Skip to content

Commit 696f305

Browse files
committed
Wrap sequence-parallel (sp) imports in try/except
1 parent 6b5099a commit 696f305

File tree

5 files changed

+52
-34
lines changed

5 files changed

+52
-34
lines changed

paddlenlp/transformers/__init__.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -29,16 +29,20 @@
29 29
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
30 30
from .image_processing_utils import ImageProcessingMixin
31 31
from .attention_utils import create_bigbird_rand_mask_idx_list
32-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
33-
GatherOp,
34-
ScatterOp,
35-
AllGatherOp,
36-
ReduceScatterOp,
37-
ColumnSequenceParallelLinear,
38-
RowSequenceParallelLinear,
39-
mark_as_sequence_parallel_parameter,
40-
register_sequence_parallel_allreduce_hooks,
41-
)
32+
33+
try:
34+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
35+
GatherOp,
36+
ScatterOp,
37+
AllGatherOp,
38+
ReduceScatterOp,
39+
ColumnSequenceParallelLinear,
40+
RowSequenceParallelLinear,
41+
mark_as_sequence_parallel_parameter,
42+
register_sequence_parallel_allreduce_hooks,
43+
)
44+
except:
45+
pass
42 46
from .export import export_model
43 47

44 48
# isort: split

paddlenlp/transformers/gpt/modeling.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,17 @@
29 29
from paddle.distributed import fleet
30 30
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
31 31
from paddle.distributed.fleet.utils import recompute
32-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
33-
ColumnSequenceParallelLinear,
34-
GatherOp,
35-
RowSequenceParallelLinear,
36-
ScatterOp,
37-
mark_as_sequence_parallel_parameter,
38-
)
32+
33+
try:
34+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
35+
ColumnSequenceParallelLinear,
36+
GatherOp,
37+
RowSequenceParallelLinear,
38+
ScatterOp,
39+
mark_as_sequence_parallel_parameter,
40+
)
41+
except:
42+
pass
39 43
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
40 44

41 45
from ...utils.converter import StateDictNameMapping

paddlenlp/transformers/gpt/modeling_pp.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,13 @@
20 20
SharedLayerDesc,
21 21
)
22 22
from paddle.distributed.fleet.utils import recompute
23-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
24-
mark_as_sequence_parallel_parameter,
25-
)
23+
24+
try:
25+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
26+
mark_as_sequence_parallel_parameter,
27+
)
28+
except:
29+
pass
26 30

27 31
from paddlenlp.transformers.model_utils import PipelinePretrainedModel
28 32

paddlenlp/transformers/llama/modeling.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -44,13 +44,16 @@ def swiglu(x, y=None):
44 44
return F.silu(x) * y
45 45

46 46

47-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
48-
ColumnSequenceParallelLinear,
49-
GatherOp,
50-
RowSequenceParallelLinear,
51-
ScatterOp,
52-
mark_as_sequence_parallel_parameter,
53-
)
47+
try:
48+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
49+
ColumnSequenceParallelLinear,
50+
GatherOp,
51+
RowSequenceParallelLinear,
52+
ScatterOp,
53+
mark_as_sequence_parallel_parameter,
54+
)
55+
except:
56+
pass
54 57
from paddle.utils import try_import
55 58

56 59
from paddlenlp.transformers.conversion_utils import (

paddlenlp/transformers/mixtral/modeling.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,13 +33,16 @@
33 33
except ImportError:
34 34
fused_rotary_position_embedding = None
35 35

36-
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
37-
ColumnSequenceParallelLinear,
38-
GatherOp,
39-
RowSequenceParallelLinear,
40-
ScatterOp,
41-
mark_as_sequence_parallel_parameter,
42-
)
36+
try:
37+
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
38+
ColumnSequenceParallelLinear,
39+
GatherOp,
40+
RowSequenceParallelLinear,
41+
ScatterOp,
42+
mark_as_sequence_parallel_parameter,
43+
)
44+
except:
45+
pass
43 46

44 47
from paddlenlp.transformers.conversion_utils import (
45 48
StateDictNameMapping,

0 commit comments

Comments
 (0)