diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..796b93b
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,29 @@
+FROM python:3.10-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Copy requirements first to leverage the Docker layer cache
+COPY requirements.txt .
+
+# Install the packages specified in requirements.txt
+# Use --no-cache-dir to reduce image size
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Optional: clean up build dependencies to reduce image size
+# RUN apt-get purge -y --auto-remove build-essential
+
+# Copy the rest of the application code into the container at /app
+COPY . .
+
+# Make port 8000 available outside the container (only used if the server
+# runs a network transport such as SSE; the default stdio transport ignores it)
+EXPOSE 8000
+
+# Define environment variables (these will be overridden by docker run -e flags)
+ENV DATABRICKS_HOST=""
+ENV DATABRICKS_TOKEN=""
+ENV DATABRICKS_HTTP_PATH=""
+
+# Run main.py when the container launches
+CMD ["python", "main.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 85b8eae..2e1e4f6 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,46 @@ You can test the MCP server using the inspector by running
 npx @modelcontextprotocol/inspector python3 main.py
 ```
 
+## Configuring with Docker (for MCP Clients like Cursor)
+
+If you are integrating this server with an MCP client (such as Cursor), you can configure it to run via Docker. The client typically manages running the container itself, based on a configuration file (e.g., `mcp.json`).
+
+A pre-built image is available on Docker Hub and can be pulled with:
+
+```bash
+docker pull jordineil/databricks-mcp-server
+```
+
+The configuration passes environment variables directly to the Docker container. Here's an example structure, using the public image name; replace the placeholders with your actual credentials:
+
+```json
+{
+  "mcpServers": {
+    "databricks-docker": {
+      "command": "docker",
+      "args": [
+        "run",
+        "--rm",
+        "-i",
+        "-e",
+        "DATABRICKS_HOST=<your-databricks-host>",
+        "-e",
+        "DATABRICKS_TOKEN=<your-databricks-token>",
+        "-e",
+        "DATABRICKS_HTTP_PATH=<your-http-path>",
+        "jordineil/databricks-mcp-server"
+      ]
+    }
+    // ... other servers ...
+  }
+}
+```
+
+- Replace `<your-databricks-host>` with your Databricks host (e.g., `dbc-xyz.cloud.databricks.com`).
+- Replace `<your-databricks-token>` with your personal access token.
+- Replace `<your-http-path>` with the HTTP path for your SQL warehouse.
+
+This method avoids storing secrets in a `.env` file inside the project, since the MCP client injects them at runtime.
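+
+To build the image locally instead of pulling it, run `docker build -t jordineil/databricks-mcp-server .` from the repository root. As a quick smoke test outside of an MCP client, you can also start the container directly (the same placeholders as above apply); the server speaks MCP over stdio, so it will simply wait for client input until interrupted:
+
+```bash
+docker run --rm -i \
+  -e DATABRICKS_HOST=<your-databricks-host> \
+  -e DATABRICKS_TOKEN=<your-databricks-token> \
+  -e DATABRICKS_HTTP_PATH=<your-http-path> \
+  jordineil/databricks-mcp-server
+```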
+
 ## Available MCP Tools
 
 The following MCP tools are available:
diff --git a/main.py b/main.py
index 8f256a8..b8052b7 100644
--- a/main.py
+++ b/main.py
@@ -1,10 +1,12 @@
 import os
 from typing import Dict
 from dotenv import load_dotenv
-from databricks.sql import connect
+from databricks.sql import connect, exc as db_exc  # Import Databricks SQL exceptions
 from databricks.sql.client import Connection
 from mcp.server.fastmcp import FastMCP
 import requests
+from requests import exceptions as req_exc  # Import requests exceptions
+import logging
 
 # Load environment variables
 load_dotenv()
@@ -17,23 +19,62 @@
 # Set up the MCP server
 mcp = FastMCP("Databricks API Explorer")
 
+# Global variable to hold the reusable Databricks SQL connection
+_db_connection: Connection | None = None
 
-# Helper function to get a Databricks SQL connection
+# Helper function to get a reusable Databricks SQL connection
def get_databricks_connection() -> Connection:
-    """Create and return a Databricks SQL connection"""
+    """Create and return a reusable Databricks SQL connection."""
+    global _db_connection
+
+    # Check whether a connection exists and is still open
+    if _db_connection is not None:
+        try:
+            # A simple liveness check; whether this is sufficient depends on
+            # the driver's implementation, so adjust if needed
+            cursor = _db_connection.cursor()
+            cursor.execute("SELECT 1")
+            cursor.close()
+            logging.info("Reusing existing Databricks SQL connection.")
+            return _db_connection
+        except db_exc.Error as db_error:  # Catch specific DBAPI errors from the liveness check
+            logging.warning(f"Existing connection seems invalid (DB error: {db_error}), creating a new one.")
+            try:
+                if _db_connection:  # Ensure _db_connection is not None before closing
+                    _db_connection.close()
+            except db_exc.Error:  # Ignore DB errors during close if the connection was already broken
+                pass
+            except Exception as close_exc:  # Catch other potential close errors
+                logging.error(f"Unexpected error closing potentially broken DB connection: {close_exc}")
+            _db_connection = None  # Ensure we create a new one
+        except Exception as e:  # Catch other unexpected errors during the liveness check
+            logging.warning(f"Unexpected error checking connection liveness ({e}), creating a new one.")
+            _db_connection = None
+
     if not all([DATABRICKS_HOST, DATABRICKS_TOKEN, DATABRICKS_HTTP_PATH]):
+        # This is a critical configuration error, so raising is appropriate
         raise ValueError("Missing required Databricks connection details in .env file")
 
-    return connect(
-        server_hostname=DATABRICKS_HOST,
-        http_path=DATABRICKS_HTTP_PATH,
-        access_token=DATABRICKS_TOKEN
-    )
+    try:
+        logging.info("Creating new Databricks SQL connection.")
+        _db_connection = connect(
+            server_hostname=DATABRICKS_HOST,
+            http_path=DATABRICKS_HTTP_PATH,
+            access_token=DATABRICKS_TOKEN
+        )
+        return _db_connection
+    except db_exc.Error as db_connect_error:  # Catch DB connection errors
+        logging.error(f"Failed to connect to Databricks SQL Warehouse: {db_connect_error}")
+        raise  # Re-raise after logging, as the connection is essential
+    except Exception as connect_exc:
+        logging.error(f"Unexpected error creating Databricks SQL connection: {connect_exc}")
+        raise  # Re-raise unexpected connection errors
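+
+# NOTE: the cached connection is shared by the whole process, and DBAPI
+# connections are not guaranteed to be thread-safe; if tool calls ever run
+# concurrently, consider a per-call connection or a lock around this helper.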
 
 # Helper function for Databricks REST API requests
 def databricks_api_request(endpoint: str, method: str = "GET", data: Dict = None) -> Dict:
-    """Make a request to the Databricks REST API"""
+    """Make a request to the Databricks REST API, handling common errors."""
     if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+        # This is a critical configuration error, so raising is appropriate
         raise ValueError("Missing required Databricks API credentials in .env file")
 
     headers = {
@@ -43,21 +84,43 @@ def databricks_api_request(endpoint: str, method: str = "GET", data: Dict = None) -> Dict:
 
     url = f"https://{DATABRICKS_HOST}/api/2.0/{endpoint}"
 
-    if method.upper() == "GET":
-        response = requests.get(url, headers=headers)
-    elif method.upper() == "POST":
-        response = requests.post(url, headers=headers, json=data)
-    else:
-        raise ValueError(f"Unsupported HTTP method: {method}")
+    try:
+        if method.upper() == "GET":
+            # Pass data as query parameters so GET endpoints (e.g. jobs/runs/list)
+            # actually receive their arguments
+            response = requests.get(url, headers=headers, params=data)
+        elif method.upper() == "POST":
+            response = requests.post(url, headers=headers, json=data)
+        else:
+            # Internal programming error
+            raise ValueError(f"Unsupported HTTP method: {method}")
+
+        response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
+        return response.json()
 
-    response.raise_for_status()
-    return response.json()
+    except req_exc.HTTPError as http_err:
+        logging.error(f"HTTP error occurred: {http_err} - Response: {http_err.response.text}")
+        # Re-raise so the caller (the MCP tool) can handle it
+        raise
+    except req_exc.ConnectionError as conn_err:
+        logging.error(f"Connection error occurred: {conn_err}")
+        raise
+    except req_exc.Timeout as timeout_err:
+        logging.error(f"Request timed out: {timeout_err}")
+        raise
+    except req_exc.RequestException as req_err:
+        # Catch any other requests-related errors
+        logging.error(f"An error occurred during the API request: {req_err}")
+        raise
+    except ValueError as val_err:  # Catch the ValueError from the unsupported-method branch
+        logging.error(f"API request internal error: {val_err}")
+        raise
+    # No fallback 'except Exception' here -- let truly unexpected errors propagate,
+    # or add one if you want to catch everything originating within this function
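+
+# Example (illustrative): databricks_api_request("clusters/list") performs a GET
+# against https://<host>/api/2.0/clusters/list and returns the decoded JSON
+# dict; HTTP and connection failures are logged above and re-raised to the caller.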
 
 @mcp.resource("schema://tables")
 def get_schema() -> str:
     """Provide the list of tables in the Databricks SQL warehouse as a resource"""
-    conn = get_databricks_connection()
     try:
+        conn = get_databricks_connection()  # Use the shared connection
         cursor = conn.cursor()
         tables = cursor.tables().fetchall()
@@ -65,46 +128,55 @@ def get_schema() -> str:
 
         for table in tables:
             table_info.append(f"Database: {table.TABLE_CAT}, Schema: {table.TABLE_SCHEM}, Table: {table.TABLE_NAME}")
 
+        cursor.close()
         return "\n".join(table_info)
-    except Exception as e:
-        return f"Error retrieving tables: {str(e)}"
-    finally:
-        if 'conn' in locals():
-            conn.close()
+    except ValueError as e:  # Catch connection config errors
+        return f"Configuration Error: {str(e)}"
+    except db_exc.Error as db_error:  # Catch errors during DB operations
+        logging.error(f"Databricks SQL Error in get_schema: {db_error}")
+        return f"Database Error: {str(db_error)}"
+    except Exception as e:  # Fallback for unexpected errors
+        logging.error(f"Unexpected Error in get_schema: {e}", exc_info=True)
+        return f"An unexpected error occurred: {str(e)}"
 
 @mcp.tool()
 def run_sql_query(sql: str) -> str:
     """Execute SQL queries on Databricks SQL warehouse"""
-    conn = get_databricks_connection()
-
     try:
+        conn = get_databricks_connection()  # Use the shared connection
         cursor = conn.cursor()
         result = cursor.execute(sql)
 
         if result.description:
-            # Get column names
             columns = [col[0] for col in result.description]
-
-            # Format the result as a table
             rows = result.fetchall()
+            cursor.close()
+
             if not rows:
                 return "Query executed successfully. No results returned."
 
-            # Format as markdown table
-            table = "| " + " | ".join(columns) + " |\n"
-            table += "| " + " | ".join(["---" for _ in columns]) + " |\n"
-
-            for row in rows:
-                table += "| " + " | ".join([str(cell) for cell in row]) + " |\n"
-
-            return table
+            # Format as markdown table (errors here are unlikely but possible)
+            try:
+                table = "| " + " | ".join(columns) + " |\n"
+                table += "| " + " | ".join(["---" for _ in columns]) + " |\n"
+                for row in rows:
+                    table += "| " + " | ".join([str(cell) for cell in row]) + " |\n"
+                return table
+            except Exception as format_exc:
+                logging.error(f"Error formatting SQL results: {format_exc}", exc_info=True)
+                return f"Error formatting results: {str(format_exc)}"
         else:
+            cursor.close()
             return "Query executed successfully. No results returned."
-    except Exception as e:
-        return f"Error executing query: {str(e)}"
-    finally:
-        if 'conn' in locals():
-            conn.close()
+    except ValueError as e:  # Catch connection config errors
+        return f"Configuration Error: {str(e)}"
+    except db_exc.Error as db_error:  # Catch errors during DB operations (connection, syntax, permissions)
+        logging.error(f"Databricks SQL Error in run_sql_query: {db_error}")
+        # Provide a slightly friendlier message for common SQL issues
+        return f"SQL Execution Error: {str(db_error)}"
+    except Exception as e:  # Fallback for unexpected errors
+        logging.error(f"Unexpected Error in run_sql_query: {e}", exc_info=True)
+        return f"An unexpected error occurred: {str(e)}"
 
 @mcp.tool()
 def list_jobs() -> str:
@@ -112,102 +184,199 @@ def list_jobs() -> str:
     try:
         response = databricks_api_request("jobs/list")
 
-        if not response.get("jobs"):
-            return "No jobs found."
-
-        jobs = response.get("jobs", [])
-
+        # Guard against an unexpected API response structure
+        if not isinstance(response, dict):
+            logging.error(f"Unexpected API response type for list_jobs: {type(response)}")
+            return "Error: Received unexpected API response format."
+
+        jobs = response.get("jobs")  # Use .get() for safer access
+        if jobs is None:
+            # Distinguish between no jobs and an error retrieving jobs
+            if "error_code" in response:
+                logging.error(f"API error in list_jobs response: {response}")
+                return f"API Error: {response.get('message', 'Unknown error')}"
+            else:
+                return "No jobs found."
+        if not isinstance(jobs, list):
+            logging.error(f"Unexpected 'jobs' format in list_jobs response: {type(jobs)}")
+            return "Error: Received unexpected API response format for jobs list."
+
         # Format as markdown table
-        table = "| Job ID | Job Name | Created By |\n"
-        table += "| ------ | -------- | ---------- |\n"
-
+        table_rows = ["| Job ID | Job Name | Created By |", "| ------ | -------- | ---------- |"]
         for job in jobs:
+            if not isinstance(job, dict):
+                logging.warning(f"Skipping invalid job item in list_jobs: {job}")
+                continue  # Skip malformed job entries
             job_id = job.get("job_id", "N/A")
-            job_name = job.get("settings", {}).get("name", "N/A")
-            created_by = job.get("created_by", "N/A")
+            job_name = job.get("settings", {}).get("name", "N/A")  # Nested get
+            creator = job.get("creator_user_name", "N/A")  # creator_user_name is the field returned by the Jobs API
- table += f"| {job_id} | {job_name} | {created_by} |\n" + table_rows.append(f"| {job_id} | {job_name} | {creator} |") - return table - except Exception as e: - return f"Error listing jobs: {str(e)}" + return "\n".join(table_rows) + # Catch specific exceptions raised by databricks_api_request + except ValueError as e: # Catch config or internal errors + return f"Configuration/Internal Error: {str(e)}" + except req_exc.HTTPError as http_err: + return f"API Request Failed: {http_err.response.status_code} {http_err.response.reason}" + except req_exc.RequestException as req_err: + return f"API Connection/Request Error: {str(req_err)}" + except (KeyError, TypeError, AttributeError) as data_err: + # Catch errors processing the response data + logging.error(f"Error processing API response data in list_jobs: {data_err}", exc_info=True) + return f"Error processing API response: {str(data_err)}" + except Exception as e: # Fallback + logging.error(f"Unexpected Error in list_jobs: {e}", exc_info=True) + return f"An unexpected error occurred: {str(e)}" @mcp.tool() def get_job_status(job_id: int) -> str: """Get the status of a specific Databricks job""" try: - response = databricks_api_request("jobs/runs/list", data={"job_id": job_id}) + # Add basic input validation + if not isinstance(job_id, int) or job_id <= 0: + return "Error: Invalid Job ID provided." + + response = databricks_api_request("jobs/runs/list", data={"job_id": job_id}) # Can raise exceptions - if not response.get("runs"): + # Error handling for unexpected API response structure + if not isinstance(response, dict): + logging.error(f"Unexpected API response type for get_job_status: {type(response)}") + return "Error: Received unexpected API response format." + + runs = response.get("runs") + if runs is None: + if "error_code" in response: + logging.error(f"API error in get_job_status response: {response}") + return f"API Error: {response.get('message', 'Unknown error')}" + else: + return f"No runs found for job ID {job_id}." + if not isinstance(runs, list): + logging.error(f"Unexpected 'runs' format in get_job_status response: {type(runs)}") + return "Error: Received unexpected API response format for runs list." + + if not runs: # Check if list is empty after verifying it's a list return f"No runs found for job ID {job_id}." 
 
 @mcp.tool()
 def get_job_status(job_id: int) -> str:
     """Get the status of a specific Databricks job"""
     try:
-        response = databricks_api_request("jobs/runs/list", data={"job_id": job_id})
+        # Basic input validation
+        if not isinstance(job_id, int) or job_id <= 0:
+            return "Error: Invalid Job ID provided."
+
+        response = databricks_api_request("jobs/runs/list", data={"job_id": job_id})  # Can raise exceptions
 
-        if not response.get("runs"):
+        # Guard against an unexpected API response structure
+        if not isinstance(response, dict):
+            logging.error(f"Unexpected API response type for get_job_status: {type(response)}")
+            return "Error: Received unexpected API response format."
+
+        runs = response.get("runs")
+        if runs is None:
+            if "error_code" in response:
+                logging.error(f"API error in get_job_status response: {response}")
+                return f"API Error: {response.get('message', 'Unknown error')}"
+            else:
+                return f"No runs found for job ID {job_id}."
+        if not isinstance(runs, list):
+            logging.error(f"Unexpected 'runs' format in get_job_status response: {type(runs)}")
+            return "Error: Received unexpected API response format for runs list."
+
+        if not runs:  # Check if the list is empty after verifying it is a list
             return f"No runs found for job ID {job_id}."
-
-        runs = response.get("runs", [])
-
+
         # Format as markdown table
-        table = "| Run ID | State | Start Time | End Time | Duration |\n"
-        table += "| ------ | ----- | ---------- | -------- | -------- |\n"
+        import datetime  # Local import, used only for timestamp formatting
+        table_rows = ["| Run ID | State | Start Time | End Time | Duration |", "| ------ | ----- | ---------- | -------- | -------- |"]
 
         for run in runs:
-            run_id = run.get("run_id", "N/A")
-            state = run.get("state", {}).get("result_state", "N/A")
-
-            # Convert timestamps to readable format if they exist
-            start_time = run.get("start_time", 0)
-            end_time = run.get("end_time", 0)
+            if not isinstance(run, dict):
+                logging.warning(f"Skipping invalid run item in get_job_status: {run}")
+                continue
 
-            if start_time and end_time:
-                duration = f"{(end_time - start_time) / 1000:.2f}s"
-            else:
-                duration = "N/A"
+            run_id = run.get("run_id", "N/A")
+            # Safer access to the nested state
+            state_info = run.get("state", {})
+            state = state_info.get("life_cycle_state", "N/A")  # life_cycle_state is often more informative
+            result_state = state_info.get("result_state")
+            if result_state:
+                state += f" ({result_state})"
 
-            # Format timestamps
-            import datetime
+            start_time = run.get("start_time")
+            end_time = run.get("end_time")
+            duration_ms = run.get("execution_duration")  # Prefer execution_duration when available
+
+            duration = "N/A"
+            if duration_ms is not None and duration_ms > 0:
+                duration = f"{duration_ms / 1000:.2f}s"
+            elif start_time and end_time:
+                # Fallback calculation if execution_duration is missing
+                duration = f"{(end_time - start_time) / 1000:.2f}s"
+
             start_time_str = datetime.datetime.fromtimestamp(start_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if start_time else "N/A"
             end_time_str = datetime.datetime.fromtimestamp(end_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if end_time else "N/A"
 
-            table += f"| {run_id} | {state} | {start_time_str} | {end_time_str} | {duration} |\n"
+            table_rows.append(f"| {run_id} | {state} | {start_time_str} | {end_time_str} | {duration} |")
 
-        return table
-    except Exception as e:
-        return f"Error getting job status: {str(e)}"
+        return "\n".join(table_rows)
+    # Catch the specific exceptions raised by databricks_api_request
+    except ValueError as e:  # Catch config or input errors (incl. invalid job_id)
+        return f"Configuration/Input Error: {str(e)}"
+    except req_exc.HTTPError as http_err:
+        return f"API Request Failed: {http_err.response.status_code} {http_err.response.reason}"
+    except req_exc.RequestException as req_err:
+        return f"API Connection/Request Error: {str(req_err)}"
+    except (KeyError, TypeError, AttributeError) as data_err:
+        # Catch errors while processing the response data
+        logging.error(f"Error processing API response data in get_job_status: {data_err}", exc_info=True)
+        return f"Error processing API response: {str(data_err)}"
+    except Exception as e:  # Fallback
+        logging.error(f"Unexpected Error in get_job_status: {e}", exc_info=True)
+        return f"An unexpected error occurred: {str(e)}"
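+
+# NOTE: the timestamps above use datetime.fromtimestamp, i.e. the container's
+# local timezone (typically UTC in the slim image unless TZ is set); pass
+# tz=datetime.timezone.utc explicitly if UTC output is always wanted.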
 
 @mcp.tool()
 def get_job_details(job_id: int) -> str:
     """Get detailed information about a specific Databricks job"""
     try:
-        response = databricks_api_request(f"jobs/get?job_id={job_id}", method="GET")
-
-        # Format the job details
-        job_name = response.get("settings", {}).get("name", "N/A")
-        created_time = response.get("created_time", 0)
+        # Basic input validation
+        if not isinstance(job_id, int) or job_id <= 0:
+            return "Error: Invalid Job ID provided."
+
+        response = databricks_api_request(f"jobs/get?job_id={job_id}", method="GET")  # Can raise exceptions
+
+        # Guard against an unexpected API response structure
+        if not isinstance(response, dict):
+            logging.error(f"Unexpected API response type for get_job_details: {type(response)}")
+            return "Error: Received unexpected API response format."
+
+        # Safer access using .get()
+        settings = response.get("settings", {})
+        job_name = settings.get("name", "N/A")
+        created_time = response.get("created_time")
+        creator_user_name = response.get('creator_user_name', 'N/A')
+        tasks = settings.get("tasks", [])
 
         # Convert timestamp to readable format
-        import datetime
+        import datetime  # Local import, used only for timestamp formatting
         created_time_str = datetime.datetime.fromtimestamp(created_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if created_time else "N/A"
 
-        # Get job tasks
-        tasks = response.get("settings", {}).get("tasks", [])
+        result_lines = [f"## Job Details: {job_name}\n"]
+        result_lines.append(f"- **Job ID:** {job_id}")
+        result_lines.append(f"- **Created:** {created_time_str}")
+        result_lines.append(f"- **Creator:** {creator_user_name}\n")
 
-        result = f"## Job Details: {job_name}\n\n"
-        result += f"- **Job ID:** {job_id}\n"
-        result += f"- **Created:** {created_time_str}\n"
-        result += f"- **Creator:** {response.get('creator_user_name', 'N/A')}\n\n"
-
-        if tasks:
-            result += "### Tasks:\n\n"
-            result += "| Task Key | Task Type | Description |\n"
-            result += "| -------- | --------- | ----------- |\n"
+        if isinstance(tasks, list) and tasks:  # Check tasks is a non-empty list
+            result_lines.append("### Tasks:\n")
+            result_lines.append("| Task Key | Task Type | Description |")
+            result_lines.append("| -------- | --------- | ----------- |")
 
             for task in tasks:
+                if not isinstance(task, dict):
+                    logging.warning(f"Skipping invalid task item in get_job_details: {task}")
+                    continue
+
                 task_key = task.get("task_key", "N/A")
-                task_type = next(iter([k for k in task.keys() if k.endswith("_task")]), "N/A")
+                # Simplified task type extraction (assumes a single *_task key)
+                task_type = next((k.replace('_task', '') for k in task if k.endswith('_task')), "N/A")
                 description = task.get("description", "N/A")
 
-                result += f"| {task_key} | {task_type} | {description} |\n"
+                result_lines.append(f"| {task_key} | {task_type} | {description} |")
+        else:
+            result_lines.append("No tasks defined for this job.")
 
-        return result
-    except Exception as e:
-        return f"Error getting job details: {str(e)}"
+        return "\n".join(result_lines)
+    # Catch the specific exceptions raised by databricks_api_request
+    except ValueError as e:  # Catch config or input errors (incl. invalid job_id)
+        return f"Configuration/Input Error: {str(e)}"
+    except req_exc.HTTPError as http_err:
+        # A 404 means the job was not found
+        if http_err.response.status_code == 404:
+            return f"Error: Job ID {job_id} not found."
+ return f"API Request Failed: {http_err.response.status_code} {http_err.response.reason}" + except req_exc.RequestException as req_err: + return f"API Connection/Request Error: {str(req_err)}" + except (KeyError, TypeError, AttributeError) as data_err: + # Catch errors processing the response data + logging.error(f"Error processing API response data in get_job_details: {data_err}", exc_info=True) + return f"Error processing API response: {str(data_err)}" + except Exception as e: # Fallback + logging.error(f"Unexpected Error in get_job_details: {e}", exc_info=True) + return f"An unexpected error occurred: {str(e)}" if __name__ == "__main__": + # Setup basic logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') mcp.run() \ No newline at end of file