Line data Source code
1 : // The libMesh Finite Element Library.
2 : // Copyright (C) 2002-2025 Benjamin S. Kirk, John W. Peterson, Roy H. Stogner
3 :
4 : // This library is free software; you can redistribute it and/or
5 : // modify it under the terms of the GNU Lesser General Public
6 : // License as published by the Free Software Foundation; either
7 : // version 2.1 of the License, or (at your option) any later version.
8 :
9 : // This library is distributed in the hope that it will be useful,
10 : // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 : // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 : // Lesser General Public License for more details.
13 :
14 : // You should have received a copy of the GNU Lesser General Public
15 : // License along with this library; if not, write to the Free Software
16 : // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 :
18 :
19 : // libMesh includes
20 : #include "libmesh/diff_solver.h"
21 : #include "libmesh/diff_system.h"
22 : #include "libmesh/time_solver.h"
23 : #include "libmesh/unsteady_solver.h"
24 : #include "libmesh/dirichlet_boundaries.h"
25 : #include "libmesh/dof_map.h"
26 : #include "libmesh/zero_function.h"
27 :
28 : // C++ includes
29 : #include <utility> // std::swap
30 :
31 : namespace libMesh
32 : {
33 :
34 :
35 :
36 5044 : DifferentiableSystem::DifferentiableSystem(EquationSystems & es,
37 : const std::string & name_in,
38 0 : const unsigned int number_in) :
39 : Parent (es, name_in, number_in),
40 : time_solver (),
41 4744 : deltat(1.),
42 4744 : postprocess_sides(false),
43 4744 : print_solution_norms(false),
44 4744 : print_solutions(false),
45 4744 : print_residual_norms(false),
46 4744 : print_residuals(false),
47 4744 : print_jacobian_norms(false),
48 4744 : print_jacobians(false),
49 4744 : print_element_solutions(false),
50 4744 : print_element_residuals(false),
51 4744 : print_element_jacobians(false),
52 4744 : _constrain_in_solver(true),
53 : _diff_physics(),
54 5044 : _diff_qoi()
55 : {
56 5044 : }
57 :
58 :
59 :
60 9488 : DifferentiableSystem::~DifferentiableSystem () = default;
61 :
62 :
63 :
64 0 : void DifferentiableSystem::clear ()
65 : {
66 : // If we had no attached Physics object, clear our own Physics data
67 0 : if (this->_diff_physics.empty())
68 0 : this->clear_physics();
69 :
70 0 : this->_diff_physics = {}; // No stack::clear
71 0 : this->_diff_qoi = {};
72 :
73 : // If we had no attached QoI object, clear our own QoI data
74 0 : if (this->_diff_qoi.empty())
75 0 : this->clear_qoi();
76 :
77 0 : use_fixed_solution = false;
78 0 : }
79 :
80 :
81 :
82 5236 : void DifferentiableSystem::reinit ()
83 : {
84 5236 : Parent::reinit();
85 :
86 176 : libmesh_assert(time_solver.get());
87 176 : libmesh_assert_equal_to (&(time_solver->system()), this);
88 :
89 5236 : time_solver->reinit();
90 5236 : }
91 :
92 :
93 :
94 5044 : void DifferentiableSystem::init_data ()
95 : {
96 : // If it isn't a separate initialized-upon-attachment object, do any
97 : // initialization our physics needs.
98 5044 : if (this->_diff_physics.empty())
99 5044 : this->init_physics(*this);
100 :
101 : // Do any initialization our solvers need
102 150 : libmesh_assert(time_solver.get());
103 150 : libmesh_assert_equal_to (&(time_solver->system()), this);
104 :
105 : // Now check for second order variables and add their velocities to the System.
106 5044 : if (!time_solver->is_steady())
107 : {
108 : const UnsteadySolver & unsteady_solver =
109 64 : cast_ref<const UnsteadySolver &>(*(time_solver.get()));
110 :
111 2259 : if (unsteady_solver.time_order() == 1)
112 1833 : this->add_second_order_dot_vars();
113 : }
114 :
115 5044 : time_solver->init();
116 :
117 : // Next initialize ImplicitSystem data
118 5044 : Parent::init_data();
119 :
120 5044 : time_solver->init_data();
121 5044 : }
122 :
123 0 : std::unique_ptr<DiffContext> DifferentiableSystem::build_context ()
124 : {
125 0 : auto context = std::make_unique<DiffContext>(*this);
126 0 : context->set_deltat_pointer( &this->deltat );
127 0 : return context;
128 0 : }
129 :
130 :
131 0 : void DifferentiableSystem::assemble ()
132 : {
133 0 : this->assembly(true, true);
134 0 : }
135 :
136 :
137 :
138 25037 : void DifferentiableSystem::solve ()
139 : {
140 : // Get the time solver object associated with the system, and tell it that
141 : // we are not solving the adjoint problem
142 746 : this->get_time_solver().set_is_adjoint(false);
143 :
144 746 : libmesh_assert_equal_to (&(time_solver->system()), this);
145 25037 : time_solver->solve();
146 25037 : }
147 :
148 :
149 :
150 11157 : std::pair<unsigned int, Real> DifferentiableSystem::adjoint_solve (const QoISet & qoi_indices)
151 : {
152 : // Get the time solver object associated with the system, and tell it that
153 : // we are solving the adjoint problem
154 354 : this->get_time_solver().set_is_adjoint(true);
155 :
156 10803 : return time_solver->adjoint_solve(qoi_indices);
157 :
158 : //return this->ImplicitSystem::adjoint_solve(qoi_indices);
159 : }
160 :
161 :
162 :
163 156509 : LinearSolver<Number> * DifferentiableSystem::get_linear_solver() const
164 : {
165 4526 : libmesh_assert(time_solver.get());
166 4526 : libmesh_assert_equal_to (&(time_solver->system()), this);
167 156509 : return this->time_solver->linear_solver().get();
168 : }
169 :
170 :
171 :
172 14654 : std::pair<unsigned int, Real> DifferentiableSystem::get_linear_solve_parameters() const
173 : {
174 464 : libmesh_assert(time_solver.get());
175 464 : libmesh_assert_equal_to (&(time_solver->system()), this);
176 14654 : return std::make_pair(this->time_solver->diff_solver()->max_linear_iterations,
177 15118 : this->time_solver->diff_solver()->relative_residual_tolerance);
178 : }
179 :
180 :
181 :
182 1833 : void DifferentiableSystem::add_second_order_dot_vars()
183 : {
184 1833 : const std::set<unsigned int> & second_order_vars = this->get_second_order_vars();
185 1833 : if (!second_order_vars.empty())
186 : {
187 1278 : for (const auto & var_id : second_order_vars)
188 : {
189 852 : const Variable & var = this->variable(var_id);
190 1728 : std::string new_var_name = std::string("dot_")+var.name();
191 :
192 : unsigned int v_var_idx;
193 :
194 852 : if (var.active_subdomains().empty())
195 876 : v_var_idx = this->add_variable( new_var_name, var.type() );
196 : else
197 0 : v_var_idx = this->add_variable( new_var_name, var.type(), &var.active_subdomains() );
198 :
199 852 : _second_order_dot_vars.insert(std::pair<unsigned int, unsigned int>(var_id, v_var_idx));
200 :
201 : // The new velocities are time evolving variables of first order
202 852 : this->time_evolving( v_var_idx, 1 );
203 :
204 : #ifdef LIBMESH_ENABLE_DIRICHLET
205 : // And if there are any boundary conditions set on the second order
206 : // variable, we also need to set it on its velocity variable.
207 852 : this->add_dot_var_dirichlet_bcs(var_id, v_var_idx);
208 : #endif
209 : }
210 : }
211 1833 : }
212 :
#ifdef LIBMESH_ENABLE_DIRICHLET
/**
 * Mirror every Dirichlet boundary condition that constrains variable
 * \p var_idx onto its velocity variable \p dot_var_idx, using a zero
 * boundary value.
 *
 * Time derivatives of boundary functions cannot be computed
 * automatically, so the following are errors:
 * - a time-dependent FunctionBase boundary function;
 * - any FEMFunctionBase boundary function (treated as time dependent,
 *   since it is assumed to depend on the solution);
 * - a boundary condition with no boundary function at all.
 */
void DifferentiableSystem::add_dot_var_dirichlet_bcs( unsigned int var_idx,
                                                      unsigned int dot_var_idx )
{
  // Scan each boundary condition for this variable, on the assumption
  // that there are far fewer boundary conditions than variables.
  const DirichletBoundaries * const existing_dbcs =
    this->get_dof_map().get_dirichlet_boundaries();

  if (!existing_dbcs)
    return;

  // Every mirrored BC constrains just the velocity variable.
  const std::vector<unsigned int> dot_var_list {dot_var_idx};

  // Collect first and add afterwards: adding inside the loop would
  // modify the very container being traversed.
  std::vector<DirichletBoundary> dot_dbcs;

  for (const auto & bc : *existing_dbcs)
    {
      libmesh_assert(bc);

      const auto & bc_vars = bc->variables;
      const bool constrains_var =
        std::find(bc_vars.begin(), bc_vars.end(), var_idx) != bc_vars.end();

      if (!constrains_var)
        continue;

      // Only time-independent BCs are supported; their time derivative
      // (the velocity boundary value) is zero.
      bool time_dependent = false;
      if (bc->f)
        time_dependent = bc->f->is_time_dependent();
      else if (bc->f_fem)
        // An FEMFunctionBase is implicitly time-dependent, since it is
        // assumed to depend on the solution.
        time_dependent = true;
      else
        libmesh_error_msg("Could not find valid boundary function!");

      libmesh_error_msg_if(time_dependent, "Cannot currently support time-dependent Dirichlet BC for dot variables!");
      libmesh_error_msg_if(!bc->f, "Expected valid DirichletBoundary function");

      dot_dbcs.emplace_back(bc->b, dot_var_list, ZeroFunction<Number>());
    }

  // The DofMap makes its own deep copy of each DirichletBoundary.
  for (const auto & dot_dbc : dot_dbcs)
    this->get_dof_map().add_dirichlet_boundary(dot_dbc);
}
#endif // LIBMESH_ENABLE_DIRICHLET
278 :
279 141 : void DifferentiableSystem::attach_qoi( DifferentiableQoI * qoi_in )
280 : {
281 141 : this->_diff_qoi = {};
282 278 : this->_diff_qoi.push(qoi_in->clone());
283 :
284 4 : auto & dq = this->_diff_qoi.top();
285 : // User needs to resize qoi system qoi accordingly
286 : #ifdef LIBMESH_ENABLE_DEPRECATED
287 : // Call the old API for backwards compatibility
288 141 : dq->init_qoi( this->qoi );
289 :
290 : // Then the new API for forwards compatibility
291 141 : dq->init_qoi_count( *this );
292 : #else
293 : #ifndef NDEBUG
294 : // Make sure the user has updated their QoI subclass - call the old
295 : // API and make sure it does nothing
296 : std::vector<Number> deprecated_vector;
297 : dq->init_qoi( deprecated_vector );
298 : libmesh_assert(deprecated_vector.empty());
299 : #endif
300 :
301 : // Then the new API
302 : dq->init_qoi_count( *this );
303 : #endif
304 141 : }
305 :
306 14624736 : unsigned int DifferentiableSystem::get_second_order_dot_var( unsigned int var ) const
307 : {
308 : // For SteadySolver or SecondOrderUnsteadySolvers, we just give back var
309 13334640 : unsigned int dot_var = var;
310 :
311 14624736 : if (!time_solver->is_steady())
312 : {
313 : const UnsteadySolver & unsteady_solver =
314 1280184 : cast_ref<const UnsteadySolver &>(*(time_solver.get()));
315 :
316 14505792 : if (unsteady_solver.time_order() == 1)
317 9354480 : dot_var = this->_second_order_dot_vars.find(var)->second;
318 : }
319 :
320 14624736 : return dot_var;
321 : }
322 :
323 0 : bool DifferentiableSystem::have_first_order_scalar_vars() const
324 : {
325 0 : bool have_first_order_scalar_vars = false;
326 :
327 0 : if (this->have_first_order_vars())
328 0 : for (const auto & var : this->get_first_order_vars())
329 0 : if (this->variable(var).type().family == SCALAR)
330 0 : have_first_order_scalar_vars = true;
331 :
332 0 : return have_first_order_scalar_vars;
333 : }
334 :
335 0 : bool DifferentiableSystem::have_second_order_scalar_vars() const
336 : {
337 0 : bool have_second_order_scalar_vars = false;
338 :
339 0 : if (this->have_second_order_vars())
340 0 : for (const auto & var : this->get_second_order_vars())
341 0 : if (this->variable(var).type().family == SCALAR)
342 0 : have_second_order_scalar_vars = true;
343 :
344 0 : return have_second_order_scalar_vars;
345 : }
346 :
347 :
348 :
#ifdef LIBMESH_ENABLE_DEPRECATED
/**
 * Deprecated: exchange the active physics with \p swap_physics.
 *
 * Three cases are handled:
 * - No physics attached yet: this system takes ownership of
 *   \p swap_physics and pushes it. On return, \p swap_physics is set
 *   to \p this.
 * - \p swap_physics == this: this undoes a previous call of the first
 *   kind. The attached physics is released back to the caller through
 *   \p swap_physics, and the physics stack ends up empty.
 * - Otherwise: the attached physics and \p swap_physics trade places,
 *   and ownership moves with them.
 *
 * In every case the cache is disabled, so the matrix is reassembled
 * before an adjoint solve instead of being transposed. New code should
 * use push_physics()/pop_physics().
 */
void DifferentiableSystem::swap_physics ( DifferentiablePhysics * & swap_physics )
{
  // This isn't safe if users aren't very careful about memory
  // management and they don't (or aren't able to due to an exception)
  // swap back.
  libmesh_deprecated();

  if (this->_diff_physics.empty())
    {
      // Swap-something-else-for-self: adopt the caller's object.
      std::unique_ptr<DifferentiablePhysics> adopted(swap_physics);
      this->_diff_physics.push(std::move(adopted));
      swap_physics = this;
    }
  else
    {
      // In both remaining cases we give up ownership of the current top.
      DifferentiablePhysics * released = this->_diff_physics.top().release();

      if (swap_physics == this)
        {
          // Cleaning up after an earlier swap-something-else-for-self:
          // hand the object back without deleting it.
          this->_diff_physics.pop();
          swap_physics = released;

          // Anything more sophisticated than that should use push/pop.
          libmesh_assert(this->_diff_physics.empty());
        }
      else
        {
          // Swapping one external physics for another.
          std::swap(released, swap_physics);
          this->_diff_physics.top().reset(released);
        }
    }

  // The physics changed: reassemble from scratch before an adjoint
  // solve rather than just transposing.
  this->disable_cache();
}
#endif // LIBMESH_ENABLE_DEPRECATED
396 :
397 :
398 :
399 70560 : void DifferentiableSystem::push_physics ( DifferentiablePhysics & new_physics )
400 : {
401 137088 : this->_diff_physics.push(new_physics.clone_physics());
402 :
403 : // If the physics has been changed, we will reassemble
404 : // the matrix from scratch before doing an adjoint solve
405 : // rather than just transposing
406 70560 : this->disable_cache();
407 70560 : }
408 :
409 :
410 :
411 70560 : void DifferentiableSystem::pop_physics ()
412 : {
413 2016 : libmesh_assert(!this->_diff_physics.empty());
414 :
415 2016 : this->_diff_physics.pop();
416 :
417 : // If the physics has been changed, we will reassemble
418 : // the matrix from scratch before doing an adjoint solve
419 : // rather than just transposing
420 70560 : this->disable_cache();
421 70560 : }
422 :
423 :
424 1573 : void DifferentiableSystem::set_constrain_in_solver(bool enable)
425 : {
426 1573 : _constrain_in_solver = enable;
427 1573 : this->time_solver->diff_solver()->set_exact_constraint_enforcement(enable);
428 1573 : }
429 :
430 :
431 : } // namespace libMesh
|