diff --git a/plugin/host2plugin/host2plugin_test.go b/plugin/host2plugin/host2plugin_test.go index b918052..662999b 100644 --- a/plugin/host2plugin/host2plugin_test.go +++ b/plugin/host2plugin/host2plugin_test.go @@ -41,6 +41,7 @@ type mockRuleSetImpl struct { configSchema func() *hclext.BodySchema applyGlobalConfig func(*tflint.Config) error applyConfig func(*hclext.BodyContent) error + newRunner func(tflint.Runner) (tflint.Runner, error) check func(tflint.Runner) error } @@ -79,6 +80,13 @@ func (r *mockRuleSet) ApplyConfig(content *hclext.BodyContent) error { return nil } +func (r *mockRuleSet) NewRunner(runner tflint.Runner) (tflint.Runner, error) { + if r.impl.newRunner != nil { + return r.impl.newRunner(runner) + } + return runner, nil +} + func (r *mockRuleSet) Check(runner tflint.Runner) error { if r.impl.check != nil { return r.impl.check(runner) @@ -575,6 +583,14 @@ func (s *mockServer) GetFiles(tflint.ModuleCtxType) map[string][]byte { return map[string][]byte{} } +type mockCustomRunner struct { + tflint.Runner +} + +func (s *mockCustomRunner) Hello() string { + return "Hello from custom runner!" +} + func TestCheck(t *testing.T) { // default error check helper neverHappend := func(err error) bool { return err != nil } @@ -589,10 +605,11 @@ func TestCheck(t *testing.T) { } tests := []struct { - Name string - Arg func() plugin2host.Server - ServerImpl func(tflint.Runner) error - ErrCheck func(error) bool + Name string + Arg func() plugin2host.Server + ServerImpl func(tflint.Runner) error + NewRunnerImpl func(tflint.Runner) (tflint.Runner, error) + ErrCheck func(error) bool }{ { Name: "bidirectional", @@ -667,11 +684,26 @@ resource "aws_instance" "foo" { return err == nil || err.Error() != "unexpected error" }, }, + { + Name: "inject new runner", + Arg: func() plugin2host.Server { + return &mockServer{} + }, + NewRunnerImpl: func(runner tflint.Runner) (tflint.Runner, error) { + return &mockCustomRunner{runner}, nil + }, + ServerImpl: func(runner tflint.Runner) error { + return errors.New(runner.(*mockCustomRunner).Hello()) + }, + ErrCheck: func(err error) bool { + return err == nil || err.Error() != "Hello from custom runner!" + }, + }, } for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - client := startTestGRPCPluginServer(t, newMockRuleSet("test_ruleset", "0.1.0", mockRuleSetImpl{check: test.ServerImpl})) + client := startTestGRPCPluginServer(t, newMockRuleSet("test_ruleset", "0.1.0", mockRuleSetImpl{check: test.ServerImpl, newRunner: test.NewRunnerImpl})) err := client.Check(test.Arg()) if test.ErrCheck(err) { diff --git a/plugin/host2plugin/server.go b/plugin/host2plugin/server.go index 6eee98c..77b4fa4 100644 --- a/plugin/host2plugin/server.go +++ b/plugin/host2plugin/server.go @@ -116,8 +116,11 @@ func (s *GRPCServer) Check(ctx context.Context, req *proto.Check_Request) (*prot } defer conn.Close() - err = s.impl.Check(&plugin2host.GRPCClient{Client: proto.NewRunnerClient(conn)}) - + runner, err := s.impl.NewRunner(&plugin2host.GRPCClient{Client: proto.NewRunnerClient(conn)}) + if err != nil { + return nil, toproto.Error(codes.FailedPrecondition, err) + } + err = s.impl.Check(runner) if err != nil { return nil, toproto.Error(codes.Aborted, err) } diff --git a/tflint/ruleset.go b/tflint/ruleset.go index c0e2d25..abeea6e 100644 --- a/tflint/ruleset.go +++ b/tflint/ruleset.go @@ -106,11 +106,6 @@ func (r *BuiltinRuleSet) NewRunner(runner Runner) (Runner, error) { // Check runs inspection for each rule by applying Runner. func (r *BuiltinRuleSet) Check(runner Runner) error { - runner, err := r.NewRunner(runner) - if err != nil { - return err - } - for _, rule := range r.EnabledRules { if err := rule.Check(runner); err != nil { return fmt.Errorf("Failed to check `%s` rule: %s", rule.Name(), err)