Skip to content

Commit

Permalink
FromSql returns cloned DbSet to allow additional annotation extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
mikary committed Mar 13, 2015
1 parent af958b4 commit e9ca28e
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 63 deletions.
24 changes: 9 additions & 15 deletions src/EntityFramework.Core/DbSet`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ namespace Microsoft.Data.Entity
/// </para>
/// </summary>
/// <typeparam name="TEntity"> The type of entity being operated on by this set. </typeparam>
public abstract class DbSet<TEntity>
: IOrderedQueryable<TEntity>, IAsyncEnumerableAccessor<TEntity>, IAccessor<IServiceProvider>, IAccessor<EntityQueryable<TEntity>>
public abstract class DbSet<TEntity> : IOrderedQueryable<TEntity>, IAsyncEnumerableAccessor<TEntity>, IAccessor<IServiceProvider>
where TEntity : class
{
/// <summary>
Expand Down Expand Up @@ -263,6 +262,14 @@ IQueryProvider IQueryable.Provider
get { throw new NotImplementedException(); }
}

/// <summary>
/// Adds an <see cref="IAnnotation"/> to the current query.
/// </summary>
public virtual DbSet<TEntity> AddAnnotation([NotNull] string annotationName, [NotNull] string value)
{
throw new NotImplementedException();
}

/// <summary>
/// <para>
/// Gets the <see cref="DbContext" /> instance.
Expand All @@ -286,18 +293,5 @@ IServiceProvider IAccessor<IServiceProvider>.Service
{
get { throw new NotImplementedException(); }
}

/// <summary>
/// <para>
/// Gets the <see cref="EntityQueryable{TResult}" /> instance.
/// </para>
/// <para>
/// This property is intended for use by extension methods to clone and add query annotations.
/// </para>
/// </summary>
EntityQueryable<TEntity> IAccessor<EntityQueryable<TEntity>>.Service
{
get { throw new NotImplementedException(); }
}
}
}
25 changes: 21 additions & 4 deletions src/EntityFramework.Core/Internal/InternalDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

namespace Microsoft.Data.Entity.Internal
{
public class InternalDbSet<TEntity>
: DbSet<TEntity>, IOrderedQueryable<TEntity>, IAsyncEnumerableAccessor<TEntity>, IAccessor<IServiceProvider>, IAccessor<EntityQueryable<TEntity>>
public class InternalDbSet<TEntity> : DbSet<TEntity>, IOrderedQueryable<TEntity>, IAsyncEnumerableAccessor<TEntity>, IAccessor<IServiceProvider>
where TEntity : class
{
private readonly DbContext _context;
Expand All @@ -36,6 +35,15 @@ public InternalDbSet([NotNull] DbContext context)
((IAccessor<IServiceProvider>)_context).Service.GetRequiredService<EntityQueryProvider>()));
}

private InternalDbSet([NotNull] DbContext context, [NotNull] EntityQueryable<TEntity> entityQueryable)
{
Check.NotNull(context, nameof(context));
Check.NotNull(entityQueryable, nameof(entityQueryable));

_context = context;
_entityQueryable = new LazyRef<EntityQueryable<TEntity>>(() => entityQueryable);
}

public override EntityEntry<TEntity> Add(TEntity entity)
{
Check.NotNull(entity, nameof(entity));
Expand Down Expand Up @@ -132,10 +140,19 @@ public override void UpdateRange(IEnumerable<TEntity> entities)

IQueryProvider IQueryable.Provider => _entityQueryable.Value.Provider;

public override DbSet<TEntity> AddAnnotation([NotNull] string annotationName, [NotNull] string value)
{
Check.NotEmpty(annotationName, nameof(annotationName));
Check.NotEmpty(value, nameof(value));

var entityQueryable = _entityQueryable.Value.Clone();
entityQueryable.AddAnnotation(annotationName, value);

return new InternalDbSet<TEntity>(_context, entityQueryable);
}

public override DbContext Context => _context;

IServiceProvider IAccessor<IServiceProvider>.Service => ((IAccessor<IServiceProvider>)_context).Service;

EntityQueryable<TEntity> IAccessor<EntityQueryable<TEntity>>.Service => _entityQueryable.Value;
}
}
54 changes: 18 additions & 36 deletions src/EntityFramework.Core/Query/EntityQueryable`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@ namespace Microsoft.Data.Entity.Query
public class EntityQueryable<TResult>
: QueryableBase<TResult>, IAsyncEnumerable<TResult>, IAnnotatable
{
private EntityQueryProvider _provider;

private readonly LazyRef<Annotatable> _annotatable
= new LazyRef<Annotatable>(
() => new Annotatable());
private readonly EntityQueryProvider _provider;
private readonly Annotatable _annotatable;

public EntityQueryable([NotNull] EntityQueryProvider provider)
: base(Check.NotNull(provider, nameof(provider)))
{
_provider = provider;
_annotatable = new Annotatable();
}

public EntityQueryable([NotNull] EntityQueryProvider provider, [NotNull] Expression expression)
Expand All @@ -32,57 +30,41 @@ public EntityQueryable([NotNull] EntityQueryProvider provider, [NotNull] Express
Check.NotNull(expression, nameof(expression)))
{
_provider = provider;
_annotatable = new Annotatable();
}

public virtual EntityQueryable<TResult> Clone()
{
return new EntityQueryable<TResult>(_provider);
var clone = new EntityQueryable<TResult>(_provider);

foreach (var annotation in _annotatable.Annotations)
{
clone.AddAnnotation(annotation.Name, annotation.Value);
}

return clone;
}

IAsyncEnumerator<TResult> IAsyncEnumerable<TResult>.GetEnumerator()
{
return ((IAsyncQueryProvider)Provider).ExecuteAsync<TResult>(Expression).GetEnumerator();
}

public virtual Annotation AddAnnotation([NotNull] string annotationName, [NotNull] string value)
{
Check.NotEmpty(annotationName, nameof(annotationName));
Check.NotNull(value, nameof(value));
public virtual Annotation AddAnnotation([NotNull] string annotationName, [NotNull] string value) => _annotatable.AddAnnotation(annotationName, value);

return _annotatable.Value.AddAnnotation(annotationName, value);
}
public virtual string this[[NotNull]string annotationName] => _annotatable[annotationName];

public virtual string this[[NotNull]string annotationName]
{
get
{
Check.NotNull(annotationName, annotationName);
return _annotatable.Value[annotationName];
}
}
public virtual IEnumerable<IAnnotation> Annotations => _annotatable.Annotations;

public virtual IEnumerable<IAnnotation> Annotations
{
get
{
return _annotatable.Value.Annotations;
}
}

public virtual Annotation GetAnnotation([NotNull]string annotationName)
{
Check.NotNull(annotationName, nameof(annotationName));

return _annotatable.Value.GetAnnotation(annotationName);
}
public virtual Annotation GetAnnotation([NotNull]string annotationName) => _annotatable.GetAnnotation(annotationName);

public override string ToString()
{
return _annotatable.Value.Annotations.Count() == 0
return _annotatable.Annotations.Count() == 0
? base.ToString()
: string.Format("{0} ({1})",
base.ToString(),
string.Join(", ", _annotatable.Value.Annotations.Select(annotation =>
string.Join(", ", _annotatable.Annotations.Select(annotation =>
string.Format("{0} = {1}", annotation.Name, annotation.Value))));
}
}
Expand Down
10 changes: 2 additions & 8 deletions src/EntityFramework.Relational/RelationalDbSetExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

using System.Linq;
using JetBrains.Annotations;
using Microsoft.Data.Entity.Infrastructure;
using Microsoft.Data.Entity.Query;
using Microsoft.Data.Entity.Utilities;

// ReSharper disable once CheckNamespace
Expand All @@ -13,17 +11,13 @@ namespace Microsoft.Data.Entity
{
public static class RelationalDbSetExtensions
{
public static IQueryable<TEntity> FromSql<TEntity>([NotNull]this DbSet<TEntity> dbSet, [NotNull]string query)
public static DbSet<TEntity> FromSql<TEntity>([NotNull]this DbSet<TEntity> dbSet, [NotNull]string query)
where TEntity : class
{
Check.NotNull(dbSet, nameof(dbSet));
Check.NotNull(query, nameof(query));

var queryable = ((IAccessor<EntityQueryable<TEntity>>)dbSet).Service.Clone();

queryable.AddAnnotation("Sql", query);

return queryable;
return dbSet.AddAnnotation("Sql", query);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using Microsoft.Data.Entity.Tests;
using Xunit;

using CoreStrings = Microsoft.Data.Entity.Internal.Strings;

namespace Microsoft.Data.Entity.Relational.FunctionalTests
{
public abstract class FromSqlQueryTestBase<TFixture> : IClassFixture<TFixture>
Expand Down Expand Up @@ -95,6 +97,18 @@ public virtual void From_sql_annotations_do_not_modify_successive_calls()
}
}

[Fact]
public virtual void Multiple_calls_to_from_sql_throw()
{
using (var context = CreateContext())
{
Assert.Equal(
CoreStrings.DuplicateAnnotation("Sql"),
Assert.Throws<InvalidOperationException>(
() => context.Customers.FromSql("X").FromSql("X")).Message);
}
}

protected NorthwindContext CreateContext()
{
return Fixture.CreateContext();
Expand Down

0 comments on commit e9ca28e

Please sign in to comment.