Merged
6 changes: 4 additions & 2 deletions finetuner-workflow/finetuner/ds_config.json
@@ -32,8 +32,10 @@
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true,
"stage3_gather_fp16_weights_on_model_save": true
"offload_optimizer": {
"device": "cpu"
},
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
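For context, the two edited keys track DeepSpeed's newer ZeRO config schema: the boolean `cpu_offload` flag was replaced by an `offload_optimizer` block, and the fp16-specific gather key was renamed to its 16bit form. A minimal migration sketch (the `migrate_zero_config` helper is hypothetical, not part of this PR):

```python
# Hypothetical helper: patch an old-style ZeRO section into the newer schema
# used by ds_config.json above.
import json


def migrate_zero_config(path: str) -> dict:
    with open(path) as f:
        cfg = json.load(f)
    zero = cfg.get("zero_optimization", {})
    # "cpu_offload": true became an "offload_optimizer" block.
    if zero.pop("cpu_offload", False):
        zero.setdefault("offload_optimizer", {"device": "cpu"})
    # The fp16 gather key was renamed to the 16bit variant.
    if "stage3_gather_fp16_weights_on_model_save" in zero:
        zero["stage3_gather_16bit_weights_on_model_save"] = zero.pop(
            "stage3_gather_fp16_weights_on_model_save"
        )
    cfg["zero_optimization"] = zero
    return cfg
```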
191 changes: 95 additions & 96 deletions finetuner-workflow/finetuner/finetuner.py
@@ -17,14 +17,10 @@
from torch.utils.data import Dataset, random_split, RandomSampler
import argparse
import pathlib
from typing import Callable, Tuple, Optional, List
from typing import Tuple, List
import socket
from contextlib import closing
from contextlib import closing, contextmanager

try:
from tensorizer.tensorizer import load_model, get_tokenizer, no_init
except ModuleNotFoundError:
pass
import validators
import deepspeed

@@ -91,18 +87,24 @@ def find_free_port():
default=0.9,
)
parser.add_argument(
"--eot", type=str, help="EOT token to use", default="<|endoftext|>"
"--eot",
type=str,
help="EOT token to use",
default="", # default is model-dependent
)
parser.add_argument(
"--pad", type=str, help="Pad token to use", default="<|padding|>"
"--pad",
type=str,
help="Pad token to use",
default="", # default is model-dependent
)
parser.add_argument(
"--bs", type=int, help="Batch size (-1 == autosize)", default=-1
)
parser.add_argument(
"--bs_divisor",
type=float,
help="Batch size divisor for " "automatically " "determining batch size",
help="Batch size divisor for automatically determining batch size",
default=1.0,
)
parser.add_argument(
@@ -208,12 +210,11 @@ def find_free_port():

# Properly type-cast the param (str to bool)
FALSE = [
"False",
"false",
"f",
"0",
]
args.no_resume = args.no_resume not in FALSE
args.no_resume = args.no_resume.lower() not in FALSE
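Since argparse hands the flag over as a string, the cast above lowercases it and treats anything outside the FALSE list as truthy; a quick illustration:

```python
# Illustrative only: how the --no_resume string resolves after the cast above.
FALSE = ["false", "f", "0"]
for raw in ("False", "0", "True", "yes"):
    print(raw, "->", raw.lower() not in FALSE)
# False -> False, 0 -> False, True -> True, yes -> True
```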

# Discover if we have any checkpoints to resume from.
if not args.no_resume:
@@ -230,19 +231,19 @@ def find_free_port():

# Set up `wandb` reporting if we have an API key, and resume reporting
# if we are resuming a checkpoint.
report_to = None
wandb_key = os.getenv("WANDB_API_KEY", "").lstrip().rstrip()
report_to = "none"
wandb_key = os.getenv("WANDB_API_KEY", "").strip()
if not wandb_key:
print("WANDB_API_KEY: No WANDB_API_KEY found, not reporting to wandb.")
os.environ["WANDB_DISABLED"] = "True"

import wandb

if wandb_key:
wandbApi = wandb.Api(overrides={"project": args.project_id})
report_to = "wandb"

if lastCheckpoint is not None:
wandbApi = wandb.Api(overrides={"project": args.project_id})
for run in wandbApi.runs(path=args.project_id):
print("PRIOR RUN:", run, run.name, run.id, run.state)
if run.state in ["crashed", "failed"] and run.name == args.run_name:
@@ -263,18 +264,42 @@ def find_free_port():
# Set up our tokenizer.
tokenizer: PreTrainedTokenizer
try:
# If a special token (args.eot or args.pad) is explicitly provided,
# then use it; otherwise use the model's defaults if they exist;
# otherwise use hardcoded defaults.

# The resulting padding token ID must match the one the dataset tokenizer
# used, or the existing padding tokens in the dataset
# will not be properly masked during training.

tokens_to_add = {}
if args.eot:
tokens_to_add["eos_token"] = args.eot
if args.pad:
tokens_to_add["pad_token"] = args.pad

tokenizer = AutoTokenizer.from_pretrained(
args.model,
eos_token=args.eot,
pad_token=args.pad,
**tokens_to_add,
cache_dir=args.cache,
)

tokens_to_add.clear()
if "eos_token" not in tokenizer.special_tokens_map:
tokens_to_add["eos_token"] = "<|endoftext|>"
if "pad_token" not in tokenizer.special_tokens_map:
tokens_to_add["pad_token"] = "<|endoftext|>"
if tokens_to_add:
tokenizer.add_special_tokens(tokens_to_add)
except Exception as e:
print(e)
sys.exit(1)
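The fallback described in the comments above resolves in three steps: an explicit `--eot`/`--pad` wins, then the model's own special tokens, then a hardcoded `<|endoftext|>`. A minimal check of the result, using `gpt2` purely as an illustrative model ID:

```python
# Illustrative only: gpt2 ships an eos_token but no pad_token, so the fallback
# assigns <|endoftext|> as the pad token without growing the vocabulary.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
if "pad_token" not in tok.special_tokens_map:
    tok.add_special_tokens({"pad_token": "<|endoftext|>"})
print(tok.eos_token, tok.pad_token, tok.pad_token_id)
# <|endoftext|> <|endoftext|> 50256 -- the ID the dataset tokenizer must agree with
```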


def no_init(loading_code: Callable[[], PreTrainedModel]) -> PreTrainedModel:
@contextmanager
def no_init():
# `no_init_weights` doesn't suppress initialization of some layers by default
# See https://github.com/huggingface/transformers/issues/18505
def dummy(self):
return

@@ -284,12 +309,12 @@ def dummy(self):
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy

with no_init_weights():
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]

return result
try:
with no_init_weights():
yield
finally:
for mod in modules:
mod.reset_parameters = original[mod]
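The context-manager form above is needed because, as the linked issue notes, `no_init_weights()` alone does not stop every layer from running its own `reset_parameters`. The same idea in isolation, as a standalone sketch (not part of this PR):

```python
# Sketch only: temporarily replace reset_parameters so constructing a layer
# allocates its tensors but skips the (soon to be overwritten) random init.
from torch import nn


def _noop(self):
    return


orig = nn.Linear.reset_parameters
nn.Linear.reset_parameters = _noop
try:
    layer = nn.Linear(4096, 4096)  # weights allocated, never initialized
finally:
    nn.Linear.reset_parameters = orig
```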


def estimate_batch_size(divisor: float = 1.0) -> int:
@@ -329,34 +354,32 @@ def get_gpu_ram() -> str:
cudadev = torch.cuda.current_device()
nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
gpu_total = int(gpu_info.total / 1e6)
gpu_free = int(gpu_info.free / 1e6)
gpu_used = int(gpu_info.used / 1e6)
gpu_total = gpu_info.total >> 20
gpu_free = gpu_info.free >> 20
gpu_used = gpu_info.used >> 20
gpu_str = (
f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb "
f"T: {gpu_total:,}mb) "
f"GPU: (U: {gpu_used:,}MiB F: {gpu_free:,}MiB T: {gpu_total:,}MiB) "
)
torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1e6)
torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1e6)
torch_used_gpu = int(torch.cuda.memory_allocated() / 1e6)
torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1e6)
torch_reserved_gpu = torch.cuda.memory.memory_reserved() >> 20
torch_reserved_max = torch.cuda.memory.max_memory_reserved() >> 20
torch_used_gpu = torch.cuda.memory_allocated() >> 20
torch_max_used_gpu = torch.cuda.max_memory_allocated() >> 20
torch_str = (
f"TORCH: (R: {torch_reserved_gpu:,}mb/"
f"{torch_reserved_max:,}mb, "
f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
f"TORCH: (R: {torch_reserved_gpu:,}MiB/"
f"{torch_reserved_max:,}MiB, "
f"A: {torch_used_gpu:,}MiB/{torch_max_used_gpu:,}MiB)"
)
except AssertionError:
pass
cpu_maxrss = int(
resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3
+ resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss / 1e3
)
cpu_maxrss = (
resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss
) >> 10
cpu_vmem = psutil.virtual_memory()
cpu_free = int(cpu_vmem.free / 1e6)
cpu_free = cpu_vmem.free >> 20
return (
f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) "
f"{gpu_str}"
f"{torch_str}"
f"CPU: (maxrss: {cpu_maxrss:,}MiB F: {cpu_free:,}MiB) "
f"{gpu_str}{torch_str}"
)
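Note the unit change that comes with the refactor: `int(x / 1e6)` reported decimal megabytes, while `x >> 20` reports true MiB, hence the relabelled output; `ru_maxrss` is given in kilobyte units on Linux, so `>> 10` turns it into MiB. A worked example with illustrative numbers:

```python
# Worked example of the byte -> MiB conversion used above.
gpu_total_bytes = 25_769_803_776   # a 24 GiB card
print(int(gpu_total_bytes / 1e6))  # 25769  (decimal "MB", old code)
print(gpu_total_bytes >> 20)       # 24576  (MiB, new code)

maxrss_kb = 8_388_608              # ru_maxrss is reported in kB units on Linux
print(maxrss_kb >> 10)             # 8192 MiB
```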


Expand All @@ -368,36 +391,23 @@ class ModifiedTrainer(Trainer):

def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs:
inputs["labels"][inputs["labels"] == tokenizer.pad_token_id] = -100

if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None

outputs = model(**inputs)

if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
inputs["labels"].masked_fill_(
inputs["labels"] == tokenizer.pad_token_id, -100
)

if labels is not None:
loss = self.label_smoother(outputs, labels)
else:
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
results = super().compute_loss(model, inputs, return_outputs)
loss = results[0] if return_outputs else results

# Hack -- enable `requires_grad` on `loss`
loss.requires_grad_(True)

# Hack -- output an (useful) GPU ram update to flush tqdm.
if not hasattr(self, "report_idx"):
self.report_idx = 1
else:
self.report_idx += 1
if self.report_idx % 10 == 0:
# Hack -- output a (useful) GPU ram update to flush tqdm.
self.report_idx = getattr(self, "report_idx", 0) + 1
if self.report_idx % (2 * self.args.gradient_accumulation_steps) == 0:
print(f"\nLOSS: {loss:.3f} {get_gpu_ram()}", file=sys.stderr)
sys.stderr.flush()

return (loss, outputs) if return_outputs else loss
return results
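The slimmed-down `compute_loss` now only masks padding positions to `-100` in place and defers everything else to the stock `Trainer.compute_loss`; `-100` is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so padded positions contribute nothing to the loss. A small illustration:

```python
# Illustration of the in-place pad masking above; pad_token_id == 0 is illustrative.
import torch

labels = torch.tensor([[12, 7, 0, 0]])
labels.masked_fill_(labels == 0, -100)
print(labels)  # tensor([[  12,    7, -100, -100]])
# CrossEntropyLoss(ignore_index=-100) then skips the last two positions.
```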


class ModelSampler(TrainerCallback):
@@ -439,7 +449,6 @@ def __init__(
def on_step_end(
self, args, state, control, model: PreTrainedModel = None, **kwargs
):

if not model:
return
if state.global_step % self.report_every == 0 or state.global_step == 1:
@@ -511,11 +520,11 @@ def __init__(self, path: str, context_length: int = 2048):
self.length = int(file_stat.st_size / 2 / context_length)
self.formatstr = "%sH" % context_length
self.context_length = context_length
length_mb = os.stat(path).st_size / 1024.0 / 1024.0
length_mb = os.stat(path).st_size / (1 << 20)
num_tokens = self.length * context_length
print(f"DATASET: {path}")
print(
f"DATASET SIZE: {length_mb:,.2f}mb, {num_tokens:,} tokens, "
f"DATASET SIZE: {length_mb:,.2f}MiB, {num_tokens:,} tokens, "
f"{self.length:,} contexts"
)

@@ -529,7 +538,7 @@ def load(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
self.formatstr, self.file.read(self.context_length * 2)
)
)
mask = torch.zeros(self.context_length)
mask = input_ids != tokenizer.pad_token_id
return input_ids, mask

def seek(self, idx):
@@ -539,7 +548,7 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
return self.load(idx)
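Each token is stored as an unsigned 16-bit integer (the `"%sH"` struct format), so one 2048-token context occupies 4 KiB and the context count is `file_size / 2 / context_length`. For instance:

```python
# Worked example of the sizing math used by the dataset class above.
context_length = 2048
file_size = 1 << 30                       # an illustrative 1 GiB token file
num_contexts = file_size // 2 // context_length
num_tokens = num_contexts * context_length
print(num_contexts, num_tokens)           # 262144 contexts, 536870912 tokens
print(file_size / (1 << 20))              # 1024.0 -> the MiB figure in the log line
```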


# Inform the user of host, and various versions -- useful for debugging isseus.
# Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
print("HOST:", socket.gethostname())
print("CUDA:", torch.version.cuda)
@@ -573,28 +582,23 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:

# Determine if we train in fp32 or fp16 mode.
print("FORCE FP16:", args.fp16)
fp16_arg = {}
if args.fp16:
fp16_arg = {"fp16": True}
fp16_arg = {"fp16": True} if args.fp16 else {}

# Load our model that we're training. This may fetch via HTTP if not cached
# already.
model: PreTrainedModel


try:
model = AutoModelForCausalLM.from_pretrained(
args.model, # Can be a HuggingFace ID or directory.
cache_dir=args.cache,
use_cache=False,
) # Gradient checkpointing needs this off.
if lastCheckpoint is None:
if args.fp16:
model = no_init(lambda: model.half().to(device))
else:
model = no_init(lambda: model.to(device))
else:
model = no_init(lambda: model)
with no_init():
model = AutoModelForCausalLM.from_pretrained(
args.model, # Can be a HuggingFace ID or directory.
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
cache_dir=args.cache,
use_cache=False,
) # Gradient checkpointing needs this off.
if lastCheckpoint is None:
model = model.to(device)
sys.stderr.flush()
sys.stdout.flush()
except Exception as e:
@@ -623,7 +627,7 @@ def evaluate(
input_tokens: Tensor = (
torch.LongTensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
)
attention_mask: Tensor = torch.ones_like(input_tokens).to(device)
attention_mask: Tensor = input_tokens != tokenizer.pad_token_id
max_length = input_tokens.shape[1] + generate_tokens
generated_tokens = eval_model.generate(
input_tokens,
@@ -640,10 +644,8 @@ def evaluate(
bad_words_ids=[[eval_tokenizer.eos_token_id]],
)

for sample_idx in range(len(generated_tokens)):
output_text = eval_tokenizer.decode(
generated_tokens[sample_idx], skip_special_tokens=False
)
for token in generated_tokens:
output_text = eval_tokenizer.decode(token, skip_special_tokens=False)
output_texts.append(output_text)

return output_texts
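The attention mask is now derived from the prompt itself (`input_tokens != tokenizer.pad_token_id`) instead of a tensor of ones, so any padding in the encoded prompt is excluded from attention during generation. A minimal sketch, with a pad_token_id of 0 assumed for illustration:

```python
# Sketch of the mask construction above; the pad_token_id of 0 is illustrative.
import torch

input_tokens = torch.tensor([[0, 0, 15, 27, 31]])  # two pad positions, then text
attention_mask = input_tokens != 0
print(attention_mask.long())  # tensor([[0, 0, 1, 1, 1]])
```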
@@ -669,10 +671,7 @@ def evaluate(
ds_args = {}
if device != "cpu":
ds_config = json.load(open(args.ds_config))
if (
"zero_optimization" in ds_config
and ds_config["zero_optimization"].get("stage", None) != args.zero_stage
):
if "zero_optimization" in ds_config:
ds_config["zero_optimization"]["stage"] = args.zero_stage
ds_args["deepspeed"] = ds_config
else:
@@ -681,7 +680,7 @@ def evaluate(
os.environ["LOCAL_RANK"] = "-1"

# The latest deepspeed logging is pretty obnoxious, so we disable it.
deepspeed.utils.logger.setLevel(logging.ERROR)
deepspeed.utils.logger.setLevel(logging.WARNING)

# Change our current directory due to some packages assumptions.
os.makedirs(args.output_path, exist_ok=True)