Skip to content

Commit fcabbb2

Browse files
committed
Building blocks for Ray DataFusionDatasource
1 parent 6badcbd commit fcabbb2

File tree

3 files changed

+165
-4
lines changed

3 files changed

+165
-4
lines changed

python/datafusion/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,9 @@ def count(self) -> int:
708708
"""
709709
return self.df.count()
710710

711+
def distributed_plan(self, num_shards: int):
    """Forward a distributed-plan request to the wrapped DataFrame.

    Args:
        num_shards: Shard count handed unchanged to the wrapped
            DataFrame's ``distributed_plan``.

    Returns:
        Whatever the wrapped DataFrame's ``distributed_plan`` returns.
    """
    plan = self.df.distributed_plan(num_shards)
    return plan
713+
711714
@deprecated("Use :py:func:`unnest_columns` instead.")
712715
def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame:
713716
"""See :py:func:`unnest_columns`."""

src/dataframe.rs

Lines changed: 159 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,27 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
2727
use datafusion::arrow::datatypes::Schema;
2828
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2929
use datafusion::arrow::util::pretty;
30-
use datafusion::common::UnnestOptions;
31-
use datafusion::config::{CsvOptions, TableParquetOptions};
30+
use datafusion::common::stats::Precision;
31+
use datafusion::common::{DFSchema, UnnestOptions};
32+
use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
3233
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
33-
use datafusion::execution::SendableRecordBatchStream;
34+
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
35+
use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
36+
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
37+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
3438
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
39+
use datafusion::physical_plan::{displayable, execute_stream, ExecutionPlan};
3540
use datafusion::prelude::*;
41+
use datafusion_expr::registry::MemoryFunctionRegistry;
42+
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
43+
use datafusion_proto::protobuf::PhysicalPlanNode;
44+
use deltalake::delta_datafusion::DeltaPhysicalCodec;
45+
use prost::Message;
3646
use pyo3::exceptions::{PyTypeError, PyValueError};
3747
use pyo3::prelude::*;
3848
use pyo3::pybacked::PyBackedStr;
3949
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
4050
use tokio::task::JoinHandle;
41-
4251
use crate::errors::py_datafusion_err;
4352
use crate::expr::sort_expr::to_sort_expressions;
4453
use crate::physical_plan::PyExecutionPlan;
@@ -49,6 +58,7 @@ use crate::{
4958
errors::DataFusionError,
5059
expr::{sort_expr::PySortExpr, PyExpr},
5160
};
61+
use crate::common::df_schema::PyDFSchema;
5262

5363
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
5464
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -650,6 +660,151 @@ impl PyDataFrame {
650660
fn count(&self, py: Python) -> PyResult<usize> {
651661
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
652662
}
663+
664+
fn distributed_plan(&self, num_shards: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
665+
let distributed_plan = wait_for_future(py, split_physical_plan(&self.df, num_shards))
666+
.map_err(py_datafusion_err)?;
667+
Ok(distributed_plan)
668+
}
669+
670+
}
671+
672+
#[pyclass(get_all)]
673+
#[derive(Debug, Clone)]
674+
pub struct Statistics {
675+
num_bytes: Option<usize>,
676+
num_rows: Option<usize>,
677+
}
678+
679+
impl Statistics {
680+
fn new(plan: &dyn ExecutionPlan) -> Self {
681+
fn extract(prec: Precision<usize>) -> Option<usize> {
682+
match prec {
683+
Precision::Exact(n) | Precision::Inexact(n) => Some(n),
684+
Precision::Absent => None,
685+
}
686+
}
687+
if let Ok(stats) = plan.statistics() {
688+
let num_bytes = extract(stats.total_byte_size);
689+
let num_rows = extract(stats.num_rows);
690+
Statistics { num_bytes, num_rows}
691+
} else {
692+
Statistics { num_bytes: None, num_rows: None }
693+
}
694+
}
695+
}
696+
697+
#[pyclass(get_all)]
698+
#[derive(Debug, Clone)]
699+
pub struct Shard {
700+
stats: Statistics,
701+
serialized_plan: Vec<u8>,
702+
}
703+
704+
impl Shard {
705+
pub fn try_new(plan: &Arc<dyn ExecutionPlan>) -> Result<Self, DataFusionError> {
706+
let stats = Statistics::new(plan.as_ref());
707+
let serialized_plan = PhysicalPlanNode::try_from_physical_plan(plan.clone(), Self::codec())?
708+
.encode_to_vec();
709+
Ok(Self { stats, serialized_plan })
710+
}
711+
712+
fn codec() -> &'static dyn PhysicalExtensionCodec {
713+
static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
714+
&CODEC
715+
}
716+
}
717+
718+
/// Result of splitting a DataFrame's physical plan into shards.
///
/// Built by `split_physical_plan`; `get_all` exposes every field to Python as
/// a read-only attribute.
#[pyclass(get_all)]
719+
#[derive(Debug, Clone)]
720+
pub struct DistributedPlan {
721+
/// One entry per plan produced by the split.
shards: Vec<Shard>,
722+
/// Schema of the unsplit physical plan.
schema: PyDFSchema,
723+
/// Statistics of the unsplit physical plan.
stats: Statistics,
724+
}
725+
726+
async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<DistributedPlan, DataFusionError> {
727+
fn split(plan: &Arc<dyn ExecutionPlan>, num_shards: usize) -> Vec<Arc<dyn ExecutionPlan>> {
728+
if let Some(parquet) = plan.as_any().downcast_ref::<ParquetExec>() {
729+
let parquet = if let Ok(Some(repartitioned)) = parquet.repartitioned(num_shards, &ConfigOptions::default()) {
730+
repartitioned.as_any().downcast_ref::<ParquetExec>()
731+
.expect("repartitioned parquet is no longer parquet")
732+
.clone()
733+
} else { // repartition failed
734+
parquet.clone()
735+
};
736+
let config = parquet.base_config();
737+
config
738+
.file_groups
739+
.iter()
740+
.map(|shard| {
741+
FileScanConfig {
742+
object_store_url: config.object_store_url.clone(),
743+
file_schema: config.file_schema.clone(),
744+
file_groups: shard.iter().map(|file| vec![file.to_owned()]).collect(), // one partition per file
745+
statistics: config.statistics.clone(),
746+
projection: config.projection.clone(),
747+
projection_deep: config.projection_deep.clone(),
748+
limit: config.limit,
749+
table_partition_cols: config.table_partition_cols.clone(),
750+
output_ordering: config.output_ordering.clone(),
751+
}
752+
})
753+
.map(|config| {
754+
let mut builder = ParquetExecBuilder::new(config)
755+
.with_table_parquet_options(parquet.table_parquet_options().clone());
756+
if let Some(predicate) = parquet.predicate() {
757+
builder = builder.with_predicate(predicate.clone());
758+
}
759+
builder.build_arc()
760+
})
761+
.map(|shard| shard as Arc<dyn ExecutionPlan>)
762+
.collect()
763+
} else if plan.children().len() == 0 { // TODO: split leaf nodes other than parquet?
764+
vec![plan.clone()]
765+
} else if plan.children().len() == 1 {
766+
plan.children().into_iter()
767+
.flat_map(|child| {
768+
split(child, num_shards)
769+
.into_iter()
770+
.map(|shard| plan.clone().with_new_children(vec![shard]))
771+
})
772+
.collect::<Result<Vec<_>, _>>()
773+
.expect("Unable to split plan")
774+
} else {
775+
panic!(
776+
"Only leaf or single-child plans are supported, found {}",
777+
displayable(plan.as_ref()).one_line()
778+
)
779+
}
780+
}
781+
let plan = df.clone().create_physical_plan().await?;
782+
let shards = split(&plan, num_shards)
783+
.iter()
784+
.map(Shard::try_new)
785+
.collect::<Result<Vec<_>, _>>()?;
786+
let schema = DFSchema::try_from(plan.schema().as_ref().to_owned())?.into();
787+
let stats = Statistics::new(plan.as_ref());
788+
Ok(DistributedPlan { shards, schema, stats })
789+
}
790+
791+
#[pyfunction]
792+
pub fn shard_stream(serialized_shard_plan: &[u8], py: Python) -> PyResult<PyRecordBatchStream> {
793+
deltalake::ensure_initialized();
794+
let registry = MemoryFunctionRegistry::default();
795+
let runtime = RuntimeEnvBuilder::new().build()?;
796+
let codec = DeltaPhysicalCodec {};
797+
let node = PhysicalPlanNode::decode(serialized_shard_plan)
798+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
799+
.map_err(py_datafusion_err)?;
800+
let plan = node.try_into_physical_plan(&registry, &runtime, &codec)?;
801+
let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
802+
execute_stream(plan, Arc::new(TaskContext::default()))
803+
});
804+
wait_for_future(py, stream_with_runtime)
805+
.map_err(py_datafusion_err)?
806+
.map(PyRecordBatchStream::new)
807+
.map_err(py_datafusion_err)
653808
}
654809

655810
/// Print DataFrame

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
115115
#[cfg(feature = "substrait")]
116116
setup_substrait_module(py, &m)?;
117117

118+
m.add_class::<dataframe::Shard>()?;
119+
m.add_class::<dataframe::DistributedPlan>()?;
120+
m.add_wrapped(wrap_pyfunction!(dataframe::shard_stream))?;
118121
Ok(())
119122
}
120123

0 commit comments

Comments
 (0)