
Update ILQL details #156

Merged: 19 commits, merged into main on Jan 11, 2023
Conversation

@maxreciprocate (Collaborator) commented Jan 2, 2023

This PR:

  • adds tensor statistics to ILQL training
  • changes the offline training API by removing split_token, instead delegating preprocessing per example
  • confines the optional cross-entropy loss to logits from completions

https://api.wandb.ai/report/sorry/t4w326pt
https://api.wandb.ai/report/sorry/93xza5j0

The difference in the randomwalks results is due to removing <bos> from samples and instead specifying starting nodes directly.
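Concretely, the offline API now takes pre-split examples (the toy samples and rewards below are made up; see the full script later in this thread for a real config):

import trlx

# Each sample is already split into (prompt, completion); no split_token
# is needed to find the boundary between input and output.
samples = [
    ("What is 2 + 2?", " 4"),
    ("Name a prime number.", " 7"),
]
rewards = [1.0, 0.5]

trlx.train(dataset=(samples, rewards))  # plus a config, as in the full script below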

@jon-tow added this to the v0.4.0 milestone on Jan 2, 2023
@maxreciprocate marked this pull request as ready for review on Jan 5, 2023, 16:14
@@ -36,7 +36,6 @@ def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
 imdb = load_dataset("imdb", split="train+test")

 trlx.train(
-    "gpt2",
Contributor:

Why are we changing this in the example?

Collaborator:

I like this change. Previously, if your ILQL config set the model_path option, it would be overridden by this "gpt2" arg. That led to unexpected behavior on my end, whereby editing the example ILQL config's model path didn't actually update the model.
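Schematically (a sketch only; the config filename is assumed from the repo's examples, and the toy dataset is made up):

import yaml

import trlx
from trlx.data.configs import TRLConfig

# Hypothetical minimal offline dataset: one (prompt, completion) pair and a reward.
dataset = ([("Hello", " world!")], [1.0])
config = TRLConfig.update(yaml.safe_load(open("configs/ilql_config.yml")), {})

# Before this PR, the example passed a positional model name, which silently
# overrode config.model.model_path:
#   trlx.train("gpt2", dataset=dataset, config=config)
# After, the model comes from the config alone:
trlx.train(dataset=dataset, config=config)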

Collaborator (author):

Yes, it's for consistency with the other examples.

Collaborator:

Agreed

import torch

from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import ILQLRolloutStorage


def tokenize_dialogue(
Contributor:

Why are we adding dialogue-related functionality to the core API?

Collaborator:

I think having this function in trlx makes sense; it makes the completion-split-based preprocessing for ILQL a bit clearer, although we should probably pair it with a multi-turn example to show it working.
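Usage would look something like this (a hypothetical sketch: the import path and the toy strings are assumptions, based on the signature and docstring quoted further down):

from transformers import AutoTokenizer

from trlx.orchestrator.offline_orchestrator import tokenize_dialogue  # import path assumed

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Interleaved (question_1, answer_1, question_2, answer_2, ...) form:
# even positions are prompts, odd positions are completions to be scored.
sample = ["User: hi. ", "Bot: hello! ", "User: how are you? ", "Bot: fine."]
token_phrases = tokenize_dialogue(sample, tokenizer, max_length=2048)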

Contributor:

That makes sense.

Collaborator:

Perhaps including it is fine, but I think it should be optional (and not the default).

@LouisCastricato (Contributor):

Where exactly does this fix the memory issue?

@maxreciprocate (Collaborator, author) commented Jan 9, 2023

> Where exactly does this fix the memory issue?

Implicitly, it fixes the issue by removing split_token, which was easy to misuse in a way that turned the whole sample into an output over which q/v losses had to be computed. Also, cross-entropy is now computed only on the few output logits (as per Charlie's comment) instead of the whole sample. As a result of both, I can now train on @AlekseyKorshuk's data referenced in the issue with the same gpt2-xl, no layers frozen, mixed precision, batch size 8, context length 1024, on a single A100 without OOMs.

https://wandb.ai/sorry/public/runs/17v1c8st

import yaml
import trlx
from trlx.data.configs import TRLConfig
from datasets import load_dataset

default_config = yaml.safe_load(open("configs/ilql_dalio.yml"))


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    # Hold out the first `val_split` examples; their inputs become eval prompts.
    val_split = 16
    dataset = load_dataset("ChaiML/dalio_scored_responses_v1")["train"]
    valid_ds = dataset[:val_split]
    train_ds = dataset[val_split:]

    eval_prompts = valid_ds["input_text"]

    # Offline ILQL data: (prompt, completion) pairs plus per-example scores.
    dataset = (
        list(zip(
            train_ds["input_text"],
            train_ds["output_text"],
        )),
        train_ds["score"],
    )

    trlx.train(
        dataset=dataset,
        config=config,
        eval_prompts=eval_prompts,
    )


if __name__ == "__main__":
    main()


@@ -48,18 +51,20 @@ def heads(self, hidden_size: int, vocab_size: int):

    def loss(self, outputs, labels: ILQLBatch):
        logits, (qs, target_qs, vs) = outputs
        terminal_mask = labels.dones[:, :-1]
Collaborator:

Can you remind me why we need this?

Collaborator (author):

It's the same as attention_mask, except it accounts only for states_ixs (the subset of tokens from which continuations were sampled), plus the last token, which is also masked. It lets V(terminal) = 0 in V[:, 1:] * terminal_mask[:, 1:] and also masks padding in the other losses.
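A toy numeric sketch of that masking (batch of one; shapes and values are made up):

import torch

# The 0 in `dones` marks the terminal state; after the shift, V at the
# terminal position gets multiplied by 0.
dones = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])
terminal_mask = dones[:, :-1]             # [[1., 1., 1., 0.]]

V = torch.tensor([[0.3, 0.2, 0.5, 0.9]])  # one V estimate per state
masked = V[:, 1:] * terminal_mask[:, 1:]  # [[0.2, 0.5, 0.0]] -> V(terminal) = 0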

for phrase in sample:
    if isoutput:
        actions_ixs.append(
            torch.arange(length - 1, length + len(phrase) - 1)
Collaborator:

Why do states and actions append exactly the same thing?

Collaborator (author):

It can be refactored; the difference between the two is a few lines below.
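For intuition only (a hypothetical sketch, not the actual implementation): if a sample tokenizes to two prompt tokens followed by a three-token completion, the action indices are the positions whose logits predict the completion tokens, while the state indices add the final (terminal) position:

import torch

# Hypothetical 5-token sample: prompt tokens at positions 0-1,
# completion tokens at positions 2-4.
length = 2       # running length before the completion phrase
phrase_len = 3   # completion length

# Positions whose logits predict the completion tokens.
actions = torch.arange(length - 1, length + phrase_len - 1)             # tensor([1, 2, 3])
# States additionally include the terminal position.
states = torch.cat([actions, torch.tensor([length + phrase_len - 1])])  # tensor([1, 2, 3, 4])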

all_input_ids.append(torch.tensor(sum(sample, [])))
isoutput = False
actions_ixs, states_ixs = [], []
for phrase in sample:
Collaborator:

So a sample can contain multiple "phrases" which we can reward?

Is there a reason we can't just separate each phrase into its own datapoint with the "output" at the end (with all the prior dialogue as context)? Do you see some advantage to treating things this way?

Collaborator (author):

Yes, both are supported and opt-in.

    dialogue: Union[str, List[str]], tokenizer, max_length=2048, truncation_side="left"
) -> List[int]:
    """
    Tokenize sample with the interleaved form of (question_1, answer_1, question_2, answer_2...)
Collaborator:

Do we intend for answer_2 to depend on question_1, answer_1?

Collaborator (author):

Yes, ILQL was designed to do this.

break

# in case of odd number of phrases (possibly due to truncation)
# since the first phrase always has to be a prompt, force it to be <bos>
Collaborator:

Shouldn't the last phrase always be a prompt? Why must the first always be a prompt?

Collaborator (author):

Counting from the left, the last one should be an action; otherwise there would be nothing for the final prompt to elicit, and it would be redundant, since the trajectory (more specifically, its actions) might as well be graded without it. At least the datasets we are interested in are, to my knowledge, of this form.

        out[0].pop(0)
        out.insert(0, [tokenizer.bos_token_id])

    elif truncation_side == "right":
Collaborator:

When would we want this? If we truncate on the right, the agent is missing the most recent dialogue.

Collaborator (author):

This function accepts both dialogues and monologues and it's in anticipation of

@Dahoas (Collaborator) commented Jan 9, 2023

Left some comments. In general I'm kind of skeptical about the need to support multiple "actions" (responses) in one trajectory. I don't see why we can't handle this just by splitting a dialogue with 10 interactions into 10 samples, each of which has only one action (at the very end of the dialogue). Perhaps including it is fine, but I think it should be optional.

(Sorry if I misunderstood something)

@maxreciprocate (Collaborator, author):

No Alex, this is a great question. The difference arises when you don't have intermediate rewards for each cumulative interaction (q0, a0, r0), (q0, a0, q1, a1, r1), etc., but instead a single rating for the whole trajectory. In that case it makes sense to credit-assign the return to the preceding interactions as well, not necessarily only to the latest one, given that all interactions were in fact judged as a whole. And it's already totally optional: depending on how you pre-split your data yourself, either way could work; it's to be experimentally determined which one is better.
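As a concrete (made-up) illustration of the two layouts being contrasted, using the interleaved dialogue form:

# Option A: one multi-turn sample with a single trajectory-level reward;
# the return is credit-assigned across all actions in the dialogue.
samples_a = [["q0", "a0", "q1", "a1"]]
rewards_a = [1.0]

# Option B: the same dialogue pre-split into per-turn samples, each carrying
# its own reward and all prior dialogue as context.
samples_b = [["q0", "a0"], ["q0", "a0", "q1", "a1"]]
rewards_b = [0.5, 1.0]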

@Dahoas merged commit 0cb8438 into main on Jan 11, 2023