paddlenlp/experimental/transformers/llama/modeling.py (65 changes: 12 additions & 53 deletions)
@@ -371,7 +371,17 @@ def __init__(self, config: LlamaConfig):
             self.quant_type
         )

-        self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
+        if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
+            self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
+                self.vocab_size,
+                self.hidden_size,
+                weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
+            )
+        else:
+            self.embed_tokens = nn.Embedding(
+                self.vocab_size,
+                self.hidden_size,
+            )

         # get ring_id
         ring_id = -1
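
The hunk above switches the token embedding to a vocabulary-parallel table whenever the model runs under tensor parallelism and the vocabulary divides evenly across ranks; each rank then holds only vocab_size / tensor_parallel_degree embedding rows, with a plain nn.Embedding kept as the replicated fallback. A minimal standalone sketch of that selection logic, assuming the fleet hybrid-parallel environment has already been initialized (the helper name build_embed_tokens is illustrative, not part of the diff):

import paddle
import paddle.nn as nn
from paddle.distributed import fleet


def build_embed_tokens(vocab_size, hidden_size, tensor_parallel_degree):
    # Shard the embedding rows across model-parallel ranks when the vocab splits evenly.
    if tensor_parallel_degree > 1 and vocab_size % tensor_parallel_degree == 0:
        # Each rank stores vocab_size // tensor_parallel_degree rows; conceptually,
        # every rank embeds the token ids it owns and the partial results are
        # summed across the model-parallel group.
        return fleet.meta_parallel.VocabParallelEmbedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
        )
    # Fallback: a full, replicated embedding table on every rank.
    return nn.Embedding(vocab_size, hidden_size)
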
@@ -1246,58 +1256,6 @@ def __init__(self, config):
         self.llama = LlamaInferenceModel(config)
         self.lm_head = LlamaLMHead(config)

-    @classmethod
-    def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-
-        from paddlenlp.transformers.conversion_utils import split_or_merge_func
-
-        fn = split_or_merge_func(
-            is_split=is_split,
-            tensor_parallel_degree=config.tensor_parallel_degree,
-            tensor_parallel_rank=config.tensor_parallel_rank,
-            num_attention_heads=config.num_attention_heads,
-        )
-
-        def get_tensor_parallel_split_mappings(num_layers):
-            final_actions = {}
-
-            base_actions = {
-                "lm_head.weight": partial(fn, is_column=True),
-                # Row Linear
-                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
-                "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
-            }
-
-            # Column Linear
-            if config.fuse_attention_qkv:
-                base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
-            else:
-                base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
-                # if we have enough num_key_value_heads to split, then split it.
-                if config.num_key_value_heads % config.tensor_parallel_degree == 0:
-                    base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
-                    base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
-
-            if config.fuse_attention_ffn:
-                base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
-                    fn, is_column=True, is_naive_2fuse=True
-                )
-            else:
-                base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
-                base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
-
-            for key, action in base_actions.items():
-                if "layers.0." in key:
-                    for i in range(num_layers):
-                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
-                final_actions[key] = action
-
-            return final_actions
-
-        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
-
-        return mappings
-
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
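
The classmethod removed here built this class's tensor-parallel split mappings by defining each split action once on a "layers.0." key and fanning it out to every layer index; after this change the class relies on the mapping kept elsewhere in the file (the hunk below patches that remaining copy). A minimal sketch of the fan-out pattern, with a stand-in row-split function instead of split_or_merge_func (split_rows, the degrees, and the sample key are illustrative only):

from functools import partial


def expand_layer_actions(base_actions, num_layers):
    # Replicate every "layers.0." action for all layer indices, keeping the original key too.
    final_actions = {}
    for key, action in base_actions.items():
        if "layers.0." in key:
            for i in range(num_layers):
                final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
        final_actions[key] = action
    return final_actions


def split_rows(weight, degree, rank):
    # Stand-in for partial(fn, is_column=False): keep this rank's slice of the rows.
    rows = weight.shape[0] // degree
    return weight[rank * rows:(rank + 1) * rows]


actions = expand_layer_actions(
    {"layers.0.self_attn.o_proj.weight": partial(split_rows, degree=2, rank=0)},
    num_layers=2,
)
# actions now holds entries for layers.0 and layers.1, each mapping a full weight to this rank's shard.
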
Expand Down Expand Up @@ -1477,6 +1435,7 @@ def get_tensor_parallel_split_mappings(num_layers):
base_actions = {
"lm_head.weight": partial(fn, is_column=True),
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
}
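
The added "embed_tokens.weight" entry tells the checkpoint converter to split the embedding along rows (the vocabulary dimension), matching the VocabParallelEmbedding introduced above, while "lm_head.weight" remains column-split. A small illustration of what the two directions mean for these weight shapes, using numpy as a stand-in for the real split function (the shapes are made up for the example):

import numpy as np

vocab_size, hidden_size, degree = 8, 4, 2

# embed_tokens.weight has shape (vocab_size, hidden_size); is_column=False splits rows,
# so each rank owns a contiguous slice of the vocabulary.
embed = np.zeros((vocab_size, hidden_size))
embed_shards = np.split(embed, degree, axis=0)      # two (4, 4) shards

# lm_head.weight has shape (hidden_size, vocab_size); is_column=True splits columns,
# so each rank produces logits for its slice of the vocabulary.
lm_head = np.zeros((hidden_size, vocab_size))
lm_head_shards = np.split(lm_head, degree, axis=1)  # two (4, 4) shards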