1313 FlightServerMethods ,
1414 ServerContext ,
1515 TaskExecutionResult ,
16+ TaskWaitTimeoutError ,
1617 flight_server_methods ,
1718)
1819
2122 CancelInvocation ,
2223 RetryInvocation ,
2324 SubmitInvocation ,
24- extract_invocation_from_descriptor ,
25+ extract_pollable_invocation_from_descriptor ,
26+ extract_submit_invocation_from_descriptor ,
2527)
2628from gooddata_flexconnect .function .function_registry import FlexConnectFunctionRegistry
2729from gooddata_flexconnect .function .function_task import FlexConnectFunctionTask
2830
2931_LOGGER = structlog .get_logger ("gooddata_flexconnect.rpc" )
3032
33+ POLLING_HEADER_NAME = "x-quiver-pollable"
34+ """
35+ If this header is present on the get flight info call, the polling extension will be used.
36+ Otherwise the basic do get will be used.
37+ """
38+
3139
3240def _prepare_poll_error (task_id : str ) -> pyarrow .flight .FlightError :
3341 return ErrorInfo .poll (
@@ -122,28 +130,69 @@ def _prepare_flight_info(
122130 total_bytes = - 1 ,
123131 )
124132
125- ###################################################################
126- # Implementation of Flight RPC methods
127- ###################################################################
128-
129- def list_flights (
130- self , context : pyarrow .flight .ServerCallContext , criteria : bytes
131- ) -> Generator [pyarrow .flight .FlightInfo , None , None ]:
133+ def _get_flight_info_no_polling (
134+ self ,
135+ context : pyarrow .flight .ServerCallContext ,
136+ descriptor : pyarrow .flight .FlightDescriptor ,
137+ ) -> pyarrow .flight .FlightInfo :
138+ """
139+ Basic DoGetInfo flow with no polling extension.
140+ This conforms to the mainline Arrow Flight RPC specification.
141+ """
132142 structlog .contextvars .bind_contextvars (peer = context .peer ())
133- _LOGGER . info ( "list_flights" , available_funs = self . _registry . function_names )
143+ invocation = extract_submit_invocation_from_descriptor ( descriptor )
134144
135- return ( self . _create_fun_info ( fun ) for fun in self . _registry . functions . values ())
145+ task : Optional [ FlexConnectFunctionTask ] = None
136146
137- def get_flight_info (
147+ try :
148+ task = self ._prepare_task (context , invocation )
149+ self ._ctx .task_executor .submit (task )
150+
151+ try :
152+ task_result = self ._ctx .task_executor .wait_for_result (task .task_id , self ._call_deadline )
153+ except TaskWaitTimeoutError :
154+ cancelled = self ._ctx .task_executor .cancel (task .task_id )
155+ _LOGGER .warning (
156+ "flexconnect_fun_call_timeout" , task_id = task .task_id , fun = task .fun_name , cancelled = cancelled
157+ )
158+
159+ raise ErrorInfo .for_reason (
160+ ErrorCode .TIMEOUT , f"GetFlightInfo timed out while waiting for task { task .task_id } ."
161+ ).to_timeout_error ()
162+
163+ # if this bombs then there must be something really wrong because the task
164+ # was clearly submitted and code was waiting for its completion. this invariant
165+ # should not happen in this particular code path. The None return value may
166+ # be applicable one day when polling is in use and a request comes to check whether
167+ # particular task id finished
168+ assert task_result is not None
169+
170+ return self ._prepare_flight_info (task_id = task .task_id , task_result = task_result )
171+ except Exception :
172+ if task is not None :
173+ _LOGGER .error (
174+ "get_flight_info_failed" , task_id = task .task_id , fun = task .fun_name , exc_info = True , polling = False
175+ )
176+ else :
177+ _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True , polling = False )
178+ raise
179+
180+ def _get_flight_info_polling (
138181 self ,
139182 context : pyarrow .flight .ServerCallContext ,
140183 descriptor : pyarrow .flight .FlightDescriptor ,
141184 ) -> pyarrow .flight .FlightInfo :
185+ """
186+ DoGetInfo flow with polling extension.
187+ This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo
188+ encoded into the FlightTimedOutError.extra_info.
189+ Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library.
190+ """
142191 structlog .contextvars .bind_contextvars (peer = context .peer ())
192+ invocation = extract_pollable_invocation_from_descriptor (descriptor )
143193
144194 task_id : str
145195 fun_name : Optional [str ] = None
146- invocation = extract_invocation_from_descriptor (descriptor )
147196
148197 if isinstance (invocation , CancelInvocation ):
149198 # cancel the given task and raise cancellation exception
@@ -166,7 +215,7 @@ def get_flight_info(
166215 task_id = task .task_id
167216 fun_name = task .fun_name
168217 except Exception :
169- _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True )
218+ _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True , polling = True )
170219 raise
171220 else :
172221 # can be replaced by assert_never when we are on 3.11
@@ -188,9 +237,36 @@ def get_flight_info(
188237 # how to poll for the results
189238 raise _prepare_poll_error (task_id )
190239 except Exception :
191- _LOGGER .error ("get_flight_info_failed" , task_id = task_id , fun = fun_name , exc_info = True )
240+ _LOGGER .error ("get_flight_info_failed" , task_id = task_id , fun = fun_name , exc_info = True , polling = True )
192241 raise
193242
243+ ###################################################################
244+ # Implementation of Flight RPC methods
245+ ###################################################################
246+
247+ def list_flights (
248+ self , context : pyarrow .flight .ServerCallContext , criteria : bytes
249+ ) -> Generator [pyarrow .flight .FlightInfo , None , None ]:
250+ structlog .contextvars .bind_contextvars (peer = context .peer ())
251+ _LOGGER .info ("list_flights" , available_funs = self ._registry .function_names )
252+
253+ return (self ._create_fun_info (fun ) for fun in self ._registry .functions .values ())
254+
255+ def get_flight_info (
256+ self ,
257+ context : pyarrow .flight .ServerCallContext ,
258+ descriptor : pyarrow .flight .FlightDescriptor ,
259+ ) -> pyarrow .flight .FlightInfo :
260+ structlog .contextvars .bind_contextvars (peer = context .peer ())
261+
262+ headers = self .call_info_middleware (context ).headers
263+ allow_polling = headers .get (POLLING_HEADER_NAME ) is not None
264+
265+ if allow_polling :
266+ return self ._get_flight_info_polling (context , descriptor )
267+ else :
268+ return self ._get_flight_info_no_polling (context , descriptor )
269+
194270 def do_get (
195271 self ,
196272 context : pyarrow .flight .ServerCallContext ,
0 commit comments