|
16 | 16 |
|
17 | 17 | Usage: |
18 | 18 | python tools/count_preset_params.py |
| 19 | +python tools/count_preset_params.py --model BertBackbone |
| 20 | +python tools/count_preset_params.py --preset bert_base_multi |
19 | 21 | """ |
20 | 22 |
|
21 | 23 | import inspect |
22 | 24 |
|
| 25 | +from absl import app |
| 26 | +from absl import flags |
23 | 27 | from keras.utils.layer_utils import count_params |
24 | 28 | from tensorflow import keras |
25 | 29 |
|
26 | 30 | import keras_nlp |
27 | 31 |
|
28 | | -for name, symbol in keras_nlp.models.__dict__.items(): |
29 | | - if inspect.isclass(symbol) and issubclass(symbol, keras.Model): |
30 | | - for preset in symbol.presets: |
31 | | - model = symbol.from_preset(preset) |
32 | | - params = count_params(model.weights) |
33 | | - print(f"{name} {preset} {params}") |
| 32 | +FLAGS = flags.FLAGS |
| 33 | +flags.DEFINE_string("model", None, "The name of a model, e.g. BertBackbone.") |
| 34 | +flags.DEFINE_string( |
| 35 | + "preset", None, "The name of a preset, e.g. bert_base_multi." |
| 36 | +) |
| 37 | + |
| 38 | + |
| 39 | +def main(_): |
| 40 | + for name, symbol in keras_nlp.models.__dict__.items(): |
| 41 | + if inspect.isclass(symbol) and issubclass(symbol, keras.Model): |
| 42 | + if FLAGS.model and name != FLAGS.model: |
| 43 | + continue |
| 44 | + if not hasattr(symbol, "from_preset"): |
| 45 | + continue |
| 46 | + for preset in symbol.presets: |
| 47 | + if FLAGS.preset and preset != FLAGS.preset: |
| 48 | + continue |
| 49 | + model = symbol.from_preset(preset) |
| 50 | + params = count_params(model.weights) |
| 51 | + print(f"{name} {preset} {params}") |
| 52 | + |
| 53 | + |
| 54 | +if __name__ == "__main__": |
| 55 | + app.run(main) |
0 commit comments