Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-4935: Linq3Implementation unable to handle interface as root #1419

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public DiscriminatedInterfaceSerializer()
/// <exception cref="System.ArgumentException">interfaceType</exception>
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention)
: this(discriminatorConvention, CreateInterfaceSerializer())
: this(discriminatorConvention, CreateInterfaceSerializer(), objectSerializer: null)
{
}

Expand All @@ -86,6 +86,19 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo
/// <exception cref="System.ArgumentException">interfaceType</exception>
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer<TInterface> interfaceSerializer)
: this(discriminatorConvention, interfaceSerializer, objectSerializer: null)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="DiscriminatedInterfaceSerializer{TInterface}" /> class.
/// </summary>
/// <param name="discriminatorConvention">The discriminator convention.</param>
/// <param name="interfaceSerializer">The interface serializer (necessary to support LINQ queries).</param>
/// <param name="objectSerializer">The serializer that is used to serialize any objects.</param>
/// <exception cref="System.ArgumentException">interfaceType</exception>
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer<TInterface> interfaceSerializer, IBsonSerializer<object> objectSerializer)
{
var interfaceTypeInfo = typeof(TInterface).GetTypeInfo();
if (!interfaceTypeInfo.IsInterface)
Expand All @@ -96,7 +109,7 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo

_interfaceType = typeof(TInterface);
_discriminatorConvention = discriminatorConvention ?? BsonSerializer.LookupDiscriminatorConvention(typeof(TInterface));
_objectSerializer = BsonSerializer.LookupSerializer<object>();
_objectSerializer = objectSerializer ?? new ObjectSerializer(allowedTypes: type => typeof(TInterface).IsAssignableFrom(type));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you see any backward compatibility issues with this change?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be a problem if someone registered their own object serializer for some reasons. Can we implement method like a WithAllowedTypes on objectSerializer? So we lookup the same way as it was before - and then try to use this new method, similar to this:

var objectSerializer = BsonSerializer.LookupSerializer<object>();
if (objectSerializer is ObjectSerializer builtInSerializer)
{
    objectSerializer  = builtInSerializer.WithAllowedTypes(allowedTypes: type => typeof(TInterface).IsAssignableFrom(type));
}

if (_objectSerializer is ObjectSerializer standardObjectSerializer)
{
_objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ private static bool IsConvertToBaseType(Type sourceType, Type targetType)

private static bool IsConvertToDerivedType(Type sourceType, Type targetType)
{
return targetType.IsSubclassOf(sourceType);
return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface
}

private static bool IsConvertToNullableType(Type targetType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static bool IsConvertToBaseType(Type fieldType, Type targetType)

private static bool IsConvertToDerivedType(Type fieldType, Type targetType)
{
return targetType.IsSubclassOf(fieldType);
return fieldType.IsAssignableFrom(targetType); // targetType either derives from fieldType or implements fieldType interface
}

private static bool IsConvertToNullable(Type fieldType, Type targetType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,47 @@ public void Project_using_convert_nullable_enum_to_nullable_underlying_type_work
result.EnumAsNullableInt.Should().Be(2);
}

[Fact]
public void Should_translate_from_base_interface_to_derived_class_on_method_call()
{
var collection = GetInterfaceCollection();
var queryable = collection.AsQueryable()
.Select(p => new DerivedClass
{
Id = p.Id,
A = ((DerivedClass)p).A.ToUpper()
});

var stages = Translate(collection, queryable);
AssertStages(
stages,
"{ '$project' : { _id : '$_id', A : { '$toUpper' : '$A' } } }");

var result = queryable.Single();
result.Id.Should().Be(1);
result.A.Should().Be("ABC");
}

[Fact]
public void Should_translate_from_base_interface_to_derived_class_on_projection()
{
var collection = GetInterfaceCollection();
var queryable = collection.AsQueryable()
.Select(p => new DerivedClass()
{
Id = p.Id,
A = ((DerivedClass)p).A
});

var stages = Translate(collection, queryable);
AssertStages(
stages,
"{ '$project' : { _id : '$_id', A : '$A' } }");

var result = queryable.Single();
result.Id.Should().Be(1);
result.A.Should().Be("abc");
}

private IMongoCollection<BaseClass> GetCollection()
{
Expand All @@ -208,7 +248,31 @@ private IMongoCollection<BaseClass> GetCollection()
return collection;
}

private class BaseClass
private IMongoCollection<IBaseInterface> GetInterfaceCollection()
{
var collection = GetCollection<IBaseInterface>("test");
CreateCollection(collection, new DerivedClass()
{
Id = 1,
A = "abc",
Enum = Enum.Two,
NullableEnum = Enum.Two,
EnumAsInt = 2,
EnumAsNullableInt = 2
});
return collection;
}

private interface IBaseInterface
{
public int Id { get; set; }
public Enum Enum { get; set; }
public Enum? NullableEnum { get; set; }
public int EnumAsInt { get; set; }
public int? EnumAsNullableInt { get; set; }
}

private class BaseClass : IBaseInterface
{
public int Id { get; set; }
public Enum Enum { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ public void Filter_using_convert_nullable_enum_to_underlying_type_should_work()
result.Id.Should().Be(2);
}

[Fact]
public void Filter_using_field_from_implementing_type_should_work()
{
var collection = GetInterfaceCollection();

var filter = Builders<IData>.Filter.Eq(x => ((Data)x).AdditionalValue, "value");

var result = collection.Find(filter).Single();
result.Id.Should().Be(2);
}

private IMongoCollection<Data> GetCollection()
{
var collection = GetCollection<Data>("test");
Expand All @@ -95,13 +106,33 @@ private IMongoCollection<Data> GetCollection()
return collection;
}

private class Data
private IMongoCollection<IData> GetInterfaceCollection()
{
var collection = GetCollection<IData>("test");
CreateCollection(
collection,
new Data { Id = 1, Enum = Enum.One, NullableEnum = Enum.One, EnumAsInt = 1, EnumAsNullableInt = 1 },
new Data { Id = 2, Enum = Enum.Two, NullableEnum = Enum.Two, EnumAsInt = 2, EnumAsNullableInt = 2, AdditionalValue = "value"});
return collection;
}

private interface IData
{
public int Id { get; set; }
public Enum Enum { get; set; }
public Enum? NullableEnum { get; set; }
public int EnumAsInt { get; set; }
public int? EnumAsNullableInt { get; set; }
}

private class Data : IData
{
public int Id { get; set; }
public Enum Enum { get; set; }
public Enum? NullableEnum { get; set; }
public int EnumAsInt { get; set; }
public int? EnumAsNullableInt { get; set; }
public string AdditionalValue { get; set; }
}

private enum Enum
Expand Down