Commit 5364d5e

[DCU] high performance LLM train and inference for DCU (#8580)
* [DCU] high performance LLM train and inference for DCU
1 parent a2b8a78 commit 5364d5e

12 files changed: 249 additions and 7 deletions

csrc/generation/get_padding_offset_v2.cu

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/extension.h"
 
 __global__ void RemovePaddingV2(int64_t *output_data,

csrc/generation/helper.h

Lines changed: 37 additions & 0 deletions
@@ -15,12 +15,44 @@
 #pragma once
 
 #include "paddle/extension.h"
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#include <hipcub/hipcub.hpp>
+#include <hiprand.h>
+#include <hiprand_kernel.h>
+namespace cub = hipcub;
+#else
 #include <cub/cub.cuh>
 #include <curand_kernel.h>
+#endif
 
 constexpr int kBlockSize = 256;
 constexpr int kNumWaves = 16;
 
+#ifdef PADDLE_WITH_HIP
+inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) {
+  int dev;
+  {
+    hipError_t err = hipGetDevice(&dev);
+    if (err != hipSuccess) { return err; }
+  }
+  int sm_count;
+  {
+    hipError_t err = hipDeviceGetAttribute(&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
+    if (err != hipSuccess) { return err; }
+  }
+  int tpm;
+  {
+    hipError_t err = hipDeviceGetAttribute(&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
+    if (err != hipSuccess) { return err; }
+  }
+  *num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
+                                                   sm_count * tpm / kBlockSize * kNumWaves));
+  return hipSuccess;
+}
+#else
 inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
   int dev;
   {
@@ -41,6 +73,7 @@ inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
                                                    sm_count * tpm / kBlockSize * kNumWaves));
   return cudaSuccess;
 }
+#endif
 
 template<typename T>
 __device__ T max_func(const T a, const T b) {
@@ -74,7 +107,11 @@ class PDTraits<paddle::DataType::FLOAT16> {
 template <>
 class PDTraits<paddle::DataType::BFLOAT16> {
 public:
+#ifdef PADDLE_WITH_HIP
+  typedef hip_bfloat16 DataType;
+#else
   typedef __nv_bfloat16 DataType;
+#endif
   typedef paddle::bfloat16 data_t;
 };
 
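
Note (not part of this commit): GetNumBlocks caps a 1-D launch at roughly kNumWaves waves of kBlockSize-thread blocks across the device's compute units, with the HIP branch querying the same device attributes through hipDeviceGetAttribute. A minimal usage sketch under that reading follows; the Scale kernel and LaunchScale wrapper are hypothetical and only illustrate the call pattern (helper.h included, error handling elided).

__global__ void Scale(float* x, float alpha, int64_t n) {
  // grid-stride loop: the capped grid still covers all n elements
  for (int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x; i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    x[i] *= alpha;
  }
}

void LaunchScale(float* x, float alpha, int64_t n) {
  int num_blocks = 1;
  GetNumBlocks(n, &num_blocks);  // hipError_t on DCU builds, cudaError_t otherwise
  Scale<<<num_blocks, kBlockSize>>>(x, alpha, n);
}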

csrc/generation/quant_int8.cu

Lines changed: 17 additions & 1 deletion
@@ -22,8 +22,13 @@
 #include<sys/mman.h>
 #include<stdio.h>
 #include<algorithm>
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#else
 #include<cuda_fp16.h>
 #include<cuda_bf16.h>
+#endif
 
 
 constexpr int DequantKernelVecSize = 4;
@@ -52,11 +57,17 @@ __forceinline__ __device__ half add_mul<half>(half a, half b, half c) {
   return __hmul(__hadd(a, b), c);
 }
 
+#ifdef PADDLE_WITH_HIP
+template<>
+__forceinline__ __device__ hip_bfloat16 add_mul<hip_bfloat16>(hip_bfloat16 a, hip_bfloat16 b, hip_bfloat16 c) {
+  return (a + b) * c;
+}
+#else
 template<>
 __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
   return __hmul(__hadd(a, b), c);
 }
-
+#endif
 
 
 template <typename data_t>
@@ -173,8 +184,13 @@ std::vector<paddle::Tensor> LaunchQuantInt8(const paddle::Tensor& input,
   auto output=paddle::full(input_shape, -1, paddle::DataType::INT8, input.place());
   int m = input_shape[0];
   int n = input_shape[1];
+#ifdef PADDLE_WITH_HIP
+  dim3 grid(((n >> 2) + 63) / 64, (m + 7) / 8);
+  dim3 block(64, 8);
+#else
   dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
   dim3 block(32, 32);
+#endif
   auto stream = input.stream();
   if (shift && smooth) {
     QuantKernel<DataType_><<<grid, block, 0, stream>>>(reinterpret_cast<const DataType_*>(input.data<data_t>()),
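
Note (not part of this commit): both launch-configuration branches are the same ceil-divide of the problem size by the block shape. The HIP branch picks block = (64, 8), that is 512 threads and a multiple of the 64-lane wavefront typical of DCU hardware, instead of (32, 32) = 1024 threads, and it also parenthesizes (n >> 2) so the shift happens before the rounding add. A sketch of the shared arithmetic, assuming each thread quantizes 4 columns (the n >> 2):

static inline int ceil_div(int a, int b) { return (a + b - 1) / b; }
// grid.x = ceil_div(n / 4, block.x);  // columns, 4 per thread (assumption)
// grid.y = ceil_div(m, block.y);      // rows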

csrc/generation/rebuild_padding.cu

Lines changed: 6 additions & 1 deletion
@@ -58,7 +58,12 @@ void InvokeRebuildPadding(T *output_data,
                           const int *padding_offset,
                           const int token_num,
                           const int dim_embed,
-                          cudaStream_t stream) {
+#ifdef PADDLE_WITH_HIP
+                          hipStream_t stream
+#else
+                          cudaStream_t stream
+#endif
+                          ) {
   // src: [token_num, dim_embed]
   // dst: [batch_size * max_seq_len, dim_embed]
   RebuildPaddingKernel<<<token_num, 256, 0, stream>>>(
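
Note (not part of this commit): the commit switches the stream parameter's type per backend at each signature. An equivalent, purely hypothetical refactor would define one alias once and reuse it everywhere:

#ifdef PADDLE_WITH_HIP
using gpuStream_t = hipStream_t;   // hypothetical alias, not introduced by this commit
#else
using gpuStream_t = cudaStream_t;
#endif
// ...then: void InvokeRebuildPadding(..., gpuStream_t stream) { ... }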

csrc/generation/rebuild_padding_v2.cu

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 template <typename T, int VecSize>

csrc/generation/set_value_by_flags_v2.cu

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/extension.h"
 
 __global__ void set_value_by_flag_and_id_v2(const bool *stop_flags,

csrc/generation/step.cu

Lines changed: 22 additions & 0 deletions
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 // #define DEBUG_STEP
@@ -255,7 +269,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
                 max_decoder_block_num
   );
 #ifdef DEBUG_STEP
+#ifdef PADDLE_WITH_HIP
+  hipDeviceSynchronize();
+#else
   cudaDeviceSynchronize();
+#endif
 #endif
   auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
   const int grid_size = cpu_recover_lens.data<int>()[0];
@@ -287,7 +305,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
                 first_token_id
   );
 #ifdef DEBUG_STEP
+#ifdef PADDLE_WITH_HIP
+  hipDeviceSynchronize();
+#else
   cudaDeviceSynchronize();
+#endif
 #endif
   }
 }
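
Note (not part of this commit): the DEBUG_STEP synchronization now nests a backend check at both call sites; a small hypothetical helper would keep those sites identical on CUDA and HIP:

#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
#define DEBUG_SYNC() hipDeviceSynchronize()
#else
#define DEBUG_SYNC() cudaDeviceSynchronize()
#endif
#else
#define DEBUG_SYNC() ((void)0)
#endif
// each debug point then reduces to: DEBUG_SYNC();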

csrc/generation/stop_generation_multi_ends_v2.cu

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/extension.h"
 #include<stdlib.h>
 #include<string.h>

csrc/generation/token_penalty_multi_scores_v2.cu

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 
csrc/generation/transpose_removing_padding.cu

Lines changed: 6 additions & 1 deletion
@@ -65,7 +65,12 @@ void InvokeTransposeRemovePadding(const T* input_data,
                                   const int head_dim,
                                   const int token_num,
                                   const int* padding_offset,
-                                  cudaStream_t cu_stream) {
+#ifdef PADDLE_WITH_HIP
+                                  hipStream_t cu_stream
+#else
+                                  cudaStream_t cu_stream
+#endif
+                                  ) {
   // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head,
   // head_dim]
   constexpr int VEC_16B = 16;
