LCOV - code coverage report
Current view: top level - src/tensor_solver - TensorSolver.C (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 24 29 82.8 %
Date: 2025-09-10 17:10:32 Functions: 5 6 83.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #include "TensorSolver.h"
      10             : #include "TensorProblem.h"
      11             : #include "SwiftTypes.h"
      12             : 
      13             : InputParameters
      14         222 : TensorSolver::validParams()
      15             : {
      16         222 :   InputParameters params = TensorOperatorBase::validParams();
      17         222 :   params.registerBase("TensorSolver");
      18         444 :   params.addParam<TensorComputeName>(
      19             :       "root_compute",
      20             :       "Primary compute object that updates the buffers. This is usually a "
      21             :       "ComputeGroup object. A ComputeGroup encompassing all computes will be generated "
      22             :       "automatically if the user does not provide this parameter.");
      23         222 :   params.addClassDescription("TensorSolver object.");
      24             : 
      25         444 :   params.addParam<std::vector<TensorOutputBufferName>>(
      26             :       "forward_buffer",
      27             :       {},
      28             :       "These buffers are updated with the corresponding buffers from forward_buffer_old. No "
      29             :       "integration is performed. Buffer forwarding is used only to resolve cyclic dependencies.");
      30         444 :   params.addParam<std::vector<TensorInputBufferName>>(
      31             :       "forward_buffer_new", {}, "New values to update `forward_buffer` with.");
      32             : 
      33         222 :   return params;
      34           0 : }
      35             : 
      36         110 : TensorSolver::TensorSolver(const InputParameters & parameters)
      37         110 :   : TensorOperatorBase(parameters), _dt(_tensor_problem.dt()), _dt_old(_tensor_problem.dtOld())
      38             : {
      39             :   const auto & forward_buffer_names = getParam<TensorOutputBufferName, TensorOutputBufferName>(
      40         220 :       "forward_buffer", "forward_buffer_new");
      41         110 :   for (const auto & [forward_buffer, forward_buffer_new] : forward_buffer_names)
      42           0 :     _forwarded_buffers.emplace_back(getOutputBufferByName(forward_buffer),
      43             :                                     getInputBufferByName(forward_buffer_new));
      44         110 : }
      45             : 
      46             : const std::vector<torch::Tensor> &
      47           0 : TensorSolver::getBufferOld(const std::string & param, unsigned int max_states)
      48             : {
      49           0 :   return getBufferOldByName(getParam<TensorInputBufferName>(param), max_states);
      50             : }
      51             : 
      52             : const std::vector<torch::Tensor> &
      53         110 : TensorSolver::getBufferOldByName(const TensorInputBufferName & buffer_name, unsigned int max_states)
      54             : {
      55         110 :   return _tensor_problem.getBufferOld(buffer_name, max_states);
      56             : }
      57             : 
      58             : void
      59         110 : TensorSolver::updateDependencies()
      60             : {
      61             :   // the compute that's being solved for (usually a ComputeGroup)
      62         110 :   const auto & root_name = getParam<TensorComputeName>("root_compute");
      63         410 :   for (const auto & cmp : _tensor_problem.getComputes())
      64         410 :     if (cmp->name() == root_name)
      65             :     {
      66             :       _compute = cmp;
      67         110 :       _compute->updateDependencies();
      68         110 :       return;
      69             :     }
      70             : 
      71           0 :   paramError("root_compute", "Compute object not found.");
      72             : }
      73             : 
      74             : void
      75       82784 : TensorSolver::forwardBuffers()
      76             : {
      77       82784 :   for (const auto & [forward_buffer, forward_buffer_new] : _forwarded_buffers)
      78             :     forward_buffer = forward_buffer_new;
      79       82784 : }

Generated by: LCOV version 1.14