Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 235 additions & 30 deletions crates/goose/src/agents/code_execution_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,74 @@ fn create_tool_function(full_tool_name: String) -> NativeFunction {
)
}

fn is_continuation_line(line: &str) -> bool {
let trimmed = line.trim().trim_end_matches(';');
trimmed.starts_with('.')
|| trimmed
.chars()
.all(|c| matches!(c, '}' | ')' | ']') || c.is_whitespace())
}

fn wrap_code_for_result(code: &str) -> String {
let lines: Vec<&str> = code.trim().lines().collect();
let last_idx = lines
.iter()
.rposition(|l| !l.trim().is_empty() && !l.trim().starts_with("//"))
.unwrap_or(0);
let last = lines.get(last_idx).map(|s| s.trim()).unwrap_or("");

const NO_WRAP: &[&str] = &["import ", "export ", "function ", "class "];
const DECLS: &[&str] = &["const ", "let ", "var "];

if last.contains("__result__") || NO_WRAP.iter().any(|p| last.starts_with(p)) {
return code.to_string();
}

if is_continuation_line(last) {
// Find declaration only if it doesn't end with ; (i.e., continuation is part of initialization)
// If declaration ends with ;, it's a complete statement and continuation is a new expression
if let Some((_idx, decl)) = lines.iter().enumerate().rev().find(|(_, l)| {
let trimmed = l.trim();
DECLS.iter().any(|d| trimmed.starts_with(d)) && !trimmed.ends_with(';')
}) {
for d in DECLS {
if let Some(rest) = decl.trim().strip_prefix(d) {
if let Some(name) = rest.split('=').next().map(str::trim) {
return format!("{}\n__result__ = {name};", lines.join("\n"));
}
}
}
}
if let Some(start_idx) = lines.iter().position(|l| {
let t = l.trim();
!t.is_empty()
&& !t.starts_with("//")
&& !t.starts_with('.')
&& !t.ends_with(';')
&& !NO_WRAP.iter().any(|p| t.starts_with(p))
&& !DECLS.iter().any(|d| t.starts_with(d))
&& !is_continuation_line(t)
}) {
let before = lines[..start_idx].join("\n");
let expr = lines[start_idx..=last_idx].join("\n");
let expr_clean = expr.trim_end_matches(';');
return format!("{before}\n__result__ = {expr_clean};");
}
return code.to_string();
}

let before = lines[..last_idx].join("\n");
for decl in DECLS {
if let Some(rest) = last.strip_prefix(decl) {
if let Some(name) = rest.split('=').next().map(str::trim) {
return format!("{before}\n{last}\n__result__ = {name};");
}
}
}
let last_clean = last.trim_end_matches(';');
format!("{before}\n__result__ = {last_clean};")
}

