Skip to content

Commit 9e892a2

Browse files
committed
reduce list tool context by default
Signed-off-by: Manabu McCloskey <[email protected]>
1 parent 5959b6f commit 9e892a2

File tree

3 files changed

+285
-40
lines changed

3 files changed

+285
-40
lines changed
Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,92 @@
1-
from typing import Optional, Sequence
1+
from datetime import datetime
2+
from typing import List, Optional
23

34
from pydantic import BaseModel, ConfigDict, Field
45

6+
from spark_history_mcp.models.spark_types import JobData, StageData
7+
58

69
class JobSummary(BaseModel):
    """Condensed view of a Spark job.

    Carries the job's identity, timing, and its stage IDs grouped by the
    stage status observed in the history server (succeeded / failed /
    active / pending / skipped). Intended as a lighter-weight alternative
    to returning full JobData objects from list tools.
    """

    # Job identity and status, aliased to the Spark REST API's camelCase keys.
    job_id: Optional[int] = Field(None, alias="jobId")
    name: str
    description: Optional[str] = None
    status: str
    submission_time: Optional[datetime] = Field(None, alias="submissionTime")
    completion_time: Optional[datetime] = Field(None, alias="completionTime")
    # Wall-clock duration derived from submission/completion; None while running.
    duration_seconds: Optional[float] = None
    # Stage IDs bucketed by stage status; empty lists when no stage data given.
    succeeded_stage_ids: List[int] = Field(
        default_factory=list, alias="succeededStageIds"
    )
    failed_stage_ids: List[int] = Field(default_factory=list, alias="failedStageIds")
    active_stage_ids: List[int] = Field(default_factory=list, alias="activeStageIds")
    pending_stage_ids: List[int] = Field(default_factory=list, alias="pendingStageIds")
    skipped_stage_ids: List[int] = Field(default_factory=list, alias="skippedStageIds")

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)

    @classmethod
    def parse_datetime(cls, value):
        """Coerce a Spark timestamp into a datetime where possible.

        Accepts None (returned as-is), epoch milliseconds (int/float), or
        Spark's "...GMT"-suffixed ISO strings. Any value that does not match
        one of these forms is returned unchanged rather than raising.
        """
        if value is None:
            return None
        if isinstance(value, (int, float)):
            # Spark reports timestamps as epoch milliseconds.
            return datetime.fromtimestamp(value / 1000)
        if isinstance(value, str) and value.endswith("GMT"):
            try:
                # e.g. "2024-01-01T00:00:00.000GMT" -> treat "GMT" as +0000.
                dt_str = value.replace("GMT", "+0000")
                return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%f%z")
            except ValueError:
                # Unparseable string: fall through and return it untouched.
                pass
        return value

    @classmethod
    def from_job_data(
        cls, job_data: JobData, stages: Optional[List[StageData]] = None
    ) -> "JobSummary":
        """Create a JobSummary from full JobData and optional stage data.

        Args:
            job_data: The full job record from the history server.
            stages: Optional stage records used to bucket the job's stage IDs
                by status. When omitted (or when the job lists no stage IDs),
                all stage-ID buckets are left empty.
                (Fixed: was annotated ``List[StageData] = None`` — an implicit
                Optional — now an explicit ``Optional[List[StageData]]``.)

        Returns:
            A JobSummary mirroring job_data with derived duration and
            per-status stage-ID lists.
        """
        duration = None
        if job_data.completion_time and job_data.submission_time:
            duration = (
                job_data.completion_time - job_data.submission_time
            ).total_seconds()

        # One bucket per recognized stage status; stages whose status maps to
        # none of these (e.g. "UNKNOWN") are intentionally dropped, matching
        # the previous if/elif chain.
        buckets: dict = {
            "COMPLETE": [],
            "FAILED": [],
            "ACTIVE": [],
            "PENDING": [],
            "SKIPPED": [],
        }

        if stages and job_data.stage_ids:
            status_by_id = {stage.stage_id: stage.status for stage in stages}
            for stage_id in job_data.stage_ids:
                bucket = buckets.get(status_by_id.get(stage_id, "UNKNOWN"))
                if bucket is not None:
                    bucket.append(stage_id)

        return cls(
            job_id=job_data.job_id,
            name=job_data.name,
            description=job_data.description,
            status=job_data.status,
            submission_time=job_data.submission_time,
            completion_time=job_data.completion_time,
            duration_seconds=duration,
            succeeded_stage_ids=buckets["COMPLETE"],
            failed_stage_ids=buckets["FAILED"],
            active_stage_ids=buckets["ACTIVE"],
            pending_stage_ids=buckets["PENDING"],
            skipped_stage_ids=buckets["SKIPPED"],
        )
1490

1591

