Train the neural network using the given (serialized) data and the options for the training process.
101 const auto t_begin = MPI_Wtime();
113 int used_rank = real_rank < num_ranks ? real_rank : 0;
115 const auto num_samples = dataset.size().value();
117 if (num_ranks * options.num_batches > num_samples)
118 mooseError(
"The number of used processors* number of requestedf batches " +
119 std::to_string(num_ranks * options.num_batches) +
120 " is greater than the number of samples used for the training!");
123 const unsigned int sample_per_batch =
computeBatchSize(num_samples, options.num_batches);
129 auto transformed_data_set = dataset.map(torch::data::transforms::Stack<>());
132 SamplerType sampler(num_samples, num_ranks, used_rank, options.allow_duplicates);
136 torch::data::make_data_loader(std::move(transformed_data_set), sampler, sample_per_proc);
142 Real initial_loss = 1.0;
143 Real epoch_loss = 0.0;
146 unsigned int epoch = 1;
147 while (epoch <= options.num_epochs && rel_loss > options.rel_loss_tol)
151 for (
auto & batch : *data_loader)
154 optimizer->zero_grad();
157 torch::Tensor prediction =
_nn.
forward(batch.data);
160 torch::Tensor loss = torch::mse_loss(prediction, batch.target);
166 if (real_rank == used_rank)
167 epoch_loss += loss.item<
double>();
174 for (
auto & param :
_nn.named_parameters())
176 if (real_rank != used_rank)
177 param.value().grad().data() = param.value().grad().data() * 0.0;
179 MPI_Allreduce(MPI_IN_PLACE,
180 param.value().grad().data_ptr(),
181 param.value().grad().numel(),
186 param.value().grad().data() = param.value().grad().data() / num_ranks;
197 epoch_loss = epoch_loss / options.num_batches / num_ranks;
200 initial_loss = epoch_loss;
202 rel_loss = epoch_loss / initial_loss;
205 if (options.print_loss)
206 if (epoch % options.print_epoch_loss == 0 || epoch == 1)
207 Moose::out <<
"Epoch: " << epoch <<
" | Loss: " << COLOR_GREEN << epoch_loss
208 << COLOR_DEFAULT <<
" | Rel. loss: " << COLOR_GREEN << rel_loss << COLOR_DEFAULT
215 auto t_end = MPI_Wtime();
217 if (options.print_loss && used_rank == 0)
218 Moose::out <<
"Neural net training time: " << COLOR_GREEN << (t_end - t_begin) << COLOR_DEFAULT
219 <<
" s" << std::endl;
void mooseError(Args &&... args)
Emit an error message with the given stringified, concatenated args and terminate the application...
static unsigned int computeLocalBatchSize(const unsigned int batch_size, const unsigned int num_ranks)
Computes the number of samples in the local (per-rank) portion of a batch.
const Parallel::Communicator & _communicator
processor_id_type n_processors() const
LibtorchArtificialNeuralNet & _nn
Reference to the neural network which is trained.
DIE A HORRIBLE DEATH HERE typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
static unsigned int computeBatchSize(const unsigned int num_samples, const unsigned int num_batches)
Computes the number of samples used for each batch.
processor_id_type processor_id() const
auto min(const L &left, const R &right)
virtual torch::Tensor forward(const torch::Tensor &x) override
Overriding the forward substitution function for the neural network, unfortunately this cannot be con...
static std::unique_ptr< torch::optim::Optimizer > createOptimizer(const LibtorchArtificialNeuralNet &nn, const LibtorchTrainingOptions &options)
Set up the optimizer based on the provided options.