denoise implemented as text_preprocessor or token_preprocessor #173
Hey Xavier, regarding:
First of all, thanks for helping out so quickly, greatly appreciated! OK, so I had actually managed to train the model by providing the denoise preprocessor as a token_preprocessor. My problem is that after training I'd like to evaluate the model on the eval metrics (currently just accuracy), but it fails because there is no 'inputs_plaintext' field:

```python
# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="unsupervised_representation_learning",
    checkpoint_steps="all"
)
```

which yields:
Do I need to provide a postprocessor to combine the token sequence back into a string? My Task definition currently looks like this:

```python
import t5
import tensorflow as tf
# denoise and its helpers live in t5.data.preprocessors.
from t5.data.preprocessors import denoise, iid_noise_mask, noise_token_to_sentinel


def text_preprocessor(dataset):
  def _to_inputs_and_targets(input_dict):
    seq_str = input_dict['targets']
    return {"inputs": seq_str, "targets": seq_str}
  return dataset.map(_to_inputs_and_targets,
                     num_parallel_calls=tf.data.experimental.AUTOTUNE)


def token_noise(dataset, vocabulary, **unused_kwargs):
  return denoise(dataset, vocabulary,
                 noise_density=0.15,
                 noise_mask_fn=iid_noise_mask,
                 inputs_fn=noise_token_to_sentinel,
                 targets_fn=None)


t5.data.TaskRegistry.remove("unsupervised_denoising_task")
t5.data.TaskRegistry.add(
    "unsupervised_denoising_task",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=swissprot_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[text_preprocessor],
    # Supply a function which preprocesses tokens.
    token_preprocessor=[token_noise],
    # Supply the char-based vocab model.
    sentencepiece_model_path=vocab_model_path,
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_examples
)

t5.data.MixtureRegistry.remove("unsupervised_representation_learning")
t5.data.MixtureRegistry.add(
    "unsupervised_representation_learning",
    ["unsupervised_denoising_task"],
    default_rate=1.0
)
```
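One quick way to check whether the plaintext features survive preprocessing is to pull a few examples straight from the registered Task. This is a minimal sketch following the pattern used in the T5 Colab, not a verified fix; the task name comes from the snippet above and the sequence lengths are placeholders.

```python
import t5

# Pull a few preprocessed examples from the registered Task to inspect its features.
task = t5.data.TaskRegistry.get("unsupervised_denoising_task")
ds = task.get_dataset(
    sequence_length={"inputs": 128, "targets": 128},  # placeholder lengths
    split="validation",
)

for ex in ds.take(3):
  # Expect 'inputs'/'targets' token IDs plus, if the text preprocessor emitted
  # string features, 'inputs_plaintext'/'targets_plaintext'.
  print(list(ex.keys()))
```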
Hmm, the fact that your text preprocessor returns inputs and targets should mean that you are getting plaintext for both inputs and targets. Can you share a full Colab? If you have a dataset you don't want to share, feel free to swap in one of the TFDS datasets instead.
Hi Colin, I want to experiment with applying T5 to amino acid sequences; my current notebook is here. There are also a few other issues I'm trying to fix, such as how to load a pretrained checkpoint file into the model object without using the train/finetune functions. Any help is greatly appreciated! Best regards,
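On the side question about loading a pretrained checkpoint without calling train or finetune: one possibility, assuming the t5.models.MtfModel API from the public Colab, is to point model_dir at a directory that already contains the checkpoint and then call eval or predict directly, since those restore weights from model_dir on their own. The path, batch size, and sequence lengths below are placeholders, and checkpoint_steps=-1 is assumed to mean the latest checkpoint.

```python
import t5

# Placeholder path; substitute the directory that holds the pretrained checkpoint.
MODEL_DIR = "gs://my-bucket/models/protein_t5_small"

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=None,  # or the TPU address when running on a TPU
    model_parallelism=1,
    batch_size=16,
    sequence_length={"inputs": 128, "targets": 128},
)

# eval()/predict() restore the weights from MODEL_DIR themselves, so no
# train()/finetune() call is needed first.
model.eval(
    mixture_or_task_name="unsupervised_representation_learning",
    checkpoint_steps=-1,
)
```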
I have the same KeyError: 'targets_plaintext' when pre-training the unsupervised task and evaluating it.
When I tried to view the examples from this task using:
The output looks like this:
which is exactly the same but without the plaintext fields. My text_preprocessor is:
What is the correct way to set the text_preprocessor and do unsupervised training?
The plaintext features aren't compatible with unsupervised pretraining. Simply remove the metrics and postprocessor and it should train.
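If that advice is applied to the Task registration earlier in the thread, it would seem to amount to dropping the accuracy metric (no postprocess_fn was set there anyway). A sketch reusing the names from that snippet, not a verified fix:

```python
# Same registration as before, minus the evaluation-only pieces.
t5.data.TaskRegistry.remove("unsupervised_denoising_task")
t5.data.TaskRegistry.add(
    "unsupervised_denoising_task",
    dataset_fn=swissprot_dataset_fn,
    splits=["train", "validation"],
    text_preprocessor=[text_preprocessor],
    token_preprocessor=[token_noise],
    sentencepiece_model_path=vocab_model_path,
    # No metrics: the metric path is what needs the plaintext features
    # and raised the KeyError during eval.
    metric_fns=[],
    num_input_examples=num_examples
)
```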
I'm trying to adapt the Colab notebook to train a model from scratch using a custom dataset + training objectives.
After shaping my data as a TSV TextLineDataset, I'm having issues implementing a denoising task.
When running this I get the following output:
However, when I provide the custom_preprocessor as a token_preprocessor everything works fine (but then I get downstream issues because there's no longer a 'targets_plaintext' field for evaluation...)
I feel like I'm missing a simple step here as the denoise function seems to return just a single token under the 'inputs' field rather than a string sequence, but I can't figure out where my mistake is...
Looking at the implementation of denoise (and the corresponding mask_fn), this doesn't make sense to me at the moment. Any help would be greatly appreciated!
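A note that may help with the confusion above: text preprocessors run before SentencePiece encoding and see string features, while token preprocessors run afterwards and see integer ID sequences; denoise masks and replaces token IDs, which is why it only behaves sensibly as a token_preprocessor. A toy illustration of the two feature shapes (not T5 code, just the shapes involved; the IDs and sequence are made up):

```python
import tensorflow as tf

# What a text_preprocessor sees: scalar string features.
text_example = {
    "inputs": tf.constant("M K T A Y I A K Q R"),
    "targets": tf.constant("M K T A Y I A K Q R"),
}

# What a token_preprocessor sees: 1-D integer vectors of SentencePiece IDs.
token_example = {
    "inputs": tf.constant([71, 12, 305, 42, 9, 88, 3, 15], dtype=tf.int64),
    "targets": tf.constant([71, 12, 305, 42, 9, 88, 3, 15], dtype=tf.int64),
}

# A denoising objective drops or replaces spans of IDs with sentinels, which is
# only well defined on token_example. Applied at the text stage it would be
# operating on a single string scalar, which matches the "single token under
# 'inputs'" symptom described above.
```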