Skip to content

Commit a58de00

Browse files
authored
feat: add inputs.json file for dataset traceability (#310)
1 parent 1b14d24 commit a58de00

File tree

14 files changed

+474
-16
lines changed

14 files changed

+474
-16
lines changed

aiperf/common/config/config_defaults.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class TurnDelayDefaults:
114114
class OutputDefaults:
115115
ARTIFACT_DIRECTORY = Path("./artifacts")
116116
PROFILE_EXPORT_FILE = Path("profile_export.json")
117+
LOG_FOLDER = Path("logs")
118+
LOG_FILE = Path("aiperf.log")
119+
INPUTS_JSON_FILE = Path("inputs.json")
120+
PROFILE_EXPORT_AIPERF_CSV_FILE = Path("profile_export_aiperf.csv")
121+
PROFILE_EXPORT_AIPERF_JSON_FILE = Path("profile_export_aiperf.json")
117122

118123

119124
@dataclass(frozen=True)

aiperf/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
MILLIS_PER_SECOND = 1000
99
BYTES_PER_MIB = 1024 * 1024
1010

11+
1112
GRACEFUL_SHUTDOWN_TIMEOUT = 5.0
1213
"""Default timeout for shutting down services in seconds."""
1314

aiperf/common/logging.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from aiperf.common.aiperf_logger import _DEBUG, _TRACE, AIPerfLogger
1313
from aiperf.common.config import ServiceConfig, ServiceDefaults, UserConfig
14+
from aiperf.common.config.config_defaults import OutputDefaults
1415
from aiperf.common.enums import ServiceType
1516
from aiperf.common.enums.ui_enums import AIPerfUIType
1617
from aiperf.common.factories import ServiceFactory
@@ -112,7 +113,7 @@ def setup_child_process_logging(
112113

113114
if user_config and user_config.output.artifact_directory:
114115
file_handler = create_file_handler(
115-
user_config.output.artifact_directory / "logs", level
116+
user_config.output.artifact_directory / OutputDefaults.LOG_FOLDER, level
116117
)
117118
root_logger.addHandler(file_handler)
118119

@@ -138,9 +139,9 @@ def setup_rich_logging(user_config: UserConfig, service_config: ServiceConfig) -
138139

139140
# Enable file logging for services
140141
# TODO: Use config to determine if file logging is enabled and the folder path.
141-
log_folder = user_config.output.artifact_directory / "logs"
142+
log_folder = user_config.output.artifact_directory / OutputDefaults.LOG_FOLDER
142143
log_folder.mkdir(parents=True, exist_ok=True)
143-
file_handler = logging.FileHandler(log_folder / "aiperf.log")
144+
file_handler = logging.FileHandler(log_folder / OutputDefaults.LOG_FILE)
144145
file_handler.setLevel(level)
145146
file_handler.formatter = logging.Formatter(
146147
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -158,7 +159,7 @@ def create_file_handler(
158159
"""Configure a file handler for logging."""
159160

160161
log_folder.mkdir(parents=True, exist_ok=True)
161-
log_file_path = log_folder / "aiperf.log"
162+
log_file_path = log_folder / OutputDefaults.LOG_FILE
162163

163164
file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
164165
file_handler.setLevel(level)

aiperf/common/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
Audio,
2222
Conversation,
2323
Image,
24+
InputsFile,
2425
Media,
26+
SessionPayloads,
2527
Text,
2628
Turn,
2729
)
@@ -86,6 +88,7 @@
8688
"IOCounters",
8789
"Image",
8890
"InferenceServerResponse",
91+
"InputsFile",
8992
"Media",
9093
"MetricResult",
9194
"ParsedResponse",
@@ -102,6 +105,7 @@
102105
"SSEField",
103106
"SSEMessage",
104107
"ServiceRunInfo",
108+
"SessionPayloads",
105109
"StatsProtocol",
106110
"Text",
107111
"TextResponse",

aiperf/common/models/dataset_models.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from typing import ClassVar
4+
from typing import Any, ClassVar
55

66
from pydantic import Field
77

@@ -81,3 +81,25 @@ class Conversation(AIPerfBaseModel):
8181
default=[], description="List of turns in the conversation."
8282
)
8383
session_id: str = Field(default="", description="Session ID of the conversation.")
84+
85+
86+
class SessionPayloads(AIPerfBaseModel):
87+
"""A single session, with its session ID and a list of formatted payloads (one per turn)."""
88+
89+
session_id: str | None = Field(
90+
default=None, description="Session ID of the conversation."
91+
)
92+
payloads: list[dict[str, Any]] = Field(
93+
default=[],
94+
description="List of formatted payloads in the session (one per turn). These have been formatted for the model and endpoint.",
95+
)
96+
97+
98+
class InputsFile(AIPerfBaseModel):
99+
"""A list of all dataset sessions. Each session contains a list of formatted payloads (one per turn).
100+
This is similar to the format used by GenAI-Perf for the inputs.json file.
101+
"""
102+
103+
data: list[SessionPayloads] = Field(
104+
default=[], description="List of all dataset sessions."
105+
)

aiperf/controller/system_controller.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from aiperf.common.base_service import BaseService
1212
from aiperf.common.config import ServiceConfig, UserConfig
13+
from aiperf.common.config.config_defaults import OutputDefaults
1314
from aiperf.common.config.dev_config import print_developer_mode_warning
1415
from aiperf.common.constants import (
1516
AIPERF_DEV_MODE,
@@ -513,7 +514,11 @@ async def _print_post_benchmark_info_and_metrics(self) -> None:
513514

514515
def _print_log_file_info(self, console: Console) -> None:
515516
"""Print the log file info."""
516-
log_file = self.user_config.output.artifact_directory / "logs" / "aiperf.log"
517+
log_file = (
518+
self.user_config.output.artifact_directory
519+
/ OutputDefaults.LOG_FOLDER
520+
/ OutputDefaults.LOG_FILE
521+
)
517522
console.print(
518523
f"[bold green]Log File:[/bold green] [cyan]{log_file.resolve()}[/cyan]"
519524
)

aiperf/dataset/dataset_manager.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
import random
55
import time
66

7+
import aiofiles
8+
9+
from aiperf.clients.model_endpoint_info import ModelEndpointInfo
710
from aiperf.common.aiperf_logger import AIPerfLogger
811
from aiperf.common.base_component_service import BaseComponentService
912
from aiperf.common.config import ServiceConfig, UserConfig
13+
from aiperf.common.config.config_defaults import OutputDefaults
1014
from aiperf.common.decorators import implements_protocol
1115
from aiperf.common.enums import (
1216
CommAddress,
@@ -16,7 +20,11 @@
1620
ServiceType,
1721
)
1822
from aiperf.common.enums.dataset_enums import CustomDatasetType
19-
from aiperf.common.factories import ComposerFactory, ServiceFactory
23+
from aiperf.common.factories import (
24+
ComposerFactory,
25+
RequestConverterFactory,
26+
ServiceFactory,
27+
)
2028
from aiperf.common.hooks import on_command, on_request
2129
from aiperf.common.messages import (
2230
ConversationRequestMessage,
@@ -29,8 +37,9 @@
2937
ProfileConfigureCommand,
3038
)
3139
from aiperf.common.mixins import ReplyClientMixin
32-
from aiperf.common.models import Conversation
33-
from aiperf.common.protocols import ServiceProtocol
40+
from aiperf.common.models import Conversation, InputsFile
41+
from aiperf.common.models.dataset_models import SessionPayloads
42+
from aiperf.common.protocols import RequestConverterProtocol, ServiceProtocol
3443
from aiperf.common.tokenizer import Tokenizer
3544
from aiperf.dataset.loader import ShareGPTLoader
3645

@@ -87,6 +96,7 @@ async def _profile_configure_command(
8796
self.info(lambda: f"Configuring dataset for {self.service_id}")
8897
begin = time.perf_counter()
8998
await self._configure_dataset()
99+
await self._generate_inputs_json_file()
90100
duration = time.perf_counter() - begin
91101
self.info(lambda: f"Dataset configured in {duration:.2f} seconds")
92102

@@ -104,6 +114,57 @@ async def _configure_tokenizer(self) -> None:
104114
revision=self.user_config.tokenizer.revision,
105115
)
106116

117+
async def _generate_input_payloads(
118+
self,
119+
model_endpoint: ModelEndpointInfo,
120+
request_converter: RequestConverterProtocol,
121+
) -> InputsFile:
122+
"""Generate input payloads from the dataset for use in the inputs.json file."""
123+
inputs = InputsFile()
124+
for conversation in self.dataset.values():
125+
payloads = await asyncio.gather(
126+
*[
127+
request_converter.format_payload(model_endpoint, turn)
128+
for turn in conversation.turns
129+
]
130+
)
131+
inputs.data.append(
132+
SessionPayloads(session_id=conversation.session_id, payloads=payloads)
133+
)
134+
return inputs
135+
136+
async def _generate_inputs_json_file(self) -> None:
137+
"""Generate inputs.json file in the artifact directory."""
138+
file_path = (
139+
self.user_config.output.artifact_directory / OutputDefaults.INPUTS_JSON_FILE
140+
)
141+
self.info(f"Generating inputs.json file at {file_path.resolve()}")
142+
143+
try:
144+
start_time = time.perf_counter()
145+
file_path.parent.mkdir(parents=True, exist_ok=True)
146+
147+
model_endpoint = ModelEndpointInfo.from_user_config(self.user_config)
148+
request_converter = RequestConverterFactory.create_instance(
149+
model_endpoint.endpoint.type,
150+
)
151+
152+
inputs = await self._generate_input_payloads(
153+
model_endpoint, request_converter
154+
)
155+
156+
async with aiofiles.open(file_path, "w") as f:
157+
await f.write(inputs.model_dump_json(indent=2, exclude_unset=True))
158+
159+
duration = time.perf_counter() - start_time
160+
self.info(f"inputs.json file generated in {duration:.2f} seconds")
161+
162+
except Exception as e:
163+
# Log as warning, but continue to run the benchmark
164+
self.warning(
165+
f"Error generating inputs.json file at {file_path.resolve()}: {e}"
166+
)
167+
107168
async def _configure_dataset(self) -> None:
108169
if self.user_config is None:
109170
raise self._service_error("User config is required for dataset manager")

aiperf/exporters/csv_exporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import aiofiles
1111

12+
from aiperf.common.config.config_defaults import OutputDefaults
1213
from aiperf.common.decorators import implements_protocol
1314
from aiperf.common.enums import DataExporterType
1415
from aiperf.common.enums.metric_enums import MetricFlags
@@ -40,7 +41,9 @@ def __init__(self, exporter_config: ExporterConfig, **kwargs) -> None:
4041
self._results = exporter_config.results
4142
self._output_directory = exporter_config.user_config.output.artifact_directory
4243
self._metric_registry = MetricRegistry
43-
self._file_path = self._output_directory / "profile_export_aiperf.csv"
44+
self._file_path = (
45+
self._output_directory / OutputDefaults.PROFILE_EXPORT_AIPERF_CSV_FILE
46+
)
4447
self._percentile_keys = _percentile_keys_from(STAT_KEYS)
4548

4649
def get_export_info(self) -> FileExportInfo:

aiperf/exporters/json_exporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pydantic import BaseModel
88

99
from aiperf.common.config import UserConfig
10+
from aiperf.common.config.config_defaults import OutputDefaults
1011
from aiperf.common.constants import NANOS_PER_SECOND
1112
from aiperf.common.decorators import implements_protocol
1213
from aiperf.common.enums import DataExporterType, MetricFlags
@@ -45,7 +46,9 @@ def __init__(self, exporter_config: ExporterConfig, **kwargs) -> None:
4546
self._output_directory = exporter_config.user_config.output.artifact_directory
4647
self._input_config = exporter_config.user_config
4748
self._metric_registry = MetricRegistry
48-
self._file_path = self._output_directory / "profile_export_aiperf.json"
49+
self._file_path = (
50+
self._output_directory / OutputDefaults.PROFILE_EXPORT_AIPERF_JSON_FILE
51+
)
4952

5053
def get_export_info(self) -> FileExportInfo:
5154
return FileExportInfo(

tests/data_exporters/test_csv_exporter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from aiperf.common.config import EndpointConfig, ServiceConfig, UserConfig
11+
from aiperf.common.config.config_defaults import OutputDefaults
1112
from aiperf.common.enums import EndpointType
1213
from aiperf.common.models import MetricResult
1314
from aiperf.exporters.csv_exporter import CsvExporter
@@ -146,7 +147,7 @@ async def test_csv_exporter_writes_two_sections_and_values(
146147
exporter = CsvExporter(cfg)
147148
await exporter.export()
148149

149-
expected = outdir / "profile_export_aiperf.csv"
150+
expected = outdir / OutputDefaults.PROFILE_EXPORT_AIPERF_CSV_FILE
150151
assert expected.exists()
151152

152153
text = _read(expected)
@@ -195,7 +196,7 @@ async def test_csv_exporter_empty_records_creates_empty_file(
195196
exporter = CsvExporter(cfg)
196197
await exporter.export()
197198

198-
expected = outdir / "profile_export_aiperf.csv"
199+
expected = outdir / OutputDefaults.PROFILE_EXPORT_AIPERF_CSV_FILE
199200
assert expected.exists()
200201
content = _read(expected)
201202
assert content.strip() == ""
@@ -233,7 +234,7 @@ async def test_csv_exporter_deterministic_sort_order(
233234
exporter = CsvExporter(cfg)
234235
await exporter.export()
235236

236-
text = _read(outdir / "profile_export_aiperf.csv")
237+
text = _read(outdir / OutputDefaults.PROFILE_EXPORT_AIPERF_CSV_FILE)
237238

238239
# Request section should list aaa_latency then zzz_latency in order
239240
# Pull only the request rows region (before the blank line separator).
@@ -288,7 +289,7 @@ async def test_csv_exporter_unit_aware_number_formatting(
288289
exporter = CsvExporter(cfg)
289290
await exporter.export()
290291

291-
text = _read(outdir / "profile_export_aiperf.csv")
292+
text = _read(outdir / OutputDefaults.PROFILE_EXPORT_AIPERF_CSV_FILE)
292293

293294
# counts: integer
294295
assert re.search(r"Input Sequence Length \(tokens\),\s*4096\b", text)

0 commit comments

Comments (0)