 # limitations under the License.
 
 import math
-import os
 from typing import List, Optional
 
 import paddle
     RowParallelLinear,
 )
 
-from .lora_quick_layers import quick_lora
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        AllGatherOp,
+        ColumnSequenceParallelLinear,
+        ReduceScatterOp,
+        RowSequenceParallelLinear,
+        mark_as_sequence_parallel_parameter,
+    )
+except ImportError:
+    pass
+
+from paddlenlp.transformers.mc2_parallel_linear import (
+    MC2ColumnParallelCoreLinear,
+    MC2ColumnSeqParallelCoreLinear,
+    MC2RowParallelCoreLinear,
+    MC2RowSeqParallelCoreLinear,
+)
 
-if "npu" in paddle.device.get_all_custom_device_type():
-    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
-else:
-    MC2LoRaRowParallelLinear = None
-    MC2LoRaColumnParallelLinear = None
+from .lora_quick_layers import quick_lora
 
 
 class LoRALinear(nn.Linear):
@@ -266,16 +277,16 @@ def forward(self, x: paddle.Tensor):
             )
         else:
             # x @ W: [bz, in_f / ws] ===> [bz, out_f]
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
-            else:
+            if MC2RowParallelCoreLinear is None:
                 result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
                 output = mp_ops._mp_allreduce(
                     result_mp,
                     group=self.model_parallel_group,
                     use_calc_stream=True,
                     use_model_parallel=True,
                 )
+            else:
+                output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
 
         if not self.merged:
             # x @ A: [bz, in_f / ws] ===> [bz, r]
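Note: the per-call device check ("npu" in paddle.device.get_all_custom_device_type() combined with the MC2 environment variable) is replaced above by a plain "is None" test. This relies on paddlenlp.transformers.mc2_parallel_linear exporting each MC2*CoreLinear symbol as None whenever the fused MC2 kernels are unavailable. A minimal sketch of that export convention, with a hypothetical module and name, purely for illustration:

    # Hypothetical availability guard: publish the fused op when it imports,
    # otherwise publish None so callers can branch with "is None".
    try:
        from .mc2_fused_ops import MC2RowParallelCoreLinear  # hypothetical fused kernel wrapper
    except ImportError:
        MC2RowParallelCoreLinear = None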
@@ -298,6 +309,120 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
 
 
+class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        use_quick_lora: bool = False,
+        **kwargs
+    ):
+        RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("Lora rank r should be a positive integer")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[self.input_size_per_partition, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+            ),
+        )
+        self.lora_B = self.create_parameter(
+            shape=[r, self.out_features],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_A.is_distributed = True
+        self.lora_A.split_axis = 0
+        self.lora_B.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_B)
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if not self.input_is_parallel:
+            input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
+        else:
+            input_mp = x
+
+        if MC2RowSeqParallelCoreLinear is None:
+            output_parallel = self.linear(input_mp, self.weight, name=self._name)
+            output_ = ReduceScatterOp.apply(output_parallel)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+        else:
+            output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+
+        if not self.merged:
+            input_mp = self.lora_dropout(input_mp)
+            if MC2RowSeqParallelCoreLinear is None:
+                input_mp = input_mp @ self.lora_A
+                input_mp = ReduceScatterOp.apply(input_mp)
+            else:
+                input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+            delta_mp = (input_mp @ self.lora_B) * self.scaling
+            result_mp += delta_mp
+        return result_mp
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
 class ColumnParallelLoRALinear(ColumnParallelLinear):
     def __init__(
         self,
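For context, a minimal construction sketch for the new RowSequenceParallelLoRALinear. Sizes and keyword arguments below are illustrative; extra kwargs are forwarded to the parent RowSequenceParallelLinear, whose input_is_parallel and bias attributes the forward() above relies on:

    # Illustrative only: wrap a row-wise sequence-parallel projection with a
    # rank-8 LoRA adapter; the base weight stays frozen, lora_A/lora_B train.
    layer = RowSequenceParallelLoRALinear(
        in_features=4096,        # hypothetical hidden size
        out_features=4096,       # hypothetical hidden size
        r=8,
        lora_alpha=16,
        input_is_parallel=True,  # parent-class kwarg used by forward()
        has_bias=False,          # assumed parent-class kwarg
    )
    y = layer(x_shard)           # x_shard: this rank's slice of the sequence-parallel input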
@@ -400,21 +525,21 @@ def forward(self, input: paddle.Tensor):
                 world_size=self.world_size,
             )
         else:
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
-                result_mp = res_mp + self.bias
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
                 result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+            else:
+                res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
+                result_mp = res_mp + self.bias
 
         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A
-            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
-                delta_mp = tmp * self.scaling
-            else:
+            if MC2ColumnParallelCoreLinear is None:
                 input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
                 delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+            else:
+                tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = tmp * self.scaling
             result_mp += delta_mp
 
         if self.gather_output and self.is_mp:
@@ -428,6 +553,123 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
 
 
+class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
+        use_quick_lora: bool = False,
+        **kwargs
+    ):
+        ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("Lora rank r should be a positive integer")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[in_features, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=lora_A_weight_attr,
+        )
+        self.lora_A.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_A)
+
+        self.lora_B = self.create_parameter(
+            shape=[r, self.output_size_per_partition],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_B.is_distributed = True
+        self.lora_B.split_axis = 1
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if MC2ColumnSeqParallelCoreLinear is None:
+            if self.is_mp:
+                input_parallel = AllGatherOp.apply(x)
+            else:
+                input_parallel = x
+            result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
+        else:
+            result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group)
+            if self.bias is not None:
+                result_mp += self.bias
+
+        if not self.merged:
+            input_a = self.lora_dropout(x) @ self.lora_A
+            if MC2ColumnSeqParallelCoreLinear is None:
+                input_a = AllGatherOp.apply(input_a)
+                delta_mp = (input_a @ self.lora_B) * self.scaling
+            else:
+                input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = input_a * self.scaling
+            result_mp += delta_mp
+
+        if self.gather_output and self.is_mp:
+            result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
+        else:
+            result = result_mp
+        return result
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
 class LoRAMergedLinear(nn.Linear):
     # LoRA implemented in a dense layer with merged linear weights for q, k, v
     def __init__(
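Both new classes follow the usual LoRA merge/unmerge convention: eval() folds the low-rank update into the frozen base weight in place, and train() subtracts it back out. A small self-contained sketch of that arithmetic (shapes and values are illustrative; in the layers above lora_B starts at zero, so the merged and base weights coincide at initialization):

    import paddle

    in_f, out_f, r, scaling = 16, 16, 8, 2.0
    W = paddle.randn([in_f, out_f])   # frozen base weight
    A = paddle.randn([in_f, r])       # lora_A
    B = paddle.randn([r, out_f])      # lora_B (random here to make the check non-trivial)

    W_merged = W + (A @ B) * scaling            # what eval() writes via set_value
    W_restored = W_merged - (A @ B) * scaling   # what train() restores
    assert paddle.allclose(W, W_restored)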