Skip to content

Commit

Permalink
Update rngdet_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mjyun01 authored Feb 7, 2024
1 parent 2322237 commit a07dd62
Showing 1 changed file with 53 additions and 70 deletions.
123 changes: 53 additions & 70 deletions official/projects/rngdet/tasks/rngdet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
import tensorflow_datasets as tfds
import os

from official.projects.rngdet import optimization
from official.modeling import optimization
from official.projects.rngdet.configs import rngdet as rngdet_cfg
from official.projects.rngdet.tasks import rngdet
from official.vision.configs import backbones
from official.vision.configs import decoders

import sys
import pdb;
_NUM_EXAMPLES = 10

CITYSCALE_INPUT_PATH_BASE = 'FIX_ME'

def _gen_fn():
h = 128
Expand All @@ -42,7 +44,6 @@ def _gen_fn():
'gt_masks': np.ones(shape=(h, w, num_query), dtype=np.uint8),
}


def _as_dataset(self, *args, **kwargs):
del args
del kwargs
Expand All @@ -52,15 +53,10 @@ def _as_dataset(self, *args, **kwargs):
output_shapes=self.info.features.shape,
)

#CITYSCALE_INPUT_PATH_BASE = '/home/ghpark/cityscale'
CITYSCALE_INPUT_PATH_BASE = '/home/ghpark.epiclab/03_rngdet/models/official/projects/rngdet'

class RngdetTest(tf.test.TestCase):

def test_train_step(self):
config = rngdet_cfg.RngdetTask(
init_checkpoint='gs://ghpark-ckpts/rngdet/test_02',
init_checkpoint_modules='all',
model=rngdet_cfg.Rngdet(
input_size=[128, 128, 3],
roi_size=128,
Expand All @@ -78,7 +74,6 @@ def test_train_step(self):
decoder=decoders.Decoder(
type='fpn',
fpn=decoders.FPN())

),
train_data=rngdet_cfg.DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
Expand All @@ -87,85 +82,73 @@ def test_train_step(self):
global_batch_size=3,
shuffle_buffer_size=1000,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = rngdet.RNGDetTask(config)
model = task.build_model()
dummy_images = tf.keras.Input([128, 128, 3])
dummy_history = tf.keras.Input([128, 128, 1])
_ = model(dummy_images, dummy_history, training=False)
#task.initialize(model)
#ckpt_dir_or_file = 'gs://ghpark-ckpts/rngdet/test_00'
ckpt_dir_or_file = '/home/ghpark.epiclab/03_rngdet/ckpt/test_02'
ckpt = tf.train.Checkpoint(
backbone=model.backbone,
backbone_history=model.backbone_history,
transformer=model.transformer,
segment_fpn=model._segment_fpn,
keypoint_fpn=model._keypoint_fpn,
query_embeddings=model._query_embeddings,
segment_head=model._segment_head,
keypoint_head=model._keypoint_head,
class_embed=model._class_embed,
bbox_embed=model._bbox_embed,
input_proj=model.input_proj)
status = ckpt.restore(tf.train.latest_checkpoint(ckpt_dir_or_file))
status.expect_partial().assert_existing_objects_matched()
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("LOAD CHECKPOINT DONE")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")


dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
with tfds.testing.mock_data(as_dataset_fn=_as_dataset, num_examples=5):
task = rngdet.RNGDetTask(config)
model = task.build_model()
opt_cfg = optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'type': 'adamw_experimental',
'adamw_experimental': {
'epsilon': 1.0e-08,
'weight_decay': 1e-4,
'global_clipnorm': 0.1,
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [120000],
'values': [0.0001, 1.0e-05]
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.0001,
'end_learning_rate': 0.000001,
'offset': 0,
'power': 1.0,
'decay_steps': 50 * 10,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2 * 10,
'warmup_learning_rate': 0,
},
},
})
optimizer = rngdet.RNGDetTask.create_optimizer(opt_cfg)
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)

task.train_step(next(iterator), model, optimizer)
dummy_images = tf.keras.Input([128, 128, 3])
dummy_history = tf.keras.Input([128, 128, 1])
_ = model(dummy_images, dummy_history, training=False)

"""def test_validation_step(self):
config = rngdet_cfg.DetrTask(
model=rngdet_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,
num_classes=81,
def test_validation_step(self):
config = rngdet_cfg.RngdetTask(
model=rngdet_cfg.Rngdet(
input_size=[128, 128, 3],
roi_size=128,
num_encoder_layers=6,
num_decoder_layers=6,
num_queries=10,
hidden_size=256,
num_classes=2,
min_level=2,
max_level=5,
backbone_endpoint_name='5',
backbone=backbones.Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=10, bn_trainable=False))
),
validation_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=2,
resnet=backbones.ResNet(model_id=50, bn_trainable=False)),
decoder=decoders.Decoder(
type='fpn',
fpn=decoders.FPN())

))

with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DetectionTask(config)
task = rngdet.RNGDetTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
iterator = iter(dataset)
logs = task.validation_step(next(iterator), model, metrics)
state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state)"""

dummy_images = tf.keras.Input([128, 128, 3])
dummy_history = tf.keras.Input([128, 128, 1])
_ = model(dummy_images, dummy_history, training=False)

if __name__ == '__main__':
tf.test.main()

0 comments on commit a07dd62

Please sign in to comment.