Skip to content

Commit

Permalink
Fix sft dataset truncation (#7464)
Browse files Browse the repository at this point in the history
* Add fix

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

---------

Signed-off-by: Cheng-Ping Hsieh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
hsiehjackson and pre-commit-ci[bot] committed Sep 20, 2023
1 parent c8894ce commit 19a3b70
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,10 @@ def _multiple_truncation(self, template_ids: List[List[int]], template_ids_keys:
for i, (ids, key) in enumerate(zip(template_ids, template_ids_keys)):
if key in self.truncation_fields:
truncation_length = truncation_length_list.pop()
assert len(ids) >= truncation_length, f'{key} is not long enough to truncate.'
if len(ids) < truncation_length:
logging.warning(f'{key} is not long enough to truncate.')
truncation_length = len(ids)

if self.truncation_method == 'left':
window_offset = truncation_length
elif self.truncation_method == 'right':
Expand Down Expand Up @@ -328,6 +331,7 @@ def _process_example(self, example):
if len(input_ids) > self.max_seq_length:
logging.warning(f'Input ids length {len(input_ids)} exceed max sequence length {self.max_seq_length}')
input_ids = input_ids[: self.max_seq_length]
answer_ids = input_ids[answer_start_idx:]

# store metadata in dataset, in case user may have keys required in the prediction json files
metadata = {k: v for k, v in example.items() if k not in self.prompt_template_keys}
Expand Down

0 comments on commit 19a3b70

Please sign in to comment.