Allow batch eval/inference flexibility #80
base: main
Conversation
Walkthrough

The changes introduce a new approach to handling batch configurations.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Config
    participant Dataset
    participant Dataloader
    participant Model
    User->>Config: Provide batch_config
    Config->>Dataset: Load dataset
    Dataset-->>Config: Return dataset or None
    Config->>Dataloader: Create dataloader
    Dataloader-->>Config: Return dataloader or None
    Config->>Model: Run training/evaluation
```
Actionable comments posted: 7
Outside diff range, codebase verification and nitpick comments (1)
dreem/io/config.py (1)
Line range hint 200-257: Approve handling of empty datasets and suggest improvement.

Returning `None` for an empty dataset and logging a warning is good practice. However, consider making the warning message clearer by specifying which dataset is empty. The code changes are approved.

Consider improving the warning message for clarity:

```diff
- logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+ logger.warn(f"Length of {mode} dataset is 0! Returning None as the dataset is empty.")
```

Tools
Ruff
261-261: Undefined name `SleapDataset` (F821)
261-261: Undefined name `MicroscopyDataset` (F821)
261-261: Undefined name `CellTrackingDataset` (F821)
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (5)
- dreem/inference/eval.py (1 hunks)
- dreem/inference/track.py (2 hunks)
- dreem/io/config.py (6 hunks)
- dreem/models/embedding.py (1 hunks)
- dreem/models/gtr_runner.py (2 hunks)
Additional context used
Ruff

dreem/inference/eval.py
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
56-56: f-string without any placeholders. Remove extraneous `f` prefix (F541)

dreem/inference/track.py
100-100: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)

dreem/io/config.py
200-200: Undefined name `SleapDataset` (F821)
200-200: Undefined name `MicroscopyDataset` (F821)
200-200: Undefined name `CellTrackingDataset` (F821)
261-261: Undefined name `SleapDataset` (F821)
261-261: Undefined name `MicroscopyDataset` (F821)
261-261: Undefined name `CellTrackingDataset` (F821)
Additional comments not posted (10)
dreem/inference/eval.py (3)
38-44: Good handling of hyperparameters from CSV.

Reading hyperparameters from a CSV file and removing the unnamed index column is a good practice for data cleanup. The logging statements provide visibility into the configuration process, which is beneficial for debugging and verification.

The changes are approved.
42-44: Effective use of logging.

The logging statements added provide useful information about the hyperparameters being set, which enhances transparency and aids in debugging.

The changes are approved.
46-47: Proper initialization of default hyperparameters.

Initializing `hparams` as an empty dictionary when "batch_config" is not present is a good fallback mechanism, ensuring that the function remains flexible and robust.

The changes are approved.
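A compact sketch of that fallback; the function wrapper is hypothetical:

```python
import pandas as pd

def load_batch_hparams(cfg, index: int) -> dict:
    """Sketch: read one row of hyperparameter overrides, or default to {}."""
    if "batch_config" not in cfg:
        return {}  # nothing to override; a later set_hparams({}) is a no-op
    hparams = pd.read_csv(cfg.batch_config).iloc[index].to_dict()
    hparams.pop("Unnamed: 0", None)  # drop pandas' serialized index column
    return hparams
```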
dreem/io/config.py (2)
177-179: Approve sorting of file paths.

Sorting file paths for labels and videos ensures consistency in the order of files processed, which is beneficial for reproducibility.

The code changes are approved.
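A small illustration of why the sort matters; the paths and extensions are hypothetical:

```python
from pathlib import Path

# glob order is filesystem-dependent, so label/video pairing can silently
# shuffle between machines; sorting both lists makes the pairing reproducible.
label_files = sorted(Path("data/train").glob("*.slp"))
video_files = sorted(Path("data/train").glob("*.mp4"))
for label, video in zip(label_files, video_files):
    assert label.stem == video.stem, f"mismatched pair: {label} / {video}"
```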
Line range hint 261-313: Approve robust handling of invalid datasets in dataloader.

Handling `None` or empty datasets by returning `None` and logging warnings in the `get_dataloader` method is a good practice to prevent errors during data loading and maintain transparency.

The code changes are approved.
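A minimal sketch of that guard; the signature is assumed from the review, not copied from the PR:

```python
import logging
from typing import Optional

from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger("dreem.io.config")

def get_dataloader(dataset: Optional[Dataset], mode: str) -> Optional[DataLoader]:
    """Sketch: propagate None rather than wrapping a missing/empty dataset."""
    if dataset is None or len(dataset) == 0:
        logger.warning(f"{mode} dataset is None or empty! Returning None")
        return None
    return DataLoader(dataset, batch_size=1, shuffle=(mode == "train"))
```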
dreem/inference/track.py (4)
116-117: Initialize empty dictionary for hyperparameters.

This initialization is straightforward and correct.

The code changes are approved.
120-120: Review model and tracker configuration updates.

The updates to the model and tracker configurations are crucial for the correct functioning of the inference process.

The code changes are approved.

Also applies to: 122-122, 124-124, 127-127
137-144: Enhance logging for tracking and results saving.

The enhanced logging provides better visibility into the operations being performed, which is beneficial for debugging and monitoring.

The code changes are approved.
118-118: Verify checkpoint path retrieval.

Ensure that the checkpoint path is correctly retrieved and handled, especially in scenarios where it might not be specified.

Verification successful

Checkpoint path retrieval is correctly handled. The codebase includes checks for `None` or empty checkpoint paths, ensuring robustness in scenarios where the path might not be specified. The checkpoint path is accessed and used appropriately across different modules, and test cases confirm its correct handling. No issues were found.

Scripts executed

The following scripts were executed for the analysis:

```bash
#!/bin/bash
# Description: Verify that the checkpoint path is always available in the configuration.
# Test: Search for the checkpoint path handling in the configuration loading.
rg --type python -A 5 $'ckpt_path'
```

Length of output: 64

```bash
#!/bin/bash
# Description: Verify that the checkpoint path is always available in the configuration.
# Test: Search for the checkpoint path handling in the configuration loading.
rg --type py -A 5 $'ckpt_path'
```

Length of output: 4118
dreem/models/gtr_runner.py (1)

304-304: Ensure consistent video name handling.

Retaining the entire filename could impact how video names are stored or referenced later in the code. Ensure consistency across the application.
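To make the concern concrete, a hypothetical illustration of the two naming conventions:

```python
from pathlib import Path

video = "experiments/fly_arena_01.mp4"  # hypothetical path
print(Path(video).name)  # "fly_arena_01.mp4" — entire filename retained
print(Path(video).stem)  # "fly_arena_01"     — extension stripped
# Any downstream lookup (e.g., an HDF5 group keyed by video name) must use
# the same convention everywhere, or reads will miss previously written keys.
```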
```python
model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")
```
Remove extraneous f-string prefix.

The f-string used in the logging statement does not contain any placeholders, making the `f` prefix unnecessary.

Remove the extraneous `f` prefix to clean up the code:

```diff
- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")
```
Tools
Ruff
56-56: f-string without any placeholders. Remove extraneous `f` prefix (F541)
```diff
 # update with parameters for batch train job
 if "batch_config" in cfg.keys():
     try:
         index = int(os.environ["POD_INDEX"])
-    # For testing without deploying a job on runai
-    except KeyError:
-        index = input("Pod Index Not found! Please choose a pod index: ")
-
-    logger.info(f"Pod Index: {index}")
-
-    checkpoints = pd.read_csv(cfg.checkpoints)
-    checkpoint = checkpoints.iloc[index]
+    except KeyError as e:
+        index = int(
+            input(f"{e}. Assuming single run!\nPlease input task index to run:")
+        )
+
+    hparams_df = pd.read_csv(cfg.batch_config)
+    hparams = hparams_df.iloc[index].to_dict()
+    _ = hparams.pop("Unnamed: 0", None)
+
+    if eval_cfg.set_hparams(hparams):
+        logger.info("Updated the following hparams to the following values")
+        logger.info(hparams)
```
Refactor dictionary key check and improve error handling.

The handling of the `POD_INDEX` environment variable and user input is robust, enhancing the user experience by providing clear error messages. However, the check for `"batch_config"` can be simplified by removing `.keys()` for a more Pythonic approach.

Apply this diff to refactor the dictionary key check:

```diff
- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
```

The changes are approved.
Committable suggestion
```python
# update with parameters for batch train job
if "batch_config" in cfg:
    try:
        index = int(os.environ["POD_INDEX"])
    except KeyError as e:
        index = int(
            input(f"{e}. Assuming single run!\nPlease input task index to run:")
        )
    hparams_df = pd.read_csv(cfg.batch_config)
    hparams = hparams_df.iloc[index].to_dict()
    _ = hparams.pop("Unnamed: 0", None)
    if eval_cfg.set_hparams(hparams):
        logger.info("Updated the following hparams to the following values")
        logger.info(hparams)
```
Tools
Ruff
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
```diff
@@ -323,6 +323,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
         """
         temp_lookup = self.lookup
         N = times.shape[0]
+        times = times / times.max()
```
Add safeguard against division by zero in normalization step.

The normalization step `times = times / times.max()` could lead to a division-by-zero error if `times.max()` is zero. It's crucial to add a small epsilon value to avoid this issue.

Consider modifying the normalization step to include a small epsilon value:

```diff
- times = times / times.max()
+ times = times / (times.max() + 1e-6)
```
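A quick illustration of the failure mode the epsilon guards against:

```python
import torch

times = torch.zeros(4)  # e.g., every instance stamped with frame index 0
print(times / times.max())           # tensor([nan, nan, nan, nan]) — 0/0
print(times / (times.max() + 1e-6))  # tensor([0., 0., 0., 0.])    — safe
```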
```python
except KeyError as e:
    index = int(
        input(f"{e}. Assuming single run!\nPlease input task index to run:")
    )
```
Handle missing environment variable more robustly.

The handling of a missing `POD_INDEX` environment variable could be improved by providing a default value or a more robust error-handling mechanism, rather than prompting the user.

Consider setting a default index or handling the error in a way that does not require user interaction, which might not be feasible in batch processes.
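One non-interactive alternative, sketched here with an assumed default of index 0:

```python
import os

# Falling back to a default index keeps unattended batch jobs from hanging
# on input(); whether 0 is a sensible default depends on the deployment.
index = int(os.environ.get("POD_INDEX", "0"))
```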
```python
hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)

if pred_cfg.set_hparams(hparams):
    logger.info("Updated the following hparams to the following values")
    logger.info(hparams)
```
Validate and log hyperparameter updates.

The code reads hyperparameters from a CSV and updates the configuration. Ensure that the hyperparameters are validated before applying them, to avoid runtime errors.

Add validation for the hyperparameters read from the CSV to ensure they meet expected formats and constraints.
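A hypothetical sketch of such validation; the expected keys and types are illustrative assumptions, not the project's schema:

```python
import math

EXPECTED_TYPES = {"lr": float, "batch_size": int}  # assumed schema

def validate_hparams(hparams: dict) -> dict:
    """Drop NaN cells and coerce known keys so bad CSV rows fail early."""
    clean = {}
    for key, value in hparams.items():
        if isinstance(value, float) and math.isnan(value):
            continue  # empty CSV cells arrive from pandas as NaN
        if key in EXPECTED_TYPES:
            value = EXPECTED_TYPES[key](value)  # raises on malformed values
        clean[key] = value
    return clean
```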
```python
# update with parameters for batch train job
if "batch_config" in cfg.keys():
```
Simplify dictionary key check.

The check for `"batch_config"` in the configuration dictionary can be simplified by removing `.keys()`.

Apply this diff to simplify the check:

```diff
-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:
```
Tools
Ruff
100-100: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
```diff
+                        [
+                            instance.gt_track_id.item()
+                            for instance in frame.instances
+                        ],
                         save={"crop": True, "features": True, "embeddings": True},
                     )
                 else:
                     _ = frame.to_h5(
-                        clip_group, frame.get_gt_track_ids().cpu().numpy()
+                        clip_group,
+                        [
+                            instance.gt_track_id.item()
+                            for instance in frame.instances
+                        ],
```
Review changes to track ID handling.

The new implementation uses a list comprehension to extract `gt_track_id` values. Ensure that this change does not affect the expected data structure in downstream processes.

Verify that the new list structure of track IDs is compatible with all downstream processes that consume this data.
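A hedged illustration of the difference; the tensor below is a stand-in for the real per-instance `gt_track_id` values:

```python
import torch

gt_track_ids = torch.tensor([3, 1, 2])  # stand-in for per-instance IDs

old_style = gt_track_ids.cpu().numpy()             # numpy.ndarray
new_style = [int(t.item()) for t in gt_track_ids]  # plain list of ints

# The values agree, but the container type differs: h5py accepts both as a
# dataset, while any consumer relying on ndarray attributes (.shape, .dtype,
# vectorized ops) would break on the list.
assert old_style.tolist() == new_style
```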
Here we switch the batch logic to take a CSV of hyperparameters, in order to allow more flexibility in batch jobs rather than just running different models.
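For instance, a batch config CSV might look like the following; the column names are illustrative, not from the PR. Each row holds one job's hyperparameter overrides, selected by `POD_INDEX`:

```python
import pandas as pd

# Build an illustrative batch_config.csv: one row per batch job.
pd.DataFrame(
    {
        "model.d_model": [256, 512, 512],
        "tracker.window_size": [8, 8, 16],
    }
).to_csv("batch_config.csv")  # pandas writes the index as an unnamed column

# Each job then selects its own row, mirroring the logic in eval.py/track.py:
row = pd.read_csv("batch_config.csv").iloc[0].to_dict()
row.pop("Unnamed: 0", None)  # the serialized index comes back as "Unnamed: 0"
print(row)  # {'model.d_model': 256, 'tracker.window_size': 8}
```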
Summary by CodeRabbit

- New Features
- Bug Fixes
- Documentation