diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/CHANGELOG.md b/instrumentation-genai/opentelemetry-instrumentation-google-genai/CHANGELOG.md index 01ef457d9a..3e331c68b1 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/CHANGELOG.md +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Add more request configuration options to the span attributes ([#3374](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3374)) - Restructure tests to keep in line with repository conventions ([#3344](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3344)) - Fix [bug](https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3416) where diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/TODOS.md b/instrumentation-genai/opentelemetry-instrumentation-google-genai/TODOS.md index 16a8299e2a..15e4226479 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/TODOS.md +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/TODOS.md @@ -4,7 +4,6 @@ Here are some TODO items required to achieve stability for this package: - - Add more span-level attributes for request configuration - Add more span-level attributes for response information - Verify and correct formatting of events: - Including the 'role' field for message events diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/pyproject.toml b/instrumentation-genai/opentelemetry-instrumentation-google-genai/pyproject.toml index 9bb9e0e279..de36cab532 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/pyproject.toml +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/pyproject.toml @@ -37,9 +37,9 @@ classifiers = [ "Programming Language :: Python :: 3.12" ] dependencies = [ - "opentelemetry-api >=1.30.0, <2", - "opentelemetry-instrumentation >=0.51b0, <2", - "opentelemetry-semantic-conventions >=0.51b0, <2" + "opentelemetry-api >=1.31.1, <2", + "opentelemetry-instrumentation >=0.52b1, <2", + "opentelemetry-semantic-conventions >=0.52b1, <2" ] [project.optional-dependencies] diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/allowlist_util.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/allowlist_util.py new file mode 100644 index 0000000000..562c91ea34 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/allowlist_util.py @@ -0,0 +1,97 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from typing import Iterable, Optional, Set + +ALLOWED = True +DENIED = False + + +def _parse_env_list(s: str) -> Set[str]: + result = set() + for entry in s.split(","): + stripped_entry = entry.strip() + if not stripped_entry: + continue + result.add(stripped_entry) + return result + + +class _CompoundMatcher: + def __init__(self, entries: Set[str]): + self._match_all = "*" in entries + self._entries = entries + self._regex_matcher = None + regex_entries = [] + for entry in entries: + if "*" not in entry: + continue + if entry == "*": + continue + entry = entry.replace("[", "\\[") + entry = entry.replace("]", "\\]") + entry = entry.replace(".", "\\.") + entry = entry.replace("*", ".*") + regex_entries.append(f"({entry})") + if regex_entries: + joined_regex = "|".join(regex_entries) + regex_str = f"^({joined_regex})$" + self._regex_matcher = re.compile(regex_str) + + @property + def match_all(self): + return self._match_all + + def matches(self, x): + if self._match_all: + return True + if x in self._entries: + return True + if (self._regex_matcher is not None) and ( + self._regex_matcher.fullmatch(x) + ): + return True + return False + + +class AllowList: + def __init__( + self, + includes: Optional[Iterable[str]] = None, + excludes: Optional[Iterable[str]] = None, + ): + self._includes = _CompoundMatcher(set(includes or [])) + self._excludes = _CompoundMatcher(set(excludes or [])) + assert (not self._includes.match_all) or ( + not self._excludes.match_all + ), "Can't have '*' in both includes and excludes." + + def allowed(self, x: str): + if self._excludes.match_all: + return self._includes.matches(x) + if self._includes.match_all: + return not self._excludes.matches(x) + return self._includes.matches(x) and not self._excludes.matches(x) + + @staticmethod + def from_env( + includes_env_var: str, excludes_env_var: Optional[str] = None + ): + includes = _parse_env_list(os.getenv(includes_env_var) or "") + excludes = set() + if excludes_env_var: + excludes = _parse_env_list(os.getenv(excludes_env_var) or "") + return AllowList(includes=includes, excludes=excludes) diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/custom_semconv.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/custom_semconv.py new file mode 100644 index 0000000000..fcdf6b1c39 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/custom_semconv.py @@ -0,0 +1,18 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Semantic Convention still being defined in: +# https://github.com/open-telemetry/semantic-conventions/pull/2125 +GCP_GENAI_OPERATION_CONFIG = "gcp.gen_ai.operation.config" diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/dict_util.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/dict_util.py new file mode 100644 index 0000000000..6f39474edf --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/dict_util.py @@ -0,0 +1,301 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from typing import ( + Any, + Dict, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Union, +) + +Primitive = Union[bool, str, int, float] +BoolList = list[bool] +StringList = list[str] +IntList = list[int] +FloatList = list[float] +HomogenousPrimitiveList = Union[BoolList, StringList, IntList, FloatList] +FlattenedValue = Union[Primitive, HomogenousPrimitiveList] +FlattenedDict = Dict[str, FlattenedValue] + + +class FlattenFunc(Protocol): + def __call__( + self, + key: str, + value: Any, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, "FlattenFunc"], + **kwargs: Any, + ) -> Any: + return None + + +_logger = logging.getLogger(__name__) + + +def _concat_key(prefix: Optional[str], suffix: str): + if not prefix: + return suffix + return f"{prefix}.{suffix}" + + +def _is_primitive(v): + for t in [str, bool, int, float]: + if isinstance(v, t): + return True + return False + + +def _is_homogenous_primitive_list(v): + if not isinstance(v, list): + return False + if len(v) == 0: + return True + if not _is_primitive(v[0]): + return False + first_entry_value_type = type(v[0]) + for entry in v[1:]: + if not isinstance(entry, first_entry_value_type): + return False + return True + + +def _get_flatten_func( + flatten_functions: Dict[str, FlattenFunc], key_names: set[str] +) -> Optional[FlattenFunc]: + for key in key_names: + flatten_func = flatten_functions.get(key) + if flatten_func is not None: + return flatten_func + return None + + +def _flatten_with_flatten_func( + key: str, + value: Any, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, FlattenFunc], + key_names: Set[str], +) -> Tuple[bool, Any]: + flatten_func = _get_flatten_func(flatten_functions, key_names) + if flatten_func is None: + return False, value + func_output = flatten_func( + key, + value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) + if func_output is None: + return True, {} + if _is_primitive(func_output) or _is_homogenous_primitive_list( + func_output + ): + return True, {key: func_output} + return False, func_output + + +def _flatten_compound_value_using_json( + key: str, + value: Any, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, FlattenFunc], + _from_json=False, +) -> FlattenedDict: + if _from_json: + _logger.debug( + "Cannot flatten value with key %s; value: %s", key, value + ) + return {} + try: + json_string = json.dumps(value) + except TypeError: + _logger.debug( + "Cannot flatten value with key %s; value: %s. Not JSON serializable.", + key, + value, + ) + return {} + json_value = json.loads(json_string) + return _flatten_value( + key, + json_value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + # Ensure that we don't recurse indefinitely if "json.loads()" somehow returns + # a complex, compound object that does not get handled by the "primitive", "list", + # or "dict" cases. Prevents falling back on the JSON serialization fallback path. + _from_json=True, + ) + + +def _flatten_compound_value( + key: str, + value: Any, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, FlattenFunc], + key_names: Set[str], + _from_json=False, +) -> FlattenedDict: + fully_flattened_with_flatten_func, value = _flatten_with_flatten_func( + key=key, + value=value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + key_names=key_names, + ) + if fully_flattened_with_flatten_func: + return value + if isinstance(value, dict): + return _flatten_dict( + value, + key_prefix=key, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) + if isinstance(value, list): + if _is_homogenous_primitive_list(value): + return {key: value} + return _flatten_list( + value, + key_prefix=key, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) + if hasattr(value, "model_dump"): + return _flatten_dict( + value.model_dump(), + key_prefix=key, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) + return _flatten_compound_value_using_json( + key, + value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + _from_json=_from_json, + ) + + +def _flatten_value( + key: str, + value: Any, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, FlattenFunc], + _from_json=False, +) -> FlattenedDict: + if value is None: + return {} + key_names = set([key]) + renamed_key = rename_keys.get(key) + if renamed_key is not None: + key_names.add(renamed_key) + key = renamed_key + if key_names & exclude_keys: + return {} + if _is_primitive(value): + return {key: value} + return _flatten_compound_value( + key=key, + value=value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + key_names=key_names, + _from_json=_from_json, + ) + + +def _flatten_dict( + d: Dict[str, Any], + key_prefix: str, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, FlattenFunc], +) -> FlattenedDict: + result = {} + for key, value in d.items(): + if key in exclude_keys: + continue + full_key = _concat_key(key_prefix, key) + flattened = _flatten_value( + full_key, + value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) + result.update(flattened) + return result + + +def _flatten_list( + lst: list[Any], + key_prefix: str, + exclude_keys: Set[str], + rename_keys: Dict[str, str], + flatten_functions: Dict[str, FlattenFunc], +) -> FlattenedDict: + result = {} + result[_concat_key(key_prefix, "length")] = len(lst) + for index, value in enumerate(lst): + full_key = f"{key_prefix}[{index}]" + flattened = _flatten_value( + full_key, + value, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) + result.update(flattened) + return result + + +def flatten_dict( + d: Dict[str, Any], + key_prefix: Optional[str] = None, + exclude_keys: Optional[Sequence[str]] = None, + rename_keys: Optional[Dict[str, str]] = None, + flatten_functions: Optional[Dict[str, FlattenFunc]] = None, +): + key_prefix = key_prefix or "" + exclude_keys = set(exclude_keys or []) + rename_keys = rename_keys or {} + flatten_functions = flatten_functions or {} + return _flatten_dict( + d, + key_prefix=key_prefix, + exclude_keys=exclude_keys, + rename_keys=rename_keys, + flatten_functions=flatten_functions, + ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py index dcdcf6d9f7..a029c992df 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py @@ -39,6 +39,9 @@ ) from opentelemetry.semconv.attributes import error_attributes +from .allowlist_util import AllowList +from .custom_semconv import GCP_GENAI_OPERATION_CONFIG +from .dict_util import flatten_dict from .flags import is_content_recording_enabled from .otel_wrapper import OTelWrapper @@ -129,21 +132,65 @@ def _determine_genai_system(models_object: Union[Models, AsyncModels]): return _get_gemini_system_name() -def _get_config_property( - config: Optional[GenerateContentConfigOrDict], path: str -) -> Any: +def _to_dict(value: object): + if isinstance(value, dict): + return value + if hasattr(value, "model_dump"): + return value.model_dump() + return json.loads(json.dumps(value)) + + +def _add_request_options_to_span( + span, config: Optional[GenerateContentConfigOrDict], allow_list: AllowList +): if config is None: - return None - path_segments = path.split(".") - current_context: Any = config - for path_segment in path_segments: - if current_context is None: - return None - if isinstance(current_context, dict): - current_context = current_context.get(path_segment) - else: - current_context = getattr(current_context, path_segment) - return current_context + return + span_context = span.get_span_context() + if not span_context.trace_flags.sampled: + # Avoid potentially costly traversal of config + # options if the span will be dropped, anyway. + return + # Automatically derive attributes from the contents of the + # config object. This ensures that all relevant parameters + # are captured in the telemetry data (except for those + # that are excluded via "exclude_keys"). Dynamic attributes (those + # starting with "gcp.gen_ai." instead of simply "gen_ai.request.") + # are filtered with the "allow_list" before inclusion in the span. + attributes = flatten_dict( + _to_dict(config), + # A custom prefix is used, because the names/structure of the + # configuration is likely to be specific to Google Gen AI SDK. + key_prefix=GCP_GENAI_OPERATION_CONFIG, + exclude_keys=[ + # System instruction can be overly long for a span attribute. + # Additionally, it is recorded as an event (log), instead. + "gcp.gen_ai.operation.config.system_instruction", + ], + # Although a custom prefix is used by default, some of the attributes + # are captured in common, standard, Semantic Conventions. For the + # well-known properties whose values align with Semantic Conventions, + # we ensure that the key name matches the standard SemConv name. + rename_keys={ + # TODO: add more entries here as more semantic conventions are + # generalized to cover more of the available config options. + "gcp.gen_ai.operation.config.temperature": gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE, + "gcp.gen_ai.operation.config.top_k": gen_ai_attributes.GEN_AI_REQUEST_TOP_K, + "gcp.gen_ai.operation.config.top_p": gen_ai_attributes.GEN_AI_REQUEST_TOP_P, + "gcp.gen_ai.operation.config.candidate_count": gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT, + "gcp.gen_ai.operation.config.max_output_tokens": gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS, + "gcp.gen_ai.operation.config.stop_sequences": gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES, + "gcp.gen_ai.operation.config.frequency_penalty": gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY, + "gcp.gen_ai.operation.config.presence_penalty": gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY, + "gcp.gen_ai.operation.config.seed": gen_ai_attributes.GEN_AI_REQUEST_SEED, + }, + ) + for key, value in attributes.items(): + if key.startswith( + GCP_GENAI_OPERATION_CONFIG + ) and not allow_list.allowed(key): + # The allowlist is used to control inclusion of the dynamic keys. + continue + span.set_attribute(key, value) def _get_response_property(response: GenerateContentResponse, path: str): @@ -159,50 +206,13 @@ def _get_response_property(response: GenerateContentResponse, path: str): return current_context -def _get_temperature(config: Optional[GenerateContentConfigOrDict]): - return _get_config_property(config, "temperature") - - -def _get_top_k(config: Optional[GenerateContentConfigOrDict]): - return _get_config_property(config, "top_k") - - -def _get_top_p(config: Optional[GenerateContentConfigOrDict]): - return _get_config_property(config, "top_p") - - -# A map from define attributes to the function that can obtain -# the relevant information from the request object. -# -# TODO: expand this to cover a larger set of the available -# span attributes from GenAI semantic conventions. -# -# TODO: define semantic conventions for attributes that -# are relevant for the Google GenAI SDK which are not -# currently covered by the existing semantic conventions. -# -# See also: TODOS.md -_SPAN_ATTRIBUTE_TO_CONFIG_EXTRACTOR = { - gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE: _get_temperature, - gen_ai_attributes.GEN_AI_REQUEST_TOP_K: _get_top_k, - gen_ai_attributes.GEN_AI_REQUEST_TOP_P: _get_top_p, -} - - -def _to_dict(value: object): - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - return value.model_dump() - return json.loads(json.dumps(value)) - - class _GenerateContentInstrumentationHelper: def __init__( self, models_object: Union[Models, AsyncModels], otel_wrapper: OTelWrapper, model: str, + generate_content_config_key_allowlist: Optional[AllowList] = None, ): self._start_time = time.time_ns() self._otel_wrapper = otel_wrapper @@ -215,6 +225,9 @@ def __init__( self._content_recording_enabled = is_content_recording_enabled() self._response_index = 0 self._candidate_index = 0 + self._generate_content_config_key_allowlist = ( + generate_content_config_key_allowlist or AllowList() + ) def start_span_as_current_span( self, model_name, function_name, end_on_exit=True @@ -237,13 +250,9 @@ def process_request( config: Optional[GenerateContentConfigOrDict], ): span = trace.get_current_span() - for ( - attribute_key, - extractor, - ) in _SPAN_ATTRIBUTE_TO_CONFIG_EXTRACTOR.items(): - attribute_value = extractor(config) - if attribute_value is not None: - span.set_attribute(attribute_key, attribute_value) + _add_request_options_to_span( + span, config, self._generate_content_config_key_allowlist + ) self._maybe_log_system_instruction(config=config) self._maybe_log_user_prompt(contents) @@ -330,7 +339,12 @@ def _maybe_update_error_type(self, response: GenerateContentResponse): def _maybe_log_system_instruction( self, config: Optional[GenerateContentConfigOrDict] = None ): - system_instruction = _get_config_property(config, "system_instruction") + system_instruction = None + if config is not None: + if isinstance(config, dict): + system_instruction = config.get("system_instruction") + else: + system_instruction = config.system_instruction if not system_instruction: return attributes = { @@ -512,7 +526,9 @@ def _record_duration_metric(self): def _create_instrumented_generate_content( - snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper + snapshot: _MethodsSnapshot, + otel_wrapper: OTelWrapper, + generate_content_config_key_allowlist: Optional[AllowList] = None, ): wrapped_func = snapshot.generate_content @@ -526,7 +542,10 @@ def instrumented_generate_content( **kwargs: Any, ) -> GenerateContentResponse: helper = _GenerateContentInstrumentationHelper( - self, otel_wrapper, model + self, + otel_wrapper, + model, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) with helper.start_span_as_current_span( model, "google.genai.Models.generate_content" @@ -552,7 +571,9 @@ def instrumented_generate_content( def _create_instrumented_generate_content_stream( - snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper + snapshot: _MethodsSnapshot, + otel_wrapper: OTelWrapper, + generate_content_config_key_allowlist: Optional[AllowList] = None, ): wrapped_func = snapshot.generate_content_stream @@ -566,7 +587,10 @@ def instrumented_generate_content_stream( **kwargs: Any, ) -> Iterator[GenerateContentResponse]: helper = _GenerateContentInstrumentationHelper( - self, otel_wrapper, model + self, + otel_wrapper, + model, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) with helper.start_span_as_current_span( model, "google.genai.Models.generate_content_stream" @@ -592,7 +616,9 @@ def instrumented_generate_content_stream( def _create_instrumented_async_generate_content( - snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper + snapshot: _MethodsSnapshot, + otel_wrapper: OTelWrapper, + generate_content_config_key_allowlist: Optional[AllowList] = None, ): wrapped_func = snapshot.async_generate_content @@ -606,7 +632,10 @@ async def instrumented_generate_content( **kwargs: Any, ) -> GenerateContentResponse: helper = _GenerateContentInstrumentationHelper( - self, otel_wrapper, model + self, + otel_wrapper, + model, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) with helper.start_span_as_current_span( model, "google.genai.AsyncModels.generate_content" @@ -633,7 +662,9 @@ async def instrumented_generate_content( # Disabling type checking because this is not yet implemented and tested fully. def _create_instrumented_async_generate_content_stream( # type: ignore - snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper + snapshot: _MethodsSnapshot, + otel_wrapper: OTelWrapper, + generate_content_config_key_allowlist: Optional[AllowList] = None, ): wrapped_func = snapshot.async_generate_content_stream @@ -647,7 +678,10 @@ async def instrumented_generate_content_stream( **kwargs: Any, ) -> Awaitable[AsyncIterator[GenerateContentResponse]]: # type: ignore helper = _GenerateContentInstrumentationHelper( - self, otel_wrapper, model + self, + otel_wrapper, + model, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) with helper.start_span_as_current_span( model, @@ -691,20 +725,29 @@ def uninstrument_generate_content(snapshot: object): snapshot.restore() -def instrument_generate_content(otel_wrapper: OTelWrapper) -> object: +def instrument_generate_content( + otel_wrapper: OTelWrapper, + generate_content_config_key_allowlist: Optional[AllowList] = None, +) -> object: snapshot = _MethodsSnapshot() Models.generate_content = _create_instrumented_generate_content( - snapshot, otel_wrapper + snapshot, + otel_wrapper, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) - Models.generate_content_stream = ( - _create_instrumented_generate_content_stream(snapshot, otel_wrapper) + Models.generate_content_stream = _create_instrumented_generate_content_stream( + snapshot, + otel_wrapper, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) AsyncModels.generate_content = _create_instrumented_async_generate_content( - snapshot, otel_wrapper + snapshot, + otel_wrapper, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) - AsyncModels.generate_content_stream = ( - _create_instrumented_async_generate_content_stream( - snapshot, otel_wrapper - ) + AsyncModels.generate_content_stream = _create_instrumented_async_generate_content_stream( + snapshot, + otel_wrapper, + generate_content_config_key_allowlist=generate_content_config_key_allowlist, ) return snapshot diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/instrumentor.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/instrumentor.py index ef57f5891c..8a3f792651 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/instrumentor.py +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/instrumentor.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection +from typing import Any, Collection, Optional from opentelemetry._events import get_event_logger_provider from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.metrics import get_meter_provider from opentelemetry.trace import get_tracer_provider +from .allowlist_util import AllowList from .generate_content import ( instrument_generate_content, uninstrument_generate_content, @@ -27,8 +28,17 @@ class GoogleGenAiSdkInstrumentor(BaseInstrumentor): - def __init__(self): + def __init__( + self, generate_content_config_key_allowlist: Optional[AllowList] = None + ): self._generate_content_snapshot = None + self._generate_content_config_key_allowlist = ( + generate_content_config_key_allowlist + or AllowList.from_env( + "OTEL_GOOGLE_GENAI_GENERATE_CONTENT_CONFIG_INCLUDES", + excludes_env_var="OTEL_GOOGLE_GENAI_GENERATE_CONTENT_CONFIG_EXCLUDES", + ) + ) # Inherited, abstract function from 'BaseInstrumentor'. Even though 'self' is # not used in the definition, a method is required per the API contract. @@ -49,7 +59,8 @@ def _instrument(self, **kwargs: Any): meter_provider=meter_provider, ) self._generate_content_snapshot = instrument_generate_content( - otel_wrapper + otel_wrapper, + generate_content_config_key_allowlist=self._generate_content_config_key_allowlist, ) def _uninstrument(self, **kwargs: Any): diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/base.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/base.py index 1624b47868..2bb686e057 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/base.py +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/base.py @@ -33,11 +33,17 @@ def setUp(self): self._client = None self._uses_vertex = False self._credentials = FakeCredentials() + self._instrumentor_args = {} def _lazy_init(self): - self._instrumentation_context = InstrumentationContext() + self._instrumentation_context = InstrumentationContext( + **self._instrumentor_args + ) self._instrumentation_context.install() + def set_instrumentor_constructor_kwarg(self, key, value): + self._instrumentor_args[key] = value + @property def client(self): if self._client is None: diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/instrumentation_context.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/instrumentation_context.py index 6bd6ddd7aa..83ebc6fe91 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/instrumentation_context.py +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/common/instrumentation_context.py @@ -18,8 +18,8 @@ class InstrumentationContext: - def __init__(self): - self._instrumentor = GoogleGenAiSdkInstrumentor() + def __init__(self, **kwargs): + self._instrumentor = GoogleGenAiSdkInstrumentor(**kwargs) def install(self): self._instrumentor.instrument() diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/nonstreaming_base.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/nonstreaming_base.py index 9bd5df8157..39f1dfe927 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/nonstreaming_base.py +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/nonstreaming_base.py @@ -35,6 +35,15 @@ def generate_content(self, *args, **kwargs): def expected_function_name(self): raise NotImplementedError("Must implement 'expected_function_name'.") + def _generate_and_get_span(self, config): + self.generate_content( + model="gemini-2.0-flash", + contents="Some input prompt", + config=config, + ) + self.otel.assert_has_span_named("generate_content gemini-2.0-flash") + return self.otel.get_span_named("generate_content gemini-2.0-flash") + def test_instrumentation_does_not_break_core_functionality(self): self.configure_valid_response(text="Yep, it works!") response = self.generate_content( diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/test_config_span_attributes.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/test_config_span_attributes.py new file mode 100644 index 0000000000..acb6c41d0f --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/generate_content/test_config_span_attributes.py @@ -0,0 +1,162 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +from google.genai.types import GenerateContentConfig + +from opentelemetry.instrumentation.google_genai.allowlist_util import AllowList + +from .base import TestCase + + +class ConfigSpanAttributesTestCase(TestCase): + def setUp(self): + super().setUp() + self.configure_valid_response(text="Some response") + + def generate_content(self, *args, **kwargs): + return self.client.models.generate_content(*args, **kwargs) + + def generate_and_get_span(self, config): + self.client.models.generate_content( + model="gemini-2.0-flash", + contents="Some input prompt", + config=config, + ) + self.otel.assert_has_span_named("generate_content gemini-2.0-flash") + return self.otel.get_span_named("generate_content gemini-2.0-flash") + + def test_option_reflected_to_span_attribute_choice_count_config_dict(self): + span = self.generate_and_get_span(config={"candidate_count": 2}) + self.assertEqual(span.attributes["gen_ai.request.choice.count"], 2) + + def test_option_reflected_to_span_attribute_choice_count_config_obj(self): + span = self.generate_and_get_span( + config=GenerateContentConfig(candidate_count=2) + ) + self.assertEqual(span.attributes["gen_ai.request.choice.count"], 2) + + def test_option_reflected_to_span_attribute_seed_config_dict(self): + span = self.generate_and_get_span(config={"seed": 12345}) + self.assertEqual(span.attributes["gen_ai.request.seed"], 12345) + + def test_option_reflected_to_span_attribute_seed_config_obj(self): + span = self.generate_and_get_span( + config=GenerateContentConfig(seed=12345) + ) + self.assertEqual(span.attributes["gen_ai.request.seed"], 12345) + + def test_option_reflected_to_span_attribute_frequency_penalty(self): + span = self.generate_and_get_span(config={"frequency_penalty": 1.0}) + self.assertEqual( + span.attributes["gen_ai.request.frequency_penalty"], 1.0 + ) + + def test_option_reflected_to_span_attribute_max_tokens(self): + span = self.generate_and_get_span( + config=GenerateContentConfig(max_output_tokens=5000) + ) + self.assertEqual(span.attributes["gen_ai.request.max_tokens"], 5000) + + def test_option_reflected_to_span_attribute_presence_penalty(self): + span = self.generate_and_get_span( + config=GenerateContentConfig(presence_penalty=0.5) + ) + self.assertEqual( + span.attributes["gen_ai.request.presence_penalty"], 0.5 + ) + + def test_option_reflected_to_span_attribute_stop_sequences(self): + span = self.generate_and_get_span( + config={"stop_sequences": ["foo", "bar"]} + ) + stop_sequences = span.attributes["gen_ai.request.stop_sequences"] + self.assertEqual(len(stop_sequences), 2) + self.assertEqual(stop_sequences[0], "foo") + self.assertEqual(stop_sequences[1], "bar") + + def test_option_reflected_to_span_attribute_top_k(self): + span = self.generate_and_get_span( + config=GenerateContentConfig(top_k=20) + ) + self.assertEqual(span.attributes["gen_ai.request.top_k"], 20) + + def test_option_reflected_to_span_attribute_top_p(self): + span = self.generate_and_get_span(config={"top_p": 10}) + self.assertEqual(span.attributes["gen_ai.request.top_p"], 10) + + @mock.patch.dict( + os.environ, {"OTEL_GOOGLE_GENAI_GENERATE_CONTENT_CONFIG_INCLUDES": "*"} + ) + def test_option_not_reflected_to_span_attribute_system_instruction(self): + span = self.generate_and_get_span( + config={"system_instruction": "Yadda yadda yadda"} + ) + self.assertNotIn( + "gcp.gen_ai.operation.config.system_instruction", span.attributes + ) + self.assertNotIn("gen_ai.request.system_instruction", span.attributes) + for key in span.attributes: + value = span.attributes[key] + if isinstance(value, str): + self.assertNotIn("Yadda yadda yadda", value) + + @mock.patch.dict( + os.environ, {"OTEL_GOOGLE_GENAI_GENERATE_CONTENT_CONFIG_INCLUDES": "*"} + ) + def test_option_reflected_to_span_attribute_automatic_func_calling(self): + span = self.generate_and_get_span( + config={ + "automatic_function_calling": { + "ignore_call_history": True, + } + } + ) + self.assertTrue( + span.attributes[ + "gcp.gen_ai.operation.config.automatic_function_calling.ignore_call_history" + ] + ) + + def test_dynamic_config_options_not_included_without_allow_list(self): + span = self.generate_and_get_span( + config={ + "automatic_function_calling": { + "ignore_call_history": True, + } + } + ) + self.assertNotIn( + "gcp.gen_ai.operation.config.automatic_function_calling.ignore_call_history", + span.attributes, + ) + + def test_can_supply_allow_list_via_instrumentor_constructor(self): + self.set_instrumentor_constructor_kwarg( + "generate_content_config_key_allowlist", AllowList(includes=["*"]) + ) + span = self.generate_and_get_span( + config={ + "automatic_function_calling": { + "ignore_call_history": True, + } + } + ) + self.assertTrue( + span.attributes[ + "gcp.gen_ai.operation.config.automatic_function_calling.ignore_call_history" + ] + ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/requirements.oldest.txt b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/requirements.oldest.txt index f04e668799..50fc45f39f 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/requirements.oldest.txt +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/requirements.oldest.txt @@ -21,10 +21,10 @@ pytest-vcr==1.0.2 google-auth==2.15.0 google-genai==1.0.0 -opentelemetry-api==1.30.0 -opentelemetry-sdk==1.30.0 -opentelemetry-semantic-conventions==0.51b0 -opentelemetry-instrumentation==0.51b0 +opentelemetry-api==1.31.1 +opentelemetry-sdk==1.31.1 +opentelemetry-semantic-conventions==0.52b1 +opentelemetry-instrumentation==0.52b1 # Install locally from the folder. This path is relative to the # root directory, given invocation from "tox" at root level. diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/utils/test_allowlist_util.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/utils/test_allowlist_util.py new file mode 100644 index 0000000000..f65a90f6ed --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/utils/test_allowlist_util.py @@ -0,0 +1,162 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +from opentelemetry.instrumentation.google_genai.allowlist_util import AllowList + + +def test_empty_allowlist_allows_nothing(): + allow_list = AllowList() + assert not allow_list.allowed("") + assert not allow_list.allowed("foo") + assert not allow_list.allowed("bar") + assert not allow_list.allowed("baz") + assert not allow_list.allowed("anything at all") + + +def test_simple_include_allow_list(): + allow_list = AllowList(includes=["abc", "xyz"]) + assert allow_list.allowed("abc") + assert not allow_list.allowed("abc.xyz") + assert allow_list.allowed("xyz") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +def test_allow_list_with_prefix_matching(): + allow_list = AllowList(includes=["abc.*", "xyz"]) + assert not allow_list.allowed("abc") + assert allow_list.allowed("abc.foo") + assert allow_list.allowed("abc.bar") + assert allow_list.allowed("xyz") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +def test_allow_list_with_array_wildcard_matching(): + allow_list = AllowList(includes=["abc[*].foo", "xyz[*].*"]) + assert not allow_list.allowed("abc") + assert allow_list.allowed("abc[0].foo") + assert not allow_list.allowed("abc[0].bar") + assert allow_list.allowed("abc[1].foo") + assert allow_list.allowed("xyz[0].blah") + assert allow_list.allowed("xyz[1].yadayada") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +def test_includes_and_excludes(): + allow_list = AllowList(includes=["abc", "xyz"], excludes=["xyz"]) + assert allow_list.allowed("abc") + assert not allow_list.allowed("xyz") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +def test_includes_and_excludes_with_wildcards(): + allow_list = AllowList( + includes=["abc", "xyz", "xyz.*"], excludes=["xyz.foo", "xyz.foo.*"] + ) + assert allow_list.allowed("abc") + assert allow_list.allowed("xyz") + assert not allow_list.allowed("xyz.foo") + assert not allow_list.allowed("xyz.foo.bar") + assert not allow_list.allowed("xyz.foo.baz") + assert allow_list.allowed("xyz.not_foo") + assert allow_list.allowed("xyz.blah") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +def test_default_include_with_excludes(): + allow_list = AllowList(includes=["*"], excludes=["foo", "bar"]) + assert not allow_list.allowed("foo") + assert not allow_list.allowed("bar") + assert allow_list.allowed("abc") + assert allow_list.allowed("xyz") + assert allow_list.allowed("blah") + assert allow_list.allowed("other value not in includes") + + +def test_default_exclude_with_includes(): + allow_list = AllowList(includes=["foo", "bar"], excludes=["*"]) + assert allow_list.allowed("foo") + assert allow_list.allowed("bar") + assert not allow_list.allowed("abc") + assert not allow_list.allowed("xyz") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +@mock.patch.dict(os.environ, {"TEST_ALLOW_LIST_INCLUDE_KEYS": "abc,xyz"}) +def test_can_load_from_env_with_just_include_list(): + allow_list = AllowList.from_env("TEST_ALLOW_LIST_INCLUDE_KEYS") + assert allow_list.allowed("abc") + assert allow_list.allowed("xyz") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +@mock.patch.dict( + os.environ, {"TEST_ALLOW_LIST_INCLUDE_KEYS": " abc , , xyz ,"} +) +def test_can_handle_spaces_and_empty_entries(): + allow_list = AllowList.from_env("TEST_ALLOW_LIST_INCLUDE_KEYS") + assert allow_list.allowed("abc") + assert allow_list.allowed("xyz") + assert not allow_list.allowed("") + assert not allow_list.allowed(",") + assert not allow_list.allowed("blah") + assert not allow_list.allowed("other value not in includes") + + +@mock.patch.dict( + os.environ, + { + "TEST_ALLOW_LIST_INCLUDE_KEYS": "abc,xyz", + "TEST_ALLOW_LIST_EXCLUDE_KEYS": "xyz, foo, bar", + }, +) +def test_can_load_from_env_with_includes_and_excludes(): + allow_list = AllowList.from_env( + "TEST_ALLOW_LIST_INCLUDE_KEYS", + excludes_env_var="TEST_ALLOW_LIST_EXCLUDE_KEYS", + ) + assert allow_list.allowed("abc") + assert not allow_list.allowed("xyz") + assert not allow_list.allowed("foo") + assert not allow_list.allowed("bar") + assert not allow_list.allowed("not in the list") + + +@mock.patch.dict( + os.environ, + { + "TEST_ALLOW_LIST_INCLUDE_KEYS": "*", + "TEST_ALLOW_LIST_EXCLUDE_KEYS": "xyz, foo, bar", + }, +) +def test_supports_wildcards_in_loading_from_env(): + allow_list = AllowList.from_env( + "TEST_ALLOW_LIST_INCLUDE_KEYS", + excludes_env_var="TEST_ALLOW_LIST_EXCLUDE_KEYS", + ) + assert allow_list.allowed("abc") + assert not allow_list.allowed("xyz") + assert not allow_list.allowed("foo") + assert not allow_list.allowed("bar") + assert allow_list.allowed("blah") + assert allow_list.allowed("not in the list") diff --git a/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/utils/test_dict_util.py b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/utils/test_dict_util.py new file mode 100644 index 0000000000..ef2e641360 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-google-genai/tests/utils/test_dict_util.py @@ -0,0 +1,357 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel + +from opentelemetry.instrumentation.google_genai import dict_util + + +class PydanticModel(BaseModel): + """Used to verify handling of pydantic models in the flattener.""" + + str_value: str = "" + int_value: int = 0 + + +class ModelDumpableNotPydantic: + """Used to verify general handling of 'model_dump'.""" + + def __init__(self, dump_output): + self._dump_output = dump_output + + def model_dump(self): + return self._dump_output + + +class NotJsonSerializable: + def __init__(self): + pass + + +def test_flatten_empty_dict(): + input_dict = {} + output_dict = dict_util.flatten_dict(input_dict) + assert output_dict is not None + assert isinstance(output_dict, dict) + assert not output_dict + + +def test_flatten_simple_dict(): + input_dict = { + "int_key": 1, + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + } + assert dict_util.flatten_dict(input_dict) == input_dict + + +def test_flatten_nested_dict(): + input_dict = { + "int_key": 1, + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + "object_key": { + "nested": { + "foo": 1, + "bar": "baz", + }, + "qux": 54321, + }, + } + assert dict_util.flatten_dict(input_dict) == { + "int_key": 1, + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + "object_key.nested.foo": 1, + "object_key.nested.bar": "baz", + "object_key.qux": 54321, + } + + +def test_flatten_with_key_exclusion(): + input_dict = { + "int_key": 1, + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + } + output = dict_util.flatten_dict(input_dict, exclude_keys=["int_key"]) + assert "int_key" not in output + assert output == { + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + } + + +def test_flatten_with_renaming(): + input_dict = { + "int_key": 1, + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + } + output = dict_util.flatten_dict( + input_dict, rename_keys={"float_key": "math_key"} + ) + assert "float_key" not in output + assert "math_key" in output + assert output == { + "int_key": 1, + "string_key": "somevalue", + "math_key": 3.14, + "bool_key": True, + } + + +def test_flatten_with_prefixing(): + input_dict = { + "int_key": 1, + "string_key": "somevalue", + "float_key": 3.14, + "bool_key": True, + } + output = dict_util.flatten_dict(input_dict, key_prefix="someprefix") + assert output == { + "someprefix.int_key": 1, + "someprefix.string_key": "somevalue", + "someprefix.float_key": 3.14, + "someprefix.bool_key": True, + } + + +def test_flatten_with_custom_flatten_func(): + def summarize_int_list(key, value, **kwargs): + total = 0 + for item in value: + total += item + avg = total / len(value) + return f"{len(value)} items (total: {total}, average: {avg})" + + flatten_functions = {"some.deeply.nested.key": summarize_int_list} + input_dict = { + "some": { + "deeply": { + "nested": { + "key": [1, 2, 3, 4, 5, 6, 7, 8, 9], + }, + }, + }, + "other": [1, 2, 3, 4, 5, 6, 7, 8, 9], + } + output = dict_util.flatten_dict( + input_dict, flatten_functions=flatten_functions + ) + assert output == { + "some.deeply.nested.key": "9 items (total: 45, average: 5.0)", + "other": [1, 2, 3, 4, 5, 6, 7, 8, 9], + } + + +def test_flatten_with_pydantic_model_value(): + input_dict = { + "foo": PydanticModel(str_value="bar", int_value=123), + } + + output = dict_util.flatten_dict(input_dict) + assert output == { + "foo.str_value": "bar", + "foo.int_value": 123, + } + + +def test_flatten_with_model_dumpable_value(): + input_dict = { + "foo": ModelDumpableNotPydantic( + { + "str_value": "bar", + "int_value": 123, + } + ), + } + + output = dict_util.flatten_dict(input_dict) + assert output == { + "foo.str_value": "bar", + "foo.int_value": 123, + } + + +def test_flatten_with_mixed_structures(): + input_dict = { + "foo": ModelDumpableNotPydantic( + { + "pydantic": PydanticModel(str_value="bar", int_value=123), + } + ), + } + + output = dict_util.flatten_dict(input_dict) + assert output == { + "foo.pydantic.str_value": "bar", + "foo.pydantic.int_value": 123, + } + + +def test_converts_tuple_with_json_fallback(): + input_dict = { + "foo": ("abc", 123), + } + output = dict_util.flatten_dict(input_dict) + assert output == { + "foo.length": 2, + "foo[0]": "abc", + "foo[1]": 123, + } + + +def test_json_conversion_handles_unicode(): + input_dict = { + "foo": ("❤️", 123), + } + output = dict_util.flatten_dict(input_dict) + assert output == { + "foo.length": 2, + "foo[0]": "❤️", + "foo[1]": 123, + } + + +def test_flatten_with_complex_object_not_json_serializable(): + result = dict_util.flatten_dict( + { + "cannot_serialize_directly": NotJsonSerializable(), + } + ) + assert result is not None + assert isinstance(result, dict) + assert len(result) == 0 + + +def test_flatten_good_with_non_serializable_complex_object(): + result = dict_util.flatten_dict( + { + "foo": { + "bar": "blah", + "baz": 5, + }, + "cannot_serialize_directly": NotJsonSerializable(), + } + ) + assert result == { + "foo.bar": "blah", + "foo.baz": 5, + } + + +def test_flatten_with_complex_object_not_json_serializable_and_custom_flatten_func(): + def flatten_not_json_serializable(key, value, **kwargs): + assert isinstance(value, NotJsonSerializable) + return "blah" + + output = dict_util.flatten_dict( + { + "cannot_serialize_directly": NotJsonSerializable(), + }, + flatten_functions={ + "cannot_serialize_directly": flatten_not_json_serializable, + }, + ) + assert output == { + "cannot_serialize_directly": "blah", + } + + +def test_flatten_simple_homogenous_primitive_string_list(): + input_dict = {"list_value": ["abc", "def"]} + assert dict_util.flatten_dict(input_dict) == input_dict + + +def test_flatten_simple_homogenous_primitive_int_list(): + input_dict = {"list_value": [123, 456]} + assert dict_util.flatten_dict(input_dict) == input_dict + + +def test_flatten_simple_homogenous_primitive_bool_list(): + input_dict = {"list_value": [True, False]} + assert dict_util.flatten_dict(input_dict) == input_dict + + +def test_flatten_simple_heterogenous_primitive_list(): + input_dict = {"list_value": ["abc", 123]} + assert dict_util.flatten_dict(input_dict) == { + "list_value.length": 2, + "list_value[0]": "abc", + "list_value[1]": 123, + } + + +def test_flatten_list_of_compound_types(): + input_dict = { + "list_value": [ + {"a": 1, "b": 2}, + {"x": 100, "y": 123, "z": 321}, + "blah", + [ + "abc", + 123, + ], + ] + } + assert dict_util.flatten_dict(input_dict) == { + "list_value.length": 4, + "list_value[0].a": 1, + "list_value[0].b": 2, + "list_value[1].x": 100, + "list_value[1].y": 123, + "list_value[1].z": 321, + "list_value[2]": "blah", + "list_value[3].length": 2, + "list_value[3][0]": "abc", + "list_value[3][1]": 123, + } + + +def test_handles_simple_output_from_flatten_func(): + def f(*args, **kwargs): + return "baz" + + input_dict = { + "foo": PydanticModel(), + } + + output = dict_util.flatten_dict(input_dict, flatten_functions={"foo": f}) + + assert output == { + "foo": "baz", + } + + +def test_handles_compound_output_from_flatten_func(): + def f(*args, **kwargs): + return {"baz": 123, "qux": 456} + + input_dict = { + "foo": PydanticModel(), + } + + output = dict_util.flatten_dict(input_dict, flatten_functions={"foo": f}) + + assert output == { + "foo.baz": 123, + "foo.qux": 456, + }