Commit 7c1aa91

precommit
1 parent 57459fd commit 7c1aa91

2 files changed (+15, -11 lines)


tools/ckpts/convert_hf_llama_to_neox.py

Lines changed: 1 addition & 5 deletions
@@ -62,11 +62,7 @@ def convert_model(hf_state_dict, hf_config, tp_ranks):
             # The GQA code simply expects concatenated q,k,v weights for each tp partition
             conv_state_dicts[i][
                 f"sequential.{layer_num+2}.attention.query_key_value.weight"
-            ] = (
-                torch.cat([q_chunk, k_chunk, v_chunk], dim=0)
-                .clone()
-                .detach()
-            )
+            ] = (torch.cat([q_chunk, k_chunk, v_chunk], dim=0).clone().detach())
             print(
                 f"model.layers.{layer_num}.self_attn.(q/k/v)_proj.weight",
                 hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"].shape,
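
For context, a minimal standalone sketch (made-up sizes, not the converter itself) of the layout this assignment builds: for each tensor-parallel rank, the per-rank q, k and v weight slices are concatenated along dim 0 into one fused query_key_value weight.

import torch

# Illustrative sizes only; q_chunk/k_chunk/v_chunk stand in for the per-rank
# slices of the HF q/k/v projection weights.
hidden_size = 8        # assumed model hidden size
kv_hidden_size = 4     # assumed k/v projection size (GQA has fewer kv heads)
tp_size = 2            # assumed tensor-parallel degree

q_chunk = torch.randn(hidden_size // tp_size, hidden_size)
k_chunk = torch.randn(kv_hidden_size // tp_size, hidden_size)
v_chunk = torch.randn(kv_hidden_size // tp_size, hidden_size)

# Fused per-rank QKV weight, concatenated q,k,v as the GQA code expects.
fused = torch.cat([q_chunk, k_chunk, v_chunk], dim=0).clone().detach()
print(fused.shape)  # [(hidden_size + 2*kv_hidden_size) / tp_size, hidden_size] -> [8, 8]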

tools/ckpts/convert_neox_to_hf.py

Lines changed: 14 additions & 6 deletions
@@ -371,12 +371,16 @@ def reshard_and_split_qkv(
     )
     # We should now have shape [TP_SIZE, (hidden_size + 2 * kv_hidden_size) / TP_SIZE, hidden_size].
     # At this point, for each TP rank, q, k, and v are concatenated
-
+
     # Next, we split tp_harded_qkv into q, k, v along dim 1
-    hidden_size_per_attention_head = hf_config.hidden_size // hf_config.num_attention_heads
-    kv_hidden_size = int(hidden_size_per_attention_head * hf_config.num_key_value_heads)
+    hidden_size_per_attention_head = (
+        hf_config.hidden_size // hf_config.num_attention_heads
+    )
+    kv_hidden_size = int(
+        hidden_size_per_attention_head * hf_config.num_key_value_heads
+    )
     tensor_parallel_size = len(loaded_tp_ranks)
-
+
     q, k, v = torch.split(
         tp_sharded_qkv,
         [
@@ -385,13 +389,17 @@ def reshard_and_split_qkv(
             kv_hidden_size // tensor_parallel_size,
         ],
         dim=1,
-    ) # New shapes:
+    )  # New shapes:
     # q-->[TP_SIZE, hidden_size/TP_SIZE, hidden_size]
     # k-->[TP_SIZE, kv_hidden_size/TP_SIZE, hidden_size]
     # v-->[TP_SIZE, kv_hidden_size/TP_SIZE, hidden_size]

     # Finally, we flatten the first two dimensions merging the TP partitions
-    q, k, v = q.reshape(-1, q.shape[2]), k.reshape(-1, k.shape[2]), v.reshape(-1, k.shape[2])
+    q, k, v = (
+        q.reshape(-1, q.shape[2]),
+        k.reshape(-1, k.shape[2]),
+        v.reshape(-1, k.shape[2]),
+    )

     # return these
     state_dict = {}
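
For context, a minimal standalone sketch (made-up sizes, not the converter itself) of the split-and-merge this section performs: the TP-sharded fused QKV tensor is split along dim 1 into q, k, v, and each piece is then flattened across the TP dimension to recover full projection matrices.

import torch

# Illustrative sizes only.
hidden_size = 8
kv_hidden_size = 4
tensor_parallel_size = 2

# Stand-in for tp_sharded_qkv:
# [TP_SIZE, (hidden_size + 2 * kv_hidden_size) / TP_SIZE, hidden_size]
tp_sharded_qkv = torch.randn(
    tensor_parallel_size,
    (hidden_size + 2 * kv_hidden_size) // tensor_parallel_size,
    hidden_size,
)

# Split each TP shard into its q, k, v slices along dim 1.
q, k, v = torch.split(
    tp_sharded_qkv,
    [
        hidden_size // tensor_parallel_size,
        kv_hidden_size // tensor_parallel_size,
        kv_hidden_size // tensor_parallel_size,
    ],
    dim=1,
)

# Flatten the first two dimensions, merging the TP partitions.
q, k, v = (
    q.reshape(-1, q.shape[2]),
    k.reshape(-1, k.shape[2]),
    v.reshape(-1, v.shape[2]),
)
print(q.shape, k.shape, v.shape)  # [8, 8], [4, 8], [4, 8]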
