@@ -61,15 +61,57 @@ class RowSequenceParallelLinear:
6161from ...utils .log import logger
6262from ...utils .tools import get_env_device
6363from .lora_config import LoRAConfig
64- from .lora_layers import (
65- ColumnParallelLoRALinear ,
66- ColumnSequenceParallelLoRALinear ,
67- LoRAConv2D ,
68- LoRALinear ,
69- RowParallelLoRALinear ,
70- RowSequenceParallelLoRALinear ,
71- )
7264
65+
66+ def get_lora_layers ():
67+ try :
68+ if get_env_device () == "xpu" :
69+ # If paddle_xpu is not installed, just use PaddleNLP's native lora layers
70+ from paddle_xpu .layers .nn .lora_layers import (
71+ XPUColumnParallelLoRALinear as ColumnParallelLoRALinear ,
72+ )
73+ from paddle_xpu .layers .nn .lora_layers import (
74+ XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear ,
75+ )
76+ from paddle_xpu .layers .nn .lora_layers import XPULoRALinear as LoRALinear
77+ from paddle_xpu .layers .nn .lora_layers import (
78+ XPURowParallelLoRALinear as RowParallelLoRALinear ,
79+ )
80+ from paddle_xpu .layers .nn .lora_layers import (
81+ XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear ,
82+ )
83+
84+ from .lora_layers import LoRAConv2D
85+
86+ else :
87+ raise ImportError # Force to use the fallback if not XPU
88+ except ImportError :
89+ from .lora_layers import (
90+ ColumnParallelLoRALinear ,
91+ ColumnSequenceParallelLoRALinear ,
92+ LoRAConv2D ,
93+ LoRALinear ,
94+ RowParallelLoRALinear ,
95+ RowSequenceParallelLoRALinear ,
96+ )
97+
98+ return {
99+ "ColumnParallelLoRALinear" : ColumnParallelLoRALinear ,
100+ "ColumnSequenceParallelLoRALinear" : ColumnSequenceParallelLoRALinear ,
101+ "LoRAConv2D" : LoRAConv2D ,
102+ "LoRALinear" : LoRALinear ,
103+ "RowParallelLoRALinear" : RowParallelLoRALinear ,
104+ "RowSequenceParallelLoRALinear" : RowSequenceParallelLoRALinear ,
105+ }
106+
107+
108+ lora_layers = get_lora_layers ()
109+ ColumnParallelLoRALinear = lora_layers ["ColumnParallelLoRALinear" ]
110+ ColumnSequenceParallelLoRALinear = lora_layers ["ColumnSequenceParallelLoRALinear" ]
111+ LoRAConv2D = lora_layers ["LoRAConv2D" ]
112+ LoRALinear = lora_layers ["LoRALinear" ]
113+ RowParallelLoRALinear = lora_layers ["RowParallelLoRALinear" ]
114+ RowSequenceParallelLoRALinear = lora_layers ["RowSequenceParallelLoRALinear" ]
73115AVALIABLE_LAYERS = [
74116 ColumnParallelLoRALinear ,
75117 ColumnSequenceParallelLoRALinear ,
0 commit comments