8 changes: 6 additions & 2 deletions open_instruct/dpo_tune.py
@@ -632,7 +632,9 @@ def load_model():
                 quantization_config=bnb_config,
                 device_map=device_map,
                 torch_dtype=torch.bfloat16,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                attn_implementation="flash_attention_2"
+                if args.use_flash_attn
+                else "eager",
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
@@ -642,7 +644,9 @@ def load_model():
                 config=config,
                 trust_remote_code=args.trust_remote_code,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                attn_implementation="flash_attention_2"
+                if args.use_flash_attn
+                else "eager",
             )
         else:
             logger.info("Training new model from scratch")
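For reference, the hunks above swap the deprecated `use_flash_attention_2` flag for the `attn_implementation` argument that recent `transformers` releases expect. A minimal sketch of the new call shape (the checkpoint name is a placeholder, not from this PR):

```python
import torch
from transformers import AutoModelForCausalLM

# "flash_attention_2" requires the flash-attn package to be installed;
# "eager" is the plain PyTorch attention fallback the diff selects when
# the flag is off.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```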
45 changes: 39 additions & 6 deletions open_instruct/dpo_tune_cache.py
@@ -405,8 +405,35 @@ def __post_init__(self):
raise ValueError("Need either a dataset name, dataset mixer, or a training file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file."
if os.path.isdir(self.train_file):
# just assume they
self.train_file = [
os.path.join(self.train_file, x)
for x in os.listdir(self.train_file)
]
self.train_file_type = [
x.split(".")[-1] for x in self.train_file
]
self.train_file_type = [
x for x in self.train_file_type
if x in ["json", "jsonl", "parquet"]
]
self.train_file_type = list(set(self.train_file_type)) # unique
# assume the directory cannot mix types
self.train_file_type = (
None if len(self.train_file_type) == 0 else
self.train_file_type[0]
)
else:
self.train_file_type = self.train_file.split(".")[-1]

# some slight renames
if self.train_file_type == 'jsonl':
self.train_file_type = "json"

assert self.train_file_type in ["json", "parquet"], (
"`train_file` should be a json(l) or parquet file."
)
if (
(self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None))
or (self.dataset_name is not None and self.train_file is not None)
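To make the new branch easier to scan, here is the same logic pulled out as a standalone helper; the function name and the example at the bottom are illustrative only, not code from this PR:

```python
import os
from typing import List, Optional, Tuple, Union

def detect_train_files(train_file: str) -> Tuple[Union[str, List[str]], Optional[str]]:
    """Expand a directory into its files and infer a single file type,
    mirroring the __post_init__ branch above (jsonl normalizes to json)."""
    if os.path.isdir(train_file):
        files = [os.path.join(train_file, x) for x in os.listdir(train_file)]
        types = {x.split(".")[-1] for x in files} & {"json", "jsonl", "parquet"}
        file_type = next(iter(types)) if types else None  # assumes no mixed types
    else:
        files, file_type = train_file, train_file.split(".")[-1]
    if file_type == "jsonl":
        file_type = "json"
    return files, file_type

# e.g. detect_train_files("data/") -> (["data/a.parquet", "data/b.parquet"], "parquet")
```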
@@ -456,6 +483,7 @@ def get_cache_ref_logprobs(
     resume_step: int,
     epoch_range: range,
     forward_fn: Callable,
+    use_lora: bool = False,
 ):
     epoch_cached_reference_chosen_logps = []
     epoch_cached_reference_rejected_logps = []
@@ -468,7 +496,7 @@
         cached_reference_rejected_logps = []
         with torch.no_grad():
             for step, batch in tqdm(enumerate(active_dataloader), disable=not accelerator.is_local_main_process):
-                if args.use_lora:
+                if use_lora:
                     with accelerator.unwrap_model(model).disable_adapter():
                         reference_chosen_logps, reference_rejected_logps, _ = forward_fn(
                             model, batch, average_log_prob=average_log_prob
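The `use_lora` path leans on a peft feature: `disable_adapter()` is a context manager on `PeftModel` that temporarily switches the LoRA weights off, so the frozen base model doubles as the DPO reference model without keeping a second copy in memory. A minimal sketch, assuming `peft` is installed (the tiny checkpoint is just a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

input_ids = torch.tensor([[1, 2, 3]])
with torch.no_grad(), model.disable_adapter():
    # adapter off -> outputs come from the frozen base, i.e. the reference model
    ref_logits = model(input_ids).logits
```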
@@ -575,7 +603,7 @@ def main(args: FlatArguments):
     if args.train_file is not None:
         data_files["train"] = args.train_file
     raw_datasets = load_dataset(
-        "json",
+        args.train_file_type,
         data_files=data_files,
         **dataset_args,
     )
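The first argument to `load_dataset` is a builder name, which is what `train_file_type` now supplies; both the "json" and "parquet" builders take a `data_files` mapping, and the "json" builder also parses JSON Lines. A small sketch (file names are made up):

```python
from datasets import load_dataset

# the "json" builder handles .json and .jsonl alike
ds_json = load_dataset("json", data_files={"train": "train.jsonl"})
ds_parquet = load_dataset("parquet", data_files={"train": ["shard0.parquet"]})
```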
@@ -646,7 +674,9 @@ def load_model():
                 quantization_config=bnb_config,
                 device_map=device_map,
                 torch_dtype=torch.bfloat16,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                attn_implementation="flash_attention_2"
+                if args.use_flash_attn
+                else "eager",
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
@@ -656,7 +686,9 @@ def load_model():
                 config=config,
                 trust_remote_code=args.trust_remote_code,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                attn_implementation="flash_attention_2"
+                if args.use_flash_attn
+                else "eager",
             )
         else:
             logger.info("Training new model from scratch")
@@ -960,6 +992,7 @@ def load_model():
         resume_step,
         range(starting_epoch, args.num_train_epochs),
         forward_fn,
+        args.use_lora,
     )
     print("=============after cache logprobs")
     print_gpu_stats(init_gpu_memory)
8 changes: 6 additions & 2 deletions open_instruct/finetune.py
@@ -472,9 +472,13 @@ def __post_init__(self):
                 )
                 else:
                     self.train_file_type = self.train_file.split(".")[-1]
+
+                # rename: jsonl files load through the datasets "json" builder
+                if self.train_file_type == 'jsonl':
+                    self.train_file_type = "json"
+
-                assert self.train_file_type in ["json", "jsonl", "parquet"], (
-                    "`train_file` should be a json or a jsonl or parquet file."
+                assert self.train_file_type in ["json", "parquet"], (
+                    "`train_file` should be a json(l) or parquet file."
                 )
         if (
             (
181 changes: 181 additions & 0 deletions scripts/test_dataloading.py
@@ -0,0 +1,181 @@
import torch
import os

from unittest.mock import patch, Mock
from accelerate import Accelerator


from transformers import AutoModelForCausalLM
from accelerate import init_empty_weights
from transformers.modeling_outputs import CausalLMOutputWithPast

from typing import Callable, List
from types import MethodType

# builds a replacement from_pretrained to patch into AutoModelForCausalLM
# - loads the model on the meta device (no real weights)
# - swaps in a forward that records each batch and returns dummy outputs

def built_from_pretrained(
    store: List,
    extra_keys: List = ['labels'],
):

    def forward(self, input_ids, *args, **kwargs):
        # - hook to capture the input_ids
        store.append(
            {
                'input_ids': input_ids,
                **{k: kwargs.get(k) for k in extra_keys},
            }
        )

        # - returns dummy outputs / loss
        return CausalLMOutputWithPast(
            loss=torch.tensor(0.),
            logits=torch.zeros(input_ids.shape + (self.config.vocab_size,))
        )

    # mock the from_pretrained function
    old_func = AutoModelForCausalLM.from_pretrained

    def _from_pretrained(model_name_or_path, *args, **kwargs):
        with init_empty_weights():
            model = old_func(
                model_name_or_path, *args, **kwargs,
            )
        model.forward = MethodType(forward, model)
        return model

    return _from_pretrained
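# Note: init_empty_weights() puts all parameters on the "meta" device, so
# from_pretrained builds only the architecture and never loads a checkpoint;
# combined with the dummy forward above, this keeps the data-loading test
# cheap enough to run without real weights.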

def accelerate_prepare(self, model, optimizer, dataloader, scheduler):
    self.state = Mock()  # sneak it in

    # skip through the set_epoch if the train loop assumes a distributed
    # dataloader

    def set_epoch(self, *args):
        pass

    dataloader.set_epoch = MethodType(set_epoch, dataloader)
    return model, optimizer, dataloader, scheduler

# - patches

patch_accelerate = patch.multiple(
    Accelerator,
    wait_for_everyone=Mock(),
    prepare=accelerate_prepare,
    num_processes=1,
    device=torch.device('cuda'),
    backward=Mock(),
    sync_gradients=False,
)
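# Note: patch.multiple temporarily swaps these attributes on the Accelerator
# class itself, so inside the context prepare() is the passthrough above and
# backward()/wait_for_everyone() become no-ops; the training loop then runs
# as a single process with no real distributed setup.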


def test_tuning_script(
    script: Callable,
    args: object,
    patches: List[object],
    store: List,
    write_data_to_directory: str = None,
):
    from contextlib import ExitStack

    # clear the store
    store.clear()
    with ExitStack() as stack:
        for patcher in patches:
            stack.enter_context(patcher)
        script(args)

    if write_data_to_directory is None:
        return store

    os.makedirs(write_data_to_directory)
    for i, data in enumerate(store):
        torch.save(data, os.path.join(write_data_to_directory, f'batch_{i}.pt'))


def test_finetune(
    model_name_or_path: str,
    dataset_name: str = None,
    train_file: str = None,
    max_train_steps: int = 2,
    write_data_to_directory: str = None,
):
    from open_instruct.finetune import main, FlatArguments

    args = FlatArguments(
        model_name_or_path=model_name_or_path,
        dataset_name=dataset_name,
        train_file=train_file,
        push_to_hub=False,
        try_launch_beaker_eval_jobs=False,
        output_dir=None,
        max_train_steps=max_train_steps,
    )

    # - initialize store in transformers patcher
    STORE = []
    patch_transformers = patch.multiple(
        AutoModelForCausalLM,
        from_pretrained=built_from_pretrained(STORE),
    )

    return test_tuning_script(
        main, args,
        [
            patch_transformers,
            patch_accelerate,
        ],
        STORE,
        write_data_to_directory,
    )

def test_dpo_tune(
    model_name_or_path: str,
    dataset_name: str = None,
    train_file: str = None,
    max_train_steps: int = 2,
    write_data_to_directory: str = None,
):
    from open_instruct.dpo_tune_cache import main, FlatArguments

    args = FlatArguments(
        model_name_or_path=model_name_or_path,
        dataset_name=dataset_name,
        train_file=train_file,
        push_to_hub=False,
        try_launch_beaker_eval_jobs=False,
        try_auto_save_to_beaker=False,
        output_dir=None,
        max_train_steps=max_train_steps,
    )

    # - initialize store in transformers patcher
    STORE = []
    patch_transformers = patch.multiple(
        AutoModelForCausalLM,
        from_pretrained=built_from_pretrained(STORE, extra_keys=[]),
    )

    return test_tuning_script(
        main, args,
        [
            patch_transformers,
            patch_accelerate,
        ],
        STORE,
        write_data_to_directory,
    )


if __name__ == "__main__":
    import fire

    data = fire.Fire(test_finetune)
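A hedged usage sketch of the new script (model name and data path are placeholders); the `fire` entrypoint above exposes `test_finetune` on the command line as `python scripts/test_dataloading.py --model_name_or_path <model> --train_file <file>`:

```python
from scripts.test_dataloading import test_finetune

# capture the first two batches the finetune loop would feed the model,
# without loading real weights
batches = test_finetune(
    model_name_or_path="your-org/your-model",  # placeholder
    train_file="data/train.jsonl",             # placeholder
    max_train_steps=2,
)
print(batches[0]["input_ids"].shape)
```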