@@ -53,6 +53,7 @@ use crate::{
5353 scalar:: ScalarValue ,
5454} ;
5555use arrow:: array:: new_null_array;
56+ use arrow:: compute:: can_cast_types;
5657use arrow:: record_batch:: RecordBatchOptions ;
5758use datafusion_common:: tree_node:: { TreeNode , VisitRecursion } ;
5859use datafusion_physical_expr:: expressions:: Column ;
@@ -450,6 +451,77 @@ impl SchemaAdapter {
450451 & options,
451452 ) ?)
452453 }
454+
455+ /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema.
456+ ///
457+ /// If the provided `file_schema` contains columns of a different type to the expected
458+ /// `table_schema`, the method will attempt to cast the array data from the file schema
459+ /// to the table schema where possible.
460+ #[ allow( dead_code) ]
461+ pub fn map_schema ( & self , file_schema : & Schema ) -> Result < SchemaMapping > {
462+ let mut field_mappings = Vec :: new ( ) ;
463+
464+ for ( idx, field) in self . table_schema . fields ( ) . iter ( ) . enumerate ( ) {
465+ match file_schema. field_with_name ( field. name ( ) ) {
466+ Ok ( file_field) => {
467+ if can_cast_types ( file_field. data_type ( ) , field. data_type ( ) ) {
468+ field_mappings. push ( ( idx, field. data_type ( ) . clone ( ) ) )
469+ } else {
470+ return Err ( DataFusionError :: Plan ( format ! (
471+ "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}" ,
472+ field. name( ) ,
473+ file_field. data_type( ) ,
474+ field. data_type( )
475+ ) ) ) ;
476+ }
477+ }
478+ Err ( _) => {
479+ return Err ( DataFusionError :: Plan ( format ! (
480+ "File schema does not contain expected field {}" ,
481+ field. name( )
482+ ) ) ) ;
483+ }
484+ }
485+ }
486+ Ok ( SchemaMapping {
487+ table_schema : self . table_schema . clone ( ) ,
488+ field_mappings,
489+ } )
490+ }
491+ }
492+
493+ /// The SchemaMapping struct holds a mapping from the file schema to the table schema
494+ /// and any necessary type conversions that need to be applied.
495+ #[ derive( Debug ) ]
496+ pub struct SchemaMapping {
497+ #[ allow( dead_code) ]
498+ table_schema : SchemaRef ,
499+ #[ allow( dead_code) ]
500+ field_mappings : Vec < ( usize , DataType ) > ,
501+ }
502+
503+ impl SchemaMapping {
504+ /// Adapts a `RecordBatch` to match the `table_schema` using the stored mapping and conversions.
505+ #[ allow( dead_code) ]
506+ fn map_batch ( & self , batch : RecordBatch ) -> Result < RecordBatch > {
507+ let mut mapped_cols = Vec :: with_capacity ( self . field_mappings . len ( ) ) ;
508+
509+ for ( idx, data_type) in & self . field_mappings {
510+ let array = batch. column ( * idx) ;
511+ let casted_array = arrow:: compute:: cast ( array, data_type) ?;
512+ mapped_cols. push ( casted_array) ;
513+ }
514+
515+ // Necessary to handle empty batches
516+ let options = RecordBatchOptions :: new ( ) . with_row_count ( Some ( batch. num_rows ( ) ) ) ;
517+
518+ let record_batch = RecordBatch :: try_new_with_options (
519+ self . table_schema . clone ( ) ,
520+ mapped_cols,
521+ & options,
522+ ) ?;
523+ Ok ( record_batch)
524+ }
453525}
454526
455527/// A helper that projects partition columns into the file record batches.
@@ -805,6 +877,9 @@ fn get_projected_output_ordering(
805877
806878#[ cfg( test) ]
807879mod tests {
880+ use arrow_array:: cast:: AsArray ;
881+ use arrow_array:: types:: { Float64Type , UInt32Type } ;
882+ use arrow_array:: { Float32Array , StringArray , UInt64Array } ;
808883 use chrono:: Utc ;
809884
810885 use crate :: {
@@ -1124,6 +1199,122 @@ mod tests {
11241199 assert ! ( mapped. is_err( ) ) ;
11251200 }
11261201
1202+ #[ test]
1203+ fn schema_adapter_map_schema ( ) {
1204+ let table_schema = Arc :: new ( Schema :: new ( vec ! [
1205+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1206+ Field :: new( "c2" , DataType :: UInt64 , true ) ,
1207+ Field :: new( "c3" , DataType :: Float64 , true ) ,
1208+ ] ) ) ;
1209+
1210+ let adapter = SchemaAdapter :: new ( table_schema. clone ( ) ) ;
1211+
1212+ // file schema matches table schema
1213+ let file_schema = Schema :: new ( vec ! [
1214+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1215+ Field :: new( "c2" , DataType :: UInt64 , true ) ,
1216+ Field :: new( "c3" , DataType :: Float64 , true ) ,
1217+ ] ) ;
1218+
1219+ let mapping = adapter. map_schema ( & file_schema) . unwrap ( ) ;
1220+
1221+ assert_eq ! (
1222+ mapping. field_mappings,
1223+ vec![
1224+ ( 0 , DataType :: Utf8 ) ,
1225+ ( 1 , DataType :: UInt64 ) ,
1226+ ( 2 , DataType :: Float64 ) ,
1227+ ]
1228+ ) ;
1229+ assert_eq ! ( mapping. table_schema, table_schema) ;
1230+
1231+ // file schema has columns of a different but castable type
1232+ let file_schema = Schema :: new ( vec ! [
1233+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1234+ Field :: new( "c2" , DataType :: Int32 , true ) , // can be casted to UInt64
1235+ Field :: new( "c3" , DataType :: Float32 , true ) , // can be casted to Float64
1236+ ] ) ;
1237+
1238+ let mapping = adapter. map_schema ( & file_schema) . unwrap ( ) ;
1239+
1240+ assert_eq ! (
1241+ mapping. field_mappings,
1242+ vec![
1243+ ( 0 , DataType :: Utf8 ) ,
1244+ ( 1 , DataType :: UInt64 ) ,
1245+ ( 2 , DataType :: Float64 ) ,
1246+ ]
1247+ ) ;
1248+ assert_eq ! ( mapping. table_schema, table_schema) ;
1249+
1250+ // file schema lacks necessary columns
1251+ let file_schema = Schema :: new ( vec ! [
1252+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1253+ Field :: new( "c2" , DataType :: Int32 , true ) ,
1254+ ] ) ;
1255+
1256+ let err = adapter. map_schema ( & file_schema) . unwrap_err ( ) ;
1257+
1258+ assert ! ( err
1259+ . to_string( )
1260+ . contains( "File schema does not contain expected field" ) ) ;
1261+
1262+ // file schema has columns of a different and non-castable type
1263+ let file_schema = Schema :: new ( vec ! [
1264+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1265+ Field :: new( "c2" , DataType :: Int32 , true ) ,
1266+ Field :: new( "c3" , DataType :: Date64 , true ) , // cannot be casted to Float64
1267+ ] ) ;
1268+ let err = adapter. map_schema ( & file_schema) . unwrap_err ( ) ;
1269+
1270+ assert ! ( err. to_string( ) . contains( "Cannot cast file schema field" ) ) ;
1271+ }
1272+
1273+ #[ test]
1274+ fn schema_mapping_map_batch ( ) {
1275+ let table_schema = Arc :: new ( Schema :: new ( vec ! [
1276+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1277+ Field :: new( "c2" , DataType :: UInt32 , true ) ,
1278+ Field :: new( "c3" , DataType :: Float64 , true ) ,
1279+ ] ) ) ;
1280+
1281+ let adapter = SchemaAdapter :: new ( table_schema. clone ( ) ) ;
1282+
1283+ let file_schema = Schema :: new ( vec ! [
1284+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
1285+ Field :: new( "c2" , DataType :: UInt64 , true ) ,
1286+ Field :: new( "c3" , DataType :: Float32 , true ) ,
1287+ ] ) ;
1288+
1289+ let mapping = adapter. map_schema ( & file_schema) . expect ( "map schema failed" ) ;
1290+
1291+ let c1 = StringArray :: from ( vec ! [ "hello" , "world" ] ) ;
1292+ let c2 = UInt64Array :: from ( vec ! [ 9_u64 , 5_u64 ] ) ;
1293+ let c3 = Float32Array :: from ( vec ! [ 2.0_f32 , 7.0_f32 ] ) ;
1294+ let batch = RecordBatch :: try_new (
1295+ Arc :: new ( file_schema) ,
1296+ vec ! [ Arc :: new( c1) , Arc :: new( c2) , Arc :: new( c3) ] ,
1297+ )
1298+ . unwrap ( ) ;
1299+
1300+ let mapped_batch = mapping. map_batch ( batch) . unwrap ( ) ;
1301+
1302+ assert_eq ! ( mapped_batch. schema( ) , table_schema) ;
1303+ assert_eq ! ( mapped_batch. num_columns( ) , 3 ) ;
1304+ assert_eq ! ( mapped_batch. num_rows( ) , 2 ) ;
1305+
1306+ let c1 = mapped_batch. column ( 0 ) . as_string :: < i32 > ( ) ;
1307+ let c2 = mapped_batch. column ( 1 ) . as_primitive :: < UInt32Type > ( ) ;
1308+ let c3 = mapped_batch. column ( 2 ) . as_primitive :: < Float64Type > ( ) ;
1309+
1310+ assert_eq ! ( c1. value( 0 ) , "hello" ) ;
1311+ assert_eq ! ( c1. value( 1 ) , "world" ) ;
1312+ assert_eq ! ( c2. value( 0 ) , 9_u32 ) ;
1313+ assert_eq ! ( c2. value( 1 ) , 5_u32 ) ;
1314+ assert_eq ! ( c3. value( 0 ) , 2.0_f64 ) ;
1315+ assert_eq ! ( c3. value( 1 ) , 7.0_f64 ) ;
1316+ }
1317+
11271318 // sets default for configs that play no role in projections
11281319 fn config_for_projection (
11291320 file_schema : SchemaRef ,
0 commit comments