@@ -1,6 +1,6 @@
from enum import IntEnum, auto
from typing import Optional
-from typing import List
+from typing import List, Set
from transformers import PretrainedConfig

from scratchpad.utils import get_config, get_context_length
@@ -70,6 +70,9 @@ def __init__(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_multimodal_gen = False
+        self.is_image_gen = False
+        self.is_audio_model = False
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        if context_length is not None:
            self.context_len = context_length
@@ -82,38 +85,26 @@ def __init__(
8285 "head_dim" ,
8386 self .hf_text_config .hidden_size // self .hf_text_config .num_attention_heads ,
8487 )
85-
86- # FIXME: temporary special judge for deepseek v2 MLA architecture
87- if "DeepseekV2ForCausalLM" in self .hf_config .architectures :
88- self .head_dim = 256
89- self .attention_arch = AttentionArch .MLA
90- self .kv_lora_rank = self .hf_config .kv_lora_rank
91- self .qk_rope_head_dim = self .hf_config .qk_rope_head_dim
92- elif "MiniCPM3ForCausalLM" in self .hf_config .architectures :
93- self .head_dim = 128
94- self .attention_arch = AttentionArch .MLA
95- self .kv_lora_rank = self .hf_config .kv_lora_rank
96- self .qk_rope_head_dim = self .hf_config .qk_rope_head_dim
97- else :
98- self .attention_arch = AttentionArch .MHA
99-
88+ self .attention_arch = AttentionArch .MHA
10089 self .num_attention_heads = self .hf_text_config .num_attention_heads
10190 self .num_key_value_heads = getattr (
10291 self .hf_text_config , "num_key_value_heads" , None
10392 )
104-
105- # for Dbrx and MPT models
106- if self .hf_config .model_type in ["dbrx" , "mpt" ]:
107- self .num_key_value_heads = getattr (
108- self .hf_config .attn_config , "kv_n_heads" , None
109- )
110-
11193 if self .num_key_value_heads is None :
11294 self .num_key_value_heads = self .num_attention_heads
11395 self .hidden_size = self .hf_text_config .hidden_size
11496 self .num_hidden_layers = self .hf_text_config .num_hidden_layers
11597 self .vocab_size = self .hf_text_config .vocab_size
11698 self .is_encoder_decoder = self .hf_config .model_type in ["mllama" ]
99+ self .hf_eos_token_id = self .get_hf_eos_token_id ()
100+ self .image_token_id = getattr (self .hf_config , "image_token_id" , None )
101+
102+ def get_hf_eos_token_id (self ) -> Optional [Set [int ]]:
103+ eos_ids = getattr (self .hf_config , "eos_token_id" , None )
104+ if eos_ids :
105+ # it can be either int or list of int
106+ eos_ids = {eos_ids } if isinstance (eos_ids , int ) else set (eos_ids )
107+ return eos_ids
117108
118109 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
119110 def get_total_num_kv_heads (self ) -> int :
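
For reference, here is a minimal standalone sketch of the normalization performed by the newly added get_hf_eos_token_id helper. The MiniConfig class below is hypothetical and only stands in for a Hugging Face config whose eos_token_id may be an int, a list of ints, or absent:

from typing import Optional, Set

class MiniConfig:
    # Hypothetical stand-in for a Hugging Face config object.
    def __init__(self, eos_token_id=None):
        self.eos_token_id = eos_token_id

def normalize_eos_ids(config: MiniConfig) -> Optional[Set[int]]:
    # Same logic as the helper in the diff: accept an int or a list of ints
    # and normalize it to a set; return None when the attribute is missing.
    eos_ids = getattr(config, "eos_token_id", None)
    if eos_ids:
        eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
    return eos_ids

assert normalize_eos_ids(MiniConfig(2)) == {2}
assert normalize_eos_ids(MiniConfig([2, 32000])) == {2, 32000}
assert normalize_eos_ids(MiniConfig()) is None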