
Commit a9c32a1

[tnx] fix optimum token selection and sampling (#2233)
1 parent e7dd68e commit a9c32a1

File tree

4 files changed: +233, -10 lines


engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_modeling.py

Lines changed: 9 additions & 8 deletions
@@ -25,7 +25,7 @@
 from transformers import PretrainedConfig
 from transformers_neuronx import bucket
 from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB
-from optimum.neuron.generation import TokenSelector
+from djl_python.transformers_neuronx_scheduler.optimum_token_selector import OptimumTokenSelector
 from optimum.neuron.utils.version_utils import check_compiler_compatibility, get_neuronxcc_version
 from optimum.modeling_base import OptimizedModel
 from transformers.generation import StoppingCriteriaList
@@ -238,11 +238,12 @@ def generate(
         self._validate_model_kwargs(model_kwargs)
 
         # Instantiate a TokenSelector for the specified configuration
-        selector = TokenSelector.create(input_ids,
-                                        generation_config,
-                                        self,
-                                        self.max_length,
-                                        stopping_criteria=stopping_criteria)
+        selector = OptimumTokenSelector.create(
+            input_ids,
+            generation_config,
+            self,
+            self.max_length,
+            stopping_criteria=stopping_criteria)
 
         # Verify that the inputs are compatible with the model static input dimensions
         batch_size, sequence_length = input_ids.shape
@@ -280,7 +281,7 @@ def generate(
     def generate_tokens(
         self,
         input_ids: torch.LongTensor,
-        selector: TokenSelector,
+        selector: OptimumTokenSelector,
         batch_size: int,
         attention_mask: Optional[torch.Tensor] = None,
         **model_kwargs,
@@ -291,7 +292,7 @@ def generate_tokens(
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            selector (`TokenSelector`):
+            selector (`OptimumTokenSelector`):
                 The object implementing the generation logic based on transformers processors and stopping criteria.
             batch_size (`int`):
                 The actual input batch size. Used to avoid generating tokens for padded inputs.
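For reference, the loop below is a minimal sketch of how a selector of this kind drives decoding: the model produces logits for the last position, `select()` picks the next token, and the loop stops on EOS or after `max_new_tokens` steps. The plain-transformers `gpt2` model and the simplified single-sequence loop are illustrative assumptions, not this file's code path: the real `generate_tokens` runs a compiled transformers-neuronx model and also handles padded batches and stopping criteria.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from djl_python.transformers_neuronx_scheduler.optimum_token_selector import OptimumTokenSelector

# Illustrative model and tokenizer; the real code path uses a compiled Neuron model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
generation_config = GenerationConfig(max_new_tokens=8,
                                     do_sample=False,
                                     eos_token_id=tokenizer.eos_token_id)

selector = OptimumTokenSelector.create(input_ids,
                                       generation_config,
                                       model,
                                       max_seq_length=128)

for _ in range(generation_config.max_new_tokens):
    logits = model(input_ids).logits[:, -1, :]  # logits for the last position only
    next_tokens = selector.select(input_ids, logits)
    input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
    if next_tokens[0].item() in selector.eos_token_ids:
        break

print(tokenizer.decode(input_ids[0]))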

engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@
 from dataclasses import dataclass
 
 from djl_python.transformers_neuronx_scheduler.slot import Slot
-from djl_python.rolling_batch.rolling_batch import filter_unused_generation_params
 from djl_python.request import Request
 from djl_python.transformers_neuronx_scheduler.token_selector import TokenSelector
 from djl_python.transformers_neuronx_scheduler.speculation import (
engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_token_selector.py

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
+#!/usr/bin/env python
+#
+# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+# The below code is heavily inspired by Optimum Neuron under the following link:
+# https://github.com/huggingface/optimum-neuron/blob/main/optimum/neuron/generation/token_selector.py
+
+import copy
+import logging
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+from transformers.generation import (
+    GenerationConfig,
+    GenerationMixin,
+    LogitsProcessorList,
+    StoppingCriteriaList,
+)
+from transformers.generation.utils import GenerationMode
+
+from optimum.neuron.generation import FusedLogitsWarper
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: This is a temporary solution to avoid Optimum's dependency on transformers<4.42.
+class OptimumTokenSelector:
+    """Implements the token selection logic corresponding to a generation configuration.
+
+    This class combines and uses the logits processors and stopping criteria implemented in
+    the transformers library.
+
+    The algorithm to select these objects is heavily inspired by the transformers `GenerationMixin.generate()`
+    method, but the actual token selection methods are specific.
+
+    The reason why this class does not inherit from `GenerationMixin` is because it does not
+    include the code to produce the tokens logits.
+    Separating the production of the tokens logits from the tokens selection allows this class
+    to be used with different generation paradigms, either synchronously using a single `OptimumTokenSelector` in
+    `GenerationMixin.generate()` or asynchronously using multiple `OptimumTokenSelector` inside an inference endpoint.
+
+    The constructor of this class should not be called directly: instances should be obtained by
+    calling `OptimumTokenSelector.create()`.
+    """
+
+    def __init__(
+        self,
+        mode: GenerationMode,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        eos_token_ids: List[int],
+        pad_token_id: int,
+        logits_warper: Optional[LogitsProcessorList] = None,
+        seed: Optional[int] = 0,
+    ):
+        self.mode = mode
+        self.logits_processor = logits_processor
+        self.stopping_criteria = stopping_criteria
+        self.eos_token_ids = eos_token_ids
+        self.pad_token_id = pad_token_id
+        self.logits_warper = logits_warper
+        self.generator = torch.Generator()
+        self.generator.manual_seed(seed)
+
+    @classmethod
+    def create(
+        cls,
+        input_ids: torch.Tensor,
+        generation_config: GenerationConfig,
+        model: GenerationMixin,
+        max_seq_length: int,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        tokenizer: Optional["PreTrainedTokenizer"] = None,
+        seed: Optional[int] = 0,
+    ) -> "OptimumTokenSelector":
+        r"""Creates the `OptimumTokenSelector` for a specific generation configuration.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            generation_config (`~transformers.generation.GenerationConfig`, *optional*):
+                The generation configuration to parametrize the token selection.
+            model (`~transformers.generation.GenerationMixin`):
+                The model providing the internal helpers used to select the logits processors and stopping criteria.
+            max_seq_length (`int`):
+                The maximum number of input + generated tokens for this model. It depends on the model compilation parameters.
+            stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config.
+            tokenizer (`Optional[transformers.PreTrainedTokenizer]`, defaults to `None`):
+                A tokenizer used when stop strings are passed to generate.
+            seed (`Optional[int]`):
+                The optional seed for sampling. Defaults to zero.
+        Return:
+            `OptimumTokenSelector`: The selector to use for token selection.
+        """
+        generation_config.validate()
+        generation_config = copy.deepcopy(generation_config)
+
+        unsupported_generation_flags = [
+            "output_attentions",
+            "output_hidden_states",
+            "output_scores",
+            "return_dict_in_generate",
+        ]
+        for flag in unsupported_generation_flags:
+            if getattr(generation_config, flag, False):
+                raise ValueError(f"{flag} is not supported for generation.")
+
+        if generation_config.max_new_tokens is not None:
+            logger.warning(
+                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                "Please refer to the documentation for more information. "
+                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+            )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[
+                -1]
+
+        min_length = generation_config.min_length
+        if min_length > max_seq_length:
+            raise ValueError(
+                f"The minimum generation length ({min_length}) exceeds the model maximum sequence length ({max_seq_length})"
+            )
+        max_length = generation_config.max_length
+        if max_length > max_seq_length:
+            logger.warning(
+                f"Adjusting the maximum generation length ({max_length}) to the model maximum sequence length ({max_seq_length})"
+            )
+            generation_config.max_length = max_seq_length
+
+        # This is not supposed to happen for any of the models we support
+        eos_token_id = generation_config.eos_token_id
+        assert eos_token_id is not None
+        # The generation requires special tokens
+        eos_token_ids = eos_token_id if isinstance(eos_token_id,
+                                                   list) else [eos_token_id]
+        generation_config._eos_token_tensor = torch.tensor(
+            eos_token_ids, device=input_ids.device)
+        if generation_config.pad_token_id is None:
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-ended generation."
+            )
+            generation_config.pad_token_id = eos_token_ids[0]
+
+        # Instantiate transformers library processors and criteria
+        logits_processor = model._get_logits_processor(
+            generation_config,
+            input_ids_seq_length=input_ids.shape[-1],
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=LogitsProcessorList(),
+        )
+        if stopping_criteria is None:
+            stopping_criteria = StoppingCriteriaList()
+        stopping_criteria = model._get_stopping_criteria(
+            generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=tokenizer)
+
+        generation_mode = generation_config.get_generation_mode()
+        if generation_mode not in [
+                GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE
+        ]:
+            raise ValueError("Unsupported generation mode")
+
+        logits_warper = None
+        if generation_mode == GenerationMode.SAMPLE:
+            logits_warper = FusedLogitsWarper.from_config(generation_config)
+
+        return cls(
+            mode=generation_mode,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            eos_token_ids=eos_token_ids,
+            pad_token_id=generation_config.pad_token_id,
+            seed=seed,
+        )
+
+    def select(self, input_ids: torch.LongTensor,
+               logits: torch.Tensor) -> torch.LongTensor:
+        """Select the next tokens from the candidate logits.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation (not used in all generation modes).
+            logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
+                The logits corresponding to the generated tokens.
+
+        Return:
+            `torch.LongTensor`: A `torch.LongTensor` containing the selected tokens.
+        """
+        scores = self.logits_processor(input_ids, logits)
+        if self.mode == GenerationMode.SAMPLE:
+            return self._sample(scores)
+        else:
+            return torch.argmax(scores, dim=-1)
+
+    def _sample(self, scores: torch.Tensor) -> torch.LongTensor:
+        # Get [batch_size, kept] scores and indices instead of [batch_size, vocab_size] scores
+        scores, next_token_indices = self.logits_warper(scores)
+
+        # sample
+        probs = torch.nn.functional.softmax(scores, dim=-1)
+        next_tokens = torch.multinomial(probs,
+                                        num_samples=1,
+                                        generator=self.generator)
+        # Convert the filtered tokens to actual vocabulary tokens
+        next_tokens = torch.gather(next_token_indices, 1, next_tokens)
+        return next_tokens.squeeze(1)
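The `_sample` path above works because the fused warper returns both the filtered scores and their original vocabulary indices: `multinomial` then draws from a small `[batch_size, kept]` tensor, and `gather` maps the draw back to real token ids. Below is a self-contained sketch of that trick, with `torch.topk` standing in for optimum-neuron's `FusedLogitsWarper` (an illustrative substitution, not the library call):

import torch

torch.manual_seed(0)
batch_size, vocab_size, top_k = 2, 50_000, 50
logits = torch.randn(batch_size, vocab_size)

# Keep only the top-k scores, remembering their vocabulary indices
# (the fused warper similarly returns (scores, indices) after top-k/top-p filtering).
kept_scores, kept_indices = torch.topk(logits, top_k, dim=-1)

# Sample over the small [batch_size, top_k] set instead of the full vocabulary
probs = torch.nn.functional.softmax(kept_scores, dim=-1)
sampled_positions = torch.multinomial(probs, num_samples=1)

# Map the sampled positions back to actual vocabulary token ids
next_tokens = torch.gather(kept_indices, 1, sampled_positions).squeeze(1)
print(next_tokens.shape)  # torch.Size([2])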

engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py

Lines changed: 2 additions & 1 deletion
@@ -169,7 +169,8 @@ def create(
 
         logits_warper = None
         if generation_mode == GenerationMode.SAMPLE:
-            logits_warper = model._get_logits_warper(generation_config)
+            logits_warper = model._get_logits_warper(generation_config,
+                                                     device=model.device)
             if len(logits_warper) == 0:
                 generation_mode = GenerationMode.GREEDY_SEARCH
 
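This one-line change tracks recent transformers releases, in which the private `GenerationMixin._get_logits_warper()` helper gained a `device` argument. If this code ever needs to run against both older and newer transformers versions, one option (a sketch under that assumption, not part of this commit) is to inspect the helper's signature before calling it:

import inspect

def build_logits_warper(model, generation_config):
    """Call model._get_logits_warper across transformers versions.

    Newer releases expect a `device` argument; older ones do not accept it.
    """
    if "device" in inspect.signature(model._get_logits_warper).parameters:
        return model._get_logits_warper(generation_config, device=model.device)
    return model._get_logits_warper(generation_config)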
