Skip to content

Commit 744971e

Browse files
committed
Extract ReceiverStreamBuilder
1 parent 1bfe740 commit 744971e

File tree

1 file changed

+118
-100
lines changed

1 file changed

+118
-100
lines changed

datafusion/physical-plan/src/stream.rs

Lines changed: 118 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,113 @@ use tokio::task::JoinSet;
3838
use super::metrics::BaselineMetrics;
3939
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
4040

41+
/// Creates a stream from a collection of producing tasks, routing panics to the stream
42+
pub(crate) struct ReceiverStreamBuilder<O> {
43+
tx: Sender<Result<O>>,
44+
rx: Receiver<Result<O>>,
45+
join_set: JoinSet<Result<()>>,
46+
}
47+
48+
impl<O: Send + 'static> ReceiverStreamBuilder<O> {
49+
/// create new channels with the specified buffer size
50+
pub fn new(capacity: usize) -> Self {
51+
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
52+
53+
Self {
54+
tx,
55+
rx,
56+
join_set: JoinSet::new(),
57+
}
58+
}
59+
60+
/// Get a handle for sending [`O`] to the output
61+
pub fn tx(&self) -> Sender<Result<O>> {
62+
self.tx.clone()
63+
}
64+
65+
/// Spawn task that will be aborted if this builder (or the stream
66+
/// built from it) are dropped
67+
pub fn spawn<F>(&mut self, task: F)
68+
where
69+
F: Future<Output = Result<()>>,
70+
F: Send + 'static,
71+
{
72+
self.join_set.spawn(task);
73+
}
74+
75+
/// Spawn a blocking task that will be aborted if this builder (or the stream
76+
/// built from it) are dropped
77+
///
78+
/// this is often used to spawn tasks that write to the sender
79+
/// retrieved from `Self::tx`
80+
pub fn spawn_blocking<F>(&mut self, f: F)
81+
where
82+
F: FnOnce() -> Result<()>,
83+
F: Send + 'static,
84+
{
85+
self.join_set.spawn_blocking(f);
86+
}
87+
88+
/// Create a stream of all [`O`] written to `tx`
89+
pub fn build(self) -> BoxStream<'static, Result<O>> {
90+
let Self {
91+
tx,
92+
rx,
93+
mut join_set,
94+
} = self;
95+
96+
// don't need tx
97+
drop(tx);
98+
99+
// future that checks the result of the join set, and propagates panic if seen
100+
let check = async move {
101+
while let Some(result) = join_set.join_next().await {
102+
match result {
103+
Ok(task_result) => {
104+
match task_result {
105+
// nothing to report
106+
Ok(_) => continue,
107+
// This means a blocking task error
108+
Err(e) => {
109+
return Some(exec_err!("Spawned Task error: {e}"));
110+
}
111+
}
112+
}
113+
// This means a tokio task error, likely a panic
114+
Err(e) => {
115+
if e.is_panic() {
116+
// resume on the main thread
117+
std::panic::resume_unwind(e.into_panic());
118+
} else {
119+
// This should only occur if the task is
120+
// cancelled, which would only occur if
121+
// the JoinSet were aborted, which in turn
122+
// would imply that the receiver has been
123+
// dropped and this code is not running
124+
return Some(internal_err!("Non Panic Task error: {e}"));
125+
}
126+
}
127+
}
128+
}
129+
None
130+
};
131+
132+
let check_stream = futures::stream::once(check)
133+
// unwrap Option / only return the error
134+
.filter_map(|item| async move { item });
135+
136+
// Convert the receiver into a stream
137+
let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
138+
let next_item = rx.recv().await;
139+
next_item.map(|next_item| (next_item, rx))
140+
});
141+
142+
// Merge the streams together so whichever is ready first
143+
// produces the batch
144+
futures::stream::select(rx_stream, check_stream).boxed()
145+
}
146+
}
147+
41148
/// Builder for [`RecordBatchReceiverStream`] that propagates errors
42149
/// and panic's correctly.
43150
///
@@ -47,28 +154,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
47154
///
48155
/// This also handles propagating panic`s and canceling the tasks.
49156
pub struct RecordBatchReceiverStreamBuilder {
50-
tx: Sender<Result<RecordBatch>>,
51-
rx: Receiver<Result<RecordBatch>>,
52157
schema: SchemaRef,
53-
join_set: JoinSet<Result<()>>,
158+
inner: ReceiverStreamBuilder<RecordBatch>,
54159
}
55160

