diff --git a/official/projects/rngdet/tasks/rngdet_test.py b/official/projects/rngdet/tasks/rngdet_test.py
index ccab141ded..c5b7f5fbd5 100644
--- a/official/projects/rngdet/tasks/rngdet_test.py
+++ b/official/projects/rngdet/tasks/rngdet_test.py
@@ -19,14 +19,16 @@
 import tensorflow_datasets as tfds
 import os
 
-from official.projects.rngdet import optimization
+from official.modeling import optimization
 from official.projects.rngdet.configs import rngdet as rngdet_cfg
 from official.projects.rngdet.tasks import rngdet
 from official.vision.configs import backbones
 from official.vision.configs import decoders
-
+import sys
+import pdb;
 _NUM_EXAMPLES = 10
+CITYSCALE_INPUT_PATH_BASE = 'FIX_ME'
 
 
 def _gen_fn():
   h = 128
@@ -42,7 +44,6 @@ def _gen_fn():
       'gt_masks': np.ones(shape=(h, w, num_query), dtype=np.uint8),
   }
 
-
 def _as_dataset(self, *args, **kwargs):
   del args
   del kwargs
@@ -52,15 +53,10 @@ def _as_dataset(self, *args, **kwargs):
       output_shapes=self.info.features.shape,
   )
 
-#CITYSCALE_INPUT_PATH_BASE = '/home/ghpark/cityscale'
-CITYSCALE_INPUT_PATH_BASE = '/home/ghpark.epiclab/03_rngdet/models/official/projects/rngdet'
-
 class RngdetTest(tf.test.TestCase):
 
   def test_train_step(self):
     config = rngdet_cfg.RngdetTask(
-        init_checkpoint='gs://ghpark-ckpts/rngdet/test_02',
-        init_checkpoint_modules='all',
         model=rngdet_cfg.Rngdet(
             input_size=[128, 128, 3],
             roi_size=128,
@@ -78,7 +74,6 @@ def test_train_step(self):
             decoder=decoders.Decoder(
                 type='fpn',
                 fpn=decoders.FPN())
-
         ),
         train_data=rngdet_cfg.DataConfig(
             input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
@@ -87,85 +82,73 @@ def test_train_step(self):
             global_batch_size=3,
             shuffle_buffer_size=1000,
         ))
-    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
-      task = rngdet.RNGDetTask(config)
-      model = task.build_model()
-      dummy_images = tf.keras.Input([128, 128, 3])
-      dummy_history = tf.keras.Input([128, 128, 1])
-      _ = model(dummy_images, dummy_history, training=False)
-      #task.initialize(model)
-      #ckpt_dir_or_file = 'gs://ghpark-ckpts/rngdet/test_00'
-      ckpt_dir_or_file = '/home/ghpark.epiclab/03_rngdet/ckpt/test_02'
-      ckpt = tf.train.Checkpoint(
-          backbone=model.backbone,
-          backbone_history=model.backbone_history,
-          transformer=model.transformer,
-          segment_fpn=model._segment_fpn,
-          keypoint_fpn=model._keypoint_fpn,
-          query_embeddings=model._query_embeddings,
-          segment_head=model._segment_head,
-          keypoint_head=model._keypoint_head,
-          class_embed=model._class_embed,
-          bbox_embed=model._bbox_embed,
-          input_proj=model.input_proj)
-      status = ckpt.restore(tf.train.latest_checkpoint(ckpt_dir_or_file))
-      status.expect_partial().assert_existing_objects_matched()
-      print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
-      print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
-      print("LOAD CHECKPOINT DONE")
-      print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
-      print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
-
-
-      dataset = task.build_inputs(config.train_data)
-      iterator = iter(dataset)
+    with tfds.testing.mock_data(as_dataset_fn=_as_dataset, num_examples=5):
+      task = rngdet.RNGDetTask(config)
+      model = task.build_model()
 
       opt_cfg = optimization.OptimizationConfig({
           'optimizer': {
-              'type': 'detr_adamw',
-              'detr_adamw': {
-                  'weight_decay_rate': 1e-4,
+              'type': 'adamw_experimental',
+              'adamw_experimental': {
+                  'epsilon': 1.0e-08,
+                  'weight_decay': 1e-4,
                   'global_clipnorm': 0.1,
               }
           },
           'learning_rate': {
-              'type': 'stepwise',
-              'stepwise': {
-                  'boundaries': [120000],
-                  'values': [0.0001, 1.0e-05]
+              'type': 'polynomial',
+              'polynomial': {
+                  'initial_learning_rate': 0.0001,
+                  'end_learning_rate': 0.000001,
+                  'offset': 0,
+                  'power': 1.0,
+                  'decay_steps': 50 * 10,
               }
           },
+          'warmup': {
+              'type': 'linear',
+              'linear': {
+                  'warmup_steps': 2 * 10,
+                  'warmup_learning_rate': 0,
+              },
+          },
       })
       optimizer = rngdet.RNGDetTask.create_optimizer(opt_cfg)
+      dataset = task.build_inputs(config.train_data)
+      iterator = iter(dataset)
+      task.train_step(next(iterator), model, optimizer)
+      dummy_images = tf.keras.Input([128, 128, 3])
+      dummy_history = tf.keras.Input([128, 128, 1])
+      _ = model(dummy_images, dummy_history, training=False)
 
 
-  """def test_validation_step(self):
-    config = rngdet_cfg.DetrTask(
-        model=rngdet_cfg.Detr(
-            input_size=[1333, 1333, 3],
-            num_encoder_layers=1,
-            num_decoder_layers=1,
-            num_classes=81,
+  def test_validation_step(self):
+    config = rngdet_cfg.RngdetTask(
+        model=rngdet_cfg.Rngdet(
+            input_size=[128, 128, 3],
+            roi_size=128,
+            num_encoder_layers=6,
+            num_decoder_layers=6,
+            num_queries=10,
+            hidden_size=256,
+            num_classes=2,
+            min_level=2,
+            max_level=5,
+            backbone_endpoint_name='5',
             backbone=backbones.Backbone(
                 type='resnet',
-                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
-        ),
-        validation_data=coco.COCODataConfig(
-            tfds_name='coco/2017',
-            tfds_split='validation',
-            is_training=False,
-            global_batch_size=2,
+                resnet=backbones.ResNet(model_id=50, bn_trainable=False)),
+            decoder=decoders.Decoder(
+                type='fpn',
+                fpn=decoders.FPN())
+        ))
     with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
-      task = detection.DetectionTask(config)
+      task = rngdet.RNGDetTask(config)
       model = task.build_model()
-      metrics = task.build_metrics(training=False)
-      dataset = task.build_inputs(config.validation_data)
-      iterator = iter(dataset)
-      logs = task.validation_step(next(iterator), model, metrics)
-      state = task.aggregate_logs(step_outputs=logs)
-      task.reduce_aggregated_logs(state)"""
-
+      dummy_images = tf.keras.Input([128, 128, 3])
+      dummy_history = tf.keras.Input([128, 128, 1])
+      _ = model(dummy_images, dummy_history, training=False)
 
 
 if __name__ == '__main__':
   tf.test.main()
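
Note on the optimizer change (not part of the diff): the test now resolves its optimizer through the shared official.modeling.optimization package rather than a project-local copy. Below is a minimal standalone sketch of how that same config is turned into a learning-rate schedule and optimizer; it assumes the stock OptimizerFactory API, which is the path the inherited base-task create_optimizer (used here as rngdet.RNGDetTask.create_optimizer) delegates to.

# Sketch only: builds the new adamw_experimental + polynomial-decay + linear-warmup
# setup from the diff through official.modeling.optimization.OptimizerFactory.
from official.modeling import optimization

opt_cfg = optimization.OptimizationConfig({
    'optimizer': {
        'type': 'adamw_experimental',
        'adamw_experimental': {
            'epsilon': 1.0e-08,
            'weight_decay': 1e-4,
            'global_clipnorm': 0.1,
        }
    },
    'learning_rate': {
        'type': 'polynomial',
        'polynomial': {
            'initial_learning_rate': 0.0001,
            'end_learning_rate': 0.000001,
            'offset': 0,
            'power': 1.0,
            'decay_steps': 50 * 10,  # short schedule sized for the unit test
        }
    },
    'warmup': {
        'type': 'linear',
        'linear': {
            'warmup_steps': 2 * 10,
            'warmup_learning_rate': 0,
        },
    },
})

factory = optimization.OptimizerFactory(opt_cfg)
# build_learning_rate() returns the polynomial schedule wrapped in linear warmup;
# build_optimizer() then constructs the Keras AdamW optimizer around it.
lr_schedule = factory.build_learning_rate()
optimizer = factory.build_optimizer(lr_schedule)

Calling rngdet.RNGDetTask.create_optimizer(opt_cfg) in the test should produce an equivalent optimizer, plus any mixed-precision wrapping the base task applies.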