@@ -120,10 +120,15 @@ def load(self,
120
120
handle = self .cache [(runtime_shape , graph_index , self .compiler .name )]
121
121
compiled_graph = self .compiler .load (handle , graph , example_inputs ,
122
122
graph_index , runtime_shape )
123
- logger .debug (
124
- "Directly load the %s-th graph for shape %s from %s via "
125
- "handle %s" , graph_index , str (runtime_shape ), self .compiler .name ,
126
- handle )
123
+ if runtime_shape is None :
124
+ logger .debug (
125
+ "Directly load the %s-th graph for dynamic shape from %s via "
126
+ "handle %s" , graph_index , self .compiler .name , handle )
127
+ else :
128
+ logger .debug (
129
+ "Directly load the %s-th graph for shape %s from %s via "
130
+ "handle %s" , graph_index , str (runtime_shape ),
131
+ self .compiler .name , handle )
127
132
return compiled_graph
128
133
129
134
def compile (self ,
@@ -152,9 +157,15 @@ def compile(self,
152
157
# there can be multiple graphs due to piecewise compilation.
153
158
now = time .time ()
154
159
elapsed = now - compilation_start_time
155
- logger .info (
156
- "Directly load the compiled graph(s) for shape %s "
157
- "from the cache, took %.3f s" , str (runtime_shape ), elapsed )
160
+ if runtime_shape is None :
161
+ logger .info (
162
+ "Directly load the compiled graph(s) for dynamic shape "
163
+ "from the cache, took %.3f s" , elapsed )
164
+ else :
165
+ logger .info (
166
+ "Directly load the compiled graph(s) for shape %s "
167
+ "from the cache, took %.3f s" , str (runtime_shape ),
168
+ elapsed )
158
169
return compiled_graph
159
170
160
171
# no compiler cached the graph, or the cache is disabled,
@@ -178,19 +189,29 @@ def compile(self,
178
189
self .is_cache_updated = True
179
190
if graph_index == 0 :
180
191
# adds some info logging for the first graph
181
- logger .info ("Cache the graph of shape %s for later use" ,
182
- str (runtime_shape ))
183
- logger .debug (
184
- "store the %s-th graph for shape %s from %s via handle %s" ,
185
- graph_index , str (runtime_shape ), self .compiler .name , handle )
192
+ if runtime_shape is None :
193
+ logger .info (
194
+ "Cache the graph for dynamic shape for later use" )
195
+ else :
196
+ logger .info ("Cache the graph of shape %s for later use" ,
197
+ str (runtime_shape ))
198
+ if runtime_shape is None :
199
+ logger .debug (
200
+ "Store the %s-th graph for dynamic shape from %s via "
201
+ "handle %s" , graph_index , self .compiler .name , handle )
202
+ else :
203
+ logger .debug (
204
+ "Store the %s-th graph for shape %s from %s via handle %s" ,
205
+ graph_index , str (runtime_shape ), self .compiler .name ,
206
+ handle )
186
207
187
208
# after compiling the last graph, record the end time
188
209
if graph_index == num_graphs - 1 :
189
210
now = time .time ()
190
211
elapsed = now - compilation_start_time
191
212
compilation_config .compilation_time += elapsed
192
213
if runtime_shape is None :
193
- logger .info ("Compiling a graph for general shape takes %.2f s" ,
214
+ logger .info ("Compiling a graph for dynamic shape takes %.2f s" ,
194
215
elapsed )
195
216
else :
196
217
logger .info ("Compiling a graph for shape %s takes %.2f s" ,
@@ -308,7 +329,7 @@ def call_module(self, target: torch.fx.node.Target,
308
329
i for i , x in enumerate (args ) if isinstance (x , torch .SymInt )
309
330
]
310
331
global compilation_start_time
311
- compiled_graph_for_general_shape = self .vllm_backend .\
332
+ compiled_graph_for_dynamic_shape = self .vllm_backend .\
312
333
compiler_manager .compile (
313
334
submod ,
314
335
args ,
@@ -323,7 +344,7 @@ def call_module(self, target: torch.fx.node.Target,
323
344
self .module .__dict__ [target ] = piecewise_backend (
324
345
submod , self .vllm_config , self .graph_pool , index ,
325
346
len (self .compile_submod_names ), sym_shape_indices ,
326
- compiled_graph_for_general_shape , self .vllm_backend )
347
+ compiled_graph_for_dynamic_shape , self .vllm_backend )
327
348
328
349
compilation_counter .num_piecewise_capturable_graphs_seen += 1
329
350
0 commit comments