Skip to content

Commit 3bd7b2d

Browse files
authored
[SOT] Refactor MapVariable to align with Python (PaddlePaddle#71346)
1 parent 35caacb commit 3bd7b2d

File tree

3 files changed

+46
-29
lines changed

3 files changed

+46
-29
lines changed

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -645,16 +645,12 @@ def create_zip(*var: VariableBase):
645645

646646

647647
# map
648-
Dispatcher.register(
649-
map,
650-
(
651-
"CallableVariable",
652-
"VariableBase",
653-
),
654-
lambda fn, var: MapVariable.from_iterator(
655-
fn, var, graph=var.graph, tracker=DummyTracker([var])
656-
),
657-
)
648+
@Dispatcher.register_decorator(map)
649+
def create_map(func: CallableVariable, *var: VariableBase):
650+
tracked_vars = [func, *var]
651+
return MapVariable.from_iterator(
652+
func, var, graph=Dispatcher.graph, tracker=DummyTracker(tracked_vars)
653+
)
658654

659655

660656
# reversed

python/paddle/jit/sot/opcode_translator/executor/variables/iter.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ..tracker import ConstTracker, DummyTracker
2525
from .base import VariableFactory
2626
from .basic import ConstantVariable
27-
from .container import ContainerVariable, TupleVariable
27+
from .container import TupleVariable
2828

2929
if TYPE_CHECKING:
3030
from collections.abc import Sequence
@@ -226,21 +226,25 @@ class MapVariable(SequenceIterVariable):
226226
MapVariable holds a SequenceIterVariable and return a Iterable Variable after map function
227227
"""
228228

229-
def __init__(self, func, val_iterator, graph, tracker):
230-
super().__init__(val_iterator, graph, tracker)
229+
def __init__(self, func, iters, graph, tracker):
230+
# iters may contain only one iter.
231+
super().__init__(iters, graph, tracker)
231232
self.func = func
232233

233234
def next(self):
234-
return self.func(self.hold.next())
235+
values = []
236+
for iter_var in self.hold:
237+
next_var = iter_var.next()
238+
values.append(next_var)
239+
return self.func(*values)
235240

236241
def to_list(self) -> list:
237-
retval = []
238-
while True:
239-
try:
240-
retval.append(self.func(self.hold.next()))
241-
except StopIteration:
242-
break
243-
return retval
242+
lists = [iter_vars.to_list() for iter_vars in self.hold]
243+
min_len = min(len(l) for l in lists)
244+
result = []
245+
for i in range(min_len):
246+
result.append(self.func(*(l[i] for l in lists)))
247+
return result
244248

245249
def has_side_effect(self) -> bool:
246250
return self.hold.has_side_effect()
@@ -256,16 +260,20 @@ def _reconstruct(self, codegen: PyCodeGen):
256260

257261
@staticmethod
258262
def from_iterator(
259-
func, value, graph: FunctionGraph | None, tracker: Tracker
263+
func,
264+
value: Sequence[VariableBase],
265+
graph: FunctionGraph | None,
266+
tracker: Tracker,
260267
):
261-
iter_variable = (
262-
value.get_iter() if isinstance(value, ContainerVariable) else value
263-
)
268+
map_targets = []
264269

265-
if isinstance(iter_variable, IterVariable):
266-
return MapVariable(func, iter_variable, graph, tracker)
267-
else:
268-
return UserDefinedIterVariable(value, graph, tracker)
270+
for variable in value:
271+
iter_variable = variable.get_iter()
272+
if not isinstance(iter_variable, SequenceIterVariable):
273+
return UserDefinedIterVariable(value, graph, tracker)
274+
map_targets.append(iter_variable)
275+
276+
return MapVariable(func, map_targets, graph, tracker)
269277

270278

271279
# what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph

test/sot/test_builtin_map.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from test_case_base import TestCaseBase, test_with_faster_guard
2121

22+
from paddle import to_tensor
2223
from paddle.jit import sot
2324
from paddle.jit.sot.psdb import check_no_breakgraph
2425
from paddle.jit.sot.utils import strict_mode_guard
@@ -98,6 +99,12 @@ def test_map_for_loop(x: list):
9899
return res
99100

100101

102+
@check_no_breakgraph
103+
def test_map_multi_input(func, tensor_, tuple_):
104+
x, y, z = map(func, tensor_, tuple_)
105+
return x
106+
107+
101108
class TestMap(TestCaseBase):
102109
@test_with_faster_guard
103110
def test_map(self):
@@ -122,6 +129,12 @@ def test_map_with_breakgraph(self):
122129
@test_with_faster_guard
123130
def test_map_unpack(self):
124131
self.assert_results(test_map_unpack, [1, 2, 3, 4])
132+
self.assert_results(
133+
test_map_multi_input,
134+
lambda x, y: x + y,
135+
to_tensor([1, 2, 3]),
136+
(2, 4, 6),
137+
)
125138

126139
@test_with_faster_guard
127140
def test_map_for_loop(self):

0 commit comments

Comments
 (0)