Commit 9e7e546

Move input_ids to hpu and remove disposal of adapter_meta (#3237)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent e325287 commit 9e7e546

File tree

4 files changed (+137, -108 lines)

backends/gaudi/server/text_generation_server/layers/attention/common.py

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@ def clamp(self, max):
 def _async_h2d_tensor_copy(source, device="hpu"):
     if source is None:
         return None
+    if source.device.type == "hpu":
+        return source
     assert source.device.type == "cpu", "Source tensor is not present in host memory!"
     target = torch.empty(source.shape, dtype=source.dtype, device=device)
     target.copy_(source, non_blocking=True)
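
For reference, the full helper after this change reads roughly as follows. This is a sketch reconstructed from the hunk; the trailing return target is outside the diff context but implied by how the helper is used:

import torch

def _async_h2d_tensor_copy(source, device="hpu"):
    if source is None:
        return None
    if source.device.type == "hpu":
        # New early return: tensors already resident on the device are
        # handed back unchanged instead of tripping the host-memory assert.
        return source
    assert source.device.type == "cpu", "Source tensor is not present in host memory!"
    target = torch.empty(source.shape, dtype=source.dtype, device=device)
    target.copy_(source, non_blocking=True)
    return target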

backends/gaudi/server/text_generation_server/models/flash_causal_lm.py

Lines changed: 128 additions & 101 deletions
@@ -634,21 +634,25 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         cache_lengths_tensor = self.cache_lengths_tensor[indices]

         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)
-
-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+        if self.adapter_meta is not None:
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            adapter_segments, adapter_segment_indices = find_segments(
+                adapter_indices
+            )
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
+        else:
+            adapter_meta = None
         htorch.core.mark_step()
         return type(self)(
             batch_id=self.batch_id,
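
find_segments groups consecutive positions that share an adapter index into contiguous segments. An illustrative stand-in for it is shown below; the real helper imported by this file may differ in details:

import torch

def find_segments(adapter_indices: torch.Tensor):
    # Returns segment boundary positions and the adapter index owning each segment.
    indices = adapter_indices.tolist()
    segments = [0]
    segment_indices = []
    for i in range(1, len(indices)):
        if indices[i] != indices[i - 1]:
            segments.append(i)
            segment_indices.append(indices[i - 1])
    if indices:
        segments.append(len(indices))
        segment_indices.append(indices[-1])
    return segments, segment_indices

segments, segment_indices = find_segments(torch.tensor([0, 0, 1, 1, 1, 2]))
# segments == [0, 2, 5, 6]; segment_indices == [0, 1, 2]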
@@ -710,6 +714,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         max_length = 0
         max_input_length = 0
         max_current_length = 0
+        ADAPTER_TO_INDEX = get_adapter_to_index()
         for b in batches:
             total_batch_size += len(b)
             max_blocks = max(max_blocks, b.max_blocks)
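
ADAPTER_TO_INDEX caches the global adapter-to-index mapping once per concatenation. Assuming it behaves like a plain dict (an assumption; only its truthiness matters here), the gating works like this:

# Hypothetical illustration: only the truthiness of the mapping matters.
ADAPTER_TO_INDEX = {}                # no LoRA adapters registered
assert not ADAPTER_TO_INDEX          # falsy: adapter bookkeeping is skipped
ADAPTER_TO_INDEX = {"my-lora": 0}    # one registered adapter (made-up name)
assert ADAPTER_TO_INDEX              # truthy: adapter metadata is built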
@@ -763,14 +768,15 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
             total_batch_size
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_segment_builder = SegmentConcatBuilder()
-        adapter_set = set()
+        if ADAPTER_TO_INDEX:
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()

         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
@@ -821,9 +827,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + valid_bsize

-            index = torch.tensor(
-                list(range(start_index, end_index)), device=batch.input_ids.device
-            )
+            index = torch.tensor(list(range(start_index, end_index)), device="cpu")
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -847,7 +851,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             )

             if not prefilling:
-                input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
+                input_ids.index_copy_(
+                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
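
index_copy_ requires the index tensor to live on the same device as the destination. The batch-level index is now built once on the CPU, where most of these tensors reside, and moved only for the input_ids copy, since input_ids is the tensor that now lives on the HPU. A minimal illustration of the constraint:

import torch

dst = torch.zeros(4, dtype=torch.int64)        # host-resident destination
src = torch.tensor([7, 8], dtype=torch.int64)
index = torch.tensor([1, 3], device="cpu")     # built once on the CPU
dst.index_copy_(0, index, src)                 # fine: everything on one device
# dst == tensor([0, 7, 0, 8])
# If dst lived on an accelerator, the call would need index.to(dst.device),
# which is exactly what the input_ids.index_copy_ above does.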
@@ -858,20 +864,21 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
                 cache_lengths_tensor.index_copy_(
                     0, index, batch.cache_lengths_tensor[:valid_bsize]
                 )
-                adapter_start_index = cumulative_adapter_indices_size
-                adapter_end_index = (
-                    cumulative_adapter_indices_size
-                    + batch.adapter_meta.adapter_indices.shape[0]
-                )
-                adapter_indices[adapter_start_index:adapter_end_index] = (
-                    batch.adapter_meta.adapter_indices
-                )
-                cumulative_adapter_indices_size = adapter_end_index
-                adapter_set.update(batch.adapter_meta.adapter_set)
-                adapter_segment_builder.concat(
-                    batch.adapter_meta.adapter_segments,
-                    batch.adapter_meta.segment_indices,
-                )
+                if ADAPTER_TO_INDEX:
+                    adapter_start_index = cumulative_adapter_indices_size
+                    adapter_end_index = (
+                        cumulative_adapter_indices_size
+                        + batch.adapter_meta.adapter_indices.shape[0]
+                    )
+                    adapter_indices[adapter_start_index:adapter_end_index] = (
+                        batch.adapter_meta.adapter_indices
+                    )
+                    cumulative_adapter_indices_size = adapter_end_index
+                    adapter_set.update(batch.adapter_meta.adapter_set)
+                    adapter_segment_builder.concat(
+                        batch.adapter_meta.adapter_segments,
+                        batch.adapter_meta.segment_indices,
+                    )
             else:
                 if isinstance(batch.input_ids, torch.Tensor):
                     batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -914,7 +921,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         else:
             speculative_ids = None

-        if adapter_segment_builder is not None:
+        if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
             adapter_meta = AdapterBatchMetadata(
                 adapter_indices=adapter_indices,
@@ -961,7 +968,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=adapter_meta,
+            adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
             hpu_attn_meta=None,
             next_token_logits=None,
             speculative_logits=None,
@@ -1037,6 +1044,7 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
         # need extra pad to match warmup seq
         extra_pad = max_padded_input_len - self.max_input_length
         extra_pad_bs = max_padded_bs - len(self)
+        device = self.all_input_ids_tensor.device
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1047,12 +1055,12 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
             input_ids = [0] * extra_pad + input_ids
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
             self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
             input_ids_padded_length.extend([extra_pad] * len(self))
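
Capturing device = self.all_input_ids_tensor.device up front and allocating the padded input_ids there directly saves a separate host-to-device transfer later. A minimal sketch of the pattern, with the device set to "cpu" so it runs anywhere:

import torch

device = torch.device("cpu")   # "hpu" on Gaudi in the real code
padded = [0, 0, 101, 102, 103]
# One allocation on the target device instead of a CPU tensor plus a .to() copy.
input_ids = torch.tensor(padded, dtype=torch.int64, device=device)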
@@ -1245,7 +1253,9 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
         self.slot_indices = slot_indices

         self.prefill_cu_outlens = prefill_cu_outlens
-        self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
+        self.prefill_cache_indices = torch.zeros_like(
+            self.input_ids, dtype=torch.bool, device="cpu"
+        )
         self.prefill_cache_indices[prefill_cache_indices] = True

         if all_prefill_logprobs:
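
Since self.input_ids may now already be on the HPU, zeros_like would inherit that device; pinning the boolean mask to the CPU keeps it alongside the host-side bookkeeping it indexes. Illustration:

import torch

input_ids = torch.zeros(8, dtype=torch.int64)        # stand-in; may be on hpu
prefill_cache_indices = torch.tensor([1, 3, 4])
# Explicit device="cpu" overrides whatever device input_ids happens to be on.
mask = torch.zeros_like(input_ids, dtype=torch.bool, device="cpu")
mask[prefill_cache_indices] = True
# mask == tensor([False, True, False, True, True, False, False, False])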
@@ -1301,21 +1311,24 @@ def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
                 fsm_grammar_states,
             )

-        if adapter_set:
-            adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
-            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        else:
-            adapter_indices = torch.zeros_like(self.input_ids)
-            adapter_segments = [0, len(adapter_indices)]
-            adapter_segment_indices = [len(adapter_indices) - 1]
-
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        self.adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+        if ADAPTER_TO_INDEX:
+            if adapter_set:
+                adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
+                adapter_segments, adapter_segment_indices = find_segments(
+                    adapter_indices
+                )
+            else:
+                adapter_indices = torch.zeros_like(self.input_ids)
+                adapter_segments = [0, len(adapter_indices)]
+                adapter_segment_indices = [len(adapter_indices) - 1]
+
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            self.adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

     def __len__(self):
         return len(self.requests)
@@ -1941,11 +1954,11 @@ def forward(
         # This makes sure the max_s for the decode pass is correct.
         max_s = min(self.max_past(), max_s)
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
         seqlen = Seqlen(
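
Same idea in forward: zeros_like(input_ids) would now allocate the pad on the HPU, but slots is still host-side at this point, so the pad is created on slots.device explicitly. Sketch:

import torch

input_ids = torch.zeros(16, dtype=torch.int64)       # stand-in for an hpu tensor
slots = torch.tensor([3, 7, 11], dtype=torch.int64)  # host-side slot mapping
# Shape and dtype come from input_ids, device from slots, so the in-place
# assignment below never mixes devices.
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots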
@@ -1965,7 +1978,7 @@ def forward(
         )

         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
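
Since prepare_for_prefill and concatenate now build input_ids on the device, the explicit copy wrapper is dropped here; with the early return added to _async_h2d_tensor_copy in common.py it would have been an identity call anyway. A hypothetical check, runnable only on Gaudi hardware with habana_frameworks.torch loaded:

import torch

x_hpu = torch.arange(4, device="hpu")            # requires a Habana device
assert _async_h2d_tensor_copy(x_hpu) is x_hpu    # early return hands back the same tensor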
@@ -2059,15 +2072,16 @@ def generate_token(
                 batch.position_ids = batch.position_ids[indices]

                 batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
-                batch.adapter_meta.adapter_indices = (
-                    batch.adapter_meta.adapter_indices[indices]
-                )
+                if batch.adapter_meta is not None:
+                    batch.adapter_meta.adapter_indices = (
+                        batch.adapter_meta.adapter_indices[indices]
+                    )
             # For each member of the batch
             # Cumulative length
-            accepted_ids = accepted_ids.cpu()
-            cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
-            torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+
             if batch.speculative_logits is not None:
+                cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+                torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
                 for i in range(len(batch)):
                     batch.all_input_ids_tensor[
                         i,
@@ -2076,6 +2090,20 @@ def generate_token(
                         + batch.input_lengths[i]
                         + accepted_ids[i],
                     ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+                batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+                accepted_ids = accepted_ids.cpu()
+                if batch.position_ids.dim() == 2:
+                    # Qwen2_vl case:
+                    batch.position_ids += accepted_ids.unsqueeze(-1)
+                else:
+                    batch.position_ids += accepted_ids
+                batch.cache_lengths_tensor += (
+                    batch.input_lengths_tensor + accepted_ids - 1
+                )
+                batch.input_lengths_tensor = torch.ones_like(
+                    batch.input_lengths_tensor
+                )
+                batch.slot_indices += accepted_ids[: len(batch)]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
                 index = index.to(batch.all_input_ids_tensor.device)
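
The speculative branch now owns the cumulative-sum bookkeeping: cu_accepted_ids converts per-request accepted-token counts into offsets into the flattened next_input_ids, and cu_accepted_ids[1:] - 1 picks each request's last accepted token as its next input. Worked example:

import torch

accepted_ids = torch.tensor([2, 1, 3])            # tokens accepted per request
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
# cu_accepted_ids == tensor([0, 2, 3, 6])

next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])  # flattened accepted tokens
last_per_request = next_input_ids[cu_accepted_ids[1:] - 1]
# last_per_request == tensor([12, 21, 33]) -> the new batch.input_ids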
@@ -2088,22 +2116,18 @@ def generate_token(
                 batch.all_input_ids_tensor.index_put_(
                     (batch_idx, index.long()), next_input_ids
                 )
-                next_input_ids = next_input_ids.cpu()
-                batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+                batch.input_ids = next_input_ids
+                batch.position_ids += 1
+                batch.cache_lengths_tensor += batch.input_lengths_tensor
+                batch.input_lengths_tensor = torch.ones_like(
+                    batch.input_lengths_tensor
+                )
+                batch.slot_indices += 1
+
             batch.speculative_ids = speculative_ids
-            if batch.position_ids.dim() == 2:
-                # Qwen2_vl case:
-                batch.position_ids += accepted_ids.unsqueeze(-1)
-            else:
-                batch.position_ids += accepted_ids
-            batch.cache_lengths_tensor += (
-                batch.input_lengths_tensor + accepted_ids - 1
-            )
-            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
-            batch.slot_indices += accepted_ids[: len(batch)]

             # Does a HPU <-> CPU sync internally
-            if prefill:
+            if prefill and batch.adapter_meta is not None:
                 # adjust segment lengths to account for all request lengths being 1 during decoding
                 adapter_segments, _ = find_segments(
                     batch.adapter_meta.adapter_indices
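
In the non-speculative path every request accepts exactly one token, so the updates become constants (+1) and next_input_ids can stay on the device; the old .cpu() round trip is gone. index_put_ scatters each request's token into its row, e.g.:

import torch

all_input_ids = torch.zeros(3, 8, dtype=torch.int64)   # per-request token history
index = torch.tensor([4, 2, 6])                        # write position per request
batch_idx = torch.arange(3)
next_input_ids = torch.tensor([101, 202, 303])
all_input_ids.index_put_((batch_idx, index.long()), next_input_ids)
# Row i now holds next_input_ids[i] at column index[i]; next_input_ids itself
# is reused as batch.input_ids without ever leaving the device.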
@@ -2194,30 +2218,33 @@ def generate_token(
         prefill_logprobs = batch.prefill_next_token_indices is not None
         # Update adapter indices for speculative tokens (if present)
         adapter_meta = batch.adapter_meta
-        if batch.speculative_ids is not None:
-            B, speculative_length = batch.speculative_ids.shape
-            new_length = speculative_length + 1
-            adapter_indices = (
-                adapter_meta.adapter_indices.unsqueeze(-1)
-                .expand(B, new_length)
-                .reshape(-1)
-            )
-            adapter_segments = adapter_meta.adapter_segments * new_length
-            adapter_meta = AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_meta.adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_meta.segment_indices,
-            )
+        if adapter_meta is not None:
+            if batch.speculative_ids is not None:
+                B, speculative_length = batch.speculative_ids.shape
+                new_length = speculative_length + 1
+                adapter_indices = (
+                    adapter_meta.adapter_indices.unsqueeze(-1)
+                    .expand(B, new_length)
+                    .reshape(-1)
+                )
+                adapter_segments = adapter_meta.adapter_segments * new_length
+                adapter_meta = AdapterBatchMetadata(
+                    adapter_indices=adapter_indices,
+                    adapter_set=adapter_meta.adapter_set,
+                    adapter_segments=adapter_segments,
+                    segment_indices=adapter_meta.segment_indices,
+                )

-        # Assign pointers to adapter weights
-        # TODO(travis): don't update this if indices haven't changed
-        adapter_data = AdapterBatchData.from_meta(
-            adapter_meta,
-            self.layer_to_adapter_weights,
-            prefill,
-            batch.prefill_head_indices,
-        )
+            # Assign pointers to adapter weights
+            # TODO(travis): don't update this if indices haven't changed
+            adapter_data = AdapterBatchData.from_meta(
+                adapter_meta,
+                self.layer_to_adapter_weights,
+                prefill,
+                batch.prefill_head_indices,
+            )
+        else:
+            adapter_data = None

         out, speculative_logits = self.forward(batch, adapter_data)

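When speculation is active, each request contributes speculative_length + 1 positions, so its adapter index is repeated across them and the segment boundaries scale by the same factor. The expansion, in isolation:

import torch

adapter_indices = torch.tensor([0, 1, 2])   # one adapter index per request
B, speculative_length = 3, 1                # from batch.speculative_ids.shape
new_length = speculative_length + 1

expanded = adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1)
# expanded == tensor([0, 0, 1, 1, 2, 2])

adapter_segments = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
# Boundary positions scale by the per-request width:
# adapter_segments * new_length == tensor([0, 2, 4, 6])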