-
Notifications
You must be signed in to change notification settings - Fork 309
Add Base Task Class #671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add Base Task Class #671
Changes from 5 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
78a25cc
Add Base Task Class
abheesht17 0adf63d
Merge branch 'keras-team:master' into base-task
abheesht17 d2e9de5
Minor fix
abheesht17 6db9d6d
Fixes
abheesht17 6fb860f
Minor edit
abheesht17 8724f81
Address comments
abheesht17 d843b90
Merge branch 'master' into base-task
abheesht17 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,6 @@ | |
| """BERT classification model.""" | ||
|
|
||
| import copy | ||
| import os | ||
|
|
||
| from tensorflow import keras | ||
|
|
||
|
|
@@ -23,15 +22,15 @@ | |
| from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor | ||
| from keras_nlp.models.bert.bert_presets import backbone_presets | ||
| from keras_nlp.models.bert.bert_presets import classifier_presets | ||
| from keras_nlp.utils.pipeline_model import PipelineModel | ||
| from keras_nlp.models.task import Task | ||
| from keras_nlp.utils.python_utils import classproperty | ||
| from keras_nlp.utils.python_utils import format_docstring | ||
|
|
||
| PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets)) | ||
|
|
||
|
|
||
| @keras.utils.register_keras_serializable(package="keras_nlp") | ||
| class BertClassifier(PipelineModel): | ||
| class BertClassifier(Task): | ||
| """An end-to-end BERT model for classification tasks | ||
|
|
||
| This model attaches a classification head to a `keras_nlp.model.BertBackbone` | ||
|
|
@@ -124,164 +123,47 @@ def __init__( | |
| self._backbone = backbone | ||
| self._preprocessor = preprocessor | ||
| self.num_classes = num_classes | ||
|
|
||
| def preprocess_samples(self, x, y=None, sample_weight=None): | ||
| return self.preprocessor(x, y=y, sample_weight=sample_weight) | ||
|
|
||
| @property | ||
| def backbone(self): | ||
| """A `keras_nlp.models.BertBackbone` instance providing the encoder | ||
| submodel. | ||
| """ | ||
| return self._backbone | ||
|
|
||
| @property | ||
| def preprocessor(self): | ||
| """A `keras_nlp.models.BertPreprocessor` for preprocessing inputs.""" | ||
| return self._preprocessor | ||
| self.dropout = dropout | ||
|
|
||
| def get_config(self): | ||
| return { | ||
| "backbone": keras.layers.serialize(self.backbone), | ||
| "preprocessor": keras.layers.serialize(self.preprocessor), | ||
| "num_classes": self.num_classes, | ||
| "name": self.name, | ||
| "trainable": self.trainable, | ||
| } | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "num_classes": self.num_classes, | ||
| "dropout": self.dropout, | ||
| "name": self.name, | ||
| "trainable": self.trainable, | ||
| } | ||
| ) | ||
| return config | ||
|
|
||
| @classmethod | ||
| def from_config(cls, config): | ||
| if "backbone" in config and isinstance(config["backbone"], dict): | ||
| config["backbone"] = keras.layers.deserialize(config["backbone"]) | ||
| if "preprocessor" in config and isinstance( | ||
| config["preprocessor"], dict | ||
| ): | ||
| config["preprocessor"] = keras.layers.deserialize( | ||
| config["preprocessor"] | ||
| ) | ||
| return cls(**config) | ||
| @classproperty | ||
| def backbone_cls(cls): | ||
| return BertBackbone | ||
|
|
||
| @classproperty | ||
| def preprocessor_cls(cls): | ||
| return BertPreprocessor | ||
|
|
||
| @classproperty | ||
| def presets(cls): | ||
| return copy.deepcopy({**backbone_presets, **classifier_presets}) | ||
|
|
||
| @classmethod | ||
| @format_docstring(names=PRESET_NAMES) | ||
| def from_preset( | ||
| cls, | ||
| preset, | ||
| load_weights=True, | ||
| **kwargs, | ||
| ): | ||
| """Create a classification model from a preset architecture and weights. | ||
|
|
||
| By default, this method will automatically create a `preprocessor` | ||
| layer to preprocess raw inputs during `fit()`, `predict()`, and | ||
| `evaluate()`. If you would like to disable this behavior, pass | ||
| `preprocessor=None`. | ||
|
|
||
| Args: | ||
| preset: string. Must be one of {{names}}. | ||
| load_weights: Whether to load pre-trained weights into model. | ||
| Defaults to `True`. | ||
|
|
||
| Examples: | ||
|
|
||
| Raw string inputs. | ||
| ```python | ||
| # Create a dataset with raw string features in an `(x, y)` format. | ||
| features = ["The quick brown fox jumped.", "I forgot my homework."] | ||
| labels = [0, 3] | ||
|
|
||
| # Create a BertClassifier and fit your data. | ||
| classifier = keras_nlp.models.BertClassifier.from_preset( | ||
| "bert_base_en_uncased", | ||
| num_classes=4, | ||
| ) | ||
| classifier.compile( | ||
| loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
| return super().from_preset( | ||
| preset=preset, load_weights=load_weights, **kwargs | ||
| ) | ||
| classifier.fit(x=features, y=labels, batch_size=2) | ||
| ``` | ||
|
|
||
| Raw string inputs with customized preprocessing. | ||
jbischof marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ```python | ||
| # Create a dataset with raw string features in an `(x, y)` format. | ||
| features = ["The quick brown fox jumped.", "I forgot my homework."] | ||
| labels = [0, 3] | ||
|
|
||
| # Use a shorter sequence length. | ||
| preprocessor = keras_nlp.models.BertPreprocessor.from_preset( | ||
| "bert_base_en_uncased", | ||
| sequence_length=128, | ||
| ) | ||
|
|
||
| # Create a BertClassifier and fit your data. | ||
| classifier = keras_nlp.models.BertClassifier.from_preset( | ||
| "bert_base_en_uncased", | ||
| num_classes=4, | ||
| preprocessor=preprocessor, | ||
| ) | ||
| classifier.compile( | ||
| loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
| ) | ||
| classifier.fit(x=features, y=labels, batch_size=2) | ||
| ``` | ||
|
|
||
| Preprocessed inputs. | ||
| ```python | ||
| # Create a dataset with preprocessed features in an `(x, y)` format. | ||
| preprocessed_features = { | ||
| "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64), | ||
| "segment_ids": tf.constant( | ||
| [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12) | ||
| ), | ||
| "padding_mask": tf.constant( | ||
| [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12) | ||
| ), | ||
| } | ||
| labels = [0, 3] | ||
|
|
||
| # Create a BERT classifier and fit your data. | ||
| classifier = keras_nlp.models.BertClassifier.from_preset( | ||
| "bert_base_en_uncased", | ||
| num_classes=4, | ||
| preprocessor=None, | ||
| ) | ||
| classifier.compile( | ||
| loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
| ) | ||
| classifier.fit(x=preprocessed_features, y=labels, batch_size=2) | ||
| ``` | ||
| """ | ||
| if preset not in cls.presets: | ||
| raise ValueError( | ||
| "`preset` must be one of " | ||
| f"""{", ".join(cls.presets)}. Received: {preset}.""" | ||
| ) | ||
|
|
||
| if "preprocessor" not in kwargs: | ||
| kwargs["preprocessor"] = BertPreprocessor.from_preset(preset) | ||
|
|
||
| # Check if preset is backbone-only model | ||
| if preset in BertBackbone.presets: | ||
| backbone = BertBackbone.from_preset(preset, load_weights) | ||
| return cls(backbone, **kwargs) | ||
|
|
||
| # Otherwise must be one of class presets | ||
| metadata = cls.presets[preset] | ||
| config = metadata["config"] | ||
| model = cls.from_config({**config, **kwargs}) | ||
|
|
||
| if not load_weights: | ||
| return model | ||
|
|
||
| weights = keras.utils.get_file( | ||
| "model.h5", | ||
| metadata["weights_url"], | ||
| cache_subdir=os.path.join("models", preset), | ||
| file_hash=metadata["weights_hash"], | ||
| ) | ||
|
|
||
| model.load_weights(weights) | ||
| return model | ||
| BertClassifier.from_preset.__func__.__doc__ = Task.from_preset.__doc__ | ||
|
||
| format_docstring( | ||
| model_task_name=BertClassifier.__name__, | ||
| example_preset_name="bert_base_en_uncased", | ||
| preset_names=PRESET_NAMES, | ||
| )(BertClassifier.from_preset.__func__) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could `"name"` and `"trainable"` be added to the `Task` `get_config`? I guess it would be awkward to look for them in `kwargs` to pass to `super().get_config()`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make this change. I was going to do the same for backbones.
Having `name` and `trainable` handled in a base class is actually the norm, so it will improve readability here.