Skip to content

Commit 90ab316

Browse files
authored
chore: fix linter issues
chore: fix linter issues
2 parents a88eeb3 + 7b4363c commit 90ab316

File tree

4 files changed

+138
-112
lines changed

4 files changed

+138
-112
lines changed

internal/app/app.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ type ListTablesOptions struct {
2121

2222
// ExecuteQueryOptions represents options for executing queries.
2323
type ExecuteQueryOptions struct {
24-
Query string `json:"query"`
25-
Args []interface{} `json:"args,omitempty"`
26-
Limit int `json:"limit,omitempty"`
24+
Query string `json:"query"`
25+
Args []any `json:"args,omitempty"`
26+
Limit int `json:"limit,omitempty"`
2727
}
2828

2929
// App represents the main application structure.
@@ -281,7 +281,7 @@ func (a *App) GetCurrentDatabase() (string, error) {
281281
}
282282

283283
// ExplainQuery returns the execution plan for a query.
284-
func (a *App) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) {
284+
func (a *App) ExplainQuery(query string, args ...any) (*QueryResult, error) {
285285
if err := a.ensureConnection(); err != nil {
286286
return nil, err
287287
}

internal/app/client.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,16 +358,16 @@ func validateQuery(query string) error {
358358
}
359359

360360
// processRows processes query result rows and handles type conversion.
361-
func processRows(rows *sql.Rows) ([][]interface{}, error) {
361+
func processRows(rows *sql.Rows) ([][]any, error) {
362362
columns, err := rows.Columns()
363363
if err != nil {
364364
return nil, fmt.Errorf("failed to get columns: %w", err)
365365
}
366366

367-
var result [][]interface{}
367+
var result [][]any
368368
for rows.Next() {
369-
values := make([]interface{}, len(columns))
370-
valuePtrs := make([]interface{}, len(columns))
369+
values := make([]any, len(columns))
370+
valuePtrs := make([]any, len(columns))
371371
for i := range values {
372372
valuePtrs[i] = &values[i]
373373
}
@@ -389,7 +389,7 @@ func processRows(rows *sql.Rows) ([][]interface{}, error) {
389389
}
390390

391391
// ExecuteQuery executes a SELECT query and returns the results.
392-
func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...interface{}) (*QueryResult, error) {
392+
func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...any) (*QueryResult, error) {
393393
if c.db == nil {
394394
return nil, ErrNoDatabaseConnection
395395
}
@@ -425,7 +425,7 @@ func (c *PostgreSQLClientImpl) ExecuteQuery(query string, args ...interface{}) (
425425
}
426426

427427
// ExplainQuery returns the execution plan for a query.
428-
func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...interface{}) (*QueryResult, error) {
428+
func (c *PostgreSQLClientImpl) ExplainQuery(query string, args ...any) (*QueryResult, error) {
429429
if c.db == nil {
430430
return nil, ErrNoDatabaseConnection
431431
}

internal/app/interfaces.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ var (
1515
ErrQueryRequired = errors.New("query is required")
1616
ErrInvalidQuery = errors.New("only SELECT and WITH queries are allowed")
1717
ErrNoConnectionString = errors.New(
18-
"no database connection string provided. Either call connect_database tool or set POSTGRES_URL/DATABASE_URL environment variable",
18+
"no database connection string provided. " +
19+
"Either call connect_database tool or set POSTGRES_URL/DATABASE_URL environment variable",
1920
)
2021
ErrNoDatabaseConnection = errors.New("no database connection")
2122
ErrTableNotFound = errors.New("table does not exist")
@@ -69,9 +70,9 @@ type IndexInfo struct {
6970

7071
// QueryResult represents the result of a query execution.
7172
type QueryResult struct {
72-
Columns []string `json:"columns"`
73-
Rows [][]interface{} `json:"rows"`
74-
RowCount int `json:"row_count"`
73+
Columns []string `json:"columns"`
74+
Rows [][]any `json:"rows"`
75+
RowCount int `json:"row_count"`
7576
}
7677

7778
// ConnectionManager handles database connection operations.
@@ -99,8 +100,8 @@ type TableExplorer interface {
99100

100101
// QueryExecutor handles query operations.
101102
type QueryExecutor interface {
102-
ExecuteQuery(query string, args ...interface{}) (*QueryResult, error)
103-
ExplainQuery(query string, args ...interface{}) (*QueryResult, error)
103+
ExecuteQuery(query string, args ...any) (*QueryResult, error)
104+
ExplainQuery(query string, args ...any) (*QueryResult, error)
104105
}
105106

106107
// PostgreSQLClient interface combines all database operations.

main.go

Lines changed: 121 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ var version = "dev"
2222
// Error variables for static errors.
2323
var (
2424
ErrInvalidConnectionParameters = errors.New("invalid connection parameters")
25+
ErrHostRequired = errors.New("host is required")
26+
ErrUserRequired = errors.New("user is required")
27+
ErrDatabaseRequired = errors.New("database is required")
2528
)
2629

2730
// ConnectionParams represents individual database connection parameters.
@@ -39,13 +42,13 @@ type ConnectionParams struct {
3942
func buildConnectionString(params ConnectionParams) (string, error) {
4043
// Validate required parameters
4144
if params.Host == "" {
42-
return "", errors.New("host is required")
45+
return "", ErrHostRequired
4346
}
4447
if params.User == "" {
45-
return "", errors.New("user is required")
48+
return "", ErrUserRequired
4649
}
4750
if params.Database == "" {
48-
return "", errors.New("database is required")
51+
return "", ErrDatabaseRequired
4952
}
5053

5154
// Set defaults
@@ -59,26 +62,126 @@ func buildConnectionString(params ConnectionParams) (string, error) {
5962
sslMode = "prefer" // PostgreSQL default SSL mode
6063
}
6164

62-
// Build connection string
65+
// Build connection string using net.JoinHostPort pattern
66+
hostPort := fmt.Sprintf("%s:%d", params.Host, port)
6367
connStr := fmt.Sprintf(
64-
"postgres://%s:%s@%s:%d/%s?sslmode=%s",
68+
"postgres://%s:%s@%s/%s?sslmode=%s",
6569
params.User,
6670
params.Password,
67-
params.Host,
68-
port,
71+
hostPort,
6972
params.Database,
7073
sslMode,
7174
)
7275

7376
return connStr, nil
7477
}
7578

79+
// extractConnectionParams extracts connection parameters from args.
80+
func extractConnectionParams(args map[string]any) ConnectionParams {
81+
params := ConnectionParams{
82+
Host: "localhost", // Default
83+
}
84+
85+
if host, ok := args["host"].(string); ok && host != "" {
86+
params.Host = host
87+
}
88+
89+
if portFloat, ok := args["port"].(float64); ok {
90+
params.Port = int(portFloat)
91+
}
92+
93+
if user, ok := args["user"].(string); ok {
94+
params.User = user
95+
}
96+
97+
if password, ok := args["password"].(string); ok {
98+
params.Password = password
99+
}
100+
101+
if database, ok := args["database"].(string); ok {
102+
params.Database = database
103+
}
104+
105+
if sslmode, ok := args["sslmode"].(string); ok {
106+
params.SSLMode = sslmode
107+
}
108+
109+
return params
110+
}
111+
112+
// getConnectionString determines the connection string from args.
113+
func getConnectionString(
114+
args map[string]any,
115+
debugLogger *slog.Logger,
116+
) (string, error) {
117+
// Check if full connection URL is provided
118+
if connURL, ok := args["connection_url"].(string); ok && connURL != "" {
119+
debugLogger.Debug("Using provided connection URL")
120+
return connURL, nil
121+
}
122+
123+
// Build connection string from individual parameters
124+
params := extractConnectionParams(args)
125+
connectionString, err := buildConnectionString(params)
126+
if err != nil {
127+
debugLogger.Error("Failed to build connection string", "error", err)
128+
return "", fmt.Errorf("invalid connection parameters: %w", err)
129+
}
130+
131+
debugLogger.Debug("Built connection string from parameters",
132+
"host", params.Host, "port", params.Port, "database", params.Database)
133+
return connectionString, nil
134+
}
135+
136+
// handleConnectDatabaseRequest handles the connect_database tool request.
137+
func handleConnectDatabaseRequest(
138+
args map[string]any,
139+
appInstance *app.App,
140+
debugLogger *slog.Logger,
141+
) (*mcp.CallToolResult, error) {
142+
debugLogger.Debug("Received connect_database tool request", "args", args)
143+
144+
connectionString, err := getConnectionString(args, debugLogger)
145+
if err != nil {
146+
return mcp.NewToolResultError(err.Error()), nil
147+
}
148+
149+
// Attempt to connect
150+
if err := appInstance.Connect(connectionString); err != nil {
151+
debugLogger.Error("Failed to connect to database", "error", err)
152+
return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil
153+
}
154+
155+
// Get current database name to confirm connection
156+
dbName, err := appInstance.GetCurrentDatabase()
157+
if err != nil {
158+
debugLogger.Warn("Connected but failed to get database name", "error", err)
159+
dbName = "unknown"
160+
}
161+
162+
debugLogger.Info("Successfully connected to database", "database", dbName)
163+
164+
response := map[string]any{
165+
"status": "connected",
166+
"database": dbName,
167+
"message": "Successfully connected to database: " + dbName,
168+
}
169+
170+
jsonData, err := json.Marshal(response)
171+
if err != nil {
172+
debugLogger.Error("Failed to marshal connection response", "error", err)
173+
return mcp.NewToolResultError("Failed to format connection response"), nil
174+
}
175+
176+
return mcp.NewToolResultText(string(jsonData)), nil
177+
}
178+
76179
// setupConnectDatabaseTool creates and registers the connect_database tool.
77180
func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLogger *slog.Logger) {
78181
connectDBTool := mcp.NewTool("connect_database",
79-
mcp.WithDescription("Connect to a PostgreSQL database using connection parameters or connection URL"),
182+
mcp.WithDescription("Connect to a PostgreSQL database using connection parameters or URL"),
80183
mcp.WithString("connection_url",
81-
mcp.Description("Full PostgreSQL connection URL (postgres://user:password@host:port/dbname?sslmode=mode). If provided, individual parameters are ignored."),
184+
mcp.Description("Full PostgreSQL connection URL. If provided, individual parameters are ignored."),
82185
),
83186
mcp.WithString("host",
84187
mcp.Description("Database host (default: localhost)"),
@@ -101,85 +204,7 @@ func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLo
101204
)
102205

103206
s.AddTool(connectDBTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
104-
args := request.GetArguments()
105-
debugLogger.Debug("Received connect_database tool request", "args", args)
106-
107-
var connectionString string
108-
109-
// Check if full connection URL is provided
110-
if connURL, ok := args["connection_url"].(string); ok && connURL != "" {
111-
connectionString = connURL
112-
debugLogger.Debug("Using provided connection URL")
113-
} else {
114-
// Build connection string from individual parameters
115-
params := ConnectionParams{}
116-
117-
if host, ok := args["host"].(string); ok && host != "" {
118-
params.Host = host
119-
} else {
120-
params.Host = "localhost" // Default
121-
}
122-
123-
if portFloat, ok := args["port"].(float64); ok {
124-
params.Port = int(portFloat)
125-
}
126-
// Port will default to 5432 in buildConnectionString if 0
127-
128-
if user, ok := args["user"].(string); ok {
129-
params.User = user
130-
}
131-
132-
if password, ok := args["password"].(string); ok {
133-
params.Password = password
134-
}
135-
136-
if database, ok := args["database"].(string); ok {
137-
params.Database = database
138-
}
139-
140-
if sslmode, ok := args["sslmode"].(string); ok {
141-
params.SSLMode = sslmode
142-
}
143-
144-
// Validate and build connection string
145-
var err error
146-
connectionString, err = buildConnectionString(params)
147-
if err != nil {
148-
debugLogger.Error("Failed to build connection string", "error", err)
149-
return mcp.NewToolResultError(fmt.Sprintf("Invalid connection parameters: %v", err)), nil
150-
}
151-
152-
debugLogger.Debug("Built connection string from parameters", "host", params.Host, "port", params.Port, "database", params.Database)
153-
}
154-
155-
// Attempt to connect
156-
if err := appInstance.Connect(connectionString); err != nil {
157-
debugLogger.Error("Failed to connect to database", "error", err)
158-
return mcp.NewToolResultError(fmt.Sprintf("Failed to connect to database: %v", err)), nil
159-
}
160-
161-
// Get current database name to confirm connection
162-
dbName, err := appInstance.GetCurrentDatabase()
163-
if err != nil {
164-
debugLogger.Warn("Connected but failed to get database name", "error", err)
165-
dbName = "unknown"
166-
}
167-
168-
debugLogger.Info("Successfully connected to database", "database", dbName)
169-
170-
response := map[string]interface{}{
171-
"status": "connected",
172-
"database": dbName,
173-
"message": fmt.Sprintf("Successfully connected to database: %s", dbName),
174-
}
175-
176-
jsonData, err := json.Marshal(response)
177-
if err != nil {
178-
debugLogger.Error("Failed to marshal connection response", "error", err)
179-
return mcp.NewToolResultError("Failed to format connection response"), nil
180-
}
181-
182-
return mcp.NewToolResultText(string(jsonData)), nil
207+
return handleConnectDatabaseRequest(request.GetArguments(), appInstance, debugLogger)
183208
})
184209
}
185210

@@ -289,7 +314,7 @@ func setupListTablesTool(s *server.MCPServer, appInstance *app.App, debugLogger
289314

290315
// handleTableSchemaToolRequest handles tool requests that require table and optional schema parameters.
291316
func handleTableSchemaToolRequest(
292-
args map[string]interface{},
317+
args map[string]any,
293318
debugLogger *slog.Logger,
294319
toolName string,
295320
) (string, string, error) {
@@ -311,7 +336,7 @@ func handleTableSchemaToolRequest(
311336
}
312337

313338
// marshalToJSON converts data to JSON and handles errors.
314-
func marshalToJSON(data interface{}, debugLogger *slog.Logger, errorMsg string) ([]byte, error) {
339+
func marshalToJSON(data any, debugLogger *slog.Logger, errorMsg string) ([]byte, error) {
315340
jsonData, err := json.Marshal(data)
316341
if err != nil {
317342
debugLogger.Error("Failed to marshal data to JSON", "error", err, "context", errorMsg)
@@ -325,8 +350,8 @@ type TableToolConfig struct {
325350
Name string
326351
Description string
327352
TableDesc string
328-
Operation func(appInstance *app.App, schema, table string) (interface{}, error)
329-
SuccessMsg func(result interface{}, schema, table string) (string, []any)
353+
Operation func(appInstance *app.App, schema, table string) (any, error)
354+
SuccessMsg func(result any, schema, table string) (string, []any)
330355
ErrorMsg string
331356
}
332357

@@ -375,10 +400,10 @@ func setupDescribeTableTool(s *server.MCPServer, appInstance *app.App, debugLogg
375400
Name: "describe_table",
376401
Description: "Get detailed information about a table's structure (columns, types, constraints)",
377402
TableDesc: "Table name to describe",
378-
Operation: func(appInstance *app.App, schema, table string) (interface{}, error) {
403+
Operation: func(appInstance *app.App, schema, table string) (any, error) {
379404
return appInstance.DescribeTable(schema, table)
380405
},
381-
SuccessMsg: func(result interface{}, schema, table string) (string, []any) {
406+
SuccessMsg: func(result any, schema, table string) (string, []any) {
382407
columns, ok := result.([]*app.ColumnInfo)
383408
if !ok {
384409
return "Error processing result", []any{"error", "type assertion failed"}
@@ -449,10 +474,10 @@ func setupListIndexesTool(s *server.MCPServer, appInstance *app.App, debugLogger
449474
Name: "list_indexes",
450475
Description: "List indexes for a specific table",
451476
TableDesc: "Table name to list indexes for",
452-
Operation: func(appInstance *app.App, schema, table string) (interface{}, error) {
477+
Operation: func(appInstance *app.App, schema, table string) (any, error) {
453478
return appInstance.ListIndexes(schema, table)
454479
},
455-
SuccessMsg: func(result interface{}, schema, table string) (string, []any) {
480+
SuccessMsg: func(result any, schema, table string) (string, []any) {
456481
indexes, ok := result.([]*app.IndexInfo)
457482
if !ok {
458483
return "Error processing result", []any{"error", "type assertion failed"}

0 commit comments

Comments
 (0)