Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ Here are a few settings you may be interested in:
- `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
- `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
- `--fp8` Enable Quantization to fp8
- `--kv_cache_fp8` Deprecated - Store kv-cache in float8 when kv-cache is used. Should not be used with HQT (the Habana Quantization Toolkit)

For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command:
```bash
Expand Down
9 changes: 0 additions & 9 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,6 @@ def setup_parser(parser):
help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)",
)

parser.add_argument(
"--kv_cache_fp8",
action="store_true",
help="Store kv-cache in float8 when kv-cache is used. Can't use this argument together with QUANT_CONFIG env var",
)
parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8")
parser.add_argument(
"--use_flash_attention",
Expand Down Expand Up @@ -259,10 +254,6 @@ def setup_parser(parser):
args.limit_hpu_graphs = False

args.quant_config = os.getenv("QUANT_CONFIG", "")
if args.quant_config and args.kv_cache_fp8:
# can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization
# with habana quantization toolkit
raise parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument")
return args


Expand Down
7 changes: 1 addition & 6 deletions examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,7 @@ def _model_call(self, inps):
if self.options.static_shapes:
bucket_length = self.find_bucket(seq_length)
if self.options.use_cache and self.options.reuse_cache:
self.model.allocate_kv_cache(
bs,
bucket_length + 1,
bucket_length,
False,
)
self.model.allocate_kv_cache(bs, bucket_length + 1, bucket_length)
padding_length = bucket_length - seq_length
inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu()
Expand Down
35 changes: 16 additions & 19 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def setup_distributed(args):
args.global_rank = int(os.getenv("RANK", "0"))


def setup_quantization(args, model):
# Initialize `model` for quantized inference and return the same `model` object.
#
# args:  object exposing a `quant_config` attribute (truth-tested below).
# model: passed to `htcore.hpu_initialize` and returned to the caller.
#
# NOTE(review): indentation is not preserved in this view, so it cannot be told
# from here whether `htcore.hpu_initialize(model)` belongs to the `if` body or
# runs unconditionally - confirm against the original file.
#
# The Habana modules are imported inside the function rather than at module level.
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpu import hpu

print("Initializing inference with quantization")
# `hpu.enable_quantization()` is only reached when `args.quant_config` is falsy
# (e.g. an empty string).
if not args.quant_config:
hpu.enable_quantization()
htcore.hpu_initialize(model)
return model
def setup_const_serialization(const_serialization_path):
    """Create a uniquely named directory and enable const-section serialization into it.

    A random hex suffix is appended to ``const_serialization_path`` and the
    resulting directory is created before serialization is enabled.

    Args:
        const_serialization_path: Path prefix (string) for the output directory.

    Raises:
        OSError: Propagated from ``os.makedirs`` (e.g. the directory already exists).
    """
    from uuid import uuid4

    # NOTE: `os.path.join` with a single argument returns it unchanged, so the hex
    # suffix is glued directly onto the prefix rather than forming a subdirectory.
    target_dir = os.path.join(const_serialization_path + uuid4().hex)
    os.makedirs(target_dir)

    # Imported only after the directory has been created, as in the original order.
    from habana_frameworks.torch.hpu import enable_const_section_serialization as _enable_serialization

    print("Serializing const params to {}".format(target_dir))
    # NOTE(review): the meaning of the two positional flags is not visible here -
    # confirm against the habana_frameworks API.
    _enable_serialization(target_dir, False, True)


def setup_env(args):
Expand Down Expand Up @@ -346,7 +346,6 @@ def setup_generation_config(args, model, tokenizer):
generation_config.reduce_recompile = args.reduce_recompile
if generation_config.reduce_recompile:
assert generation_config.bucket_size > 0
generation_config.kv_cache_fp8 = args.kv_cache_fp8
generation_config.use_flash_attention = args.use_flash_attention
return generation_config

Expand Down Expand Up @@ -383,16 +382,14 @@ def initialize_model(args, logger):
generation_config = setup_generation_config(args, model, tokenizer)

if args.const_serialization_path:
import uuid

args.const_serialization_path = os.path.join(args.const_serialization_path + uuid.uuid4().hex)
os.makedirs(args.const_serialization_path)
from habana_frameworks.torch.hpu import enable_const_section_serialization

print("Serializing const params to {}".format(args.const_serialization_path))
enable_const_section_serialization(args.const_serialization_path, True)
setup_const_serialization(args.const_serialization_path)
if args.fp8:
model = setup_quantization(args, model)
import habana_frameworks.torch.core as htcore

print("Initializing inference mode")
const_marking = os.getenv("ENABLE_CONST_MARKING", "True")
if const_marking == "True":
htcore.hpu_initialize(model)
init_end = time.perf_counter()
logger.info(f"Args: {args}")
logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}")
Expand Down
3 changes: 0 additions & 3 deletions optimum/habana/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class GaudiGenerationConfig(GenerationConfig):
Only active if `static_shapes` is used. Can't be used with `reuse_cache`.
bucket_internal (`bool`, *optional*):
Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.
Comment thread
HolyFalafel marked this conversation as resolved.
kv_cache_fp8 (`bool`, *optional*):
Store kv-cache in float8 when kv-cache is used
use_flash_attention (`bool`, *optional*):
Whether to use flash attention optimization.
flash_attention_recompute (`bool`, *optional*):
Expand All @@ -48,7 +46,6 @@ def __init__(self, **kwargs):
self.bucket_size = kwargs.get("bucket_size", -1)
self.bucket_internal = kwargs.get("bucket_internal", None)
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
5 changes: 1 addition & 4 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,10 +730,7 @@ def generate(
bs, _ = input_ids.shape
if not is_greedy_or_beam_and_bucket:
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams,
calculated_max_length,
token_idx,
generation_config.kv_cache_fp8,
bs * generation_config.num_beams, calculated_max_length, token_idx
)
model_kwargs["kv_cache_len"] = calculated_max_length

Expand Down
79 changes: 37 additions & 42 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@
FusedSDPA = None


def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
if prev.dtype == torch.float8_e4m3fn:
from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2

cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0]
if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
prev_cast = prev.to(orig_cur.dtype)
return prev_cast
else:
return torch.cat((prev, cur), dim=dim)


def gaudi_llama_rmsnorm_forward(self, hidden_states):
"""
Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand Down Expand Up @@ -171,25 +152,39 @@ def __init__(self):
self.cache = None
self.inp_seq_len = -1

def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape):
def allocate(self, inp_seq_len, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
if kv_cache_fp8:
dtype = torch.float8_e4m3fn
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
self.cache.fill_(0)

def update(self, prev, cur, dim, idx, inp_seq_len):
    """Write new key/value states ``cur`` into the cache tensor ``prev``.

    Cases, checked in order:

    * ``cur`` has exactly the shape of ``prev``: ``prev`` is overwritten with
      ``cur`` and ``cur`` is returned;
    * ``cur`` is longer than 1 along axis 2 and not longer than ``prev``:
      ``cur`` is copied into the first ``inp_seq_len`` positions of ``prev``
      and ``cur`` is returned;
    * ``cur`` has length 1 along axis 2 and ``idx`` is given: ``cur`` is written
      in place at position ``idx - 1`` along ``dim`` and ``prev`` is returned;
    * ``cur`` has length 1 along axis 2 and ``idx`` is ``None``: the
      concatenation of ``prev`` and ``cur`` along ``dim`` is returned.

    Raises:
        AssertionError: If none of the cases above applies.
    """
    if prev.shape == cur.shape:
        prev.copy_(cur)
        return cur

    step_len = cur.shape[2]
    if 1 < step_len <= prev.shape[2]:
        # Initialize
        prev[:, :, :inp_seq_len, :].copy_(cur)
        return cur

    assert step_len == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"

    if idx is None:
        return torch.cat((prev, cur), dim=dim)

    # idx is 1-based with respect to the slot being written, hence idx - 1.
    prev.index_copy_(dim, idx - 1, cur)
    return prev

def get_shape(self):
    """Return the shape of the allocated cache, or ``None`` when no cache is allocated."""
    return None if self.cache is None else self.cache.shape

def forward(self, cur, dim, idx):
return update(self.cache, cur, dim, idx, self.inp_seq_len)
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


class GaudiLlamaRotaryEmbedding(torch.nn.Module):
Expand Down Expand Up @@ -273,12 +268,12 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.k_proj.weight.device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape)
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)

def update_sincos_cache(self, seq_len):
# Call rotary emb forward() to update cos/sin cache when inferring more than self.max_position_embeddings
Expand Down Expand Up @@ -373,27 +368,28 @@ def pre_attn_forward(
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None or reuse_cache:
if use_cache:
# reuse k, v, self_attention
if reuse_cache:
key_states = self.k_cache(key_states, 2, token_idx)
value_states = self.v_cache(value_states, 2, token_idx)
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
if past_key_value is None:
past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
)
past_key_value = (past_key, past_value)
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
if reuse_cache:
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
past_key_value = (key_states.contiguous(), value_states.contiguous())
else:
past_key_value = None

Expand Down Expand Up @@ -475,8 +471,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.self_attn.reorder_kv_cache(beam_idx)
Expand Down Expand Up @@ -631,9 +627,9 @@ def __init__(self, config: LlamaConfig):
# Initialize weights and apply final processing
self.post_init()

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
for layer in self.layers:
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)
Expand Down Expand Up @@ -822,9 +818,8 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):
- add new args reuse_cache
"""

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
self.kv_cache_len = max_seq_len
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)
Expand Down