import random

import pytest
import torch
import transformers
from tqdm import tqdm

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize


class GPTQDataloaderPreprocessor:
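    """Collect a fixed number of calibration batches from an arbitrary dataloader for GPTQ.

    A minimal usage sketch (the inline list standing in for a dataloader is hypothetical,
    for illustration only):

        calib_data = [torch.ones([1, 512], dtype=torch.long) for _ in range(4)]
        preprocessor = GPTQDataloaderPreprocessor(calib_data, use_max_length=False, nsamples=4)
        prepared = preprocessor.get_prepared_dataloader()  # list of (possibly sliced) batches
    """
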
    def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128):
        self.dataloader_original = dataloader_original
        self.use_max_length = use_max_length
        self.max_seq_length = max_seq_length
        self.nsamples = nsamples
        self.dataloader = []
        self.is_ready = False

    def get_prepared_dataloader(self):
        if not self.is_ready:
            self.prepare_dataloader()
        return self.dataloader

    def prepare_dataloader(self):
        if self.use_max_length:
            # (Recommended) Only keep sequences whose length reaches self.max_seq_length,
            # which guarantees that all calibration tokens are valid.
            # This matches the official GPTQ dataloader implementation.
            self.obtain_first_n_samples_fulllength()
        else:
            # General selection without padding; not the original GPTQ implementation.
            self.obtain_first_n_samples()
        self.is_ready = True

    def obtain_first_n_samples(self, seed=0):
        """Get the first `nsamples` batches as the actual calibration dataset."""
        self.dataloader.clear()
        random.seed(seed)
        for batch in self.dataloader_original:
            # Process data depending on its type.
            if len(self.dataloader) == self.nsamples:
                logger.info(f"Successfully collected {self.nsamples} calibration samples.")
                break
            # list, tuple
            if isinstance(batch, (list, tuple)):
                if batch[0].shape[-1] > self.max_seq_length:
                    i = random.randint(0, batch[0].shape[-1] - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    batch_final = []
                    for item in batch:
                        if isinstance(item, torch.Tensor) and item.dim() == 2:
                            batch_final.append(item[:, i:j])
                        else:
                            batch_final.append(item)
                else:
                    batch_final = batch[:]
            # dict
            elif isinstance(batch, dict):
                try:
                    length = batch["input_ids"].shape[-1]
                except KeyError:
                    logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
                    continue
                batch_final = {}
                if length > self.max_seq_length:
                    i = random.randint(0, length - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    # May have to slice every sequence-related entry.
                    for key in batch.keys():
                        if isinstance(batch[key], torch.Tensor):
                            batch_final[key] = batch[key][:, i:j]  # slice on the sequence-length dim
                        else:
                            batch_final[key] = batch[key]
                else:
                    batch_final = batch
            # tensor
            else:
                if batch.shape[-1] > self.max_seq_length:
                    i = random.randint(0, batch.shape[-1] - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    batch_final = batch[:, i:j]
                else:
                    batch_final = batch
            self.dataloader.append(batch_final)

        if len(self.dataloader) < self.nsamples:
            logger.warning(f"Tried to collect {self.nsamples} samples, but the entire dataset only has {len(self.dataloader)}.")

    def obtain_first_n_samples_fulllength(self, seed=0):
        """Get the first `nsamples` batches whose sequence length is at least `max_seq_length`."""
        self.dataloader.clear()
        random.seed(seed)
        unified_length = self.max_seq_length
        for batch in self.dataloader_original:
            if len(self.dataloader) == self.nsamples:
                logger.info(f"Successfully collected {self.nsamples} calibration samples.")
                break
            # list & tuple (gpt-j-6b mlperf, etc.)
            if isinstance(batch, (list, tuple)):
                if batch[0].shape[-1] == unified_length:
                    batch_final = batch[:]
                elif batch[0].shape[-1] > unified_length:
                    i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
                    j = i + unified_length
                    batch_final = []
                    for item in batch:
                        if isinstance(item, torch.Tensor) and item.dim() == 2:
                            batch_final.append(item[:, i:j])
                        else:
                            batch_final.append(item)
                else:
                    # Shorter than the unified length; exclude from the target dataset.
                    continue
            # dict
            elif isinstance(batch, dict):
                try:
                    length = batch["input_ids"].shape[-1]
                except KeyError:
                    logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
                    continue
                batch_final = {}
                if length == self.max_seq_length:
                    batch_final = batch
                elif length > self.max_seq_length:
                    i = random.randint(0, length - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    # May have to slice every sequence-related entry.
                    for key in batch.keys():
                        if isinstance(batch[key], torch.Tensor):
                            batch_final[key] = batch[key][:, i:j]  # slice each tensor at the same positions
                        else:
                            batch_final[key] = batch[key]
                else:
                    # Shorter than the unified length; exclude from the target dataset.
                    continue
            # tensor
            else:
                if batch.shape[-1] == unified_length:
                    batch_final = batch
                elif batch.shape[-1] > unified_length:
                    i = random.randint(0, batch.shape[-1] - unified_length - 1)
                    j = i + unified_length
                    batch_final = batch[:, i:j]
                else:
                    # Shorter than the unified length; exclude from the target dataset.
                    continue
            self.dataloader.append(batch_final)
        if len(self.dataloader) < self.nsamples:  # pragma: no cover
            logger.warning(
                f"Tried to collect {self.nsamples} samples with fixed length {unified_length}, "
                f"but only {len(self.dataloader)} were found. Please use a smaller 'max_seq_length' value."
            )


class TestGPTQ:
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
    def test_GPTQ_fixed_length_quant(self):
        class GPTQLLMDataLoader:
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for _ in range(10):
                    yield torch.ones([1, 512], dtype=torch.long)

        class GPTQLLMDataLoaderList:
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for _ in range(10):
                    yield (torch.ones([1, 512], dtype=torch.long), torch.ones([1, 512], dtype=torch.long))

        class GPTQLLMDataLoaderDict:
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for _ in range(10):
                    yield {
                        "input_ids": torch.ones([1, 512], dtype=torch.long),
                        "attention_mask": torch.ones([1, 512], dtype=torch.long),
                    }

        dataloader_list = GPTQLLMDataLoaderList()
        dataloader_dict = GPTQLLMDataLoaderDict()

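        # Per-module override: the intent is to quantize everything with the default GPTQ
        # settings while keeping the LM output head in FP32. "lm_head" is assumed to be the
        # head's module name in this tiny GPT-J checkpoint.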
        quant_config = GPTQConfig()
        quant_config.set_local("lm_head", GPTQConfig(dtype="fp32"))

        gptq_use_max_length = False
        gptq_max_seq_length = 2048
        dataloader_preprocessor = GPTQDataloaderPreprocessor(
            dataloader_original=dataloader_list,
            use_max_length=gptq_use_max_length,
            max_seq_length=gptq_max_seq_length,
        )
        dataloader_for_calibration = dataloader_preprocessor.get_prepared_dataloader()

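        # The calibration run function below simply drives forward passes over the prepared
        # batches. The try/except ValueError is defensive: GPTQ-style calibration commonly
        # interrupts the forward pass once the layer inputs it needs have been captured
        # (an assumption about the hook behavior, hence the broad pass).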
        def run_fn_for_gptq(model, dataloader_for_calibration, *args):
            for batch in tqdm(dataloader_for_calibration):
                batch = move_input_to_device(batch, device=model.device)
                try:
                    if isinstance(batch, (list, tuple)):
                        model(batch[0])
                    elif isinstance(batch, dict):
                        model(**batch)
                    else:
                        model(batch)
                except ValueError:
                    pass

        user_model = transformers.AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-GPTJForCausalLM",
        )

        user_model = quantize(
            model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
        )
        model_device = str(user_model.device)
        assert "cuda" in model_device, f"Model device is {model_device}"


class TestRTNQuant:

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
    def test_rtn(self):
        self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-GPTJForCausalLM",
        )
        self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
        model = self.tiny_gptj
        # Record the FP32 output as a reference for later comparison.
        self.label = model(self.example_inputs.to(model.device))[0]
        # Test the default RTN config.
        quant_config = get_default_rtn_config()
        q_model = quantize(model, quant_config)
        assert "cuda" in str(q_model.device), f"Expected the quantized model on cuda, got {q_model.device}"
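        # Possible extension (sketch only, not executed): compare the quantized output against
        # the recorded FP32 reference to bound the quantization error. The tolerance below is
        # an assumption, not a validated threshold.
        #   out = q_model(self.example_inputs.to(q_model.device))[0]
        #   assert torch.allclose(out, self.label, atol=1e-1)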