@@ -22,6 +22,9 @@ var version = "dev"
2222// Error variables for static errors.
2323var (
2424 ErrInvalidConnectionParameters = errors .New ("invalid connection parameters" )
25+ ErrHostRequired = errors .New ("host is required" )
26+ ErrUserRequired = errors .New ("user is required" )
27+ ErrDatabaseRequired = errors .New ("database is required" )
2528)
2629
2730// ConnectionParams represents individual database connection parameters.
@@ -39,13 +42,13 @@ type ConnectionParams struct {
3942func buildConnectionString (params ConnectionParams ) (string , error ) {
4043 // Validate required parameters
4144 if params .Host == "" {
42- return "" , errors . New ( "host is required" )
45+ return "" , ErrHostRequired
4346 }
4447 if params .User == "" {
45- return "" , errors . New ( "user is required" )
48+ return "" , ErrUserRequired
4649 }
4750 if params .Database == "" {
48- return "" , errors . New ( "database is required" )
51+ return "" , ErrDatabaseRequired
4952 }
5053
5154 // Set defaults
@@ -59,26 +62,126 @@ func buildConnectionString(params ConnectionParams) (string, error) {
5962 sslMode = "prefer" // PostgreSQL default SSL mode
6063 }
6164
62- // Build connection string
65+ // Build connection string using net.JoinHostPort pattern
66+ hostPort := fmt .Sprintf ("%s:%d" , params .Host , port )
6367 connStr := fmt .Sprintf (
64- "postgres://%s:%s@%s:%d /%s?sslmode=%s" ,
68+ "postgres://%s:%s@%s/%s?sslmode=%s" ,
6569 params .User ,
6670 params .Password ,
67- params .Host ,
68- port ,
71+ hostPort ,
6972 params .Database ,
7073 sslMode ,
7174 )
7275
7376 return connStr , nil
7477}
7578
79+ // extractConnectionParams extracts connection parameters from args.
80+ func extractConnectionParams (args map [string ]any ) ConnectionParams {
81+ params := ConnectionParams {
82+ Host : "localhost" , // Default
83+ }
84+
85+ if host , ok := args ["host" ].(string ); ok && host != "" {
86+ params .Host = host
87+ }
88+
89+ if portFloat , ok := args ["port" ].(float64 ); ok {
90+ params .Port = int (portFloat )
91+ }
92+
93+ if user , ok := args ["user" ].(string ); ok {
94+ params .User = user
95+ }
96+
97+ if password , ok := args ["password" ].(string ); ok {
98+ params .Password = password
99+ }
100+
101+ if database , ok := args ["database" ].(string ); ok {
102+ params .Database = database
103+ }
104+
105+ if sslmode , ok := args ["sslmode" ].(string ); ok {
106+ params .SSLMode = sslmode
107+ }
108+
109+ return params
110+ }
111+
112+ // getConnectionString determines the connection string from args.
113+ func getConnectionString (
114+ args map [string ]any ,
115+ debugLogger * slog.Logger ,
116+ ) (string , error ) {
117+ // Check if full connection URL is provided
118+ if connURL , ok := args ["connection_url" ].(string ); ok && connURL != "" {
119+ debugLogger .Debug ("Using provided connection URL" )
120+ return connURL , nil
121+ }
122+
123+ // Build connection string from individual parameters
124+ params := extractConnectionParams (args )
125+ connectionString , err := buildConnectionString (params )
126+ if err != nil {
127+ debugLogger .Error ("Failed to build connection string" , "error" , err )
128+ return "" , fmt .Errorf ("invalid connection parameters: %w" , err )
129+ }
130+
131+ debugLogger .Debug ("Built connection string from parameters" ,
132+ "host" , params .Host , "port" , params .Port , "database" , params .Database )
133+ return connectionString , nil
134+ }
135+
136+ // handleConnectDatabaseRequest handles the connect_database tool request.
137+ func handleConnectDatabaseRequest (
138+ args map [string ]any ,
139+ appInstance * app.App ,
140+ debugLogger * slog.Logger ,
141+ ) (* mcp.CallToolResult , error ) {
142+ debugLogger .Debug ("Received connect_database tool request" , "args" , args )
143+
144+ connectionString , err := getConnectionString (args , debugLogger )
145+ if err != nil {
146+ return mcp .NewToolResultError (err .Error ()), nil
147+ }
148+
149+ // Attempt to connect
150+ if err := appInstance .Connect (connectionString ); err != nil {
151+ debugLogger .Error ("Failed to connect to database" , "error" , err )
152+ return mcp .NewToolResultError (fmt .Sprintf ("Failed to connect to database: %v" , err )), nil
153+ }
154+
155+ // Get current database name to confirm connection
156+ dbName , err := appInstance .GetCurrentDatabase ()
157+ if err != nil {
158+ debugLogger .Warn ("Connected but failed to get database name" , "error" , err )
159+ dbName = "unknown"
160+ }
161+
162+ debugLogger .Info ("Successfully connected to database" , "database" , dbName )
163+
164+ response := map [string ]any {
165+ "status" : "connected" ,
166+ "database" : dbName ,
167+ "message" : "Successfully connected to database: " + dbName ,
168+ }
169+
170+ jsonData , err := json .Marshal (response )
171+ if err != nil {
172+ debugLogger .Error ("Failed to marshal connection response" , "error" , err )
173+ return mcp .NewToolResultError ("Failed to format connection response" ), nil
174+ }
175+
176+ return mcp .NewToolResultText (string (jsonData )), nil
177+ }
178+
76179// setupConnectDatabaseTool creates and registers the connect_database tool.
77180func setupConnectDatabaseTool (s * server.MCPServer , appInstance * app.App , debugLogger * slog.Logger ) {
78181 connectDBTool := mcp .NewTool ("connect_database" ,
79- mcp .WithDescription ("Connect to a PostgreSQL database using connection parameters or connection URL" ),
182+ mcp .WithDescription ("Connect to a PostgreSQL database using connection parameters or URL" ),
80183 mcp .WithString ("connection_url" ,
81- mcp .Description ("Full PostgreSQL connection URL (postgres://user:password@host:port/dbname?sslmode=mode) . If provided, individual parameters are ignored." ),
184+ mcp .Description ("Full PostgreSQL connection URL. If provided, individual parameters are ignored." ),
82185 ),
83186 mcp .WithString ("host" ,
84187 mcp .Description ("Database host (default: localhost)" ),
@@ -101,85 +204,7 @@ func setupConnectDatabaseTool(s *server.MCPServer, appInstance *app.App, debugLo
101204 )
102205
103206 s .AddTool (connectDBTool , func (ctx context.Context , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
104- args := request .GetArguments ()
105- debugLogger .Debug ("Received connect_database tool request" , "args" , args )
106-
107- var connectionString string
108-
109- // Check if full connection URL is provided
110- if connURL , ok := args ["connection_url" ].(string ); ok && connURL != "" {
111- connectionString = connURL
112- debugLogger .Debug ("Using provided connection URL" )
113- } else {
114- // Build connection string from individual parameters
115- params := ConnectionParams {}
116-
117- if host , ok := args ["host" ].(string ); ok && host != "" {
118- params .Host = host
119- } else {
120- params .Host = "localhost" // Default
121- }
122-
123- if portFloat , ok := args ["port" ].(float64 ); ok {
124- params .Port = int (portFloat )
125- }
126- // Port will default to 5432 in buildConnectionString if 0
127-
128- if user , ok := args ["user" ].(string ); ok {
129- params .User = user
130- }
131-
132- if password , ok := args ["password" ].(string ); ok {
133- params .Password = password
134- }
135-
136- if database , ok := args ["database" ].(string ); ok {
137- params .Database = database
138- }
139-
140- if sslmode , ok := args ["sslmode" ].(string ); ok {
141- params .SSLMode = sslmode
142- }
143-
144- // Validate and build connection string
145- var err error
146- connectionString , err = buildConnectionString (params )
147- if err != nil {
148- debugLogger .Error ("Failed to build connection string" , "error" , err )
149- return mcp .NewToolResultError (fmt .Sprintf ("Invalid connection parameters: %v" , err )), nil
150- }
151-
152- debugLogger .Debug ("Built connection string from parameters" , "host" , params .Host , "port" , params .Port , "database" , params .Database )
153- }
154-
155- // Attempt to connect
156- if err := appInstance .Connect (connectionString ); err != nil {
157- debugLogger .Error ("Failed to connect to database" , "error" , err )
158- return mcp .NewToolResultError (fmt .Sprintf ("Failed to connect to database: %v" , err )), nil
159- }
160-
161- // Get current database name to confirm connection
162- dbName , err := appInstance .GetCurrentDatabase ()
163- if err != nil {
164- debugLogger .Warn ("Connected but failed to get database name" , "error" , err )
165- dbName = "unknown"
166- }
167-
168- debugLogger .Info ("Successfully connected to database" , "database" , dbName )
169-
170- response := map [string ]interface {}{
171- "status" : "connected" ,
172- "database" : dbName ,
173- "message" : fmt .Sprintf ("Successfully connected to database: %s" , dbName ),
174- }
175-
176- jsonData , err := json .Marshal (response )
177- if err != nil {
178- debugLogger .Error ("Failed to marshal connection response" , "error" , err )
179- return mcp .NewToolResultError ("Failed to format connection response" ), nil
180- }
181-
182- return mcp .NewToolResultText (string (jsonData )), nil
207+ return handleConnectDatabaseRequest (request .GetArguments (), appInstance , debugLogger )
183208 })
184209}
185210
@@ -289,7 +314,7 @@ func setupListTablesTool(s *server.MCPServer, appInstance *app.App, debugLogger
289314
290315// handleTableSchemaToolRequest handles tool requests that require table and optional schema parameters.
291316func handleTableSchemaToolRequest (
292- args map [string ]interface {} ,
317+ args map [string ]any ,
293318 debugLogger * slog.Logger ,
294319 toolName string ,
295320) (string , string , error ) {
@@ -311,7 +336,7 @@ func handleTableSchemaToolRequest(
311336}
312337
313338// marshalToJSON converts data to JSON and handles errors.
314- func marshalToJSON (data interface {} , debugLogger * slog.Logger , errorMsg string ) ([]byte , error ) {
339+ func marshalToJSON (data any , debugLogger * slog.Logger , errorMsg string ) ([]byte , error ) {
315340 jsonData , err := json .Marshal (data )
316341 if err != nil {
317342 debugLogger .Error ("Failed to marshal data to JSON" , "error" , err , "context" , errorMsg )
@@ -325,8 +350,8 @@ type TableToolConfig struct {
325350 Name string
326351 Description string
327352 TableDesc string
328- Operation func (appInstance * app.App , schema , table string ) (interface {} , error )
329- SuccessMsg func (result interface {} , schema , table string ) (string , []any )
353+ Operation func (appInstance * app.App , schema , table string ) (any , error )
354+ SuccessMsg func (result any , schema , table string ) (string , []any )
330355 ErrorMsg string
331356}
332357
@@ -375,10 +400,10 @@ func setupDescribeTableTool(s *server.MCPServer, appInstance *app.App, debugLogg
375400 Name : "describe_table" ,
376401 Description : "Get detailed information about a table's structure (columns, types, constraints)" ,
377402 TableDesc : "Table name to describe" ,
378- Operation : func (appInstance * app.App , schema , table string ) (interface {} , error ) {
403+ Operation : func (appInstance * app.App , schema , table string ) (any , error ) {
379404 return appInstance .DescribeTable (schema , table )
380405 },
381- SuccessMsg : func (result interface {} , schema , table string ) (string , []any ) {
406+ SuccessMsg : func (result any , schema , table string ) (string , []any ) {
382407 columns , ok := result .([]* app.ColumnInfo )
383408 if ! ok {
384409 return "Error processing result" , []any {"error" , "type assertion failed" }
@@ -449,10 +474,10 @@ func setupListIndexesTool(s *server.MCPServer, appInstance *app.App, debugLogger
449474 Name : "list_indexes" ,
450475 Description : "List indexes for a specific table" ,
451476 TableDesc : "Table name to list indexes for" ,
452- Operation : func (appInstance * app.App , schema , table string ) (interface {} , error ) {
477+ Operation : func (appInstance * app.App , schema , table string ) (any , error ) {
453478 return appInstance .ListIndexes (schema , table )
454479 },
455- SuccessMsg : func (result interface {} , schema , table string ) (string , []any ) {
480+ SuccessMsg : func (result any , schema , table string ) (string , []any ) {
456481 indexes , ok := result .([]* app.IndexInfo )
457482 if ! ok {
458483 return "Error processing result" , []any {"error" , "type assertion failed" }
0 commit comments