from FIAT.hdivcurl import Hdiv, Hcurl
from FIAT.reference_element import LINE
import gem
from gem.utils import cached_property
from finat.finiteelementbase import FiniteElementBase
from finat.tensor_product import TensorProductElement
class WrapperElementBase(FiniteElementBase):
    """Common base class for H(div) and H(curl) element wrappers.

    Topological information (cell, degree, entity dofs, index shape, ...)
    is forwarded to the wrapped element unchanged.  The value shape is
    replaced by a vector of length ``cell.get_spatial_dimension()``.
    Basis evaluation, point evaluation and the dual basis are obtained by
    applying ``transform`` to the wrapped element's corresponding tables.
    """

    def __init__(self, wrappee, transform):
        super().__init__()
        #: An appropriate tensor product FInAT element whose basis
        #: functions are mapped to produce an H(div) or H(curl)
        #: conforming element.
        self.wrappee = wrappee
        #: A transformation applied on the scalar/vector values of the
        #: wrapped element to produce an H(div) or H(curl) conforming
        #: element.
        self.transform = transform

    # ------------------------------------------------------------------
    # Properties and queries delegated verbatim to the wrapped element.
    # ------------------------------------------------------------------

    @property
    def cell(self):
        """Reference cell of the wrapped element."""
        return self.wrappee.cell

    @property
    def complex(self):
        """``complex`` attribute of the wrapped element."""
        return self.wrappee.complex

    @property
    def degree(self):
        """Degree of the wrapped element."""
        return self.wrappee.degree

    def entity_dofs(self):
        """Entity dofs of the wrapped element."""
        return self.wrappee.entity_dofs()

    @property
    def entity_permutations(self):
        """Entity permutations of the wrapped element."""
        return self.wrappee.entity_permutations

    def entity_closure_dofs(self):
        """Entity closure dofs of the wrapped element."""
        return self.wrappee.entity_closure_dofs()

    def entity_support_dofs(self):
        """Entity support dofs of the wrapped element."""
        return self.wrappee.entity_support_dofs()

    def space_dimension(self):
        """Number of basis functions; identical to the wrapped element's."""
        return self.wrappee.space_dimension()

    @property
    def index_shape(self):
        """Basis index shape of the wrapped element."""
        return self.wrappee.index_shape

    # ------------------------------------------------------------------
    # Quantities altered by the wrapper.
    # ------------------------------------------------------------------

    @property
    def value_shape(self):
        """A vector whose length is the cell's spatial dimension."""
        dim = self.cell.get_spatial_dimension()
        return (dim,)

    def _transform_evaluation(self, core_eval):
        """Apply ``self.transform`` to every table in ``core_eval``.

        :arg core_eval: dict of GEM tables from the wrapped element.  Keys
            are passed through unchanged.
        :returns: a dict with the same keys.  Each value is a
            ``gem.ComponentTensor`` over the basis indices followed by the
            value indices of this element.
        """
        basis_idx = self.get_indices()
        value_idx = self.get_value_indices()

        def promote(table):
            # Fix the leading (basis) indices so the transform acts on the
            # values belonging to a single basis function.
            per_basis = gem.partial_indexed(table, basis_idx)
            vector = gem.ListTensor(self.transform(per_basis))
            entry = gem.Indexed(vector, value_idx)
            # Free the basis and value indices again to rebuild a tensor.
            return gem.ComponentTensor(entry, basis_idx + value_idx)

        return {key: promote(tab) for key, tab in core_eval.items()}

    def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
        """Transformed basis evaluation of the wrapped element.

        Calls ``self.wrappee.basis_evaluation(order, ps, entity)`` and
        applies ``self.transform`` to each resulting table.

        ``coordinate_mapping`` is accepted for interface compatibility but
        is not forwarded to the wrapped element.
        """
        wrapped_tables = self.wrappee.basis_evaluation(order, ps, entity)
        return self._transform_evaluation(wrapped_tables)

    def point_evaluation(self, order, refcoords, entity=None):
        """Transformed point evaluation of the wrapped element.

        Calls ``self.wrappee.point_evaluation(order, refcoords, entity)``
        and applies ``self.transform`` to each resulting table.
        """
        wrapped_tables = self.wrappee.point_evaluation(order, refcoords, entity)
        return self._transform_evaluation(wrapped_tables)

    @property
    def dual_basis(self):
        """Dual basis of the wrapped element, promoted to this value shape.

        :returns: the pair ``(Q, x)``.  ``x`` is taken unchanged from the
            wrapped element's ``dual_basis``.  ``Q`` is the wrapped
            element's ``Q`` with ``transform`` applied, as for basis
            evaluation.
        """
        wrapped_Q, points = self.wrappee.dual_basis
        basis_idx = self.get_indices()
        value_idx = self.get_value_indices()
        # Index out the basis indices from the wrappee's Q to get something
        # of wrappee.value_shape. Then promote it to the new shape with the
        # same transform used for basis evaluation.
        per_basis = gem.partial_indexed(wrapped_Q, basis_idx)
        promoted = gem.ListTensor(self.transform(per_basis))
        # Wrap Q up in shape again, now with the extra value_shape indices.
        new_Q = gem.ComponentTensor(promoted[value_idx], basis_idx + value_idx)
        return new_Q, points
class HDivElement(WrapperElementBase):
    """H(div) wrapper element for tensor product elements.

    The wrapped element must be a ``TensorProductElement`` whose factors
    all have a form degree. Those form degrees must sum to the cell's
    spatial dimension minus one.
    """

    def __init__(self, wrappee):
        assert isinstance(wrappee, TensorProductElement)
        factor_degrees = [factor.formdegree for factor in wrappee.factors]
        if None in factor_degrees:
            raise ValueError("Form degree of subelement is None, cannot H(div)!")
        required = wrappee.cell.get_spatial_dimension() - 1
        if sum(factor_degrees) != required:
            raise ValueError("H(div) requires (n-1)-form element!")
        super().__init__(wrappee, select_hdiv_transformer(wrappee))

    @property
    def formdegree(self):
        """Form degree: the cell's spatial dimension minus one."""
        dim = self.cell.get_spatial_dimension()
        return dim - 1

    @cached_property
    def fiat_equivalent(self):
        """FIAT ``Hdiv`` wrapper around the wrapped element's FIAT equivalent."""
        return Hdiv(self.wrappee.fiat_equivalent)

    @property
    def mapping(self):
        """Name of the pullback used by this element."""
        return "contravariant piola"
class HCurlElement(WrapperElementBase):
    """H(curl) wrapper element for tensor product elements.

    The wrapped element must be a ``TensorProductElement`` whose factors
    all have a form degree. Those form degrees must sum to one.
    """

    def __init__(self, wrappee):
        assert isinstance(wrappee, TensorProductElement)
        factor_degrees = [factor.formdegree for factor in wrappee.factors]
        if None in factor_degrees:
            raise ValueError("Form degree of subelement is None, cannot H(curl)!")
        required = 1
        if sum(factor_degrees) != required:
            raise ValueError("H(curl) requires 1-form element!")
        super().__init__(wrappee, select_hcurl_transformer(wrappee))

    @property
    def formdegree(self):
        """Form degree: always one for H(curl)."""
        return 1

    @cached_property
    def fiat_equivalent(self):
        """FIAT ``Hcurl`` wrapper around the wrapped element's FIAT equivalent."""
        return Hcurl(self.wrappee.fiat_equivalent)

    @property
    def mapping(self):
        """Name of the pullback used by this element."""
        return "covariant piola"