
Commit ce083f0

[intel_hpu] initial commit for intel_hpu support (#9273)
* add intel hpu device
* add wa for intel hpu

Co-authored-by: Zhai Feiyue <[email protected]>
1 parent 26b73c2 commit ce083f0

File tree

7 files changed, +300 -10 lines changed


llm/intel_hpu/llama/README.md

Lines changed: 104 additions & 0 deletions
## 🚣‍♂️ Running the llama2-7b model with PaddleNLP on Intel HPU 🚣

PaddleNLP has deeply adapted and optimized the llama2-7B model for Intel® Gaudi®2D ([learn about Gaudi](https://docs.habana.ai/en/latest/index.html)). Detailed installation steps are given below.

## 🚀 Quick Start 🚀

### (0) Before you start, you need an Intel Gaudi machine. The system requirements are as follows:

| Chip type | Card model | Driver version |
| --- | --- | --- |
| Gaudi | 225D | 1.17.0 |

### (1) Environment preparation (this will take you about 5~15 minutes)

1. Pull the image
```
# Note: this image is a development environment only; it does not contain a pre-built PaddlePaddle package
docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
```
2. Start the container with a command like the following
```
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
```
3. Install Paddle
```
# paddlepaddle is the PaddlePaddle deep learning framework; it provides the underlying compute capabilities
pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
```
4. Install PaddleCustomDevice (an optional verification sketch follows this list)
```
# PaddleCustomDevice is the custom-hardware integration layer of the PaddlePaddle framework; it provides the Intel HPU operator implementations.
git clone --recursive https://github.com/PaddlePaddle/PaddleCustomDevice
cd PaddleCustomDevice
git submodule sync
git submodule update --remote --init --recursive
cd backends/intel_hpu/
mkdir build && cd build
cmake ..
make -j8
pip install dist/paddle_intel_hpu-0.0.1-cp310-cp310-linux_x86_64.whl
```
5. Clone the PaddleNLP repository and install its dependencies
```
# PaddleNLP is a natural language processing and large language model (LLM) library built on PaddlePaddle; it contains a variety of large models implemented on the framework, including llama2-7B. To make better use of PaddleNLP, you need to clone the whole repository.
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP
python -m pip install -r requirements.txt
python -m pip install -e .
```
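
Once the steps above succeed, a quick check like the one below can confirm that the intel_hpu plugin is visible to Paddle. This is only a verification sketch, not part of the repository; it assumes the nightly paddlepaddle wheel and the paddle_intel_hpu wheel built in step 4 are installed, and that `paddle.device.get_all_custom_device_type()` is available in that build.

```python
# Optional sanity check (sketch): list registered custom device types and run a
# trivial op on the Intel HPU. Assumes the wheels installed in steps 3 and 4 above.
import paddle

print(paddle.device.get_all_custom_device_type())  # expected to include "intel_hpu"
paddle.set_device("intel_hpu")
x = paddle.ones([2, 2])
print(x + x)  # a trivial elementwise op executed on the HPU
```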

### (2) Inference (this will take you about 10~15 minutes)

1. Single-card inference

Run the following command to perform inference:
```bash
python inference_hpu.py
```

After a successful run, you can see the generated inference output; a sample follows:
```
[2024-10-25 02:42:42,220] [ INFO] - We are using <class 'paddlenlp.transformers.llama.tokenizer.LlamaTokenizer'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-25 02:42:42,427] [ INFO] - We are using <class 'paddlenlp.transformers.llama.modeling.LlamaForCausalLM'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-25 02:42:42,427] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/config.json
[2024-10-25 02:42:42,428] [ INFO] - Loading weights file from cache at /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/model_state.pdparams
[2024-10-25 02:43:32,871] [ INFO] - Loaded weights file from disk, setting weights to model.
[2024-10-25 02:44:15,226] [ INFO] - All model checkpoint weights were used when initializing LlamaForCausalLM.

[2024-10-25 02:44:15,226] [ INFO] - All the weights of LlamaForCausalLM were initialized from the model checkpoint at meta-llama/Llama-2-7b-chat.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
[2024-10-25 02:44:15,229] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/generation_config.json

['myself. I am a 35 year old woman from the United States. I am a writer and artist, and I have been living in Japan for the past 5 years. I am originally from the Midwest, but I have lived in several different places around the world, including California, New York, and now Japan.\nI am passionate about many things, including art, writing, music, and travel. I love to explore new places and cultures, and I am always looking for new inspiration for my art and writing. I am also a big fan of Japanese culture, and I try to learn as much']
```
2. Multi-card inference

Run the following command to perform inference:
```bash
bash test_llama_2x.sh
```
After a successful run, you can see the generated inference output; a sample follows:
```bash
[2024-10-29 11:24:39,468] [ INFO] - We are using <class 'paddlenlp.transformers.llama.tokenizer.LlamaTokenizer'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-29 11:24:40,705] [ INFO] distributed_strategy.py:214 - distributed strategy initialized
I1029 11:24:40.706755 14711 tcp_utils.cc:181] The server starts to listen on IP_ANY:59129
I1029 11:24:40.706897 14711 tcp_utils.cc:130] Successfully connected to 127.0.0.1:59129
[2024-10-29 11:24:42,740] [ INFO] topology.py:357 - Total 2 pipe comm group(s) create successfully!
[2024-10-29 11:24:52,064] [ INFO] topology.py:357 - Total 2 data comm group(s) create successfully!
[2024-10-29 11:24:52,064] [ INFO] topology.py:357 - Total 1 model comm group(s) create successfully!
[2024-10-29 11:24:52,065] [ INFO] topology.py:357 - Total 2 sharding comm group(s) create successfully!
[2024-10-29 11:24:52,065] [ INFO] topology.py:279 - HybridParallelInfo: rank_id: 0, mp_degree: 2, sharding_degree: 1, pp_degree: 1, dp_degree: 1, sep_degree: 1, mp_group: [0, 1], sharding_group: [0], pp_group: [0], dp_group: [0], sep:group: None, check/clip group: [0, 1]
[2024-10-29 11:24:52,067] [ INFO] - We are using <class 'paddlenlp.transformers.llama.modeling.LlamaForCausalLM'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-29 11:24:52,067] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/config.json
[2024-10-29 11:24:52,068] [ INFO] - Loading weights file from cache at /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/model_state.pdparams
[2024-10-29 11:25:43,202] [ INFO] - Starting to convert orignal state_dict to tensor parallel state_dict.
[2024-10-29 11:25:45,125] [ INFO] - Loaded weights file from disk, setting weights to model.
[2024-10-29 11:26:04,008] [ INFO] - All model checkpoint weights were used when initializing LlamaForCausalLM.
[2024-10-29 11:26:04,008] [ INFO] - All the weights of LlamaForCausalLM were initialized from the model checkpoint at meta-llama/Llama-2-7b-chat.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
[2024-10-29 11:26:04,010] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/generation_config.json

['myself\nHello everyone my name is [Your Name], and I am a new member of this community']
I1029 11:26:16.184163 14767 tcp_store.cc:293] receive shutdown event and so quit from MasterDaemon run loop
LAUNCH INFO 2024-10-29 11:26:17,186 Pod completed
LAUNCH INFO 2024-10-29 11:26:17,186 Exit code 0
```
Lines changed: 40 additions & 0 deletions
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# import os
# os.environ['ENABLE_EXPERIMENTAL_FLAGS'] = '1'
# os.environ['VISUALIZATION_MODE'] = '0'
# os.environ['GRAPH_VISUALIZATION'] = '1'
# os.environ["HABANA_LOGS"] = "logs"
# os.environ["LOG_LEVEL_ALL"] = "0"
# os.environ['GLOG_v'] = '10'


paddle.set_device("intel_hpu")
paddle.set_default_dtype("bfloat16")

model = "meta-llama/Llama-2-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model)
model = AutoModelForCausalLM.from_pretrained(model, dtype="bfloat16")

input_features = tokenizer("please introduce llm", return_tensors="pd")

with paddle.amp.auto_cast(dtype="bfloat16", custom_white_list={"elementwise_add", "rms_norm"}):
    outputs = model.generate(**input_features, max_length=128)

print(tokenizer.batch_decode(outputs[0]))

llm/intel_hpu/llama/test_llama.py

Lines changed: 48 additions & 0 deletions
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed import fleet

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

paddle.set_device("intel_hpu")
paddle.set_default_dtype("bfloat16")

model = "meta-llama/Llama-2-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model)
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 1,
    "sharding_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
tensor_parallel_rank = hcg.get_model_parallel_rank()

model = AutoModelForCausalLM.from_pretrained(
    model,
    tensor_parallel_degree=2,
    tensor_parallel_rank=tensor_parallel_rank,
    dtype="bfloat16",
)
input_features = tokenizer("please introduce llm", return_tensors="pd")


with paddle.amp.auto_cast(dtype="bfloat16", custom_white_list={"elementwise_add", "rms_norm"}):
    outputs = model.generate(**input_features, max_length=20)

print(tokenizer.batch_decode(outputs[0]))
Lines changed: 40 additions & 0 deletions
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -ex

# export LOG_LEVEL_ALL=0
export HABANA_LOGS=./logs

# export HCCL_COMM_ID=127.0.0.1:5555
# export INTEL_HPU_VISIBLE_DEVICES=0,1 # 3,4
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
# PYTHONPATH=../../:$PYTHONPATH \
export FLAGS_intel_hpu_runtime_debug=0

# export HABANA_PROFILE=1
# export HABANA_PROFILE_WRITE_HLTV_WITH_HOST=1

echo $INTEL_HPU_VISIBLE_DEVICES

# export GRAPH_VISUALIZATION=1
# export ENABLE_EXPERIMENTAL_FLAGS=1
# export VISUALIZATION_MODE=0

#GLOG_v=10
python -m paddle.distributed.launch --devices "3,5" test_llama.py 2>&1 | tee test_llama_2x.log

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 38 additions & 2 deletions
@@ -64,18 +64,32 @@ def fusion_rope(
     rotary_emb,
     context_parallel_degree=-1,
 ):
-    if get_env_device() != "gcu":
+    if get_env_device() not in ["gcu", "intel_hpu"]:
         assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
     if context_parallel_degree > 1:
         assert get_env_device() == "gpu", "context parallel only support cuda device for now"
         kv_seq_len *= context_parallel_degree
-    if get_env_device() != "gcu":
+    if get_env_device() not in ["gcu", "intel_hpu"]:
         cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
+    elif get_env_device() == "intel_hpu":
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-3]
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
+        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
+        query_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
+            paddle.transpose(query_states, [0, 2, 1, 3]), None, None, sin=sin, cos=cos, position_ids=position_ids
+        )
+        key_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
+            paddle.transpose(key_states, [0, 2, 1, 3]), None, None, sin=sin, cos=cos, position_ids=position_ids
+        )
+        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
+        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
     elif get_env_device() == "gcu":
         cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
         query_states, key_states = core.eager._run_custom_op(
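
(Note: the intel_hpu branch above transposes query/key from [batch, seq, heads, head_dim] to [batch, heads, seq, head_dim] before calling the fused op and transposes back afterwards. For orientation, a rough sketch of the rotary embedding being applied, in the usual rotate-half convention, is shown below; it is illustrative only and does not reproduce the intel_hpu kernel or the fused op's exact layout handling.)

```python
# Illustrative rotary-position-embedding reference (rotate-half convention).
# Sketch only; not the fused intel_hpu kernel.
import paddle


def rope_reference(x, cos, sin):
    # x: [batch, num_heads, seq_len, head_dim]; cos/sin broadcastable to x
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = paddle.concat([-x2, x1], axis=-1)
    return x * cos + rotated * sin
```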
@@ -132,6 +146,10 @@ def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
         return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "gcu":
         return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
+    elif get_env_device() == "intel_hpu":
+        return paddle.incubate.nn.functional.fused_rms_norm(
+            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
+        )[0]
     elif get_env_device() == "xpu":
         try:
             import paddle_xpu_nn  # noqa: F821
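
(For context, fused_rms_norm in the new intel_hpu branch is expected to compute the standard RMSNorm over the last axis, matching the unfused Llama path; the sketch below is illustrative only and is not the HPU implementation.)

```python
# Illustrative RMSNorm reference (sketch): scale by the reciprocal
# root-mean-square of the last axis, then multiply by the learned weight.
import paddle


def rms_norm_reference(hidden_states, weight, variance_epsilon):
    variance = hidden_states.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    normed = hidden_states.astype("float32") * paddle.rsqrt(variance + variance_epsilon)
    return normed.astype(hidden_states.dtype) * weight
```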
@@ -205,6 +223,24 @@ def fusion_flash_attention(
             attention_mask is None,
             True,
         )[0]
+    elif get_env_device() == "intel_hpu":
+        if config.context_parallel_degree > 1:
+            raise ValueError("Context parallel is not implemented for intel_hpu")
+        scaling_factor = query_states.shape[3] ** -0.5
+        attention_mask = attention_mask.astype(query_states.dtype)
+        attn_output = paddle.incubate.nn.functional.fused_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling_factor,
+            0.0,
+            False,
+            attention_mask is None,
+            None,
+            False,
+        )
+        attn_output = paddle.transpose(attn_output, [0, 2, 1, 3])
     else:
         if config.context_parallel_degree > 1:
             attn_output = RingFlashAttention.apply(
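
(For context, the fused_dot_product_attention call added for intel_hpu stands in for ordinary scaled-dot-product attention; the sketch below shows the unfused computation it replaces in spirit. It is illustrative only; dropout, causal-masking options, and the fused op's exact argument semantics and layouts are not reproduced.)

```python
# Illustrative scaled-dot-product attention reference (sketch, not the fused op).
import paddle
import paddle.nn.functional as F


def sdpa_reference(query, key, value, attention_mask, scaling_factor):
    # query/key/value: [batch, num_heads, seq_len, head_dim]; mask broadcastable to the scores
    scores = paddle.matmul(query, key, transpose_y=True) * scaling_factor
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = F.softmax(scores, axis=-1)
    return paddle.matmul(probs, value)
```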

0 commit comments
