@@ -12,84 +12,7 @@
                "pack_single=1,pack_complete=0,pack_buffer_size=50",
     "ds_name": "RefactFIMCodeDataset"
 }
-_bigcode_tokenizer_mapping = {
-    "eot_idx": 0,
-    "padding_idx": 4,
-    "fim_prefix": 1,
-    "fim_middle": 2,
-    "fim_suffix": 3,
-    "escape": 14
-}
-_starcoder_base = {
-    "lora_target_modules_mapping": {
-        "qkv": ["attn.q_attn", "attn.c_attn"],
-        "out": ["attn.c_proj"],
-        "backproj": ["attn.c_proj"],
-        "mlp": ["mlp.c_fc", "mlp.c_proj"],
-    },
-    "freeze_exceptions_mapping": {
-        "wte": ["wte", "wpe"],
-        "lm_head": ["lm_head"],
-        "lora": ["lora"]
-    },
-    "tokenizer": _bigcode_tokenizer_mapping,
-    "train_ds_pipeline": _fim_train_ds_pipeline,
-    "test_ds_pipeline": _fim_test_ds_pipeline,
-    "train_model_modifiers": [
-        "flash_sa.apply_flash_mha_to_starcoder_model"
-    ],
-    "force_enable_checkpointing": False
-}
-_starcoder2_base = {
-    "lora_target_modules_mapping": {
-        "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
-        "out": ["self_attn.o_proj"],
-        "backproj": ["self_attn.o_proj"],
-        "mlp": ["mlp.c_fc", "mlp.c_proj"],
-    },
-    "freeze_exceptions_mapping": {
-        "wte": ["embed_tokens"],
-        "lm_head": ["lm_head"],
-        "lora": ["lora"]
-    },
-    "tokenizer": _bigcode_tokenizer_mapping,
-    "train_ds_pipeline": _fim_train_ds_pipeline,
-    "test_ds_pipeline": _fim_test_ds_pipeline,
-    "train_model_modifiers": [
-        "flash_sa.apply_flash_mha_to_starcoder2_model"
-    ],
-    "force_enable_checkpointing": True
-}
-_deepseek_base = {
-    "lora_target_modules_mapping": {
-        "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
-        "out": ["self_attn.o_proj"],
-        "backproj": ["self_attn.o_proj"],
-        "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
-    },
-    "freeze_exceptions_mapping": {
-        "wte": ["embed_tokens"],
-        "lm_head": ["lm_head"],
-        "lora": ["lora"]
-    },
-    "tokenizer": {
-        "eot_idx": 32021,      # `<|EOT|>`
-        "padding_idx": 32018,  # `<pad>`
-        "fim_prefix": 32016,   # `<|fim▁begin|>`
-        "fim_middle": 32017,   # `<|fim▁end|>`
-        "fim_suffix": 32015,   # `<|fim▁hole|>`
-        "escape": 32013,       # using `<|begin▁of▁sentence|>` token for now
-    },
-    "train_ds_pipeline": {
-        "ds_opts": f"{_fim_train_ds_pipeline['ds_opts']},spm_prob=0.0",
-        "ds_name": _fim_train_ds_pipeline["ds_name"]
-    },
-    "test_ds_pipeline": _fim_test_ds_pipeline,
-    "train_model_modifiers": [
-        "flash_sa.apply_flash_mha_to_codellama_model"
-    ],
-    "force_enable_checkpointing": False
-}
+
 _qwen_base = {
     "lora_target_modules_mapping": {
         "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
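Side note on the removed hunk above: the dataset pipeline options are plain comma-separated key=value strings, so `_deepseek_base` extended the shared `_fim_train_ds_pipeline` by string concatenation, using an f-string to append `spm_prob=0.0` (presumably the probability of SPM-ordered FIM samples). A minimal runnable sketch of that pattern, with the option string abbreviated rather than copied from the file:

    _fim_train_ds_pipeline = {
        # Abbreviated stand-in for the real option string defined at the top of the file.
        "ds_opts": "fim_probability=0.9,pack_single=1",
        "ds_name": "RefactFIMCodeDataset",
    }

    # A derived pipeline reuses the base options and appends one more key=value
    # pair to the same comma-separated string; this is how _deepseek_base set
    # spm_prob=0.0 on top of the shared training pipeline.
    _deepseek_train_ds_pipeline = {
        "ds_opts": f"{_fim_train_ds_pipeline['ds_opts']},spm_prob=0.0",
        "ds_name": _fim_train_ds_pipeline["ds_name"],
    }

    assert _deepseek_train_ds_pipeline["ds_opts"] == \
        "fim_probability=0.9,pack_single=1,spm_prob=0.0"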

@@ -122,95 +45,6 @@
 }
 
 config = {
-    "Refact/1.6B": {
-        "lora_target_modules_mapping": {
-            "qkv": ["attn.q", "attn.kv"],
-            "out": ["attn.c_proj"],
-            "backproj": ["attn.c_proj"],
-            "mlp": ["mlp.gate_up_proj", "mlp.c_proj"],
-        },
-        "freeze_exceptions_mapping": {
-            "wte": ["wte"],
-            "lm_head": ["lm_head"],
-            "lora": ["lora"]
-        },
-        "tokenizer": _bigcode_tokenizer_mapping,
-        "train_ds_pipeline": _fim_train_ds_pipeline,
-        "test_ds_pipeline": _fim_test_ds_pipeline,
-        "train_model_modifiers": [
-            "flash_sa.apply_flash_mha_to_refact_model"
-        ],
-        "force_enable_checkpointing": False
-    },
-
-    "starcoder/1b/base": _starcoder_base,
-
-    "starcoder/3b/base": _starcoder_base,
-
-    "starcoder/7b/base": {
-        **_starcoder_base,
-        "force_enable_checkpointing": True
-    },
-
-    "starcoder2/3b/base": _starcoder2_base,
-
-    "starcoder2/7b/base": {
-        **_starcoder2_base,
-        "force_enable_checkpointing": True
-    },
-
-    "starcoder2/15b/base": {
-        **_starcoder2_base,
-        "force_enable_checkpointing": True
-    },
-
-    "codellama/7b": {
-        "lora_target_modules_mapping": {
-            "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
-            "out": ["self_attn.o_proj"],
-            "backproj": ["self_attn.o_proj"],
-            "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
-        },
-        "freeze_exceptions_mapping": {
-            "wte": ["embed_tokens"],
-            "lm_head": ["lm_head"],
-            "lora": ["lora"]
-        },
-        "tokenizer": {
-            "eot_idx": 32010,
-            "padding_idx": 2,  # there is no padding token, so instead using `eos` token as in `gpt2`
-            "fim_prefix": 32007,
-            "fim_middle": 32009,
-            "fim_suffix": 32008,
-            "escape": 0,       # using <unk> token
-            "bos_idx": 1
-        },
-        "train_ds_pipeline": {
-            **_fim_train_ds_pipeline,
-            "ds_name": "CodeLLamaFIMDataset"
-        },
-        "test_ds_pipeline": {
-            **_fim_test_ds_pipeline,
-            "ds_name": "CodeLLamaFIMDataset"
-        },
-        "train_model_modifiers": [
-            "flash_sa.apply_flash_mha_to_codellama_model"
-        ],
-        "force_enable_checkpointing": True
-    },
-
-    "deepseek-coder/1.3b/base": _deepseek_base,
-
-    "deepseek-coder/5.7b/mqa-base": {
-        **_deepseek_base,
-        "force_enable_checkpointing": True
-    },
-
-    "deepseek-coder/6.7b/base": {
-        **_deepseek_base,
-        "force_enable_checkpointing": True
-    },
-
     # qwen models
     "qwen2.5/coder/32b/base": {
         **_qwen_base,
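One pattern worth seeing in isolation: every entry in `config`, removed or kept, is either a bare base dict or a dict literal that unpacks a shared base and then overrides single keys (typically `force_enable_checkpointing` for the larger checkpoints), relying on later keys winning inside a dict literal. A small self-contained sketch; `_base`, the model names, and `resolve_model_config` are made up for illustration:

    _base = {"force_enable_checkpointing": False}  # made-up shared defaults

    config = {
        "example/1b/base": _base,  # small model: shared defaults as-is
        "example/7b/base": {
            **_base,                             # base keys expand first...
            "force_enable_checkpointing": True,  # ...then later keys override
        },
    }

    # Made-up lookup helper, only to show how the registry is keyed by model name.
    def resolve_model_config(model_name: str) -> dict:
        try:
            return config[model_name]
        except KeyError:
            raise ValueError(f"unknown model: {model_name}") from None

    assert resolve_model_config("example/7b/base")["force_enable_checkpointing"] is True
    assert resolve_model_config("example/1b/base")["force_enable_checkpointing"] is False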