1692
class SqlQuerySummary(BaseModel):
@@ -22,6 +98,8 @@ class SqlQuerySummary(BaseModel):
2298
status: str
2399
submission_time: Optional[str] = Field(None, alias="submissionTime")
24100
plan_description: str = Field(..., alias="planDescription")
25-
job_summary: JobSummary = Field(..., alias="jobSummary")
101+
success_job_ids: List[int] = Field(..., alias="successJobIds")
102+
failed_job_ids: List[int] = Field(..., alias="failedJobIds")
103+
running_job_ids: List[int] = Field(..., alias="runningJobIds")
26104

27105
model_config = ConfigDict(populate_by_name=True)

src/spark_history_mcp/tools/tools.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,21 @@ def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInf
115115

116116
@mcp.tool()
117117
def list_jobs(
118-
app_id: str, server: Optional[str] = None, status: Optional[list[str]] = None
119-
) -> list:
118+
app_id: str,
119+
server: Optional[str] = None,
120+
status: Optional[list[str]] = None,
121+
limit: int = 50,
122+
) -> List[JobSummary]:
120123
"""
121-
Get a list of all jobs for a Spark application.
122-
124+
Get a list of jobs for a Spark application.
123125
Args:
124126
app_id: The Spark application ID
125127
server: Optional server name to use (uses default if not specified)
126-
status: Optional list of job status values to filter by
128+
status: Optional list of job status values to filter by (running|succeeded|failed|unknown)
129+
limit: Maximum number of jobs to return (default: 50)
127130
128131
Returns:
129-
List of JobData objects for the application
132+
List of JobSummary objects for the application
130133
"""
131134
ctx = mcp.get_context()
132135
client = get_client_or_default(ctx, server)
@@ -136,7 +139,43 @@ def list_jobs(
136139
if status:
137140
job_statuses = [JobExecutionStatus.from_string(s) for s in status]
138141

139-
return client.list_jobs(app_id=app_id, status=job_statuses)
142+
jobs = client.list_jobs(app_id=app_id, status=job_statuses)
143+
144+
stages = client.list_stages(app_id=app_id, details=False)
145+
146+
job_summaries = [JobSummary.from_job_data(job, stages) for job in jobs]
147+
148+
if limit > 0:
149+
job_summaries = job_summaries[:limit]
150+
151+
return job_summaries
152+
153+
154+
@mcp.tool()
def get_job(
    app_id: str,
    job_id: int,
    server: Optional[str] = None,
) -> JobSummary:
    """
    Get information about a specific job.

    Args:
        app_id: The Spark application ID
        job_id: The job ID
        server: Optional server name to use (uses default if not specified)

    Returns:
        JobSummary object containing job information with stage IDs grouped by status
    """
    # Resolve the client for the requested (or default) server.
    client = get_client_or_default(mcp.get_context(), server)

    # Fetch the job plus lightweight stage records, and fold them into a summary.
    return JobSummary.from_job_data(
        client.get_job(app_id, job_id),
        client.list_stages(app_id=app_id, details=False),
    )
140179

141180

142181
@mcp.tool()
@@ -190,6 +229,7 @@ def list_stages(
190229
server: Optional[str] = None,
191230
status: Optional[list[str]] = None,
192231
with_summaries: bool = False,
232+
limit: int = 20,
193233
) -> list:
194234
"""
195235
Get a list of all stages for a Spark application.
@@ -202,6 +242,7 @@ def list_stages(
202242
server: Optional server name to use (uses default if not specified)
203243
status: Optional list of stage status values to filter by
204244
with_summaries: Whether to include summary metrics in the response
245+
limit: Maximum number of stages to return (default: 20)
205246
206247
Returns:
207248
List of StageData objects for the application
@@ -214,12 +255,17 @@ def list_stages(
214255
if status:
215256
stage_statuses = [StageStatus.from_string(s) for s in status]
216257

217-
return client.list_stages(
258+
stages = client.list_stages(
218259
app_id=app_id,
219260
status=stage_statuses,
220261
with_summaries=with_summaries,
221262
)
222263

264+
if limit > 0:
265+
stages = stages[:limit]
266+
267+
return stages
268+
223269

224270
@mcp.tool()
225271
def list_slowest_stages(
@@ -939,12 +985,6 @@ def list_slowest_sql_queries(
939985
# Create simplified results without additional API calls. Raw object is too verbose.
940986
simplified_results = []
941987
for execution in slowest_executions:
942-
job_summary = JobSummary(
943-
success_job_ids=execution.success_job_ids,
944-
failed_job_ids=execution.failed_job_ids,
945-
running_job_ids=execution.running_job_ids,
946-
)
947-
948988
# Handle plan description based on include_plan_description flag
949989
plan_description = ""
950990
if include_plan_description and execution.plan_description:
@@ -961,7 +1001,9 @@ def list_slowest_sql_queries(
9611001
if execution.submission_time
9621002
else None,
9631003
plan_description=plan_description,
964-
job_summary=job_summary,
1004+
success_job_ids=execution.success_job_ids,
1005+
failed_job_ids=execution.failed_job_ids,
1006+
running_job_ids=execution.running_job_ids,
9651007
)
9661008

9671009
simplified_results.append(query_summary)

0 commit comments

Comments
 (0)