LCOV - code coverage report
Current view: top level - src/systems - diff_system.C (source / functions) Hit Total Coverage
Test: libMesh/libmesh: #4229 (6a9aeb) with base 727f46 Lines: 112 165 67.9 %
Date: 2025-08-19 19:27:09 Functions: 15 24 62.5 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : // The libMesh Finite Element Library.
       2             : // Copyright (C) 2002-2025 Benjamin S. Kirk, John W. Peterson, Roy H. Stogner
       3             : 
       4             : // This library is free software; you can redistribute it and/or
       5             : // modify it under the terms of the GNU Lesser General Public
       6             : // License as published by the Free Software Foundation; either
       7             : // version 2.1 of the License, or (at your option) any later version.
       8             : 
       9             : // This library is distributed in the hope that it will be useful,
      10             : // but WITHOUT ANY WARRANTY; without even the implied warranty of
      11             : // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      12             : // Lesser General Public License for more details.
      13             : 
      14             : // You should have received a copy of the GNU Lesser General Public
      15             : // License along with this library; if not, write to the Free Software
      16             : // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
      17             : 
      18             : 
// libMesh includes
#include "libmesh/diff_solver.h"
#include "libmesh/diff_system.h"
#include "libmesh/time_solver.h"
#include "libmesh/unsteady_solver.h"
#include "libmesh/dirichlet_boundaries.h"
#include "libmesh/dof_map.h"
#include "libmesh/zero_function.h"

