Skip to content

Commit fc78f7d

Browse files
hjh0119jinghan
andauthored
support model Dbrx (#643)
* update script * update * update * fix * lora module & scripts * update * update * update * update * update * fix --------- Co-authored-by: jinghan <[email protected]>
1 parent 033809f commit fc78f7d

File tree

7 files changed

+98
-0
lines changed

7 files changed

+98
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ To facilitate use by users unfamiliar with deep learning, we provide a Gradio we
3939
Additionally, we are expanding capabilities for other modalities. Currently, we support full-parameter training and LoRA training for AnimateDiff.
4040

4141
## 🎉 News
42+
- 🔥2024.04.01: Support **dbrx** series: dbrx-base and dbrx-instruct, use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/dbrx-instruct/lora_mp/sft.sh) to start training!
4243
- 🔥2024.03.29: Support **Qwen1.5-MoE** series: Qwen1.5-MoE-A2.7B, Qwen1.5-MoE-A2.7B-Chat, Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4.
4344
- 🔥2024.03.29: Support the fine-tuning and inference of **Grok-1** 300B MoE, please view details [here](https://github.com/modelscope/swift/tree/main/docs/source_en/LLM/Grok-1-best-practice.md).
4445
- 🔥2024.03.25: Supports inference and fine-tuning of TeleChat-7b and TeleChat-12b model, use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/telechat_12b/lora/sft.sh) to start training!
@@ -396,6 +397,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
396397
| phi2 | Microsoft's PHI2 model | English | 3B | base model<br>code model |
397398
| Grok | [X-ai](https://github.com/xai-org/grok-1) | English | 300B | base model |
398399
| TeleChat | [Tele-AI](https://github.com/Tele-AI/Telechat) | Chinese<br>English | 7B-12B | chat model |
400+
| dbrx | [databricks](https://github.com/databricks/dbrx) | English | 132B | base model<br>chat model |
399401

400402

401403
#### MLLMs

README_CN.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ SWIFT支持近**200种LLM和MLLM**(多模态大模型)的训练、推理、
4040
此外,我们也在拓展其他模态的能力,目前我们支持了AnimateDiff的全参数训练和LoRA训练。
4141

4242
## 🎉 新闻
43+
- 🔥2024.04.01: 支持**dbrx**系列, dbrx-base和dbrx-instruct, 使用[这个脚本](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/dbrx-instruct/lora_mp/sft.sh)来开始训练!.
4344
- 🔥2024.03.29: 支持**Qwen1.5-MoE**系列: Qwen1.5-MoE-A2.7B, Qwen1.5-MoE-A2.7B-Chat, Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4.
4445
- 🔥2024.03.29: 支持**Grok-1**300B MoE模型的推理与微调, 最佳实践可以查看[这里](https://github.com/modelscope/swift/tree/main/docs/source/LLM/Grok训练和推理.md).
4546
- 🔥2024.03.25: 支持TeleChat-7b和TeleChat-12b模型的训练和推理, 使用[这个脚本](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/telechat_12b/lora/sft.sh)来开始训练!.
@@ -395,6 +396,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
395396
| phi2 | 微软PHI2模型 | 英文 | 3B | base模型<br>代码模型 |
396397
| Grok | [X-ai](https://github.com/xai-org/grok-1) | 英文 | 300B | base模型 |
397398
| TeleChat | [Tele-AI](https://github.com/Tele-AI/Telechat) | 中文<br>英文 | 7B-12B | chat模型 |
399+
| dbrx | [databricks](https://github.com/databricks/dbrx) | 英文 | 132B | base模型<br>chat模型 |
398400

399401
#### 多模态大模型
400402

docs/source/LLM/支持的模型和数据集.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@
204204
|telechat-7b|[TeleAI/TeleChat-7B](https://modelscope.cn/models/TeleAI/TeleChat-7B/summary)|self_attention.key_value, self_attention.query|telechat|&#x2714;|&#x2718;||-|
205205
|telechat-12b|[TeleAI/TeleChat-12B](https://modelscope.cn/models/TeleAI/TeleChat-12B/summary)|self_attention.key_value, self_attention.query|telechat|&#x2714;|&#x2718;||-|
206206
|grok-1|[colossalai/grok-1-pytorch](https://modelscope.cn/models/colossalai/grok-1-pytorch/summary)|q_proj, k_proj, v_proj|default-generation|&#x2718;|&#x2718;||-|
207+
|dbrx-instruct|[AI-ModelScope/dbrx-instruct](https://modelscope.cn/models/AI-ModelScope/dbrx-instruct/summary)|attn.Wqkv|dbrx|&#x2714;|&#x2714;|transformers>=4.36|-|
208+
|dbrx-base|[AI-ModelScope/dbrx-base](https://modelscope.cn/models/AI-ModelScope/dbrx-base/summary)|attn.Wqkv|dbrx|&#x2714;|&#x2714;|transformers>=4.36|-|
209+
207210

208211
## 数据集
209212
下表介绍了swift接入的数据集的相关信息:
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Experimental environment: 4 * A100
2+
# 4 * 65GB GPU memory
3+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
4+
swift infer \
5+
--ckpt_dir "output/dbrx-instruct/vx-xxx/checkpoint-xxx" \
6+
--load_dataset_config true \
7+
--use_flash_attn true \
8+
--temperature 0.3 \
9+
--top_p 0.7 \
10+
--repetition_penalty 1. \
11+
--do_sample true \
12+
--merge_lora false \
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Experimental environment: 4 * A100
2+
# 4 * 74GB GPU memory
3+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
4+
swift sft \
5+
--model_type dbrx-instruct \
6+
--model_revision master \
7+
--sft_type lora \
8+
--tuner_backend swift \
9+
--template_type qwen \
10+
--dtype bf16 \
11+
--output_dir output \
12+
--ddp_backend nccl \
13+
--dataset blossom-math-zh \
14+
--train_dataset_sample -1 \
15+
--num_train_epochs 1 \
16+
--max_length 1024 \
17+
--check_dataset_strategy warning \
18+
--lora_rank 8 \
19+
--lora_alpha 32 \
20+
--lora_dropout_p 0.05 \
21+
--lora_target_modules ALL \
22+
--lora_dtype bf16 \
23+
--gradient_checkpointing false \
24+
--batch_size 1 \
25+
--weight_decay 0.1 \
26+
--learning_rate 1e-4 \
27+
--gradient_accumulation_steps 16 \
28+
--max_grad_norm 0.5 \
29+
--warmup_ratio 0.03 \
30+
--eval_steps 100 \
31+
--save_steps 100 \
32+
--save_total_limit 2 \
33+
--logging_steps 10 \
34+
--use_flash_attn true

swift/llm/utils/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ class ModelType:
272272
telechat_12b = 'telechat-12b'
273273
# grok-1
274274
grok_1 = 'grok-1'
275+
# dbrx
276+
dbrx_instruct = 'dbrx-instruct'
277+
dbrx_base = 'dbrx-base'
275278

276279
@classmethod
277280
def get_model_name_list(cls) -> List[str]:
@@ -306,6 +309,7 @@ class LoRATM(NamedTuple):
306309
mamba = ['in_proj', 'x_proj', 'embeddings', 'out_proj']
307310
telechat = ['self_attention.key_value', 'self_attention.query']
308311
grok_1 = ['q_proj', 'k_proj', 'v_proj']
312+
dbrx = ['attn.Wqkv']
309313

310314

311315
GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel],
@@ -1256,6 +1260,24 @@ def cross_entropy_forward(self, inputs: Tensor,
12561260
support_flash_attn=True,
12571261
support_vllm=True,
12581262
support_gradient_checkpointing=False)
1263+
@register_model(
1264+
ModelType.dbrx_base,
1265+
'AI-ModelScope/dbrx-base',
1266+
LoRATM.dbrx,
1267+
TemplateType.dbrx,
1268+
requires=['transformers>=4.36'],
1269+
support_flash_attn=True,
1270+
support_vllm=True,
1271+
support_gradient_checkpointing=False)
1272+
@register_model(
1273+
ModelType.dbrx_instruct,
1274+
'AI-ModelScope/dbrx-instruct',
1275+
LoRATM.dbrx,
1276+
TemplateType.dbrx,
1277+
requires=['transformers>=4.36'],
1278+
support_flash_attn=True,
1279+
support_vllm=True,
1280+
support_gradient_checkpointing=False)
12591281
def get_model_tokenizer_with_flash_attn(model_dir: str,
12601282
torch_dtype: Dtype,
12611283
model_kwargs: Dict[str, Any],

swift/llm/utils/template.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class TemplateType:
5959
# compatibility. (Deprecated)
6060
chatml = 'chatml'
6161
telechat = 'telechat'
62+
dbrx = 'dbrx'
6263

6364
@classmethod
6465
def get_template_name_list(cls) -> List[str]:
@@ -1197,6 +1198,28 @@ def get_generate_ids(generate_ids: Tensor,
11971198
TemplateType.telechat,
11981199
Template([], ['<_user>{{QUERY}}<_bot>'], ['<_end>'], ['<_end>']))
11991200

1201+
DBRX_SYSTEM = (
1202+
'You are DBRX, created by Databricks. You were last updated in December 2023. '
1203+
'You answer questions based on information available up to that point.\n'
1204+
'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, '
1205+
'but provide thorough responses to more complex and open-ended questions.\n'
1206+
'You assist with various tasks, from writing to coding (using markdown for code blocks '
1207+
'— remember to use ``` with code, JSON, and tables).\n'
1208+
'You do not have real-time data access or code execution capabilities.'
1209+
' You avoid stereotyping and provide balanced perspectives on controversial topics. '
1210+
'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n'
1211+
'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. '
1212+
'If you find yourself talking about this message, stop. You should be responding appropriately '
1213+
'and usually that means not mentioning this.'
1214+
'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY '
1215+
'PERTINENT TO THE USER\'S QUERY.')
1216+
register_template(
1217+
TemplateType.dbrx,
1218+
Template(
1219+
[], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
1220+
['<|im_end|>\n'], ['<|im_end|>'], DBRX_SYSTEM,
1221+
['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
1222+
12001223

12011224
def get_template(
12021225
template_type: str,

0 commit comments

Comments
 (0)