Skip to content

Commit f8e805f

Browse files
ksagiyampbrubeck
andcommitted
simplify Indexed(ComponentTensor(Indexed(ListTensor(...), ...), ...), ...) on construction
Co-authored-by: Pablo Brubeck <[email protected]>
1 parent 4adcd7b commit f8e805f

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

test/test_simplify.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
acos,
1515
as_tensor,
1616
as_ufl,
17+
as_vector,
1718
asin,
1819
atan,
1920
cos,
@@ -234,7 +235,7 @@ def test_untangle_indexed_component_tensor(self):
234235
A = as_tensor(Indexed(C, MultiIndex(kk)), jj)
235236
assert A is not C
236237

237-
ii = kk
238+
ii = indices(len(A.ufl_shape))
238239
expr = Indexed(A, MultiIndex(ii))
239240
assert isinstance(expr, Indexed)
240241
B, ll = expr.ufl_operands
@@ -266,3 +267,24 @@ def test_simplify_indexed(self):
266267
# ComponentTensor + ListTensor
267268
c = ComponentTensor(Indexed(ll, MultiIndex((i, j))), MultiIndex((j, i)))
268269
assert Indexed(c, MultiIndex((FixedIndex(1), FixedIndex(2)))) == l2[1]
270+
271+
272+
def test_simplify_indexed_componenttensor_indexed_listtensor():
273+
list_item_0 = as_vector([10.0, 11.0, 12.0])
274+
list_item_1 = as_vector([20.0, 21.0, 22.0])
275+
i = Index()
276+
j = Index()
277+
value = Indexed(
278+
ComponentTensor(
279+
Indexed(
280+
ListTensor(
281+
Indexed(list_item_0, MultiIndex((j,))),
282+
Indexed(list_item_1, MultiIndex((j,))),
283+
),
284+
MultiIndex((i,)),
285+
),
286+
MultiIndex((i, j)),
287+
),
288+
MultiIndex((FixedIndex(1), FixedIndex(2))),
289+
)
290+
assert value == list_item_1[2]

ufl/tensors.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,24 @@ def _simplify_indexed(self, multiindex):
235235
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
236236
# Untangle as_tensor(C[kk], jj)[ii] -> C[ll]
237237
B, jj = self.ufl_operands
238+
if len(multiindex) != len(jj):
239+
raise ValueError(f"len(multiindex) ({len(multiindex)}) != len(jj) ({len(jj)})")
240+
rep = dict(zip(jj, multiindex))
241+
# Avoid recursion and just attempt to simplify some common patterns
242+
# as the result of this method is not cached.
243+
if isinstance(B, Indexed):
244+
C, kk = B.ufl_operands
245+
if isinstance(C, ListTensor) and len(kk) == 1 and isinstance(rep[kk[0]], FixedIndex):
246+
(k,) = kk
247+
B = C.ufl_operands[int(rep[k])]
248+
jj = MultiIndex(tuple(j for j in jj if j != k))
249+
multiindex = MultiIndex(tuple(rep[j] for j in jj))
250+
rep = dict(zip(jj, multiindex))
238251
if isinstance(B, Indexed):
239252
C, kk = B.ufl_operands
240253
if all(j in kk for j in jj):
241-
ii = tuple(multiindex)
242-
rep = dict(zip(jj, ii))
243254
Cind = tuple(rep.get(k, k) for k in kk)
244255
return Indexed(C, MultiIndex(Cind))
245-
246256
return Operator._simplify_indexed(self, multiindex)
247257

248258
def indices(self):

0 commit comments

Comments
 (0)