
Commit f08b44a

Upgrade to new vllm extension ops for Gaudi backend (fix issue in exponential bucketing) (#3239)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent 674c514 commit f08b44a

File tree

3 files changed: +10 additions, −13 deletions

Dockerfile_gaudi

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
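
This hunk swaps the pinned vllm-hpu-extension from the HabanaAI commit a060794 to the bmax_fix branch of the sywangyi fork, which carries the updated extension ops. A minimal, purely illustrative sanity check one could run inside the built image (not an official verification step; it only exercises the import path that hpu.py relies on):

# Hypothetical post-build check: confirm the pinned vllm-hpu-extension fork is importable
# and still exposes the fused-SDPA wrapper imported by hpu.py below.
from vllm_hpu_extension.utils import ModuleFusedSDPA

print("vllm-hpu-extension OK:", ModuleFusedSDPA.__module__)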

backends/gaudi/server/text_generation_server/layers/attention/hpu.py

Lines changed: 3 additions & 0 deletions

@@ -7,6 +7,7 @@
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
+from text_generation_server.models.globals import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@@ -126,6 +127,7 @@ def paged_attention(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),

@@ -160,6 +162,7 @@ def paged_attention_mla(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
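
Both paged_attention and paged_attention_mla now pass block_size=BLOCK_SIZE explicitly to the extension ops. With the KV cache stored as a flat (num_blocks * BLOCK_SIZE, ...) tensor (see kv_cache.py below), the op needs the block size to translate between flat slot indices and (block, offset) coordinates. A rough sketch of that mapping, using made-up block ids and offsets for illustration:

import torch

BLOCK_SIZE = 128  # assumed value; the real constant comes from text_generation_server.models.globals

# Hypothetical example: the block each token lives in, and its offset inside that block.
block_ids = torch.tensor([0, 0, 3, 3, 3])
offsets = torch.tensor([0, 1, 0, 1, 2])

# In a flat cache of shape (num_blocks * BLOCK_SIZE, num_heads, head_size),
# a token's row is block_id * BLOCK_SIZE + offset, and the reverse split uses // and %.
slots = block_ids * BLOCK_SIZE + offsets
assert torch.equal(slots // BLOCK_SIZE, block_ids)
assert torch.equal(slots % BLOCK_SIZE, offsets)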

backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py

Lines changed: 6 additions & 12 deletions

@@ -5,7 +5,6 @@

 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
-from vllm_hpu_extension import cache_ops


 @dataclass

@@ -55,12 +54,12 @@ def __init__(

         self.kv_cache = (
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),

@@ -129,7 +128,7 @@ def __init__(
             raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

         self.kv_cache = torch.zeros(
-            (num_blocks, BLOCK_SIZE, 1, head_size),
+            (num_blocks * BLOCK_SIZE, 1, head_size),
             dtype=dtype,
             device=device,
         )
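
The cache tensors lose their explicit block dimension: (num_blocks, BLOCK_SIZE, num_heads, head_size) becomes (num_blocks * BLOCK_SIZE, num_heads, head_size), so a flat slot index addresses a cache row directly. A small sketch of the equivalence, with illustrative sizes only:

import torch

num_blocks, BLOCK_SIZE, num_heads, head_size = 4, 8, 2, 16  # illustrative sizes, not the real defaults

flat = torch.zeros((num_blocks * BLOCK_SIZE, num_heads, head_size))

# The old blocked layout is just a reshaped view of the same storage, so capacity and
# data placement are unchanged; only the addressing scheme differs.
blocked = flat.view(num_blocks, BLOCK_SIZE, num_heads, head_size)
assert blocked.data_ptr() == flat.data_ptr()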
@@ -161,14 +160,11 @@ def store(
     ):
         """Store the key and value at the given slots."""
         ## TODO FP8 kv cache support
-
-        block_idx = slots // BLOCK_SIZE
-        block_offset = slots % BLOCK_SIZE
         if self.kv_cache.dtype == torch.float8_e4m3fn:
             key = torch.ops.hpu.cast_to_fp8_v2(
                 key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
             )[0]
-        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+        self.kv_cache.index_copy_(0, slots, key)


 def paged_reshape_and_cache(

@@ -180,17 +176,15 @@ def paged_reshape_and_cache(
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
 ):
-    block_idx = slots // BLOCK_SIZE
-    block_offset = slots % BLOCK_SIZE
     if key_cache.dtype == torch.float8_e4m3fn:
         key = torch.ops.hpu.cast_to_fp8_v2(
             key, k_scale, False, False, torch.float8_e4m3fn
         )[0]
         value = torch.ops.hpu.cast_to_fp8_v2(
             value, v_scale, False, False, torch.float8_e4m3fn
         )[0]
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    key_cache.index_copy_(0, slots, key)
+    value_cache.index_copy_(0, slots, value)


 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
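
store() and paged_reshape_and_cache() now write through torch.Tensor.index_copy_ on the flat slot indices instead of computing block_idx/block_offset and calling cache_ops.insert_or_update_cache, which is why the cache_ops import could be dropped. A quick check, with illustrative shapes, that the flat write hits the same rows the old blocked addressing would have targeted:

import torch

BLOCK_SIZE, num_blocks, num_heads, head_size = 4, 3, 2, 8  # illustrative sizes only

key_cache = torch.zeros((num_blocks * BLOCK_SIZE, num_heads, head_size))
slots = torch.tensor([1, 6, 11])  # flat slot per token, as passed to store()
key = torch.randn((slots.numel(), num_heads, head_size))

# New path: a single in-place copy into the flat cache.
key_cache.index_copy_(0, slots, key)

# The old (block_idx, block_offset) coordinates recover exactly the same rows.
block_idx, block_offset = slots // BLOCK_SIZE, slots % BLOCK_SIZE
blocked_view = key_cache.view(num_blocks, BLOCK_SIZE, num_heads, head_size)
assert torch.equal(blocked_view[block_idx, block_offset], key)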
