refactor: Added stablebaselines3 trainer class
RickFqt committed Nov 14, 2024
1 parent b67bd85 commit 8a70ca9
Showing 2 changed files with 80 additions and 57 deletions.
89 changes: 32 additions & 57 deletions experiments/solves/solve_collectables_sb3.py
@@ -3,6 +3,7 @@
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

+from absl import app
 from gymnasium import spaces
 from pysc2.env import sc2_env
 from stable_baselines3 import PPO
@@ -12,70 +13,44 @@
 from urnai.sc2.environments.sc2environment import SC2Env
 from urnai.sc2.rewards.collectables import CollectablesReward
 from urnai.sc2.states.collectables import CollectablesMethod, CollectablesState
+from urnai.trainers.stablebaselines3_trainer import SB3Trainer

-players = [sc2_env.Agent(sc2_env.Race.terran)]
-env = SC2Env(map_name='CollectMineralShards', visualize=False,
-             step_mul=16, players=players)
-state = CollectablesState(method=CollectablesMethod.STATE_NON_SPATIAL)
-urnai_action_space = CollectablesActionSpace()
-reward = CollectablesReward()
-
-# Define action and observation space
-action_space = spaces.Discrete(n=4, start=0)
-observation_space = spaces.Box(low=0, high=255, shape=(2, ), dtype=float)
+def declare_trainer():
+    players = [sc2_env.Agent(sc2_env.Race.terran)]
+    env = SC2Env(map_name='CollectMineralShards', visualize=False,
+                 step_mul=16, players=players)
+    state = CollectablesState(method=CollectablesMethod.STATE_NON_SPATIAL)
+    urnai_action_space = CollectablesActionSpace()
+    reward = CollectablesReward()

-# Create the custom environment
-custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
-                       action_space)
+    # Define action and observation space
+    action_space = spaces.Discrete(n=4, start=0)
+    observation_space = spaces.Box(low=0, high=255, shape=(2, ), dtype=float)

+    # Create the custom environment
+    custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
+                           action_space)

-# models_dir = "saves/models/DQN"
-models_dir = "saves/models/PPO"
-logdir = "saves/logs"
+    # models_dir = "saves/models/DQN"
+    models_dir = "saves/models/PPO"
+    logdir = "saves/logs"

-if not os.path.exists(models_dir):
-    os.makedirs(models_dir)
+    model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)

-if not os.path.exists(logdir):
-    os.makedirs(logdir)
+    trainer = SB3Trainer(custom_env, models_dir, logdir, model)

-# If training from scratch, uncomment 1
-# If loading a model, uncomment 2
+    return trainer

-## 1 - Train and Save model
+def main(unused_argv):
+    try:
+        trainer = declare_trainer()
+        trainer.train_model(timesteps=10000, reset_num_timesteps=False,
+                            tb_log_name="PPO", repeat_times=30)
+        # trainer.load_model(f"{trainer.models_dir}/290000")
+        trainer.test_model(total_steps=10000, deterministic=True)
+    except KeyboardInterrupt:
+        print("Training interrupted by user")

-# model=DQN("MlpPolicy",custom_env,buffer_size=100000,verbose=1,tensorboard_log=logdir)
-model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)
-
-TIMESTEPS = 10000
-for i in range(1,30):
-    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO")
-    model.save(f"{models_dir}/{TIMESTEPS*i}")
-
-## 1 - End
-
-## 2 - Load model
-# model = PPO.load(f"{models_dir}/290000.zip", env = custom_env)
-## 2 - End
-
-vec_env = model.get_env()
-obs = vec_env.reset()
-
-# Test model
-total_episodes = 0
-total_reward = 0
-
-total_steps = 10000
-
-for _ in range(total_steps):
-    action, _state = model.predict(obs, deterministic=True)
-    obs, rewards, done, info = vec_env.step(action)
-
-    total_reward += rewards
-    if done:
-        total_episodes += 1
-        print(f"Episode: {total_episodes}, Total Reward: {total_reward}")
-        obs = vec_env.reset() # Reset the environment
-        total_reward = 0 # Reset reward for the new episode
-
-env.close()
+if __name__ == '__main__':
+    app.run(main)
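Run note (an assumption, not stated in the commit): with the absl entry point added above, the experiment is presumably launched directly from the repository root, e.g.

python experiments/solves/solve_collectables_sb3.py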
48 changes: 48 additions & 0 deletions urnai/trainers/stablebaselines3_trainer.py
@@ -0,0 +1,48 @@
import os

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.type_aliases import MaybeCallback


class SB3Trainer:
    def __init__(self, custom_env, models_dir, logdir, model : BaseAlgorithm):
        self.custom_env = custom_env
        self.models_dir = models_dir
        self.model = model

        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        if not os.path.exists(logdir):
            os.makedirs(logdir)

    def load_model(self, model_path):
        self.model = self.model.load(model_path, env = self.custom_env)

    def train_model(self, timesteps: int = 10000, callback: MaybeCallback = None,
                    log_interval: int = 100, tb_log_name: str = "run",
                    reset_num_timesteps: bool = True, progress_bar: bool = False,
                    repeat_times: int = 1):
        for repeat_time in range(repeat_times):
            self.model.learn(total_timesteps = timesteps, callback = callback,
                             log_interval = log_interval, tb_log_name = tb_log_name,
                             reset_num_timesteps = reset_num_timesteps,
                             progress_bar = progress_bar)
            self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + 1)}")

    def test_model(self, total_steps: int = 10000, deterministic: bool = True):
        vec_env = self.model.get_env()
        obs = vec_env.reset()

        total_episodes = 0
        total_reward = 0

        for _ in range(total_steps):
            action, _state = self.model.predict(obs, deterministic=deterministic)
            obs, rewards, done, info = vec_env.step(action)

            total_reward += rewards
            if done:
                total_episodes += 1
                total_reward = 0  # Reset reward for the new episode
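For context, a minimal usage sketch of the new SB3Trainer outside the StarCraft II setup (not part of this commit): it assumes any Gymnasium-compatible environment, and CartPole-v1, the directory paths, and the PPO settings below are illustrative placeholders rather than project defaults.

import gymnasium as gym
from stable_baselines3 import PPO

from urnai.trainers.stablebaselines3_trainer import SB3Trainer

# Illustrative environment and model -- placeholders, not part of the commit
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="saves/logs")

# SB3Trainer creates the model/log directories if they do not exist
trainer = SB3Trainer(env, "saves/models/PPO", "saves/logs", model)

# Train in 3 chunks of 10,000 timesteps, saving a checkpoint after each chunk
trainer.train_model(timesteps=10000, tb_log_name="PPO", repeat_times=3)

# Roll out the trained policy for 1,000 environment steps
trainer.test_model(total_steps=1000, deterministic=True)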
