Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef MOOSE_LIBTORCH_ENABLED

#include "LibtorchNeuralNetControlTransfer.h"
#include "LibtorchNeuralNetControl.h"

registerMooseObject("StochasticToolsApp", LibtorchNeuralNetControlTransfer);

/**
 * Declares the input parameters of the transfer.
 *
 * On top of the MultiAppTransfer and SurrogateModelInterface parameters, the user must name
 * the trainer object ("trainer_name") that owns the control neural net and the control
 * object ("control_name") on the subapp that should receive it.
 */
InputParameters
LibtorchNeuralNetControlTransfer::validParams()
{
  InputParameters params = MultiAppTransfer::validParams();
  params += SurrogateModelInterface::validParams();

  params.addClassDescription("Copies a neural network from a trainer object on the main app to a "
                             "LibtorchNeuralNetControl object on the subapp.");

  // The transfer always goes from the main app to a subapp (see the class description), so the
  // source-app and direction parameters are hidden from the user.
  params.suppressParameter<MultiAppName>("from_multi_app");
  params.suppressParameter<MultiAppName>("multi_app");
  params.suppressParameter<MultiMooseEnum>("direction");

  params.addRequiredParam<UserObjectName>("trainer_name",
                                          "Trainer object that contains the neural networks "
                                          "for different samples.");
  params.addRequiredParam<std::string>("control_name", "Controller object name.");
  return params;
}

/**
 * Stores the name of the target control object and looks up the trainer (by the
 * "trainer_name" parameter) as a LibtorchDRLControlTrainer through the surrogate interface.
 */
LibtorchNeuralNetControlTransfer::LibtorchNeuralNetControlTransfer(
    const InputParameters & parameters)
  : MultiAppTransfer(parameters),
    SurrogateModelInterface(this),
    _control_name(getParam<std::string>("control_name")),
    _trainer(getSurrogateTrainerByName<LibtorchDRLControlTrainer>(
        getParam<UserObjectName>("trainer_name")))
{
}

/**
 * Copies the trainer's control neural net into the LibtorchNeuralNetControl named by
 * "control_name" on subapp 0, then executes that control once so that it produces initial
 * values with the new network.
 *
 * Raises a parameter error on "control_name" if the named active control is not a
 * LibtorchNeuralNetControl.
 */
void
LibtorchNeuralNetControlTransfer::execute()
{
  // Get the control neural net from the trainer
  const Moose::LibtorchArtificialNeuralNet & trainer_nn = _trainer.controlNeuralNet();

  // Only the target subapp (index 0) is touched. If it is not owned by this processor,
  // appProblemBase(0) cannot be queried here, and there is nothing to update locally.
  if (!_multi_app->hasLocalApp(0))
    return;

  // Get the control object from the other app
  FEProblemBase & app_problem = _multi_app->appProblemBase(0);
  auto & control_warehouse = app_problem.getControlWarehouse();
  std::shared_ptr<Control> control_ptr = control_warehouse.getActiveObject(_control_name);
  LibtorchNeuralNetControl * control_object =
      dynamic_cast<LibtorchNeuralNetControl *>(control_ptr.get());

  // paramError does not return, so control_object is non-null past this point
  if (!control_object)
    paramError("control_name", "The given control is not a LibtorchNeuralNetControl!");

  // Copy the neural net and execute the control to get the initial values
  control_object->loadControlNeuralNet(trainer_nn);
  control_object->execute();
}
#endif