Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef LIBTORCH_ENABLED

#pragma once

#include <torch/torch.h>
#include "MooseError.h"

namespace Moose
{

/**
 * Abstract base class that gathers the functions and members common to
 * every neural network based on Libtorch.
 *
 * The class only declares an interface: it holds no data, and the single
 * evaluation routine is pure virtual, so it cannot be instantiated directly.
 */
class LibtorchNeuralNetBase
{
public:
  /// Virtual destructor so that derived networks can be destroyed through a
  /// pointer to this base class. Defaulted because the base owns no resources.
  virtual ~LibtorchNeuralNetBase() = default;

  /**
   * Evaluation function of the libtorch modules. Since there are considerable
   * differences between self-built modules and modules read using a
   * torch-script format, this serves as a common denominator between the two.
   *
   * @param x The input tensor handed to the network
   * @return The tensor produced by the derived class' evaluation
   */
  virtual torch::Tensor forward(const torch::Tensor & x) = 0;
};

} // namespace Moose

#endif