
Commit 46f7ef2

fix typo in LmHeadLinearAllreduce initialization
1 parent 9b947a7 commit 46f7ef2

2 files changed: +1 −3 lines changed

deepspeed/inference/quantization/layers.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def forward(self, input: Tensor) -> Tensor:
 class QuantizedLmHeadLinearAllreduce(nn.Linear):

     def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
-        super(QuantizedLinearLayer, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+        super(QuantizedLmHeadLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
                                                    out_features=pre_quant_layer.weight.shape[0],
                                                    bias=pre_quant_layer.bias is not None,
                                                    device=pre_quant_layer.weight.device,
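
Why this "typo" mattered: in a two-argument super() call, Python requires the object to be an instance or subtype of the named class, and QuantizedLmHeadLinearAllreduce does not inherit from QuantizedLinearLayer (presumably a sibling class that also wraps nn.Linear). A minimal, self-contained sketch of the failure mode, with hypothetical class names standing in for the two quantized layers:

    import torch.nn as nn

    class SiblingQuantizedLinear(nn.Linear):    # stands in for QuantizedLinearLayer
        pass

    class QuantizedHead(nn.Linear):             # stands in for QuantizedLmHeadLinearAllreduce
        def __init__(self) -> None:
            # Buggy form from before this commit: QuantizedHead is not a
            # subclass of SiblingQuantizedLinear, so Python raises
            # "TypeError: super(type, obj): obj must be an instance or
            # subtype of type" the moment the layer is constructed.
            #
            #   super(SiblingQuantizedLinear, self).__init__(4, 4)
            #
            # Fixed form: name the defining class itself.
            super(QuantizedHead, self).__init__(in_features=4, out_features=4)

    layer = QuantizedHead()                     # constructs cleanly after the fix
    print(layer.weight.shape)                   # torch.Size([4, 4])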

deepspeed/inference/quantization/quantization.py

Lines changed: 0 additions & 2 deletions
@@ -85,8 +85,6 @@ def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> nn.Module:
         if is_zero3_enabled:
             module.weight.all_gather()

-        assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'
-
         new_module = QUANTIZATION_LAYER_MAPPINGS[type(module)](matched_quantization_config, module)

         if is_zero3_enabled:
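
The second change drops the hard float16 requirement before a module is swapped for its quantized counterpart. A hedged reading (the commit message does not state the intent, but it follows from the deleted line): weights in other dtypes, such as bfloat16 or float32, previously hit the AssertionError and can now reach QUANTIZATION_LAYER_MAPPINGS. A small illustrative snippet:

    import torch
    import torch.nn as nn

    # A bfloat16 module of the kind the deleted assertion would have rejected:
    module = nn.Linear(8, 8).to(torch.bfloat16)

    # Before this commit, the conversion loop asserted
    #   assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'
    # which raises AssertionError here. After the commit no dtype gate runs,
    # so the module proceeds to the quantization-layer mapping as-is.
    print(module.weight.dtype)  # torch.bfloat16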

0 commit comments