Skip to content

Commit f58c4ef

Browse files
committed
feat: add support for polling in FlexConnect server
Add a polling mechanism to the FlexConnect functions so that clients can poll for the results of long-running tasks and cancel them. The TaskExecutor can now return the timestamp at which a particular task was submitted; this is used to detect when a call exceeds its deadline across multiple poll iterations. JIRA: CQ-1124 risk: low
1 parent 4462856 commit f58c4ef

File tree

7 files changed

+226
-31
lines changed

7 files changed

+226
-31
lines changed

gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py

Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# (C) 2024 GoodData Corporation
2+
import time
23
from collections.abc import Generator
34
from typing import Optional
45

@@ -12,7 +13,6 @@
1213
FlightServerMethods,
1314
ServerContext,
1415
TaskExecutionResult,
15-
TaskWaitTimeoutError,
1616
flight_server_methods,
1717
)
1818

@@ -23,11 +23,26 @@
2323
_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")
2424

2525

26+
def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
    """
    Build the error that asks the client to poll again for the result of a task.

    The error carries two command descriptors derived from the task id: one
    prefixed with ``c:`` for cancelling the task and one prefixed with ``r:``
    for retrying (polling) it. ``get_flight_info`` recognizes these prefixes.

    :param task_id: identifier of the task that has not finished yet
    :return: error created by ``ErrorInfo.poll``; the caller is expected to raise it
    """
    cancel_descriptor = pyarrow.flight.FlightDescriptor.for_command(f"c:{task_id}".encode())
    retry_descriptor = pyarrow.flight.FlightDescriptor.for_command(f"r:{task_id}".encode())

    return ErrorInfo.poll(
        flight_info=None,
        cancel_descriptor=cancel_descriptor,
        retry_descriptor=retry_descriptor,
    )
32+
33+
2634
class _FlexConnectServerMethods(FlightServerMethods):
27-
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None:
35+
def __init__(
    self,
    ctx: ServerContext,
    registry: FlexConnectFunctionRegistry,
    call_deadline_ms: float,
    poll_interval_ms: float,
) -> None:
    """
    Create the Flight RPC methods that invoke FlexConnect functions.

    :param ctx: server context; stored for later use by the RPC methods
    :param registry: registry of the available FlexConnect functions
    :param call_deadline_ms: maximum duration of a whole function call, in milliseconds
    :param poll_interval_ms: duration of one polling iteration, in milliseconds
    """
    self._ctx = ctx
    self._registry = registry

    # both durations are configured in milliseconds but kept in seconds
    self._call_deadline = call_deadline_ms / 1000
    self._poll_interval = poll_interval_ms / 1000
3146

3247
@staticmethod
3348
def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor:
@@ -140,39 +155,53 @@ def get_flight_info(
140155
descriptor: pyarrow.flight.FlightDescriptor,
141156
) -> pyarrow.flight.FlightInfo:
142157
structlog.contextvars.bind_contextvars(peer=context.peer())
143-
task: Optional[FlexConnectFunctionTask] = None
144158

145-
try:
146-
task = self._prepare_task(context, descriptor)
147-
self._ctx.task_executor.submit(task)
159+
# first, make sure the descriptor carries a non-empty command; cancel and retry descriptors are recognized further below
160+
if descriptor.command is None or not len(descriptor.command):
161+
raise ErrorInfo.bad_argument(
162+
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
163+
"with the invocation payload."
164+
)
148165

149-
try:
150-
# XXX: this should be enhanced to implement polling
151-
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
152-
except TaskWaitTimeoutError:
153-
cancelled = self._ctx.task_executor.cancel(task.task_id)
154-
_LOGGER.warning(
155-
"flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled
156-
)
166+
task_id: str
167+
fun_name: Optional[str] = None
157168

169+
if descriptor.command.startswith(b"c:"):
170+
# cancel descriptor: just cancel the given task and raise cancellation exception
171+
task_id = descriptor.command[2:].decode()
172+
self._ctx.task_executor.cancel(task_id)
173+
raise ErrorInfo.for_reason(
174+
ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
175+
).to_cancelled_error()
176+
elif descriptor.command.startswith(b"r:"):
177+
# retry descriptor: extract the task_id, do not submit it again and do one polling iteration
178+
task_id = descriptor.command[2:].decode()
179+
# for retries, we also need to check the call deadline for the whole call duration
180+
task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
181+
if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline:
182+
self._ctx.task_executor.cancel(task_id)
158183
raise ErrorInfo.for_reason(
159-
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}."
184+
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
160185
).to_timeout_error()
186+
else:
187+
# basic first-time submit: submit the task and do one polling iteration.
188+
# do not check call deadline to give it a chance to wait for the result at least once
189+
try:
190+
task = self._prepare_task(context, descriptor)
191+
self._ctx.task_executor.submit(task)
192+
task_id = task.task_id
193+
fun_name = task.fun_name
194+
except Exception:
195+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
196+
raise
161197

