Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -443,25 +443,7 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, resultType, _typeMappingSource.FindMapping(resultType));

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
=> TranslateAggregate(source, selector, resultType, "AVG");

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -841,26 +823,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
=> TranslateAggregate(source, selector, resultType, "MAX");

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -869,26 +832,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
=> TranslateAggregate(source, selector, resultType, "MIN");

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -1241,7 +1185,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s

projection = _sqlExpressionFactory.Function("SUM", new[] { projection }, serverOutputType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
return AggregateResultShaper(source, projection, resultType);
}

/// <summary>
Expand Down Expand Up @@ -1515,6 +1459,35 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s

#endregion Queryable collection support

private ShapedQueryExpression? TranslateAggregate(ShapedQueryExpression source, LambdaExpression? selector, Type resultType, string functionName)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

if (!_subquery && resultType.IsNullableType())
{
// For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
// Note that relational databases typically return null, which propagates. Cosmos will instead return no elements,
// and hence for Cosmos only we need to change no elements into null.
source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType));

return AggregateResultShaper(source, projection, resultType);
}

private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate)
{
var select = (SelectExpression)source.QueryExpression;
Expand Down Expand Up @@ -1695,7 +1668,6 @@ private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression,
private static ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source,
Expression projection,
bool throwOnNullResult,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
Expand All @@ -1706,29 +1678,7 @@ private static ShapedQueryExpression AggregateResultShaper(
var nullableResultType = resultType.MakeNullable();
Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);

if (throwOnNullResult)
{
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.SequenceContainsNoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expression.Assign(resultVariable, shaper),
Expression.Condition(
Expression.Equal(resultVariable, Expression.Default(nullableResultType)),
returnValueForNull,
resultType != resultVariable.Type
? Expression.Convert(resultVariable, resultType)
: resultVariable));
}
else if (resultType != shaper.Type)
if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
ShapedQueryExpression source,
LambdaExpression? selector,
Type resultType)
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType);
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType)
Expand Down Expand Up @@ -971,7 +971,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), resultType);
}

/// <inheritdoc />
Expand All @@ -990,7 +990,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), resultType);
}

/// <inheritdoc />
Expand Down Expand Up @@ -1241,7 +1241,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType);
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count)
Expand Down Expand Up @@ -1966,7 +1966,6 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
ShapedQueryExpression source,
LambdaExpression? selectorLambda,
Func<Type, MethodInfo> methodGenerator,
bool throwWhenEmpty,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
Expand Down Expand Up @@ -2012,48 +2011,13 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), translation } });

selectExpression.ClearOrdering();
Expression shaper;

if (throwWhenEmpty)
{
// Avg/Max/Min case.
// We always read nullable value
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
// otherwise, server would return null only if it is empty, and we throw
var nullableResultType = resultType.MakeNullable();
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Default(resultType)
: translation.Type.IsNullableType()
? Expression.Default(resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.SequenceContainsNoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expression.Assign(resultVariable, shaper),
Expression.Condition(
Expression.Equal(resultVariable, Expression.Default(nullableResultType)),
returnValueForNull,
resultType != resultVariable.Type
? Expression.Convert(resultVariable, resultType)
: resultVariable));
}
else
{
// Sum case. Projection is always non-null. We read nullable value.
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());

if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
// Sum case. Projection is always non-null. We read nullable value.
Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());

if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}

return source.UpdateShaperExpression(shaper);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.ComponentModel.DataAnnotations.Schema;

namespace Microsoft.EntityFrameworkCore.Query;

#nullable disable
Expand Down Expand Up @@ -50,6 +52,115 @@ public enum MemberType

#endregion 34911

#region 35094

// TODO: Move these tests to a better location. They require nullable properties with nulls in the database.

[ConditionalFact]
public virtual async Task Min_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Min_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MinAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_no_data()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Equal(3.14, await context.Set<Context35094.Product>().MaxAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Max_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MaxAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Equal("Value", await context.Set<Context35094.Product>().MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_no_data()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Average_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().AverageAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Average_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).AverageAsync(p => p.NullableVal));
}

protected class Context35094(DbContextOptions options) : DbContext(options)
{
public DbSet<Product> Products { get; set; }

protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder.Entity<Product>().HasData(
new Product { Id = 1, NullableRef = "Value", NullableVal = 3.14 },
new Product { Id = 2, NullableVal = 3.14 },
new Product { Id = 3, NullableRef = "Value" });

public class Product
{
[DatabaseGenerated(DatabaseGeneratedOption.None)]
public int Id { get; set; }
public double? NullableVal { get; set; }
public string NullableRef { get; set; }
}
}

#endregion 35094

protected override string StoreName
=> "AdHocMiscellaneousQueryTests";

Expand Down
Loading