train.py
import re
from pathlib import Path

import hydra
import numpy as np
import ray
from omegaconf import OmegaConf
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env

import ig_navigation
from ig_navigation.callbacks import DummyCallback, MetricsCallback
from ig_navigation.model import ComplexInputNetwork

# Register the custom feature extractor so the model config can refer to it by name.
ModelCatalog.register_custom_model("complex_input_network", ComplexInputNetwork)
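

# Environment factory registered with RLlib below. The full Hydra config dict is
# passed through as env_config; SearchEnv is imported inside the function so the
# iGibson stack is only loaded in processes that actually build an environment.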
def igibson_env_creator(env_config):
    from ig_navigation.igibson_env import SearchEnv

    return SearchEnv(
        config_file=env_config,
        mode=env_config["mode"],
        action_timestep=1 / 10.0,
        physics_timestep=1 / 120.0,
    )
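

# Hydra composes `cfg` from config.yaml under ig_navigation.CONFIG_PATH and applies
# any command-line overrides before calling main().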
@hydra.main(config_path=ig_navigation.CONFIG_PATH, config_name="config")
def main(cfg):
    ray.init()
    env_config = OmegaConf.to_object(cfg)
    register_env("igibson_env_creator", igibson_env_creator)

    checkpoint_path = Path(cfg.experiment_save_path, cfg.experiment_name)
    num_epochs = np.round(cfg.training_timesteps / cfg.n_steps).astype(int)
    save_ep_freq = np.round(
        num_epochs / (cfg.training_timesteps / cfg.save_freq)
    ).astype(int)
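    # num_epochs counts PPO training iterations (each iteration consumes n_steps env
    # steps); save_ep_freq simplifies to roughly save_freq / n_steps, i.e. a checkpoint
    # about every save_freq environment steps (assumes save_freq >= n_steps so it stays >= 1).

    # Map the flat hyperparameters from the Hydra config onto RLlib's PPO config keys.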
    config = {
        "env": "igibson_env_creator",
        "model": OmegaConf.to_object(cfg.model),
        "env_config": env_config,  # config to pass to env class
        "num_workers": cfg.num_envs,
        "framework": "torch",
        "seed": cfg.seed,
        "lambda": cfg.gae_lambda,
        "lr": cfg.learning_rate,
        "train_batch_size": cfg.n_steps,
        "rollout_fragment_length": cfg.n_steps // cfg.num_envs,
        "num_sgd_iter": cfg.n_epochs,
        "sgd_minibatch_size": cfg.batch_size,
        "gamma": cfg.gamma,
        "create_env_on_driver": False,
        "num_gpus": 1,
        "callbacks": MetricsCallback,
    }
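
    # Optional periodic evaluation: run 20 evaluation episodes in a separate worker,
    # in parallel with training, with env recording enabled and the metrics callbacks
    # swapped for DummyCallback.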
    if cfg.eval_freq > 0:
        eval_ep_freq = np.round(
            num_epochs / (cfg.training_timesteps / cfg.eval_freq)
        ).astype(int)
        config.update(
            {
                "evaluation_interval": eval_ep_freq,  # evaluate every eval_ep_freq training iterations
                "evaluation_duration": 20,
                "evaluation_duration_unit": "episodes",
                "evaluation_num_workers": 1,
                "evaluation_parallel_to_training": True,
                "evaluation_config": {
                    "callbacks": DummyCallback,
                    "record_env": True,
                },
            }
        )
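
    # Write Tune/RLlib logs under <experiment_save_path>/<experiment_name>/log.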
    log_path = str(checkpoint_path.joinpath("log"))
    Path(log_path).mkdir(parents=True, exist_ok=True)

    trainer = ppo.PPOTrainer(
        config,
        logger_creator=lambda x: UnifiedLogger(x, log_path),  # type: ignore
    )
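
    # Resume from the most recent checkpoint under checkpoint_path, if one exists.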
    if Path(checkpoint_path).exists():
        checkpoints = Path(checkpoint_path).rglob("checkpoint-*")
        checkpoints = [
            str(f) for f in checkpoints if re.search(r".*checkpoint-\d+$", str(f))
        ]
        # Sort numerically by checkpoint index so that e.g. checkpoint-10 is preferred
        # over checkpoint-9 (a plain lexicographic sort would not guarantee this).
        checkpoints = sorted(checkpoints, key=lambda f: int(f.rsplit("-", 1)[-1]))
        if len(checkpoints) > 0:
            trainer.restore(checkpoints[-1])

    for i in range(num_epochs):
        # Perform one iteration of training the policy with PPO
        trainer.train()
        if (i % save_ep_freq) == 0:
            checkpoint = trainer.save(checkpoint_path)
            print("checkpoint saved at", checkpoint)


if __name__ == "__main__":
    main()
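

# Example launch (hypothetical override values; the keys are the cfg fields used above):
#   python train.py experiment_name=search_debug num_envs=4 training_timesteps=1000000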