Skip to content

Commit 299c1fd

Browse files
committed
Add flags
1 parent 57d292c commit 299c1fd

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

tools/count_preset_params.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,40 @@
1616
1717
Usage:
1818
python tools/count_preset_params.py
19+
python tools/count_preset_params.py --model BertBackbone
20+
python tools/count_preset_params.py --preset bert_base_multi
1921
"""
2022

2123
import inspect
2224

25+
from absl import app
26+
from absl import flags
2327
from keras.utils.layer_utils import count_params
2428
from tensorflow import keras
2529

2630
import keras_nlp
2731

28-
for name, symbol in keras_nlp.models.__dict__.items():
29-
if inspect.isclass(symbol) and issubclass(symbol, keras.Model):
30-
for preset in symbol.presets:
31-
model = symbol.from_preset(preset)
32-
params = count_params(model.weights)
33-
print(f"{name} {preset} {params}")
32+
FLAGS = flags.FLAGS
33+
flags.DEFINE_string("model", None, "The name of a model, e.g. BertBackbone.")
34+
flags.DEFINE_string(
35+
"preset", None, "The name of a preset, e.g. bert_base_multi."
36+
)
37+
38+
39+
def main(_):
40+
for name, symbol in keras_nlp.models.__dict__.items():
41+
if inspect.isclass(symbol) and issubclass(symbol, keras.Model):
42+
if FLAGS.model and name != FLAGS.model:
43+
continue
44+
if not hasattr(symbol, "from_preset"):
45+
continue
46+
for preset in symbol.presets:
47+
if FLAGS.preset and preset != FLAGS.preset:
48+
continue
49+
model = symbol.from_preset(preset)
50+
params = count_params(model.weights)
51+
print(f"{name} {preset} {params}")
52+
53+
54+
if __name__ == "__main__":
55+
app.run(main)

0 commit comments

Comments
 (0)