(moose/test/src/libtorch/vectorpostprocessors/TorchScriptModuleTest.C)
// This file is part of the MOOSE framework
// https://mooseframework.inl.gov
//
// All rights reserved, see COPYRIGHT for full restrictions
// https://github.com/idaholab/moose/blob/master/COPYRIGHT
//
// Licensed under LGPL 2.1, please see LICENSE for details
// https://www.gnu.org/licenses/lgpl-2.1.html
#ifdef MOOSE_LIBTORCH_ENABLED
#include <torch/torch.h>
#include "TorchScriptModule.h"
#include "TorchScriptModuleTest.h"
// Register TorchScriptModuleTest with the "MooseTestApp" application's object registry.
registerMooseObject("MooseTestApp", TorchScriptModuleTest);
InputParameters
TorchScriptModuleTest::validParams()
{
  // Inherit everything a GeneralVectorPostprocessor already accepts.
  auto params = GeneralVectorPostprocessor::validParams();
  params.addClassDescription("Evaluates a neural network saved from python.");

  // Location of the serialized TorchScript module to load (defaults to "my_net.pt").
  params.addParam<std::string>(
      "filename", "my_net.pt", "The name of the file where the torch script is saved.");

  return params;
}
TorchScriptModuleTest::TorchScriptModuleTest(const InputParameters & params)
  : GeneralVectorPostprocessor(params),
    _x_values(declareVector("x_values")),
    _y_values(declareVector("y_values")),
    _z_values(declareVector("z_values")),
    _nn_values(declareVector("nn_values"))
{
  // Load the neural net from the file given by the "filename" parameter
  // (prepared in python). The module is only used inside this constructor,
  // so unique ownership is sufficient.
  auto nn = std::make_unique<Moose::TorchScriptModule>();
  nn->loadNeuralNetwork(getParam<std::string>("filename"));

  // Evaluate the network at evenly spaced points on the diagonal x = y = z,
  // i.e. at (0, 0, 0), (0.5, 0.5, 0.5) and (1, 1, 1). Each point and the
  // corresponding scalar prediction are appended to the VPP vectors.
  constexpr unsigned int num_points = 3;
  constexpr double spacing = 0.5;

  for (unsigned int i = 0; i < num_points; ++i)
  {
    const double coord = i * spacing;

    // A single (1 x 3) double-precision input row filled with the coordinate.
    const torch::Tensor input = torch::full({1, 3}, coord, at::kDouble);
    const torch::Tensor prediction = nn->forward(input);

    _x_values.push_back(coord);
    _y_values.push_back(coord);
    _z_values.push_back(coord);
    // NOTE(review): item<double>() assumes the network returns a single-element tensor.
    _nn_values.push_back(prediction.item<double>());
  }
}
#endif