Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,14 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended), the following line ignores the entire .idea folder.
.idea/

# VS Code
.vscode/

# Local evaluation scripts, logs, results, and generated data
evaluation/modified_evaluate.py
evaluation/a100_*.sh
evaluation/4090_*.sh
evaluation/logs/*
evaluation/results/*
evaluation/ruler/data/*
10 changes: 9 additions & 1 deletion evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

import torch
from datasets import load_dataset
import os
from fire import Fire
from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from kvpress.ada_attn import replace_var_flash_attn
from loogle.calculate_metrics import calculate_metrics as loogle_scorer
from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
Expand All @@ -23,6 +25,8 @@
RandomPress,
SnapKVPress,
StreamingLLMPress,
AdaSnapKVPress,
AdaScorerPress
)

logger = logging.getLogger(__name__)
Expand All @@ -48,6 +52,7 @@
"random": RandomPress(),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"ada_snapkv": AdaSnapKVPress()
}


Expand Down Expand Up @@ -124,10 +129,13 @@ def evaluate(
# Initialize pipeline with the correct attention implementation
if isinstance(press, ObservedAttentionPress):
model_kwargs = {"attn_implementation": "eager"}
# AdaKV support: AdaScorerPress-based presses call replace_var_flash_attn and use flash_attention_2
elif isinstance(press, AdaScorerPress):
replace_var_flash_attn(model=model)
model_kwargs = {"attn_implementation": "flash_attention_2"}
else:
try:
import flash_attn # noqa: F401

model_kwargs = {"attn_implementation": "flash_attention_2"}
except ImportError:
model_kwargs = {}
Expand Down
2 changes: 1 addition & 1 deletion evaluation/evaluate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dataset="ruler"
data_dir="4096"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
compression_ratios=(0.1 0.25 0.5)
press_names=("expected_attention" "knorm" "streaming_llm" "snapkv")
press_names=("expected_attention" "knorm" "streaming_llm" "snapkv" "ada_snapkv")

# Check if the number of press names is less than or equal to the number of available GPUs
num_gpus=$(nvidia-smi --list-gpus | wc -l)
Expand Down
7 changes: 6 additions & 1 deletion kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.ada_scorer_press import AdaScorerPress
from kvpress.presses.ada_snapkv_press import AdaSnapKVPress


__all__ = [
"BasePress",
"ComposedPress",
"ScorerPress",
"AdaScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
Expand All @@ -29,6 +34,6 @@
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"AdaSnapKVPress",
]

from kvpress.presses.tova_press import TOVAPress
Loading