diff --git a/test/test_env.py b/test/test_env.py
index 415c973b6fb..bdb0283059c 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -112,6 +112,7 @@ from torchrl.envs import (
     CatFrames,
     CatTensors,
+    ChessEnv,
     DoubleToFloat,
     EnvBase,
     EnvCreator,
@@ -3380,6 +3381,111 @@ def test_partial_rest(self, batched):
         assert s["next", "string"] == ["6", "6"]
 
 
+# fen strings for board positions generated with:
+# https://lichess.org/editor
+@pytest.mark.parametrize("stateful", [False, True])
+class TestChessEnv:
+    def test_env(self, stateful):
+        env = ChessEnv(stateful=stateful)
+        check_env_specs(env)
+
+    def test_rollout(self, stateful):
+        env = ChessEnv(stateful=stateful)
+        env.rollout(5000)
+
+    def test_reset_white_to_move(self, stateful):
+        env = ChessEnv(stateful=stateful)
+        fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
+        td = env.reset(TensorDict({"fen": fen}))
+        assert td["fen"] == fen
+        assert td["turn"] == env.lib.WHITE
+        assert not td["done"]
+
+    def test_reset_black_to_move(self, stateful):
+        env = ChessEnv(stateful=stateful)
+        fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
+        td = env.reset(TensorDict({"fen": fen}))
+        assert td["fen"] == fen
+        assert td["turn"] == env.lib.BLACK
+        assert not td["done"]
+
+    def test_reset_done_error(self, stateful):
+        env = ChessEnv(stateful=stateful)
+        fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
+        with pytest.raises(ValueError) as e_info:
+            env.reset(TensorDict({"fen": fen}))
+
+        assert "Cannot reset to a fen that is a gameover state" in str(e_info)
+
+    @pytest.mark.parametrize("reset_without_fen", [False, True])
+    @pytest.mark.parametrize(
+        "endstate", ["white win", "black win", "stalemate", "50 move", "insufficient", "not_done"]
+    )
+    def test_reward(self, stateful, reset_without_fen, endstate):
+        if stateful and reset_without_fen:
+            pytest.skip("reset_without_fen is only used for stateless env")
+
+        env = ChessEnv(stateful=stateful)
+
+        if endstate == "white win":
+            fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
+            expected_turn = env.lib.WHITE
+            move = "Rb8#"
+            expected_reward = 1
+            expected_done = True
+
+        elif endstate == "black win":
+            fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
+            expected_turn = env.lib.BLACK
+            move = "Rg1#"
+            expected_reward = -1
+            expected_done = True
+
+        elif endstate == "stalemate":
+            fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
+            expected_turn = env.lib.BLACK
+            move = "Rb7"
+            expected_reward = 0
+            expected_done = True
+
+        elif endstate == "insufficient":
+            fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
+            expected_turn = env.lib.WHITE
+            move = "Kxd4"
+            expected_reward = 0
+            expected_done = True
+
+        elif endstate == "50 move":
+            fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
+            expected_turn = env.lib.BLACK
+            move = "Kf7"
+            expected_reward = 0
+            expected_done = True
+
+        elif endstate == "not_done":
+            fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
+            expected_turn = env.lib.WHITE
+            move = "e4"
+            expected_reward = 0
+            expected_done = False
+
+        else:
+            raise RuntimeError(f"endstate not supported: {endstate}")
+
+        if reset_without_fen:
+            td = TensorDict({"fen": fen})
+        else:
+            td = env.reset(TensorDict({"fen": fen}))
+            assert td["turn"] == expected_turn
+
+        moves = env.get_legal_moves(None if stateful else td)
+        td["action"] = moves.index(move)
+        td = env.step(td)["next"]
+        assert td["done"] == expected_done
+        assert td["reward"] == expected_reward
+        assert td["turn"] == (not expected_turn)
+
+
 class TestCustomEnvs:
     def test_tictactoe_env(self):
         torch.manual_seed(0)
diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py
index f97f05b4d96..4dc5dbe5321 100644
--- a/torchrl/envs/custom/chess.py
+++ b/torchrl/envs/custom/chess.py
@@ -127,10 +127,13 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
             self._set_action_space(tensordict)
         return super().rand_action(tensordict)
 
+    def _is_done(self, board):
+        return board.is_game_over() | board.is_fifty_moves()
+
     def _reset(self, tensordict=None):
         fen = None
         if tensordict is not None:
-            fen = self._get_fen(tensordict)
+            fen = self._get_fen(tensordict).data
             dest = tensordict.empty()
         else:
             dest = TensorDict()
@@ -139,7 +142,11 @@ def _reset(self, tensordict=None):
             self.board.reset()
             fen = self.board.fen()
         else:
-            self.board.set_fen(fen.data)
+            self.board.set_fen(fen)
+            if self._is_done(self.board):
+                raise ValueError(
+                    "Cannot reset to a fen that is a gameover state." f" fen: {fen}"
+                )
 
         hashing = hash(fen)
 
@@ -162,6 +169,38 @@ def _get_fen(cls, tensordict):
         fen = cls._hash_table.get(hashing.item())
         return fen
 
+    def get_legal_moves(self, tensordict=None, uci=False):
+        """List the legal moves in a position.
+
+        To choose one of the actions, the "action" key can be set to the index
+        of the move in this list.
+
+        Args:
+            tensordict (TensorDict, optional): Tensordict containing the fen
+                string of a position. Required if not stateful. If stateful,
+                this argument is ignored and the current state of the env is
+                used instead.
+
+            uci (bool, optional): If ``False``, moves are given in SAN format.
+                If ``True``, moves are given in UCI format. Default is
+                ``False``.
+
+        """
+        board = self.board
+        if not self.stateful:
+            if tensordict is None:
+                raise ValueError(
+                    "tensordict must be given since this env is not stateful"
+                )
+            fen = self._get_fen(tensordict).data
+            board.set_fen(fen)
+        moves = board.legal_moves
+
+        if uci:
+            return [board.uci(move) for move in moves]
+        else:
+            return [board.san(move) for move in moves]
+
     def _step(self, tensordict):
         # action
         action = tensordict.get("action")
@@ -169,9 +208,8 @@ def _step(self, tensordict):
         if not self.stateful:
             fen = self._get_fen(tensordict).data
             board.set_fen(fen)
-        action = str(list(board.legal_moves)[action])
-        # assert chess.Move.from_uci(action) in board.legal_moves
-        board.push_san(action)
+        action = list(board.legal_moves)[action]
+        board.push(action)
         self._set_action_space()
 
         # Collect data
@@ -181,10 +219,15 @@ def _step(self, tensordict):
         dest.set("fen", fen)
         dest.set("hashing", hashing)
 
-        done = board.is_checkmate()
         turn = torch.tensor(board.turn)
-        reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
-        done = done | board.is_stalemate() | board.is_game_over()
+        if board.is_checkmate():
+            # turn flips after every move, even if the game is over
+            winner = not turn
+            reward_val = 1 if winner == self.lib.WHITE else -1
+        else:
+            reward_val = 0
+        reward = torch.tensor([reward_val], dtype=torch.int32)
+        done = self._is_done(board)
         dest.set("reward", reward)
         dest.set("turn", turn)
         dest.set("done", [done])
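
A minimal usage sketch of the loop this patch enables, assuming a stateful env; every name used below (ChessEnv, get_legal_moves, the "action" index convention, and the reward/done semantics) comes from the diff itself:

    import random

    from torchrl.envs import ChessEnv

    env = ChessEnv(stateful=True)
    td = env.reset()
    # Play random legal moves until _is_done() reports game over
    # (checkmate, stalemate, insufficient material, or the fifty-move rule).
    while not td["done"]:
        moves = env.get_legal_moves()  # SAN strings; pass uci=True for UCI
        td["action"] = random.randrange(len(moves))
        td = env.step(td)["next"]
    # reward is 1 if White delivered mate, -1 if Black did, 0 for draws.
    print(td["reward"], td["fen"])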