@@ -24,27 +24,156 @@ use std::fs::File;
2424use std:: io:: BufReader ;
2525use std:: path:: { Path , PathBuf } ;
2626use std:: ptr:: NonNull ;
27+ use std:: sync:: Arc ;
2728
2829use arrow:: array:: ArrayData ;
2930use arrow:: datatypes:: { Schema , SchemaRef } ;
3031use arrow:: ipc:: { reader:: StreamReader , writer:: StreamWriter } ;
3132use arrow:: record_batch:: RecordBatch ;
32- use tokio:: sync:: mpsc:: Sender ;
33-
34- use datafusion_common:: { exec_datafusion_err, HashSet , Result } ;
35-
36- fn read_spill ( sender : Sender < Result < RecordBatch > > , path : & Path ) -> Result < ( ) > {
37- let file = BufReader :: new ( File :: open ( path) ?) ;
38- // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
39- // with validated schemas and buffers. Skip redundant validation during read
40- // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written.
41- let reader = unsafe { StreamReader :: try_new ( file, None ) ?. with_skip_validation ( true ) } ;
42- for batch in reader {
43- sender
44- . blocking_send ( batch. map_err ( Into :: into) )
45- . map_err ( |e| exec_datafusion_err ! ( "{e}" ) ) ?;
33+
34+ use datafusion_common:: { exec_datafusion_err, DataFusionError , HashSet , Result } ;
35+ use datafusion_common_runtime:: SpawnedTask ;
36+ use datafusion_execution:: disk_manager:: RefCountedTempFile ;
37+ use datafusion_execution:: RecordBatchStream ;
38+ use futures:: { FutureExt as _, Stream } ;
39+
40+ /// Stream that reads spill files from disk where each batch is read in a spawned blocking task
41+ /// It will read one batch at a time and will not do any buffering, to buffer data use [`crate::common::spawn_buffered`]
42+ struct SpillReaderStream {
43+ schema : SchemaRef ,
44+ state : SpillReaderStreamState ,
45+ }
46+
47+ /// When we poll for the next batch, we will get back both the batch and the reader,
48+ /// so we can call `next` again.
49+ type NextRecordBatchResult = Result < ( StreamReader < BufReader < File > > , Option < RecordBatch > ) > ;
50+
51+ enum SpillReaderStreamState {
52+ /// Initial state: the stream was not initialized yet
53+ /// and the file was not opened
54+ Uninitialized ( RefCountedTempFile ) ,
55+
56+ /// A read is in progress in a spawned blocking task for which we hold the handle.
57+ ReadInProgress ( SpawnedTask < NextRecordBatchResult > ) ,
58+
59+ /// A read has finished and we wait for being polled again in order to start reading the next batch.
60+ Waiting ( StreamReader < BufReader < File > > ) ,
61+
62+ /// The stream has finished, successfully or not.
63+ Done ,
64+ }
65+
66+ impl SpillReaderStream {
67+ fn new ( schema : SchemaRef , spill_file : RefCountedTempFile ) -> Self {
68+ Self {
69+ schema,
70+ state : SpillReaderStreamState :: Uninitialized ( spill_file) ,
71+ }
72+ }
73+
74+ fn poll_next_inner (
75+ & mut self ,
76+ cx : & mut std:: task:: Context < ' _ > ,
77+ ) -> std:: task:: Poll < Option < Result < RecordBatch > > > {
78+ match & mut self . state {
79+ SpillReaderStreamState :: Uninitialized ( _) => {
80+ // Temporarily replace with `Done` to be able to pass the file to the task.
81+ let SpillReaderStreamState :: Uninitialized ( spill_file) =
82+ std:: mem:: replace ( & mut self . state , SpillReaderStreamState :: Done )
83+ else {
84+ unreachable ! ( )
85+ } ;
86+
87+ let task = SpawnedTask :: spawn_blocking ( move || {
88+ let file = BufReader :: new ( File :: open ( spill_file. path ( ) ) ?) ;
89+ // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
90+ // with validated schemas and buffers. Skip redundant validation during read
91+ // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written.
92+ let mut reader = unsafe {
93+ StreamReader :: try_new ( file, None ) ?. with_skip_validation ( true )
94+ } ;
95+
96+ let next_batch = reader. next ( ) . transpose ( ) ?;
97+
98+ Ok ( ( reader, next_batch) )
99+ } ) ;
100+
101+ self . state = SpillReaderStreamState :: ReadInProgress ( task) ;
102+
103+ // Poll again immediately so the inner task is polled and the waker is
104+ // registered.
105+ self . poll_next_inner ( cx)
106+ }
107+
108+ SpillReaderStreamState :: ReadInProgress ( task) => {
109+ let result = futures:: ready!( task. poll_unpin( cx) )
110+ . unwrap_or_else ( |err| Err ( DataFusionError :: External ( Box :: new ( err) ) ) ) ;
111+
112+ match result {
113+ Ok ( ( reader, batch) ) => {
114+ match batch {
115+ Some ( batch) => {
116+ self . state = SpillReaderStreamState :: Waiting ( reader) ;
117+
118+ std:: task:: Poll :: Ready ( Some ( Ok ( batch) ) )
119+ }
120+ None => {
121+ // Stream is done
122+ self . state = SpillReaderStreamState :: Done ;
123+
124+ std:: task:: Poll :: Ready ( None )
125+ }
126+ }
127+ }
128+ Err ( err) => {
129+ self . state = SpillReaderStreamState :: Done ;
130+
131+ std:: task:: Poll :: Ready ( Some ( Err ( err) ) )
132+ }
133+ }
134+ }
135+
136+ SpillReaderStreamState :: Waiting ( _) => {
137+ // Temporarily replace with `Done` to be able to pass the file to the task.
138+ let SpillReaderStreamState :: Waiting ( mut reader) =
139+ std:: mem:: replace ( & mut self . state , SpillReaderStreamState :: Done )
140+ else {
141+ unreachable ! ( )
142+ } ;
143+
144+ let task = SpawnedTask :: spawn_blocking ( move || {
145+ let next_batch = reader. next ( ) . transpose ( ) ?;
146+
147+ Ok ( ( reader, next_batch) )
148+ } ) ;
149+
150+ self . state = SpillReaderStreamState :: ReadInProgress ( task) ;
151+
152+ // Poll again immediately so the inner task is polled and the waker is
153+ // registered.
154+ self . poll_next_inner ( cx)
155+ }
156+
157+ SpillReaderStreamState :: Done => std:: task:: Poll :: Ready ( None ) ,
158+ }
159+ }
160+ }
161+
162+ impl Stream for SpillReaderStream {
163+ type Item = Result < RecordBatch > ;
164+
165+ fn poll_next (
166+ self : std:: pin:: Pin < & mut Self > ,
167+ cx : & mut std:: task:: Context < ' _ > ,
168+ ) -> std:: task:: Poll < Option < Self :: Item > > {
169+ self . get_mut ( ) . poll_next_inner ( cx)
170+ }
171+ }
172+
173+ impl RecordBatchStream for SpillReaderStream {
174+ fn schema ( & self ) -> SchemaRef {
175+ Arc :: clone ( & self . schema )
46176 }
47- Ok ( ( ) )
48177}
49178
50179/// Spill the `RecordBatch` to disk as smaller batches
@@ -205,6 +334,7 @@ mod tests {
205334 use arrow:: record_batch:: RecordBatch ;
206335 use datafusion_common:: Result ;
207336 use datafusion_execution:: runtime_env:: RuntimeEnv ;
337+ use futures:: StreamExt as _;
208338
209339 use std:: sync:: Arc ;
210340
@@ -604,4 +734,42 @@ mod tests {
604734
605735 Ok ( ( ) )
606736 }
737+
738+ #[ test]
739+ fn test_reading_more_spills_than_tokio_blocking_threads ( ) -> Result < ( ) > {
740+ tokio:: runtime:: Builder :: new_current_thread ( )
741+ . enable_all ( )
742+ . max_blocking_threads ( 1 )
743+ . build ( )
744+ . unwrap ( )
745+ . block_on ( async {
746+ let batch = build_table_i32 (
747+ ( "a2" , & vec ! [ 0 , 1 , 2 ] ) ,
748+ ( "b2" , & vec ! [ 3 , 4 , 5 ] ) ,
749+ ( "c2" , & vec ! [ 4 , 5 , 6 ] ) ,
750+ ) ;
751+
752+ let schema = batch. schema ( ) ;
753+
754+ // Construct SpillManager
755+ let env = Arc :: new ( RuntimeEnv :: default ( ) ) ;
756+ let metrics = SpillMetrics :: new ( & ExecutionPlanMetricsSet :: new ( ) , 0 ) ;
757+ let spill_manager = SpillManager :: new ( env, metrics, Arc :: clone ( & schema) ) ;
758+ let batches: [ _ ; 10 ] = std:: array:: from_fn ( |_| batch. clone ( ) ) ;
759+
760+ let spill_file_1 = spill_manager
761+ . spill_record_batch_and_finish ( & batches, "Test1" ) ?
762+ . unwrap ( ) ;
763+ let spill_file_2 = spill_manager
764+ . spill_record_batch_and_finish ( & batches, "Test2" ) ?
765+ . unwrap ( ) ;
766+
767+ let mut stream_1 = spill_manager. read_spill_as_stream ( spill_file_1) ?;
768+ let mut stream_2 = spill_manager. read_spill_as_stream ( spill_file_2) ?;
769+ stream_1. next ( ) . await ;
770+ stream_2. next ( ) . await ;
771+
772+ Ok ( ( ) )
773+ } )
774+ }
607775}
0 commit comments