|
44 | 44 | from ...utils.distributed import distributed_gather |
45 | 45 | from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME |
46 | 46 | from ...utils.log import logger |
| 47 | +from ...utils.tools import get_env_device |
47 | 48 | from .lora_config import LoRAConfig |
48 | 49 |
|
49 | 50 | try: |
50 | 51 |     from paddle.distributed.fleet.utils.sequence_parallel_utils import ( |
51 | 52 |         ColumnSequenceParallelLinear, |
52 | 53 |         RowSequenceParallelLinear, |
53 | 54 |     ) |
54 | | - |
55 | | -    from .lora_layers import ( |
56 | | -        ColumnParallelLoRALinear, |
57 | | -        ColumnParallelLoRAMergedLinear, |
58 | | -        ColumnSequenceParallelLoRALinear, |
59 | | -        LoRAConv2D, |
60 | | -        LoRALinear, |
61 | | -        LoRAMergedLinear, |
62 | | -        RowParallelLoRALinear, |
63 | | -        RowSequenceParallelLoRALinear, |
64 | | -    ) |
65 | 55 | except: |
66 | 56 |     pass |
67 | 57 |
|
| 58 | + |
| 59 | +def get_lora_layers(): |
| 60 | +    try: |
| 61 | +        if get_env_device() == "xpu": |
| 62 | +            # If paddle_xpu is not installed, the ImportError is caught below and PaddleNLP's native LoRA layers are used |
| 63 | +            from paddle_xpu.layers.nn.lora_layers import ( |
| 64 | +                XPUColumnParallelLoRALinear as ColumnParallelLoRALinear, |
| 65 | +            ) |
| 66 | +            from paddle_xpu.layers.nn.lora_layers import ( |
| 67 | +                XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear, |
| 68 | +            ) |
| 69 | +            from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear |
| 70 | +            from paddle_xpu.layers.nn.lora_layers import ( |
| 71 | +                XPURowParallelLoRALinear as RowParallelLoRALinear, |
| 72 | +            ) |
| 73 | +            from paddle_xpu.layers.nn.lora_layers import ( |
| 74 | +                XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear, |
| 75 | +            ) |
| 76 | +
| 77 | +            from .lora_layers import ( |
| 78 | +                ColumnParallelLoRAMergedLinear, |
| 79 | +                LoRAConv2D, |
| 80 | +                LoRAMergedLinear, |
| 81 | +            ) |
| 82 | +
| 83 | +        else: |
| 84 | +            raise ImportError  # Force the fallback to the native LoRA layers on non-XPU devices |
| 85 | +    except ImportError: |
| 86 | +        from .lora_layers import ( |
| 87 | +            ColumnParallelLoRALinear, |
| 88 | +            ColumnParallelLoRAMergedLinear, |
| 89 | +            ColumnSequenceParallelLoRALinear, |
| 90 | +            LoRAConv2D, |
| 91 | +            LoRALinear, |
| 92 | +            LoRAMergedLinear, |
| 93 | +            RowParallelLoRALinear, |
| 94 | +            RowSequenceParallelLoRALinear, |
| 95 | +        ) |
| 96 | +
| 97 | +    return { |
| 98 | +        "ColumnParallelLoRALinear": ColumnParallelLoRALinear, |
| 99 | +        "ColumnParallelLoRAMergedLinear": ColumnParallelLoRAMergedLinear, |
| 100 | +        "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear, |
| 101 | +        "LoRAConv2D": LoRAConv2D, |
| 102 | +        "LoRALinear": LoRALinear, |
| 103 | +        "LoRAMergedLinear": LoRAMergedLinear, |
| 104 | +        "RowParallelLoRALinear": RowParallelLoRALinear, |
| 105 | +        "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear, |
| 106 | +    } |
| 107 | + |
| 108 | + |
| 109 | +lora_layers = get_lora_layers() |
| 110 | +ColumnParallelLoRALinear = lora_layers["ColumnParallelLoRALinear"] |
| 111 | +ColumnParallelLoRAMergedLinear = lora_layers["ColumnParallelLoRAMergedLinear"] |
| 112 | +ColumnSequenceParallelLoRALinear = lora_layers["ColumnSequenceParallelLoRALinear"] |
| 113 | +LoRAConv2D = lora_layers["LoRAConv2D"] |
| 114 | +LoRALinear = lora_layers["LoRALinear"] |
| 115 | +LoRAMergedLinear = lora_layers["LoRAMergedLinear"] |
| 116 | +RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"] |
| 117 | +RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"] |
| 118 | + |
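For reference, a minimal sketch (not part of the diff) of the selection pattern that the new get_lora_layers() helper implements, reduced to a single symbol. The helper name _select_lora_linear is hypothetical; everything else reuses names already present above (get_env_device, paddle_xpu's XPULoRALinear, and the native LoRALinear):

def _select_lora_linear():
    try:
        if get_env_device() == "xpu":
            # Prefer the XPU-specific layer when paddle_xpu is available
            from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear
        else:
            raise ImportError  # non-XPU devices always take the fallback branch
    except ImportError:
        # paddle_xpu missing or device is not XPU: use PaddleNLP's native layer
        from .lora_layers import LoRALinear
    return LoRALinear

LoRALinear = _select_lora_linear()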
68 | 119 | try: |
69 | 120 |     from ...quantization.quantization_linear import ( |
70 | 121 |         ColumnParallelQuantizationLinear, |
|