Skip to content

Commit e9973c4

Browse files
committed
Pipeline: properly implements 'errors=ignore' option
1 parent ff558a3 commit e9973c4

File tree

2 files changed

+96
-7
lines changed

2 files changed

+96
-7
lines changed

extraasync/pipeline.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def _sync_to_async_iterable() -> t.AsyncIterable[T]:
119119
return _sync_to_async_iterable()
120120

121121

122-
PipelineErrors = t.Literal["strict", "ignore", "lazy_raise"]
122+
PipelineErrors = t.Literal["eager", "ignore", "group"]
123123

124124

125125
class RateLimiter:
@@ -244,7 +244,7 @@ def _collect_result(self, task, next_):
244244
return next_((task.order_tag, task.result()))
245245
# Hardcoded: exception - ignore.
246246
logger.error("Exception in pipelined stage: %s", exc)
247-
self.parent.output.put_nowait((EXC_MARKER, exc))
247+
self.parent.output.put_nowait((EXC_MARKER, (task.order_tag, exc)))
248248

249249
def _create_task(self, value: tuple[int, t.Any]):
250250

@@ -312,7 +312,7 @@ def __init__(
312312
max_concurrency: t.Optional[int] = None,
313313
rate_limit: None | RateLimiter | Real = None,
314314
rate_limit_unit: TIME_UNIT = "second",
315-
on_error: PipelineErrors = "strict",
315+
on_error: PipelineErrors = "eager",
316316
preserve_order: bool = False,
317317
max_simultaneous_records: t.Optional[int] = None,
318318
sink: None | SupportsRShift | MutableSequence | MutableSet = None,
@@ -326,7 +326,7 @@ def __init__(
326326
(i.e. if there are 2 stages, and max_concurrency is set to 4, we may have
327327
up to 8 concurrent tasks running at once in the pipeline, but each stage is
328328
limited to 4)
329-
- on_error: WHat to do if any stage raises an exeception - defaults to re-raise the
329+
- on_error: WHat to do if any stage raises an exception - defaults to re-raise the
330330
exception and stop the whole pipeline
331331
- rate_limit: An overall rate-limitting parameter which can be used to throtle all stages.
332332
If anyone stage should have a limit different from the limit to the whole pipeline,
@@ -390,7 +390,7 @@ def chain_data(self, data_source):
390390
# TBD
391391

392392
async def __aiter__(self):
393-
"""Each iteration retrieves the next final result, after passing it trhough all teh stages
393+
"""Each iteration retrieves the next final result, after passing it trhough all the stages
394394
395395
NB: calling this a single time will trigger the Pipeline background execution, and
396396
more than one item can be (or will be) fectched from source in a single iteration,
@@ -429,9 +429,11 @@ async def __aiter__(self):
429429
if order_marker is EXC_MARKER:
430430
if self.on_error == "ignore":
431431
await asyncio.sleep(0)
432+
if self.preserve_order:
433+
self.ordered_results.push((result_data[0], EXC_MARKER))
432434
continue
433435
elif self.on_error == "strict":
434-
raise result_data
436+
raise result_data[1]
435437
elif self.on_error == "lazy":
436438
raise NotImplementedError("Lazy error raising in pipeline")
437439
if not self.preserve_order:
@@ -440,7 +442,8 @@ async def __aiter__(self):
440442
self.ordered_results.push((order_marker, result_data))
441443
if self.ordered_results.peek() == last_yielded_index + 1:
442444
last_yielded_index, result_data = self.ordered_results.pop()
443-
yield result_data
445+
if result_data is not EXC_MARKER:
446+
yield result_data
444447

445448
await asyncio.sleep(0)
446449

tests/test_pipeline.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,92 @@ async def map_function(n):
431431
assert set(results) == set(range(0, 2 * task_amount, 2))
432432

433433

434+
@pytest.mark.parametrize(
435+
["task_ammount", "failing_steps", "two_stages", "preserve_order"],
436+
[
437+
(
438+
10,
439+
{
440+
5,
441+
},
442+
False,
443+
True,
444+
),
445+
(
446+
10,
447+
{
448+
5,
449+
},
450+
True,
451+
True,
452+
),
453+
(
454+
10,
455+
{
456+
5,
457+
},
458+
True,
459+
False,
460+
),
461+
(
462+
10,
463+
{
464+
5,
465+
},
466+
False,
467+
False,
468+
),
469+
(10, {1, 3, 6, 9}, False, True),
470+
],
471+
)
472+
@pytest.mark.asyncio
473+
async def test_pipeline_runs_to_completion_when_ignoring_exceptions(
474+
task_ammount, failing_steps, two_stages, preserve_order
475+
):
476+
477+
async def producer(n, interval=0):
478+
for i in range(n):
479+
yield i
480+
await asyncio.sleep(interval)
481+
482+
async def map_function(n):
483+
if n in failing_steps:
484+
raise ValueError()
485+
await asyncio.sleep(0)
486+
return n * 2
487+
488+
async def second_stage(n):
489+
await asyncio.sleep(0)
490+
return n
491+
492+
use_second_stage = (second_stage,) if two_stages else ()
493+
494+
results = []
495+
496+
try:
497+
async with asyncio.timeout(0.1):
498+
async for result in Pipeline(
499+
producer(10),
500+
map_function,
501+
*use_second_stage,
502+
preserve_order=preserve_order,
503+
on_error="ignore",
504+
):
505+
results.append(result)
506+
except TimeoutError:
507+
508+
assert False, f"Timed out waiting pipeline to complete. Results: {results}"
509+
510+
expected = [
511+
item * 2 for item in range(0, task_ammount) if item not in failing_steps
512+
]
513+
if not preserve_order:
514+
results = set(results)
515+
expected = set(expected)
516+
517+
assert results == expected
518+
519+
434520
@pytest.mark.skip
435521
@pytest.mark.asyncio
436522
async def test_pipeline_max_simultaneous_record_limit(): ...

0 commit comments

Comments
 (0)