Skip to content

Commit 845d749

Browse files
committed
Return null when the type is nullable for Cosmos Max/Min/Average
Fixes #35094 This was a regression resulting from the major Cosmos query refactoring that happened in EF9. In EF8, the functions Min, Max, and Average would return null if the return type was nullable or was cast to a nullable when the collection is empty. In EF9, this started throwing, which is correct for non-nullable types, but a regression for nullable types.
1 parent a4a350a commit 845d749

File tree

2 files changed

+81
-145
lines changed

2 files changed

+81
-145
lines changed

src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 33 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,25 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
444444
/// </summary>
445445
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
446446
{
447-
var selectExpression = (SelectExpression)source.QueryExpression;
447+
var updatedSource = TranslateAggregateCommon(source, selector, resultType, out var selectExpression);
448+
if (updatedSource == null)
449+
{
450+
return null;
451+
}
452+
453+
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
454+
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, resultType, _typeMappingSource.FindMapping(resultType));
455+
456+
return AggregateResultShaper(updatedSource, projection, resultType);
457+
}
458+
459+
private ShapedQueryExpression? TranslateAggregateCommon(
460+
ShapedQueryExpression source,
461+
LambdaExpression? selector,
462+
Type resultType,
463+
out SelectExpression selectExpression)
464+
{
465+
selectExpression = (SelectExpression)source.QueryExpression;
448466
if (selectExpression.IsDistinct
449467
|| selectExpression.Limit != null
450468
|| selectExpression.Offset != null)
@@ -457,10 +475,13 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
457475
source = TranslateSelect(source, selector);
458476
}
459477

460-
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
461-
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, resultType, _typeMappingSource.FindMapping(resultType));
478+
if (resultType.IsNullableType())
479+
{
480+
// For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
481+
source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
482+
}
462483

463-
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
484+
return source;
464485
}
465486

466487
/// <summary>
@@ -842,24 +863,17 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
842863
/// </summary>
843864
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
844865
{
845-
var selectExpression = (SelectExpression)source.QueryExpression;
846-
if (selectExpression.IsDistinct
847-
|| selectExpression.Limit != null
848-
|| selectExpression.Offset != null)
866+
var updatedSource = TranslateAggregateCommon(source, selector, resultType, out var selectExpression);
867+
if (updatedSource == null)
849868
{
850869
return null;
851870
}
852871

853-
if (selector != null)
854-
{
855-
source = TranslateSelect(source, selector);
856-
}
857-
858872
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
859873

860874
projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);
861875

862-
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
876+
return AggregateResultShaper(updatedSource, projection, resultType);
863877
}
864878

865879
/// <summary>
@@ -870,24 +884,17 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
870884
/// </summary>
871885
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
872886
{
873-
var selectExpression = (SelectExpression)source.QueryExpression;
874-
if (selectExpression.IsDistinct
875-
|| selectExpression.Limit != null
876-
|| selectExpression.Offset != null)
887+
var updatedSource = TranslateAggregateCommon(source, selector, resultType, out var selectExpression);
888+
if (updatedSource == null)
877889
{
878890
return null;
879891
}
880892

881-
if (selector != null)
882-
{
883-
source = TranslateSelect(source, selector);
884-
}
885-
886893
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
887894

888895
projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);
889896

890-
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
897+
return AggregateResultShaper(updatedSource, projection, resultType);
891898
}
892899

893900
/// <summary>
@@ -1241,7 +1248,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
12411248

12421249
projection = _sqlExpressionFactory.Function("SUM", new[] { projection }, serverOutputType, projection.TypeMapping);
12431250

1244-
return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
1251+
return AggregateResultShaper(source, projection, resultType);
12451252
}
12461253

12471254
/// <summary>
@@ -1695,7 +1702,6 @@ private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression,
16951702
private static ShapedQueryExpression AggregateResultShaper(
16961703
ShapedQueryExpression source,
16971704
Expression projection,
1698-
bool throwOnNullResult,
16991705
Type resultType)
17001706
{
17011707
var selectExpression = (SelectExpression)source.QueryExpression;
@@ -1706,29 +1712,7 @@ private static ShapedQueryExpression AggregateResultShaper(
17061712
var nullableResultType = resultType.MakeNullable();
17071713
Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
17081714

1709-
if (throwOnNullResult)
1710-
{
1711-
var resultVariable = Expression.Variable(nullableResultType, "result");
1712-
var returnValueForNull = resultType.IsNullableType()
1713-
? (Expression)Expression.Constant(null, resultType)
1714-
: Expression.Throw(
1715-
Expression.New(
1716-
typeof(InvalidOperationException).GetConstructors()
1717-
.Single(ci => ci.GetParameters().Length == 1),
1718-
Expression.Constant(CoreStrings.SequenceContainsNoElements)),
1719-
resultType);
1720-
1721-
shaper = Expression.Block(
1722-
new[] { resultVariable },
1723-
Expression.Assign(resultVariable, shaper),
1724-
Expression.Condition(
1725-
Expression.Equal(resultVariable, Expression.Default(nullableResultType)),
1726-
returnValueForNull,
1727-
resultType != resultVariable.Type
1728-
? Expression.Convert(resultVariable, resultType)
1729-
: resultVariable));
1730-
}
1731-
else if (resultType != shaper.Type)
1715+
if (resultType != shaper.Type)
17321716
{
17331717
shaper = Expression.Convert(shaper, resultType);
17341718
}

test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs

Lines changed: 48 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -555,49 +555,33 @@ FROM root c
555555
}
556556
}
557557

