Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## Unreleased

FEATURES

* **Toolsets Flag**: Added `--toolsets` flag to selectively enable tool groups. Three toolset groups are available: `registry` (public Terraform Registry), `registry-private` (private TFE/TFC registry), and `terraform` (TFE/TFC operations). Default is `registry` only.

## 0.3.3 (Nov 21, 2025)

IMPROVEMENTS
Expand Down
16 changes: 11 additions & 5 deletions cmd/terraform-mcp-server/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/terraform-mcp-server/pkg/client"
"github.com/hashicorp/terraform-mcp-server/pkg/resources"
"github.com/hashicorp/terraform-mcp-server/pkg/tools"
"github.com/hashicorp/terraform-mcp-server/pkg/toolsets"
"github.com/hashicorp/terraform-mcp-server/version"
"github.com/mark3labs/mcp-go/server"
log "github.com/sirupsen/logrus"
Expand All @@ -37,7 +38,7 @@ var (
Use: "stdio",
Short: "Start stdio server",
Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, _ []string) {
logFile, err := rootCmd.PersistentFlags().GetString("log-file")
if err != nil {
stdlog.Fatal("Failed to get log file:", err)
Expand All @@ -47,7 +48,9 @@ var (
stdlog.Fatal("Failed to initialize logger:", err)
}

if err := runStdioServer(logger); err != nil {
enabledToolsets := getToolsetsFromCmd(cmd.Root(), logger)

if err := runStdioServer(logger, enabledToolsets); err != nil {
stdlog.Fatal("failed to run stdio server:", err)
}
},
Expand Down Expand Up @@ -81,7 +84,9 @@ var (
stdlog.Fatal("Failed to get endpoint path:", err)
}

if err := runHTTPServer(logger, host, port, endpointPath); err != nil {
enabledToolsets := getToolsetsFromCmd(cmd.Root(), logger)

if err := runHTTPServer(logger, host, port, endpointPath, enabledToolsets); err != nil {
stdlog.Fatal("failed to run streamableHTTP server:", err)
}
},
Expand All @@ -104,6 +109,7 @@ func init() {
cobra.OnInitialize(initConfig)
rootCmd.SetVersionTemplate("{{.Short}}\n{{.Version}}\n")
rootCmd.PersistentFlags().String("log-file", "", "Path to log file")
rootCmd.PersistentFlags().String("toolsets", "default", toolsets.GenerateToolsetsHelp())

// Add StreamableHTTP command flags (avoid 'h' shorthand conflict with help)
streamableHTTPCmd.Flags().String("transport-host", "127.0.0.1", "Host to bind to")
Expand Down Expand Up @@ -142,8 +148,8 @@ func initLogger(outPath string) (*log.Logger, error) {
}

// registerToolsAndResources registers tools and resources with the MCP server
func registerToolsAndResources(hcServer *server.MCPServer, logger *log.Logger) {
tools.RegisterTools(hcServer, logger)
func registerToolsAndResources(hcServer *server.MCPServer, logger *log.Logger, enabledToolsets []string) {
tools.RegisterTools(hcServer, logger, enabledToolsets)
resources.RegisterResources(hcServer, logger)
resources.RegisterResourceTemplates(hcServer, logger)
}
Expand Down
51 changes: 42 additions & 9 deletions cmd/terraform-mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"syscall"

"github.com/hashicorp/terraform-mcp-server/pkg/client"
"github.com/hashicorp/terraform-mcp-server/pkg/toolsets"
"github.com/hashicorp/terraform-mcp-server/version"

"github.com/mark3labs/mcp-go/server"
Expand All @@ -24,27 +25,27 @@ import (
//go:embed instructions.md
var instructions string

func runHTTPServer(logger *log.Logger, host string, port string, endpointPath string) error {
func runHTTPServer(logger *log.Logger, host string, port string, endpointPath string, enabledToolsets []string) error {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

hcServer := NewServer(version.Version, logger)
registerToolsAndResources(hcServer, logger)
hcServer := NewServer(version.Version, logger, enabledToolsets)
registerToolsAndResources(hcServer, logger, enabledToolsets)

return streamableHTTPServerInit(ctx, hcServer, logger, host, port, endpointPath)
}

func runStdioServer(logger *log.Logger) error {
func runStdioServer(logger *log.Logger, enabledToolsets []string) error {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

hcServer := NewServer(version.Version, logger)
registerToolsAndResources(hcServer, logger)
hcServer := NewServer(version.Version, logger, enabledToolsets)
registerToolsAndResources(hcServer, logger, enabledToolsets)

return serverInit(ctx, hcServer, logger)
}

func NewServer(version string, logger *log.Logger, opts ...server.ServerOption) *server.MCPServer {
func NewServer(version string, logger *log.Logger, enabledToolsets []string, opts ...server.ServerOption) *server.MCPServer {
// Create rate limiting middleware with environment-based configuration
rateLimitConfig := client.LoadRateLimitConfigFromEnv()
rateLimitMiddleware := client.NewRateLimitMiddleware(rateLimitConfig, logger)
Expand Down Expand Up @@ -80,6 +81,33 @@ func NewServer(version string, logger *log.Logger, opts ...server.ServerOption)
return s
}

// parseToolsets parses and validates the toolsets flag value
func parseToolsets(toolsetsFlag string, logger *log.Logger) []string {
rawToolsets := strings.Split(toolsetsFlag, ",")

cleaned, invalid := toolsets.CleanToolsets(rawToolsets)
if len(invalid) > 0 {
logger.Warnf("Invalid toolsets ignored: %v", invalid)
}

expanded := toolsets.ExpandDefaultToolset(cleaned)

logger.Infof("Enabled toolsets: %v", expanded)
return expanded
}

func getToolsetsFromCmd(cmd *cobra.Command, logger *log.Logger) []string {
toolsetsFlag, err := cmd.Flags().GetString("toolsets")
if err != nil {
toolsetsFlag, err = cmd.Root().PersistentFlags().GetString("toolsets")
if err != nil {
logger.Warnf("Failed to get toolsets flag, using default: %v", err)
toolsetsFlag = "default"
}
}
return parseToolsets(toolsetsFlag, logger)
}

// runDefaultCommand handles the default behavior when no subcommand is provided
func runDefaultCommand(cmd *cobra.Command, _ []string) {
// Default to stdio mode when no subcommand is provided
Expand All @@ -92,7 +120,10 @@ func runDefaultCommand(cmd *cobra.Command, _ []string) {
stdlog.Fatal("Failed to initialize logger:", err)
}

if err := runStdioServer(logger); err != nil {
// Get toolsets from the command that was passed in
enabledToolsets := getToolsetsFromCmd(cmd, logger)

if err := runStdioServer(logger, enabledToolsets); err != nil {
stdlog.Fatal("failed to run stdio server:", err)
}
}
Expand All @@ -110,7 +141,9 @@ func main() {
stdlog.Fatal("Failed to initialize logger:", err)
}

if err := runHTTPServer(logger, host, port, endpointPath); err != nil {
enabledToolsets := getToolsetsFromCmd(rootCmd, logger)

if err := runHTTPServer(logger, host, port, endpointPath, enabledToolsets); err != nil {
stdlog.Fatal("failed to run StreamableHTTP server:", err)
}
return
Expand Down
Loading
Loading