LCOV - code coverage report
Current view: top level - include/libtorch/controls - LibtorchNeuralNetControl.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 8601ad Lines: 2 2 100.0 %
Date: 2025-07-18 13:27:08 Functions: 2 2 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef LIBTORCH_ENABLED
      11             : 
      12             : #pragma once
      13             : 
      14             : #include "LibtorchArtificialNeuralNet.h"
      15             : #include "Control.h"
      16             : 
      17             : /**
      18             :  * A time-dependent, neural network-based control of multiple input parameters.
      19             :  * The control strategy depends on the training of the neural net, which is
      20             :  * typically done in a trainer object in the main app. Alternatively, the user can read
      21             :  * neural networks using two formats:
      22             :  * 1. Torchscript format (from python)
      23             :  * 2. Regular data format containing the parameter values of the neural net
      24             :  */
      25             : class LibtorchNeuralNetControl : public Control
      26             : {
      27             : public:
      28             :   static InputParameters validParams();
      29             : 
      30             :   /// Construct using input parameters
      31             :   LibtorchNeuralNetControl(const InputParameters & parameters);
      32             : 
      33             :   /// Execute neural network to determine the controllable parameter values
      34             :   virtual void execute() override;
      35             : 
      36             :   /**
      37             :    * Get the (signal_index)-th signal of the control neural net
      38             :    * @param signal_index The index of the queried control signal
      39             :    * @return The (signal_index)-th control signal
      40             :    */
      41             :   Real getSignal(const unsigned int signal_index) const;
      42             : 
      43             :   /// Get the number of controls this object is computing (one per name in _control_names)
      44           3 :   unsigned int numberOfControlSignals() const { return _control_names.size(); }
      45             : 
      46             :   /**
      47             :    * Function responsible for loading the neural network for the controller. This function is used
      48             :    * when copying the neural network from a main app which trains it.
      49             :    * @param input_nn Reference to a neural network which will be copied into this object
      50             :    */
      51             :   void loadControlNeuralNet(const Moose::LibtorchArtificialNeuralNet & input_nn);
      52             : 
      53             :   /// Return a reference to the stored neural network
      54             :   const Moose::LibtorchNeuralNetBase & controlNeuralNet() const;
      55             : 
      56             :   /// Return true if a neural network has already been stored in this object (_nn is non-null)
      57           1 :   bool hasControlNeuralNet() const { return _nn != nullptr; }
      58             : 
      59             : protected:
      60             :   /**
      61             :    * Function responsible for checking for potential user errors in the input file
      62             :    * @param param_name The name of the main parameter
      63             :    * @param conditional_param Vector of parameter names that depend on the main parameter
      64             :    * @param should_be_defined If the conditional parameters should be defined when the main
      65             :    * parameter is defined
      66             :    */
      67             :   void conditionalParameterError(const std::string & param_name,
      68             :                                  const std::vector<std::string> & conditional_param,
      69             :                                  bool should_be_defined = true);
      70             : 
      71             :   /// Function that updates the values of the current response
      72             :   void updateCurrentResponse();
      73             : 
      74             :   /// Function that prepares the input tensor for the controller neural network
      75             :   torch::Tensor prepareInputTensor();
      76             : 
      77             :   /// The current values of the observed postprocessors
      78             :   std::vector<Real> _current_response;
      79             :   /// This variable is populated if the controller needs access to older values of the
      80             :   /// observed postprocessors
      81             :   std::vector<std::vector<Real>> & _old_responses;
      82             : 
      83             :   /// The names of the controllable parameters
      84             :   const std::vector<std::string> & _control_names;
      85             :   /// The control signals from the last evaluation of the controller
      86             :   std::vector<Real> _current_control_signals;
      87             : 
      88             :   /// Names of the postprocessors which contain the observations of the system
      89             :   const std::vector<PostprocessorName> & _response_names;
      90             : 
      91             :   /// Links to the current response postprocessor values. This is necessary so that we can check
      92             :   /// if the postprocessors exist.
      93             :   std::vector<const Real *> _response_values;
      94             : 
      95             :   /// Number of timesteps to use as input data from the response postprocessors (this influences
      96             :   /// how many past results are used, e.g. the size of _old_responses)
      97             :   const unsigned int _input_timesteps;
      98             : 
      99             :   /// Shifting constants for the responses
     100             :   const std::vector<Real> _response_shift_factors;
     101             :   /// Scaling constants (multipliers) for the responses
     102             :   const std::vector<Real> _response_scaling_factors;
     103             :   /// Multipliers for the actions
     104             :   const std::vector<Real> _action_scaling_factors;
     105             : 
     106             :   /// Pointer to the neural net object which is supposed to be used to control
     107             :   /// the parameter values. The controller owns this object, but it can be read
     108             :   /// from file or copied by a transfer.
     109             :   std::shared_ptr<Moose::LibtorchNeuralNetBase> _nn;
     110             : };
     111             : 
     112             : #endif

Generated by: LCOV version 1.14