Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge iodic-diva-branch into main #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions dvc.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,53 +8,53 @@ stages:
deps:
- path: http://www.manythings.org/anki/fra-eng.zip
hash: md5
checksum: '29366693926503416364759562686349725439'
size: 7757635
checksum: '58835075968442113841504768009775441090'
size: 7833145
outs:
- path: fra.txt
hash: md5
md5: f16099673fd64e9fda1e17927ad02248
size: 34292144
md5: efac451a5f83015366bfe6ac117b9ba4
size: 34591986
train:
cmd: python train.py
deps:
- path: fra.txt
hash: md5
md5: f16099673fd64e9fda1e17927ad02248
size: 34292144
md5: efac451a5f83015366bfe6ac117b9ba4
size: 34591986
- path: train.py
hash: md5
md5: 9505d81175e09e8cc80a8905c95abed7
size: 7136
md5: 51d5640636dc54ef6171b81ec77b599f
size: 7224
params:
params.yaml:
data_path: fra.txt
model:
batch_size: 512
latent_dim: 8
duration: 00:00:30:00
max_epochs: 2
max_epochs: 20
optim:
lr: 0.01
lr: 0.004
num_samples: 10000
seed: 423
outs:
- path: model
hash: md5
md5: b4e402de57bdfa7871c9ea74562482f5.dir
size: 18561
md5: deae7576b332077b74aaf9470428840d.dir
size: 18930
nfiles: 1
- path: results/artifacts
hash: md5
md5: b4e402de57bdfa7871c9ea74562482f5.dir
size: 18561
md5: deae7576b332077b74aaf9470428840d.dir
size: 18930
nfiles: 1
- path: results/metrics.json
hash: md5
md5: e525dbdc2466b77f199fe5fe24afc549
size: 256
md5: 1e13f4dfbfa45966ad62b7df95e79e47
size: 366
- path: results/plots
hash: md5
md5: 1dd4ea44497391604fc8c379789b52b4.dir
size: 240
nfiles: 5
md5: ac5e4f9a863532eccbdf9e2ab81cad46.dir
size: 1255
nfiles: 7
2 changes: 1 addition & 1 deletion dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ plots:
x: step
artifacts:
best:
path: results/artifacts/epoch=0-step=2.ckpt
path: results/artifacts/epoch=0-step=16.ckpt
type: model
4 changes: 2 additions & 2 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ model:
batch_size: 512
latent_dim: 8
duration: 00:00:30:00
max_epochs: 2
max_epochs: 20
optim:
lr: 0.01
lr: 0.004
data_path: fra.txt
num_samples: 10000
seed: 423
16 changes: 10 additions & 6 deletions results/metrics.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
{
"val": {
"loss": 1.7237510681152344,
"acc": 0.014084506779909134
"loss": 0.6839841604232788,
"acc": 0.01470289845019579
},
"epoch": 1,
"step": 3,
"epoch": 9,
"step": 159,
"train": {
"epoch": {
"loss": 2.709216356277466,
"acc": 0.013341710902750492
"loss": 0.6881032586097717,
"acc": 0.01480439305305481
},
"step": {
"loss": 0.7074033617973328,
"acc": 0.013838487677276134
}
}
}
2 changes: 1 addition & 1 deletion results/params.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
latent_dim: 8
optim_params:
lr: 0.01
lr: 0.004
27 changes: 23 additions & 4 deletions results/plots/metrics/epoch.tsv
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
step epoch
1 0
1 0
3 1
3 1
15 0
15 0
31 1
31 1
47 2
47 2
49 3
63 3
63 3
79 4
79 4
95 5
95 5
99 6
111 6
111 6
127 7
127 7
143 8
143 8
149 9
159 9
159 9
12 changes: 10 additions & 2 deletions results/plots/metrics/train/epoch/acc.tsv
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
step acc
1 0.0067295110784471035
3 0.013341710902750492
15 0.013551988638937473
31 0.014475204981863499
47 0.014578690752387047
63 0.012611094862222672
79 0.012640479020774364
95 0.0126886535435915
111 0.012571862898766994
127 0.012932837940752506
143 0.013741872273385525
159 0.01480439305305481
12 changes: 10 additions & 2 deletions results/plots/metrics/train/epoch/loss.tsv
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
step loss
1 4.339257717132568
3 2.709216356277466
15 2.651099920272827
31 1.1034443378448486
47 0.9179204106330872
63 0.8584996461868286
79 0.7974901795387268
95 0.7613820433616638
111 0.7376429438591003
127 0.7168747782707214
143 0.7019720673561096
159 0.6881032586097717
4 changes: 4 additions & 0 deletions results/plots/metrics/train/step/acc.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
step acc
49 0.012820512987673283
99 0.01176421158015728
149 0.013838487677276134
4 changes: 4 additions & 0 deletions results/plots/metrics/train/step/loss.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
step loss
49 0.8957815766334534
99 0.7495381236076355
149 0.7074033617973328
12 changes: 10 additions & 2 deletions results/plots/metrics/val/acc.tsv
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
step acc
1 0.014448276720941067
3 0.014084506779909134
15 0.01582844741642475
31 0.012163701467216015
47 0.012383817695081234
63 0.012163614854216576
79 0.013182752765715122
95 0.01217393297702074
111 0.012239444069564342
127 0.012904092669487
143 0.014017394743859768
159 0.01470289845019579
12 changes: 10 additions & 2 deletions results/plots/metrics/val/loss.tsv
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
step loss
1 3.0123467445373535
3 1.7237510681152344
15 1.277770757675171
31 0.9838346242904663
47 0.8614449501037598
63 0.8240187168121338
79 0.7715871334075928
95 0.7506844997406006
111 0.7279419302940369
127 0.7120298147201538
143 0.696197509765625
159 0.6839841604232788
12 changes: 7 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import numpy as np
import torch
import pytorch_lightning as pl
import lightning
from lightning.pytorch import callbacks
import torchmetrics
from ruamel.yaml import YAML
from dvclive import Live
# from lightning.pytorch.loggers import DVCLiveLogger
from dvclive.lightning import DVCLiveLogger

yaml = YAML(typ="safe")
Expand Down Expand Up @@ -76,7 +78,7 @@


# Define the model
class LSTMSeqToSeq(pl.LightningModule):
class LSTMSeqToSeq(lightning.LightningModule):
def __init__(self, latent_dim, optim_params):
super().__init__()
# Log parameters (saves them to self.hparams)
Expand Down Expand Up @@ -163,14 +165,14 @@ def __getitem__(self, idx):

exp = Live("results", save_dvc_exp=True)
live = DVCLiveLogger(report=None, experiment=exp, log_model=True)
checkpoint = pl.callbacks.ModelCheckpoint(
checkpoint = callbacks.ModelCheckpoint(
dirpath="model",
monitor="val_acc",
mode="max",
save_weights_only=True, every_n_epochs=1)
timer = pl.callbacks.Timer(duration=params["model"]["duration"])
timer = callbacks.Timer(duration=params["model"]["duration"])

trainer = pl.Trainer(max_epochs=params["model"]["max_epochs"], logger=[live],
trainer = lightning.Trainer(max_epochs=params["model"]["max_epochs"], logger=[live],
callbacks=[timer, checkpoint])
trainer.fit(model=arch, train_dataloaders=train_loader,
val_dataloaders=val_loader)