Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Feb 14, 2024
1 parent 8101139 commit 93799fc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/autora/doc/pipelines/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None:

model = AutoModelForCausalLM.from_pretrained(
base_model,
*kwargs,
**kwargs,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
Expand Down Expand Up @@ -61,8 +61,8 @@ def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None:
logging_steps=1, # TODO: Increase once there's more data
learning_rate=2e-4,
weight_decay=0.001,
fp16=False,
bf16=cuda_available,
fp16=cuda_available,
bf16=False,
max_grad_norm=0.3,
max_steps=-1,
warmup_ratio=0.03,
Expand Down
2 changes: 1 addition & 1 deletion src/autora/doc/runtime/predict_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,5 @@ def get_quantization_config() -> Any:
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_compute_dtype=torch.float16,
)

0 comments on commit 93799fc

Please sign in to comment.