56161
impl RecordBatchReceiverStreamBuilder {
57162
/// create new channels with the specified buffer size
58163
pub fn new(schema: SchemaRef, capacity: usize) -> Self {
59-
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
60-
61164
Self {
62-
tx,
63-
rx,
64165
schema,
65-
join_set: JoinSet::new(),
166+
inner: ReceiverStreamBuilder::new(capacity),
66167
}
67168
}
68169

69-
/// Get a handle for sending [`RecordBatch`]es to the output
170+
/// Get a handle for sending [`RecordBatch`] to the output
70171
pub fn tx(&self) -> Sender<Result<RecordBatch>> {
71-
self.tx.clone()
172+
self.inner.tx()
72173
}
73174

74175
/// Spawn task that will be aborted if this builder (or the stream
@@ -81,7 +182,7 @@ impl RecordBatchReceiverStreamBuilder {
81182
F: Future<Output = Result<()>>,
82183
F: Send + 'static,
83184
{
84-
self.join_set.spawn(task);
185+
self.inner.spawn(task)
85186
}
86187

87188
/// Spawn a blocking task that will be aborted if this builder (or the stream
@@ -94,7 +195,7 @@ impl RecordBatchReceiverStreamBuilder {
94195
F: FnOnce() -> Result<()>,
95196
F: Send + 'static,
96197
{
97-
self.join_set.spawn_blocking(f);
198+
self.inner.spawn_blocking(f)
98199
}
99200

100201
/// runs the input_partition of the `input` ExecutionPlan on the
@@ -110,7 +211,7 @@ impl RecordBatchReceiverStreamBuilder {
110211
) {
111212
let output = self.tx();
112213

113-
self.spawn(async move {
214+
self.inner.spawn(async move {
114215
let mut stream = match input.execute(partition, context) {
115216
Err(e) => {
116217
// If send fails, the plan being torn down, there
@@ -155,80 +256,14 @@ impl RecordBatchReceiverStreamBuilder {
155256
});
156257
}
157258

158-
/// Create a stream of all `RecordBatch`es written to `tx`
259+
/// Create a stream of all [`RecordBatch`] written to `tx`
159260
pub fn build(self) -> SendableRecordBatchStream {
160-
let Self {
161-
tx,
162-
rx,
163-
schema,
164-
mut join_set,
165-
} = self;
166-
167-
// don't need tx
168-
drop(tx);
169-
170-
// future that checks the result of the join set, and propagates panic if seen
171-
let check = async move {
172-
while let Some(result) = join_set.join_next().await {
173-
match result {
174-
Ok(task_result) => {
175-
match task_result {
176-
// nothing to report
177-
Ok(_) => continue,
178-
// This means a blocking task error
179-
Err(e) => {
180-
return Some(exec_err!("Spawned Task error: {e}"));
181-
}
182-
}
183-
}
184-
// This means a tokio task error, likely a panic
185-
Err(e) => {
186-
if e.is_panic() {
187-
// resume on the main thread
188-
std::panic::resume_unwind(e.into_panic());
189-
} else {
190-
// This should only occur if the task is
191-
// cancelled, which would only occur if
192-
// the JoinSet were aborted, which in turn
193-
// would imply that the receiver has been
194-
// dropped and this code is not running
195-
return Some(internal_err!("Non Panic Task error: {e}"));
196-
}
197-
}
198-
}
199-
}
200-
None
201-
};
202-
203-
let check_stream = futures::stream::once(check)
204-
// unwrap Option / only return the error
205-
.filter_map(|item| async move { item });
206-
207-
// Convert the receiver into a stream
208-
let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
209-
let next_item = rx.recv().await;
210-
next_item.map(|next_item| (next_item, rx))
211-
});
212-
213-
// Merge the streams together so whichever is ready first
214-
// produces the batch
215-
let inner = futures::stream::select(rx_stream, check_stream).boxed();
216-
217-
Box::pin(RecordBatchReceiverStream { schema, inner })
261+
Box::pin(RecordBatchStreamAdapter::new(self.schema, self.inner.build()))
218262
}
219263
}
220264

221-
/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs,
222-
/// on new tokio Tasks, increasing the potential parallelism.
223-
///
224-
/// This structure also handles propagating panics and cancelling the
225-
/// underlying tasks correctly.
226-
///
227-
/// Use [`Self::builder`] to construct one.
228-
pub struct RecordBatchReceiverStream {
229-
schema: SchemaRef,
230-
inner: BoxStream<'static, Result<RecordBatch>>,
231-
}
265+
#[doc(hidden)]
266+
pub struct RecordBatchReceiverStream {}
232267

233268
impl RecordBatchReceiverStream {
234269
/// Create a builder with an internal buffer of capacity batches.
@@ -240,23 +275,6 @@ impl RecordBatchReceiverStream {
240275
}
241276
}
242277

243-
impl Stream for RecordBatchReceiverStream {
244-
type Item = Result<RecordBatch>;
245-
246-
fn poll_next(
247-
mut self: Pin<&mut Self>,
248-
cx: &mut Context<'_>,
249-
) -> Poll<Option<Self::Item>> {
250-
self.inner.poll_next_unpin(cx)
251-
}
252-
}
253-
254-
impl RecordBatchStream for RecordBatchReceiverStream {
255-
fn schema(&self) -> SchemaRef {
256-
self.schema.clone()
257-
}
258-
}
259-
260278
pin_project! {
261279
/// Combines a [`Stream`] with a [`SchemaRef`] implementing
262280
/// [`RecordBatchStream`] for the combination

0 commit comments

Comments
 (0)