Skip to content

Commit

Permalink
Revert __getitem__
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 21, 2025
1 parent e8ff24b commit 4cff9b4
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 34 deletions.
4 changes: 2 additions & 2 deletions finat/physically_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
M = self.basis_transformation(coordinate_mapping)
# we expect M to be sparse with O(1) nonzeros per row
# for each row, get the column index of each nonzero entry
csr = [[j for j in range(M.shape[1]) if not isinstance(M[i, j], gem.Zero)]
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
for i in range(M.shape[0])]

def matvec(table):
# basis recombination using hand-rolled sparse-dense matrix multiplication
ii = gem.indices(len(table.shape)-1)
phi = [gem.Indexed(table, (j, *ii)) for j in range(M.shape[1])]
# the sum approach is faster than calling numpy.dot or gem.IndexSum
expressions = [gem.ComponentTensor(sum(M[i, j] * phi[j] for j in js), ii)
expressions = [gem.ComponentTensor(sum(M.array[i, j] * phi[j] for j in js), ii)
for i, js in enumerate(csr)]
val = gem.ListTensor(expressions)
# val = M @ table
Expand Down
44 changes: 19 additions & 25 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def __call__(self, *args, **kwargs):

# Set free_indices if not set already
if not hasattr(obj, 'free_indices'):
obj.free_indices = unique(chain(*[c.free_indices
for c in obj.children]))
obj.free_indices = unique(chain.from_iterable(c.free_indices
for c in obj.children))
# Set dtype if not set already.
if not hasattr(obj, 'dtype'):
obj.dtype = obj.inherit_dtype_from_children(obj.children)
Expand Down Expand Up @@ -306,9 +306,6 @@ def value(self):
def shape(self):
return self.array.shape

def __getitem__(self, i):
return self.array[i]


class Variable(Terminal):
"""Symbolic variable tensor"""
Expand Down Expand Up @@ -337,7 +334,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b]))
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b)))

self = super(Sum, cls).__new__(cls)
self.children = a, b
Expand All @@ -361,7 +358,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b]))
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b)))

self = super(Product, cls).__new__(cls)
self.children = a, b
Expand All @@ -385,7 +382,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b]))
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b)))

self = super(Division, cls).__new__(cls)
self.children = a, b
Expand Down Expand Up @@ -676,6 +673,19 @@ def __new__(cls, aggregate, multiindex):
if isinstance(aggregate, Zero):
return Zero(dtype=aggregate.dtype)

# Simplify Literal and ListTensor
if isinstance(aggregate, (Constant, ListTensor)):
if all(isinstance(i, int) for i in multiindex):
# All indices fixed
sub = aggregate.array[multiindex]
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
# Some indices fixed
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
sub = aggregate.array[slices]
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))

# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
if isinstance(aggregate, ComponentTensor):
B, = aggregate.children
Expand All @@ -689,19 +699,6 @@ def __new__(cls, aggregate, multiindex):
ll = tuple(rep.get(k, k) for k in kk)
return Indexed(C, ll)

# Simplify Literal and ListTensor
if isinstance(aggregate, (Constant, ListTensor)):
if all(isinstance(i, int) for i in multiindex):
# All indices fixed
sub = aggregate[multiindex]
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
# Some indices fixed
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
sub = aggregate[slices]
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))

self = super(Indexed, cls).__new__(cls)
self.children = (aggregate,)
self.multiindex = multiindex
Expand Down Expand Up @@ -945,9 +942,6 @@ def shape(self):
def __reduce__(self):
return type(self), (self.array,)

def __getitem__(self, i):
return self.array[i]

def reconstruct(self, *args):
return ListTensor(asarray(args).reshape(self.array.shape))

Expand All @@ -958,7 +952,7 @@ def is_equal(self, other):
"""Common subexpression eliminating equality predicate."""
if type(self) is not type(other):
return False
if (self.array == other.array).all():
if numpy.array_equal(self.array, other.array):
self.array = other.array
return True
return False
Expand Down
10 changes: 5 additions & 5 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _constant_fold_zero_listtensor(node, self):
new_children = list(map(self, node.children))
if all(isinstance(nc, Zero) for nc in new_children):
return Zero(node.shape)
elif all(nc == c for nc, c in zip(new_children, node.children)):
elif new_children == node.children:
return node
else:
return node.reconstruct(*new_children)
Expand All @@ -207,7 +207,7 @@ def constant_fold_zero(exprs):
otherwise Literal `0`s would be reintroduced.
"""
mapper = Memoizer(_constant_fold_zero)
return [mapper(e) for e in exprs]
return list(map(mapper, exprs))


def _select_expression(expressions, index):
Expand Down Expand Up @@ -252,9 +252,9 @@ def child(expression):
assert all(len(e.children) == len(expr.children) for e in expressions)
assert len(expr.children) > 0

return expr.reconstruct(*[_select_expression(nth_children, index)
for nth_children in zip(*[e.children
for e in expressions])])
return expr.reconstruct(*(_select_expression(nth_children, index)
for nth_children in zip(*(e.children
for e in expressions))))

raise NotImplementedError("No rule for factorising expressions of this kind.")

Expand Down
5 changes: 3 additions & 2 deletions test/finat/test_zany_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
from gem.interpreter import evaluate
from finat.physically_mapped import PhysicallyMappedElement


def make_unisolvent_points(element, interior=False):
Expand Down Expand Up @@ -65,11 +66,11 @@ def check_zany_mapping(element, ref_to_phys, *args, **kwargs):
# Zany map the results
num_bfs = phys_element.space_dimension()
num_dofs = finat_element.space_dimension()
try:
if isinstance(finat_element, PhysicallyMappedElement):
Mgem = finat_element.basis_transformation(ref_to_phys)
M = evaluate([Mgem])[0].arr
ref_vals_zany = np.tensordot(M, ref_vals_piola, (-1, 0))
except AttributeError:
else:
M = np.eye(num_dofs, num_bfs)
ref_vals_zany = ref_vals_piola

Expand Down

0 comments on commit 4cff9b4

Please sign in to comment.