Add latest support for flash-attention from hf #727

Merged 1 commit on Mar 31, 2024

This PR replaces LMFlow's manual flash-attention monkey patches with the attn_implementation="flash_attention_2" argument now accepted by Hugging Face transformers' from_pretrained, relaxes the flash-attn version pin in install.sh, adds flash-attn to requirements.txt, and exposes a --per_device_train_batch_size option in scripts/run_finetune_with_lisa.sh.
2 changes: 1 addition & 1 deletion install.sh
@@ -4,5 +4,5 @@ pip install -e .

gpu_state="$(nvidia-smi --query-gpu=name --format=csv,noheader)"
if [[ "${gpu_state}" == *"A100"* || "${gpu_state}" == *"A40"* || "${gpu_state}" == *"A6000"* ]]; then
- pip install flash-attn==2.0.2
+ pip install "flash-attn>=2.0.2"
fi
1 change: 1 addition & 0 deletions requirements.txt
@@ -22,3 +22,4 @@ gradio
accelerate>=0.27.2
einops>=0.6.1
scikit-learn==1.2.2
+ flash-attn
7 changes: 6 additions & 1 deletion scripts/run_finetune_with_lisa.sh
@@ -16,6 +16,7 @@ gradient_checkpointing=True
use_flash_attention=0
gradient_accumulation_steps=1
block_size=256
+ per_device_train_batch_size=1

# Enable model parallelism for multiple gpus, modify this if you prefer
# customized deepspeed zero-redundancy optimization settings
@@ -72,6 +73,10 @@ while [[ $# -ge 1 ]]; do
block_size="$2"
shift
;;
+ --per_device_train_batch_size|--batch_size)
+   per_device_train_batch_size="$2"
+   shift
+   ;;
*)
echo "error: unknown option \"${key}\"" 1>&2
exit 1
@@ -94,7 +99,7 @@ deepspeed ${deepspeed_args} \
--learning_rate 2e-5 \
--disable_group_texts 1 \
--block_size ${block_size} \
- --per_device_train_batch_size 1 \
+ --per_device_train_batch_size ${per_device_train_batch_size} \
--deepspeed ${ds_config_file} \
--fp16 \
--run_name finetune \
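With this change, the per-device batch size can be set from the command line, either as --per_device_train_batch_size or through the --batch_size alias added in the same case branch. A minimal sketch of an invocation, using only the options visible in this diff and leaving every other argument at its default (the values shown are illustrative):

  # Run LISA fine-tuning with a per-device batch size of 4 and a block size of 512.
  bash scripts/run_finetune_with_lisa.sh \
    --per_device_train_batch_size 4 \
    --block_size 512

  # --batch_size is accepted as a shorthand for the same option.
  bash scripts/run_finetune_with_lisa.sh --batch_size 4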
60 changes: 9 additions & 51 deletions src/lmflow/models/hf_decoder_model.py
@@ -218,56 +218,7 @@ def __init__(
replace_llama_with_condense,
)
replace_llama_with_condense(model_args.rope_pi_ratio, model_args.rope_ntk_ratio)

- # Whether use flash attention
- if model_args.use_flash_attention:
-     supported_gpu_device = None
-     for gpu in GPU_SUPPORT_FLASH_ATTENTION:
-         if gpu in torch.cuda.get_device_name():
-             supported_gpu_device = gpu
-     if not any(model_supported in config.architectures
-                for model_supported in MODELS_SUPPORT_FLASH_ATTENTION):
-         logger.warning(
-             f"Model \"{config.architectures}\" does not support"
-             " flash attention, use normal attention layer instead"
-         )
-     elif supported_gpu_device is None:
-         logger.warning(
-             f"Your device \"{torch.cuda.get_device_name()}\""
-             " does not support flash attention, it will"
-             " automatically use normal attention layer"
-         )
-     else:
-         supported_models = GPU_SUPPORT_FLASH_ATTENTION[supported_gpu_device]
-         config.use_cache = False
-         if "LlamaForCausalLM" in config.architectures and "LlamaForCausalLM" in supported_models:
-             from lmflow.utils.flash_attention.llama_flash_attention import (
-                 replace_llama_attn_with_flash_attn,
-             )
-             replace_llama_attn_with_flash_attn()
-         elif "GPTNeoForCausalLM" in config.architectures and "GPTNeoForCausalLM" in supported_models:
-             from lmflow.utils.flash_attention.gpt_neo_flash_attention import (
-                 replace_gpt_neo_attn_with_flash_attn,
-             )
-             replace_gpt_neo_attn_with_flash_attn()
-         elif "GPT2ForCausalLM" in config.architectures and "GPT2ForCausalLM" in supported_models:
-             from lmflow.utils.flash_attention.gpt2_flash_attention import (
-                 replace_gpt2_attn_with_flash_attn,
-             )
-             replace_gpt2_attn_with_flash_attn()
-         elif "BloomForCausalLM" in config.architectures and "BloomForCausalLM" in supported_models:
-             from lmflow.utils.flash_attention.bloom_flash_attention import (
-                 replace_bloom_attn_with_flash_attn
-             )
-             replace_bloom_attn_with_flash_attn()
-         else:
-             raise ValueError(
-                 f"Model \"{config.architectures}\" with GPU {supported_gpu_device} does not support"
-                 " flash attention, use normal attention layer instead"
-             )


if tune_strategy == 'normal':
if model_args.model_name_or_path:
compute_dtype = torch_dtype
@@ -298,6 +249,7 @@ def __init__(
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
trust_remote_code = model_args.trust_remote_code,
+ attn_implementation="flash_attention_2" if model_args.use_flash_attention else None,
)
#for deepspeed zero3, we don't need to specify device_map
except:
@@ -311,6 +263,7 @@
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
trust_remote_code = model_args.trust_remote_code,
+ attn_implementation="flash_attention_2" if model_args.use_flash_attention else None,
)
if model_args.use_qlora:
model.gradient_checkpointing_enable()
@@ -359,7 +312,8 @@ def __init__(
offload_folder="offload",
offload_state_dict=True,
torch_dtype=torch_dtype,
- load_in_8bit = model_args.use_int8
+ load_in_8bit = model_args.use_int8,
+ attn_implementation="flash_attention_2" if model_args.use_flash_attention else None,
)
if peft_model_id is not None:
self.backend_model = PeftModel.from_pretrained(
@@ -388,6 +342,7 @@ def __init__(
offload_folder="offload",
offload_state_dict=True,
torch_dtype=torch_dtype,
+ attn_implementation="flash_attention_2" if model_args.use_flash_attention else None,
)
except:
logger.warning(
@@ -399,6 +354,7 @@
model_args.model_name_or_path,
config=config,
torch_dtype=torch_dtype,
+ attn_implementation="flash_attention_2" if model_args.use_flash_attention else None,
)
else:
if peft_model_id is not None:
@@ -410,6 +366,7 @@
model_args.model_name_or_path,
config=config,
torch_dtype=torch_dtype,
+ attn_implementation="flash_attention_2" if model_args.use_flash_attention else None,
)

self.backend_model_full = self.backend_model
@@ -811,6 +768,7 @@ def get_peft_without_qlora(self):
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code = self.model_args.trust_remote_code,
+ attn_implementation="flash_attention_2" if self.model_args.use_flash_attention else None,
)

self.backend_model = PeftModel.from_pretrained(self.backend_model_full, tmpdirname)
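Taken together, these changes swap LMFlow's per-architecture monkey patches for the attention-backend selection that transformers handles itself. A minimal sketch of the pattern the PR adopts, assuming transformers >= 4.36 (where from_pretrained accepts attn_implementation) and a working flash-attn installation; the checkpoint name, dtype, and the use_flash_attention flag are illustrative stand-ins for LMFlow's model_args:

  # Request the flash-attention backend directly from transformers at load
  # time instead of patching individual attention classes.
  import torch
  from transformers import AutoModelForCausalLM

  use_flash_attention = True  # stands in for model_args.use_flash_attention

  model = AutoModelForCausalLM.from_pretrained(
      "meta-llama/Llama-2-7b-hf",   # illustrative checkpoint
      torch_dtype=torch.bfloat16,   # flash-attention kernels expect fp16/bf16
      attn_implementation="flash_attention_2" if use_flash_attention else None,
  )

Passing None leaves transformers on its default attention implementation, and transformers itself errors out at load time if the chosen model class does not support flash_attention_2 or flash-attn cannot be imported, which is what makes the per-architecture dispatch removed above unnecessary on the LMFlow side.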