// C++ includes
#include <algorithm> // std::find, std::any_of
#include <memory>    // std::unique_ptr, std::make_unique
#include <set>
#include <string>
#include <utility>   // std::pair, std::make_pair, std::move, std::swap
#include <vector>
      30             : 
      31             : namespace libMesh
      32             : {
      33             : 
      34             : 
      35             : 
      36        5044 : DifferentiableSystem::DifferentiableSystem(EquationSystems & es,
      37             :                                            const std::string & name_in,
      38           0 :                                            const unsigned int number_in) :
      39             :   Parent      (es, name_in, number_in),
      40             :   time_solver (),
      41        4744 :   deltat(1.),
      42        4744 :   postprocess_sides(false),
      43        4744 :   print_solution_norms(false),
      44        4744 :   print_solutions(false),
      45        4744 :   print_residual_norms(false),
      46        4744 :   print_residuals(false),
      47        4744 :   print_jacobian_norms(false),
      48        4744 :   print_jacobians(false),
      49        4744 :   print_element_solutions(false),
      50        4744 :   print_element_residuals(false),
      51        4744 :   print_element_jacobians(false),
      52        4744 :   _constrain_in_solver(true),
      53             :   _diff_physics(),
      54        5044 :   _diff_qoi()
      55             : {
      56        5044 : }
      57             : 
      58             : 
      59             : 
      60        9488 : DifferentiableSystem::~DifferentiableSystem () = default;
      61             : 
      62             : 
      63             : 
      64           0 : void DifferentiableSystem::clear ()
      65             : {
      66             :   // If we had no attached Physics object, clear our own Physics data
      67           0 :   if (this->_diff_physics.empty())
      68           0 :     this->clear_physics();
      69             : 
      70           0 :   this->_diff_physics = {}; // No stack::clear
      71           0 :   this->_diff_qoi = {};
      72             : 
      73             :   // If we had no attached QoI object, clear our own QoI data
      74           0 :   if (this->_diff_qoi.empty())
      75           0 :     this->clear_qoi();
      76             : 
      77           0 :   use_fixed_solution = false;
      78           0 : }
      79             : 
      80             : 
      81             : 
      82        5236 : void DifferentiableSystem::reinit ()
      83             : {
      84        5236 :   Parent::reinit();
      85             : 
      86         176 :   libmesh_assert(time_solver.get());
      87         176 :   libmesh_assert_equal_to (&(time_solver->system()), this);
      88             : 
      89        5236 :   time_solver->reinit();
      90        5236 : }
      91             : 
      92             : 
      93             : 
      94        5044 : void DifferentiableSystem::init_data ()
      95             : {
      96             :   // If it isn't a separate initialized-upon-attachment object, do any
      97             :   // initialization our physics needs.
      98        5044 :   if (this->_diff_physics.empty())
      99        5044 :     this->init_physics(*this);
     100             : 
     101             :   // Do any initialization our solvers need
     102         150 :   libmesh_assert(time_solver.get());
     103         150 :   libmesh_assert_equal_to (&(time_solver->system()), this);
     104             : 
     105             :   // Now check for second order variables and add their velocities to the System.
     106        5044 :   if (!time_solver->is_steady())
     107             :     {
     108             :       const UnsteadySolver & unsteady_solver =
     109          64 :         cast_ref<const UnsteadySolver &>(*(time_solver.get()));
     110             : 
     111        2259 :       if (unsteady_solver.time_order() == 1)
     112        1833 :         this->add_second_order_dot_vars();
     113             :     }
     114             : 
     115        5044 :   time_solver->init();
     116             : 
     117             :   // Next initialize ImplicitSystem data
     118        5044 :   Parent::init_data();
     119             : 
     120        5044 :   time_solver->init_data();
     121        5044 : }
     122             : 
     123           0 : std::unique_ptr<DiffContext> DifferentiableSystem::build_context ()
     124             : {
     125           0 :   auto context = std::make_unique<DiffContext>(*this);
     126           0 :   context->set_deltat_pointer( &this->deltat );
     127           0 :   return context;
     128           0 : }
     129             : 
     130             : 
     131           0 : void DifferentiableSystem::assemble ()
     132             : {
     133           0 :   this->assembly(true, true);
     134           0 : }
     135             : 
     136             : 
     137             : 
     138       25037 : void DifferentiableSystem::solve ()
     139             : {
     140             :   // Get the time solver object associated with the system, and tell it that
     141             :   // we are not solving the adjoint problem
     142         746 :   this->get_time_solver().set_is_adjoint(false);
     143             : 
     144         746 :   libmesh_assert_equal_to (&(time_solver->system()), this);
     145       25037 :   time_solver->solve();
     146       25037 : }
     147             : 
     148             : 
     149             : 
     150       11157 : std::pair<unsigned int, Real> DifferentiableSystem::adjoint_solve (const QoISet & qoi_indices)
     151             : {
     152             :   // Get the time solver object associated with the system, and tell it that
     153             :   // we are solving the adjoint problem
     154         354 :   this->get_time_solver().set_is_adjoint(true);
     155             : 
     156       10803 :   return time_solver->adjoint_solve(qoi_indices);
     157             : 
     158             :   //return this->ImplicitSystem::adjoint_solve(qoi_indices);
     159             : }
     160             : 
     161             : 
     162             : 
     163      156509 : LinearSolver<Number> * DifferentiableSystem::get_linear_solver() const
     164             : {
     165        4526 :   libmesh_assert(time_solver.get());
     166        4526 :   libmesh_assert_equal_to (&(time_solver->system()), this);
     167      156509 :   return this->time_solver->linear_solver().get();
     168             : }
     169             : 
     170             : 
     171             : 
     172       14654 : std::pair<unsigned int, Real> DifferentiableSystem::get_linear_solve_parameters() const
     173             : {
     174         464 :   libmesh_assert(time_solver.get());
     175         464 :   libmesh_assert_equal_to (&(time_solver->system()), this);
     176       14654 :   return std::make_pair(this->time_solver->diff_solver()->max_linear_iterations,
     177       15118 :                         this->time_solver->diff_solver()->relative_residual_tolerance);
     178             : }
     179             : 
     180             : 
     181             : 
     182        1833 : void DifferentiableSystem::add_second_order_dot_vars()
     183             : {
     184        1833 :   const std::set<unsigned int> & second_order_vars = this->get_second_order_vars();
     185        1833 :   if (!second_order_vars.empty())
     186             :     {
     187        1278 :       for (const auto & var_id : second_order_vars)
     188             :         {
     189         852 :           const Variable & var = this->variable(var_id);
     190        1728 :           std::string new_var_name = std::string("dot_")+var.name();
     191             : 
     192             :           unsigned int v_var_idx;
     193             : 
     194         852 :           if (var.active_subdomains().empty())
     195         876 :             v_var_idx = this->add_variable( new_var_name, var.type() );
     196             :           else
     197           0 :             v_var_idx = this->add_variable( new_var_name, var.type(), &var.active_subdomains() );
     198             : 
     199         852 :           _second_order_dot_vars.insert(std::pair<unsigned int, unsigned int>(var_id, v_var_idx));
     200             : 
     201             :           // The new velocities are time evolving variables of first order
     202         852 :           this->time_evolving( v_var_idx, 1 );
     203             : 
     204             : #ifdef LIBMESH_ENABLE_DIRICHLET
     205             :           // And if there are any boundary conditions set on the second order
     206             :           // variable, we also need to set it on its velocity variable.
     207         852 :           this->add_dot_var_dirichlet_bcs(var_id, v_var_idx);
     208             : #endif
     209             :         }
     210             :     }
     211        1833 : }
     212             : 
     213             : #ifdef LIBMESH_ENABLE_DIRICHLET
     214         852 : void DifferentiableSystem::add_dot_var_dirichlet_bcs( unsigned int var_idx,
     215             :                                                       unsigned int dot_var_idx )
     216             : {
     217             :   // We're assuming that there could be a lot more variables than
     218             :   // boundary conditions, so we search each of the boundary conditions
     219             :   // for this variable rather than looping over boundary conditions
     220             :   // in a separate loop and searching through all the variables.
     221             :   const DirichletBoundaries * all_dbcs =
     222          24 :     this->get_dof_map().get_dirichlet_boundaries();
     223             : 
     224         852 :   if (all_dbcs)
     225             :     {
     226             :       // We need to cache the DBCs to be added so that we add them
     227             :       // after looping over the existing DBCs. Otherwise, we're polluting
     228             :       // the thing we're looping over.
     229          72 :       std::vector<DirichletBoundary> new_dbcs;
     230             : 
     231        2840 :       for (const auto & dbc : *all_dbcs)
     232             :         {
     233          56 :           libmesh_assert(dbc);
     234             : 
     235             :           // Look for second order variable in the current
     236             :           // DirichletBoundary object
     237             :           std::vector<unsigned int>::const_iterator dbc_var_it =
     238        1988 :             std::find( dbc->variables.begin(), dbc->variables.end(), var_idx );
     239             : 
     240             :           // If we found it, then we also need to add it's corresponding
     241             :           // "dot" variable to a DirichletBoundary
     242         112 :           std::vector<unsigned int> vars_to_add;
     243        1988 :           if (dbc_var_it != dbc->variables.end())
     244         994 :             vars_to_add.push_back(dot_var_idx);
     245             : 
     246        1988 :           if (!vars_to_add.empty())
     247             :             {
     248             :               // We need to check if the boundary condition is time-dependent.
     249             :               // Currently, we cannot automatically differentiate w.r.t. time
     250             :               // so if the user supplies a time-dependent Dirichlet BC, then
     251             :               // we can't automatically support the Dirichlet BC for the
     252             :               // "velocity" boundary condition, so we error. Otherwise,
     253             :               // the "velocity boundary condition will just be zero.
     254          28 :               bool is_time_evolving_bc = false;
     255         994 :               if (dbc->f)
     256          56 :                 is_time_evolving_bc = dbc->f->is_time_dependent();
     257           0 :               else if (dbc->f_fem)
     258             :                 // We it's a FEMFunctionBase object, it will be implicitly
     259             :                 // time-dependent since it is assumed to depend on the solution.
     260           0 :                 is_time_evolving_bc = true;
     261             :               else
     262           0 :                 libmesh_error_msg("Could not find valid boundary function!");
     263             : 
     264         994 :               libmesh_error_msg_if(is_time_evolving_bc, "Cannot currently support time-dependent Dirichlet BC for dot variables!");
     265          28 :               libmesh_error_msg_if(!dbc->f, "Expected valid DirichletBoundary function");
     266             : 
     267         994 :               new_dbcs.emplace_back(dbc->b, vars_to_add, ZeroFunction<Number>());
     268             :             }
     269             :         }
     270             : 
     271             :       // Let the DofMap make its own deep copy of the DirichletBC objects
     272        1846 :       for (const auto & dbc : new_dbcs)
     273         994 :         this->get_dof_map().add_dirichlet_boundary(dbc);
     274             : 
     275         804 :     } // if (all_dbcs)
     276         852 : }
     277             : #endif // LIBMESH_ENABLE_DIRICHLET
     278             : 
     279         141 : void DifferentiableSystem::attach_qoi( DifferentiableQoI * qoi_in )
     280             : {
     281         141 :   this->_diff_qoi = {};
     282         278 :   this->_diff_qoi.push(qoi_in->clone());
     283             : 
     284           4 :   auto & dq = this->_diff_qoi.top();
     285             :   // User needs to resize qoi system qoi accordingly
     286             : #ifdef LIBMESH_ENABLE_DEPRECATED
     287             :   // Call the old API for backwards compatibility
     288         141 :   dq->init_qoi( this->qoi );
     289             : 
     290             :   // Then the new API for forwards compatibility
     291         141 :   dq->init_qoi_count( *this );
     292             : #else
     293             : #ifndef NDEBUG
     294             :   // Make sure the user has updated their QoI subclass - call the old
     295             :   // API and make sure it does nothing
     296             :   std::vector<Number> deprecated_vector;
     297             :   dq->init_qoi( deprecated_vector );
     298             :   libmesh_assert(deprecated_vector.empty());
     299             : #endif
     300             : 
     301             :   // Then the new API
     302             :   dq->init_qoi_count( *this );
     303             : #endif
     304         141 : }
     305             : 
     306    14624736 : unsigned int DifferentiableSystem::get_second_order_dot_var( unsigned int var ) const
     307             : {
     308             :   // For SteadySolver or SecondOrderUnsteadySolvers, we just give back var
     309    13334640 :   unsigned int dot_var = var;
     310             : 
     311    14624736 :   if (!time_solver->is_steady())
     312             :     {
     313             :       const UnsteadySolver & unsteady_solver =
     314     1280184 :         cast_ref<const UnsteadySolver &>(*(time_solver.get()));
     315             : 
     316    14505792 :       if (unsteady_solver.time_order() == 1)
     317     9354480 :         dot_var = this->_second_order_dot_vars.find(var)->second;
     318             :     }
     319             : 
     320    14624736 :   return dot_var;
     321             : }
     322             : 
     323           0 : bool DifferentiableSystem::have_first_order_scalar_vars() const
     324             : {
     325           0 :   bool have_first_order_scalar_vars = false;
     326             : 
     327           0 :   if (this->have_first_order_vars())
     328           0 :     for (const auto & var : this->get_first_order_vars())
     329           0 :       if (this->variable(var).type().family == SCALAR)
     330           0 :         have_first_order_scalar_vars = true;
     331             : 
     332           0 :   return have_first_order_scalar_vars;
     333             : }
     334             : 
     335           0 : bool DifferentiableSystem::have_second_order_scalar_vars() const
     336             : {
     337           0 :   bool have_second_order_scalar_vars = false;
     338             : 
     339           0 :   if (this->have_second_order_vars())
     340           0 :     for (const auto & var : this->get_second_order_vars())
     341           0 :       if (this->variable(var).type().family == SCALAR)
     342           0 :         have_second_order_scalar_vars = true;
     343             : 
     344           0 :   return have_second_order_scalar_vars;
     345             : }
     346             : 
     347             : 
     348             : 
     349             : #ifdef LIBMESH_ENABLE_DEPRECATED
     350           0 : void DifferentiableSystem::swap_physics ( DifferentiablePhysics * & swap_physics )
     351             : {
     352             :   // This isn't safe if users aren't very careful about memory
     353             :   // management and they don't (or aren't able to due to an exception)
     354             :   // swap back.
     355             :   libmesh_deprecated();
     356             : 
     357             :   // A mess of code for backwards compatibility
     358           0 :   if (this->_diff_physics.empty())
     359             :     {
     360             :       // Swap-something-else-for-self
     361           0 :       std::unique_ptr<DifferentiablePhysics> scary_hack(swap_physics);
     362           0 :       this->_diff_physics.push(std::move(scary_hack));
     363           0 :       swap_physics = this;
     364           0 :     }
     365           0 :   else if (swap_physics == this)
     366             :     {
     367             :       // The user must be cleaning up after a previous
     368             :       // swap-something-else-for-self
     369           0 :       libmesh_assert(!this->_diff_physics.empty());
     370             : 
     371             :       // So we don't want to delete what got swapped in, but we do
     372             :       // want to put it back into their pointer
     373           0 :       DifferentiablePhysics * old_p = this->_diff_physics.top().release();
     374           0 :       this->_diff_physics.pop();
     375           0 :       swap_physics = old_p;
     376             : 
     377             :       // And if the user is doing anything more sophisticated than
     378             :       // that then the user is sophisticated enough to upgrade to
     379             :       // push/pop.
     380           0 :       libmesh_assert(this->_diff_physics.empty());
     381             :     }
     382             :   else
     383             :     {
     384             :       // Swapping one external physics for another
     385           0 :       DifferentiablePhysics * old_p = this->_diff_physics.top().release();
     386           0 :       std::swap(old_p, swap_physics);
     387           0 :       this->_diff_physics.top().reset(old_p);
     388             :     }
     389             : 
     390             :   // If the physics has been swapped, we will reassemble
     391             :   // the matrix from scratch before doing an adjoint solve
     392             :   // rather than just transposing
     393           0 :   this->disable_cache();
     394           0 : }
     395             : #endif // LIBMESH_ENABLE_DEPRECATED
     396             : 
     397             : 
     398             : 
     399       70560 : void DifferentiableSystem::push_physics ( DifferentiablePhysics & new_physics )
     400             : {
     401      137088 :   this->_diff_physics.push(new_physics.clone_physics());
     402             : 
     403             :   // If the physics has been changed, we will reassemble
     404             :   // the matrix from scratch before doing an adjoint solve
     405             :   // rather than just transposing
     406       70560 :   this->disable_cache();
     407       70560 : }
     408             : 
     409             : 
     410             : 
     411       70560 : void DifferentiableSystem::pop_physics ()
     412             : {
     413        2016 :   libmesh_assert(!this->_diff_physics.empty());
     414             : 
     415        2016 :   this->_diff_physics.pop();
     416             : 
     417             :   // If the physics has been changed, we will reassemble
     418             :   // the matrix from scratch before doing an adjoint solve
     419             :   // rather than just transposing
     420       70560 :   this->disable_cache();
     421       70560 : }
     422             : 
     423             : 
     424        1573 : void DifferentiableSystem::set_constrain_in_solver(bool enable)
     425             : {
     426        1573 :   _constrain_in_solver = enable;
     427        1573 :   this->time_solver->diff_solver()->set_exact_constraint_enforcement(enable);
     428        1573 : }
     429             : 
     430             : 
     431             : } // namespace libMesh

Generated by: LCOV version 1.14