Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : // MOOSE includes 13 : #include "GeneralReporter.h" 14 : #include "StochasticToolsUtils.h" 15 : #include "Sampler.h" 16 : 17 : template <typename T> 18 : class StochasticReporterContext : public ReporterGeneralContext<std::vector<T>> 19 : { 20 : public: 21 : StochasticReporterContext(const libMesh::ParallelObject & other, 22 : const MooseObject & producer, 23 : ReporterState<std::vector<T>> & state, 24 : const Sampler & sampler); 25 : 26 : virtual void copyValuesBack() override; 27 : virtual void finalize() override; 28 0 : virtual std::string contextType() const override { return MooseUtils::prettyCppType(this); } 29 : virtual void storeInfo(nlohmann::json & json) const override; 30 : 31 : protected: 32 : const Sampler & _sampler; 33 : bool _has_gathered; 34 : bool _has_allgathered; 35 : }; 36 : 37 : template <typename T> 38 11247 : StochasticReporterContext<T>::StochasticReporterContext(const libMesh::ParallelObject & other, 39 : const MooseObject & producer, 40 : ReporterState<std::vector<T>> & state, 41 : const Sampler & sampler) 42 : : ReporterGeneralContext<std::vector<T>>(other, producer, state), 43 11247 : _sampler(sampler), 44 11247 : _has_gathered(false), 45 11247 : _has_allgathered(false) 46 : { 47 11247 : this->_state.value().resize(_sampler.getNumberOfLocalRows()); 48 11247 : } 49 : 50 : template <typename T> 51 : void 52 29215 : StochasticReporterContext<T>::copyValuesBack() 53 : { 54 29215 : this->_state.copyValuesBack(); 55 29215 : if (_has_allgathered || (_has_gathered && this->processor_id() == 0)) 56 : { 57 7411 : auto & val = this->_state.value(); 58 7411 : val.erase(val.begin(), val.begin() + _sampler.getLocalRowBegin()); 59 7411 : val.erase(val.begin() + _sampler.getLocalRowEnd(), val.end()); 60 : } 61 29215 : _has_gathered = false; 62 29215 : _has_allgathered = false; 63 29215 : } 64 : 65 : template <typename T> 66 : void 67 29690 : StochasticReporterContext<T>::finalize() 68 : { 69 29690 : bool gather_required = this->_producer_enum == REPORTER_MODE_ROOT; 70 29690 : bool allgather_required = this->_producer_enum == REPORTER_MODE_REPLICATED; 71 38268 : for (const auto & pair : this->_state.getConsumers()) 72 : { 73 : const ReporterMode consumer = pair.first; 74 8578 : if (consumer == REPORTER_MODE_ROOT) 75 : gather_required = true; 76 8578 : else if (consumer == REPORTER_MODE_REPLICATED) 77 : allgather_required = true; 78 : } 79 : 80 29690 : if (allgather_required && !_has_allgathered) 81 0 : StochasticTools::stochasticAllGather(this->comm(), this->_state.value()); 82 29690 : else if (gather_required && !_has_gathered) 83 19229 : StochasticTools::stochasticGather(this->comm(), 0, this->_state.value()); 84 : 85 29690 : _has_gathered = gather_required || _has_gathered; 86 29690 : _has_allgathered = allgather_required || _has_allgathered; 87 29690 : } 88 : 89 : template <typename T> 90 : void 91 2242 : StochasticReporterContext<T>::storeInfo(nlohmann::json & json) const 92 : { 93 2242 : ReporterGeneralContext<std::vector<T>>::storeInfo(json); 94 2242 : if (_has_allgathered || (_has_gathered && this->processor_id() == 0)) 95 : { 96 668 : json["row_begin"] = 0; 97 2004 : json["row_end"] = this->_sampler.getNumberOfRows(); 98 : } 99 : else 100 : { 101 3148 : json["row_begin"] = this->_sampler.getLocalRowBegin(); 102 4722 : json["row_end"] = this->_sampler.getLocalRowEnd(); 103 : } 104 2242 : } 105 : 106 : /** 107 : * This is a non-typed base class of the stochastic vector value, used to update 108 : * reporter values during StochasticReporter::initialize() call. 109 : */ 110 : class StochasticReporterValueBase 111 : { 112 : public: 113 11213 : StochasticReporterValueBase(const Sampler & sampler) : _sampler(sampler) {} 114 0 : virtual ~StochasticReporterValueBase() = default; 115 : 116 0 : virtual void initialize() {} 117 : 118 : protected: 119 : const Sampler & _sampler; 120 : }; 121 : 122 : template <typename T> 123 : class StochasticReporterValue : public StochasticReporterValueBase 124 : { 125 : public: 126 11213 : StochasticReporterValue(std::vector<T> & value, const Sampler & sampler) 127 11213 : : StochasticReporterValueBase(sampler), _value(value) 128 : { 129 : } 130 : 131 33962 : virtual void initialize() { this->_value.resize(this->_sampler.getNumberOfLocalRows()); } 132 : 133 : private: 134 : std::vector<T> & _value; 135 : }; 136 : 137 : class StochasticReporter : public GeneralReporter 138 : { 139 : public: 140 : static InputParameters validParams(); 141 : 142 : StochasticReporter(const InputParameters & parameters); 143 : virtual void initialize() override final; 144 5914 : virtual void execute() override {} 145 9622 : virtual void finalize() override {} 146 : 147 : protected: 148 : virtual ReporterName declareStochasticReporterClone(const Sampler & sampler, 149 : const ReporterData & from_data, 150 : const ReporterName & from_reporter, 151 : std::string prefix = ""); 152 : template <typename T> 153 : std::vector<T> & declareStochasticReporter(std::string value_name, const Sampler & sampler); 154 : friend class SamplerReporterTransfer; 155 : 156 : private: 157 : const unsigned int _parallel_type; 158 : /// Container for declared values that we may need to resize at initialize 159 : std::deque<std::unique_ptr<StochasticReporterValueBase>> _vectors; 160 : }; 161 : 162 : template <typename T> 163 : std::vector<T> & 164 11213 : StochasticReporter::declareStochasticReporter(std::string value_name, const Sampler & sampler) 165 : { 166 11213 : const ReporterMode mode = 167 11213 : this->_parallel_type == 0 ? REPORTER_MODE_DISTRIBUTED : REPORTER_MODE_ROOT; 168 : std::vector<T> & vector = 169 33639 : this->template declareValueByName<std::vector<T>, StochasticReporterContext<T>>( 170 : value_name, mode, sampler); 171 : 172 22426 : _vectors.push_back(std::make_unique<StochasticReporterValue<T>>(vector, sampler)); 173 11213 : return vector; 174 : }