Commit 2c7d7b7

Implement norm head for Baichuan2
Parent: 68f178a

File tree

3 files changed: +64 -40 lines


flash_attn/models/baichuan.py

Lines changed: 4 additions & 0 deletions

@@ -116,6 +116,9 @@ def key_mapping_attn(key):
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
     # HACK: the config doesn't have say whether it's rotary or alibi.
     # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+    # HACK: the config doesn't have say whether it uses norm head.
+    # So we have to infer from the vocab size
+    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
     use_rotary = baichuan_config.hidden_size < 5000
     return GPT2Config(
         vocab_size=baichuan_config.vocab_size,
@@ -141,6 +144,7 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Con
         use_alibi=not use_rotary,
         use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
         tie_word_embeddings=False,
+        norm_head=baichuan_config.vocab_size > 70000,
         qkv_proj_bias=False,
         out_proj_bias=False,
         mlp_fc1_bias=False,
flash_attn/models/gpt.py

Lines changed: 16 additions & 5 deletions

@@ -32,7 +32,12 @@
     ParallelMLP,
 )
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (
+    all_gather,
+    all_gather_raw,
+    get_dim_for_local_rank,
+    sync_shared_params,
+)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained

@@ -355,9 +360,8 @@ def from_pretrained(
             state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith("facebook/opt"):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
-        elif (
-            model_name.startswith("EleutherAI/gpt-j-")
-            or model_name.startswith("togethercomputer/GPT-JT-")
+        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
+            "togethercomputer/GPT-JT-"
         ):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
         elif (
@@ -621,6 +625,7 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
             sequence_parallel=getattr(config, "sequence_parallel", True),
             **factory_kwargs,
         )
+        self.norm_head = getattr(config, "norm_head", False)
         # Initialize weights and apply final processing
         self.apply(
             partial(
@@ -662,7 +667,13 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
-        lm_logits = self.lm_head(hidden_states)
+        if not self.norm_head:
+            lm_logits = self.lm_head(hidden_states)
+        else:
+            lm_head_weight = F.normalize(self.lm_head.weight)
+            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
+                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
+            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
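
In the non-parallel case the new branch reduces to L2-normalizing each row of the output projection before the matmul, so logits depend on the direction of each vocabulary embedding rather than its magnitude. A standalone sketch of just that computation, with toy sizes and no tensor parallelism (the ColumnParallelLinear/all_gather handling in the diff only adds a gather of the sequence-parallel hidden states before the same matmul on the local weight shard):

    import torch
    import torch.nn.functional as F

    hidden_dim, vocab_size = 8, 16  # toy sizes
    lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
    hidden_states = torch.randn(2, 5, hidden_dim)  # (batch, seqlen, hidden)

    # F.normalize defaults to p=2, dim=1, i.e. each row of the (vocab, hidden) weight
    # is scaled to unit L2 norm before being used as the output projection.
    normed_weight = F.normalize(lm_head.weight)
    lm_logits = F.linear(hidden_states, normed_weight, bias=lm_head.bias)
    assert lm_logits.shape == (2, 5, vocab_size)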

tests/models/test_baichuan.py

Lines changed: 44 additions & 35 deletions

@@ -23,7 +23,15 @@
 from flash_attn.utils.generation import update_graph_cache


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -39,7 +47,15 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -66,9 +82,7 @@ def test_baichuan_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -89,7 +103,10 @@ def test_baichuan_optimized(model_name):
     del model_ref

     model_hf = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
+        model_name,
+        torch_dtype=dtype,
+        device_map={"": device},
+        trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -101,9 +118,7 @@ def test_baichuan_optimized(model_name):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 3 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -116,7 +131,15 @@ def test_baichuan_optimized(model_name):

 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -146,20 +169,14 @@ def test_baichuan_parallel_forward(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )

-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()

     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -198,9 +215,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 2 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -211,7 +226,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
+)
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
@@ -258,9 +275,7 @@ def test_baichuan_generation(model_name):
     )
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = (
-            model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
-        )
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref

     pretrained_state_dict = remap_state_dict_hf_baichuan(
@@ -370,12 +385,8 @@ def test_baichuan_parallel_generation(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )

-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()

     print("Without CUDA graph")
@@ -425,9 +436,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         output_scores=True,
     )
     torch.cuda.synchronize()
-    print(
-        f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms"
-    )
+    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
     del model_hf

     model_ref = AutoModelForCausalLM.from_pretrained(
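
Outside the test harness, loading a Baichuan2 checkpoint into the flash-attn model follows the same call sequence the updated tests exercise. A rough single-GPU sketch (no tensor parallelism; the import paths are assumed to match the ones the test file uses, and the norm_head flag is picked up from the converted config automatically):

    import torch
    from transformers import AutoConfig

    from flash_attn.models.baichuan import (
        baichuan_config_to_gpt2_config,
        remap_state_dict_hf_baichuan,
    )
    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.utils.pretrained import state_dict_from_pretrained

    model_name = "baichuan-inc/Baichuan2-7B-Base"
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
    model.load_state_dict(
        remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)
    )
    model.eval()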

0 commit comments
