diff --git a/src/dolfinx_adjoint/__init__.py b/src/dolfinx_adjoint/__init__.py index 00e96e7..7a0a688 100644 --- a/src/dolfinx_adjoint/__init__.py +++ b/src/dolfinx_adjoint/__init__.py @@ -7,7 +7,7 @@ from .assembly import assemble_scalar, error_norm from .solvers import LinearProblem, NonlinearProblem -from .types import Constant, Function +from .types import Constant, Function, dirichletbc from .types.function import assign meta = metadata("dolfinx_adjoint") @@ -24,6 +24,7 @@ __all__ = [ "Constant", "Function", + "dirichletbc", "LinearProblem", "NonlinearProblem", "assemble_scalar", diff --git a/src/dolfinx_adjoint/blocks/assembly.py b/src/dolfinx_adjoint/blocks/assembly.py index b288732..dc50824 100644 --- a/src/dolfinx_adjoint/blocks/assembly.py +++ b/src/dolfinx_adjoint/blocks/assembly.py @@ -129,6 +129,7 @@ def compute_action_adjoint( form_compiler_options=self._form_compiler_options, entity_maps=self._entity_maps, ) + if space is None: # If space is not supplied infer it from the form assert len(dform.arguments()) == 1 @@ -143,6 +144,9 @@ def compute_action_adjoint( # assemble_compiled_form(compiled_adjoint, self._cached_vectors[id(space)]) assemble_compiled_form(compiled_adjoint, vector) # return a vector scaled by the scalar `adj_input` + # Safeguard against None seeds from PyAdjoint + if adj_input is None: + adj_input = 1.0 vector.array[:] *= vector.x.array.dtype.type(adj_input) vector.scatter_forward() diff --git a/src/dolfinx_adjoint/blocks/dirichletbc.py b/src/dolfinx_adjoint/blocks/dirichletbc.py new file mode 100644 index 0000000..0fd8b48 --- /dev/null +++ b/src/dolfinx_adjoint/blocks/dirichletbc.py @@ -0,0 +1,24 @@ +import dolfinx +import numpy as np +import numpy.typing as npt +from pyadjoint.block import Block + + +class DirichletBCBlock(Block): + def __init__( + self, + value: dolfinx.fem.Function | dolfinx.fem.Constant, + dofs: npt.NDArray[np.int32], + V: dolfinx.fem.FunctionSpace | None = None, + ad_block_tag=None, + ): + 
super().__init__(ad_block_tag=ad_block_tag) + self.dofs = dofs + self.V = V + self.add_dependency(value) + + def prepare_recompute_component(self, inputs, relevant_outputs): + return inputs[0] if inputs else None + + def recompute_component(self, inputs, block_variable, idx, prepared): + return block_variable.saved_output diff --git a/src/dolfinx_adjoint/blocks/function_assigner.py b/src/dolfinx_adjoint/blocks/function_assigner.py index 7407de3..fca04b5 100644 --- a/src/dolfinx_adjoint/blocks/function_assigner.py +++ b/src/dolfinx_adjoint/blocks/function_assigner.py @@ -178,15 +178,18 @@ def prepare_recompute_component(self, inputs, relevant_outputs): def recompute_component(self, inputs, block_variable, idx, prepared): if self.expr is None: prepared = inputs[0] - output = dolfinx.fem.Function( - block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute" - ) + + # We should return the exact object instance to maintain C++ memory bindings + # (especially for DirichletBCs), updating it in-place. 
+ output = block_variable.saved_output + try: if output.function_space == prepared.function_space: output.x.array[:] = prepared.x.array[:] except AttributeError: # Handling float value output.x.array[:] = prepared + return output def __str__(self): diff --git a/src/dolfinx_adjoint/blocks/solvers.py b/src/dolfinx_adjoint/blocks/solvers.py index b5db2e3..0776ec9 100644 --- a/src/dolfinx_adjoint/blocks/solvers.py +++ b/src/dolfinx_adjoint/blocks/solvers.py @@ -64,6 +64,7 @@ def __init__( self.add_dependency(c, no_duplicates=True) for c in self._rhs.coefficients(): # type: ignore self.add_dependency(c, no_duplicates=True) + except AttributeError: raise NotImplementedError("Blocked systems not implemented yet.") self._compiled_lhs = dolfinx.fem.form( @@ -86,6 +87,13 @@ def __init__( self._petsc_options = petsc_options if petsc_options is not None else {} self._petsc_options_prefix = petsc_options_prefix self._bcs = bcs if bcs is not None else [] + + # Add dependencies from the boundary conditions + if self._bcs is not None: + for bc in self._bcs: + if hasattr(bc, "block_variable"): + self.add_dependency(bc, no_duplicates=True) + # Solver for recomputing the linear problem self._forward_solver = dolfinx.fem.petsc.LinearProblem( a=self._lhs, @@ -162,16 +170,6 @@ def prepare_recompute_component(self, inputs, relevant_outputs): else: initial_guess = [dolfinx.fem.Function(u.function_space, name=u.name + "_initial_guess") for u in self._u] - # Replace values in the DirichletBC if it is dependent on a control - # NOTE: Currently assume that BCS are control independent. - bcs = self._bcs - # for block_variable in self.get_dependencies(): - # c = block_variable.output - # c_rep = block_variable.saved_output - - # if isinstance(c, dolfinx.fem.DirichletBC): - # bcs.append(c_rep) - # Replace form coefficients with checkpointed values. 
# Loop through the dependencies of the lhs and rhs, check if they are in the respective form lhs = self._replace_coefficients_in_form(self._lhs) @@ -206,7 +204,7 @@ def prepare_recompute_component(self, inputs, relevant_outputs): self._forward_solver._a = compiled_lhs self._forward_solver._L = compiled_rhs self._forward_solver._P = compiled_preconditioner - self._forward_solver.bcs = bcs + self._forward_solver.bcs = self._bcs self._forward_solver._u = initial_guess def recompute_component( @@ -434,6 +432,7 @@ def evaluate_adj_component( entity_maps=self._entity_maps, ) vec = _create_vector(compiled_sensitivity, sensitivity.arguments()[0].ufl_function_space()) + vec.array[:] = 0.0 assemble_compiled_form(compiled_sensitivity, tensor=vec) return vec @@ -584,6 +583,7 @@ def evaluate_hessian_component( entity_maps=self._entity_maps, ) hessian_output = _create_vector(compiled_hessian, hessian_form.arguments()[0].ufl_function_space()) + hessian_output.array[:] = 0.0 assemble_compiled_form(compiled_hessian, hessian_output) hessian_output.array[:] *= -1.0 return hessian_output @@ -999,6 +999,7 @@ def evaluate_adj_component( entity_maps=self._entity_maps, ) vec = _create_vector(compiled_sensitivity, sensitivity.arguments()[0].ufl_function_space()) + vec.array[:] = 0.0 assemble_compiled_form(compiled_sensitivity, tensor=vec) return vec @@ -1055,7 +1056,6 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_ entity_maps=self._entity_maps, ) - # Solve adjoint problem self._adjoint_solver._a = dFdu_adj self._adjoint_solver._b = b.petsc_vec self._adjoint_solver._u = self._second_adjoint_solutions @@ -1147,6 +1147,7 @@ def evaluate_hessian_component( entity_maps=self._entity_maps, ) hessian_output = _create_vector(compiled_hessian, hessian_form.arguments()[0].ufl_function_space()) + hessian_output.array[:] = 0.0 assemble_compiled_form(compiled_hessian, hessian_output) hessian_output.array[:] *= -1.0 return hessian_output diff --git 
a/src/dolfinx_adjoint/types/__init__.py b/src/dolfinx_adjoint/types/__init__.py index c1a57a0..a781b83 100644 --- a/src/dolfinx_adjoint/types/__init__.py +++ b/src/dolfinx_adjoint/types/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["Function", "Constant"] +__all__ = ["Function", "Constant", "dirichletbc"] +from .dirichletbc import dirichletbc from .function import Constant, Function diff --git a/src/dolfinx_adjoint/types/dirichletbc.py b/src/dolfinx_adjoint/types/dirichletbc.py new file mode 100644 index 0000000..547bf15 --- /dev/null +++ b/src/dolfinx_adjoint/types/dirichletbc.py @@ -0,0 +1,63 @@ +import dolfinx +import numpy as np +import numpy.typing as npt +import pyadjoint +from pyadjoint.overloaded_type import FloatingType + +from ..blocks.dirichletbc import DirichletBCBlock +from .function import Function + + +class DirichletBC(dolfinx.fem.DirichletBC, FloatingType): + """A class overloading `dolfinx.fem.DirichletBC` to support it being used as a control variable + in the adjoint framework. + + Args: + g: The value of the Dirichlet BC. + dofs: An array of degree-of-freedom indices in `V` where the BC should be applied. + **kwargs: Additional keyword arguments to pass to the `pyadjoint.overloaded_type.FloatingType` constructor. 
+ + """ + + def __init__(self, g: Function, dofs: npt.NDArray[np.int32], **kwargs): + dtype = g.dtype + if np.issubdtype(dtype, np.float32): + bctype = dolfinx.cpp.fem.DirichletBC_float32 + elif np.issubdtype(dtype, np.float64): + bctype = dolfinx.cpp.fem.DirichletBC_float64 + elif np.issubdtype(dtype, np.complex64): + bctype = dolfinx.cpp.fem.DirichletBC_complex64 + elif np.issubdtype(dtype, np.complex128): + bctype = dolfinx.cpp.fem.DirichletBC_complex128 + else: + raise NotImplementedError(f"Type {dtype} not supported.") + + super().__init__(bctype(g._cpp_object, dofs)) + + annotate = kwargs.pop("annotate", True) + annotate = annotate and pyadjoint.annotate_tape() + + FloatingType.__init__( + self, + g, + dtype=dtype, + block_class=kwargs.pop("block_class", DirichletBCBlock), + _ad_floating_active=False, + _ad_args=kwargs.pop("_ad_args", (g, dofs)), + annotate=annotate, + **kwargs, + ) + + if annotate: + self._ad_annotate_block() + + def _ad_create_checkpoint(self): + return self + + def _ad_restore_at_checkpoint(self, checkpoint): + return self + + +def dirichletbc(value: Function, dofs: npt.NDArray[np.int32], **kwargs) -> DirichletBC: + """Overloaded DirichletBC constructor that creates an adjoint-aware DirichletBC""" + return DirichletBC(value, dofs, **kwargs) diff --git a/tests/test_dirichlet_bc.py b/tests/test_dirichlet_bc.py new file mode 100644 index 0000000..571553f --- /dev/null +++ b/tests/test_dirichlet_bc.py @@ -0,0 +1,137 @@ +from mpi4py import MPI + +import dolfinx +import numpy as np +import pyadjoint +import ufl +from pyadjoint.overloaded_type import Weakref + +from dolfinx_adjoint import Function, LinearProblem, assemble_scalar, assign, dirichletbc +from dolfinx_adjoint.blocks.dirichletbc import DirichletBCBlock + + +def test_dirichletbc_recording(): + """Test that creating an overloaded dirichletbc correctly registers a block and dependency on the tape.""" + pyadjoint.get_working_tape().clear_tape() + mesh = 
dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: x[0]) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + bc = dirichletbc(c, dofs) + + tape = pyadjoint.get_working_tape() + blocks = tape.get_blocks() + + # The tape should have 1 block: DirichletBCBlock + assert len(blocks) == 1 + assert isinstance(blocks[0], DirichletBCBlock) + + # The block should have exactly 1 dependency (the function 'c') + assert len(blocks[0].get_dependencies()) == 1 + assert blocks[0].get_dependencies()[0].output is c + + # The returned BC object should now possess the injected block_variable + assert hasattr(bc, "block_variable") + + +def test_dirichletbc_no_annotate(): + """Test that setting annotate=False bypasses tape recording entirely.""" + + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: x[0]) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + + # Run with annotation off + bc = dirichletbc(c, dofs, annotate=False) + + tape = pyadjoint.get_working_tape() + + assert len(tape.get_blocks()) == 0 + # FIX: Check the underlying weak reference rather than invoking the property + assert getattr(bc, "_block_variable", Weakref())() is None + + +def test_dirichletbc_recompute(): + """Test the PyAdjoint internal recompute logic specifically for the DirichletBCBlock.""" + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: np.full_like(x[0], 5.0)) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + bc = dirichletbc(c, 
dofs) + + tape = pyadjoint.get_working_tape() + block = tape.get_blocks()[0] + + # Simulate an optimizer changing the function value + c.interpolate(lambda x: np.full_like(x[0], 15.0)) + + # Replay the PyAdjoint mechanics manually + prepared = block.prepare_recompute_component([c], None) + new_bc = block.recompute_component([c], bc.block_variable, 0, prepared) + + # Assert that the BC returned by the block (its saved output) sees the updated control value + assert isinstance(new_bc, dolfinx.fem.bcs.DirichletBC) + assert np.isclose(new_bc.g.x.array[0], 15.0) + + +def test_time_dependent_bc_replay(): + pyadjoint.get_working_tape().clear_tape() + + mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 8, 8) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + dt = 0.1 + num_steps = 3 + + m = Function(V, name="control") + m.interpolate(lambda x: np.sin(x[0] * np.pi)) + + u = ufl.TrialFunction(V) + v = ufl.TestFunction(V) + + uh = Function(V, name="state") + assign(0.0, uh) + + u_prev = Function(V, name="state_prev") + assign(0.0, u_prev) + + F = (u - u_prev) / dt * v * ufl.dx + ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx - m * v * ufl.dx + a, L = ufl.system(F) + + bc_func = Function(V, name="bc_func") + mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim) + boundary_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology) + boundary_dofs = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets) + + # Use the overloaded dolfinx_adjoint `dirichletbc` here so the BC is recorded on the tape. 
+ bc = dirichletbc(bc_func, boundary_dofs) + + problem = LinearProblem(a, L, bcs=[bc], u=uh) + + J = 0.0 + + for i in range(num_steps): + assign(float(i + 1), bc_func) + problem.solve() + J += assemble_scalar(0.5 * ufl.inner(uh, uh) * ufl.dx) + assign(uh, u_prev) + + J_forward = float(J) + + control = pyadjoint.Control(m) + Jhat = pyadjoint.ReducedFunctional(J, control) + J_replay = Jhat(m) + + assert np.isclose(J_replay, J_forward, atol=1e-10, rtol=1e-10) diff --git a/tests/test_hessian.py b/tests/test_hessian.py new file mode 100644 index 0000000..a8a05c7 --- /dev/null +++ b/tests/test_hessian.py @@ -0,0 +1,188 @@ +from mpi4py import MPI + +import dolfinx +import numpy as np +import pyadjoint +import ufl + +import dolfinx_adjoint + + +def test_constant_hessian(): + pyadjoint.get_working_tape().clear_tape() + + domain = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 10, 10) + V = dolfinx.fem.functionspace(domain, ("Lagrange", 1)) + + # ========================================== + # 1. 
SETUP THE FORWARD PROBLEM + # PDE: -div(grad(u)) + m * u = f + # Where 'm' is our scalar control parameter + # ========================================== + u = ufl.TrialFunction(V) + v = ufl.TestFunction(V) + f = dolfinx_adjoint.Constant(domain, 1.0) + + # The true parameter value + m_val = 2.0 + m_control = dolfinx_adjoint.Constant(domain, m_val) + + # Weak form + a = ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx + m_control * ufl.inner(u, v) * ufl.dx + L = ufl.inner(f, v) * ufl.dx + + # Zero Dirichlet Boundary Conditions + domain.topology.create_connectivity(domain.topology.dim - 1, domain.topology.dim) + boundary_facets = dolfinx.mesh.exterior_facet_indices(domain.topology) + boundary_dofs = dolfinx.fem.locate_dofs_topological(V, domain.topology.dim - 1, boundary_facets) + uD = dolfinx_adjoint.Function(V) + uD.x.array[:] = 0.0 + bc = dolfinx_adjoint.dirichletbc(uD, boundary_dofs) + + # Solve and tape the PDE + u_sol = dolfinx_adjoint.Function(V, name="State") + problem = dolfinx_adjoint.LinearProblem(a, L, bcs=[bc], u=u_sol) + problem.solve() + + # ========================================== + # 2. SETUP DATA MISFIT + # J_data = 1/2 * \int (u - u_obs)^2 dx + # ========================================== + u_obs = dolfinx_adjoint.Function(V) + u_obs.x.array[:] = 0.0 # Dummy observation + + J_form = 0.5 * ufl.inner(u_sol - u_obs, u_sol - u_obs) * ufl.dx + J_data = dolfinx_adjoint.assemble_scalar(J_form) + + # ========================================== + # 3. 
EXTRACT THE EXACT TRUE HESSIAN + # ========================================== + control = pyadjoint.Control(m_control) + Jhat = pyadjoint.ReducedFunctional(J_data, control) + + # To get the dense 1x1 Hessian matrix, we compute the Hessian Action + # in the standard basis direction (which for a scalar is simply 1.0) + direction = dolfinx_adjoint.Constant(domain, 1.0) + hessian_action = Jhat.hessian(direction) + + # Cast the action result to a standard Python float + H_misfit = hessian_action.x.array[0] + + # Ensure PyAdjoint actually computed a non-zero curvature! + assert H_misfit > 0.0, "Hessian computation failed or is zero!" + + +def test_constant_hessian_assemble_only(): + """ + Test 1: Does AssembleBlock support Hessians for Constants? + J(m) = 0.5 * (m - 5)**2 * vol + d2J/dm2 = 1.0 * vol + """ + pyadjoint.get_working_tape().clear_tape() + domain = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 3, 3) + + m = dolfinx_adjoint.Constant(domain, 3.0) + + # J = 0.5 * (m - 5)^2 * dx + J_form = 0.5 * (m - 5.0) ** 2 * ufl.dx(domain) + J = dolfinx_adjoint.assemble_scalar(J_form) + + control = pyadjoint.Control(m) + Jhat = pyadjoint.ReducedFunctional(J, control) + + # Direction m_t = 1.0 + direction = dolfinx_adjoint.Constant(domain, 1.0) + hessian_action = Jhat.hessian(direction) + + # Expected Hessian is simply the volume of the domain (1.0 for a unit square) + H_val = hessian_action.x.array[0] + + assert H_val > 0.0, f"AssembleBlock Hessian failed! Value is {H_val}" + assert np.isclose(H_val, 1.0), f"Expected 1.0, got {H_val}" + + +def test_constant_hessian_linear_source(): + """ + Test 2: Does LinearProblemBlock support TLM and SOA for linear parameters? + PDE: -div(grad(u)) = m + J(u) = 0.5 * u**2 * dx + Here, d2F/dudm = 0, so the Hessian is purely the pullback of the objective Hessian. 
+ """ + pyadjoint.get_working_tape().clear_tape() + domain = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 3, 3) + V = dolfinx.fem.functionspace(domain, ("Lagrange", 1)) + + m = dolfinx_adjoint.Constant(domain, 2.0) + + u = ufl.TrialFunction(V) + v = ufl.TestFunction(V) + + # m is just a source term (linear dependence) + a = ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx + L = m * v * ufl.dx + + domain.topology.create_connectivity(domain.topology.dim - 1, domain.topology.dim) + boundary_facets = dolfinx.mesh.exterior_facet_indices(domain.topology) + boundary_dofs = dolfinx.fem.locate_dofs_topological(V, domain.topology.dim - 1, boundary_facets) + u_bc = dolfinx_adjoint.Function(V) + u_bc.x.array[:] = 0.0 + bc = dolfinx_adjoint.dirichletbc(u_bc, boundary_dofs) + + u_sol = dolfinx_adjoint.Function(V) + problem = dolfinx_adjoint.LinearProblem(a, L, bcs=[bc], u=u_sol) + problem.solve() + + J_form = 0.5 * ufl.inner(u_sol, u_sol) * ufl.dx + J = dolfinx_adjoint.assemble_scalar(J_form) + + control = pyadjoint.Control(m) + Jhat = pyadjoint.ReducedFunctional(J, control) + + direction = dolfinx_adjoint.Constant(domain, 1.0) + hessian_action = Jhat.hessian(direction) + H_val = hessian_action.x.array[0] + + assert H_val > 0.0, f"Linear source Hessian failed! Value is {H_val}" + + +def test_constant_hessian_linear_operator(): + """ + Test 3: Does LinearProblemBlock support cross-derivatives (d2F/dudm)? 
+ PDE: -div(grad(u)) + m * u = f + """ + pyadjoint.get_working_tape().clear_tape() + domain = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 3, 3) + V = dolfinx.fem.functionspace(domain, ("Lagrange", 1)) + + m = dolfinx_adjoint.Constant(domain, 2.0) + f = dolfinx_adjoint.Constant(domain, 1.0) + + u = ufl.TrialFunction(V) + v = ufl.TestFunction(V) + + # m multiplies u (non-linear dependence on the parameter) + a = ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx + m * ufl.inner(u, v) * ufl.dx + L = f * v * ufl.dx + + domain.topology.create_connectivity(domain.topology.dim - 1, domain.topology.dim) + boundary_facets = dolfinx.mesh.exterior_facet_indices(domain.topology) + boundary_dofs = dolfinx.fem.locate_dofs_topological(V, domain.topology.dim - 1, boundary_facets) + u_bc = dolfinx_adjoint.Function(V) + u_bc.x.array[:] = 0.0 + bc = dolfinx_adjoint.dirichletbc(u_bc, boundary_dofs) + + u_sol = dolfinx_adjoint.Function(V) + problem = dolfinx_adjoint.LinearProblem(a, L, bcs=[bc], u=u_sol) + problem.solve() + + J_form = 0.5 * ufl.inner(u_sol, u_sol) * ufl.dx + J = dolfinx_adjoint.assemble_scalar(J_form) + + control = pyadjoint.Control(m) + Jhat = pyadjoint.ReducedFunctional(J, control) + + direction = dolfinx_adjoint.Constant(domain, 1.0) + hessian_action = Jhat.hessian(direction) + H_val = hessian_action.x.array[0] + + assert H_val > 0.0, f"Operator Hessian failed! Value is {H_val}"