
Commit 475a2ef

Merge pull request #1 from ROCm/greg/fp8_tests
Greg/fp8 tests
2 parents 644b165 + 926e2b8 commit 475a2ef

11 files changed: +117 -35 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -181,6 +181,7 @@ _build/
 # hip files generated by PyTorch
 *.hip
 *_hip*
+hip_compat.h
 
 # Benchmark dataset
 *.json

csrc/attention/attention_dtypes.h

Lines changed: 1 addition & 1 deletion

@@ -4,4 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
-#include "dtype_fp8_e5m2.cuh"
+#include "dtype_fp8.cuh"

csrc/attention/dtype_fp8_e5m2.cuh renamed to csrc/attention/dtype_fp8.cuh

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 #endif
 
 namespace vllm {
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
 // fp8 vector types for quantization of kv cache
 
 template<>

csrc/cache.h

Lines changed: 1 addition & 1 deletion

@@ -31,6 +31,6 @@ void gather_cached_kv(
   torch::Tensor& slot_mapping);
 
 // Just for unittest
-void convert_fp8_e5m2(
+void convert_fp8(
   torch::Tensor& src_cache,
   torch::Tensor& dst_cache);

csrc/cache_kernels.cu

Lines changed: 35 additions & 15 deletions

@@ -4,8 +4,10 @@
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2)
 #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#elif defined(ENABLE_FP8_E4M3)
+#include "quantization/fp8/amd_detail/quant_utils.cuh"
 #endif
 
 #include <algorithm>
@@ -196,9 +198,12 @@ __global__ void reshape_and_cache_kernel(
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
     if constexpr (is_fp8_e5m2_kv_cache) {
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2)
       key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
       value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#elif defined(ENABLE_FP8_E4M3)
+      key_cache[tgt_key_idx] = fp8_e4m3::vec_conversion<uint8_t, scalar_t>(tgt_key);
+      value_cache[tgt_value_idx] = fp8_e4m3::vec_conversion<uint8_t, scalar_t>(tgt_value);
 #else
       assert(false);
 #endif
@@ -431,15 +436,17 @@ void gather_cached_kv(
 namespace vllm {
 
 template<typename Tout, typename Tin>
-__global__ void convert_fp8_e5m2_kernel(
+__global__ void convert_fp8_kernel(
   const Tin* __restrict__ src_cache,
   Tout* __restrict__ dst_cache,
   const int64_t block_stride) {
   const int64_t block_idx = blockIdx.x;
   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
     int64_t idx = block_idx * block_stride + i;
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2)
     dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#elif defined(ENABLE_FP8_E4M3)
+    dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
 #else
     assert(false);
 #endif
@@ -448,16 +455,29 @@ __global__ void convert_fp8_e5m2_kernel(
 
 } // namespace vllm
 
-#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
-  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
-    reinterpret_cast<Tin*>(src_cache.data_ptr()), \
-    reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
+#define CALL_CONVERT_FP8(Tout, Tin) \
+  vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
+    reinterpret_cast<Tin*>(src_cache.data_ptr()), \
+    reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
     block_stride);
 
-void convert_fp8_e5m2(
+void convert_fp8(
   torch::Tensor& src_cache,
   torch::Tensor& dst_cache)
 {
+  torch::Device src_device = src_cache.device();
+  torch::Device dst_device = dst_cache.device();
+  if (src_device.is_cuda() && dst_device.is_cuda()) {
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
+  }
+  at::cuda::OptionalCUDAGuard device_guard;
+  if (src_device.is_cuda()) {
+    device_guard.set_device(src_device);
+  } else if (dst_device.is_cuda()) {
+    device_guard.set_device(dst_device);
+  }
   int64_t num_blocks = src_cache.size(0);
   int64_t block_stride = src_cache.stride(0);
 
@@ -466,16 +486,16 @@ void convert_fp8_e5m2(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   if (src_cache.dtype() == at::ScalarType::Float) {
-    CALL_CONVERT_FP8_E5M2(uint8_t, float);
+    CALL_CONVERT_FP8(uint8_t, float);
   } else if (src_cache.dtype() == at::ScalarType::Half) {
-    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+    CALL_CONVERT_FP8(uint8_t, uint16_t);
   } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
-    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+    CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
   } else if (dst_cache.dtype() == at::ScalarType::Float) {
-    CALL_CONVERT_FP8_E5M2(float, uint8_t);
+    CALL_CONVERT_FP8(float, uint8_t);
   } else if (dst_cache.dtype() == at::ScalarType::Half) {
-    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+    CALL_CONVERT_FP8(uint16_t, uint8_t);
   } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
-    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+    CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
   }
 }

csrc/pybind.cpp

Lines changed: 2 additions & 2 deletions

@@ -80,8 +80,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &gather_cached_kv,
     "Gather key and value from the cache into contiguous QKV tensors");
   cache_ops.def(
-    "convert_fp8_e5m2",
-    &convert_fp8_e5m2,
+    "convert_fp8",
+    &convert_fp8,
     "Convert the key and value cache to fp8_e5m2 data type");
 
   // Cuda utils
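
Not part of the diff, but as a usage sketch of the renamed binding: the Python-visible name changes from cache_ops.convert_fp8_e5m2 to cache_ops.convert_fp8, and the direction of the conversion is selected by the dtypes of the two tensors (see the CALL_CONVERT_FP8 dispatch above). The import path below is assumed from how the kernel tests call it and may differ between builds.

import torch
from vllm._C import cache_ops  # assumed import path for the compiled extension

x = torch.randn(16, 8, 64, 16, dtype=torch.float16, device="cuda")

# Quantize: a uint8 destination selects the scalar -> fp8 path.
x_fp8 = torch.empty_like(x, dtype=torch.uint8)
cache_ops.convert_fp8(x, x_fp8)

# Dequantize: a uint8 source selects the fp8 -> scalar path.
x_roundtrip = torch.empty_like(x)
cache_ops.convert_fp8(x_fp8, x_roundtrip)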

csrc/quantization/fp8/amd_detail/quant_utils.cuh

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,7 @@
 
 namespace vllm
 {
-
+namespace fp8_e4m3 {
 template <typename Tout, typename Tin>
 __inline__ __device__ Tout vec_conversion(const Tin& x)
 {
@@ -290,4 +290,5 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_&
   b.w = __float22bfloat162_rn(a.w);
   return b;
 }
+}
 } // namespace vllm

setup.py

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,7 @@
 
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx942", "gfx1030", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
 
 
@@ -296,6 +296,7 @@ def get_torch_arch_list() -> Set[str]:
                 f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
                 f"amdgpu_arch_found: {arch}")
         NVCC_FLAGS += [f"--offload-arch={arch}"]
+        NVCC_FLAGS += ["-DENABLE_FP8_E4M3"]
 
 elif _is_neuron():
     neuronxcc_version = get_neuronxcc_version()
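
For context (not part of the diff), the two setup.py changes work together: gfx942 (MI300-series) is added to the accepted ROCm architectures, and ROCm builds now define ENABLE_FP8_E4M3, which routes the kv-cache code through csrc/quantization/fp8/amd_detail/quant_utils.cuh. A minimal sketch of that flag handling, with the surrounding structure assumed and only the names visible in the diff taken as given:

ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx942",
                        "gfx1030", "gfx1100"}
NVCC_FLAGS = []

def add_rocm_arch_flags(arch: str) -> None:
    # Hypothetical helper mirroring the hunk above; not a function in setup.py.
    if arch not in ROCM_SUPPORTED_ARCHS:
        raise RuntimeError(
            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS} "
            f"amdgpu_arch_found: {arch}")
    NVCC_FLAGS.append(f"--offload-arch={arch}")
    # New in this commit: enable the E4M3 code paths for ROCm builds.
    NVCC_FLAGS.append("-DENABLE_FP8_E4M3")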

tests/kernels/test_attention.py

Lines changed: 2 additions & 2 deletions

@@ -230,14 +230,14 @@ def test_paged_attention(
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
                                             device=gpu_id)
-        cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
+        cache_ops.convert_fp8(key_cache, dequantized_key_cache)
         key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=gpu_id)
-        cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
+        cache_ops.convert_fp8(value_cache, dequantized_value_cache)
         value_cache = dequantized_value_cache
 
     ref_output = torch.empty_like(query)

tests/kernels/test_cache.py

Lines changed: 69 additions & 10 deletions

@@ -99,6 +99,7 @@ def test_copy_blocks(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @torch.inference_mode()
 def test_reshape_and_cache(
     kv_cache_factory,
@@ -110,6 +111,7 @@ def test_reshape_and_cache(
     dtype: torch.dtype,
     seed: int,
     device: int,
+    kv_cache_dtype: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -130,17 +132,29 @@ def test_reshape_and_cache(
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
-                                                num_heads, head_size, dtype,
-                                                None, seed, gpu_id)
+                                                num_heads, head_size, kv_cache_dtype,
+                                                dtype, seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Clone the KV caches.
-    cloned_key_cache = key_cache.clone()
-    cloned_value_cache = value_cache.clone()
+    if kv_cache_dtype == "fp8_e5m2":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        cache_ops.convert_fp8(key_cache, cloned_key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        cache_ops.convert_fp8(value_cache, cloned_value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
 
     # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping, "auto")
+                                slot_mapping, kv_cache_dtype)
+
+    if kv_cache_dtype == "fp8_e5m2":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        cache_ops.convert_fp8(key_cache, result_key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        cache_ops.convert_fp8(value_cache, result_value_cache)
 
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -153,9 +167,13 @@ def test_reshape_and_cache(
         block_offset = block_offsets[i]
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
-
-    assert torch.allclose(key_cache, cloned_key_cache)
-    assert torch.allclose(value_cache, cloned_value_cache)
+
+    if kv_cache_dtype == "fp8_e5m2":
+        assert torch.allclose(result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1)
+        assert torch.allclose(result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1)
+    else:
+        assert torch.allclose(key_cache, cloned_key_cache)
+        assert torch.allclose(value_cache, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
@@ -167,6 +185,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @torch.inference_mode()
 def test_swap_blocks(
     kv_cache_factory,
@@ -179,7 +198,10 @@ def test_swap_blocks(
     dtype: torch.dtype,
     seed: int,
     device: int,
+    kv_cache_dtype: str,
 ) -> None:
+    if kv_cache_dtype == "fp8_e5m2" and "cpu" in direction:
+        return
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -200,12 +222,12 @@ def test_swap_blocks(
 
     # Create the KV caches on the first device.
     src_key_caches, src_value_caches = kv_cache_factory(
-        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
+        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, seed,
         src_device)
 
     # Create the KV caches on the second device.
     dist_key_caches, dist_value_caches = kv_cache_factory(
-        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
+        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, seed,
         dst_device)
 
     src_key_caches_clone = src_key_caches[0].clone()
@@ -221,3 +243,40 @@ def test_swap_blocks(
                           dist_key_caches[0][dst].cpu())
     assert torch.allclose(src_value_caches_clone[src].cpu(),
                           dist_value_caches[0][dst].cpu())
+
+
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_fp8_conversion(
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
+
+    low = -240.0
+    high = 240.0
+    shape = (num_blocks, num_heads, head_size, block_size)
+    cache = torch.empty(shape, dtype=dtype, device=gpu_id)
+    cache.uniform_(low, high)
+
+    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
+    cache_ops.convert_fp8(cache, cache_fp8)
+
+    converted_cache = torch.empty_like(cache)
+    cache_ops.convert_fp8(cache_fp8, converted_cache)
+
+    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
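
Not part of the diff: a small helper distilling the comparison pattern the fp8_e5m2 branch of test_reshape_and_cache uses. A quantized (uint8) cache cannot be compared directly against a float reference, so it is first dequantized with convert_fp8 and then compared with the loose tolerances from the test. The import path is assumed, as above.

import torch
from vllm._C import cache_ops  # assumed import path for the compiled extension

def assert_quantized_matches_reference(quantized: torch.Tensor,
                                        reference_f16: torch.Tensor) -> None:
    # Dequantize the fp8 cache to float16, then compare with the tolerances
    # used in test_reshape_and_cache (atol=0.001, rtol=0.1).
    dequantized = torch.empty_like(quantized, dtype=torch.float16)
    cache_ops.convert_fp8(quantized, dequantized)
    assert torch.allclose(dequantized, reference_f16, atol=0.001, rtol=0.1)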
