@@ -38,6 +38,113 @@ use tokio::task::JoinSet;
3838use super :: metrics:: BaselineMetrics ;
3939use super :: { ExecutionPlan , RecordBatchStream , SendableRecordBatchStream } ;
4040
41+ /// Creates a stream from a collection of producing tasks, routing panics to the stream
42+ pub ( crate ) struct ReceiverStreamBuilder < O > {
43+ tx : Sender < Result < O > > ,
44+ rx : Receiver < Result < O > > ,
45+ join_set : JoinSet < Result < ( ) > > ,
46+ }
47+
48+ impl < O : Send + ' static > ReceiverStreamBuilder < O > {
49+ /// create new channels with the specified buffer size
50+ pub fn new ( capacity : usize ) -> Self {
51+ let ( tx, rx) = tokio:: sync:: mpsc:: channel ( capacity) ;
52+
53+ Self {
54+ tx,
55+ rx,
56+ join_set : JoinSet :: new ( ) ,
57+ }
58+ }
59+
60+ /// Get a handle for sending [`O`] to the output
61+ pub fn tx ( & self ) -> Sender < Result < O > > {
62+ self . tx . clone ( )
63+ }
64+
65+ /// Spawn task that will be aborted if this builder (or the stream
66+ /// built from it) are dropped
67+ pub fn spawn < F > ( & mut self , task : F )
68+ where
69+ F : Future < Output = Result < ( ) > > ,
70+ F : Send + ' static ,
71+ {
72+ self . join_set . spawn ( task) ;
73+ }
74+
75+ /// Spawn a blocking task that will be aborted if this builder (or the stream
76+ /// built from it) are dropped
77+ ///
78+ /// this is often used to spawn tasks that write to the sender
79+ /// retrieved from `Self::tx`
80+ pub fn spawn_blocking < F > ( & mut self , f : F )
81+ where
82+ F : FnOnce ( ) -> Result < ( ) > ,
83+ F : Send + ' static ,
84+ {
85+ self . join_set . spawn_blocking ( f) ;
86+ }
87+
88+ /// Create a stream of all [`O`] written to `tx`
89+ pub fn build ( self ) -> BoxStream < ' static , Result < O > > {
90+ let Self {
91+ tx,
92+ rx,
93+ mut join_set,
94+ } = self ;
95+
96+ // don't need tx
97+ drop ( tx) ;
98+
99+ // future that checks the result of the join set, and propagates panic if seen
100+ let check = async move {
101+ while let Some ( result) = join_set. join_next ( ) . await {
102+ match result {
103+ Ok ( task_result) => {
104+ match task_result {
105+ // nothing to report
106+ Ok ( _) => continue ,
107+ // This means a blocking task error
108+ Err ( e) => {
109+ return Some ( exec_err ! ( "Spawned Task error: {e}" ) ) ;
110+ }
111+ }
112+ }
113+ // This means a tokio task error, likely a panic
114+ Err ( e) => {
115+ if e. is_panic ( ) {
116+ // resume on the main thread
117+ std:: panic:: resume_unwind ( e. into_panic ( ) ) ;
118+ } else {
119+ // This should only occur if the task is
120+ // cancelled, which would only occur if
121+ // the JoinSet were aborted, which in turn
122+ // would imply that the receiver has been
123+ // dropped and this code is not running
124+ return Some ( internal_err ! ( "Non Panic Task error: {e}" ) ) ;
125+ }
126+ }
127+ }
128+ }
129+ None
130+ } ;
131+
132+ let check_stream = futures:: stream:: once ( check)
133+ // unwrap Option / only return the error
134+ . filter_map ( |item| async move { item } ) ;
135+
136+ // Convert the receiver into a stream
137+ let rx_stream = futures:: stream:: unfold ( rx, |mut rx| async move {
138+ let next_item = rx. recv ( ) . await ;
139+ next_item. map ( |next_item| ( next_item, rx) )
140+ } ) ;
141+
142+ // Merge the streams together so whichever is ready first
143+ // produces the batch
144+ futures:: stream:: select ( rx_stream, check_stream) . boxed ( )
145+ }
146+ }
147+
41148/// Builder for [`RecordBatchReceiverStream`] that propagates errors
42149/// and panic's correctly.
43150///
@@ -47,28 +154,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
47154///
48155/// This also handles propagating panic`s and canceling the tasks.
49156pub struct RecordBatchReceiverStreamBuilder {
50- tx : Sender < Result < RecordBatch > > ,
51- rx : Receiver < Result < RecordBatch > > ,
52157 schema : SchemaRef ,
53- join_set : JoinSet < Result < ( ) > > ,
158+ inner : ReceiverStreamBuilder < RecordBatch > ,
54159}
55160
56161impl RecordBatchReceiverStreamBuilder {
57162 /// create new channels with the specified buffer size
58163 pub fn new ( schema : SchemaRef , capacity : usize ) -> Self {
59- let ( tx, rx) = tokio:: sync:: mpsc:: channel ( capacity) ;
60-
61164 Self {
62- tx,
63- rx,
64165 schema,
65- join_set : JoinSet :: new ( ) ,
166+ inner : ReceiverStreamBuilder :: new ( capacity ) ,
66167 }
67168 }
68169
69- /// Get a handle for sending [`RecordBatch`]es to the output
170+ /// Get a handle for sending [`RecordBatch`] to the output
70171 pub fn tx ( & self ) -> Sender < Result < RecordBatch > > {
71- self . tx . clone ( )
172+ self . inner . tx ( )
72173 }
73174
74175 /// Spawn task that will be aborted if this builder (or the stream
@@ -81,7 +182,7 @@ impl RecordBatchReceiverStreamBuilder {
81182 F : Future < Output = Result < ( ) > > ,
82183 F : Send + ' static ,
83184 {
84- self . join_set . spawn ( task) ;
185+ self . inner . spawn ( task)
85186 }
86187
87188 /// Spawn a blocking task that will be aborted if this builder (or the stream
@@ -94,7 +195,7 @@ impl RecordBatchReceiverStreamBuilder {
94195 F : FnOnce ( ) -> Result < ( ) > ,
95196 F : Send + ' static ,
96197 {
97- self . join_set . spawn_blocking ( f) ;
198+ self . inner . spawn_blocking ( f)
98199 }
99200
100201 /// runs the input_partition of the `input` ExecutionPlan on the
@@ -110,7 +211,7 @@ impl RecordBatchReceiverStreamBuilder {
110211 ) {
111212 let output = self . tx ( ) ;
112213
113- self . spawn ( async move {
214+ self . inner . spawn ( async move {
114215 let mut stream = match input. execute ( partition, context) {
115216 Err ( e) => {
116217 // If send fails, the plan being torn down, there
@@ -155,80 +256,14 @@ impl RecordBatchReceiverStreamBuilder {
155256 } ) ;
156257 }
157258
158- /// Create a stream of all `RecordBatch`es written to `tx`
259+ /// Create a stream of all [ `RecordBatch`] written to `tx`
159260 pub fn build ( self ) -> SendableRecordBatchStream {
160- let Self {
161- tx,
162- rx,
163- schema,
164- mut join_set,
165- } = self ;
166-
167- // don't need tx
168- drop ( tx) ;
169-
170- // future that checks the result of the join set, and propagates panic if seen
171- let check = async move {
172- while let Some ( result) = join_set. join_next ( ) . await {
173- match result {
174- Ok ( task_result) => {
175- match task_result {
176- // nothing to report
177- Ok ( _) => continue ,
178- // This means a blocking task error
179- Err ( e) => {
180- return Some ( exec_err ! ( "Spawned Task error: {e}" ) ) ;
181- }
182- }
183- }
184- // This means a tokio task error, likely a panic
185- Err ( e) => {
186- if e. is_panic ( ) {
187- // resume on the main thread
188- std:: panic:: resume_unwind ( e. into_panic ( ) ) ;
189- } else {
190- // This should only occur if the task is
191- // cancelled, which would only occur if
192- // the JoinSet were aborted, which in turn
193- // would imply that the receiver has been
194- // dropped and this code is not running
195- return Some ( internal_err ! ( "Non Panic Task error: {e}" ) ) ;
196- }
197- }
198- }
199- }
200- None
201- } ;
202-
203- let check_stream = futures:: stream:: once ( check)
204- // unwrap Option / only return the error
205- . filter_map ( |item| async move { item } ) ;
206-
207- // Convert the receiver into a stream
208- let rx_stream = futures:: stream:: unfold ( rx, |mut rx| async move {
209- let next_item = rx. recv ( ) . await ;
210- next_item. map ( |next_item| ( next_item, rx) )
211- } ) ;
212-
213- // Merge the streams together so whichever is ready first
214- // produces the batch
215- let inner = futures:: stream:: select ( rx_stream, check_stream) . boxed ( ) ;
216-
217- Box :: pin ( RecordBatchReceiverStream { schema, inner } )
261+ Box :: pin ( RecordBatchStreamAdapter :: new ( self . schema , self . inner . build ( ) ) )
218262 }
219263}
220264
221- /// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs,
222- /// on new tokio Tasks, increasing the potential parallelism.
223- ///
224- /// This structure also handles propagating panics and cancelling the
225- /// underlying tasks correctly.
226- ///
227- /// Use [`Self::builder`] to construct one.
228- pub struct RecordBatchReceiverStream {
229- schema : SchemaRef ,
230- inner : BoxStream < ' static , Result < RecordBatch > > ,
231- }
265+ #[ doc( hidden) ]
266+ pub struct RecordBatchReceiverStream { }
232267
233268impl RecordBatchReceiverStream {
234269 /// Create a builder with an internal buffer of capacity batches.
@@ -240,23 +275,6 @@ impl RecordBatchReceiverStream {
240275 }
241276}
242277
243- impl Stream for RecordBatchReceiverStream {
244- type Item = Result < RecordBatch > ;
245-
246- fn poll_next (
247- mut self : Pin < & mut Self > ,
248- cx : & mut Context < ' _ > ,
249- ) -> Poll < Option < Self :: Item > > {
250- self . inner . poll_next_unpin ( cx)
251- }
252- }
253-
254- impl RecordBatchStream for RecordBatchReceiverStream {
255- fn schema ( & self ) -> SchemaRef {
256- self . schema . clone ( )
257- }
258- }
259-
260278pin_project ! {
261279 /// Combines a [`Stream`] with a [`SchemaRef`] implementing
262280 /// [`RecordBatchStream`] for the combination
0 commit comments