Skip to content

Commit

Permalink
add InjectCall/EnableCall to avoid test code pollute main code (#83)
Browse files Browse the repository at this point in the history
* change

* makefile

* move eval inside to avoid line number change
  • Loading branch information
D3Hunter authored May 27, 2024
1 parent fd0796e commit 9b3b6e3
Show file tree
Hide file tree
Showing 13 changed files with 345 additions and 48 deletions.
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,18 @@ check-static: tools/bin/gometalinter

gotest:
@ echo "----------- go test ---------------"
$(GOTEST) -covermode=atomic -coverprofile=coverage.txt -coverpkg=./... -v ./...
$(GOTEST) -covermode=atomic -coverprofile=coverage.txt -coverpkg=./... -v $(go list ./... | grep -v examples)

tools/bin/gometalinter:
cd tools; \
curl -L https://git.io/vp6lP | sh

test-examples:
@ echo "----------- go test examples ---------------"
$(GO) run failpoint-ctl/main.go enable ./examples
$(GOTEST) -covermode=atomic -coverprofile=coverage.txt -coverpkg=./... -v ./examples/...
$(GO) run failpoint-ctl/main.go disable ./examples

test-examples-toolexec: build
@ echo "----------- go test examples using toolexec ---------------"
GOCACHE=/tmp/failpoint-cache $(GOTEST) -covermode=atomic -coverprofile=coverage.txt -coverpkg=./... -toolexec="$(PWD)/bin/failpoint-toolexec" -v ./examples/...
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ An implementation of [failpoints][failpoint] for Golang. Fail points are used to
GO_FAILPOINTS="main/testPanic=return(true)" ./your-program
```
Note: `GO_FAILPOINTS` does not work with `InjectCall` type of marker.
6. If you use `go run` to run the test, don't forget to add the generated `binding__failpoint_binding__.go` in your command, like:
```bash
Expand Down Expand Up @@ -137,6 +139,7 @@ An implementation of [failpoints][failpoint] for Golang. Fail points are used to
- `func Inject(fpname string, fpblock func(val Value)) {}`
- `func InjectContext(fpname string, ctx context.Context, fpblock func(val Value)) {}`
- `func InjectCall(fpname string, args ...any) {}`
- `func Break(label ...string) {}`
- `func Goto(label string) {}`
- `func Continue(label ...string) {}`
Expand All @@ -148,6 +151,8 @@ An implementation of [failpoints][failpoint] for Golang. Fail points are used to
failpoint can be enabled by export environment variables with the following patten, which is quite similar to [freebsd failpoint SYSCTL VARIABLES](https://www.freebsd.org/cgi/man.cgi?query=fail)
Note: `InjectCall` cannot be enabled by environment variables.
```regexp
[<percent>%][<count>*]<type>[(args...)][-><more terms>]
```
Expand Down Expand Up @@ -240,6 +245,8 @@ active in parallel tests or other cases. For example,
}
```
- You can use `failpoint.InjectCall` to inject a function call, this type of marker can only be enabled using `failpoint.EnableCall` and it must be called in the same process as the `InjectCall` call site. Using this marker, you can avoid failpoint code pollute you source code. See [examples](./examples/injectcall/inject_call.go).
- You can control a failpoint by failpoint.WithHook
```go
Expand Down
36 changes: 36 additions & 0 deletions code/expr_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type exprRewriter func(rewriter *Rewriter, call *ast.CallExpr) (rewritten bool,
var exprRewriters = map[string]exprRewriter{
"Inject": (*Rewriter).rewriteInject,
"InjectContext": (*Rewriter).rewriteInjectContext,
"InjectCall": (*Rewriter).rewriteInjectCall,
"Break": (*Rewriter).rewriteBreak,
"Continue": (*Rewriter).rewriteContinue,
"Label": (*Rewriter).rewriteLabel,
Expand Down Expand Up @@ -220,6 +221,41 @@ func (r *Rewriter) rewriteInjectContext(call *ast.CallExpr) (bool, ast.Stmt, err
return true, stmt, nil
}

func (r *Rewriter) rewriteInjectCall(call *ast.CallExpr) (bool, ast.Stmt, error) {
if len(call.Args) < 1 {
return false, nil, fmt.Errorf("failpoint.InjectCall: expect at least 1 arguments but got %v in %s", len(call.Args), r.pos(call.Pos()))
}
// First argument need not to be a string literal, any string type stuff is ok.
// Type safe is convinced by compiler.
fpname, ok := call.Args[0].(ast.Expr)
if !ok {
return false, nil, fmt.Errorf("failpoint.InjectCall: first argument expect a valid expression in %s", r.pos(call.Pos()))
}

fpnameExtendCall := &ast.CallExpr{
Fun: ast.NewIdent(ExtendPkgName),
Args: []ast.Expr{fpname},
}

// failpoint.InjectCall("name", a, b, c)
// |
// v
// failpoint.Call(_curpkg_("name"), a, b, c)
fnArgs := make([]ast.Expr, 0, len(call.Args))
fnArgs = append(fnArgs, fpnameExtendCall)
fnArgs = append(fnArgs, call.Args[1:]...)
fnCall := &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{NamePos: call.Pos(), Name: r.failpointName},
Sel: ast.NewIdent(callFunction),
},
Args: fnArgs,
},
}
return true, fnCall, nil
}

