[Bugfix][Logprobs] Fix logprobs op to support more backend#21591

Merged
vllm-bot merged 1 commit into vllm-project:main from MengqingCao:compile
Jul 25, 2025

Conversation

@MengqingCao (Contributor) commented Jul 25, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR fixes the logprobs op to support more backends. Currently it only supports the inductor backend, which may break some hardware platforms.

Closes: #21592
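The core of the fix is to stop hardcoding `backend="inductor"` when compiling the op and instead pick up whatever backend the current platform declares (`current_platform.simple_compile_backend`). The standalone sketch below illustrates only the selection pattern; the platform classes and the `compile_op` helper are hypothetical stand-ins, not vLLM's actual implementation.

```python
# Hypothetical sketch of the pattern behind this fix (not vLLM's actual code):
# instead of hardcoding backend="inductor" when compiling an op, look up the
# backend the current platform declares, so non-CUDA platforms (e.g. NPU)
# can substitute their own.

class CudaPlatform:
    simple_compile_backend = "inductor"

class AscendPlatform:
    # Inductor lowering is unsupported on this platform, so it falls back to eager.
    simple_compile_backend = "eager"

def compile_op(fn, platform):
    """Stand-in for torch.compile: records which backend would be used."""
    backend = platform.simple_compile_backend
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs), backend
    return wrapped

def count_greater(xs, v):
    # Count elements strictly greater than v (the scalar analogue of the op).
    return sum(1 for x in xs if x > v)

op = compile_op(count_greater, AscendPlatform)
result, backend = op([0.1, 0.9, 0.5], 0.4)
```

With the backend resolved per platform rather than fixed at the call site, the same op definition works on CUDA (inductor) and on platforms where inductor lowering fails.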

Test Plan

Tested with vllm-ascend using the following script:

from vllm import LLM, SamplingParams

def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, prompt_logprobs=1, temperature=0.0)
    # Create an LLM.
    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()

Test Result

Before this PR:

Traceback (most recent call last):
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/home/xxx/code/vllm-cpu/vllm/vllm/entrypoints/llm.py", line 1701, in _run_engine
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 638, in run_engine_core
    raise e
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 627, in run_engine_core
    engine_core.run_busy_loop()
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 654, in run_busy_loop
    self._process_engine_step()
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 679, in _process_engine_step
    outputs, model_executed = self.step_fn()
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 268, in step
    model_output = self.execute_model_with_error_logging(
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 254, in execute_model_with_error_logging
    raise err
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core.py", line 245, in execute_model_with_error_logging
    return model_fn(scheduler_output)
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/executor/abstract.py", line 87, in execute_model
    output = self.collective_rpc("execute_model",
  File "/home/xxx/code/vllm-cpu/vllm/vllm/executor/uniproc_executor.py", line 58, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/home/xxx/code/vllm-cpu/vllm/vllm/utils/__init__.py", line 2986, in run_method
    return func(*args, **kwargs)
  File "/home/xxx/code/vllm-ascend/vllm_ascend/worker/worker_v1.py", line 190, in execute_model
    output = self.model_runner.execute_model(scheduler_output,
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/xxx/code/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1523, in execute_model
    prompt_logprobs_dict = self._get_prompt_logprobs_dict(
  File "/home/xxx/code/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 2387, in _get_prompt_logprobs_dict
    token_ids, logprobs, ranks = self.sampler.gather_logprobs(
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/sample/sampler.py", line 191, in gather_logprobs
    token_ranks = batched_count_greater_than(logprobs, token_logprobs)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: TypeError: 'NoneType' object is not callable
  target: aten.sum.dim_IntList
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'npu',
      torch.bool,
      def inner_fn(index):
          i0, i1 = index
          tmp0 = ops.load(arg2_1, i1 + i0 * s1)
          tmp1 = ops.load(arg3_1, i0)
          tmp2 = tmp0 >= tmp1
          return tmp2
      ,
      ranges=[s0, s1],
      origin_node=ge,
      origins=OrderedSet([ge])
    )
  ))
  args[1]: [-1]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

    step_outputs = self.llm_engine.step()
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/llm_engine.py", line 237, in step
    outputs = self.engine_core.get_output()
  File "/home/xxx/code/vllm-cpu/vllm/vllm/v1/engine/core_client.py", line 582, in get_output
    raise self._format_exception(outputs) from None

After this PR:

INFO 07-25 07:14:45 [__init__.py:38] Available plugins for group vllm.platform_plugins:
INFO 07-25 07:14:45 [__init__.py:40] - ascend -> vllm_ascend:register
INFO 07-25 07:14:45 [__init__.py:43] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 07-25 07:14:45 [__init__.py:226] Platform plugin ascend is activated
WARNING 07-25 07:14:48 [_custom_ops.py:20] Failed to import from vllm._C with ImportError('libnuma.so.1: cannot open shared object file: No such file or directory')
INFO 07-25 07:14:51 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture DeepSeekMTPModel is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture Qwen2VLForConditionalGeneration is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture Qwen2_5_VLForConditionalGeneration is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture DeepseekV3ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture Qwen3MoeForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM.
WARNING 07-25 07:14:52 [registry.py:435] Model architecture Qwen3ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen3:CustomQwen3ForCausalLM.
2025-07-25 07:14:52,570 - modelscope - INFO - Not logged-in, you can login for uploadingor accessing controlled entities.
Downloading Model from https://www.modelscope.cn to directory: /home/xxx/cache/modelscope/models/Qwen/Qwen2.5-0.5B-Instruct
2025-07-25 07:14:53,085 - modelscope - INFO - Target directory already exists, skipping creation.
Downloading Model from https://www.modelscope.cn to directory: /home/xxx/cache/modelscope/models/Qwen/Qwen2.5-0.5B-Instruct
2025-07-25 07:14:54,035 - modelscope - INFO - Target directory already exists, skipping creation.
INFO 07-25 07:15:11 [config.py:1605] Using max model len 32768
INFO 07-25 07:15:11 [config.py:2416] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 07-25 07:15:11 [platform.py:157] PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode
INFO 07-25 07:15:11 [utils.py:311] Calculated maximum supported batch sizes for ACL graph: 76
INFO 07-25 07:15:11 [utils.py:337] No adjustment needed for ACL graph batch sizes: Qwen2ForCausalLM model (layers: 24) with 67 sizes
Downloading Model from https://www.modelscope.cn to directory: /home/xxx/cache/modelscope/models/Qwen/Qwen2.5-0.5B-Instruct
2025-07-25 07:15:12,255 - modelscope - INFO - Target directory already exists, skipping creation.
Downloading Model from https://www.modelscope.cn to directory: /home/xxx/cache/modelscope/models/Qwen/Qwen2.5-0.5B-Instruct
2025-07-25 07:15:13,782 - modelscope - INFO - Target directory already exists, skipping creation.
INFO 07-25 07:15:24 [__init__.py:38] Available plugins for group vllm.platform_plugins:
INFO 07-25 07:15:24 [__init__.py:40] - ascend -> vllm_ascend:register
INFO 07-25 07:15:24 [__init__.py:43] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 07-25 07:15:24 [__init__.py:226] Platform plugin ascend is activated
WARNING 07-25 07:15:27 [_custom_ops.py:20] Failed to import from vllm._C with ImportError('libnuma.so.1: cannot open shared object file: No such file or directory')
INFO 07-25 07:15:27 [core.py:574] Waiting for init message from front-end.
INFO 07-25 07:15:29 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture DeepSeekMTPModel is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture Qwen2VLForConditionalGeneration is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture Qwen2_5_VLForConditionalGeneration is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture DeepseekV3ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture Qwen3MoeForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM.
WARNING 07-25 07:15:30 [registry.py:435] Model architecture Qwen3ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen3:CustomQwen3ForCausalLM.
INFO 07-25 07:15:30 [core.py:72] Initializing a V1 LLM engine (v0.9.2.dev301+g3c545c0c3) with config: model='Qwen/Qwen2.5-0.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-0.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=npu, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen2.5-0.5B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["all"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.unified_ascend_attention_with_output"],"use_inductor":false,"compile_sizes":[],"inductor_compile_config":{},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
INFO 07-25 07:15:33 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 07-25 07:15:33 [model_runner_v1.py:1800] Starting to load model Qwen/Qwen2.5-0.5B-Instruct...
2025-07-25 07:15:35,213 - modelscope - INFO - Not logged-in, you can login for uploadingor accessing controlled entities.
Downloading Model from https://www.modelscope.cn to directory: /home/xxx/cache/modelscope/models/Qwen/Qwen2.5-0.5B-Instruct
2025-07-25 07:15:35,778 - modelscope - INFO - Target directory already exists, skipping creation.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.69it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.69it/s]

INFO 07-25 07:15:36 [default_loader.py:262] Loading weights took 0.37 seconds
INFO 07-25 07:15:37 [model_runner_v1.py:1833] Loading model weights took 0.9278 GB
INFO 07-25 07:15:44 [backends.py:530] Using cache directory: /home/xxx/.cache/vllm/torch_compile_cache/36b867fecb/rank_0_0/backbone for vLLM's torch.compile
INFO 07-25 07:15:44 [backends.py:541] Dynamo bytecode transform time: 6.48 s
INFO 07-25 07:15:47 [backends.py:215] Compiling a graph for dynamic shape takes 2.24 s
INFO 07-25 07:15:54 [monitor.py:34] torch.compile takes 8.72 s in total
INFO 07-25 07:15:55 [worker_v1.py:175] Available memory: 56856685772, total memory: 65464696832
INFO 07-25 07:15:55 [kv_cache_utils.py:833] GPU KV cache size: 4,626,944 tokens
INFO 07-25 07:15:55 [kv_cache_utils.py:837] Maximum concurrency for 32,768 tokens per request: 141.20x
INFO 07-25 07:16:42 [model_runner_v1.py:2129] Graph capturing finished in 48 secs, took 0.17 GiB
INFO 07-25 07:16:42 [core.py:194] init engine (profile, create kv cache, warmup model) took 65.29 seconds
Downloading Model from https://www.modelscope.cn to directory: /home/xxx/cache/modelscope/models/Qwen/Qwen2.5-0.5B-Instruct
2025-07-25 07:16:43,684 - modelscope - INFO - Target directory already exists, skipping creation.
INFO 07-25 07:16:44 [platform.py:157] PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode
INFO 07-25 07:16:44 [utils.py:311] Calculated maximum supported batch sizes for ACL graph: 76
INFO 07-25 07:16:44 [utils.py:337] No adjustment needed for ACL graph batch sizes: Qwen2ForCausalLM model (layers: 24) with 67 sizes
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 260.74it/s]
Processed prompts: 100%|████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.64it/s, est. speed input: 14.53 toks/s, output: 264.15 toks/s]
Prompt: 'Hello, my name is', Generated text: ' Alex and I am a 16 year old male. I have been diagnosed with a rare genetic disorder called X-linked recessive. I have been told that I will not be able to have children. I have been told that I will not be able to have children because of the X-linked recessive disorder. I have been told that I will not be able to have children because of the X-linked recessive disorder. I have been told that I will not be able to have children because of'
Prompt: 'The president of the United States is', Generated text: ' a very important person. He is the leader of the country. He is the president of the United States. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the country. He is the leader of the'
Prompt: 'The capital of France is', Generated text: ' Paris. It is the largest city in Europe and the second largest city in the world. It is located in the south of France, on the banks of the Seine River. It is situated on the Île de la Cité, which is a small island in the center of the city. The city is surrounded by the Seine River and the Mediterranean Sea. The city is also surrounded by the Pyrenees mountains. The city is home to many famous landmarks, including the Eiffel'
Prompt: 'The future of AI is', Generated text: ' in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of'

(Optional) Documentation Update

@gemini-code-assist bot left a comment
Code Review

The pull request fixes the torch.compile backend selection for the logprobs op so that it supports more backends. Consider using lazy compilation to allow backend selection based on runtime parameters.

gemini-code-assist bot review comment (severity: high):

Evaluating current_platform.simple_compile_backend at module import time makes the backend choice static for the lifetime of the process. Consider using lazy compilation to allow backend selection based on runtime parameters.[1]

def _batched_count_greater_than_impl(x: torch.Tensor,
                                     values: torch.Tensor) -> torch.Tensor:
    """Implementation of batched_count_greater_than."""
    return (x > values[..., None]).count_nonzero(dim=-1)

_cached_compiled_fn = None

def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """
    For each row in `x`, counts the number of elements that are greater than
    the corresponding value in `values`.

    Args:
        x: A 2D tensor of shape (num_rows, num_elements).
        values: A 1D tensor of shape (num_rows,).
    """
    global _cached_compiled_fn
    if _cached_compiled_fn is None:
        from vllm.platforms import current_platform
        _cached_compiled_fn = torch.compile(
            dynamic=True,
            backend=current_platform.simple_compile_backend
        )(_batched_count_greater_than_impl)
    
    return _cached_compiled_fn(x, values)
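For reference, the semantics of the op under discussion can be written in plain Python. This is an illustrative reimplementation (assumed equivalent behavior, not vLLM code): for each row of `x`, count the entries strictly greater than the corresponding value in `values`. In logprob gathering, this count is the 0-based rank of each chosen token's logprob within its row.

```python
# Plain-Python reference for batched_count_greater_than (illustrative only):
# for each row of x, count entries strictly greater than that row's value.

def batched_count_greater_than_ref(x, values):
    return [sum(1 for e in row if e > v) for row, v in zip(x, values)]

logprobs = [
    [-0.1, -2.3, -0.7],  # row 0: token logprobs over a tiny vocab
    [-1.5, -0.2, -3.0],  # row 1
]
chosen = [-0.7, -0.2]     # logprob of the chosen token in each row

ranks = batched_count_greater_than_ref(logprobs, chosen)
# row 0: only -0.1 exceeds -0.7, so rank 1; row 1: nothing exceeds -0.2, so rank 0
```

The tensor version `(x > values[..., None]).count_nonzero(dim=-1)` computes the same quantity in one vectorized pass.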


Footnotes

  1. Use lazy compilation to allow backend selection based on runtime parameters.

@MengqingCao (Author) replied:

I think using a static compile backend is enough for a specific platform.

@github-actions bot commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@Yikun (Member) commented Jul 25, 2025

@houseroad Would you mind taking a look? Many thanks.

@Yikun (Member) commented Jul 25, 2025

@MengqingCao is running an e2e test; we will paste the e2e results after the tests complete.

(later) Updated the PR description. I also ran an e2e test on vllm-project/vllm-ascend#1927 and it works as expected.

Signed-off-by: MengqingCao <cmq0113@163.com>
@DarkLight1337 (Member) left a comment

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 25, 2025 10:26
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 25, 2025
@vllm-bot vllm-bot merged commit f3a683b into vllm-project:main Jul 25, 2025
61 of 66 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Failed to execute_model with logprobs on v0.10.0rc2

4 participants