558-
public override async Task Average_no_data_nullable(bool async)
559-
{
560-
// Sync always throws before getting to exception being tested.
561-
if (async)
562-
{
563-
await Fixture.NoSyncTest(
564-
async, async a =>
565-
{
566-
Assert.Equal(
567-
CoreStrings.SequenceContainsNoElements,
568-
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Average_no_data_nullable(a))).Message);
558+
public override Task Average_no_data_nullable(bool async)
559+
=> Fixture.NoSyncTest(
560+
async, async a =>
561+
{
562+
await base.Average_no_data_nullable(a);
569563

570-
AssertSql(
571-
"""
564+
AssertSql(
565+
"""
572566
SELECT VALUE AVG(c["SupplierID"])
573567
FROM root c
574568
WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1))
575569
""");
576-
});
577-
}
578-
}
570+
});
579571

580-
public override async Task Average_no_data_cast_to_nullable(bool async)
581-
{
582-
// Sync always throws before getting to exception being tested.
583-
if (async)
584-
{
585-
await Fixture.NoSyncTest(
586-
async, async a =>
587-
{
588-
Assert.Equal(
589-
CoreStrings.SequenceContainsNoElements,
590-
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Average_no_data_cast_to_nullable(a))).Message);
572+
public override Task Average_no_data_cast_to_nullable(bool async)
573+
=> Fixture.NoSyncTest(
574+
async, async a =>
575+
{
576+
await base.Average_no_data_cast_to_nullable(a);
591577

592-
AssertSql(
593-
"""
578+
AssertSql(
579+
"""
594580
SELECT VALUE AVG(c["OrderID"])
595581
FROM root c
596582
WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1))
597583
""");
598-
});
599-
}
600-
}
584+
});
601585

602586
public override async Task Min_no_data(bool async)
603587
{
@@ -647,49 +631,33 @@ public override async Task Max_no_data_subquery(bool async)
647631
AssertSql();
648632
}
649633

650-
public override async Task Max_no_data_nullable(bool async)
651-
{
652-
// Sync always throws before getting to exception being tested.
653-
if (async)
654-
{
655-
await Fixture.NoSyncTest(
656-
async, async a =>
657-
{
658-
Assert.Equal(
659-
CoreStrings.SequenceContainsNoElements,
660-
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Max_no_data_nullable(a))).Message);
634+
public override Task Max_no_data_nullable(bool async)
635+
=> Fixture.NoSyncTest(
636+
async, async a =>
637+
{
638+
await base.Max_no_data_nullable(a);
661639

662-
AssertSql(
663-
"""
640+
AssertSql(
641+
"""
664642
SELECT VALUE MAX(c["SupplierID"])
665643
FROM root c
666644
WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1))
667645
""");
668-
});
669-
}
670-
}
646+
});
671647

672-
public override async Task Max_no_data_cast_to_nullable(bool async)
673-
{
674-
// Sync always throws before getting to exception being tested.
675-
if (async)
676-
{
677-
await Fixture.NoSyncTest(
678-
async, async a =>
679-
{
680-
Assert.Equal(
681-
CoreStrings.SequenceContainsNoElements,
682-
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Max_no_data_cast_to_nullable(a))).Message);
648+
public override Task Max_no_data_cast_to_nullable(bool async)
649+
=> Fixture.NoSyncTest(
650+
async, async a =>
651+
{
652+
await base.Max_no_data_cast_to_nullable(a);
683653

684-
AssertSql(
685-
"""
654+
AssertSql(
655+
"""
686656
SELECT VALUE MAX(c["OrderID"])
687657
FROM root c
688658
WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1))
689659
""");
690-
});
691-
}
692-
}
660+
});
693661

694662
public override async Task Min_no_data_subquery(bool async)
695663
{
@@ -868,49 +836,33 @@ FROM root c
868836
""");
869837
});
870838

871-
public override async Task Min_no_data_nullable(bool async)
872-
{
873-
// Sync always throws before getting to exception being tested.
874-
if (async)
875-
{
876-
await Fixture.NoSyncTest(
877-
async, async a =>
878-
{
879-
Assert.Equal(
880-
CoreStrings.SequenceContainsNoElements,
881-
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Min_no_data_nullable(a))).Message);
839+
public override Task Min_no_data_nullable(bool async)
840+
=> Fixture.NoSyncTest(
841+
async, async a =>
842+
{
843+
await base.Min_no_data_nullable(a);
882844

883-
AssertSql(
884-
"""
845+
AssertSql(
846+
"""
885847
SELECT VALUE MIN(c["SupplierID"])
886848
FROM root c
887849
WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1))
888850
""");
889-
});
890-
}
891-
}
851+
});
892852

893-
public override async Task Min_no_data_cast_to_nullable(bool async)
894-
{
895-
// Sync always throws before getting to exception being tested.
896-
if (async)
897-
{
898-
await Fixture.NoSyncTest(
899-
async, async a =>
900-
{
901-
Assert.Equal(
902-
CoreStrings.SequenceContainsNoElements,
903-
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Min_no_data_cast_to_nullable(a))).Message);
853+
public override Task Min_no_data_cast_to_nullable(bool async)
854+
=> Fixture.NoSyncTest(
855+
async, async a =>
856+
{
857+
await base.Min_no_data_cast_to_nullable(a);
904858

905-
AssertSql(
906-
"""
859+
AssertSql(
860+
"""
907861
SELECT VALUE MIN(c["OrderID"])
908862
FROM root c
909863
WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1))
910864
""");
911-
});
912-
}
913-
}
865+
});
914866

915867
public override Task Min_with_coalesce(bool async)
916868
=> Fixture.NoSyncTest(

0 commit comments

Comments
 (0)