diff --git a/Spec.go b/Spec.go index 3e6493a..d916ac1 100644 --- a/Spec.go +++ b/Spec.go @@ -683,6 +683,12 @@ func (spec *Spec) Spec(oth *Spec) { opt.setup(oth) } oth.isSuite = isSuite + for _, hook := range spec.hooks.BeforeAll { + oth.BeforeAll(hook.Block) + } + for _, hook := range spec.hooks.Around { + oth.Around(hook.Block) + } for _, def := range spec.defs { def(oth) } @@ -730,10 +736,11 @@ type SpecSuite struct { S *Spec } -func (suite SpecSuite) Name() string { return suite.N } +func (suite SpecSuite) Name() string { return suite.N } +func (suite SpecSuite) Spec(s *Spec) { suite.S.Spec(s) } + func (suite SpecSuite) Test(t *testing.T) { suite.run(t) } func (suite SpecSuite) Benchmark(b *testing.B) { suite.run(b) } -func (suite SpecSuite) Spec(s *Spec) { suite.S.Spec(s) } func (suite SpecSuite) run(tb testing.TB) { s := NewSpec(tb) diff --git a/Suite_test.go b/Suite_test.go index 49bac93..6021ab3 100644 --- a/Suite_test.go +++ b/Suite_test.go @@ -140,3 +140,62 @@ func (c RunContractFmtStringerContract) String() string { return "Hello, world!" func (c RunContractFmtStringerContract) Spec(s *testcase.Spec) { s.Test(``, func(t *testcase.T) { t.Log("!dlrow ,olleH") }) } + +func TestSpec_AsSuite_merge(t *testing.T) { + t.Run("Before", func(t *testing.T) { + var n int + t.Run("", func(t *testing.T) { + suite := testcase.NewSpec(nil) + suite.HasSideEffect() + suite.Before(func(t *testcase.T) { n++ }) + suite.Test("", func(t *testcase.T) {}) + suite.Test("", func(t *testcase.T) {}) + suite.AsSuite("suite").Test(t) + }) + assert.Equal(t, 2, n) + }) + t.Run("BeforeAll", func(t *testing.T) { + var n int + t.Run("", func(t *testing.T) { + suite := testcase.NewSpec(nil) + suite.HasSideEffect() + suite.BeforeAll(func(tb testing.TB) { n++ }) + suite.Test("", func(t *testcase.T) {}) + suite.Test("", func(t *testcase.T) {}) + suite.AsSuite("suite").Test(t) + }) + assert.Equal(t, 1, n) + }) + t.Run("After", func(t *testing.T) { + var n int + t.Run("", func(t *testing.T) { + suite := testcase.NewSpec(nil) + suite.HasSideEffect() + suite.After(func(t *testcase.T) { n++ }) + suite.Test("", func(t *testcase.T) {}) + suite.Test("", func(t *testcase.T) {}) + suite.AsSuite("suite").Test(t) + }) + assert.Equal(t, 2, n) + }) + t.Run("Around", func(t *testing.T) { + var b, a int + t.Run("", func(t *testing.T) { + suite := testcase.NewSpec(nil) + suite.HasSideEffect() + suite.Around(func(*testcase.T) func() { + b++ + return func() { + a++ + } + }) + suite.Test("", func(t *testcase.T) {}) + suite.Test("", func(t *testcase.T) {}) + suite.AsSuite("suite").Test(t) + }) + assert.Equal(t, 2, a) + assert.Equal(t, 2, b) + }) + + // TODO: cover further +} diff --git a/T.go b/T.go index 60f4a71..bffe08e 100644 --- a/T.go +++ b/T.go @@ -123,7 +123,7 @@ func (t *T) setUp() func() { for _, c := range contexts { for _, hook := range c.hooks.BeforeAll { - hook.Block() + hook.DoOnce(t) } } diff --git a/hooks.go b/hooks.go index 9c076ce..209b0c0 100644 --- a/hooks.go +++ b/hooks.go @@ -19,8 +19,9 @@ type hook struct { } type hookOnce struct { - Block func() - Frame runtime.Frame + Block func(testing.TB) + Frame runtime.Frame + DoOnce func(testing.TB) } // Before give you the ability to run a block before each test case. @@ -78,9 +79,17 @@ func (spec *Spec) BeforeAll(blk func(tb testing.TB)) { spec.testingTB.Fatal(hookWarning) } frame, _ := caller.GetFrame() - var once sync.Once - spec.hooks.BeforeAll = append(spec.hooks.BeforeAll, hookOnce{ - Block: func() { once.Do(func() { blk(spec.testingTB) }) }, - Frame: frame, - }) + + var onCall sync.Once + var beforeAll = func(tb testing.TB) { + onCall.Do(func() { blk(tb) }) + } + + h := hookOnce{ + DoOnce: beforeAll, + Block: blk, + Frame: frame, + } + + spec.hooks.BeforeAll = append(spec.hooks.BeforeAll, h) }