Skip to content

Commit

Permalink
Boundary Quadrature element
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 9, 2025
1 parent 6147622 commit 7f2dd72
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 28 deletions.
2 changes: 1 addition & 1 deletion FIAT/polynomial_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def tabulate_new(self, pts):
def tabulate(self, pts, jet_order=0):
"""Returns the values of the polynomial set."""
base_vals = self.expansion_set._tabulate(self.embedded_degree, pts, order=jet_order)
result = {alpha: numpy.dot(self.coeffs, base_vals[alpha]) for alpha in base_vals}
result = {alpha: numpy.tensordot(self.coeffs, base_vals[alpha], (-1, 0)) for alpha in base_vals}
return result

def get_expansion_set(self):
Expand Down
7 changes: 4 additions & 3 deletions finat/element_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,14 @@ def convert(element, **kwargs):
@convert.register(finat.ufl.FiniteElement)
def convert_finiteelement(element, **kwargs):
cell = as_fiat_cell(element.cell)
if element.family() == "Quadrature":
if element.family() in {"Quadrature", "Boundary Quadrature"}:
degree = element.degree()
scheme = element.quadrature_scheme()
scheme = element.quadrature_scheme() or "default"
if degree is None or scheme is None:
raise ValueError("Quadrature scheme and degree must be specified!")

return finat.make_quadrature_element(cell, degree, scheme), set()
codim = 1 if element.family() == "Boundary Quadrature" else 0
return finat.make_quadrature_element(cell, degree, scheme, codim), set()
lmbda = supported_elements[element.family()]
if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}:
lmbda = None
Expand Down
7 changes: 3 additions & 4 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
fiat_result = self._element.tabulate(order, ps.points, entity)
fiat_result = self._element.tabulate(order, ps.points.reshape(-1, ps.points.shape[-1]), entity)
result = {}
# In almost all cases, we have
# self.space_dimension() == self._element.space_dimension()
Expand All @@ -116,9 +116,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
continue

derivative = sum(alpha)
table_roll = fiat_table.reshape(
space_dimension, value_size, len(ps.points)
).transpose(1, 2, 0)
table = fiat_table.reshape(space_dimension, value_size, *ps.points.shape[:-1])
table_roll = np.moveaxis(table, 0, -1)

exprs = []
for table in table_roll:
Expand Down
7 changes: 3 additions & 4 deletions finat/point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def points(self):
@property
def dimension(self):
"""Point dimension."""
_, dim = self.points.shape
return dim
return self.points.shape[-1]

@abstractproperty
def indices(self):
Expand Down Expand Up @@ -130,7 +129,7 @@ def __init__(self, points):
:arg points: A vector of N points of shape (N, D) where D is the
dimension of each point."""
points = numpy.asarray(points)
assert len(points.shape) == 2
assert len(points.shape) > 1
self.points = points

@cached_property
Expand All @@ -139,7 +138,7 @@ def points(self):

@cached_property
def indices(self):
return (gem.Index(extent=len(self.points)),)
return tuple(gem.Index(extent=e) for e in self.points.shape[:-1])

@cached_property
def expression(self):
Expand Down
67 changes: 51 additions & 16 deletions finat/quadrature_element.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from finat.point_set import UnknownPointSet
from finat.point_set import UnknownPointSet, PointSet
from functools import reduce

import numpy
Expand All @@ -13,7 +13,7 @@
from finat.quadrature import make_quadrature, AbstractQuadratureRule


def make_quadrature_element(fiat_ref_cell, degree, scheme="default"):
def make_quadrature_element(fiat_ref_cell, degree, scheme="default", codim=0):
"""Construct a :class:`QuadratureElement` from a given a reference
element, degree and scheme.
Expand All @@ -23,9 +23,16 @@ def make_quadrature_element(fiat_ref_cell, degree, scheme="default"):
integrate exactly.
:param scheme: The quadrature scheme to use - e.g. "default",
"canonical" or "KMV".
:param codim: The codimension of the quadrature scheme.
:returns: The appropriate :class:`QuadratureElement`
"""
rule = make_quadrature(fiat_ref_cell, degree, scheme=scheme)
if codim:
sd = fiat_ref_cell.get_spatial_dimension()
rule_ref_cell = fiat_ref_cell.construct_subcomplex(sd - codim)
else:
rule_ref_cell = fiat_ref_cell

rule = make_quadrature(rule_ref_cell, degree, scheme=scheme)
return QuadratureElement(fiat_ref_cell, rule)


Expand All @@ -42,8 +49,6 @@ def __init__(self, fiat_ref_cell, rule):
self.cell = fiat_ref_cell
if not isinstance(rule, AbstractQuadratureRule):
raise TypeError("rule is not an AbstractQuadratureRule")
if fiat_ref_cell.get_spatial_dimension() != rule.point_set.dimension:
raise ValueError("Cell dimension does not match rule's point set dimension")
self._rule = rule

@cached_property
Expand All @@ -64,10 +69,16 @@ def formdegree(self):

@cached_property
def _entity_dofs(self):
# Inspired by ffc/quadratureelement.py
top = self.cell.get_topology()
entity_dofs = {dim: {entity: [] for entity in entities}
for dim, entities in self.cell.get_topology().items()}
entity_dofs[self.cell.get_dimension()] = {0: list(range(self.space_dimension()))}
for dim, entities in top.items()}
ps = self._rule.point_set
dim = ps.dimension
num_pts = len(ps.points)
cur = 0
for entity in sorted(top[dim]):
entity_dofs[dim][entity] = list(range(cur, cur + num_pts))
cur += num_pts
return entity_dofs

def entity_dofs(self):
Expand All @@ -76,9 +87,22 @@ def entity_dofs(self):
def space_dimension(self):
return numpy.prod(self.index_shape, dtype=int)

@cached_property
def _point_set(self):
ps = self._rule.point_set
sd = self.cell.get_spatial_dimension()
dim = ps.dimension
if dim != sd:
# Tile the quadrature rule on each subentity
entity_ids = self.entity_dofs()
pts = [self.cell.get_entity_transform(dim, entity)(ps.points)
for entity in entity_ids[dim]]
ps = PointSet(numpy.stack(pts, axis=0))
return ps

@property
def index_shape(self):
ps = self._rule.point_set
ps = self._point_set
return tuple(index.extent for index in ps.indices)

@property
Expand All @@ -87,7 +111,7 @@ def value_shape(self):

@cached_property
def fiat_equivalent(self):
ps = self._rule.point_set
ps = self._point_set
if isinstance(ps, UnknownPointSet):
raise ValueError("A quadrature element with rule with runtime points has no fiat equivalent!")
weights = getattr(self._rule, 'weights', None)
Expand All @@ -107,8 +131,13 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set object.
:param entity: the cell entity on which to tabulate.
'''
if entity is not None and entity != (self.cell.get_dimension(), 0):
raise ValueError('QuadratureElement does not "tabulate" on subentities.')
rule_dim = self._rule.point_set.dimension
if entity is None:
entity = (rule_dim, 0)
entity_dim, entity_id = entity
if entity_dim != rule_dim:
raise ValueError(f"Cannot tabulate QuadratureElement of dimension {rule_dim}"
f" on subentities of dimension {entity_dim}.")

if order:
raise ValueError("Derivatives are not defined on a QuadratureElement.")
Expand All @@ -119,17 +148,23 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
# Return an outer product of identity matrices
multiindex = self.get_indices()
product = reduce(gem.Product, [gem.Delta(q, r)
for q, r in zip(ps.indices, multiindex)])
for q, r in zip(ps.indices, multiindex[-len(ps.indices):])])

dim = self.cell.get_spatial_dimension()
return {(0,) * dim: gem.ComponentTensor(product, multiindex)}
sd = self.cell.get_spatial_dimension()
if sd != ps.dimension:
data = numpy.zeros(self.index_shape[:-1], dtype=object)
data[...] = gem.Zero()
data[entity_id] = gem.Literal(1)
product = gem.Product(product, gem.Indexed(gem.ListTensor(data), multiindex[:1]))

return {(0,) * sd: gem.ComponentTensor(product, multiindex)}

def point_evaluation(self, order, refcoords, entity=None):
raise NotImplementedError("QuadratureElement cannot do point evaluation!")

@property
def dual_basis(self):
ps = self._rule.point_set
ps = self._point_set
multiindex = self.get_indices()
# Evaluation matrix is just an outer product of identity
# matrices, evaluation points are just the quadrature points.
Expand Down

0 comments on commit 7f2dd72

Please sign in to comment.