@sayakpaul
Member

@sayakpaul sayakpaul commented May 16, 2022

This PR introduces TFData2VecVisionForSemanticSegmentation which takes the TFData2VecVisionMainLayer and appends the necessary layers for performing semantic segmentation along with loss computation (first one in this line?).

Notes

  • Thanks to @Rocketknight1 who implemented the adaptive average pooling layer.
  • Currently, the model-saving tests (2 tests) fail as soon as the TFData2VecVisionForSemanticSegmentation class is introduced to tests/models/data2vec/test_modeling_tf_data2vec_vision.py. Without that class, the tests run as expected. I would appreciate any help.
  • As discussed over Slack, this class should never have been subclassed from nn.ModuleList. It is currently leading to a few idiosyncrasies on the TF side (mainly related to the naming of the layers). Once that is sorted out, we can revisit this TFData2VecVisionForSemanticSegmentation class and make amendments if needed. Happy to take charge of that then.
  • I ran the tests locally with the following command: RUN_SLOW=1 python -m pytest tests/models/data2vec/test_modeling_tf_data2vec_vision.py.

Here's the trace of the errors from running tests:

            model = model_class(config)
            model(self._prepare_for_class(inputs_dict, model_class))  # Model must be called before saving.
            # Let's load it from the disk to be sure we can use pretrained weights
            with tempfile.TemporaryDirectory() as tmpdirname:
>               model.save_pretrained(tmpdirname, saved_model=False)

tests/test_modeling_tf_common.py:693: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/modeling_tf_utils.py:1513: in save_pretrained
    self.save_weights(output_model_file)
../../.local/bin/.virtualenvs/hf/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67: in error_handler
    raise e.with_traceback(filtered_tb) from None
../../.local/bin/.virtualenvs/hf/lib/python3.8/site-packages/h5py/_hl/group.py:149: in create_dataset
    dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds)
../../.local/bin/.virtualenvs/hf/lib/python3.8/site-packages/h5py/_hl/dataset.py:142: in make_new_dset
    dset_id = h5d.create(parent.id, name, tid, sid, dcpl=dcpl)
h5py/_objects.pyx:54: in h5py._objects.with_phil.wrapper
    ???
h5py/_objects.pyx:55: in h5py._objects.with_phil.wrapper
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   ???
E   ValueError: Unable to create dataset (name already exists)

h5py/h5d.pyx:87: ValueError

...


           outputs = model(self._prepare_for_class(inputs_dict, model_class))
    
            with tempfile.TemporaryDirectory() as tmpdirname:
>               model.save_pretrained(tmpdirname, saved_model=False)

tests/test_modeling_tf_common.py:175: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/modeling_tf_utils.py:1513: in save_pretrained
    self.save_weights(output_model_file)
../../.local/bin/.virtualenvs/hf/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67: in error_handler
    raise e.with_traceback(filtered_tb) from None
../../.local/bin/.virtualenvs/hf/lib/python3.8/site-packages/h5py/_hl/group.py:149: in create_dataset
    dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds)
../../.local/bin/.virtualenvs/hf/lib/python3.8/site-packages/h5py/_hl/dataset.py:142: in make_new_dset
    dset_id = h5d.create(parent.id, name, tid, sid, dcpl=dcpl)
h5py/_objects.pyx:54: in h5py._objects.with_phil.wrapper
    ???
h5py/_objects.pyx:55: in h5py._objects.with_phil.wrapper
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   ???
E   ValueError: Unable to create dataset (name already exists)

h5py/h5d.pyx:87: ValueError
-------------------------------

Additionally, here's a little code for testing the segmentation class:

from PIL import Image
import tensorflow as tf

from src.transformers.models.data2vec.modeling_tf_data2vec_vision import (
    TFData2VecVisionForSemanticSegmentation
)
from transformers import BeitFeatureExtractor


def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


feature_extractor = BeitFeatureExtractor.from_pretrained(
    "facebook/data2vec-vision-base-ft1k"
)
model = TFData2VecVisionForSemanticSegmentation.from_pretrained(
    "facebook/data2vec-vision-base",
)


image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
batch_size, num_channels, height, width = inputs["pixel_values"].shape
inputs["labels"] = tf.zeros((batch_size, height, width))
outputs = model(**inputs)

print(outputs.logits.shape)
print(outputs.loss.shape)

@Rocketknight1 @sgugger

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 16, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger sgugger left a comment

Thanks a lot for your PR! Note that on the pyramid pooling class, even if we change the PyTorch class to not subclass ModuleList anymore, it will still need to keep the same weight names, otherwise compatibility with any checkpoint on the Hub will be broken.

@sayakpaul
Member Author

> Thanks a lot for your PR! Note that on the pyramid pooling class, even if we change the PyTorch class to not subclass ModuleList anymore, it will still need to keep the same weight names, otherwise compatibility with any checkpoint on the Hub will be broken.

Absolutely.

@sayakpaul
Member Author

@Rocketknight1 a gentle ping 👀

@Rocketknight1
Member

Ah, I'm sorry! Will review it by tomorrow.

@Rocketknight1
Member

Hi, I just took a look over this! I suspect the issue with the tests is that there's something like a layer name collision when saving. In h5 files, weights are saved as 'datasets', so this error is telling us that the weights are not uniquely named: the same 'dataset' name is being written twice during saving, which means two layers share the same name.
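
One way to check for such a collision before saving is to count how often each weight name occurs. The snippet below is only a minimal sketch: `find_duplicate_names` is a hypothetical helper, and the weight names are made up for illustration rather than taken from this PR.

```python
from collections import Counter


def find_duplicate_names(weight_names):
    """Return the names that occur more than once, in sorted order."""
    counts = Counter(weight_names)
    return sorted(name for name, n in counts.items() if n > 1)


# Hypothetical weight names, made up for illustration. For a real Keras model
# the list could be built with: [w.name for w in model.weights]
example_names = [
    "encoder/conv_a/kernel:0",
    "encoder/conv_a/bias:0",
    "head/conv/kernel:0",
    "head/conv/kernel:0",  # same name as the entry above, so it collides
]

duplicates = find_duplicate_names(example_names)
print(duplicates)
```

Any name reported here would be written to the same h5 'dataset' twice, which is the situation the error message describes.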

@sayakpaul
Member Author

Yes, I suspected something similar but couldn't figure out where the duplicate is coming from. Do you have any suggestions?

@Rocketknight1

@Rocketknight1
Member

I suspect the issue is most likely related to the implementation of AdaptiveAvgPool I wrote - the practice of precomputing a constant sparse matrix like that is non-standard, and TF might be trying to save that Tensor somehow. Can you try replacing it with a 'dummy' layer that has the same output shape and seeing if the error goes away? If so, I can work on a different implementation for the layer - I have some ideas that I think will improve performance a lot, and they might also resolve the problem.
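
For readers unfamiliar with the idea of a precomputed pooling matrix, here is a rough 1D sketch in plain Python. It is not the implementation from this PR: the function names are hypothetical, and the bin boundaries follow the floor/ceil rule used by PyTorch-style adaptive pooling, which is an assumption about the intended behaviour.

```python
def adaptive_avg_pool_matrix(in_size, out_size):
    """Build an (out_size x in_size) averaging matrix as nested lists."""
    matrix = []
    for i in range(out_size):
        # Bin i covers [floor(i * in / out), ceil((i + 1) * in / out)).
        start = (i * in_size) // out_size
        end = ((i + 1) * in_size + out_size - 1) // out_size
        width = end - start
        row = [1.0 / width if start <= j < end else 0.0 for j in range(in_size)]
        matrix.append(row)
    return matrix


def apply_pool(matrix, values):
    """Multiply the constant pooling matrix by a 1D input."""
    return [sum(w * v for w, v in zip(row, values)) for row in matrix]


pool = adaptive_avg_pool_matrix(in_size=5, out_size=2)
pooled = apply_pool(pool, [1.0, 2.0, 3.0, 4.0, 5.0])
print(pooled)
```

With 5 inputs and 2 outputs, the first bin averages indices 0 to 2 and the second averages indices 2 to 4. The matrix depends only on the sizes, so it can be computed once and reused as a constant.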

@sayakpaul
Member Author

> Can you try replacing it with a 'dummy' layer that has the same output shape and seeing if the error goes away?

Sure. I will do it and get back.

@sayakpaul
Member Author

sayakpaul commented May 20, 2022

@Rocketknight1
Member

@sayakpaul I used post-mortem debugging to isolate this - just add this to TFData2VecVisionModelTest:

  def test_save_load(self):
      try:
          super().test_save_load()
      except Exception:
          # Drop into the debugger at the frame where the exception was raised.
          import pdb
          pdb.post_mortem()

Then run the tests with pytest --capture=no. This will break into a debugger at the point of failure, and you can step up to the calling frame with (u)p.

From there, I can tell that the offending array has name kernel:0 with shape (1, 1, 32, 32), though I couldn't figure out exactly where it was. Is there a 1x1 conv2D in your code that maps 32 filters to 32 filters?

@sayakpaul
Member Author

sayakpaul commented May 20, 2022

> From there, I can tell that the offending array has name kernel:0 with shape (1, 1, 32, 32), though I couldn't figure out exactly where it was. Is there a 1x1 conv2D in your code that maps 32 filters to 32 filters?

There are multiple 1x1 convs, yes.

@sayakpaul
Member Author

sayakpaul commented May 20, 2022

> Then run the tests with pytest --capture=no. This will break into a debugger at the point of failure, and you can step up to the calling frame with (u)p.

Could you elaborate a bit more here? I have added the pdb snippet into the model tester code. Then I ran RUN_SLOW=1 python -m pytest --capture=no tests/models/data2vec/test_modeling_tf_data2vec_vision.py. I do get the pdb prompt and I get to -> super().test_save_load() as the oldest frame.

@Rocketknight1

@Rocketknight1
Member

@sayakpaul I stepped up to the frame of dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds). This let me inspect the variable name and the group, but I didn't understand h5py well enough to figure out the exact weight causing the issue.
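
The same inspection can also be done without an interactive prompt by walking the traceback and reading the locals of the frame of interest. This is only an illustrative sketch: `make_new_dset` and `save_weights` below are stand-ins defined in the snippet, not the real h5py or Keras functions, and `locals_at_failure` is a hypothetical helper.

```python
import sys
import traceback


def make_new_dset(name):
    # Stand-in for the library call that fails; not the real h5py function.
    raise ValueError("Unable to create dataset (name already exists)")


def save_weights():
    # Stand-in for the saving code that passes a weight name down the stack.
    make_new_dset("kernel:0")


def locals_at_failure(func_name):
    """Run save_weights() and return the locals of the frame named func_name."""
    try:
        save_weights()
    except ValueError:
        tb = sys.exc_info()[2]
        # walk_tb yields (frame, line number) pairs from outermost to innermost.
        for frame, _lineno in traceback.walk_tb(tb):
            if frame.f_code.co_name == func_name:
                return dict(frame.f_locals)
    return None


found = locals_at_failure("make_new_dset")
print(found)
```

In the interactive session, moving up or down to that frame and printing `name` gives the same information.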

@sayakpaul
Member Author

@Rocketknight1 I looked into the layers with kernel_size=1 and tried to fix their names to use something that's suffixed with identifiers. You can find the commit here.

It still didn't resolve the issue. The only potential suspect I could find is the following: there are two layers named classifier in TFData2VecVisionForSemanticSegmentation, added via TFData2VecVisionUperHead and TFData2VecVisionFCNHead respectively.

Thoughts?

@sayakpaul
Member Author

sayakpaul commented May 30, 2022

Update:

With @Rocketknight1's help, I was able to resolve the current test failure (commit here). But I have run into two more failures, which I am currently discussing with @Rocketknight1. He's on vacation. Once he gets back, I will hopefully be able to report back with updates.

Member

@Rocketknight1 Rocketknight1 left a comment

With tests passing now, I'm happy to approve this!


This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow.
[sayakpaul](https://github.com/sayakpaul) and [Rocketknight1](https://github.com/Rocketknight1) contributed Data2Vec for vision in TensorFlow.
Member

You did almost all of it!

@Rocketknight1 Rocketknight1 merged commit 9d99489 into huggingface:main Jun 8, 2022
@sayakpaul sayakpaul deleted the tf-data2vec-vision-seg branch June 8, 2022 13:04
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* feat: initial implementation of data2vec segmentation model in TF.

* chore: minor corrections to make the segmenter work.

* chore: removed unncessary files.

* chore: add tests and other modifications.

* fix: loss computation for segmentation.

* chore: remove unused variable.

* chore: formatting.

* added a dummy adaptive pooling layer.

* removed unnecessary file.

* potentially add identifiers to layer names.

* fix: layer naming.

* chore: removed unnecessary print.

* Skipping unneeded test

* chore: add logging to debug tolerance.

* fix: segmentation tests for tfdata2vecvision

* chore: make style.

* fix: layer names, assertion to be resolved.

* Bumping test tolerance a bit

* chore: bump the tol in PT test.

Co-authored-by: matt <[email protected]>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jun 16, 2022