Skip to content

Commit 9568bc7

Browse files
authored
[SOT] Add functools.lru_cache support (PaddlePaddle#71298)
1 parent 535a222 commit 9568bc7

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import functools
1718
import inspect
1819
import itertools
1920
import operator
@@ -794,6 +795,27 @@ def main_info(self) -> dict[str, Any]:
794795
}
795796

796797

798+
class FunctoolsLruCacheWrapperVariable(FunctionVariable):
799+
def __init__(
800+
self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker
801+
):
802+
super().__init__(fn, graph, tracker)
803+
self.value = fn
804+
805+
def call_function(self, /, *args, **kwargs):
806+
wrapped_fn = self.value.__wrapped__
807+
wrapped_fn = VariableFactory.from_value(
808+
wrapped_fn, self.graph, GetAttrTracker(self, "__wrapped__")
809+
)
810+
return wrapped_fn(*args, **kwargs)
811+
812+
@VariableFactory.register_from_value()
813+
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
814+
if isinstance(value, functools._lru_cache_wrapper):
815+
return FunctoolsLruCacheWrapperVariable(value, graph, tracker)
816+
return None
817+
818+
797819
class UserDefinedGeneratorFunctionVariable(FunctionVariable):
798820
"""
799821
UserDefinedGeneratorFunctionVariable is a subclass of FunctionVariable used to wrap a user-defined generator.

test/sot/test_functools.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import functools
16+
import unittest
17+
18+
from test_case_base import TestCaseBase
19+
20+
import paddle
21+
from paddle.jit.sot.psdb import check_no_breakgraph
22+
23+
24+
@functools.lru_cache
25+
def fn_with_cache(x):
26+
return x + 1
27+
28+
29+
@check_no_breakgraph
30+
def fn_with_lru_cache(x: paddle.Tensor):
31+
a1 = fn_with_cache(1)
32+
b1 = fn_with_cache(x)
33+
b2 = fn_with_cache(x)
34+
a2 = fn_with_cache(1)
35+
c1 = fn_with_cache(2)
36+
c2 = fn_with_cache(2)
37+
return a1, a2, b1, b2, c1, c2
38+
39+
40+
class TestFunctools(TestCaseBase):
41+
def test_lru_cache(self):
42+
x = paddle.rand([2, 3, 4])
43+
self.assert_results(fn_with_lru_cache, x)
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()

0 commit comments

Comments
 (0)