diff --git a/experiments/solves/solve_collectables_sb3.py b/experiments/solves/solve_collectables_sb3.py
index 26e2206..f2cc3b9 100644
--- a/experiments/solves/solve_collectables_sb3.py
+++ b/experiments/solves/solve_collectables_sb3.py
@@ -3,6 +3,7 @@

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

+from absl import app
 from gymnasium import spaces
 from pysc2.env import sc2_env
 from stable_baselines3 import PPO
@@ -12,70 +13,44 @@
 from urnai.sc2.environments.sc2environment import SC2Env
 from urnai.sc2.rewards.collectables import CollectablesReward
 from urnai.sc2.states.collectables import CollectablesMethod, CollectablesState
+from urnai.trainers.stablebaselines3_trainer import SB3Trainer

-players = [sc2_env.Agent(sc2_env.Race.terran)]
-env = SC2Env(map_name='CollectMineralShards', visualize=False,
-             step_mul=16, players=players)
-state = CollectablesState(method=CollectablesMethod.STATE_NON_SPATIAL)
-urnai_action_space = CollectablesActionSpace()
-reward = CollectablesReward()
-# Define action and observation space
-action_space = spaces.Discrete(n=4, start=0)
-observation_space = spaces.Box(low=0, high=255, shape=(2, ), dtype=float)
+def declare_trainer():
+    players = [sc2_env.Agent(sc2_env.Race.terran)]
+    env = SC2Env(map_name='CollectMineralShards', visualize=False,
+                 step_mul=16, players=players)
+    state = CollectablesState(method=CollectablesMethod.STATE_NON_SPATIAL)
+    urnai_action_space = CollectablesActionSpace()
+    reward = CollectablesReward()

-# Create the custom environment
-custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
-                       action_space)
+    # Define action and observation space
+    action_space = spaces.Discrete(n=4, start=0)
+    observation_space = spaces.Box(low=0, high=255, shape=(2, ), dtype=float)
+    # Create the custom environment
+    custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
+                           action_space)

-# models_dir = "saves/models/DQN"
-models_dir = "saves/models/PPO"
-logdir = "saves/logs"
+    # models_dir = "saves/models/DQN"
+    models_dir = "saves/models/PPO"
+    logdir = "saves/logs"

-if not os.path.exists(models_dir):
-    os.makedirs(models_dir)
+    model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)

-if not os.path.exists(logdir):
-    os.makedirs(logdir)
+    trainer = SB3Trainer(custom_env, models_dir, logdir, model)

-# If training from scratch, uncomment 1
-# If loading a model, uncomment 2
+    return trainer

-## 1 - Train and Save model
+def main(unused_argv):
+    try:
+        trainer = declare_trainer()
+        trainer.train_model(timesteps=10000, reset_num_timesteps=False,
+                            tb_log_name="PPO", repeat_times=30)
+        # trainer.load_model(f"{trainer.models_dir}/290000")
+        trainer.test_model(total_steps=10000, deterministic=True)
+    except KeyboardInterrupt:
+        print("Training interrupted by user")

-# model=DQN("MlpPolicy",custom_env,buffer_size=100000,verbose=1,tensorboard_log=logdir)
-model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)
-
-TIMESTEPS = 10000
-for i in range(1,30):
-    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO")
-    model.save(f"{models_dir}/{TIMESTEPS*i}")
-
-## 1 - End
-
-## 2 - Load model
-# model = PPO.load(f"{models_dir}/290000.zip", env = custom_env)
-## 2 - End
-
-vec_env = model.get_env()
-obs = vec_env.reset()
-
-# Test model
-total_episodes = 0
-total_reward = 0
-
-total_steps = 10000
-
-for _ in range(total_steps):
-    action, _state = model.predict(obs, deterministic=True)
-    obs, rewards, done, info = vec_env.step(action)
-
-    total_reward += rewards
-    if done:
-        total_episodes += 1
-        print(f"Episode: {total_episodes}, Total Reward: {total_reward}")
-        obs = vec_env.reset() # Reset the environment
-        total_reward = 0 # Reset reward for the new episode
-
-env.close()
\ No newline at end of file
+if __name__ == '__main__':
+    app.run(main)
\ No newline at end of file
diff --git a/urnai/trainers/stablebaselines3_trainer.py b/urnai/trainers/stablebaselines3_trainer.py
new file mode 100644
index 0000000..fbb7c57
--- /dev/null
+++ b/urnai/trainers/stablebaselines3_trainer.py
@@ -0,0 +1,48 @@
+import os
+
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.type_aliases import MaybeCallback
+
+
+class SB3Trainer:
+    def __init__(self, custom_env, models_dir, logdir, model : BaseAlgorithm):
+        self.custom_env = custom_env
+        self.models_dir = models_dir
+        self.model = model
+
+        if not os.path.exists(models_dir):
+            os.makedirs(models_dir)
+
+        if not os.path.exists(logdir):
+            os.makedirs(logdir)
+
+    def load_model(self, model_path):
+        self.model = self.model.load(model_path, env = self.custom_env)
+
+    def train_model(self, timesteps: int = 10000, callback: MaybeCallback = None,
+                    log_interval: int = 100, tb_log_name: str = "run",
+                    reset_num_timesteps: bool = True, progress_bar: bool = False,
+                    repeat_times: int = 1):
+        for repeat_time in range(repeat_times):
+            self.model.learn(total_timesteps = timesteps, callback = callback,
+                             log_interval = log_interval, tb_log_name = tb_log_name,
+                             reset_num_timesteps = reset_num_timesteps,
+                             progress_bar = progress_bar)
+            self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + 1)}")
+
+    def test_model(self, total_steps: int = 10000, deterministic: bool = True):
+        vec_env = self.model.get_env()
+        obs = vec_env.reset()
+
+        total_episodes = 0
+        total_reward = 0
+
+        for _ in range(total_steps):
+            action, _state = self.model.predict(obs, deterministic=deterministic)
+            obs, rewards, done, info = vec_env.step(action)
+
+            total_reward += rewards
+            if done:
+                total_episodes += 1
+                total_reward = 0 # Reset reward for the new episode
+
\ No newline at end of file