Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void TestUdfWithReturnAsMapType()
[Fact]
public void TestUdfWithRowType()
{
// Single Row
// Single Row.
{
Func<Column, Column> udf = Udf<Row, string>(
(row) => row.GetAs<string>("city"));
Expand All @@ -149,7 +149,7 @@ public void TestUdfWithRowType()
Assert.Equal(expected, actual);
}

// Multiple Rows
// Multiple Rows.
{
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
(row1, row2, str) =>
Expand All @@ -173,7 +173,7 @@ public void TestUdfWithRowType()
Assert.Equal(expected, actual);
}

// Nested Row
// Nested Rows.
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
Expand All @@ -197,7 +197,7 @@ public void TestUdfWithRowType()
[Fact]
public void TestUdfWithReturnAsRowType()
{
// Single GenericRow
// Test UDF that returns a Row object.
{
var schema = new StructType(new[]
{
Expand All @@ -219,7 +219,7 @@ public void TestUdfWithReturnAsRowType()
}
}

// Generic row is a part of top-level column.
// GenericRow is a part of top-level column.
{
var schema = new StructType(new[]
{
Expand All @@ -244,7 +244,7 @@ public void TestUdfWithReturnAsRowType()
}
}

// Nested GenericRow
// Test UDF that returns a nested Row object.
{
var subSchema1 = new StructType(new[]
{
Expand Down Expand Up @@ -295,6 +295,27 @@ public void TestUdfWithReturnAsRowType()
outerCol.GetAs<Row>("col3"));
}
}

// Chained UDFs: the Row produced by the inner UDF is fed straight into the outer UDF.
{
    // Return schema declared for the inner, Row-producing UDF.
    var rowSchema = new StructType(new[]
    {
        new StructField("col1", new IntegerType()),
        new StructField("col2", new StringType())
    });

    // Inner UDF: ignores its input and always yields the row (1, "abc").
    Func<Column, Column> makeRow = Udf<string>(
        name => new GenericRow(new object[] { 1, "abc" }), rowSchema);

    // Outer UDF: reads the field at index 1 of the incoming row as a string.
    Func<Column, Column> readSecondField = Udf<Row, string>(
        r => r.GetAs<string>(1));

    Column chained = readSecondField(makeRow(_df["name"]));
    Row[] collected = _df.Select(chained).Collect().ToArray();
    Assert.Equal(3, collected.Length);

    // Every collected row should carry the "abc" value taken from the inner UDF's row.
    string[] actual = collected.Select(r => r[0].ToString()).ToArray();
    Assert.Equal(new[] { "abc", "abc", "abc" }, actual);
}
}
}
}
2 changes: 1 addition & 1 deletion src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void RowCollectorTest()
private Pickler CreatePickler()
{
new StructTypePickler().Register();
new RowPickler().Register();
new TestUtils.RowPickler().Register();
return new Pickler();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@

namespace Microsoft.Spark.Sql
{
/// <summary>
/// Custom pickler for <see cref="Row"/> objects. A row is pickled by saving
/// its <c>Values</c> through the pickler that is currently in use.
/// </summary>
internal class RowPickler : IObjectPickler
{
    /// <summary>
    /// Pickles the given object, which must be a <see cref="Row"/>.
    /// </summary>
    /// <param name="o">Object to pickle; cast to <see cref="Row"/></param>
    /// <param name="outs">Output stream (not used directly by this pickler)</param>
    /// <param name="currentPickler">Pickler used to save the row's values</param>
    public void pickle(object o, Stream outs, Pickler currentPickler)
    {
        // A direct cast is used, so a non-Row argument fails with InvalidCastException.
        var row = (Row)o;
        currentPickler.save(row.Values);
    }
}

/// <summary>
/// Custom pickler for GenericRow objects.
/// </summary>
Expand Down
22 changes: 11 additions & 11 deletions src/csharp/Microsoft.Spark/Sql/Functions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3804,7 +3804,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// <returns>
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
public static Func<Column> Udf(Func<Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply0;
}
Expand All @@ -3816,7 +3816,7 @@ public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
/// <returns>
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType returnType)
public static Func<Column, Column> Udf<T>(Func<T, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply1;
}
Expand All @@ -3830,7 +3830,7 @@ public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType re
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column> Udf<T1, T2>(
Func<T1, T2, GenericRow> udf, StructType returnType)
Func<T1, T2, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply2;
}
Expand All @@ -3845,7 +3845,7 @@ public static Func<Column, Column, Column> Udf<T1, T2>(
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
Func<T1, T2, T3, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply3;
}
Expand All @@ -3861,7 +3861,7 @@ public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
Func<T1, T2, T3, T4, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply4;
}
Expand All @@ -3878,7 +3878,7 @@ public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5>(
Func<T1, T2, T3, T4, T5, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply5;
}
Expand All @@ -3896,7 +3896,7 @@ public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6>(
Func<T1, T2, T3, T4, T5, T6, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply6;
}
Expand All @@ -3915,7 +3915,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7>(
Func<T1, T2, T3, T4, T5, T6, T7, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply7;
}
Expand All @@ -3935,7 +3935,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply8;
}
Expand All @@ -3954,7 +3954,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// <param name="returnType">Schema associated with this row</param>
/// <returns>A delegate that when invoked will return a <see cref="Column"/> for the result of the UDF.</returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply9;
}
Expand All @@ -3976,7 +3976,7 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply10;
}
Expand Down
2 changes: 1 addition & 1 deletion src/csharp/Microsoft.Spark/Sql/GenericRow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ public override bool Equals(object obj) =>
/// Returns the hash code of the current object.
/// </summary>
/// <returns>The hash code of the current object</returns>
public override int GetHashCode() => base.GetHashCode();
public override int GetHashCode() => base.GetHashCode();
}
}
20 changes: 18 additions & 2 deletions src/csharp/Microsoft.Spark/Sql/Row.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.Sql
Expand Down Expand Up @@ -36,6 +34,24 @@ internal Row(object[] values, StructType schema)
Convert();
}

/// <summary>
/// Constructor for the schema-less Row class used for chained UDFs.
/// </summary>
/// <param name="genericRow">GenericRow that this Row wraps; it is stored as-is, without conversion</param>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no conversion happening here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor is for converting GenericRow to schema-less Row, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, it's just setting the given value to its member (not "convert"ing it)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, makes sense.

internal Row(GenericRow genericRow)
{
    // Only the backing GenericRow reference is set here; nothing is copied or converted.
    this._genericRow = genericRow;
}

/// <summary>
/// Implicitly wraps a <see cref="GenericRow"/> in a schema-less <see cref="Row"/>.
/// Schema-less rows can occur within chained UDFs (same behavior as PySpark).
/// </summary>
/// <param name="genericRow">GenericRow to wrap</param>
/// <returns>Schema-less Row constructed from the given GenericRow</returns>
public static implicit operator Row(GenericRow genericRow) => new Row(genericRow);

/// <summary>
/// Schema associated with this row.
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion src/csharp/Microsoft.Spark/Utils/PythonSerDe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ static PythonSerDe()
Unpickler.registerConstructor(
"pyspark.sql.types", "_create_row_inbound_converter", s_rowConstructor);

// Register custom pickler for GenericRow objects.
// Register custom picklers.
Pickler.registerCustomPickler(typeof(Row), new RowPickler());
Pickler.registerCustomPickler(typeof(GenericRow), new GenericRowPickler());
}

Expand Down