denoise implemented as text_preprocessor or token_preprocessor #173

Closed
aiXander opened this issue Apr 15, 2020 · 6 comments
Comments

@aiXander

aiXander commented Apr 15, 2020

I'm trying to adapt the Colab notebook to train a model from scratch using a custom dataset + training objectives.

After shaping my data as a TSV TextLineDataset, I'm having issues implementing a denoising task.
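A minimal sketch of a dataset_fn that produces such a dataset (the bucket path here is just a placeholder, not the real one):

import tensorflow as tf

def dataset_fn(split, shuffle_files=False):
  # One TSV file per split, each line holding a single raw sequence.
  del shuffle_files  # only one file per split in this sketch
  ds = tf.data.TextLineDataset("gs://my-bucket/data/%s.tsv" % split)
  # Emit the single text feature that the text preprocessor / denoise pipeline expects.
  return ds.map(lambda line: {"targets": line},
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

This yields the {targets: ()} string dataset shown in the printout further down.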

from t5.data.preprocessors import *
def custom_preprocessor(dataset):
  print(dataset)
  dataset = denoise(dataset, vocab, 
                    noise_density = 0.15, 
                    noise_mask_fn = iid_noise_mask, 
                    inputs_fn = noise_token_to_sentinel, 
                    targets_fn=None)
  print(dataset)
  return dataset

t5.data.TaskRegistry.remove("denoising_objective")
t5.data.TaskRegistry.add(
    "denoising_objective",
    dataset_fn=dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text:
    text_preprocessor=[custom_preprocessor],
    # Supply a function which preprocesses tokens:
    #token_preprocessor=[custom_preprocessor],
    sentencepiece_model_path=vocab_model_path,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_examples
)

denoising_task = t5.data.TaskRegistry.get("denoising_objective")
ds = denoising_task.get_dataset(split="train", sequence_length={"inputs": 512, "targets": 512})

When running this I get the following output:

<DatasetV1Adapter shapes: {targets: ()}, types: {targets: tf.string}>
<DatasetV1Adapter shapes: {inputs: (1,), targets: ()}, types: {inputs: tf.string, targets: tf.string}>
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-20-9d72a0006afb> in <module>()
     26 
     27 denoising_task = t5.data.TaskRegistry.get("denoising_objective")
---> 28 ds = denoising_task.get_dataset(split="train", sequence_length={"inputs": 20, "targets": 20})

2 frames
/usr/local/lib/python3.6/dist-packages/t5/data/utils.py in _validate_dataset(self, dataset, expected_output_type, expected_output_rank, error_label, ensure_no_eos)
    500             "{label}: Got {actual}, expected {expected}".format(
    501                 feat=feat, label=error_label, actual=len(shapes[feat]),
--> 502                 expected=expected_output_rank))
    503 
    504     def _ensure_no_eos(feat, v):

ValueError: Task dataset has incorrect rank for feature 'inputs' after text preprocessing: Got 1, expected 0

However, when I provide the custom_preprocessor as a token_preprocessor everything works fine (but then I get downstream issues because there's no longer a 'targets_plaintext' field for evaluation...)

I feel like I'm missing a simple step here, as the denoise function seems to return just a single element under the 'inputs' field rather than a full sequence, but I can't figure out where my mistake is.

Looking at the implementation of denoise (and the corresponding noise_mask_fn), this doesn't make sense to me at the moment. Any help would be greatly appreciated!

def denoise(dataset,
            vocabulary,
            noise_density=gin.REQUIRED,
            noise_mask_fn=gin.REQUIRED,
            inputs_fn=gin.REQUIRED,
            targets_fn=None,
            **unused_kwargs):
  ...
  def my_fn(features):
    tokens = features['targets']
    noise_mask = noise_mask_fn(tf.size(tokens), noise_density)
    inputs = inputs_fn(tokens, noise_mask, vocabulary)
    if targets_fn:
      targets = targets_fn(tokens, noise_mask, vocabulary)
    else:
      targets = tokens
    return {'inputs': inputs, 'targets': targets}
  return dataset.map(my_fn, num_parallel_calls=num_parallel_calls())
@craffel
Contributor

craffel commented Apr 15, 2020

Hey Xavier, denoise is a token preprocessor so it should definitely be provided via the token_preprocessor arg. When you load the dataset (via denoising_task.get_dataset), inputs should be sequences of tokens. Is this not what you're seeing? If you post a full colab/gist of what you are trying to do with denoise as a token preprocessor I can take a look.

Regarding inputs_plaintext, the unsupervised text preprocessors (denoise) assume that your dataset does not have inputs and targets in its raw form (otherwise it would be a supervised dataset), so there is no notion of inputs until the token preprocessor comes around and creates them. You'd need to make a different preprocessing pipeline if you wanted inputs_plaintext/targets_plaintext with an unsupervised objective.
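As a rough illustration of the rank error above (shapes mirror the printout in the first comment; the sentinel string and sequence are just stand-ins): before tokenization the text preprocessor sees scalar tf.string features, so tf.size(tokens) inside denoise is 1 and inputs_fn produces a length-1, rank-1 tensor, which is what the rank-0 check rejects.

import tensorflow as tf

tokens = tf.constant("MSVLTPLLLR")            # rank-0 string, not a sequence of token ids
print(tf.size(tokens))                        # 1: the whole string counts as one element
noise_mask = tf.constant([True])              # iid_noise_mask(1, 0.15) has shape (1,)
inputs = tf.where(noise_mask, tf.constant(["<X>"]), tf.reshape(tokens, [1]))
print(inputs.shape)                           # (1,): rank 1, but text features must stay rank 0

When the same function runs as a token preprocessor, tokens is already a rank-1 sequence of ids and the shapes line up.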

@aiXander
Author

aiXander commented Apr 16, 2020

First of all, thanks for helping out so quickly, greatly appreciated!

OK, so I had actually managed to train the model by providing the denoise preprocessor as a token_preprocessor. My problem is that after training I'd like to evaluate the model on the eval metrics (currently just accuracy), but it fails because there is no 'targets_plaintext' field:

# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="unsupervised_representation_learning",
    checkpoint_steps="all"
)

yields:

INFO:tensorflow:Using config: {'_model_dir': 'gs://t5_nlp_for_proteins/models/small', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.104.10:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.104.10:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.104.10:8470', '_evaluation_master': 'grpc://10.20.104.10:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=100, num_shards=None, num_cores_per_replica=1, per_host_input_for_training=4, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver.TPUClusterResolver object at 0x7f83fc39f6a0>}
INFO:tensorflow:_TPUContext: eval_on_tpu True
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-54-84513320c975> in <module>()
      2 model.eval(
      3     mixture_or_task_name="unsupervised_representation_learning",
----> 4     checkpoint_steps="all"
      5 )

2 frames
/usr/local/lib/python3.6/dist-packages/mesh_tensorflow/transformer/utils.py in <listcomp>(.0)
   1261                 tf.compat.as_text(ex["targets_plaintext"]),
   1262                 example=ex, is_target=True)
-> 1263             for ex in examples
   1264         ]
   1265         targets_filename = os.path.join(

KeyError: 'targets_plaintext'

Do I need to provide a postprocessor to combine the token sequence back into a string?

My Task definition currently looks like this:

def text_preprocessor(dataset):
  
  def _to_inputs_and_targets(input_dict):
    seq_str = input_dict['targets']
    return {"inputs": seq_str, "targets": seq_str}

  return dataset.map(_to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

def token_noise(dataset, vocabulary, **unused_kwargs):
  return denoise(dataset, vocabulary, 
                 noise_density = 0.15, 
                 noise_mask_fn = iid_noise_mask, 
                 inputs_fn = noise_token_to_sentinel, 
                 targets_fn=None)

t5.data.TaskRegistry.remove("unsupervised_denoising_task")
t5.data.TaskRegistry.add(
    "unsupervised_denoising_task",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=swissprot_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[text_preprocessor],
    # Supply a function which preprocesses tokens.
    token_preprocessor=[token_noise],
    # Supply the char-based vocab model.
    sentencepiece_model_path=vocab_model_path,
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_examples
)

t5.data.MixtureRegistry.remove("unsupervised_representation_learning")
t5.data.MixtureRegistry.add(
    "unsupervised_representation_learning",
    ["unsupervised_denoising_task"],
     default_rate=1.0
)

@craffel
Copy link
Contributor

craffel commented Apr 16, 2020

Hmm, the fact that your text preprocessor returns inputs and targets should mean that you are getting plaintext for both inputs and targets. Can you share a full colab? If you have a dataset you don't want to share, feel free to swap in one of the TFDS datasets instead.

@aiXander
Author

aiXander commented Apr 16, 2020

Hi Colin,

I want to experiment with applying T5 to amino acid sequences; my current notebook is here.
(I made my GCS bucket publicly readable, so you should be able to run it!)

There are also a few other issues I'm trying to figure out, like how to load a pretrained checkpoint file into the model object without using the train/finetune functions.

Any help is greatly appreciated!

Best regards,
Xander

@matchlesswei

I get the same KeyError: 'targets_plaintext' when pre-training the unsupervised task and then evaluating it.
I add the Task to the TaskRegistry as follows:

t5.data.TaskRegistry.remove('review_context_free')
t5.data.TaskRegistry.add(
    "review_context_free",
    dataset_fn = review_dataset_fn,
    splits=["train","validation"],
    text_preprocessor = [review_preprocessor],
    token_preprocessor=t5.data.preprocessors.unsupervised,
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_review_examples
)

When I tried to view the examples from this task using:

review_task = t5.data.TaskRegistry.get("review_context_free")
ds = review_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(1)):
  print(ex)

The output looks like this:

A few preprocessed validation examples...
{'inputs': array([ 31, 17, 1672, 396, 231, 5, 2, 115, 52, 3, 87, 3155, 4188, 6994, 31, 17, 129, 396, 13423, 81, 578, 2841, 32099, 81, 48, 733, 3, 18, 3, 15429, 2448, 141, 3, 14339, 13127, 5, 16, 32098, 6254, 4413, 3, 88, 47, 882, 964, 24, 1664, 3, 88, 47, 182, 24357, 81, 8, 3, 5490, 463, 13, 8, 999, 3, 32097, 19, 614, 12, 253, 231, 13, 48, 3071, 406, 20, 32096, 139, 8, 1819, 296, 21, 199, 233, 2, 115, 32095, 3, 87, 3155, 155, 31, 7, 1245, 12, 32094, 9, 113, 1819, 6, 68, 25, 32093, 16, 22118, 12, 8, 296, 1187, 8, 1861, 5, 3806, 1132, 32092, 3, 10, 48, 1974, 47, 81, 38, 1477, 11, 3, 31733, 32091, 18, 75, 12513, 15, 1]), 'targets': array([32099, 21029, 32098, 3, 89, 32097, 18, 34, 32096, 40, 3745, 32095, 52, 32094, 36, 3, 32093, 225, 470, 129, 396, 32092, 8181, 32091, 38, 3, 9, 4541, 3, 26459, 32090, 1])}

which has no 'inputs_plaintext' or 'targets_plaintext' fields.

My text_preprocessor is:

def review_preprocessor(ds):
  def normalize_text(text):
    text = tf.strings.lower(text)
    return text
  
  def to_inputs_and_targets(ex):
    return{
        "inputs": tf.strings.join(["generate mask reviw : ", normalize_text(ex["review"])]),
        "targets": tf.strings.join(["generate review mask : ", normalize_text(ex["review"])])
    }
  
  return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

What is the correct way to set the text_preprocessor and do unsupervised training?

@adarob
Collaborator

adarob commented May 20, 2020

The plaintext features aren't compatible with unsupervised pretraining. Simply remove the metrics and postprocessor and it should train.
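Concretely, the registration above with those two arguments dropped would look something like this (otherwise unchanged from the earlier snippet):

t5.data.TaskRegistry.remove("review_context_free")
t5.data.TaskRegistry.add(
    "review_context_free",
    dataset_fn=review_dataset_fn,
    splits=["train", "validation"],
    text_preprocessor=[review_preprocessor],
    token_preprocessor=t5.data.preprocessors.unsupervised,
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # No postprocess_fn, and no metrics: both rely on the *_plaintext features,
    # which aren't produced in this unsupervised setup.
    metric_fns=[],
    num_input_examples=num_review_examples
)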
