Commit 11fe163

Merge pull request #500 from cangtianhuang/develop
Fix `API Tracer` about exceeding length
2 parents 8a1d4d9 + 13e5613

File tree: 3 files changed (+94, −37 lines)

tools/api_tracer/config_serializer.py (60 additions, 25 deletions)
```diff
@@ -2,7 +2,6 @@
 import time
 from collections import defaultdict, deque
 from threading import Event, Thread
-from types import EllipsisType
 from typing import Any, Dict, List, TextIO
 
 import yaml
@@ -27,29 +26,33 @@ def __init__(
         self.merge_output = merge_output
 
         self.file_handlers: Dict[int, Dict[str, TextIO]] = {}
-
         self.buffer_limit = 20000
         self.buffers: Dict[int, List[Dict]] = defaultdict(list)
 
+        self.max_args_count = 100
+        self.max_item_count = 100
+        self.max_line_length = 1024
+        self.max_nest_depth = 5
+
         # asyncio
         self.log_queue = deque()
         self._stop_event = Event()
         self.writer_thread = Thread(target=self._writer_loop)
         self.total_calls_processed = 0
 
         self._serialize_handlers = {
-            type(None): lambda x: x,
-            bool: lambda x: x,
-            int: lambda x: x,
-            float: lambda x: x,
-            str: lambda x: x,
+            type(None): lambda x, depth: x,
+            bool: lambda x, depth: x,
+            int: lambda x, depth: x,
+            float: lambda x, depth: x,
+            str: lambda x, depth: x,
             list: self._serialize_list,
             tuple: self._serialize_tuple,
             set: self._serialize_set,
             dict: self._serialize_dict,
             type: self._serialize_type,
             slice: self._serialize_slice,
-            EllipsisType: self._serialize_ellipsis,
+            type(Ellipsis): self._serialize_ellipsis,
         }
 
     def open(self):
```
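A note on the `type(Ellipsis)` swap above: `types.EllipsisType` only exists on Python 3.10+, so keying the handler table on `type(Ellipsis)` drops the import and keeps the tracer working on older interpreters. A one-line sanity check:

```python
# type(Ellipsis) is the same class that types.EllipsisType names on
# Python 3.10+, but it needs no import and works on earlier versions too.
assert type(Ellipsis) is type(...)
```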
```diff
@@ -157,60 +160,86 @@ def dump_call(
     ):
         """Record a single API call."""
         try:
+            total_args = len(args) + len(kwargs)
+            if total_args > self.max_args_count:
+                if len(args) < self.max_args_count:
+                    kwargs = dict(
+                        list(kwargs.items())[: self.max_args_count - len(args) - 1]
+                    )
+                    kwargs["__truncated__"] = "<Truncated: max args exceeded>"
+                else:
+                    args = tuple(
+                        list(args)[: self.max_args_count - 1]
+                        + ["<Truncated: max args exceeded>"]
+                    )
+                    kwargs = {}
             call_record = {
                 "level": level,
                 "api": api_name,
-                "args": [self._serialize_item(arg) for arg in args],
+                "args": [self._serialize_item(arg, depth=0) for arg in args],
                 "kwargs": {
-                    key: self._serialize_item(value) for key, value in kwargs.items()
+                    key: self._serialize_item(value, depth=0)
+                    for key, value in kwargs.items()
                 },
-                # "output_summary": self._serialize_item(output)
+                # "output_summary": self._serialize_item(output, depth=0)
             }
             self.log_queue.append(call_record)
         except Exception as e:
             print(f"[ConfigSerializer] Error serializing call for '{api_name}': {e}")
 
-    def _serialize_list(self, item: list) -> Dict:
+    def _serialize_list(self, item: list, depth: int) -> Dict:
+        if len(item) > self.max_item_count:
+            item = item[: self.max_item_count - 1] + ["<Truncated: max item count>"]
         return {
             "type": "list",
-            "value": [self._serialize_item(sub_item) for sub_item in item],
+            "value": [self._serialize_item(sub_item, depth) for sub_item in item],
         }
 
-    def _serialize_tuple(self, item: tuple) -> Dict:
+    def _serialize_tuple(self, item: tuple, depth: int) -> Dict:
+        if len(item) > self.max_item_count:
+            item = item[: self.max_item_count - 1] + ("<Truncated: max item count>",)
         return {
             "type": "tuple",
-            "value": [self._serialize_item(sub_item) for sub_item in item],
+            "value": [self._serialize_item(sub_item, depth) for sub_item in item],
         }
 
-    def _serialize_set(self, item: set) -> Dict:
+    def _serialize_set(self, item: set, depth: int) -> Dict:
+        if len(item) > self.max_item_count:
+            item = set(list(item)[: self.max_item_count - 1])
         return {
             "type": "set",
-            "value": [self._serialize_item(sub_item) for sub_item in item],
+            "value": [self._serialize_item(sub_item, depth) for sub_item in item],
         }
 
-    def _serialize_dict(self, item: dict) -> Dict:
+    def _serialize_dict(self, item: dict, depth: int) -> Dict:
+        if len(item) > self.max_item_count:
+            item = dict(list(item.items())[: self.max_item_count - 1])
+            item["__truncated__"] = "<Truncated: max item count>"
         return {
             "type": "dict",
-            "value": {str(k): self._serialize_item(v) for k, v in item.items()},
+            "value": {str(k): self._serialize_item(v, depth) for k, v in item.items()},
         }
 
-    def _serialize_type(self, item: type) -> Dict:
+    def _serialize_type(self, item: type, depth: int) -> Dict:
         return {"type": "type", "value": f"{item.__module__}.{item.__name__}"}
 
-    def _serialize_slice(self, item: slice) -> Dict:
+    def _serialize_slice(self, item: slice, depth: int) -> Dict:
         return {
             "type": "slice",
             "value": {"start": item.start, "stop": item.stop, "step": item.step},
         }
 
-    def _serialize_ellipsis(self, item: Any) -> Dict:
+    def _serialize_ellipsis(self, item: Any, depth: int) -> Dict:
         return {"type": "ellipsis", "value": "..."}
 
-    def _serialize_item(self, item: Any) -> Any:
+    def _serialize_item(self, item: Any, depth=0) -> Any:
         """Recursively serialize an object."""
+        if depth > self.max_nest_depth:
+            return "<Truncated: max depth exceeded>"
+
         handler = self._serialize_handlers.get(type(item))
         if handler:
-            return handler(item)
+            return handler(item, depth=depth + 1)
 
         special_serialization = self.dialect.serialize_special_type(item)
         if special_serialization is not None:
```
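Taken together, the new fields bound every axis along which a record can grow: `max_args_count` caps the number of recorded arguments per call, `max_item_count` caps each container, and `max_nest_depth` stops recursion into deeply nested values. Below is a minimal standalone sketch of the same capping strategy; the `serialize` helper and its constants are illustrative stand-ins, not the module's API, and it increments depth on each recursion into children rather than on each handler dispatch:

```python
# Minimal sketch of depth- and size-bounded serialization, mirroring the
# diff's strategy: spend one slot on a truncation marker, and cut off
# recursion past a fixed depth. Names here are illustrative only.
from typing import Any

MAX_ITEMS = 3   # stand-in for max_item_count
MAX_DEPTH = 2   # stand-in for max_nest_depth


def serialize(item: Any, depth: int = 0) -> Any:
    if depth > MAX_DEPTH:
        return "<Truncated: max depth exceeded>"
    if isinstance(item, (list, tuple)):
        seq = list(item)
        if len(seq) > MAX_ITEMS:
            # Keep MAX_ITEMS - 1 real entries; the last slot is the marker.
            seq = seq[: MAX_ITEMS - 1] + ["<Truncated: max item count>"]
        return [serialize(x, depth + 1) for x in seq]
    if isinstance(item, dict):
        entries = list(item.items())
        if len(entries) > MAX_ITEMS:
            entries = entries[: MAX_ITEMS - 1]
            entries.append(("__truncated__", "<Truncated: max item count>"))
        return {str(k): serialize(v, depth + 1) for k, v in entries}
    return item  # scalars pass through unchanged


print(serialize(list(range(10))))  # [0, 1, '<Truncated: max item count>']
print(serialize([[[["deep"]]]]))   # nesting past MAX_DEPTH gets cut off
```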
```diff
@@ -228,6 +257,8 @@ def format_arg(arg: Any) -> str:
         if arg is None or isinstance(arg, (bool, int, float)):
             return str(arg)
         if isinstance(arg, str):
+            if len(arg) > 100:
+                return f'"{arg[:97]}..."'
             return f'"{arg}"'
 
         if isinstance(arg, dict) and "type" in arg:
@@ -252,7 +283,11 @@ def format_arg(arg: Any) -> str:
 
         args_str = ", ".join(format_arg(arg) for arg in args)
         kwargs_str = ", ".join(f"{k}={format_arg(v)}" for k, v in kwargs.items())
-        return f"{api_name}({args_str + (', ' + kwargs_str if kwargs_str else '')})"
+        result = f"{api_name}({args_str + (', ' + kwargs_str if kwargs_str else '')})"
+
+        if len(result) > self.max_line_length:
+            result = result[: self.max_line_length - 4] + "...)"
+        return result
 
     def get_apis_and_configs(self):
         if self.merge_output:
```
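With the `format_arg` and line-length changes, a formatted log line is now bounded on both fronts: long strings render as at most 97 characters plus `...`, and the assembled call line is clipped so the 4-character `...)` suffix lands exactly at `max_line_length`. A worked check of the arithmetic:

```python
# Worked example of the clipping above with max_line_length = 1024:
# keep the first 1020 characters, append the 4-character "...)" suffix,
# and the result is exactly 1024 characters long.
max_line_length = 1024
result = "torch.cat(" + "x" * 2000 + ")"

if len(result) > max_line_length:
    result = result[: max_line_length - 4] + "...)"

assert len(result) == max_line_length
print(result[-8:])  # -> xxxx...)
```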

tools/api_tracer/test_infer.py (7 additions, 5 deletions)
```diff
@@ -14,7 +14,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MODELS = [
-    # "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     # "Qwen/Qwen2-0.5B",
     # "Qwen/Qwen3-0.6B",
     # "Qwen/Qwen3-30B-A3B",
@@ -27,8 +26,11 @@
 
 def run_inference_test(model_name: str):
     print(f"🚀 Running inference test for: {model_name}")
-    output_path = f"tools/api_tracer/trace_output_test_infer/{model_name}"
-    tracer = APITracer("torch", output_path=output_path, levels=[0, 1])
+    true_model_name = "/".join(model_name.rsplit("/", 2)[-2:])
+    output_path = f"tools/api_tracer/trace_output_test_infer/{true_model_name}"
+    tracer = APITracer(
+        "torch", output_path=output_path, levels=[0, 1], merge_output=True
+    )
 
     try:
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
@@ -59,10 +61,10 @@ def run_inference_test(model_name: str):
         print("\n--- Generated Response ---")
         print(response)
         print("--------------------------\n")
-        print(f"✅ Test for {model_name} finished.")
+        print(f"✅ Test for {true_model_name} finished.")
     except Exception as e:
         traceback.print_exc()
-        print(f"❌ An error occurred during inference for {model_name}: {e}")
+        print(f"❌ An error occurred during inference for {true_model_name}: {e}")
 
 
 def main():
```
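Both test scripts now derive `true_model_name` by keeping at most the last two path components of `model_name`, so a Hub ID and a local checkpoint directory map to the same `org/model` output folder. A quick illustration (the local path here is invented):

```python
# The normalization used in test_infer.py and test_train.py: keep at most
# the last two path components. The second input is a hypothetical local
# checkpoint path, shown only to illustrate why the rsplit is needed.
for name in [
    "Qwen/Qwen3-0.6B",
    "/data/checkpoints/Qwen/Qwen3-0.6B",
]:
    print("/".join(name.rsplit("/", 2)[-2:]))
# Both lines print: Qwen/Qwen3-0.6B
```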

tools/api_tracer/test_train.py (27 additions, 7 deletions)
```diff
@@ -1,5 +1,4 @@
 import os
-import time
 import traceback
 
 os.environ["HF_HOME"] = "tools/api_tracer/.huggingface"
@@ -17,9 +16,9 @@
 from tools.api_tracer import APITracer
 
 MODELS = [
-    # "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     # "Qwen/Qwen2-0.5B",
     # "Qwen/Qwen3-0.6B",
+    # "Qwen/Qwen3-30B-A3B",
     # "Qwen/Qwen2.5-VL-3B-Instruct",
     # "deepseek-ai/DeepSeek-V2-Lite",
     # "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
@@ -29,20 +28,41 @@
 
 def run_training_test(model_name: str):
     print(f"🚀 Running training test for: {model_name}")
-    output_path = f"tools/api_tracer/trace_output_test_train/{model_name}"
-    tracer = APITracer("torch", output_path=output_path, levels=[0, 1])
+    true_model_name = "/".join(model_name.rsplit("/", 2)[-2:])
+    output_path = f"tools/api_tracer/trace_output_test_train/{true_model_name}"
+    tracer = APITracer(
+        "torch", output_path=output_path, levels=[0, 1], merge_output=True
+    )
 
     try:
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map="auto",
             trust_remote_code=True,
+            use_cache=False,
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
+        if "Llama" in true_model_name:
+            llama_chat_template = (
+                "{% for message in messages %}"
+                "{% if message['role'] == 'system' %}"
+                "{{'<|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>'}}"
+                "{% elif message['role'] == 'user' %}"
+                "{{'<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>'}}"
+                "{% elif message['role'] == 'assistant' %}"
+                "{{'<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>'}}"
+                "{% endif %}"
+                "{% endfor %}"
+                "{% if add_generation_prompt %}"
+                "{{'<|start_header_id|>assistant<|end_header_id|>\n\n'}}"
+                "{% endif %}"
+            )
+            tokenizer.chat_template = llama_chat_template
+
         print(f"Model Class: {model.__class__}")
         print(f"Tokenizer Class: {tokenizer.__class__}")
 
@@ -82,7 +102,7 @@ def preprocess_function(examples):
             save_strategy="no",
             bf16=True,
             report_to="none",
-            max_steps=5,
+            max_steps=1,
             gradient_checkpointing=True,
         )
 
@@ -98,10 +118,10 @@ def preprocess_function(examples):
         with tracer:
             trainer.train()
 
-        print(f"✅ Test for {model_name} finished.")
+        print(f"✅ Test for {true_model_name} finished.")
     except Exception as e:
         traceback.print_exc()
-        print(f"❌ An error occurred during training for {model_name}: {e}")
+        print(f"❌ An error occurred during training for {true_model_name}: {e}")
 
 
 def run_training_test_vision(model_name: str):
```
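The hand-written Llama chat template takes effect only when something later calls `tokenizer.apply_chat_template`, which renders it with Jinja. As a self-contained sanity check, a shortened user-only variant of the template can be rendered with Jinja2 directly; the message list here is invented:

```python
# Sketch: render a shortened, user-only variant of the Llama chat template
# with Jinja2 directly, without downloading a tokenizer. The messages and
# the truncation of the template to the 'user' branch are for illustration.
from jinja2 import Template

template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{'<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>'}}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{'<|start_header_id|>assistant<|end_header_id|>\n\n'}}"
    "{% endif %}"
)

rendered = Template(template).render(
    messages=[{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
)
print(rendered)
# <|start_header_id|>user<|end_header_id|>
#
# Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```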
