Refactor visual encoder and features #101
base: main
Conversation
…ct appropriate encoder
- registry allows users to define a custom visual encoder, add configs, and register it into the registry
- global_tracking_Transformer calls the router method rather than VisualEncoder directly
- update config file structure to allow different config sets for different encoder types
- maintain backward compatibility in case old configs are used; defaults to existing encoder
move data to and from gpu when computing skimage features
…visual features to larger d_model
- transformer concatenates learnable fourier params onto feature vector and projects
- remove Hu embeddings from feature descriptor, for numerical stability
- use embed_dim = d_model rather than taking it from the incoming feature shapes
- apply fourier embeddings and project to d_model; fix minor bugs in dims and order of operations
- apply fourier embeddings to decoder input as well as encoder
… regions
- empty crops prevent calculation of visual features
… used (for training stability)
- remove lbp encoder from visual encoder and config structure
- add config example
Walkthrough
The pull request introduces several modifications across multiple files. Key changes include the addition of a …
Changes
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (14)
dreem/models/__init__.py (1)
9-14: Excellent architectural design for extensibility
The combination of factory and registry patterns creates a flexible plugin architecture that:
- Enables users to register custom encoders
- Provides a clean interface for encoder instantiation
- Facilitates future additions of new encoder types without modifying existing code
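To make the extension path concrete, here is a minimal sketch of registering a user-defined encoder; the HistogramEncoder class, its config keys, and the histogram features are hypothetical illustrations — only register_encoder comes from this PR:

```python
import torch

# Hypothetical custom encoder: any module mapping (B, C, H, W) crops to (B, d_model).
class HistogramEncoder(torch.nn.Module):
    def __init__(self, d_model: int, n_bins: int = 32, **kwargs):
        super().__init__()
        self.n_bins = n_bins
        self.proj = torch.nn.Linear(n_bins, d_model)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # Per-crop intensity histogram, projected to the model dimension.
        hists = torch.stack(
            [torch.histc(im, bins=self.n_bins, min=0.0, max=1.0) for im in img]
        )
        return self.proj(hists)

# After registration the new type would be selectable from a config such as
# encoder_cfg: {encoder_type: "histogram", histogram: {n_bins: 32}}.
register_encoder("histogram", HistogramEncoder)
```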
🧰 Tools
🪛 Ruff (0.8.0)
10-10: .visual_encoder.VisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
11-11: .visual_encoder.DescriptorVisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
12-12: .visual_encoder.create_visual_encoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
13-13: .visual_encoder.register_encoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
dreem/models/global_tracking_transformer.py (3)
75-75: Consider reviewing the architecture design
Passing encoder_cfg to the Transformer initialization might indicate tight coupling between the Transformer and encoder configuration. Consider whether this violates separation of concerns and whether the encoder configuration should be handled entirely by the visual encoder component.
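One hedged way to address this would be for the Transformer to accept only the values it actually consumes (for example, a declared visual feature dimension) and leave all encoder construction to the factory; the parameter name below is an assumption for illustration, not the current dreem signature:

```python
import torch

class SlimTransformer(torch.nn.Module):
    """Sketch: depend on a declared feature dim instead of the full encoder_cfg."""

    def __init__(self, d_model: int, visual_feat_dim: int | None = None):
        super().__init__()
        self.d_model = d_model
        # Build the extra projection only when a compact feature vector is declared.
        if visual_feat_dim is not None:
            self.fourier_proj = torch.nn.Linear(d_model + visual_feat_dim, d_model)
```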
Line range hint 8-8: Address TODO comment about parameter handling
The TODO comment suggests uncertainty about parameter handling with configs. This should be resolved before merging.
Would you like me to help design a more structured approach to parameter handling?
Line range hint 10-11: Consider splitting responsibilities
The GlobalTrackingTransformer class handles both visual encoding and transformer operations. Consider splitting these responsibilities into separate components for better maintainability and testing.
dreem/models/embedding.py (3)
361-363: Enhance function documentation
The docstring should include parameter descriptions and return type information.
Consider updating the docstring to:
```diff
-    """Create a tensor of shape (1,n) of fourier frequency coefficients"""
+    """Create a tensor of shape (1,n) of Fourier frequency coefficients.
+
+    Args:
+        cutoff: The maximum frequency cutoff value
+        n: Number of frequency components
+
+    Returns:
+        torch.Tensor: A tensor of shape (1,n) containing logarithmically spaced frequencies
+    """
```
366-376: Improve class documentation and add input validation
The class docstring should better describe its purpose and parameters. Additionally, consider adding input validation.
```diff
 class FourierPositionalEmbeddings(torch.nn.Module):
+    """Learnable Fourier positional embeddings for transformer models.
+
+    This class implements positional encodings using learnable Fourier frequency
+    coefficients. The embeddings are computed using both sine and cosine
+    transformations of the input positions.
+
+    Args:
+        n_components (int): Number of frequency components for each dimension
+        d_model (int): The model dimension (must be divisible by 2*n_components)
+    """
     def __init__(
         self,
         n_components: int,
         d_model: int,
     ):
-        """Positional encoding with given cutoff and number of frequencies for each dimension.
-        number of dimension is inferred from the length of cutoffs and n_pos.
-        """
         super().__init__()
+        if d_model % (2 * n_components) != 0:
+            raise ValueError(
+                f"d_model ({d_model}) must be divisible by 2*n_components ({2*n_components})"
+            )
         self.d_model = d_model
         self.n_components = n_components
```
383-404: Optimize computation and improve type hints
The forward method could benefit from type hints and clearer variable names.
```diff
-    def forward(self, seq_positions: torch.Tensor):
+    def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
         """Compute learnable fourier coefficients for each spatial/temporal position.

         Args:
-            seq_positions: tensor of shape (num_queries,)
+            seq_positions: Tensor of shape (num_queries,) containing position indices

         Returns:
-            tensor of shape (num_queries, embed_dim)
+            torch.Tensor: Positional embeddings of shape (1, num_queries, d_model)
         """
         freq = self.freq.to(seq_positions.device)
+        # Reshape positions once for both sin and cos computations
+        positions = seq_positions.unsqueeze(-1).unsqueeze(0) * freq * 0.5 * math.pi
         embed = torch.cat(
             (
-                torch.sin(
-                    0.5 * math.pi * seq_positions.unsqueeze(-1).unsqueeze(0) * freq
-                ),
-                torch.cos(
-                    0.5 * math.pi * seq_positions.unsqueeze(-1).unsqueeze(0) * freq
-                ),
+                torch.sin(positions),
+                torch.cos(positions),
             ),
             axis=-1,
         ) / math.sqrt(len(freq))
```
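As a quick sanity check of the shape this forward pass produces, the core computation can be run standalone; freq is random here purely for illustration, whereas in the model it is a learnable parameter:

```python
import math
import torch

num_queries, n_components, d_model = 10, 8, 16
seq_positions = torch.arange(num_queries, dtype=torch.float32)
freq = torch.rand(1, n_components)  # stand-in for the learnable coefficients

positions = seq_positions.unsqueeze(-1).unsqueeze(0) * freq * 0.5 * math.pi
embed = torch.cat((torch.sin(positions), torch.cos(positions)), dim=-1)
embed = embed / math.sqrt(freq.shape[-1])
print(embed.shape)  # torch.Size([1, 10, 16]) == (1, num_queries, 2 * n_components)
```

With d_model = 2 * n_components the embedding fills the model dimension exactly, which is why the divisibility check suggested above matters.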
dreem/io/config.py (1)
271-275: Consider decoupling encoder configuration from dataset preprocessing
The current implementation creates a direct dependency between the encoder type and dataset preprocessing. This coupling might make it harder to:
- Add new encoder types without modifying dataset logic
- Mix different preprocessing strategies independently
- Test preprocessing configurations in isolation
Consider these alternatives:
- Move preprocessing configuration to the encoder config itself:
```python
encoder_cfg = {
    "type": "descriptor",
    "preprocessing": {"normalize_image": False},
}
```
- Create a separate preprocessing configuration section:
```python
preprocessing_cfg = {
    "descriptor": {"normalize_image": False},
    "resnet": {"normalize_image": True},
}
```
This would make the system more modular and easier to extend with new encoder types.
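A small sketch of how alternative (2) might be consumed; the helper name and its fallback behavior are assumptions for illustration, not existing dreem code:

```python
# Hypothetical helper: resolve preprocessing flags from a per-encoder section.
def resolve_preprocessing(preprocessing_cfg: dict, encoder_type: str) -> dict:
    # Unknown encoder types fall back to conventional normalization.
    return preprocessing_cfg.get(encoder_type, {"normalize_image": True})

preprocessing_cfg = {
    "descriptor": {"normalize_image": False},
    "resnet": {"normalize_image": True},
}
assert resolve_preprocessing(preprocessing_cfg, "descriptor")["normalize_image"] is False
assert resolve_preprocessing(preprocessing_cfg, "custom")["normalize_image"] is True
```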
dreem/models/visual_encoder.py (2)
9-9: Remove unused import local_binary_pattern
The function local_binary_pattern from skimage.feature is imported but not used in the code. Removing it will clean up the imports.
Apply this diff to remove the unused import:
```diff
-from skimage.feature import local_binary_pattern
```
🧰 Tools
🪛 Ruff (0.8.0)
9-9: skimage.feature.local_binary_pattern imported but unused
Remove unused import: skimage.feature.local_binary_pattern (F401)
187-215: Optimize forward method by vectorizing computations
The forward method processes images in a loop, which can be inefficient for large batches. Consider vectorizing the computations to leverage PyTorch's batch processing capabilities and improve performance.
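For reference, the per-crop loop could plausibly be replaced with batched tensor operations along these lines; this is a sketch that assumes single-channel crops and uses raw second central moments rather than skimage's exact inertia-tensor convention:

```python
import torch

def batched_descriptor(imgs: torch.Tensor) -> torch.Tensor:
    """Vectorized 5-dim descriptor (second central moments + mean) for (B, H, W) crops."""
    B, H, W = imgs.shape
    ys = torch.arange(H, dtype=imgs.dtype, device=imgs.device)
    xs = torch.arange(W, dtype=imgs.dtype, device=imgs.device)
    mass = imgs.sum(dim=(1, 2)).clamp_min(1e-8)       # guard against empty crops
    y_bar = (imgs.sum(dim=2) * ys).sum(dim=1) / mass  # intensity-weighted centroids
    x_bar = (imgs.sum(dim=1) * xs).sum(dim=1) / mass
    dy = ys[None, :, None] - y_bar[:, None, None]     # (B, H, 1)
    dx = xs[None, None, :] - x_bar[:, None, None]     # (B, 1, W)
    mu20 = (imgs * dy * dy).sum(dim=(1, 2)) / mass    # second central moments
    mu02 = (imgs * dx * dx).sum(dim=(1, 2)) / mass
    mu11 = (imgs * dy * dx).sum(dim=(1, 2)) / mass
    return torch.stack([mu20, mu11, mu11, mu02, imgs.mean(dim=(1, 2))], dim=1)
```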
dreem/models/transformer.py (4)
45-45: Ensure encoder_cfg parameter is documented and type-hinted correctly
The new parameter encoder_cfg is added to the Transformer class. Please ensure that it's fully documented in the method's docstring, including its purpose and expected structure.
97-100: Make n_components in FourierPositionalEmbeddings configurable
Currently, n_components is hard-coded to 8. Consider making n_components configurable through encoder_cfg or constructor parameters to provide flexibility for different model configurations.
203-214: Refactor condition to reduce code duplication
The condition to apply Fourier embeddings is repeated in both the encoder and decoder sections. Refactor this condition into a helper method to improve code maintainability and readability.
Apply this diff to create a helper method:
```diff
+    def should_apply_fourier_embeddings(self) -> bool:
+        return (
+            "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]
+        ) or (
+            self.encoder_cfg is not None
+            and "encoder_type" in self.encoder_cfg
+            and self.encoder_cfg["encoder_type"] == "descriptor"
+        )
 ...
         # Apply Fourier embeddings if conditions are met
-        if (
-            "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]
-        ) or (
-            self.encoder_cfg is not None
-            and "encoder_type" in self.encoder_cfg
-            and self.encoder_cfg["encoder_type"] == "descriptor"
-        ):
+        if self.should_apply_fourier_embeddings():
             encoder_queries = self.apply_fourier_embeddings(
                 encoder_queries, ref_times
             )
```
256-267: Refactor condition to reduce code duplication in decoder
Similar to the encoder section, the condition to apply Fourier embeddings is repeated. Use the should_apply_fourier_embeddings helper method to improve code maintainability.
Apply this diff:
```diff
         # Apply Fourier embeddings if conditions are met
-        if (
-            "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]
-        ) or (
-            self.encoder_cfg is not None
-            and "encoder_type" in self.encoder_cfg
-            and self.encoder_cfg["encoder_type"] == "descriptor"
-        ):
+        if self.should_apply_fourier_embeddings():
             query_features = self.apply_fourier_embeddings(
                 query_features, query_times
             )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (8)
dreem/datasets/sleap_dataset.py (4 hunks)
dreem/io/config.py (1 hunks)
dreem/models/__init__.py (1 hunks)
dreem/models/embedding.py (1 hunks)
dreem/models/global_tracking_transformer.py (4 hunks)
dreem/models/transformer.py (8 hunks)
dreem/models/visual_encoder.py (2 hunks)
dreem/training/configs/base.yaml (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.0)
dreem/models/__init__.py
3-3: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
3-3: .embedding.FourierPositionalEmbeddings imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
8-8: .transformer.Transformer imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
10-10: .visual_encoder.VisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
11-11: .visual_encoder.DescriptorVisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
12-12: .visual_encoder.create_visual_encoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
13-13: .visual_encoder.register_encoder imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
dreem/models/visual_encoder.py
9-9: skimage.feature.local_binary_pattern imported but unused
Remove unused import: skimage.feature.local_binary_pattern (F401)
🪛 yamllint (1.35.1)
dreem/training/configs/base.yaml
[error] 8-8: trailing spaces
(trailing-spaces)
🔇 Additional comments (12)
dreem/models/__init__.py (2)
3-3: LGTM: Appropriate addition of FourierPositionalEmbeddings to public API
The addition of FourierPositionalEmbeddings to the public API aligns with the PR objective of introducing learnable Fourier embeddings for positional encoding.
9-14: LGTM: Well-structured expansion of visual encoder capabilities
The additions to the public API are well-organized and comprehensively support the new visual encoder framework:
- Base VisualEncoder class
- New DescriptorVisualEncoder implementation
- Factory function create_visual_encoder
- Registration system via register_encoder
The multi-line import format enhances readability.
dreem/models/global_tracking_transformer.py (3)
4-4: LGTM: Good use of factory pattern for encoder creation
The switch to using a factory function aligns well with the PR's objective of supporting multiple visual encoder types and improving configuration flexibility.
59-59: Verify error handling for invalid encoder configurations
The simplified encoder creation looks good, but we should ensure proper error handling for invalid configurations.
✅ Verification successful
Error handling for invalid encoder configurations is properly implemented
The implementation includes comprehensive error handling:
- Validates encoder class type during registration
- Handles unknown encoder types with descriptive error messages
- Validates input channel mismatches in the VisualEncoder
- Provides clear error messages with available encoder types
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for error handling in the create_visual_encoder implementation
ast-grep --pattern 'def create_visual_encoder($$$) {
$$$
raise $$$
$$$
}'
# Check for any existing error handling tests
rg -l "test.*create_visual_encoder" tests/
Length of output: 126
Script:
#!/bin/bash
# Let's try a different approach to find the create_visual_encoder implementation
rg -A 10 "def create_visual_encoder"
# Also search for visual encoder related error handling
rg "raise \w+Error.*encoder"
# Look for visual encoder related tests with a broader pattern
rg -l "visual.*encoder.*test|test.*visual.*encoder" tests/
Length of output: 1107
Script:
#!/bin/bash
# Let's check the ENCODER_REGISTRY implementation and usage
rg -B 2 -A 2 "ENCODER_REGISTRY"
# Check the register_encoder implementation
rg -B 2 -A 5 "def register_encoder"
# Look for any additional error handling in visual_encoder.py
cat dreem/models/visual_encoder.py
Length of output: 10970
126-126: Verify device handling for edge cases
The explicit device management is good for multi-GPU compatibility. However, we should verify handling of edge cases.
✅ Verification successful
Let me check the device-related tests and other device management code to better understand the edge cases.
Let me check the context around the device management in the global tracking transformer to better understand how the features and instances are handled.
Device handling is properly implemented
The device management is correctly implemented with proper safeguards:
- Features are moved to the same device as instances after visual encoding
- The code handles empty and single-instance cases
- Device consistency is maintained throughout the processing pipeline
- The commented-out test fixture in tests/fixtures/torch.py suggests device-related testing was considered but separated to avoid interference
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for other device management patterns in the codebase
rg "\.to\(device="
# Look for any existing device-related tests
rg -l "test.*device" tests/
Length of output: 289
Script:
#!/bin/bash
# Check device-related test content
cat tests/test_inference.py
cat tests/fixtures/torch.py
# Check the context of device management in boxes.py
rg -B 3 -A 3 "\.to\(device=" dreem/inference/boxes.py
Length of output: 9439
Script:
#!/bin/bash
# Check the context around device management in global_tracking_transformer.py
rg -B 10 -A 10 "\.to\(device=" dreem/models/global_tracking_transformer.py
# Look for instance device handling
ast-grep --pattern 'instances_to_compute'
Length of output: 1497
dreem/datasets/sleap_dataset.py (2)
38-38: LGTM: Well-structured parameter addition
The new normalize_image parameter is properly integrated with clear documentation and an appropriate default value for backward compatibility.
Also applies to: 75-75, 103-103
224-226: Verify the impact of normalization changes on model predictions
Since this change affects image preprocessing, it's important to verify:
- The usage of normalize_image=False doesn't negatively impact model performance
- The handling of low-intensity regions is improved as intended
Let's check how this parameter is used across the codebase:
✅ Verification successful
Image normalization change is safe and intentionally handled
The verification reveals that the normalization change is well-controlled:
- The normalize_image parameter is explicitly set to False in config.py when using the descriptor encoder type, showing intentional handling of different preprocessing requirements.
- The normalization is properly documented in the SleapDataset class with a clear parameter description.
- The change only affects integer-type images, which are first converted to float32 before optional normalization.
- No other parts of the codebase are impacted by this change as the normalization is encapsulated within the dataset class.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check usage of normalize_image parameter and related image processing
# Check for normalize_image parameter usage
rg -l "normalize_image"
# Look for image preprocessing patterns
ast-grep --pattern 'img = img / 255'
ast-grep --pattern 'img.astype(np.float32)'
# Check for potential configuration files
fd -e yaml -e json . | xargs rg -l "normalize_image"
Length of output: 365
Script:
#!/bin/bash
# Let's check the context of normalize_image usage in config.py and sleap_dataset.py
rg "normalize_image" -B 3 -A 3 dreem/io/config.py dreem/datasets/sleap_dataset.py
# Check if there are any tests related to image normalization
fd -e py . | xargs rg -l "test.*normalize.*image"
# Look for any model configuration files that might reference image normalization
fd "config" -e yaml -e json . | xargs cat
Length of output: 2500
dreem/models/embedding.py (1)
366-416: Verify integration with transformer architecture
The implementation looks solid, but let's verify its integration with the transformer architecture as mentioned in the PR objectives.
✅ Verification successful
Integration with transformer architecture verified successfully
The implementation is correctly integrated with the transformer architecture as intended:
- Found in the apply_fourier_embeddings function:
  - Fourier embeddings are concatenated with input queries: torch.cat([queries, embs], dim=-1)
  - The concatenated features are properly projected to the model dimension using a linear layer
  - LayerNorm is applied after projection, ensuring proper normalization
- The transformer properly initializes FourierPositionalEmbeddings with correct parameters in its constructor
- The integration follows the PR objectives of concatenating with feature vectors and projecting to the model dimension
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the integration of FourierPositionalEmbeddings with transformer architecture
# Check for transformer usage of FourierPositionalEmbeddings
echo "Checking transformer integration..."
rg -l "FourierPositionalEmbeddings" --type py
# Check for visual feature concatenation
echo "Checking feature concatenation..."
ast-grep --pattern 'torch.cat([$$$ FourierPositionalEmbeddings $$$])'
Length of output: 354
Script:
#!/bin/bash
# Let's check how FourierPositionalEmbeddings is used in transformer.py
echo "Checking transformer.py implementation..."
rg -A 10 -B 10 "FourierPositionalEmbeddings" dreem/models/transformer.py
# Let's also check for any concatenation operations in transformer.py
echo "Checking concatenation operations in transformer.py..."
rg "torch.cat" dreem/models/transformer.py
Length of output: 2538
dreem/io/config.py (1)
271-275: Verify the impact of disabling image normalization
Let's verify the integration points of this change:
- Usage of normalize_image parameter in SleapDataset
- References to descriptor encoder type
- Test coverage
✅ Verification successful
Let me gather more information about the descriptor encoder and test coverage.
Let me check one more aspect to understand the configuration and test coverage better.
Based on the gathered information, I can now provide a final response:
Disabling image normalization for descriptor encoder is correctly integrated
The change is well-integrated within the codebase for the following reasons:
- The normalize_image parameter in SleapDataset is properly handled, with a default value of True and explicit normalization logic.
- The descriptor encoder implementation (DescriptorVisualEncoder) is designed to work with raw image values as it computes Hu moments.
- The configuration system correctly supports the descriptor encoder type through the encoder_type setting.
- Test coverage exists for both the configuration system and the SleapDataset, including image normalization.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify integration points of the normalization change
# Check for normalize_image parameter usage in SleapDataset
echo "Checking normalize_image usage in SleapDataset:"
rg -A 5 "normalize_image" "dreem/datasets/sleap_dataset.py"
# Check for descriptor encoder type references
echo -e "\nChecking descriptor encoder type references:"
rg -l "descriptor" "dreem/"
# Check for relevant test files
echo -e "\nChecking for related test files:"
fd -g "*test*.py" "tests/"
Length of output: 1424
Script:
#!/bin/bash
# Check descriptor encoder implementation and related tests
echo "Checking descriptor encoder implementation:"
rg -A 10 "class.*Descriptor" dreem/models/visual_encoder.py
echo -e "\nChecking normalize_image in tests:"
rg "normalize_image" "tests/"
echo -e "\nChecking descriptor encoder tests:"
rg "descriptor" "tests/test_models.py"
Length of output: 792
Script:
#!/bin/bash
# Check configuration and test coverage
echo "Checking base configuration for descriptor encoder:"
cat dreem/training/configs/base.yaml
echo -e "\nChecking test_config.py for descriptor tests:"
cat tests/test_config.py
echo -e "\nChecking test_datasets.py for SleapDataset tests:"
cat tests/test_datasets.py
Length of output: 21977
dreem/models/visual_encoder.py (1)
194-194: Verify that images are single-channel or handle multiple channels appropriately
The code uses im = im[0].cpu().numpy(), which extracts the first channel of the image tensor. This assumes that the images are single-channel or that only the first channel is relevant.
Please verify whether all images are single-channel. If images can have multiple channels, consider adjusting the code to handle all channels or appropriately convert images to grayscale before processing.
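If multi-channel crops are possible, one defensive option is to collapse to a single channel before the moment computation; a sketch (the BT.601 luma weights are a standard choice, not something this PR prescribes):

```python
import torch

def to_single_channel(img: torch.Tensor) -> torch.Tensor:
    """Reduce a (C, H, W) crop to (H, W) before computing descriptors."""
    if img.shape[0] == 1:
        return img[0]
    if img.shape[0] == 3:
        # Assume RGB ordering; ITU-R BT.601 luma weights.
        weights = torch.tensor([0.299, 0.587, 0.114], dtype=img.dtype, device=img.device)
        return (img * weights[:, None, None]).sum(dim=0)
    return img.mean(dim=0)  # fall back to a plain channel average
```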
dreem/models/transformer.py (2)
16-16: Approved: Importing necessary modules
The addition of FourierPositionalEmbeddings to the imports is appropriate for the new functionality.
78-78: Approved: Assigning encoder_cfg to an instance variable
Properly storing encoder_cfg for later use in the class.
```python
img = img.astype(np.float32)
if self.normalize_image:
    img = img / 255
```
🛠️ Refactor suggestion
Consider enhancing image type handling and normalization logic
The current implementation has several areas for improvement:
- The unconditional conversion to float32 could impact memory usage unnecessarily when normalization is disabled
- The normalization assumes 8-bit images (division by 255) which may not always be the case
- There's no handling for images that are already in float format
Consider this improved implementation:
```diff
-    if np.issubdtype(img.dtype, np.integer):  # convert int to float
-        img = img.astype(np.float32)
-        if self.normalize_image:
-            img = img / 255
+    if self.normalize_image:
+        if np.issubdtype(img.dtype, np.integer):
+            max_val = np.iinfo(img.dtype).max
+            img = img.astype(np.float32) / max_val
+        elif np.issubdtype(img.dtype, np.floating):
+            if img.max() > 1.0:
+                img = img / img.max()
+    elif np.issubdtype(img.dtype, np.integer):
+        img = img.astype(np.float32)
```
This implementation:
- Only converts to float32 when necessary
- Handles different bit depths correctly
- Properly processes images that are already in float format
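A quick numeric illustration of the bit-depth point: for a 16-bit image, dividing by a fixed 255 leaves values far above 1, while scaling by the dtype's own maximum behaves as intended:

```python
import numpy as np

img = np.array([[0, 1000, 65535]], dtype=np.uint16)
max_val = np.iinfo(img.dtype).max          # 65535 for uint16, 255 for uint8
scaled = img.astype(np.float32) / max_val  # values now land in [0.0, 1.0]
assert scaled.max() == 1.0
print(img.astype(np.float32).max() / 255)  # 257.0 -- what a hard-coded 255 would produce
```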
Committable suggestion skipped: line range outside the PR's diff.
```python
if self.d_model % self.n_components != 0:
    raise ValueError(
        f"d_model ({self.d_model}) must be divisible by number of Fourier components n_components ({self.n_components})"
    )
```
🛠️ Refactor suggestion
Move dimension validation to initialization
The dimension check should be performed during initialization rather than in the forward pass to fail fast and avoid unnecessary computations.
This validation block should be moved to __init__, as shown in the earlier suggestion.
dreem/io/config.py (Outdated)
```python
if (
    "encoder_type" in self.cfg["model"]["encoder_cfg"]
    and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
):
    dataset_params["normalize_image"] = False
```
🛠️ Refactor suggestion
Add validation and improve robustness of configuration access
The current implementation has several areas for improvement:
- Deep nesting of dictionary access could lead to KeyError exceptions
- The rationale for disabling normalization with descriptor encoder should be documented
- The string literal "descriptor" should be extracted to a constant
Consider applying this refactor:
```diff
+# Constants
+DESCRIPTOR_ENCODER_TYPE = "descriptor"
+
+def _get_encoder_type(self) -> str | None:
+    """Get the encoder type from model config safely."""
+    try:
+        return self.cfg["model"]["encoder_cfg"]["encoder_type"]
+    except (KeyError, TypeError):
+        return None
+
-if (
-    "encoder_type" in self.cfg["model"]["encoder_cfg"]
-    and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
-):
+if (encoder_type := self._get_encoder_type()) == DESCRIPTOR_ENCODER_TYPE:
+    # Disable normalization for descriptor encoder to prevent empty crops
+    # in low-intensity regions that could hinder feature calculation
     dataset_params["normalize_image"] = False
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Constants
DESCRIPTOR_ENCODER_TYPE = "descriptor"

def _get_encoder_type(self) -> str | None:
    """Get the encoder type from model config safely."""
    try:
        return self.cfg["model"]["encoder_cfg"]["encoder_type"]
    except (KeyError, TypeError):
        return None

if (encoder_type := self._get_encoder_type()) == DESCRIPTOR_ENCODER_TYPE:
    # Disable normalization for descriptor encoder to prevent empty crops
    # in low-intensity regions that could hinder feature calculation
    dataset_params["normalize_image"] = False
```
dreem/models/visual_encoder.py
Outdated
```python
def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
    """Create a visual encoder based on the specified type."""
    register_encoder("resnet", VisualEncoder)
    register_encoder("descriptor", DescriptorVisualEncoder)
    # register any custom encoders here

    # compatibility with configs that don't specify encoder_type; default to resnet
    if "encoder_type" not in encoder_cfg:
        encoder_type = "resnet"
        return ENCODER_REGISTRY[encoder_type](d_model=d_model, **encoder_cfg)
    else:
        encoder_type = encoder_cfg.pop("encoder_type")

        if encoder_type in ENCODER_REGISTRY:
            # choose the relevant encoder configs based on the encoder_type
            configs = encoder_cfg[encoder_type]
            return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
        else:
            raise ValueError(
                f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
            )
```
Fix potential KeyError in create_visual_encoder
In the create_visual_encoder function, accessing configs = encoder_cfg[encoder_type] may raise a KeyError if encoder_type is not a key in encoder_cfg. This can happen when encoder_cfg does not contain configurations specific to the encoder_type.
Apply this diff to ensure configs is correctly retrieved:
```diff
 def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
     """Create a visual encoder based on the specified type."""
     register_encoder("resnet", VisualEncoder)
     register_encoder("descriptor", DescriptorVisualEncoder)
     # register any custom encoders here

     # Compatibility with configs that don't specify encoder_type; default to resnet
     if "encoder_type" not in encoder_cfg:
         encoder_type = "resnet"
-        return ENCODER_REGISTRY[encoder_type](d_model=d_model, **encoder_cfg)
+        configs = encoder_cfg
     else:
         encoder_type = encoder_cfg.pop("encoder_type")

         if encoder_type in ENCODER_REGISTRY:
-            # choose the relevant encoder configs based on the encoder_type
-            configs = encoder_cfg[encoder_type]
-            return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
+            # Use the remaining encoder_cfg as configs
+            configs = encoder_cfg
         else:
             raise ValueError(
                 f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
             )
+    return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
```
This adjustment ensures that configs is appropriately assigned, preventing a KeyError when configurations are not nested under the encoder_type key.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def create_visual_encoder(d_model: int, **encoder_cfg) -> torch.nn.Module:
    """Create a visual encoder based on the specified type."""
    register_encoder("resnet", VisualEncoder)
    register_encoder("descriptor", DescriptorVisualEncoder)
    # register any custom encoders here

    # Compatibility with configs that don't specify encoder_type; default to resnet
    if "encoder_type" not in encoder_cfg:
        encoder_type = "resnet"
        configs = encoder_cfg
    else:
        encoder_type = encoder_cfg.pop("encoder_type")

        if encoder_type in ENCODER_REGISTRY:
            # Use the remaining encoder_cfg as configs
            configs = encoder_cfg
        else:
            raise ValueError(
                f"Unknown encoder type: {encoder_type}. Please use one of {list(ENCODER_REGISTRY.keys())}"
            )
    return ENCODER_REGISTRY[encoder_type](d_model=d_model, **configs)
```
```diff
@@ -228,12 +246,25 @@ def forward(
         else:
             query_instances = ref_instances
             query_times = ref_times
```
Ensure query_instances is assigned when None
When query_instances is None, only query_times is assigned. This may lead to query_instances being None later, causing errors. Assign query_instances = ref_instances to prevent potential issues.
Apply this diff to fix the issue:
```diff
 else:
+    query_instances = ref_instances
     query_times = ref_times
```
Committable suggestion skipped: line range outside the PR's diff.
dreem/models/transformer.py
Outdated
```python
def apply_fourier_embeddings(
    queries: torch.Tensor,
    times: torch.Tensor,
    d_model: int,
    fourier_embeddings: FourierPositionalEmbeddings,
) -> torch.Tensor:
    embs = fourier_embeddings(times).permute(1, 0, 2)
    cat_queries = torch.cat([queries, embs], dim=-1)
    # project to d_model
    proj = nn.Linear(cat_queries.shape[-1], d_model).to(queries.device)
    norm = nn.LayerNorm(d_model).to(queries.device)
    cat_queries = proj(cat_queries)
    cat_queries = norm(cat_queries)

    return cat_queries
```
🛠️ Refactor suggestion
Avoid defining neural network layers inside a function
Defining nn.Linear and nn.LayerNorm inside apply_fourier_embeddings leads to new instances being created on each call. This prevents the layers from learning during training, as their parameters are not registered with the model.
Move the layer definitions to the Transformer class's __init__ method and store them as instance variables.
Apply this diff:
```diff
 class Transformer(torch.nn.Module):
     def __init__(self, ...):
         ...
+        # Initialize Fourier embeddings
+        self.fourier_embeddings = FourierPositionalEmbeddings(
+            n_components=8, d_model=d_model
+        )
+
+        # Define projection and normalization layers for Fourier embeddings
+        fourier_dim = self.fourier_embeddings.d_model
+        self.fourier_proj = nn.Linear(d_model + fourier_dim, d_model)
+        self.fourier_norm = nn.LayerNorm(d_model)
         ...
+    def apply_fourier_embeddings(
+        self,
+        queries: torch.Tensor,
+        times: torch.Tensor,
+    ) -> torch.Tensor:
+        embs = self.fourier_embeddings(times).permute(1, 0, 2)
+        cat_queries = torch.cat([queries, embs], dim=-1)
+        cat_queries = self.fourier_proj(cat_queries)
+        cat_queries = self.fourier_norm(cat_queries)
+        return cat_queries
     ...
-def apply_fourier_embeddings(
-    queries: torch.Tensor,
-    times: torch.Tensor,
-    d_model: int,
-    fourier_embeddings: FourierPositionalEmbeddings,
-) -> torch.Tensor:
-
-    embs = fourier_embeddings(times).permute(1, 0, 2)
-    cat_queries = torch.cat([queries, embs], dim=-1)
-    # Project to d_model
-    proj = nn.Linear(cat_queries.shape[-1], d_model).to(queries.device)
-    norm = nn.LayerNorm(d_model).to(queries.device)
-    cat_queries = proj(cat_queries)
-    cat_queries = norm(cat_queries)
-    return cat_queries
```
Update calls to apply_fourier_embeddings:
```diff
 if self.should_apply_fourier_embeddings():
-    encoder_queries = apply_fourier_embeddings(
-        encoder_queries, ref_times, self.d_model, self.fourier_embeddings
-    )
+    encoder_queries = self.apply_fourier_embeddings(
+        encoder_queries, ref_times
+    )
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def apply_fourier_embeddings(
    self,
    queries: torch.Tensor,
    times: torch.Tensor,
) -> torch.Tensor:
    embs = self.fourier_embeddings(times).permute(1, 0, 2)
    cat_queries = torch.cat([queries, embs], dim=-1)
    cat_queries = self.fourier_proj(cat_queries)
    cat_queries = self.fourier_norm(cat_queries)
    return cat_queries
```
update sample configs for train and for unit tests
Actionable comments posted: 4
🧹 Outside diff range and nitpick comments (11)
tests/configs/base.yaml (3)
3-3: Remove trailing whitespace.
There is a trailing space after encoder_cfg:.
```diff
-encoder_cfg: 
+encoder_cfg:
```
🧰 Tools
🪛 yamllint (1.35.1)
[error] 3-3: trailing spaces (trailing-spaces)
6-6: Enhance documentation for ndim parameter.
The current comment could be more descriptive about the relationship between ndim and encoder types.
```diff
-  ndim: 5 # dim of the visual feature vector; 5 for descriptor, 12 for descriptor + hu moments
+  ndim: 5 # Dimension of the visual feature vector:
+         #   - 5 for basic descriptor (inertia tensor + mean intensity)
+         #   - 12 when using descriptor with Hu moments enabled
```
12-13: Add documentation for descriptor configuration.
The descriptor configuration would benefit from comments explaining:
- The impact of enabling Hu moments on model performance
- When to use Hu moments vs. basic descriptors
```diff
   descriptor:
+    # Enable Hu moments for rotation-invariant shape descriptors
+    # Recommended for tracking objects with significant rotation
     use_hu_moments: false
```
dreem/training/configs/base.yaml (3)
4-7: Consider moving legacy configuration documentation.
While keeping the old configuration structure as comments can help during migration, it's better to document these changes in proper migration guides or documentation to avoid confusion and reduce maintenance overhead.
🧰 Tools
🪛 yamllint (1.35.1)
[error] 4-4: trailing spaces
(trailing-spaces)
10-12: Enhance configuration documentation and validation.
The encoder configuration would benefit from:
- Documentation of all supported encoder_type values
- Clarification if both resnet and descriptor configs are required
- More flexible documentation for ndim that won't need updates if new descriptors are added
Consider adding a comment like this:
```yaml
encoder_cfg:
  # Supported encoder_type values: "resnet", "descriptor"
  # Only the configuration matching the selected encoder_type is required
  encoder_type: "descriptor"
  # ndim: dimension of the visual feature vector
  #   - descriptor without hu_moments: 5
  #   - descriptor with hu_moments: 12
  ndim: 5
```
13-19: Consider enhancing the descriptor encoder configuration.
The descriptor configuration could benefit from:
- Validation constraints for use_hu_moments
- Additional customization parameters if applicable
- Clear documentation of the relationship between use_hu_moments and ndim
Consider expanding the descriptor configuration like this:
```yaml
descriptor:
  use_hu_moments: false # If true, ndim should be 12; if false, ndim should be 5
  # Add any additional parameters for customization
  # For example:
  # normalize_features: true
  # feature_weights: [1.0, 1.0, 1.0, 1.0, 1.0]
```
dreem/models/visual_encoder.py (2)
9-9: Remove unused import
The local_binary_pattern from skimage.feature is imported but never used in the code.
```diff
-from skimage.feature import local_binary_pattern
```
🧰 Tools
🪛 Ruff (0.8.0)
9-9: skimage.feature.local_binary_pattern imported but unused
Remove unused import: skimage.feature.local_binary_pattern (F401)
167-221: Consider documenting performance characteristics
The DescriptorVisualEncoder provides a lightweight alternative to CNN-based encoding, which could be beneficial for resource-constrained environments. Consider adding documentation about:
- Memory usage comparison with ResNet
- Inference speed benchmarks
- Use cases where descriptor-based encoding might be preferred
- Limitations of the descriptor-based approach
This will help users make informed decisions about which encoder to use.
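As a starting point, that documentation could live in the class docstring; the trade-off notes below are placeholders to be replaced with measured numbers, not benchmark results:

```python
import torch

class DescriptorVisualEncoder(torch.nn.Module):
    """Compact, training-free visual encoder based on image moments.

    Trade-offs vs. the ResNet encoder (placeholders -- fill in with benchmarks):
    - Memory: no convolutional weights; each crop reduces to 5 (or 12) floats.
    - Speed: dominated by per-crop moment computation rather than conv passes.
    - Suited to: small crops whose identity is carried by shape and intensity.
    - Limitations: features are not learned, so texture-rich targets may be
      better served by the ResNet encoder.
    """
```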
dreem/models/transformer.py (3)
98-100: Consider making Fourier embedding parameters configurable
The Fourier embeddings are initialized with hardcoded parameters. Consider making these configurable through encoder_cfg for better flexibility.
```diff
 self.fourier_embeddings = FourierPositionalEmbeddings(
-    n_components=8, d_model=d_model
+    n_components=self.encoder_cfg.get("fourier_components", 8),
+    d_model=d_model,
 )
```
110-115: Simplify configuration access using dict.get()
The nested configuration checks can be simplified using the dict.get() method.
```diff
-if self.encoder_cfg and "encoder_type" in self.encoder_cfg:
-    self.visual_feat_dim = (
-        self.encoder_cfg["ndim"] if "ndim" in self.encoder_cfg else 5
-    )  # 5 is default for descriptor
+if self.encoder_cfg and self.encoder_cfg.get("encoder_type"):
+    self.visual_feat_dim = self.encoder_cfg.get("ndim", 5)  # 5 is default for descriptor
     self.fourier_proj = nn.Linear(self.d_model + self.visual_feat_dim, d_model)
     self.fourier_norm = nn.LayerNorm(self.d_model)
```
🧰 Tools
🪛 Ruff (0.8.0)
112-112: Use self.encoder_cfg.get("ndim", 5) instead of an if block
Replace with self.encoder_cfg.get("ndim", 5) (SIM401)
318-346: Consider moving apply_fourier_embeddings into the Transformer class
This function uses class members and is only used within the Transformer class. Consider making it a method of the Transformer class to improve encapsulation.
```diff
-def apply_fourier_embeddings(
-    queries: torch.Tensor,
-    times: torch.Tensor,
-    d_model: int,
-    fourier_embeddings: FourierPositionalEmbeddings,
-    proj: nn.Linear,
-    norm: nn.LayerNorm,
-) -> torch.Tensor:
+def _apply_fourier_embeddings(
+    self,
+    queries: torch.Tensor,
+    times: torch.Tensor,
+) -> torch.Tensor:
     """Apply fourier embeddings to queries.

     Args:
         queries: The input tensor of shape (n_query, batch_size, embed_dim).
         times: The times index tensor of shape (n_query,).
-        d_model: Model dimension.
-        fourier_embeddings: The Fourier positional embeddings object.
-        proj: Linear projection layer that projects concatenated feature vector to model dimension.
-        norm: The normalization layer.

     Returns:
         The output queries of shape (n_query, batch_size, embed_dim).
     """
-    embs = fourier_embeddings(times).permute(1, 0, 2)
+    embs = self.fourier_embeddings(times).permute(1, 0, 2)
     cat_queries = torch.cat([queries, embs], dim=-1)
     # project to d_model
-    cat_queries = proj(cat_queries)
-    cat_queries = norm(cat_queries)
+    cat_queries = self.fourier_proj(cat_queries)
+    cat_queries = self.fourier_norm(cat_queries)
     return cat_queries
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (7)
dreem/io/config.py (1 hunks)
dreem/models/embedding.py (1 hunks)
dreem/models/global_tracking_transformer.py (4 hunks)
dreem/models/transformer.py (9 hunks)
dreem/models/visual_encoder.py (2 hunks)
dreem/training/configs/base.yaml (1 hunks)
tests/configs/base.yaml (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- dreem/models/embedding.py
- dreem/models/global_tracking_transformer.py
🧰 Additional context used
🪛 Ruff (0.8.0)
dreem/io/config.py
272-276: Use a single if
statement instead of nested if
statements
Combine if
statements using and
(SIM102)
dreem/models/transformer.py
112-112: Use self.encoder_cfg.get("ndim", 5)
instead of an if
block
Replace with self.encoder_cfg.get("ndim", 5)
(SIM401)
dreem/models/visual_encoder.py
9-9: skimage.feature.local_binary_pattern
imported but unused
Remove unused import: skimage.feature.local_binary_pattern
(F401)
🪛 yamllint (1.35.1)
dreem/training/configs/base.yaml
[error] 4-4: trailing spaces
(trailing-spaces)
[error] 9-9: trailing spaces
(trailing-spaces)
tests/configs/base.yaml
[error] 3-3: trailing spaces
(trailing-spaces)
🔇 Additional comments (5)
tests/configs/base.yaml (1)
7-11: Consider enabling pretrained weights for better feature extraction.
The ResNet encoder is configured with pretrained: false. Consider enabling pretrained weights, as they often provide better feature extraction capabilities, especially for tasks with limited training data.
dreem/training/configs/base.yaml (1)
10-19: Verify configuration handling in code.
Please ensure that:
- The code properly validates the encoder_type selection
- The relationship between use_hu_moments and ndim is enforced
- Configuration errors provide clear error messages
Let me gather more information about the descriptor encoder and its validation.
Let me search for the descriptor encoder implementation and validation logic in a different way.
Configuration validation is properly implemented
The codebase shows proper handling of the configuration concerns:
- The
encoder_type
validation is implemented increate_visual_encoder()
with a registry system and clear error messages for unknown types. - The relationship between
use_hu_moments
and descriptor dimensionality is handled in theDescriptorVisualEncoder
class, which correctly combines the features based on the configuration. - The code includes proper error handling with descriptive messages for invalid encoder types.
Additional observations:
- Default fallback to "resnet" encoder when type is not specified
- Clear separation of encoder-specific configurations
- Type validation through the registry system
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for configuration validation code
rg -A 5 "encoder_type.*=|encoder_cfg.*="
# Search for error handling related to configuration
rg -A 5 "raise.*Error.*encoder"
Length of output: 6104
Script:
#!/bin/bash
# Search for descriptor encoder implementation and validation
ast-grep --pattern 'class DescriptorEncoder {
$$$
}'
# Search for configuration validation in descriptor-related code
rg -A 10 "descriptor.*config|descriptor.*cfg"
# Search for ndim validation
rg -A 5 "ndim.*=|ndim.*\+"
Length of output: 882
Script:
#!/bin/bash
# Search for descriptor-related code
rg -A 10 "class.*Descriptor"
# Search for validation of Hu moments
rg -A 5 "hu_moments|use_hu"
# Search for any validation in the model creation
rg -A 10 "def create_.*model|def build_.*model"
Length of output: 4150
dreem/models/visual_encoder.py (1)
230-250
: Fix potential KeyError
in create_visual_encoder
The issue mentioned in the past review about potential KeyError when accessing encoder_cfg[encoder_type]
is still present. Please refer to the previous review comment for the suggested fix.
dreem/io/config.py (1)
271-277
: 🛠️ Refactor suggestion
Improve robustness and documentation of encoder configuration
The current implementation has several areas for improvement:
- Deep nesting of dictionary access could lead to KeyError exceptions
- The rationale for disabling normalization with descriptor encoder should be documented
- The string literal "descriptor" should be extracted to a constant
- The nested if statements can be combined
Apply this refactor:
+ # Constants
+ DESCRIPTOR_ENCODER_TYPE = "descriptor"
+
+ def _get_encoder_type(self) -> str | None:
+ """Get the encoder type from model config safely."""
+ try:
+ return self.cfg["model"]["encoder_cfg"]["encoder_type"]
+ except (KeyError, TypeError):
+ return None
+
- # TODO: handle this better!
- if "model" in self.cfg:
- if (
- "encoder_type" in self.cfg["model"]["encoder_cfg"]
- and self.cfg["model"]["encoder_cfg"]["encoder_type"] == "descriptor"
- ):
- dataset_params["normalize_image"] = False
+ # Disable normalization for descriptor encoder to prevent empty crops
+ # in low-intensity regions that could hinder feature calculation
+ if (encoder_type := self._get_encoder_type()) == DESCRIPTOR_ENCODER_TYPE:
+ dataset_params["normalize_image"] = False
🧰 Tools
🪛 Ruff (0.8.0)
272-276: Use a single if
statement instead of nested if
statements
Combine if
statements using and
(SIM102)
dreem/models/transformer.py (1)
265-265
: Ensure query_instances is assigned when None
When query_instances
is None
, only query_times
is assigned. This may lead to query_instances
being None
later, causing errors.
encoder_cfg: | ||
# based on encoder_type, the relevant settings are selected | ||
encoder_type: "resnet" | ||
ndim: 5 # dim of the visual feature vector; 5 for descriptor, 12 for descriptor + hu moments | ||
resnet: | ||
model_name: "resnet18" | ||
in_chans: 3 | ||
backend: "timm" | ||
pretrained: false | ||
descriptor: | ||
use_hu_moments: false |
💡 Codebase verification
Fourier embeddings are controlled by the use_fourier flag in the embedding configuration
Based on the codebase analysis, I found that:
- Fourier embeddings are implemented in dreem/models/transformer.py and dreem/models/embedding.py
- They are activated through the use_fourier flag in the embedding_meta configuration
- The base configuration file shows an embedding_meta section but lacks the use_fourier setting
The configuration should be updated to include:
- Add use_fourier: true under the embedding_meta section
- Optionally, consider adding Fourier-specific settings like n_components (default is 8 from the code)
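Based on that finding, a hedged example of the missing section; the key names follow the review's reading of the code and should be verified against embedding.py before committing:

```yaml
model:
  embedding_meta:
    use_fourier: true  # activates FourierPositionalEmbeddings in the transformer
    # n_components: 8  # optional Fourier setting; 8 is the hard-coded default
```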
🔗 Analysis chain
Verify configuration for Fourier embeddings.
The PR objectives mention learnable Fourier embeddings, but there's no configuration section for them in this file. Please verify if:
- Fourier embedding settings should be added to this configuration
- These settings are defined in another configuration file
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for Fourier embedding configurations
rg -l "fourier.*embedding" "tests/configs/"
Length of output: 43
Script:
#!/bin/bash
# Let's broaden the search to look for any Fourier-related configurations and implementation
rg -i "fourier" --type yaml tests/configs/ -A 5
Length of output: 47
Script:
#!/bin/bash
# Search for Fourier-related code in the entire codebase
rg -i "fourier" -g "!*.pyc" -g "!*.git*" -A 5
Length of output: 8183
Script:
#!/bin/bash
# Look for encoder types in configuration files
rg "encoder_type:" tests/configs/ -A 3
Length of output: 299
Script:
#!/bin/bash
# Let's check for embedding configurations in the config files
rg "embedding" tests/configs/ -A 5
Length of output: 860
Script:
#!/bin/bash
# Check if there are other configuration files that might contain the settings
fd -e yaml -e yml . tests/configs/
Length of output: 115
🧰 Tools
🪛 yamllint (1.35.1)
[error] 3-3: trailing spaces
(trailing-spaces)
```python
@torch.no_grad()
def forward(self, img: torch.Tensor) -> torch.Tensor:
    """Forward pass of feature extractor to get feature vector."""
    descriptors = []

    for im in img:
        im = im[0].cpu().numpy()

        inertia_tensor = self.compute_inertia_tensor(im)
        mean_intensity = im.mean()
        if self.use_hu_moments:
            hu_moments = self.compute_hu_moments(im)

        # Flatten inertia tensor
        inertia_tensor_flat = inertia_tensor.flatten()

        # Combine all features into a single descriptor
        descriptor = np.concatenate(
            [
                inertia_tensor_flat,
                [mean_intensity],
                hu_moments if self.use_hu_moments else [],
            ]
        )

        descriptors.append(torch.tensor(descriptor, dtype=torch.float32))

    return torch.stack(descriptors)
```
Improve forward method implementation
Several improvements are needed in the forward method:
- Missing input shape validation
- Inefficient CPU transfer
- Missing feature normalization
- Incomplete docstring
```diff
 @torch.no_grad()
 def forward(self, img: torch.Tensor) -> torch.Tensor:
-    """Forward pass of feature extractor to get feature vector."""
+    """Forward pass of feature extractor to get feature vector.
+
+    Args:
+        img: Input image tensor of shape (B, C, H, W). Expects grayscale images (C=1).
+
+    Returns:
+        feats: Normalized output tensor of shape (B, d_model).
+    """
+    if img.shape[1] != 1:
+        raise ValueError(f"Expected grayscale images with 1 channel, got {img.shape[1]} channels")
+
     descriptors = []
     for im in img:
-        im = im[0].cpu().numpy()
+        # Keep computation on GPU if available
+        im = im[0].numpy() if not torch.cuda.is_available() else im[0].cpu().numpy()
         inertia_tensor = self.compute_inertia_tensor(im)
         mean_intensity = im.mean()
         if self.use_hu_moments:
             hu_moments = self.compute_hu_moments(im)
-        # Flatten inertia tensor
         inertia_tensor_flat = inertia_tensor.flatten()
-        # Combine all features into a single descriptor
         descriptor = np.concatenate(
             [
                 inertia_tensor_flat,
                 [mean_intensity],
                 hu_moments if self.use_hu_moments else [],
             ]
         )
         descriptors.append(torch.tensor(descriptor, dtype=torch.float32))
-    return torch.stack(descriptors)
+    feats = torch.stack(descriptors)
+    if torch.cuda.is_available():
+        feats = feats.cuda()
+
+    # Project to d_model dimension and normalize
+    feats = self.projection(feats)
+    feats = F.normalize(feats)
+    return feats
```
Committable suggestion skipped: line range outside the PR's diff.
```python
def __init__(self, use_hu_moments: bool = False, **kwargs):
    """Initialize Descriptor Visual Encoder.

    Args:
        use_hu_moments: Whether to use Hu moments.
    """
    super().__init__()
    self.use_hu_moments = use_hu_moments
```
Add missing d_model parameter for consistency
The DescriptorVisualEncoder should maintain consistency with VisualEncoder by accepting and utilizing the d_model parameter to ensure proper feature dimensionality.
```diff
-def __init__(self, use_hu_moments: bool = False, **kwargs):
+def __init__(self, d_model: int, use_hu_moments: bool = False, **kwargs):
     """Initialize Descriptor Visual Encoder.

     Args:
+        d_model: Output embedding dimension.
         use_hu_moments: Whether to use Hu moments.
     """
     super().__init__()
+    self.d_model = d_model
     self.use_hu_moments = use_hu_moments
+    # Add projection layer to match d_model dimension
+    feature_dim = 4 + 1 + (7 if use_hu_moments else 0)  # inertia tensor (4) + mean (1) + optional hu moments (7)
+    self.projection = torch.nn.Linear(feature_dim, d_model)
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def __init__(self, d_model: int, use_hu_moments: bool = False, **kwargs):
    """Initialize Descriptor Visual Encoder.

    Args:
        d_model: Output embedding dimension.
        use_hu_moments: Whether to use Hu moments.
    """
    super().__init__()
    self.d_model = d_model
    self.use_hu_moments = use_hu_moments
    # Add projection layer to match d_model dimension
    feature_dim = 4 + 1 + (7 if use_hu_moments else 0)  # inertia tensor (4) + mean (1) + optional hu moments (7)
    self.projection = torch.nn.Linear(feature_dim, d_model)
```
```python
# apply fourier embeddings if using fourier rope, OR if using descriptor (compact) visual encoder
if (
    self.embedding_meta
    and "use_fourier" in self.embedding_meta
    and self.embedding_meta["use_fourier"]
) or (
    self.encoder_cfg
    and "encoder_type" in self.encoder_cfg
    and self.encoder_cfg["encoder_type"] == "descriptor"
):
    encoder_queries = apply_fourier_embeddings(
        encoder_queries,
        ref_times,
        self.d_model,
        self.fourier_embeddings,
        self.fourier_proj,
        self.fourier_norm,
    )
```
🛠️ Refactor suggestion
Refactor duplicate Fourier embedding logic
The code for applying Fourier embeddings is duplicated. Consider extracting the condition check into a method and reusing it.
```diff
+    def should_apply_fourier_embeddings(self) -> bool:
+        """Check if Fourier embeddings should be applied."""
+        return (
+            (self.embedding_meta and self.embedding_meta.get("use_fourier")) or
+            (self.encoder_cfg and self.encoder_cfg.get("encoder_type") == "descriptor")
+        )

-        if (
-            self.embedding_meta
-            and "use_fourier" in self.embedding_meta
-            and self.embedding_meta["use_fourier"]
-        ) or (
-            self.encoder_cfg
-            and "encoder_type" in self.encoder_cfg
-            and self.encoder_cfg["encoder_type"] == "descriptor"
-        ):
+        if self.should_apply_fourier_embeddings():
             encoder_queries = apply_fourier_embeddings(
                 encoder_queries,
                 ref_times,
                 self.d_model,
                 self.fourier_embeddings,
                 self.fourier_proj,
                 self.fourier_norm,
             )
```
Also applies to: 272-289
Summary by CodeRabbit
Release Notes
New Features
- DescriptorVisualEncoder class for feature extraction based on image descriptors.
- FourierPositionalEmbeddings for improved positional encoding in the transformer model.
Enhancements
Documentation