Skip to content

Commit 5403dc8

Browse files
authored
Merge pull request #47 from Tiramisu-Compiler/isl-ast
fix: fix upper bound coming from isl ast
2 parents 91e20ca + b706455 commit 5403dc8

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
name: Python package
55

6-
on: [push, pull_request]
6+
on: [pull_request]
77

88
jobs:
99
format:

tiralib/tiramisu/tiramisu_tree.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def from_isl_ast_string_list(cls, isl_ast_string_list: list[str]) -> "TiramisuTr
183183
else:
184184
upper_bound = loop_condition
185185
try:
186-
upper_bound = int(upper_bound)
186+
# isl has <= so we add 1 to the upper bound
187+
upper_bound = int(upper_bound) + 1
187188
except ValueError:
188189
# Upper bound is not an integer so we keep it string
189190
pass
@@ -428,3 +429,48 @@ def __repr__(self) -> str:
428429
@property
429430
def depth(self) -> int:
430431
return max([iterator.level for iterator in self.iterators.values()]) + 1
432+
433+
def get_isl_ast_string(self) -> str:
434+
representation = ""
435+
for root in self.roots:
436+
representation += self._get_isl_ast_string_of_node(root)
437+
return representation
438+
439+
def _get_isl_ast_string_of_node(self, node_id: IteratorIdentifier) -> str:
440+
representation = ""
441+
iterator = self.iterators[node_id]
442+
upper_bound_str = (
443+
f"{iterator.name} <= {iterator.upper_bound - 1}"
444+
if isinstance(iterator.upper_bound, int)
445+
else iterator.upper_bound
446+
)
447+
representation += (
448+
f"{iterator.name}|iterator|{iterator.lower_bound}|{upper_bound_str}|1\n"
449+
)
450+
comps_and_iterators = [
451+
(comp, "comp") for comp in self.iterators[node_id].computations_list
452+
]
453+
comps_and_iterators += [
454+
(iterator, "iterator")
455+
for iterator in self.iterators[node_id].child_iterators
456+
]
457+
458+
# sort them by computations_absolute_order
459+
comps_and_iterators = sorted(
460+
comps_and_iterators,
461+
key=lambda item: (
462+
self.computations_absolute_order[item[0]]
463+
if item[1] == "comp"
464+
else self.computations_absolute_order[item[0][0]]
465+
),
466+
)
467+
468+
for comp_or_iterator, type in comps_and_iterators:
469+
if type == "comp":
470+
representation += (
471+
f"{iterator.level + 1}|computation|{comp_or_iterator}\n"
472+
)
473+
474+
else:
475+
representation += self._get_isl_ast_string_of_node(comp_or_iterator)
476+
return representation

0 commit comments

Comments
 (0)