@@ -634,21 +634,25 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         cache_lengths_tensor = self.cache_lengths_tensor[indices]

         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)
-
-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+        if self.adapter_meta is not None:
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            adapter_segments, adapter_segment_indices = find_segments(
+                adapter_indices
+            )
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
+        else:
+            adapter_meta = None
         htorch.core.mark_step()
         return type(self)(
             batch_id=self.batch_id,
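
The filter path above now rebuilds adapter metadata only when the batch carries any. The segment computation it defers to follows the usual find_segments contract: collapse a per-token adapter-index tensor into contiguous runs. A minimal sketch of the assumed semantics (not the repository's implementation):

import torch

def find_segments_sketch(adapter_indices: torch.Tensor):
    # Collapse per-token adapter indices into contiguous runs, e.g.
    # [0, 0, 1, 1, 1, 2] -> segments [0, 2, 5, 6], run ids [0, 1, 2].
    values = adapter_indices.tolist()
    segments, segment_indices = [0], []
    for i in range(1, len(values)):
        if values[i] != values[i - 1]:
            segments.append(i)
            segment_indices.append(values[i - 1])
    if values:
        segments.append(len(values))
        segment_indices.append(values[-1])
    return segments, segment_indices
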
@@ -710,6 +714,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         max_length = 0
         max_input_length = 0
         max_current_length = 0
+        ADAPTER_TO_INDEX = get_adapter_to_index()
         for b in batches:
             total_batch_size += len(b)
             max_blocks = max(max_blocks, b.max_blocks)
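
ADAPTER_TO_INDEX is fetched once at the top of concatenate and then used as a feature gate: an empty mapping disables every adapter-related branch below. A small sketch of the assumed contract (the _REGISTRY here is a hypothetical stand-in for whatever registry the real get_adapter_to_index reads):

from typing import Dict

# Hypothetical stand-in for the adapter registry populated at load time.
_REGISTRY: Dict[str, int] = {}

def get_adapter_to_index() -> Dict[str, int]:
    return _REGISTRY

ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX:  # an empty mapping is falsy: no LoRA bookkeeping
    print("adapter bookkeeping enabled")
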
@@ -763,14 +768,15 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
             total_batch_size
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_segment_builder = SegmentConcatBuilder()
-        adapter_set = set()
+        if ADAPTER_TO_INDEX:
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()

         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
@@ -821,9 +827,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + valid_bsize

-            index = torch.tensor(
-                list(range(start_index, end_index)), device=batch.input_ids.device
-            )
+            index = torch.tensor(list(range(start_index, end_index)), device="cpu")
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -847,7 +851,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             )

             if not prefilling:
-                input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
+                input_ids.index_copy_(
+                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
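
Building the gather index on CPU works because index_copy_ needs the index tensor on the destination's device, and most destinations in this path are CPU-resident; the one that is not converts the index at the call site. A sketch of the pattern, assuming a generic destination tensor:

import torch

def copy_rows(dst: torch.Tensor, src: torch.Tensor, start: int) -> None:
    # Build the row index once on CPU; index_copy_ requires the index
    # on the destination's device, so convert only where that differs.
    index = torch.arange(start, start + src.shape[0], device="cpu")
    dst.index_copy_(0, index.to(dst.device), src)
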
@@ -858,20 +864,21 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
                 cache_lengths_tensor.index_copy_(
                     0, index, batch.cache_lengths_tensor[:valid_bsize]
                 )
-                adapter_start_index = cumulative_adapter_indices_size
-                adapter_end_index = (
-                    cumulative_adapter_indices_size
-                    + batch.adapter_meta.adapter_indices.shape[0]
-                )
-                adapter_indices[adapter_start_index:adapter_end_index] = (
-                    batch.adapter_meta.adapter_indices
-                )
-                cumulative_adapter_indices_size = adapter_end_index
-                adapter_set.update(batch.adapter_meta.adapter_set)
-                adapter_segment_builder.concat(
-                    batch.adapter_meta.adapter_segments,
-                    batch.adapter_meta.segment_indices,
-                )
+                if ADAPTER_TO_INDEX:
+                    adapter_start_index = cumulative_adapter_indices_size
+                    adapter_end_index = (
+                        cumulative_adapter_indices_size
+                        + batch.adapter_meta.adapter_indices.shape[0]
+                    )
+                    adapter_indices[adapter_start_index:adapter_end_index] = (
+                        batch.adapter_meta.adapter_indices
+                    )
+                    cumulative_adapter_indices_size = adapter_end_index
+                    adapter_set.update(batch.adapter_meta.adapter_set)
+                    adapter_segment_builder.concat(
+                        batch.adapter_meta.adapter_segments,
+                        batch.adapter_meta.segment_indices,
+                    )
             else:
                 if isinstance(batch.input_ids, torch.Tensor):
                     batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -914,7 +921,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         else:
             speculative_ids = None

-        if adapter_segment_builder is not None:
+        if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
@@ -961,7 +968,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=adapter_meta,
+            adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
             hpu_attn_meta=None,
             next_token_logits=None,
             speculative_logits=None,
@@ -1037,6 +1044,7 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
         # need extra pad to match warmup seq
         extra_pad = max_padded_input_len - self.max_input_length
         extra_pad_bs = max_padded_bs - len(self)
+        device = self.all_input_ids_tensor.device
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1047,12 +1055,12 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
             input_ids = [0] * extra_pad + input_ids
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
             self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
             input_ids_padded_length.extend([extra_pad] * len(self))
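
Passing device= at construction materializes the padded prompt directly where all_input_ids_tensor lives, instead of building a CPU tensor and moving it later. A toy example of the same call shape (the CPU device here merely stands in for all_input_ids_tensor.device):

import numpy as np
import torch

padded_chunks = [np.array([0, 0, 5, 6]), np.array([0, 7, 8, 9])]
device = torch.device("cpu")  # stands in for all_input_ids_tensor.device

# One host-side concatenation, then a single materialization on the
# target device, rather than torch.tensor(...) followed by .to(device).
input_ids = np.concatenate(padded_chunks, dtype=np.int64)
input_ids_t = torch.tensor(input_ids, dtype=torch.int64, device=device)
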
@@ -1245,7 +1253,9 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
         self.slot_indices = slot_indices

         self.prefill_cu_outlens = prefill_cu_outlens
-        self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
+        self.prefill_cache_indices = torch.zeros_like(
+            self.input_ids, dtype=torch.bool, device="cpu"
+        )
         self.prefill_cache_indices[prefill_cache_indices] = True

         if all_prefill_logprobs:
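
prefill_cache_indices is now pinned to CPU explicitly, since torch.zeros_like would otherwise inherit whatever device input_ids ended up on. A sketch of the mask-then-scatter shape it feeds, with made-up values:

import torch

input_ids = torch.zeros(8, dtype=torch.int64)    # padded prompt (CPU)
prefill_cache_indices = torch.tensor([2, 3, 5])  # positions to keep

# Boolean mask built explicitly on CPU, then set by advanced indexing;
# True marks the positions whose cache slots are written during prefill.
mask = torch.zeros_like(input_ids, dtype=torch.bool, device="cpu")
mask[prefill_cache_indices] = True
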
@@ -1301,21 +1311,24 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
             fsm_grammar_states,
         )

-        if adapter_set:
-            adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
-            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        else:
-            adapter_indices = torch.zeros_like(self.input_ids)
-            adapter_segments = [0, len(adapter_indices)]
-            adapter_segment_indices = [len(adapter_indices) - 1]
-
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        self.adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+        if ADAPTER_TO_INDEX:
+            if adapter_set:
+                adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
+                adapter_segments, adapter_segment_indices = find_segments(
+                    adapter_indices
+                )
+            else:
+                adapter_indices = torch.zeros_like(self.input_ids)
+                adapter_segments = [0, len(adapter_indices)]
+                adapter_segment_indices = [len(adapter_indices) - 1]
+
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            self.adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

    def __len__(self):
        return len(self.requests)
@@ -1941,11 +1954,11 @@ def forward(
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)
        if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
            slots_pad[batch.prefill_cache_indices] = slots
            slots = slots_pad
        else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
            slots_pad[: slots.shape[0]] = slots
            slots = slots_pad
        seqlen = Seqlen(
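
Both branches now pin slots_pad to slots' device because torch.zeros_like defaults to input_ids' device, which no longer matches once input_ids stays on the host until the model call. A minimal sketch of the else branch under that assumption:

import torch

def pad_slots(input_ids: torch.Tensor, slots: torch.Tensor) -> torch.Tensor:
    # zeros_like alone would allocate on input_ids' device; pinning the
    # pad buffer to slots' device keeps the in-place copy single-device.
    slots_pad = torch.zeros_like(input_ids, device=slots.device)
    slots_pad[: slots.shape[0]] = slots
    return slots_pad
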
@@ -1965,7 +1978,7 @@
        )

        logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
            position_ids=_async_h2d_tensor_copy(position_ids),
            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
            kv_cache=kv_cache,
@@ -2059,15 +2072,16 @@ def generate_token(
            batch.position_ids = batch.position_ids[indices]

            batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
-            batch.adapter_meta.adapter_indices = (
-                batch.adapter_meta.adapter_indices[indices]
-            )
+            if batch.adapter_meta is not None:
+                batch.adapter_meta.adapter_indices = (
+                    batch.adapter_meta.adapter_indices[indices]
+                )
        # For each member of the batch
        # Cumulative length
-        accepted_ids = accepted_ids.cpu()
-        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
-        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+
        if batch.speculative_logits is not None:
+            cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+            torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
            for i in range(len(batch)):
                batch.all_input_ids_tensor[
                    i,
@@ -2076,6 +2090,20 @@ def generate_token(
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+            accepted_ids = accepted_ids.cpu()
+            if batch.position_ids.dim() == 2:
+                # Qwen2_vl case:
+                batch.position_ids += accepted_ids.unsqueeze(-1)
+            else:
+                batch.position_ids += accepted_ids
+            batch.cache_lengths_tensor += (
+                batch.input_lengths_tensor + accepted_ids - 1
+            )
+            batch.input_lengths_tensor = torch.ones_like(
+                batch.input_lengths_tensor
+            )
+            batch.slot_indices += accepted_ids[: len(batch)]
        else:
            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
            index = index.to(batch.all_input_ids_tensor.device)
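
The speculative branch builds an exclusive prefix sum over per-sequence accepted counts, so next_input_ids[cu_accepted_ids[1:] - 1] picks each sequence's last accepted token as its next input. A worked example with made-up token ids:

import torch

# Two sequences: the first accepted 2 tokens, the second accepted 1.
accepted_ids = torch.tensor([2, 1])
next_input_ids = torch.tensor([11, 12, 21])  # flattened accepted tokens

cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
# cu_accepted_ids == [0, 2, 3]; sequence i's last accepted token sits
# at position cu_accepted_ids[i + 1] - 1 of the flattened buffer.
batch_input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
# -> tensor([12, 21])
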
@@ -2088,22 +2116,18 @@ def generate_token(
            batch.all_input_ids_tensor.index_put_(
                (batch_idx, index.long()), next_input_ids
            )
-            next_input_ids = next_input_ids.cpu()
-            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+            batch.input_ids = next_input_ids
+            batch.position_ids += 1
+            batch.cache_lengths_tensor += batch.input_lengths_tensor
+            batch.input_lengths_tensor = torch.ones_like(
+                batch.input_lengths_tensor
+            )
+            batch.slot_indices += 1
+
        batch.speculative_ids = speculative_ids
-        if batch.position_ids.dim() == 2:
-            # Qwen2_vl case:
-            batch.position_ids += accepted_ids.unsqueeze(-1)
-        else:
-            batch.position_ids += accepted_ids
-        batch.cache_lengths_tensor += (
-            batch.input_lengths_tensor + accepted_ids - 1
-        )
-        batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
-        batch.slot_indices += accepted_ids[: len(batch)]

        # Does a HPU <-> CPU sync internally
-        if prefill:
+        if prefill and batch.adapter_meta is not None:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(
                batch.adapter_meta.adapter_indices
@@ -2194,30 +2218,33 @@ def generate_token(
        prefill_logprobs = batch.prefill_next_token_indices is not None
        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
-        if batch.speculative_ids is not None:
-            B, speculative_length = batch.speculative_ids.shape
-            new_length = speculative_length + 1
-            adapter_indices = (
-                adapter_meta.adapter_indices.unsqueeze(-1)
-                .expand(B, new_length)
-                .reshape(-1)
-            )
-            adapter_segments = adapter_meta.adapter_segments * new_length
-            adapter_meta = AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_meta.adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_meta.segment_indices,
-            )
+        if adapter_meta is not None:
+            if batch.speculative_ids is not None:
+                B, speculative_length = batch.speculative_ids.shape
+                new_length = speculative_length + 1
+                adapter_indices = (
+                    adapter_meta.adapter_indices.unsqueeze(-1)
+                    .expand(B, new_length)
+                    .reshape(-1)
+                )
+                adapter_segments = adapter_meta.adapter_segments * new_length
+                adapter_meta = AdapterBatchMetadata(
+                    adapter_indices=adapter_indices,
+                    adapter_set=adapter_meta.adapter_set,
+                    adapter_segments=adapter_segments,
+                    segment_indices=adapter_meta.segment_indices,
+                )

-        # Assign pointers to adapter weights
-        # TODO(travis): don't update this if indices haven't changed
-        adapter_data = AdapterBatchData.from_meta(
-            adapter_meta,
-            self.layer_to_adapter_weights,
-            prefill,
-            batch.prefill_head_indices,
-        )
+            # Assign pointers to adapter weights
+            # TODO(travis): don't update this if indices haven't changed
+            adapter_data = AdapterBatchData.from_meta(
+                adapter_meta,
+                self.layer_to_adapter_weights,
+                prefill,
+                batch.prefill_head_indices,
+            )
+        else:
+            adapter_data = None

        out, speculative_logits = self.forward(batch, adapter_data)
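
When speculative ids are present, each request contributes speculative_length + 1 token rows, so the per-request adapter indices must be repeated to match. A worked example of the expand/reshape used in this hunk:

import torch

adapter_indices = torch.tensor([0, 1])  # one adapter id per request
B, speculative_length = 2, 3
new_length = speculative_length + 1

# Repeat each request's adapter id once per (speculative + 1) token row.
expanded = adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1)
# -> tensor([0, 0, 0, 0, 1, 1, 1, 1])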