@@ -53,7 +53,67 @@ def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
                                               device=pre_quant_layer.weight.device,
                                               dtype=pre_quant_layer.weight.dtype)
         self.config = config
+        self.quantizer = Quantizer(config=config)
+        self.bias = pre_quant_layer.bias
+        self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
+                                                   get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))
+
+        self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)
+
+    def forward(self, input: Tensor) -> Tensor:
+        quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
+        temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
+                                                                     quant_min)
+
+        # !!! Do not use torch.nn.functional.linear(input, temp_dequantized_weight, self.bias) here: under
+        # ZeRO stage 3, torch.nn.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes
+        # the weight is non-temporary. If the weight is a temporary buffer, that path leaks memory.
+        return torch._C._nn.linear(input, temp_dequantized_weight, self.bias)
+
+
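> **Note:** all four wrappers in this diff share the same dequantize-on-the-fly pattern: the floating-point weight exists only for the duration of a single forward call, so persistent memory stays at the quantized footprint. A minimal, self-contained sketch of that pattern (toy symmetric int8 quantization, not the actual `Quantizer`/`DeQuantizer` API):
>
> ```python
> import torch
> import torch.nn as nn
>
> class DequantizeOnTheFlyLinear(nn.Module):
>     """Toy linear layer: stores int8 weights, dequantizes per forward call."""
>
>     def __init__(self, weight: torch.Tensor, bias: torch.Tensor = None):
>         super().__init__()
>         # Symmetric per-tensor quantization to int8 (illustration only).
>         self.scale = weight.abs().max() / 127.0
>         self.register_buffer('qweight', torch.round(weight / self.scale).to(torch.int8))
>         self.bias = bias
>
>     def forward(self, input: torch.Tensor) -> torch.Tensor:
>         # The fp weight is a temporary; only the int8 buffer persists.
>         # (Outside ZeRO stage 3, the functional linear is safe to use here.)
>         temp_weight = self.qweight.to(input.dtype) * self.scale
>         return torch.nn.functional.linear(input, temp_weight, self.bias)
>
> layer = DequantizeOnTheFlyLinear(torch.randn(16, 32))
> out = layer(torch.randn(4, 32))  # shape (4, 16)
> ```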
+class QuantizedLinearAllreduce(nn.Linear):
+
+    def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
+        super(QuantizedLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+                                                       out_features=pre_quant_layer.weight.shape[0],
+                                                       bias=pre_quant_layer.bias is not None,
+                                                       device=pre_quant_layer.weight.device,
+                                                       dtype=pre_quant_layer.weight.dtype)
+        self.config = config
+        self.mp_group = pre_quant_layer.mp_group if hasattr(pre_quant_layer, 'mp_group') else None
+        self.quantizer = Quantizer(config=config, mp_group=self.mp_group)
+        self.bias = pre_quant_layer.bias
+        self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
+                                                   get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))
+
+        self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)
+
+    def forward(self, input: Tensor) -> Tensor:
+        quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
+        temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
+                                                                     quant_min)
 
+        # !!! Do not use torch.nn.functional.linear(input, temp_dequantized_weight, self.bias) here: under
+        # ZeRO stage 3, torch.nn.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes
+        # the weight is non-temporary. If the weight is a temporary buffer, that path leaks memory.
+        output = torch._C._nn.linear(input, temp_dequantized_weight)
+        if self.mp_group is not None:
+            from deepspeed import comm as dist
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
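> **Note:** the ordering in this forward matters: each rank's linear call produces only a partial sum over its shard of the contraction dimension, so `self.bias` must be added once, *after* `dist.inference_all_reduce`; adding it before the reduction would accumulate it `world_size` times. A single-process sketch of the arithmetic, with `chunk` standing in for the real sharding:
>
> ```python
> import torch
>
> # Row-parallel linear: each of `world_size` ranks holds one shard of the
> # contraction dimension; the partial matmuls sum to the full result.
> world_size = 2
> x, w, b = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6)
>
> partials = [
>     xs @ ws.t()
>     for xs, ws in zip(x.chunk(world_size, dim=-1), w.chunk(world_size, dim=-1))
> ]
>
> # all-reduce == summing the partials; the bias is added exactly once, after.
> output = sum(partials) + b
> assert torch.allclose(output, x @ w.t() + b, atol=1e-5)
> ```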
+class QuantizedLinearLayer(nn.Linear):
+
+    def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
+        super(QuantizedLinearLayer, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+                                                   out_features=pre_quant_layer.weight.shape[0],
+                                                   bias=pre_quant_layer.bias is not None,
+                                                   device=pre_quant_layer.weight.device,
+                                                   dtype=pre_quant_layer.weight.dtype)
+        self.config = config
         self.quantizer = Quantizer(config=config)
         self.bias = pre_quant_layer.bias
         self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
@@ -72,6 +132,46 @@ def forward(self, input: Tensor) -> Tensor:
         return torch._C._nn.linear(input, temp_dequantized_weight, self.bias)
 
 
+class QuantizedLmHeadLinearAllreduce(nn.Linear):
+
+    def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
+        super(QuantizedLmHeadLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+                                                             out_features=pre_quant_layer.weight.shape[0],
+                                                             bias=pre_quant_layer.bias is not None,
+                                                             device=pre_quant_layer.weight.device,
+                                                             dtype=pre_quant_layer.weight.dtype)
+        self.config = config
+        self.quantizer = Quantizer(config=config)
+        self.bias = pre_quant_layer.bias
+        self.rank = pre_quant_layer.rank
+        self.world_size = pre_quant_layer.world_size
+        # forward() relies on self.mp_group for the all-reduce.
+        self.mp_group = pre_quant_layer.mp_group if hasattr(pre_quant_layer, 'mp_group') else None
+        self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
+                                                   get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))
+
+        self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)
+
+    def forward(self, input: Tensor) -> Tensor:
+        quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
+        temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
+                                                                     quant_min)
+        from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size)
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
+
+        # !!! Do not use torch.nn.functional.linear(input, temp_dequantized_weight, self.bias) here: under
+        # ZeRO stage 3, torch.nn.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes
+        # the weight is non-temporary. If the weight is a temporary buffer, that path leaks memory.
+        output = torch._C._nn.linear(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
+                                     temp_dequantized_weight.transpose(-1, -2))
+
+        if self.mp_group is not None:
+            from deepspeed import comm as dist
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
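> **Note:** `QuantizedLmHeadLinearAllreduce` contracts only this rank's slice of the hidden dimension against its `(shard, vocab)` weight shard, then relies on the all-reduce to sum the partial logits. A single-process sketch, with an even split standing in for `get_shard_size` / `get_shard_size_list` (the real helpers also handle uneven splits):
>
> ```python
> import torch
>
> world_size, hidden, vocab = 2, 8, 16
> x = torch.randn(2, 3, hidden)   # (batch, seq, hidden)
> w = torch.randn(hidden, vocab)  # full lm_head weight
>
> shard_sizes = [hidden // world_size] * world_size  # stand-in for get_shard_size_list
>
> partials = []
> for rank in range(world_size):
>     offset, size = sum(shard_sizes[:rank]), shard_sizes[rank]
>     w_shard = w[offset:offset + size, :]  # this rank's (shard, vocab) weight
>     # Same shapes as the forward above: slice the input's hidden dim and
>     # transpose the shard so linear sees (out_features, in_features).
>     partials.append(torch._C._nn.linear(x[:, :, offset:offset + size],
>                                         w_shard.transpose(-1, -2)))
>
> # dist.inference_all_reduce sums these partials across ranks.
> assert torch.allclose(sum(partials), x @ w, atol=1e-4)
> ```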
 class QuantizedEmbedding(nn.Embedding):
 
     def __init__(self, config: Dict, pre_quant_layer: nn.Embedding) -> None:
@@ -108,7 +208,12 @@ def forward(self, input: Tensor) -> Tensor:
                            self.scale_grad_by_freq, self.sparse)
 
 
+from ...module_inject import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
+
 QUANTIZATION_LAYER_MAPPINGS = {
     nn.Linear: QuantizedLinear,
     nn.Embedding: QuantizedEmbedding,
+    LinearAllreduce: QuantizedLinearAllreduce,
+    LinearLayer: QuantizedLinearLayer,
+    LmHeadLinearAllreduce: QuantizedLmHeadLinearAllreduce
 }
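> **Note:** for context on how the extended mapping is consumed: the quantization entry point walks the module tree and swaps any child whose exact type is a key. A hypothetical sketch (the real traversal lives elsewhere in DeepSpeed); the exact-type check matters because the wrappers themselves subclass `nn.Linear` / `nn.Embedding`, so an `isinstance` test would re-wrap already-quantized layers:
>
> ```python
> import torch.nn as nn
>
> def _replace_with_quantized(model: nn.Module, config: dict) -> nn.Module:
>     """Hypothetical helper: recursively swap mapped layers for quantized wrappers."""
>     for name, child in model.named_children():
>         if type(child) in QUANTIZATION_LAYER_MAPPINGS:
>             quantized_cls = QUANTIZATION_LAYER_MAPPINGS[type(child)]
>             setattr(model, name, quantized_cls(config, child))
>         else:
>             _replace_with_quantized(child, config)
>     return model
> ```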