
Refactor visual encoder and features #101

Open · wants to merge 12 commits into main

Conversation

@shaikh58 (Contributor) commented Nov 27, 2024

  • Adds support for different visual feature encoders. Currently only a ResNet-based encoder is supported; this PR adds a compact, descriptive feature encoder based on the inertia tensor, mean intensity, and, optionally, Hu moments (a minimal sketch of such a descriptor follows this list)
  • Adds an encoder registry that picks the appropriate encoder based on user input in configs. Users can also configure whether or not to use Hu moments
  • Adds a sample config showing how users can specify the encoder
  • Adds a feature to disable image normalization when descriptive features are used, since normalizing produces empty crops in low-intensity regions, which prevents calculation of visual features
  • Allows users to define their own visual encoder and register it
  • Introduces learnable Fourier embeddings that are concatenated with the low-dimensional visual feature vector and projected to the user-defined model dimension; applied to both encoder and decoder input
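As a rough sketch of the kind of descriptor described above (a hedged illustration: describe_crop, its feature ordering, and the single-channel input are assumptions, not the PR's actual API), the 5- and 12-dimensional variants line up with skimage's inertia-tensor and Hu-moment utilities:

```python
import numpy as np
from skimage.measure import (
    inertia_tensor,
    moments_central,
    moments_hu,
    moments_normalized,
)

def describe_crop(crop: np.ndarray, use_hu_moments: bool = False) -> np.ndarray:
    """Compact descriptor for one single-channel crop (illustrative only)."""
    feats = [
        *inertia_tensor(crop).ravel(),  # 2x2 inertia tensor -> 4 values
        crop.mean(),                    # mean intensity -> 5 values total
    ]
    if use_hu_moments:
        nu = moments_normalized(moments_central(crop))
        feats.extend(moments_hu(nu))    # 7 Hu moments -> 12 values total
    return np.asarray(feats, dtype=np.float32)
```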

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced a new parameter for image normalization in the dataset class, enhancing flexibility during dataset initialization.
    • Added a new DescriptorVisualEncoder class for feature extraction based on image descriptors.
    • Implemented FourierPositionalEmbeddings for improved positional encoding in the transformer model.
  • Enhancements

    • Updated dataset retrieval process to adapt image normalization based on encoder type.
    • Streamlined visual encoder initialization within the transformer class.
    • Enhanced encoder configuration structure in YAML files for improved clarity and extensibility.
  • Documentation

    • Updated configuration files to reflect new encoder settings and added comments for better understanding.

Commits

…ct appropriate encoder

- registry allows users to define a custom visual encoder, add configs, and register it into the registry
- global_tracking_transformer calls the router method rather than VisualEncoder directly
- update config file structure to allow different config sets for different encoder types
- maintain backward compatibility in case old configs are used; defaults to the existing encoder

move data to and from gpu when computing skimage features

…visual features to larger d_model

- transformer concatenates learnable fourier params onto the feature vector and projects
- remove Hu embeddings from the feature descriptor, for numerical stability
- use embed_dim = d_model rather than taking it from the incoming feature shapes
- apply fourier embeddings and project to d_model; fix minor bugs in dims and order of operations
- apply fourier embeddings to decoder input as well as encoder

… regions - empty crops prevent calculation of visual features

… used (for training stability)

- remove lbp encoder from visual encoder and config structure
@coderabbitai bot (Contributor) commented Nov 27, 2024

Walkthrough

The pull request introduces several modifications across multiple files. Key changes include the addition of a normalize_image parameter in the SleapDataset class to control image normalization, updates to the dataset retrieval process in the Config class, and the introduction of a new DescriptorVisualEncoder class for feature extraction. Additionally, a FourierPositionalEmbeddings class is added to enhance positional encoding in transformers. Various import statements are updated to support these new functionalities, streamlining the overall architecture and improving flexibility in image processing and model configuration.
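For orientation, here is a hedged reconstruction of the FourierPositionalEmbeddings idea, assembled from the snippets quoted in the reviews below; the cutoff value and exact output shape are assumptions, not the PR's verbatim code:

```python
import math
import torch

def _pos_embed_fourier1d_init(cutoff: float, n: int) -> torch.Tensor:
    """(1, n) tensor of log-spaced Fourier frequency coefficients."""
    return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0)

class FourierPositionalEmbeddings(torch.nn.Module):
    def __init__(self, n_components: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.n_components = n_components
        # learnable frequencies: one row of n_components coefficients
        self.freq = torch.nn.Parameter(_pos_embed_fourier1d_init(1000.0, n_components))

    def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
        """Map positions of shape (num_queries,) to (1, num_queries, 2 * n_components)."""
        phase = 0.5 * math.pi * seq_positions.unsqueeze(-1).unsqueeze(0) * self.freq
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1) / math.sqrt(
            self.n_components
        )
```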

Changes

| File | Change Summary |
| --- | --- |
| dreem/datasets/sleap_dataset.py | Added normalize_image parameter to __init__; updated get_instances to conditionally apply image normalization; updated the __init__ docstring to include normalize_image. |
| dreem/io/config.py | Modified get_dataset to set normalize_image to False if encoder_type is "descriptor". |
| dreem/models/__init__.py | Updated import statements to include FourierPositionalEmbeddings, DescriptorVisualEncoder, create_visual_encoder, and register_encoder. |
| dreem/models/embedding.py | Introduced FourierPositionalEmbeddings class with methods for initialization and forward pass; added helper function _pos_embed_fourier1d_init. |
| dreem/models/global_tracking_transformer.py | Replaced VisualEncoder import with create_visual_encoder; simplified visual encoder instantiation; updated extract_features to ensure compatibility with device context. |
| dreem/models/transformer.py | Added encoder_cfg parameter to __init__; updated forward method to use Fourier embeddings; introduced apply_fourier_embeddings method. |
| dreem/models/visual_encoder.py | Added DescriptorVisualEncoder class with methods for feature extraction; introduced ENCODER_REGISTRY and register_encoder function; updated create_visual_encoder function. |
| dreem/training/configs/base.yaml | Added comments outlining a new structure for encoder_cfg in the model configuration, organized by encoder type. |
| tests/configs/base.yaml | Updated encoder_cfg section with a new structure including encoder_type, ndim, and nested settings for resnet and descriptor. |

Possibly related PRs

  • Mustafa multiple crop size #100: Modifies the SleapDataset class in sleap_dataset.py, updating the constructor to include a new parameter data_dirs and altering the crop_size parameter, which is relevant to the normalize_image changes.

🐇 In the code we hop and play,
New features bloom, bright as day.
Normalize images, oh so fine,
With descriptors, our models shine.
Fourier waves dance in the air,
In this update, joy we share! 🎉



@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (14)
dreem/models/__init__.py (1)

9-14: Excellent architectural design for extensibility

The combination of factory and registry patterns creates a flexible plugin architecture that:

  • Enables users to register custom encoders (see the sketch after this list)
  • Provides a clean interface for encoder instantiation
  • Facilitates future additions of new encoder types without modifying existing code
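For example, a user-defined encoder might plug in like this (a hedged sketch: MyTinyEncoder and the "tiny" config key are hypothetical, and the registration/factory signatures are inferred from the create_visual_encoder code quoted later in this review):

```python
import torch
from dreem.models import create_visual_encoder, register_encoder

class MyTinyEncoder(torch.nn.Module):
    """Toy encoder: global-average-pool each crop, then project to d_model."""

    def __init__(self, d_model: int, **kwargs):
        super().__init__()
        self.proj = torch.nn.LazyLinear(d_model)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        pooled = imgs.flatten(2).mean(dim=-1)  # (B, C, H, W) -> (B, C)
        return self.proj(pooled)               # (B, C) -> (B, d_model)

register_encoder("tiny", MyTinyEncoder)
# per the factory code quoted below, type-specific settings nest under the type name
encoder = create_visual_encoder(d_model=128, encoder_type="tiny", tiny={})
```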

dreem/models/global_tracking_transformer.py (3)

75-75: Consider reviewing the architecture design

Passing encoder_cfg to the Transformer initialization might indicate tight coupling between the Transformer and encoder configuration. Consider if this violates separation of concerns and whether the encoder configuration should be handled entirely by the visual encoder component.


Line range hint 8-8: Address TODO comment about parameter handling

The TODO comment suggests uncertainty about parameter handling with configs. This should be resolved before merging.

Would you like me to help design a more structured approach to parameter handling?


Line range hint 10-11: Consider splitting responsibilities

The GlobalTrackingTransformer class handles both visual encoding and transformer operations. Consider splitting these responsibilities into separate components for better maintainability and testing.

dreem/models/embedding.py (3)

361-363: Enhance function documentation

The docstring should include parameter descriptions and return type information.

Consider updating the docstring to:

-    """Create a tensor of shape (1,n) of fourier frequency coefficients"""
+    """Create a tensor of shape (1,n) of Fourier frequency coefficients.
+    
+    Args:
+        cutoff: The maximum frequency cutoff value
+        n: Number of frequency components
+        
+    Returns:
+        torch.Tensor: A tensor of shape (1,n) containing logarithmically spaced frequencies
+    """

366-376: Improve class documentation and add input validation

The class docstring should better describe its purpose and parameters. Additionally, consider adding input validation.

 class FourierPositionalEmbeddings(torch.nn.Module):
+    """Learnable Fourier positional embeddings for transformer models.
+    
+    This class implements positional encodings using learnable Fourier frequency
+    coefficients. The embeddings are computed using both sine and cosine
+    transformations of the input positions.
+    
+    Args:
+        n_components (int): Number of frequency components for each dimension
+        d_model (int): The model dimension (must be divisible by 2*n_components)
+    """
     def __init__(
         self,
         n_components: int,
         d_model: int,
     ):
-        """Positional encoding with given cutoff and number of frequencies for each dimension.
-        number of dimension is inferred from the length of cutoffs and n_pos.
-        """
         super().__init__()
+        if d_model % (2 * n_components) != 0:
+            raise ValueError(
+                f"d_model ({d_model}) must be divisible by 2*n_components ({2*n_components})"
+            )
         self.d_model = d_model
         self.n_components = n_components

383-404: Optimize computation and improve type hints

The forward method could benefit from type hints and clearer variable names.

-    def forward(self, seq_positions: torch.Tensor):
+    def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
         """Compute learnable fourier coefficients for each spatial/temporal position.
         Args:
-            seq_positions: tensor of shape (num_queries,)
+            seq_positions: Tensor of shape (num_queries,) containing position indices
         Returns:
-            tensor of shape (num_queries, embed_dim)
+            torch.Tensor: Positional embeddings of shape (1, num_queries, d_model)
         """
         freq = self.freq.to(seq_positions.device)
+        # Reshape positions once for both sin and cos computations
+        positions = seq_positions.unsqueeze(-1).unsqueeze(0) * freq * 0.5 * math.pi
         embed = torch.cat(
             (
-                torch.sin(
-                    0.5 * math.pi * seq_positions.unsqueeze(-1).unsqueeze(0) * freq
-                ),
-                torch.cos(
-                    0.5 * math.pi * seq_positions.unsqueeze(-1).unsqueeze(0) * freq
-                ),
+                torch.sin(positions),
+                torch.cos(positions),
             ),
             axis=-1,
         ) / math.sqrt(len(freq))
dreem/io/config.py (1)

271-275: Consider decoupling encoder configuration from dataset preprocessing

The current implementation creates a direct dependency between the encoder type and dataset preprocessing. This coupling might make it harder to:

  • Add new encoder types without modifying dataset logic
  • Mix different preprocessing strategies independently
  • Test preprocessing configurations in isolation

Consider these alternatives:

  1. Move preprocessing configuration to the encoder config itself:
encoder_cfg = {
    "type": "descriptor",
    "preprocessing": {
        "normalize_image": False
    }
}
  2. Create a separate preprocessing configuration section:
preprocessing_cfg = {
    "descriptor": {
        "normalize_image": False
    },
    "resnet": {
        "normalize_image": True
    }
}

This would make the system more modular and easier to extend with new encoder types.

dreem/models/visual_encoder.py (2)

9-9: Remove unused import local_binary_pattern

The function local_binary_pattern from skimage.feature is imported but not used in the code. Removing it will clean up the imports.

Apply this diff to remove the unused import:

- from skimage.feature import local_binary_pattern


187-215: Optimize forward method by vectorizing computations

The forward method processes images in a loop, which can be inefficient for large batches. Consider vectorizing the computations to leverage PyTorch's batch processing capabilities and improve performance.
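One possible direction is to re-derive the per-crop statistics with batched torch ops instead of looping over skimage calls; a minimal sketch, assuming single-channel float crops of shape (B, H, W) and illustrative moment formulas:

```python
import torch

def batched_descriptors(crops: torch.Tensor) -> torch.Tensor:
    """Inertia-tensor entries plus mean intensity for a batch of crops."""
    B, H, W = crops.shape
    mass = crops.sum(dim=(1, 2)).clamp_min(1e-8)
    ys = torch.arange(H, device=crops.device, dtype=crops.dtype)
    xs = torch.arange(W, device=crops.device, dtype=crops.dtype)
    cy = (crops.sum(dim=2) * ys).sum(dim=1) / mass  # intensity-weighted centroid
    cx = (crops.sum(dim=1) * xs).sum(dim=1) / mass
    dy = ys.view(1, H, 1) - cy.view(B, 1, 1)
    dx = xs.view(1, 1, W) - cx.view(B, 1, 1)
    # second central moments, normalized by total intensity
    iyy = (crops * dy**2).sum(dim=(1, 2)) / mass
    ixx = (crops * dx**2).sum(dim=(1, 2)) / mass
    ixy = (crops * dx * dy).sum(dim=(1, 2)) / mass
    mean = crops.mean(dim=(1, 2))
    # (iyy, -ixy, -ixy, ixx) mirrors a flattened 2x2 inertia tensor, plus mean
    return torch.stack([iyy, -ixy, -ixy, ixx, mean], dim=-1)  # (B, 5)
```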

dreem/models/transformer.py (4)

45-45: Ensure encoder_cfg parameter is documented and type-hinted correctly

The new parameter encoder_cfg is added to the Transformer class. Please ensure that it's fully documented in the method's docstring, including its purpose and expected structure.


97-100: Make n_components in FourierPositionalEmbeddings configurable

Currently, n_components is hard-coded to 8. Consider making n_components configurable through encoder_cfg or constructor parameters to provide flexibility for different model configurations.


203-214: Refactor condition to reduce code duplication

The condition to apply Fourier embeddings is repeated in both the encoder and decoder sections. Refactor this condition into a helper method to improve code maintainability and readability.

Apply this diff to create a helper method:

+    def should_apply_fourier_embeddings(self) -> bool:
+        return (
+            "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]
+        ) or (
+            self.encoder_cfg is not None
+            and "encoder_type" in self.encoder_cfg
+            and self.encoder_cfg["encoder_type"] == "descriptor"
+        )

...

             # Apply Fourier embeddings if conditions are met
-            if (
-                "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]
-            ) or (
-                self.encoder_cfg is not None
-                and "encoder_type" in self.encoder_cfg
-                and self.encoder_cfg["encoder_type"] == "descriptor"
-            ):
+            if self.should_apply_fourier_embeddings():
                 encoder_queries = self.apply_fourier_embeddings(
                     encoder_queries, ref_times
                 )

256-267: Refactor condition to reduce code duplication in decoder

Similar to the encoder section, the condition to apply Fourier embeddings is repeated. Use the should_apply_fourier_embeddings helper method to improve code maintainability.

Apply this diff:

             # Apply Fourier embeddings if conditions are met
-            if (
-                "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]
-            ) or (
-                self.encoder_cfg is not None
-                and "encoder_type" in self.encoder_cfg
-                and self.encoder_cfg["encoder_type"] == "descriptor"
-            ):
+            if self.should_apply_fourier_embeddings():
                 query_features = self.apply_fourier_embeddings(
                     query_features, query_times
                 )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 9cef2a2 and dc5a6f0.

📒 Files selected for processing (8)
  • dreem/datasets/sleap_dataset.py (4 hunks)
  • dreem/io/config.py (1 hunks)
  • dreem/models/__init__.py (1 hunks)
  • dreem/models/embedding.py (1 hunks)
  • dreem/models/global_tracking_transformer.py (4 hunks)
  • dreem/models/transformer.py (8 hunks)
  • dreem/models/visual_encoder.py (2 hunks)
  • dreem/training/configs/base.yaml (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.0)
dreem/models/__init__.py

3-3: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


3-3: .embedding.FourierPositionalEmbeddings imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


8-8: .transformer.Transformer imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


10-10: .visual_encoder.VisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


11-11: .visual_encoder.DescriptorVisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


12-12: .visual_encoder.create_visual_encoder imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


13-13: .visual_encoder.register_encoder imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

dreem/models/visual_encoder.py

9-9: skimage.feature.local_binary_pattern imported but unused

Remove unused import: skimage.feature.local_binary_pattern

(F401)

🪛 yamllint (1.35.1)
dreem/training/configs/base.yaml

[error] 8-8: trailing spaces

(trailing-spaces)

🔇 Additional comments (12)
dreem/models/__init__.py (2)

3-3: LGTM: Appropriate addition of FourierPositionalEmbeddings to public API

The addition of FourierPositionalEmbeddings to the public API aligns with the PR objective of introducing learnable Fourier embeddings for positional encoding.



9-14: LGTM: Well-structured expansion of visual encoder capabilities

The additions to the public API are well-organized and comprehensively support the new visual encoder framework:

  • Base VisualEncoder class
  • New DescriptorVisualEncoder implementation
  • Factory function create_visual_encoder
  • Registration system via register_encoder

The multi-line import format enhances readability.


dreem/models/global_tracking_transformer.py (3)

4-4: LGTM: Good use of factory pattern for encoder creation

The switch to using a factory function aligns well with the PR's objective of supporting multiple visual encoder types and improving configuration flexibility.


59-59: Verify error handling for invalid encoder configurations

The simplified encoder creation looks good, but we should ensure proper error handling for invalid configurations.

✅ Verification successful

Error handling for invalid encoder configurations is properly implemented

The implementation includes comprehensive error handling (a registration-time sketch follows the list below):

  • Validates encoder class type during registration
  • Handles unknown encoder types with descriptive error messages
  • Validates input channel mismatches in the VisualEncoder
  • Provides clear error messages with available encoder types
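A minimal sketch of registration-time validation consistent with these findings (the function body is illustrative; only the names mirror the quoted code):

```python
import torch

ENCODER_REGISTRY: dict[str, type] = {}

def register_encoder(encoder_type: str, encoder_class: type) -> None:
    """Register an encoder class, failing fast on invalid types."""
    if not (isinstance(encoder_class, type) and issubclass(encoder_class, torch.nn.Module)):
        raise ValueError(f"{encoder_class!r} must be a subclass of torch.nn.Module")
    ENCODER_REGISTRY[encoder_type] = encoder_class
```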
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for error handling in the create_visual_encoder implementation
ast-grep --pattern 'def create_visual_encoder($$$) {
  $$$
  raise $$$
  $$$
}'

# Check for any existing error handling tests
rg -l "test.*create_visual_encoder" tests/

Length of output: 126


Script:

#!/bin/bash
# Let's try a different approach to find the create_visual_encoder implementation
rg -A 10 "def create_visual_encoder" 

# Also search for visual encoder related error handling
rg "raise \w+Error.*encoder"

# Look for visual encoder related tests with a broader pattern
rg -l "visual.*encoder.*test|test.*visual.*encoder" tests/

Length of output: 1107


Script:

#!/bin/bash
# Let's check the ENCODER_REGISTRY implementation and usage
rg -B 2 -A 2 "ENCODER_REGISTRY"

# Check the register_encoder implementation
rg -B 2 -A 5 "def register_encoder"

# Look for any additional error handling in visual_encoder.py
cat dreem/models/visual_encoder.py

Length of output: 10970


126-126: Verify device handling for edge cases

The explicit device management is good for multi-GPU compatibility. However, we should verify handling of edge cases.

✅ Verification successful

Let me check the device-related tests and other device management code to better understand the edge cases.


Let me check the context around the device management in the global tracking transformer to better understand how the features and instances are handled.


Device handling is properly implemented

The device management is correctly implemented with proper safeguards (outlined in the sketch after this list):

  • Features are moved to the same device as instances after visual encoding
  • The code handles empty and single-instance cases
  • Device consistency is maintained throughout the processing pipeline
  • The commented-out test fixture in tests/fixtures/torch.py suggests device-related testing was considered but separated to avoid interference
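In outline, the pattern being verified looks like this (a sketch with assumed names, not the module's actual code):

```python
import torch

def encode_on_instance_device(
    encoder: torch.nn.Module, crops: torch.Tensor, instance_device: torch.device
) -> torch.Tensor:
    """Run the visual encoder, then move features to where the instances live."""
    features = encoder(crops)  # may execute on a different GPU/CPU
    return features.to(device=instance_device)
```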
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other device management patterns in the codebase
rg "\.to\(device=" 

# Look for any existing device-related tests
rg -l "test.*device" tests/

Length of output: 289


Script:

#!/bin/bash
# Check device-related test content
cat tests/test_inference.py
cat tests/fixtures/torch.py

# Check the context of device management in boxes.py
rg -B 3 -A 3 "\.to\(device=" dreem/inference/boxes.py

Length of output: 9439


Script:

#!/bin/bash
# Check the context around device management in global_tracking_transformer.py
rg -B 10 -A 10 "\.to\(device=" dreem/models/global_tracking_transformer.py

# Look for instance device handling
ast-grep --pattern 'instances_to_compute'

Length of output: 1497

dreem/datasets/sleap_dataset.py (2)

38-38: LGTM: Well-structured parameter addition

The new normalize_image parameter is properly integrated with clear documentation and appropriate default value for backward compatibility.

Also applies to: 75-75, 103-103


224-226: Verify the impact of normalization changes on model predictions

Since this change affects image preprocessing, it's important to verify:

  1. The usage of normalize_image=False doesn't negatively impact model performance
  2. The handling of low-intensity regions is improved as intended

Let's check how this parameter is used across the codebase:

✅ Verification successful

Image normalization change is safe and intentionally handled

The verification reveals that the normalization change is well-controlled:

  1. The normalize_image parameter is explicitly set to False in config.py when using the descriptor encoder type, showing intentional handling of different preprocessing requirements.
  2. The normalization is properly documented in the SleapDataset class with a clear parameter description.
  3. The change only affects integer-type images, which are first converted to float32 before optional normalization.
  4. No other parts of the codebase are impacted by this change as the normalization is encapsulated within the dataset class.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check usage of normalize_image parameter and related image processing

# Check for normalize_image parameter usage
rg -l "normalize_image" 

# Look for image preprocessing patterns
ast-grep --pattern 'img = img / 255'
ast-grep --pattern 'img.astype(np.float32)'

# Check for potential configuration files
fd -e yaml -e json . | xargs rg -l "normalize_image"

Length of output: 365


Script:

#!/bin/bash
# Let's check the context of normalize_image usage in config.py and sleap_dataset.py
rg "normalize_image" -B 3 -A 3 dreem/io/config.py dreem/datasets/sleap_dataset.py

# Check if there are any tests related to image normalization
fd -e py . | xargs rg -l "test.*normalize.*image"

# Look for any model configuration files that might reference image normalization
fd "config" -e yaml -e json . | xargs cat

Length of output: 2500

dreem/models/embedding.py (1)

366-416: Verify integration with transformer architecture

The implementation looks solid, but let's verify its integration with the transformer architecture as mentioned in the PR objectives.

✅ Verification successful

Integration with transformer architecture verified successfully

The implementation is correctly integrated with the transformer architecture as intended:

  • Found in apply_fourier_embeddings function:
    • Fourier embeddings are concatenated with input queries: torch.cat([queries, embs], dim=-1)
    • The concatenated features are properly projected to the model dimension using a linear layer
    • LayerNorm is applied after projection, ensuring proper normalization
  • The transformer properly initializes FourierPositionalEmbeddings with correct parameters in its constructor
  • The integration follows the PR objectives of concatenating with feature vectors and projecting to the model dimension
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the integration of FourierPositionalEmbeddings with transformer architecture

# Check for transformer usage of FourierPositionalEmbeddings
echo "Checking transformer integration..."
rg -l "FourierPositionalEmbeddings" --type py

# Check for visual feature concatenation
echo "Checking feature concatenation..."
ast-grep --pattern 'torch.cat([$$$ FourierPositionalEmbeddings $$$])'

Length of output: 354


Script:

#!/bin/bash
# Let's check how FourierPositionalEmbeddings is used in transformer.py
echo "Checking transformer.py implementation..."
rg -A 10 -B 10 "FourierPositionalEmbeddings" dreem/models/transformer.py

# Let's also check for any concatenation operations in transformer.py
echo "Checking concatenation operations in transformer.py..."
rg "torch.cat" dreem/models/transformer.py

Length of output: 2538

dreem/io/config.py (1)

271-275: Verify the impact of disabling image normalization

Let's verify the integration points of this change:

  1. Usage of normalize_image parameter in SleapDataset
  2. References to descriptor encoder type
  3. Test coverage
✅ Verification successful

Let me gather more information about the descriptor encoder and test coverage.


Let me check one more aspect to understand the configuration and test coverage better.


Based on the gathered information, I can now provide a final response:

Disabling image normalization for descriptor encoder is correctly integrated

The change is well-integrated within the codebase for the following reasons:

  • The normalize_image parameter in SleapDataset is properly handled, with a default value of True and explicit normalization logic
  • The descriptor encoder implementation (DescriptorVisualEncoder) is designed to work with raw image values as it computes Hu moments
  • The configuration system correctly supports the descriptor encoder type through the encoder_type setting
  • Test coverage exists for both the configuration system and the SleapDataset, including image normalization
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify integration points of the normalization change

# Check for normalize_image parameter usage in SleapDataset
echo "Checking normalize_image usage in SleapDataset:"
rg -A 5 "normalize_image" "dreem/datasets/sleap_dataset.py"

# Check for descriptor encoder type references
echo -e "\nChecking descriptor encoder type references:"
rg -l "descriptor" "dreem/"

# Check for relevant test files
echo -e "\nChecking for related test files:"
fd -g "*test*.py" "tests/"

Length of output: 1424


Script:

#!/bin/bash
# Check descriptor encoder implementation and related tests
echo "Checking descriptor encoder implementation:"
rg -A 10 "class.*Descriptor" dreem/models/visual_encoder.py

echo -e "\nChecking normalize_image in tests:"
rg "normalize_image" "tests/"

echo -e "\nChecking descriptor encoder tests:"
rg "descriptor" "tests/test_models.py"

Length of output: 792


Script:

#!/bin/bash
# Check configuration and test coverage
echo "Checking base configuration for descriptor encoder:"
cat dreem/training/configs/base.yaml

echo -e "\nChecking test_config.py for descriptor tests:"
cat tests/test_config.py

echo -e "\nChecking test_datasets.py for SleapDataset tests:"
cat tests/test_datasets.py

Length of output: 21977

dreem/models/visual_encoder.py (1)

194-194: Verify that images are single-channel or handle multiple channels appropriately

The code uses im = im[0].cpu().numpy(), which extracts the first channel of the image tensor. This assumes that the images are single-channel or that only the first channel is relevant.

Please verify whether all images are single-channel. If images can have multiple channels, consider adjusting the code to handle all channels or appropriately convert images to grayscale before processing.
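If multi-channel crops are possible, one simple option is to collapse channels before computing descriptors rather than silently dropping all but the first; a hedged sketch (the averaging policy is an assumption, not the PR's behavior):

```python
import torch

def to_single_channel(im: torch.Tensor) -> torch.Tensor:
    """Reduce a (C, H, W) crop to (H, W) by averaging channels."""
    if im.ndim == 3 and im.shape[0] > 1:
        return im.float().mean(dim=0)
    return im.squeeze(0) if im.ndim == 3 else im
```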

dreem/models/transformer.py (2)

16-16: Approved: Importing necessary modules

The addition of FourierPositionalEmbeddings to the imports is appropriate for the new functionality.


78-78: Approved: Assigning encoder_cfg to an instance variable

Properly storing encoder_cfg for later use in the class.

Comment on lines +224 to +226
img = img.astype(np.float32)
if self.normalize_image:
    img = img / 255

🛠️ Refactor suggestion

Consider enhancing image type handling and normalization logic

The current implementation has several areas for improvement:

  1. The unconditional conversion to float32 could impact memory usage unnecessarily when normalization is disabled
  2. The normalization assumes 8-bit images (division by 255) which may not always be the case
  3. There's no handling for images that are already in float format

Consider this improved implementation:

-            if np.issubdtype(img.dtype, np.integer):  # convert int to float
-                img = img.astype(np.float32)
-                if self.normalize_image:
-                    img = img / 255
+            if self.normalize_image:
+                if np.issubdtype(img.dtype, np.integer):
+                    max_val = np.iinfo(img.dtype).max
+                    img = img.astype(np.float32) / max_val
+                elif np.issubdtype(img.dtype, np.floating):
+                    if img.max() > 1.0:
+                        img = img / img.max()
+            elif np.issubdtype(img.dtype, np.integer):
+                img = img.astype(np.float32)

This implementation:

  • Only converts to float32 when necessary
  • Handles different bit depths correctly
  • Properly processes images that are already in float format

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +406 to +409
if self.d_model % self.n_components != 0:
    raise ValueError(
        f"d_model ({self.d_model}) must be divisible by number of Fourier components n_components ({self.n_components})"
    )

🛠️ Refactor suggestion

Move dimension validation to initialization

The dimension check should be performed during initialization rather than in the forward pass to fail fast and avoid unnecessary computations.

This validation block should be moved to __init__ as shown in the earlier suggestion.

Comment on lines 271 to 275
if (
    "encoder_type" in self.cfg["model"]["encoder_cfg"]
    and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
):
    dataset_params["normalize_image"] = False

🛠️ Refactor suggestion

Add validation and improve robustness of configuration access

The current implementation has several areas for improvement:

  1. Deep nesting of dictionary access could lead to KeyError exceptions
  2. The rationale for disabling normalization with descriptor encoder should be documented
  3. The string literal "descriptor" should be extracted to a constant

Consider applying this refactor:

+ # Constants
+ DESCRIPTOR_ENCODER_TYPE = "descriptor"
+
+ def _get_encoder_type(self) -> str | None:
+     """Get the encoder type from model config safely."""
+     try:
+         return self.cfg["model"]["encoder_cfg"]["encoder_type"]
+     except (KeyError, TypeError):
+         return None
+
- if (
-     "encoder_type" in self.cfg["model"]["encoder_cfg"]
-     and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
- ):
+ if (encoder_type := self._get_encoder_type()) == DESCRIPTOR_ENCODER_TYPE:
+     # Disable normalization for descriptor encoder to prevent empty crops
+     # in low-intensity regions that could hinder feature calculation
      dataset_params["normalize_image"] = False
📝 Committable suggestion


Suggested change

Before:

    if (
        "encoder_type" in self.cfg["model"]["encoder_cfg"]
        and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
    ):
        dataset_params["normalize_image"] = False

After:

    # Constants
    DESCRIPTOR_ENCODER_TYPE = "descriptor"

    def _get_encoder_type(self) -> str | None:
        """Get the encoder type from model config safely."""
        try:
            return self.cfg["model"]["encoder_cfg"]["encoder_type"]
        except (KeyError, TypeError):
            return None

    if (encoder_type := self._get_encoder_type()) == DESCRIPTOR_ENCODER_TYPE:
        # Disable normalization for descriptor encoder to prevent empty crops
        # in low-intensity regions that could hinder feature calculation
        dataset_params["normalize_image"] = False

Comment on lines 225 to 246
def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
    """Create a visual encoder based on the specified type."""
    register_encoder("resnet", VisualEncoder)
    register_encoder("descriptor", DescriptorVisualEncoder)
    # register any custom encoders here

    # compatibility with configs that don't specify encoder_type; default to resnet
    if "encoder_type" not in encoder_cfg:
        encoder_type = "resnet"
        return ENCODER_REGISTRY[encoder_type](d_model=d_model, **encoder_cfg)
    else:
        encoder_type = encoder_cfg.pop("encoder_type")

    if encoder_type in ENCODER_REGISTRY:
        # choose the relevant encoder configs based on the encoder_type
        configs = encoder_cfg[encoder_type]
        return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
    else:
        raise ValueError(
            f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
        )

⚠️ Potential issue

Fix potential KeyError in create_visual_encoder

In the create_visual_encoder function, accessing configs = encoder_cfg[encoder_type] may raise a KeyError if encoder_type is not a key in encoder_cfg. This can happen when encoder_cfg does not contain configurations specific to the encoder_type.

Apply this diff to ensure configs is correctly retrieved:

 def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
     """Create a visual encoder based on the specified type."""
 
     register_encoder("resnet", VisualEncoder)
     register_encoder("descriptor", DescriptorVisualEncoder)
     # register any custom encoders here
 
     # Compatibility with configs that don't specify encoder_type; default to resnet
     if "encoder_type" not in encoder_cfg:
         encoder_type = "resnet"
+        configs = encoder_cfg
     else:
         encoder_type = encoder_cfg.pop("encoder_type")
 
-    if encoder_type in ENCODER_REGISTRY:
-        # Choose the relevant encoder configs based on the encoder_type
-        configs = encoder_cfg[encoder_type]
+    if encoder_type in ENCODER_REGISTRY:
+        # Use the remaining encoder_cfg as configs
+        configs = encoder_cfg
         return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
     else:
         raise ValueError(
             f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
         )

This adjustment ensures that configs are appropriately assigned, preventing a KeyError when configurations are not nested under the encoder_type key.

📝 Committable suggestion


Suggested change

Before:

    def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
        """Create a visual encoder based on the specified type."""
        register_encoder("resnet", VisualEncoder)
        register_encoder("descriptor", DescriptorVisualEncoder)
        # register any custom encoders here

        # compatibility with configs that don't specify encoder_type; default to resnet
        if "encoder_type" not in encoder_cfg:
            encoder_type = "resnet"
            return ENCODER_REGISTRY[encoder_type](d_model=d_model, **encoder_cfg)
        else:
            encoder_type = encoder_cfg.pop("encoder_type")

        if encoder_type in ENCODER_REGISTRY:
            # choose the relevant encoder configs based on the encoder_type
            configs = encoder_cfg[encoder_type]
            return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
        else:
            raise ValueError(
                f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
            )

After:

    def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
        """Create a visual encoder based on the specified type."""
        register_encoder("resnet", VisualEncoder)
        register_encoder("descriptor", DescriptorVisualEncoder)
        # register any custom encoders here

        # Compatibility with configs that don't specify encoder_type; default to resnet
        if "encoder_type" not in encoder_cfg:
            encoder_type = "resnet"
            configs = encoder_cfg
        else:
            encoder_type = encoder_cfg.pop("encoder_type")

        if encoder_type in ENCODER_REGISTRY:
            # Use the remaining encoder_cfg as configs
            configs = encoder_cfg
            return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
        else:
            raise ValueError(
                f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
            )

@@ -228,12 +246,25 @@ def forward(

        else:
            query_instances = ref_instances
            query_times = ref_times

⚠️ Potential issue

Ensure query_instances is assigned when None

When query_instances is None, only query_times is assigned. This may lead to query_instances being None later, causing errors. Assign query_instances = ref_instances to prevent potential issues.

Apply this diff to fix the issue:

             else:
+                query_instances = ref_instances
                 query_times = ref_times

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines 295 to 311
def apply_fourier_embeddings(
    queries: torch.Tensor,
    times: torch.Tensor,
    d_model: int,
    fourier_embeddings: FourierPositionalEmbeddings,
) -> torch.Tensor:

    embs = fourier_embeddings(times).permute(1, 0, 2)
    cat_queries = torch.cat([queries, embs], dim=-1)
    # project to d_model
    proj = nn.Linear(cat_queries.shape[-1], d_model).to(queries.device)
    norm = nn.LayerNorm(d_model).to(queries.device)
    cat_queries = proj(cat_queries)
    cat_queries = norm(cat_queries)

    return cat_queries


🛠️ Refactor suggestion

⚠️ Potential issue

Avoid defining neural network layers inside a function

Defining nn.Linear and nn.LayerNorm inside apply_fourier_embeddings leads to new instances being created on each call. This prevents the layers from learning during training, as their parameters are not registered with the model.

Move the layer definitions to the Transformer class's __init__ method and store them as instance variables.

Apply this diff:

 class Transformer(torch.nn.Module):
     def __init__(self, ...):
         ...
+        # Initialize Fourier embeddings
+        self.fourier_embeddings = FourierPositionalEmbeddings(
+            n_components=8, d_model=d_model
+        )
+
+        # Define projection and normalization layers for Fourier embeddings
+        fourier_dim = self.fourier_embeddings.d_model
+        self.fourier_proj = nn.Linear(d_model + fourier_dim, d_model)
+        self.fourier_norm = nn.LayerNorm(d_model)
         ...

+    def apply_fourier_embeddings(
+        self,
+        queries: torch.Tensor,
+        times: torch.Tensor,
+    ) -> torch.Tensor:
+        embs = self.fourier_embeddings(times).permute(1, 0, 2)
+        cat_queries = torch.cat([queries, embs], dim=-1)
+        cat_queries = self.fourier_proj(cat_queries)
+        cat_queries = self.fourier_norm(cat_queries)
+        return cat_queries

...

-    def apply_fourier_embeddings(
-        queries: torch.Tensor,
-        times: torch.Tensor,
-        d_model: int,
-        fourier_embeddings: FourierPositionalEmbeddings,
-    ) -> torch.Tensor:
-
-        embs = fourier_embeddings(times).permute(1, 0, 2)
-        cat_queries = torch.cat([queries, embs], dim=-1)
-        # Project to d_model
-        proj = nn.Linear(cat_queries.shape[-1], d_model).to(queries.device)
-        norm = nn.LayerNorm(d_model).to(queries.device)
-        cat_queries = proj(cat_queries)
-        cat_queries = norm(cat_queries)
-        return cat_queries

Update calls to apply_fourier_embeddings:

             if self.should_apply_fourier_embeddings():
-                encoder_queries = apply_fourier_embeddings(
-                    encoder_queries, ref_times, self.d_model, self.fourier_embeddings
-                )
+                encoder_queries = self.apply_fourier_embeddings(
+                    encoder_queries, ref_times
                 )
📝 Committable suggestion


Suggested change

Before:

    def apply_fourier_embeddings(
        queries: torch.Tensor,
        times: torch.Tensor,
        d_model: int,
        fourier_embeddings: FourierPositionalEmbeddings,
    ) -> torch.Tensor:
        embs = fourier_embeddings(times).permute(1, 0, 2)
        cat_queries = torch.cat([queries, embs], dim=-1)
        # project to d_model
        proj = nn.Linear(cat_queries.shape[-1], d_model).to(queries.device)
        norm = nn.LayerNorm(d_model).to(queries.device)
        cat_queries = proj(cat_queries)
        cat_queries = norm(cat_queries)
        return cat_queries

After:

    def apply_fourier_embeddings(
        self,
        queries: torch.Tensor,
        times: torch.Tensor,
    ) -> torch.Tensor:
        embs = self.fourier_embeddings(times).permute(1, 0, 2)
        cat_queries = torch.cat([queries, embs], dim=-1)
        cat_queries = self.fourier_proj(cat_queries)
        cat_queries = self.fourier_norm(cat_queries)
        return cat_queries

update sample configs for train and for unit tests
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (11)
tests/configs/base.yaml (3)

3-3: Remove trailing whitespace.

There is a trailing space after encoder_cfg:.

-  encoder_cfg: 
+  encoder_cfg:


6-6: Enhance documentation for ndim parameter.

The current comment could be more descriptive about the relationship between ndim and encoder types.

-    ndim: 5 # dim of the visual feature vector; 5 for descriptor, 12 for descriptor + hu moments
+    ndim: 5 # Dimension of the visual feature vector:
+           # - 5 for basic descriptor (inertia tensor + mean intensity)
+           # - 12 when using descriptor with Hu moments enabled

12-13: Add documentation for descriptor configuration.

The descriptor configuration would benefit from comments explaining:

  • The impact of enabling Hu moments on model performance
  • When to use Hu moments vs. basic descriptors
     descriptor:
+      # Enable Hu moments for rotation-invariant shape descriptors
+      # Recommended for tracking objects with significant rotation
       use_hu_moments: false
dreem/training/configs/base.yaml (3)

4-7: Consider moving legacy configuration documentation.

While keeping the old configuration structure as comments can help during migration, it's better to document these changes in proper migration guides or documentation to avoid confusion and reduce maintenance overhead.



10-12: Enhance configuration documentation and validation.

The encoder configuration would benefit from:

  1. Documentation of all supported encoder_type values
  2. Clarification if both resnet and descriptor configs are required
  3. More flexible documentation for ndim that won't need updates if new descriptors are added

Consider adding a comment like this:

encoder_cfg:
  # Supported encoder_type values: "resnet", "descriptor"
  # Only the configuration matching the selected encoder_type is required
  encoder_type: "descriptor"
  # ndim: dimension of the visual feature vector
  # - descriptor without hu_moments: 5
  # - descriptor with hu_moments: 12
  ndim: 5

13-19: Consider enhancing the descriptor encoder configuration.

The descriptor configuration could benefit from:

  1. Validation constraints for use_hu_moments
  2. Additional customization parameters if applicable
  3. Clear documentation of the relationship between use_hu_moments and ndim

Consider expanding the descriptor configuration like this:

descriptor:
  use_hu_moments: false  # If true, ndim should be 12; if false, ndim should be 5
  # Add any additional parameters for customization
  # For example:
  # normalize_features: true
  # feature_weights: [1.0, 1.0, 1.0, 1.0, 1.0]
dreem/models/visual_encoder.py (2)

9-9: Remove unused import

The local_binary_pattern from skimage.feature is imported but never used in the code.

-from skimage.feature import local_binary_pattern


167-221: Consider documenting performance characteristics

The DescriptorVisualEncoder provides a lightweight alternative to CNN-based encoding, which could be beneficial for resource-constrained environments. Consider adding documentation about:

  1. Memory usage comparison with ResNet
  2. Inference speed benchmarks
  3. Use cases where descriptor-based encoding might be preferred
  4. Limitations of the descriptor-based approach

This will help users make informed decisions about which encoder to use.

dreem/models/transformer.py (3)

98-100: Consider making Fourier embedding parameters configurable

The Fourier embeddings are initialized with hardcoded parameters. Consider making these configurable through encoder_cfg for better flexibility.

         self.fourier_embeddings = FourierPositionalEmbeddings(
-            n_components=8, d_model=d_model
+            n_components=self.encoder_cfg.get("fourier_components", 8),
+            d_model=d_model
         )

110-115: Simplify configuration access using dict.get()

The nested configuration checks can be simplified using dict.get() method.

-        if self.encoder_cfg and "encoder_type" in self.encoder_cfg:
-            self.visual_feat_dim = (
-                self.encoder_cfg["ndim"] if "ndim" in self.encoder_cfg else 5
-            )  # 5 is default for descriptor
+        if self.encoder_cfg and self.encoder_cfg.get("encoder_type"):
+            self.visual_feat_dim = self.encoder_cfg.get("ndim", 5)  # 5 is default for descriptor
             self.fourier_proj = nn.Linear(self.d_model + self.visual_feat_dim, d_model)
             self.fourier_norm = nn.LayerNorm(self.d_model)


318-346: Consider moving apply_fourier_embeddings into the Transformer class

This function uses class members and is only used within the Transformer class. Consider making it a method of the Transformer class to improve encapsulation.

-def apply_fourier_embeddings(
-    queries: torch.Tensor,
-    times: torch.Tensor,
-    d_model: int,
-    fourier_embeddings: FourierPositionalEmbeddings,
-    proj: nn.Linear,
-    norm: nn.LayerNorm,
-) -> torch.Tensor:
+def _apply_fourier_embeddings(
+    self,
+    queries: torch.Tensor,
+    times: torch.Tensor,
+) -> torch.Tensor:
     """Apply fourier embeddings to queries.
     
     Args:
         queries: The input tensor of shape (n_query, batch_size, embed_dim).
         times: The times index tensor of shape (n_query,).
-        d_model: Model dimension.
-        fourier_embeddings: The Fourier positional embeddings object.
-        proj: Linear projection layer that projects concantenated feature vector to model dimension.
-        norm: The normalization layer.
     
     Returns:
         The output queries of shape (n_query, batch_size, embed_dim).
     """
-    embs = fourier_embeddings(times).permute(1, 0, 2)
+    embs = self.fourier_embeddings(times).permute(1, 0, 2)
     cat_queries = torch.cat([queries, embs], dim=-1)
     # project to d_model
-    cat_queries = proj(cat_queries)
-    cat_queries = norm(cat_queries)
+    cat_queries = self.fourier_proj(cat_queries)
+    cat_queries = self.fourier_norm(cat_queries)
 
     return cat_queries
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between dc5a6f0 and f37c4c4.

📒 Files selected for processing (7)
  • dreem/io/config.py (1 hunks)
  • dreem/models/embedding.py (1 hunks)
  • dreem/models/global_tracking_transformer.py (4 hunks)
  • dreem/models/transformer.py (9 hunks)
  • dreem/models/visual_encoder.py (2 hunks)
  • dreem/training/configs/base.yaml (1 hunks)
  • tests/configs/base.yaml (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • dreem/models/embedding.py
  • dreem/models/global_tracking_transformer.py
🧰 Additional context used
🪛 Ruff (0.8.0)
dreem/io/config.py

272-276: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

dreem/models/transformer.py

112-112: Use self.encoder_cfg.get("ndim", 5) instead of an if block

Replace with self.encoder_cfg.get("ndim", 5)

(SIM401)

dreem/models/visual_encoder.py

9-9: skimage.feature.local_binary_pattern imported but unused

Remove unused import: skimage.feature.local_binary_pattern

(F401)

🪛 yamllint (1.35.1)
dreem/training/configs/base.yaml

[error] 4-4: trailing spaces

(trailing-spaces)


[error] 9-9: trailing spaces

(trailing-spaces)

tests/configs/base.yaml

[error] 3-3: trailing spaces

(trailing-spaces)

🔇 Additional comments (5)
tests/configs/base.yaml (1)

7-11: Consider enabling pretrained weights for better feature extraction.

The ResNet encoder is configured with pretrained: false. Consider enabling pretrained weights as they often provide better feature extraction capabilities, especially for tasks with limited training data.

dreem/training/configs/base.yaml (1)

10-19: Verify configuration handling in code.

Please ensure that:

  1. The code properly validates the encoder_type selection
  2. The relationship between use_hu_moments and ndim is enforced
  3. Configuration errors provide clear error messages
✅ Verification successful

Let me gather more information about the descriptor encoder and its validation.


Let me search for the descriptor encoder implementation and validation logic in a different way.


Configuration validation is properly implemented

The codebase shows proper handling of the configuration concerns:

  1. The encoder_type validation is implemented in create_visual_encoder() with a registry system and clear error messages for unknown types.
  2. The relationship between use_hu_moments and descriptor dimensionality is handled in the DescriptorVisualEncoder class, which correctly combines the features based on the configuration.
  3. The code includes proper error handling with descriptive messages for invalid encoder types.

Additional observations:

  • Default fallback to "resnet" encoder when type is not specified
  • Clear separation of encoder-specific configurations
  • Type validation through the registry system
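To make the verified behavior concrete, here is a minimal sketch of such a registry under the assumptions above; every name and signature is illustrative, not the exact dreem API:

from torch import nn

ENCODER_REGISTRY: dict[str, type] = {}

def register_encoder(name: str):
    """Class decorator so user-defined encoders can self-register."""
    def wrap(cls):
        ENCODER_REGISTRY[name] = cls
        return cls
    return wrap

@register_encoder("descriptor")
class ToyDescriptorEncoder(nn.Module):
    def __init__(self, d_model: int, use_hu_moments: bool = False):
        super().__init__()
        self.d_model = d_model
        self.use_hu_moments = use_hu_moments

def create_visual_encoder(d_model: int, **encoder_cfg) -> nn.Module:
    # default to "resnet" when no type is specified, matching the fallback above
    encoder_type = encoder_cfg.pop("encoder_type", "resnet")
    if encoder_type not in ENCODER_REGISTRY:
        raise ValueError(
            f"Unknown encoder type '{encoder_type}'; "
            f"registered: {sorted(ENCODER_REGISTRY)}"
        )
    # encoder-specific settings live under a key named after the type
    type_cfg = encoder_cfg.get(encoder_type, {})
    return ENCODER_REGISTRY[encoder_type](d_model=d_model, **type_cfg)

encoder = create_visual_encoder(
    d_model=128, encoder_type="descriptor", descriptor={"use_hu_moments": False}
)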
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for configuration validation code
rg -A 5 "encoder_type.*=|encoder_cfg.*=" 

# Search for error handling related to configuration
rg -A 5 "raise.*Error.*encoder"

Length of output: 6104


Script:

#!/bin/bash
# Search for descriptor encoder implementation and validation
ast-grep --pattern 'class DescriptorEncoder {
  $$$
}'

# Search for configuration validation in descriptor-related code
rg -A 10 "descriptor.*config|descriptor.*cfg"

# Search for ndim validation
rg -A 5 "ndim.*=|ndim.*\+" 

Length of output: 882


Script:

#!/bin/bash
# Search for descriptor-related code
rg -A 10 "class.*Descriptor" 

# Search for validation of Hu moments
rg -A 5 "hu_moments|use_hu"

# Search for any validation in the model creation
rg -A 10 "def create_.*model|def build_.*model"

Length of output: 4150

dreem/models/visual_encoder.py (1)

230-250: Fix potential KeyError in create_visual_encoder

The issue mentioned in the past review about potential KeyError when accessing encoder_cfg[encoder_type] is still present. Please refer to the previous review comment for the suggested fix.

dreem/io/config.py (1)

271-277: 🛠️ Refactor suggestion

Improve robustness and documentation of encoder configuration

The current implementation has several areas for improvement:

  1. Deep nesting of dictionary access could lead to KeyError exceptions
  2. The rationale for disabling normalization with descriptor encoder should be documented
  3. The string literal "descriptor" should be extracted to a constant
  4. The nested if statements can be combined

Apply this refactor:

+ # Constants
+ DESCRIPTOR_ENCODER_TYPE = "descriptor"
+
+ def _get_encoder_type(self) -> str | None:
+     """Get the encoder type from model config safely."""
+     try:
+         return self.cfg["model"]["encoder_cfg"]["encoder_type"]
+     except (KeyError, TypeError):
+         return None
+
- # TODO: handle this better!
- if "model" in self.cfg:
-     if (
-         "encoder_type" in self.cfg["model"]["encoder_cfg"]
-         and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
-     ):
-         dataset_params["normalize_image"] = False
+ # Disable normalization for descriptor encoder to prevent empty crops
+ # in low-intensity regions that could hinder feature calculation
+ if (encoder_type := self._get_encoder_type()) == DESCRIPTOR_ENCODER_TYPE:
+     dataset_params["normalize_image"] = False
🧰 Tools
🪛 Ruff (0.8.0)

272-276: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

dreem/models/transformer.py (1)

265-265: Ensure query_instances is assigned when None

When query_instances is None, only query_times is assigned. This may lead to query_instances being None later, causing errors.
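A hedged sketch of the guard this comment is asking for; the fallback target is an assumption, not the confirmed fix:

# sketch only — falls back to the reference instances when no queries are given
if query_instances is None:
    query_instances = ref_instances  # assumption: self-attend over the reference window
    query_times = ref_times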

Comment on lines +3 to +13
encoder_cfg:
  # based on encoder_type, the relevant settings are selected
  encoder_type: "resnet"
  ndim: 5 # dim of the visual feature vector; 5 for descriptor, 12 for descriptor + hu moments
  resnet:
    model_name: "resnet18"
    in_chans: 3
    backend: "timm"
    pretrained: false
  descriptor:
    use_hu_moments: false

💡 Codebase verification

Fourier embeddings are controlled by use_fourier flag in embedding configuration

Based on the codebase analysis, I found that:

  1. Fourier embeddings are implemented in dreem/models/transformer.py and dreem/models/embedding.py
  2. They are activated through the use_fourier flag in embedding_meta configuration
  3. The base configuration file shows embedding_meta section but lacks the use_fourier setting

The configuration should be updated to include:

  • Add use_fourier: true under the embedding_meta section
  • Optionally, consider adding Fourier-specific settings like n_components (default is 8 from the code)
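A sketch of the suggested addition; the exact key placement under embedding_meta is an assumption based on this analysis:

embedding_meta:
  use_fourier: true
  # optionally expose the component count (the code reportedly defaults to 8)
  # n_components: 8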
🔗 Analysis chain

Verify configuration for Fourier embeddings.

The PR objectives mention learnable Fourier embeddings, but there's no configuration section for them in this file. Please verify if:

  1. Fourier embedding settings should be added to this configuration
  2. These settings are defined in another configuration file
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for Fourier embedding configurations
rg -l "fourier.*embedding" "tests/configs/"

Length of output: 43


Script:

#!/bin/bash
# Let's broaden the search to look for any Fourier-related configurations and implementation
rg -i "fourier" --type yaml tests/configs/ -A 5

Length of output: 47


Script:

#!/bin/bash
# Search for Fourier-related code in the entire codebase
rg -i "fourier" -g "!*.pyc" -g "!*.git*" -A 5

Length of output: 8183


Script:

#!/bin/bash
# Look for encoder types in configuration files
rg "encoder_type:" tests/configs/ -A 3

Length of output: 299


Script:

#!/bin/bash
# Let's check for embedding configurations in the config files
rg "embedding" tests/configs/ -A 5

Length of output: 860


Script:

#!/bin/bash
# Check if there are other configuration files that might contain the settings
fd -e yaml -e yml . tests/configs/

Length of output: 115

🧰 Tools
🪛 yamllint (1.35.1)

[error] 3-3: trailing spaces

(trailing-spaces)

Comment on lines +193 to +220
    @torch.no_grad()
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward pass of feature extractor to get feature vector."""
        descriptors = []

        for im in img:
            im = im[0].cpu().numpy()

            inertia_tensor = self.compute_inertia_tensor(im)
            mean_intensity = im.mean()
            if self.use_hu_moments:
                hu_moments = self.compute_hu_moments(im)

            # Flatten inertia tensor
            inertia_tensor_flat = inertia_tensor.flatten()

            # Combine all features into a single descriptor
            descriptor = np.concatenate(
                [
                    inertia_tensor_flat,
                    [mean_intensity],
                    hu_moments if self.use_hu_moments else [],
                ]
            )

            descriptors.append(torch.tensor(descriptor, dtype=torch.float32))

        return torch.stack(descriptors)

⚠️ Potential issue

Improve forward method implementation

Several improvements are needed in the forward method:

  1. Missing input shape validation
  2. Inefficient CPU transfer
  3. Missing feature normalization
  4. Incomplete docstring

     @torch.no_grad()
     def forward(self, img: torch.Tensor) -> torch.Tensor:
-        """Forward pass of feature extractor to get feature vector."""
+        """Forward pass of feature extractor to get feature vector.
+        
+        Args:
+            img: Input image tensor of shape (B, C, H, W). Expects grayscale images (C=1).
+            
+        Returns:
+            feats: Normalized output tensor of shape (B, d_model).
+        """
+        if img.shape[1] != 1:
+            raise ValueError(f"Expected grayscale images with 1 channel, got {img.shape[1]} channels")
+
         descriptors = []
 
         for im in img:
-            im = im[0].cpu().numpy()
+            # .numpy() requires a CPU tensor; skip the device transfer when already on CPU
+            im = im[0].numpy() if not torch.cuda.is_available() else im[0].cpu().numpy()
 
             inertia_tensor = self.compute_inertia_tensor(im)
             mean_intensity = im.mean()
             if self.use_hu_moments:
                 hu_moments = self.compute_hu_moments(im)
 
-            # Flatten inertia tensor
             inertia_tensor_flat = inertia_tensor.flatten()
 
-            # Combine all features into a single descriptor
             descriptor = np.concatenate(
                 [
                     inertia_tensor_flat,
                     [mean_intensity],
                     hu_moments if self.use_hu_moments else [],
                 ]
             )
 
             descriptors.append(torch.tensor(descriptor, dtype=torch.float32))
 
-        return torch.stack(descriptors)
+        feats = torch.stack(descriptors)
+        if torch.cuda.is_available():
+            feats = feats.cuda()
+        
+        # Project to d_model dimension and normalize
+        feats = self.projection(feats)
+        feats = F.normalize(feats)
+        return feats

Committable suggestion skipped: line range outside the PR's diff.
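The forward pass above relies on two helpers outside this hunk. A plausible reconstruction with skimage follows — an assumption, since the real implementations are not shown in this diff. With 2-D crops, these yield the 4 + 1 (+ 7) feature layout the descriptor concatenates:

import numpy as np
from skimage.measure import (
    inertia_tensor,
    moments_central,
    moments_hu,
    moments_normalized,
)


def compute_inertia_tensor(im: np.ndarray) -> np.ndarray:
    """2x2 inertia tensor of a grayscale crop -> 4 values once flattened."""
    return inertia_tensor(im)


def compute_hu_moments(im: np.ndarray) -> np.ndarray:
    """Hu's 7 translation-, scale-, and rotation-invariant moments."""
    nu = moments_normalized(moments_central(im))
    return moments_hu(nu)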

Comment on lines +170 to +178
    def __init__(self, use_hu_moments: bool = False, **kwargs):
        """Initialize Descriptor Visual Encoder.

        Args:
            use_hu_moments: Whether to use Hu moments.
        """
        super().__init__()
        self.use_hu_moments = use_hu_moments


⚠️ Potential issue

Add missing d_model parameter for consistency

The DescriptorVisualEncoder should maintain consistency with VisualEncoder by accepting and utilizing the d_model parameter to ensure proper feature dimensionality.

-    def __init__(self, use_hu_moments: bool = False, **kwargs):
+    def __init__(self, d_model: int, use_hu_moments: bool = False, **kwargs):
         """Initialize Descriptor Visual Encoder.

         Args:
+            d_model: Output embedding dimension.
             use_hu_moments: Whether to use Hu moments.
         """
         super().__init__()
+        self.d_model = d_model
         self.use_hu_moments = use_hu_moments
+        # Add projection layer to match d_model dimension
+        feature_dim = 5 + (7 if use_hu_moments else 0)  # inertia tensor (4) + mean intensity (1) + optional hu moments (7)
+        self.projection = torch.nn.Linear(feature_dim, d_model)

Comment on lines +212 to +229
        # apply fourier embeddings if using fourier rope, OR if using descriptor (compact) visual encoder
        if (
            self.embedding_meta
            and "use_fourier" in self.embedding_meta
            and self.embedding_meta["use_fourier"]
        ) or (
            self.encoder_cfg
            and "encoder_type" in self.encoder_cfg
            and self.encoder_cfg["encoder_type"] == "descriptor"
        ):
            encoder_queries = apply_fourier_embeddings(
                encoder_queries,
                ref_times,
                self.d_model,
                self.fourier_embeddings,
                self.fourier_proj,
                self.fourier_norm,
            )

🛠️ Refactor suggestion

Refactor duplicate Fourier embedding logic

The code for applying Fourier embeddings is duplicated. Consider extracting the condition check into a method and reusing it.

+    def should_apply_fourier_embeddings(self) -> bool:
+        """Check if Fourier embeddings should be applied."""
+        return (
+            (self.embedding_meta and self.embedding_meta.get("use_fourier")) or
+            (self.encoder_cfg and self.encoder_cfg.get("encoder_type") == "descriptor")
+        )

-        if (
-            self.embedding_meta
-            and "use_fourier" in self.embedding_meta
-            and self.embedding_meta["use_fourier"]
-        ) or (
-            self.encoder_cfg
-            and "encoder_type" in self.encoder_cfg
-            and self.encoder_cfg["encoder_type"] == "descriptor"
-        ):
+        if self.should_apply_fourier_embeddings():
             encoder_queries = apply_fourier_embeddings(
                 encoder_queries,
                 ref_times,
                 self.d_model,
                 self.fourier_embeddings,
                 self.fourier_proj,
                 self.fourier_norm,
             )

Also applies to: 272-289
