
Allow batch eval/inference flexibility #80

Open · wants to merge 1 commit into base: main
Conversation

@aaprasad aaprasad commented Aug 30, 2024

Here we switch the batch logic to take a CSV of hyperparameters, allowing more flexibility in batch jobs rather than just running different models.
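For illustration, a minimal sketch of how such a batch config CSV could be generated (the column names here are hypothetical placeholders, not the PR's actual schema; each row is one job, selected later by its index):

```python
import pandas as pd

# Hypothetical hyperparameter grid: one row per batch job.
grid = pd.DataFrame({"model.lr": [1e-4, 5e-5], "tracker.window_size": [8, 16]})
grid.to_csv("batch_config.csv")  # each job later picks its row via an index
```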

Summary by CodeRabbit

  • New Features

    • Enhanced support for batch training configurations, improving user interaction and error handling.
    • Added logging for hyperparameter updates, providing better visibility during configuration processes.
    • Implemented normalization of input data in embedding learning, improving robustness.
  • Bug Fixes

    • Improved error handling for empty datasets and dataloaders, preventing issues from propagating.
  • Documentation

    • Updated function signatures to reflect new return types for dataset and dataloader methods.


coderabbitai bot commented Aug 30, 2024

Walkthrough

The changes introduce a new approach to handling batch configurations in the run function across multiple files, replacing previous checkpoint handling. Error handling is improved, especially regarding missing environment variables and empty datasets. Additionally, normalization is added to the embedding process, and adjustments are made to how video names and ground truth track IDs are processed. Overall, the updates enhance configuration management and data handling.

Changes

| Files | Change Summary |
| --- | --- |
| dreem/inference/eval.py, dreem/inference/track.py | Introduced batch configuration handling in the run function, improved error messaging, and logging. |
| dreem/io/config.py | Enhanced dataset and dataloader methods to return None for empty datasets, with added logging. |
| dreem/models/embedding.py | Added normalization for the times tensor in the _learned_temp_embedding method. |
| dreem/models/gtr_runner.py | Modified how video names and ground truth track IDs are derived and processed in the on_test_epoch_end method. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Config
    participant Dataset
    participant Dataloader
    participant Model

    User->>Config: Provide batch_config
    Config->>Dataset: Load dataset
    Dataset-->>Config: Return dataset or None
    Config->>Dataloader: Create dataloader
    Dataloader-->>Config: Return dataloader or None
    Config->>Model: Run training/evaluation
```

🐇 "In the meadow, hops a little bunny,
With changes bright, oh so sunny!
Batch configs now take the lead,
No more empty datasets to heed.
With logs and norms, we dance and play,
Hooray for updates, hip-hip-hooray!" 🐇



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Outside diff range, codebase verification and nitpick comments (1)
dreem/io/config.py (1)

Line range hint 200-257: Approve handling of empty datasets and suggest improvement.

Returning None for an empty dataset and logging a warning is a good practice. However, consider improving the clarity of the warning message to specify which dataset is empty.

The code changes are approved.

Consider improving the warning message for clarity:

```diff
- logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+ logger.warn(f"Length of {mode} dataset is 0! Returning None as the dataset is empty.")
```
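As a usage sketch, a call site could guard on the `None` return like this (hypothetical caller; the `get_dataloader` signature is assumed, not taken from the PR):

```python
dataset = cfg.get_dataset(mode="train")
dataloader = cfg.get_dataloader(dataset, mode="train") if dataset is not None else None
if dataloader is None:
    logger.warning("No usable train data; skipping this stage.")
```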
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR, between commits 2af0dd5 and 9eddead.

Files selected for processing (5)
  • dreem/inference/eval.py (1 hunks)
  • dreem/inference/track.py (2 hunks)
  • dreem/io/config.py (6 hunks)
  • dreem/models/embedding.py (1 hunks)
  • dreem/models/gtr_runner.py (2 hunks)
Additional context used
Ruff

dreem/inference/eval.py

  • 30-30: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
  • 56-56: f-string without any placeholders; remove extraneous `f` prefix (F541)

dreem/inference/track.py

  • 100-100: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)

dreem/io/config.py

  • 200-200: Undefined names `SleapDataset`, `MicroscopyDataset`, `CellTrackingDataset` (F821)
  • 261-261: Undefined names `SleapDataset`, `MicroscopyDataset`, `CellTrackingDataset` (F821)

Additional comments not posted (10)
dreem/inference/eval.py (3)

38-44: Good handling of hyperparameters from CSV.

Reading hyperparameters from a CSV file and removing an unnamed column is a good practice for data cleanup. The logging statements provide visibility into the configuration process, which is beneficial for debugging and verification.

The changes are approved.
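For context, the `Unnamed: 0` column is what pandas produces when a CSV was written with the default `index=True`; a minimal reproduction:

```python
import pandas as pd

df = pd.DataFrame({"model.lr": [1e-4, 5e-5]})
df.to_csv("batch_config.csv")  # default index=True writes the row index as a column
print(pd.read_csv("batch_config.csv").columns)
# Index(['Unnamed: 0', 'model.lr'], dtype='object')
```

Writing the file with `to_csv(..., index=False)` would make the `pop("Unnamed: 0", None)` cleanup unnecessary.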


42-44: Effective use of logging.

The logging statements added provide useful information about the hyperparameters being set, which enhances transparency and aids in debugging.

The changes are approved.


46-47: Proper initialization of default hyperparameters.

Initializing hparams as an empty dictionary when "batch_config" is not present is a good fallback mechanism, ensuring that the function remains flexible and robust.

The changes are approved.

dreem/io/config.py (2)

177-179: Approve sorting of file paths.

Sorting file paths for labels and videos ensures consistency in the order of files processed, which is beneficial for reproducibility.

The code changes are approved.
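A minimal sketch of the pattern (directory and extensions are hypothetical):

```python
from pathlib import Path

label_files = sorted(Path("data").glob("*.slp"))  # deterministic ordering
vid_files = sorted(Path("data").glob("*.mp4"))    # keeps labels and videos aligned
```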


Line range hint 261-313: Approve robust handling of invalid datasets in dataloader.

Handling None or empty datasets by returning None and logging warnings in the get_dataloader method is a good practice to prevent errors during data loading and maintain transparency.

The code changes are approved.

dreem/inference/track.py (4)

116-117: Initialize empty dictionary for hyperparameters.

This initialization is straightforward and correct.

The code changes are approved.


120-120: Review model and tracker configuration updates.

The updates to the model and tracker configurations are crucial for the correct functioning of the inference process.

The code changes are approved.

Also applies to: 122-122, 124-124, 127-127


137-144: Enhance logging for tracking and results saving.

The enhanced logging provides better visibility into the operations being performed, which is beneficial for debugging and monitoring.

The code changes are approved.


118-118: Verify checkpoint path retrieval.

Ensure that the checkpoint path is correctly retrieved and handled, especially in scenarios where it might not be specified.

Verification successful

Checkpoint path retrieval is correctly handled.

The codebase includes checks for None or empty checkpoint paths, ensuring robustness in scenarios where the path might not be specified. The checkpoint path is accessed and used appropriately across different modules, and test cases confirm its correct handling. No issues were found.

Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Verify that the checkpoint path is always available in the configuration.

# Test: Search for the checkpoint path handling in the configuration loading.
rg --type python -A 5 $'ckpt_path'
```

Length of output: 64


Script:

```bash
#!/bin/bash
# Description: Verify that the checkpoint path is always available in the configuration.

# Test: Search for the checkpoint path handling in the configuration loading.
rg --type py -A 5 $'ckpt_path'
```

Length of output: 4118

dreem/models/gtr_runner.py (1)

304-304: Ensure consistent video name handling.

Retaining the entire filename could impact how video names are stored or referenced later in the code. Ensure consistency across the application.

The following comment refers to this hunk in dreem/inference/eval.py:

```python
    model.tracker_cfg = eval_cfg.cfg.tracker
    model.tracker = Tracker(**model.tracker_cfg)

    logger.info(f"Using the following tracker:")
```
Remove extraneous f-string prefix.

The f-string used in the logging statement does not contain any placeholders, making the f prefix unnecessary.

Remove the extraneous f prefix to clean up the code:

```diff
- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")
```

Comment on lines +29 to +44 (dreem/inference/eval.py):

```diff
     # update with parameters for batch train job
     if "batch_config" in cfg.keys():
         try:
             index = int(os.environ["POD_INDEX"])
-        # For testing without deploying a job on runai
-        except KeyError:
-            index = input("Pod Index Not found! Please choose a pod index: ")
-
-        logger.info(f"Pod Index: {index}")
-
-        checkpoints = pd.read_csv(cfg.checkpoints)
-        checkpoint = checkpoints.iloc[index]
+        except KeyError as e:
+            index = int(
+                input(f"{e}. Assuming single run!\nPlease input task index to run:")
+            )
+
+        hparams_df = pd.read_csv(cfg.batch_config)
+        hparams = hparams_df.iloc[index].to_dict()
+        _ = hparams.pop("Unnamed: 0", None)
+
+        if eval_cfg.set_hparams(hparams):
+            logger.info("Updated the following hparams to the following values")
+            logger.info(hparams)
```
Refactor dictionary key check and improve error handling.

The handling of the POD_INDEX environment variable and user input is robust, enhancing the user experience by providing clear error messages. However, the check for "batch_config" can be simplified by removing .keys() for a more Pythonic approach.

Apply this diff to refactor the dictionary key check:

- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:

The changes are approved.


In dreem/models/embedding.py:

```diff
@@ -323,6 +323,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
         """
         temp_lookup = self.lookup
         N = times.shape[0]
+        times = times / times.max()
```
Add safeguard against division by zero in normalization step.

The normalization step times = times / times.max() could potentially lead to a division by zero error if times.max() is zero. It's crucial to add a small epsilon value to avoid this issue.

Consider modifying the normalization step to include a small epsilon value:

```diff
- times = times / times.max()
+ times = times / (times.max() + 1e-6)
```
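A standalone sketch of the suggested safeguard, independent of the module under review (plain PyTorch):

```python
import torch

def normalize_times(times: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # eps keeps the division finite when times is all zeros (e.g. a single frame at t=0).
    return times / (times.max() + eps)

print(normalize_times(torch.tensor([0.0, 2.0, 4.0])))  # ~tensor([0.0000, 0.5000, 1.0000])
```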

Comment on lines +103 to +106 (dreem/inference/track.py):

```python
    except KeyError as e:
        index = int(
            input(f"{e}. Assuming single run!\nPlease input task index to run:")
        )
```
Handle missing environment variable more robustly.

The handling of a missing POD_INDEX environment variable could be improved by providing a default value or a more robust error handling mechanism, rather than prompting the user.

Consider setting a default index or handling the error in a way that does not require user interaction, which might not be feasible in batch processes.
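One non-interactive alternative, assuming index 0 is an acceptable default for single runs:

```python
import os

index = int(os.environ.get("POD_INDEX", 0))  # falls back to the first row of the batch CSV
```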

Comment on lines +108 to +114 (dreem/inference/track.py):

```python
    hparams_df = pd.read_csv(cfg.batch_config)
    hparams = hparams_df.iloc[index].to_dict()
    _ = hparams.pop("Unnamed: 0", None)

    if pred_cfg.set_hparams(hparams):
        logger.info("Updated the following hparams to the following values")
        logger.info(hparams)
```
Validate and log hyperparameter updates.

The code reads hyperparameters from a CSV and updates the configuration. Ensure that the hyperparameters are validated before applying them to avoid runtime errors.

Add validation for the hyperparameters read from the CSV to ensure they meet expected formats and constraints.
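A sketch of what such validation could look like (the allowed keys are hypothetical placeholders, not the project's actual config schema):

```python
ALLOWED_KEYS = {"model.lr", "tracker.window_size"}  # assumption: the tunable config keys

def validate_hparams(hparams: dict) -> dict:
    # Reject any column from the CSV that is not a known, tunable config key.
    unknown = set(hparams) - ALLOWED_KEYS
    if unknown:
        raise ValueError(f"Unknown hyperparameters in batch_config: {sorted(unknown)}")
    return hparams
```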

Comment on lines +99 to +100 (dreem/inference/track.py):

```python
    # update with parameters for batch train job
    if "batch_config" in cfg.keys():
```
Simplify dictionary key check.

The check for "batch_config" in the configuration dictionary can be simplified by removing .keys().

Apply this diff to simplify the check:

-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:

Comment on lines +313 to +325 (dreem/models/gtr_runner.py):

```diff
+                    [
+                        instance.gt_track_id.item()
+                        for instance in frame.instances
+                    ],
                     save={"crop": True, "features": True, "embeddings": True},
                 )
             else:
                 _ = frame.to_h5(
-                    clip_group, frame.get_gt_track_ids().cpu().numpy()
+                    clip_group,
+                    [
+                        instance.gt_track_id.item()
+                        for instance in frame.instances
+                    ],
```
Review changes to track ID handling.

The new implementation uses a list comprehension to extract gt_track_id. Ensure that this change does not affect the expected data structure in downstream processes.

Verify that the new list structure of track IDs is compatible with all downstream processes that consume this data.
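One way to check this during the transition is an equivalence assertion inside the test loop (assumes both the old and new accessors are still available on `frame`):

```python
# Hypothetical spot check, run per frame inside on_test_epoch_end:
old_ids = frame.get_gt_track_ids().cpu().numpy().tolist()
new_ids = [instance.gt_track_id.item() for instance in frame.instances]
assert old_ids == new_ids, f"track-id mismatch: {old_ids} != {new_ids}"
```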
