Commit 5c9c8d3

[model_zoo/gpt-3] Fix bugs from PR-61236 which cleared paddle.jit.dy2static.utils_helper (#7989)
* fix bugs
* add try import to support develop and release
1 parent 4eb6f0a · commit 5c9c8d3
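For context, the fix applies a version-compatibility import: the patch tries paddle.jit.api first (where develop now exposes set_dynamic_shape) and falls back to paddle.jit.dy2static.utils_helper for released versions, and the call sites switch to the imported name. Below is a minimal sketch of that pattern; the wrapper mark_attention_mask_dynamic is hypothetical and only mirrors the call sites touched by this commit, and it uses except ImportError where the patch itself uses a bare except.

# A minimal sketch (not from the Paddle docs) of the compatibility import used
# in all three files; it assumes the installed Paddle provides set_dynamic_shape
# in one of the two locations named by the patch.
try:
    # develop branch: the helper is exposed from paddle.jit.api
    from paddle.jit.api import set_dynamic_shape
except ImportError:  # the patch itself uses a bare except
    # released versions: fall back to the old module path
    from paddle.jit.dy2static.utils_helper import set_dynamic_shape


def mark_attention_mask_dynamic(model_kwargs):
    # Hypothetical wrapper mirroring the changed call sites: mark every
    # dimension of the attention mask as dynamic (-1) so dy2static does not
    # specialize the converted static graph to a fixed mask shape.
    set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])

Importing the symbol once at module level, rather than spelling out the full paddle.jit.dy2static.utils_helper path at each call site, is what lets the same call work on both develop and release.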

File tree: 3 files changed (+18 additions, −3 deletions)


model_zoo/gpt-3/ppfleetx/models/language_model/gpt/auto/auto_model.py

Lines changed: 6 additions & 1 deletion
@@ -49,6 +49,11 @@
 except:
     flash_attention = None
 
+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
+
 def shard_op_for_sequence_parallel_linear(tgt, mesh):
     # FIXME Hack to shard op for module (linear)
     # we only shard the second to the last op (matmul) leave the last op (elementwise_add) un-touched

@@ -1206,7 +1211,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
 
         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+        set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         max_length = paddle.to_tensor(max_length)
         while cur_len < max_length:

model_zoo/gpt-3/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py

Lines changed: 7 additions & 1 deletion
@@ -62,11 +62,17 @@
     from paddle.nn.functional.flash_attention import flash_attention
 except:
     flash_attention = None
+
 try:
     from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
 except:
     FusedDropoutAdd = None
 
+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
+
 def get_attr(layer, name):
     if getattr(layer, name, None) is not None:
         return getattr(layer, name, None)

@@ -1501,7 +1507,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
 
         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+        set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         while cur_len < max_length:
             # Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)

model_zoo/gpt-3/ppfleetx/models/language_model/gpt/dygraph/single_model.py

Lines changed: 5 additions & 1 deletion
@@ -43,6 +43,10 @@
 except:
     flash_attention = None
 
+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
 
 def get_attr(layer, name):
     if getattr(layer, name, None) is not None:

@@ -1077,7 +1081,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
 
         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+        set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         if hasattr(paddle.framework, "_no_check_dy2st_diff"):
             # TODO(wanghuancoder): _no_check_dy2st_diff is used to turn off the checking of behavior
