Skip to content

Commit 173aa5b

Browse files
committed
feat: make the polling extension opt-in
By default, the FlexConnect will conform to the Arrow Flight RPC spec. However, if an opt-in header is present, it will use the polling extension used by GoodData. This allows for things like query cancellation. Ideally, we would use the PollFlightInfo from the Arrow Flight RPC but unfortunately it is not yet available in PyArrow. JIRA: CQ-1124 risk: low
1 parent ce2fdb1 commit 173aa5b

File tree

3 files changed

+199
-50
lines changed

3 files changed

+199
-50
lines changed

gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
FlightServerMethods,
1414
ServerContext,
1515
TaskExecutionResult,
16+
TaskWaitTimeoutError,
1617
flight_server_methods,
1718
)
1819

@@ -21,13 +22,20 @@
2122
CancelInvocation,
2223
RetryInvocation,
2324
SubmitInvocation,
24-
extract_invocation_from_descriptor,
25+
extract_pollable_invocation_from_descriptor,
26+
extract_submit_invocation_from_descriptor,
2527
)
2628
from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry
2729
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask
2830

2931
_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")
3032

33+
POLLING_HEADER_NAME = "x-quiver-pollable"
34+
"""
35+
If this header is present on the get flight info call, the polling extension will be used.
36+
Otherwise the basic do get will be used.
37+
"""
38+
3139

3240
def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
3341
return ErrorInfo.poll(
@@ -122,28 +130,69 @@ def _prepare_flight_info(
122130
total_bytes=-1,
123131
)
124132

125-
###################################################################
126-
# Implementation of Flight RPC methods
127-
###################################################################
128-
129-
def list_flights(
130-
self, context: pyarrow.flight.ServerCallContext, criteria: bytes
131-
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
133+
def _get_flight_info_no_polling(
134+
self,
135+
context: pyarrow.flight.ServerCallContext,
136+
descriptor: pyarrow.flight.FlightDescriptor,
137+
) -> pyarrow.flight.FlightInfo:
138+
"""
139+
Basic DoGetInfo flow with no polling extension.
140+
This conforms to the mainline Arrow Flight RPC specification.
141+
"""
132142
structlog.contextvars.bind_contextvars(peer=context.peer())
133-
_LOGGER.info("list_flights", available_funs=self._registry.function_names)
143+
invocation = extract_submit_invocation_from_descriptor(descriptor)
134144

135-
return (self._create_fun_info(fun) for fun in self._registry.functions.values())
145+
task: Optional[FlexConnectFunctionTask] = None
136146

137-
def get_flight_info(
147+
try:
148+
task = self._prepare_task(context, invocation)
149+
self._ctx.task_executor.submit(task)
150+
151+
try:
152+
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
153+
except TaskWaitTimeoutError:
154+
cancelled = self._ctx.task_executor.cancel(task.task_id)
155+
_LOGGER.warning(
156+
"flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled
157+
)
158+
159+
raise ErrorInfo.for_reason(
160+
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}."
161+
).to_timeout_error()
162+
163+
# if this bombs then there must be something really wrong because the task
164+
# was clearly submitted and code was waiting for its completion. this invariant
165+
# should not happen in this particular code path. The None return value may
166+
# be applicable one day when polling is in use and a request comes to check whether
167+
# particular task id finished
168+
assert task_result is not None
169+
170+
return self._prepare_flight_info(task_id=task.task_id, task_result=task_result)
171+
except Exception:
172+
if task is not None:
173+
_LOGGER.error(
174+
"get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True, polling=False
175+
)
176+
else:
177+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=False)
178+
raise
179+
180+
def _get_flight_info_polling(
138181
self,
139182
context: pyarrow.flight.ServerCallContext,
140183
descriptor: pyarrow.flight.FlightDescriptor,
141184
) -> pyarrow.flight.FlightInfo:
185+
"""
186+
DoGetInfo flow with polling extension.
187+
This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo
188+
encoded into the FlightTimedOutError.extra_info.
189+
Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library.
190+
"""
142191
structlog.contextvars.bind_contextvars(peer=context.peer())
192+
invocation = extract_pollable_invocation_from_descriptor(descriptor)
143193

144194
task_id: str
145195
fun_name: Optional[str] = None
146-
invocation = extract_invocation_from_descriptor(descriptor)
147196

148197
if isinstance(invocation, CancelInvocation):
149198
# cancel the given task and raise cancellation exception
@@ -166,7 +215,7 @@ def get_flight_info(
166215
task_id = task.task_id
167216
fun_name = task.fun_name
168217
except Exception:
169-
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
218+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True)
170219
raise
171220
else:
172221
# can be replaced by assert_never when we are on 3.11
@@ -188,9 +237,36 @@ def get_flight_info(
188237
# how to poll for the results
189238
raise _prepare_poll_error(task_id)
190239
except Exception:
191-
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True)
240+
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True)
192241
raise
193242

