Skip to content

Commit 19b1b4f

Browse files
authored
Merge pull request #8 from dsarno/codex/streamline-protocol-framing-in-c#-and-python
2 parents 031f5d7 + 49a3355 commit 19b1b4f

File tree

6 files changed

+99
-102
lines changed

6 files changed

+99
-102
lines changed

UnityMcpBridge/Editor/UnityMcpBridge.cs

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -431,15 +431,7 @@ private static async Task HandleClientAsync(TcpClient client)
431431
if (true)
432432
{
433433
// Enforced framed mode for this connection
434-
byte[] header = await ReadExactAsync(stream, 8, FrameIOTimeoutMs);
435-
ulong payloadLen = ReadUInt64BigEndian(header);
436-
if (payloadLen == 0UL || payloadLen > MaxFrameBytes)
437-
{
438-
throw new System.IO.IOException($"Invalid framed length: {payloadLen}");
439-
}
440-
int payloadLenInt = checked((int)payloadLen);
441-
byte[] payload = await ReadExactAsync(stream, payloadLenInt, FrameIOTimeoutMs);
442-
commandText = System.Text.Encoding.UTF8.GetString(payload);
434+
commandText = await ReadFrameAsUtf8Async(stream, FrameIOTimeoutMs);
443435
}
444436

445437
try
@@ -459,16 +451,7 @@ private static async Task HandleClientAsync(TcpClient client)
459451
/*lang=json,strict*/
460452
"{\"status\":\"success\",\"result\":{\"message\":\"pong\"}}"
461453
);
462-
if ((ulong)pingResponseBytes.Length > MaxFrameBytes)
463-
{
464-
throw new System.IO.IOException($"Frame too large: {pingResponseBytes.Length}");
465-
}
466-
{
467-
byte[] outHeader = new byte[8];
468-
WriteUInt64BigEndian(outHeader, (ulong)pingResponseBytes.Length);
469-
await stream.WriteAsync(outHeader, 0, outHeader.Length);
470-
}
471-
await stream.WriteAsync(pingResponseBytes, 0, pingResponseBytes.Length);
454+
await WriteFrameAsync(stream, pingResponseBytes);
472455
continue;
473456
}
474457

@@ -479,16 +462,7 @@ private static async Task HandleClientAsync(TcpClient client)
479462

480463
string response = await tcs.Task;
481464
byte[] responseBytes = System.Text.Encoding.UTF8.GetBytes(response);
482-
if ((ulong)responseBytes.Length > MaxFrameBytes)
483-
{
484-
throw new System.IO.IOException($"Frame too large: {responseBytes.Length}");
485-
}
486-
{
487-
byte[] outHeader = new byte[8];
488-
WriteUInt64BigEndian(outHeader, (ulong)responseBytes.Length);
489-
await stream.WriteAsync(outHeader, 0, outHeader.Length);
490-
}
491-
await stream.WriteAsync(responseBytes, 0, responseBytes.Length);
465+
await WriteFrameAsync(stream, responseBytes);
492466
}
493467
catch (Exception ex)
494468
{
@@ -499,22 +473,6 @@ private static async Task HandleClientAsync(TcpClient client)
499473
}
500474
}
501475

