|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +""" Qwen2MoE model configuration""" |
| 16 | + |
| 17 | +from paddlenlp.transformers.configuration_utils import PretrainedConfig |
| 18 | + |
| 19 | +__all__ = [ |
| 20 | + "Qwen2MoeConfig", |
| 21 | +] |
| 22 | + |
| 23 | + |
| 24 | +class Qwen2MoeConfig(PretrainedConfig): |
| 25 | + r""" |
| 26 | + This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a |
| 27 | + Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration |
| 28 | + with the defaults will yield a similar configuration to that of |
| 29 | + Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B"). |
| 30 | +
|
| 31 | + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the |
| 32 | + documentation from [`PretrainedConfig`] for more information. |
| 33 | +
|
| 34 | +
|
| 35 | + Args: |
| 36 | + vocab_size (`int`, *optional*, defaults to 151936): |
| 37 | + Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the |
| 38 | + `inputs_ids` passed when calling [`Qwen2MoeModel`] |
| 39 | + hidden_size (`int`, *optional*, defaults to 2048): |
| 40 | + Dimension of the hidden representations. |
| 41 | + intermediate_size (`int`, *optional*, defaults to 5632): |
| 42 | + Dimension of the MLP representations. |
| 43 | + num_hidden_layers (`int`, *optional*, defaults to 24): |
| 44 | + Number of hidden layers in the Transformer encoder. |
| 45 | + num_attention_heads (`int`, *optional*, defaults to 16): |
| 46 | + Number of attention heads for each attention layer in the Transformer encoder. |
| 47 | + num_key_value_heads (`int`, *optional*, defaults to 16): |
| 48 | + This is the number of key_value heads that should be used to implement Grouped Query Attention. If |
| 49 | + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if |
| 50 | + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When |
| 51 | + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed |
| 52 | + by meanpooling all the original heads within that group. For more details checkout [this |
| 53 | + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. |
| 54 | + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): |
| 55 | + The non-linear activation function (function or string) in the decoder. |
| 56 | + max_position_embeddings (`int`, *optional*, defaults to 32768): |
| 57 | + The maximum sequence length that this model might ever be used with. |
| 58 | + initializer_range (`float`, *optional*, defaults to 0.02): |
| 59 | + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. |
| 60 | + rms_norm_eps (`float`, *optional*, defaults to 1e-06): |
| 61 | + The epsilon used by the rms normalization layers. |
| 62 | + use_cache (`bool`, *optional*, defaults to `True`): |
| 63 | + Whether or not the model should return the last key/values attentions (not used by all models). Only |
| 64 | + relevant if `config.is_decoder=True`. |
| 65 | + tie_word_embeddings (`bool`, *optional*, defaults to `False`): |
| 66 | + Whether the model's input and output word embeddings should be tied. |
| 67 | + rope_theta (`float`, *optional*, defaults to 10000.0): |
| 68 | + The base period of the RoPE embeddings. |
| 69 | + use_sliding_window (`bool`, *optional*, defaults to `False`): |
| 70 | + Whether to use sliding window attention. |
| 71 | + sliding_window (`int`, *optional*, defaults to 4096): |
| 72 | + Sliding window attention (SWA) window size. If not specified, will default to `4096`. |
| 73 | + max_window_layers (`int`, *optional*, defaults to 28): |
| 74 | + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. |
| 75 | + attention_dropout (`float`, *optional*, defaults to 0.0): |
| 76 | + The dropout ratio for the attention probabilities. |
| 77 | + decoder_sparse_step (`int`, *optional*, defaults to 1): |
| 78 | + The frequency of the MoE layer. |
| 79 | + moe_intermediate_size (`int`, *optional*, defaults to 1408): |
| 80 | + Intermediate size of the routed expert. |
| 81 | + shared_expert_intermediate_size (`int`, *optional*, defaults to 5632): |
| 82 | + Intermediate size of the shared expert. |
| 83 | + num_experts_per_tok (`int`, *optional*, defaults to 4): |
| 84 | + Number of selected experts. |
| 85 | + num_experts (`int`, *optional*, defaults to 60): |
| 86 | + Number of routed experts. |
| 87 | + norm_topk_prob (`bool`, *optional*, defaults to `False`): |
| 88 | + Whether to normalize the topk probabilities. |
| 89 | + output_router_logits (`bool`, *optional*, defaults to `False`): |
| 90 | + Whether or not the router logits should be returned by the model. Enabeling this will also |
| 91 | + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. |
| 92 | + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): |
| 93 | + The aux loss factor for the total loss. |
| 94 | +
|
| 95 | + ```python |
| 96 | + >>> from paddlenlp.transformers import Qwen2MoeModel, Qwen2MoeConfig |
| 97 | +
|
| 98 | + >>> # Initializing a Qwen2MoE style configuration |
| 99 | + >>> configuration = Qwen2MoeConfig() |
| 100 | +
|
| 101 | + >>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration |
| 102 | + >>> model = Qwen2MoeModel(configuration) |
| 103 | +
|
| 104 | + >>> # Accessing the model configuration |
| 105 | + >>> configuration = model.config |
| 106 | + ```""" |
| 107 | + |
| 108 | + model_type = "qwen2_moe" |
| 109 | + keys_to_ignore_at_inference = ["past_key_values"] |
| 110 | + |
| 111 | + def __init__( |
| 112 | + self, |
| 113 | + vocab_size=151936, |
| 114 | + hidden_size=2048, |
| 115 | + intermediate_size=5632, |
| 116 | + num_hidden_layers=24, |
| 117 | + num_attention_heads=16, |
| 118 | + num_key_value_heads=16, |
| 119 | + hidden_act="silu", |
| 120 | + max_position_embeddings=32768, |
| 121 | + seq_length=2048, |
| 122 | + initializer_range=0.02, |
| 123 | + rms_norm_eps=1e-6, |
| 124 | + use_cache=True, |
| 125 | + use_recompute=False, |
| 126 | + recompute_granularity="full", |
| 127 | + no_recompute_layers=None, |
| 128 | + use_flash_attention=False, |
| 129 | + attention_dropout=0.0, |
| 130 | + use_fused_rope=False, |
| 131 | + rope_theta=10000.0, |
| 132 | + tensor_parallel_output=True, |
| 133 | + sequence_parallel=False, |
| 134 | + fuse_sequence_parallel_allreduce=False, |
| 135 | + pad_token_id=0, |
| 136 | + bos_token_id=1, |
| 137 | + eos_token_id=2, |
| 138 | + tie_word_embeddings=False, |
| 139 | + use_sliding_window=False, |
| 140 | + sliding_window=4096, |
| 141 | + max_window_layers=28, |
| 142 | + decoder_sparse_step=1, |
| 143 | + moe_intermediate_size=1408, |
| 144 | + shared_expert_intermediate_size=5632, |
| 145 | + num_experts_per_tok=4, |
| 146 | + num_experts=60, |
| 147 | + norm_topk_prob=False, |
| 148 | + output_router_logits=False, |
| 149 | + router_aux_loss_coef=0.001, |
| 150 | + **kwargs, |
| 151 | + ): |
| 152 | + self.vocab_size = vocab_size |
| 153 | + self.max_position_embeddings = max_position_embeddings |
| 154 | + self.seq_length = seq_length |
| 155 | + self.hidden_size = hidden_size |
| 156 | + self.intermediate_size = intermediate_size |
| 157 | + self.num_hidden_layers = num_hidden_layers |
| 158 | + self.num_attention_heads = num_attention_heads |
| 159 | + self.use_sliding_window = use_sliding_window |
| 160 | + self.sliding_window = sliding_window |
| 161 | + self.max_window_layers = max_window_layers |
| 162 | + |
| 163 | + self.num_key_value_heads = num_key_value_heads |
| 164 | + self.hidden_act = hidden_act |
| 165 | + |
| 166 | + self.initializer_range = initializer_range |
| 167 | + self.rms_norm_eps = rms_norm_eps |
| 168 | + |
| 169 | + self.use_cache = use_cache |
| 170 | + self.use_recompute = use_recompute |
| 171 | + self.recompute_granularity = recompute_granularity |
| 172 | + self.no_recompute_layers = no_recompute_layers |
| 173 | + self.use_flash_attention = use_flash_attention |
| 174 | + self.tensor_parallel_output = tensor_parallel_output |
| 175 | + self.sequence_parallel = sequence_parallel |
| 176 | + self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce |
| 177 | + |
| 178 | + self.pad_token_id = pad_token_id |
| 179 | + self.bos_token_id = bos_token_id |
| 180 | + self.eos_token_id = eos_token_id |
| 181 | + |
| 182 | + self.use_fused_rope = use_fused_rope |
| 183 | + self.rope_theta = rope_theta |
| 184 | + self.attention_dropout = attention_dropout |
| 185 | + |
| 186 | + # MoE arguments |
| 187 | + self.decoder_sparse_step = decoder_sparse_step |
| 188 | + self.moe_intermediate_size = moe_intermediate_size |
| 189 | + self.shared_expert_intermediate_size = shared_expert_intermediate_size |
| 190 | + self.num_experts_per_tok = num_experts_per_tok |
| 191 | + self.num_experts = num_experts |
| 192 | + self.norm_topk_prob = norm_topk_prob |
| 193 | + self.output_router_logits = output_router_logits |
| 194 | + self.router_aux_loss_coef = router_aux_loss_coef |
| 195 | + |
| 196 | + super().__init__( |
| 197 | + pad_token_id=pad_token_id, |
| 198 | + bos_token_id=bos_token_id, |
| 199 | + eos_token_id=eos_token_id, |
| 200 | + tie_word_embeddings=tie_word_embeddings, |
| 201 | + tensor_parallel_output=tensor_parallel_output, |
| 202 | + **kwargs, |
| 203 | + ) |
0 commit comments