[Bug]: GPU memory leak in TextPairRegressor when embed_separately is set to False #3487

Open
MattGPT-ai opened this issue Jul 3, 2024 · 0 comments
Labels: bug (Something isn't working)

MattGPT-ai (Contributor) commented Jul 3, 2024

Describe the bug

When training a TextPairRegressor model with embed_separately=False (the default), e.g. via ModelTrainer.fine_tune, GPU memory slowly creeps up with each batch, eventually causing an OOM even when the model and a single batch fit easily in GPU memory.

The function store_embeddings is supposed to clear the embeddings of each DataPoint; for this model, the data point type is TextPair. It does seem to handle clearing text_pair.first and .second correctly when embed_separately=True, because in that case embed is run on each sentence (see TextPairRegressor._get_embedding_for_data_point), so the embedding is attached to each sentence and can be reached (and cleared) through it.
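
For context, embeddings in flair live in a dictionary on the data point itself, so clearing only works for tensors that were actually attached. Below is a minimal illustration using flair's public DataPoint API; it is illustrative only, not code taken from the trainer:

import torch
from flair.data import Sentence

# Illustrative only: embeddings are stored in a dict on the data point, so
# clear_embeddings can only drop tensors that were attached via set_embedding
# (directly or through an .embed() call that attaches to the data point).
sentence = Sentence("a minimal example")
sentence.set_embedding("toy-embedding", torch.zeros(8))  # attached -> reachable
print(sentence.get_embedding().shape)                    # concatenation of attached embeddings
sentence.clear_embeddings()                              # drops the stored tensor

# A tensor that is merely *returned* from a helper and never attached to any
# data point is invisible to clear_embeddings and identify_dynamic_embeddings.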

However, the default setting is False. In that case, to embed the pair, it concatenates the text of both sentences (adding a separator), creates a new Sentence, embeds that sentence, and returns the resulting embedding. Since that embedding is never attached to the TextPair data point, clear_embeddings does not find it when iterating over the data points, and identify_dynamic_embeddings also always comes up empty.
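
Roughly, the two code paths look like this (a paraphrased sketch of TextPairRegressor._get_embedding_for_data_point in flair 0.13.1; not the verbatim source):

# Paraphrased sketch of TextPairRegressor._get_embedding_for_data_point (flair 0.13.1);
# details may differ slightly from the actual source.
def _get_embedding_for_data_point(self, pair):
    names = self.embeddings.get_names()
    if self.embed_separately:
        # Embeddings are attached to pair.first / pair.second, so a later
        # clear_embeddings on the data points can find and free them.
        self.embeddings.embed([pair.first, pair.second])
        return torch.cat([pair.first.get_embedding(names), pair.second.get_embedding(names)])
    else:
        # A throwaway Sentence is built from the concatenated text and embedded;
        # its embedding is returned but never attached to the TextPair, so
        # store_embeddings / clear_embeddings never see it.
        concatenated = Sentence(
            pair.first.to_tokenized_string() + self.sep + pair.second.to_tokenized_string(),
            use_tokenizer=False,
        )
        self.embeddings.embed(concatenated)
        return concatenated.get_embedding(names)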

To Reproduce

from pathlib import Path

import flair
from flair.datasets import DataPairCorpus  # DataPairCorpus lives in flair.datasets, not flair.data
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextPairRegressor
from flair.trainers import ModelTrainer

search_rel_corpus = DataPairCorpus(Path('text_pair_dataset'), train_file='train.tsv', test_file='test.tsv', dev_file='dev.tsv', label_type='relevance', in_memory=False)

embeddings = TransformerDocumentEmbeddings(
    model='xlm-roberta-base',
    layers="-1",
    subtoken_pooling='first',
    fine_tune=True,
    use_context=True,
    is_word_embedding=True,
)

text_pair_regressor = TextPairRegressor(embeddings=embeddings, label_type='relevance')

trainer = ModelTrainer(text_pair_regressor, search_rel_corpus)

trainer.fine_tune(
    "relevance_regressor",
    learning_rate=1e-5,
    epoch=0,
    max_epochs=5,
    mini_batch_size=4,
    save_optimizer_state=True,
    save_model_each_k_epochs=1,
    use_amp=True,  # aka Automatic Mixed Precision, e.g. float16
)

Expected behavior

The memory should remain relatively flat across batches and epochs if embeddings are cleared correctly. In other training setups, such as for a TextClassifier, memory usage stays roughly the same after each mini-batch.
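
A simple way to check this per epoch, using only public PyTorch APIs (an illustrative sketch; assumes training runs on GPU 0):

import torch

# Illustrative check: if embeddings are cleared correctly, the peak allocated
# memory per epoch should stay roughly constant instead of growing.
torch.cuda.reset_peak_memory_stats(0)
# ... run one epoch of training here ...
print(f"peak allocated this epoch: {torch.cuda.max_memory_allocated(0) / 1024**2:.0f} MiB")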

Logs and Stack traces

OutOfMemoryError                          Traceback (most recent call last)
Cell In[15], line 1
----> 1 final_score = trainer.fine_tune(
      2     "relevance_regressor",
      3     learning_rate=1e-5,
      4     epoch=0,
      5     max_epochs=5,
      6     mini_batch_size=4,
      7     save_optimizer_state=True,
      8     save_model_each_k_epochs=1,
      9     use_amp=True,  # aka Automatic Mixed Precision, e.g. float16
     10 )
     11 final_score

File /pyzr/active_venv/lib/python3.10/site-packages/flair/trainers/trainer.py:253, in ModelTrainer.fine_tune(self, base_path, warmup_fraction, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, use_amp, plugins, attach_default_scheduler, **kwargs)
    250 if attach_default_scheduler:
    251     plugins.append(LinearSchedulerPlugin(warmup_fraction=warmup_fraction))
--> 253 return self.train_custom(
    254     base_path=base_path,
    255     # training parameters
    256     learning_rate=learning_rate,
    257     decoder_learning_rate=decoder_learning_rate,
    258     mini_batch_size=mini_batch_size,
    259     eval_batch_size=eval_batch_size,
    260     mini_batch_chunk_size=mini_batch_chunk_size,
    261     max_epochs=max_epochs,
    262     optimizer=optimizer,
    263     train_with_dev=train_with_dev,
    264     train_with_test=train_with_test,
    265     reduce_transformer_vocab=reduce_transformer_vocab,
    266     # evaluation and monitoring
    267     main_evaluation_metric=main_evaluation_metric,
    268     monitor_test=monitor_test,
    269     monitor_train_sample=monitor_train_sample,
    270     use_final_model_for_eval=use_final_model_for_eval,
    271     gold_label_dictionary_for_eval=gold_label_dictionary_for_eval,
    272     exclude_labels=exclude_labels,
    273     # sampling and shuffling
    274     sampler=sampler,
    275     shuffle=shuffle,
    276     shuffle_first_epoch=shuffle_first_epoch,
    277     # evaluation and monitoring
    278     embeddings_storage_mode=embeddings_storage_mode,
    279     epoch=epoch,
    280     # when and what to save
    281     save_final_model=save_final_model,
    282     save_optimizer_state=save_optimizer_state,
    283     save_model_each_k_epochs=save_model_each_k_epochs,
    284     # logging parameters
    285     create_file_logs=create_file_logs,
    286     create_loss_file=create_loss_file,
    287     write_weights=write_weights,
    288     # amp
    289     use_amp=use_amp,
    290     # plugins
    291     plugins=plugins,
    292     **kwargs,
    293 )

File /pyzr/active_venv/lib/python3.10/site-packages/flair/trainers/trainer.py:624, in ModelTrainer.train_custom(self, base_path, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, max_grad_norm, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, use_amp, plugins, **kwargs)
    622     gradient_norm = None
    623 scale_before = scaler.get_scale()
--> 624 scaler.step(self.optimizer)
    625 scaler.update()
    626 scale_after = scaler.get_scale()

File /pyzr/active_venv/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py:370, in GradScaler.step(self, optimizer, *args, **kwargs)
    366     self.unscale_(optimizer)
    368 assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
--> 370 retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
    372 optimizer_state["stage"] = OptState.STEPPED
    374 return retval

File /pyzr/active_venv/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py:290, in GradScaler._maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs)
    288 retval = None
    289 if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
--> 290     retval = optimizer.step(*args, **kwargs)
    291 return retval

File /pyzr/active_venv/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:69, in LRScheduler.__init__.<locals>.with_counter.<locals>.wrapper(*args, **kwargs)
     67 instance._step_count += 1
     68 wrapped = func.__get__(instance, cls)
---> 69 return wrapped(*args, **kwargs)

File /pyzr/active_venv/lib/python3.10/site-packages/torch/optim/optimizer.py:280, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    276         else:
    277             raise RuntimeError(f"{func} must return None or a tuple of (new_args, new_kwargs),"
    278                                f"but got {result}.")
--> 280 out = func(*args, **kwargs)
    281 self._optimizer_step_code()
    283 # call optimizer step post hooks

File /pyzr/active_venv/lib/python3.10/site-packages/torch/optim/optimizer.py:33, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     31 try:
     32     torch.set_grad_enabled(self.defaults['differentiable'])
---> 33     ret = func(self, *args, **kwargs)
     34 finally:
     35     torch.set_grad_enabled(prev_grad)

File /pyzr/active_venv/lib/python3.10/site-packages/torch/optim/adamw.py:171, in AdamW.step(self, closure)
    158     beta1, beta2 = group["betas"]
    160     self._init_group(
    161         group,
    162         params_with_grad,
   (...)
    168         state_steps,
    169     )
--> 171     adamw(
    172         params_with_grad,
    173         grads,
    174         exp_avgs,
    175         exp_avg_sqs,
    176         max_exp_avg_sqs,
    177         state_steps,
    178         amsgrad=amsgrad,
    179         beta1=beta1,
    180         beta2=beta2,
    181         lr=group["lr"],
    182         weight_decay=group["weight_decay"],
    183         eps=group["eps"],
    184         maximize=group["maximize"],
    185         foreach=group["foreach"],
    186         capturable=group["capturable"],
    187         differentiable=group["differentiable"],
    188         fused=group["fused"],
    189         grad_scale=getattr(self, "grad_scale", None),
    190         found_inf=getattr(self, "found_inf", None),
    191     )
    193 return loss

File /pyzr/active_venv/lib/python3.10/site-packages/torch/optim/adamw.py:321, in adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)
    318 else:
    319     func = _single_tensor_adamw
--> 321 func(
    322     params,
    323     grads,
    324     exp_avgs,
    325     exp_avg_sqs,
    326     max_exp_avg_sqs,
    327     state_steps,
    328     amsgrad=amsgrad,
    329     beta1=beta1,
    330     beta2=beta2,
    331     lr=lr,
    332     weight_decay=weight_decay,
    333     eps=eps,
    334     maximize=maximize,
    335     capturable=capturable,
    336     differentiable=differentiable,
    337     grad_scale=grad_scale,
    338     found_inf=found_inf,
    339 )

File /pyzr/active_venv/lib/python3.10/site-packages/torch/optim/adamw.py:566, in _multi_tensor_adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable)
    564     exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
    565     torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
--> 566     denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
    568 torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 15.78 GiB total capacity; 14.06 GiB already allocated; 12.00 MiB free; 14.90 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Screenshots

No response

Additional Context

I printed out the GPU usage in an altered train_custom:

import torch  # already available inside trainer.py; shown so the snippet is self-contained

def print_gpu_usage(entry=None):
    allocated_memory = torch.cuda.memory_allocated(0)
    reserved_memory = torch.cuda.memory_reserved(0)
    print(f"{entry}\t{allocated_memory:<15,} / {reserved_memory:<15,}")

I saw that when training a TextClassifier, the memory usage goes back down to the value at the beginning of a batch after store_embeddings is called. In TextPairRegressor, the memory does not go down at all after store_embeddings is called.
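
The calls were interleaved roughly like this around the storage step in the batch loop (paraphrased; the exact placement in the altered train_custom may differ):

# Paraphrased placement of the debugging calls around the storage step in the
# training batch loop (illustrative; exact location in train_custom may differ).
print_gpu_usage("before store_embeddings")
store_embeddings(batch, embeddings_storage_mode, dynamic_embeddings)
print_gpu_usage("after store_embeddings")
# For a TextClassifier the "after" value drops back to the start-of-batch level;
# for a TextPairRegressor it stays unchanged, consistent with nothing being cleared.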

Environment

Versions:

Flair: 0.13.1
Pytorch: 2.3.1+cu121
Transformers: 4.31.0
GPU: True

MattGPT-ai added the bug label Jul 3, 2024
alanakbik added a commit to MattGPT-ai/flair that referenced this issue Jul 8, 2024
MattGPT-ai added a commit to MattGPT-ai/flair that referenced this issue Jul 9, 2024
alanakbik added a commit that referenced this issue Jul 13, 2024: Fix GPU memory leak in TextPairRegressor