Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : // MOOSE includes 13 : #include "GeneralReporter.h" 14 : #include "StochasticToolsUtils.h" 15 : #include "Sampler.h" 16 : 17 : template <typename T> 18 : class StochasticReporterContext : public ReporterGeneralContext<std::vector<T>> 19 : { 20 : public: 21 : StochasticReporterContext(const libMesh::ParallelObject & other, 22 : const MooseObject & producer, 23 : ReporterState<std::vector<T>> & state, 24 : const Sampler & sampler); 25 : 26 : virtual void copyValuesBack() override; 27 : virtual void finalize() override; 28 0 : virtual std::string contextType() const override { return MooseUtils::prettyCppType(this); } 29 : virtual void storeInfo(nlohmann::json & json) const override; 30 : 31 : protected: 32 : const Sampler & _sampler; 33 : bool _has_gathered; 34 : bool _has_allgathered; 35 : }; 36 : 37 : template <typename T> 38 10422 : StochasticReporterContext<T>::StochasticReporterContext(const libMesh::ParallelObject & other, 39 : const MooseObject & producer, 40 : ReporterState<std::vector<T>> & state, 41 : const Sampler & sampler) 42 : : ReporterGeneralContext<std::vector<T>>(other, producer, state), 43 10422 : _sampler(sampler), 44 10422 : _has_gathered(false), 45 10422 : _has_allgathered(false) 46 : { 47 10422 : this->_state.value().resize(_sampler.getNumberOfLocalRows()); 48 10422 : } 49 : 50 : template <typename T> 51 : void 52 24689 : StochasticReporterContext<T>::copyValuesBack() 53 : { 54 24689 : this->_state.copyValuesBack(); 55 24689 : if (_has_allgathered || (_has_gathered && this->processor_id() == 0)) 56 : { 57 5027 : auto & val = this->_state.value(); 58 5027 : val.erase(val.begin(), val.begin() + _sampler.getLocalRowBegin()); 59 5027 : val.erase(val.begin() + _sampler.getLocalRowEnd(), val.end()); 60 : } 61 24689 : _has_gathered = false; 62 24689 : _has_allgathered = false; 63 24689 : } 64 : 65 : template <typename T> 66 : void 67 25096 : StochasticReporterContext<T>::finalize() 68 : { 69 25096 : bool gather_required = this->_producer_enum == REPORTER_MODE_ROOT; 70 25096 : bool allgather_required = this->_producer_enum == REPORTER_MODE_REPLICATED; 71 33674 : for (const auto & pair : this->_state.getConsumers()) 72 : { 73 : const ReporterMode consumer = pair.first; 74 8578 : if (consumer == REPORTER_MODE_ROOT) 75 : gather_required = true; 76 8578 : else if (consumer == REPORTER_MODE_REPLICATED) 77 : allgather_required = true; 78 : } 79 : 80 25096 : if (allgather_required && !_has_allgathered) 81 0 : StochasticTools::stochasticAllGather(this->comm(), this->_state.value()); 82 25096 : else if (gather_required && !_has_gathered) 83 14635 : StochasticTools::stochasticGather(this->comm(), 0, this->_state.value()); 84 : 85 25096 : _has_gathered = gather_required || _has_gathered; 86 25096 : _has_allgathered = allgather_required || _has_allgathered; 87 25096 : } 88 : 89 : template <typename T> 90 : void 91 2182 : StochasticReporterContext<T>::storeInfo(nlohmann::json & json) const 92 : { 93 2182 : ReporterGeneralContext<std::vector<T>>::storeInfo(json); 94 2182 : if (_has_allgathered || (_has_gathered && this->processor_id() == 0)) 95 : { 96 608 : json["row_begin"] = 0; 97 1824 : json["row_end"] = this->_sampler.getNumberOfRows(); 98 : } 99 : else 100 : { 101 3148 : json["row_begin"] = this->_sampler.getLocalRowBegin(); 102 4722 : json["row_end"] = this->_sampler.getLocalRowEnd(); 103 : } 104 2182 : } 105 : 106 : class StochasticReporter : public GeneralReporter 107 : { 108 : public: 109 : static InputParameters validParams(); 110 : 111 : StochasticReporter(const InputParameters & parameters); 112 9083 : virtual void initialize() override {} 113 5914 : virtual void execute() override {} 114 9083 : virtual void finalize() override {} 115 : 116 : protected: 117 : virtual ReporterName declareStochasticReporterClone(const Sampler & sampler, 118 : const ReporterData & from_data, 119 : const ReporterName & from_reporter, 120 : std::string prefix = ""); 121 : template <typename T> 122 : std::vector<T> & declareStochasticReporter(std::string value_name, const Sampler & sampler); 123 : friend class SamplerReporterTransfer; 124 : 125 : private: 126 : const unsigned int _parallel_type; 127 : }; 128 : 129 : template <typename T> 130 : std::vector<T> & 131 10388 : StochasticReporter::declareStochasticReporter(std::string value_name, const Sampler & sampler) 132 : { 133 10388 : const ReporterMode mode = 134 10388 : this->_parallel_type == 0 ? REPORTER_MODE_DISTRIBUTED : REPORTER_MODE_ROOT; 135 41552 : return this->template declareValueByName<std::vector<T>, StochasticReporterContext<T>>( 136 10388 : value_name, mode, sampler); 137 : }