
Commit 462ebaa

[Mixtral] Add mixtral moe (#7803)
* Add mixtral moe
* fix mixtral moe
* add mixtral lora config
* fix load balancing loss
* add sft scripts
* add unittest
1 parent 0d56490 commit 462ebaa

File tree

12 files changed: +2238 -3 lines

llm/data.py

Lines changed: 3 additions & 3 deletions
@@ -44,11 +44,11 @@ def get_convert_example(model):
     if base_model_prefix == "chatglm":
         return convert_example_chatglm
-    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen"]:
+    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]:
         return convert_example_common
     else:
         raise ValueError(
-            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama."
+            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral"
         )

@@ -107,7 +107,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):
     # 0. prepare data
     context_data = example.get("context", {})
-    context_data["is_training"] = True
+    context_data["is_training"] = True
     example["src"] = example["src"] if isinstance(example["src"], list) else [example["src"]]
     example["tgt"] = example["tgt"] if isinstance(example["tgt"], list) else [example["tgt"]]
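
As a usage sketch (not part of the diff), the effect of the allow-list change is that Mixtral checkpoints now resolve to the shared conversion path instead of raising. This assumes get_convert_example only inspects base_model_prefix, as the hunk suggests; the stand-in object and import path are illustrative.

# Illustrative sketch only: dispatch is keyed on model.base_model_prefix,
# so a lightweight stand-in (an assumption for illustration; a real loaded
# Mixtral model behaves the same) is enough to show the new routing.
from types import SimpleNamespace

from data import get_convert_example  # llm/data.py, run from the llm/ directory

mixtral_like = SimpleNamespace(base_model_prefix="mixtral")
convert_fn = get_convert_example(mixtral_like)
print(convert_fn.__name__)  # convert_example_common (previously this raised ValueError)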

llm/mixtral/lora_argument.json

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
{
    "model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/mixtral_lora_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 3,
    "learning_rate": 3e-04,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 8,
    "pipeline_parallel_degree": 1,
    "lora": true,
    "zero_padding": false,
    "use_flash_attention": false
}
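
This JSON file is the single argument file passed to the llm fine-tuning entry point. As a quick sketch (the only assumption is the file path relative to the repository root), the key knobs can be inspected with the standard library alone:

# Minimal sketch: inspect the LoRA recipe above.
import json

with open("llm/mixtral/lora_argument.json") as f:
    args = json.load(f)

# "lora": true switches the fine-tuning script into LoRA mode, and
# "tensor_parallel_degree": 8 means the run is sharded across 8 devices.
for key in ("model_name_or_path", "lora", "learning_rate", "tensor_parallel_degree"):
    print(key, "=", args[key])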

llm/mixtral/sft_argument.json

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
{
    "model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/mixtral_sft_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 3,
    "learning_rate": 3e-05,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "bf16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 8,
    "sharding": "stage2",
    "pipeline_parallel_degree": 1
}
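
The SFT recipe mostly mirrors the LoRA one, so a small sketch makes the differences explicit (lower learning rate, bf16 instead of fp16, stage2 sharding, no LoRA flags); file paths are assumed relative to the repository root.

# Sketch: print every key whose value differs between the two argument files.
import json

with open("llm/mixtral/lora_argument.json") as f:
    lora = json.load(f)
with open("llm/mixtral/sft_argument.json") as f:
    sft = json.load(f)

for key in sorted(set(lora) | set(sft)):
    if lora.get(key) != sft.get(key):
        print(f"{key}: lora={lora.get(key)!r} sft={sft.get(key)!r}")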

llm/utils.py

Lines changed: 10 additions & 0 deletions
@@ -149,6 +149,16 @@ def get_lora_target_modules(model):
             ".*mlp.w2.*",
             ".*mlp.c_proj.*",
         ]
+    elif model.base_model_prefix == "mixtral":
+        target_modules = [
+            ".*q_proj.*",
+            ".*k_proj.*",
+            ".*v_proj.*",
+            ".*o_proj.*",
+            ".*w1.*",
+            ".*w2.*",
+            ".*w3.*",
+        ]
     else:
         raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
     return target_modules
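
To see which Mixtral weights these patterns are meant to pick up, here is a small matching sketch; the sample parameter names are illustrative stand-ins for the attention projections, the per-expert w1/w2/w3 linears, and the router gate (which is deliberately not targeted), not names dumped from the real model.

# Illustrative only: sample names below are assumptions about Mixtral layer naming.
import re

target_modules = [
    ".*q_proj.*", ".*k_proj.*", ".*v_proj.*", ".*o_proj.*",
    ".*w1.*", ".*w2.*", ".*w3.*",
]

sample_names = [
    "layers.0.self_attn.q_proj.weight",
    "layers.0.block_sparse_moe.experts.3.w1.weight",
    "layers.0.block_sparse_moe.gate.weight",  # router gate: no pattern matches
]

for name in sample_names:
    matched = any(re.match(pattern, name) for pattern in target_modules)
    print(f"{name}: {'LoRA target' if matched else 'skipped'}")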

paddlenlp/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -281,6 +281,8 @@
 from .rw.configuration import *
 from .rw.tokenizer import *
 from .qwen import *
+from .mixtral.modeling import *
+from .mixtral.configuration import *

 # For faster tokenizer
 from ..utils.import_utils import is_fast_tokenizer_available
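
With these re-exports in place, the new classes are importable from the top-level package. A smoke-test sketch follows; the small hyperparameters are arbitrary values chosen only so the toy model fits in memory, not Mixtral defaults, and the construction pattern is assumed to follow the usual PaddleNLP config-to-model flow.

# Smoke-test sketch: build a deliberately tiny Mixtral rather than the real 8x7B.
from paddlenlp.transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,
    num_local_experts=4,
)
model = MixtralForCausalLM(config)
print(type(model).__name__, config.num_local_experts, "experts,", config.num_experts_per_tok, "active per token")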

paddlenlp/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@
         ("Blip", "blip"),
         ("Bloom", "bloom"),
         ("QWen", "qwen"),
+        ("Mixtral", "mixtral"),
     ]
 )
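
The new mapping entry is what lets the Auto classes dispatch on model_type == "mixtral". A hedged usage sketch, assuming the checkpoint name is resolvable exactly as it appears in the argument files above:

# Sketch of Auto-class resolution for Mixtral; loading the full 8x7B checkpoint
# is left commented out because the weights are on the order of 90 GB.
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
print(config.model_type)  # "mixtral", which the mapping above ties to the Mixtral* classes

# model = AutoModelForCausalLM.from_pretrained(
#     "mistralai/Mixtral-8x7B-Instruct-v0.1", dtype="bfloat16"
# )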

paddlenlp/transformers/mixtral/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration import MixtralConfig
from .modeling import MixtralForCausalLM

paddlenlp/transformers/mixtral/configuration.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mixtral model configuration"""

from paddlenlp.transformers.configuration_utils import PretrainedConfig

__all__ = [
    "MixtralConfig",
]


class MixtralConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~MixtralModel`]. It is used to instantiate a
    Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mixtral-8x7B-v0.1 or Mixtral-8x7B-Instruct-v0.1.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`~MixtralModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
            allows sequences of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*):
            Sliding window attention window size.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route each token to; can also be interpreted as the `top-k` routing parameter.
        num_local_experts (`int`, *optional*, defaults to 8):
            Number of experts per Sparse MLP layer.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also allow the
            model to output the auxiliary load-balancing loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        use_fused_rope (`bool`, *optional*, defaults to `False`):
            Enable rope fusion or not.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
            `num_attention_heads`.

    Example:
    ```python
    >>> from paddlenlp.transformers import MixtralModel, MixtralConfig

    >>> # Initializing a Mixtral 8x7B style configuration
    >>> configuration = MixtralConfig()

    >>> # Initializing a model from the Mixtral 8x7B style configuration
    >>> model = MixtralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mixtral"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        max_position_embeddings=4096 * 32,
        seq_length=2048,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        use_recompute=False,
        recompute_granularity="full",
        no_recompute_layers=None,
        use_flash_attention=False,
        attention_dropout=0.0,
        use_fused_rope=False,
        rope_theta=1e6,
        tensor_parallel_output=True,
        sequence_parallel=False,
        fuse_sequence_parallel_allreduce=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        num_experts_per_tok=2,
        num_local_experts=8,
        router_aux_loss_coef=0.001,
        output_router_logits=False,
        sliding_window=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.seq_length = seq_length
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        self.use_cache = use_cache
        self.use_recompute = use_recompute
        self.recompute_granularity = recompute_granularity
        self.no_recompute_layers = no_recompute_layers
        self.use_flash_attention = use_flash_attention
        self.tensor_parallel_output = tensor_parallel_output
        self.sequence_parallel = sequence_parallel
        self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce

        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.use_fused_rope = use_fused_rope
        self.rope_theta = rope_theta

        # ----------------- Experts -------------------- #
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.router_aux_loss_coef = router_aux_loss_coef
        self.output_router_logits = output_router_logits

        self.sliding_window = sliding_window

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            tensor_parallel_output=tensor_parallel_output,
            **kwargs,
        )
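
The MoE-specific fields (num_local_experts, num_experts_per_tok, router_aux_loss_coef) feed the router's load-balancing auxiliary loss that the commit message mentions fixing. The sketch below mirrors the standard Switch-Transformer-style formulation over router logits; it is a hedged illustration, not a copy of the modeling code added in this commit.

# Hedged sketch of a load-balancing auxiliary loss; the real modeling.py in
# this commit may differ in details (masking, reshaping across layers, etc.).
import paddle
import paddle.nn.functional as F


def load_balancing_loss(router_logits, num_experts=8, top_k=2):
    # router_logits: [num_tokens, num_experts]
    probs = F.softmax(router_logits, axis=-1)                    # per-token routing probabilities
    _, selected = paddle.topk(probs, k=top_k, axis=-1)           # top-k expert indices per token
    expert_mask = F.one_hot(selected, num_classes=num_experts)   # [num_tokens, top_k, num_experts]
    tokens_per_expert = paddle.mean(expert_mask, axis=[0, 1])    # fraction of assignments per expert
    prob_per_expert = paddle.mean(probs, axis=0)                 # mean router probability per expert
    return num_experts * paddle.sum(tokens_per_expert * prob_per_expert)


router_logits = paddle.randn([16, 8])                   # 16 tokens routed over 8 local experts
aux_loss = 0.001 * load_balancing_loss(router_logits)   # scaled by router_aux_loss_coef
print(float(aux_loss))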
