11# (C) 2024 GoodData Corporation
2+ import time
23from collections .abc import Generator
34from typing import Optional
45
1213 FlightServerMethods ,
1314 ServerContext ,
1415 TaskExecutionResult ,
15- TaskWaitTimeoutError ,
1616 flight_server_methods ,
1717)
1818
2323_LOGGER = structlog .get_logger ("gooddata_flexconnect.rpc" )
2424
2525
def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
    """
    Build the error raised when a task has not finished within one polling iteration.

    The error is created via `ErrorInfo.poll` without any flight info and with
    two command descriptors derived from the task id:

    - ``c:<task_id>`` as the cancel descriptor
    - ``r:<task_id>`` as the retry descriptor

    :param task_id: identifier of the task that is being waited for
    :return: error produced by `ErrorInfo.poll`; the caller is expected to raise it
    """
    cancel_command = f"c:{task_id}".encode()
    retry_command = f"r:{task_id}".encode()

    return ErrorInfo.poll(
        flight_info=None,
        cancel_descriptor=pyarrow.flight.FlightDescriptor.for_command(cancel_command),
        retry_descriptor=pyarrow.flight.FlightDescriptor.for_command(retry_command),
    )
32+
33+
2634class _FlexConnectServerMethods (FlightServerMethods ):
def __init__(
    self,
    ctx: ServerContext,
    registry: FlexConnectFunctionRegistry,
    call_deadline_ms: float,
    poll_interval_ms: float,
) -> None:
    """
    Set up the FlexConnect Flight RPC methods.

    :param ctx: server context; stored on the instance
    :param registry: registry of FlexConnect functions; stored on the instance
    :param call_deadline_ms: call deadline, in milliseconds
    :param poll_interval_ms: polling interval, in milliseconds
    """
    ms_per_second = 1000

    # both durations arrive in milliseconds and are kept in seconds
    self._poll_interval = poll_interval_ms / ms_per_second
    self._call_deadline = call_deadline_ms / ms_per_second

    self._registry = registry
    self._ctx = ctx
