Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/csharp/Microsoft.Spark.E2ETest/Resources/people.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{"name":"Michael", "ids":[1], "info":{"city":"Burdwan", "state":"Paschimbanga"}}
{"name":"Andy", "age":30, "ids":[3,5], "info":{"city":"Los Angeles", "state":"California"}}
{"name":"Justin", "age":19, "ids":[2,4], "info":{"city":"Seattle"}}
{"name":"Michael", "ids":[1], "info1":{"city":"Burdwan"}, "info2":{"state":"Paschimbanga"}, "info3":{"company":{"job":"Developer"}}}
{"name":"Andy", "age":30, "ids":[3,5], "info1":{"city":"Los Angeles"}, "info2":{"state":"California"}, "info3":{"company":{"job":"Developer"}}}
{"name":"Justin", "age":19, "ids":[2,4], "info1":{"city":"Seattle"}, "info2":{"state":"Washington"}, "info3":{"company":{"job":"Developer"}}}
110 changes: 90 additions & 20 deletions src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfComplexTypesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,62 @@ public void TestUdfWithReturnAsMapType()
[Fact]
public void TestUdfWithRowType()
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
string city = row.GetAs<string>("city");
string state = row.GetAs<string>("state");
return $"{city},{state}";
});
// Single Row
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
return row.GetAs<string>("city");
});

Row[] rows = _df.Select(udf(_df["info"])).Collect().ToArray();
Assert.Equal(3, rows.Length);
Row[] rows = _df.Select(udf(_df["info1"])).Collect().ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] { "Burdwan,Paschimbanga", "Los Angeles,California", "Seattle," };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
var expected = new[] { "Burdwan", "Los Angeles", "Seattle" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}

// Multiple Rows
{
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
(row1, row2, str) =>
{
string city = row1.GetAs<string>("city");
string state = row2.GetAs<string>("state");
return $"{str}:{city},{state}";
});

Row[] rows = _df
.Select(udf(_df["info1"], _df["info2"], _df["name"]))
.Collect()
.ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] {
"Michael:Burdwan,Paschimbanga",
"Andy:Los Angeles,California",
"Justin:Seattle,Washington" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}

// Nested Row
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
Row outerCol = row.GetAs<Row>("company");
return outerCol.GetAs<string>("job");
});

Row[] rows = _df.Select(udf(_df["info3"])).Collect().ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] { "Developer", "Developer", "Developer" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}
}

/// <summary>
Expand All @@ -168,14 +210,40 @@ public void TestUdfWithReturnAsRowType()
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 1, "abc" }), schema);

Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Equal(1, row.Size());
Row outerCol = row.GetAs<Row>("col");
Assert.Equal(2, outerCol.Size());
Assert.Equal(1, outerCol.GetAs<int>("col1"));
Assert.Equal("abc", outerCol.GetAs<string>("col2"));
}
}

// Generic row is a part of top-level column.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType())
});
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 111 }), schema);

Column nameCol = _df["name"];
Row[] rows = _df.Select(udf(nameCol).As("col"), nameCol).Collect().ToArray();
Assert.Equal(3, rows.Length);

foreach (Row row in rows)
{
Assert.Equal(2, row.Size());
Assert.Equal(1, row.GetAs<int>("col1"));
Assert.Equal("abc", row.GetAs<string>("col2"));
Row col1 = row.GetAs<Row>("col");
Assert.Equal(1, col1.Size());
Assert.Equal(111, col1.GetAs<int>("col1"));

string col2 = row.GetAs<string>("name");
Assert.NotEmpty(col2);
}
}

Expand Down Expand Up @@ -211,21 +279,23 @@ public void TestUdfWithReturnAsRowType()
}),
schema);

Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);

foreach (Row row in rows)
{
Assert.Equal(3, row.Size());
Assert.Equal(1, row.GetAs<int>("col1"));
Assert.Equal(1, row.Size());
Row outerCol = row.GetAs<Row>("col");
Assert.Equal(3, outerCol.Size());
Assert.Equal(1, outerCol.GetAs<int>("col1"));
Assert.Equal(
new Row(new object[] { 1 }, subSchema1),
row.GetAs<Row>("col2"));
outerCol.GetAs<Row>("col2"));
Assert.Equal(
new Row(
new object[] { "abc", new Row(new object[] { 10 }, subSchema1) },
subSchema2),
row.GetAs<Row>("col3"));
outerCol.GetAs<Row>("col3"));
}
}
}
Expand Down
11 changes: 10 additions & 1 deletion src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,17 @@ protected override CommandExecutorStat ExecuteCore(

for (int i = 0; i < inputRows.Length; ++i)
{
object row = inputRows[i];
// The following can happen if a UDF takes Row object(s).
// The JVM Spark side sends a Row object that wraps all the columns used
// in the UDF, thus, it is normalized below (the extra layer is removed).
if (row is RowConstructor rowConstructor)
{
row = rowConstructor.GetRow().Values;
}

// Split id is not used for SQL UDFs, so 0 is passed.
outputRows.Add(commandRunner.Run(0, inputRows[i]));
outputRows.Add(commandRunner.Run(0, row));
}

// The initial (estimated) buffer size for pickling rows is set to the size of
Expand Down
20 changes: 1 addition & 19 deletions src/csharp/Microsoft.Spark/Sql/RowCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
Expand Down Expand Up @@ -34,24 +33,7 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)

foreach (object unpickled in unpickledObjects)
{
// Unpickled object can be either a RowConstructor object (not materialized),
// or a Row object (materialized). Refer to RowConstruct.construct() to see how
// Row objects are unpickled.
switch (unpickled)
{
case RowConstructor rc:
yield return rc.GetRow();
break;

case object[] objs when objs.Length == 1 && (objs[0] is Row row):
yield return row;
break;

default:
throw new NotSupportedException(
string.Format("Unpickle type {0} is not supported",
unpickled.GetType()));
}
yield return (unpickled as RowConstructor).GetRow();
}
}
}
Expand Down
9 changes: 0 additions & 9 deletions src/csharp/Microsoft.Spark/Sql/RowConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,6 @@ public object construct(object[] args)
s_schemaCache = new Dictionary<string, StructType>();
}

// When a row is ready to be materialized, then construct() is called
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a breaking change, but @suhsteve is updating the worker version in his PR: https://github.com/dotnet/spark/pull/387/files

// on the RowConstructor which represents the row.
if ((args.Length == 1) && (args[0] is RowConstructor rowConstructor))
{
// Construct the Row and return args containing the Row.
args[0] = rowConstructor.GetRow();
return args;
}

// Return a new RowConstructor where the args either represent the
// schema or the row data. The parent becomes important when calling
// GetRow() on the RowConstructor containing the row data.
Expand Down