Neuron backend may select wrong batch size for cached model #3299

@jimburtoft

Description

System Info

Container:
https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference/445854433?tag=sha-9f38d93-neuron

Running on HF Neuron DLAMI.

Information

  • Docker
  • The CLI directly

Tasks

  • An officially supported command
  • My own modifications

Reproduction

Running on the HF Neuron DLAMI.

docker-compose.yaml:

services:
  tgi-1:
    image: ghcr.io/huggingface/text-generation-inference:latest-neuron
    ports:
      - "8080:8080"
    environment:
      - PORT=8080
      - MODEL_ID=ibm-granite/granite-3.1-8b-instruct
      #- HF_AUTO_CAST_TYPE='bf16'
      - HF_NUM_CORES=2
      - MAX_BATCH_SIZE=4
      #- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
      - MAX_TOTAL_TOKENS=4096
      #- MAX_CONCURRENT_REQUESTS=512
      #- HF_TOKEN=${HF_TOKEN} #only needed for gated models
    devices:
      - "/dev/neuron0"

(aws_neuronx_venv_pytorch_2_5) ubuntu@ip-172-31-10-149:~$ docker compose -f docker-compose.yaml up

[+] Running 1/1
 ✔ Container ubuntu-tgi-1-1  Created                                                                                  0.0s 
Attaching to tgi-1-1
tgi-1-1  | WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
tgi-1-1  | WARNING:root:Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.
tgi-1-1  | /usr/local/lib/python3.10/dist-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
tgi-1-1  |   from neuronx_distributed.modules.moe.blockwise import (
tgi-1-1  | /usr/local/lib/python3.10/dist-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
tgi-1-1  |   from neuronx_distributed.modules.moe.blockwise import (
tgi-1-1  | /usr/local/lib/python3.10/dist-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
tgi-1-1  |   from neuronx_distributed.modules.moe.blockwise import (
tgi-1-1  | 2025-07-18T20:33:54.299502Z  INFO text_generation_launcher: Args {
tgi-1-1  |     model_id: "ibm-granite/granite-3.1-8b-instruct",
tgi-1-1  |     revision: None,
tgi-1-1  |     validation_workers: 2,
tgi-1-1  |     sharded: None,
tgi-1-1  |     num_shard: None,
tgi-1-1  |     quantize: None,
tgi-1-1  |     speculate: None,
tgi-1-1  |     dtype: None,
tgi-1-1  |     kv_cache_dtype: None,
tgi-1-1  |     trust_remote_code: false,
tgi-1-1  |     max_concurrent_requests: 128,
tgi-1-1  |     max_best_of: 2,
tgi-1-1  |     max_stop_sequences: 4,
tgi-1-1  |     max_top_n_tokens: 5,
tgi-1-1  |     max_input_tokens: Some(
tgi-1-1  |         2048,
tgi-1-1  |     ),
tgi-1-1  |     max_input_length: None,
tgi-1-1  |     max_total_tokens: Some(
tgi-1-1  |         4096,
tgi-1-1  |     ),
tgi-1-1  |     waiting_served_ratio: 0.3,
tgi-1-1  |     max_batch_prefill_tokens: Some(
tgi-1-1  |         16384,
tgi-1-1  |     ),
tgi-1-1  |     max_batch_total_tokens: None,
tgi-1-1  |     max_waiting_tokens: 20,
tgi-1-1  |     max_batch_size: Some(
tgi-1-1  |         8,
tgi-1-1  |     ),
tgi-1-1  |     cuda_graphs: None,
tgi-1-1  |     hostname: "07f2479259fe",
tgi-1-1  |     port: 8080,
tgi-1-1  |     prometheus_port: 9000,
tgi-1-1  |     shard_uds_path: "/tmp/text-generation-server",
tgi-1-1  |     master_addr: "localhost",
tgi-1-1  |     master_port: 29500,
tgi-1-1  |     huggingface_hub_cache: Some(
tgi-1-1  |         "/tmp",
tgi-1-1  |     ),
tgi-1-1  |     weights_cache_override: None,
tgi-1-1  |     disable_custom_kernels: false,
tgi-1-1  |     cuda_memory_fraction: 1.0,
tgi-1-1  |     rope_scaling: None,
tgi-1-1  |     rope_factor: None,
tgi-1-1  |     json_output: false,
tgi-1-1  |     otlp_endpoint: None,
tgi-1-1  |     otlp_service_name: "text-generation-inference.router",
tgi-1-1  |     cors_allow_origin: [],
tgi-1-1  |     api_key: None,
tgi-1-1  |     watermark_gamma: None,
tgi-1-1  |     watermark_delta: None,
tgi-1-1  |     ngrok: false,
tgi-1-1  |     ngrok_authtoken: None,
tgi-1-1  |     ngrok_edge: None,
tgi-1-1  |     tokenizer_config_path: None,
tgi-1-1  |     disable_grammar_support: false,
tgi-1-1  |     env: false,
tgi-1-1  |     max_client_batch_size: 4,
tgi-1-1  |     lora_adapters: None,
tgi-1-1  |     usage_stats: On,
tgi-1-1  |     payload_limit: 2000000,
tgi-1-1  |     enable_prefill_logprobs: false,
tgi-1-1  |     graceful_termination_timeout: 90,
tgi-1-1  | }
tgi-1-1  | 2025-07-18T20:33:55.392501Z  WARN text_generation_launcher::gpu: Cannot determine GPU compute capability: AssertionError: Torch not compiled with CUDA enabled
tgi-1-1  | 2025-07-18T20:33:55.392531Z  INFO text_generation_launcher: Using attention flashinfer - Prefix caching true
tgi-1-1  | 2025-07-18T20:33:55.392537Z  INFO text_generation_launcher: Using default cuda graphs [1, 2, 4, 8, 16, 32]
tgi-1-1  | 2025-07-18T20:33:55.392638Z  INFO download: text_generation_launcher: Starting check and download process for ibm-granite/granite-3.1-8b-instruct
tgi-1-1  | 2025-07-18T20:33:55.483205Z  WARN text_generation_launcher: 'extension' argument is not supported and will be ignored.
tgi-1-1  | 2025-07-18T20:33:55.483246Z  WARN text_generation_launcher: 'merge_lora' argument is not supported and will be ignored.
tgi-1-1  | 2025-07-18T20:34:02.614961Z  WARN text_generation_launcher: ibm-granite/granite-3.1-8b-instruct is not a neuron model: it will be exported using cached artifacts.
tgi-1-1  | 2025-07-18T20:34:02.615033Z  INFO text_generation_launcher: Cache disk [/tmp]: total = 496.03 G, free = 396.94 G
tgi-1-1  | 2025-07-18T20:34:02.647749Z  INFO text_generation_launcher: Model weights fetched in 0.03 s.
tgi-1-1  | 2025-07-18T20:34:02.647828Z  INFO text_generation_launcher: Cache disk [/tmp]: total = 496.03 G, free = 396.94 G
tgi-1-1  | 2025-07-18T20:34:03.403412Z  INFO download: text_generation_launcher: Successfully downloaded weights for ibm-granite/granite-3.1-8b-instruct
tgi-1-1  | 2025-07-18T20:34:03.403744Z  INFO shard-manager: text_generation_launcher: Starting shard rank=0
tgi-1-1  | 2025-07-18T20:34:07.269504Z  INFO text_generation_launcher: Exporting model to neuron with config: {'batch_size': 8, 'sequence_length': 4096, 'num_cores': 2, 'auto_cast_type': 'bf16'}.
tgi-1-1  | 2025-07-18T20:34:13.415084Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
tgi-1-1  | 2025-07-18T20:34:23.422447Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
tgi-1-1  | 2025-07-18T20:34:33.430380Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
tgi-1-1  | 2025-07-18T20:34:43.442764Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
tgi-1-1  | 2025-07-18T20:34:51.219431Z  INFO text_generation_launcher: Model successfully loaded in 43.95 s.
tgi-1-1  | 2025-07-18T20:34:51.343924Z  INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
tgi-1-1  | 2025-07-18T20:34:51.349221Z  INFO shard-manager: text_generation_launcher: Shard ready in 47.94151876s rank=0
tgi-1-1  | 2025-07-18T20:34:51.442177Z  INFO text_generation_launcher: Starting Webserver
tgi-1-1  | 2025-07-18T20:34:51.453274Z  INFO text_generation_router_v2: backends/v2/src/lib.rs:88: Warming up model
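
Once warmup completes, the server responds on port 8080. A quick smoke test using the standard TGI /generate endpoint (a generic example, not specific to this issue):

import requests

# Simple smoke test against the running container (standard TGI /generate API).
resp = requests.post(
    "http://localhost:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
    timeout=60,
)
print(resp.json())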

Expected behavior

In the docker-compose file I specify MAX_BATCH_SIZE=4 for the model ibm-granite/granite-3.1-8b-instruct. That model is cached with num_cores=2 for batch sizes 4 and 8:
https://huggingface.co/aws-neuron/optimum-neuron-cache/blob/main/inference-cache-config/granite.json#L43
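
The cached configurations can be checked directly from the Hub. A minimal sketch using huggingface_hub (I'm assuming the JSON maps model ids to lists of export configs, which is what the linked file shows):

import json

from huggingface_hub import hf_hub_download

# Download the granite cache config referenced above.
path = hf_hub_download(
    repo_id="aws-neuron/optimum-neuron-cache",
    filename="inference-cache-config/granite.json",
)
with open(path) as f:
    cache_config = json.load(f)

# Print the cached export configs for the model in question
# (expected: entries with batch_size 4 and 8 at num_cores=2).
for entry in cache_config["ibm-granite/granite-3.1-8b-instruct"]:
    print(entry)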

However, the log above shows that it is loading batch size 8:
tgi-1-1 | max_batch_size: Some(
tgi-1-1 | 8,
tgi-1-1 | ),

tgi-1-1 | 2025-07-18T20:34:07.269504Z INFO text_generation_launcher: Exporting model to neuron with config: {'batch_size': 8, 'sequence_length': 4096, 'num_cores': 2, 'auto_cast_type': 'bf16'}.
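
What I would expect the cache lookup to do is prefer an exact batch-size match over a larger cached configuration. A minimal sketch of that expectation (hypothetical helper, not the actual TGI lookup code):

# Hypothetical sketch of the selection I would expect; not the actual TGI code.
def select_cached_config(cached_configs, requested_batch_size):
    """Pick the cached export config whose batch_size best matches the request."""
    # Prefer an exact match (batch_size 4 when MAX_BATCH_SIZE=4)...
    for config in cached_configs:
        if config["batch_size"] == requested_batch_size:
            return config
    # ...otherwise fall back to the smallest cached batch size that still fits.
    candidates = [c for c in cached_configs if c["batch_size"] > requested_batch_size]
    return min(candidates, key=lambda c: c["batch_size"], default=None)

cached = [
    {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
    {"batch_size": 8, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
]
print(select_cached_config(cached, 4))  # expected: the batch_size=4 entry, not 8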

I think I see the problem and will submit a PR for it.
