Skip to content

Commit 57b579d

Browse files
committed
feat: add BindArguments and helper functions for CallToolRequest
1 parent 27e3a49 commit 57b579d

File tree

1 file changed

+200
-0
lines changed

1 file changed

+200
-0
lines changed

mcp/tools.go

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"strconv"
78
)
89

910
var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both")
@@ -73,6 +74,205 @@ func (r CallToolRequest) GetRawArguments() any {
7374
return r.Params.Arguments
7475
}
7576

77+
// BindArguments unmarshals the Arguments into the provided struct
78+
// This is useful for working with strongly-typed arguments
79+
func (r CallToolRequest) BindArguments(target any) error {
80+
data, err := json.Marshal(r.Params.Arguments)
81+
if err != nil {
82+
return fmt.Errorf("failed to marshal arguments: %w", err)
83+
}
84+
85+
return json.Unmarshal(data, target)
86+
}
87+
88+
// GetString returns a string argument by key, or the default value if not found
89+
func (r CallToolRequest) GetString(key string, defaultValue string) string {
90+
args := r.GetArguments()
91+
if val, ok := args[key]; ok {
92+
if str, ok := val.(string); ok {
93+
return str
94+
}
95+
}
96+
return defaultValue
97+
}
98+
99+
// RequireString returns a string argument by key, or an error if not found or not a string
100+
func (r CallToolRequest) RequireString(key string) (string, error) {
101+
args := r.GetArguments()
102+
if val, ok := args[key]; ok {
103+
if str, ok := val.(string); ok {
104+
return str, nil
105+
}
106+
return "", fmt.Errorf("argument %q is not a string", key)
107+
}
108+
return "", fmt.Errorf("required argument %q not found", key)
109+
}
110+
111+
// GetInt returns an int argument by key, or the default value if not found
112+
func (r CallToolRequest) GetInt(key string, defaultValue int) int {
113+
args := r.GetArguments()
114+
if val, ok := args[key]; ok {
115+
switch v := val.(type) {
116+
case int:
117+
return v
118+
case float64:
119+
return int(v)
120+
case string:
121+
if i, err := strconv.Atoi(v); err == nil {
122+
return i
123+
}
124+
}
125+
}
126+
return defaultValue
127+
}
128+
129+
// RequireInt returns an int argument by key, or an error if not found or not convertible to int
130+
func (r CallToolRequest) RequireInt(key string) (int, error) {
131+
args := r.GetArguments()
132+
if val, ok := args[key]; ok {
133+
switch v := val.(type) {
134+
case int:
135+
return v, nil
136+
case float64:
137+
return int(v), nil
138+
case string:
139+
if i, err := strconv.Atoi(v); err == nil {
140+
return i, nil
141+
}
142+
return 0, fmt.Errorf("argument %q cannot be converted to int", key)
143+
default:
144+
return 0, fmt.Errorf("argument %q is not an int", key)
145+
}
146+
}
147+
return 0, fmt.Errorf("required argument %q not found", key)
148+
}
149+
150+
// GetFloat returns a float64 argument by key, or the default value if not found
151+
func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 {
152+
args := r.GetArguments()
153+
if val, ok := args[key]; ok {
154+
switch v := val.(type) {
155+
case float64:
156+
return v
157+
case int:
158+
return float64(v)
159+
case string:
160+
if f, err := strconv.ParseFloat(v, 64); err == nil {
161+
return f
162+
}
163+
}
164+
}
165+
return defaultValue
166+
}
167+
168+
// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64
169+
func (r CallToolRequest) RequireFloat(key string) (float64, error) {
170+
args := r.GetArguments()
171+
if val, ok := args[key]; ok {
172+
switch v := val.(type) {
173+
case float64:
174+
return v, nil
175+
case int:
176+
return float64(v), nil
177+
case string:
178+
if f, err := strconv.ParseFloat(v, 64); err == nil {
179+
return f, nil
180+
}
181+
return 0, fmt.Errorf("argument %q cannot be converted to float64", key)
182+
default:
183+
return 0, fmt.Errorf("argument %q is not a float64", key)
184+
}
185+
}
186+
return 0, fmt.Errorf("required argument %q not found", key)
187+
}
188+
189+
// GetBool returns a bool argument by key, or the default value if not found
190+
func (r CallToolRequest) GetBool(key string, defaultValue bool) bool {
191+
args := r.GetArguments()
192+
if val, ok := args[key]; ok {
193+
switch v := val.(type) {
194+
case bool:
195+
return v
196+
case string:
197+
if b, err := strconv.ParseBool(v); err == nil {
198+
return b
199+
}
200+
case int:
201+
return v != 0
202+
case float64:
203+
return v != 0
204+
}
205+
}
206+
return defaultValue
207+
}
208+
209+
// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool
210+
func (r CallToolRequest) RequireBool(key string) (bool, error) {
211+
args := r.GetArguments()
212+
if val, ok := args[key]; ok {
213+
switch v := val.(type) {
214+
case bool:
215+
return v, nil
216+
case string:
217+
if b, err := strconv.ParseBool(v); err == nil {
218+
return b, nil
219+
}
220+
return false, fmt.Errorf("argument %q cannot be converted to bool", key)
221+
case int:
222+
return v != 0, nil
223+
case float64:
224+
return v != 0, nil
225+
default:
226+
return false, fmt.Errorf("argument %q is not a bool", key)
227+
}
228+
}
229+
return false, fmt.Errorf("required argument %q not found", key)
230+
}
231+
232+
// GetStringSlice returns a string slice argument by key, or the default value if not found
233+
func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string {
234+
args := r.GetArguments()
235+
if val, ok := args[key]; ok {
236+
switch v := val.(type) {
237+
case []string:
238+
return v
239+
case []any:
240+
result := make([]string, 0, len(v))
241+
for _, item := range v {
242+
if str, ok := item.(string); ok {
243+
result = append(result, str)
244+
}
245+
}
246+
return result
247+
}
248+
}
249+
return defaultValue
250+
}
251+
252+
// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice
253+
func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) {
254+
args := r.GetArguments()
255+
if val, ok := args[key]; ok {
256+
switch v := val.(type) {
257+
case []string:
258+
return v, nil
259+
case []any:
260+
result := make([]string, 0, len(v))
261+
for i, item := range v {
262+
if str, ok := item.(string); ok {
263+
result = append(result, str)
264+
} else {
265+
return nil, fmt.Errorf("item %d in argument %q is not a string", i, key)
266+
}
267+
}
268+
return result, nil
269+
default:
270+
return nil, fmt.Errorf("argument %q is not a string slice", key)
271+
}
272+
}
273+
return nil, fmt.Errorf("required argument %q not found", key)
274+
}
275+
76276
// ToolListChangedNotification is an optional notification from the server to
77277
// the client, informing it that the list of tools it offers has changed. This may
78278
// be issued by servers without any previous subscription from the client.

0 commit comments

Comments
 (0)