Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/v1/sample/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logger = init_logger(__name__)

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/sample/tpu/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TPUSupportedSamplingMetadata:
top_p: torch.Tensor = None

all_greedy: bool = True
all_random: bool = False

# Whether logprobs are to be gathered in this batch of request. To balance
# out compile time and runtime, a fixed `max_number_logprobs` value is used
Expand Down Expand Up @@ -110,6 +111,7 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
xla_device
),
all_greedy=input_batch.all_greedy,
all_random=input_batch.all_random,
# TODO enable more and avoid returning None values
top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
Expand Down
8 changes: 7 additions & 1 deletion vllm/v1/sample/tpu/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
all_random: bool = False,
) -> torch.Tensor:
# Avoid division by zero for greedy sampling (temperature ~ 0.0).
if not all_random:
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
return logits.div_(temp.unsqueeze(dim=1))

def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
Expand All @@ -56,7 +60,9 @@ def sample(
assert sampling_metadata.temperature is not None

# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_temperature(
logits, sampling_metadata.temperature, sampling_metadata.all_random
)

# Apply min_p.
if sampling_metadata.min_p is not None:
Expand Down
12 changes: 10 additions & 2 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
Expand Down Expand Up @@ -1140,8 +1141,15 @@ def compute_probs_and_sample_next_token(
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs

is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
assert sampling_metadata.temperature is not None

# Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
# consistent with sampler.py's _SAMPLING_EPS threshold
temperature = sampling_metadata.temperature
# Avoid division by zero if there are greedy requests.
if not sampling_metadata.all_random:
is_greedy = temperature < _SAMPLING_EPS
temperature = torch.where(is_greedy, 1.0, temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)

Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/worker/tpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def add_request(
sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
# Should avoid division by zero later when apply_temperature.
self.temperature_cpu[req_index] = 0.0
Comment thread
Pradyun92 marked this conversation as resolved.
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
Expand Down