@@ -44,6 +44,8 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
4444 request_id_counter = get_req_id_counter (kwargs )
4545 for i , input_item in enumerate (batch ):
4646 try :
47+ kwargs ["is_rolling_batch" ] = is_rolling_batch_enabled (
48+ kwargs .get ("configs" ).rolling_batch )
4749 request_id = request_id_counter .next_id (
4850 ) if request_id_counter else i
4951 # TODO: Decide whether it is a text input based on content-type
@@ -70,7 +72,7 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
7072
7173def get_req_id_counter (kwargs ):
7274 req_id_counter = None
73- if is_rolling_batch_enabled ( kwargs .get ("configs" ). rolling_batch ):
75+ if kwargs .get ("is_rolling_batch" ):
7476 req_id_counter = kwargs .get ("rolling_batch" ).req_id_counter
7577 return req_id_counter
7678
@@ -89,26 +91,29 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
8991 invoke_type = input_item .get_property ("X-Amzn-SageMaker-Forwarded-Api" )
9092 tokenizer = kwargs .get ("tokenizer" )
9193 if is_chat_completions_request (input_map ):
92- _inputs , _param = parse_chat_completions_request (
94+ inputs , param = parse_chat_completions_request (
9395 input_map , kwargs .get ("is_rolling_batch" ), tokenizer )
9496 elif is_3p_request (invoke_type ):
95- _inputs , _param = parse_3p_request (input_map ,
96- kwargs .get ("is_rolling_batch" ),
97- tokenizer , invoke_type )
97+ inputs , param = parse_3p_request (input_map ,
98+ kwargs .get ("is_rolling_batch" ),
99+ tokenizer , invoke_type )
98100 else :
99- _inputs = input_map .pop ("inputs" , input_map )
100- _param = input_map .pop ("parameters" , {})
101-
102- request_input .input_text = _inputs
103- request_input .parameters = _param
104- # assign input_ids
105- if kwargs .get ("tokenizer" ):
101+ inputs = input_map .pop ("inputs" , input_map )
102+ param = input_map .pop ("parameters" , {})
103+
104+ request_input .input_text = inputs
105+ request_input .parameters = param
106+ # assigns input_ids
107+ # TODO: for dynamic batching, or HF pipeline, tokenizer is applied differently.
108+ if kwargs .get ("tokenizer" ) and kwargs .get ("is_rolling_batch" ):
106109 request_input .input_ids = tokenizer .encode (request_input .input_text )
107110
111+ # TODO: Instead of modifying user parameters, maintain this in server_parameters.
112+ # Added here for backward compatibility
108113 # re-organize the parameters
109- if is_rolling_batch_enabled ( kwargs .get ("configs" ). rolling_batch ):
114+ if kwargs .get ("is_rolling_batch" ):
110115 if "stream" in input_map :
111- request_input .parameters [ " stream" ] = input_map .pop ("stream" )
116+ request_input .stream = input_map .pop ("stream" )
112117 if "cached_prompt" in input_map :
113118 request_input .parameters ["cached_prompt" ] = input_map .pop (
114119 "cached_prompt" )
@@ -124,18 +129,20 @@ def add_server_maintained_params(request_input: TextInput, input_item: Input,
124129 if input_item .contains_key ("seed" ):
125130 request_input .server_parameters ["seed" ] = input_item .get_as_string (
126131 key = "seed" )
132+
133+ # setting the output formatter
134+ output_formatter = request_input .server_parameters .pop ("output_formatter" , None )
127135 if not "output_formatter" in request_input .server_parameters :
128- request_input .server_parameters ["output_formatter" ] = kwargs .get (
129- "configs" ).output_formatter
136+ output_formatter = kwargs .get ("configs" ).output_formatter
130137
131- request_input .output_formatter = request_input .server_parameters .get (
132- "output_formatter" )
138+ request_input .output_formatter = output_formatter
133139
134140 if request_input .output_formatter == "json" or request_input .output_formatter == "sse" :
135- request_input .tgi_compat = kwargs .get ("configs" ).tgi_compat
141+ request_input .tgi_compat = kwargs .get ("configs" ).tgi_compa
136142
137143 # duplicating parameters for client side batching
138- if isinstance (request_input .input_text , list ):
144+ if isinstance (request_input .input_text , list ) and len (
145+ request_input .input_text ) > 1 :
139146 parameters = []
140147 for _ in range (len (request_input .input_text )):
141148 parameters .append (request_input .server_parameters .copy ())
@@ -147,22 +154,28 @@ def parse_adapters(request_input: TextInput, input_item: Input,
147154 adapter_registry = kwargs .get ("adapter_registry" )
148155 # if adapter registry exists and not empty, then we assume, peft is supported for the incoming
149156 if adapter_registry :
157+ input_len = len (request_input .input_text ) if isinstance (
158+ request_input .input_text , list ) else 1
150159 adapters_per_item = _fetch_adapters_from_input (input_map , input_item )
151160 if adapters_per_item :
152161 _validate_adapters (adapters_per_item ,
153162 kwargs .get ("adapter_registry" ))
154163 else :
155164 # inference with just base model.
156- adapters_per_item = ["" ] * len ( request_input . input_text )
165+ adapters_per_item = ["" ] * input_len
157166
158- if len ( request_input . input_text ) != len (adapters_per_item ):
167+ if input_len != len (adapters_per_item ):
159168 raise ValueError (
160169 f"Number of adapters is not equal to the number of inputs" )
161170 # lookup the adapter registry to get the adapter details of the registered adapter.
162- request_input . adapters = [
171+ adapters_data = [
163172 kwargs .get ("adapter_registry" ).get (adapter , None )
164- for adapter in adapter_registry
173+ for adapter in adapters_per_item
165174 ]
175+ if len (adapters_data ) == 1 :
176+ adapters_data = adapters_data [0 ]
177+
178+ request_input .adapters = adapters_data
166179
167180
168181def _fetch_adapters_from_input (input_map : dict , input_item : Input ):
0 commit comments