@@ -420,6 +420,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
@@ -433,6 +434,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
@@ -449,6 +451,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -460,6 +463,9 @@ async def add_request_async(
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+        if priority != 0 and not self.scheduler_config.policy == "priority":
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
         if arrival_time is None:
             arrival_time = time.time()
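The guard above mirrors the existing LoRA check: a non-zero priority is rejected unless the engine was built with the "priority" scheduler policy. A minimal sketch of enabling that policy, assuming it is surfaced through the engine args as `scheduling_policy` (that plumbing is not shown in this diff):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# `scheduling_policy` is an assumed arg name; under the default FCFS
# policy, any request with priority != 0 now raises the ValueError
# added above.
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              scheduling_policy="priority")
engine = AsyncLLMEngine.from_engine_args(engine_args)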
@@ -479,6 +485,7 @@ async def add_request_async(
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
+            priority=priority,
         )

     async def check_health_async(self) -> None:
@@ -829,6 +836,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -843,6 +851,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -860,6 +869,7 @@ async def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
@@ -877,6 +887,11 @@ async def add_request(
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")

+        if (priority != 0
+                and not self.engine.scheduler_config.policy == "priority"):
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
+
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
@@ -885,7 +900,9 @@ async def add_request(
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
+        )

         return stream.generator()
@@ -896,7 +913,8 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
@@ -913,6 +931,8 @@ async def generate(
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                                     for generation, if any.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.

         Yields:
             The output `RequestOutput` objects from the LLMEngine
@@ -968,6 +988,7 @@ async def generate(
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
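Continuing the sketch from above, a hedged usage example: two concurrent generate() calls with different priorities, reusing the `engine` built earlier. Which integer wins (lower or higher) is decided by the scheduler policy, not by this diff; this only shows the keyword flowing through to add_request().

import asyncio
from vllm import SamplingParams

async def main():
    params = SamplingParams(max_tokens=32)

    async def run(prompt: str, request_id: str, priority: int):
        # `priority` is forwarded to add_request() and from there
        # into the engine's scheduler.
        async for output in engine.generate(prompt, params, request_id,
                                            priority=priority):
            last = output
        return last

    # Under the "priority" policy the scheduler may reorder these;
    # under the default policy both calls would raise the new ValueError.
    await asyncio.gather(run("Write a haiku.", "bulk-0", priority=10),
                         run("Translate 'hello'.", "urgent-0", priority=0))

asyncio.run(main())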