@@ -27,18 +27,27 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
-use datafusion::common::UnnestOptions;
-use datafusion::config::{CsvOptions, TableParquetOptions};
+use datafusion::common::stats::Precision;
+use datafusion::common::{DFSchema, UnnestOptions};
+use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
-use datafusion::execution::SendableRecordBatchStream;
+use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
+use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
+use datafusion::execution::runtime_env::RuntimeEnvBuilder;
+use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::physical_plan::{displayable, execute_stream, ExecutionPlan};
 use datafusion::prelude::*;
+use datafusion_expr::registry::MemoryFunctionRegistry;
+use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
+use datafusion_proto::protobuf::PhysicalPlanNode;
+use deltalake::delta_datafusion::DeltaPhysicalCodec;
+use prost::Message;
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;
-
 use crate::errors::py_datafusion_err;
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
@@ -49,6 +58,7 @@ use crate::{
     errors::DataFusionError,
     expr::{sort_expr::PySortExpr, PyExpr},
 };
+use crate::common::df_schema::PyDFSchema;
 
 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -650,6 +660,179 @@ impl PyDataFrame {
     fn count(&self, py: Python) -> PyResult<usize> {
         Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
     }
+
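+    /// Splits this DataFrame's physical plan into independently executable shards
+    /// (see `split_physical_plan` below) without executing it.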
+    fn distributed_plan(&self, num_shards: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
+        let distributed_plan = wait_for_future(py, split_physical_plan(&self.df, num_shards))
+            .map_err(py_datafusion_err)?;
+        Ok(distributed_plan)
+    }
+}
+
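+/// Row-count and byte-size estimates extracted from an `ExecutionPlan`, exposed to Python.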
+#[pyclass(get_all)]
+#[derive(Debug, Clone)]
+pub struct Statistics {
+    num_bytes: Option<usize>,
+    num_rows: Option<usize>,
+}
+
+impl Statistics {
+    fn new(plan: &dyn ExecutionPlan) -> Self {
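+        // Collapse DataFusion's `Precision` into a plain Option: exact and inexact
+        // estimates are treated alike, absent statistics become `None`.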
+        fn extract(prec: Precision<usize>) -> Option<usize> {
+            match prec {
+                Precision::Exact(n) | Precision::Inexact(n) => Some(n),
+                Precision::Absent => None,
+            }
+        }
+        if let Ok(stats) = plan.statistics() {
+            let num_bytes = extract(stats.total_byte_size);
+            let num_rows = extract(stats.num_rows);
+            Statistics { num_bytes, num_rows }
+        } else {
+            Statistics { num_bytes: None, num_rows: None }
+        }
+    }
+}
+
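+/// One unit of distributed work: a protobuf-serialized physical plan together with
+/// the statistics of the plan it was built from.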
+#[pyclass(get_all)]
+#[derive(Debug, Clone)]
+pub struct Shard {
+    stats: Statistics,
+    serialized_plan: Vec<u8>,
+}
+
+impl Shard {
+    pub fn try_new(plan: &Arc<dyn ExecutionPlan>) -> Result<Self, DataFusionError> {
+        let stats = Statistics::new(plan.as_ref());
+        let serialized_plan =
+            PhysicalPlanNode::try_from_physical_plan(plan.clone(), Self::codec())?.encode_to_vec();
+        Ok(Self { stats, serialized_plan })
+    }
+
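+    /// Extension codec that lets Delta Lake scan nodes survive (de)serialization.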
+    fn codec() -> &'static dyn PhysicalExtensionCodec {
+        static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
+        &CODEC
+    }
+}
+
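+/// A physical plan split into shards that can be executed on separate workers.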
+#[pyclass(get_all)]
+#[derive(Debug, Clone)]
+pub struct DistributedPlan {
+    shards: Vec<Shard>,
+    schema: PyDFSchema,
+    stats: Statistics,
+}
+
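+/// Builds the physical plan of `df` and splits it into shards by distributing the
+/// parquet file groups of its leaf `ParquetExec` across copies of the plan.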
+async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<DistributedPlan, DataFusionError> {
+    fn split(plan: &Arc<dyn ExecutionPlan>, num_shards: usize) -> Vec<Arc<dyn ExecutionPlan>> {
+        if let Some(parquet) = plan.as_any().downcast_ref::<ParquetExec>() {
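+            // Ask DataFusion to spread the parquet files across `num_shards` file
+            // groups; fall back to the exec's existing grouping if that fails.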
+            let parquet = if let Ok(Some(repartitioned)) =
+                parquet.repartitioned(num_shards, &ConfigOptions::default())
+            {
+                repartitioned
+                    .as_any()
+                    .downcast_ref::<ParquetExec>()
+                    .expect("repartitioned parquet is no longer parquet")
+                    .clone()
+            } else {
+                // repartitioning failed; keep the original grouping
+                parquet.clone()
+            };
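+            // Rebuild a standalone ParquetExec per file group, carrying over the
+            // schema, projection, limit, ordering, and pushed-down predicate.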
+            let config = parquet.base_config();
+            config
+                .file_groups
+                .iter()
+                .map(|shard| FileScanConfig {
+                    object_store_url: config.object_store_url.clone(),
+                    file_schema: config.file_schema.clone(),
+                    // one partition per file
+                    file_groups: shard.iter().map(|file| vec![file.to_owned()]).collect(),
+                    statistics: config.statistics.clone(),
+                    projection: config.projection.clone(),
+                    projection_deep: config.projection_deep.clone(),
+                    limit: config.limit,
+                    table_partition_cols: config.table_partition_cols.clone(),
+                    output_ordering: config.output_ordering.clone(),
+                })
+                .map(|config| {
+                    let mut builder = ParquetExecBuilder::new(config)
+                        .with_table_parquet_options(parquet.table_parquet_options().clone());
+                    if let Some(predicate) = parquet.predicate() {
+                        builder = builder.with_predicate(predicate.clone());
+                    }
+                    builder.build_arc()
+                })
+                .map(|shard| shard as Arc<dyn ExecutionPlan>)
+                .collect()
+        } else if plan.children().is_empty() {
+            // TODO: split leaf nodes other than parquet?
+            vec![plan.clone()]
+        } else if plan.children().len() == 1 {
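+            // Single-child node: split the child and re-wrap each resulting shard
+            // in a copy of this node.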
+            plan.children()
+                .into_iter()
+                .flat_map(|child| {
+                    split(child, num_shards)
+                        .into_iter()
+                        .map(|shard| plan.clone().with_new_children(vec![shard]))
+                })
+                .collect::<Result<Vec<_>, _>>()
+                .expect("Unable to split plan")
+        } else {
+            panic!(
+                "Only leaf or single-child plans are supported, found {}",
+                displayable(plan.as_ref()).one_line()
+            )
+        }
+    }
+    let plan = df.clone().create_physical_plan().await?;
+    let shards = split(&plan, num_shards)
+        .iter()
+        .map(Shard::try_new)
+        .collect::<Result<Vec<_>, _>>()?;
+    let schema = DFSchema::try_from(plan.schema().as_ref().to_owned())?.into();
+    let stats = Statistics::new(plan.as_ref());
+    Ok(DistributedPlan { shards, schema, stats })
+}
+
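+/// Deserializes a shard's physical plan and starts executing it on the global Tokio
+/// runtime, returning the resulting record batch stream.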
+#[pyfunction]
+pub fn shard_stream(serialized_shard_plan: &[u8], py: Python) -> PyResult<PyRecordBatchStream> {
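+    // Assumes the shard plan references no user-defined functions, so an empty
+    // function registry and a default runtime environment suffice.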
+    deltalake::ensure_initialized();
+    let registry = MemoryFunctionRegistry::default();
+    let runtime = RuntimeEnvBuilder::new().build()?;
+    let codec = DeltaPhysicalCodec {};
+    let node = PhysicalPlanNode::decode(serialized_shard_plan)
+        .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
+        .map_err(py_datafusion_err)?;
+    let plan = node.try_into_physical_plan(&registry, &runtime, &codec)?;
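+    // Create the stream from within the Tokio runtime so that tasks spawned during
+    // execution have a reactor to run on.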
+    let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
+        execute_stream(plan, Arc::new(TaskContext::default()))
+    });
+    wait_for_future(py, stream_with_runtime)
+        .map_err(py_datafusion_err)?
+        .map(PyRecordBatchStream::new)
+        .map_err(py_datafusion_err)
 }
 
 /// Print DataFrame