-
Notifications
You must be signed in to change notification settings - Fork 309
Light-weight benchmarking script #664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
chenmoneygithub
merged 4 commits into
keras-team:master
from
NusretOzates:sentinement_analysis_benchmark
Jan 19, 2023
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
2975e0c
Implemented sentiment analysis benchmark for Classifiers using IMDB r…
NusretOzates 63fdf3d
Script usage code updated and unnecessary spaces removed. Val and tes…
NusretOzates d6a2b7b
Made "model" flag required using absl and made mixed precision policy…
NusretOzates 22a0338
Fix flag parsing issue and line length over 80 chars
mattdangerw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| # Copyright 2023 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import inspect | ||
| import time | ||
|
|
||
| import tensorflow as tf | ||
| import tensorflow_datasets as tfds | ||
| from absl import app | ||
| from absl import flags | ||
| from tensorflow import keras | ||
|
|
||
| import keras_nlp | ||
|
|
||
| FLAGS = flags.FLAGS | ||
| flags.DEFINE_string( | ||
| "model", | ||
| None, | ||
| "The name of the classifier such as BertClassifier.", | ||
| ) | ||
| flags.DEFINE_string( | ||
| "preset", | ||
| None, | ||
| "The name of a preset, e.g. bert_base_multi.", | ||
| ) | ||
|
|
||
| flags.DEFINE_string( | ||
| "mixed_precision_policy", | ||
| "mixed_float16", | ||
| "The global precision policy to use. E.g. 'mixed_float16' or 'float32'.", | ||
| ) | ||
|
|
||
| flags.DEFINE_float("learning_rate", 5e-5, "The learning rate.") | ||
| flags.DEFINE_integer("num_epochs", 1, "The number of epochs.") | ||
| flags.DEFINE_integer("batch_size", 16, "The batch size.") | ||
|
|
||
| tfds.disable_progress_bar() | ||
|
|
||
| BUFFER_SIZE = 10000 | ||
|
|
||
|
|
||
| def create_imdb_dataset(): | ||
| dataset, info = tfds.load( | ||
| "imdb_reviews", as_supervised=True, with_info=True | ||
| ) | ||
| train_dataset, test_dataset = dataset["train"], dataset["test"] | ||
|
|
||
| train_dataset = ( | ||
| train_dataset.shuffle(BUFFER_SIZE) | ||
| .batch(FLAGS.batch_size) | ||
| .prefetch(tf.data.AUTOTUNE) | ||
| ) | ||
|
|
||
| # We split the test data evenly into validation and test sets. | ||
| test_dataset_size = info.splits["test"].num_examples // 2 | ||
|
|
||
| val_dataset = ( | ||
| test_dataset.take(test_dataset_size) | ||
| .batch(FLAGS.batch_size) | ||
| .prefetch(tf.data.AUTOTUNE) | ||
| ) | ||
| test_dataset = ( | ||
| test_dataset.skip(test_dataset_size) | ||
| .batch(FLAGS.batch_size) | ||
| .prefetch(tf.data.AUTOTUNE) | ||
| ) | ||
|
|
||
| return train_dataset, val_dataset, test_dataset | ||
|
|
||
|
|
||
| def create_model(): | ||
| for name, symbol in keras_nlp.models.__dict__.items(): | ||
| if inspect.isclass(symbol) and issubclass(symbol, keras.Model): | ||
| if FLAGS.model and name != FLAGS.model: | ||
| continue | ||
| if not hasattr(symbol, "from_preset"): | ||
| continue | ||
| for preset in symbol.presets: | ||
| if FLAGS.preset and preset != FLAGS.preset: | ||
| continue | ||
| model = symbol.from_preset(preset) | ||
| print(f"Using model {name} with preset {preset}") | ||
| return model | ||
|
|
||
| raise ValueError(f"Model {FLAGS.model} or preset {FLAGS.preset} not found.") | ||
|
|
||
|
|
||
| def train_model( | ||
| model: keras.Model, | ||
| train_dataset: tf.data.Dataset, | ||
| validation_dataset: tf.data.Dataset, | ||
| ): | ||
| model.compile( | ||
| loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
| optimizer=keras.optimizers.Adam(5e-5), | ||
| metrics=keras.metrics.SparseCategoricalAccuracy(), | ||
| jit_compile=True, | ||
| ) | ||
|
|
||
| model.fit( | ||
| train_dataset, | ||
| epochs=FLAGS.num_epochs, | ||
| validation_data=validation_dataset, | ||
| verbose=2, | ||
| ) | ||
|
|
||
| return model | ||
|
|
||
|
|
||
| def evaluate_model(model: keras.Model, test_dataset: tf.data.Dataset): | ||
| loss, accuracy = model.evaluate(test_dataset) | ||
| print(f"Test loss: {loss}") | ||
| print(f"Test accuracy: {accuracy}") | ||
|
|
||
|
|
||
| def main(_): | ||
| keras.mixed_precision.set_global_policy(FLAGS.mixed_precision_policy) | ||
|
|
||
| # Start time | ||
| start_time = time.time() | ||
|
|
||
| train_dataset, validation_dataset, test_dataset = create_imdb_dataset() | ||
| model = create_model() | ||
| model = train_model(model, train_dataset, validation_dataset) | ||
| evaluate_model(model, test_dataset) | ||
|
|
||
| # End time | ||
| end_time = time.time() | ||
| print(f"Total wall time: {end_time - start_time}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| flags.mark_flag_as_required("model") | ||
| app.run(main) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.