Add Support for Passing Pretokenized Datasets to TRL #166
alex-jw-brooks wants to merge 8 commits into foundation-model-stack:main from
Conversation
```python
)


### Utils for custom masking / manipulating input / output strs, etc
def combine_sequence(input_element: str, output_element: str):
```
This ties in with my upcoming change to accept a template in the API, which will act like a verbalizer field. In the future we'll need to apply the template while combining the sequence, which will be a minor addition; that way we can accept "input/output" plus a custom template. No need to worry about that now, though.
Cool, sounds good. That is partly why I wrote things this way: even though the input/output are hardcoded in the caller, we can pass everything through here as needed.
Are the templates going to be Jinja-style, like this? https://huggingface.co/docs/transformers/en/chat_templating
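To make the template discussion above concrete, here is a minimal sketch of how `combine_sequence` could grow an optional template parameter. The `template` argument and the `{{input}}`/`{{output}}` placeholder convention are assumptions for illustration only; they are not part of this PR, and a Jinja-style template would be a drop-in alternative.

```python
def combine_sequence(
    input_element: str, output_element: str, template: str = None
) -> str:
    """Combine an input/output pair, optionally through a custom template.

    HYPOTHETICAL: the `template` parameter does not exist in this PR.
    The template is assumed to use {{input}} / {{output}} placeholders,
    similar in spirit to a verbalizer field.
    """
    if template is None:
        # Current behavior: plain concatenation
        return input_element + output_element
    # Future behavior sketch: substitute the pair into the template
    return template.replace("{{input}}", input_element).replace(
        "{{output}}", output_element
    )
```

With a template, the same pair can be rendered in different prompt formats without touching the caller, which is why passing everything through this helper keeps the extension small.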
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Force-pushed from de61112 to 5520d92
```python
data_kwargs = {}
if isinstance(data_collator, DataCollatorForSeq2Seq):
    # HACK: This function is never called, but is needed to sidestep TRL's internal validation.
    data_kwargs["formatting_func"] = lambda x: x
```
I was implementing a somewhat silly workaround for TRL's validation logic, wrapping the tokenizer to make tokenize() a no-op, when I realized that TRL already has logic for handling pretokenized data based on the columns of the passed dataset (ref). Currently, all we have to do to handle tokenized data is pass a dummy formatting function, because _prepare_dataset will inspect the dataset and return immediately without ever calling the value passed here.
There is additionally an extra argument, dataset_kwargs, a dict through which the bool-valued skip_prepare_dataset can be passed. I think the correct behavior is to check that value in TRL and skip the validation for the packing=True case as well, which is a small change that neither affects TRL's API nor changes the supported data formats.
I have opened an issue / pull request in TRL to this effect: huggingface/trl#1673
If this change is merged, I will open another PR to remove this hack.
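The workaround described above can be sketched in isolation. This is an illustrative stand-in, not the repo's actual code: `build_data_kwargs` is a hypothetical helper name, and the local `DataCollatorForSeq2Seq` class stands in for the Hugging Face collator so the snippet is self-contained.

```python
class DataCollatorForSeq2Seq:
    """Stand-in for transformers.DataCollatorForSeq2Seq, for illustration."""


def build_data_kwargs(data_collator) -> dict:
    """Build extra SFTTrainer kwargs when the dataset is pretokenized.

    HACK (per the PR): the identity formatting_func is never called.
    TRL's _prepare_dataset inspects the dataset columns, sees it is
    already tokenized, and returns before invoking the function, so
    passing it only sidesteps TRL's argument validation.
    """
    data_kwargs = {}
    if isinstance(data_collator, DataCollatorForSeq2Seq):
        data_kwargs["formatting_func"] = lambda x: x
    return data_kwargs
```

For non-seq2seq collators the dict stays empty, so the normal (non-pretokenized) path through TRL is unaffected.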
Force-pushed from 1c22429 to 7fc6478
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Force-pushed from 7fc6478 to 2b38589
@alex-jw-brooks Are we using huggingface/trl#1520?
```python
    return input_element + output_element


def preprocess_and_tokenize(
```
I think here we'll need to accept the template as an argument and use it to combine the sequence.
```python
from tuning.config import configs


def get_data_trainer_kwargs(
```
I am just wondering if this wrapper function hides too much from the main code. It might make sense to move this back into the main code, especially once the hack is removed, as it is basically two steps:
get the collator, then get the formatted dataset. Internally, data formatting might happen in different ways, and that can be combined into one function.
I am still thinking about it, but my initial feeling is that it would be good to see the high-level steps in train().
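The restructuring suggested above can be sketched with toy stand-ins. Everything here is hypothetical: the helper names, the `input_ids` column check, and the string collator labels are assumptions to show the shape of a train() that exposes the two steps, not the repo's actual API.

```python
def get_formatted_dataset(dataset, tokenize):
    """Tokenize raw records unless the dataset is already pretokenized."""
    if dataset and "input_ids" in dataset[0]:
        return dataset  # pretokenized: pass through unchanged
    return [{"input_ids": tokenize(rec["output"])} for rec in dataset]


def get_data_collator(is_pretokenized):
    """Pick a collator; pretokenized data pairs with a seq2seq-style collator."""
    return "DataCollatorForSeq2Seq" if is_pretokenized else "default_collator"


def train(dataset, tokenize):
    # Step 1: choose the collator based on whether the data is pretokenized
    collator = get_data_collator("input_ids" in dataset[0])
    # Step 2: format (or pass through) the dataset
    formatted = get_formatted_dataset(dataset, tokenize)
    return formatted, collator
```

With this shape, train() reads as the two high-level steps directly, and the different internal formatting paths stay encapsulated in get_formatted_dataset.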
Closing this PR, as it was refactored and merged in #260.
Description of the change
Related issue number
How to verify the PR
Was the PR tested