Skip to content

Commit cef772f

Browse files
committed
fix sp import
1 parent d1c43cd commit cef772f

File tree

4 files changed

+32
-18
lines changed

4 files changed

+32
-18
lines changed

model_zoo/gpt-3/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,16 @@
4848
MinLengthLogitsProcessor,
4949
RepetitionPenaltyLogitsProcessor,
5050
)
51-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
52-
ColumnSequenceParallelLinear,
53-
GatherOp,
54-
RowSequenceParallelLinear,
55-
ScatterOp,
56-
mark_as_sequence_parallel_parameter,
57-
)
51+
try:
52+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
53+
ColumnSequenceParallelLinear,
54+
GatherOp,
55+
RowSequenceParallelLinear,
56+
ScatterOp,
57+
mark_as_sequence_parallel_parameter,
58+
)
59+
except:
60+
pass
5861

5962
from paddlenlp.transformers.segment_parallel_utils import ReshardLayer
6063

model_zoo/gpt-3/ppfleetx/models/language_model/language_module.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
from ppfleetx.core.module.basic_module import BasicModule
2525
from ppfleetx.data.tokenizers import GPTTokenizer
2626
from ppfleetx.distributed.apis import env
27-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
28-
register_sequence_parallel_allreduce_hooks,
29-
)
27+
try:
28+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
29+
register_sequence_parallel_allreduce_hooks,
30+
)
31+
except:
32+
pass
3033
from ppfleetx.utils.log import logger
3134

3235
# TODO(haohongxiang): to solve the problem of cross-reference

paddlenlp/transformers/gpt/modeling_auto.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@
3030
from paddle.distributed import fleet
3131
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
3232
from paddle.distributed.fleet.utils import recompute
33-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
34-
ScatterOp,
35-
mark_as_sequence_parallel_parameter,
36-
)
33+
34+
try:
35+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
36+
ScatterOp,
37+
mark_as_sequence_parallel_parameter,
38+
)
39+
except:
40+
pass
3741

3842
from ...utils.converter import StateDictNameMapping
3943
from .. import PretrainedModel, register_base_model

paddlenlp/transformers/mc2_seqence_parallel_linear.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@
2323

2424
from paddle import distributed as dist
2525
from paddle.autograd import PyLayer
26-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
27-
ColumnSequenceParallelLinear,
28-
RowSequenceParallelLinear,
29-
)
26+
27+
try:
28+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
29+
ColumnSequenceParallelLinear,
30+
RowSequenceParallelLinear,
31+
)
32+
except:
33+
pass
3034

3135
__all_gather_recomputation__ = False
3236
if int(os.getenv("MC2_Recompute", 0)):

0 commit comments

Comments (0)