Commit 32d49fb

feat: support type cast in SchemaAdapter (#6404)

* feat: support type cast in SchemaAdapter
* make ci happy
* improve the code
* make ci happy

1 parent 924059a commit 32d49fb

File tree (1 file changed, +191 -0 lines)

  • datafusion/core/src/physical_plan/file_format

datafusion/core/src/physical_plan/file_format/mod.rs

Lines changed: 191 additions & 0 deletions
@@ -53,6 +53,7 @@ use crate::{
     scalar::ScalarValue,
 };
 use arrow::array::new_null_array;
+use arrow::compute::can_cast_types;
 use arrow::record_batch::RecordBatchOptions;
 use datafusion_common::tree_node::{TreeNode, VisitRecursion};
 use datafusion_physical_expr::expressions::Column;
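
The one new import here is arrow's can_cast_types, which reports whether the cast kernel supports a conversion between two DataTypes at all. A minimal standalone sketch of how it answers that question (the cast pairs mirror the ones exercised in the new tests further down):

    use arrow::compute::can_cast_types;
    use arrow::datatypes::DataType;

    fn main() {
        // Widening numeric casts are supported...
        assert!(can_cast_types(&DataType::Int32, &DataType::UInt64));
        assert!(can_cast_types(&DataType::Float32, &DataType::Float64));
        // ...but there is no defined cast from Date64 to Float64.
        assert!(!can_cast_types(&DataType::Date64, &DataType::Float64));
    }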
@@ -450,6 +451,77 @@ impl SchemaAdapter {
             &options,
         )?)
     }
+
+    /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema.
+    ///
+    /// If the provided `file_schema` contains columns of a different type to the expected
+    /// `table_schema`, the method will attempt to cast the array data from the file schema
+    /// to the table schema where possible.
+    #[allow(dead_code)]
+    pub fn map_schema(&self, file_schema: &Schema) -> Result<SchemaMapping> {
+        let mut field_mappings = Vec::new();
+
+        for (idx, field) in self.table_schema.fields().iter().enumerate() {
+            match file_schema.field_with_name(field.name()) {
+                Ok(file_field) => {
+                    if can_cast_types(file_field.data_type(), field.data_type()) {
+                        field_mappings.push((idx, field.data_type().clone()))
+                    } else {
+                        return Err(DataFusionError::Plan(format!(
+                            "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
+                            field.name(),
+                            file_field.data_type(),
+                            field.data_type()
+                        )));
+                    }
+                }
+                Err(_) => {
+                    return Err(DataFusionError::Plan(format!(
+                        "File schema does not contain expected field {}",
+                        field.name()
+                    )));
+                }
+            }
+        }
+        Ok(SchemaMapping {
+            table_schema: self.table_schema.clone(),
+            field_mappings,
+        })
+    }
+}
+
+/// The SchemaMapping struct holds a mapping from the file schema to the table schema
+/// and any necessary type conversions that need to be applied.
+#[derive(Debug)]
+pub struct SchemaMapping {
+    #[allow(dead_code)]
+    table_schema: SchemaRef,
+    #[allow(dead_code)]
+    field_mappings: Vec<(usize, DataType)>,
+}
+
+impl SchemaMapping {
+    /// Adapts a `RecordBatch` to match the `table_schema` using the stored mapping and conversions.
+    #[allow(dead_code)]
+    fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
+        let mut mapped_cols = Vec::with_capacity(self.field_mappings.len());
+
+        for (idx, data_type) in &self.field_mappings {
+            let array = batch.column(*idx);
+            let casted_array = arrow::compute::cast(array, data_type)?;
+            mapped_cols.push(casted_array);
+        }
+
+        // Necessary to handle empty batches
+        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+        let record_batch = RecordBatch::try_new_with_options(
+            self.table_schema.clone(),
+            mapped_cols,
+            &options,
+        )?;
+        Ok(record_batch)
+    }
 }
 
 /// A helper that projects partition columns into the file record batches.
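
Worth a note on the "Necessary to handle empty batches" comment above: a RecordBatch with zero columns has nothing to infer its row count from, so the count must be passed explicitly through RecordBatchOptions. A minimal sketch of that arrow behavior (my reading of the API, separate from this commit):

    use arrow::datatypes::Schema;
    use arrow::record_batch::{RecordBatch, RecordBatchOptions};
    use std::sync::Arc;

    fn main() {
        // With no columns, the row count cannot be inferred from the data,
        // so it is supplied explicitly via RecordBatchOptions.
        let options = RecordBatchOptions::new().with_row_count(Some(3));
        let batch = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &options,
        )
        .unwrap();
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 0);
    }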
@@ -805,6 +877,9 @@ fn get_projected_output_ordering(
 
 #[cfg(test)]
 mod tests {
+    use arrow_array::cast::AsArray;
+    use arrow_array::types::{Float64Type, UInt32Type};
+    use arrow_array::{Float32Array, StringArray, UInt64Array};
     use chrono::Utc;
 
     use crate::{
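
The AsArray import backs the assertions in the new tests: RecordBatch columns are type-erased ArrayRefs, and AsArray downcasts them to concrete array types so individual values can be read. A minimal standalone sketch:

    use arrow_array::cast::AsArray;
    use arrow_array::types::UInt32Type;
    use arrow_array::{ArrayRef, UInt32Array};
    use std::sync::Arc;

    fn main() {
        let column: ArrayRef = Arc::new(UInt32Array::from(vec![9, 5]));
        // as_primitive panics if the underlying type does not match,
        // which is exactly what a test wants.
        let values = column.as_primitive::<UInt32Type>();
        assert_eq!(values.value(0), 9);
        assert_eq!(values.value(1), 5);
    }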
@@ -1124,6 +1199,122 @@ mod tests {
         assert!(mapped.is_err());
     }
 
+    #[test]
+    fn schema_adapter_map_schema() {
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt64, true),
+            Field::new("c3", DataType::Float64, true),
+        ]));
+
+        let adapter = SchemaAdapter::new(table_schema.clone());
+
+        // file schema matches table schema
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt64, true),
+            Field::new("c3", DataType::Float64, true),
+        ]);
+
+        let mapping = adapter.map_schema(&file_schema).unwrap();
+
+        assert_eq!(
+            mapping.field_mappings,
+            vec![
+                (0, DataType::Utf8),
+                (1, DataType::UInt64),
+                (2, DataType::Float64),
+            ]
+        );
+        assert_eq!(mapping.table_schema, table_schema);
+
+        // file schema has columns of a different but castable type
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true), // can be cast to UInt64
+            Field::new("c3", DataType::Float32, true), // can be cast to Float64
+        ]);
+
+        let mapping = adapter.map_schema(&file_schema).unwrap();
+
+        assert_eq!(
+            mapping.field_mappings,
+            vec![
+                (0, DataType::Utf8),
+                (1, DataType::UInt64),
+                (2, DataType::Float64),
+            ]
+        );
+        assert_eq!(mapping.table_schema, table_schema);
+
+        // file schema lacks necessary columns
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true),
+        ]);
+
+        let err = adapter.map_schema(&file_schema).unwrap_err();
+
+        assert!(err
+            .to_string()
+            .contains("File schema does not contain expected field"));
+
+        // file schema has columns of a different and non-castable type
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true),
+            Field::new("c3", DataType::Date64, true), // cannot be cast to Float64
+        ]);
+        let err = adapter.map_schema(&file_schema).unwrap_err();
+
+        assert!(err.to_string().contains("Cannot cast file schema field"));
+    }
+
+    #[test]
+    fn schema_mapping_map_batch() {
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt32, true),
+            Field::new("c3", DataType::Float64, true),
+        ]));
+
+        let adapter = SchemaAdapter::new(table_schema.clone());
+
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt64, true),
+            Field::new("c3", DataType::Float32, true),
+        ]);
+
+        let mapping = adapter.map_schema(&file_schema).expect("map schema failed");
+
+        let c1 = StringArray::from(vec!["hello", "world"]);
+        let c2 = UInt64Array::from(vec![9_u64, 5_u64]);
+        let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]);
+        let batch = RecordBatch::try_new(
+            Arc::new(file_schema),
+            vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)],
+        )
+        .unwrap();
+
+        let mapped_batch = mapping.map_batch(batch).unwrap();
+
+        assert_eq!(mapped_batch.schema(), table_schema);
+        assert_eq!(mapped_batch.num_columns(), 3);
+        assert_eq!(mapped_batch.num_rows(), 2);
+
+        let c1 = mapped_batch.column(0).as_string::<i32>();
+        let c2 = mapped_batch.column(1).as_primitive::<UInt32Type>();
+        let c3 = mapped_batch.column(2).as_primitive::<Float64Type>();
+
+        assert_eq!(c1.value(0), "hello");
+        assert_eq!(c1.value(1), "world");
+        assert_eq!(c2.value(0), 9_u32);
+        assert_eq!(c2.value(1), 5_u32);
+        assert_eq!(c3.value(0), 2.0_f64);
+        assert_eq!(c3.value(1), 7.0_f64);
+    }
+
     // sets default for configs that play no role in projections
     fn config_for_projection(
         file_schema: SchemaRef,
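
One caveat the new tests do not cover, and my understanding of arrow rather than something this commit asserts: arrow::compute::cast (which map_batch calls) defaults to a "safe" cast, so a file value that does not fit the table type becomes null instead of raising an error. A minimal sketch:

    use arrow::array::{Array, ArrayRef, Int32Array};
    use arrow::compute::cast;
    use arrow::datatypes::DataType;
    use std::sync::Arc;

    fn main() {
        let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(-1)]));
        // -1 cannot be represented as a UInt64; the default safe cast
        // yields null for that slot rather than an error.
        let casted = cast(&ints, &DataType::UInt64).unwrap();
        assert!(!casted.is_null(0));
        assert!(casted.is_null(1));
    }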
