RoPE, Fourier RoPE, Fourier learnable embeddings #96
base: main
Conversation
- changes to embedding class - add apply() function to Embedding class - remove references to embedding from encoderlayer fwd pass
- add support for stack/avg/concatenate - move embedding processing out of transformer and into encoder
- get centroid from x,y for spatial embedding - complete stack agg method - add docstrings
- concatenation method with mlp - complete pre-processing for input to EncoderLayer - fix shape issues in rope/additive_embedding/forward modules in embedding.py
- bounding box embedding only for method "average" - modify emb_funcs routing - temporarily remove support for adding embeddings into instance objects - need to make compatible with x,y,t embeddings - remove config yamls from updates - current versions serve as templates - runs through to end of encoder forward pass
- implement embeddings for decoder + refactor - add 1x1 conv to final attn head to deal with stacked embeddings (3x tokens) and create channels for each dim - bug fix in rope rotation matrix product with input data
- 1x1 conv for stack embedding - stack into 3 channels for x,y,t
- add unit tests for rope - Update existing tests to use new args/return params related to tfmr - Modify test to remove return_embedding=True support - need to address this
- create rope instance once rather than each fwd pass - construct embedding lookup array each fwd pass based on num instances passed in to embedding - scale only pos embs * 100 rather than also temp embs
- times array for embedding for encoder queries inside decoder was of query size rather than ref size
- added tracker debugging script
- Fix RoPE to apply to q,k not v; pass in config to layers, update forward passes to enable this - Add support for learned Fourier coeffs for RoPE
- complete fourier rope; only change vs rope is cache creation as param not buffer - disabled encoder output embedding at decoder input (not standard practice)
…el rather than n_queries
- Apply embeddings to q,k,v instead of just q,k (but still add to orig queries rather than embedded queries)
…ce in each forward pass of transformer - add if/else branch to encoder/decoder layer to apply embeddings only to qk not v in case of rope mode
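As background for the q/k-only application mentioned in the commits above, here is a minimal, framework-free sketch of rotary position embedding: consecutive feature pairs are rotated by a position-dependent angle, and the rotation is applied to queries and keys but not values, so attention scores depend only on relative position. This is an illustrative sketch (function names and dimensions are invented for the example), not code from this PR:

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive feature pairs of vec by position-dependent angles.

    Pair j is rotated by pos * base**(-2j/d), the standard RoPE frequency.
    """
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def attend_score(q, k, pos_q, pos_k):
    """Dot-product attention score with RoPE applied to q and k only.

    Values are never rotated; only the q/k similarity is position-aware.
    """
    q_rot = rope_rotate(q, pos_q)
    k_rot = rope_rotate(k, pos_k)
    return sum(a * b for a, b in zip(q_rot, k_rot))

# The score depends only on the relative offset pos_k - pos_q:
q = [1.0, 0.0, 0.5, 0.2]
k = [0.3, 0.7, 0.1, 0.9]
print(abs(attend_score(q, k, 3, 1) - attend_score(q, k, 10, 8)) < 1e-9)  # True
```

Per the commit messages, the same idea is applied inside the encoder/decoder layers here: the rotation touches q and k before the attention product, while v keeps the unrotated features.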
Walkthrough

This pull request introduces various updates across multiple files in the project.

Changes
Sequence Diagram(s)

sequenceDiagram
participant User
participant Config
participant Dataset
participant Model
participant Tracker
User->>Config: Load configuration
Config->>Dataset: Initialize datasets
Dataset->>Model: Prepare data for training
Model->>Tracker: Start tracking process
Tracker->>Model: Process frames
Model->>User: Return results
Actionable comments posted: 35
🧹 Outside diff range and nitpick comments (20)
scripts/run_eval.py (1)

4-6: Remove debugging file paths from comments.

These absolute file paths appear to be debugging leftovers and should be removed to maintain clean code.

-# /Users/mustafashaikh/dreem/dreem/training
-# /Users/main/Documents/GitHub/dreem/dreem/training

dreem/models/__init__.py (1)

Line range hint 1-11: Consider defining __all__ to explicitly specify the public API.

To properly expose the public API and address the unused import warnings, consider adding an __all__ definition:

"""Model architectures and layers."""
from .embedding import Embedding, FourierPositionalEmbeddings
# from .mlp import MLP
# from .attention_head import ATTWeightHead
from .transformer import Transformer
from .visual_encoder import VisualEncoder
from .global_tracking_transformer import GlobalTrackingTransformer
from .gtr_runner import GTRRunner
+__all__ = [
+    'Embedding',
+    'FourierPositionalEmbeddings',
+    'Transformer',
+    'VisualEncoder',
+    'GlobalTrackingTransformer',
+    'GTRRunner',
+]

🧰 Tools
🪛 Ruff
3-3: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
3-3: .embedding.FourierPositionalEmbeddings imported but unused; consider removing, adding to __all__, or using a redundant alias (F401)
scripts/run_trainer.py (1)

1-20: Consider adding script documentation and type hints.

The script would benefit from docstrings and type hints to improve maintainability. Add a module docstring explaining the purpose and usage of the script:

"""Training script for the Dreem framework.

This script loads base and parameter configurations from YAML files and executes the training process. It supports configuration overrides through a separate parameter file.

Usage:
    python run_trainer.py --base-config path/to/base.yaml --params-config path/to/params.yaml
"""

🧰 Tools
🪛 Ruff
3-3: os imported but unused. Remove unused import: os (F401)

dreem/models/mlp.py (2)

37-39: Improve comment visibility and clarity.

The comment explaining list concatenations would be more visible and useful if moved above the list comprehension. Consider expanding it to clarify the layer shape compatibility concept.

+ # Using list concatenations ([input_dim] + h and h + [output_dim])
+ # to ensure proper layer shape transitions:
+ # input_dim -> hidden_dims -> output_dim
  [
      torch.nn.Linear(n, k)
-     # list concatenations to ensure layer shape compability
      for n, k in zip([input_dim] + h, h + [output_dim])
  ]
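As a quick illustration of the pairing that comment describes, here is a dimension-only sketch (the concrete sizes are hypothetical, chosen just to show the zip pairing):

```python
# Hypothetical dimensions, chosen only to illustrate the zip pairing.
input_dim, hidden_dim, output_dim, num_layers = 256, 512, 128, 3
h = [hidden_dim] * (num_layers - 1)

# Each (n, k) pair is one Linear layer's (in_features, out_features),
# walking input_dim -> hidden dims -> output_dim.
layer_shapes = list(zip([input_dim] + h, h + [output_dim]))
print(layer_shapes)  # [(256, 512), (512, 512), (512, 128)]
```

Each layer's out_features matches the next layer's in_features, which is the shape compatibility the original comment alludes to.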
Device management in MLP needs optimization for RoPE integration

The verification reveals that the current MLP implementation moves layers to the input device during each forward pass, which is inefficient. This is particularly important in the RoPE context where the MLP is used in:

- Query/Key projections in attention heads
- Embedding layer transformations
- Transformer layer computations

Key issues:

- layer.to(x.device) in MLP's forward method causes unnecessary device transfers on every forward pass
- Other components like RoPE embeddings handle device management during initialization or first use

🔗 Analysis chain

Line range hint 1-63: Verify MLP integration with RoPE implementation.

While the MLP implementation is generic, let's verify its usage within the RoPE context to ensure optimal performance.

Let me gather more context about the RoPE implementation and its interaction with MLP.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for MLP usage in relation to RoPE implementation
echo "Searching for MLP usage in transformer/embedding files..."
rg -l "class.*RoPE" | xargs rg "MLP"
echo "Checking for potential performance bottlenecks in model construction..."
ast-grep --pattern 'MLP( $$$ )'

Length of output: 2123

Script:

#!/bin/bash
# Check RoPE implementation and its interaction with MLP
echo "Checking RoPE implementation..."
rg -A 10 "class.*RoPE"
echo "\nChecking MLP usage in attention mechanism..."
rg -A 5 "def forward.*RoPE"
echo "\nChecking device handling in RoPE and related components..."
rg "to\(.*device\)" dreem/models/

Length of output: 2335
dreem/training/configs/base.yaml (1)

76-76: Document supported dataset modes.

The addition of the mode parameter to datasets is a good practice, but it needs documentation. Please add comments explaining:

- Available dataset modes
- How each mode affects dataset behavior
- Any mode-specific preprocessing or augmentation

Also applies to: 85-85, 94-94
dreem/training/configs/override.yaml (2)

1-142: Fix YAML formatting inconsistencies.

The YAML file has several formatting issues that should be addressed:

- Inconsistent indentation (especially in nested structures)
- Trailing spaces
- Missing newline at end of file
- Inconsistent spacing after commas and colons

Run yamllint locally to fix these issues or apply the following changes:

- Use 2 spaces for indentation consistently
- Remove trailing spaces
- Add newline at end of file
- Ensure consistent spacing after commas and colons
🧰 Tools
🪛 yamllint
[error] 3-3: trailing spaces
(trailing-spaces)
[warning] 4-4: wrong indentation: expected 4 but found 6
(indentation)
[warning] 20-20: wrong indentation: expected 6 but found 8
(indentation)
[warning] 24-24: wrong indentation: expected 6 but found 8
(indentation)
[warning] 29-29: too many spaces after colon
(colons)
[error] 43-43: trailing spaces
(trailing-spaces)
[warning] 46-46: wrong indentation: expected 6 but found 8
(indentation)
[warning] 58-58: too few spaces after comma
(commas)
[warning] 66-66: too few spaces after comma
(commas)
[error] 76-76: trailing spaces
(trailing-spaces)
[warning] 78-78: wrong indentation: expected 4 but found 6
(indentation)
[error] 78-78: trailing spaces
(trailing-spaces)
[error] 99-99: trailing spaces
(trailing-spaces)
[error] 112-112: trailing spaces
(trailing-spaces)
[error] 130-130: trailing spaces
(trailing-spaces)
[warning] 132-132: wrong indentation: expected 2 but found 4
(indentation)
[warning] 133-133: wrong indentation: expected 6 but found 8
(indentation)
[warning] 137-137: wrong indentation: expected 6 but found 8
(indentation)
[error] 142-142: no new line character at the end of file
(new-line-at-end-of-file)
139-142: Remove or document commented configuration.

The view_batch configuration is commented out without explanation. Either:

- Remove it if it's not needed
- Document why it's kept as reference
🧰 Tools
🪛 yamllint
[error] 142-142: no new line character at the end of file
(new-line-at-end-of-file)
dreem/training/train.py (4)

Line range hint 57-73: Add error handling for dataset initialization.

The code assumes all datasets will be successfully loaded, but according to related changes, get_dataset() may return None. Consider adding validation to handle cases where datasets are unavailable.

 train_dataset = train_cfg.get_dataset(mode="train")
+if train_dataset is None:
+    raise ValueError("Training dataset could not be loaded")
 train_dataloader = train_cfg.get_dataloader(train_dataset, mode="train")
 val_dataset = train_cfg.get_dataset(mode="val")
+if val_dataset is None:
+    logger.warning("Validation dataset not available")
 val_dataloader = train_cfg.get_dataloader(val_dataset, mode="val")
 test_dataset = train_cfg.get_dataset(mode="test")
+if test_dataset is None:
+    logger.warning("Test dataset not available")
 test_dataloader = train_cfg.get_dataloader(test_dataset, mode="test")

Line range hint 74-82: Consider memory optimization for batch viewing.

The batch viewing feature loads a batch into memory even when view_batch.no_train is True. Consider moving the batch loading inside the if block to avoid unnecessary memory usage.

 dataset = TrackingDataset(
     train_dl=train_dataloader, val_dl=val_dataloader, test_dl=test_dataloader
 )
 if cfg.view_batch.enable:
+    if cfg.view_batch.no_train:
+        instances = next(iter(train_dataset))
+        view_training_batch(instances, num_frames=cfg.view_batch.num_frames)
+        return
     instances = next(iter(train_dataset))
     view_training_batch(instances, num_frames=cfg.view_batch.num_frames)
-
-    if cfg.view_batch.no_train:
-        return
Line range hint 83-83: Address TODO comment regarding torch.compile().

The TODO comment suggests exploring torch.compile() usage. This could potentially improve model performance. Would you like me to help implement torch.compile() support or create an issue to track this enhancement?
RoPE implementation exists but configuration needs attention

The codebase has a complete RoPE implementation in dreem/models/embedding.py with both standard and Fourier variants, but the current configuration in params.yaml uses mode: learned for positional embeddings. The training infrastructure properly handles RoPE through the embedding system, but:

- The model configuration needs to be updated to use mode: rope for positional embeddings
- When using RoPE, the embedding aggregation method must be either 'stack' or 'concatenate', not 'average'
- The Fourier variant requires additional parameters (fourier_n_components) to be configured

🔗 Analysis chain
Line range hint 56-87: Verify RoPE configuration handling.

Given that this PR implements RoPE and Fourier RoPE, let's verify that the training configuration properly handles the related parameters.

Let me try another verification approach to check for RoPE-related configurations and implementations in the codebase.

Let me check the actual RoPE implementation and configuration details to make a final assessment.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for RoPE-related configuration handling
# Test 1: Search for RoPE configuration in config files
echo "Checking for RoPE configuration..."
rg -l "rope" configs/
# Test 2: Search for RoPE-related model parameters
echo "Checking for RoPE implementation..."
ast-grep --pattern 'class $_ { $$$ rope($$$) { $$$ } $$$ }'

Length of output: 304

Script:

#!/bin/bash
# Search for RoPE-related configurations and implementations across the codebase
echo "Searching for RoPE-related files and code..."
rg -l "rope|RoPE"
# Look for configuration-related files
echo -e "\nSearching for configuration files..."
fd -e yaml -e yml -e json
# Check for model configuration handling
echo -e "\nChecking model configuration handling..."
rg -A 5 "get_gtr_runner|get_model"

Length of output: 2868

Script:

#!/bin/bash
# Check RoPE implementation in transformer and embedding files
echo "Checking transformer.py for RoPE implementation..."
rg -A 10 "rope|RoPE" dreem/models/transformer.py
echo -e "\nChecking embedding.py for RoPE implementation..."
rg -A 10 "rope|RoPE" dreem/models/embedding.py
# Check configuration parameters
echo -e "\nChecking configuration parameters..."
cat dreem/training/configs/params.yaml

Length of output: 12157
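Based on the constraints identified above, a configuration that actually enables RoPE might look like the following fragment. This is an illustrative sketch: the key names come from this review's findings (mode, embedding_agg_method, use_fourier, fourier_n_components), but the nesting and the fourier_n_components value are hypothetical placeholders.

```yaml
model:
  embedding_meta:
    embedding_agg_method: "stack"  # RoPE requires "stack" or "concatenate", not "average"
    pos:
      mode: "rope"                 # instead of "learned"
    temp:
      mode: "rope"
    use_fourier: true              # enable the Fourier RoPE variant
    fourier_n_components: 8        # hypothetical value; required by the Fourier variant
```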
dreem/io/config.py (2)

43-44: Document the implications of disabling structure protection.

Setting struct=False allows for dynamic updates but could lead to unintended modifications. Consider adding a comment explaining why structure protection is disabled and any precautions that should be taken.

202-202: LGTM! Improved error handling for empty datasets.

Good addition of explicit handling for empty datasets. Consider making the warning message more specific about why the dataset might be empty (e.g., no matching files, filtering criteria too strict, etc.).

-logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+logger.warn(f"Length of {mode} dataset is {len(dataset)}! This might indicate no matching files or strict filtering criteria. Returning None")

Also applies to: 256-259
🧰 Tools
🪛 Ruff
202-202: Undefined name
SleapDataset
(F821)
202-202: Undefined name
MicroscopyDataset
(F821)
202-202: Undefined name
CellTrackingDataset
(F821)
dreem/io/instance.py (1)

568-572: Enhance documentation for embedding aggregation methods.

The comment briefly mentions different embedding aggregation methods ("average", "stack", "concatenate") but their requirements and behaviors aren't fully documented. Consider adding a docstring section explaining the supported embedding aggregation methods:

def add_embedding(self, emb_type: str, embedding: torch.Tensor | dict) -> None:
    """Save embedding to instance embedding dictionary.

    Args:
        emb_type: Key/embedding type to be saved to dictionary
        embedding: The actual embedding data, which can be:
            - A torch.Tensor for "average" aggregation method
            - A dict for "stack" or "concatenate" aggregation methods

    Note:
        The embedding format depends on the aggregation method:
        - "average": Expects a tensor that will be expanded to rank 2
        - "stack"/"concatenate": Expects a dict containing the embedding components
    """

🧰 Tools
🪛 Ruff
569-569: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
dreem/inference/eval.py (1)

Line range hint 63-63: Remove unnecessary f prefix in string.

The log message does not contain any placeholders, so the f prefix is unnecessary. Replace logger.info(f"Computing the following metrics:") with logger.info("Computing the following metrics:").

Apply this diff to fix the issue:

-logger.info(f"Computing the following metrics:")
+logger.info("Computing the following metrics:")

🧰 Tools
🪛 Ruff
30-30: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
56-56: f-string without any placeholders
Remove extraneous
f
prefix(F541)
dreem/models/attention_head.py (1)

22-25: Simplify the assignment of self.embedding_agg_method.

You can simplify the code by using the dict.get() method. Apply this diff to simplify the code:

-if "embedding_agg_method" in kwargs:
-    self.embedding_agg_method = kwargs["embedding_agg_method"]
-else:
-    self.embedding_agg_method = None
+self.embedding_agg_method = kwargs.get("embedding_agg_method", None)

🧰 Tools
🪛 Ruff
22-25: Use
self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
instead of anif
blockReplace with
self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
(SIM401)
dreem/inference/track.py (2)

118-118: Verify that the checkpoint path exists before loading.

Ensuring the checkpoint file exists can prevent runtime errors and provide clearer feedback if the file is missing. Consider adding a check for the checkpoint file:

 checkpoint = pred_cfg.cfg.ckpt_path
+if not os.path.exists(checkpoint):
+    raise FileNotFoundError(f"Checkpoint file not found at {checkpoint}")

Line range hint 131-132: Use os.makedirs(outdir, exist_ok=True) efficiently.

The check "outdir" in pred_cfg.cfg might not be necessary if pred_cfg.cfg.get('outdir', './results') is used to provide a default value. Simplify the outdir assignment:

-outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
+outdir = pred_cfg.cfg.get('outdir', './results')
 os.makedirs(outdir, exist_ok=True)

🧰 Tools
🪛 Ruff
100-100: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
128-128: f-string without any placeholders
Remove extraneous
f
prefix(F541)
dreem/inference/tracker.py (1)

209-211: Rename frames_to_track for better clarity.

The variable frames_to_track may not clearly convey its purpose. Consider renaming it to tracking_window or current_window_frames to improve readability and clarity. Apply this diff to rename the variable:

-frames_to_track = tracked_frames + [frame_to_track]  # better var name?
+tracking_window = tracked_frames + [frame_to_track]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
⛔ Files ignored due to path filters (1)

- dreem/training/configs/test_batch_train.csv is excluded by !**/*.csv

📒 Files selected for processing (23)

- .gitignore (1 hunks)
- dreem/datasets/sleap_dataset.py (3 hunks)
- dreem/inference/eval.py (1 hunks)
- dreem/inference/post_processing.py (3 hunks)
- dreem/inference/track.py (2 hunks)
- dreem/inference/tracker.py (14 hunks)
- dreem/io/config.py (7 hunks)
- dreem/io/instance.py (1 hunks)
- dreem/models/__init__.py (1 hunks)
- dreem/models/attention_head.py (2 hunks)
- dreem/models/embedding.py (16 hunks)
- dreem/models/gtr_runner.py (2 hunks)
- dreem/models/mlp.py (2 hunks)
- dreem/models/transformer.py (14 hunks)
- dreem/training/configs/base.yaml (6 hunks)
- dreem/training/configs/override.yaml (1 hunks)
- dreem/training/train.py (2 hunks)
- scripts/run_eval.py (1 hunks)
- scripts/run_tracker.py (1 hunks)
- scripts/run_trainer.py (1 hunks)
- tests/test_inference.py (2 hunks)
- tests/test_models.py (12 hunks)
- tests/test_training.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- .gitignore
- tests/test_training.py
🧰 Additional context used
🪛 Ruff
dreem/datasets/sleap_dataset.py
163-163: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
dreem/inference/eval.py
30-30: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
56-56: f-string without any placeholders
Remove extraneous f
prefix
(F541)
dreem/inference/track.py
100-100: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
dreem/inference/tracker.py
196-196: Loop control variable i
not used within loop body
Rename unused i
to _i
(B007)
200-200: Loop control variable i
not used within loop body
Rename unused i
to _i
(B007)
dreem/io/config.py
202-202: Undefined name SleapDataset
(F821)
202-202: Undefined name MicroscopyDataset
(F821)
202-202: Undefined name CellTrackingDataset
(F821)
263-263: Undefined name SleapDataset
(F821)
263-263: Undefined name MicroscopyDataset
(F821)
263-263: Undefined name CellTrackingDataset
(F821)
dreem/io/instance.py
569-569: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
dreem/models/__init__.py
3-3: .embedding.Embedding
imported but unused; consider removing, adding to __all__
, or using a redundant alias
(F401)
3-3: .embedding.FourierPositionalEmbeddings
imported but unused; consider removing, adding to __all__
, or using a redundant alias
(F401)
dreem/models/attention_head.py
22-25: Use self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
instead of an if
block
Replace with self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
(SIM401)
dreem/models/embedding.py
166-166: f-string without any placeholders
Remove extraneous f
prefix
(F541)
dreem/models/transformer.py
85-87: Use embedding_meta.get("embedding_agg_method", "average")
instead of an if
block
Replace with embedding_meta.get("embedding_agg_method", "average")
(SIM401)
90-92: Use embedding_meta.get("use_fourier", False)
instead of an if
block
Replace with embedding_meta.get("use_fourier", False)
(SIM401)
scripts/run_eval.py
1-1: dreem.training.train
imported but unused
Remove unused import: dreem.training.train
(F401)
scripts/run_tracker.py
3-3: os
imported but unused
Remove unused import: os
(F401)
scripts/run_trainer.py
3-3: os
imported but unused
Remove unused import: os
(F401)
tests/test_models.py
18-18: dreem.models.transformer.apply_embeddings
imported but unused
Remove unused import: dreem.models.transformer.apply_embeddings
(F401)
168-168: pytest.raises(Exception)
should be considered evil
(B017)
493-493: Local variable query_pos_emb
is assigned to but never used
Remove assignment to unused variable query_pos_emb
(F841)
🪛 yamllint
dreem/training/configs/base.yaml
[warning] 19-19: wrong indentation: expected 6 but found 8
(indentation)
[warning] 22-22: wrong indentation: expected 6 but found 8
(indentation)
dreem/training/configs/override.yaml
[error] 3-3: trailing spaces
(trailing-spaces)
[warning] 4-4: wrong indentation: expected 4 but found 6
(indentation)
[warning] 20-20: wrong indentation: expected 6 but found 8
(indentation)
[warning] 24-24: wrong indentation: expected 6 but found 8
(indentation)
[warning] 29-29: too many spaces after colon
(colons)
[error] 43-43: trailing spaces
(trailing-spaces)
[warning] 46-46: wrong indentation: expected 6 but found 8
(indentation)
[warning] 58-58: too few spaces after comma
(commas)
[warning] 66-66: too few spaces after comma
(commas)
[error] 76-76: trailing spaces
(trailing-spaces)
[warning] 78-78: wrong indentation: expected 4 but found 6
(indentation)
[error] 78-78: trailing spaces
(trailing-spaces)
[error] 99-99: trailing spaces
(trailing-spaces)
[error] 112-112: trailing spaces
(trailing-spaces)
[error] 130-130: trailing spaces
(trailing-spaces)
[warning] 132-132: wrong indentation: expected 2 but found 4
(indentation)
[warning] 133-133: wrong indentation: expected 6 but found 8
(indentation)
[warning] 137-137: wrong indentation: expected 6 but found 8
(indentation)
[error] 142-142: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (26)
scripts/run_eval.py (1)
1-12: Verify the configuration file structure.
Let's verify the structure of the inference configuration file to ensure it matches the expected format.
✅ Verification successful
Configuration structure matches the expected format

The inference configuration file in tests/configs/inference.yaml follows the correct structure as documented in dreem/inference/configs/README.md. It contains all the required sections:

- Valid tracker configuration with correct parameters
- Properly structured dataset.test_dataset with required fields
- Correct dataloader.test_dataloader settings

The test configuration appears to be a simplified version of the reference config, suitable for testing purposes.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check the structure of the inference configuration file
# Test 1: Verify the config file exists
if [ -f "tests/configs/inference.yaml" ]; then
echo "Found inference.yaml"
cat "tests/configs/inference.yaml"
else
echo "Warning: inference.yaml not found at expected location"
fi
# Test 2: Search for other inference config examples in the codebase
echo "Searching for other inference config examples:"
rg -g "*.yaml" -g "*.yml" -l "inference|eval" "tests/" "configs/"
Length of output: 804
Script:
#!/bin/bash
# Let's verify if the configuration structure matches the expected format by checking:
# 1. Usage of this config in the codebase
# 2. Other similar config files that might serve as reference
# Check how the config is used in the codebase
ast-grep --pattern 'eval.run($_)'
# Look for similar config files in the entire repository
fd -e yaml -e yml
# Check for any documentation about config structure
fd README -t f
Length of output: 459
Script:
#!/bin/bash
# Let's check the reference config files and documentation
# Check the reference inference config
cat dreem/inference/configs/inference.yaml
# Check the config documentation
cat dreem/inference/configs/README.md
# Check the implementation of eval.run to understand expected config structure
ast-grep --pattern 'def run($$$)'
Length of output: 7574
🧰 Tools
🪛 Ruff
1-1: dreem.training.train
imported but unused
Remove unused import: dreem.training.train
(F401)
dreem/models/__init__.py (1)
3-3: LGTM! Import aligns with new Fourier positional embeddings feature.
The addition of FourierPositionalEmbeddings
import is consistent with the PR's objective of implementing RoPE and Fourier RoPE functionality.
🧰 Tools
🪛 Ruff
3-3: .embedding.Embedding
imported but unused; consider removing, adding to __all__
, or using a redundant alias
(F401)
3-3: .embedding.FourierPositionalEmbeddings
imported but unused; consider removing, adding to __all__
, or using a redundant alias
(F401)
scripts/run_tracker.py (1)
10-12: Add error handling for configuration loading and execution.
The script lacks error handling for common issues like missing config files or invalid configurations.
Consider adding proper error handling:
-cfg = OmegaConf.load(config)
-track.run(cfg)
+try:
+ if not os.path.exists(config):
+ raise FileNotFoundError(f"Config file not found: {config}")
+
+ cfg = OmegaConf.load(config)
+ track.run(cfg)
+except FileNotFoundError as e:
+ print(f"Error: {e}")
+ sys.exit(1)
+except Exception as e:
+ print(f"Error during tracking: {e}")
+ sys.exit(1)
Also, consider adding input validation for the config structure before running the tracker.
Let's verify the expected config structure:
✅ Verification successful
Let's examine the actual config files to understand their structure and validate the need for error handling.
Let's check if there's any existing error handling in the track module and how the config is validated.
Let's check the config validation implementation and test files to see if there's existing error handling we should be aware of.
Based on the code analysis, I can now provide the final response:
Error handling is already implemented in the Config class, but script-level handling is still needed
The Config class has built-in error handling for:
- Config merging and validation through OmegaConf
- Parameter updates with try-except blocks
- File path resolution and dataset validation
- Proper logging of errors
However, the script should still handle file existence and basic config loading errors:
try:
if not os.path.exists(config):
raise FileNotFoundError(f"Config file not found: {config}")
cfg = OmegaConf.load(config)
track.run(cfg)
except FileNotFoundError as e:
print(f"Error: {e}")
sys.exit(1)
except Exception as e:
print(f"Error during tracking: {e}")
sys.exit(1)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check the structure of existing inference YAML configs
# to understand the expected format
# Search for YAML files that might contain inference configurations
fd -e yaml -e yml | xargs rg -l 'inference|track'
Length of output: 201
Script:
#!/bin/bash
# Check contents of inference config files to understand structure
echo "=== dreem/inference/configs/inference.yaml ==="
cat dreem/inference/configs/inference.yaml
echo -e "\n=== tests/configs/inference.yaml ==="
cat tests/configs/inference.yaml
# Check if track.run has any built-in error handling
ast-grep --pattern 'def run($$$)'
Length of output: 1408
Script:
#!/bin/bash
# Look for track.run implementation and any validation
rg -A 5 "def run"
# Check for any config validation or error handling
rg -l "OmegaConf|ConfigurationError|ValidationError"
Length of output: 1279
Script:
#!/bin/bash
# Check config validation implementation
echo "=== dreem/io/config.py ==="
cat dreem/io/config.py
# Check test files for error handling examples
echo -e "\n=== tests/test_config.py ==="
cat tests/test_config.py
Length of output: 19202
dreem/training/configs/base.yaml (2)
137-148: Verify single-epoch training configuration.
The training configuration has been set to run for exactly one epoch, which seems unusually short for training deep learning models, especially with RoPE embeddings.
Is this intentional or is it meant for development/testing purposes only? Consider:
- Adding a comment explaining why single-epoch training is used
- Creating separate configurations for development and production training
- Documenting recommended epoch settings for production use
19-23: Verify RoPE implementation and enhance documentation.
The configuration now supports RoPE for both positional and temporal embeddings, which aligns with the PR objectives. However, we should ensure proper implementation and documentation.
Consider adding docstrings or comments explaining:
- The differences between "fixed", "learned", and "rope" modes
- The impact of normalization on RoPE embeddings
- The behavior of each embedding_agg_method ("stack", "average", "concatenate")
✅ Verification successful
Let me gather more specific information about the RoPE implementation.
RoPE implementation is properly implemented with correct constraints
The codebase shows a complete and well-implemented RoPE (Rotary Positional Encoding) system:
- Proper implementation of both standard RoPE (
RotaryPositionalEmbeddings
) and Fourier RoPE (FourierRotaryPositionalEmbeddings
) with correct mathematical operations - Appropriate handling of embedding modes with validation:
- RoPE mode correctly restricts aggregation methods to 'stack' or 'concatenate' only
- Proper application of RoPE to queries and keys but not values in transformer layers
- Documentation exists explaining the implementation details and references to original papers
The configuration options in base.yaml accurately reflect the implemented functionality, and the constraints on embedding aggregation methods are properly enforced in the code.
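The rotate-queries-and-keys-but-not-values pattern noted above can be sketched in a few lines. This is a minimal illustration of standard RoPE, not the project's actual `RotaryPositionalEmbeddings` class; the function names and shapes here are assumptions:

```python
import torch

def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    # One inverse frequency per (even, odd) channel pair.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (seq_len, head_dim); rotate each (even, odd) channel pair by its position's angle.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

# Rotation goes on queries and keys only -- values are left untouched.
q, k, v = (torch.randn(8, 64) for _ in range(3))
cos, sin = build_rope_cache(seq_len=8, head_dim=64)
q_rot, k_rot = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
attn_out = (q_rot @ k_rot.T).softmax(dim=-1) @ v  # v used as-is
```

Because each pair is rotated, the transform preserves vector norms and leaves position 0 unchanged, which is a quick sanity check for any implementation of this technique.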
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Verify RoPE implementation in the codebase
echo "Checking for RoPE implementation in embedding modules..."
rg -l "rope" --type py
# Look for embedding mode handling
ast-grep --pattern 'mode = $_'
Length of output: 389
Script:
#!/bin/bash
# Check RoPE implementation in embedding.py
echo "=== Checking embedding.py for RoPE implementation ==="
rg -A 10 "class.*Embedding" dreem/models/embedding.py
# Check for mode handling in embedding.py
echo "=== Checking mode handling in embedding.py ==="
rg -B 2 -A 5 "mode.*=.*" dreem/models/embedding.py
# Check transformer.py for RoPE usage
echo "=== Checking transformer.py for RoPE implementation ==="
rg -B 2 -A 5 "rope" dreem/models/transformer.py
# Check for embedding aggregation methods
echo "=== Checking embedding aggregation methods ==="
rg -B 2 -A 5 "stack|average|concatenate" dreem/models/embedding.py
Length of output: 10430
🧰 Tools
🪛 yamllint
[warning] 19-19: wrong indentation: expected 6 but found 8
(indentation)
[warning] 22-22: wrong indentation: expected 6 but found 8
(indentation)
dreem/training/configs/override.yaml (1)
18-25
:
Missing RoPE configuration in embedding settings.
The PR objectives indicate implementation of RoPE (Rotary Positional Encoding) and Fourier RoPE, but the embedding configuration only shows fixed positional encoding. Consider adding RoPE-specific configuration parameters here.
Example configuration structure for RoPE:
embedding_meta:
  pos:
    mode: "rope"  # instead of "fixed"
    dim: 128  # rotary embedding dimension
    base: 10000  # base for frequency computation
    normalize: true
🧰 Tools
🪛 yamllint
[warning] 20-20: wrong indentation: expected 6 but found 8
(indentation)
[warning] 24-24: wrong indentation: expected 6 but found 8
(indentation)
dreem/training/train.py (2)
56-56
: LGTM! Good code organization.
The added blank line improves readability by separating model initialization from dataset setup.
87-87
: LGTM! Good code organization.
The added blank line improves readability by separating accelerator setup from device count determination.
dreem/inference/post_processing.py (2)
161-162
: Review distance normalization approach.
The normalization method has been changed from using box sizes to using mean across dimensions. This change, combined with the commented-out code and TODO about pixel scaling, suggests uncertainty about the best approach.
Questions to consider:
- Why was the box size normalization removed?
- How does the mean normalization affect the distance threshold max_center_dist?
- Should we maintain both normalization methods as configurable options?
Let's analyze the impact of this change:
#!/bin/bash
# Search for tests related to max_center_dist to understand expected behavior
rg -A 10 "max_center_dist.*test" --type python
# Look for configuration files that might specify distance thresholds
fd -e yaml -e json | xargs rg "max_center_dist"
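If both behaviors are worth keeping, one option is a small configurable normalizer. This is a hedged sketch only — the function name, signature, and the exact old/new formulas are illustrative assumptions, not the project's actual code:

```python
import torch

def normalize_center_dist(dist: torch.Tensor, boxes: torch.Tensor = None,
                          mode: str = "mean") -> torch.Tensor:
    """Normalize a (n_k, n_nonk) matrix of squared center distances.

    mode="mean": divide by the mean distance (the newly introduced behavior).
    mode="box":  divide by a per-query box scale (the removed behavior).
    """
    eps = 1e-8
    if mode == "mean":
        return dist / (dist.mean() + eps)
    if mode == "box":
        # boxes: (n_k, 4) as (x1, y1, x2, y2); scale each row by sqrt(box area)
        scale = ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])).sqrt()
        return dist / (scale[:, None] + eps)
    raise ValueError(f"Unknown normalization mode: {mode}")
```

Exposing `mode` through the tracker config would let both normalizations be benchmarked against the same `max_center_dist` threshold before one is removed.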
154-154
: 🛠️ Refactor suggestion
Verify tracking behavior with window-based nonk_boxes.
The TODO comment indicates that nonk_boxes
should only include boxes from the previous frame, but currently uses boxes from the entire window. This could affect tracking accuracy and performance.
Let's verify how nonk_boxes
is being used across the codebase:
Consider implementing a frame selection mechanism:
+ # Extract boxes only from the previous frame
+ prev_frame_mask = torch.zeros_like(nonk_boxes[:, 0, 0], dtype=torch.bool)
+ prev_frame_mask[-1] = True # Select only the last frame
+ nonk_boxes = nonk_boxes[prev_frame_mask]
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
dim=-1
) # n_k x Np
Also applies to: 158-160
tests/test_inference.py (1)
218-219
: LGTM! Test updates match the new function signature.
The addition of h
and w
parameters to filter_max_center_dist
calls correctly reflects the updated function signature. The test maintains its original validation logic while accommodating the new dimensional parameters.
Let's verify the function signature change:
Also applies to: 231-232
dreem/models/gtr_runner.py (2)
304-304
: LGTM: More robust video name extraction.
The change to use path splitting instead of extension splitting is more robust as it handles paths with multiple dots correctly.
313-325
: Verify error handling for gt_track_id access.
The code assumes that every instance has a valid gt_track_id
tensor. Consider adding error handling to gracefully handle cases where gt_track_id
might be None or not a tensor.
✅ Verification successful
Based on the shell script results, I can now generate the final response as there is sufficient evidence from the codebase.
The gt_track_id access is properly handled in the codebase
The Instance class in dreem/io/instance.py
has proper handling for gt_track_id:
- It has a default value of -1 when not set
- The setter method handles None case by setting an empty tensor
- The has_gt_track_id() method checks for empty tensor before access
- The getter returns a tensor which is safe to call .item() on
The code in question is safe because:
- All instances are created with valid gt_track_id tensors through proper initialization
- The frame.instances list only contains valid Instance objects
- The .item() access is protected by the Instance class implementation
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for any error handling around gt_track_id access in the codebase
# Test 1: Look for any None checks or error handling around gt_track_id
rg -A 3 "gt_track_id"
# Test 2: Look for any instance creation or modification of gt_track_id
ast-grep --pattern 'instance.$$$gt_track_id$$$'
Length of output: 21974
dreem/datasets/sleap_dataset.py (2)
109-109
: LGTM: Improved video reader management
The change from a list to a dictionary for storing video readers is a good improvement. It provides efficient lookup by video name and prevents duplicate loading of the same video.
Line range hint 1-370
: Verify relationship with RoPE implementation
While the changes improve video handling, it's unclear how this file relates to the PR's objective of implementing RoPE (Rotary Positional Encoding) and Fourier RoPE.
Let's check for RoPE-related code in the codebase:
🧰 Tools
🪛 Ruff
163-163: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
dreem/io/config.py (2)
179-181
: LGTM! Sorting ensures deterministic file ordering.
The addition of sorting to glob results is a good practice that ensures consistent and reproducible behavior across different systems.
420-420
: Document and verify checkpoint naming change.
The change from 'best' to 'final' in checkpoint naming could impact existing scripts or workflows. Please:
- Document this change in the code comments
- Verify that dependent code is updated accordingly
✅ Verification successful
Let me try a different search strategy to verify the checkpoint naming impact.
Based on the search results, I can now generate the final response:
Checkpoint naming change appears safe to proceed
The change from 'best' to 'final' in the checkpoint naming pattern appears safe as:
- No hardcoded references to 'best' checkpoint names were found in the codebase
- All checkpoint loading operations use path variables rather than hardcoded patterns
- The checkpoint loading mechanism (
load_from_checkpoint
) is consistently used across files
However, consider adding a comment above the line to document this naming convention:
# Format: {epoch}-final-{metric} (e.g., "10-final-accuracy")
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for references to checkpoint filenames containing 'best'
echo "Searching for potential impacts of checkpoint naming change..."
rg -l 'best.*checkpoint' --type py
# Search for checkpoint loading code
echo "Searching for checkpoint loading patterns..."
ast-grep --pattern 'load_from_checkpoint($_)'
Length of output: 318
Script:
#!/bin/bash
# Search for any references to checkpoint names or patterns
echo "Searching for checkpoint name references..."
rg -A 2 'CHECKPOINT_NAME' --type py
# Search for checkpoint loading or saving patterns
echo "Searching for checkpoint operations..."
rg -A 2 '(save|load).*checkpoint' --type py
# Search for documentation or comments about checkpoints
echo "Searching for checkpoint documentation..."
rg -B 1 'checkpoint' -g '*.md' -g '*.rst' -g '*.txt'
Length of output: 2353
tests/test_models.py (2)
38-40
: LGTM: Correct addition of embedding_agg_method parameter

The ATTWeightHead is correctly instantiated with the new embedding_agg_method parameter set to "average".
210-250
: LGTM: Proper implementation of test_rope_embedding
The test_rope_embedding
function appropriately tests the RoPE embedding functionality, ensuring correct output dimensions and validating that the embeddings are distinct.
dreem/inference/tracker.py (4)
354-355
: Include height and width when scaling boxes.
Ensure that h
and w
correctly represent the image height and width for scaling bounding boxes. This is important for accurate post-processing in filtering based on center distance.
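For reference, the center-of-box scaling being verified here can be written as a one-step helper. This is an illustrative sketch under the assumption that boxes are normalized `(x1, y1, x2, y2)` coordinates — the helper name and layout are not the project's actual code:

```python
import torch

def boxes_to_pixel_centers(boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """boxes: (N, 4) normalized (x1, y1, x2, y2) in [0, 1]; returns (N, 2) centers in pixels."""
    cx = (boxes[:, 0] + boxes[:, 2]) / 2 * w  # x scales with image width
    cy = (boxes[:, 1] + boxes[:, 3]) / 2 * h  # y scales with image height
    return torch.stack([cx, cy], dim=-1)
```

Mixing up `h` and `w` here silently skews all pairwise distances on non-square frames, which is why the reviewer asks to confirm which dimension each variable holds.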
Line range hint 479-489
: Clarify threshold logic in track assignment.
The condition in the if-statement determines whether to assign an instance to an existing track or create a new one. Ensure that the logic correctly reflects the intended behavior, especially regarding the threshold comparison.
Double-check the threshold comparison to confirm that it aligns with the intended logic. If the intention is to create a new track when the association score is lower than the threshold, the condition may need to be adjusted.
288-291
:
Handle potential empty tensors in concatenation.
When concatenating features in reid_features
, ensure that all tensors are non-empty to avoid runtime errors. If any frame does not have instances, frame.get_features()
may return an empty tensor.
Consider adding a check before concatenation:
reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[
None
] # (1, total_instances, D=512)
+ if reid_features.size(1) == 0:
+ logger.warning("No instance features available for tracking.")
+ return query_frame
Likely invalid or redundant comment.
Line range hint 261-307
: Ensure proper handling of empty frame instances.
When concatenating features for frames, ensure that frames without instances are handled correctly to prevent runtime errors due to empty tensors.
Run the following script to check for frames without instances:
dreem/models/embedding.py (3)
131-140
: Addition of RoPE mode in embeddings looks good
The implementation of the 'rope'
mode embeddings and the instantiation of the appropriate positional embeddings classes enhance the flexibility of the Embedding
class. The conditional logic correctly handles the creation of rope_instance
based on the use_fourier
flag.
174-227
: New embedding transformation methods are well-implemented
The methods _transform
, _apply_rope
, and _apply_additive_embeddings
effectively handle the embedding transformations based on the selected aggregation method and embedding mode. The code is clear and logically structured.
228-248
: Ensure all calls to Embedding.forward
are updated
The forward
method signature has been modified to include an additional parameter x
and now returns both xout
and emb
. Please verify that all calls to Embedding.forward
in the codebase are updated to match the new signature and handle the returned tuple appropriately.
Run the following script to identify usages of Embedding.forward
that may need updating:
scripts/run_eval.py
Outdated
from dreem.training import train
from omegaconf import OmegaConf
Fix import statements and remove unused imports.
The script has import-related issues:
- The
train
import is unused and should be removed - The
eval
module is used but not imported
Apply this diff to fix the imports:
-from dreem.training import train
+from dreem.inference import eval
from omegaconf import OmegaConf
🧰 Tools
🪛 Ruff
1-1: dreem.training.train
imported but unused
Remove unused import: dreem.training.train
(F401)
scripts/run_eval.py
Outdated
inference_config = "tests/configs/inference.yaml"

cfg = OmegaConf.load(inference_config)

eval.run(cfg)
🛠️ Refactor suggestion
Improve configuration handling and add error handling.
The script needs improvements in configuration handling:
- Hardcoded configuration path should be configurable
- Missing error handling for configuration loading
- No validation of loaded configuration
Consider refactoring like this:
-inference_config = "tests/configs/inference.yaml"
-
-cfg = OmegaConf.load(inference_config)
-
-eval.run(cfg)
+import argparse
+import sys
+from pathlib import Path
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run model evaluation')
+ parser.add_argument('--config', type=str, required=True,
+ help='Path to inference configuration file')
+ return parser.parse_args()
+
+def main():
+ args = parse_args()
+ config_path = Path(args.config)
+
+ if not config_path.exists():
+ print(f"Error: Configuration file not found: {config_path}")
+ sys.exit(1)
+
+ try:
+ cfg = OmegaConf.load(config_path)
+ eval.run(cfg)
+ except Exception as e:
+ print(f"Error during evaluation: {str(e)}")
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
scripts/run_tracker.py
Outdated
from dreem.inference import track
from omegaconf import OmegaConf
import os
Clean up imports
The os
module is imported but not used in the code.
Apply this diff to remove the unused import:
from dreem.inference import track
from omegaconf import OmegaConf
-import os
🧰 Tools
🪛 Ruff
3-3: os
imported but unused
Remove unused import: os
(F401)
scripts/run_tracker.py
Outdated
# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"
🛠️ Refactor suggestion
Remove development paths and use configurable config path
The code contains hardcoded development paths in comments and a hardcoded absolute path to the config file. This makes the script less portable and harder to use in different environments.
Consider:
- Remove the commented-out development paths
- Make the config path configurable via command-line arguments using
argparse
:
from dreem.inference import track
from omegaconf import OmegaConf
+import argparse
-# /Users/mustafashaikh/dreem/dreem/training
-# /Users/main/Documents/GitHub/dreem/dreem/training
-# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
-config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run tracker with YAML config')
+ parser.add_argument('--config', required=True, help='Path to inference YAML config file')
+ return parser.parse_args()
+
+args = parse_args()
+config = args.config
scripts/run_trainer.py
Outdated
# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
Remove personal directory paths from comments.
These comments contain personal directory paths that should not be in the codebase.
Remove these commented paths as they:
- Expose personal information
- Don't provide value to other developers
- May pose security risks
def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
    if mode.lower() == "rope" and self.embedding_agg_method == "average":
        raise ValueError(
            f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
Remove unnecessary f
prefix in the error message
Line 166 contains an f-string without any placeholders. The f
prefix is unnecessary and can be removed to clean up the code.
Apply this diff to fix the issue:
-f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
+"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
🧰 Tools
🪛 Ruff
166-166: f-string without any placeholders
Remove extraneous f
prefix
(F541)
proj = nn.Linear(queries_cat.shape[-1], queries.shape[-1]).to(queries_cat.device)
norm = nn.LayerNorm(queries.shape[-1]).to(queries_cat.device)

queries = proj(queries_cat)
queries = norm(queries)
Avoid instantiating nn
modules inside functions to ensure parameters are registered
Defining nn.Linear
and nn.LayerNorm
inside the apply_fourier_embeddings
function will result in new instances being created every time the function is called. Their parameters won't be registered with the model, so they won't be updated during training, leading to unexpected behavior.
Consider moving the instantiation of proj
and norm
outside the function, possibly as attributes of the class. Would you like assistance in refactoring this code to ensure proper parameter registration?
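The registration pattern the comment recommends can be sketched as below — create the layers once in `__init__` so that `parameters()` sees them and the optimizer trains them. The class and attribute names here are illustrative assumptions, not the project's actual code:

```python
import torch
import torch.nn as nn

class FourierEmbeddingHead(nn.Module):
    """Illustrative fix: build the projection/norm once so their weights are registered."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Assigned as attributes in __init__, so nn.Module registers their parameters.
        self.proj = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, queries_cat: torch.Tensor) -> torch.Tensor:
        return self.norm(self.proj(queries_cat))
```

By contrast, an `nn.Linear` built inside a plain function is re-initialized with fresh random weights on every call and never appears in the model's `parameters()`, so gradient updates to it are lost.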
mlp = MLP(
    input_dim=queries_t.shape[-1] * 3,  # t,x,y
    hidden_dim=queries_t.shape[-1] * 6,  # not applied when num_layers=1
    output_dim=queries_t.shape[-1],
    num_layers=1,
    dropout=0.0,
)
Avoid creating nn
modules inside functions to ensure parameters are registered
Instantiating the MLP
module inside the collate_queries
function means that its parameters won't be registered for optimization, and it will be recreated on every call. This prevents the model from learning the correct weights during training.
Consider defining the MLP
instance outside the function, such as in the __init__
method of a class, and passing it as a parameter. Would you like assistance in refactoring this code to properly register the MLP
parameters?
    encoder_queries, pos_emb=ref_emb
)  # (total_instances, batch_size, embed_dim)
# apply fourier embeddings
if "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]:
🛠️ Refactor suggestion
Simplify condition by using self.use_fourier
Since self.use_fourier
already holds the boolean value, you can simplify the conditional check by using it directly.
Apply this diff to simplify the condition:
- if "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]:
+ if self.use_fourier:
Committable suggestion was skipped due to low confidence.
    embedding_meta["embedding_agg_method"]
    if "embedding_agg_method" in embedding_meta
    else "average"
)
self.use_fourier = (
    embedding_meta["use_fourier"]
    if "use_fourier" in embedding_meta
    else False
)
self.fourier_n_components = (
    embedding_meta["fourier_n_components"]
    if "fourier_n_components" in embedding_meta
    else self.d_model // 2
)
🛠️ Refactor suggestion
Simplify dictionary key access using get()
method
Instead of using if
statements to check for the existence of keys in the embedding_meta
dictionary, you can use the get()
method to make the code cleaner and more readable.
Apply this diff to simplify the code:
- self.embedding_agg_method = (
- embedding_meta["embedding_agg_method"]
- if "embedding_agg_method" in embedding_meta
- else "average"
- )
+ self.embedding_agg_method = embedding_meta.get("embedding_agg_method", "average")
- self.use_fourier = (
- embedding_meta["use_fourier"]
- if "use_fourier" in embedding_meta
- else False
- )
+ self.use_fourier = embedding_meta.get("use_fourier", False)
- self.fourier_n_components = (
- embedding_meta["fourier_n_components"]
- if "fourier_n_components" in embedding_meta
- else self.d_model // 2
- )
+ self.fourier_n_components = embedding_meta.get("fourier_n_components", self.d_model // 2)
🧰 Tools
🪛 Ruff
85-87: Use embedding_meta.get("embedding_agg_method", "average")
instead of an if
block
Replace with embedding_meta.get("embedding_agg_method", "average")
(SIM401)
90-92: Use embedding_meta.get("use_fourier", False)
instead of an if
block
Replace with embedding_meta.get("use_fourier", False)
(SIM401)
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
dreem/datasets/sleap_dataset.py (2)
127-127
: Add input validation in get_indices

The method should validate that the input index is within bounds of the available chunks.

  def get_indices(self, idx: int) -> tuple:
+     if not 0 <= idx < len(self.label_idx):
+         raise IndexError(f"Index {idx} is out of bounds for dataset with {len(self.label_idx)} chunks")
      return self.label_idx[idx], self.chunked_frame_idx[idx]
Line range hint
150-170
: Fix unused variables

There are two unused variables that should be addressed:
- Rename the unused loop variable i to _
- Remove the unused exception variable e
- for i, frame_ind in enumerate(frame_idx):
+ for _, frame_ind in enumerate(frame_idx):
      (
          instances,
          gt_track_ids,
          poses,
          shown_poses,
          point_scores,
          instance_score,
      ) = ([], [], [], [], [], [])
      frame_ind = int(frame_ind)
      lf = video[frame_ind]
      try:
          img = lf.image
-     except FileNotFoundError as e:
+     except FileNotFoundError:
          if video_name not in self.vid_readers:
              self.vid_readers[video_name] = sio.load_video(video_name)

🧰 Tools
🪛 Ruff
166-166: Local variable
e
is assigned to but never used. Remove assignment to unused variable
e
(F841)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- dreem/datasets/base_dataset.py (1 hunks)
- dreem/datasets/sleap_dataset.py (5 hunks)
✅ Files skipped from review due to trivial changes (1)
- dreem/datasets/base_dataset.py
🧰 Additional context used
🪛 Ruff
dreem/datasets/sleap_dataset.py
150-150: Loop control variable i
not used within loop body
Rename unused i
to _i
(B007)
166-166: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
🔇 Additional comments (2)
dreem/datasets/sleap_dataset.py (2)
109-109
: LGTM! Good improvement in video reader management
Converting vid_readers
from a list to a dictionary improves efficiency by providing O(1) lookup and preventing duplicate video reader instances.
373-373
: Previous review comment about video reader cleanup is still valid
Please refer to the previous review comment about improving the video reader cleanup implementation.
…for extra attention block
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (7)
dreem/inference/eval.py (1)
63-64
: Remove unnecessary f-strings and improve logging format.

The f-string prefix is unnecessary when there are no placeholders. Also, consider using a more structured format for logging metrics.

-logger.info(f"Computing the following metrics:")
-logger.info(model.metrics['test'])
+logger.info("Computing the following metrics:")
+logger.info("Test metrics: %s", model.metrics['test'])
-logger.info(f"Computing the following metrics:") -logger.info(model.metrics['test']) +logger.info("Computing the following metrics:") +logger.info("Test metrics: %s", model.metrics['test'])🧰 Tools
🪛 Ruff
63-63: f-string without any placeholders
Remove extraneous
f
prefix(F541)
dreem/models/attention_head.py (3)
22-25
: Simplify kwargs handling using dict.get()

Replace the if-else block with a single line using dict.get() for cleaner code:
- if "embedding_agg_method" in kwargs: - self.embedding_agg_method = kwargs["embedding_agg_method"] - else: - self.embedding_agg_method = None + self.embedding_agg_method = kwargs.get("embedding_agg_method", None)🧰 Tools
🪛 Ruff
22-25: Use
self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
instead of anif
blockReplace with
self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
(SIM401)
19-19
: Improve parameter documentation

The embedding_agg_method parameter documentation should specify the valid options ("stack" or None) and their implications.

- embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate
+ embedding_agg_method: Method for aggregating embeddings. When set to "stack",
+     uses separate attention heads for x, y, t dimensions with 1D convolution.
+     When None, uses standard MLP projections.
- embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate + embedding_agg_method: Method for aggregating embeddings. When set to "stack", + uses separate attention heads for x, y, t dimensions with 1D convolution. + When None, uses standard MLP projections.
82-84
: Consider parallelizing attention computations

The attention computations for t, x, y dimensions are performed sequentially. Consider using
torch.nn.ModuleList
and parallel computation for better performance.+ # Store attention heads in ModuleList for parallel computation + self.attention_heads = torch.nn.ModuleList([ + self.attn_t, self.attn_x, self.attn_y + ]) + # Compute attention in parallel + attention_outputs = [ + head(query=query[:, i, :], key=key_orig, value=key_orig)[0] + for i, head in enumerate(self.attention_heads) + ] + collated = torch.stack(attention_outputs, dim=0).permute(1, 0, 2)dreem/training/losses.py (3)
84-84
: Enhance shape documentation clarity.

While the shape documentation is helpful, the comment for asso_pred could be more precise. Consider clarifying what "total_instances across all frames" means, perhaps by using mathematical notation or providing an example.

- # asso_pred is shape (total_instances across all frames, total_instances across all frames)
+ # asso_pred is shape (M, M) where M = sum(n_t), and n_t is the number of instances in each frame

Also applies to: 86-86
222-223
: Improve shape documentation precision.

The shape documentation is helpful but could be more precise:
- Use mathematical notation for clarity
- Explain why +1 is needed for background class
- Document how track IDs map to class labels
- # asso_preds_with_bg is shape (total_instances across all frames, n_t + 1) where +1 is for background class
- # asso_gt_t is shape (total_instances across all frames); each entry is the track id of the gt instance i.e. the class label
+ # asso_preds_with_bg: shape (M, n_t + 1) where:
+ #     - M = sum(n_t) is total instances across all frames
+ #     - n_t is instances in current frame
+ #     - +1 dimension handles background class for unmatched predictions
+ # asso_gt_t: shape (M,) containing track IDs as class labels
+ #     - track ID in [0, n_t - 1] for matched instances
+ #     - track ID = n_t for background class
Line range hint `8-9`: **Consider implementing the TODO suggestions for better code maintainability.**

The TODOs suggest using named tensors and cleaning up names/comments. Named tensors could help prevent shape-related bugs and improve code clarity, especially given the complex tensor operations in this file.
Would you like help implementing named tensors or suggestions for name improvements?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- dreem/inference/eval.py (1 hunks)
- dreem/models/attention_head.py (2 hunks)
- dreem/training/losses.py (4 hunks)
🧰 Additional context used
🪛 Ruff
dreem/inference/eval.py

30-30: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)

56-56: f-string without any placeholders; remove extraneous `f` prefix (F541)

63-63: f-string without any placeholders; remove extraneous `f` prefix (F541)

dreem/models/attention_head.py

22-25: Use `self.embedding_agg_method = kwargs.get("embedding_agg_method", None)` instead of an `if` block (SIM401)
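Both flagged rules reduce to one-line idioms; a quick self-contained illustration:

```python
kwargs = {"embedding_agg_method": "stack"}

# SIM401: use dict.get with a default instead of an if/else lookup
agg = kwargs.get("embedding_agg_method", None)
print(agg)  # stack

# SIM118: test membership on the dict itself; .keys() is redundant
print("embedding_agg_method" in kwargs)  # True
```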
🔇 Additional comments (9)
dreem/inference/eval.py (5)
`30-30`: **Simplify dictionary membership test.**

The existing review comment about using `if "batch_config" in cfg:` instead of `if "batch_config" in cfg.keys():` is still valid.
`34-36`: **Add input validation for task index.**

The existing review comment about adding error handling for non-integer user input is still valid.
`56-58`: **Improve logging consistency.**

The existing review comments about:

- Removing the unnecessary `f` prefix in `logger.info(f"Using the following tracker:")`
- Replacing `print(model.tracker)` with `logger.info(model.tracker)`

are still valid.
Line range hint `22-22`: **Consider maintaining specific type hints.**

The return type annotation has been changed from `dict[int, sio.Labels]` to `dict`. While this makes the function more flexible, it reduces type safety. If possible, specify the expected key and value types to help with static type checking and code clarity.
`42-44`: **Verify RoPE implementation integration.**

The PR objectives mention implementing RoPE and Fourier RoPE, but this file only shows batch configuration changes. Please verify that the hyperparameters being set here properly configure the RoPE implementation.
✅ Verification successful
Let me try a different search strategy to verify the RoPE implementation.
This broader search should help us:
- Find any RoPE-related code even if not in class definitions
- Locate positional encoding implementations that might be related
- Understand what hyperparameters are being configured in the codebase
Based on the search results, I can now provide a final response as there's clear evidence about the RoPE implementation and its integration with hyperparameters.
RoPE implementation is properly integrated with hyperparameter configuration
The codebase shows a complete RoPE implementation with proper hyperparameter integration:
- RoPE is implemented in `dreem/models/embedding.py` with both standard and Fourier variants
- The embedding configuration is properly handled through hyperparameters:
  - Mode can be set to "rope" in the embedding configuration
  - Supports both standard RoPE and Fourier RoPE variants
  - Validates embedding aggregation methods ("stack" or "concatenate" required for RoPE)
- The transformer model in `dreem/models/transformer.py` correctly applies RoPE:
  - Only applies to query and key vectors, not values
  - Properly integrates with both positional and temporal embeddings

The hyperparameter setting in `eval.py` will correctly configure RoPE when specified in the config.
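The "query and key only" behavior noted above follows from how RoPE works: positions enter the attention score through a rotation of q and k, so values need no rotation. Below is a dependency-free sketch of the standard pairwise rotation, as an illustration of the technique rather than the project's actual implementation:

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive (even, odd) feature pairs of `vec` by a
    position-dependent angle, as in standard RoPE."""
    d = len(vec)
    out = [0.0] * d
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)      # per-pair frequency
        c, s = math.cos(theta), math.sin(theta)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Key property: <rotate(q, m), rotate(k, n)> depends only on the relative
# offset m - n, never on absolute positions -- and values are left untouched.
q = [1.0, 0.0, 0.5, 0.5]
k = [0.2, 0.9, 0.1, 0.4]
d1 = dot(rope_rotate(q, 3), rope_rotate(k, 1))
d2 = dot(rope_rotate(q, 12), rope_rotate(k, 10))
print(abs(d1 - d2) < 1e-9)  # True
```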
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for RoPE-related configurations and implementations
echo "Searching for RoPE configurations..."
rg -i "rope" --type yaml
echo "Searching for RoPE implementations..."
ast-grep --pattern 'class $_(RoPE|RotaryPositionalEncoding) { $$$ }'
Length of output: 301
Script:
#!/bin/bash
# Search for RoPE-related code and configurations more broadly
echo "Searching for RoPE in Python files..."
rg -i "rope" --type py -A 5
echo -e "\nSearching for positional encoding implementations..."
rg -i "position.*encod" --type py -A 5
echo -e "\nSearching for hyperparameter configurations..."
rg "set_hparams" --type py -A 5
Length of output: 31626
dreem/models/attention_head.py (2)

`90-94`: **Inconsistent use of projections**

The previous review comment about inconsistent use of projections is still valid. The code uses `k = self.k_proj(key_stacked)` but leaves `q = self.q_proj(query)` commented out, which could lead to inconsistencies in attention weight computation.

Either uncomment and use both projections:

```diff
 k = self.k_proj(key_stacked).transpose(1, 0)
-# q = self.q_proj(query).transpose(1, 0)
+q = self.q_proj(query).transpose(1, 0)
```

Or document why only the key projection is needed:

```diff
-# TODO: uncomment this if not using modified attention heads for t,x,y
+# Note: Only key projection is used because query projection is handled by
+# separate attention heads (attn_x, attn_y, attn_t)
```
`65-72`: **Verify tensor reshaping logic**

The reshaping operations assume the input tensors have dimensions divisible by 3. This could cause runtime errors if the assumption doesn't hold.

Consider adding input validation:

```diff
+ if num_window_instances % 3 != 0 or num_query_instances % 3 != 0:
+     raise ValueError("Number of instances must be divisible by 3 when using stack mode")
```
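The guard above can be sketched together with the split it protects. The block ordering (t, then x, then y) is an assumption for illustration and may differ from the actual reshaping in `attention_head.py`:

```python
def split_stacked_tokens(tokens):
    """Split a stacked token sequence into (t, x, y) groups,
    failing loudly when the count is not a multiple of 3."""
    n = len(tokens)
    if n % 3 != 0:
        raise ValueError(
            f"Number of instances ({n}) must be divisible by 3 in stack mode"
        )
    third = n // 3
    return tokens[:third], tokens[third:2 * third], tokens[2 * third:]

t, x, y = split_stacked_tokens(list(range(9)))
print(t, x, y)  # [0, 1, 2] [3, 4, 5] [6, 7, 8]
```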
dreem/training/losses.py (2)

Line range hint `44-63`: **Review commented-out zero loss conditions.**

The commented-out block contains important logic for handling edge cases (empty videos or single-instance videos). Consider:

- Why was this code commented out?
- Are these edge cases now handled differently?
- Should this be restored with improved conditions?

```shell
#!/bin/bash
# Search for other zero loss handling in the codebase
rg "return.*zero.*loss" --type py
```
`125-125`: **Verify fixed detection assumption.**

The comment states that `pred_box` and `target_box` are the same for fixed detections, resulting in a diagonal IoU matrix. Let's verify this assumption across the codebase.

✅ Verification successful

Based on the search results, I can now provide a final response:

**Fixed detection assumption is verified correct**

The code in `dreem/training/losses.py` explicitly shows that `target_box` is directly assigned the value of `pred_box`:

```python
pred_box = get_boxes(instances)
pred_box = torch.nanmean(pred_box, axis=1)
target_box, target_time = pred_box, pred_time
```

This confirms that for fixed detections, the prediction and target boxes are indeed identical, which results in a diagonal IoU matrix when computing `box_iou(pred_box, target_box)`.
🏁 Scripts executed

The following scripts were executed for the analysis:

```shell
#!/bin/bash
# Search for places where pred_box and target_box are set to understand if they can ever differ
rg -A 5 "pred_box|target_box" --type py
```

Length of output: 4497
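Why identical boxes yield a diagonal IoU matrix can be seen with a minimal pure-Python IoU (a sketch, not the `box_iou` used in the codebase):

```python
def iou(a, b):
    """IoU of two axis-aligned boxes in (x1, y1, x2, y2) form."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area = lambda box: (box[2] - box[0]) * (box[3] - box[1])
    union = area(a) + area(b) - inter
    return inter / union if union else 0.0

boxes = [(0, 0, 2, 2), (1, 1, 3, 3), (10, 10, 12, 12)]
# target_box = pred_box, so every box has IoU 1.0 with itself
matrix = [[iou(p, t) for t in boxes] for p in boxes]
print([matrix[i][i] for i in range(3)])  # [1.0, 1.0, 1.0]
```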
```python
hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)
```
**Add CSV file validation and error handling.**

The code assumes the CSV file exists and contains the expected columns. Consider adding validation:

- Check if the CSV file exists
- Verify required columns are present
- Validate the index is within bounds of the DataFrame

```diff
+ if not os.path.exists(cfg.batch_config):
+     raise FileNotFoundError(f"Batch config file {cfg.batch_config} not found")
  hparams_df = pd.read_csv(cfg.batch_config)
+ if len(hparams_df) <= index:
+     raise IndexError(f"Task index {index} exceeds number of configurations {len(hparams_df)}")
  hparams = hparams_df.iloc[index].to_dict()
  _ = hparams.pop("Unnamed: 0", None)
```
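A stdlib-only variant of the same validation, for reference (the helper name is hypothetical; the pandas version suggested above works equally well):

```python
import csv
import os

def load_batch_hparams(path, index):
    """Load one row of hyperparameters from a batch-config CSV,
    validating the file and the task index first."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"Batch config file {path} not found")
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))
    if index >= len(rows):
        raise IndexError(
            f"Task index {index} exceeds number of configurations {len(rows)}"
        )
    hparams = rows[index]
    hparams.pop("", None)  # drop an unnamed pandas index column, if present
    return hparams
```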
Based off Mustafa-rope (that PR can be closed and not merged). Implements RoPE and Fourier RoPE.

Summary by CodeRabbit

New Features

- `.gitignore` updated to ignore specific XML and YAML files.
- `SleapDataset` improved for video data management and error handling.
- Updates to the `Tracker` class.
- `Embedding` class enhanced for flexibility.

Bug Fixes

Documentation

Tests