3146
3247 @staticmethod
3348 def _create_descriptor (fun_name : str , metadata : Optional [dict ]) -> pyarrow .flight .FlightDescriptor :
@@ -140,39 +155,53 @@ def get_flight_info(
140155 descriptor : pyarrow .flight .FlightDescriptor ,
141156 ) -> pyarrow .flight .FlightInfo :
142157 structlog .contextvars .bind_contextvars (peer = context .peer ())
143- task : Optional [FlexConnectFunctionTask ] = None
144158
145- try :
146- task = self ._prepare_task (context , descriptor )
147- self ._ctx .task_executor .submit (task )
159+ # first, check if the descriptor is a cancel descriptor
160+ if descriptor .command is None or not len (descriptor .command ):
161+ raise ErrorInfo .bad_argument (
162+ "Incorrect FlexConnect function invocation. Flight descriptor must contain command "
163+ "with the invocation payload."
164+ )
148165
149- try :
150- # XXX: this should be enhanced to implement polling
151- task_result = self ._ctx .task_executor .wait_for_result (task .task_id , self ._call_deadline )
152- except TaskWaitTimeoutError :
153- cancelled = self ._ctx .task_executor .cancel (task .task_id )
154- _LOGGER .warning (
155- "flexconnect_fun_call_timeout" , task_id = task .task_id , fun = task .fun_name , cancelled = cancelled
156- )
166+ task_id : str
167+ fun_name : Optional [str ] = None
157168
169+ if descriptor .command .startswith (b"c:" ):
170+ # cancel descriptor: just cancel the given task and raise cancellation exception
171+ task_id = descriptor .command [2 :].decode ()
172+ self ._ctx .task_executor .cancel (task_id )
173+ raise ErrorInfo .for_reason (
174+ ErrorCode .COMMAND_CANCELLED , "FlexConnect function invocation was cancelled."
175+ ).to_cancelled_error ()
176+ elif descriptor .command .startswith (b"r:" ):
177+ # retry descriptor: extract the task_id, do not submit it again and do one polling iteration
178+ task_id = descriptor .command [2 :].decode ()
179+ # for retries, we also need to check the call deadline for the whole call duration
180+ task_timestamp = self ._ctx .task_executor .get_task_submitted_timestamp (task_id )
181+ if task_timestamp is not None and time .perf_counter () - task_timestamp > self ._call_deadline :
182+ self ._ctx .task_executor .cancel (task_id )
158183 raise ErrorInfo .for_reason (
159- ErrorCode .TIMEOUT , f"GetFlightInfo timed out while waiting for task { task . task_id } ."
184+ ErrorCode .TIMEOUT , f"GetFlightInfo timed out while waiting for task { task_id } ."
160185 ).to_timeout_error ()
186+ else :
187+ # basic first-time submit: submit the task and do one polling iteration.
188+ # do not check call deadline to give it a chance to wait for the result at least once
189+ try :
190+ task = self ._prepare_task (context , descriptor )
191+ self ._ctx .task_executor .submit (task )
192+ task_id = task .task_id
193+ fun_name = task .fun_name
194+ except Exception :
195+ _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True )
196+ raise
161197
162- # if this bombs then there must be something really wrong because the task
163- # was clearly submitted and code was waiting for its completion. this invariant
164- # should not happen in this particular code path. The None return value may
165- # be applicable one day when polling is in use and a request comes to check whether
166- # particular task id finished
167- assert task_result is not None
168-
198+ try :
199+ task_result = self ._ctx .task_executor .wait_for_result (task_id , timeout = self ._poll_interval )
169200 return self ._prepare_flight_info (task_result )
201+ except TimeoutError :
202+ raise _prepare_poll_error (task_id )
170203 except Exception :
171- if task is not None :
172- _LOGGER .error ("get_flight_info_failed" , task_id = task .task_id , fun = task .fun_name , exc_info = True )
173- else :
174- _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True )
175-
204+ _LOGGER .error ("get_flight_info_failed" , task_id = task_id , fun = fun_name , exc_info = True )
176205 raise
177206
178207 def do_get (
@@ -201,7 +230,9 @@ def do_get(
# Name of the settings section that holds the FlexConnect configuration.
_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"

# Keys inside the section; they are joined with the section name using a dot
# when the settings are looked up.
_FLEX_CONNECT_FUNCTION_LIST = "functions"
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms"

# Fallback durations, in milliseconds.
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000
_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2_000
205236
206237
207238def _read_call_deadline_ms (ctx : ServerContext ) -> int :
@@ -223,6 +254,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int:
223254 )
224255
225256
def _read_polling_interval_ms(ctx: ServerContext) -> int:
    """
    Read the FlexConnect polling interval from the server settings.

    The value is looked up under the `polling_interval_ms` key of the
    `flexconnect` settings section.

    :param ctx: server context whose `settings` are consulted
    :return: the configured interval in milliseconds, or
        `_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS` when the setting is absent
    :raises ValueError: if the configured value cannot be converted to a
        positive integer
    """
    setting_name = f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}"
    configured = ctx.settings.get(setting_name)
    if configured is None:
        return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS

    try:
        polling_interval = int(configured)
    except (TypeError, ValueError, OverflowError):
        # int() raises TypeError for non-scalar values (for example a list) and
        # OverflowError for infinite floats; treat those the same way as
        # unparseable text so that the caller always gets the descriptive
        # configuration error below.
        polling_interval = None

    # raised outside of the `except` block on purpose: the low-level conversion
    # error adds nothing to the message and would only clutter the traceback
    if polling_interval is None or polling_interval <= 0:
        raise ValueError(
            f"Value of {setting_name} must "
            "be a positive number - duration, in milliseconds, that FlexConnect function "
            "waits for the result during one polling iteration."
        )

    return polling_interval
273+
274+
226275@flight_server_methods
227276def create_flexconnect_flight_methods (ctx : ServerContext ) -> FlightServerMethods :
228277 """
@@ -236,8 +285,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
236285 """
237286 modules = list (ctx .settings .get (f"{ _FLEX_CONNECT_CONFIG_SECTION } .{ _FLEX_CONNECT_FUNCTION_LIST } " ) or [])
238287 call_deadline_ms = _read_call_deadline_ms (ctx )
288+ polling_interval_ms = _read_polling_interval_ms (ctx )
239289
240290 _LOGGER .info ("flexconnect_init" , modules = modules )
241291 registry = FlexConnectFunctionRegistry ().load (ctx , modules )
242292
243- return _FlexConnectServerMethods (ctx , registry , call_deadline_ms )
293+ return _FlexConnectServerMethods (ctx , registry , call_deadline_ms , polling_interval_ms )
0 commit comments