LCOV - code coverage report
Current view: top level - src/libtorch/transfers - LibtorchNeuralNetControlTransfer.C (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 25 27 92.6 %
Date: 2025-07-25 05:00:46 Functions: 3 3 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef MOOSE_LIBTORCH_ENABLED
      11             : 
      12             : #include "LibtorchNeuralNetControlTransfer.h"
      13             : #include "LibtorchNeuralNetControl.h"
      14             : 
      15             : registerMooseObject("StochasticToolsApp", LibtorchNeuralNetControlTransfer);
      16             : 
      17             : InputParameters
      18          32 : LibtorchNeuralNetControlTransfer::validParams()
      19             : {
      20          32 :   InputParameters params = MultiAppTransfer::validParams();
      21          32 :   params += SurrogateModelInterface::validParams();
      22             : 
      23          32 :   params.addClassDescription("Copies a neural network from a trainer object on the main app to a "
      24             :                              "LibtorchNeuralNetControl object on the subapp.");
      25             : 
      26          32 :   params.suppressParameter<MultiAppName>("from_multi_app");
      27          32 :   params.suppressParameter<MultiAppName>("multi_app");
      28          32 :   params.suppressParameter<MultiMooseEnum>("direction");
      29             : 
      30          64 :   params.addRequiredParam<UserObjectName>("trainer_name",
      31             :                                           "Trainer object that contains the neural networks."
      32             :                                           " for different samples.");
      33          64 :   params.addRequiredParam<std::string>("control_name", "Controller object name.");
      34          32 :   return params;
      35           0 : }
      36             : 
      37          16 : LibtorchNeuralNetControlTransfer::LibtorchNeuralNetControlTransfer(
      38          16 :     const InputParameters & parameters)
      39             :   : MultiAppTransfer(parameters),
      40             :     SurrogateModelInterface(this),
      41          16 :     _control_name(getParam<std::string>("control_name")),
      42          32 :     _trainer(getSurrogateTrainerByName<LibtorchDRLControlTrainer>(
      43          16 :         getParam<UserObjectName>("trainer_name")))
      44             : {
      45          16 : }
      46             : 
      47             : void
      48          16 : LibtorchNeuralNetControlTransfer::execute()
      49             : {
      50             :   // Get the control neural net from the trainer
      51          16 :   const Moose::LibtorchArtificialNeuralNet & trainer_nn = _trainer.controlNeuralNet();
      52             : 
      53             :   // Get the control object from the other app
      54          16 :   FEProblemBase & app_problem = _multi_app->appProblemBase(0);
      55             :   auto & control_warehouse = app_problem.getControlWarehouse();
      56          16 :   std::shared_ptr<Control> control_ptr = control_warehouse.getActiveObject(_control_name);
      57             :   LibtorchNeuralNetControl * control_object =
      58          16 :       dynamic_cast<LibtorchNeuralNetControl *>(control_ptr.get());
      59             : 
      60          16 :   if (!control_object)
      61           0 :     paramError("control_name", "The given gontrol is not a LibtorchNeuralNetrControl!");
      62             : 
      63             :   // Copy and the neural net and execute it to get the initial values
      64          16 :   control_object->loadControlNeuralNet(trainer_nn);
      65          16 :   control_object->execute();
      66          16 : }
      67             : #endif

Generated by: LCOV version 1.14