examples/scaffolding/run_best_of_n_with_reward.py (2 changes: 1 addition & 1 deletion)

@@ -60,7 +60,7 @@ def main():
     prompts = [query]

     results = llm.generate(prompts)
-    print(results[0].output.output_str)
+    print(results[0].outputs[0].text)
    llm.shutdown(shutdown_workers=True)
     print(f'main shut down done')
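The example now reads completions through the GenerationResult-style interface instead of the removed ScaffoldingOutput wrapper. A minimal sketch of the new access pattern, assuming an already-constructed ScaffoldingLlm named llm (the prompt text is illustrative):

    # Each returned result exposes the final task's outputs list;
    # outputs[0].text replaces the removed output.output_str field.
    results = llm.generate(["What is 1 + 1?"])
    print(results[0].outputs[0].text)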
examples/scaffolding/run_majority_vote_aime24.py (5 changes: 2 additions & 3 deletions)

@@ -101,9 +101,8 @@ def main():
         result = results[i]
         test_case = test_dataset[i]
         ref_answer = int(test_case["answer"])
-        result.result()
-        output = result.output
-        extracted_answer = extract_answer_from_boxed(output.output_str)
+        output = result.outputs[0]
+        extracted_answer = extract_answer_from_boxed(output.text)
         try:
             # print(f"[QUESTION]:\n{prompt}\n\n[OUTPUT]\n\n{output.output_str}\n\n")
             answer = int(extracted_answer)
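For context, extract_answer_from_boxed (defined in tensorrt_llm/scaffolding/math_utils.py) pulls the final answer out of a LaTeX \boxed{...} span. A hedged sketch of the scoring step under that assumption; names beyond those in the diff are illustrative:

    output = result.outputs[0]
    extracted = extract_answer_from_boxed(output.text)  # assumed: "...\boxed{42}" -> "42"
    try:
        correct = int(extracted) == int(test_case["answer"])
    except (TypeError, ValueError):
        correct = False  # a missing or non-numeric boxed answer counts as wrong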
tensorrt_llm/scaffolding/__init__.py (1 change: 0 additions & 1 deletion)

@@ -12,7 +12,6 @@

 __all__ = [
     "ScaffoldingLlm",
-    "ScaffoldingOutput",
     "ParallelProcess",
     "Controller",
     "NativeGenerationController",
tensorrt_llm/scaffolding/controller.py (13 changes: 7 additions & 6 deletions)

@@ -1,7 +1,7 @@
 import copy
 from abc import ABC
 from enum import Enum
-from typing import Any, List, Mapping
+from typing import Any, List, Mapping, Tuple

 import torch
 from torch.nn import functional as F

@@ -231,13 +231,14 @@ def process(self,
                                              generation_kwargs_list)

         candidates = [tasks[0].output_str for tasks in tasks_list]
-        result = self.majority_vote(candidates, **majority_vote_kwargs)
+        majority_index, majority_answer = self.majority_vote(
+            candidates, **majority_vote_kwargs)

-        assert isinstance(result, str), "majority_vote failed"
+        assert isinstance(majority_answer, str), "majority_vote failed"
         # The task returned by majority vote does not have output_tokens and logits.
-        tasks[0].output_str = result
+        tasks[0].result = tasks_list[majority_index][0].result

-    def majority_vote(self, candidates: List[str], **kwargs) -> str:
+    def majority_vote(self, candidates: List[str], **kwargs) -> Tuple[int, str]:
         return get_digit_majority_vote_result(candidates)

@@ -292,7 +293,7 @@ def process(self,

         best_task, best_idx = self.select_best(generation_tasks, reward_values,
                                                **select_best_kwargs)
-        task.output_str = best_task.output_str
+        task.result = best_task.result

     def select_best(self, tasks: List[Task], reward_values, **kwargs) -> Task:
         max_index = torch.argmax(torch.tensor(reward_values)).item()
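Because process() now copies the winning task's whole GenerationResult instead of just its text, the majority_vote contract changed to return the winner's index as well as its answer. A hedged sketch of a custom controller under the new signature, assuming the enclosing class in this hunk is MajorityVoteController (the subclass name is hypothetical):

    from typing import List, Tuple

    class StrippedMajorityController(MajorityVoteController):  # hypothetical subclass
        def majority_vote(self, candidates: List[str],
                          **kwargs) -> Tuple[int, str]:
            # Must return (index_of_winning_candidate, answer_string) so that
            # process() can reuse tasks_list[index][0].result, which still
            # carries output_tokens and logits.
            return get_digit_majority_vote_result([c.strip() for c in candidates])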
tensorrt_llm/scaffolding/math_utils.py (34 changes: 18 additions & 16 deletions)

@@ -1,5 +1,4 @@
 import re
-from collections import Counter
 from typing import List

@@ -59,28 +58,31 @@ def get_majority_result(
         result_extractor=lambda x: x,
         result_validator=lambda x: True,
 ):
-    valid_answers_and_results = [(result, result_extractor(result))
-                                 for result in results
-                                 if result_validator(result) is True
-                                 and result_extractor(result) is not None]
-    if len(valid_answers_and_results) == 0:
+    extract_answers = [result_extractor(result) for result in results]
+    valid_answers = [
+        result for result in extract_answers
+        if result is not None and result_validator(result) is True
+    ]
+    if len(valid_answers) == 0:
         return None, None

-    majority_result = Counter(valid_answers_and_results).most_common(1)[0][0]
-    # return result and extracted result
-    return majority_result[0], majority_result[1]
+    answer_counts = {}
+    for answer in valid_answers:
+        answer_counts[answer] = answer_counts.get(answer, 0) + 1
+    majority_answer = max(answer_counts, key=answer_counts.get)
+    majority_index = next(
+        filter(lambda x: x[1] == majority_answer,
+               enumerate(extract_answers)))[0]
+    return majority_index, majority_answer


 def get_digit_majority_vote_result(results: List[str]) -> str:

     def is_digit(result: str):
-        extracted_answer = extract_answer_from_boxed(result)
-        if extracted_answer is None:
-            return False
-        return extracted_answer.isdigit()
+        return result.isdigit()

-    vote_result = get_majority_result(
+    index, extract_answer = get_majority_result(
         results,
         result_extractor=extract_answer_from_boxed,
-        result_validator=is_digit)[0]
-    return vote_result if vote_result else results[0]
+        result_validator=is_digit)
+    return (index, extract_answer) if extract_answer else (0, None)
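The old Counter-based version returned the winning raw string; the rewrite returns the index of the first candidate whose extracted answer matches the majority, plus the answer itself. A small worked example of the new behavior, assuming extract_answer_from_boxed("... \boxed{7}") returns "7":

    results = [
        "so the answer is \\boxed{7}",
        "the answer is \\boxed{9}",
        "hence \\boxed{7}",
    ]
    index, answer = get_digit_majority_vote_result(results)
    # answer == "7" (two votes to one); index == 0, the first candidate that
    # produced it, so the caller can reuse tasks_list[0][0].result, which
    # still has output_tokens and logits attached.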
tensorrt_llm/scaffolding/result.py (10 changes: 1 addition & 9 deletions)

@@ -1,23 +1,15 @@
 import asyncio
-from dataclasses import dataclass
 from typing import Mapping, Optional

 from tensorrt_llm.executor.result import GenerationResult


-@dataclass(slots=True)
-class ScaffoldingOutput:
-
-    def __init__(self):
-        self.output_str = None
-
-
 class ScaffoldingResult:

     def __init__(self, streaming_event: Optional[asyncio.Event] = None):
         super().__init__()
         self.aqueue = asyncio.Queue()
-        self.cur_output = None
+        self.cur_output: GenerationResult = None
         self._done = False
         self.task_collections = None
         self.streaming_event = streaming_event
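With ScaffoldingOutput deleted, ScaffoldingResult.cur_output now holds the final task's GenerationResult directly. A hedged consumer-side sketch: only cur_output is confirmed by this diff, and the examples above suggest results also forward an outputs list; the prompt is illustrative:

    result = llm.generate(["prompt"])[0]          # a ScaffoldingResult
    generation_result = result.cur_output         # a GenerationResult, no wrapper class
    if generation_result is not None:
        text = generation_result.outputs[0].text  # same shape as the LLM API's results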
tensorrt_llm/scaffolding/scaffolding_llm.py (2 changes: 1 addition & 1 deletion)

@@ -82,7 +82,7 @@ async def _handle_task_list(self,
         ]
         await asyncio.gather(*async_tasks)
         for task in tasks:
-            if task.streaming:
+            if getattr(task, 'streaming', False):
                 await request.result.set_output_async(task.result)
                 self.streaming_event.clear()
                 await self.streaming_event.wait()
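Only GenerationTask is guaranteed to define a streaming attribute; other Task subclasses (such as the dummy tasks in the unit tests) never declare it, so a plain attribute access would raise AttributeError. A minimal standalone illustration of the getattr default:

    class PlainTask:
        pass  # never declares `streaming`

    class StreamingTask:
        streaming = True

    assert getattr(PlainTask(), 'streaming', False) is False   # falls back instead of raising
    assert getattr(StreamingTask(), 'streaming', False) is True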
tensorrt_llm/scaffolding/task.py (27 changes: 13 additions & 14 deletions)

@@ -62,8 +62,6 @@ class GenerationTask(Task):
     worker_tag: Union[str, "Controller.WorkerTag"] = None

     # result field
-    _outputs: Optional[List[dict]] = None
-
     # link to TRTLLM's GenerationResult, for async update in streaming mode
     _result: Optional[GenerationResult] = None

@@ -74,35 +72,36 @@ def result(self) -> GenerationResult:
     @result.setter
     def result(self, result: GenerationResult) -> None:
         self._result = result
-        self._outputs = result.outputs

+    @property
+    def outputs(self) -> Optional[List[dict]]:
+        return self._result.outputs if self._result else None
+
     @property
     def output_tokens(self) -> List[int]:
-        return self._outputs[
-            0].token_ids if self.result and self._outputs else None
+        return self._result.outputs[0].token_ids if self._result else None

     @property
     def output_str(self) -> Optional[str]:
-        return self._outputs[0].text if self.result and self._outputs else None
+        return self._result.outputs[0].text if self._result else None

     @output_str.setter
     def output_str(self, output) -> Optional[str]:
-        assert self.result and self._outputs
-        self._outputs[0].text = output
+        assert self.result
+        self._result.outputs[0].text = output

     @property
     def cumulative_logprob(self) -> Optional[float]:
-        return self._outputs[
-            0].cumulative_logprob if self.result and self._outputs else None
+        return self._result.outputs[
+            0].cumulative_logprob if self._result else None

     @property
     def logprobs(self) -> Optional[List[float]]:
-        return self._outputs[
-            0].logprobs if self.result and self._outputs else None
+        return self._result.outputs[0].logprobs if self._result else None

     @property
     def context_logits(self) -> Optional[torch.Tensor]:
-        return self.result.context_logits if self.result else None
+        return self._result.context_logits if self._result else None

     @staticmethod
     def create_from_prompt(prompt: str) -> "GenerationTask":

@@ -113,7 +112,7 @@ def create_from_prompt(prompt: str) -> "GenerationTask":
         return task

     def create_scaffolding_output(self) -> GenerationResult:
-        return self.result
+        return self._result
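GenerationTask now derives every output accessor from _result instead of a separately cached _outputs list, and gains an outputs property for direct access, so a task whose result is reassigned can never drift out of sync with its cached view. A hedged sketch of the resulting behavior; some_generation_result stands in for a real completed GenerationResult:

    task = GenerationTask.create_from_prompt("2 + 2 = ?")
    task.result = some_generation_result  # setter stores _result only, no cached copy
    task.outputs                          # new property: _result.outputs, or None before generation
    task.output_str                       # -> some_generation_result.outputs[0].text
    task.output_str = "4"                 # writes through to _result.outputs[0].text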
tests/integration/test_lists/waives.txt (1 change: 0 additions & 1 deletion)

@@ -436,7 +436,6 @@ examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp
 examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5385987)
 examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5385992)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5377914)
-test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5387375)
 examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5387422)
 examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5387424)
 test_e2e.py::test_ptp_quickstart SKIP (https://nvbugs/5387762)
tests/unittest/scaffolding/test_bench.py (6 changes: 3 additions & 3 deletions)

@@ -13,7 +13,7 @@
 class DummyWorker(Worker):

     async def dummy_generation_handler(self, task: GenerationTask):
-        task.output_str = OUTPUT_STR
+        task.result = OUTPUT_STR
         return TaskStatus.SUCCESS

     task_handlers = {GenerationTask: dummy_generation_handler}

@@ -29,7 +29,7 @@ def before_yield(self, tasks: List[Task]):
         pass

     def after_yield(self, tasks: List[Task]):
-        self.output_len = len(tasks[0].output_str)
+        self.output_len = len(tasks[0].result)


 def test_scaffolding_benchmark():

@@ -56,6 +56,6 @@ def test_scaffolding_benchmark():

     assert len(results) == requests_num
     assert len(requests_execution_time) == requests_num
-    assert results[0].output.output_str == OUTPUT_STR
+    assert results[0].cur_output == OUTPUT_STR
     assert results[0].task_collections[
         "bench_dummy_collection"].output_len == len(OUTPUT_STR)
tests/unittest/scaffolding/test_parallel_process.py (8 changes: 0 additions & 8 deletions)

@@ -4,8 +4,6 @@
 from enum import Enum
 from typing import List

-import pytest
-
 from tensorrt_llm.scaffolding import (Controller, ParallelProcess,
                                       ScaffoldingLlm, Task, TaskStatus, Worker)

@@ -21,8 +19,6 @@ def create_from_prompt(prompt: str) -> "DummyTask":
         task = DummyTask(2)
         return task

-    # TODO: Fix when ScaffoldingOutput is replaced with GenerationResult
-    # def create_scaffolding_output(self) -> "ScaffoldingOutput":
     def create_scaffolding_output(self):
         self.verify()
         return None

@@ -34,8 +30,6 @@ def verify(self):

 class DummyControllerBase(Controller):

-    # TODO: Fix when ScaffoldingOutput is replaced with GenerationResult
-    # def generate(self, prompt: str, **kwargs) -> ScaffoldingOutput:
     def generate(self, prompt: str, **kwargs):
         task = DummyTask.create_from_prompt(prompt)
         yield from self.process([task], **kwargs)

@@ -125,7 +119,6 @@ def parallel_process_helper_run_and_verify(controllers):
     llm.shutdown()


-@pytest.skip(reason="ScaffoldingOutput removed in PR #5345, needs refactoring")
 def test_parallel_process_helper():
     NUM_CONTROLLERS = 3
     controllers = []

@@ -137,7 +130,6 @@ def test_parallel_process_helper():
     parallel_process_helper_run_and_verify(controllers)


-@pytest.skip(reason="ScaffoldingOutput removed in PR #5345, needs refactoring")
 def test_parallel_process_helper_with_two_level():
     NUM_CONTROLLERS_LEVEL_1 = 2
     NUM_CONTROLLERS_LEVEL_2 = 2
tests/unittest/scaffolding/test_task_collection.py (7 changes: 0 additions & 7 deletions)

@@ -2,8 +2,6 @@
 from enum import Enum
 from typing import List

-import pytest
-
 from tensorrt_llm.scaffolding import (Controller, ParallelProcess,
                                       ScaffoldingLlm, Task, TaskCollection,
                                       TaskStatus, Worker, with_task_collection)

@@ -20,8 +18,6 @@ def create_from_prompt(prompt: str) -> "DummyTask":
         task = DummyTask()
         return task

-    # TODO: Fix when ScaffoldingOutput is replaced with GenerationResult
-    # def create_scaffolding_output(self) -> "ScaffoldingOutput":
     def create_scaffolding_output(self):
         return None

@@ -55,8 +51,6 @@ def __init__(self, expected_task_count: int):
         super().__init__()
         self.expected_task_count = expected_task_count

-    # TODO: Fix when ScaffoldingOutput is replaced with GenerationResult
-    # def generate(self, prompt: str, **kwargs) -> ScaffoldingOutput:
     def generate(self, prompt: str, **kwargs):
         task = DummyTask.create_from_prompt(prompt)
         yield from self.process([task], **kwargs)

@@ -127,7 +121,6 @@ def run(controller, expected_task_count):
     llm.shutdown()


-@pytest.skip(reason="ScaffoldingOutput removed in PR #5345, needs refactoring")
 def test_dummy_task_collection():
     controller = DummyController(1)
     run(controller, 1)