162-
# if this bombs then there must be something really wrong because the task
163-
# was clearly submitted and code was waiting for its completion. this invariant
164-
# should not happen in this particular code path. The None return value may
165-
# be applicable one day when polling is in use and a request comes to check whether
166-
# particular task id finished
167-
assert task_result is not None
168-
198+
try:
199+
task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval)
169200
return self._prepare_flight_info(task_result)
201+
except TimeoutError:
202+
raise _prepare_poll_error(task_id)
170203
except Exception:
171-
if task is not None:
172-
_LOGGER.error("get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True)
173-
else:
174-
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
175-
204+
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True)
176205
raise
177206

178207
def do_get(
@@ -201,7 +230,9 @@ def do_get(
201230
_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"
202231
_FLEX_CONNECT_FUNCTION_LIST = "functions"
203232
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
233+
_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms"
204234
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000
235+
_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2000
205236

206237

207238
def _read_call_deadline_ms(ctx: ServerContext) -> int:
@@ -223,6 +254,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int:
223254
)
224255

225256

257+
def _read_polling_interval_ms(ctx: ServerContext) -> int:
    """
    Read the FlexConnect polling interval from the server settings.

    :param ctx: server context whose settings are consulted
    :return: polling interval in milliseconds; the built-in default is
        returned when the setting is not present
    :raises ValueError: when the configured value cannot be converted to a
        positive integer
    """
    setting_name = f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}"
    raw_value = ctx.settings.get(setting_name)
    if raw_value is None:
        return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS

    try:
        polling_interval = int(raw_value)
        if polling_interval <= 0:
            raise ValueError()
        return polling_interval
    except (TypeError, ValueError) as e:
        # int() raises TypeError (not ValueError) for values such as lists or
        # tables; report those with the same descriptive message as well
        raise ValueError(
            f"Value of {_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS} must "
            f"be a positive number - duration, in milliseconds, that FlexConnect function "
            f"waits for the result during one polling iteration."
        ) from e
273+
274+
226275
@flight_server_methods
227276
def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods:
228277
"""
@@ -236,8 +285,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
236285
"""
237286
modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or [])
238287
call_deadline_ms = _read_call_deadline_ms(ctx)
288+
polling_interval_ms = _read_polling_interval_ms(ctx)
239289

240290
_LOGGER.info("flexconnect_init", modules=modules)
241291
registry = FlexConnectFunctionRegistry().load(ctx, modules)
242292

243-
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms)
293+
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms, polling_interval_ms)

