Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pyop2.mpi import (
MPI, COMM_WORLD, temp_internal_comm
)
from functools import cached_property
from functools import cached_property, reduce
from pyop2.utils import as_tuple
import petsctools
from petsctools import OptionsManager, get_external_packages
Expand Down Expand Up @@ -4806,10 +4806,12 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
subdim : int | None
Topological dimension of the submesh.
Defaults to ``mesh.topological_dimension``.
subdomain_id : int | None
subdomain_id : int | Sequence | None
Subdomain ID representing the submesh.
If `None` the submesh will cover the entire domain.
This is useful to obtain a codim-1 submesh over all facets or
If multiple subdomain IDs are provided, their union is taken.
If nested lists of subdomain IDs are provided, their intersection is taken.
Comment thread
pbrubeck marked this conversation as resolved.
If `None` the submesh will cover the entire domain,
this is useful to obtain a codim-1 submesh over all facets or
a submesh over a different communicator.
label_name : str | None
Name of the label to search ``subdomain_id`` in.
Expand Down Expand Up @@ -4876,15 +4878,43 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
label_name = dmcommon.CELL_SETS_LABEL
elif subdim == dim - 1:
label_name = dmcommon.FACE_SETS_LABEL

# Parse non-integer subdomain_id
if subdomain_id == "on_boundary":
subdomain_id = tuple(mesh.exterior_facets.unique_markers)
Comment thread
pbrubeck marked this conversation as resolved.
Outdated

if isinstance(subdomain_id, Sequence):
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
# Create a temporary DMLabel with the union of the labels in the list
icomm = comm or mesh.comm
iset = PETSc.IS().createGeneral([], comm=icomm)
for sub in subdomain_id:
if isinstance(sub, Sequence):
# Take the intersection of the (closure of the) labels from nested lists
ises = [plex.getStratumIS(label_name, subi) for subi in sub]
closure = [[plex.getTransitiveClosure(p)[0] for p in i.indices] for i in ises]
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
indices = reduce(np.intersect1d, closure)
cur = PETSc.IS().createGeneral(indices, comm=icomm)
else:
cur = plex.getStratumIS(label_name, sub)
iset = iset.union(cur)
label_name = "temp_label"
subdomain_id = 1
plex.createLabel(label_name)
label = plex.getLabel(label_name)
label.setStratumIS(subdomain_id, iset)

subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm)

if label_name == "temp_label":
plex.removeLabel(label_name)

comm = comm or mesh.comm
name = name or _generate_default_submesh_name(mesh.name)
subplex.setName(_generate_default_mesh_topology_name(name))
if subplex.getDimension() != subdim:
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")
if reorder is None:
# Ideally we should set perm_is = mesh.dm_reordering[label_indices]
# Ideally we should set perm_is = mesh._dm_renumbering[label_indices]
reorder = mesh._did_reordering

submesh = Mesh(
Expand Down
93 changes: 93 additions & 0 deletions tests/firedrake/submesh/test_submesh_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import numpy as np
from firedrake import *


def test_submesh_subdomain_id_tuple():
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [111, 222]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 + m2 - m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


def test_submesh_subdomain_id_nested_tuple():
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [(111, 222)]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)])
def test_submesh_facet_subdomain_id_tuple(subdomain_id):
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
mesh = UnitCubeMesh(2, 2, 2)
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
if subdomain_id == "on_boundary":
area = assemble(1*ds(domain=mesh))
else:
area = assemble(1*ds(subdomain_id, domain=mesh))
assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12

V = FunctionSpace(mesh, "HDiv Trace", 0)
facet_function = Function(V)
DirichletBC(V, 1, subdomain_id).apply(facet_function)
facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


def test_submesh_facet_subdomain_id_nested_tuple():
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(x, 0.5), 0, 1))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [(111, 222)]
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id, label_name="Cell Sets")
Comment thread
pbrubeck marked this conversation as resolved.
Outdated

expected = 1
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

x, y = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "HDiv Trace", 0)
facet_function = Function(V)
facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0))
facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
Loading