# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch_xla.core.xla_model as xm

from vllm.v1.sample.metadata import SamplingMetadata


@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more XLA-friendly interface than SamplingMetadata
    # on TPU: all arguments must be traceable and no Optionals are allowed,
    # so that None values never trigger graph recompilation.
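    # For illustration: a request that does not set min_p is represented by a
    # min_p of 0.0 (a no-op) rather than None, so the traced graph's inputs
    # keep a fixed structure (see _get_default_params_values below).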
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling a single xla graph.
    do_argmax: torch.Tensor = None

    # Speculation is not supported.
    spec_token_ids = None

    # torch.Generator is not supported by xla.
    generators: dict[int, torch.Generator] = field(default_factory=dict)

    # Unsupported: it would require returning an extra tensor of static
    # size BxV.
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # TODO: should use a tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None
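    # Indices of the logits rows to sample from, padded to a pre-compiled
    # shape (see from_sampling_metadata below).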
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
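        # Fill in placeholder tensors for fields left unset, so the dataclass
        # never carries Nones into the traced graph.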
        temp = self.temperature
        if self.indices_do_sample is None:
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing
        the values from the input `metadata`. Do that only for non-None
        values, so that recompilation is not triggered by optional values
        (None/torch.Tensor).

        To handle params whose size ranges from 1 up to `max_num_seqs`,
        tensors are padded to the closest pre-compiled shape. The same goes
        for `padded_do_sample_indices`, which contains the indices to be fed
        to the Sampler, padded to the closest pre-compiled shape.

        Eg. pad to 4: temperature [0.7, 0.2] => [0.7, 0.2, 0.0, 0.0]
            do_sample_indices [4, 10] => padded_do_sample_indices [4, 10, 0, 0]
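
        A hypothetical call site (illustrative only; `sampling_metadata` and
        `device` are assumed to come from the model runner, whose actual
        wiring may differ) would look like:

            padded_indices = torch.zeros(4, dtype=torch.int32, device=device)
            padded_indices[:2] = torch.tensor([4, 10], dtype=torch.int32)
            tpu_meta = TPUSupportedSamplingMetadata.from_sampling_metadata(
                sampling_metadata, padded_indices, 2, device)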
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE: we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(
            num_samples,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=do_argmax)
        supported_params = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
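        # Entries past `num_do_sample` keep the "disable" defaults set by
        # get_default_sampling_params, so the padding acts as a no-op.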
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            new_val = getattr(new_metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                new_val[:num_do_sample] = old_val
                setattr(new_metadata, p_name, new_val)

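        # Flush pending lazy operations to the device and wait for them to
        # finish before returning.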
        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        # As sampling happens on a single traced graph, options are
        # "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        sampling_metadata_disable_value = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        init_kwargs = dict()
        for p_name, (default_val,
                     dtype) in sampling_metadata_disable_value.items():
            default_tensor = torch.full((num_samples, ),
                                        default_val,
                                        dtype=dtype,
                                        device=device)
            init_kwargs[p_name] = default_tensor

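        # For num_samples=4 this yields, e.g., temperature=[1., 1., 1., 1.]
        # and min_p=[0., 0., 0., 0.], i.e. sampling params that do nothing.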
        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        if sampling_metadata.all_greedy:
            # Temperature is set to None since #13587; make sure the default
            # isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
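        # Maps each supported sampling param to its (disable value, dtype),
        # used above to build the fixed-shape default tensors.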
        return dict(
            # Since #13587 greedy sampling requires branching off, which
            # leads to separate graphs. We set temperature to a no-op (1.0)
            # and handle argmax here via `do_argmax` instead.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )