Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #include "TensorHistogram.h" 10 : #include "SwiftUtils.h" 11 : 12 : #include <ATen/ATen.h> 13 : #include <ATen/native/Histogram.h> 14 : 15 : registerMooseObject("SwiftApp", TensorHistogram); 16 : 17 : InputParameters 18 8 : TensorHistogram::validParams() 19 : { 20 8 : InputParameters params = TensorVectorPostprocessor::validParams(); 21 8 : params.addClassDescription("Compute a histogram of the given tensor."); 22 16 : params.addRequiredParam<Real>("min", "Lower bound of the histogram."); 23 16 : params.addRequiredParam<Real>("max", "Upper bound of the histogram."); 24 16 : params.addRequiredRangeCheckedParam<std::size_t>("bins", "bins>0", "Number of histogram bins."); 25 8 : return params; 26 0 : } 27 : 28 4 : TensorHistogram::TensorHistogram(const InputParameters & parameters) 29 : : TensorVectorPostprocessor(parameters), 30 4 : _min(getParam<Real>("min")), 31 8 : _max(getParam<Real>("max")), 32 8 : _bins(getParam<std::size_t>("bins")), 33 8 : _bin_edges(torch::linspace(_min, _max, _bins + 1, MooseTensor::floatTensorOptions())), 34 4 : _bin_vec(declareVector("bin")), 35 8 : _count_vec(declareVector("count")) 36 : { 37 : // error check (if this is not fulfilled the histogram will be empty) 38 4 : if (_min > _max) 39 0 : paramError("min", "max must be greater than min"); 40 : 41 : // fill the bin vector 42 4 : _bin_vec.resize(_bins); 43 4 : _count_vec.resize(_bins); 44 4 : const auto step = (_max - _min) / _bins; 45 84 : for (const auto i : make_range(_bins)) 46 80 : _bin_vec[i] = _min + step / 2.0 + step * i; 47 4 : } 48 : 49 : void 50 4 : TensorHistogram::execute() 51 : { 52 : // Reshape the data to fit the expected input format for histogramdd 53 4 : const auto data = _u.reshape({-1, 1}); 54 : 55 : // Use the histogramdd function 56 : torch::Tensor hist; 57 : try 58 : { 59 : // histogramdd does not have a cuda implementation in torch 2.1 60 : // we try anyways to run on the current compute device, in case 61 : // the implemntation exists for different devices or future torch versions. 62 12 : const auto pair = at::native::histogramdd(data, {_bin_edges}); 63 4 : hist = std::get<0>(pair).cpu(); 64 : } 65 2 : catch (const std::exception &) 66 : { 67 6 : const auto pair = at::native::histogramdd(data.cpu(), {_bin_edges.cpu()}); 68 : hist = std::get<0>(pair); 69 2 : } 70 : 71 : // put into VPP vector 72 4 : if (hist.dtype() == torch::kFloat32) 73 0 : for (const auto i : make_range(int64_t(_bins))) 74 0 : _count_vec[i] = hist.index({i}).item<float>(); 75 4 : else if (hist.dtype() == torch::kFloat64) 76 84 : for (const auto i : make_range(int64_t(_bins))) 77 240 : _count_vec[i] = hist.index({i}).item<double>(); 78 : else 79 0 : mooseError("Unsupported tensor dtype() in TensorHistogram."); 80 86 : }