
Commit a53477c

[XPU] add lora optimization (#8527)
* [XPU] add lora optimization
* fix
* refine
1 parent 2723138 · commit a53477c

File tree

1 file changed: +62 -11 lines changed


paddlenlp/peft/lora/lora_model.py

Lines changed: 62 additions & 11 deletions
@@ -44,27 +44,78 @@
 from ...utils.distributed import distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
+from ...utils.tools import get_env_device
 from .lora_config import LoRAConfig
 
 try:
     from paddle.distributed.fleet.utils.sequence_parallel_utils import (
         ColumnSequenceParallelLinear,
         RowSequenceParallelLinear,
     )
-
-    from .lora_layers import (
-        ColumnParallelLoRALinear,
-        ColumnParallelLoRAMergedLinear,
-        ColumnSequenceParallelLoRALinear,
-        LoRAConv2D,
-        LoRALinear,
-        LoRAMergedLinear,
-        RowParallelLoRALinear,
-        RowSequenceParallelLoRALinear,
-    )
 except:
     pass
 
+
+def get_lora_layers():
+    try:
+        if get_env_device() == "xpu":
+            # If paddle_xpu is not installed, just use PaddleNLP's native lora layers
+            from paddle_xpu.layers.nn.lora_layers import (
+                XPUColumnParallelLoRALinear as ColumnParallelLoRALinear,
+            )
+            from paddle_xpu.layers.nn.lora_layers import (
+                XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear,
+            )
+            from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear
+            from paddle_xpu.layers.nn.lora_layers import (
+                XPURowParallelLoRALinear as RowParallelLoRALinear,
+            )
+            from paddle_xpu.layers.nn.lora_layers import (
+                XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
+            )
+
+            from .lora_layers import (
+                ColumnParallelLoRAMergedLinear,
+                LoRAConv2D,
+                LoRAMergedLinear,
+            )
+
+        else:
+            raise ImportError  # Force to use the fallback if not XPU
+    except ImportError:
+        from .lora_layers import (
+            ColumnParallelLoRALinear,
+            ColumnParallelLoRAMergedLinear,
+            ColumnSequenceParallelLoRALinear,
+            LoRAConv2D,
+            LoRALinear,
+            LoRAMergedLinear,
+            RowParallelLoRALinear,
+            RowSequenceParallelLoRALinear,
+        )
+
+    return {
+        "ColumnParallelLoRALinear": ColumnParallelLoRALinear,
+        "ColumnParallelLoRAMergedLinear": ColumnParallelLoRAMergedLinear,
+        "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear,
+        "LoRAConv2D": LoRAConv2D,
+        "LoRALinear": LoRALinear,
+        "LoRAMergedLinear": LoRAMergedLinear,
+        "RowParallelLoRALinear": RowParallelLoRALinear,
+        "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
+    }
+
+
+lora_layers = get_lora_layers()
+ColumnParallelLoRALinear = lora_layers["ColumnParallelLoRALinear"]
+ColumnParallelLoRAMergedLinear = lora_layers["ColumnParallelLoRAMergedLinear"]
+ColumnSequenceParallelLoRALinear = lora_layers["ColumnSequenceParallelLoRALinear"]
+LoRAConv2D = lora_layers["LoRAConv2D"]
+LoRALinear = lora_layers["LoRALinear"]
+LoRAMergedLinear = lora_layers["LoRAMergedLinear"]
+RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
+RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]
+
 try:
     from ...quantization.quantization_linear import (
         ColumnParallelQuantizationLinear,
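
Because the selected classes are rebound at module level, the rest of lora_model.py keeps referring to the usual names (LoRALinear, ColumnParallelLoRALinear, and so on) and transparently gets the XPU variants when paddle_xpu is installed on an XPU device. The snippet below is a minimal sketch, not part of the commit, showing how the selection could be inspected at runtime; it assumes a PaddleNLP install that includes this change, with paddle_xpu optional.

# Illustrative sketch only: inspect which LoRA layer implementations were
# selected. paddle_xpu only changes the result when get_env_device() reports
# "xpu"; otherwise every entry is PaddleNLP's native implementation.
from paddlenlp.peft.lora import lora_model

layers = lora_model.get_lora_layers()
for name, cls in layers.items():
    # On XPU with paddle_xpu installed, the parallel/linear entries resolve to
    # XPU* classes (e.g. XPULoRALinear); the merged/conv layers stay native.
    print(f"{name}: {cls.__module__}.{cls.__name__}")

# The module-level aliases are what lora_model.py itself uses, so downstream
# code does not need to change; this comparison should print True.
print(lora_model.LoRALinear is layers["LoRALinear"])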
