-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathrun.py
58 lines (43 loc) · 1.69 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
# Copyright (c) 2021 Max Bain
# This file has been modified by Graphcore
import fire
import wandb
from configs.parse_config import ConfigParser
from modeling.trainer import TrainerIPU
# Module-level configuration. It is populated by ``parse_config`` (which
# python-fire invokes when this file is executed as a script) and read by
# ``run`` and the ``__main__`` block.
config = None


def run(cfg=None):
    """Build a ``TrainerIPU`` and either validate or train with it.

    Args:
        cfg: Optional parsed configuration object (as produced by
            ``ConfigParser``). When omitted, the module-level ``config`` set by
            ``parse_config`` is used, so existing ``run()`` call sites behave
            exactly as before.

    Raises:
        RuntimeError: If no configuration is available, i.e. ``cfg`` was not
            given and ``parse_config`` has not populated ``config`` yet.
    """
    if cfg is None:
        cfg = config
    # Fail fast with an actionable message rather than passing ``None`` on to
    # ``TrainerIPU`` (which would otherwise be constructed before any check).
    if cfg is None:
        raise RuntimeError(
            "No configuration available: call parse_config() first or pass cfg to run()."
        )
    trainer = TrainerIPU(config=cfg)
    # ``parse_config`` forwards ``validation_only`` to ConfigParser as an
    # override; it selects between a validation-only pass and full training.
    if cfg["validation_only"]:
        trainer.validate()
    else:
        trainer.train()
def parse_config(
    config_name: str, validation_only: bool = False, compile_only: bool = False, timestamp_ckpt: bool = True, **kwargs
):
    """
    Parse the command line into the module-level ``config``.

    Loads the nested dict from the CONFIG_NAME JSON file (see configs/*.json
    for examples) and applies any extra ``--key=value`` arguments as overrides.

    Keys are dotted paths into the nested dict: ``A.B.C`` addresses
    ``config[A][B][C]``. For example, to override ``arch.type=FrozenInTime`` in
    webvid2m-8ipu-1f.json, pass ``--arch.type=new_value``.
    To remove a value, set its key to None, e.g. ``--arch.type=None``.
    Keys that do not exist in the loaded dict are added automatically.

    Args:
        config_name: Name/path of the JSON config file to load.
        validation_only: Stored in the config as ``validation_only``; ``run``
            reads it to choose validation instead of training.
        compile_only: Stored in the config as ``compile_only``
            (consumer not visible here — presumably the trainer; confirm).
        timestamp_ckpt: Forwarded to ``ConfigParser`` as ``timestamp``
            (presumably controls timestamped checkpoint dirs — confirm).
        **kwargs: Dotted-key overrides, as described above.

    The parsed result is stored in the module global ``config`` rather than
    returned; this function itself returns None.
    """
    global config
    # The two CLI flags travel to ConfigParser alongside the user's overrides.
    overrides = {**kwargs, "validation_only": validation_only, "compile_only": compile_only}
    config = ConfigParser(config_name, timestamp=timestamp_ckpt, **overrides)
if __name__ == "__main__":
    # python-fire maps CLI arguments onto parse_config, which fills ``config``.
    fire.Fire(parse_config)

    # Weights & Biases logging is opt-in via the ``trainer.wandb`` config flag.
    raw_config = config._config
    if raw_config["trainer"].get("wandb", False):
        trainer_section = config["trainer"]
        wandb.init(
            project=trainer_section.get("project_name", "torch-frozen-in-time"),
            name=trainer_section.get("run_name", None),
            config=raw_config,
            dir="/tmp",
        )

    run()