@@ -371,12 +371,16 @@ def reshard_and_split_qkv(
     )
     # We should now have shape [TP_SIZE, (hidden_size + 2 * kv_hidden_size) / TP_SIZE, hidden_size].
     # At this point, for each TP rank, q, k, and v are concatenated
-
+
     # Next, we split tp_sharded_qkv into q, k, v along dim 1
-    hidden_size_per_attention_head = hf_config.hidden_size // hf_config.num_attention_heads
-    kv_hidden_size = int(hidden_size_per_attention_head * hf_config.num_key_value_heads)
+    hidden_size_per_attention_head = (
+        hf_config.hidden_size // hf_config.num_attention_heads
+    )
+    kv_hidden_size = int(
+        hidden_size_per_attention_head * hf_config.num_key_value_heads
+    )
     tensor_parallel_size = len(loaded_tp_ranks)
-
+
     q, k, v = torch.split(
         tp_sharded_qkv,
         [
@@ -385,13 +389,17 @@ def reshard_and_split_qkv(
             kv_hidden_size // tensor_parallel_size,
         ],
         dim=1,
-    ) # New shapes:
+    )  # New shapes:
     # q-->[TP_SIZE, hidden_size/TP_SIZE, hidden_size]
     # k-->[TP_SIZE, kv_hidden_size/TP_SIZE, hidden_size]
     # v-->[TP_SIZE, kv_hidden_size/TP_SIZE, hidden_size]

     # Finally, we flatten the first two dimensions merging the TP partitions
-    q, k, v = q.reshape(-1, q.shape[2]), k.reshape(-1, k.shape[2]), v.reshape(-1, k.shape[2])
+    q, k, v = (
+        q.reshape(-1, q.shape[2]),
+        k.reshape(-1, k.shape[2]),
+        v.reshape(-1, k.shape[2]),
+    )

     # return these
     state_dict = {}
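
To make the shapes concrete, here is a minimal standalone sketch of the same split-then-merge on toy dimensions. It is not part of the commit: the SimpleNamespace config, TP_SIZE, and the randomly initialized tp_sharded_qkv are stand-ins for the real hf_config and the loaded tensor-parallel shards.

import torch
from types import SimpleNamespace

# Stand-in config and TP degree (toy values, not from the repository)
hf_config = SimpleNamespace(hidden_size=8, num_attention_heads=4, num_key_value_heads=2)
TP_SIZE = 2

hidden_size_per_attention_head = hf_config.hidden_size // hf_config.num_attention_heads  # 2
kv_hidden_size = hidden_size_per_attention_head * hf_config.num_key_value_heads          # 4

# Fake fused QKV shards stacked along a leading TP dimension:
# [TP_SIZE, (hidden_size + 2 * kv_hidden_size) / TP_SIZE, hidden_size]
tp_sharded_qkv = torch.randn(
    TP_SIZE,
    (hf_config.hidden_size + 2 * kv_hidden_size) // TP_SIZE,
    hf_config.hidden_size,
)

# Split each rank's shard into its q, k, v rows along dim 1
q, k, v = torch.split(
    tp_sharded_qkv,
    [
        hf_config.hidden_size // TP_SIZE,
        kv_hidden_size // TP_SIZE,
        kv_hidden_size // TP_SIZE,
    ],
    dim=1,
)

# Flatten the leading TP dimension to merge the partitions
q, k, v = q.reshape(-1, q.shape[2]), k.reshape(-1, k.shape[2]), v.reshape(-1, v.shape[2])
print(q.shape, k.shape, v.shape)  # [8, 8], [4, 8], [4, 8]

The split sizes along dim 1 must sum to the per-rank fused dimension, which is why hidden_size and kv_hidden_size are each divided by the TP degree before the split.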