Skip to content

Commit

Permalink
remove reference to Keras in test files
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Nov 7, 2024
1 parent e83741e commit f5acef3
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion deel/torchlip/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


def _is_supported_1lip_layer(layer):
"""Return True if the Keras layer is 1-Lipschitz. Note that in some cases, the layer
"""Return True if the layer is 1-Lipschitz. Note that in some cases, the layer
is 1-Lipschitz for specific set of parameters.
"""
supported_1lip_layers = (
Expand Down
4 changes: 2 additions & 2 deletions tests/test_compute_layer_sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def train_compute_and_verifySV(
input_shape = uft.to_framework_channel(input_shape)

# tf.random.uft.set_seed(1234)
# create the keras model, defin opt, and compile it
# create the model, defin opt, and compile it
model = uft.generate_k_lip_model(layer_type, layer_params, input_shape, k_lip_model)

optimizer = uft.get_instance_framework(
Expand All @@ -166,7 +166,7 @@ def train_compute_and_verifySV(
logdir = os.path.join("logs", uft.LIP_LAYERS, "%s" % layer_type.__name__)
os.makedirs(logdir, exist_ok=True)

callback_list = [] # [hp.KerasCallback(logdir, hparams)]
callback_list = []
if "callbacks" in kwargs and (kwargs["callbacks"] is not None):
callback_list = callback_list + kwargs["callbacks"]
# train model
Expand Down
2 changes: 1 addition & 1 deletion tests/test_condense.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):

model = get_model(layer_type, layer_params, input_shape, k_coef_lip)

# create the keras model, defin opt, and compile it
# create the model, defin opt, and compile it
optimizer = uft.get_instance_framework(
uft.Adam, inst_params={"lr": 0.001, "model": model}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_initializer(layer_type, layer_params, input_shape, orthogonal_test):
# clear session to avoid side effects from previous train
uft.init_session() # K.clear_session()
input_shape = uft.to_framework_channel(input_shape)
# create the keras model, defin opt, and compile it
# create the model, defin opt, and compile it
model = uft.generate_k_lip_model(layer_type, layer_params, input_shape)
uft.initialize_kernel(model, 0, layer_params["kernel_initializer"])

Expand Down
4 changes: 2 additions & 2 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train_k_lip_model(
uft.init_session() # K.clear_session()
np.random.seed(42)
input_shape = uft.to_framework_channel(input_shape)
# create the keras model, defin opt, and compile it
# create the model, defin opt, and compile it
model = uft.generate_k_lip_model(layer_type, layer_params, input_shape, k_lip_model)

optimizer = uft.get_instance_framework(
Expand All @@ -195,7 +195,7 @@ def train_k_lip_model(

callback_list = (
[]
) # [callbacks.TensorBoard(logdir), hp.KerasCallback(logdir, hparams)]
)
if kwargs["callbacks"] is not None:
callback_list = callback_list + kwargs["callbacks"]
# train model
Expand Down
8 changes: 3 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def assert_model_outputs(input_shape, model1, model2):
np.testing.assert_allclose(uft.to_numpy(y2), uft.to_numpy(y1), atol=1e-5)


def test_keras_Sequential():
"""Assert vanilla conversion of a tf.keras.Sequential model"""
def test_Sequential():
"""Assert vanilla conversion of a Sequential model"""
input_shape = uft.to_framework_channel((3, 20, 20))
model = uft.generate_k_lip_model(
tSequential,
Expand Down Expand Up @@ -203,14 +203,12 @@ def test_deel_lip_Sequential():
reason="tModel not available",
)
def test_Model():
"""Assert vanilla conversion of a tf.keras.Model model"""
"""Assert vanilla conversion of a Model model"""
input_shape = uft.to_framework_channel((3, 8, 8))
dict_tensors = get_functional_tensors(input_shape)
model = uft.get_functional_model(
tModel, dict_tensors, functional_input_output_tensors
)
# inputs, outputs = functional_input_output_tensors()
# model = tf.keras.Model(inputs, outputs)
if uft.vanilla_require_a_copy():
dict_tensors2 = get_functional_tensors(input_shape)
model2 = uft.get_functional_model(
Expand Down

0 comments on commit f5acef3

Please sign in to comment.