Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2203,11 +2203,11 @@ def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):

def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
"""Return mesh periodicity information.

This function returns a 2-tuple of bools per dimension where the first entry indicates
whether the mesh is periodic in that dimension, and the second indicates whether the
mesh is single-cell periodic in that dimension.

"""
cdef:
const PetscReal *maxCell, *L
Expand Down Expand Up @@ -4325,3 +4325,41 @@ def get_dm_cell_types(PETSc.DM dm):
return tuple(
polytope_type_enum for polytope_type_enum, found in enumerate(found_all) if found
)


def create_label_intersection(PETSc.DM dm, label_name, label_values):
"""Return the intersection of the closure of a subdomains of a DMPlex.

Parameters
----------
dm : PETSc.DM
The DMPlex.
label_name : str
The name of the label
label_values : Sequence[int]
The values of the subdomain label to intersect

Returns
-------
tuple
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
A PETSc.IS with the points in the intersection.

"""
cdef:
PETSc.DMLabel label
PETSc.PetscIS is1, is2
PetscInt val = label_values[0]

label = dm.getLabel(label_name)
CHKERR(DMPlexLabelComplete(dm.dm, label.dmlabel))
CHKERR(DMLabelGetStratumIS(<DMLabel>label.dmlabel, val, &is1))

for i in range(1, len(label_values)):
iout = PETSc.IS()
val = label_values[i]
CHKERR(DMLabelGetStratumIS(<DMLabel>label.dmlabel, val, &is2))
CHKERR(ISIntersect(is1, is2, &(<PETSc.IS?>iout).iset))
CHKERR(ISDestroy(&is1))
CHKERR(ISDestroy(&is2))
is1 = (<PETSc.IS?>iout).iset
return iout
1 change: 1 addition & 0 deletions firedrake/cython/petschdr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ cdef extern from "petscis.h" nogil:
PetscErrorCode ISLocalToGlobalMappingGetBlockIndices(PETSc.PetscLGMap, const PetscInt**)
PetscErrorCode ISLocalToGlobalMappingRestoreBlockIndices(PETSc.PetscLGMap, const PetscInt**)
PetscErrorCode ISDestroy(PETSc.PetscIS*)
PetscErrorCode ISIntersect(PETSc.PetscIS, PETSc.PetscIS, PETSc.PetscIS*)

cdef extern from "petscsf.h" nogil:
struct PetscSFNode_:
Expand Down
46 changes: 41 additions & 5 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pyop2.mpi import (
MPI, COMM_WORLD, temp_internal_comm
)
from functools import cached_property
from functools import cached_property, reduce

Check failure on line 28 in firedrake/mesh.py

View workflow job for this annotation

GitHub Actions / test / Lint codebase

F401

firedrake/mesh.py:28:1: F401 'functools.reduce' imported but unused
from pyop2.utils import as_tuple
import petsctools
from petsctools import OptionsManager, get_external_packages
Expand Down Expand Up @@ -4806,10 +4806,12 @@
subdim : int | None
Topological dimension of the submesh.
Defaults to ``mesh.topological_dimension``.
subdomain_id : int | None
subdomain_id : int | Sequence | None
Subdomain ID representing the submesh.
If `None` the submesh will cover the entire domain.
This is useful to obtain a codim-1 submesh over all facets or
If multiple subdomain IDs are provided, their union is taken.
If nested lists of subdomain IDs are provided, their intersection is taken.
Comment thread
pbrubeck marked this conversation as resolved.
If `None` the submesh will cover the entire domain,
this is useful to obtain a codim-1 submesh over all facets or
a submesh over a different communicator.
label_name : str | None
Name of the label to search ``subdomain_id`` in.
Expand Down Expand Up @@ -4861,6 +4863,8 @@
raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``")
if subdim is None:
subdim = mesh.topological_dimension
if subdomain_id == "on_boundary":
subdim = subdim - 1
plex = mesh.topology_dm
dim = plex.getDimension()
if subdim not in {dim, dim - 1}:
Expand All @@ -4876,15 +4880,47 @@
label_name = dmcommon.CELL_SETS_LABEL
elif subdim == dim - 1:
label_name = dmcommon.FACE_SETS_LABEL

# Parse non-integer subdomain_id
if isinstance(subdomain_id, str):
if subdomain_id == "on_boundary":
subdomain_id = tuple(mesh.exterior_facets.unique_markers)
else:
raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.")

if isinstance(subdomain_id, Sequence):
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
# Create a temporary DMLabel with the union of the labels in the list
icomm = comm or mesh.comm
iset = PETSc.IS().createGeneral([], comm=icomm)
for sub in subdomain_id:
try:
sub, = sub
except ValueError:
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
pass
if isinstance(sub, Sequence):
# Take the intersection of the (closure of the) labels from nested lists
cur = dmcommon.create_label_intersection(plex, label_name, sub)
else:
cur = plex.getStratumIS(label_name, sub)
iset = iset.union(cur)
label_name = "temp_label"
subdomain_id = 1
plex.createLabel(label_name)
label = plex.getLabel(label_name)
label.setStratumIS(subdomain_id, iset)

subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm)

if label_name == "temp_label":
plex.removeLabel(label_name)

comm = comm or mesh.comm
name = name or _generate_default_submesh_name(mesh.name)
subplex.setName(_generate_default_mesh_topology_name(name))
if subplex.getDimension() != subdim:
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")
if reorder is None:
# Ideally we should set perm_is = mesh.dm_reordering[label_indices]
# Ideally we should set perm_is = mesh._dm_renumbering[label_indices]
reorder = mesh._did_reordering

submesh = Mesh(
Expand Down
93 changes: 93 additions & 0 deletions tests/firedrake/submesh/test_submesh_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import numpy as np
from firedrake import *


def test_submesh_subdomain_id_union():
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [111, 222]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 + m2 - m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


def test_submesh_subdomain_id_intersection():
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [(111, 222)]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)])
def test_submesh_facet_subdomain_id_union(subdomain_id):
mesh = UnitCubeMesh(2, 2, 2)
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
if subdomain_id == "on_boundary":
area = assemble(1*ds(domain=mesh))
else:
area = assemble(1*ds(subdomain_id, domain=mesh))
assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12

V = FunctionSpace(mesh, "HDiv Trace", 0)
facet_function = Function(V)
DirichletBC(V, 1, subdomain_id).apply(facet_function)
facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


def test_submesh_facet_subdomain_id_intersection():
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(x, 0.5), 0, 1))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [(111, 222)]
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id, label_name="Cell Sets")
Comment thread
pbrubeck marked this conversation as resolved.
Outdated

expected = 1
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

x, y = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "HDiv Trace", 0)
facet_function = Function(V)
facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0))
facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
Loading