Merged
43 commits
816ddfa
enabling runtime optimization
grzegorz-k-karch Apr 28, 2026
7aa5fe7
Merge branch 'main' into gkarch/runtime_opt
grzegorz-k-karch Apr 28, 2026
3041dc2
done ruff formatting and docstrings
grzegorz-k-karch Apr 28, 2026
a363750
distributed timeout is configurable
grzegorz-k-karch May 4, 2026
8739fa0
Merge branch 'main' into gkarch/runtime_opt
grzegorz-k-karch May 4, 2026
53a2caf
added example config for attn pruning and runtime constraint
grzegorz-k-karch May 4, 2026
dfb905c
renamed configs
grzegorz-k-karch May 5, 2026
e165171
working on readme
grzegorz-k-karch May 6, 2026
d47b69c
working on refactoring
grzegorz-k-karch May 6, 2026
12ed46b
working on fix
grzegorz-k-karch May 17, 2026
ab925b9
runtime accuracy improved
grzegorz-k-karch May 18, 2026
58f17e4
using vllm api instead of subprocess
grzegorz-k-karch May 18, 2026
e868303
working on review feedback
grzegorz-k-karch May 19, 2026
f7be643
removed unused batch_size; cleaned up config loading
grzegorz-k-karch May 19, 2026
8423676
Merge branch 'main' into gkarch/runtime_opt
grzegorz-k-karch May 19, 2026
49235d1
cleanup based on pre-commit
grzegorz-k-karch May 19, 2026
781d44d
added docstrings
grzegorz-k-karch May 19, 2026
a1901c7
updated readme
grzegorz-k-karch May 19, 2026
0b75502
further changes based on review
grzegorz-k-karch May 19, 2026
7e2f995
further changes based on review
grzegorz-k-karch May 19, 2026
2ca5306
removed synth_dataset_num_requests
grzegorz-k-karch May 19, 2026
ca21748
removed duplicate model saving
grzegorz-k-karch May 19, 2026
26ceb36
added test
grzegorz-k-karch May 20, 2026
4c5b133
suppressing bandit warnings B404 and B603; precedence found in repo
grzegorz-k-karch May 20, 2026
398808a
removed gpu utilization param
grzegorz-k-karch May 21, 2026
e468f62
wip
grzegorz-k-karch May 21, 2026
34dbe52
removed redundant configs; guards for vllm results
grzegorz-k-karch May 21, 2026
24fa2d5
following annotation suggestion
grzegorz-k-karch May 21, 2026
4b824f1
updated readme
grzegorz-k-karch May 22, 2026
ae25ec7
moved stats utils from nas to puzzletron
grzegorz-k-karch May 22, 2026
c14cad0
Merge branch 'main' into gkarch/runtime_opt
grzegorz-k-karch May 27, 2026
f34d3a3
responding to reviews
grzegorz-k-karch May 27, 2026
3332149
reenabled some vars
grzegorz-k-karch May 27, 2026
88e16d7
added support for batch_sizes
grzegorz-k-karch May 28, 2026
3f69e55
further fixes
grzegorz-k-karch May 28, 2026
7e48dbd
Merge branch 'main' into gkarch/runtime_opt
grzegorz-k-karch May 30, 2026
36f4685
using 5s latency target in the example
grzegorz-k-karch May 31, 2026
b1b810f
added vllm adapter
grzegorz-k-karch Jun 8, 2026
354dd8d
Merge branch 'main' into gkarch/runtime_opt
grzegorz-k-karch Jun 8, 2026
f49fbc9
disabled vllm tests that depends on anymodel
grzegorz-k-karch Jun 8, 2026
cebc4cd
Merge branch 'main' into gkarch/runtime_opt
kevalmorabia97 Jun 8, 2026
d6e1c6b
Fix CI failures
kevalmorabia97 Jun 8, 2026
105c736
Merge branch 'main' into gkarch/runtime_opt
kevalmorabia97 Jun 8, 2026
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@


 # Mock imports for autodoc
-autodoc_mock_imports = ["mpi4py", "tensorrt_llm", "triton"]
+autodoc_mock_imports = ["mpi4py", "tensorrt_llm", "triton", "vllm"]

 autosummary_generate = True
 autosummary_imported_members = False
Expand Down
40 changes: 40 additions & 0 deletions examples/puzzletron/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,46 @@ See [Megatron-Bridge distillation](../megatron_bridge/README.md#distillation) fo

For distillation results on Puzzletron-compressed models, see [examples/pruning/puzzletron/](../pruning/puzzletron/README.md).

## Runtime-Based Latency Optimization

You can enable **runtime stats** to measure actual inference latency via vLLM, which unlocks latency-based MIP constraints.

A ready-to-run example config is included at [`configs/llama-3_1-8B_pruneffn_runtime/`](./configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml). The following fields in `llama-3_1-8B_pruneffn_runtime.yaml` enable runtime stats collection and control its warmup and measurement iterations:

```yaml
calc_subblock_stats:
  runtime_stats:
    enabled: true
    num_warmup_iters: 2
    num_iters: 10
```
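
As a rough illustration of what `num_warmup_iters` and `num_iters` control, the sketch below runs a callable a few times untimed and then averages the timed runs. It is a minimal stand-in, not the vLLM-based measurement added in this PR; `measure_latency` and `fake_forward` are hypothetical names introduced only for this example.

```python
import statistics
import time
from typing import Callable


def measure_latency(fn: Callable[[], None], num_warmup_iters: int = 2, num_iters: int = 10) -> float:
    """Return the mean wall-clock seconds per call of `fn` over `num_iters` timed runs."""
    # Warmup runs are executed but not timed (caches, lazy initialization, compilation).
    for _ in range(num_warmup_iters):
        fn()

    samples = []
    for _ in range(num_iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples)


calls = []


def fake_forward() -> None:
    """Hypothetical workload standing in for one inference pass."""
    calls.append(1)
    sum(range(10_000))


mean_seconds = measure_latency(fake_forward, num_warmup_iters=2, num_iters=10)
print(f"mean latency: {mean_seconds * 1000:.3f} ms over 10 timed runs")
```

Discarding the warmup runs keeps one-time startup costs out of the reported average.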

The runtime constraint is specified in the `human_constraints` section of the config `Llama-3_1-8B.yaml`:

```yaml
human_constraints:
  target_latency_seconds: 5
```

Run the pipeline against this config the same way as the memory-constrained variant:

```bash
torchrun --nproc_per_node 2 examples/puzzletron/main.py \
--config examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress"
```

The MIP solver will now search for a heterogeneous architecture whose measured end-to-end latency is at or below `target_latency_seconds`, instead of optimizing for a memory budget.
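
To make the idea of a latency-constrained search concrete, here is a toy brute-force version: pick one variant per layer so that total quality loss is minimal while total latency stays within the budget. This is only a sketch under strong assumptions. The candidate names, losses, and latencies are invented, per-layer latency is assumed to be additive, and the real pipeline uses a MIP solver rather than enumeration.

```python
from itertools import product

# Hypothetical per-layer candidates: (name, quality_loss, latency_seconds).
layers = [
    [("full", 0.00, 2.0), ("ffn_half", 0.05, 1.2), ("no_op", 0.30, 0.0)],
    [("full", 0.00, 2.0), ("ffn_half", 0.02, 1.2), ("no_op", 0.10, 0.0)],
    [("full", 0.00, 2.0), ("ffn_half", 0.08, 1.2), ("no_op", 0.50, 0.0)],
]
target_latency_seconds = 5.0


def best_architecture(layers, budget):
    """Return (loss, latency, names) of the lowest-loss choice within the latency budget."""
    best = None
    for choice in product(*layers):
        latency = sum(candidate[2] for candidate in choice)
        if latency > budget:
            continue  # violates the latency constraint
        loss = sum(candidate[1] for candidate in choice)
        if best is None or loss < best[0]:
            best = (loss, latency, [candidate[0] for candidate in choice])
    return best


loss, latency, names = best_architecture(layers, target_latency_seconds)
print(names, f"loss={loss:.2f}", f"latency={latency:.1f}s")
```

The result mixes block variants across layers, which is what "heterogeneous architecture" refers to here.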

Because vLLM startup adds substantial overhead during stats collection, extend the distributed process group timeout accordingly (already included in the example config):

```yaml
nccl_timeout_minutes: 90 # default is 10 if omitted
```

This field is supported in any Puzzletron YAML config and overrides the default 10-minute distributed timeout.
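
The fallback behavior can be sketched as follows. This is an illustrative helper, not the code in `main.py`: `resolve_timeout` is a hypothetical name, and accepting a plain number of minutes in addition to a `timedelta` is an assumption made for the example.

```python
from datetime import timedelta
from types import SimpleNamespace

DEFAULT_TIMEOUT = timedelta(minutes=10)


def resolve_timeout(cfg) -> timedelta:
    """Return the distributed timeout from `cfg`, falling back to 10 minutes."""
    value = getattr(cfg, "nccl_timeout_minutes", None)
    if value is None:
        return DEFAULT_TIMEOUT
    if isinstance(value, timedelta):
        return value  # already resolved to a timedelta
    # Assumption: a bare number is interpreted as minutes.
    return timedelta(minutes=float(value))


print(resolve_timeout(SimpleNamespace()))                         # field omitted
print(resolve_timeout(SimpleNamespace(nccl_timeout_minutes=90)))  # field set
```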

Because the measured runtimes of individual subblocks do not add up linearly to the runtime of the full model, set `target_latency_seconds` slightly below the latency you actually want. For example, in our experiments a `target_latency_seconds` of 5 resulted in a final model latency of 5.4 seconds.
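
One simple way to pick the lower value is to rescale by a previously observed overshoot. The helper below is a hypothetical heuristic, not part of Puzzletron: it assumes the overshoot ratio from the experiment above (5 requested, 5.4 measured) carries over to your setup, which may not hold.

```python
def adjusted_target(desired_latency_seconds: float,
                    reference_target: float = 5.0,
                    reference_measured: float = 5.4) -> float:
    """Scale the desired latency down by the overshoot seen in a reference run."""
    overshoot = reference_measured / reference_target  # 1.08 for the run above
    return desired_latency_seconds / overshoot


# To land near 5 s end to end, request a somewhat tighter target.
print(f"{adjusted_target(5.0):.2f}")
```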

## Advanced Usage

Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ scoring:
   teacher_dir: ${to_path:${teacher_dir}}
   output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

-  eval_samples: 128
+  eval_samples: 8
   micro_batch_size: 1
   seed: 42
   shuffle_seed: 444
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
 block_size: 8192
 bos_rate: 0.5
 data_column: messages
-val_dataset_name: valid
+val_dataset_name: validation
 shuffle_seed: 81436
 seed: 42
 fim_rate: 0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
defaults:
- ../llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning@pruning
- ../llama-3_1-8B_pruneffn_memory/validate_solutions_defaults@scoring
- ../llama-3_1-8B_pruneffn_memory/validate_solutions_defaults@realize_model
- bypass:
- override hydra/hydra_logging: disabled
- _self_

puzzle_dir: ???
descriptor: llama
teacher_dir: ${puzzle_dir}/ckpts/teacher/
replacement_library_path: ${puzzle_dir}/replacement_library.json
dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2

skip_realize_model: false

build_replacement_library:
  add_ffn_no_ops: true
  add_attention_no_ops: true

calc_subblock_stats:
  batch_sizes: [1, 4]
  prefill_seq_len: 1024
  generation_seq_len: 1024
  num_active_tokens_override: # Optional override for sequence lengths
  prefill_queue_size: 0
  allocate_prefill_query: false
  merge_with_existing_stats: false
  subblock_stats_filename: "subblock_stats.json"
  moe_stats_filename: "moe_stats.json"

scoring:
  descriptor: ${descriptor}
  solutions_to_validate:
  skip_existing_solutions: true

  replacement_library_path: ${replacement_library_path}
  solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
  teacher_dir: ${to_path:${teacher_dir}}
  output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

  eval_samples: 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

mip:
single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
gathered_metrics_path:
puzzle_profile:

# puzzle_profile:
objective: metrics.cosine_embedding_loss_hidden_states
bigger_is_better: false

subblock_stats_args:
- batch_size: 1
weights_dtype: torch.bfloat16

report_additional_costs:
- stats.memory_mib
- stats.num_params
- stats.num_kv_heads
- stats.has_attention
- stats.has_ffn
- stats.kv_cache_memory_mib
- stats.attention_memory_mib
- stats.ffn_memory_mib
- stats.ffn_num_params
- stats.attention_num_params

human_constraints:
target_latency_seconds: 5

mip_constraints:
metric_overrides:
max_seconds_per_solution: 60

realize_model:
  descriptor: ${descriptor}
  teacher_dir: ${to_path:${teacher_dir}}
  tokenizer_name: ${to_path:${teacher_dir}}
  replacement_library_path: ${replacement_library_path}
  save_models: true
  solutions_path: # Filled dynamically

  # Validate params
  skip_validation: false # To enable validation of the model solution, set `skip_validation` to false
  eval_samples: 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

nccl_timeout_minutes: ${timedelta_minutes:120}

# This section redirects Hydra outputs
hydra:
  run:
    dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
defaults:
- Llama-3_1-8B
- _self_

# Input Hugging Face model to compress
input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct

# Dataset path for pruning and NAS scoring
dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2

# Working directory for puzzletron outputs
puzzle_dir: /workspace/puzzle_dir

calc_subblock_stats:
  runtime_stats:
    enabled: true
    num_warmup_iters: 2
    num_iters: 10

# FFN intermediate sizes to search over (heterogeneous architecture)
pruning:
  intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
9 changes: 8 additions & 1 deletion examples/puzzletron/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def run_full_puzzletron(hydra_config_path: str):
         config_path: Path to the YAML configuration file
     """
     mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
-    dist.setup(timeout=timedelta(minutes=10))

     # Register Hydra custom resolvers (needed for config resolution)
     mtpz.tools.register_hydra_resolvers()
Expand All @@ -84,6 +83,14 @@ def run_full_puzzletron(hydra_config_path: str):
         overrides=[],
     )

+    # Default timeout: 10 minutes, or extended to nccl_timeout_minutes if set in config
+    if hasattr(hydra_cfg, "nccl_timeout_minutes"):
+        timeout_minutes = hydra_cfg.nccl_timeout_minutes
+    else:
+        timeout_minutes = timedelta(minutes=10)
+
+    dist.setup(timeout=timeout_minutes)
+
     # Convert model (convert from HF to DeciLM, score pruning activations,
     # prune the model and save pruned checkpoints)
     input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel()
Expand Down
13 changes: 7 additions & 6 deletions modelopt/torch/kernels/sparsity/attention/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,18 @@ def attention_calibrate(
measuring how many KV tiles would be skipped at each threshold in
``threshold_trials``. No autograd — forward only.

+    All arguments except ``threshold_trials`` match
+    :func:`modelopt.torch.kernels.common.attention.attention`.
+
     Args:
-        q, k, v, b_start_loc, b_seq_len, max_input_len, is_causal,
-        softmax_scale, b_start_loc_k, b_seq_len_k, max_input_len_k:
-            Same as :func:`modelopt.torch.kernels.common.attention.attention`.
         threshold_trials: List of threshold values to measure sparsity for.
             Each value is converted to log2-scaled space for the kernel.

     Returns:
-        Tuple of (output, sparsity_counters):
-        - output: ``[total_q_tokens, num_q_heads, head_dim]``
-        - sparsity_counters: ``[num_thresholds, 2]`` int64 tensor where
+        Tuple of ``(output, sparsity_counters)``:
+
+        - ``output``: ``[total_q_tokens, num_q_heads, head_dim]``
+        - ``sparsity_counters``: ``[num_thresholds, 2]`` int64 tensor where
           ``[:, 0]`` = total tile evaluations, ``[:, 1]`` = skipped tiles.
           Sparsity per threshold = ``counters[:, 1] / counters[:, 0]``.
"""
Expand Down
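
The `sparsity_counters` layout described in the docstring above can be turned into per-threshold sparsity ratios as sketched below. This is a plain-Python illustration with made-up counter values; on the real int64 tensor the equivalent expression is `counters[:, 1] / counters[:, 0]`, and the zero-total guard is an addition of this example.

```python
def sparsity_per_threshold(counters):
    """Map rows of [total_tile_evaluations, skipped_tiles] to skipped/total ratios."""
    # Guard against a threshold with no evaluated tiles (assumption of this sketch).
    return [skipped / total if total else 0.0 for total, skipped in counters]


# Hypothetical counters for three thresholds.
counters = [[1000, 100], [1000, 450], [1000, 900]]
ratios = sparsity_per_threshold(counters)
print(ratios)
```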
6 changes: 3 additions & 3 deletions modelopt/torch/puzzletron/mip/run_puzzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Type(enum.Enum):
 _ALLOWED_HUMAN_CONSTRAINTS = {
     "target_memory",
     "target_throughput",
-    "target_latency",
+    "target_latency_seconds",
     "target_time_to_first_token",
     "num_params",
     "stats.has_attention",
Expand Down Expand Up @@ -175,8 +175,8 @@ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]:
             throughput_constraints.append(
                 batch_size * generation_seq_len / self.constraints["target_throughput"]
             )
-        if "target_latency" in self.constraints:
-            throughput_constraints.append(self.constraints["target_latency"])
+        if "target_latency_seconds" in self.constraints:
+            throughput_constraints.append(self.constraints["target_latency_seconds"])
         if throughput_constraints:
             mip_constraints["stats.runtime_ms"] = 1000 * min(throughput_constraints)

Expand Down
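
The conversion in the `to_mip_constraints` hunk above can be restated as a standalone function: a throughput target is turned into seconds, combined with the latency target, and the tighter of the two becomes the runtime budget in milliseconds. `runtime_budget_ms` is a hypothetical name for this sketch, and the example numbers are illustrative only.

```python
def runtime_budget_ms(constraints, batch_size, generation_seq_len):
    """Return the tightest runtime budget in milliseconds, or None if unconstrained."""
    candidates_seconds = []
    if "target_throughput" in constraints:
        # tokens generated per batch divided by tokens/second gives seconds
        candidates_seconds.append(
            batch_size * generation_seq_len / constraints["target_throughput"]
        )
    if "target_latency_seconds" in constraints:
        candidates_seconds.append(constraints["target_latency_seconds"])
    if not candidates_seconds:
        return None
    return 1000 * min(candidates_seconds)


both = runtime_budget_ms(
    {"target_latency_seconds": 5, "target_throughput": 256},
    batch_size=1,
    generation_seq_len=1024,
)
print(both)  # throughput implies 4 s, which is tighter than the 5 s latency target
```

Taking the minimum means that when both constraints are present, the stricter one decides the budget.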
1 change: 0 additions & 1 deletion modelopt/torch/puzzletron/subblock_stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@

"""Subblock statistics collection for Puzzletron."""

from .calc_subblock_params_and_memory import *
from .calc_subblock_stats import *