-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[0.6.3.1] [feat]: add cfg.restore_model_meta
- Loading branch information
1 parent
e93cfa2
commit 3a47d54
Showing
7 changed files
with
17 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,13 @@ | |
Email: [email protected] | ||
Date: 2023-01-01 16:20:49 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-06-14 09:33:39 | ||
LastEditTime: 2024-06-14 09:41:26 | ||
Discription: | ||
''' | ||
from joyrl import algos, framework, envs | ||
from joyrl.run import run | ||
|
||
__version__ = "0.6.3" | ||
__version__ = "0.6.3.1" | ||
|
||
__all__ = [ | ||
"algos", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-02 15:30:09 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-06-11 13:31:40 | ||
LastEditTime: 2024-06-14 09:39:48 | ||
Discription: | ||
''' | ||
class DefaultConfig: | ||
|
@@ -52,6 +52,7 @@ def __init__(self) -> None: | |
self.model_save_fre = 500 # model save frequency per update step | ||
# load model settings | ||
self.load_checkpoint = False # if load checkpoint | ||
self.restore_model_meta = True # if restore model meta | ||
self.load_path = "Train_single_CartPole-v1_DQN_20230515-211721" # path to load model | ||
self.load_model_step = 'best' # load model at which step | ||
# stats recorder settings | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2024-02-25 15:46:04 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-06-14 09:34:08 | ||
LastEditTime: 2024-06-14 09:40:51 | ||
Discription: | ||
''' | ||
import copy | ||
|
@@ -49,7 +49,7 @@ def _init_n_sample_steps(self): | |
self.n_sample_steps = float('inf') | ||
|
||
def _load_model_meta(self): | ||
if self.cfg.load_checkpoint: | ||
if self.cfg.load_checkpoint and self.cfg.restore_model_meta: | ||
model_meta = self.cfg.model_meta.get(self.name, {}) | ||
self.policy.load_model_meta(model_meta) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-02 15:02:30 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-06-14 09:34:20 | ||
LastEditTime: 2024-06-14 09:40:42 | ||
Discription: | ||
''' | ||
import numpy as np | ||
|
@@ -38,7 +38,7 @@ def _init_update_steps(self): | |
self.n_update_steps = float('inf') | ||
|
||
def _load_model_meta(self): | ||
if self.cfg.load_checkpoint: | ||
if self.cfg.load_checkpoint and self.cfg.restore_model_meta: | ||
model_meta = self.cfg.model_meta.get(self.name, {}) | ||
self.policy.load_model_meta(model_meta) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 13:16:59 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-06-13 22:15:22 | ||
LastEditTime: 2024-06-14 09:41:35 | ||
Discription: | ||
''' | ||
import os,copy | ||
|
@@ -15,7 +15,7 @@ | |
from pathlib import Path | ||
from joyrl.framework.config import GeneralConfig, MergedConfig, DefaultConfig | ||
from joyrl.framework.trainer import Trainer | ||
from joyrl.framework.utils import merge_class_attrs, all_seed, create_module,exec_method | ||
from joyrl.framework.utils import merge_class_attrs, all_seed, load_model_meta | ||
|
||
class Launcher(object): | ||
def __init__(self, **kwargs): | ||
|
@@ -162,13 +162,14 @@ def policy_config(self): | |
policy = policy_mod.Policy(self.cfg) | ||
self.cfg.start_model_step = 0 | ||
if self.cfg.load_checkpoint: | ||
policy.load_model(f"tasks/{self.cfg.load_path}/models/{self.cfg.load_model_step}") | ||
self.cfg.model_meta = load_model_meta(f"tasks/{self.cfg.load_path}/models") | ||
policy.load_model(model_path = f"tasks/{self.cfg.load_path}/models/{self.cfg.load_model_step}") | ||
policy.save_model(f"{self.cfg.model_dir}/{self.cfg.load_model_step}") | ||
if isinstance(self.cfg.load_model_step, int): | ||
self.cfg.start_model_step = self.cfg.load_model_step | ||
if str(self.cfg.load_model_step).startswith('best'): | ||
if str(self.cfg.load_model_step).startswith('best') and self.cfg.restore_model_meta: | ||
try: | ||
self.cfg.start_model_step = int(self.cfg.load_model_step.split('_')[-1]) | ||
self.cfg.start_model_step = self.cfg.model_meta['OnlineTester']['best_model_step'] | ||
except: | ||
self.cfg.start_model_step = 0 | ||
data_handler = data_handler_mod.DataHandler(self.cfg) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 13:16:59 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-06-14 09:27:40 | ||
LastEditTime: 2024-06-14 09:40:18 | ||
Discription: | ||
''' | ||
import os,copy | ||
|
@@ -167,7 +167,7 @@ def policy_config(self): | |
policy.save_model(f"{self.cfg.model_dir}/{self.cfg.load_model_step}") | ||
if isinstance(self.cfg.load_model_step, int): | ||
self.cfg.start_model_step = self.cfg.load_model_step | ||
if str(self.cfg.load_model_step).startswith('best'): | ||
if str(self.cfg.load_model_step).startswith('best') and self.cfg.restore_model_meta: | ||
try: | ||
self.cfg.start_model_step = self.cfg.model_meta['OnlineTester']['best_model_step'] | ||
except: | ||
|