Skip to content

Commit

Permalink
[0.6.3.1] [feat]: add cfg.restore_model_meta
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjim0816 committed Jun 14, 2024
1 parent e93cfa2 commit 3a47d54
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 15 deletions.
4 changes: 2 additions & 2 deletions joyrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
Email: [email protected]
Date: 2023-01-01 16:20:49
LastEditor: JiangJi
LastEditTime: 2024-06-14 09:33:39
LastEditTime: 2024-06-14 09:41:26
Description:
'''
from joyrl import algos, framework, envs
from joyrl.run import run

__version__ = "0.6.3"
__version__ = "0.6.3.1"

__all__ = [
"algos",
Expand Down
3 changes: 2 additions & 1 deletion joyrl/framework/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:30:09
LastEditor: JiangJi
LastEditTime: 2024-06-11 13:31:40
LastEditTime: 2024-06-14 09:39:48
Description:
'''
class DefaultConfig:
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(self) -> None:
self.model_save_fre = 500 # model save frequency per update step
# load model settings
self.load_checkpoint = False # whether to load a checkpoint
self.restore_model_meta = True # whether to restore the saved model meta when loading a checkpoint
self.load_path = "Train_single_CartPole-v1_DQN_20230515-211721" # path to load model
self.load_model_step = 'best' # load model at which step
# stats recorder settings
Expand Down
4 changes: 2 additions & 2 deletions joyrl/framework/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2024-02-25 15:46:04
LastEditor: JiangJi
LastEditTime: 2024-06-14 09:34:08
LastEditTime: 2024-06-14 09:40:51
Description:
'''
import copy
Expand Down Expand Up @@ -49,7 +49,7 @@ def _init_n_sample_steps(self):
self.n_sample_steps = float('inf')

def _load_model_meta(self):
if self.cfg.load_checkpoint:
if self.cfg.load_checkpoint and self.cfg.restore_model_meta:
model_meta = self.cfg.model_meta.get(self.name, {})
self.policy.load_model_meta(model_meta)

Expand Down
4 changes: 2 additions & 2 deletions joyrl/framework/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
LastEditTime: 2024-06-14 09:34:20
LastEditTime: 2024-06-14 09:40:42
Description:
'''
import numpy as np
Expand Down Expand Up @@ -38,7 +38,7 @@ def _init_update_steps(self):
self.n_update_steps = float('inf')

def _load_model_meta(self):
if self.cfg.load_checkpoint:
if self.cfg.load_checkpoint and self.cfg.restore_model_meta:
model_meta = self.cfg.model_meta.get(self.name, {})
self.policy.load_model_meta(model_meta)

Expand Down
2 changes: 1 addition & 1 deletion joyrl/framework/message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum, unique
from typing import Optional
from typing import Optional, Any
from dataclasses import dataclass

@unique
Expand Down
11 changes: 6 additions & 5 deletions joyrl/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 13:16:59
LastEditor: JiangJi
LastEditTime: 2024-06-13 22:15:22
LastEditTime: 2024-06-14 09:41:35
Description:
'''
import os,copy
Expand All @@ -15,7 +15,7 @@
from pathlib import Path
from joyrl.framework.config import GeneralConfig, MergedConfig, DefaultConfig
from joyrl.framework.trainer import Trainer
from joyrl.framework.utils import merge_class_attrs, all_seed, create_module,exec_method
from joyrl.framework.utils import merge_class_attrs, all_seed, load_model_meta

class Launcher(object):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -162,13 +162,14 @@ def policy_config(self):
policy = policy_mod.Policy(self.cfg)
self.cfg.start_model_step = 0
if self.cfg.load_checkpoint:
policy.load_model(f"tasks/{self.cfg.load_path}/models/{self.cfg.load_model_step}")
self.cfg.model_meta = load_model_meta(f"tasks/{self.cfg.load_path}/models")
policy.load_model(model_path = f"tasks/{self.cfg.load_path}/models/{self.cfg.load_model_step}")
policy.save_model(f"{self.cfg.model_dir}/{self.cfg.load_model_step}")
if isinstance(self.cfg.load_model_step, int):
self.cfg.start_model_step = self.cfg.load_model_step
if str(self.cfg.load_model_step).startswith('best'):
if str(self.cfg.load_model_step).startswith('best') and self.cfg.restore_model_meta:
try:
self.cfg.start_model_step = int(self.cfg.load_model_step.split('_')[-1])
self.cfg.start_model_step = self.cfg.model_meta['OnlineTester']['best_model_step']
except:
self.cfg.start_model_step = 0
data_handler = data_handler_mod.DataHandler(self.cfg)
Expand Down
4 changes: 2 additions & 2 deletions offline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 13:16:59
LastEditor: JiangJi
LastEditTime: 2024-06-14 09:27:40
LastEditTime: 2024-06-14 09:40:18
Description:
'''
import os,copy
Expand Down Expand Up @@ -167,7 +167,7 @@ def policy_config(self):
policy.save_model(f"{self.cfg.model_dir}/{self.cfg.load_model_step}")
if isinstance(self.cfg.load_model_step, int):
self.cfg.start_model_step = self.cfg.load_model_step
if str(self.cfg.load_model_step).startswith('best'):
if str(self.cfg.load_model_step).startswith('best') and self.cfg.restore_model_meta:
try:
self.cfg.start_model_step = self.cfg.model_meta['OnlineTester']['best_model_step']
except:
Expand Down

0 comments on commit 3a47d54

Please sign in to comment.