4
4
import random
5
5
import time
6
6
7
+ import aiofiles
8
+
9
+ from aiperf .clients .model_endpoint_info import ModelEndpointInfo
7
10
from aiperf .common .aiperf_logger import AIPerfLogger
8
11
from aiperf .common .base_component_service import BaseComponentService
9
12
from aiperf .common .config import ServiceConfig , UserConfig
13
+ from aiperf .common .config .config_defaults import OutputDefaults
10
14
from aiperf .common .decorators import implements_protocol
11
15
from aiperf .common .enums import (
12
16
CommAddress ,
16
20
ServiceType ,
17
21
)
18
22
from aiperf .common .enums .dataset_enums import CustomDatasetType
19
- from aiperf .common .factories import ComposerFactory , ServiceFactory
23
+ from aiperf .common .factories import (
24
+ ComposerFactory ,
25
+ RequestConverterFactory ,
26
+ ServiceFactory ,
27
+ )
20
28
from aiperf .common .hooks import on_command , on_request
21
29
from aiperf .common .messages import (
22
30
ConversationRequestMessage ,
29
37
ProfileConfigureCommand ,
30
38
)
31
39
from aiperf .common .mixins import ReplyClientMixin
32
- from aiperf .common .models import Conversation
33
- from aiperf .common .protocols import ServiceProtocol
40
+ from aiperf .common .models import Conversation , InputsFile
41
+ from aiperf .common .models .dataset_models import SessionPayloads
42
+ from aiperf .common .protocols import RequestConverterProtocol , ServiceProtocol
34
43
from aiperf .common .tokenizer import Tokenizer
35
44
from aiperf .dataset .loader import ShareGPTLoader
36
45
@@ -87,6 +96,7 @@ async def _profile_configure_command(
87
96
self .info (lambda : f"Configuring dataset for { self .service_id } " )
88
97
begin = time .perf_counter ()
89
98
await self ._configure_dataset ()
99
+ await self ._generate_inputs_json_file ()
90
100
duration = time .perf_counter () - begin
91
101
self .info (lambda : f"Dataset configured in { duration :.2f} seconds" )
92
102
@@ -104,6 +114,57 @@ async def _configure_tokenizer(self) -> None:
104
114
revision = self .user_config .tokenizer .revision ,
105
115
)
106
116
117
async def _generate_input_payloads(
    self,
    model_endpoint: ModelEndpointInfo,
    request_converter: RequestConverterProtocol,
) -> InputsFile:
    """Build the contents of the inputs.json file from the loaded dataset.

    Every turn of every conversation in ``self.dataset`` is passed to
    ``request_converter.format_payload`` together with ``model_endpoint``.
    The turns of a single conversation are awaited together through
    ``asyncio.gather``; conversations are handled one after another.

    Args:
        model_endpoint: Endpoint description forwarded to the converter.
        request_converter: Converter whose ``format_payload`` produces the
            payload for a single turn.

    Returns:
        An ``InputsFile`` holding one ``SessionPayloads`` entry per
        conversation, keyed by that conversation's ``session_id``.
    """
    result = InputsFile()
    for convo in self.dataset.values():
        # Create the per-turn coroutines first, then await them as a group;
        # gather returns the results in the same order as the turns.
        pending = [
            request_converter.format_payload(model_endpoint, turn)
            for turn in convo.turns
        ]
        formatted = await asyncio.gather(*pending)
        session = SessionPayloads(session_id=convo.session_id, payloads=formatted)
        result.data.append(session)
    return result
135
+
136
async def _generate_inputs_json_file(self) -> None:
    """Write the dataset's formatted request payloads to the inputs.json file.

    The target path is ``artifact_directory / OutputDefaults.INPUTS_JSON_FILE``
    taken from the user config; missing parent directories are created.

    This step is best effort: any exception raised while building or writing
    the file is logged as a warning and swallowed so that the benchmark can
    continue without the file.
    """
    file_path = (
        self.user_config.output.artifact_directory / OutputDefaults.INPUTS_JSON_FILE
    )
    self.info(f"Generating inputs.json file at {file_path.resolve()} ")

    try:
        start_time = time.perf_counter()
        file_path.parent.mkdir(parents=True, exist_ok=True)

        model_endpoint = ModelEndpointInfo.from_user_config(self.user_config)
        request_converter = RequestConverterFactory.create_instance(
            model_endpoint.endpoint.type,
        )

        inputs = await self._generate_input_payloads(
            model_endpoint, request_converter
        )

        # Serialize before opening: mode "w" truncates the file, so a
        # serialization failure after the open would leave an empty file.
        content = inputs.model_dump_json(indent=2, exclude_unset=True)

        # Explicit UTF-8 so the output does not depend on the platform's
        # locale encoding when payloads contain non-ASCII text.
        async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
            await f.write(content)

        duration = time.perf_counter() - start_time
        self.info(f"inputs.json file generated in {duration:.2f} seconds")

    except Exception as e:
        # Log as warning, but continue to run the benchmark
        self.warning(
            f"Error generating inputs.json file at {file_path.resolve()} : {e} "
        )
167
+
107
168
async def _configure_dataset (self ) -> None :
108
169
if self .user_config is None :
109
170
raise self ._service_error ("User config is required for dataset manager" )
0 commit comments