LCOV - code coverage report
Current view: top level - include/reporters - StochasticReporter.h (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: #31706 (f8ed4a) with base bb0a08 Lines: 47 51 92.2 %
Date: 2025-11-03 17:29:08 Functions: 32 89 36.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : // MOOSE includes
      13             : #include "GeneralReporter.h"
      14             : #include "StochasticToolsUtils.h"
      15             : #include "Sampler.h"
      16             : 
      17             : template <typename T>
      18             : class StochasticReporterContext : public ReporterGeneralContext<std::vector<T>>
      19             : {
      20             : public:
      21             :   StochasticReporterContext(const libMesh::ParallelObject & other,
      22             :                             const MooseObject & producer,
      23             :                             ReporterState<std::vector<T>> & state,
      24             :                             const Sampler & sampler);
      25             : 
      26             :   virtual void copyValuesBack() override;
      27             :   virtual void finalize() override;
      28           0 :   virtual std::string contextType() const override { return MooseUtils::prettyCppType(this); }
      29             :   virtual void storeInfo(nlohmann::json & json) const override;
      30             : 
      31             : protected:
      32             :   const Sampler & _sampler;
      33             :   bool _has_gathered;
      34             :   bool _has_allgathered;
      35             : };
      36             : 
      37             : template <typename T>
      38       11247 : StochasticReporterContext<T>::StochasticReporterContext(const libMesh::ParallelObject & other,
      39             :                                                         const MooseObject & producer,
      40             :                                                         ReporterState<std::vector<T>> & state,
      41             :                                                         const Sampler & sampler)
      42             :   : ReporterGeneralContext<std::vector<T>>(other, producer, state),
      43       11247 :     _sampler(sampler),
      44       11247 :     _has_gathered(false),
      45       11247 :     _has_allgathered(false)
      46             : {
      47       11247 :   this->_state.value().resize(_sampler.getNumberOfLocalRows());
      48       11247 : }
      49             : 
      50             : template <typename T>
      51             : void
      52       29215 : StochasticReporterContext<T>::copyValuesBack()
      53             : {
      54       29215 :   this->_state.copyValuesBack();
      55       29215 :   if (_has_allgathered || (_has_gathered && this->processor_id() == 0))
      56             :   {
      57        7411 :     auto & val = this->_state.value();
      58        7411 :     val.erase(val.begin(), val.begin() + _sampler.getLocalRowBegin());
      59        7411 :     val.erase(val.begin() + _sampler.getLocalRowEnd(), val.end());
      60             :   }
      61       29215 :   _has_gathered = false;
      62       29215 :   _has_allgathered = false;
      63       29215 : }
      64             : 
      65             : template <typename T>
      66             : void
      67       29690 : StochasticReporterContext<T>::finalize()
      68             : {
      69       29690 :   bool gather_required = this->_producer_enum == REPORTER_MODE_ROOT;
      70       29690 :   bool allgather_required = this->_producer_enum == REPORTER_MODE_REPLICATED;
      71       38268 :   for (const auto & pair : this->_state.getConsumers())
      72             :   {
      73             :     const ReporterMode consumer = pair.first;
      74        8578 :     if (consumer == REPORTER_MODE_ROOT)
      75             :       gather_required = true;
      76        8578 :     else if (consumer == REPORTER_MODE_REPLICATED)
      77             :       allgather_required = true;
      78             :   }
      79             : 
      80       29690 :   if (allgather_required && !_has_allgathered)
      81           0 :     StochasticTools::stochasticAllGather(this->comm(), this->_state.value());
      82       29690 :   else if (gather_required && !_has_gathered)
      83       19229 :     StochasticTools::stochasticGather(this->comm(), 0, this->_state.value());
      84             : 
      85       29690 :   _has_gathered = gather_required || _has_gathered;
      86       29690 :   _has_allgathered = allgather_required || _has_allgathered;
      87       29690 : }
      88             : 
      89             : template <typename T>
      90             : void
      91        2242 : StochasticReporterContext<T>::storeInfo(nlohmann::json & json) const
      92             : {
      93        2242 :   ReporterGeneralContext<std::vector<T>>::storeInfo(json);
      94        2242 :   if (_has_allgathered || (_has_gathered && this->processor_id() == 0))
      95             :   {
      96         668 :     json["row_begin"] = 0;
      97        2004 :     json["row_end"] = this->_sampler.getNumberOfRows();
      98             :   }
      99             :   else
     100             :   {
     101        3148 :     json["row_begin"] = this->_sampler.getLocalRowBegin();
     102        4722 :     json["row_end"] = this->_sampler.getLocalRowEnd();
     103             :   }
     104        2242 : }
     105             : 
     106             : /**
     107             :  * This is a non-typed base class of the stochastic vector value, used to update
     108             :  * reporter values during StochasticReporter::initialize() call.
     109             :  */
     110             : class StochasticReporterValueBase
     111             : {
     112             : public:
     113       11213 :   StochasticReporterValueBase(const Sampler & sampler) : _sampler(sampler) {}
     114           0 :   virtual ~StochasticReporterValueBase() = default;
     115             : 
     116           0 :   virtual void initialize() {}
     117             : 
     118             : protected:
     119             :   const Sampler & _sampler;
     120             : };
     121             : 
     122             : template <typename T>
     123             : class StochasticReporterValue : public StochasticReporterValueBase
     124             : {
     125             : public:
     126       11213 :   StochasticReporterValue(std::vector<T> & value, const Sampler & sampler)
     127       11213 :     : StochasticReporterValueBase(sampler), _value(value)
     128             :   {
     129             :   }
     130             : 
     131       33962 :   virtual void initialize() { this->_value.resize(this->_sampler.getNumberOfLocalRows()); }
     132             : 
     133             : private:
     134             :   std::vector<T> & _value;
     135             : };
     136             : 
     137             : class StochasticReporter : public GeneralReporter
     138             : {
     139             : public:
     140             :   static InputParameters validParams();
     141             : 
     142             :   StochasticReporter(const InputParameters & parameters);
     143             :   virtual void initialize() override final;
     144        5914 :   virtual void execute() override {}
     145        9622 :   virtual void finalize() override {}
     146             : 
     147             : protected:
     148             :   virtual ReporterName declareStochasticReporterClone(const Sampler & sampler,
     149             :                                                       const ReporterData & from_data,
     150             :                                                       const ReporterName & from_reporter,
     151             :                                                       std::string prefix = "");
     152             :   template <typename T>
     153             :   std::vector<T> & declareStochasticReporter(std::string value_name, const Sampler & sampler);
     154             :   friend class SamplerReporterTransfer;
     155             : 
     156             : private:
     157             :   const unsigned int _parallel_type;
     158             :   /// Container for declared values that we may need to resize at initialize
     159             :   std::deque<std::unique_ptr<StochasticReporterValueBase>> _vectors;
     160             : };
     161             : 
     162             : template <typename T>
     163             : std::vector<T> &
     164       11213 : StochasticReporter::declareStochasticReporter(std::string value_name, const Sampler & sampler)
     165             : {
     166       11213 :   const ReporterMode mode =
     167       11213 :       this->_parallel_type == 0 ? REPORTER_MODE_DISTRIBUTED : REPORTER_MODE_ROOT;
     168             :   std::vector<T> & vector =
     169       33639 :       this->template declareValueByName<std::vector<T>, StochasticReporterContext<T>>(
     170             :           value_name, mode, sampler);
     171             : 
     172       22426 :   _vectors.push_back(std::make_unique<StochasticReporterValue<T>>(vector, sampler));
     173       11213 :   return vector;
     174             : }

Generated by: LCOV version 1.14