diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 4105af7dd367..1d856e42cec4 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -16,7 +16,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: - filters = get_filter_config(options) + filters = get_filter_config(inputs) activation = get_activation_fn(options.get("conv_activation")) @@ -47,7 +47,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): return flatten(fc2), flatten(fc1) -def get_filter_config(options): +def get_filter_config(inputs): filters_84x84 = [ [16, [8, 8], 4], [32, [4, 4], 2], @@ -58,12 +58,15 @@ def get_filter_config(options): [32, [4, 4], 2], [256, [11, 11], 1], ] - dim = options.get("dim") - if dim == 84: + shape = inputs.shape.as_list()[1:] + if len(shape) == 3 and shape[:2] == [84, 84]: return filters_84x84 - elif dim == 42: + elif len(shape) == 3 and shape[:2] == [42, 42]: return filters_42x42 else: raise ValueError( - "No default configuration for image size={}".format(dim) + - ", you must specify `conv_filters` manually as a model option.") + "No default configuration for obs input {}".format(inputs) + + ", you must specify `conv_filters` manually as a model option. " + "Default configurations are only available for inputs of size " + "[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want " + "to use a custom model or preprocessor.") diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index 852a02fc4d1e..efa1aba0e2f0 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -72,13 +72,13 @@ def testDefaultModels(self): with tf.variable_scope("test1"): p1 = ModelCatalog.get_model({ - "obs": np.zeros((10, 3), dtype=np.float32) + "obs": tf.zeros((10, 3), dtype=tf.float32) }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {}) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): p2 = ModelCatalog.get_model({ - "obs": np.zeros((10, 84, 84, 3), dtype=np.float32) + "obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32) }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {}) self.assertEqual(type(p2), VisionNetwork)