gooddata-flexconnect/tests/server/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def flexconnect_server(
8080
funs = f"[{funs}]"
8181

8282
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs
83-
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "500"
83+
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1000"
84+
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__POLLING_INTERVAL_MS"] = "500"
8485

8586
with server(create_flexconnect_flight_methods, tls, mtls) as s:
8687
yield s

gooddata-flexconnect/tests/server/funs/fun3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def call(
2727
) -> ArrowData:
2828
# sleep is intentionally setup to be longer than the deadline for
2929
# the function invocation (see conftest.py // flexconnect_server fixture)
30-
time.sleep(1)
30+
time.sleep(1.5)
3131

3232
return pyarrow.table(
3333
data={
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# (C) 2024 GoodData Corporation
2+
import time
3+
from typing import Optional
4+
5+
import pyarrow
6+
from gooddata_flexconnect.function.function import FlexConnectFunction
7+
from gooddata_flight_server import ArrowData
8+
9+
_DATA: Optional[pyarrow.Table] = None
10+
11+
12+
class _PollableFun(FlexConnectFunction):
    """
    Test function whose result only becomes available after more than one
    polling iteration of the test server.
    """

    Name = "PollableFun"
    Schema = pyarrow.schema(
        fields=[
            pyarrow.field("col1", pyarrow.int64()),
            pyarrow.field("col2", pyarrow.string()),
            pyarrow.field("col3", pyarrow.bool_()),
        ]
    )

    def call(
        self,
        parameters: dict,
        columns: tuple[str, ...],
        headers: dict[str, list[str]],
    ) -> ArrowData:
        """
        Sleep for a while and then return a fixed three-row table.

        None of the arguments influence the returned data.
        """
        # sleep is intentionally setup to be longer than one polling interval
        # (see conftest.py // flexconnect_server fixture)
        time.sleep(0.7)

        return pyarrow.table(
            data={
                "col1": [1, 2, 3],
                "col2": ["a", "b", "c"],
                "col3": [True, False, True],
            },
            schema=self.Schema,
        )

gooddata-flexconnect/tests/server/test_flexconnect_server.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# (C) 2024 GoodData Corporation
2+
23
import orjson
34
import pyarrow.flight
45
import pytest
5-
from gooddata_flight_server import ErrorCode
6+
from gooddata_flight_server import ErrorCode, ErrorInfo, RetryInfo
67

78
from tests.assert_error_info import assert_error_code
89
from tests.server.conftest import flexconnect_server
@@ -89,6 +90,43 @@ def test_basic_function_tls(tls_ca_cert):
8990
assert data.column_names == ["col1", "col2", "col3"]
9091

9192

93+
def test_function_with_polling():
    """
    Flight RPC implementation that invokes FlexConnect can return a polling info.

    This way, the client can poll for results that take longer to complete.
    """
    invocation = orjson.dumps(
        {
            "functionName": "PollableFun",
            "parameters": {"test1": 1, "test2": 2, "test3": 3},
        }
    )

    with flexconnect_server(["tests.server.funs.fun4"]) as s:
        client = pyarrow.flight.FlightClient(s.location)
        descriptor = pyarrow.flight.FlightDescriptor.for_command(invocation)

        # the function is set to sleep a bit longer than the polling interval,
        # so the first iteration returns retry info in the exception
        with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
            client.get_flight_info(descriptor)

        assert e.value is not None
        assert_error_code(ErrorCode.POLL, e.value)

        # the retry info travels in the body of the error info attached to the exception
        error_info = ErrorInfo.from_bytes(e.value.extra_info)
        retry_info = RetryInfo.from_bytes(error_info.body)

        # use the retry info to poll again for the result,
        # now it should be ready and returned normally
        info = client.get_flight_info(retry_info.retry_descriptor)
        data: pyarrow.Table = client.do_get(info.endpoints[0].ticket).read_all()

        assert len(data) == 3
        assert data.column_names == ["col1", "col2", "col3"]
128+
129+
92130
def test_function_with_call_deadline():
93131
"""
94132
Flight RPC implementation that invokes FlexConnect can be setup with
@@ -115,4 +153,54 @@ def test_function_with_call_deadline():
115153
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
116154
c.get_flight_info(descriptor)
117155

156+
assert e.value is not None
157+
assert_error_code(ErrorCode.POLL, e.value)
158+
159+
error_info = ErrorInfo.from_bytes(e.value.extra_info)
160+
retry_info = RetryInfo.from_bytes(error_info.body)
161+
162+
# poll twice to reach the call deadline
163+
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
164+
c.get_flight_info(retry_info.retry_descriptor)
165+
166+
assert e.value is not None
167+
assert_error_code(ErrorCode.POLL, e.value)
168+
169+
error_info = ErrorInfo.from_bytes(e.value.extra_info)
170+
retry_info = RetryInfo.from_bytes(error_info.body)
171+
172+
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
173+
c.get_flight_info(retry_info.retry_descriptor)
174+
175+
# and then ensure the timeout error is returned
118176
assert_error_code(ErrorCode.TIMEOUT, e.value)
177+
178+
179+
def test_function_with_cancelation():
    """
    Run a long-running function and cancel it after one poll iteration.
    """
    invocation = orjson.dumps(
        {
            "functionName": "LongRunningFun",
            "parameters": {"test1": 1, "test2": 2, "test3": 3},
        }
    )

    with flexconnect_server(["tests.server.funs.fun3"]) as s:
        client = pyarrow.flight.FlightClient(s.location)
        descriptor = pyarrow.flight.FlightDescriptor.for_command(invocation)

        # the first call does not finish within one polling iteration,
        # so the server answers with the poll error
        with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
            client.get_flight_info(descriptor)

        assert e.value is not None
        assert_error_code(ErrorCode.POLL, e.value)

        error_info = ErrorInfo.from_bytes(e.value.extra_info)
        retry_info = RetryInfo.from_bytes(error_info.body)

        with pytest.raises(pyarrow.flight.FlightCancelledError) as cancel_exc:
            client.get_flight_info(retry_info.cancel_descriptor)

        # checking for a non-None exception alone proves nothing after
        # pytest.raises; verify the server reported the cancellation code
        assert cancel_exc.value is not None
        assert_error_code(ErrorCode.COMMAND_CANCELLED, cancel_exc.value)

gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ def submit(
5454
"""
5555
raise NotImplementedError
5656

57+
@abc.abstractmethod
def get_task_submitted_timestamp(self, task_id: str) -> Optional[float]:
    """
    Returns the timestamp of when the task with the given id was submitted.

    NOTE(review): the FlexConnect ``get_flight_info`` implementation compares
    the returned value against ``time.perf_counter()`` to detect call deadline
    breaches, so implementations presumably have to report the timestamp on
    that clock rather than as wall-clock seconds since epoch - confirm.

    :param task_id: task id to get the timestamp for
    :return: timestamp, in seconds, of when the task was submitted or None if there is no such task
    """
    raise NotImplementedError
65+
5766
@abc.abstractmethod
5867
def wait_for_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[TaskExecutionResult]:
5968
"""

gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,14 @@ def submit(
565565
execution.start()
566566
self._metrics.queue_size.set(self._queue_size)
567567

568+
def get_task_submitted_timestamp(self, task_id: str) -> Optional[float]:
    """
    Return the ``created`` value from the stats of the execution registered
    under the given task id.

    :param task_id: task id to get the timestamp for
    :return: the execution's ``stats.created`` value, or None when no
        execution is registered under ``task_id``
    """
    # the lock is held only for the lookup in the executions registry
    with self._task_lock:
        execution = self._executions.get(task_id)

    if execution is None:
        return None
    return execution.stats.created
575+
568576
def wait_for_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[TaskExecutionResult]:
569577
with self._task_lock:
570578
execution = self._executions.get(task_id)

0 commit comments

Comments
 (0)