# Source code for finat.tensor_product

from functools import reduce
from itertools import chain, product
from operator import methodcaller

import numpy

import FIAT
from FIAT.polynomial_set import mis
from FIAT.reference_element import TensorProductCell
from FIAT.orientation_utils import make_entity_permutations_tensorproduct

import gem
from gem.utils import cached_property

from finat.finiteelementbase import FiniteElementBase
from finat.point_set import PointSingleton, PointSet, TensorPointSet


class TensorProductElement(FiniteElementBase):
    """Tensor product of an ordered sequence of FInAT elements.

    The product cell, degree, entity-DoF maps and tabulations are all
    assembled from the corresponding data of the factors. At most one
    factor may have a non-scalar value shape; that shape becomes the value
    shape of the product element.
    """

    def __init__(self, factors):
        """Build the tensor product of ``factors``.

        :arg factors: iterable of element factors (stored as a tuple).
        :raises NotImplementedError: if more than one factor has a
            non-scalar value shape.
        """
        super(TensorProductElement, self).__init__()
        self.factors = tuple(factors)

        shapes = [fe.value_shape for fe in self.factors if fe.value_shape != ()]
        if len(shapes) == 0:
            self._value_shape = ()
        elif len(shapes) == 1:
            self._value_shape = shapes[0]
        else:
            raise NotImplementedError("Only one nonscalar factor permitted!")

    @cached_property
    def cell(self):
        """The :class:`TensorProductCell` of the factor cells."""
        return TensorProductCell(*[fe.cell for fe in self.factors])

    @cached_property
    def complex(self):
        """The :class:`TensorProductCell` of the factor complexes."""
        return TensorProductCell(*[fe.complex for fe in self.factors])

    @property
    def degree(self):
        """Tuple of the factor degrees, in factor order."""
        return tuple(fe.degree for fe in self.factors)

    @cached_property
    def formdegree(self):
        """Sum of the factor form degrees, or ``None`` if any factor's is ``None``."""
        if any(fe.formdegree is None for fe in self.factors):
            return None
        else:
            return sum(fe.formdegree for fe in self.factors)

    @cached_property
    def _entity_dofs(self):
        # Tensor product of the factors' entity -> DoF maps.
        return productise(self.factors, methodcaller("entity_dofs"))

    @cached_property
    def _entity_support_dofs(self):
        # Same product construction, applied to the entity support DoFs.
        return productise(self.factors, methodcaller("entity_support_dofs"))

    def entity_dofs(self):
        """Return the (cached) product entity -> DoF map."""
        return self._entity_dofs

    @cached_property
    def entity_permutations(self):
        """DoF permutations per dimension tuple, entity and orientation tuple."""
        return compose_permutations(self.factors)

    def space_dimension(self):
        """Product of the factor space dimensions."""
        return numpy.prod([fe.space_dimension() for fe in self.factors])

    @property
    def index_shape(self):
        """Concatenation of the factor index shapes, in factor order."""
        return tuple(chain.from_iterable(fe.index_shape for fe in self.factors))

    @property
    def value_shape(self):
        """Value shape of the (at most one) non-scalar factor, else ``()``."""
        return self._value_shape

    @cached_property
    def fiat_equivalent(self):
        """The equivalent FIAT ``TensorProductElement``.

        :raises NotImplementedError: unless there are exactly two factors,
            because FIAT's TensorProductElement supports only 2 factors.
        """
        # Fail with an explicit message instead of an opaque unpacking
        # ValueError when the product does not have exactly two factors.
        if len(self.factors) != 2:
            raise NotImplementedError(
                "FIAT TensorProductElement supports only 2 factors, got %d"
                % len(self.factors))
        A, B = self.factors
        return FIAT.TensorProductElement(A.fiat_equivalent, B.fiat_equivalent)

    def _factor_entity(self, entity):
        """Split a product-cell entity into per-factor ``(dim, id)`` pairs.

        :arg entity: ``(entity_dim, entity_id)`` on the product cell, where
            ``entity_dim`` is a tuple with one dimension per factor; ``None``
            selects the cell itself, ``(self.cell.get_dimension(), 0)``.
        :returns: list of ``(dim, id)`` pairs, one per factor. The flat
            ``entity_id`` is unravelled in C order over the number of
            entities of each factor cell in the corresponding dimension.
        """
        # Default entity
        if entity is None:
            entity = (self.cell.get_dimension(), 0)
        entity_dim, entity_id = entity

        # Factor entity
        assert isinstance(entity_dim, tuple)
        assert len(entity_dim) == len(self.factors)

        shape = tuple(len(c.get_topology()[d])
                      for c, d in zip(self.cell.cells, entity_dim))
        entities = list(zip(entity_dim, numpy.unravel_index(entity_id, shape)))
        return entities

    def _merge_evaluations(self, factor_results):
        """Combine per-factor tabulations into product-element tabulations.

        :arg factor_results: one dict per factor, keyed by derivative
            multiindex, whose values are gem expressions that are indexed
            below with the factor's basis indices followed by its value
            indices.
        :returns: dict mapping every derivative multiindex ``Delta`` on the
            product cell, up to the highest order present in
            ``factor_results``, to a gem tensor whose shape is the
            concatenated factor basis shapes followed by the concatenated
            factor value shapes.
        """
        # Spatial dimension
        dimension = self.cell.get_spatial_dimension()

        # Derivative order
        order = max(map(sum, chain(*factor_results)))

        # A list of slices that are used to select dimensions
        # corresponding to each subelement.
        dim_slices = TensorProductCell._split_slices([c.get_spatial_dimension()
                                                      for c in self.cell.cells])

        # A list of multiindices, one multiindex per subelement, each
        # multiindex describing the shape of basis functions of the
        # subelement.
        alphas = [fe.get_indices() for fe in self.factors]

        # A list of multiindices, one multiindex per subelement, each
        # multiindex describing the value shape of the subelement.
        zetas = [fe.get_value_indices() for fe in self.factors]

        result = {}
        for derivative in range(order + 1):
            for Delta in mis(dimension, derivative):
                # Split the multiindex for the subelements
                deltas = [Delta[s] for s in dim_slices]
                # GEM scalars (can have free indices) for collecting
                # the contributions from the subelements.
                scalars = []
                for fr, delta, alpha, zeta in zip(factor_results, deltas, alphas, zetas):
                    # Turn basis shape to free indices, select the
                    # right derivative entry, and collect the result.
                    scalars.append(gem.Indexed(fr[delta], alpha + zeta))
                # Multiply the values from the subelements and wrap up
                # non-point indices into shape.
                result[Delta] = gem.ComponentTensor(
                    reduce(gem.Product, scalars),
                    tuple(chain(*(alphas + zetas)))
                )
        return result

    def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
        """Tabulate the basis functions and derivatives up to ``order``.

        :arg order: highest derivative order to tabulate.
        :arg ps: point set on the given entity of the product cell; it is
            split into per-factor point sets by :func:`factor_point_set`.
        :arg entity: product-cell entity, see :meth:`_factor_entity`.
        :arg coordinate_mapping: accepted for interface compatibility; not
            used by this method.
        """
        entities = self._factor_entity(entity)
        entity_dim, _ = zip(*entities)

        ps_factors = factor_point_set(self.cell, entity_dim, ps)

        factor_results = [fe.basis_evaluation(order, ps_, e)
                          for fe, ps_, e in zip(self.factors, ps_factors, entities)]

        return self._merge_evaluations(factor_results)

    def point_evaluation(self, order, point, entity=None):
        """Tabulate at a symbolic point given as a gem expression.

        :arg order: highest derivative order to tabulate.
        :arg point: gem expression with shape ``(d,)``, where ``d`` is the
            summed spatial dimension of the per-factor sub-entities.
        :arg entity: product-cell entity, see :meth:`_factor_entity`.
        """
        entities = self._factor_entity(entity)
        entity_dim, _ = zip(*entities)

        # Split point expression
        assert len(self.cell.cells) == len(entity_dim)
        point_dims = [cell.construct_subelement(dim).get_spatial_dimension()
                      for cell, dim in zip(self.cell.cells, entity_dim)]
        assert isinstance(point, gem.Node) and point.shape == (sum(point_dims),)
        slices = TensorProductCell._split_slices(point_dims)
        point_factors = []
        for s in slices:
            point_factors.append(gem.ListTensor(
                [gem.Indexed(point, (i,))
                 for i in range(s.start, s.stop)]
            ))

        # Subelement results
        factor_results = [fe.point_evaluation(order, p_, e)
                          for fe, p_, e in zip(self.factors, point_factors, entities)]

        return self._merge_evaluations(factor_results)

    @property
    def dual_basis(self):
        """Outer product of the factor dual bases.

        :returns: ``(Q, ps)`` where ``ps`` is the :class:`TensorPointSet` of
            the factor point sets and ``Q`` is the product of the factor
            ``q`` tensors, reshaped to (all factor basis indices, then all
            factor value indices).
        """
        # Outer product the dual bases of the factors
        qs, pss = zip(*(factor.dual_basis for factor in self.factors))
        ps = TensorPointSet(pss)
        # Naming as _merge_evaluations above
        alphas = [factor.get_indices() for factor in self.factors]
        zetas = [factor.get_value_indices() for factor in self.factors]
        # Index the factors so that we can reshape into index-shape
        # followed by value-shape
        qis = [q[alpha + zeta] for q, alpha, zeta in zip(qs, alphas, zetas)]
        Q = gem.ComponentTensor(
            reduce(gem.Product, qis),
            tuple(chain(*(alphas + zetas)))
        )
        return Q, ps

    @cached_property
    def mapping(self):
        """``"affine"`` if every factor is affine, the single non-affine
        factor mapping if there is exactly one, otherwise ``None``."""
        mappings = [fe.mapping for fe in self.factors if fe.mapping != "affine"]
        if len(mappings) == 0:
            return "affine"
        elif len(mappings) == 1:
            return mappings[0]
        else:
            return None
