Skip to content

Commit 0e6e66c

Browse files
authored
Correct Cosmos ReadItem logic for abstract base types (#36884)
Fixes #36882
1 parent cc12a9a commit 0e6e66c

7 files changed

+151
-12
lines changed

src/EFCore.Cosmos/Query/Internal/CosmosReadItemAndPartitionKeysExtractor.cs

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class CosmosReadItemAndPartitionKeysExtractor : ExpressionVisitor
2121
private IEntityType _entityType = null!;
2222
private string _rootAlias = null!;
2323
private bool _isPredicateCompatibleWithReadItem;
24-
private bool _discriminatorHandled;
24+
private bool _nonRootDiscriminatorInJsonId;
2525
private string? _discriminatorJsonPropertyName;
2626
private Dictionary<IProperty, Expression?> _jsonIdPropertyValues = null!;
2727
private Dictionary<IProperty, (Expression? ValueExpression, Expression? OriginalExpression)> _partitionKeyPropertyValues = null!;
@@ -69,8 +69,7 @@ public virtual Expression ExtractPartitionKeysAndId(
6969
_isPredicateCompatibleWithReadItem = false;
7070
}
7171

72-
_discriminatorHandled = jsonIdDefinition?.IncludesDiscriminator != true
73-
|| jsonIdDefinition.DiscriminatorIsRootType;
72+
_nonRootDiscriminatorInJsonId = jsonIdDefinition is { IncludesDiscriminator: true, DiscriminatorIsRootType: false };
7473

7574
_jsonIdPropertyValues = jsonIdProperties.ToDictionary(p => p, _ => (Expression?)null);
7675

@@ -190,11 +189,14 @@ protected override Expression VisitExtension(Expression node)
190189
// This is the case where there is more than one possible discriminator value to look for, so we can only
191190
// ignore it if the discriminator is either not included in the JSON id (which means it must be unique without it)
192191
// or the root discriminator type must be included in the JSON id, since this value is always known.
193-
if (_discriminatorHandled
192+
if (!_nonRootDiscriminatorInJsonId
194193
&& scalarAccessExpression.PropertyName == _discriminatorJsonPropertyName)
195194
{
196195
var comparer = _entityType.FindDiscriminatorProperty()!.GetValueComparer();
197-
var discriminatorValues = _entityType.GetDerivedTypesInclusive().Select(e => e.GetDiscriminatorValue()).ToList();
196+
var discriminatorValues = _entityType.GetDerivedTypesInclusive()
197+
.Where(t => !t.IsAbstract())
198+
.Select(e => e.GetDiscriminatorValue())
199+
.ToList();
198200
if (discriminatorValues.Count == sqlExpressions.Length)
199201
{
200202
foreach (var sqlExpression in sqlExpressions)
@@ -268,18 +270,32 @@ void ProcessPropertyComparison(string propertyName, SqlExpression propertyValue,
268270
// property, a partition key property, or certain cases involving the discriminator property.
269271
var isCompatibleComparisonForReadItem = false;
270272

273+
// Handle comparison of the discriminator property
271274
if (propertyName == _discriminatorJsonPropertyName
272275
&& propertyValue is SqlConstantExpression { Value: { } specifiedDiscriminatorValue }
273-
&& _entityType.FindDiscriminatorProperty() is { } discriminatorProperty
274-
&& _entityType.GetDiscriminatorValue() is { } entityDiscriminatorValue
275-
&& discriminatorProperty.GetValueComparer().Equals(specifiedDiscriminatorValue, entityDiscriminatorValue))
276+
&& _entityType.FindDiscriminatorProperty() is { } discriminatorProperty)
276277
{
277-
// This is the case where there is a single leaf node with a discriminator value. We always know this value,
278-
// so the query never needs to drop out of ReadItem because of it.
279-
isCompatibleComparisonForReadItem = true;
278+
if (_entityType.GetDiscriminatorValue() is { } entityDiscriminatorValue
279+
&& discriminatorProperty.GetValueComparer().Equals(specifiedDiscriminatorValue, entityDiscriminatorValue))
280+
{
281+
isCompatibleComparisonForReadItem = true;
282+
}
283+
// If there are abstract base types involved (e.g. we're querying on an abstract base class), the discriminator
284+
// may not correspond (in the above condition) but we can still use ReadItem - as long as
285+
// there's only a single non-abstract leaf type and also we don't have a (non-root) discriminator in the JSON ID.
286+
else if (!_nonRootDiscriminatorInJsonId
287+
&& _entityType.GetDerivedTypesInclusive()
288+
.Where(t => !t.IsAbstract())
289+
.Select(e => e.GetDiscriminatorValue())
290+
.ToList() is [var entityDiscriminatorValue2]
291+
&& discriminatorProperty.GetValueComparer().Equals(specifiedDiscriminatorValue, entityDiscriminatorValue2))
292+
{
293+
isCompatibleComparisonForReadItem = true;
294+
}
280295
}
281296
else
282297
{
298+
// Handle comparison of the JSON ID properties
283299
foreach (var property in _jsonIdPropertyValues.Keys)
284300
{
285301
if (propertyName == property.GetJsonPropertyName())
@@ -296,6 +312,7 @@ void ProcessPropertyComparison(string propertyName, SqlExpression propertyValue,
296312
}
297313
}
298314

315+
// Handle comparison of the partition key properties
299316
foreach (var property in _partitionKeyPropertyValues.Keys)
300317
{
301318
// We found a comparison for a partition key property.

test/EFCore.Cosmos.FunctionalTests/Query/ReadItemPartitionKeyQueryDiscriminatorInIdTest.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,19 @@ FROM root c
557557
""");
558558
}
559559

560+
public override async Task ReadItem_for_abstract_base_type_with_shared_container()
561+
{
562+
// Not ReadItem because discriminator value in the JSON id is unknown
563+
await base.ReadItem_for_abstract_base_type_with_shared_container();
564+
565+
AssertSql(
566+
"""
567+
SELECT VALUE c
568+
FROM root c
569+
WHERE ((c["$type"] = "SharedContainerEntity3Child") AND (c["Id"] = 6))
570+
""");
571+
}
572+
560573
public override async Task ReadItem_for_child_type_with_shared_container()
561574
{
562575
await base.ReadItem_for_child_type_with_shared_container();

test/EFCore.Cosmos.FunctionalTests/Query/ReadItemPartitionKeyQueryFixtureBase.cs

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Linq;
5+
46
namespace Microsoft.EntityFrameworkCore.Query;
57

68
public class ReadItemPartitionKeyQueryFixtureBase : SharedStoreFixtureBase<DbContext>, IQueryFixtureBase
@@ -63,6 +65,13 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
6365

6466
modelBuilder.Entity<SharedContainerEntity2Child>();
6567

68+
modelBuilder.Entity<SharedContainerEntity3>()
69+
.ToContainer("SharedContainer")
70+
.HasPartitionKey(e => e.PartitionKey)
71+
.HasKey(e => new { e.Id, e.PartitionKey });
72+
73+
modelBuilder.Entity<SharedContainerEntity3Child>();
74+
6675
modelBuilder.Entity<FancyDiscriminatorEntity>()
6776
.ToContainer("Cat35224")
6877
.HasPartitionKey(e => e.Id)
@@ -91,6 +100,7 @@ protected override Task SeedAsync(DbContext context)
91100
context.AddRange(data.SharedContainerEntities1);
92101
context.AddRange(data.SharedContainerEntities2);
93102
context.AddRange(data.SharedContainerEntities2Children);
103+
context.AddRange(data.SharedContainerEntities3Children);
94104
context.AddRange(data.Cat35224Entities);
95105

96106
return context.SaveChangesAsync();
@@ -108,7 +118,8 @@ public virtual ISetSource GetExpectedData()
108118
{ typeof(OnlySinglePartitionKeyEntity), e => ((OnlySinglePartitionKeyEntity?)e)?.Payload },
109119
{ typeof(NoPartitionKeyEntity), e => ((NoPartitionKeyEntity?)e)?.Id },
110120
{ typeof(SharedContainerEntity1), e => ((SharedContainerEntity1?)e)?.Id },
111-
{ typeof(SharedContainerEntity2), e => ((SharedContainerEntity2?)e)?.Id }
121+
{ typeof(SharedContainerEntity2), e => ((SharedContainerEntity2?)e)?.Id },
122+
{ typeof(SharedContainerEntity3), e => ((SharedContainerEntity3?)e)?.Id }
112123
}.ToDictionary(e => e.Key, e => (object)e.Value);
113124

114125
public IReadOnlyDictionary<Type, object> EntityAsserters { get; } = new Dictionary<Type, Action<object?, object?>>
@@ -243,6 +254,39 @@ public virtual ISetSource GetExpectedData()
243254
}
244255
}
245256
},
257+
{
258+
typeof(SharedContainerEntity3), (e, a) =>
259+
{
260+
Assert.Equal(e == null, a == null);
261+
262+
if (a != null)
263+
{
264+
var ee = (SharedContainerEntity3)e!;
265+
var aa = (SharedContainerEntity3)a;
266+
267+
Assert.Equal(ee.Id, aa.Id);
268+
Assert.Equal(ee.PartitionKey, aa.PartitionKey);
269+
Assert.Equal(ee.Payload3, aa.Payload3);
270+
}
271+
}
272+
},
273+
{
274+
typeof(SharedContainerEntity3Child), (e, a) =>
275+
{
276+
Assert.Equal(e == null, a == null);
277+
278+
if (a != null)
279+
{
280+
var ee = (SharedContainerEntity3Child)e!;
281+
var aa = (SharedContainerEntity3Child)a;
282+
283+
Assert.Equal(ee.Id, aa.Id);
284+
Assert.Equal(ee.PartitionKey, aa.PartitionKey);
285+
Assert.Equal(ee.Payload3, aa.Payload3);
286+
Assert.Equal(ee.Child1Payload, aa.Child1Payload);
287+
}
288+
}
289+
},
246290
{
247291
typeof(FancyDiscriminatorEntity), (e, a) =>
248292
{
@@ -274,6 +318,7 @@ public class PartitionKeyData : ISetSource
274318
public List<SharedContainerEntity1> SharedContainerEntities1 { get; } = CreateSharedContainerEntities1();
275319
public List<SharedContainerEntity2> SharedContainerEntities2 { get; } = CreateSharedContainerEntities2();
276320
public List<SharedContainerEntity2Child> SharedContainerEntities2Children { get; } = CreateSharedContainerEntities2Children();
321+
public List<SharedContainerEntity3Child> SharedContainerEntities3Children { get; } = CreateSharedContainerEntities3Children1();
277322

278323
public List<FancyDiscriminatorEntity> Cat35224Entities { get; } = CreateCat35224Entities();
279324

@@ -320,6 +365,11 @@ public virtual IQueryable<TEntity> Set<TEntity>()
320365
return (IQueryable<TEntity>)SharedContainerEntities2Children.AsQueryable();
321366
}
322367

368+
if (typeof(TEntity) == typeof(SharedContainerEntity3) || typeof(TEntity) == typeof(SharedContainerEntity3Child))
369+
{
370+
return (IQueryable<TEntity>)SharedContainerEntities3Children.AsQueryable();
371+
}
372+
323373
if (typeof(TEntity) == typeof(FancyDiscriminatorEntity))
324374
{
325375
return (IQueryable<TEntity>)Cat35224Entities.AsQueryable();
@@ -518,6 +568,26 @@ private static List<SharedContainerEntity2Child> CreateSharedContainerEntities2C
518568
}
519569
];
520570

571+
private static List<SharedContainerEntity3Child> CreateSharedContainerEntities3Children1()
572+
=>
573+
[
574+
new SharedContainerEntity3Child
575+
{
576+
Id = 6,
577+
PartitionKey = "PK1",
578+
Payload3 = "Payload8",
579+
Child1Payload = "Child1"
580+
},
581+
582+
new SharedContainerEntity3Child
583+
{
584+
Id = 6,
585+
PartitionKey = "PK2",
586+
Payload3 = "Payload9",
587+
Child1Payload = "Child2"
588+
}
589+
];
590+
521591
private static List<FancyDiscriminatorEntity> CreateCat35224Entities()
522592
=>
523593
[
@@ -599,3 +669,15 @@ public class SharedContainerEntity2Child : SharedContainerEntity2
599669
{
600670
public required string ChildPayload { get; set; }
601671
}
672+
673+
public abstract class SharedContainerEntity3
674+
{
675+
public int Id { get; set; }
676+
public required string PartitionKey { get; set; }
677+
public required string Payload3 { get; set; }
678+
}
679+
680+
public class SharedContainerEntity3Child : SharedContainerEntity3
681+
{
682+
public required string Child1Payload { get; set; }
683+
}

test/EFCore.Cosmos.FunctionalTests/Query/ReadItemPartitionKeyQueryNoDiscriminatorInIdTest.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,13 @@ public override async Task ReadItem_for_base_type_with_shared_container()
442442
AssertSql("""ReadItem(["PK2"], 4)""");
443443
}
444444

445+
public override async Task ReadItem_for_abstract_base_type_with_shared_container()
446+
{
447+
await base.ReadItem_for_abstract_base_type_with_shared_container();
448+
449+
AssertSql("""ReadItem(["PK2"], 6)""");
450+
}
451+
445452
public override async Task ReadItem_for_child_type_with_shared_container()
446453
{
447454
await base.ReadItem_for_child_type_with_shared_container();

test/EFCore.Cosmos.FunctionalTests/Query/ReadItemPartitionKeyQueryRootDiscriminatorInIdTest.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,13 @@ public override async Task ReadItem_for_base_type_with_shared_container()
435435
AssertSql("""ReadItem(["PK2"], SharedContainerEntity2|4)""");
436436
}
437437

438+
public override async Task ReadItem_for_abstract_base_type_with_shared_container()
439+
{
440+
await base.ReadItem_for_abstract_base_type_with_shared_container();
441+
442+
AssertSql("""ReadItem(["PK2"], SharedContainerEntity3|6)""");
443+
}
444+
438445
public override async Task ReadItem_for_child_type_with_shared_container()
439446
{
440447
await base.ReadItem_for_child_type_with_shared_container();

test/EFCore.Cosmos.FunctionalTests/Query/ReadItemPartitionKeyQueryTest.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,13 @@ public override async Task ReadItem_for_base_type_with_shared_container()
435435
AssertSql("""ReadItem(["PK2"], 4)""");
436436
}
437437

438+
public override async Task ReadItem_for_abstract_base_type_with_shared_container()
439+
{
440+
await base.ReadItem_for_abstract_base_type_with_shared_container();
441+
442+
AssertSql("""ReadItem(["PK2"], 6)""");
443+
}
444+
438445
public override async Task ReadItem_for_child_type_with_shared_container()
439446
{
440447
await base.ReadItem_for_child_type_with_shared_container();

test/EFCore.Cosmos.FunctionalTests/Query/ReadItemPartitionKeyQueryTestBase.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,12 @@ public virtual Task ReadItem_for_base_type_with_shared_container()
363363
async: true,
364364
ss => ss.Set<SharedContainerEntity2>().Where(e => e.Id == 4 && e.PartitionKey == "PK2"));
365365

366+
[ConditionalFact]
367+
public virtual Task ReadItem_for_abstract_base_type_with_shared_container()
368+
=> AssertQuery(
369+
async: true,
370+
ss => ss.Set<SharedContainerEntity3>().Where(e => e.Id == 6 && e.PartitionKey == "PK2"));
371+
366372
[ConditionalFact]
367373
public virtual Task ReadItem_for_child_type_with_shared_container()
368374
=> AssertQuery(

0 commit comments

Comments
 (0)