502-
private static async System.Threading.Tasks.Task<byte[]> ReadExactAsync(NetworkStream stream, int count)
503-
{
504-
byte[] data = new byte[count];
505-
int offset = 0;
506-
while (offset < count)
507-
{
508-
int r = await stream.ReadAsync(data, offset, count - offset);
509-
if (r == 0)
510-
{
511-
throw new System.IO.IOException("Connection closed before reading expected bytes");
512-
}
513-
offset += r;
514-
}
515-
return data;
516-
}
517-
518476
// Timeout-aware exact read helper; avoids indefinite stalls
519477
private static async System.Threading.Tasks.Task<byte[]> ReadExactAsync(NetworkStream stream, int count, int timeoutMs)
520478
{
@@ -538,6 +496,35 @@ private static async System.Threading.Tasks.Task<byte[]> ReadExactAsync(NetworkS
538496
return data;
539497
}
540498

499+
private static async System.Threading.Tasks.Task WriteFrameAsync(NetworkStream stream, byte[] payload)
500+
{
501+
if ((ulong)payload.LongLength > MaxFrameBytes)
502+
{
503+
throw new System.IO.IOException($"Frame too large: {payload.LongLength}");
504+
}
505+
byte[] header = new byte[8];
506+
WriteUInt64BigEndian(header, (ulong)payload.LongLength);
507+
await stream.WriteAsync(header, 0, header.Length);
508+
await stream.WriteAsync(payload, 0, payload.Length);
509+
}
510+
511+
private static async System.Threading.Tasks.Task<string> ReadFrameAsUtf8Async(NetworkStream stream, int timeoutMs)
512+
{
513+
byte[] header = await ReadExactAsync(stream, 8, timeoutMs);
514+
ulong payloadLen = ReadUInt64BigEndian(header);
515+
if (payloadLen == 0UL || payloadLen > MaxFrameBytes)
516+
{
517+
throw new System.IO.IOException($"Invalid framed length: {payloadLen}");
518+
}
519+
if (payloadLen > int.MaxValue)
520+
{
521+
throw new System.IO.IOException("Frame too large for buffer");
522+
}
523+
int count = (int)payloadLen;
524+
byte[] payload = await ReadExactAsync(stream, count, timeoutMs);
525+
return System.Text.Encoding.UTF8.GetString(payload);
526+
}
527+
541528
private static ulong ReadUInt64BigEndian(byte[] buffer)
542529
{
543530
if (buffer == null || buffer.Length < 8) return 0UL;

UnityMcpBridge/UnityMcpServer~/src/server.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from tools import register_all_tools
1010
from unity_connection import get_unity_connection, UnityConnection
1111
from pathlib import Path
12+
import os
13+
import hashlib
1214

1315
# Configure logging: strictly stderr/file only (never stdout)
1416
stderr_handler = logging.StreamHandler(stream=sys.stderr)
@@ -98,52 +100,44 @@ def asset_creation_strategy() -> str:
98100
class _:
99101
pass
100102

101-
import os
102-
import hashlib
103-
104-
def _unity_assets_root() -> str:
105-
# Heuristic: from the Unity project root (one level up from Library/ProjectSettings), 'Assets'
106-
# Here, assume server runs from repo; let clients pass absolute paths under project too.
107-
return None
103+
PROJECT_ROOT = Path(os.environ.get("UNITY_PROJECT_ROOT", Path.cwd())).resolve()
104+
ASSETS_ROOT = (PROJECT_ROOT / "Assets").resolve()
108105

109-
def _safe_path(uri: str) -> str | None:
110-
# URIs: unity://path/Assets/... or file:///absolute
106+
def _resolve_safe_path_from_uri(uri: str) -> Path | None:
107+
raw: str | None = None
111108
if uri.startswith("unity://path/"):
112-
p = uri[len("unity://path/"):]
113-
return p
114-
if uri.startswith("file://"):
115-
return uri[len("file://"):]
116-
# Minimal tolerance for plain Assets/... paths
117-
if uri.startswith("Assets/"):
118-
return uri
119-
return None
109+
raw = uri[len("unity://path/"):]
110+
elif uri.startswith("file://"):
111+
raw = uri[len("file://"):]
112+
elif uri.startswith("Assets/"):
113+
raw = uri
114+
if raw is None:
115+
return None
116+
p = (PROJECT_ROOT / raw).resolve()
117+
try:
118+
p.relative_to(PROJECT_ROOT)
119+
except ValueError:
120+
return None
121+
return p
120122

121123
@mcp.resource.list()
122124
def list_resources(ctx: Context) -> list[dict]:
123-
# Lightweight: expose only C# under Assets by default
124125
assets = []
125126
try:
126-
root = os.getcwd()
127-
for base, _, files in os.walk(os.path.join(root, "Assets")):
128-
for f in files:
129-
if f.endswith(".cs"):
130-
rel = os.path.relpath(os.path.join(base, f), root).replace("\\", "/")
131-
assets.append({
132-
"uri": f"unity://path/{rel}",
133-
"name": os.path.basename(rel)
134-
})
127+
for p in ASSETS_ROOT.rglob("*.cs"):
128+
rel = p.relative_to(PROJECT_ROOT).as_posix()
129+
assets.append({"uri": f"unity://path/{rel}", "name": p.name})
135130
except Exception:
136131
pass
137132
return assets
138133

139134
@mcp.resource.read()
140135
def read_resource(ctx: Context, uri: str) -> dict:
141-
path = _safe_path(uri)
142-
if not path or not os.path.exists(path):
136+
p = _resolve_safe_path_from_uri(uri)
137+
if not p or not p.exists():
143138
return {"mimeType": "text/plain", "text": f"Resource not found: {uri}"}
144139
try:
145-
with open(path, "r", encoding="utf-8") as f:
146-
text = f.read()
140+
text = p.read_text(encoding="utf-8")
147141
sha = hashlib.sha256(text.encode("utf-8")).hexdigest()
148142
return {"mimeType": "text/plain", "text": text, "metadata": {"sha256": sha}}
149143
except Exception as e:

UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from .manage_script_edits import register_manage_script_edits_tools
23
from .manage_script import register_manage_script_tools
34
from .manage_scene import register_manage_scene_tools
@@ -8,9 +9,11 @@
89
from .read_console import register_read_console_tools
910
from .execute_menu_item import register_execute_menu_item_tools
1011

12+
logger = logging.getLogger("unity-mcp-server")
13+
1114
def register_all_tools(mcp):
1215
"""Register all refactored tools with the MCP server."""
13-
# Note: Do not print to stdout; Claude treats stdout as MCP JSON. Use logging.
16+
logger.info("Registering Unity MCP Server refactored tools...")
1417
# Prefer the surgical edits tool so LLMs discover it first
1518
register_manage_script_edits_tools(mcp)
1619
register_manage_script_tools(mcp)
@@ -21,4 +24,4 @@ def register_all_tools(mcp):
2124
register_manage_shader_tools(mcp)
2225
register_read_console_tools(mcp)
2326
register_execute_menu_item_tools(mcp)
24-
# Do not print to stdout here either.
27+
logger.info("Unity MCP Server tool registration complete.")

UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def create_script(
6060
"namespace": namespace,
6161
"scriptType": script_type,
6262
}
63-
if contents is not None:
63+
if contents:
6464
params["encodedContents"] = base64.b64encode(contents.encode("utf-8")).decode("utf-8")
6565
params["contentsEncoded"] = True
6666
params = {k: v for k, v in params.items() if v is not None}
@@ -107,7 +107,7 @@ def manage_script(
107107
- Edits should use apply_text_edits.
108108
109109
Args:
110-
action: Operation ('create', 'read', 'update', 'delete').
110+
action: Operation ('create', 'read', 'delete').
111111
name: Script name (no .cs extension).
112112
path: Asset path (default: "Assets/").
113113
contents: C# code for 'create'/'update'.
@@ -132,8 +132,8 @@ def manage_script(
132132
}
133133

134134
# Base64 encode the contents if they exist to avoid JSON escaping issues
135-
if contents is not None:
136-
if action in ['create', 'update']:
135+
if contents:
136+
if action == 'create':
137137
params["encodedContents"] = base64.b64encode(contents.encode('utf-8')).decode('utf-8')
138138
params["contentsEncoded"] = True
139139
else:
@@ -143,22 +143,22 @@ def manage_script(
143143

144144
response = send_command_with_retry("manage_script", params)
145145

146-
if isinstance(response, dict) and response.get("success"):
147-
if response.get("data", {}).get("contentsEncoded"):
148-
decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8')
149-
response["data"]["contents"] = decoded_contents
150-
del response["data"]["encodedContents"]
151-
del response["data"]["contentsEncoded"]
152-
153-
return {
154-
"success": True,
155-
"message": response.get("message", "Operation successful."),
156-
"data": response.get("data"),
157-
}
158-
return response if isinstance(response, dict) else {
159-
"success": False,
160-
"message": str(response),
161-
}
146+
if isinstance(response, dict):
147+
if response.get("success"):
148+
if response.get("data", {}).get("contentsEncoded"):
149+
decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8')
150+
response["data"]["contents"] = decoded_contents
151+
del response["data"]["encodedContents"]
152+
del response["data"]["contentsEncoded"]
153+
154+
return {
155+
"success": True,
156+
"message": response.get("message", "Operation successful."),
157+
"data": response.get("data"),
158+
}
159+
return response
160+
161+
return {"success": False, "message": str(response)}
162162

163163
except Exception as e:
164164
return {

UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str
5151
end_line = int(edit.get("endLine", start_line))
5252
replacement = edit.get("text", "")
5353
lines = text.splitlines(keepends=True)
54-
if start_line < 1 or end_line < start_line or end_line > len(lines):
54+
max_end = len(lines) + 1
55+
if start_line < 1 or end_line < start_line or end_line > max_end:
5556
raise RuntimeError("replace_range out of bounds")
5657
a = start_line - 1
57-
b = end_line
58+
b = min(end_line, len(lines))
5859
rep = replacement
5960
if rep and not rep.endswith("\n"):
6061
rep += "\n"
@@ -88,7 +89,8 @@ def script_apply_edits(
8889
script_type: str = "MonoBehaviour",
8990
namespace: str = "",
9091
) -> Dict[str, Any]:
91-
# If the edits request structured class/method ops, route directly to Unity's 'edit' action
92+
# If the edits request structured class/method ops, route directly to Unity's 'edit' action.
93+
# These bypass local text validation/encoding since Unity performs the semantic changes.
9294
for e in edits or []:
9395
op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower()
9496
if op in ("replace_class", "delete_class", "replace_method", "delete_method", "insert_method"):

UnityMcpBridge/UnityMcpServer~/src/unity_connection.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,23 @@ def connect(self) -> bool:
4949
self.use_framing = True
5050
logger.debug('Unity MCP handshake received: FRAMING=1 (strict)')
5151
else:
52+
try:
53+
msg = b'Unity MCP requires FRAMING=1'
54+
header = struct.pack('>Q', len(msg))
55+
self.sock.sendall(header + msg)
56+
except Exception:
57+
pass
5258
raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}')
5359
finally:
5460
self.sock.settimeout(config.connection_timeout)
5561
return True
5662
except Exception as e:
5763
logger.error(f"Failed to connect to Unity: {str(e)}")
64+
try:
65+
if self.sock:
66+
self.sock.close()
67+
except Exception:
68+
pass
5869
self.sock = None
5970
return False
6071

@@ -83,7 +94,7 @@ def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes:
8394
try:
8495
header = self._read_exact(sock, 8)
8596
payload_len = struct.unpack('>Q', header)[0]
86-
if payload_len == 0 or payload_len > (64 * 1024 * 1024):
97+
if payload_len > (64 * 1024 * 1024):
8798
raise Exception(f"Invalid framed length: {payload_len}")
8899
payload = self._read_exact(sock, payload_len)
89100
logger.info(f"Received framed response ({len(payload)} bytes)")

0 commit comments

Comments
 (0)