diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bee0047..ff8b599 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -140,7 +140,7 @@ When adding new tools, follow this pattern: @mcp.tool() def your_new_tool( app_id: str, - server: Optional[str] = None, + server_spec: ServerSpec, # other parameters ) -> YourReturnType: """ @@ -148,13 +148,13 @@ def your_new_tool( Args: app_id: The Spark application ID - server: Optional server name to use + server_spec: ServerSpec Returns: Description of return value """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Your implementation here return client.your_method(app_id) @@ -163,7 +163,7 @@ def your_new_tool( **Don't forget to add tests:** ```python -@patch("tools.get_client_or_default") +@patch("tools.get_client") def test_your_new_tool(self, mock_get_client): """Test your new tool functionality""" # Setup mocks @@ -172,7 +172,7 @@ def test_your_new_tool(self, mock_get_client): mock_get_client.return_value = mock_client # Call the tool - result = your_new_tool("spark-app-123") + result = your_new_tool("spark-app-123", self.DEFAULT_SERVER_SPEC) # Verify results self.assertEqual(result, expected_result) diff --git a/README.md b/README.md index 3f36863..cecb7ff 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # MCP Server for Apache Spark History Server + [![CI](https://github.com/DeepDiagnostix-AI/mcp-apache-spark-history-server/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/DeepDiagnostix-AI/mcp-apache-spark-history-server/actions) [![Python 3.12+](https://img.shields.io/badge/python-3.12+-blue.svg)](https://www.python.org/downloads/) @@ -112,6 +113,10 @@ servers: auth: # optional username: "user" password: "pass" + +# Enable dynamic EMR clusters mode to specify EMR clusters directly in tool calls +dynamic_emr_clusters_mode: true + mcp: transports: - streamable-http # streamable-http or stdio. @@ -243,13 +248,35 @@ servers: url: "http://staging-spark-history:18080" ``` +With static configuration: + 💁 User Query: "Can you get application using production server?" 🤖 AI Tool Request: ```json { "app_id": "", - "server": "production" + "server_spec": { + "static_server_spec": { + "server_name": "production" + } + } +} +``` + +With dynamic EMR configuration: + +💁 User Query: "Can you get application using cluster j-I4VIWMNGOIP7" + +🤖 AI Tool Request: +```json +{ + "app_id": "", + "server_spec": { + "dynamic_emr_server_spec": { + "emr_cluster_id": "j-I4VIWMNGOIP7" + } + } } ``` 🤖 AI Tool Response: @@ -289,6 +316,7 @@ SHS_SERVERS_*_AUTH_TOKEN - Token for a specific server SHS_SERVERS_*_VERIFY_SSL - Whether to verify SSL for a specific server (true/false) SHS_SERVERS_*_TIMEOUT - HTTP request timeout in seconds for a specific server (default: 30) SHS_SERVERS_*_EMR_CLUSTER_ARN - EMR cluster ARN for a specific server +SHS_DYNAMIC_EMR_CLUSTERS_MODE - Enable dynamic EMR clusters mode (default: false) ``` ## 🤖 AI Agent Integration diff --git a/config.yaml b/config.yaml index fcccfb7..13bd924 100644 --- a/config.yaml +++ b/config.yaml @@ -36,6 +36,8 @@ servers: # emr_persistent_ui: # emr_cluster_arn: "" +# dynamic_emr_clusters_mode: true # To be able to add cluster id or name to the prompt + mcp: transports: - streamable-http # streamable-http or stdio. you can only specify one right now. 
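The static and dynamic tool-request payloads shown in the README diff above map directly onto the `ServerSpec` pydantic model this PR adds in `src/spark_history_mcp/models/server_spec.py`. A minimal sketch of that mapping (field names are taken from this diff; the cluster ID is illustrative):

```python
from spark_history_mcp.models.server_spec import (
    DynamicEMRServerSpec,
    ServerSpec,
)

# Static mode: the JSON payload from the README routes the call to the
# pre-configured "production" server from config.yaml.
static_spec = ServerSpec.model_validate(
    {"static_server_spec": {"server_name": "production"}}
)
assert static_spec.static_server_spec.server_name == "production"

# Dynamic EMR mode: the cluster is identified directly in the tool call.
dynamic_spec = ServerSpec(
    dynamic_emr_server_spec=DynamicEMRServerSpec(emr_cluster_id="j-1234567890ABC")
)
assert dynamic_spec.dynamic_emr_server_spec.emr_cluster_id == "j-1234567890ABC"
```

One sub-spec is expected per call: as `get_client` in `tools.py` below shows, a spec that does not match the mode the server was started in, or one with neither sub-spec set, is rejected with a `ValueError`.
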
@@ -55,3 +57,4 @@ mcp: # SHS_SERVERS_*_AUTH_TOKEN - Token for a specific server # SHS_SERVERS_*_VERIFY_SSL - Whether to verify SSL for a specific server (true/false) # SHS_SERVERS_*_EMR_CLUSTER_ARN - EMR cluster ARN for a specific server +# SHS_DYNAMIC_EMR_CLUSTERS_MODE - Enable dynamic EMR clusters mode (default: false) diff --git a/deploy/kubernetes/helm/README.md b/deploy/kubernetes/helm/README.md index bba2d48..6021a0e 100644 --- a/deploy/kubernetes/helm/README.md +++ b/deploy/kubernetes/helm/README.md @@ -81,6 +81,12 @@ config: url: "http://dev-spark-history:18080" ``` +#### 1a. Dynamic EMR Clusters Configuration +```yaml +config: + dynamic_emr_clusters_mode: true +``` + #### 2. Authentication Setup ```yaml auth: diff --git a/examples/aws/emr/README.md b/examples/aws/emr/README.md index 1dc042f..b1b326f 100644 --- a/examples/aws/emr/README.md +++ b/examples/aws/emr/README.md @@ -4,7 +4,12 @@ [![Watch the demo video](https://img.shields.io/badge/YouTube-Watch%20Demo-red?style=for-the-badge&logo=youtube)](https://www.youtube.com/watch?v=FaduuvMdGxI) -If you are an existing Amazon EMR user looking to analyze your Spark Applications, then you can follow the steps below to start using the Spark History Server MCP in 5 simple steps. +If you are an existing Amazon EMR user looking to analyze your Spark Applications, you have **two options**: + +1. **🔧 Static Configuration** - Pre-configure EMR clusters in `config.yaml` +2. **⚡ Dynamic Configuration** - Specify EMR clusters directly in tool calls + +The dynamic approach is particularly useful when analyzing Spark applications across multiple EMR clusters without pre-configuration. ## Step 1: Setup project on your laptop @@ -24,16 +29,38 @@ task install # Install dependencies Amazon EMR-EC2 users can use a service-managed [Persistent UI](https://docs.aws.amazon.com/emr/latest/ManagementGuide/app-history-spark-UI.html) which automatically creates the Spark History Server for Spark applications on a given EMR Cluster. You can directly go to Step 3 and configure the MCP server with an EMR Cluster Id to analyze the Spark applications on that cluster. -## Step 3: Configure the MCP Server to use the EMR Persistent UI +## Step 3: Configure the MCP Server + +You have **two configuration options**: + +### Option A: 🔧 Static Configuration - Identify the Amazon EMR Cluster Id for which you want the MCP server to analyze the Spark applications - Edit SHS MCP Config: [config.yaml](../../../config.yaml) to add the EMR Cluster Id ```yaml -emr_persistent_ui: - emr_cluster_arn: "" +servers: + emr_persistent_ui: + emr_cluster_arn: "" + +dynamic_emr_clusters_mode: false # Disable dynamic mode +``` + +### Option B: ⚡ Dynamic Configuration + +Enable dynamic EMR clusters mode in [config.yaml](../../../config.yaml): + +```yaml +dynamic_emr_clusters_mode: true # Enable dynamic mode + +# No need to pre-configure servers - specify clusters in tool calls ``` +With dynamic mode, you can specify EMR clusters directly in AI queries: +- **By ARN**: `"arn:aws:emr:us-east-1:123456789012:cluster/j-1234567890ABC"` +- **By Cluster ID**: `"j-1234567890ABC"` +- **By Cluster Name**: `"my-production-cluster"` (active clusters only) + **Note**: The MCP Server manages the creation of the Persistent UI and its authentication using tokens with Persistent UI. You do not need to open the Persistent UI URL in a Web Browser. 
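Under the hood, each dynamic identifier form listed above is resolved to a cluster ARN by the `EMRClient` this PR adds in `src/spark_history_mcp/api/emr_client.py`. A short sketch of the resolution logic, assuming configured AWS credentials and illustrative identifiers:

```python
from spark_history_mcp.api.emr_client import EMRClient

emr = EMRClient(region_name="us-east-1")

# Identifiers starting with 'j-' are treated as cluster IDs; anything else is
# looked up by name among RUNNING/WAITING clusters only.
arn_from_id = emr.get_cluster_arn("j-1234567890ABC")
arn_from_name = emr.get_cluster_arn("my-production-cluster")
```

A name shared by more than one active cluster raises `EMRClusterNotFoundError` and asks for a cluster ID instead, so prefer IDs or ARNs when cluster names are reused.
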
Please ensure the user running the MCP has access to create and view the Persistent UI for that cluster by following the [EMR Documentation](https://docs.aws.amazon.com/emr/latest/ManagementGuide/app-history-spark-UI.html#app-history-spark-UI-permissions). ## Step 4: Start the MCP Server diff --git a/src/spark_history_mcp/api/__init__.py b/src/spark_history_mcp/api/__init__.py index 5613a31..60d4a88 100644 --- a/src/spark_history_mcp/api/__init__.py +++ b/src/spark_history_mcp/api/__init__.py @@ -1 +1 @@ -"""API clients for interacting with Spark History Server.""" +"""API clients.""" diff --git a/src/spark_history_mcp/api/client_factory.py b/src/spark_history_mcp/api/client_factory.py new file mode 100644 index 0000000..6ec4f85 --- /dev/null +++ b/src/spark_history_mcp/api/client_factory.py @@ -0,0 +1,62 @@ +from typing import Optional + +from spark_history_mcp.api.emr_persistent_ui_client import EMRPersistentUIClient +from spark_history_mcp.api.spark_client import SparkRestClient +from spark_history_mcp.config.config import ServerConfig + + +def create_spark_client_from_config(server_config: ServerConfig) -> SparkRestClient: + """ + Create a SparkRestClient from a ServerConfig. + + This function handles both regular Spark History Servers and EMR Persistent UI configurations. + + Args: + server_config: The server configuration + + Returns: + SparkRestClient instance properly configured + """ + # Check if this is an EMR server configuration + if server_config.emr_cluster_arn: + return create_spark_emr_client(server_config.emr_cluster_arn, server_config) + else: + # Regular Spark REST client + return SparkRestClient(server_config) + + +def create_spark_emr_client( + emr_cluster_arn: str, server_config: Optional[ServerConfig] = None +) -> SparkRestClient: + """ + Create a SparkRestClient from EMR cluster arn and optional ServerConfig. + + This function handles EMR Persistent UI applications. + + Args: + emr_cluster_arn: The EMR cluster ARN + server_config: The server configuration + + Returns: + SparkRestClient instance properly configured + """ + if server_config is None: + server_config = ServerConfig() + server_config.emr_cluster_arn = emr_cluster_arn + emr_client = EMRPersistentUIClient(server_config) + + # Initialize EMR client (create persistent UI, get presigned URL, setup session) + base_url, session = emr_client.initialize() + + # Create a modified server config with the base URL + if server_config is None: + server_config = ServerConfig() + else: + server_config = server_config.model_copy() + server_config.url = base_url + + # Create SparkRestClient with the session + spark_client = SparkRestClient(server_config) + spark_client.session = session # Use the authenticated session + + return spark_client diff --git a/src/spark_history_mcp/api/emr_client.py b/src/spark_history_mcp/api/emr_client.py new file mode 100644 index 0000000..970a6b7 --- /dev/null +++ b/src/spark_history_mcp/api/emr_client.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +""" +EMR Client + +This module provides functionality to interact with AWS EMR clusters, +specifically to get cluster ARNs by cluster ID or cluster name. 
+""" + +import logging +from typing import Dict, List, Optional + +import boto3 +from botocore.exceptions import ClientError, NoCredentialsError + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class EMRClusterNotFoundError(Exception): + """Raised when an EMR cluster cannot be found.""" + + pass + + +class EMRClient: + """Client for interacting with AWS EMR clusters.""" + + def __init__(self, region_name: Optional[str] = None): + """ + Initialize the EMR client. + + Args: + region_name: AWS region name. If not provided, uses default region from AWS config. + """ + try: + self.emr_client = boto3.client("emr", region_name=region_name) + self.region_name = region_name or self.emr_client.meta.region_name + logger.info(f"Initialized EMR client for region: {self.region_name}") + except NoCredentialsError: + logger.error("AWS credentials not found. Please configure AWS credentials.") + raise + except Exception as e: + logger.error(f"Failed to initialize EMR client: {str(e)}") + raise + + def get_cluster_arn_by_id(self, cluster_id: str) -> str: + """ + Get cluster ARN by cluster ID. + + Args: + cluster_id: EMR cluster ID (e.g., 'j-1234567890ABC') + + Returns: + Cluster ARN + + Raises: + EMRClusterNotFoundError: If cluster is not found + ClientError: If AWS API call fails + """ + try: + logger.info(f"Getting cluster details for cluster ID: {cluster_id}") + response = self.emr_client.describe_cluster(ClusterId=cluster_id) + cluster_arn = response["Cluster"]["ClusterArn"] + logger.info(f"Found cluster ARN: {cluster_arn}") + return cluster_arn + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "InvalidRequestException": + raise EMRClusterNotFoundError( + f"Cluster with ID '{cluster_id}' not found" + ) from e + else: + logger.error(f"AWS API error getting cluster by ID: {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error getting cluster by ID: {str(e)}") + raise + + def get_active_cluster_arn_by_name(self, cluster_name: str) -> str: + """ + Get cluster ARN by cluster name. This only searches for RUNNING or WAITING clusters. + + Args: + cluster_name: EMR cluster name + + Returns: + Cluster ARN + + Raises: + EMRClusterNotFoundError: If cluster is not found or multiple clusters with same name exist + ClientError: If AWS API call fails + """ + try: + logger.info(f"Searching for cluster with name: {cluster_name}") + + # List clusters to find matching name + matching_clusters = self._find_active_clusters_by_name(cluster_name) + + if not matching_clusters: + raise EMRClusterNotFoundError( + f"No cluster found with name '{cluster_name}'" + ) + + if len(matching_clusters) > 1: + cluster_ids = [cluster["Id"] for cluster in matching_clusters] + raise EMRClusterNotFoundError( + f"Multiple clusters found with name '{cluster_name}': {cluster_ids}. " + "Please use cluster ID instead." + ) + + cluster = matching_clusters[0] + cluster_id = cluster["Id"] + + # Get full cluster details to retrieve ARN + return self.get_cluster_arn_by_id(cluster_id) + + except EMRClusterNotFoundError: + raise + except Exception as e: + logger.error(f"Unexpected error getting cluster by name: {str(e)}") + raise + + def _find_active_clusters_by_name(self, cluster_name: str) -> List[Dict]: + """ + Find clusters by name using list_clusters API. This only searches for RUNNING or WAITING clusters. 
+ + Args: + cluster_name: Name of the cluster to search for + + Returns: + List of cluster summaries matching the name + """ + matching_clusters = [] + marker = None + + try: + while True: + # List clusters with pagination + list_params = {} + if marker: + list_params["Marker"] = marker + list_params["ClusterStates"] = ["RUNNING", "WAITING"] + + response = self.emr_client.list_clusters(**list_params) + clusters = response.get("Clusters", []) + + # Filter clusters by name + for cluster in clusters: + if cluster.get("Name") == cluster_name: + matching_clusters.append(cluster) + + # Check if there are more pages + marker = response.get("Marker") + if not marker: + break + + except ClientError as e: + logger.error(f"AWS API error listing clusters: {str(e)}") + raise + + return matching_clusters + + def get_cluster_arn(self, cluster_identifier: str) -> str: + """ + Get cluster ARN by cluster identifier (ID or name). + + This method automatically detects whether the identifier is a cluster ID or name: + - Cluster IDs follow the pattern 'j-' followed by alphanumeric characters + - Everything else is treated as a cluster name + + Args: + cluster_identifier: Either cluster ID (e.g., 'j-1234567890ABC') or cluster name + + Returns: + Cluster ARN + + Raises: + EMRClusterNotFoundError: If cluster is not found + ClientError: If AWS API call fails + """ + # Check if it's a cluster ID (starts with 'j-') + if cluster_identifier.startswith("j-"): + logger.info(f"Treating '{cluster_identifier}' as cluster ID") + return self.get_cluster_arn_by_id(cluster_identifier) + else: + logger.info(f"Treating '{cluster_identifier}' as cluster name") + return self.get_active_cluster_arn_by_name(cluster_identifier) + + def get_cluster_details(self, cluster_identifier: str) -> Dict: + """ + Get full cluster details by cluster identifier (ID or name). 
+ + Args: + cluster_identifier: Either cluster ID or cluster name + + Returns: + Cluster details dictionary + + Raises: + EMRClusterNotFoundError: If cluster is not found + ClientError: If AWS API call fails + """ + # First get the cluster ARN to ensure we have the correct cluster + cluster_arn = self.get_cluster_arn(cluster_identifier) + + # Extract cluster ID from ARN if we started with a name + if not cluster_identifier.startswith("j-"): + # ARN format: arn:aws:elasticmapreduce:region:account:cluster/cluster-id + cluster_id = cluster_arn.split("/")[-1] + else: + cluster_id = cluster_identifier + + try: + response = self.emr_client.describe_cluster(ClusterId=cluster_id) + return response["Cluster"] + except ClientError as e: + logger.error(f"AWS API error getting cluster details: {str(e)}") + raise diff --git a/src/spark_history_mcp/config/config.py b/src/spark_history_mcp/config/config.py index 639f9e6..c4c266f 100644 --- a/src/spark_history_mcp/config/config.py +++ b/src/spark_history_mcp/config/config.py @@ -2,7 +2,7 @@ from typing import Dict, List, Literal, Optional import yaml -from pydantic import Field +from pydantic import Field, model_validator from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -45,9 +45,8 @@ class McpConfig(BaseSettings): class Config(BaseSettings): """Configuration for the Spark client.""" - servers: Dict[str, ServerConfig] = { - "local": ServerConfig(url="http://localhost:18080", default=True), - } + dynamic_emr_clusters_mode: bool = False + servers: Optional[Dict[str, ServerConfig]] = None mcp: Optional[McpConfig] = McpConfig(transports=["streamable-http"]) model_config = SettingsConfigDict( env_prefix="SHS_", @@ -56,6 +55,22 @@ class Config(BaseSettings): env_file_encoding="utf-8", ) + @model_validator(mode="after") + def validate_and_set_default_values(self) -> "Config": + """Set default servers and validate that dynamic_emr_clusters_mode and servers are mutually exclusive.""" + if self.servers is None and not self.dynamic_emr_clusters_mode: + self.servers = { + "local": ServerConfig(url="http://localhost:18080", default=True), + } + + # Validate mutual exclusivity + if self.dynamic_emr_clusters_mode and self.servers: + raise ValueError( + "dynamic_emr_clusters_mode cannot be True when servers is not empty. " + "These modes are mutually exclusive." 
+ ) + return self + @classmethod def from_file(cls, file_path: str) -> "Config": """Load configuration from a YAML file.""" diff --git a/src/spark_history_mcp/core/app.py b/src/spark_history_mcp/core/app.py index 954436a..e6483d3 100644 --- a/src/spark_history_mcp/core/app.py +++ b/src/spark_history_mcp/core/app.py @@ -8,17 +8,25 @@ from mcp.server.fastmcp import FastMCP -from spark_history_mcp.api.emr_persistent_ui_client import EMRPersistentUIClient +from spark_history_mcp.api.client_factory import create_spark_client_from_config +from spark_history_mcp.api.emr_client import EMRClient from spark_history_mcp.api.spark_client import SparkRestClient from spark_history_mcp.config.config import Config @dataclass -class AppContext: +class StaticClients: clients: dict[str, SparkRestClient] default_client: Optional[SparkRestClient] = None +@dataclass +class AppContext: + dynamic_emr_clusters_mode: bool = False + emr_client: Optional[EMRClient] = None + static_clients: Optional[StaticClients] = None + + class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder that handles datetime objects.""" @@ -32,35 +40,22 @@ def default(self, obj): async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: config = Config.from_file("config.yaml") + if config.dynamic_emr_clusters_mode: + yield AppContext(dynamic_emr_clusters_mode=True, emr_client=EMRClient()) + return + clients: dict[str, SparkRestClient] = {} default_client = None for name, server_config in config.servers.items(): - # Check if this is an EMR server configuration - if server_config.emr_cluster_arn: - # Create EMR client - emr_client = EMRPersistentUIClient(server_config) - - # Initialize EMR client (create persistent UI, get presigned URL, setup session) - base_url, session = emr_client.initialize() - - # Create a modified server config with the base URL - emr_server_config = server_config.model_copy() - emr_server_config.url = base_url - - # Create SparkRestClient with the session - spark_client = SparkRestClient(emr_server_config) - spark_client.session = session # Use the authenticated session - - clients[name] = spark_client - else: - # Regular Spark REST client - clients[name] = SparkRestClient(server_config) + clients[name] = create_spark_client_from_config(server_config) if server_config.default: default_client = clients[name] - yield AppContext(clients=clients, default_client=default_client) + yield AppContext( + static_clients=StaticClients(clients=clients, default_client=default_client) + ) def run(config: Config): diff --git a/src/spark_history_mcp/models/server_spec.py b/src/spark_history_mcp/models/server_spec.py new file mode 100644 index 0000000..e29f965 --- /dev/null +++ b/src/spark_history_mcp/models/server_spec.py @@ -0,0 +1,39 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class StaticServerSpec(BaseModel): + """Model for specifying static Spark server configuration in the tool call.""" + + default_client: bool = Field( + False, description="Use the default client from config.yaml if True. " + ) + server_name: Optional[str] = Field( + None, description="Name of a pre-configured server from config.yaml" + ) + + +class DynamicEMRServerSpec(BaseModel): + """Model for specifying dynamic EMR server in the tool call.""" + + emr_cluster_arn: Optional[str] = Field(None, description="ARN of the EMR cluster. ") + emr_cluster_id: Optional[str] = Field( + None, description="ID of the EMR cluster. 
Starts with 'j-'" + ) + emr_cluster_name: Optional[str] = Field( + None, + description="Name of the *active* EMR cluster. Terminated clusters are not supported.", + ) + + +class ServerSpec(BaseModel): + """Model for specifying which Spark server to use in the tool call.""" + + static_server_spec: Optional[StaticServerSpec] = Field( + None, + description="spec to be used with static Spark servers defined in config.yaml.", + ) + dynamic_emr_server_spec: Optional[DynamicEMRServerSpec] = Field( + None, description="spec to be used in dynamic EMR server mode. " + ) diff --git a/src/spark_history_mcp/tools/tools.py b/src/spark_history_mcp/tools/tools.py index 3c83aa5..ee0d342 100644 --- a/src/spark_history_mcp/tools/tools.py +++ b/src/spark_history_mcp/tools/tools.py @@ -1,11 +1,16 @@ import heapq -from typing import Any, Dict, List, Optional +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, TypeVar +from spark_history_mcp.api.client_factory import create_spark_emr_client +from spark_history_mcp.api.emr_client import EMRClient +from spark_history_mcp.api.spark_client import SparkRestClient from spark_history_mcp.core.app import mcp from spark_history_mcp.models.mcp_types import ( JobSummary, SqlQuerySummary, ) +from spark_history_mcp.models.server_spec import ServerSpec from spark_history_mcp.models.spark_types import ( ApplicationInfo, ExecutionData, @@ -17,39 +22,112 @@ TaskMetricDistributions, ) +emr_cluster_id_to_arn_cache: Dict[str, str] = {} +arn_to_spark_emr_client_cache: Dict[str, SparkRestClient] = {} +session_emr_cluster_name_to_arn_cache: Dict[int, Dict[str, str]] = defaultdict(dict) -def get_client_or_default(ctx, server_name: Optional[str] = None): + +K = TypeVar("K") # Key type +V = TypeVar("V") # Value type + + +def _get_cacheable(key: K, cache: Dict[K, V], get_value_func: Callable[[K], V]) -> V: + if key in cache: + return cache[key] + + value = get_value_func(key) + cache[key] = value + return value + + +def _get_emr_client(ctx) -> EMRClient: + emr_client: EMRClient = ctx.request_context.lifespan_context.emr_client + if not emr_client: + raise ValueError("EMR client is not initialized in dynamic mode") + return emr_client + + +def get_client(ctx, server_spec: ServerSpec) -> SparkRestClient: """ - Get a client by server name or the default client if no name is provided. + Get a client by ServerSpec. Args: ctx: The MCP context - server_name: Optional server name + server_spec: ServerSpec Returns: SparkRestClient: The requested client or default client Raises: - ValueError: If no client is found + ValueError: If no client is found or configuration is invalid """ - clients = ctx.request_context.lifespan_context.clients - default_client = ctx.request_context.lifespan_context.default_client - if server_name: - client = clients.get(server_name) - if client: + if server_spec.static_server_spec: + if ctx.request_context.lifespan_context.dynamic_emr_clusters_mode: + raise ValueError( + "MCP is running in dynamic EMR mode, but static server spec was provided." 
+ ) + + spec = server_spec.static_server_spec + + if spec.default_client: + default_client = ( + ctx.request_context.lifespan_context.static_clients.default_client + ) + if not default_client: + raise ValueError("No default client configured") + return default_client + + if spec.server_name: + clients = ctx.request_context.lifespan_context.static_clients.clients + client = clients.get(spec.server_name) + if not client: + raise ValueError(f"No server configured with name: {spec.server_name}") return client - if default_client: - return default_client + if server_spec.dynamic_emr_server_spec: + if not ctx.request_context.lifespan_context.dynamic_emr_clusters_mode: + raise ValueError( + "MCP is not running in dynamic EMR mode, but dynamic server spec was provided." + ) - raise ValueError( - "No Spark client found. Please specify a valid server name or set a default server." - ) + spec = server_spec.dynamic_emr_server_spec + + if spec.emr_cluster_arn: + arn = spec.emr_cluster_arn + + elif spec.emr_cluster_id: + arn = _get_cacheable( + spec.emr_cluster_id, + emr_cluster_id_to_arn_cache, + lambda cluster_id: _get_emr_client(ctx).get_cluster_arn_by_id( + cluster_id + ), + ) + + elif spec.emr_cluster_name: + arn = _get_cacheable( + spec.emr_cluster_name, + session_emr_cluster_name_to_arn_cache[id(ctx.session)], + lambda cluster_name: _get_emr_client( + ctx + ).get_active_cluster_arn_by_name(cluster_name), + ) + + else: + raise ValueError("Invalid server_spec") + + return _get_cacheable( + arn, + arn_to_spark_emr_client_cache, + lambda _arn: create_spark_emr_client(_arn), + ) + + raise ValueError("Invalid server_spec") @mcp.tool() -def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInfo: +def get_application(app_id: str, server_spec: ServerSpec) -> ApplicationInfo: """ Get detailed information about a specific Spark application. @@ -58,34 +136,34 @@ def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInf Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: ApplicationInfo object containing application details """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) return client.get_application(app_id) @mcp.tool() def list_jobs( - app_id: str, server: Optional[str] = None, status: Optional[list[str]] = None + app_id: str, server_spec: ServerSpec, status: Optional[list[str]] = None ) -> list: """ Get a list of all jobs for a Spark application. 
Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server status: Optional list of job status values to filter by Returns: List of JobData objects for the application """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Convert string status values to JobExecutionStatus enum if provided job_statuses = None @@ -98,7 +176,7 @@ def list_jobs( @mcp.tool() def list_slowest_jobs( app_id: str, - server: Optional[str] = None, + server_spec: ServerSpec, include_running: bool = False, n: int = 5, ) -> List[JobData]: @@ -109,7 +187,7 @@ def list_slowest_jobs( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server include_running: Whether to include running jobs in the search n: Number of slowest jobs to return (default: 5) @@ -117,7 +195,7 @@ def list_slowest_jobs( List of JobData objects for the slowest jobs, or empty list if no jobs found """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Get all jobs jobs = client.list_jobs(app_id=app_id) @@ -143,7 +221,7 @@ def get_job_duration(job): @mcp.tool() def list_stages( app_id: str, - server: Optional[str] = None, + server_spec: ServerSpec, status: Optional[list[str]] = None, with_summaries: bool = False, ) -> list: @@ -155,7 +233,7 @@ def list_stages( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server status: Optional list of stage status values to filter by with_summaries: Whether to include summary metrics in the response @@ -163,7 +241,7 @@ def list_stages( List of StageData objects for the application """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Convert string status values to StageStatus enum if provided stage_statuses = None @@ -180,7 +258,7 @@ def list_stages( @mcp.tool() def list_slowest_stages( app_id: str, - server: Optional[str] = None, + server_spec: ServerSpec, include_running: bool = False, n: int = 5, ) -> List[StageData]: @@ -191,7 +269,7 @@ def list_slowest_stages( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server include_running: Whether to include running stages in the search n: Number of slowest stages to return (default: 5) @@ -199,7 +277,7 @@ def list_slowest_stages( List of StageData objects for the slowest stages, or empty list if no stages found """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) stages = client.list_stages(app_id=app_id) @@ -224,8 +302,8 @@ def get_stage_duration(stage: StageData): def get_stage( app_id: str, stage_id: int, + server_spec: ServerSpec, attempt_id: Optional[int] = None, - server: Optional[str] = None, with_summaries: bool = False, ) -> StageData: """ @@ -234,15 +312,15 @@ def get_stage( Args: app_id: The Spark application ID stage_id: The stage ID + server_spec: ServerSpec for specifying the server attempt_id: Optional stage attempt ID (if not provided, returns the latest attempt) - server: Optional server name to use (uses default if not specified) with_summaries: Whether to include summary metrics Returns: StageData 
object containing stage information """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) if attempt_id is not None: # Get specific attempt @@ -287,7 +365,7 @@ def get_stage( @mcp.tool() -def get_environment(app_id: str, server: Optional[str] = None): +def get_environment(app_id: str, server_spec: ServerSpec): """ Get the comprehensive Spark runtime configuration for a Spark application. @@ -296,20 +374,20 @@ def get_environment(app_id: str, server: Optional[str] = None): Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: ApplicationEnvironmentInfo object containing environment details """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) return client.get_environment(app_id=app_id) @mcp.tool() def list_executors( - app_id: str, server: Optional[str] = None, include_inactive: bool = False + app_id: str, server_spec: ServerSpec, include_inactive: bool = False ): """ Get executor information for a Spark application. @@ -319,14 +397,14 @@ def list_executors( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server include_inactive: Whether to include inactive executors (default: False) Returns: List of ExecutorSummary objects containing executor information """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) if include_inactive: return client.list_all_executors(app_id=app_id) @@ -335,7 +413,7 @@ def list_executors( @mcp.tool() -def get_executor(app_id: str, executor_id: str, server: Optional[str] = None): +def get_executor(app_id: str, executor_id: str, server_spec: ServerSpec): """ Get information about a specific executor. @@ -345,13 +423,13 @@ def get_executor(app_id: str, executor_id: str, server: Optional[str] = None): Args: app_id: The Spark application ID executor_id: The executor ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: ExecutorSummary object containing executor details or None if not found """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Get all executors and find the one with matching ID executors = client.list_all_executors(app_id=app_id) @@ -364,7 +442,7 @@ def get_executor(app_id: str, executor_id: str, server: Optional[str] = None): @mcp.tool() -def get_executor_summary(app_id: str, server: Optional[str] = None): +def get_executor_summary(app_id: str, server_spec: ServerSpec): """ Aggregates metrics across all executors for a Spark application. 
@@ -373,13 +451,13 @@ def get_executor_summary(app_id: str, server: Optional[str] = None): Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: Dictionary containing aggregated executor metrics """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) executors = client.list_all_executors(app_id=app_id) @@ -417,7 +495,7 @@ def get_executor_summary(app_id: str, server: Optional[str] = None): @mcp.tool() def compare_job_environments( - app_id1: str, app_id2: str, server: Optional[str] = None + app_id1: str, app_id2: str, server_spec: ServerSpec ) -> Dict[str, Any]: """ Compare Spark environment configurations between two jobs. @@ -428,13 +506,13 @@ def compare_job_environments( Args: app_id1: First Spark application ID app_id2: Second Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: Dictionary containing configuration differences and similarities """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) env1 = client.get_environment(app_id=app_id1) env2 = client.get_environment(app_id=app_id2) @@ -504,7 +582,7 @@ def props_to_dict(props): @mcp.tool() def compare_job_performance( - app_id1: str, app_id2: str, server: Optional[str] = None + app_id1: str, app_id2: str, server_spec: ServerSpec ) -> Dict[str, Any]: """ Compare performance metrics between two Spark jobs. @@ -515,21 +593,21 @@ def compare_job_performance( Args: app_id1: First Spark application ID app_id2: Second Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: Dictionary containing detailed performance comparison """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Get application info app1 = client.get_application(app_id1) app2 = client.get_application(app_id2) # Get executor summaries - exec_summary1 = get_executor_summary(app_id1, server) - exec_summary2 = get_executor_summary(app_id2, server) + exec_summary1 = get_executor_summary(app_id1, server_spec) + exec_summary2 = get_executor_summary(app_id2, server_spec) # Get job data jobs1 = client.list_jobs(app_id=app_id1) @@ -618,9 +696,9 @@ def calc_job_stats(jobs): def compare_sql_execution_plans( app_id1: str, app_id2: str, + server_spec: ServerSpec, execution_id1: Optional[int] = None, execution_id2: Optional[int] = None, - server: Optional[str] = None, ) -> Dict[str, Any]: """ Compare SQL execution plans between two Spark jobs. 
@@ -631,15 +709,15 @@ def compare_sql_execution_plans( Args: app_id1: First Spark application ID app_id2: Second Spark application ID + server_spec: ServerSpec for specifying the server execution_id1: Optional specific execution ID for first app (uses longest if not specified) execution_id2: Optional specific execution ID for second app (uses longest if not specified) - server: Optional server name to use (uses default if not specified) Returns: Dictionary containing SQL execution plan comparison """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Get SQL executions for both applications sql_execs1 = client.get_sql_list( @@ -738,8 +816,8 @@ def analyze_nodes(execution): def get_stage_task_summary( app_id: str, stage_id: int, + server_spec: ServerSpec, attempt_id: int = 0, - server: Optional[str] = None, quantiles: str = "0.05,0.25,0.5,0.75,0.95", ) -> TaskMetricDistributions: """ @@ -751,15 +829,15 @@ def get_stage_task_summary( Args: app_id: The Spark application ID stage_id: The stage ID + server_spec: ServerSpec for specifying the server attempt_id: The stage attempt ID (default: 0) - server: Optional server name to use (uses default if not specified) quantiles: Comma-separated list of quantiles to use for summary metrics Returns: TaskMetricDistributions object containing metric distributions """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) return client.get_stage_task_summary( app_id=app_id, stage_id=stage_id, attempt_id=attempt_id, quantiles=quantiles @@ -794,7 +872,7 @@ def truncate_plan_description(plan_desc: str, max_length: int) -> str: @mcp.tool() def list_slowest_sql_queries( app_id: str, - server: Optional[str] = None, + server_spec: ServerSpec, attempt_id: Optional[str] = None, top_n: int = 1, page_size: int = 100, @@ -807,7 +885,7 @@ def list_slowest_sql_queries( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server attempt_id: Optional attempt ID top_n: Number of slowest queries to return (default: 1) page_size: Number of executions to fetch per page (default: 100) @@ -819,7 +897,7 @@ def list_slowest_sql_queries( List of SqlQuerySummary objects for the slowest queries """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) all_executions: List[ExecutionData] = [] offset = 0 @@ -889,7 +967,7 @@ def list_slowest_sql_queries( @mcp.tool() def get_job_bottlenecks( - app_id: str, server: Optional[str] = None, top_n: int = 5 + app_id: str, server_spec: ServerSpec, top_n: int = 5 ) -> Dict[str, Any]: """ Identify performance bottlenecks in a Spark job. 
@@ -899,23 +977,23 @@ def get_job_bottlenecks( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server top_n: Number of top bottlenecks to return Returns: Dictionary containing identified bottlenecks and recommendations """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Get slowest stages - slowest_stages = list_slowest_stages(app_id, server, False, top_n) + slowest_stages = list_slowest_stages(app_id, server_spec, False, top_n) # Get slowest jobs - slowest_jobs = list_slowest_jobs(app_id, server, False, top_n) + slowest_jobs = list_slowest_jobs(app_id, server_spec, False, top_n) # Get executor summary - exec_summary = get_executor_summary(app_id, server) + exec_summary = get_executor_summary(app_id, server_spec) all_stages = client.list_stages(app_id=app_id) @@ -1030,9 +1108,7 @@ def get_job_bottlenecks( @mcp.tool() -def get_resource_usage_timeline( - app_id: str, server: Optional[str] = None -) -> Dict[str, Any]: +def get_resource_usage_timeline(app_id: str, server_spec: ServerSpec) -> Dict[str, Any]: """ Get resource usage timeline for a Spark application. @@ -1041,13 +1117,13 @@ def get_resource_usage_timeline( Args: app_id: The Spark application ID - server: Optional server name to use (uses default if not specified) + server_spec: ServerSpec for specifying the server Returns: Dictionary containing timeline of resource usage """ ctx = mcp.get_context() - client = get_client_or_default(ctx, server) + client = get_client(ctx, server_spec) # Get application info app = client.get_application(app_id) diff --git a/tests/emr/test_emr_integration.py b/tests/emr/test_emr_integration.py index 6fb1157..044c45b 100644 --- a/tests/emr/test_emr_integration.py +++ b/tests/emr/test_emr_integration.py @@ -62,7 +62,7 @@ def test_spark_client_with_emr_session(self, mock_initialize): mock_session.get.assert_called_once() self.assertEqual(apps, []) - @patch("spark_history_mcp.core.app.EMRPersistentUIClient") + @patch("spark_history_mcp.api.client_factory.EMRPersistentUIClient") @patch("spark_history_mcp.core.app.Config.from_file") def test_app_lifespan_with_emr_config( self, mock_config_from_file, mock_emr_client_class @@ -92,6 +92,7 @@ def test_app_lifespan_with_emr_config( # Set up the mock config mock_config = MagicMock() + mock_config.dynamic_emr_clusters_mode = False mock_config.servers = { "emr": ServerConfig( emr_cluster_arn=self.emr_cluster_arn, default=True, verify_ssl=True @@ -109,8 +110,11 @@ async def test_lifespan(): mock_emr_client.initialize.assert_called_once() # Verify context has clients - self.assertIn("emr", context.clients) - self.assertEqual(context.default_client, context.clients["emr"]) + self.assertIn("emr", context.static_clients.clients) + self.assertEqual( + context.static_clients.default_client, + context.static_clients.clients["emr"], + ) # Run the async test try: diff --git a/tests/unit/config.py b/tests/unit/config.py index 3495755..cfa137e 100644 --- a/tests/unit/config.py +++ b/tests/unit/config.py @@ -180,3 +180,71 @@ def test_model_serialization(self): # Test with explicit exclude server_dict = server.model_dump(exclude={"auth"}) self.assertNotIn("auth", server_dict) + + def test_dynamic_emr_clusters_mode_default(self): + """Test that dynamic_emr_clusters_mode defaults to False.""" + minimal_config = {"servers": {"minimal": {"url": "http://minimal:18080"}}} + + with tempfile.NamedTemporaryFile(mode="w", 
delete=False) as temp_file: + yaml.dump(minimal_config, temp_file) + temp_file_path = temp_file.name + + try: + config = Config.from_file(temp_file_path) + self.assertFalse(config.dynamic_emr_clusters_mode) + finally: + os.unlink(temp_file_path) + + def test_dynamic_emr_clusters_mode_enabled(self): + """Test that dynamic_emr_clusters_mode can be enabled.""" + config_data = {"dynamic_emr_clusters_mode": True} + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + yaml.dump(config_data, temp_file) + temp_file_path = temp_file.name + + try: + config = Config.from_file(temp_file_path) + self.assertTrue(config.dynamic_emr_clusters_mode) + self.assertIsNone(config.servers) + finally: + os.unlink(temp_file_path) + + def test_dynamic_emr_clusters_mode_mutually_exclusive_with_servers(self): + """Test that dynamic_emr_clusters_mode and servers are mutually exclusive.""" + config_data = { + "dynamic_emr_clusters_mode": True, + "servers": {"test": {"url": "http://test:18080"}}, + } + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + yaml.dump(config_data, temp_file) + temp_file_path = temp_file.name + + try: + with self.assertRaises(ValueError) as context: + Config.from_file(temp_file_path) + self.assertIn( + "dynamic_emr_clusters_mode cannot be True when servers is not empty", + str(context.exception), + ) + finally: + os.unlink(temp_file_path) + + def test_default_servers_when_dynamic_mode_disabled(self): + """Test that default servers are set when dynamic_emr_clusters_mode is False.""" + minimal_config = {} # Empty config + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + yaml.dump(minimal_config, temp_file) + temp_file_path = temp_file.name + + try: + config = Config.from_file(temp_file_path) + self.assertFalse(config.dynamic_emr_clusters_mode) + self.assertIsNotNone(config.servers) + self.assertIn("local", config.servers) + self.assertEqual(config.servers["local"].url, "http://localhost:18080") + self.assertTrue(config.servers["local"].default) + finally: + os.unlink(temp_file_path) diff --git a/tests/unit/test_client_factory.py b/tests/unit/test_client_factory.py new file mode 100644 index 0000000..1de0c2c --- /dev/null +++ b/tests/unit/test_client_factory.py @@ -0,0 +1,140 @@ +import unittest +from unittest.mock import MagicMock, patch + +from spark_history_mcp.api.client_factory import ( + create_spark_client_from_config, + create_spark_emr_client, +) +from spark_history_mcp.api.spark_client import SparkRestClient +from spark_history_mcp.config.config import ServerConfig + + +class TestAPIFactory(unittest.TestCase): + """Test cases for the API factory functions.""" + + def test_create_spark_client_from_config_regular(self): + """Test creating a regular SparkRestClient from ServerConfig.""" + server_config = ServerConfig(url="http://localhost:18080") + + client = create_spark_client_from_config(server_config) + + self.assertIsInstance(client, SparkRestClient) + # Note: SparkRestClient stores the config internally but doesn't expose it as server_config + + @patch("spark_history_mcp.api.client_factory.create_spark_emr_client") + def test_create_spark_client_from_config_emr(self, mock_create_emr_client): + """Test creating an EMR SparkRestClient from ServerConfig.""" + mock_emr_client = MagicMock(spec=SparkRestClient) + mock_create_emr_client.return_value = mock_emr_client + + server_config = ServerConfig( + url="http://localhost:18080", + 
emr_cluster_arn="arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC", + ) + + client = create_spark_client_from_config(server_config) + + self.assertEqual(client, mock_emr_client) + mock_create_emr_client.assert_called_once_with( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC", + server_config, + ) + + @patch("spark_history_mcp.api.client_factory.EMRPersistentUIClient") + @patch("spark_history_mcp.api.client_factory.SparkRestClient") + def test_create_spark_emr_client_success( + self, mock_spark_client_class, mock_emr_client_class + ): + """Test successful creation of EMR SparkRestClient.""" + # Mock EMR client + mock_emr_client = MagicMock() + mock_emr_client.initialize.return_value = ("http://emr-base-url", MagicMock()) + mock_emr_client_class.return_value = mock_emr_client + + # Mock SparkRestClient + mock_spark_client = MagicMock() + mock_spark_client_class.return_value = mock_spark_client + + emr_cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + server_config = ServerConfig(url="http://original-url:18080") + + result = create_spark_emr_client(emr_cluster_arn, server_config) + + # Verify EMR client was created and initialized + mock_emr_client_class.assert_called_once() + call_args = mock_emr_client_class.call_args[0][0] + self.assertIsInstance(call_args, ServerConfig) + self.assertEqual(call_args.emr_cluster_arn, emr_cluster_arn) + self.assertEqual(call_args.url, "http://original-url:18080") + mock_emr_client.initialize.assert_called_once() + + # Verify SparkRestClient was created with modified config + self.assertEqual(result, mock_spark_client) + # Check that the server config URL was modified + call_args = mock_spark_client_class.call_args[0][0] + self.assertEqual(call_args.url, "http://emr-base-url") + + @patch("spark_history_mcp.api.client_factory.EMRPersistentUIClient") + @patch("spark_history_mcp.api.client_factory.SparkRestClient") + def test_create_spark_emr_client_no_server_config( + self, mock_spark_client_class, mock_emr_client_class + ): + """Test creating EMR SparkRestClient without server config.""" + # Mock EMR client + mock_emr_client = MagicMock() + mock_emr_client.initialize.return_value = ("http://emr-base-url", MagicMock()) + mock_emr_client_class.return_value = mock_emr_client + + # Mock SparkRestClient + mock_spark_client = MagicMock() + mock_spark_client_class.return_value = mock_spark_client + + emr_cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + + result = create_spark_emr_client(emr_cluster_arn, None) + + # Verify EMR client was created and initialized + mock_emr_client_class.assert_called_once() + call_args = mock_emr_client_class.call_args[0][0] + self.assertIsInstance(call_args, ServerConfig) + self.assertEqual(call_args.emr_cluster_arn, emr_cluster_arn) + mock_emr_client.initialize.assert_called_once() + + # Verify SparkRestClient was created with new config + self.assertEqual(result, mock_spark_client) + # Check that a default server config was created + call_args = mock_spark_client_class.call_args[0][0] + self.assertEqual(call_args.url, "http://emr-base-url") + + @patch("spark_history_mcp.api.client_factory.EMRPersistentUIClient") + @patch("spark_history_mcp.api.client_factory.SparkRestClient") + def test_create_spark_emr_client_session_assignment( + self, mock_spark_client_class, mock_emr_client_class + ): + """Test that the authenticated session is properly assigned.""" + # Mock EMR client + 
mock_emr_client = MagicMock() + mock_session = MagicMock() + mock_emr_client.initialize.return_value = ("http://emr-base-url", mock_session) + mock_emr_client_class.return_value = mock_emr_client + + # Mock SparkRestClient + mock_spark_client = MagicMock() + mock_spark_client_class.return_value = mock_spark_client + + emr_cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + + result = create_spark_emr_client(emr_cluster_arn) + + # Verify the session was assigned to the SparkRestClient + self.assertEqual(result.session, mock_session) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_emr_client.py b/tests/unit/test_emr_client.py new file mode 100644 index 0000000..5e09572 --- /dev/null +++ b/tests/unit/test_emr_client.py @@ -0,0 +1,262 @@ +import unittest +from unittest.mock import MagicMock, patch + +from botocore.exceptions import ClientError, NoCredentialsError + +from spark_history_mcp.api.emr_client import EMRClient, EMRClusterNotFoundError + + +class TestEMRClient(unittest.TestCase): + """Test cases for the EMRClient class.""" + + def setUp(self): + """Set up test fixtures.""" + self.mock_emr_client = MagicMock() + self.mock_emr_client.meta.region_name = "us-east-1" + + @patch("boto3.client") + def test_init_success(self, mock_boto_client): + """Test successful EMRClient initialization.""" + mock_boto_client.return_value = self.mock_emr_client + + client = EMRClient() + + mock_boto_client.assert_called_once_with("emr", region_name=None) + self.assertEqual(client.region_name, "us-east-1") + + @patch("boto3.client") + def test_init_with_region(self, mock_boto_client): + """Test EMRClient initialization with specific region.""" + mock_boto_client.return_value = self.mock_emr_client + + EMRClient(region_name="us-west-2") + + mock_boto_client.assert_called_once_with("emr", region_name="us-west-2") + + @patch("boto3.client") + def test_init_no_credentials(self, mock_boto_client): + """Test EMRClient initialization fails with no credentials.""" + mock_boto_client.side_effect = NoCredentialsError() + + with self.assertRaises(NoCredentialsError): + EMRClient() + + @patch("boto3.client") + def test_get_cluster_arn_by_id_success(self, mock_boto_client): + """Test successful cluster ARN retrieval by ID.""" + mock_boto_client.return_value = self.mock_emr_client + + cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + self.mock_emr_client.describe_cluster.return_value = { + "Cluster": {"ClusterArn": cluster_arn} + } + + client = EMRClient() + result = client.get_cluster_arn_by_id("j-1234567890ABC") + + self.assertEqual(result, cluster_arn) + self.mock_emr_client.describe_cluster.assert_called_once_with( + ClusterId="j-1234567890ABC" + ) + + @patch("boto3.client") + def test_get_cluster_arn_by_id_not_found(self, mock_boto_client): + """Test cluster ARN retrieval by ID when cluster doesn't exist.""" + mock_boto_client.return_value = self.mock_emr_client + + error_response = {"Error": {"Code": "InvalidRequestException"}} + self.mock_emr_client.describe_cluster.side_effect = ClientError( + error_response, "DescribeCluster" + ) + + client = EMRClient() + + with self.assertRaises(EMRClusterNotFoundError) as context: + client.get_cluster_arn_by_id("j-nonexistent") + + self.assertIn("not found", str(context.exception)) + + @patch("boto3.client") + def test_get_active_cluster_arn_by_name_success(self, mock_boto_client): + """Test successful cluster ARN retrieval by name.""" + 
mock_boto_client.return_value = self.mock_emr_client + + cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + + # Mock list_clusters response + self.mock_emr_client.list_clusters.return_value = { + "Clusters": [ + { + "Id": "j-1234567890ABC", + "Name": "test-cluster", + "Status": {"State": "RUNNING"}, + } + ] + } + + # Mock describe_cluster response + self.mock_emr_client.describe_cluster.return_value = { + "Cluster": {"ClusterArn": cluster_arn} + } + + client = EMRClient() + result = client.get_active_cluster_arn_by_name("test-cluster") + + self.assertEqual(result, cluster_arn) + + @patch("boto3.client") + def test_get_active_cluster_arn_by_name_not_found(self, mock_boto_client): + """Test cluster ARN retrieval by name when cluster doesn't exist.""" + mock_boto_client.return_value = self.mock_emr_client + + # Mock empty list_clusters response + self.mock_emr_client.list_clusters.return_value = {"Clusters": []} + + client = EMRClient() + + with self.assertRaises(EMRClusterNotFoundError) as context: + client.get_active_cluster_arn_by_name("nonexistent-cluster") + + self.assertIn("No cluster found", str(context.exception)) + + @patch("boto3.client") + def test_get_active_cluster_arn_by_name_multiple_found(self, mock_boto_client): + """Test cluster ARN retrieval by name when multiple clusters exist.""" + mock_boto_client.return_value = self.mock_emr_client + + # Mock list_clusters response with multiple clusters + self.mock_emr_client.list_clusters.return_value = { + "Clusters": [ + { + "Id": "j-1234567890ABC", + "Name": "test-cluster", + "Status": {"State": "RUNNING"}, + }, + { + "Id": "j-0987654321DEF", + "Name": "test-cluster", + "Status": {"State": "WAITING"}, + }, + ] + } + + client = EMRClient() + + with self.assertRaises(EMRClusterNotFoundError) as context: + client.get_active_cluster_arn_by_name("test-cluster") + + self.assertIn("Multiple clusters found", str(context.exception)) + + @patch("boto3.client") + def test_get_cluster_arn_with_cluster_id(self, mock_boto_client): + """Test get_cluster_arn with cluster ID.""" + mock_boto_client.return_value = self.mock_emr_client + + cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + self.mock_emr_client.describe_cluster.return_value = { + "Cluster": {"ClusterArn": cluster_arn} + } + + client = EMRClient() + result = client.get_cluster_arn("j-1234567890ABC") + + self.assertEqual(result, cluster_arn) + + @patch("boto3.client") + def test_get_cluster_arn_with_cluster_name(self, mock_boto_client): + """Test get_cluster_arn with cluster name.""" + mock_boto_client.return_value = self.mock_emr_client + + cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + + # Mock list_clusters response + self.mock_emr_client.list_clusters.return_value = { + "Clusters": [ + { + "Id": "j-1234567890ABC", + "Name": "test-cluster", + "Status": {"State": "RUNNING"}, + } + ] + } + + # Mock describe_cluster response + self.mock_emr_client.describe_cluster.return_value = { + "Cluster": {"ClusterArn": cluster_arn} + } + + client = EMRClient() + result = client.get_cluster_arn("test-cluster") + + self.assertEqual(result, cluster_arn) + + @patch("boto3.client") + def test_get_cluster_details_success(self, mock_boto_client): + """Test successful cluster details retrieval.""" + mock_boto_client.return_value = self.mock_emr_client + + cluster_arn = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + 
cluster_details = { + "Id": "j-1234567890ABC", + "Name": "test-cluster", + "ClusterArn": cluster_arn, + "Status": {"State": "RUNNING"}, + } + + self.mock_emr_client.describe_cluster.return_value = { + "Cluster": cluster_details + } + + client = EMRClient() + result = client.get_cluster_details("j-1234567890ABC") + + self.assertEqual(result, cluster_details) + + @patch("boto3.client") + def test_find_active_clusters_by_name_pagination(self, mock_boto_client): + """Test _find_active_clusters_by_name with pagination.""" + mock_boto_client.return_value = self.mock_emr_client + + # Mock paginated responses + self.mock_emr_client.list_clusters.side_effect = [ + { + "Clusters": [ + { + "Id": "j-1234567890ABC", + "Name": "other-cluster", + "Status": {"State": "RUNNING"}, + } + ], + "Marker": "next-page-token", + }, + { + "Clusters": [ + { + "Id": "j-0987654321DEF", + "Name": "test-cluster", + "Status": {"State": "RUNNING"}, + } + ] + }, + ] + + client = EMRClient() + result = client._find_active_clusters_by_name("test-cluster") + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["Id"], "j-0987654321DEF") + self.assertEqual(self.mock_emr_client.list_clusters.call_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_tools.py b/tests/unit/test_tools.py index 89e0107..7a78501 100644 --- a/tests/unit/test_tools.py +++ b/tests/unit/test_tools.py @@ -3,6 +3,11 @@ from unittest.mock import MagicMock, patch from spark_history_mcp.api.spark_client import SparkRestClient +from spark_history_mcp.models.server_spec import ( + DynamicEMRServerSpec, + ServerSpec, + StaticServerSpec, +) from spark_history_mcp.models.spark_types import ( ApplicationInfo, ExecutionData, @@ -12,7 +17,7 @@ ) from spark_history_mcp.tools.tools import ( get_application, - get_client_or_default, + get_client, get_stage, get_stage_task_summary, list_jobs, @@ -23,74 +28,451 @@ ) +class MockRequestContext: + """Simple mock request context that can store attributes properly""" + + def __init__(self, lifespan_context): + self.lifespan_context = lifespan_context + + class TestTools(unittest.TestCase): + # Common server spec for tests that use default client + DEFAULT_SERVER_SPEC = ServerSpec( + static_server_spec=StaticServerSpec(default_client=True) + ) + def setUp(self): # Create mock context self.mock_ctx = MagicMock() self.mock_lifespan_context = MagicMock() - self.mock_ctx.request_context.lifespan_context = self.mock_lifespan_context + + # Create a request context that can properly store attributes + self.mock_request_context = MockRequestContext(self.mock_lifespan_context) + self.mock_ctx.request_context = self.mock_request_context # Create mock clients self.mock_client1 = MagicMock(spec=SparkRestClient) self.mock_client2 = MagicMock(spec=SparkRestClient) - # Set up clients dictionary - self.mock_lifespan_context.clients = { + # Set up static clients structure + self.mock_static_clients = MagicMock() + self.mock_static_clients.clients = { "server1": self.mock_client1, "server2": self.mock_client2, } + self.mock_static_clients.default_client = self.mock_client1 - def test_get_client_with_name(self): - """Test getting a client by name""" - self.mock_lifespan_context.default_client = self.mock_client1 + # Set up lifespan context for static mode + self.mock_lifespan_context.dynamic_emr_clusters_mode = False + self.mock_lifespan_context.static_clients = self.mock_static_clients + self.mock_lifespan_context.emr_client = None + + def test_get_client_with_static_server_name(self): + """Test 
+    def test_get_client_with_static_server_name(self):
+        """Test getting a client by server name with static spec"""
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(server_name="server2")
+        )

         # Get client by name
-        client = get_client_or_default(self.mock_ctx, "server2")
+        client = get_client(self.mock_ctx, server_spec)

         # Should return the requested client
         self.assertEqual(client, self.mock_client2)

-    def test_get_default_client(self):
-        """Test getting the default client when no name is provided"""
-        self.mock_lifespan_context.default_client = self.mock_client1
+    def test_get_default_client_static(self):
+        """Test getting the default client with static spec"""
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )

         # Get client without specifying name
-        client = get_client_or_default(self.mock_ctx)
+        # Get default client
+        client = get_client(self.mock_ctx, server_spec)

         # Should return the default client
         self.assertEqual(client, self.mock_client1)

-    def test_get_client_not_found_with_default(self):
-        """Test behavior when requested client is not found but default exists"""
-        self.mock_lifespan_context.default_client = self.mock_client1
+    def test_get_client_not_found_static(self):
+        """Test error when requested static client is not found"""
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(server_name="non_existent_server")
+        )

         # Get non-existent client
-        client = get_client_or_default(self.mock_ctx, "non_existent_server")
+        # Try to get non-existent client
+        with self.assertRaises(ValueError) as context:
+            get_client(self.mock_ctx, server_spec)

-        # Should fall back to default client
-        self.assertEqual(client, self.mock_client1)
+        self.assertIn("No server configured with name", str(context.exception))

-    def test_no_client_found(self):
-        """Test error when no client is found and no default exists"""
-        self.mock_lifespan_context.default_client = None
+    def test_no_default_client_static(self):
+        """Test error when no default client exists in static mode"""
+        self.mock_static_clients.default_client = None
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )

-        # Try to get non-existent client with no default
+        # Try to get default client when none exists
         with self.assertRaises(ValueError) as context:
-            get_client_or_default(self.mock_ctx, "non_existent_server")
+            get_client(self.mock_ctx, server_spec)

-        self.assertIn("No Spark client found", str(context.exception))
+        self.assertIn("No default client configured", str(context.exception))

-    def test_no_default_client(self):
-        """Test error when no name is provided and no default exists"""
-        self.mock_lifespan_context.default_client = None
+    def test_dynamic_emr_mode_with_static_spec_error(self):
+        """Test error when using static spec in dynamic EMR mode"""
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )
+
+        with self.assertRaises(ValueError) as context:
+            get_client(self.mock_ctx, server_spec)
+
+        self.assertIn(
+            "MCP is running in dynamic EMR mode, but static server spec was provided",
+            str(context.exception),
+        )
+
+    def test_static_mode_with_dynamic_spec_error(self):
+        """Test error when using dynamic spec in static mode"""
+        server_spec = ServerSpec(
+            dynamic_emr_server_spec=DynamicEMRServerSpec(emr_cluster_id="j-123456")
+        )
+
+        with self.assertRaises(ValueError) as context:
+            get_client(self.mock_ctx, server_spec)
+
+        self.assertIn(
server spec was provided", + str(context.exception), + ) + + @patch("spark_history_mcp.tools.tools.create_spark_emr_client") + def test_dynamic_emr_client_by_cluster_arn(self, mock_create_client): + """Test getting EMR client by cluster ARN in dynamic mode""" + # Set up dynamic EMR mode + self.mock_lifespan_context.dynamic_emr_clusters_mode = True + self.mock_lifespan_context.static_clients = None + + # Mock the create_spark_emr_client function + mock_emr_client = MagicMock(spec=SparkRestClient) + mock_create_client.return_value = mock_emr_client + + server_spec = ServerSpec( + dynamic_emr_server_spec=DynamicEMRServerSpec( + emr_cluster_arn="arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + ) + + client = get_client(self.mock_ctx, server_spec) + + self.assertEqual(client, mock_emr_client) + mock_create_client.assert_called_once_with( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + + @patch("spark_history_mcp.tools.tools.create_spark_emr_client") + def test_dynamic_emr_client_by_cluster_id(self, mock_create_client): + """Test getting EMR client by cluster ID in dynamic mode""" + # Set up dynamic EMR mode with mock EMR client + self.mock_lifespan_context.dynamic_emr_clusters_mode = True + self.mock_lifespan_context.static_clients = None + + # Clear caches to ensure clean test + from spark_history_mcp.tools.tools import ( + arn_to_spark_emr_client_cache, + emr_cluster_id_to_arn_cache, + ) + + arn_to_spark_emr_client_cache.clear() + emr_cluster_id_to_arn_cache.clear() + + mock_emr_client = MagicMock() + mock_emr_client.get_cluster_arn_by_id.return_value = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + self.mock_lifespan_context.emr_client = mock_emr_client + + # Mock the create_spark_emr_client function + mock_spark_client = MagicMock(spec=SparkRestClient) + mock_create_client.return_value = mock_spark_client + + server_spec = ServerSpec( + dynamic_emr_server_spec=DynamicEMRServerSpec( + emr_cluster_id="j-1234567890ABC" + ) + ) + + client = get_client(self.mock_ctx, server_spec) + + self.assertEqual(client, mock_spark_client) + mock_emr_client.get_cluster_arn_by_id.assert_called_once_with("j-1234567890ABC") + mock_create_client.assert_called_once_with( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + + @patch("spark_history_mcp.tools.tools.create_spark_emr_client") + def test_dynamic_emr_client_by_cluster_name(self, mock_create_client): + """Test getting EMR client by cluster name in dynamic mode""" + # Set up dynamic EMR mode with mock EMR client + self.mock_lifespan_context.dynamic_emr_clusters_mode = True + self.mock_lifespan_context.static_clients = None + + # Clear caches to ensure clean test + from spark_history_mcp.tools.tools import arn_to_spark_emr_client_cache + + arn_to_spark_emr_client_cache.clear() + + mock_emr_client = MagicMock() + mock_emr_client.get_active_cluster_arn_by_name.return_value = ( + "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC" + ) + self.mock_lifespan_context.emr_client = mock_emr_client + + # Mock the create_spark_emr_client function + mock_spark_client = MagicMock(spec=SparkRestClient) + mock_create_client.return_value = mock_spark_client + + server_spec = ServerSpec( + dynamic_emr_server_spec=DynamicEMRServerSpec( + emr_cluster_name="test-cluster" + ) + ) + + client = get_client(self.mock_ctx, server_spec) + + self.assertEqual(client, mock_spark_client) + 
+        mock_emr_client.get_active_cluster_arn_by_name.assert_called_once_with(
+            "test-cluster"
+        )
+        mock_create_client.assert_called_once_with(
+            "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC"
+        )
+
+    @patch("spark_history_mcp.tools.tools.create_spark_emr_client")
+    def test_dynamic_emr_cluster_name_caching(self, mock_create_client):
+        """Test that cluster name to ARN mapping is cached (request-scoped)"""
+        # Set up dynamic EMR mode with mock EMR client
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        self.mock_lifespan_context.static_clients = None
+
+        # Clear caches to ensure clean test
+        from spark_history_mcp.tools.tools import arn_to_spark_emr_client_cache
+
+        arn_to_spark_emr_client_cache.clear()
+        # Note: cluster name cache is now request-scoped and automatically fresh per test
+
+        mock_emr_client = MagicMock()
+        mock_emr_client.get_active_cluster_arn_by_name.return_value = (
+            "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC"
+        )
+        self.mock_lifespan_context.emr_client = mock_emr_client
+
+        # Mock the create_spark_emr_client function
+        mock_spark_client = MagicMock(spec=SparkRestClient)
+        mock_create_client.return_value = mock_spark_client
+
+        server_spec = ServerSpec(
+            dynamic_emr_server_spec=DynamicEMRServerSpec(
+                emr_cluster_name="test-cluster"
+            )
+        )
+
+        # First call should trigger EMR client call
+        client1 = get_client(self.mock_ctx, server_spec)
+        # Second call should use cached ARN
+        client2 = get_client(self.mock_ctx, server_spec)
+
+        # Both calls should return the same client
+        self.assertEqual(client1, client2)
+        # EMR client method should only be called once (caching works)
+        mock_emr_client.get_active_cluster_arn_by_name.assert_called_once_with(
+            "test-cluster"
+        )
+        # Spark client should only be created once (ARN caching leads to client caching)
+        mock_create_client.assert_called_once_with(
+            "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC"
+        )
+
+    @patch("spark_history_mcp.tools.tools.create_spark_emr_client")
+    def test_dynamic_emr_cluster_name_session_isolation(self, mock_create_client):
+        """Test that cluster name caching is isolated between different sessions"""
+        # Set up dynamic EMR mode with mock EMR client
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        self.mock_lifespan_context.static_clients = None
+
+        # Clear caches to ensure clean test
+        from spark_history_mcp.tools.tools import (
+            arn_to_spark_emr_client_cache,
+            session_emr_cluster_name_to_arn_cache,
+        )
+
+        arn_to_spark_emr_client_cache.clear()
+        session_emr_cluster_name_to_arn_cache.clear()
+
+        mock_emr_client = MagicMock()
+        # Configure the EMR client to return different ARNs for the same cluster name
+        # to demonstrate that caching is working independently per session
+        mock_emr_client.get_active_cluster_arn_by_name.side_effect = [
+            "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-session1cluster",
+            "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-session2cluster",
+        ]
+        self.mock_lifespan_context.emr_client = mock_emr_client
+
+        # Create two different sessions
+        session1 = MagicMock()
+        session2 = MagicMock()
+
+        # Create two different contexts with different sessions
+        ctx1 = MagicMock()
+        ctx1.session = session1
+        ctx1.request_context = MockRequestContext(self.mock_lifespan_context)
+
+        ctx2 = MagicMock()
+        ctx2.session = session2
+        ctx2.request_context = MockRequestContext(self.mock_lifespan_context)
+
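+        # Editor's note (added comment): ctx1 and ctx2 share one lifespan
+        # context, so any isolation must come from the per-session cluster-name
+        # cache rather than from global state.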
+        # Mock the create_spark_emr_client function
+        mock_spark_client1 = MagicMock(spec=SparkRestClient)
+        mock_spark_client2 = MagicMock(spec=SparkRestClient)
+        mock_create_client.side_effect = [mock_spark_client1, mock_spark_client2]
+
+        server_spec = ServerSpec(
+            dynamic_emr_server_spec=DynamicEMRServerSpec(
+                emr_cluster_name="shared-cluster-name"
+            )
+        )
+
+        # Make calls from both sessions with the same cluster name
+        client1 = get_client(ctx1, server_spec)
+        client2 = get_client(ctx2, server_spec)
+
+        # Both calls should trigger EMR client calls (no cross-session caching)
+        self.assertEqual(mock_emr_client.get_active_cluster_arn_by_name.call_count, 2)
+
+        # Verify each session got its own client
+        self.assertEqual(client1, mock_spark_client1)
+        self.assertEqual(client2, mock_spark_client2)
+
+        # Verify different ARNs were used for Spark client creation
+        expected_calls = [
+            unittest.mock.call(
+                "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-session1cluster"
+            ),
+            unittest.mock.call(
+                "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-session2cluster"
+            ),
+        ]
+        mock_create_client.assert_has_calls(expected_calls)
+
+        # Now test that caching works within each session
+        # Reset the mock to track additional calls
+        mock_emr_client.reset_mock()
+
+        # Make another call from session1 - should use cache
+        client1_cached = get_client(ctx1, server_spec)
+        # Make another call from session2 - should use cache
+        client2_cached = get_client(ctx2, server_spec)
+
+        # No additional EMR client calls should be made (caching is working)
+        mock_emr_client.get_active_cluster_arn_by_name.assert_not_called()
+
+        # Should return the same clients as before (from ARN cache)
+        self.assertEqual(client1_cached, mock_spark_client1)
+        self.assertEqual(client2_cached, mock_spark_client2)
+
+    @patch("spark_history_mcp.tools.tools.create_spark_emr_client")
+    def test_dynamic_emr_client_caching(self, mock_create_client):
+        """Test that EMR clients are cached by ARN"""
+        # Set up dynamic EMR mode
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        self.mock_lifespan_context.static_clients = None
+
+        arn = "arn:aws:elasticmapreduce:us-east-1:123456789012:cluster/j-1234567890ABC"
+
+        # Clear the cache to ensure clean test
+        from spark_history_mcp.tools.tools import arn_to_spark_emr_client_cache
+
+        arn_to_spark_emr_client_cache.clear()
+
+        # Mock the create_spark_emr_client function
+        mock_spark_client = MagicMock(spec=SparkRestClient)
+        mock_create_client.return_value = mock_spark_client
+
+        server_spec = ServerSpec(
+            dynamic_emr_server_spec=DynamicEMRServerSpec(emr_cluster_arn=arn)
+        )
+
+        # First call
+        client1 = get_client(self.mock_ctx, server_spec)
+        # Second call
+        client2 = get_client(self.mock_ctx, server_spec)
+
+        # Should return the same cached client
+        self.assertEqual(client1, client2)
+        # Should only create the client once
+        mock_create_client.assert_called_once_with(arn)
+
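+    # Editor's note (added comment): error paths below — a dynamic spec needs
+    # an initialized EMR client to resolve cluster IDs/names, and must carry at
+    # least one cluster field.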
+    def test_dynamic_emr_no_emr_client_error(self):
+        """Test error when EMR client is not initialized in dynamic mode"""
+        # Set up dynamic EMR mode without EMR client
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        self.mock_lifespan_context.static_clients = None
+        self.mock_lifespan_context.emr_client = None
+
+        # Clear caches to ensure the ID lookup is attempted
+        from spark_history_mcp.tools.tools import (
+            emr_cluster_id_to_arn_cache,
+        )
+
+        emr_cluster_id_to_arn_cache.clear()
+
+        server_spec = ServerSpec(
+            dynamic_emr_server_spec=DynamicEMRServerSpec(
+                emr_cluster_id="j-1234567890ABC"
+            )
+        )
-        # Try to get default client when none exists
         with self.assertRaises(ValueError) as context:
-            get_client_or_default(self.mock_ctx)
+            get_client(self.mock_ctx, server_spec)
+
+        self.assertIn(
+            "EMR client is not initialized in dynamic mode", str(context.exception)
+        )
+
+    def test_dynamic_emr_no_emr_client_error_cluster_name(self):
+        """Test error when EMR client is not initialized in dynamic mode with cluster name"""
+        # Set up dynamic EMR mode without EMR client
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        self.mock_lifespan_context.static_clients = None
+        self.mock_lifespan_context.emr_client = None
+
+        server_spec = ServerSpec(
+            dynamic_emr_server_spec=DynamicEMRServerSpec(
+                emr_cluster_name="test-cluster"
+            )
+        )
-        self.assertIn("No Spark client found", str(context.exception))
+        with self.assertRaises(ValueError) as context:
+            get_client(self.mock_ctx, server_spec)
-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+        self.assertIn(
+            "EMR client is not initialized in dynamic mode", str(context.exception)
+        )
+
+    def test_dynamic_emr_invalid_server_spec(self):
+        """Test error when dynamic server spec is invalid"""
+        # Set up dynamic EMR mode
+        self.mock_lifespan_context.dynamic_emr_clusters_mode = True
+        self.mock_lifespan_context.static_clients = None
+
+        # Empty dynamic server spec
+        server_spec = ServerSpec(dynamic_emr_server_spec=DynamicEMRServerSpec())
+
+        with self.assertRaises(ValueError) as context:
+            get_client(self.mock_ctx, server_spec)
+
+        self.assertIn("Invalid server_spec", str(context.exception))
+
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_jobs_empty(self, mock_get_client):
         """Test list_slowest_jobs when no jobs are found"""
         # Setup mock client
@@ -99,13 +481,13 @@ def test_get_slowest_jobs_empty(self, mock_get_client):
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_slowest_jobs("app-123", n=3)
+        result = list_slowest_jobs("app-123", self.DEFAULT_SERVER_SPEC, n=3)

         # Verify results
         self.assertEqual(result, [])
         mock_client.list_jobs.assert_called_once_with(app_id="app-123")

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_jobs_exclude_running(self, mock_get_client):
         """Test list_slowest_jobs excluding running jobs"""
         # Setup mock client and jobs
@@ -136,7 +518,7 @@
         mock_get_client.return_value = mock_client

         # Call the function with include_running=False (default)
-        result = list_slowest_jobs("app-123", n=2)
+        result = list_slowest_jobs("app-123", self.DEFAULT_SERVER_SPEC, n=2)

         # Verify results - should return job3 and job2 (in that order)
         self.assertEqual(len(result), 2)
@@ -146,7 +528,7 @@
         self.assertEqual(result[1], job2)

         # Running job (job1) should be excluded
         self.assertNotIn(job1, result)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_jobs_include_running(self, mock_get_client):
         """Test list_slowest_jobs including running jobs"""
         # Setup mock client and jobs
@@ -174,7 +556,9 @@
         mock_get_client.return_value = mock_client

         # Call the function with include_running=True
-        result = list_slowest_jobs("app-123", include_running=True, n=2)
+        result = list_slowest_jobs(
+            "app-123", self.DEFAULT_SERVER_SPEC, include_running=True, n=2
+        )

         # Verify results - should include the running job
         self.assertEqual(len(result), 2)
@@ -183,7 +567,7 @@
         self.assertEqual(result[0], job3)
         self.assertEqual(result[1], job2)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_jobs_limit_results(self, mock_get_client):
         """Test list_slowest_jobs limits results to n"""
         # Setup mock client and jobs
@@ -203,12 +587,12 @@
         mock_get_client.return_value = mock_client

         # Call the function with n=3
-        result = list_slowest_jobs("app-123", n=3)
+        result = list_slowest_jobs("app-123", self.DEFAULT_SERVER_SPEC, n=3)

         # Verify results - should return only 3 jobs
         self.assertEqual(len(result), 3)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stage_with_attempt_id(self, mock_get_client):
         """Test get_stage with a specific attempt ID"""
         # Setup mock client
@@ -221,7 +605,7 @@
         mock_get_client.return_value = mock_client

         # Call the function with attempt_id
-        result = get_stage("app-123", stage_id=1, attempt_id=0)
+        result = get_stage("app-123", 1, self.DEFAULT_SERVER_SPEC, attempt_id=0)

         # Verify results
         self.assertEqual(result, mock_stage)
@@ -233,7 +617,7 @@
             with_summaries=False,
         )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stage_without_attempt_id_single_stage(self, mock_get_client):
         """Test get_stage without attempt ID when a single stage is returned"""
         # Setup mock client
@@ -246,7 +630,10 @@
         mock_get_client.return_value = mock_client

         # Call the function without attempt_id
-        result = get_stage("app-123", stage_id=1)
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )
+        result = get_stage("app-123", 1, server_spec)

         # Verify results
         self.assertEqual(result, mock_stage)
@@ -257,7 +644,7 @@
             with_summaries=False,
         )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stage_without_attempt_id_multiple_stages(self, mock_get_client):
         """Test get_stage without attempt ID when multiple stages are returned"""
         # Setup mock client
@@ -276,7 +663,10 @@
         mock_get_client.return_value = mock_client

         # Call the function without attempt_id
-        result = get_stage("app-123", stage_id=1)
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )
+        result = get_stage("app-123", 1, server_spec)

         # Verify results - should return the stage with highest attempt_id
         self.assertEqual(result, mock_stage2)
@@ -287,7 +677,7 @@
             with_summaries=False,
         )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stage_with_summaries_missing_metrics(self, mock_get_client):
         """Test get_stage with summaries when metrics distributions are missing"""
         # Setup mock client
@@ -305,7 +695,9 @@ def test_get_stage_with_summaries_missing_metrics(self, mock_get_client):
         mock_get_client.return_value = mock_client

         # Call the function with with_summaries=True
-        result = get_stage("app-123", stage_id=1, attempt_id=0, with_summaries=True)
+        result = get_stage(
+            "app-123", 1, self.DEFAULT_SERVER_SPEC, attempt_id=0, with_summaries=True
+        )

         # Verify results
         self.assertEqual(result, mock_stage)
@@ -325,7 +717,7 @@
             attempt_id=0,
         )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stage_no_stages_found(self, mock_get_client):
         """Test get_stage when no stages are found"""
         # Setup mock client
@@ -334,12 +726,12 @@
         mock_get_client.return_value = mock_client

         with self.assertRaises(ValueError) as context:
-            get_stage("app-123", stage_id=1)
+            get_stage("app-123", 1, self.DEFAULT_SERVER_SPEC)

         self.assertIn("No stage found with ID 1", str(context.exception))

     # Tests for get_application tool
-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_application_success(self, mock_get_client):
         """Test successful application retrieval"""
         # Setup mock client
@@ -351,14 +743,16 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = get_application("spark-app-123")
+        result = get_application("spark-app-123", self.DEFAULT_SERVER_SPEC)

         # Verify results
         self.assertEqual(result, mock_app)
         mock_client.get_application.assert_called_once_with("spark-app-123")
-        mock_get_client.assert_called_once_with(unittest.mock.ANY, None)
+        mock_get_client.assert_called_once_with(
+            unittest.mock.ANY, self.DEFAULT_SERVER_SPEC
+        )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_application_with_server(self, mock_get_client):
         """Test application retrieval with specific server"""
         # Setup mock client
@@ -368,12 +762,15 @@
         mock_get_client.return_value = mock_client

         # Call the function with server
-        get_application("spark-app-123", server="production")
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(server_name="production")
+        )
+        get_application("spark-app-123", server_spec)

         # Verify server parameter is passed
-        mock_get_client.assert_called_once_with(unittest.mock.ANY, "production")
+        mock_get_client.assert_called_once_with(unittest.mock.ANY, server_spec)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_application_not_found(self, mock_get_client):
         """Test application retrieval when app doesn't exist"""
         # Setup mock client to raise exception
@@ -383,12 +780,12 @@
         # Verify exception is propagated
         with self.assertRaises(Exception) as context:
-            get_application("non-existent-app")
+            get_application("non-existent-app", self.DEFAULT_SERVER_SPEC)

         self.assertIn("Application not found", str(context.exception))

     # Tests for list_jobs tool
-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_jobs_no_filter(self, mock_get_client):
         """Test job retrieval without status filter"""
         # Setup mock client
@@ -398,7 +795,10 @@ def test_list_jobs_no_filter(self, mock_get_client):
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_jobs("spark-app-123")
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )
+        result = list_jobs("spark-app-123", server_spec)

         # Verify results
         self.assertEqual(result, mock_jobs)
@@ -406,7 +806,7 @@
             app_id="spark-app-123", status=None
         )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_jobs_with_status_filter(self, mock_get_client):
         """Test job retrieval with status filter"""
         # Setup mock client
@@ -417,13 +817,15 @@
         mock_get_client.return_value = mock_client

         # Call the function with status filter
-        result = list_jobs("spark-app-123", status=["SUCCEEDED"])
+        result = list_jobs(
+            "spark-app-123", self.DEFAULT_SERVER_SPEC, status=["SUCCEEDED"]
+        )

         # Verify results
         self.assertEqual(len(result), 1)
         self.assertEqual(result[0].status, "SUCCEEDED")

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_jobs_empty_result(self, mock_get_client):
         """Test job retrieval with empty result"""
         # Setup mock client
@@ -432,12 +834,15 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_jobs("spark-app-123")
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )
+        result = list_jobs("spark-app-123", server_spec)

         # Verify results
         self.assertEqual(result, [])

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_jobs_status_filtering(self, mock_get_client):
         """Test job status filtering logic"""
         # Setup mock client
@@ -456,14 +861,16 @@
         mock_get_client.return_value = mock_client

         # Test filtering for SUCCEEDED jobs
-        result = list_jobs("spark-app-123", status=["SUCCEEDED"])
+        result = list_jobs(
+            "spark-app-123", self.DEFAULT_SERVER_SPEC, status=["SUCCEEDED"]
+        )

         # Should only return SUCCEEDED job
         self.assertEqual(len(result), 1)
         self.assertEqual(result[0].status, "SUCCEEDED")

     # Tests for list_stages tool
-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stages_no_filter(self, mock_get_client):
         """Test stage retrieval without filters"""
         # Setup mock client
@@ -473,7 +880,10 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_stages("spark-app-123")
+        server_spec = ServerSpec(
+            static_server_spec=StaticServerSpec(default_client=True)
+        )
+        result = list_stages("spark-app-123", server_spec)

         # Verify results
         self.assertEqual(result, mock_stages)
@@ -481,7 +891,7 @@
             app_id="spark-app-123", status=None, with_summaries=False
         )

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_stages_with_status_filter(self, mock_get_client):
         """Test stage retrieval with status filter"""
         # Setup mock client
@@ -500,13 +910,15 @@
         mock_get_client.return_value = mock_client

         # Call with status filter
status=["COMPLETE"]) + result = list_stages( + "spark-app-123", self.DEFAULT_SERVER_SPEC, status=["COMPLETE"] + ) # Should only return COMPLETE stage self.assertEqual(len(result), 1) self.assertEqual(result[0].status, "COMPLETE") - @patch("spark_history_mcp.tools.tools.get_client_or_default") + @patch("spark_history_mcp.tools.tools.get_client") def test_get_stages_with_summaries(self, mock_get_client): """Test stage retrieval with summaries enabled""" # Setup mock client @@ -516,14 +928,14 @@ def test_get_stages_with_summaries(self, mock_get_client): mock_get_client.return_value = mock_client # Call with summaries enabled - list_stages("spark-app-123", with_summaries=True) + list_stages("spark-app-123", self.DEFAULT_SERVER_SPEC, with_summaries=True) # Verify summaries parameter is passed mock_client.list_stages.assert_called_once_with( app_id="spark-app-123", status=None, with_summaries=True ) - @patch("spark_history_mcp.tools.tools.get_client_or_default") + @patch("spark_history_mcp.tools.tools.get_client") def test_get_stages_empty_result(self, mock_get_client): """Test stage retrieval with empty result""" # Setup mock client @@ -532,13 +944,16 @@ def test_get_stages_empty_result(self, mock_get_client): mock_get_client.return_value = mock_client # Call the function - result = list_stages("spark-app-123") + server_spec = ServerSpec( + static_server_spec=StaticServerSpec(default_client=True) + ) + result = list_stages("spark-app-123", server_spec) # Verify results self.assertEqual(result, []) # Tests for get_stage_task_summary tool - @patch("spark_history_mcp.tools.tools.get_client_or_default") + @patch("spark_history_mcp.tools.tools.get_client") def test_get_stage_task_summary_success(self, mock_get_client): """Test successful stage task summary retrieval""" # Setup mock client @@ -548,7 +963,7 @@ def test_get_stage_task_summary_success(self, mock_get_client): mock_get_client.return_value = mock_client # Call the function - result = get_stage_task_summary("spark-app-123", 1, 0) + result = get_stage_task_summary("spark-app-123", 1, self.DEFAULT_SERVER_SPEC, 0) # Verify results self.assertEqual(result, mock_summary) @@ -559,7 +974,7 @@ def test_get_stage_task_summary_success(self, mock_get_client): quantiles="0.05,0.25,0.5,0.75,0.95", ) - @patch("spark_history_mcp.tools.tools.get_client_or_default") + @patch("spark_history_mcp.tools.tools.get_client") def test_get_stage_task_summary_with_quantiles(self, mock_get_client): """Test stage task summary with custom quantiles""" # Setup mock client @@ -569,14 +984,16 @@ def test_get_stage_task_summary_with_quantiles(self, mock_get_client): mock_get_client.return_value = mock_client # Call with custom quantiles - get_stage_task_summary("spark-app-123", 1, 0, quantiles="0.25,0.5,0.75") + get_stage_task_summary( + "spark-app-123", 1, self.DEFAULT_SERVER_SPEC, 0, quantiles="0.25,0.5,0.75" + ) # Verify quantiles parameter is passed mock_client.get_stage_task_summary.assert_called_once_with( app_id="spark-app-123", stage_id=1, attempt_id=0, quantiles="0.25,0.5,0.75" ) - @patch("spark_history_mcp.tools.tools.get_client_or_default") + @patch("spark_history_mcp.tools.tools.get_client") def test_get_stage_task_summary_not_found(self, mock_get_client): """Test stage task summary when stage doesn't exist""" # Setup mock client to raise exception @@ -586,12 +1003,12 @@ def test_get_stage_task_summary_not_found(self, mock_get_client): # Verify exception is propagated with self.assertRaises(Exception) as context: - get_stage_task_summary("spark-app-123", 999, 0) 
+            get_stage_task_summary("spark-app-123", 999, self.DEFAULT_SERVER_SPEC, 0)

         self.assertIn("Stage not found", str(context.exception))

     # Tests for list_slowest_stages tool
-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_slowest_stages_execution_time_vs_total_time(self, mock_get_client):
         """Test that list_slowest_stages prioritizes execution time over total stage duration"""
         mock_client = MagicMock()
@@ -628,7 +1045,9 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_slowest_stages("app-123", n=2)
+        result = list_slowest_stages(
+            "app-123", server_spec=self.DEFAULT_SERVER_SPEC, n=2
+        )

         # Verify results - Stage B should be first (longer execution time: 7 min vs 5 min)
         # even though Stage A has longer total duration (10 min vs 8 min)
         self.assertEqual(len(result), 2)
         self.assertEqual(result[0], stage_b)  # Stage B first (7 min execution)
         self.assertEqual(result[1], stage_a)  # Stage A second (5 min execution)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_slowest_stages_exclude_running(self, mock_get_client):
         """Test that list_slowest_stages excludes running stages by default"""
         mock_client = MagicMock()
@@ -665,14 +1084,16 @@
         mock_get_client.return_value = mock_client

         # Call the function with include_running=False (default)
-        result = list_slowest_stages("app-123", include_running=False, n=2)
+        result = list_slowest_stages(
+            "app-123", server_spec=self.DEFAULT_SERVER_SPEC, include_running=False, n=2
+        )

         # Should only return the completed stage
         self.assertEqual(len(result), 1)
         self.assertEqual(result[0], completed_stage)
         self.assertNotIn(running_stage, result)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_slowest_stages_include_running(self, mock_get_client):
         """Test that list_slowest_stages includes running stages when requested"""
         mock_client = MagicMock()
@@ -701,7 +1122,9 @@
         mock_get_client.return_value = mock_client

         # Call the function with include_running=True
-        result = list_slowest_stages("app-123", include_running=True, n=2)
+        result = list_slowest_stages(
+            "app-123", server_spec=self.DEFAULT_SERVER_SPEC, include_running=True, n=2
+        )

         # Should include both stages, but running stage will have duration 0
         # so completed stage should be first
@@ -709,7 +1132,7 @@
         self.assertEqual(result[0], completed_stage)  # Has actual duration
         self.assertEqual(result[1], running_stage)  # Duration 0

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_slowest_stages_missing_timestamps(self, mock_get_client):
         """Test list_slowest_stages handles stages with missing timestamps"""
         # Setup mock client
@@ -755,13 +1178,15 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_slowest_stages("app-123", n=3)
+        result = list_slowest_stages(
+            "app-123", server_spec=self.DEFAULT_SERVER_SPEC, n=3
+        )

         # Should return valid stage first, others should have duration 0
         self.assertEqual(len(result), 3)
         self.assertEqual(result[0], valid_stage)  # Only one with valid duration

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_slowest_stages_empty_result(self, mock_get_client):
         """Test list_slowest_stages with no stages"""
         # Setup mock client
@@ -770,12 +1195,14 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_slowest_stages("app-123", n=5)
+        result = list_slowest_stages(
+            "app-123", server_spec=self.DEFAULT_SERVER_SPEC, n=5
+        )

         # Should return empty list
         self.assertEqual(result, [])

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_list_slowest_stages_limit_results(self, mock_get_client):
         """Test list_slowest_stages limits results to n"""
         # Setup mock client
@@ -799,7 +1226,9 @@
         mock_get_client.return_value = mock_client

         # Call the function with n=3
-        result = list_slowest_stages("app-123", n=3)
+        result = list_slowest_stages(
+            "app-123", server_spec=self.DEFAULT_SERVER_SPEC, n=3
+        )

         # Should return only 3 stages (the ones with longest execution times)
         self.assertEqual(len(result), 3)
@@ -809,7 +1238,7 @@
         self.assertEqual(result[2].stage_id, 2)  # 3 minutes

     # Tests for list_slowest_sql_queries tool
-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_sql_queries_success(self, mock_get_client):
         """Test successful SQL query retrieval and sorting"""
         # Setup mock client
@@ -853,14 +1282,16 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_slowest_sql_queries("spark-app-123", top_n=2)
+        result = list_slowest_sql_queries(
+            "spark-app-123", self.DEFAULT_SERVER_SPEC, top_n=2
+        )

         # Verify results are sorted by duration (descending)
         self.assertEqual(len(result), 2)
         self.assertEqual(result[0].duration, 10000)  # Slowest first
         self.assertEqual(result[1].duration, 5000)  # Second slowest

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_sql_queries_exclude_running(self, mock_get_client):
         """Test SQL query retrieval excluding running queries"""
         # Setup mock client
@@ -893,13 +1324,13 @@
         mock_get_client.return_value = mock_client

         # Call the function (include_running=False by default)
-        result = list_slowest_sql_queries("spark-app-123")
+        result = list_slowest_sql_queries("spark-app-123", self.DEFAULT_SERVER_SPEC)

         # Should exclude running query
         self.assertEqual(len(result), 1)
         self.assertEqual(result[0].status, "COMPLETED")

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_sql_queries_include_running(self, mock_get_client):
         """Test SQL query retrieval including running queries"""
         # Setup mock client
@@ -933,13 +1364,13 @@
         # Call the function with include_running=True and top_n=2
         result = list_slowest_sql_queries(
-            "spark-app-123", include_running=True, top_n=2
+            "spark-app-123", self.DEFAULT_SERVER_SPEC, include_running=True, top_n=2
         )

         # Should include both queries
         self.assertEqual(len(result), 2)

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_sql_queries_empty_result(self, mock_get_client):
         """Test SQL query retrieval with empty result"""
         # Setup mock client
@@ -948,12 +1379,12 @@
         mock_get_client.return_value = mock_client

         # Call the function
-        result = list_slowest_sql_queries("spark-app-123")
+        result = list_slowest_sql_queries("spark-app-123", self.DEFAULT_SERVER_SPEC)

         # Verify results
         self.assertEqual(result, [])

-    @patch("spark_history_mcp.tools.tools.get_client_or_default")
+    @patch("spark_history_mcp.tools.tools.get_client")
     def test_get_slowest_sql_queries_limit(self, mock_get_client):
         """Test SQL query retrieval with limit"""
         # Setup mock client
@@ -978,7 +1409,9 @@
         mock_get_client.return_value = mock_client

         # Call the function with top_n=3
-        result = list_slowest_sql_queries("spark-app-123", top_n=3)
+        result = list_slowest_sql_queries(
+            "spark-app-123", self.DEFAULT_SERVER_SPEC, top_n=3
+        )

         # Verify results - should return only 3 queries
         self.assertEqual(len(result), 3)