Skip to content

Commit 124cc14

Browse files
committed
feat: Add RFC 7230 compliant header validation and sanitization
1 parent d211763 commit 124cc14

File tree

3 files changed

+672
-31
lines changed

3 files changed

+672
-31
lines changed
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Internal utilities for MCP tools.
16+
17+
This module contains internal validation and sanitization utilities
18+
that are not part of the public API and follow RFC 7230 properly.
19+
20+
**Security Notes:**
21+
22+
- Header validation implements RFC 7230 §3.2 for proper HTTP header format
23+
- Only truly dangerous control characters are removed from header values
24+
- Legitimate multi-line headers with proper folding are preserved
25+
- Binary data handling is separate from text data for security
26+
- All functions log security-relevant warnings when appropriate
27+
28+
**RFC 7230 Compliance:**
29+
30+
- Header names: restricted to letters, digits, and hyphens (a deliberately strict subset of the RFC 7230 token grammar)
31+
- Header values: control characters (0x00-0x1F, 0x7F) are dangerous
32+
- Header folding: CRLF sequences are preserved for legitimate use cases
33+
- Binary data: handled separately with explicit allow_binary flag
34+
35+
**Attack Prevention:**
36+
37+
- HTTP header injection attacks via control character filtering
38+
- Response splitting attacks through CRLF handling
39+
- Log injection attacks via character sanitization
40+
- Type confusion attacks through strict validation
41+
"""
42+
43+
import logging
44+
import re
45+
from typing import Any
46+
47+
logger = logging.getLogger("google_adk." + __name__)
48+
49+
# RFC 7230 compliant header patterns
# Control characters and special characters not allowed in header names
_HEADER_NAME_FORBIDDEN = r'\x00-\x1F\x7F()<>@,;:\\"/[\]?={} \t'

# Line-break characters (CR, LF) tested by _is_whitespace(). These are the
# line terminators, not the SP / HTAB optional whitespace of RFC 7230 §3.2.3.
_HEADER_WHITESPACE = "\r\n"

# Accepted header names: ASCII letters, digits and hyphens only (a strict
# subset of the RFC 7230 "token" grammar). The pattern is anchored with
# \A ... \Z instead of ^ ... $ because "$" also matches just before a trailing
# newline, which would let a name such as "X-Foo\n" through.
_HEADER_NAME_PATTERN = re.compile(r"\A[a-zA-Z0-9-]+\Z")

# Characters that must never appear in a header value: every C0 control
# character (0x00-0x1F) except horizontal tab, plus DEL (0x7F).
# CR and LF are part of the set: a raw line break inside a value is what makes
# header injection / response splitting possible, and RFC 7230 §3.2.4 forbids
# senders from generating folded (multi-line) field values.
_DANGEROUS_CHARS = {
    chr(code)
    for code in (*range(0x00, 0x20), 0x7F)
    if code != 0x09  # HTAB is legal inside a field value
}
93+
94+
95+
def _is_printable_ascii(char: str) -> bool:
    """Return True if ``char`` is a single printable ASCII character.

    Printable ASCII is the inclusive range 0x20 (space) to 0x7E (tilde).

    Args:
      char: The character to test.

    Returns:
      True for a one-character string in the printable range; False for
      anything else, including the empty string, longer strings and values
      that ``ord()`` cannot handle.
    """
    try:
        code_point = ord(char)
    except TypeError:
        # ord() signals "not exactly one character" (or a non-string argument)
        # with TypeError, never ValueError.
        return False
    return 0x20 <= code_point <= 0x7E
101+
102+
103+
def _is_control_char(char: str) -> bool:
    """Report whether ``char`` is one of the module's dangerous characters.

    Args:
      char: The character to look up.

    Returns:
      True when ``char`` is a member of ``_DANGEROUS_CHARS``, else False.
    """
    is_dangerous = char in _DANGEROUS_CHARS
    return is_dangerous
106+
107+
108+
def _is_whitespace(char: str) -> bool:
    """Report whether ``char`` occurs in ``_HEADER_WHITESPACE``.

    The lookup is a plain string containment test against the constant.

    Args:
      char: The character to test.

    Returns:
      True when ``char`` is contained in ``_HEADER_WHITESPACE``, else False.
    """
    found = char in _HEADER_WHITESPACE
    return found
111+
112+
113+
def _get_forbidden_char_desc(char: str) -> str:
    """Return a short human-readable description of a single character.

    Used to enrich validation error messages.

    Args:
      char: A one-character string.

    Returns:
      A fixed name for CR, LF and HTAB; ``"control character: <repr>"`` for any
      other C0 control character or DEL; ``"non-printable ASCII: <repr>"`` for
      characters above the printable ASCII range; and
      ``"printable character: <repr>"`` for printable ASCII.
    """
    well_known = {
        "\r": "carriage return",
        "\n": "line feed",
        "\t": "horizontal tab",
    }
    if char in well_known:
        return well_known[char]

    # The range checks are done inline so the labels cannot drift from the
    # classification: control first, then "outside printable ASCII".
    code_point = ord(char)
    if code_point < 0x20 or code_point == 0x7F:
        return f"control character: {repr(char)}"
    if code_point > 0x7E:
        return f"non-printable ASCII: {repr(char)}"
    return f"printable character: {repr(char)}"
125+
126+
127+
def _validate_header_name(header_name: str) -> None:
    """Validate an HTTP header name, raising if it is unacceptable.

    A name is accepted only when every character of it is matched by
    ``_HEADER_NAME_PATTERN``.

    Args:
      header_name: The header name to check.

    Raises:
      ValueError: If the name is empty or contains a character the pattern
        does not allow.
    """
    if not header_name:
        raise ValueError("Header name cannot be empty.")

    # fullmatch() requires the pattern to consume the whole string. With
    # match() a "$"-anchored pattern also succeeds just before a trailing
    # newline, so a name like "X-Foo\n" would be accepted.
    if not _HEADER_NAME_PATTERN.fullmatch(header_name):
        raise ValueError(
            f'Header name "{header_name}" contains invalid characters. '
            "Header names must conform to RFC 7230 and cannot contain "
            'control characters, spaces, or separators like ():<>@,;:\\"/[]?={}.'
        )
141+
142+
143+
def _sanitize_header_value(value: str) -> str:
    """Strip characters from a header value that could enable injection.

    Every character in ``_DANGEROUS_CHARS`` is removed, and so are carriage
    return and line feed. A raw line break inside a value is what lets an
    attacker start a new header line, and RFC 7230 §3.2.4 forbids senders from
    generating folded field values. Dropping CR/LF from a folded value leaves
    the whitespace that followed the line break, so the value collapses onto
    one line.

    Args:
      value: The header value to sanitize. Non-string input is converted with
        ``str()`` first.

    Returns:
      The value with all dangerous characters removed.
    """
    if not isinstance(value, str):
        value = str(value)

    kept = []
    reported = set()
    for char in value:
        # CR/LF are tested explicitly so this guarantee does not depend on the
        # exact contents of _DANGEROUS_CHARS.
        if char in _DANGEROUS_CHARS or char in ("\r", "\n"):
            # Warn once per distinct character so a long hostile value cannot
            # flood the log with one record per character.
            if char not in reported:
                reported.add(char)
                logger.warning(
                    "Removed dangerous character %r from header value "
                    "for security reasons",
                    char,
                )
            continue
        kept.append(char)

    return "".join(kept)
173+
174+
175+
def _validate_header_value(value: Any, allow_binary: bool = False) -> None:
    """Reject header values that contain characters usable for injection.

    ``None`` is accepted as-is. ``bytes`` are only inspected when
    ``allow_binary`` is True, and then only bytes below 128 are checked.
    Strings are checked directly; any other type is checked through its
    ``str()`` representation.

    A character is rejected when it is in ``_DANGEROUS_CHARS`` or is a carriage
    return / line feed: raw line breaks allow header injection, and
    RFC 7230 §3.2.4 forbids senders from generating folded values.

    Args:
      value: The header value to validate.
      allow_binary: Whether to allow binary data (bytes) in header values.

    Raises:
      ValueError: If ``value`` is bytes and ``allow_binary`` is False, or if
        the value contains a dangerous character.
    """
    if value is None:
        return

    # Pick the characters to scan and the matching error prefix per input kind,
    # then run a single scan below.
    if isinstance(value, bytes):
        if not allow_binary:
            raise ValueError("Binary data not allowed in HTTP header values")
        # Only the ASCII range is inspected; bytes >= 128 are passed through.
        candidates = (chr(byte_val) for byte_val in value if byte_val < 128)
        problem = "Binary data contains dangerous byte"
    elif isinstance(value, str):
        candidates = value
        problem = "Header value contains dangerous character"
    else:
        candidates = str(value)
        problem = (
            "Header value (converted to string) contains dangerous character"
        )

    for char in candidates:
        # CR/LF are tested explicitly so the check does not depend on the exact
        # contents of _DANGEROUS_CHARS.
        if char in _DANGEROUS_CHARS or char in ("\r", "\n"):
            raise ValueError(
                f"{problem}: {repr(char)} ({_get_forbidden_char_desc(char)})"
            )
220+
221+
222+
def sanitize_header_value(value: Any) -> str:
    """Return a sanitized string form of ``value`` for use as a header value.

    Thin public wrapper: the input is coerced to ``str`` when it is not one
    already, and the result is handed to ``_sanitize_header_value``.

    Args:
      value: The header value to sanitize (any type).

    Returns:
      The sanitized header value as a string.
    """
    text = value if isinstance(value, str) else str(value)
    return _sanitize_header_value(text)
238+
239+
240+
def validate_header_value(
    state_key: str, value: Any, strict: bool = False
) -> None:
    """Check that a state value can safely be placed in a header.

    Two checks run in order. First, a value that is not a ``str``, ``int``,
    ``float`` or ``bool`` is reported: as a ``ValueError`` when ``strict`` is
    True, otherwise as a logged warning. Second, the value is always passed to
    ``_validate_header_value``, whatever ``strict`` is.

    Args:
      state_key: The key being validated (used in the diagnostic message).
      value: The value to validate.
      strict: If True, raises ValueError for non-primitive types.

    Raises:
      ValueError: If ``strict`` is True and ``value`` is not a primitive type,
        or if ``_validate_header_value`` rejects the value.
    """
    is_primitive = isinstance(value, (str, int, float, bool))
    if not is_primitive:
        msg = (
            f'Value for state key "{state_key}" is of type '
            f"{type(value).__name__}, which may not serialize correctly into a "
            "header. Consider pre-serializing complex values or using "
            "state_header_format."
        )
        if strict:
            raise ValueError(msg)
        logger.warning(msg)

    # The character-level check runs in both strict and non-strict mode.
    _validate_header_value(value)
267+
268+
269+
def create_session_state_header_provider(
    state_key: str,
    header_name: str = "Authorization",
    header_format: str = "Bearer {value}",
    default_value: str = None,
    strict: bool = False,
):
    """Build a header provider that reads one value from session state.

    The returned callable can be used as a ``header_provider`` with McpToolset:
    on each call it looks up ``state_key`` in ``ctx.state``, validates the
    value, renders it through ``header_format`` and sanitizes the result.

    .. warning::
      **Security Best Practice**: For sensitive, short-lived tokens like JWTs,
      use ``request_state`` instead of ``session.state`` to avoid persisting
      sensitive data to the database. Pass tokens via
      ``RunAgentRequest.request_state``, which will override ``session.state``
      for the duration of the request without being persisted.

    Args:
      state_key: The key to look up in session.state (or request_state).
      header_name: The HTTP header name to set (default: 'Authorization').
        It is validated once, when this factory is called.
      header_format: Format string for the header value. Use {value} as a
        placeholder for the state value (default: 'Bearer {value}').
      default_value: Value used when state_key is not found in session state.
        If None, the header is omitted when the key is missing.
      strict: If True, raises ValueError when non-primitive types are
        encountered. If False (default), logs a warning instead.

    Returns:
      A callable that takes a ReadonlyContext and returns a dictionary of
      headers to be used for the MCP session. The dictionary is empty when the
      resolved value is None or an empty string.

    Raises:
      ValueError: If ``header_name`` fails validation.
    """
    # Fail fast on a bad header name instead of on the first request.
    _validate_header_name(header_name)

    def provider(ctx) -> dict[str, str]:
        raw = ctx.state.get(state_key, default_value)
        # Nothing to send: omit the header entirely.
        if raw is None or raw == "":
            return {}

        validate_header_value(state_key, raw, strict=strict)
        rendered = header_format.format(value=raw)
        return {header_name: sanitize_header_value(rendered)}

    return provider

0 commit comments

Comments
 (0)