243+
###################################################################
244+
# Implementation of Flight RPC methods
245+
###################################################################
246+
247+
def list_flights(
248+
self, context: pyarrow.flight.ServerCallContext, criteria: bytes
249+
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
250+
structlog.contextvars.bind_contextvars(peer=context.peer())
251+
_LOGGER.info("list_flights", available_funs=self._registry.function_names)
252+
253+
return (self._create_fun_info(fun) for fun in self._registry.functions.values())
254+
255+
def get_flight_info(
256+
self,
257+
context: pyarrow.flight.ServerCallContext,
258+
descriptor: pyarrow.flight.FlightDescriptor,
259+
) -> pyarrow.flight.FlightInfo:
260+
structlog.contextvars.bind_contextvars(peer=context.peer())
261+
262+
headers = self.call_info_middleware(context).headers
263+
allow_polling = headers.get(POLLING_HEADER_NAME) is not None
264+
265+
if allow_polling:
266+
return self._get_flight_info_polling(context, descriptor)
267+
else:
268+
return self._get_flight_info_no_polling(context, descriptor)
269+
194270
def do_get(
195271
self,
196272
context: pyarrow.flight.ServerCallContext,

gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,42 +53,51 @@ class SubmitInvocation:
5353
"""
5454

5555

56-
def extract_invocation_from_descriptor(
56+
def extract_submit_invocation_from_descriptor(descriptor: pyarrow.flight.FlightDescriptor) -> SubmitInvocation:
57+
"""
58+
Given a flight descriptor, extract the invocation information from it.
59+
Do not allow the polling-related variants.
60+
"""
61+
try:
62+
payload = orjson.loads(descriptor.command)
63+
except Exception:
64+
raise ErrorInfo.bad_argument(
65+
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
66+
)
67+
68+
function_name = payload.get("functionName")
69+
if function_name is None or not len(function_name):
70+
raise ErrorInfo.bad_argument(
71+
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
72+
)
73+
74+
parameters = payload.get("parameters") or {}
75+
columns = parameters.get("columns")
76+
77+
return SubmitInvocation(
78+
function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command
79+
)
80+
81+
82+
def extract_pollable_invocation_from_descriptor(
5783
descriptor: pyarrow.flight.FlightDescriptor,
5884
) -> Union[RetryInvocation, CancelInvocation, SubmitInvocation]:
5985
"""
6086
Given a flight descriptor, extract the invocation information from it.
87+
Allow also the polling-related variants.
6188
"""
62-
6389
if descriptor.command is None or not len(descriptor.command):
6490
raise ErrorInfo.bad_argument(
6591
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
6692
"with the invocation payload."
6793
)
6894

95+
# we are in the polling-enabled realm: try parsing the retry and cancel descriptors first
6996
if descriptor.command.startswith(b"c:"):
7097
task_id = descriptor.command[2:].decode()
7198
return CancelInvocation(task_id)
7299
elif descriptor.command.startswith(b"r:"):
73100
task_id = descriptor.command[2:].decode()
74101
return RetryInvocation(task_id)
75-
else:
76-
try:
77-
payload = orjson.loads(descriptor.command)
78-
except Exception:
79-
raise ErrorInfo.bad_argument(
80-
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
81-
)
82-
83-
function_name = payload.get("functionName")
84-
if function_name is None or not len(function_name):
85-
raise ErrorInfo.bad_argument(
86-
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
87-
)
88-
89-
parameters = payload.get("parameters") or {}
90-
columns = parameters.get("columns")
91-
92-
return SubmitInvocation(
93-
function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command
94-
)
102+
103+
return extract_submit_invocation_from_descriptor(descriptor)

0 commit comments

Comments
 (0)