Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : // MOOSE includes 13 : #include "GeneralReporter.h" 14 : #include "StochasticToolsUtils.h" 15 : #include "Sampler.h" 16 : 17 : template <typename T> 18 : class StochasticReporterContext : public ReporterGeneralContext<std::vector<T>> 19 : { 20 : public: 21 : StochasticReporterContext(const libMesh::ParallelObject & other, 22 : const MooseObject & producer, 23 : ReporterState<std::vector<T>> & state, 24 : const Sampler & sampler); 25 : 26 : virtual void copyValuesBack() override; 27 : virtual void finalize() override; 28 0 : virtual std::string contextType() const override { return MooseUtils::prettyCppType(this); } 29 : virtual void storeInfo(nlohmann::json & json) const override; 30 : 31 : protected: 32 : const Sampler & _sampler; 33 : bool _has_gathered; 34 : bool _has_allgathered; 35 : }; 36 : 37 : template <typename T> 38 9810 : StochasticReporterContext<T>::StochasticReporterContext(const libMesh::ParallelObject & other, 39 : const MooseObject & producer, 40 : ReporterState<std::vector<T>> & state, 41 : const Sampler & sampler) 42 : : ReporterGeneralContext<std::vector<T>>(other, producer, state), 43 9810 : _sampler(sampler), 44 9810 : _has_gathered(false), 45 9810 : _has_allgathered(false) 46 : { 47 9810 : this->_state.value().resize(_sampler.getNumberOfLocalRows()); 48 9810 : } 49 : 50 : template <typename T> 51 : void 52 23012 : StochasticReporterContext<T>::copyValuesBack() 53 : { 54 23012 : this->_state.copyValuesBack(); 55 23012 : if (_has_allgathered || (_has_gathered && this->processor_id() == 0)) 56 : { 57 4570 : auto & val = this->_state.value(); 58 4570 : val.erase(val.begin(), val.begin() + _sampler.getLocalRowBegin()); 59 4570 : val.erase(val.begin() + _sampler.getLocalRowEnd(), val.end()); 60 : } 61 23012 : _has_gathered = false; 62 23012 : _has_allgathered = false; 63 23012 : } 64 : 65 : template <typename T> 66 : void 67 23388 : StochasticReporterContext<T>::finalize() 68 : { 69 23388 : bool gather_required = this->_producer_enum == REPORTER_MODE_ROOT; 70 23388 : bool allgather_required = this->_producer_enum == REPORTER_MODE_REPLICATED; 71 31344 : for (const auto & pair : this->_state.getConsumers()) 72 : { 73 : const ReporterMode consumer = pair.first; 74 7956 : if (consumer == REPORTER_MODE_ROOT) 75 : gather_required = true; 76 7956 : else if (consumer == REPORTER_MODE_REPLICATED) 77 : allgather_required = true; 78 : } 79 : 80 23388 : if (allgather_required && !_has_allgathered) 81 0 : StochasticTools::stochasticAllGather(this->comm(), this->_state.value()); 82 23388 : else if (gather_required && !_has_gathered) 83 13740 : StochasticTools::stochasticGather(this->comm(), 0, this->_state.value()); 84 : 85 23388 : _has_gathered = gather_required || _has_gathered; 86 23388 : _has_allgathered = allgather_required || _has_allgathered; 87 23388 : } 88 : 89 : template <typename T> 90 : void 91 1992 : StochasticReporterContext<T>::storeInfo(nlohmann::json & json) const 92 : { 93 1992 : ReporterGeneralContext<std::vector<T>>::storeInfo(json); 94 1992 : if (_has_allgathered || (_has_gathered && this->processor_id() == 0)) 95 : { 96 560 : json["row_begin"] = 0; 97 1680 : json["row_end"] = this->_sampler.getNumberOfRows(); 98 : } 99 : else 100 : { 101 2864 : json["row_begin"] = this->_sampler.getLocalRowBegin(); 102 4296 : json["row_end"] = this->_sampler.getLocalRowEnd(); 103 : } 104 1992 : } 105 : 106 : class StochasticReporter : public GeneralReporter 107 : { 108 : public: 109 : static InputParameters validParams(); 110 : 111 : StochasticReporter(const InputParameters & parameters); 112 8424 : virtual void initialize() override {} 113 5462 : virtual void execute() override {} 114 8424 : virtual void finalize() override {} 115 : 116 : protected: 117 : virtual ReporterName declareStochasticReporterClone(const Sampler & sampler, 118 : const ReporterData & from_data, 119 : const ReporterName & from_reporter, 120 : std::string prefix = ""); 121 : template <typename T> 122 : std::vector<T> & declareStochasticReporter(std::string value_name, const Sampler & sampler); 123 : friend class SamplerReporterTransfer; 124 : 125 : private: 126 : const unsigned int _parallel_type; 127 : }; 128 : 129 : template <typename T> 130 : std::vector<T> & 131 9778 : StochasticReporter::declareStochasticReporter(std::string value_name, const Sampler & sampler) 132 : { 133 9778 : const ReporterMode mode = 134 9778 : this->_parallel_type == 0 ? REPORTER_MODE_DISTRIBUTED : REPORTER_MODE_ROOT; 135 39112 : return this->template declareValueByName<std::vector<T>, StochasticReporterContext<T>>( 136 9778 : value_name, mode, sampler); 137 : }