Skip to content

Commit

Permalink
Save copy of labeled data csv to output dir
Browse files: browse the repository at this point in the history
  • Loading branch information
ksikka committed Nov 18, 2024
1 parent 873fa32 commit 5f572e2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
53 changes: 36 additions & 17 deletions lightning_pose/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,8 +2,10 @@

import os
import random
import shutil
import sys
import warnings
from pathlib import Path

import lightning.pytorch as pl
import numpy as np
Expand Down Expand Up @@ -76,6 +78,33 @@ def train(cfg: DictConfig) -> None:
# model
model = get_model(cfg=cfg, data_module=data_module, loss_factories=loss_factories)

# ----------------------------------------------------------------------------------
# Save configuration in output directory
# ----------------------------------------------------------------------------------
# Done before training; files will exist even if script dies prematurely.
hydra_output_directory = os.getcwd()
print(f"Hydra output directory: {hydra_output_directory}")

# save config file
dest_config_file = Path(hydra_output_directory) / "config.yaml"
OmegaConf.save(config=cfg, f=dest_config_file, resolve=True)

# save labeled data file(s)
if isinstance(cfg.data.csv_file, str):
# single view
csv_files = [cfg.data.csv_file]
else:
# multi view
assert isinstance(cfg.data.csv_file, ListConfig)
csv_files = cfg.data.csv_file
for csv_file in csv_files:
src_csv_file = Path(csv_file)
if not src_csv_file.is_absolute():
src_csv_file = Path(data_dir) / src_csv_file

dest_csv_file = Path(hydra_output_directory) / src_csv_file.name
shutil.copyfile(src_csv_file, dest_csv_file)

# ----------------------------------------------------------------------------------
# Set up and run training
# ----------------------------------------------------------------------------------
Expand Down Expand Up @@ -138,18 +167,12 @@ def train(cfg: DictConfig) -> None:
# ----------------------------------------------------------------------------------
# Post-training analysis
# ----------------------------------------------------------------------------------
hydra_output_directory = os.getcwd()
print(f"Hydra output directory: {hydra_output_directory}")
# get best ckpt
best_ckpt = os.path.abspath(trainer.checkpoint_callback.best_model_path)
print(f"Best checkpoint: {os.path.basename(best_ckpt)}")
# check if best_ckpt is a file
if not os.path.isfile(best_ckpt):
raise FileNotFoundError("Cannot find checkpoint. Have you trained for too few epochs?")
# save config file
cfg_file_local = os.path.join(hydra_output_directory, "config.yaml")
with open(cfg_file_local, "w") as fp:
OmegaConf.save(config=cfg, f=fp.name)

# make unaugmented data_loader if necessary
if cfg.training.imgaug != "default":
Expand Down Expand Up @@ -181,17 +204,13 @@ def train(cfg: DictConfig) -> None:
preds_file=preds_file,
)
# compute and save various metrics
try:
# take care of multiview case, where multiple csv files have been saved
preds_files = [
os.path.join(hydra_output_directory, path) for path in
os.listdir(hydra_output_directory) if path.endswith(".csv")
]
if len(preds_files) > 1:
preds_file = preds_files
compute_metrics(cfg=cfg, preds_file=preds_file, data_module=data_module_pred)
except Exception as e:
print(f"Error computing metrics\n{e}")
# for multiview, predict_dataset outputs one pred file per view.
multiview_pred_files = [
str(p) for p in Path(hydra_output_directory).glob("predictions_*.csv")
]
if len(multiview_pred_files) > 0:
preds_file = multiview_pred_files
compute_metrics(cfg=cfg, preds_file=preds_file, data_module=data_module_pred)

# ----------------------------------------------------------------------------------
# predict folder of videos
Expand Down
3 changes: 3 additions & 0 deletions tests/test_train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_train_singleview(cfg, tmp_path):

# ensure labeled data was properly processed
assert (tmp_path / "config.yaml").is_file()
assert (tmp_path / "CollectedData.csv").is_file()
assert (tmp_path / "predictions.csv").is_file()
assert (tmp_path / "predictions_pca_multiview_error.csv").is_file()
assert (tmp_path / "predictions_pca_singleview_error.csv").is_file()
Expand All @@ -80,6 +81,8 @@ def test_train_multiview(cfg_multiview, tmp_path):
train(cfg)

assert (tmp_path / "config.yaml").is_file()
assert (tmp_path / "top.csv").is_file()
assert (tmp_path / "bot.csv").is_file()

for view in ["top", "bot"]:
# ensure labeled data was properly processed
Expand Down

0 comments on commit 5f572e2

Please sign in to comment.