func (r *Rewriter) rewriteBreak(call *ast.CallExpr) (bool, ast.Stmt, error) {
if count := len(call.Args); count > 1 {
return false, nil, fmt.Errorf("failpoint.Break expect 1 or 0 arguments, but got %v in %s", count, r.pos(call.Pos()))
Expand Down
1 change: 1 addition & 0 deletions code/rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
packagePath = "github.com/pingcap/failpoint"
packageName = "failpoint"
evalFunction = "Eval"
callFunction = "Call"
evalCtxFunction = "EvalContext"
ExtendPkgName = "_curpkg_"
// It is an indicator to indicate the label is converted from `failpoint.Label("...")`
Expand Down
102 changes: 92 additions & 10 deletions code/rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package code_test

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
Expand All @@ -26,12 +27,15 @@ import (
"github.com/pingcap/failpoint/code"
)

type rewriteCase struct {
filepath string
errormsg string
original string
expected string
}

func TestRewrite(t *testing.T) {
var cases = []struct {
filepath string
original string
expected string
}{
var cases = []rewriteCase{
{
filepath: "func-args-test.go",
original: `
Expand Down Expand Up @@ -2477,11 +2481,7 @@ func unittest() {
}

func TestRewriteBad(t *testing.T) {
var cases = []struct {
filepath string
errormsg string
original string
}{
var cases = []rewriteCase{

{
filepath: "bad-basic-test.go",
Expand Down Expand Up @@ -3627,3 +3627,85 @@ label:
require.Equalf(t, cs.original, string(content), "%v", cs.filepath)
}
}

func TestRewriteInjectCall(t *testing.T) {
cases := []rewriteCase{

{
filepath: "test.go",
original: `
package rewriter_test
import (
"fmt"
"github.com/pingcap/failpoint"
)
func main() {
var (
a int
b string
c []float64
)
a, b, c = 1, "hello", []float64{1.0, 2.0}
failpoint.InjectCall("test", a, b, c)
}
`,
expected: `
package rewriter_test
import (
"fmt"
"github.com/pingcap/failpoint"
)
func main() {
var (
a int
b string
c []float64
)
a, b, c = 1, "hello", []float64{1.0, 2.0}
if _, _err_ := failpoint.Eval(_curpkg_("test")); _err_ == nil {
failpoint.Call(_curpkg_("test"), a, b, c)
}
}
`,
},
}
tempDir := t.TempDir()
for i, cs := range cases {
t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) {
caseDir := filepath.Join(tempDir, fmt.Sprintf("case-%d", i))
require.NoError(t, os.Mkdir(caseDir, 0755))
caseFileName := filepath.Join(caseDir, cs.filepath)
require.NoError(t, os.WriteFile(caseFileName, []byte(cs.original), 0644))

rewriter := code.NewRewriter(caseDir)
err := rewriter.Rewrite()
if cs.errormsg != "" {
require.Error(t, err)
require.Regexp(t, cs.errormsg, err.Error(), "%v", cs.filepath)

content, err := os.ReadFile(caseFileName)
require.NoError(t, err)
require.Equalf(t, cs.original, string(content), "%v", cs.filepath)
} else {
require.NoError(t, err)

content, err := os.ReadFile(caseFileName)
require.NoError(t, err)
require.Equalf(t, strings.TrimSpace(cs.expected), strings.TrimSpace(string(content)), "%v", cs.filepath)

restorer := code.NewRestorer(caseDir)
err = restorer.Restore()
require.NoError(t, err)
content, err = os.ReadFile(caseFileName)
require.NoError(t, err)
require.Equal(t, cs.original, string(content))
}
})
}
}
35 changes: 35 additions & 0 deletions examples/injectcall/inject_call.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package injectcall

import (
"context"
"fmt"

"github.com/pingcap/failpoint"
)

func foo(ctx context.Context, count int) int {
for i := 0; i < count; i++ {
fmt.Println(i)
failpoint.InjectCall("test", ctx, i, count)
select {
case <-ctx.Done():
return i
default:
}
}
return count
}
49 changes: 49 additions & 0 deletions examples/injectcall/inject_call_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package injectcall

import (
"context"
"testing"

"github.com/pingcap/failpoint"
"github.com/stretchr/testify/require"
)

func TestFoo(t *testing.T) {
ctx := context.WithValue(context.Background(), "key", "ctx-value")
ctx, cancel := context.WithCancel(ctx)
var (
capturedCtxVal string
capturedArgCount int
)
require.NoError(t, failpoint.EnableCall("github.com/pingcap/failpoint/examples/injectcall/test",
func(ctx context.Context, i, count int) {
if i == 5 {
cancel()
capturedCtxVal = ctx.Value("key").(string)
capturedArgCount = count
}
},
))
t.Cleanup(func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/failpoint/examples/injectcall/test"))
})

loopCount := foo(ctx, 123)
require.EqualValues(t, "ctx-value", capturedCtxVal)
require.EqualValues(t, 5, loopCount)
require.EqualValues(t, 123, capturedArgCount)
}
38 changes: 38 additions & 0 deletions failpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package failpoint

import (
"context"
"fmt"
"reflect"
"sync"
)

Expand All @@ -41,6 +43,8 @@ type (
mu sync.RWMutex
t *terms
waitChan chan struct{}
// fn is the function to be called for InjectCall type failpoint.
fn *reflect.Value
}
)

Expand Down Expand Up @@ -81,6 +85,24 @@ func (fp *Failpoint) EnableWith(inTerms string, action func() error) error {
return nil
}

// EnableCall enables a failpoint which is a InjectCall type failpoint.
func (fp *Failpoint) EnableCall(fn any) error {
value := reflect.ValueOf(fn)
if value.Kind() != reflect.Func {
return fmt.Errorf("failpoint: not a function")
}
t, err := newTerms("return(true)", fp)
if err != nil {
return err
}
fp.mu.Lock()
fp.t = t
fp.waitChan = make(chan struct{})
fp.fn = &value
fp.mu.Unlock()
return nil
}

// Disable stops a failpoint
func (fp *Failpoint) Disable() {
select {
Expand Down Expand Up @@ -110,3 +132,19 @@ func (fp *Failpoint) Eval() (Value, error) {
}
return v, nil
}

// Call calls the function passed by EnableCall with args supplied in InjectCall.
func (fp *Failpoint) Call(args ...any) {
fp.mu.RLock()
fn := fp.fn
fp.mu.RUnlock()

if fn == nil {
return
}
argVals := make([]reflect.Value, 0, len(args))
for _, a := range args {
argVals = append(argVals, reflect.ValueOf(a))
}
fn.Call(argVals)
}
Loading

0 comments on commit 9b3b6e3

Please sign in to comment.