Skip to content

Commit 8c1d258

Browse files
committed
Ensure entities from query are tracked as correct type
Issue #4817 The problem was that query was passing in the base type to the state manager since the actual type varies by row. It would likely be possible to flow the actual type through from the materializer but this would require significant changes and would not be used in a lot of cases. Instead the state manager now looks up the entity type if it is not the base type, with some optimizations to that lookup made in metadata.
1 parent 39b2aa3 commit 8c1d258

File tree

14 files changed

+139
-52
lines changed

14 files changed

+139
-52
lines changed

src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IStateManager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public interface IStateManager
1515
{
1616
InternalEntityEntry GetOrCreateEntry([NotNull] object entity);
1717

18-
InternalEntityEntry StartTrackingFromQuery([NotNull] IEntityType entityType, [NotNull] object entity, ValueBuffer valueBuffer);
18+
InternalEntityEntry StartTrackingFromQuery([NotNull] IEntityType baseEntityType, [NotNull] object entity, ValueBuffer valueBuffer);
1919

2020
void BeginTrackingQuery();
2121

src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/StateManager.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public virtual InternalEntityEntry GetOrCreateEntry(object entity)
8484
public virtual void BeginTrackingQuery() => SingleQueryMode = SingleQueryMode == null;
8585

8686
public virtual InternalEntityEntry StartTrackingFromQuery(
87-
IEntityType entityType,
87+
IEntityType baseEntityType,
8888
object entity,
8989
ValueBuffer valueBuffer)
9090
{
@@ -94,11 +94,17 @@ public virtual InternalEntityEntry StartTrackingFromQuery(
9494
return existingEntry;
9595
}
9696

97-
var newEntry = _factory.Create(this, entityType, entity, valueBuffer);
97+
var clrType = entity.GetType();
98+
99+
var newEntry = _factory.Create(this,
100+
baseEntityType.ClrType == clrType
101+
? baseEntityType
102+
: _model.FindEntityType(clrType),
103+
entity, valueBuffer);
98104

99105
_subscriber.SnapshotAndSubscribe(newEntry);
100106

101-
foreach (var key in entityType.GetKeys())
107+
foreach (var key in baseEntityType.GetKeys())
102108
{
103109
GetOrCreateIdentityMap(key).AddOrUpdate(newEntry);
104110
}

src/Microsoft.EntityFrameworkCore/Extensions/EntityTypeExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ public static IEnumerable<IEntityType> GetDerivedTypes([NotNull] this IEntityTyp
2828
// ReSharper disable once LoopCanBeConvertedToQuery
2929
foreach (var derivedType in entityType.Model.GetEntityTypes())
3030
{
31-
if ((derivedType.BaseType != null)
32-
&& (derivedType != entityType)
31+
if (derivedType.BaseType != null
32+
&& derivedType != entityType
3333
&& entityType.IsAssignableFrom(derivedType))
3434
{
3535
yield return derivedType;

src/Microsoft.EntityFrameworkCore/Extensions/ModelExtensions.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using JetBrains.Annotations;
66
using Microsoft.EntityFrameworkCore.Internal;
77
using Microsoft.EntityFrameworkCore.Metadata;
8+
using Microsoft.EntityFrameworkCore.Metadata.Internal;
89
using Microsoft.EntityFrameworkCore.Utilities;
910

1011
namespace Microsoft.EntityFrameworkCore
@@ -24,7 +25,11 @@ public static IEntityType FindEntityType([NotNull] this IModel model, [NotNull]
2425
{
2526
Check.NotNull(type, nameof(type));
2627

27-
return model.FindEntityType(type.DisplayName());
28+
var canFindEntityType = model as ICanFindEntityType;
29+
30+
return canFindEntityType != null
31+
? canFindEntityType.FindEntityType(type)
32+
: model.FindEntityType(type.DisplayName());
2833
}
2934
}
3035
}

src/Microsoft.EntityFrameworkCore/Extensions/MutableModelExtensions.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using JetBrains.Annotations;
66
using Microsoft.EntityFrameworkCore.Internal;
77
using Microsoft.EntityFrameworkCore.Metadata;
8+
using Microsoft.EntityFrameworkCore.Metadata.Internal;
89
using Microsoft.EntityFrameworkCore.Utilities;
910

1011
namespace Microsoft.EntityFrameworkCore
@@ -44,8 +45,14 @@ public static IMutableEntityType AddEntityType([NotNull] this IMutableModel mode
4445
Check.NotNull(model, nameof(model));
4546
Check.NotNull(type, nameof(type));
4647

47-
var entityType = model.AddEntityType(type.DisplayName());
48+
var canFindEntityType = model as ICanFindEntityType;
49+
50+
var entityType = canFindEntityType != null
51+
? canFindEntityType.AddEntityType(type.DisplayName(), type) :
52+
model.AddEntityType(type.DisplayName());
53+
4854
entityType.ClrType = type;
55+
4956
return entityType;
5057
}
5158

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using JetBrains.Annotations;
6+
7+
namespace Microsoft.EntityFrameworkCore.Metadata.Internal
8+
{
9+
public interface ICanFindEntityType
10+
{
11+
IEntityType FindEntityType([NotNull] Type type);
12+
IMutableEntityType AddEntityType([NotNull] string name, [CanBeNull] Type type);
13+
}
14+
}

src/Microsoft.EntityFrameworkCore/Metadata/Internal/InternalModelBuilder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public virtual InternalEntityTypeBuilder Entity([NotNull] string name, Configura
3131
{
3232
Metadata.Unignore(name);
3333

34-
entityType = Metadata.AddEntityType(name, configurationSource);
34+
entityType = Metadata.AddEntityType(name, null, configurationSource);
3535
}
3636
else
3737
{

src/Microsoft.EntityFrameworkCore/Metadata/Internal/Model.cs

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313

1414
namespace Microsoft.EntityFrameworkCore.Metadata.Internal
1515
{
16-
public class Model : ConventionalAnnotatable, IMutableModel
16+
public class Model : ConventionalAnnotatable, IMutableModel, ICanFindEntityType
1717
{
18-
private readonly SortedDictionary<string, EntityType> _entityTypes = new SortedDictionary<string, EntityType>();
18+
private readonly SortedDictionary<string, EntityType> _entityTypes
19+
= new SortedDictionary<string, EntityType>();
20+
21+
private readonly IDictionary<Type, EntityType> _clrTypeMap
22+
= new Dictionary<Type, EntityType>();
1923

2024
private readonly Dictionary<string, ConfigurationSource> _ignoredEntityTypeNames
2125
= new Dictionary<string, ConfigurationSource>();
@@ -38,47 +42,53 @@ public Model([NotNull] ConventionSet conventions)
3842
public virtual IEnumerable<EntityType> GetEntityTypes() => _entityTypes.Values;
3943

4044
public virtual EntityType AddEntityType(
41-
[NotNull] string name, ConfigurationSource configurationSource = ConfigurationSource.Explicit)
45+
[NotNull] string name, [CanBeNull] Type type = null, ConfigurationSource configurationSource = ConfigurationSource.Explicit)
4246
{
4347
Check.NotEmpty(name, nameof(name));
4448

45-
var entityType = AddEntityTypeWithoutConventions(name, configurationSource);
46-
47-
return ConventionDispatcher.OnEntityTypeAdded(entityType.Builder)?.Metadata;
48-
}
49-
50-
public virtual EntityType AddEntityType(
51-
[NotNull] Type type, ConfigurationSource configurationSource = ConfigurationSource.Explicit)
52-
{
53-
Check.NotNull(type, nameof(type));
54-
55-
var entityType = AddEntityTypeWithoutConventions(type.DisplayName(), configurationSource);
56-
entityType.ClrType = type;
57-
58-
return ConventionDispatcher.OnEntityTypeAdded(entityType.Builder)?.Metadata;
59-
}
60-
61-
private EntityType AddEntityTypeWithoutConventions(string name, ConfigurationSource configurationSource)
62-
{
6349
var entityType = new EntityType(name, this, configurationSource);
50+
6451
var previousLength = _entityTypes.Count;
6552
_entityTypes[name] = entityType;
66-
6753
if (previousLength == _entityTypes.Count)
6854
{
6955
throw new InvalidOperationException(CoreStrings.DuplicateEntityType(entityType.Name));
7056
}
71-
return entityType;
57+
58+
if (type != null)
59+
{
60+
entityType.ClrType = type;
61+
_clrTypeMap[type] = entityType;
62+
}
63+
64+
return ConventionDispatcher.OnEntityTypeAdded(entityType.Builder)?.Metadata;
7265
}
7366

67+
IMutableEntityType ICanFindEntityType.AddEntityType(string name, Type type)
68+
=> AddEntityType(name, type);
69+
70+
public virtual EntityType AddEntityType(
71+
[NotNull] Type type, ConfigurationSource configurationSource = ConfigurationSource.Explicit)
72+
=> AddEntityType(type.DisplayName(), type, configurationSource);
73+
7474
public virtual EntityType GetOrAddEntityType([NotNull] Type type)
7575
=> FindEntityType(type) ?? AddEntityType(type);
7676

7777
public virtual EntityType GetOrAddEntityType([NotNull] string name)
7878
=> FindEntityType(name) ?? AddEntityType(name);
7979

80-
public virtual EntityType FindEntityType([NotNull] Type type)
81-
=> (EntityType)((IMutableModel)this).FindEntityType(type);
80+
public virtual EntityType FindEntityType([NotNull] Type type)
81+
=> (EntityType)((ICanFindEntityType)this).FindEntityType(type);
82+
83+
IEntityType ICanFindEntityType.FindEntityType(Type type)
84+
{
85+
Check.NotNull(type, nameof(type));
86+
87+
EntityType entityType;
88+
return _clrTypeMap.TryGetValue(type, out entityType)
89+
? entityType
90+
: FindEntityType(type.DisplayName());
91+
}
8292

8393
public virtual EntityType FindEntityType([NotNull] string name)
8494
{
@@ -129,6 +139,11 @@ private EntityType RemoveEntityType([NotNull] EntityType entityType)
129139
derivedEntityType.DisplayName()));
130140
}
131141

142+
if (entityType.ClrType != null)
143+
{
144+
_clrTypeMap.Remove(entityType.ClrType);
145+
}
146+
132147
var removed = _entityTypes.Remove(entityType.Name);
133148
Debug.Assert(removed);
134149
entityType.Builder = null;

src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@
320320
<Compile Include="Metadata\Internal\IEntityMaterializer.cs" />
321321
<Compile Include="Metadata\Internal\IEntityMaterializerSource.cs" />
322322
<Compile Include="Metadata\Internal\IFieldMatcher.cs" />
323+
<Compile Include="Metadata\Internal\ICanFindEntityType.cs" />
323324
<Compile Include="Metadata\Internal\IIdentityMapFactorySource.cs" />
324325
<Compile Include="Metadata\Internal\IMemberMapper.cs" />
325326
<Compile Include="Metadata\Internal\INavigationAccessors.cs" />

test/Microsoft.EntityFrameworkCore.FunctionalTests/InheritanceRelationshipsQueryTestBase.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Collections.Generic;
56
using System.Linq;
67
using Microsoft.EntityFrameworkCore.FunctionalTests.TestModels.InheritanceRelationships;
78
using Microsoft.EntityFrameworkCore.Metadata.Internal;
@@ -13,6 +14,38 @@ public abstract class InheritanceRelationshipsQueryTestBase<TTestStore, TFixture
1314
where TTestStore : TestStore
1415
where TFixture : InheritanceRelationshipsQueryFixtureBase<TTestStore>, new()
1516
{
17+
[Fact]
18+
public virtual void Changes_in_derived_related_entities_are_detected()
19+
{
20+
using (var context = CreateContext())
21+
{
22+
context.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.TrackAll;
23+
24+
var derivedEntity = context.BaseEntities.Include(e => e.BaseCollectionOnBase)
25+
.Single(e => e.Name == "Derived1(4)") as DerivedInheritanceRelationshipEntity;
26+
27+
var firstRelatedEntity = derivedEntity.BaseCollectionOnBase.Cast<DerivedCollectionOnBase>().First();
28+
29+
var originalValue = firstRelatedEntity.DerivedProperty;
30+
Assert.NotEqual(0, originalValue);
31+
32+
var entry = context.ChangeTracker.Entries<DerivedCollectionOnBase>()
33+
.Single(e => e.Entity == firstRelatedEntity);
34+
35+
Assert.IsType<DerivedCollectionOnBase>(entry.Entity);
36+
37+
Assert.Equal(
38+
"Microsoft.EntityFrameworkCore.FunctionalTests.TestModels.InheritanceRelationships.DerivedCollectionOnBase",
39+
entry.Metadata.Name);
40+
41+
firstRelatedEntity.DerivedProperty = originalValue + 1;
42+
context.ChangeTracker.DetectChanges();
43+
44+
Assert.Equal(EntityState.Modified, entry.State);
45+
Assert.Equal(originalValue, entry.Property(e => e.DerivedProperty).OriginalValue);
46+
Assert.Equal(originalValue + 1, entry.Property(e => e.DerivedProperty).CurrentValue);
47+
}
48+
}
1649

1750
[Fact]
1851
public virtual void Entity_can_make_separate_relationships_with_base_type_and_derived_type_both()
@@ -194,6 +227,7 @@ public virtual void Include_collection_with_inheritance1()
194227
var result = query.ToList();
195228

196229
Assert.Equal(6, result.Count);
230+
Assert.Equal(3, result.SelectMany(e => e.BaseCollectionOnBase.OfType<DerivedCollectionOnBase>()).Count(e => e.DerivedProperty != 0));
197231
}
198232
}
199233

@@ -231,6 +265,7 @@ public virtual void Include_collection_with_inheritance_with_filter1()
231265
var result = query.ToList();
232266

233267
Assert.Equal(6, result.Count);
268+
Assert.Equal(3, result.SelectMany(e => e.BaseCollectionOnBase.OfType<DerivedCollectionOnBase>()).Count(e => e.DerivedProperty != 0));
234269
}
235270
}
236271

@@ -474,6 +509,7 @@ public virtual void Include_collection_with_inheritance_on_derived1()
474509
var result = query.ToList();
475510

476511
Assert.Equal(3, result.Count);
512+
Assert.Equal(2, result.SelectMany(e => e.BaseCollectionOnBase.OfType<DerivedCollectionOnBase>()).Count(e => e.DerivedProperty != 0));
477513
}
478514
}
479515

@@ -660,6 +696,7 @@ public virtual void Nested_include_with_inheritance_collection_reference1()
660696
var result = query.ToList();
661697

662698
Assert.Equal(6, result.Count);
699+
Assert.Equal(3, result.SelectMany(e => e.BaseCollectionOnBase.OfType<DerivedCollectionOnBase>()).Count(e => e.DerivedProperty != 0));
663700
}
664701
}
665702

@@ -723,6 +760,7 @@ public virtual void Nested_include_with_inheritance_collection_collection1()
723760
var result = query.ToList();
724761

725762
Assert.Equal(6, result.Count);
763+
Assert.Equal(3, result.SelectMany(e => e.BaseCollectionOnBase.OfType<DerivedCollectionOnBase>()).Count(e => e.DerivedProperty != 0));
726764
}
727765
}
728766

0 commit comments

Comments
 (0)