
RoPE, Fourier RoPE, Fourier learnable embeddings #96

Open · wants to merge 81 commits into base: main
Conversation

@shaikh58 (Contributor) commented Oct 30, 2024

Based on Mustafa-rope (that PR can be closed without merging). Implements RoPE and Fourier RoPE positional embeddings.
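For background, rotary position embeddings (RoPE) rotate pairs of feature dimensions by a position-dependent angle, so that relative offsets between tokens survive in the attention dot product. A minimal self-contained sketch of the idea — illustrative only, not the `Embedding` implementation in this PR:

```python
import torch

def rotary_embed(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (n_tokens, d_model).

    positions: (n_tokens,) positions (e.g. frame times or centroid coords).
    d_model must be even: dims are treated as d_model/2 (cos, sin) pairs.
    """
    n, d = x.shape
    assert d % 2 == 0, "rotary embedding requires an even feature dim"
    # standard RoPE frequency schedule
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = positions[:, None].float() * inv_freq[None, :]  # (n, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]  # split into rotation pairs
    # rotate each (x1, x2) pair by its position-dependent angle
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(5, 8)
q_rot = rotary_embed(q, torch.arange(5))
# rotation preserves each token's feature norm
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5))
```

Note the multiplicative rotation contrasts with additive learned embeddings; this is why the commits below route RoPE to queries/keys only rather than adding it to values.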

Summary by CodeRabbit

  • New Features

    • Enhanced .gitignore to ignore specific XML and YAML files.
    • Updated SleapDataset for improved video data management and error handling.
    • Modified inference scripts to support batch configuration and improved hyperparameter handling.
    • Added new methods and improved tracking functionality in the Tracker class.
    • Introduced new embedding options in the Embedding class for enhanced flexibility.
  • Bug Fixes

    • Improved error handling in various inference scripts and dataset management.
  • Documentation

    • Added comments and clarifications across multiple files to improve code readability and maintainability.
  • Tests

    • Expanded test cases for embedding and model functionality to ensure robustness against new features.

shaikh58 and others added 30 commits July 31, 2024 15:23
- changes to embedding class
- add apply() function to Embedding class
- remove references to embedding from encoderlayer fwd pass
- add support for stack/avg/concatenate
- move embedding processing out of transformer and into encoder
- get centroid from x,y for spatial embedding
- complete stack agg method
- add docstrings
- concatenation method with mlp
- complete pre-processing for input to EncoderLayer
- fix shape issues in rope/additive_embedding/forward modules in embedding.py
- bounding box embedding only for method "average" - modify emb_funcs routing
- temporarily remove support for adding embeddings into instance objects - need to make compatible with x,y,t embeddings
- remove config yamls from updates - current versions serve as templates
- runs through to end of encoder forward pass
- implement embeddings for decoder + refactor
- add 1x1 conv to final attn head to deal with stacked embeddings (3x tokens) and create channels for each dim
- bug fix in rope rotation matrix product with input data
- 1x1 conv for stack embedding
- stack into 3 channels for x,y,t
- add unit tests for rope
- Update existing tests to use new args/return params related to tfmr
- Modify test to remove return_embedding=True support - need to address this
- create rope instance once rather than each fwd pass
- construct embedding lookup array each fwd pass based on num instances passed in to embedding
- scale only pos embs * 100 rather than also temp embs
- times array for embedding for encoder queries inside decoder was of query size rather than ref size
shaikh58 and others added 19 commits September 27, 2024 09:48
- times array for embedding for encoder queries inside decoder was of query size rather than ref size
- Fix RoPE to apply to q,k not v; pass in config to layers, update forward passes to enable this
- Add support for learned Fourier coeffs for RoPE
- complete fourier rope; only change vs rope is cache creation as param not buffer
- disabled encoder output embedding at decoder input (not standard practice)
- Apply embeddings to q,k,v instead of just q,k (but still add to orig queries rather than embedded queries)
…ce in each forward pass of transformer

- add if/else branch to encoder/decoder layer to apply embeddings only to qk not v in case of rope mode
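The distinction the commits draw between RoPE and the Fourier-learnable variant — the rotation cache being created as a parameter rather than a buffer — can be sketched as follows (hypothetical class, not the PR's actual code):

```python
import torch

class RotationCache(torch.nn.Module):
    """Holds per-dimension rotation frequencies for rotary embeddings.

    Fixed RoPE: frequencies are a registered buffer (moved with .to(device),
    saved in state_dict, but never trained). Learnable Fourier variant:
    frequencies are an nn.Parameter and receive gradients.
    """

    def __init__(self, d_model: int, learnable: bool = False):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        if learnable:
            # Fourier variant: coefficients trained alongside the model
            self.inv_freq = torch.nn.Parameter(inv_freq)
        else:
            # fixed RoPE: follows the model's device/dtype but stays frozen
            self.register_buffer("inv_freq", inv_freq)

fixed = RotationCache(8)
learned = RotationCache(8, learnable=True)
print(sum(p.numel() for p in fixed.parameters()))    # 0: a buffer is not a parameter
print(sum(p.numel() for p in learned.parameters()))  # 4: d_model/2 learnable freqs
```

Both variants still appear in `state_dict()`, so checkpoints round-trip either way; only the optimizer sees the learnable one.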
coderabbitai bot (Contributor) commented Oct 30, 2024

Walkthrough

This pull request introduces various updates across multiple files in the project. Key changes include enhancements to the .gitignore file to exclude specific YAML and XML files, modifications to the SleapDataset class for improved video data management, and updates to several inference and training scripts to support batch configurations and new embedding methods. Additionally, several configuration files have been added or modified to streamline the training process and enhance logging. Overall, these changes aim to improve the flexibility, clarity, and functionality of the codebase.

Changes

File Change Summary
.gitignore Added patterns: *.xml, dreem/training/configs/base.yaml, dreem/training/configs/override.yaml to ignore specific files.
dreem/datasets/sleap_dataset.py Renamed self.videos to self.vid_readers; updated get_instances method for error handling; added multiple new variables for enhanced functionality.
dreem/inference/eval.py Updated run function to check for batch_config instead of checkpoints; modified hyperparameter handling to read from CSV.
dreem/inference/post_processing.py Updated filter_max_center_dist function to include new parameters h and w; added comments for clarity.
dreem/inference/track.py Modified run function to handle batch_config; improved hyperparameter reading and error handling.
dreem/inference/tracker.py Enhanced Tracker class with improved tracking logic and added methods for sliding inference and global tracking.
dreem/io/config.py Updated Config class methods to handle None return types and improved error logging.
dreem/io/instance.py Modified add_embedding method to handle different types of embeddings; updated pose setter logic.
dreem/models/__init__.py Added import for FourierPositionalEmbeddings.
dreem/models/attention_head.py Updated ATTWeightHead to accept embedding_agg_method in constructor; modified initialization logic.
dreem/models/embedding.py Enhanced Embedding class with new parameters and modes; added new methods for handling different embedding types.
dreem/models/gtr_runner.py Simplified video name extraction and ground truth track ID handling in on_test_epoch_end.
dreem/models/mlp.py Added comments for clarity in __init__ and forward methods.
dreem/models/transformer.py Significant updates to positional encoding, decoder output, and embedding handling; introduced new helper functions.
dreem/training/configs/base.yaml Added embedding_agg_method; modified dataset configurations and training parameters.
dreem/training/configs/override.yaml New configuration file defining model, dataset, and training parameters.
dreem/training/train.py Enhanced dataset handling and training configuration; added batch viewing option.
tests/test_inference.py Updated test_post_processing to include new parameters for filter_max_center_dist.
tests/test_models.py Expanded tests for Embedding class and added new tests for RoPE embedding functionality.
tests/test_training.py Commented out anomaly detection in test_config_gtr_runner.
dreem/datasets/base_dataset.py Added comments for clarity in create_chunks method.
dreem/training/losses.py Enhanced documentation for tensor shapes in AssoLoss class methods.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Config
    participant Dataset
    participant Model
    participant Tracker

    User->>Config: Load configuration
    Config->>Dataset: Initialize datasets
    Dataset->>Model: Prepare data for training
    Model->>Tracker: Start tracking process
    Tracker->>Model: Process frames
    Model->>User: Return results

🐇 "In the code, a change we see,
With videos tracked so easily!
YAMLs ignored, configs refined,
A smoother path for all aligned.
With hops and jumps, we now can play,
In the world of code, hip-hip-hooray!" 🐇



coderabbitai bot (Contributor) left a comment

Actionable comments posted: 35

🧹 Outside diff range and nitpick comments (20)
scripts/run_eval.py (1)

4-6: Remove debugging file paths from comments.

These absolute file paths appear to be debugging leftovers and should be removed to maintain clean code.

-# /Users/mustafashaikh/dreem/dreem/training
-# /Users/main/Documents/GitHub/dreem/dreem/training
dreem/models/__init__.py (1)

Line range hint 1-11: Consider defining __all__ to explicitly specify the public API.

To properly expose the public API and address the unused import warnings, consider adding an __all__ definition:

"""Model architectures and layers."""

from .embedding import Embedding, FourierPositionalEmbeddings

# from .mlp import MLP
# from .attention_head import ATTWeightHead

from .transformer import Transformer
from .visual_encoder import VisualEncoder

from .global_tracking_transformer import GlobalTrackingTransformer
from .gtr_runner import GTRRunner

+__all__ = [
+    'Embedding',
+    'FourierPositionalEmbeddings',
+    'Transformer',
+    'VisualEncoder',
+    'GlobalTrackingTransformer',
+    'GTRRunner',
+]
🧰 Tools
🪛 Ruff

3-3: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


3-3: .embedding.FourierPositionalEmbeddings imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

scripts/run_trainer.py (1)

1-20: Consider adding script documentation and type hints.

The script would benefit from docstrings and type hints to improve maintainability.

Add a module docstring explaining the purpose and usage of the script:

"""
Training script for the Dreem framework.

This script loads base and parameter configurations from YAML files and executes
the training process. It supports configuration overrides through a separate
parameter file.

Usage:
    python run_trainer.py --base-config path/to/base.yaml --params-config path/to/params.yaml
"""
🧰 Tools
🪛 Ruff

3-3: os imported but unused

Remove unused import: os

(F401)

dreem/models/mlp.py (2)

37-39: Improve comment visibility and clarity.

The comment explaining list concatenations would be more visible and useful if moved above the list comprehension. Consider expanding it to clarify the layer shape compatibility concept.

+            # Using list concatenations ([input_dim] + h and h + [output_dim])
+            # to ensure proper layer shape transitions:
+            # input_dim -> hidden_dims -> output_dim
             [
                 torch.nn.Linear(n, k)
-                # list concatenations to ensure layer shape compability
                 for n, k in zip([input_dim] + h, h + [output_dim])
             ]

Device management in MLP needs optimization for RoPE integration

The verification reveals that the current MLP implementation moves layers to the input device during each forward pass, which is inefficient. This is particularly important in the RoPE context where the MLP is used in:

  • Query/Key projections in attention heads
  • Embedding layer transformations
  • Transformer layer computations

Key issues:

  • layer.to(x.device) in MLP's forward method causes unnecessary device transfers on every forward pass
  • Other components like RoPE embeddings handle device management during initialization or first use
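One common fix for the per-forward transfer is to register the layers in an `nn.ModuleList`, so a single `model.to(device)` moves them once with the rest of the module. A hedged sketch under that assumption, not the repo's exact `MLP`:

```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
        super().__init__()
        h = [hidden_dim] * (num_layers - 1)
        # nn.ModuleList registers each layer as a submodule, so model.to(device)
        # moves them all once; no layer.to(x.device) needed inside forward
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
        return x

mlp = MLP(16, 32, 8, num_layers=3)
out = mlp(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 8])
```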
🔗 Analysis chain

Line range hint 1-63: Verify MLP integration with RoPE implementation.

While the MLP implementation is generic, let's verify its usage within the RoPE context to ensure optimal performance.


Let me gather more context about the RoPE implementation and its interaction with MLP.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for MLP usage in relation to RoPE implementation
echo "Searching for MLP usage in transformer/embedding files..."
rg -l "class.*RoPE" | xargs rg "MLP"

echo "Checking for potential performance bottlenecks in model construction..."
ast-grep --pattern 'MLP(
  $$$
)'

Length of output: 2123


Script:

#!/bin/bash
# Check RoPE implementation and its interaction with MLP
echo "Checking RoPE implementation..."
rg -A 10 "class.*RoPE"

echo "\nChecking MLP usage in attention mechanism..."
rg -A 5 "def forward.*RoPE"

echo "\nChecking device handling in RoPE and related components..."
rg "to\(.*device\)" dreem/models/

Length of output: 2335

dreem/training/configs/base.yaml (1)

76-76: Document supported dataset modes

The addition of mode parameter to datasets is a good practice, but it needs documentation.

Please add comments explaining:

  1. Available dataset modes
  2. How each mode affects dataset behavior
  3. Any mode-specific preprocessing or augmentation

Also applies to: 85-85, 94-94

dreem/training/configs/override.yaml (2)

1-142: Fix YAML formatting inconsistencies.

The YAML file has several formatting issues that should be addressed:

  1. Inconsistent indentation (especially in nested structures)
  2. Trailing spaces
  3. Missing newline at end of file
  4. Inconsistent spacing after commas and colons

Run yamllint locally to fix these issues or apply the following changes:

  • Use 2 spaces for indentation consistently
  • Remove trailing spaces
  • Add newline at end of file
  • Ensure consistent spacing after commas and colons
🧰 Tools
🪛 yamllint

[error] 3-3: trailing spaces

(trailing-spaces)


[warning] 4-4: wrong indentation: expected 4 but found 6

(indentation)


[warning] 20-20: wrong indentation: expected 6 but found 8

(indentation)


[warning] 24-24: wrong indentation: expected 6 but found 8

(indentation)


[warning] 29-29: too many spaces after colon

(colons)


[error] 43-43: trailing spaces

(trailing-spaces)


[warning] 46-46: wrong indentation: expected 6 but found 8

(indentation)


[warning] 58-58: too few spaces after comma

(commas)


[warning] 66-66: too few spaces after comma

(commas)


[error] 76-76: trailing spaces

(trailing-spaces)


[warning] 78-78: wrong indentation: expected 4 but found 6

(indentation)


[error] 78-78: trailing spaces

(trailing-spaces)


[error] 99-99: trailing spaces

(trailing-spaces)


[error] 112-112: trailing spaces

(trailing-spaces)


[error] 130-130: trailing spaces

(trailing-spaces)


[warning] 132-132: wrong indentation: expected 2 but found 4

(indentation)


[warning] 133-133: wrong indentation: expected 6 but found 8

(indentation)


[warning] 137-137: wrong indentation: expected 6 but found 8

(indentation)


[error] 142-142: no new line character at the end of file

(new-line-at-end-of-file)


139-142: Remove or document commented configuration.

The view_batch configuration is commented out without explanation. Either:

  1. Remove it if it's not needed
  2. Document why it's kept as reference
🧰 Tools
🪛 yamllint

[error] 142-142: no new line character at the end of file

(new-line-at-end-of-file)

dreem/training/train.py (4)

Line range hint 57-73: Add error handling for dataset initialization.

The code assumes all datasets will be successfully loaded, but according to related changes, get_dataset() may return None. Consider adding validation to handle cases where datasets are unavailable.

     train_dataset = train_cfg.get_dataset(mode="train")
+    if train_dataset is None:
+        raise ValueError("Training dataset could not be loaded")
     train_dataloader = train_cfg.get_dataloader(train_dataset, mode="train")

     val_dataset = train_cfg.get_dataset(mode="val")
+    if val_dataset is None:
+        logger.warning("Validation dataset not available")
     val_dataloader = train_cfg.get_dataloader(val_dataset, mode="val")

     test_dataset = train_cfg.get_dataset(mode="test")
+    if test_dataset is None:
+        logger.warning("Test dataset not available")
     test_dataloader = train_cfg.get_dataloader(test_dataset, mode="test")

Line range hint 74-82: Consider memory optimization for batch viewing.

The batch viewing feature loads a batch into memory even when view_batch.no_train is True. Consider moving the batch loading inside the if block to avoid unnecessary memory usage.

     dataset = TrackingDataset(
         train_dl=train_dataloader, val_dl=val_dataloader, test_dl=test_dataloader
     )

     if cfg.view_batch.enable:
+        if cfg.view_batch.no_train:
+            instances = next(iter(train_dataset))
+            view_training_batch(instances, num_frames=cfg.view_batch.num_frames)
+            return
         instances = next(iter(train_dataset))
         view_training_batch(instances, num_frames=cfg.view_batch.num_frames)
-
-        if cfg.view_batch.no_train:
-            return

Line range hint 83-83: Address TODO comment regarding torch.compile().

The TODO comment suggests exploring torch.compile() usage. This could potentially improve model performance.

Would you like me to help implement torch.compile() support or create an issue to track this enhancement?


RoPE implementation exists but configuration needs attention

The codebase has a complete RoPE implementation in dreem/models/embedding.py with both standard and Fourier variants, but the current configuration in params.yaml uses mode: learned for positional embeddings. The training infrastructure properly handles RoPE through the embedding system, but:

  • The model configuration needs to be updated to use mode: rope for positional embeddings
  • When using RoPE, the embedding aggregation method must be either 'stack' or 'concatenate', not 'average'
  • The Fourier variant requires additional parameters (fourier_n_components) to be configured
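Putting the three constraints above together, a configuration fragment might look like the following. This is a hypothetical sketch: the key names `embedding_agg_method`, `use_fourier`, and `fourier_n_components` are taken from this review's notes, but the exact nesting under `embedding_meta` is not verified against the repository.

```yaml
model:
  embedding_meta:
    pos:
      mode: rope                   # instead of the current `learned`
    embedding_agg_method: stack    # RoPE requires 'stack' or 'concatenate'
    # Fourier variant only:
    # use_fourier: true
    # fourier_n_components: 8
```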
🔗 Analysis chain

Line range hint 56-87: Verify RoPE configuration handling.

Given that this PR implements RoPE and Fourier RoPE, let's verify that the training configuration properly handles the related parameters.


Let me try another verification approach to check for RoPE-related configurations and implementations in the codebase.


Let me check the actual RoPE implementation and configuration details to make a final assessment.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for RoPE-related configuration handling

# Test 1: Search for RoPE configuration in config files
echo "Checking for RoPE configuration..."
rg -l "rope" configs/

# Test 2: Search for RoPE-related model parameters
echo "Checking for RoPE implementation..."
ast-grep --pattern 'class $_ {
  $$$
  rope($$$) {
    $$$
  }
  $$$
}'

Length of output: 304


Script:

#!/bin/bash
# Search for RoPE-related configurations and implementations across the codebase
echo "Searching for RoPE-related files and code..."
rg -l "rope|RoPE" 

# Look for configuration-related files
echo -e "\nSearching for configuration files..."
fd -e yaml -e yml -e json

# Check for model configuration handling
echo -e "\nChecking model configuration handling..."
rg -A 5 "get_gtr_runner|get_model" 

Length of output: 2868


Script:

#!/bin/bash
# Check RoPE implementation in transformer and embedding files
echo "Checking transformer.py for RoPE implementation..."
rg -A 10 "rope|RoPE" dreem/models/transformer.py

echo -e "\nChecking embedding.py for RoPE implementation..."
rg -A 10 "rope|RoPE" dreem/models/embedding.py

# Check configuration parameters
echo -e "\nChecking configuration parameters..."
cat dreem/training/configs/params.yaml

Length of output: 12157

dreem/io/config.py (2)

43-44: Document the implications of disabling structure protection.

Setting struct=False allows for dynamic updates but could lead to unintended modifications. Consider adding a comment explaining why structure protection is disabled and any precautions that should be taken.


202-202: LGTM! Improved error handling for empty datasets.

Good addition of explicit handling for empty datasets. Consider making the warning message more specific about why the dataset might be empty (e.g., no matching files, filtering criteria too strict, etc.).

-            logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+            logger.warn(f"Length of {mode} dataset is {len(dataset)}! This might indicate no matching files or strict filtering criteria. Returning None")

Also applies to: 256-259

🧰 Tools
🪛 Ruff

202-202: Undefined name SleapDataset

(F821)


202-202: Undefined name MicroscopyDataset

(F821)


202-202: Undefined name CellTrackingDataset

(F821)

dreem/io/instance.py (1)

568-572: Enhance documentation for embedding aggregation methods.

The comment briefly mentions different embedding aggregation methods ("average", "stack", "concatenate") but their requirements and behaviors aren't fully documented.

Consider adding a docstring section explaining the supported embedding aggregation methods:

def add_embedding(self, emb_type: str, embedding: torch.Tensor | dict) -> None:
    """Save embedding to instance embedding dictionary.

    Args:
        emb_type: Key/embedding type to be saved to dictionary
        embedding: The actual embedding data, which can be:
            - A torch.Tensor for "average" aggregation method
            - A dict for "stack" or "concatenate" aggregation methods

    Note:
        The embedding format depends on the aggregation method:
        - "average": Expects a tensor that will be expanded to rank 2
        - "stack"/"concatenate": Expects a dict containing the embedding components
    """
🧰 Tools
🪛 Ruff

569-569: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

dreem/inference/eval.py (1)

Line range hint 63-63: Remove unnecessary f prefix in string

The log message does not contain any placeholders, so the f prefix is unnecessary. Replace logger.info(f"Computing the following metrics:") with logger.info("Computing the following metrics:").

Apply this diff to fix the issue:

-logger.info(f"Computing the following metrics:")
+logger.info("Computing the following metrics:")
🧰 Tools
🪛 Ruff

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/models/attention_head.py (1)

22-25: Simplify the assignment of self.embedding_agg_method

You can simplify the code by using the dict.get() method.

Apply this diff to simplify the code:

-        if "embedding_agg_method" in kwargs:
-            self.embedding_agg_method = kwargs["embedding_agg_method"]
-        else:
-            self.embedding_agg_method = None
+        self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
🧰 Tools
🪛 Ruff

22-25: Use self.embedding_agg_method = kwargs.get("embedding_agg_method", None) instead of an if block

Replace with self.embedding_agg_method = kwargs.get("embedding_agg_method", None)

(SIM401)

dreem/inference/track.py (2)

118-118: Verify that the checkpoint path exists before loading

Ensuring the checkpoint file exists can prevent runtime errors and provide clearer feedback if the file is missing.

Consider adding a check for the checkpoint file:

 checkpoint = pred_cfg.cfg.ckpt_path
+if not os.path.exists(checkpoint):
+    raise FileNotFoundError(f"Checkpoint file not found at {checkpoint}")

Line range hint 131-132: Use os.makedirs(outdir, exist_ok=True) efficiently

The check "outdir" in pred_cfg.cfg might not be necessary if pred_cfg.cfg.get('outdir', './results') is used to provide a default value.

Simplify the outdir assignment:

-outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
+outdir = pred_cfg.cfg.get('outdir', './results')
 os.makedirs(outdir, exist_ok=True)
🧰 Tools
🪛 Ruff

100-100: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


128-128: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/inference/tracker.py (1)

209-211: Rename frames_to_track for better clarity.

The variable frames_to_track may not clearly convey its purpose. Consider renaming it to tracking_window or current_window_frames to improve readability and clarity.

Apply this diff to rename the variable:

- frames_to_track = tracked_frames + [frame_to_track]  # better var name?
+ tracking_window = tracked_frames + [frame_to_track]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 11a33c8 and 711028a.

⛔ Files ignored due to path filters (1)
  • dreem/training/configs/test_batch_train.csv is excluded by !**/*.csv
📒 Files selected for processing (23)
  • .gitignore (1 hunks)
  • dreem/datasets/sleap_dataset.py (3 hunks)
  • dreem/inference/eval.py (1 hunks)
  • dreem/inference/post_processing.py (3 hunks)
  • dreem/inference/track.py (2 hunks)
  • dreem/inference/tracker.py (14 hunks)
  • dreem/io/config.py (7 hunks)
  • dreem/io/instance.py (1 hunks)
  • dreem/models/__init__.py (1 hunks)
  • dreem/models/attention_head.py (2 hunks)
  • dreem/models/embedding.py (16 hunks)
  • dreem/models/gtr_runner.py (2 hunks)
  • dreem/models/mlp.py (2 hunks)
  • dreem/models/transformer.py (14 hunks)
  • dreem/training/configs/base.yaml (6 hunks)
  • dreem/training/configs/override.yaml (1 hunks)
  • dreem/training/train.py (2 hunks)
  • scripts/run_eval.py (1 hunks)
  • scripts/run_tracker.py (1 hunks)
  • scripts/run_trainer.py (1 hunks)
  • tests/test_inference.py (2 hunks)
  • tests/test_models.py (12 hunks)
  • tests/test_training.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • .gitignore
  • tests/test_training.py
🧰 Additional context used
🪛 Ruff
dreem/datasets/sleap_dataset.py

163-163: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

dreem/inference/eval.py

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/inference/track.py

100-100: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

dreem/inference/tracker.py

196-196: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


200-200: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

dreem/io/config.py

202-202: Undefined name SleapDataset

(F821)


202-202: Undefined name MicroscopyDataset

(F821)


202-202: Undefined name CellTrackingDataset

(F821)


263-263: Undefined name SleapDataset

(F821)


263-263: Undefined name MicroscopyDataset

(F821)


263-263: Undefined name CellTrackingDataset

(F821)

dreem/io/instance.py

569-569: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

dreem/models/__init__.py

3-3: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


3-3: .embedding.FourierPositionalEmbeddings imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

dreem/models/attention_head.py

22-25: Use self.embedding_agg_method = kwargs.get("embedding_agg_method", None) instead of an if block

Replace with self.embedding_agg_method = kwargs.get("embedding_agg_method", None)

(SIM401)

dreem/models/embedding.py

166-166: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/models/transformer.py

85-87: Use embedding_meta.get("embedding_agg_method", "average") instead of an if block

Replace with embedding_meta.get("embedding_agg_method", "average")

(SIM401)


90-92: Use embedding_meta.get("use_fourier", False) instead of an if block

Replace with embedding_meta.get("use_fourier", False)

(SIM401)

scripts/run_eval.py

1-1: dreem.training.train imported but unused

Remove unused import: dreem.training.train

(F401)

scripts/run_tracker.py

3-3: os imported but unused

Remove unused import: os

(F401)

scripts/run_trainer.py

3-3: os imported but unused

Remove unused import: os

(F401)

tests/test_models.py

18-18: dreem.models.transformer.apply_embeddings imported but unused

Remove unused import: dreem.models.transformer.apply_embeddings

(F401)


168-168: pytest.raises(Exception) should be considered evil

(B017)


493-493: Local variable query_pos_emb is assigned to but never used

Remove assignment to unused variable query_pos_emb

(F841)

🪛 yamllint
dreem/training/configs/base.yaml

[warning] 19-19: wrong indentation: expected 6 but found 8

(indentation)


[warning] 22-22: wrong indentation: expected 6 but found 8

(indentation)

dreem/training/configs/override.yaml

[error] 3-3: trailing spaces

(trailing-spaces)


[warning] 4-4: wrong indentation: expected 4 but found 6

(indentation)


[warning] 20-20: wrong indentation: expected 6 but found 8

(indentation)


[warning] 24-24: wrong indentation: expected 6 but found 8

(indentation)


[warning] 29-29: too many spaces after colon

(colons)


[error] 43-43: trailing spaces

(trailing-spaces)


[warning] 46-46: wrong indentation: expected 6 but found 8

(indentation)


[warning] 58-58: too few spaces after comma

(commas)


[warning] 66-66: too few spaces after comma

(commas)


[error] 76-76: trailing spaces

(trailing-spaces)


[warning] 78-78: wrong indentation: expected 4 but found 6

(indentation)


[error] 78-78: trailing spaces

(trailing-spaces)


[error] 99-99: trailing spaces

(trailing-spaces)


[error] 112-112: trailing spaces

(trailing-spaces)


[error] 130-130: trailing spaces

(trailing-spaces)


[warning] 132-132: wrong indentation: expected 2 but found 4

(indentation)


[warning] 133-133: wrong indentation: expected 6 but found 8

(indentation)


[warning] 137-137: wrong indentation: expected 6 but found 8

(indentation)


[error] 142-142: no new line character at the end of file

(new-line-at-end-of-file)

🔇 Additional comments (26)
scripts/run_eval.py (1)

1-12: Verify the configuration file structure.

Let's verify the structure of the inference configuration file to ensure it matches the expected format.

✅ Verification successful

Configuration structure matches the expected format

The inference configuration file in tests/configs/inference.yaml follows the correct structure as documented in dreem/inference/configs/README.md. It contains all the required sections:

  • Valid tracker configuration with correct parameters
  • Properly structured dataset.test_dataset with required fields
  • Correct dataloader.test_dataloader settings

The test configuration appears to be a simplified version of the reference config, suitable for testing purposes.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check the structure of the inference configuration file

# Test 1: Verify the config file exists
if [ -f "tests/configs/inference.yaml" ]; then
    echo "Found inference.yaml"
    cat "tests/configs/inference.yaml"
else
    echo "Warning: inference.yaml not found at expected location"
fi

# Test 2: Search for other inference config examples in the codebase
echo "Searching for other inference config examples:"
rg -g "*.yaml" -g "*.yml" -l "inference|eval" "tests/" "configs/"

Length of output: 804


Script:

#!/bin/bash
# Let's verify if the configuration structure matches the expected format by checking:
# 1. Usage of this config in the codebase
# 2. Other similar config files that might serve as reference

# Check how the config is used in the codebase
ast-grep --pattern 'eval.run($_)'

# Look for similar config files in the entire repository
fd -e yaml -e yml

# Check for any documentation about config structure
fd README -t f

Length of output: 459


Script:

#!/bin/bash
# Let's check the reference config files and documentation

# Check the reference inference config
cat dreem/inference/configs/inference.yaml

# Check the config documentation
cat dreem/inference/configs/README.md

# Check the implementation of eval.run to understand expected config structure
ast-grep --pattern 'def run($$$)'

Length of output: 7574

🧰 Tools
🪛 Ruff

1-1: dreem.training.train imported but unused

Remove unused import: dreem.training.train

(F401)

dreem/models/__init__.py (1)

3-3: LGTM! Import aligns with new Fourier positional embeddings feature.

The addition of FourierPositionalEmbeddings import is consistent with the PR's objective of implementing RoPE and Fourier RoPE functionality.

🧰 Tools
🪛 Ruff

3-3: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


3-3: .embedding.FourierPositionalEmbeddings imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)
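For reviewers unfamiliar with the new class, a Fourier positional embedding maps a scalar position to sine/cosine features at several frequencies. A minimal dependency-free sketch — the geometric frequency schedule and feature layout here are illustrative assumptions, not necessarily what `FourierPositionalEmbeddings` implements:

```python
import math

def fourier_embedding(times, n_components):
    """Map each scalar position to 2 * n_components sin/cos features.

    Frequencies double per component (an assumed schedule, for illustration only).
    """
    embeddings = []
    for t in times:
        feats = []
        for k in range(n_components):
            freq = 2.0 ** k  # hypothetical geometric frequency schedule
            feats.append(math.sin(freq * t))
            feats.append(math.cos(freq * t))
        embeddings.append(feats)
    return embeddings
```

Note that with `n_components = d_model // 2`, the output width is `2 * n_components = d_model`, so such features can be combined with model activations without a projection.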

scripts/run_tracker.py (1)

10-12: ⚠️ Potential issue

Add error handling for configuration loading and execution

The script lacks error handling for common issues like missing config files or invalid configurations.

Consider adding proper error handling:

-cfg = OmegaConf.load(config)
-track.run(cfg)
+try:
+    if not os.path.exists(config):
+        raise FileNotFoundError(f"Config file not found: {config}")
+    
+    cfg = OmegaConf.load(config)
+    track.run(cfg)
+except FileNotFoundError as e:
+    print(f"Error: {e}")
+    sys.exit(1)
+except Exception as e:
+    print(f"Error during tracking: {e}")
+    sys.exit(1)

Also, consider adding input validation for the config structure before running the tracker.

Let's verify the expected config structure:

✅ Verification successful

Let's examine the actual config files to understand their structure and validate the need for error handling.


Let's check if there's any existing error handling in the track module and how the config is validated.


Let's check the config validation implementation and test files to see if there's existing error handling we should be aware of.


Based on the code analysis, I can now provide the final response:

Error handling is already implemented in the Config class, but script-level handling is still needed

The Config class has built-in error handling for:

  • Config merging and validation through OmegaConf
  • Parameter updates with try-except blocks
  • File path resolution and dataset validation
  • Proper logging of errors

However, the script should still handle file existence and basic config loading errors:

try:
    if not os.path.exists(config):
        raise FileNotFoundError(f"Config file not found: {config}")
    
    cfg = OmegaConf.load(config)
    track.run(cfg)
except FileNotFoundError as e:
    print(f"Error: {e}")
    sys.exit(1)
except Exception as e:
    print(f"Error during tracking: {e}")
    sys.exit(1)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check the structure of existing inference YAML configs
# to understand the expected format

# Search for YAML files that might contain inference configurations
fd -e yaml -e yml | xargs rg -l 'inference|track' 

Length of output: 201


Script:

#!/bin/bash
# Check contents of inference config files to understand structure
echo "=== dreem/inference/configs/inference.yaml ==="
cat dreem/inference/configs/inference.yaml

echo -e "\n=== tests/configs/inference.yaml ==="
cat tests/configs/inference.yaml

# Check if track.run has any built-in error handling
ast-grep --pattern 'def run($$$)'

Length of output: 1408


Script:

#!/bin/bash
# Look for track.run implementation and any validation
rg -A 5 "def run" 

# Check for any config validation or error handling
rg -l "OmegaConf|ConfigurationError|ValidationError"

Length of output: 1279


Script:

#!/bin/bash
# Check config validation implementation
echo "=== dreem/io/config.py ==="
cat dreem/io/config.py

# Check test files for error handling examples
echo -e "\n=== tests/test_config.py ==="
cat tests/test_config.py

Length of output: 19202

dreem/training/configs/base.yaml (2)

137-148: Verify single-epoch training configuration

The training configuration has been set to run for exactly one epoch, which seems unusually short for training deep learning models, especially with RoPE embeddings.

Is this intentional or is it meant for development/testing purposes only? Consider:

  1. Adding a comment explaining why single-epoch training is used
  2. Creating separate configurations for development and production training
  3. Documenting recommended epoch settings for production use

19-23: Verify RoPE implementation and enhance documentation

The configuration now supports RoPE for both positional and temporal embeddings, which aligns with the PR objectives. However, we should ensure proper implementation and documentation.

Consider adding docstrings or comments explaining:

  1. The differences between "fixed", "learned", and "rope" modes
  2. The impact of normalization on RoPE embeddings
  3. The behavior of each embedding_agg_method ("stack", "average", "concatenate")
✅ Verification successful

Let me gather more specific information about the RoPE implementation.


RoPE is properly implemented with correct constraints

The codebase shows a complete and well-implemented RoPE (Rotary Positional Encoding) system:

  1. Proper implementation of both standard RoPE (RotaryPositionalEmbeddings) and Fourier RoPE (FourierRotaryPositionalEmbeddings) with correct mathematical operations
  2. Appropriate handling of embedding modes with validation:
    • RoPE mode correctly restricts aggregation methods to 'stack' or 'concatenate' only
    • Proper application of RoPE to queries and keys but not values in transformer layers
  3. Documentation exists explaining the implementation details and references to original papers

The configuration options in base.yaml accurately reflect the implemented functionality, and the constraints on embedding aggregation methods are properly enforced in the code.
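To make the "queries and keys but not values" point concrete: the core RoPE operation rotates consecutive feature pairs by position-dependent angles, so a query–key dot product depends only on the relative offset between positions. A dependency-free sketch, using the angle schedule from the original RoPE paper (base**(-i/d)) rather than anything taken from this codebase:

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each consecutive (even, odd) feature pair of `vec` by a
    position-dependent angle, as in rotary positional embeddings."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)  # per-pair rotation frequency
        x, y = vec[i], vec[i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))
```

The rotation preserves vector norms, and `dot(rope(q, m), rope(k, n))` is invariant to shifting both positions by the same offset — which is why RoPE is applied to queries and keys (encoding relative position in attention scores) but not to values.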

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Verify RoPE implementation in the codebase
echo "Checking for RoPE implementation in embedding modules..."
rg -l "rope" --type py

# Look for embedding mode handling
ast-grep --pattern 'mode = $_' 

Length of output: 389


Script:

#!/bin/bash
# Check RoPE implementation in embedding.py
echo "=== Checking embedding.py for RoPE implementation ==="
rg -A 10 "class.*Embedding" dreem/models/embedding.py

# Check for mode handling in embedding.py
echo "=== Checking mode handling in embedding.py ==="
rg -B 2 -A 5 "mode.*=.*" dreem/models/embedding.py

# Check transformer.py for RoPE usage
echo "=== Checking transformer.py for RoPE implementation ==="
rg -B 2 -A 5 "rope" dreem/models/transformer.py

# Check for embedding aggregation methods
echo "=== Checking embedding aggregation methods ==="
rg -B 2 -A 5 "stack|average|concatenate" dreem/models/embedding.py

Length of output: 10430

🧰 Tools
🪛 yamllint

[warning] 19-19: wrong indentation: expected 6 but found 8

(indentation)


[warning] 22-22: wrong indentation: expected 6 but found 8

(indentation)

dreem/training/configs/override.yaml (1)

18-25: ⚠️ Potential issue

Missing RoPE configuration in embedding settings.

The PR objectives indicate implementation of RoPE (Rotary Positional Encoding) and Fourier RoPE, but the embedding configuration only shows fixed positional encoding. Consider adding RoPE-specific configuration parameters here.

Example configuration structure for RoPE:

embedding_meta:
  pos:
    mode: "rope"  # instead of "fixed"
    dim: 128      # rotary embedding dimension
    base: 10000   # base for frequency computation
    normalize: true
🧰 Tools
🪛 yamllint

[warning] 20-20: wrong indentation: expected 6 but found 8

(indentation)


[warning] 24-24: wrong indentation: expected 6 but found 8

(indentation)

dreem/training/train.py (2)

56-56: LGTM! Good code organization.

The added blank line improves readability by separating model initialization from dataset setup.


87-87: LGTM! Good code organization.

The added blank line improves readability by separating accelerator setup from device count determination.

dreem/inference/post_processing.py (2)

161-162: Review distance normalization approach.

The normalization method has been changed from using box sizes to using mean across dimensions. This change, combined with the commented-out code and TODO about pixel scaling, suggests uncertainty about the best approach.

Questions to consider:

  1. Why was the box size normalization removed?
  2. How does the mean normalization affect the distance threshold max_center_dist?
  3. Should we maintain both normalization methods as configurable options?

Let's analyze the impact of this change:

#!/bin/bash
# Search for tests related to max_center_dist to understand expected behavior
rg -A 10 "max_center_dist.*test" --type py

# Look for configuration files that might specify distance thresholds
fd -e yaml -e json | xargs rg "max_center_dist"

154-154: 🛠️ Refactor suggestion

Verify tracking behavior with window-based nonk_boxes.

The TODO comment indicates that nonk_boxes should only include boxes from the previous frame, but currently uses boxes from the entire window. This could affect tracking accuracy and performance.

Let's verify how nonk_boxes is being used across the codebase:

Consider implementing a frame selection mechanism:

+        # Extract boxes only from the previous frame
+        prev_frame_mask = torch.zeros_like(nonk_boxes[:, 0, 0], dtype=torch.bool)
+        prev_frame_mask[-1] = True  # Select only the last frame
+        nonk_boxes = nonk_boxes[prev_frame_mask]
         dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
             dim=-1
         )  # n_k x Np

Also applies to: 158-160

tests/test_inference.py (1)

218-219: LGTM! Test updates match the new function signature.

The addition of h and w parameters to filter_max_center_dist calls correctly reflects the updated function signature. The test maintains its original validation logic while accommodating the new dimensional parameters.

Let's verify the function signature change:

Also applies to: 231-232

dreem/models/gtr_runner.py (2)

304-304: LGTM: More robust video name extraction.

The change to use path splitting instead of extension splitting is more robust as it handles paths with multiple dots correctly.


313-325: Verify error handling for gt_track_id access.

The code assumes that every instance has a valid gt_track_id tensor. Consider adding error handling to gracefully handle cases where gt_track_id might be None or not a tensor.

✅ Verification successful

Based on the shell script results, I can now generate the final response as there is sufficient evidence from the codebase.

The gt_track_id access is properly handled in the codebase

The Instance class in dreem/io/instance.py has proper handling for gt_track_id:

  • It has a default value of -1 when not set
  • The setter method handles None case by setting an empty tensor
  • The has_gt_track_id() method checks for empty tensor before access
  • The getter returns a tensor which is safe to call .item() on

The code in question is safe because:

  1. All instances are created with valid gt_track_id tensors through proper initialization
  2. The frame.instances list only contains valid Instance objects
  3. The .item() access is protected by the Instance class implementation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any error handling around gt_track_id access in the codebase

# Test 1: Look for any None checks or error handling around gt_track_id
rg -A 3 "gt_track_id" 

# Test 2: Look for any instance creation or modification of gt_track_id
ast-grep --pattern 'instance.$$$gt_track_id$$$'

Length of output: 21974

dreem/datasets/sleap_dataset.py (2)

109-109: LGTM: Improved video reader management

The change from a list to a dictionary for storing video readers is a good improvement. It provides efficient lookup by video name and prevents duplicate loading of the same video.


Line range hint 1-370: Verify relationship with RoPE implementation

While the changes improve video handling, it's unclear how this file relates to the PR's objective of implementing RoPE (Rotary Positional Encoding) and Fourier RoPE.

Let's check for RoPE-related code in the codebase:

🧰 Tools
🪛 Ruff

163-163: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

dreem/io/config.py (2)

179-181: LGTM! Sorting ensures deterministic file ordering.

The addition of sorting to glob results is a good practice that ensures consistent and reproducible behavior across different systems.
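As a small illustration of why this matters: `glob` returns entries in arbitrary, filesystem-dependent order, so wrapping it in `sorted()` is what makes dataset file ordering reproducible. The helper name and `.slp` pattern below are hypothetical:

```python
import glob
import os
import tempfile

def list_label_files(dir_path, pattern="*.slp"):
    """Return matching files in deterministic lexicographic order."""
    return sorted(glob.glob(os.path.join(dir_path, pattern)))

# demo: creation order differs from the sorted result
with tempfile.TemporaryDirectory() as d:
    for name in ("b.slp", "a.slp", "c.slp"):
        open(os.path.join(d, name), "w").close()
    files = [os.path.basename(p) for p in list_label_files(d)]
```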


420-420: Document and verify checkpoint naming change.

The change from 'best' to 'final' in checkpoint naming could impact existing scripts or workflows. Please:

  1. Document this change in the code comments
  2. Verify that dependent code is updated accordingly
✅ Verification successful

Let me try a different search strategy to verify the checkpoint naming impact.


Based on the search results, I can now generate the final response:

Checkpoint naming change appears safe to proceed

The change from 'best' to 'final' in the checkpoint naming pattern appears safe as:

  • No hardcoded references to 'best' checkpoint names were found in the codebase
  • All checkpoint loading operations use path variables rather than hardcoded patterns
  • The checkpoint loading mechanism (load_from_checkpoint) is consistently used across files

However, consider adding a comment above the line to document this naming convention:

# Format: {epoch}-final-{metric} (e.g., "10-final-accuracy")
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for references to checkpoint filenames containing 'best'
echo "Searching for potential impacts of checkpoint naming change..."
rg -l 'best.*checkpoint' --type py

# Search for checkpoint loading code
echo "Searching for checkpoint loading patterns..."
ast-grep --pattern 'load_from_checkpoint($_)'

Length of output: 318


Script:

#!/bin/bash
# Search for any references to checkpoint names or patterns
echo "Searching for checkpoint name references..."
rg -A 2 'CHECKPOINT_NAME' --type py

# Search for checkpoint loading or saving patterns
echo "Searching for checkpoint operations..."
rg -A 2 '(save|load).*checkpoint' --type py

# Search for documentation or comments about checkpoints
echo "Searching for checkpoint documentation..."
rg -B 1 'checkpoint' -g '*.md' -g '*.rst' -g '*.txt'

Length of output: 2353

tests/test_models.py (2)

38-40: LGTM: Correct addition of embedding_agg_method parameter

The ATTWeightHead is correctly instantiated with the new embedding_agg_method parameter set to "average".


210-250: LGTM: Proper implementation of test_rope_embedding

The test_rope_embedding function appropriately tests the RoPE embedding functionality, ensuring correct output dimensions and validating that the embeddings are distinct.

dreem/inference/tracker.py (4)

354-355: Include height and width when scaling boxes.

Ensure that h and w correctly represent the image height and width for scaling bounding boxes. This is important for accurate post-processing in filtering based on center distance.


Line range hint 479-489: Clarify threshold logic in track assignment.

The condition in the if-statement determines whether to assign an instance to an existing track or create a new one. Ensure that the logic correctly reflects the intended behavior, especially regarding the threshold comparison.

Double-check the threshold comparison to confirm that it aligns with the intended logic. If the intention is to create a new track when the association score is lower than the threshold, the condition may need to be adjusted.


288-291: ⚠️ Potential issue

Handle potential empty tensors in concatenation.

When concatenating features in reid_features, ensure that all tensors are non-empty to avoid runtime errors. If any frame does not have instances, frame.get_features() may return an empty tensor.

Consider adding a check before concatenation:

  reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[
      None
  ]  # (1, total_instances, D=512)
+ if reid_features.size(1) == 0:
+     logger.warning("No instance features available for tracking.")
+     return query_frame

Likely invalid or redundant comment.


Line range hint 261-307: Ensure proper handling of empty frame instances.

When concatenating features for frames, ensure that frames without instances are handled correctly to prevent runtime errors due to empty tensors.

Run the following script to check for frames without instances:

dreem/models/embedding.py (3)

131-140: Addition of RoPE mode in embeddings looks good

The implementation of the 'rope' mode embeddings and the instantiation of the appropriate positional embeddings classes enhance the flexibility of the Embedding class. The conditional logic correctly handles the creation of rope_instance based on the use_fourier flag.


174-227: New embedding transformation methods are well-implemented

The methods _transform, _apply_rope, and _apply_additive_embeddings effectively handle the embedding transformations based on the selected aggregation method and embedding mode. The code is clear and logically structured.


228-248: Ensure all calls to Embedding.forward are updated

The forward method signature has been modified to include an additional parameter x and now returns both xout and emb. Please verify that all calls to Embedding.forward in the codebase are updated to match the new signature and handle the returned tuple appropriately.

Run the following script to identify usages of Embedding.forward that may need updating:

Comment on lines 1 to 2
from dreem.training import train
from omegaconf import OmegaConf

⚠️ Potential issue

Fix import statements and remove unused imports.

The script has import-related issues:

  1. The train import is unused and should be removed
  2. The eval module is used but not imported

Apply this diff to fix the imports:

-from dreem.training import train
+from dreem.inference import eval
from omegaconf import OmegaConf
🧰 Tools
🪛 Ruff

1-1: dreem.training.train imported but unused

Remove unused import: dreem.training.train

(F401)

Comment on lines 8 to 12
inference_config = "tests/configs/inference.yaml"

cfg = OmegaConf.load(inference_config)

eval.run(cfg)

🛠️ Refactor suggestion

Improve configuration handling and add error handling.

The script needs improvements in configuration handling:

  1. Hardcoded configuration path should be configurable
  2. Missing error handling for configuration loading
  3. No validation of loaded configuration

Consider refactoring like this:

-inference_config = "tests/configs/inference.yaml"
-
-cfg = OmegaConf.load(inference_config)
-
-eval.run(cfg)
+import argparse
+import sys
+from pathlib import Path
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Run model evaluation')
+    parser.add_argument('--config', type=str, required=True,
+                       help='Path to inference configuration file')
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+    config_path = Path(args.config)
+    
+    if not config_path.exists():
+        print(f"Error: Configuration file not found: {config_path}")
+        sys.exit(1)
+        
+    try:
+        cfg = OmegaConf.load(config_path)
+        eval.run(cfg)
+    except Exception as e:
+        print(f"Error during evaluation: {str(e)}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()

Comment on lines 1 to 3
from dreem.inference import track
from omegaconf import OmegaConf
import os

⚠️ Potential issue

Clean up imports

The os module is imported but not used in the code.

Apply this diff to remove the unused import:

from dreem.inference import track
from omegaconf import OmegaConf
-import os
🧰 Tools
🪛 Ruff

3-3: os imported but unused

Remove unused import: os

(F401)

Comment on lines 5 to 8
# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"

🛠️ Refactor suggestion

⚠️ Potential issue

Remove development paths and use configurable config path

The code contains hardcoded development paths in comments and a hardcoded absolute path to the config file. This makes the script less portable and harder to use in different environments.

Consider:

  1. Remove the commented-out development paths
  2. Make the config path configurable via command-line arguments using argparse:
from dreem.inference import track
from omegaconf import OmegaConf
+import argparse

-# /Users/mustafashaikh/dreem/dreem/training
-# /Users/main/Documents/GitHub/dreem/dreem/training
-# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
-config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"

+def parse_args():
+    parser = argparse.ArgumentParser(description='Run tracker with YAML config')
+    parser.add_argument('--config', required=True, help='Path to inference YAML config file')
+    return parser.parse_args()
+
+args = parse_args()
+config = args.config

Comment on lines 5 to 7
# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")

⚠️ Potential issue

Remove personal directory paths from comments.

These comments contain personal directory paths that should not be in the codebase.

Remove these commented paths as they:

  1. Expose personal information
  2. Don't provide value to other developers
  3. May pose security risks

def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
if mode.lower() == "rope" and self.embedding_agg_method == "average":
raise ValueError(
f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"

⚠️ Potential issue

Remove unnecessary f prefix in the error message

Line 166 contains an f-string without any placeholders. The f prefix is unnecessary and can be removed to clean up the code.

Apply this diff to fix the issue:

-f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
+"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
🧰 Tools
🪛 Ruff

166-166: f-string without any placeholders

Remove extraneous f prefix

(F541)

Comment on lines +728 to +733
proj = nn.Linear(queries_cat.shape[-1], queries.shape[-1]).to(queries_cat.device)
norm = nn.LayerNorm(queries.shape[-1]).to(queries_cat.device)

queries = proj(queries_cat)
queries = norm(queries)


⚠️ Potential issue

Avoid instantiating nn modules inside functions to ensure parameters are registered

Defining nn.Linear and nn.LayerNorm inside the apply_fourier_embeddings function will result in new instances being created every time the function is called. Their parameters won't be registered with the model, so they won't be updated during training, leading to unexpected behavior.

Consider moving the instantiation of proj and norm outside the function, possibly as attributes of the class. Would you like assistance in refactoring this code to ensure proper parameter registration?
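A minimal sketch of the suggested refactor: the projection and norm are created once in `__init__`, so their weights appear in `model.parameters()` and receive optimizer updates. Class and dimension names here are illustrative, not taken from the codebase:

```python
import torch
import torch.nn as nn

class FourierQueryProjection(nn.Module):
    """Owns the Linear/LayerNorm so their parameters are registered once."""

    def __init__(self, cat_dim: int, embed_dim: int):
        super().__init__()
        self.proj = nn.Linear(cat_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, queries_cat: torch.Tensor) -> torch.Tensor:
        return self.norm(self.proj(queries_cat))

module = FourierQueryProjection(cat_dim=768, embed_dim=512)
out = module(torch.randn(2, 10, 768))
```

Because `proj` and `norm` are module attributes, `list(module.parameters())` contains their weights and biases; modules built inside a plain function are instead created fresh and discarded on every call, so they can never be trained.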

Comment on lines +850 to +856
mlp = MLP(
input_dim=queries_t.shape[-1] * 3, # t,x,y
hidden_dim=queries_t.shape[-1] * 6, # not applied when num_layers=1
output_dim=queries_t.shape[-1],
num_layers=1,
dropout=0.0,
)

⚠️ Potential issue

Avoid creating nn modules inside functions to ensure parameters are registered

Instantiating the MLP module inside the collate_queries function means that its parameters won't be registered for optimization, and it will be recreated on every call. This prevents the model from learning the correct weights during training.

Consider defining the MLP instance outside the function, such as in the __init__ method of a class, and passing it as a parameter. Would you like assistance in refactoring this code to properly register the MLP parameters?

encoder_queries, pos_emb=ref_emb
) # (total_instances, batch_size, embed_dim)
# apply fourier embeddings
if "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]:
🛠️ Refactor suggestion

Simplify condition by using self.use_fourier

Since self.use_fourier already holds the boolean value, you can simplify the conditional check by using it directly.

Apply this diff to simplify the condition:

- if "use_fourier" in self.embedding_meta and self.embedding_meta["use_fourier"]:
+ if self.use_fourier:

Committable suggestion was skipped due to low confidence.

Comment on lines +85 to +98
embedding_meta["embedding_agg_method"]
if "embedding_agg_method" in embedding_meta
else "average"
)
self.use_fourier = (
embedding_meta["use_fourier"]
if "use_fourier" in embedding_meta
else False
)
self.fourier_n_components = (
embedding_meta["fourier_n_components"]
if "fourier_n_components" in embedding_meta
else self.d_model // 2
)
🛠️ Refactor suggestion

Simplify dictionary key access using get() method

Instead of using if statements to check for the existence of keys in the embedding_meta dictionary, you can use the get() method to make the code cleaner and more readable.

Apply this diff to simplify the code:

- self.embedding_agg_method = (
-     embedding_meta["embedding_agg_method"]
-     if "embedding_agg_method" in embedding_meta
-     else "average"
- )
+ self.embedding_agg_method = embedding_meta.get("embedding_agg_method", "average")

- self.use_fourier = (
-     embedding_meta["use_fourier"]
-     if "use_fourier" in embedding_meta
-     else False
- )
+ self.use_fourier = embedding_meta.get("use_fourier", False)

- self.fourier_n_components = (
-     embedding_meta["fourier_n_components"]
-     if "fourier_n_components" in embedding_meta
-     else self.d_model // 2
- )
+ self.fourier_n_components = embedding_meta.get("fourier_n_components", self.d_model // 2)
@shaikh58 shaikh58 changed the title Mustafa fourier pos emb rope RoPE, Fourier RoPE, Fourier learnable embeddings Oct 30, 2024
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
dreem/datasets/sleap_dataset.py (2)

127-127: Add input validation in get_indices

The method should validate that the input index is within bounds of the available chunks.

 def get_indices(self, idx: int) -> tuple:
+    if not 0 <= idx < len(self.label_idx):
+        raise IndexError(f"Index {idx} is out of bounds for dataset with {len(self.label_idx)} chunks")
     return self.label_idx[idx], self.chunked_frame_idx[idx]

Line range hint 150-170: Fix unused variables

There are two unused variables that should be addressed:

  1. Rename the unused loop variable i to _
  2. Remove the unused exception variable e
-    for i, frame_ind in enumerate(frame_idx):
+    for _, frame_ind in enumerate(frame_idx):
         (
             instances,
             gt_track_ids,
             poses,
             shown_poses,
             point_scores,
             instance_score,
         ) = ([], [], [], [], [], [])

         frame_ind = int(frame_ind)
         lf = video[frame_ind]

         try:
             img = lf.image
-        except FileNotFoundError as e:
+        except FileNotFoundError:
             if video_name not in self.vid_readers:
                 self.vid_readers[video_name] = sio.load_video(video_name)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 711028a and cf9453e.

📒 Files selected for processing (2)
  • dreem/datasets/base_dataset.py (1 hunks)
  • dreem/datasets/sleap_dataset.py (5 hunks)
✅ Files skipped from review due to trivial changes (1)
  • dreem/datasets/base_dataset.py
🧰 Additional context used
🪛 Ruff
dreem/datasets/sleap_dataset.py

150-150: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


166-166: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

🔇 Additional comments (2)
dreem/datasets/sleap_dataset.py (2)

109-109: LGTM! Good improvement in video reader management

Converting vid_readers from a list to a dictionary improves efficiency by providing O(1) lookup and preventing duplicate video reader instances.


373-373: Previous review comment about video reader cleanup is still valid

Please refer to the previous review comment about improving the video reader cleanup implementation.

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (7)
dreem/inference/eval.py (1)

63-64: Remove unnecessary f-strings and improve logging format.

The f-string prefix is unnecessary when there are no placeholders. Also, consider using a more structured format for logging metrics.

-logger.info(f"Computing the following metrics:")
-logger.info(model.metrics['test'])
+logger.info("Computing the following metrics:")
+logger.info("Test metrics: %s", model.metrics['test'])

dreem/models/attention_head.py (3)

22-25: Simplify kwargs handling using dict.get()

Replace the if-else block with a single line using dict.get() for cleaner code:

-        if "embedding_agg_method" in kwargs:
-            self.embedding_agg_method = kwargs["embedding_agg_method"]
-        else:
-            self.embedding_agg_method = None
+        self.embedding_agg_method = kwargs.get("embedding_agg_method", None)


19-19: Improve parameter documentation

The embedding_agg_method parameter documentation should specify the valid options ("stack" or None) and their implications.

-            embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate
+            embedding_agg_method: Method for aggregating embeddings. When set to "stack",
+                uses separate attention heads for x, y, t dimensions with 1D convolution.
+                When None, uses standard MLP projections.

82-84: Consider parallelizing attention computations

The attention computations for t, x, y dimensions are performed sequentially. Consider using torch.nn.ModuleList and parallel computation for better performance.

+        # Store attention heads in ModuleList for parallel computation
+        self.attention_heads = torch.nn.ModuleList([
+            self.attn_t, self.attn_x, self.attn_y
+        ])
+        # Compute attention in parallel
+        attention_outputs = [
+            head(query=query[:, i, :], key=key_orig, value=key_orig)[0]
+            for i, head in enumerate(self.attention_heads)
+        ]
+        collated = torch.stack(attention_outputs, dim=0).permute(1, 0, 2)
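As a runnable sketch of the ModuleList idea (shapes and sizes below are illustrative, not from the PR): keeping the three per-coordinate heads in an `nn.ModuleList` registers their parameters, and the outputs can be stacked back into a single tensor. Note the comprehension still evaluates the heads one after another; true concurrency would need additional batching or CUDA streams.

```python
import torch
import torch.nn as nn

embed_dim, n_heads = 16, 4
# One attention head per coordinate (t, x, y), registered via ModuleList.
heads = nn.ModuleList(
    nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
    for _ in range(3)
)

batch, n_inst = 2, 5
# query carries one slice per coordinate: (batch, 3, n_inst, embed_dim)
query = torch.randn(batch, 3, n_inst, embed_dim)
key = torch.randn(batch, n_inst, embed_dim)

# head(...) returns (output, attn_weights); keep only the output.
outputs = [head(query[:, i], key, key)[0] for i, head in enumerate(heads)]
collated = torch.stack(outputs, dim=1)  # (batch, 3, n_inst, embed_dim)
print(collated.shape)
```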
dreem/training/losses.py (3)

84-84: Enhance shape documentation clarity.

While the shape documentation is helpful, the comment for asso_pred could be more precise. Consider clarifying what "total_instances across all frames" means, perhaps by using mathematical notation or providing an example.

-  # asso_pred is shape (total_instances across all frames, total_instances across all frames)
+  # asso_pred is shape (M, M) where M = sum(n_t), and n_t is the number of instances in each frame

Also applies to: 86-86


222-223: Improve shape documentation precision.

The shape documentation is helpful but could be more precise:

  1. Use mathematical notation for clarity
  2. Explain why +1 is needed for background class
  3. Document how track IDs map to class labels
-            # asso_preds_with_bg is shape (total_instances across all frames, n_t + 1) where +1 is for background class
-            # asso_gt_t is shape (total_instances across all frames); each entry is the track id of the gt instance i.e. the class label
+            # asso_preds_with_bg: shape (M, n_t + 1) where:
+            #   - M = sum(n_t) is total instances across all frames
+            #   - n_t is instances in current frame
+            #   - +1 dimension handles background class for unmatched predictions
+            # asso_gt_t: shape (M,) containing track IDs as class labels
+            #   - track ID ∈ [0, n_t-1] for matched instances
+            #   - track ID = n_t for background class
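A tiny illustration of the background-class convention described above (the numbers here are made up): the extra logit column lets cross-entropy assign unmatched instances to class `n_t`.

```python
import torch
import torch.nn.functional as F

n_t, M = 3, 4  # tracks in the current frame, total instances across frames
asso_preds_with_bg = torch.randn(M, n_t + 1)   # (M, n_t + 1); last column = background
asso_gt_t = torch.tensor([0, 2, 1, n_t])       # last instance is unmatched -> background
loss = F.cross_entropy(asso_preds_with_bg, asso_gt_t)
print(loss.shape)  # scalar: torch.Size([])
```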

Line range hint 8-9: Consider implementing TODO suggestions for better code maintainability.

The TODOs suggest using named tensors and cleaning up names/comments. Named tensors could help prevent shape-related bugs and improve code clarity, especially given the complex tensor operations in this file.

Would you like help implementing named tensors or suggestions for name improvements?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between cf9453e and 0b70f6e.

📒 Files selected for processing (3)
  • dreem/inference/eval.py (1 hunks)
  • dreem/models/attention_head.py (2 hunks)
  • dreem/training/losses.py (4 hunks)
🧰 Additional context used
🪛 Ruff
dreem/inference/eval.py

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)


63-63: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/models/attention_head.py

22-25: Use self.embedding_agg_method = kwargs.get("embedding_agg_method", None) instead of an if block

Replace with self.embedding_agg_method = kwargs.get("embedding_agg_method", None)

(SIM401)

🔇 Additional comments (9)
dreem/inference/eval.py (5)

30-30: Simplify dictionary membership test.

The existing review comment about using if "batch_config" in cfg: instead of if "batch_config" in cfg.keys(): is still valid.



34-36: Add input validation for task index.

The existing review comment about adding error handling for non-integer user input is still valid.


56-58: Improve logging consistency.

The existing review comments are still valid:

  1. Remove the unnecessary f prefix in logger.info(f"Using the following tracker:").
  2. Replace print(model.tracker) with logger.info(model.tracker).


Line range hint 22-22: Consider maintaining specific type hints.

The return type annotation has been changed from dict[int, sio.Labels] to dict. While this makes the function more flexible, it reduces type safety. If possible, please specify the expected key and value types to help with static type checking and code clarity.



42-44: Verify RoPE implementation integration.

The PR objectives mention implementing RoPE and Fourier RoPE, but this file only shows batch configuration changes. Please verify that the hyperparameters being set here properly configure the RoPE implementation.

✅ Verification successful

Let me try a different search strategy to verify the RoPE implementation.

This broader search should help us:

  1. Find any RoPE-related code even if not in class definitions
  2. Locate positional encoding implementations that might be related
  3. Understand what hyperparameters are being configured in the codebase

Based on the search results, I can now provide a final response as there's clear evidence about the RoPE implementation and its integration with hyperparameters.

RoPE implementation is properly integrated with hyperparameter configuration

The codebase shows a complete RoPE implementation with proper hyperparameter integration:

  • RoPE is implemented in dreem/models/embedding.py with both standard and Fourier variants
  • The embedding configuration is properly handled through hyperparameters:
    • Mode can be set to "rope" in embedding configuration
    • Supports both standard RoPE and Fourier RoPE variants
    • Validates embedding aggregation methods ("stack" or "concatenate" required for RoPE)
  • The transformer model in dreem/models/transformer.py correctly applies RoPE:
    • Only applies to query and key vectors, not values
    • Properly integrates with both positional and temporal embeddings

The hyperparameter setting in eval.py will correctly configure RoPE when specified in the config.
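For readers unfamiliar with RoPE, here is a minimal, generic sketch of the idea (not the PR's exact implementation): pairs of channels are rotated by position-dependent angles, and the rotation is applied to queries and keys only, never to values.

```python
import torch

def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of x by position-dependent angles.

    x: (seq, d) with d even; positions: (seq,).
    """
    d = x.shape[-1]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    angles = positions[:, None] * inv_freq[None, :]   # (seq, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(6, 8)
positions = torch.arange(6).float()
q_rot = rope(q, positions)
# Rotation preserves norms, and position 0 is left unchanged.
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5))  # True
print(torch.allclose(q_rot[0], q[0], atol=1e-6))                      # True
```

Because the same rotation is applied to q and k, the dot product between them depends only on the relative offset between positions, which is what makes RoPE attractive for tracking across frames.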

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for RoPE-related configurations and implementations
echo "Searching for RoPE configurations..."
rg -i "rope" --type yaml
echo "Searching for RoPE implementations..."
ast-grep --pattern 'class $_(RoPE|RotaryPositionalEncoding) { $$$ }'

Length of output: 301


Script:

#!/bin/bash
# Search for RoPE-related code and configurations more broadly
echo "Searching for RoPE in Python files..."
rg -i "rope" --type py -A 5

echo -e "\nSearching for positional encoding implementations..."
rg -i "position.*encod" --type py -A 5

echo -e "\nSearching for hyperparameter configurations..."
rg "set_hparams" --type py -A 5

Length of output: 31626

dreem/models/attention_head.py (2)

90-94: ⚠️ Potential issue

Inconsistent use of projections

The previous review comment about inconsistent use of projections is still valid. The code uses k = self.k_proj(key_stacked) but leaves q = self.q_proj(query) commented out, which could lead to inconsistencies in attention weight computation.

Either:

  1. Uncomment and use both projections:
 k = self.k_proj(key_stacked).transpose(1, 0)
-# q = self.q_proj(query).transpose(1, 0)
+q = self.q_proj(query).transpose(1, 0)
  2. Or document why only key projection is needed:
-# TODO: uncomment this if not using modified attention heads for t,x,y
+# Note: Only key projection is used because query projection is handled by
+# separate attention heads (attn_x, attn_y, attn_t)

65-72: ⚠️ Potential issue

Verify tensor reshaping logic

The reshaping operations assume the input tensors have specific dimensions divisible by 3. This could cause runtime errors if the assumptions don't hold.

Consider adding input validation:

+        if num_window_instances % 3 != 0 or num_query_instances % 3 != 0:
+            raise ValueError("Number of instances must be divisible by 3 when using stack mode")
dreem/training/losses.py (2)

Line range hint 44-63: Review commented-out zero loss conditions.

The commented-out block contains important logic for handling edge cases (empty videos or single-instance videos). Consider:

  1. Why was this code commented out?
  2. Are these edge cases now handled differently?
  3. Should this be restored with improved conditions?
#!/bin/bash
# Search for other zero loss handling in the codebase
rg "return.*zero.*loss" --type py

125-125: Verify fixed detection assumption.

The comment states that pred_box and target_box are the same for fixed detections, resulting in a diagonal IoU matrix. Let's verify this assumption across the codebase.

✅ Verification successful

Based on the search results, I can now generate the final response:

Fixed detection assumption is verified correct

The code in dreem/training/losses.py explicitly shows that target_box is directly assigned the value of pred_box:

pred_box = get_boxes(instances)
pred_box = torch.nanmean(pred_box, axis=1)
target_box, target_time = pred_box, pred_time

This confirms that for fixed detections, the prediction and target boxes are indeed identical, which would result in a diagonal IoU matrix when computing box_iou(pred_box, target_box).
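To see why identical inputs give a diagonal of ones, here is a self-contained pairwise-IoU sketch (a stand-in for the project's box_iou, not its actual code): each box has IoU 1 with itself, while off-diagonal entries measure overlap between distinct boxes.

```python
import torch

def box_iou(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise IoU for boxes in (x1, y1, x2, y2) format; a: (N, 4), b: (M, 4)."""
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = torch.max(a[:, None, :2], b[None, :, :2])   # intersection top-left
    rb = torch.min(a[:, None, 2:], b[None, :, 2:])   # intersection bottom-right
    wh = (rb - lt).clamp(min=0)                      # zero if no overlap
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area_a[:, None] + area_b[None, :] - inter)

boxes = torch.tensor([[0., 0., 2., 2.], [1., 1., 3., 3.], [10., 10., 12., 12.]])
iou = box_iou(boxes, boxes)
print(torch.diagonal(iou))  # tensor([1., 1., 1.])
```

Note the off-diagonal entries need not be zero (here boxes 0 and 1 overlap with IoU 1/7), so "diagonal" refers to the diagonal being all ones rather than the matrix being strictly diagonal.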

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for places where pred_box and target_box are set to understand if they can ever differ
rg -A 5 "pred_box|target_box" --type py

Length of output: 4497

Comment on lines +38 to +40
hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)
⚠️ Potential issue

Add CSV file validation and error handling.

The code assumes the CSV file exists and contains the expected columns. Consider adding validation:

  1. Check if the CSV file exists
  2. Verify required columns are present
  3. Validate the index is within bounds of the DataFrame
+    if not os.path.exists(cfg.batch_config):
+        raise FileNotFoundError(f"Batch config file {cfg.batch_config} not found")
     hparams_df = pd.read_csv(cfg.batch_config)
+    if len(hparams_df) <= index:
+        raise IndexError(f"Task index {index} exceeds number of configurations {len(hparams_df)}")
     hparams = hparams_df.iloc[index].to_dict()
     _ = hparams.pop("Unnamed: 0", None)
