File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -399,7 +399,8 @@ def forward(self, x: paddle.Tensor):
399399 else :
400400 input_mp = x
401401
402- if MC2RowSeqParallelCoreLinear is None :
402+ # TODO(@gexiao): temporary workaround for deterministic calculation
403+ if True or MC2RowSeqParallelCoreLinear is None :
403404 output_parallel = self .linear (input_mp , self .weight , name = self ._name )
404405 output_ = ReduceScatterOp .apply (output_parallel )
405406 result_mp = output_ + self .bias if self .bias is not None else output_
@@ -651,7 +652,8 @@ def forward(self, x: paddle.Tensor):
651652
652653 if not self .merged :
653654 input_a = self .lora_dropout (x ) @ self .lora_A
654- if MC2ColumnSeqParallelCoreLinear is None :
655+ # TODO(@gexiao): temporary workaround for deterministic calculation
656+ if True or MC2ColumnSeqParallelCoreLinear is None :
655657 input_a = AllGatherOp .apply (input_a )
656658 delta_mp = (input_a @ self .lora_B ) * self .scaling
657659 else :
You can’t perform that action at this time.
0 commit comments