Skip to content

Commit 10c7f88

Browse files
dynamicheartwill-jl944
authored and committed
[XPU] add lora optimization (PaddlePaddle#8527)
* [XPU] add lora optimization * fix * refine
1 parent c39c08e commit 10c7f88

File tree

1 file changed

+50
-8
lines changed

1 file changed

+50
-8
lines changed

paddlenlp/peft/lora/lora_model.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,57 @@ class RowSequenceParallelLinear:
6161
from ...utils.log import logger
6262
from ...utils.tools import get_env_device
6363
from .lora_config import LoRAConfig
64-
from .lora_layers import (
65-
ColumnParallelLoRALinear,
66-
ColumnSequenceParallelLoRALinear,
67-
LoRAConv2D,
68-
LoRALinear,
69-
RowParallelLoRALinear,
70-
RowSequenceParallelLoRALinear,
71-
)
7264

65+
66+
def get_lora_layers():
    """Resolve which LoRA layer implementations this module should expose.

    On XPU devices the optimized classes shipped with the ``paddle_xpu``
    package are preferred; everywhere else — or when ``paddle_xpu`` is not
    installed — PaddleNLP's native LoRA layers are used instead.

    Returns:
        dict: maps each canonical layer name ("ColumnParallelLoRALinear",
        "ColumnSequenceParallelLoRALinear", "LoRAConv2D", "LoRALinear",
        "RowParallelLoRALinear", "RowSequenceParallelLoRALinear") to the
        resolved class.
    """
    try:
        if get_env_device() != "xpu":
            # Non-XPU devices always take the native-layer path below.
            raise ImportError
        # paddle_xpu provides drop-in replacements for the linear variants;
        # missing the package simply falls through to the native layers.
        from paddle_xpu.layers.nn.lora_layers import (
            XPUColumnParallelLoRALinear as ColumnParallelLoRALinear,
            XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear,
            XPULoRALinear as LoRALinear,
            XPURowParallelLoRALinear as RowParallelLoRALinear,
            XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
        )

        # paddle_xpu has no Conv2D variant; the native one is used on XPU too.
        from .lora_layers import LoRAConv2D
    except ImportError:
        from .lora_layers import (
            ColumnParallelLoRALinear,
            ColumnSequenceParallelLoRALinear,
            LoRAConv2D,
            LoRALinear,
            RowParallelLoRALinear,
            RowSequenceParallelLoRALinear,
        )

    return {
        "ColumnParallelLoRALinear": ColumnParallelLoRALinear,
        "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear,
        "LoRAConv2D": LoRAConv2D,
        "LoRALinear": LoRALinear,
        "RowParallelLoRALinear": RowParallelLoRALinear,
        "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
    }
106+
107+
108+
# Materialize the resolved classes as module-level names so the rest of this
# module (and downstream importers) can reference them directly.
lora_layers = get_lora_layers()
(
    ColumnParallelLoRALinear,
    ColumnSequenceParallelLoRALinear,
    LoRAConv2D,
    LoRALinear,
    RowParallelLoRALinear,
    RowSequenceParallelLoRALinear,
) = (
    lora_layers[_layer_name]
    for _layer_name in (
        "ColumnParallelLoRALinear",
        "ColumnSequenceParallelLoRALinear",
        "LoRAConv2D",
        "LoRALinear",
        "RowParallelLoRALinear",
        "RowSequenceParallelLoRALinear",
    )
)
73115
AVALIABLE_LAYERS = [
74116
ColumnParallelLoRALinear,
75117
ColumnSequenceParallelLoRALinear,

0 commit comments

Comments
 (0)