
support value attribution condition #6934

Merged
merged 24 commits into main
Jun 29, 2023
Conversation

yidong72
Collaborator

  1. support value attribute string prediction
  2. support value attribute SFT
  3. update the UI to allow users to enter values

Signed-off-by: Yi Dong <[email protected]>
Contributor

@MaximumEntropy MaximumEntropy left a comment

A few minor comments.

```python
cur_idx = header_len
tgt_len = target.shape[0]
for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
    # note, sentence piece will add extra empty token in front. s_id has that extra token too
    skip_name_len = len(tokenizer.text_to_ids(TURN_TOKEN + speaker + END_NAME_SIGNAL))
    if (s_id[1:] == 255002).any().item():
```
Contributor

Can we check the detokenized string value instead of a hardcoded value specific to a particular tokenizer?

Collaborator Author

Done, but it is still tied to SentencePiece, since that tokenizer adds an extra token in front. Need to test other tokenizers.

Collaborator Author

Okay, I fixed the code a bit, tested it on the Hugging Face tokenizer, and confirmed it is working.
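The tokenizer-agnostic check discussed above can be sketched roughly like this (a minimal sketch, not the PR's actual code: the toy tokenizer, the helper name `contains_label_token`, and the `ids_to_text` behavior are assumptions standing in for NeMo's tokenizer API; `255002` was the original hardcoded id):

```python
LABEL_START = "<extra_id_2>"  # value-attribute control token used in the diff

class ToyTokenizer:
    """Toy stand-in for a SentencePiece-style tokenizer (an assumption, not
    NeMo's actual class): it prepends one special token to every encoding,
    which is why the PR code slices with s_id[1:]."""
    _vocab = ["<s>", "<extra_id_0>", "<extra_id_1>", "<extra_id_2>"]

    def text_to_ids(self, text):
        ids = [0]  # the prepended special token
        ids += [self._vocab.index(p) for p in self._vocab[1:] if p in text]
        return ids

    def ids_to_text(self, ids):
        # skip the prepended special token when detokenizing
        return "".join(self._vocab[i] for i in ids[1:])

def contains_label_token(tokenizer, s_id):
    # compare the detokenized text instead of a tokenizer-specific id
    return LABEL_START in tokenizer.ids_to_text(s_id)
```

The idea is simply that string comparison survives a tokenizer swap, while a literal id like 255002 does not.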

```yaml
chat: False # use the chat interface
chatbot_config:
  value: False # whether to inject the value attributes
```
Contributor

Can you call this steerlm_config instead?

Collaborator Author

The paper is under anonymous review; we don't want to mention the name of the method.

@@ -26,21 +26,34 @@

```python
END_SIGNAL = "\n"
END_NAME_SIGNAL = "\n"

SCALE = 9
```
Contributor

Can we avoid making this a global variable?

Contributor

It makes it hard to experiment with different kinds of value models.

Collaborator Author

Removed this; only the value string is used now.

Collaborator

@Zhilin123 Zhilin123 left a comment

Some comments on readability

```python
if isinstance(label, str):
    return '<extra_id_2>' + label + '\n'
else:
    raise ValueError(f'Unknown label type {type(label)}')
```
Collaborator

Maybe use a more informative error message (e.g. "please provide the label in str format").

Collaborator Author

Added.
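The more informative error could look like the following sketch (`format_value_label` is a hypothetical name, not the function in the diff, and the example label string is made up):

```python
def format_value_label(label):
    """Prefix a value-attribute label string with the control token from the PR."""
    if isinstance(label, str):
        return '<extra_id_2>' + label + '\n'
    # tell the caller what was expected, not just what was received
    raise ValueError(
        f'Expected the label to be a str, got {type(label)}; '
        'please pass the value attributes as a single string'
    )
```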

@@ -157,14 +196,21 @@ def _build_samples_mapping(self):

```python
assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
assert '<extra_id_0>' in self.tokenizer.vocab, "<extra_id_0> not in the tokenizer vocab. not supported"
assert '<extra_id_1>' in self.tokenizer.vocab, "<extra_id_1> not in the tokenizer vocab. not supported"
# calcuilate <extra_id_2> id value
if '<extra_id_2>' in self.tokenizer.vocab:
    ids_1 = self.tokenizer.text_to_ids('<extra_id_1><extra_id_2>')
```
Collaborator

is this a typo?

Collaborator

Not sure why we can't get the text_to_ids('<extra_id_2>') directly. This looks pretty hacky

Collaborator Author

This is to handle the SentencePiece tokenizer, which adds a special token in front. A hacky way, I agree.

Contributor

Btw, this check is super slow since self.tokenizer.vocab is a list, not a dictionary/set. It is actually faster to do len(self.tokenizer.text_to_ids('<extra_id_2>')) == 1.
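The reviewer's performance point is easy to demonstrate: membership tests on a Python list scan linearly, while a set lookup is constant time. A toy benchmark (the vocab size and piece names are arbitrary assumptions):

```python
import timeit

# toy vocab with the control token at the end (worst case for a list scan)
vocab_list = [f"piece_{i}" for i in range(50_000)] + ["<extra_id_2>"]
vocab_set = set(vocab_list)  # one-time conversion makes lookups O(1)

# repeated membership checks, as in a dataset-building loop
t_list = timeit.timeit(lambda: "<extra_id_2>" in vocab_list, number=200)
t_set = timeit.timeit(lambda: "<extra_id_2>" in vocab_set, number=200)
print(f"list: {t_list:.4f}s  set: {t_set:.4f}s")
```

Converting the vocab to a set once (or using the `text_to_ids` length trick the reviewer suggests) avoids paying the linear scan on every sample.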

Collaborator

@Zhilin123 Zhilin123 Jun 28, 2023

Hmm, nit but isn't it just gonna be self.extra_id_2_token_id = self.tokenizer.text_to_ids('<extra_id_2>')[-1]
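Both workarounds hinge on the same tokenizer behavior, which a toy SentencePiece-like tokenizer makes concrete (a sketch; the class and the ids are made up, not real tokenizer output):

```python
class SPLikeTokenizer:
    """Toy tokenizer that, like SentencePiece, prepends one special token
    to every encoding (an assumption modeled on the PR discussion)."""
    _pieces = {"<extra_id_1>": 11, "<extra_id_2>": 12}

    def text_to_ids(self, text):
        ids = [1]  # the prepended token that breaks a naive lookup
        ids += [pid for piece, pid in self._pieces.items() if piece in text]
        return ids

tok = SPLikeTokenizer()
# the diff's approach: encode a two-token string, take the last id
via_pair = tok.text_to_ids('<extra_id_1><extra_id_2>')[-1]
# the reviewer's simpler suggestion: encode the token alone, take [-1]
via_single = tok.text_to_ids('<extra_id_2>')[-1]
```

Taking `[-1]` discards the prepended special token in either case, which is why the single-call form is equivalent and simpler.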

```diff
-def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer, mask_role):
+def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer, mask_role, gtype, extra_id_2_id):
```
Collaborator

Can you write a description on what the code is doing using an example such as (also include the control tokens)

e.g.

(user utterance1) (bot utterance1) (user utterance2) (bot utterance2)
 XXXXXXXXXXXXXXX YYYY ZZZZZZZZZZZZ AAAA → loss on YYYY and AAAA together

as well as what each arg does?

This function is pretty hard to read without it.

Collaborator Author

Added documentation.
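The documentation the reviewer asked for can be illustrated with a simplified loss-masking sketch. This is a hypothetical reduction of `_mask_targets`, not the real implementation: the actual function also handles the header, per-turn name tokens, and the value-attribute ids, and `-100` as the "no loss" label is an assumption borrowed from common fine-tuning practice:

```python
# (user utterance1)(bot utterance1)(user utterance2)(bot utterance2)
#  XXXXXXXXXXXXXXX YYYYYYYYYYYYYYY ZZZZZZZZZZZZZZZZ AAAAAAAAAAAAAAA
# loss is computed on YYYY and AAAA only; user turns are masked out.
IGNORE_INDEX = -100  # labels with this value contribute no loss

def mask_targets(target, tokenized_lens, speakers, mask_role="User"):
    """Set labels of every `mask_role` turn to IGNORE_INDEX, in place.

    target: label ids for the whole concatenated conversation
    tokenized_lens: token count of each turn, in order
    speakers: speaker of each turn, aligned with tokenized_lens
    """
    cur = 0
    for length, speaker in zip(tokenized_lens, speakers):
        if speaker == mask_role:
            target[cur:cur + length] = [IGNORE_INDEX] * length
        cur += length
    return target
```

Walking turn by turn with a running offset is the same pattern the diff's `cur_idx` loop follows; only the bot turns keep their original labels.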

Signed-off-by: Yi Dong <[email protected]>
Zhilin123
Zhilin123 previously approved these changes Jun 28, 2023
Collaborator

@Zhilin123 Zhilin123 left a comment

Generally looks good, with some remaining minor code style issues but we can address in future PR since code freeze is today.

Signed-off-by: Yi Dong <[email protected]>
@yidong72 yidong72 merged commit a27ba52 into main Jun 29, 2023
15 checks passed
@yidong72 yidong72 deleted the value_ds branch June 29, 2023 02:30