From 3a47d54498f755bfaa7a49449c3dff1c3b6071b0 Mon Sep 17 00:00:00 2001 From: johnjim0816 Date: Fri, 14 Jun 2024 09:42:16 +0800 Subject: [PATCH] [0.6.3.1] [feat]: add cfg.restore_model_meta --- joyrl/__init__.py | 4 ++-- joyrl/framework/config.py | 3 ++- joyrl/framework/interactor.py | 4 ++-- joyrl/framework/learner.py | 4 ++-- joyrl/framework/message.py | 2 +- joyrl/run.py | 11 ++++++----- offline_run.py | 4 ++-- 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/joyrl/__init__.py b/joyrl/__init__.py index 412f097..d93e9b5 100644 --- a/joyrl/__init__.py +++ b/joyrl/__init__.py @@ -5,13 +5,13 @@ Email: johnjim0816@gmail.com Date: 2023-01-01 16:20:49 LastEditor: JiangJi -LastEditTime: 2024-06-14 09:33:39 +LastEditTime: 2024-06-14 09:41:26 Discription: ''' from joyrl import algos, framework, envs from joyrl.run import run -__version__ = "0.6.3" +__version__ = "0.6.3.1" __all__ = [ "algos", diff --git a/joyrl/framework/config.py b/joyrl/framework/config.py index 3d55a33..de94a65 100644 --- a/joyrl/framework/config.py +++ b/joyrl/framework/config.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-02 15:30:09 LastEditor: JiangJi -LastEditTime: 2024-06-11 13:31:40 +LastEditTime: 2024-06-14 09:39:48 Discription: ''' class DefaultConfig: @@ -52,6 +52,7 @@ def __init__(self) -> None: self.model_save_fre = 500 # model save frequency per update step # load model settings self.load_checkpoint = False # if load checkpoint + self.restore_model_meta = True # if restore model meta self.load_path = "Train_single_CartPole-v1_DQN_20230515-211721" # path to load model self.load_model_step = 'best' # load model at which step # stats recorder settings diff --git a/joyrl/framework/interactor.py b/joyrl/framework/interactor.py index 21bf253..fcc0fe9 100644 --- a/joyrl/framework/interactor.py +++ b/joyrl/framework/interactor.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2024-02-25 15:46:04 LastEditor: JiangJi -LastEditTime: 2024-06-14 09:34:08 +LastEditTime: 2024-06-14 09:40:51 Discription: ''' import copy @@ -49,7 +49,7 @@ def _init_n_sample_steps(self): self.n_sample_steps = float('inf') def _load_model_meta(self): - if self.cfg.load_checkpoint: + if self.cfg.load_checkpoint and self.cfg.restore_model_meta: model_meta = self.cfg.model_meta.get(self.name, {}) self.policy.load_model_meta(model_meta) diff --git a/joyrl/framework/learner.py b/joyrl/framework/learner.py index f40aba7..386ca39 100644 --- a/joyrl/framework/learner.py +++ b/joyrl/framework/learner.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-02 15:02:30 LastEditor: JiangJi -LastEditTime: 2024-06-14 09:34:20 +LastEditTime: 2024-06-14 09:40:42 Discription: ''' import numpy as np @@ -38,7 +38,7 @@ def _init_update_steps(self): self.n_update_steps = float('inf') def _load_model_meta(self): - if self.cfg.load_checkpoint: + if self.cfg.load_checkpoint and self.cfg.restore_model_meta: model_meta = self.cfg.model_meta.get(self.name, {}) self.policy.load_model_meta(model_meta) diff --git a/joyrl/framework/message.py b/joyrl/framework/message.py index f4d3c0b..986beee 100644 --- a/joyrl/framework/message.py +++ b/joyrl/framework/message.py @@ -1,5 +1,5 @@ from enum import Enum, unique -from typing import Optional +from typing import Optional, Any from dataclasses import dataclass @unique diff --git a/joyrl/run.py b/joyrl/run.py index 82bb8a6..6e412eb 100644 --- a/joyrl/run.py +++ b/joyrl/run.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-22 13:16:59 LastEditor: JiangJi -LastEditTime: 2024-06-13 22:15:22 +LastEditTime: 2024-06-14 09:41:35 Discription: ''' import os,copy @@ -15,7 +15,7 @@ from pathlib import Path from joyrl.framework.config import GeneralConfig, MergedConfig, DefaultConfig from joyrl.framework.trainer import Trainer -from joyrl.framework.utils import merge_class_attrs, all_seed, create_module,exec_method +from joyrl.framework.utils import merge_class_attrs, all_seed, load_model_meta class Launcher(object): def __init__(self, **kwargs): @@ -162,13 +162,14 @@ def policy_config(self): policy = policy_mod.Policy(self.cfg) self.cfg.start_model_step = 0 if self.cfg.load_checkpoint: - policy.load_model(f"tasks/{self.cfg.load_path}/models/{self.cfg.load_model_step}") + self.cfg.model_meta = load_model_meta(f"tasks/{self.cfg.load_path}/models") + policy.load_model(model_path = f"tasks/{self.cfg.load_path}/models/{self.cfg.load_model_step}") policy.save_model(f"{self.cfg.model_dir}/{self.cfg.load_model_step}") if isinstance(self.cfg.load_model_step, int): self.cfg.start_model_step = self.cfg.load_model_step - if str(self.cfg.load_model_step).startswith('best'): + if str(self.cfg.load_model_step).startswith('best') and self.cfg.restore_model_meta: try: - self.cfg.start_model_step = int(self.cfg.load_model_step.split('_')[-1]) + self.cfg.start_model_step = self.cfg.model_meta['OnlineTester']['best_model_step'] except: self.cfg.start_model_step = 0 data_handler = data_handler_mod.DataHandler(self.cfg) diff --git a/offline_run.py b/offline_run.py index ba13b7b..2821a78 100644 --- a/offline_run.py +++ b/offline_run.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-22 13:16:59 LastEditor: JiangJi -LastEditTime: 2024-06-14 09:27:40 +LastEditTime: 2024-06-14 09:40:18 Discription: ''' import os,copy @@ -167,7 +167,7 @@ def policy_config(self): policy.save_model(f"{self.cfg.model_dir}/{self.cfg.load_model_step}") if isinstance(self.cfg.load_model_step, int): self.cfg.start_model_step = self.cfg.load_model_step - if str(self.cfg.load_model_step).startswith('best'): + if str(self.cfg.load_model_step).startswith('best') and self.cfg.restore_model_meta: try: self.cfg.start_model_step = self.cfg.model_meta['OnlineTester']['best_model_step'] except: