
Commit 9b947a7

support autoTP with weight quantization in DS inference path
Signed-off-by: Feng Tian <[email protected]>
1 parent b3edd2f commit 9b947a7

File tree

4 files changed: +200 -2 lines changed

deepspeed/inference/quantization/layers.py
deepspeed/inference/quantization/utils.py
deepspeed/module_inject/__init__.py
tests/unit/inference/test_inference.py
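
Note: the flow this commit enables is AutoTP sharding first (replace_with_kernel_inject=False), followed by group-wise weight quantization of the already-sharded model. The sketch below is distilled from the TestAutoTPwithWeightQuant case added in this commit; the model name, dtype, mp_size, prompt, and generation kwargs are only illustrative values, and the script is assumed to run under the deepspeed launcher so each rank sees LOCAL_RANK.

import os
import torch
import deepspeed
from transformers import pipeline
from deepspeed.accelerator import get_accelerator
from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization

local_rank = int(os.getenv("LOCAL_RANK", "0"))

# Load on CPU first; AutoTP shards the linear layers across ranks at init_inference time.
pipe = pipeline("text-generation",
                model="tiiuae/falcon-7b",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device=torch.device("cpu"),
                framework="pt")

# AutoTP path: tensor parallelism without kernel injection.
pipe.model = deepspeed.init_inference(pipe.model, mp_size=2, replace_with_kernel_inject=False)

# Quantize the weights of the already-sharded model group-wise (INT4, asymmetric),
# using the same config the new unit test uses.
ds_config = {
    "weight_quantization": {
        "post_init_quant": {
            '*': {
                'num_bits': 4,
                'group_size': 32,
                'group_dim': 1,
                'symmetric': False
            },
        }
    }
}
pipe.model = _init_group_wise_weight_quantization(pipe.model, ds_config)

# Move the pipeline's input tensors to the accelerator before generating.
pipe.device = torch.device(get_accelerator().device_name(local_rank))
print(local_rank, pipe("DeepSpeed is", do_sample=False, max_new_tokens=20))

Run with two ranks, this should exercise the new QuantizedLinearAllreduce / QuantizedLmHeadLinearAllreduce paths, since AutoTP will have injected those layer types before quantization is applied.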

deepspeed/inference/quantization/layers.py

Lines changed: 105 additions & 0 deletions

@@ -53,7 +53,67 @@ def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
                                               device=pre_quant_layer.weight.device,
                                               dtype=pre_quant_layer.weight.dtype)
         self.config = config
+        self.quantizer = Quantizer(config=config)
+        self.bias = pre_quant_layer.bias
+        self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
+                                                   get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))
+
+        self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)
+
+    def forward(self, input: Tensor) -> Tensor:
+        quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
+        temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
+                                                                     quant_min)
+
+        # !!! Do not use torch.functional.linear(input, temp_dequantized_weight, self.bias) here: under ZeRO-3,
+        # torch.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes the weight is
+        # non-temporary. If the weight is a temporary buffer there will be a memory leak.
+        return torch._C._nn.linear(input, temp_dequantized_weight, self.bias)
+
+
+class QuantizedLinearAllreduce(nn.Linear):
+
+    def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
+        super(QuantizedLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+                                                       out_features=pre_quant_layer.weight.shape[0],
+                                                       bias=pre_quant_layer.bias is not None,
+                                                       device=pre_quant_layer.weight.device,
+                                                       dtype=pre_quant_layer.weight.dtype)
+        self.config = config
+        self.mp_group = pre_quant_layer.mp_group if hasattr(pre_quant_layer, 'mp_group') else None
+        self.quantizer = Quantizer(config=config, mp_group=self.mp_group)
+        self.bias = pre_quant_layer.bias
+        self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
+                                                   get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))
+
+        self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)
+
+    def forward(self, input: Tensor) -> Tensor:
+        quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
+        temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
+                                                                     quant_min)
 
+        # !!! Do not use torch.functional.linear(input, temp_dequantized_weight, self.bias) here: under ZeRO-3,
+        # torch.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes the weight is
+        # non-temporary. If the weight is a temporary buffer there will be a memory leak.
+        output = torch._C._nn.linear(input, temp_dequantized_weight)
+        if self.mp_group is not None:
+            from deepspeed import comm as dist
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
+class QuantizedLinearLayer(nn.Linear):
+
+    def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
+        super(QuantizedLinearLayer, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+                                                   out_features=pre_quant_layer.weight.shape[0],
+                                                   bias=pre_quant_layer.bias is not None,
+                                                   device=pre_quant_layer.weight.device,
+                                                   dtype=pre_quant_layer.weight.dtype)
+        self.config = config
         self.quantizer = Quantizer(config=config)
         self.bias = pre_quant_layer.bias
         self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,

@@ -72,6 +132,46 @@ def forward(self, input: Tensor) -> Tensor:
         return torch._C._nn.linear(input, temp_dequantized_weight, self.bias)
 
 
+class QuantizedLmHeadLinearAllreduce(nn.Linear):
+
+    def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
+        super(QuantizedLmHeadLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+                                                             out_features=pre_quant_layer.weight.shape[0],
+                                                             bias=pre_quant_layer.bias is not None,
+                                                             device=pre_quant_layer.weight.device,
+                                                             dtype=pre_quant_layer.weight.dtype)
+        self.config = config
+        self.quantizer = Quantizer(config=config)
+        self.bias = pre_quant_layer.bias
+        self.rank = pre_quant_layer.rank
+        self.world_size = pre_quant_layer.world_size
+        self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
+                                                   get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))
+
+        self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)
+
+    def forward(self, input: Tensor) -> Tensor:
+        quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
+        temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
+                                                                     quant_min)
+        from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size)
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
+
+        # !!! Do not use torch.functional.linear(input, temp_dequantized_weight, self.bias) here: under ZeRO-3,
+        # torch.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes the weight is
+        # non-temporary. If the weight is a temporary buffer there will be a memory leak.
+        output = torch._C._nn.linear(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
+                                     temp_dequantized_weight.transpose(-1, -2))
+
+        if self.mp_group is not None:
+            from deepspeed import comm as dist
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
 class QuantizedEmbedding(nn.Embedding):
 
     def __init__(self, config: Dict, pre_quant_layer: nn.Embedding) -> None:

@@ -108,7 +208,12 @@ def forward(self, input: Tensor) -> Tensor:
                            self.scale_grad_by_freq, self.sparse)
 
 
+from ...module_inject import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
+
 QUANTIZATION_LAYER_MAPPINGS = {
     nn.Linear: QuantizedLinear,
     nn.Embedding: QuantizedEmbedding,
+    LinearAllreduce: QuantizedLinearAllreduce,
+    LinearLayer: QuantizedLinearLayer,
+    LmHeadLinearAllreduce: QuantizedLmHeadLinearAllreduce
 }
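
The three new entries in QUANTIZATION_LAYER_MAPPINGS are what let the post-init quantization pass recognize the layers AutoTP injects (LinearAllreduce, LinearLayer, LmHeadLinearAllreduce) in addition to plain nn.Linear and nn.Embedding. As a rough illustration of how such a type-to-wrapper mapping can drive module replacement (a simplified sketch, not DeepSpeed's actual _init_group_wise_weight_quantization implementation; replace_with_quantized is a hypothetical helper):

from typing import Dict, Type
import torch.nn as nn

def replace_with_quantized(model: nn.Module, config: Dict,
                           mappings: Dict[Type[nn.Module], Type[nn.Module]]) -> nn.Module:
    # Recursively walk the module tree; whenever a child's type appears in the
    # mapping, construct the quantized wrapper from the original ("pre-quant")
    # layer so it can take over the weight, bias and, for the AutoTP layers,
    # the tensor-parallel group / rank information.
    for name, child in model.named_children():
        if type(child) in mappings:
            setattr(model, name, mappings[type(child)](config, child))
        else:
            replace_with_quantized(child, config, mappings)
    return model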

deepspeed/inference/quantization/utils.py

Lines changed: 2 additions & 1 deletion

@@ -42,8 +42,9 @@ def tensor_round(tensor: Tensor) -> Tensor:
 
 class Quantizer:
 
-    def __init__(self, config: Dict) -> None:
+    def __init__(self, config: Dict, mp_group=None) -> None:
         self.config = config
+        self.mp_group = mp_group
         assert self.config['num_bits'] == 4 or self.config[
             'num_bits'] == 8, 'Only INT4 and INT8 quantization is supported.'
         assert self.config['symmetric'] == False, 'Only asymmetric quantization is supported at this moment.'
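
The Quantizer now optionally carries the tensor-parallel group, and its asserts still restrict it to asymmetric INT4/INT8. A small sketch of a per-layer config that satisfies those asserts, using the same group-wise settings as the new unit test (that these four keys are all a given code path needs at construction time is an assumption here):

from deepspeed.inference.quantization.utils import Quantizer

quant_config = {
    'num_bits': 4,        # must be 4 or 8
    'group_size': 32,     # group-wise scaling, as in the test config below
    'group_dim': 1,
    'symmetric': False,   # only asymmetric quantization is supported
}

# mp_group defaults to None; the AutoTP wrappers above pass their own group.
quantizer = Quantizer(config=quant_config, mp_group=None)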

deepspeed/module_inject/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -6,5 +6,5 @@
 from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
 from .module_quantize import quantize_transformer_layer
 from .replace_policy import HFBertLayerPolicy
-from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, EmbeddingLayer, Normalize
 from .policy import DSPolicy
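
With LmHeadLinearAllreduce exported here, the relative import added in layers.py above (and any external code) can pull all three AutoTP layer types from deepspeed.module_inject:

from deepspeed.module_inject import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce

# These are exactly the key types added to QUANTIZATION_LAYER_MAPPINGS above, so a
# sharded model's AutoTP layers can be matched by type during post-init quantization.
autotp_layer_types = (LinearAllreduce, LinearLayer, LmHeadLinearAllreduce)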

tests/unit/inference/test_inference.py

Lines changed: 92 additions & 0 deletions

@@ -512,6 +512,98 @@ def test(
         assert assert_fn(bs_output, ds_output)
 
 
+@pytest.mark.seq_inference
+@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
+class TestAutoTP(DistributedTest):
+    world_size = 1
+
+    def test(
+        self,
+        model_w_task,
+        query,
+        inf_kwargs,
+        assert_fn,
+    ):
+        # TODO: enable this test for H100 tests
+        pytest.skip("Not enough GPU memory for this on V100 runners")
+        model, task = model_w_task
+        dtype = torch.bfloat16
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        # We have to load these large models on CPU with pipeline because not
+        # enough GPU memory
+        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+        pipe = pipeline(task,
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=dtype,
+                        trust_remote_code=True,
+                        device=torch.device("cpu"),
+                        framework="pt")
+        #bs_output = pipe(query, **inf_kwargs)
+
+        pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
+        # Switch device to GPU so that input tensors are not on CPU
+        pipe.device = torch.device(get_accelerator().device_name(local_rank))
+        ds_output = pipe(query, **inf_kwargs)
+
+        #print(local_rank, "baseline", bs_output)
+        print(local_rank, "deepspeed", ds_output)
+        #assert assert_fn(bs_output, ds_output)
+
+
+@pytest.mark.seq_inference
+@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
+class TestAutoTPwithWeightQuant(DistributedTest):
+    world_size = 2
+
+    def test(
+        self,
+        model_w_task,
+        query,
+        inf_kwargs,
+        assert_fn,
+    ):
+        # TODO: enable this test for H100 tests
+        pytest.skip("Not enough GPU memory for this on V100 runners")
+        model, task = model_w_task
+        dtype = torch.bfloat16
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        # We have to load these large models on CPU with pipeline because not
+        # enough GPU memory
+        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+        pipe = pipeline(task,
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=dtype,
+                        trust_remote_code=True,
+                        device=torch.device("cpu"),
+                        framework="pt")
+
+        pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
+        ds_config = {
+            "weight_quantization": {
+                "post_init_quant": {
+                    '*': {
+                        'num_bits': 4,
+                        'group_size': 32,
+                        'group_dim': 1,
+                        'symmetric': False
+                    },
+                }
+            }
+        }
+        from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization
+        pipe.model = _init_group_wise_weight_quantization(pipe.model, ds_config)
+        pipe.device = torch.device(get_accelerator().device_name(local_rank))
+        ds_output = pipe(query, **inf_kwargs)
+
+        #print(local_rank, "baseline", bs_output)
+        print(local_rank, "deepspeed", ds_output)
+        #assert assert_fn(bs_output, ds_output)
+
+
 @pytest.mark.seq_inference
 @pytest.mark.parametrize(
     "model_w_task, injection_policy",