fn run_js_module(
code: &str,
tools: &[ToolInfo],
Expand Down Expand Up @@ -368,33 +436,7 @@ fn run_js_module(
loader.insert(*server_name, module);
}

let wrapped = {
let lines: Vec<&str> = code.trim().lines().collect();
let last_idx = lines
.iter()
.rposition(|l| !l.trim().is_empty() && !l.trim().starts_with("//"))
.unwrap_or(0);
let last = lines.get(last_idx).map(|s| s.trim()).unwrap_or("");

const NO_WRAP: &[&str] = &["import ", "export ", "function ", "class "];
if last.contains("__result__") || NO_WRAP.iter().any(|p| last.starts_with(p)) {
code.to_string()
} else {
let before = lines[..last_idx].join("\n");
let mut result = None;
for decl in ["const ", "let ", "var "] {
if let Some(rest) = last.strip_prefix(decl) {
if let Some(name) = rest.split('=').next().map(str::trim) {
result = Some(format!("{before}\n{last}\n__result__ = {name};"));
}
break;
}
}
result.unwrap_or_else(|| {
format!("{before}\n__result__ = {};", last.trim_end_matches(';'))
})
}
};
let wrapped = wrap_code_for_result(code);

let user_module = Module::parse(Source::from_bytes(&wrapped), None, &mut ctx)
.map_err(|e| format!("Parse error: {e}"))?;
Expand All @@ -410,7 +452,13 @@ fn run_js_module(
.global_object()
.get(js_string!("__result__"), &mut ctx)
.map_err(|e| format!("Failed to get result: {e}"))?;
Ok(result.display().to_string())
let json = result
.to_json(&mut ctx)
.map_err(|e| format!("Failed to serialize result: {e}"))?;
match json {
Some(v) => Ok(serde_json::to_string_pretty(&v).unwrap_or_else(|_| v.to_string())),
None => Ok(result.display().to_string()),
}
}
PromiseState::Rejected(err) => Err(format!("Module error: {}", err.display())),
PromiseState::Pending => Err("Module evaluation did not complete".to_string()),
Expand Down Expand Up @@ -758,8 +806,8 @@ impl McpClientTrait for CodeExecutionClient {
SYNTAX:
- Import: import { tool1, tool2 } from "serverName";
- Call: toolName({ param1: value, param2: value })
- All calls are synchronous, return strings
- Last expression is the result
- All calls are synchronous; JSON returns are auto-parsed (NEVER use JSON.parse)
- Last expression value is returned
- No comments in code

TOOL_GRAPH: Always provide tool_graph to describe the execution flow for the UI.
Expand Down Expand Up @@ -942,6 +990,163 @@ mod tests {
}
}

#[tokio::test]
async fn test_execute_code_multiline_expression() {
let context = PlatformExtensionContext {
session_id: None,
extension_manager: None,
};
let client = CodeExecutionClient::new(context).unwrap();

let code = indoc! {r#"
const obj = {
a: 1,
b: 2
};
obj.a + obj.b
"#};
let mut args = JsonObject::new();
args.insert("code".to_string(), Value::String(code.to_string()));

let result = client
.call_tool("execute_code", Some(args), CancellationToken::new())
.await
.unwrap();

assert!(!result.is_error.unwrap_or(false));
if let RawContent::Text(text) = &result.content[0].raw {
assert_eq!(text.text, "Result: 3");
} else {
panic!("Expected text content");
}
}

#[tokio::test]
async fn test_execute_code_multiline_continuation_char() {
let context = PlatformExtensionContext {
session_id: None,
extension_manager: None,
};
let client = CodeExecutionClient::new(context).unwrap();

let code = indoc! {r#"
const obj = {
x: 10,
y: 20
};
"#};
let mut args = JsonObject::new();
args.insert("code".to_string(), Value::String(code.to_string()));

let result = client
.call_tool("execute_code", Some(args), CancellationToken::new())
.await
.unwrap();

assert!(
!result.is_error.unwrap_or(false),
"Code ending with '}}' should not cause parse error"
);
}

#[tokio::test]
async fn test_execute_code_multiline_array_with_assignment() {
let context = PlatformExtensionContext {
session_id: None,
extension_manager: None,
};
let client = CodeExecutionClient::new(context).unwrap();

let code = indoc! {r#"
const arr = [
1,
2,
3
];
arr
"#};
let mut args = JsonObject::new();
args.insert("code".to_string(), Value::String(code.to_string()));

let result = client
.call_tool("execute_code", Some(args), CancellationToken::new())
.await
.unwrap();

assert!(!result.is_error.unwrap_or(false));
if let RawContent::Text(text) = &result.content[0].raw {
assert!(
text.text.contains("1") && text.text.contains("2") && text.text.contains("3"),
"Multi-line array should return its value when assigned, got: {}",
text.text
);
} else {
panic!("Expected text content");
}
}

#[test_case(
"[1, 2, 3]\n .map(x => x * 2)\n .filter(x => x > 2)",
"__result__ = [1, 2, 3]",
"__result__ = .";
"method chaining wraps from expression start"
)]
#[test_case(
"const arr = [1, 2, 3];\narr\n .map(x => x * 2)",
"__result__ = arr\n .map(x => x * 2);",
"__result__ = arr;";
"declaration followed by separate method chain captures chain result"
)]
#[test_case(
"foo({\n a: 1,\n b: 2\n});",
"__result__ = foo({",
"__result__ = a:";
"multiline function call wraps from function start not interior"
)]
#[test_case(
"doSomething();\nfoo({\n a: 1\n});",
"__result__ = foo({",
"__result__ = doSomething";
"complete statement before multiline call is skipped"
)]
fn test_wrap_code_edge_cases(code: &str, should_contain: &str, should_not_contain: &str) {
let wrapped = wrap_code_for_result(code);
assert!(
wrapped.contains(should_contain),
"Expected '{should_contain}' in: {wrapped}"
);
assert!(
!wrapped.contains(should_not_contain),
"Should not contain '{should_not_contain}' in: {wrapped}"
);
}

#[test]
fn test_wrap_code_nested_object_function_call() {
let code = indoc! {r#"
import { search_flight } from "kiwitravel";
search_flight({
flyFrom: "BBI",
flyTo: "PNQ",
departureDate: "2023-12-29",
cabinClass: "M",
curr: "INR",
passengers: { adults: 1 }
});
"#};

let wrapped = wrap_code_for_result(code);

assert!(
wrapped.contains("__result__"),
"Should wrap the function call result, got:\n{wrapped}"
);
assert!(
!wrapped.contains("__result__ = });"),
"Should not wrap closing brace, got:\n{wrapped}"
);
}

#[tokio::test]
async fn test_read_module_not_found() {
let context = PlatformExtensionContext {
Expand Down