 from flash_attn.utils.generation import update_graph_cache
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -39,7 +47,15 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -66,9 +82,7 @@ def test_baichuan_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -89,7 +103,10 @@ def test_baichuan_optimized(model_name):
     del model_ref
 
     model_hf = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
+        model_name,
+        torch_dtype=dtype,
+        device_map={"": device},
+        trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -101,9 +118,7 @@ def test_baichuan_optimized(model_name):
101118 print (f"Output mean diff: { (out - out_ref ).abs ().mean ().item ()} " )
102119 print (f"HF fp16 max diff: { (out_hf - out_ref ).abs ().max ().item ()} " )
103120 print (f"HF fp16 mean diff: { (out_hf - out_ref ).abs ().mean ().item ()} " )
104- assert (out - out_ref ).abs ().max ().item () < 3 * (
105- out_hf - out_ref
106- ).abs ().max ().item ()
121+ assert (out - out_ref ).abs ().max ().item () < 3 * (out_hf - out_ref ).abs ().max ().item ()
107122
108123 print (f"Logits max diff: { (logits - logits_ref ).abs ().max ().item ()} " )
109124 print (f"Logits mean diff: { (logits - logits_ref ).abs ().mean ().item ()} " )
@@ -116,7 +131,15 @@ def test_baichuan_optimized(model_name):
 
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -146,20 +169,14 @@ def test_baichuan_parallel_forward(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )
 
-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
 
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -198,9 +215,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
198215 print (f"Output mean diff: { (out - out_ref ).abs ().mean ().item ()} " )
199216 print (f"HF fp16 max diff: { (out_hf - out_ref ).abs ().max ().item ()} " )
200217 print (f"HF fp16 mean diff: { (out_hf - out_ref ).abs ().mean ().item ()} " )
201- assert (out - out_ref ).abs ().max ().item () < 2 * (
202- out_hf - out_ref
203- ).abs ().max ().item ()
218+ assert (out - out_ref ).abs ().max ().item () < 2 * (out_hf - out_ref ).abs ().max ().item ()
204219
205220 print (f"Logits max diff: { (logits - logits_ref ).abs ().max ().item ()} " )
206221 print (f"Logits mean diff: { (logits - logits_ref ).abs ().mean ().item ()} " )
@@ -211,7 +226,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
+)
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
@@ -258,9 +275,7 @@ def test_baichuan_generation(model_name):
     )
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = (
-            model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
-        )
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
 
     pretrained_state_dict = remap_state_dict_hf_baichuan(
@@ -370,12 +385,8 @@ def test_baichuan_parallel_generation(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )
 
-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
 
     print("Without CUDA graph")
@@ -425,9 +436,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         output_scores=True,
     )
     torch.cuda.synchronize()
-    print(
-        f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms"
-    )
+    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
     del model_hf
 
     model_ref = AutoModelForCausalLM.from_pretrained(