|
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | 4 | import argparse |
| 5 | +import dataclasses |
| 6 | +import functools |
| 7 | +import json |
5 | 8 | import logging |
6 | 9 | import pathlib |
| 10 | +from enum import Enum |
7 | 11 | from typing import Any, Dict, Optional |
8 | 12 |
|
9 | | -from dynamo.common.utils import canonical_json_encoder, get_dynamo_version |
| 13 | +from dynamo.common import __version__ |
10 | 14 |
|
11 | 15 | from .environment import get_environment_vars |
12 | 16 | from .system_info import get_runtime_info, get_system_info |
@@ -86,7 +90,7 @@ def get_config_dump(config: Any, extra_info: Dict[str, str] = {}) -> str: |
86 | 90 | "environment": get_environment_vars(), |
87 | 91 | "config": config, |
88 | 92 | "runtime_info": get_runtime_info(), |
89 | | - "dynamo_version": get_dynamo_version(), |
| 93 | + "dynamo_version": __version__, |
90 | 94 | } |
91 | 95 |
|
92 | 96 | # Add common versions |
@@ -126,3 +130,57 @@ def add_config_dump_args(parser: argparse.ArgumentParser): |
126 | 130 | default=None, |
127 | 131 | help="Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.", |
128 | 132 | ) |
| 133 | + |
| 134 | + |
| 135 | +@functools.singledispatch |
| 136 | +def _preprocess_for_encode(obj: object) -> object: |
| 137 | + """ |
| 138 | + Single dispatch function for preprocessing objects before JSON encoding. |
| 139 | +
|
| 140 | + This function should be extended using @register_encoder decorator |
| 141 | + for backend-specific types. |
| 142 | + """ |
| 143 | + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): |
| 144 | + return dataclasses.asdict(obj) |
| 145 | + logger.warning(f"Unknown type {type(obj)}, using __dict__ or str(obj)") |
| 146 | + if hasattr(obj, "__dict__"): |
| 147 | + return obj.__dict__ |
| 148 | + return str(obj) |
| 149 | + |
| 150 | + |
| 151 | +def register_encoder(type_class): |
| 152 | + """ |
| 153 | + Decorator to register custom encoders for specific types. |
| 154 | +
|
| 155 | + Usage: |
| 156 | + @register_encoder(MyClass) |
| 157 | + def encode_my_class(obj: MyClass): |
| 158 | + return {"field": obj.field} |
| 159 | + """ |
| 160 | + logger.info(f"Registering encoder for {type_class}") |
| 161 | + return _preprocess_for_encode.register(type_class) |
| 162 | + |
| 163 | + |
| 164 | +@register_encoder(set) |
| 165 | +def _preprocess_for_encode_set( |
| 166 | + obj: set, |
| 167 | +) -> list: # pyright: ignore[reportUnusedFunction] |
| 168 | + return sorted(list(obj)) |
| 169 | + |
| 170 | + |
| 171 | +@register_encoder(Enum) |
| 172 | +def _preprocess_for_encode_enum( |
| 173 | + obj: Enum, |
| 174 | +) -> str: # pyright: ignore[reportUnusedFunction] |
| 175 | + return str(obj) |
| 176 | + |
| 177 | + |
| 178 | +# Create a canonical JSON encoder with consistent formatting |
| 179 | +canonical_json_encoder = json.JSONEncoder( |
| 180 | + ensure_ascii=False, |
| 181 | + separators=(",", ":"), |
| 182 | + allow_nan=False, |
| 183 | + sort_keys=True, |
| 184 | + indent=None, |
| 185 | + default=_preprocess_for_encode, |
| 186 | +) |
0 commit comments