diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index b4f7c19a8d3d..d605c4b65e9d 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -5,7 +5,18 @@
 import torch
 import torch_xla.core.xla_model as xm
 
-from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+DEFAULT_SAMPLING_PARAMS = dict(
+    temperature=-1.0,
+    min_p=0.0,
+    # strictly disabled for now
+    # top_k=-1,
+    # top_p=0.0,
+    # frequency_penalties=0.0,
+    # presence_penalties=0.0,
+    # repetition_penalties=0.0,
+)
 
 
 @dataclass
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
-    # XLA-unfriendly control flow in Sampler
-    all_greedy: bool = False
-    all_random: bool = False
     # Greedy sampling flag for compiling single xla graph.
-    do_argmax: torch.Tensor = None
-
-    # speculation not supported
-    spec_token_ids = None
+    all_greedy: torch.Tensor = None
 
     # Generator not supported by xla
     generators: dict[int,
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
     bad_words_token_ids = None
     indices_do_sample: torch.Tensor = None
 
-    def __post_init__(self):
-        temp = self.temperature
-        if self.indices_do_sample is None:
-            self.indices_do_sample = torch.zeros(temp.shape[0],
-                                                 device=temp.device,
-                                                 dtype=torch.int32)
-        if self.do_argmax is None:
-            self.do_argmax = torch.tensor(0,
-                                          dtype=torch.bool,
-                                          device=temp.device)
-
     @classmethod
-    def from_sampling_metadata(
-        cls, metadata: SamplingMetadata,
-        padded_do_sample_indices: torch.Tensor, num_do_sample: int,
-        device: torch.device) -> "TPUSupportedSamplingMetadata":
+    def from_input_batch(
+        cls, input_batch: InputBatch,
+        indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
         """
-        Create an XLA-frienly SamplingMetadata structure. Do so by first
-        instantiating an object with fixed-sized tensors and then writing the
-        values in input `metadata`. Do that only for non-None values so that
-        recompilation is not triggered for optional values (None/torch.Tensor).
-
-        In order to handle different sizes for the params that range from 1 up
-        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
-        Same thing for `padded_do_sample_indices`, which contains the indices
-        to be fed to the Sampler, padded to the closest pre-compiled shape.
-
-        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
-            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
+        Copy slices of the sampling params from `input_batch` to device tensors.
+
+        `InputBatch._make_sampling_metadata` triggers XLA recompilation, as
+        it slices device tensors with dynamic shapes. This implementation
+        moves the dynamic ops to the CPU and produces fixed-size tensors of
+        length `padded_num_reqs`. It also reuses the persistent on-device
+        tensors managed by `input_batch` to reduce waste.
+
+        `indices_do_sample` contains the indices to be fed to the Sampler,
+        normally one per request, padded to the closest pre-compiled shape.
+        We expect sampling params tensors to be padded to the same fixed shape.
+
+        Eg. 3 requests, tensors padded to 4
+            temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
+            sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
         """
-        metadata = cls._validate_sampling_metadata(metadata)
-        # NOTE we have to initialize default tensor-based params first and
-        # skip None values altogether to produce the same xla graph.
-        num_samples = len(padded_do_sample_indices)
-        do_argmax = torch.tensor(metadata.all_greedy,
-                                 dtype=torch.bool,
-                                 device=device)
-        new_metadata = cls.get_default_sampling_params(num_samples, device,
-                                                       indices_do_sample=\
-                                                       padded_do_sample_indices,
-                                                       do_argmax=do_argmax
-                                                       )
-        supported_params = \
-            TPUSupportedSamplingMetadata._get_default_params_values()
-        # Copy input non-None values into `new_metadata` fixed-sized tensors.
-        for p_name in supported_params:
-            old_val = getattr(metadata, p_name)
-            new_val = getattr(new_metadata, p_name)
-            if isinstance(old_val, torch.Tensor):
-                new_val[:num_do_sample] = old_val
-            setattr(new_metadata, p_name, new_val)
+        num_reqs = input_batch.num_reqs
+        padded_num_reqs = len(indices_do_sample)
+
+        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
+                       fill_val) -> torch.Tensor:
+            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+            # Pad value is the default one.
+            cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
+
+        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
+        # consistent. We can't have flags to skip copies or we'll end up
+        # recompiling.
+        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+                   DEFAULT_SAMPLING_PARAMS["temperature"])
+        # TODO Temporarily disabled until sampling options are enabled
+        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
+        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
+        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+                   DEFAULT_SAMPLING_PARAMS["min_p"])
+
+        # copy_slice(input_batch.frequency_penalties_cpu_tensor,
+        #            input_batch.frequency_penalties)
+        # copy_slice(input_batch.presence_penalties_cpu_tensor,
+        #            input_batch.presence_penalties)
+        # copy_slice(input_batch.repetition_penalties_cpu_tensor,
+        #            input_batch.repetition_penalties)
 
         xm.mark_step()
         xm.wait_device_ops()
-        return new_metadata
 
-    @classmethod
-    def get_default_sampling_params(
-            cls,
-            num_samples: int,
-            device: torch.device,
-            indices_do_sample=None,
-            do_argmax=None) -> "TPUSupportedSamplingMetadata":
-        # As sampling happens on a single traced graph, options
-        # are "disabled" by having them evaluate to an Identity op.
-        # Note that initialization is dependent on num_samples.
-        sampling_metadata_disable_value = \
-            TPUSupportedSamplingMetadata._get_default_params_values()
-        init_kwargs = dict()
-        for p_name, (default_val,
-                     dtype) in sampling_metadata_disable_value.items():
-            default_tensor = torch.full((num_samples, ),
-                                        default_val,
-                                        dtype=dtype,
-                                        device=device)
-            init_kwargs[p_name] = default_tensor
-
-        return cls(**init_kwargs,
-                   indices_do_sample=indices_do_sample,
-                   do_argmax=do_argmax)
-
-    @staticmethod
-    def _validate_sampling_metadata(
-            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
-        if sampling_metadata.all_greedy:
-            # Set to None since #13587. Make sure default isn't overruled.
-            assert sampling_metadata.temperature is None
-        return sampling_metadata
-
-    @staticmethod
-    def _get_default_params_values():
-        return dict(
-            # Since #13587 greedy sampling requires branching off which leads
-            # to separate graphs. We set temp to noop and handle argmax here.
-            temperature=(1.0, torch.float32),
-            min_p=(0.0, torch.float32),
-            # strictly disabled for now
-            # top_k=(-1, torch.int32),
-            # top_p=(0.0, torch.float32),
-            # frequency_penalties=(0.0, torch.float32),
-            # presence_penalties=(0.0, torch.float32),
-            # repetition_penalties=(0.0, torch.float32),
-        )
\ No newline at end of file
+        # Slice persistent device tensors to a fixed pre-compiled padded shape.
+        return cls(
+            temperature=input_batch.temperature[:padded_num_reqs],
+            # Scalar tensor for xla-friendly tracing.
+            all_greedy=torch.tensor(input_batch.all_greedy,
+                                    dtype=torch.bool,
+                                    device=input_batch.device),
+            # TODO enable more and avoid returning None values
+            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_k=None,  # input_batch.top_k[:padded_num_reqs],
+            min_p=input_batch.min_p[:padded_num_reqs],
+            generators=input_batch.generators,
+            indices_do_sample=indices_do_sample)
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f4502f6b4237..0e9473a33453 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -279,9 +279,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 req_data.num_computed_tokens)
             self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                     req_index)
-        # Check if the batch has changed. If not, we can skip copying the
-        # sampling metadata from CPU to GPU.
-        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
@@ -300,9 +297,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        # TODO This slices tensors to copy to device, triggering recompilation.
-        if batch_changed:
-            self.input_batch.refresh_sampling_metadata()
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def get_model(self) -> nn.Module:
@@ -597,14 +591,12 @@ def execute_model(
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids
             inputs_embeds = None
-        sampling_metadata = self.input_batch.sampling_metadata
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: if there's any shape
-        # mismatch in pre-processing, it will trigger a small recompilation
-        # of the code thus far. Forward graph remains untouched.
+        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
+        # are copied to device in chunks of a pre-compiled padded shape to
+        # avoid recompilations.
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_sampling_metadata(sampling_metadata, logits_indices,
-                                   num_reqs, self.device)
+            from_input_batch(self.input_batch, logits_indices)
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -797,21 +789,19 @@ def capture_model(self) -> None:
                                    device=device,
                                    dtype=torch.bfloat16)
         while True:
-            # Default metadata is an all_greedy setup. But since the
-            # `do_argmax` flag is a tensor, we still compile the full graph
-            meta = self.input_batch.sampling_metadata
             indices = torch.zeros(
                 num_reqs_to_sample,
                 dtype=torch.int32,
                 device=device,
             )
+            xm.mark_step()
             sampling_meta = TPUSupportedSamplingMetadata.\
-                from_sampling_metadata(meta, indices,
-                                       num_reqs_to_sample, device)
+                from_input_batch(self.input_batch, indices)
             logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            self.model.sample_from_hidden(dummy_hidden, sampling_meta)
-            xm.mark_step()
+            out = self.model.sample_from_hidden(dummy_hidden,
+                                                sampling_meta)
+            out = out.cpu()
             if num_reqs_to_sample >= self.max_num_reqs:
                 break
             num_reqs_to_sample *= 2
@@ -910,6 +900,7 @@ def forward(
 
         return hidden_states
 
+    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,
@@ -923,10 +914,9 @@ def sample_from_hidden(
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
         logits = self.compute_logits(sample_hidden_states)
-        # Greedy sampling can't be run without branching the graph on Sampler.
-        # Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
-        # NOTE do_argmax is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.do_argmax,
+        # Optimized greedy sampling: both paths are traced in a single pass.
+        # NOTE all_greedy is a scalar; this is just an optimized if/else.
+        out_tokens = torch.where(sampling_metadata.all_greedy,
                                  torch.argmax(logits, dim=-1, keepdim=True),
                                  self.sample(logits, sampling_metadata)\
                                      .sampled_token_ids)
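
For reference, a minimal sketch of the padding scheme `from_input_batch` relies on (not part of the patch; the helper `pick_padded_size` and the size list are illustrative, not vLLM APIs). Per-request values are padded on the CPU to the nearest pre-compiled batch size before being copied to the device, so the traced XLA graph always sees the same shapes.

import torch

def pick_padded_size(num_reqs: int, sizes=(8, 16, 32, 64)) -> int:
    # Smallest pre-compiled batch size that fits the current number of requests.
    return next(s for s in sizes if num_reqs <= s)

num_reqs = 3
padded_num_reqs = pick_padded_size(num_reqs)  # -> 8
# Pad with the fill value from DEFAULT_SAMPLING_PARAMS (-1.0 for temperature).
temperature = torch.full((padded_num_reqs, ), -1.0)
temperature[:num_reqs] = torch.tensor([0.7, 0.2, 0.9])
# indices_do_sample is padded the same way; padded slots point at row 0.
indices_do_sample = torch.zeros(padded_num_reqs, dtype=torch.int32)
indices_do_sample[:num_reqs] = torch.tensor([4, 10, 11], dtype=torch.int32)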
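For the `torch.where` selection in `sample_from_hidden`, a small self-contained sketch (assumed batch and vocab sizes; not the patch's code): both the argmax path and the sampling path are computed, and a scalar boolean tensor picks the result, so greedy and random decoding share one compiled graph instead of branching in Python.

import torch

logits = torch.randn(4, 32000)   # [padded_num_reqs, vocab_size]
all_greedy = torch.tensor(True)  # scalar tensor, not a Python bool
greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
random_ids = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
# Both branches are evaluated; the scalar condition only selects the output,
# so the traced graph is identical regardless of the sampling mode.
out_tokens = torch.where(all_greedy, greedy_ids, random_ids)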
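And a sketch of the warm-up loop in `capture_model` (hypothetical sizes, not the patch's code): the sampling graph is traced once per padded batch size, doubling from a small minimum up to `max_num_reqs`, so serving never hits a request count whose padded shape has not been pre-compiled.

def warmup_batch_sizes(min_reqs: int, max_num_reqs: int):
    # Padded sizes the runner pre-compiles for: min, 2*min, ... up to the max.
    n = min_reqs
    while True:
        yield n
        if n >= max_num_reqs:
            break
        n *= 2

# e.g. max_num_reqs=256 -> trace the sampling graph for 8, 16, 32, 64, 128, 256
for num_reqs_to_sample in warmup_batch_sizes(8, 256):
    print(f"pre-compiling sampling graph for {num_reqs_to_sample} requests")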