5
5
6
6
from text_generation_server .models .globals import BLOCK_SIZE
7
7
from text_generation_server .utils .weights import Weights
8
- from vllm_hpu_extension import cache_ops
9
8
10
9
11
10
@dataclass
@@ -55,12 +54,12 @@ def __init__(
55
54
56
55
self .kv_cache = (
57
56
torch .zeros (
58
- (num_blocks , BLOCK_SIZE , num_heads , head_size ),
57
+ (num_blocks * BLOCK_SIZE , num_heads , head_size ),
59
58
dtype = dtype ,
60
59
device = device ,
61
60
),
62
61
torch .zeros (
63
- (num_blocks , BLOCK_SIZE , num_heads , head_size ),
62
+ (num_blocks * BLOCK_SIZE , num_heads , head_size ),
64
63
dtype = dtype ,
65
64
device = device ,
66
65
),
@@ -129,7 +128,7 @@ def __init__(
129
128
raise ValueError ("torch.float8_e5m2 is not supported in hpu. " )
130
129
131
130
self .kv_cache = torch .zeros (
132
- (num_blocks , BLOCK_SIZE , 1 , head_size ),
131
+ (num_blocks * BLOCK_SIZE , 1 , head_size ),
133
132
dtype = dtype ,
134
133
device = device ,
135
134
)
@@ -161,14 +160,11 @@ def store(
161
160
):
162
161
"""Store the key and value at the given slots."""
163
162
## TODO FP8 kv cache support
164
-
165
- block_idx = slots // BLOCK_SIZE
166
- block_offset = slots % BLOCK_SIZE
167
163
if self .kv_cache .dtype == torch .float8_e4m3fn :
168
164
key = torch .ops .hpu .cast_to_fp8_v2 (
169
165
key , kv_scales .key_scale , False , False , torch .float8_e4m3fn
170
166
)[0 ]
171
- cache_ops . insert_or_update_cache ( key , self .kv_cache , block_idx , block_offset )
167
+ self .kv_cache . index_copy_ ( 0 , slots , key )
172
168
173
169
174
170
def paged_reshape_and_cache (
@@ -180,17 +176,15 @@ def paged_reshape_and_cache(
180
176
k_scale : torch .Tensor ,
181
177
v_scale : torch .Tensor ,
182
178
):
183
- block_idx = slots // BLOCK_SIZE
184
- block_offset = slots % BLOCK_SIZE
185
179
if key_cache .dtype == torch .float8_e4m3fn :
186
180
key = torch .ops .hpu .cast_to_fp8_v2 (
187
181
key , k_scale , False , False , torch .float8_e4m3fn
188
182
)[0 ]
189
183
value = torch .ops .hpu .cast_to_fp8_v2 (
190
184
value , v_scale , False , False , torch .float8_e4m3fn
191
185
)[0 ]
192
- cache_ops . insert_or_update_cache ( key , key_cache , block_idx , block_offset )
193
- cache_ops . insert_or_update_cache ( value , value_cache , block_idx , block_offset )
186
+ key_cache . index_copy_ ( 0 , slots , key )
187
+ value_cache . index_copy_ ( 0 , slots , value )
194
188
195
189
196
190
def get_kv_scales (weights : Weights , prefix : str ) -> KVScales :
0 commit comments