Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-4935 allow Linq Translation conversion from interface to deriv… #1250

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public DiscriminatedInterfaceSerializer()
/// <exception cref="System.ArgumentException">interfaceType</exception>
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention)
: this(discriminatorConvention, CreateInterfaceSerializer())
: this(discriminatorConvention, CreateInterfaceSerializer(), null)
{
}

Expand All @@ -86,6 +86,19 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo
/// <exception cref="System.ArgumentException">interfaceType</exception>
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer<TInterface> interfaceSerializer)
: this(discriminatorConvention, interfaceSerializer, null)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="DiscriminatedInterfaceSerializer{TInterface}" /> class.
/// </summary>
/// <param name="discriminatorConvention">The discriminator convention.</param>
/// <param name="interfaceSerializer">The interface serializer (necessary to support LINQ queries).</param>
/// <param name="objectSerializer">The serializer that is used to serialize any objects.</param>
/// <exception cref="System.ArgumentException">interfaceType</exception>
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer<TInterface> interfaceSerializer, IBsonSerializer<object> objectSerializer)
{
var interfaceTypeInfo = typeof(TInterface).GetTypeInfo();
if (!interfaceTypeInfo.IsInterface)
Expand All @@ -96,7 +109,7 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo

_interfaceType = typeof(TInterface);
_discriminatorConvention = discriminatorConvention ?? BsonSerializer.LookupDiscriminatorConvention(typeof(TInterface));
_objectSerializer = BsonSerializer.LookupSerializer<object>();
_objectSerializer = objectSerializer ?? new ObjectSerializer(type => typeof(TInterface).IsAssignableFrom(type));
if (_objectSerializer is ObjectSerializer standardObjectSerializer)
{
_objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ private static bool IsConvertToBaseType(Type sourceType, Type targetType)

private static bool IsConvertToDerivedType(Type sourceType, Type targetType)
{
return targetType.IsSubclassOf(sourceType);
return targetType.IsSubclassOf(sourceType) || sourceType.IsAssignableFrom(targetType);
}

private static bool IsConvertUnderlyingTypeToEnum(UnaryExpression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static bool IsConvertToBaseType(Type fieldType, Type targetType)

private static bool IsConvertToDerivedType(Type fieldType, Type targetType)
{
return targetType.IsSubclassOf(fieldType);
return targetType.IsSubclassOf(fieldType) || fieldType.IsAssignableFrom(targetType);
}

private static bool IsConvertToNullable(Type fieldType, Type targetType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,47 @@ public void Project_using_convert_nullable_enum_to_nullable_underlying_type_work
result.EnumAsNullableInt.Should().Be(2);
}

[Fact]
public void Should_translate_from_base_interface_to_derived_class_on_method_call()
{
var collection = GetInterfaceCollection();
var queryable = collection.AsQueryable()
.Select(p => new DerivedClass
{
Id = p.Id,
A = ((DerivedClass)p).A.ToUpper()
});

var stages = Translate(collection, queryable);
AssertStages(
stages,
"{ '$project' : { _id : '$_id', A : { '$toUpper' : '$A' } } }");

var result = queryable.Single();
result.Id.Should().Be(1);
result.A.Should().Be("ABC");
}

[Fact]
public void Should_translate_from_base_interface_to_derived_class_on_projection()
{
var collection = GetInterfaceCollection();
var queryable = collection.AsQueryable()
.Select(p => new DerivedClass()
{
Id = p.Id,
A = ((DerivedClass)p).A
});

var stages = Translate(collection, queryable);
AssertStages(
stages,
"{ '$project' : { _id : '$_id', A : '$A' } }");

var result = queryable.Single();
result.Id.Should().Be(1);
result.A.Should().Be("abc");
}

private IMongoCollection<BaseClass> GetCollection()
{
Expand All @@ -208,7 +248,31 @@ private IMongoCollection<BaseClass> GetCollection()
return collection;
}

private class BaseClass
private IMongoCollection<IBaseInterface> GetInterfaceCollection()
{
var collection = GetCollection<IBaseInterface>("test");
CreateCollection(collection, new DerivedClass()
{
Id = 1,
A = "abc",
Enum = Enum.Two,
NullableEnum = Enum.Two,
EnumAsInt = 2,
EnumAsNullableInt = 2
});
return collection;
}

private interface IBaseInterface
{
public int Id { get; set; }
public Enum Enum { get; set; }
public Enum? NullableEnum { get; set; }
public int EnumAsInt { get; set; }
public int? EnumAsNullableInt { get; set; }
}

private class BaseClass : IBaseInterface
{
public int Id { get; set; }
public Enum Enum { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ public void Filter_using_convert_nullable_enum_to_underlying_type_should_work()
result.Id.Should().Be(2);
}

[Fact]
public void Filter_using_field_from_underlying_type_should_work()
{
var collection = GetInterfaceCollection();

var filter = Builders<IData>.Filter.Eq(x => ((Data)x).AdditionalValue, "value");

var result = collection.Find(filter).Single();
result.Id.Should().Be(2);
}

private IMongoCollection<Data> GetCollection()
{
var collection = GetCollection<Data>("test");
Expand All @@ -95,13 +106,33 @@ private IMongoCollection<Data> GetCollection()
return collection;
}

private class Data
private IMongoCollection<IData> GetInterfaceCollection()
{
var collection = GetCollection<IData>("test");
CreateCollection(
collection,
new Data { Id = 1, Enum = Enum.One, NullableEnum = Enum.One, EnumAsInt = 1, EnumAsNullableInt = 1 },
new Data { Id = 2, Enum = Enum.Two, NullableEnum = Enum.Two, EnumAsInt = 2, EnumAsNullableInt = 2, AdditionalValue = "value"});
return collection;
}

private interface IData
{
public int Id { get; set; }
public Enum Enum { get; set; }
public Enum? NullableEnum { get; set; }
public int EnumAsInt { get; set; }
public int? EnumAsNullableInt { get; set; }
}

private class Data : IData
{
public int Id { get; set; }
public Enum Enum { get; set; }
public Enum? NullableEnum { get; set; }
public int EnumAsInt { get; set; }
public int? EnumAsNullableInt { get; set; }
public string AdditionalValue { get; set; }
}

private enum Enum
Expand Down