diff --git a/src/EntityFramework.Testing.Moq/MockDbSetExtenstions.cs b/src/EntityFramework.Testing.Moq/MockDbSetExtenstions.cs index 1fed4a9..d222b40 100644 --- a/src/EntityFramework.Testing.Moq/MockDbSetExtenstions.cs +++ b/src/EntityFramework.Testing.Moq/MockDbSetExtenstions.cs @@ -1,4 +1,5 @@ using Moq; +using System; using System.Collections.Generic; using System.Data.Entity.Infrastructure; using System.Linq; @@ -14,12 +15,6 @@ public static MockDbSet SetupSeedData( { set.AddData(data); - // Need to re-setup LINQ if the data changes - if(set.IsLinqSetup) - { - set.SetupLinq(); - } - return set; } @@ -32,22 +27,45 @@ public static MockDbSet SetupLinq(this MockDbSet set) // Enable direct async enumeration of set set.As>() .Setup(m => m.GetAsyncEnumerator()) - .Returns(new TestDbAsyncEnumerator(set.Queryable.GetEnumerator())); + .Returns(() => new TestDbAsyncEnumerator(set.Queryable.GetEnumerator())); // Enable LINQ queries with async enumeration set.As>() .Setup(m => m.Provider) - .Returns(new TestDbAsyncQueryProvider(set.Queryable.Provider)); + .Returns(() => new TestDbAsyncQueryProvider(set.Queryable.Provider)); // Wire up LINQ provider to fall back to in memory LINQ provider of the data - set.As>().Setup(m => m.Expression).Returns(set.Queryable.Expression); - set.As>().Setup(m => m.ElementType).Returns(set.Queryable.ElementType); - set.As>().Setup(m => m.GetEnumerator()).Returns(set.Queryable.GetEnumerator()); + set.As>().Setup(m => m.Expression).Returns(() => set.Queryable.Expression); + set.As>().Setup(m => m.ElementType).Returns(() => set.Queryable.ElementType); + set.As>().Setup(m => m.GetEnumerator()).Returns(() => set.Queryable.GetEnumerator()); // Enable Include directly on the DbSet (Include extension method on IQueryable is a no-op when it's not a DbSet/DbQuery) // Include(string) and Include(Func s.Include(It.IsAny())).Returns(set.Object); return set; } + + public static MockDbSet SetupAddAndRemove(this MockDbSet set) + where TEntity : class + { + set.Setup(s => s.Add(It.IsAny())) + .Returns((TEntity t) => t) + .Callback((TEntity t) => set.AddData(t)); + + set.Setup(s => s.Remove(It.IsAny())) + .Returns((TEntity t) => t) + .Callback((TEntity t) => set.RemoveData(t)); + + return set; + } + + public static MockDbSet SetupFind(this MockDbSet set, Func finder) + where TEntity : class + { + set.Setup(s => s.Find(It.IsAny())) + .Returns((object[] keyValues) => set.Data.SingleOrDefault(e => finder(keyValues, e))); + + return set; + } } } diff --git a/src/EntityFramework.Testing.Moq/MockDbSet`.cs b/src/EntityFramework.Testing.Moq/MockDbSet`.cs index fd78dcc..a57847d 100644 --- a/src/EntityFramework.Testing.Moq/MockDbSet`.cs +++ b/src/EntityFramework.Testing.Moq/MockDbSet`.cs @@ -29,9 +29,19 @@ internal IQueryable Queryable get { return _queryable; } } + internal void AddData(TEntity data) + { + _data.Add(data); + } + internal void AddData(IEnumerable data) { _data.AddRange(data); } + + internal void RemoveData(TEntity data) + { + _data.Remove(data); + } } } diff --git a/test/EntityFramework.Testing.Moq.Tests/EntityFramework.Testing.Moq.Tests.csproj b/test/EntityFramework.Testing.Moq.Tests/EntityFramework.Testing.Moq.Tests.csproj index 49d6fff..e944f36 100644 --- a/test/EntityFramework.Testing.Moq.Tests/EntityFramework.Testing.Moq.Tests.csproj +++ b/test/EntityFramework.Testing.Moq.Tests/EntityFramework.Testing.Moq.Tests.csproj @@ -60,6 +60,7 @@ + diff --git a/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataOperationsTests.cs b/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataOperationsTests.cs new file mode 100644 index 0000000..4acf896 --- /dev/null +++ b/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataOperationsTests.cs @@ -0,0 +1,100 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace EntityFramework.Testing.Moq.Tests +{ + [TestClass] + public class FakeDbSetDataOperationsTests + { + [TestMethod] + public void Basic_add() + { + var set = new MockDbSet() + .SetupAddAndRemove(); + + var blog = new Blog(); + var result = set.Object.Add(blog); + + Assert.AreSame(blog, result); + Assert.AreEqual(1, set.Data.Count()); + Assert.IsTrue(set.Data.Contains(blog)); + } + + [TestMethod] + public void Basic_remove() + { + var blog1 = new Blog(); + var blog2 = new Blog(); + var data = new List { blog1, blog2 }; + var set = new MockDbSet() + .SetupSeedData(data) + .SetupAddAndRemove(); + + var result = set.Object.Remove(blog1); + + Assert.AreSame(blog1, result); + Assert.AreEqual(1, set.Data.Count()); + Assert.IsFalse(set.Data.Contains(blog1)); + Assert.IsTrue(set.Data.Contains(blog2)); + } + + [TestMethod] + public void Add_remove_work_with_enumeration() + { + var blog1 = new Blog(); + var blog2 = new Blog(); + var blog3 = new Blog(); + var data = new List { blog1, blog2 }; + var set = new MockDbSet() + .SetupSeedData(data) + .SetupLinq() + .SetupAddAndRemove(); + + set.Object.Remove(blog2); + set.Object.Add(blog3); + + var result = set.Object.ToList(); + + Assert.AreEqual(2, result.Count); + Assert.IsTrue(result.Contains(blog3)); + Assert.IsTrue(result.Contains(blog1)); + } + + [TestMethod] + public void Basic_find() + { + var blog = new Blog { BlogId = 1 }; + var data = new List { blog, new Blog { BlogId = 2 } }; + var set = new MockDbSet() + .SetupSeedData(data) + .SetupFind((keyValues, entity) => entity.BlogId == (int)keyValues.Single()); + + var result = set.Object.Find(1); + + Assert.AreSame(blog, result); + } + + [TestMethod] + public void Find_returs_null_for_no_match() + { + var data = new List{ new Blog { BlogId = 1 }, new Blog { BlogId = 2 } }; + var set = new MockDbSet() + .SetupSeedData(data) + .SetupFind((keyValues, entity) => entity.BlogId == (int)keyValues.Single()); + + var result = set.Object.Find(99); + + Assert.IsNull(result); + } + + public class Blog + { + public int BlogId { get; set; } + public string Url { get; set; } + } + } +} diff --git a/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataTests.cs b/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataTests.cs index 101e96d..0238998 100644 --- a/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataTests.cs +++ b/test/EntityFramework.Testing.Moq.Tests/FakeDbSetDataTests.cs @@ -10,6 +10,7 @@ namespace EntityFramework.Testing.Moq.Tests [TestClass] public class FakeDbSetDataTests { + [TestMethod] public void Data_is_addded_to_set() { var data = new List { new Blog(), new Blog() };