Conversation

@rlouf rlouf commented Nov 14, 2019

Custom schedulers are currently created by wrapping PyTorch's `LambdaLR`
class and passing a method of the wrapping class to the `__init__`
function of `LambdaLR`. This approach is not appropriate for several
reasons:

1. one does not need to define a class when it only defines an
`__init__()` method;
2. instantiating the parent class by passing it a method of the child class
creates a cyclic reference, which leads to memory leaks. See issues #1742
(Out of Memory (OOM) when repeatedly running large models) and #1134
(Schedulers cause memory accumulation across folds in cross-validation?).

In this commit we replace the wrapper classes with functions that
instantiate `LambdaLR` with a custom learning rate function. We use a
closure to specify the parameters of the latter. We also do a bit of
renaming within the functions to make the behaviour explicit, and removed
docstrings that were no longer necessary.
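A minimal sketch of the pattern in question (the class name, function name, and warmup logic below are illustrative only, not the actual schedulers touched by this PR): the class-based wrapper passes one of its own bound methods to the parent `LambdaLR.__init__`, so the scheduler holds a reference back to itself, while the function-based version builds the schedule from a plain closure.

```python
from torch.optim.lr_scheduler import LambdaLR


# Before: the wrapper passes a bound method of itself to its parent's
# __init__, so the LambdaLR instance keeps a reference back to the child
# instance (a cyclic reference that can delay garbage collection).
class WarmupScheduler(LambdaLR):  # hypothetical name for illustration
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        return min(1.0, step / max(1, self.warmup_steps))


# After: a plain function returns a LambdaLR built from a closure over
# warmup_steps, so no extra class (and no cycle) is needed.
def get_warmup_schedule(optimizer, warmup_steps, last_epoch=-1):
    def lr_lambda(step):
        return min(1.0, step / max(1, warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
```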
@thomwolf
Member

Yes, great job tracking and fixing this.
Can you update all the examples as well?

@LysandreJik
Member

Great job finding the issue!

@rlouf rlouf force-pushed the memory-leak-schedulers branch from 76c4b61 to 2276bf6 on November 14, 2019 at 19:38
@rlouf
Contributor Author

rlouf commented Nov 14, 2019

I updated the docs and examples. I am confused because the tests fail on this function:

=================================== FAILURES ===================================
______________ TFDistilBertModelTest.test_pt_tf_model_equivalence ______________

self = <transformers.tests.modeling_tf_distilbert_test.TFDistilBertModelTest testMethod=test_pt_tf_model_equivalence>

    def test_pt_tf_model_equivalence(self):
        if not is_torch_available():
            return
    
        import torch
        import transformers
    
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    
        for model_class in self.all_model_classes:
            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)
    
            config.output_hidden_states = True
            tf_model = model_class(config)
            pt_model = pt_model_class(config)
    
            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
    
            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
                                  for name, key in inputs_dict.items())
            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(inputs_dict)
            max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
>           self.assertLessEqual(max_diff, 2e-2)
E           AssertionError: nan not less than or equal to 0.02

which has no apparent link to my changes.

@LysandreJik
Copy link
Member

This test has been failing on and off for a week or so now; I'll look into it soon.

Member

@thomwolf thomwolf left a comment
LGTM

@thomwolf thomwolf merged commit df99f8c into master Nov 14, 2019
@julien-c julien-c deleted the memory-leak-schedulers branch November 16, 2019 02:22