Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Olmo tiny scripts #628

Merged
merged 46 commits into from
Jun 28, 2024
Merged

Olmo tiny scripts #628

merged 46 commits into from
Jun 28, 2024

Conversation

ananyahjha93
Copy link
Contributor

No description provided.

@ananyahjha93 ananyahjha93 requested a review from dirkgr June 18, 2024 22:37
@ananyahjha93 ananyahjha93 requested a review from AkshitaB June 18, 2024 22:39
Copy link
Contributor

@AkshitaB AkshitaB left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed my queries offline with @ananyahjha93

  • How were model shapes decided? Based on Pythia and then number of parameters.
  • How about LR? Also ballpark from Pythia.

Other things to note:

  • Global batch size may also require some ablation

@@ -248,7 +248,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
)
cfg.save_interval_unsharded = cfg.save_interval

if cfg.save_num_unsharded_checkpoints_to_keep < 1:
if cfg.save_num_unsharded_checkpoints_to_keep == 0:
log.warning(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if save_num_checkpoints_to_keep is also 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it then assumes that you did not want to keep checkpoints at all!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 assumes you want to save all checkpoints and so I made it ==0 instead of < 1.

Copy link
Member

@dirkgr dirkgr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the configs between the sizes are all the same, so I didn't look at all of them.

@@ -9,17 +9,15 @@ wandb:
model:
d_model: 1024
n_heads: 16
n_layers: 16
n_layers: 24
mlp_ratio: 8
weight_tying: false
alibi: false
rope: true
flash_attention: true # not available on AMD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is now available on AMD

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the comment

- label: commonsense_qa
type: downstream

- label: social_iqa
type: downstream

- label: basic_arithmetic
type: downstream

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's wrong with these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, basic_arithmetic should be in, others don't provide any signal based on my experience

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah this was commented out saying

# Doesn't work from cache.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should work with cache v4

stop_at: 100_000
global_train_batch_size: 2048
device_train_microbatch_size: 8
max_duration: 2ep
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means you'll run into this bug: #584
It might not matter. The problem is only that the second epoch will be shuffled the same way the first one is shuffled.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a stop_at 400k steps!

@@ -9,17 +9,15 @@ wandb:
model:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No DDP section in this file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is!

Comment on lines 63 to 64
grad_clip_warmup_steps: null
grad_clip_warmup_factor: 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't have these settings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

took these from @AkshitaB 's llamaish1-normal-weka.yaml.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed them for now!

paths:
######### NON WEB DATA #########
# ~> GUTENBERG BOOKS (5.256 GT)
- s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/books/gpt-neox-olmo-dolma-v1_5/part-0-00000.npy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you read from weka instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was planning to run on pluto, now I can see free nodes on jupiter, making the change!

# Unsharded checkpoints (for ddp)
save_interval_unsharded: 5000
save_num_unsharded_checkpoints_to_keep: 3
save_num_unsharded_checkpoints_to_keep: -1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does -1 do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 is for keeping all checkpoints, but I'll double check

Copy link
Member

@dirkgr dirkgr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved with a small comment about the long warmup.

units: tokens
t_warmup: 4194304000
t_max: 3e12
t_warmup: 5000
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For normal init, this is a lot of warmup? Not a big deal, but unusual?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

smaller models, higher LR, did not take a chance! never bad doing a longer warmup!

Comment on lines +77 to +78
max_duration: 1ep
stop_at: 406_934
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need both max_duration and stop_at?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, from what I have observed and Dave mentioned the training goes past max_duration if stop_at is not set

Comment on lines +152 to +155
# Doesn't work from cache.
# - label: basic_arithmetic
# type: downstream

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even with cache v4?

--run_name=$TASK_NAME \
--wandb.name=$TASK_NAME \
--wandb.group=$TASK_NAME \
--wandb.project=tiny_olmo \
--wandb.project=olmo-tiny \
--max_grad_norm=2.0 \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to do this clipping value for all small models?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, let me fix this, so the model with clipping value 2.0 does not show any downstream improvement!

@ananyahjha93 ananyahjha93 requested a review from epwalsh June 24, 2024 23:25
olmo/train.py Show resolved Hide resolved
olmo/train.py Outdated
Comment on lines 1118 to 1119
num_fwd_flops=self.model.num_fwd_flops, # this is per sequence
num_bck_flops=self.model.num_bck_flops, # this is per sequence
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"this is per sequence" ... it's per-token now, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

@ananyahjha93 ananyahjha93 merged commit a1f118a into main Jun 28, 2024
12 checks passed
@ananyahjha93 ananyahjha93 deleted the olmo-tiny branch June 28, 2024 16:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants