# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel

from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoModelForCausalLMPipe, AutoTokenizer


class TestLlama(unittest.TestCase):
    def test_sequence_model(self):
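        # Derive the parallel layout from the launched world size: pp_degree defaults
        # to the world size with tp_degree 1; with more than two ranks, fix pp_degree
        # at 2 and split the remaining ranks across tensor (model) parallelism.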
        world_size = paddle.distributed.get_world_size()
        pp_degree = world_size
        tp_degree = 1

        if world_size > 2:
            pp_degree = 2
            assert world_size % pp_degree == 0
            tp_degree = world_size // pp_degree

        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": tp_degree,
            "pp_degree": pp_degree,
            "sharding_degree": 1,
        }
        # strategy.pipeline_configs = {"enable_partial_send_recv": False if pp_degree > 1 else True}
        fleet.init(is_collective=True, strategy=strategy)
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        tensor_parallel_rank = mp_group.rank

        # With pipeline parallelism the model must be built as a PipelineLayer,
        # which AutoModelForCausalLMPipe provides.
        if pp_degree > 1:
            model_class = AutoModelForCausalLMPipe
        else:
            model_class = AutoModelForCausalLM

        model_name_or_path = "meta-llama/Llama-2-7b"

        seq_len = 2048
        batch_size = 2

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        config = AutoConfig.from_pretrained(model_name_or_path)
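        # Adapt the pretrained config for this test: fixed sequence length, GQA,
        # fused kernels, full recompute granularity (recompute itself disabled),
        # and the tensor-parallel layout derived above.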
        config.seq_length = seq_len
        config.num_key_value_heads = 8  # grouped-query attention (GQA)
        config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
        # pad the vocab size up to a multiple of 128
        config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
        config.use_flash_attention = True
        config.use_fused_rope = True
        config.use_fused_rms_norm = True
        config.fuse_attention_qkv = True
        config.recompute_granularity = "full"
        config.virtual_pp_degree = 1
        config.use_recompute = False

        config.tensor_parallel_degree = tp_degree
        config.tensor_parallel_rank = tensor_parallel_rank
        config.tensor_parallel_output = False
        config.sequence_parallel = False

        config.fuse_sequence_parallel_allreduce = False

        # hidden_size = 4096
        model = model_class.from_config(
            config,
            dtype="float16",
        )

        model.eval()

        # Deterministic synthetic inputs so every rank feeds identical data.
        input_ids = paddle.arange(100, 100 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
        labels = paddle.arange(101, 101 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])

        attention_mask = None
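        # Under pipeline parallelism the PipelineLayer model is wrapped in
        # PipelineParallel and driven through eval_batch; otherwise the model is
        # called directly and the loss is taken from the returned tuple.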
        if pp_degree > 1:
            pp_model = PipelineParallel(layers=model, hcg=hcg, strategy=strategy)
            pp_model.accumulate_steps = batch_size  # so that micro_batch_size * accumulate_steps == batch_size
            ret = pp_model.eval_batch(data=[input_ids, labels], compute_loss=True)
        else:
            ret = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
            ret = ret[0]

        print(f"ret mp{tp_degree} pp{pp_degree}", ret.item())
        ret_mp_pp = ret.item()


if __name__ == "__main__":
    TestLlama().test_sequence_model()