Update ILQL details #156
Conversation
as it already is in the ppo_randomwalks example
```diff
@@ -36,7 +36,6 @@ def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
     imdb = load_dataset("imdb", split="train+test")

     trlx.train(
-        "gpt2",
```
Why are we changing this in the example?
I like this change. Previously, if your ILQL config set the `model_path` option, it would be overridden by this `"gpt2"` arg. That led to unexpected behavior on my end, where editing the example ILQL config's model path didn't actually update the model.
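A minimal before/after sketch of the call, assuming the tuple-style `(samples, rewards)` dataset used elsewhere in this thread; the config path and dataset contents are illustrative:

```python
import trlx
from trlx.data.configs import TRLConfig

# Hypothetical toy dataset in the (samples, rewards) tuple form
dataset = (["a sample completion"], [1.0])

config = TRLConfig.load_yaml("configs/ilql_config.yml")  # sets model.model_path

# Before this PR: the positional "gpt2" overrode config.model.model_path
# trlx.train("gpt2", dataset=dataset, config=config)

# After: the model is taken from the config
trlx.train(dataset=dataset, config=config)
```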
Yes, it's for consistency with other examples
Agreed
```python
import torch

from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import ILQLRolloutStorage


def tokenize_dialogue(
```
Why are we adding dialogue-related functionality to the core API?
I think having this function in trlx makes sense; it makes the completion-split-based preprocessing for ILQL a bit clearer, although we should probably pair it with a multi-turn example to show it working.
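A hypothetical usage sketch; the import path is an assumption based on where this PR appears to define the function, and the dialogue strings are made up:

```python
from transformers import AutoTokenizer

# Assumed import location for tokenize_dialogue from this PR
from trlx.orchestrator.offline_orchestrator import tokenize_dialogue

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Interleaved (question_1, answer_1, question_2, answer_2) form
dialogue = ["How are you?", " Fine, thanks.", " And the weather?", " Sunny."]
phrases = tokenize_dialogue(dialogue, tokenizer, max_length=32)
# phrases: one list of token ids per phrase, truncated from the left
```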
That makes sense.
Perhaps including it is fine but I think it should be optional (and not the default)
Where exactly does this fix the memory issue?
Implicitly it fixes the issue by removing it (see https://wandb.ai/sorry/public/runs/17v1c8st):

```python
import yaml

import trlx
from trlx.data.configs import TRLConfig
from datasets import load_dataset

default_config = yaml.safe_load(open("configs/ilql_dalio.yml"))


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    # hold out the first val_split rows for evaluation
    val_split = 16
    dataset = load_dataset("ChaiML/dalio_scored_responses_v1")["train"]
    valid_ds = dataset[:val_split]
    train_ds = dataset[val_split:]

    eval_prompts = valid_ds["input_text"]
    # (input, output) pairs paired with their scores
    dataset = (
        list(zip(
            train_ds["input_text"],
            train_ds["output_text"],
        )),
        train_ds["score"],
    )

    trlx.train(
        dataset=dataset,
        config=config,
        eval_prompts=eval_prompts,
    )


if __name__ == "__main__":
    main()
```
```diff
@@ -48,18 +51,20 @@ def heads(self, hidden_size: int, vocab_size: int):

     def loss(self, outputs, labels: ILQLBatch):
         logits, (qs, target_qs, vs) = outputs
         terminal_mask = labels.dones[:, :-1]
```
Can you remind me why we need this?
It's the same as `attention_mask`, except it accounts only for `states_ixs` (the subset of tokens from which continuations were sampled), plus the last token, which is also masked. It lets V(terminal) = 0 in `V[:, 1:] * terminal_mask[:, 1:]`, and it also masks padding in the other losses.
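A toy illustration of the masking, with made-up shapes (not the PR's exact code):

```python
import torch

# dones marks live steps per trajectory; trailing zeros are padding
dones = torch.tensor([[1, 1, 1, 1, 0, 0]])
terminal_mask = dones[:, :-1]          # shape (1, 5), drops the last column

V = torch.randn(1, 5)                  # per-state value estimates
# zeroes V at the terminal step (and padding) in the TD target
masked_next_V = V[:, 1:] * terminal_mask[:, 1:]
```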
```python
for phrase in sample:
    if isoutput:
        actions_ixs.append(
            torch.arange(length - 1, length + len(phrase) - 1)
```
Why do states and actions append exactly the same thing?
It can be refactored; the difference between the two is a few lines below.
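A hedged sketch of the intended difference, with toy numbers (the exact PR code differs):

```python
import torch

# Assumed toy layout: prompt occupies tokens [0..3], output phrase [4..6]
length, phrase_len = 4, 3

# actions: positions from which each output token is predicted
actions_ixs = torch.arange(length - 1, length + phrase_len - 1)  # tensor([3, 4, 5])

# states: the same positions plus the final token, so V also covers the
# terminal state
states_ixs = torch.hstack((actions_ixs, torch.tensor([length + phrase_len - 1])))
# tensor([3, 4, 5, 6])
```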
```python
all_input_ids.append(torch.tensor(sum(sample, [])))
isoutput = False
actions_ixs, states_ixs = [], []
for phrase in sample:
```
So a sample can contain multiple "phrases" which we can reward?
Is there a reason we can't just separate each phrase into its own datapoint, with the "output" at the end (and all the prior dialogue as context)? Do you see some advantage to treating things this way?
Yes, both are supported and opt-in.
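A hedged sketch of the two layouts in the tuple-style `(samples, rewards)` form used above; the strings and scores are made up:

```python
# (a) one action per datapoint: all prior dialogue as context, a single
#     output phrase at the end as the graded action
split_per_turn = (
    [["Q1", " A1"], ["Q1 A1 Q2", " A2"]],
    [0.3, 0.9],
)

# (b) one multi-turn trajectory: several output phrases in a single sample,
#     graded with one score; useful when per-turn rewards are unavailable
multi_turn = (
    [["Q1", " A1", " Q2", " A2"]],
    [0.9],
)
```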
```python
    dialogue: Union[str, List[str]], tokenizer, max_length=2048, truncation_side="left"
) -> List[int]:
    """
    Tokenize sample with the interleaved form of (question_1, answer_1, question_2, answer_2...)
```
Do we intend for answer_2 to depend on question_1, answer_1?
Yes, ILQL was designed to do this
```python
            break

    # in case of odd number of phrases (possibly due to truncation)
    # since the first phrase always has to be a prompt, force it to be <bos>
```
Shouldn't the last phrase always be a prompt? Why must the first always be a prompt?
Counting from the left, the last phrase should be an action; otherwise there would be nothing to prompt, and it would be redundant, since the trajectory (more specifically, the actions) might as well be graded without it. At least the datasets we are interested in are, to my knowledge, of this form.
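A toy illustration of the odd-phrase fix-up from the hunk below, with made-up token ids:

```python
bos_token_id = 50256                  # gpt2's bos/eos id (illustrative)

# After left truncation an output phrase may end up first, breaking the
# prompt/answer alternation
out = [[523, 11, 42], [7, 8]]         # assumed: first phrase is an output here

out[0].pop(0)                         # drop one token to stay within max_length
out.insert(0, [bos_token_id])         # force the first phrase to be a lone <bos>
assert out == [[50256], [11, 42], [7, 8]]
```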
```python
        out[0].pop(0)
        out.insert(0, [tokenizer.bos_token_id])

    elif truncation_side == "right":
```
When would we want this? If we truncate on the right, the agent is missing the most recent dialogue.
This function accepts both dialogues and monologues and it's in anticipation of
Left some comments. In general I'm kind of skeptical about the need to support multiple "actions" (responses) in one trajectory. I don't see why we can't handle this just by splitting a dialogue with 10 interactions into 10 samples, each of which only has one action (at the very end of the dialogue). Perhaps including it is fine, but I think it should be optional. (Sorry if I misunderstood something.)
No Alex, this is a great question. The difference would be when you don't have intermediate rewards for each cumulative interaction.
This PR removes `split_token`, instead delegating preprocessing per example:

https://api.wandb.ai/report/sorry/t4w326pt
https://api.wandb.ai/report/sorry/93xza5j0

The difference in randomwalks is due to removing `<bos>` from samples and specifying starting nodes directly instead.