def productise(factors, method):
    """Tensor-product the entity -> DoF maps of ``factors``.

    For each tuple of per-factor entity dimensions, the DoF lists attached
    to each tuple of per-factor entities are combined by taking their
    Cartesian product. Every resulting DoF multi-index is flattened in C
    (row-major) order over the factor space dimensions. The entity tuples
    of each dimension tuple are then numbered consecutively in sorted order.

    :arg factors: element factors.
    :arg method: callable applied to each factor to obtain its
        ``{dim: {entity: dofs}}`` map (e.g. ``methodcaller("entity_dofs")``).
    :returns: dict mapping each dimension tuple to ``{entity: dofs}``.
    """
    extents = tuple(fe.space_dimension() for fe in factors)
    dim_tuples = product(*[fe.cell.get_topology().keys() for fe in factors])

    result = {}
    for dim_tuple in dim_tuples:
        per_factor = [method(fe)[d] for fe, d in zip(factors, dim_tuple)]
        entries = []
        for entity_tuple in product(*map(sorted, per_factor)):
            combos = list(product(*[m[e] for m, e in zip(per_factor, entity_tuple)]))
            if not combos:
                flat = []
            else:
                # Columns of the transpose are per-factor DoF numbers.
                flat = list(numpy.ravel_multi_index(numpy.transpose(combos), extents))
            entries.append((entity_tuple, flat))
        # Renumber the entity tuples consecutively in sorted order.
        entries.sort()
        result[dim_tuple] = {i: dof_list for i, (_, dof_list) in enumerate(entries)}
    return result
def compose_permutations(factors):
    """Build the DoF permutation lists of a tensor-product element.

    For the :class:`TensorProductElement` object composed of ``factors``,
    construct, for each dimension tuple, for each entity, and for each
    possible entity orientation combination, the DoF permutation list.

    :arg factors: element factors.
    :returns: entity_permutation dict of the :class:`TensorProductElement`
        object composed of factors.

    For tensor-product elements, one needs to consider two kinds of
    orientations: extrinsic orientations and intrinsic ("material")
    orientations.

    Example:

    UFCQuadrilateral := UFCInterval x UFCInterval

    eo (extrinsic orientation): swap axes (X -> y, Y-> x)
    io (intrinsic orientation): reflect component intervals
    o  (total orientation)    : (2 ** dim) * eo + io

    eo\\io    0      1      2      3

            1---3  0---2  3---1  2---0
      0     |   |  |   |  |   |  |   |
            0---2  1---3  2---0  3---1

            2---3  3---2  0---1  1---0
      1     |   |  |   |  |   |  |   |
            0---1  1---0  2---3  3---2

    .. code-block:: python3

        import FIAT
        import finat

        cell = FIAT.ufc_cell("interval")
        elem = finat.DiscontinuousLagrange(cell, 1)
        elem = finat.TensorProductElement([elem, elem])
        print(elem.entity_permutations)

    prints:

    {(0, 0): {0: {(0, 0, 0): []},
              1: {(0, 0, 0): []},
              2: {(0, 0, 0): []},
              3: {(0, 0, 0): []}},
     (0, 1): {0: {(0, 0, 0): [], (0, 0, 1): []},
              1: {(0, 0, 0): [], (0, 0, 1): []}},
     (1, 0): {0: {(0, 0, 0): [], (0, 1, 0): []},
              1: {(0, 0, 0): [], (0, 1, 0): []}},
     (1, 1): {0: {(0, 0, 0): [0, 1, 2, 3],
                  (0, 0, 1): [1, 0, 3, 2],
                  (0, 1, 0): [2, 3, 0, 1],
                  (0, 1, 1): [3, 2, 1, 0],
                  (1, 0, 0): [0, 2, 1, 3],
                  (1, 0, 1): [2, 0, 3, 1],
                  (1, 1, 0): [1, 3, 0, 2],
                  (1, 1, 1): [3, 1, 2, 0]}}}
    """
    cells = [fe.cell for fe in factors]
    permutations = {}
    for dim in product(*[c.get_topology().keys() for c in cells]):
        # Per factor: {entity: {orientation: permutation}} for this dimension.
        factor_maps = [fe.entity_permutations[d] for fe, d in zip(factors, dim)]
        collected = []
        for entity_tuple in product(*[sorted(m) for m in factor_maps]):
            orientation_maps = [m[e] for m, e in zip(factor_maps, entity_tuple)]
            combined = make_entity_permutations_tensorproduct(cells, dim, orientation_maps)
            collected.append((entity_tuple, combined))
        # Flatten entity tuples into consecutive entity numbers (sorted order).
        collected.sort(key=lambda item: item[0])
        permutations[dim] = {i: perm for i, (_, perm) in enumerate(collected)}
    return permutations
