Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfComplexTypesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void TestUdfWithReturnAsMapType()
[Fact]
public void TestUdfWithRowType()
{
// Single Row
// Single Row.
{
Func<Column, Column> udf = Udf<Row, string>(
(row) => row.GetAs<string>("city"));
Expand All @@ -149,7 +149,7 @@ public void TestUdfWithRowType()
Assert.Equal(expected, actual);
}

// Multiple Rows
// Multiple Rows.
{
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
(row1, row2, str) =>
Expand All @@ -173,7 +173,7 @@ public void TestUdfWithRowType()
Assert.Equal(expected, actual);
}

// Nested Row
// Nested Rows.
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
Expand All @@ -197,15 +197,15 @@ public void TestUdfWithRowType()
[Fact]
public void TestUdfWithReturnAsRowType()
{
// Single GenericRow
// Single Row.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType()),
new StructField("col2", new StringType())
});
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 1, "abc" }), schema);
str => new Row(new object[] { 1, "abc" }), schema);

Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);
Expand All @@ -219,14 +219,14 @@ public void TestUdfWithReturnAsRowType()
}
}

// Generic row is a part of top-level column.
// Row is a part of top-level column.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType())
});
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 111 }), schema);
str => new Row(new object[] { 111 }), schema);

Column nameCol = _df["name"];
Row[] rows = _df.Select(udf(nameCol).As("col"), nameCol).Collect().ToArray();
Expand All @@ -244,7 +244,7 @@ public void TestUdfWithReturnAsRowType()
}
}

// Nested GenericRow
// Nested Rows.
{
var subSchema1 = new StructType(new[]
{
Expand All @@ -263,15 +263,15 @@ public void TestUdfWithReturnAsRowType()
});

Func<Column, Column> udf = Udf<string>(
str => new GenericRow(
str => new Row(
new object[]
{
1,
new GenericRow(new object[] { 1 }),
new GenericRow(new object[]
new Row(new object[] { 1 }),
new Row(new object[]
{
"abc",
new GenericRow(new object[] { 10 })
new Row(new object[] { 10 })
})
}),
schema);
Expand All @@ -295,6 +295,27 @@ public void TestUdfWithReturnAsRowType()
outerCol.GetAs<Row>("col3"));
}
}

// Chained UDFs with Row type.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType()),
new StructField("col2", new StringType())
});
Func<Column, Column> udf1 = Udf<string>(
str => new Row(new object[] { 1, "abc" }), schema);

Func<Column, Column> udf2 = Udf<Row, string>(
row => row.GetAs<string>(1));

Row[] rows = _df.Select(udf2(udf1(_df["name"]))).Collect().ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] { "abc", "abc", "abc" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}
}
}
}
2 changes: 1 addition & 1 deletion src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void RowCollectorTest()
private Pickler CreatePickler()
{
new StructTypePickler().Register();
new RowPickler().Register();
new Microsoft.Spark.UnitTest.TestUtils.RowPickler().Register();
return new Pickler();
}

Expand Down
22 changes: 11 additions & 11 deletions src/csharp/Microsoft.Spark/Sql/Functions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3804,7 +3804,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// <returns>
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
public static Func<Column> Udf(Func<Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply0;
}
Expand All @@ -3816,7 +3816,7 @@ public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
/// <returns>
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType returnType)
public static Func<Column, Column> Udf<T>(Func<T, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply1;
}
Expand All @@ -3830,7 +3830,7 @@ public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType re
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column> Udf<T1, T2>(
Func<T1, T2, GenericRow> udf, StructType returnType)
Func<T1, T2, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply2;
}
Expand All @@ -3845,7 +3845,7 @@ public static Func<Column, Column, Column> Udf<T1, T2>(
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
Func<T1, T2, T3, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply3;
}
Expand All @@ -3861,7 +3861,7 @@ public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
Func<T1, T2, T3, T4, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply4;
}
Expand All @@ -3878,7 +3878,7 @@ public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5>(
Func<T1, T2, T3, T4, T5, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply5;
}
Expand All @@ -3896,7 +3896,7 @@ public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6>(
Func<T1, T2, T3, T4, T5, T6, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply6;
}
Expand All @@ -3915,7 +3915,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7>(
Func<T1, T2, T3, T4, T5, T6, T7, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply7;
}
Expand All @@ -3935,7 +3935,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply8;
}
Expand All @@ -3954,7 +3954,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// <param name="returnType">Schema associated with this row</param>
/// <returns>A delegate that when invoked will return a <see cref="Column"/> for the result of the UDF.</returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply9;
}
Expand All @@ -3976,7 +3976,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply10;
}
Expand Down
9 changes: 9 additions & 0 deletions src/csharp/Microsoft.Spark/Sql/Row.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ internal Row(object[] values, StructType schema)
Convert();
}

/// <summary>
/// Constructor for the Row class.
/// </summary>
/// <param name="values">Column values for a row</param>
internal Row(object[] values)
{
_genericRow = new GenericRow(values);
}

/// <summary>
/// Schema associated with this row.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ namespace Microsoft.Spark.Sql
/// <summary>
/// Custom pickler for GenericRow objects.
/// </summary>
internal class GenericRowPickler : IObjectPickler
internal class RowPickler : IObjectPickler
{
public void pickle(object o, Stream outs, Pickler currentPickler)
{
currentPickler.save(((GenericRow)o).Values);
currentPickler.save(((Row)o).Values);
}
}
}
4 changes: 2 additions & 2 deletions src/csharp/Microsoft.Spark/Utils/PythonSerDe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ static PythonSerDe()
Unpickler.registerConstructor(
"pyspark.sql.types", "_create_row_inbound_converter", s_rowConstructor);

// Register custom pickler for GenericRow objects.
Pickler.registerCustomPickler(typeof(GenericRow), new GenericRowPickler());
// Register custom pickler for Row objects.
Pickler.registerCustomPickler(typeof(Row), new RowPickler());
}

/// <summary>
Expand Down