From 9067013cec625fec294ddd2856322043e5349261 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 12 Dec 2023 13:39:41 -0500 Subject: [PATCH] dvc: commit experiment 5c79eb5d512b5775838831a58d11815e63653a8b57c818f24949f96bb0a06837 --- dvc.lock | 38 +++++++++++----------- dvc.yaml | 2 +- params.yaml | 4 +-- results/metrics.json | 16 +++++---- results/params.yaml | 2 +- results/plots/metrics/epoch.tsv | 27 ++++++++++++--- results/plots/metrics/train/epoch/acc.tsv | 12 +++++-- results/plots/metrics/train/epoch/loss.tsv | 12 +++++-- results/plots/metrics/train/step/acc.tsv | 4 +++ results/plots/metrics/train/step/loss.tsv | 4 +++ results/plots/metrics/val/acc.tsv | 12 +++++-- results/plots/metrics/val/loss.tsv | 12 +++++-- train.py | 12 ++++--- 13 files changed, 111 insertions(+), 46 deletions(-) create mode 100644 results/plots/metrics/train/step/acc.tsv create mode 100644 results/plots/metrics/train/step/loss.tsv diff --git a/dvc.lock b/dvc.lock index fc26f5728..497f11bb0 100644 --- a/dvc.lock +++ b/dvc.lock @@ -8,24 +8,24 @@ stages: deps: - path: http://www.manythings.org/anki/fra-eng.zip hash: md5 - checksum: '29366693926503416364759562686349725439' - size: 7757635 + checksum: '58835075968442113841504768009775441090' + size: 7833145 outs: - path: fra.txt hash: md5 - md5: f16099673fd64e9fda1e17927ad02248 - size: 34292144 + md5: efac451a5f83015366bfe6ac117b9ba4 + size: 34591986 train: cmd: python train.py deps: - path: fra.txt hash: md5 - md5: f16099673fd64e9fda1e17927ad02248 - size: 34292144 + md5: efac451a5f83015366bfe6ac117b9ba4 + size: 34591986 - path: train.py hash: md5 - md5: 9505d81175e09e8cc80a8905c95abed7 - size: 7136 + md5: 51d5640636dc54ef6171b81ec77b599f + size: 7224 params: params.yaml: data_path: fra.txt @@ -33,28 +33,28 @@ stages: batch_size: 512 latent_dim: 8 duration: 00:00:30:00 - max_epochs: 2 + max_epochs: 20 optim: - lr: 0.01 + lr: 0.004 num_samples: 10000 seed: 423 outs: - path: model hash: md5 - md5: b4e402de57bdfa7871c9ea74562482f5.dir - size: 18561 + md5: deae7576b332077b74aaf9470428840d.dir + size: 18930 nfiles: 1 - path: results/artifacts hash: md5 - md5: b4e402de57bdfa7871c9ea74562482f5.dir - size: 18561 + md5: deae7576b332077b74aaf9470428840d.dir + size: 18930 nfiles: 1 - path: results/metrics.json hash: md5 - md5: e525dbdc2466b77f199fe5fe24afc549 - size: 256 + md5: 1e13f4dfbfa45966ad62b7df95e79e47 + size: 366 - path: results/plots hash: md5 - md5: 1dd4ea44497391604fc8c379789b52b4.dir - size: 240 - nfiles: 5 + md5: ac5e4f9a863532eccbdf9e2ab81cad46.dir + size: 1255 + nfiles: 7 diff --git a/dvc.yaml b/dvc.yaml index eafd9dc22..3865486f6 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -31,5 +31,5 @@ plots: x: step artifacts: best: - path: results/artifacts/epoch=0-step=2.ckpt + path: results/artifacts/epoch=0-step=16.ckpt type: model diff --git a/params.yaml b/params.yaml index 0dea3bef1..3b73d2685 100644 --- a/params.yaml +++ b/params.yaml @@ -2,9 +2,9 @@ model: batch_size: 512 latent_dim: 8 duration: 00:00:30:00 - max_epochs: 2 + max_epochs: 20 optim: - lr: 0.01 + lr: 0.004 data_path: fra.txt num_samples: 10000 seed: 423 diff --git a/results/metrics.json b/results/metrics.json index 2d84b4694..dfccf8d6a 100644 --- a/results/metrics.json +++ b/results/metrics.json @@ -1,14 +1,18 @@ { "val": { - "loss": 1.7237510681152344, - "acc": 0.014084506779909134 + "loss": 0.6839841604232788, + "acc": 0.01470289845019579 }, - "epoch": 1, - "step": 3, + "epoch": 9, + "step": 159, "train": { "epoch": { - "loss": 2.709216356277466, - "acc": 0.013341710902750492 + "loss": 0.6881032586097717, + "acc": 0.01480439305305481 + }, + "step": { + "loss": 0.7074033617973328, + "acc": 0.013838487677276134 } } } diff --git a/results/params.yaml b/results/params.yaml index b9a798c19..28eb9a19e 100644 --- a/results/params.yaml +++ b/results/params.yaml @@ -1,3 +1,3 @@ latent_dim: 8 optim_params: - lr: 0.01 + lr: 0.004 diff --git a/results/plots/metrics/epoch.tsv b/results/plots/metrics/epoch.tsv index 9ced0a6d2..56f5f1a62 100644 --- a/results/plots/metrics/epoch.tsv +++ b/results/plots/metrics/epoch.tsv @@ -1,5 +1,24 @@ step epoch -1 0 -1 0 -3 1 -3 1 +15 0 +15 0 +31 1 +31 1 +47 2 +47 2 +49 3 +63 3 +63 3 +79 4 +79 4 +95 5 +95 5 +99 6 +111 6 +111 6 +127 7 +127 7 +143 8 +143 8 +149 9 +159 9 +159 9 diff --git a/results/plots/metrics/train/epoch/acc.tsv b/results/plots/metrics/train/epoch/acc.tsv index 2a2f3f7d0..d87ac92a4 100644 --- a/results/plots/metrics/train/epoch/acc.tsv +++ b/results/plots/metrics/train/epoch/acc.tsv @@ -1,3 +1,11 @@ step acc -1 0.0067295110784471035 -3 0.013341710902750492 +15 0.013551988638937473 +31 0.014475204981863499 +47 0.014578690752387047 +63 0.012611094862222672 +79 0.012640479020774364 +95 0.0126886535435915 +111 0.012571862898766994 +127 0.012932837940752506 +143 0.013741872273385525 +159 0.01480439305305481 diff --git a/results/plots/metrics/train/epoch/loss.tsv b/results/plots/metrics/train/epoch/loss.tsv index caa189d7b..b00608237 100644 --- a/results/plots/metrics/train/epoch/loss.tsv +++ b/results/plots/metrics/train/epoch/loss.tsv @@ -1,3 +1,11 @@ step loss -1 4.339257717132568 -3 2.709216356277466 +15 2.651099920272827 +31 1.1034443378448486 +47 0.9179204106330872 +63 0.8584996461868286 +79 0.7974901795387268 +95 0.7613820433616638 +111 0.7376429438591003 +127 0.7168747782707214 +143 0.7019720673561096 +159 0.6881032586097717 diff --git a/results/plots/metrics/train/step/acc.tsv b/results/plots/metrics/train/step/acc.tsv new file mode 100644 index 000000000..55e0b6ee8 --- /dev/null +++ b/results/plots/metrics/train/step/acc.tsv @@ -0,0 +1,4 @@ +step acc +49 0.012820512987673283 +99 0.01176421158015728 +149 0.013838487677276134 diff --git a/results/plots/metrics/train/step/loss.tsv b/results/plots/metrics/train/step/loss.tsv new file mode 100644 index 000000000..71229a09f --- /dev/null +++ b/results/plots/metrics/train/step/loss.tsv @@ -0,0 +1,4 @@ +step loss +49 0.8957815766334534 +99 0.7495381236076355 +149 0.7074033617973328 diff --git a/results/plots/metrics/val/acc.tsv b/results/plots/metrics/val/acc.tsv index 1df924139..7a2921d52 100644 --- a/results/plots/metrics/val/acc.tsv +++ b/results/plots/metrics/val/acc.tsv @@ -1,3 +1,11 @@ step acc -1 0.014448276720941067 -3 0.014084506779909134 +15 0.01582844741642475 +31 0.012163701467216015 +47 0.012383817695081234 +63 0.012163614854216576 +79 0.013182752765715122 +95 0.01217393297702074 +111 0.012239444069564342 +127 0.012904092669487 +143 0.014017394743859768 +159 0.01470289845019579 diff --git a/results/plots/metrics/val/loss.tsv b/results/plots/metrics/val/loss.tsv index e124247ee..3a32ae43d 100644 --- a/results/plots/metrics/val/loss.tsv +++ b/results/plots/metrics/val/loss.tsv @@ -1,3 +1,11 @@ step loss -1 3.0123467445373535 -3 1.7237510681152344 +15 1.277770757675171 +31 0.9838346242904663 +47 0.8614449501037598 +63 0.8240187168121338 +79 0.7715871334075928 +95 0.7506844997406006 +111 0.7279419302940369 +127 0.7120298147201538 +143 0.696197509765625 +159 0.6839841604232788 diff --git a/train.py b/train.py index 2f55782af..e679a21a4 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,12 @@ import os import numpy as np import torch -import pytorch_lightning as pl +import lightning +from lightning.pytorch import callbacks import torchmetrics from ruamel.yaml import YAML from dvclive import Live +# from lightning.pytorch.loggers import DVCLiveLogger from dvclive.lightning import DVCLiveLogger yaml = YAML(typ="safe") @@ -76,7 +78,7 @@ # Define the model -class LSTMSeqToSeq(pl.LightningModule): +class LSTMSeqToSeq(lightning.LightningModule): def __init__(self, latent_dim, optim_params): super().__init__() # Log parameters (saves them to self.hparams) @@ -163,14 +165,14 @@ def __getitem__(self, idx): exp = Live("results", save_dvc_exp=True) live = DVCLiveLogger(report=None, experiment=exp, log_model=True) -checkpoint = pl.callbacks.ModelCheckpoint( +checkpoint = callbacks.ModelCheckpoint( dirpath="model", monitor="val_acc", mode="max", save_weights_only=True, every_n_epochs=1) -timer = pl.callbacks.Timer(duration=params["model"]["duration"]) +timer = callbacks.Timer(duration=params["model"]["duration"]) -trainer = pl.Trainer(max_epochs=params["model"]["max_epochs"], logger=[live], +trainer = lightning.Trainer(max_epochs=params["model"]["max_epochs"], logger=[live], callbacks=[timer, checkpoint]) trainer.fit(model=arch, train_dataloaders=train_loader, val_dataloaders=val_loader)