# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Small utility script to count parameters in our preset checkpoints.

Usage:
python tools/count_preset_params.py
python tools/count_preset_params.py --model BertBackbone
python tools/count_preset_params.py --preset bert_base_multi
"""

import inspect

from absl import app
from absl import flags
from keras.utils.layer_utils import count_params
from tensorflow import keras

import keras_nlp

FLAGS = flags.FLAGS
flags.DEFINE_string("model", None, "The name of a model, e.g. BertBackbone.")
flags.DEFINE_string(
    "preset", None, "The name of a preset, e.g. bert_base_multi."
)


def main(_):
    """Instantiate each matching preset and print its parameter count.

    Walks every public symbol in `keras_nlp.models`, keeps the Keras model
    classes that expose `from_preset`, optionally narrows by the `--model`
    and `--preset` flags, and prints one `"<class> <preset> <params>"` line
    per preset checkpoint.
    """
    for class_name, model_cls in vars(keras_nlp.models).items():
        # Skip anything that is not a Keras model class (modules, helpers).
        if not inspect.isclass(model_cls):
            continue
        if not issubclass(model_cls, keras.Model):
            continue
        # Honor the optional --model filter.
        if FLAGS.model and class_name != FLAGS.model:
            continue
        # Only classes with preset support can be counted.
        if not hasattr(model_cls, "from_preset"):
            continue
        for preset_name in model_cls.presets:
            # Honor the optional --preset filter.
            if FLAGS.preset and preset_name != FLAGS.preset:
                continue
            instance = model_cls.from_preset(preset_name)
            num_params = count_params(instance.weights)
            print(f"{class_name} {preset_name} {num_params}")


if __name__ == "__main__":
    app.run(main)