LCOV - code coverage report
Current view: top level - include/reporters - StochasticReporter.h (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 43 45 95.6 %
Date: 2025-07-25 05:00:46 Functions: 28 75 37.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : // MOOSE includes
      13             : #include "GeneralReporter.h"
      14             : #include "StochasticToolsUtils.h"
      15             : #include "Sampler.h"
      16             : 
      17             : template <typename T>
      18             : class StochasticReporterContext : public ReporterGeneralContext<std::vector<T>>
      19             : {
      20             : public:
      21             :   StochasticReporterContext(const libMesh::ParallelObject & other,
      22             :                             const MooseObject & producer,
      23             :                             ReporterState<std::vector<T>> & state,
      24             :                             const Sampler & sampler);
      25             : 
      26             :   virtual void copyValuesBack() override;
      27             :   virtual void finalize() override;
      28           0 :   virtual std::string contextType() const override { return MooseUtils::prettyCppType(this); }
      29             :   virtual void storeInfo(nlohmann::json & json) const override;
      30             : 
      31             : protected:
      32             :   const Sampler & _sampler;
      33             :   bool _has_gathered;
      34             :   bool _has_allgathered;
      35             : };
      36             : 
      37             : template <typename T>
      38        9810 : StochasticReporterContext<T>::StochasticReporterContext(const libMesh::ParallelObject & other,
      39             :                                                         const MooseObject & producer,
      40             :                                                         ReporterState<std::vector<T>> & state,
      41             :                                                         const Sampler & sampler)
      42             :   : ReporterGeneralContext<std::vector<T>>(other, producer, state),
      43        9810 :     _sampler(sampler),
      44        9810 :     _has_gathered(false),
      45        9810 :     _has_allgathered(false)
      46             : {
      47        9810 :   this->_state.value().resize(_sampler.getNumberOfLocalRows());
      48        9810 : }
      49             : 
      50             : template <typename T>
      51             : void
      52       23012 : StochasticReporterContext<T>::copyValuesBack()
      53             : {
      54       23012 :   this->_state.copyValuesBack();
      55       23012 :   if (_has_allgathered || (_has_gathered && this->processor_id() == 0))
      56             :   {
      57        4570 :     auto & val = this->_state.value();
      58        4570 :     val.erase(val.begin(), val.begin() + _sampler.getLocalRowBegin());
      59        4570 :     val.erase(val.begin() + _sampler.getLocalRowEnd(), val.end());
      60             :   }
      61       23012 :   _has_gathered = false;
      62       23012 :   _has_allgathered = false;
      63       23012 : }
      64             : 
      65             : template <typename T>
      66             : void
      67       23388 : StochasticReporterContext<T>::finalize()
      68             : {
      69       23388 :   bool gather_required = this->_producer_enum == REPORTER_MODE_ROOT;
      70       23388 :   bool allgather_required = this->_producer_enum == REPORTER_MODE_REPLICATED;
      71       31344 :   for (const auto & pair : this->_state.getConsumers())
      72             :   {
      73             :     const ReporterMode consumer = pair.first;
      74        7956 :     if (consumer == REPORTER_MODE_ROOT)
      75             :       gather_required = true;
      76        7956 :     else if (consumer == REPORTER_MODE_REPLICATED)
      77             :       allgather_required = true;
      78             :   }
      79             : 
      80       23388 :   if (allgather_required && !_has_allgathered)
      81           0 :     StochasticTools::stochasticAllGather(this->comm(), this->_state.value());
      82       23388 :   else if (gather_required && !_has_gathered)
      83       13740 :     StochasticTools::stochasticGather(this->comm(), 0, this->_state.value());
      84             : 
      85       23388 :   _has_gathered = gather_required || _has_gathered;
      86       23388 :   _has_allgathered = allgather_required || _has_allgathered;
      87       23388 : }
      88             : 
      89             : template <typename T>
      90             : void
      91        1992 : StochasticReporterContext<T>::storeInfo(nlohmann::json & json) const
      92             : {
      93        1992 :   ReporterGeneralContext<std::vector<T>>::storeInfo(json);
      94        1992 :   if (_has_allgathered || (_has_gathered && this->processor_id() == 0))
      95             :   {
      96         560 :     json["row_begin"] = 0;
      97        1680 :     json["row_end"] = this->_sampler.getNumberOfRows();
      98             :   }
      99             :   else
     100             :   {
     101        2864 :     json["row_begin"] = this->_sampler.getLocalRowBegin();
     102        4296 :     json["row_end"] = this->_sampler.getLocalRowEnd();
     103             :   }
     104        1992 : }
     105             : 
     106             : class StochasticReporter : public GeneralReporter
     107             : {
     108             : public:
     109             :   static InputParameters validParams();
     110             : 
     111             :   StochasticReporter(const InputParameters & parameters);
     112        8424 :   virtual void initialize() override {}
     113        5462 :   virtual void execute() override {}
     114        8424 :   virtual void finalize() override {}
     115             : 
     116             : protected:
     117             :   virtual ReporterName declareStochasticReporterClone(const Sampler & sampler,
     118             :                                                       const ReporterData & from_data,
     119             :                                                       const ReporterName & from_reporter,
     120             :                                                       std::string prefix = "");
     121             :   template <typename T>
     122             :   std::vector<T> & declareStochasticReporter(std::string value_name, const Sampler & sampler);
     123             :   friend class SamplerReporterTransfer;
     124             : 
     125             : private:
     126             :   const unsigned int _parallel_type;
     127             : };
     128             : 
     129             : template <typename T>
     130             : std::vector<T> &
     131        9778 : StochasticReporter::declareStochasticReporter(std::string value_name, const Sampler & sampler)
     132             : {
     133        9778 :   const ReporterMode mode =
     134        9778 :       this->_parallel_type == 0 ? REPORTER_MODE_DISTRIBUTED : REPORTER_MODE_ROOT;
     135       39112 :   return this->template declareValueByName<std::vector<T>, StochasticReporterContext<T>>(
     136        9778 :       value_name, mode, sampler);
     137             : }

Generated by: LCOV version 1.14