def factor_point_set(product_cell, product_dim, point_set):
    """Split a point set for the product element into one point set per subelement.

    :arg product_cell: a TensorProductCell
    :arg product_dim: entity dimension for the product cell (one entry per
        factor cell)
    :arg point_set: point set for the product element
    :returns: a sequence with one point set per factor cell.
    :raises NotImplementedError: if ``point_set`` is of an unsupported type.
    """
    cells = product_cell.cells
    assert len(cells) == len(product_dim)
    point_dims = [c.construct_subelement(d).get_spatial_dimension()
                  for c, d in zip(cells, product_dim)]

    # A TensorPointSet already factored per cell: hand back its factors,
    # after checking that the dimensions match.
    already_factored = (isinstance(point_set, TensorPointSet)
                        and len(cells) == len(point_set.factors))
    if already_factored:
        assert len(point_set.factors) == len(point_dims)
        assert all(f.dimension == d for f, d in zip(point_set.factors, point_dims))
        return point_set.factors

    # Otherwise split the point coordinates along the point dimensions
    # required by the subelements.
    assert point_set.dimension == sum(point_dims)
    slices = TensorProductCell._split_slices(point_dims)

    if isinstance(point_set, PointSingleton):
        return [PointSingleton(point_set.point[s]) for s in slices]

    if not isinstance(point_set, (PointSet, TensorPointSet)):
        raise NotImplementedError(
            f"How to tabulate TensorProductElement on {type(point_set).__name__}?")

    def _with_shared_index(sub):
        # Use the same point index for the new point set.
        sub.indices = point_set.indices
        return sub

    return [_with_shared_index(PointSet(point_set.points[:, s])) for s in slices]