// game_tree_node.h (forked from tseip/fourinarow)
#ifndef GAME_TREE_NODE_H_INCLUDED
#define GAME_TREE_NODE_H_INCLUDED
#include <cstddef>
#include <iterator>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <vector>

#include "player.h"
/**
* Represents a single node in the game tree.
*
* @tparam Board The board representation used by the game.
*/
template <class Board>
class Node : public std::enable_shared_from_this<Node<Board>> {
public:
/**
 * Forward iterator that walks a tree of nodes in breadth-first order.
 *
 * The iterator keeps a FIFO queue of the nodes that are still to be visited;
 * the node at the front of the queue is the one the iterator points to. An
 * iterator whose queue is empty is the past-the-end iterator.
 *
 * @tparam It The (possibly const-qualified) node type being iterated over. It
 * must provide get_children(), returning a range of shared pointers that are
 * convertible to std::shared_ptr<It>.
 */
template <class It>
struct Iterator {
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = It;
  using pointer = std::shared_ptr<value_type>;
  using reference = value_type &;

  /**
   * Constructor.
   *
   * @param ptr The root node of the tree to iterate over. A null pointer
   * yields the past-the-end iterator.
   */
  Iterator(pointer ptr) : nodes() {
    if (ptr) nodes.push(ptr);
  }

  /**
   * @return A reference to the node the iterator points to. The iterator must
   * not be past-the-end.
   */
  reference operator*() const { return *nodes.front(); }

  /**
   * @return A shared pointer to the node the iterator points to. The iterator
   * must not be past-the-end.
   */
  pointer operator->() const { return nodes.front(); }

  /**
   * Prefix increment. Enqueues the children of the current node and then
   * drops the current node, so the next queued node becomes current.
   * Incrementing a past-the-end iterator is a no-op.
   *
   * @return This iterator, now pointing to the next node.
   */
  Iterator &operator++() {
    if (nodes.empty()) return *this;
    // The front node stays alive (and its child list valid) until the pop
    // below, because the queue still holds a shared pointer to it.
    for (const auto &child : nodes.front()->get_children()) {
      nodes.push(child);
    }
    nodes.pop();
    return *this;
  }

  /**
   * Postfix increment.
   *
   * @return A copy of the iterator as it was before being advanced.
   */
  Iterator operator++(int) {
    Iterator previous = *this;
    ++(*this);
    return previous;
  }

  /**
   * Iterator comparison functions. Two iterators compare equal when their
   * queues of pending nodes are equal.
   *
   * @param a The first iterator to compare.
   * @param b The second iterator to compare.
   *
   * @{
   */
  friend bool operator==(const Iterator &a, const Iterator &b) {
    return a.nodes == b.nodes;
  }
  friend bool operator!=(const Iterator &a, const Iterator &b) {
    return a.nodes != b.nodes;
  }
  /**
   * @}
   */

 private:
  /**
   * The queue of nodes yet to be iterated over; the front is the current
   * node.
   */
  std::queue<pointer> nodes;
};
/**
 * Starts a breadth-first traversal rooted at this node.
 *
 * @note The iterator is seeded via shared_from_this(), so this node must be
 * owned by a std::shared_ptr.
 *
 * @return An iterator positioned at this node.
 */
Iterator<Node> begin() {
  std::shared_ptr<Node> self = this->shared_from_this();
  return Iterator<Node>(self);
}
/**
 * @return The past-the-end iterator, i.e. one with no nodes left to visit.
 */
Iterator<Node> end() {
  return Iterator<Node>(std::shared_ptr<Node>());
}
/**
 * Starts a read-only breadth-first traversal rooted at this node.
 *
 * @note The iterator is seeded via shared_from_this(), so this node must be
 * owned by a std::shared_ptr.
 *
 * @return A const iterator positioned at this node.
 */
Iterator<const Node> begin() const {
  std::shared_ptr<const Node> self = this->shared_from_this();
  return Iterator<const Node>(self);
}
/**
 * @return The past-the-end const iterator, i.e. one with no nodes left to
 * visit.
 */
Iterator<const Node> end() const {
  return Iterator<const Node>(std::shared_ptr<const Node>());
}
/**
* @return The node's heuristic value, determined by the search algorithm
* used.
* @note Since the base class has no search algorithm associated with it,
* we order arbitrarily by board position.
*/
virtual double get_value() const { return move.board_position; }
/**
 * Pure virtual; every concrete node type must implement it.
 *
 * @return True if this node is solved, i.e. if the outcome of the game is
 * known from this position.
 */
virtual bool determined() const = 0;
bool operator<(const Node &other) { return get_value() < other.get_value(); }
/**
 * @return A copy of the board state represented by this node.
 */
const Board get_board() const {
  return this->board;
}
/**
* @return The move prior to the board state represented by this node.
*/
const typename Board::MoveT get_move() const { return move; }
/**
* @return The children of this node.
*/
std::vector<std::shared_ptr<Node>> &get_children() { return children; }
/**
* @return The children of this node.
*/
const std::vector<std::shared_ptr<Node>> &get_children() const {
return children;
}
/**
* @return The depth of this node in the tree.
*/
std::size_t get_depth() const { return depth; }
/**
 * @return A shared pointer to the parent of this node, obtained by locking
 * the stored weak pointer. The result is empty if there is no parent to lock.
 */
std::shared_ptr<Node> get_parent() const {
  std::shared_ptr<Node> owner = parent.lock();
  return owner;
}
/**
 * Describes this node alone: the board position and player of its move, and
 * its depth in the tree.
 *
 * @return A string representing the state of this node.
 */
virtual std::string to_string() const {
  std::ostringstream description;
  description << "Position: " << move.board_position;
  description << ", Player: " << static_cast<size_t>(move.player);
  description << ", Depth: " << depth;
  return description.str();
}
/**
 * Describes this node and its descendants, one node per line, in
 * breadth-first order.
 *
 * @param max_depth How many layers of the tree should be printed, counting
 * this node's own layer as the first. A value of 0 prints nothing.
 *
 * @return A string representing the state of this node and of the
 * descendants within the requested number of layers.
 */
std::string to_string(std::size_t max_depth) const {
  std::stringstream stream;
  for (const auto &node : *this) {
    // The traversal is breadth-first, so depths never decrease: the first
    // node that is too deep means every remaining node is too deep as well.
    if (node.depth - depth >= max_depth) break;
    // '\n' rather than std::endl: flushing a string stream is pointless.
    stream << node.to_string() << '\n';
  }
  return stream.str();
}
protected:
/**
 * The children of this node. They are held through shared pointers, so this
 * node shares ownership of each child.
 */
std::vector<std::shared_ptr<Node>> children;
/**
 * The parent of this node. Held as a weak pointer, so a child does not keep
 * its parent alive; it is empty for nodes constructed without a parent.
 */
const std::weak_ptr<Node> parent;
/**
 * The depth of this node in the game tree. A node constructed without a
 * parent has depth 1; any other node is one deeper than its parent.
 */
const std::size_t depth;
/**
 * The board state of this node with the move that this node represents
 * included.
 */
const Board board;
/**
 * The move that led to the current board state. Default-constructed for
 * nodes constructed without a parent.
 */
const typename Board::MoveT move;
/**
* Private node constructor for nodes without meaningful move history (i.e.,
* with no parents).
*
* @param board The board state prior to the move being represented by this
* node is made.
*/
Node(const Board &board)
: children(), parent(), depth(1U), board(board), move() {}
/**
 * Protected constructor for nodes that do have move history, i.e. nodes
 * hanging off a parent.
 *
 * The new node starts without children, sits one level below its parent, and
 * holds the parent's board with the given move applied (via the board's
 * operator+).
 *
 * @note The parent is dereferenced unconditionally; callers are expected to
 * have ruled out a null parent before calling.
 *
 * @param parent The parent of this node.
 * @param move The move represented by this node.
 */
Node(const std::shared_ptr<Node> parent, const typename Board::MoveT &move)
    : children(),
      parent(parent),
      depth(parent->depth + 1U),
      board(parent->board + move),
      move(move) {}
/**
 * Replaces a field (such as val, opt, or pess) with a child's value when the
 * child's value is better for the player to move on this node's board.
 * Player1 prefers larger values; any other player prefers smaller values.
 *
 * @tparam Field The type of the field to be updated.
 * @param child_field The candidate value taken from a child.
 * @param field The field to overwrite when child_field is strictly preferable
 * for the active player.
 *
 * @return True if the field has been updated, false otherwise.
 */
template <class Field>
bool update_field_against_child(const Field &child_field,
                                Field &field) const {
  const bool maximizing = board.active_player() == Player::Player1;
  // Only a strict improvement counts; ties leave the field untouched.
  const bool keep_current =
      maximizing ? !(child_field > field) : !(child_field < field);
  if (keep_current) {
    return false;
  }
  field = child_field;
  return true;
}
/**
 * Adds up the depths of every leaf (childless node) in the subtree rooted at
 * this node.
 *
 * @return This node's own depth if it has no children; otherwise the sum of
 * the values reported by its children.
 */
std::size_t get_sum_depth() const {
  std::size_t total = 0;
  for (auto it = children.begin(); it != children.end(); ++it) {
    total += (*it)->get_sum_depth();
  }
  // A childless node skips the loop and contributes its own depth instead.
  return children.empty() ? depth : total;
}
public:
/**
 * Accepts a list of moves that can be played from the position represented
 * by this node and adds them to the game tree.
 *
 * Pure virtual; every concrete node type must implement it.
 *
 * @param moves A list of moves that can be played from this position.
 */
virtual void expand(const std::vector<typename Board::MoveT> &moves) = 0;
/**
 * Pure virtual; every concrete node type must implement it.
 *
 * @return The number of moves between us and our recursively best known
 * child.
 */
virtual std::size_t get_depth_of_pv() const = 0;
/**
* Select the next move to be searched from among our children recursively.
*
* @note We delegate this to a virtual function so that we don't have to
* define a const and non-const version of select in all subclasses.
*
* @return The best move from amongst our children and ourselves.
*/
std::shared_ptr<const Node> select() const { return virtual_select(); }
/**
* Select the next move to be searched from among our children recursively.
*
* @return The next move from amongst our children and ourselves that ought to
* be expanded.
*/
std::shared_ptr<Node> select() {
return std::const_pointer_cast<Node>(
std::const_pointer_cast<const Node>(this->shared_from_this())
->select());
}
/**
 * Counts the leaves in the subtree rooted at this node. A node counts as a
 * leaf when it has no children.
 *
 * @return 1 if this node has no children; otherwise the sum of the leaf
 * counts of its children.
 */
std::size_t get_num_leaves() const {
  std::size_t leaves = 0;
  for (auto it = children.begin(); it != children.end(); ++it) {
    leaves += (*it)->get_num_leaves();
  }
  // A childless node skips the loop and counts as a single leaf.
  return children.empty() ? 1 : leaves;
}
/**
* @return The number of all nodes in the tree.
*/
std::size_t get_node_count() const {
std::size_t node_count = 1;
for (const auto &child : children) {
node_count += child->get_node_count();
}
return node_count;
}
/**
* @return The average branching factor of the tree.
*/
double get_average_branching_factor() const {
if (children.empty()) return 0.0;
return static_cast<double>(get_node_count() - 1) / get_num_internal_nodes();
}
/**
 * Counts the internal nodes in the subtree rooted at this node. Only nodes
 * that have children are internal nodes.
 *
 * @return 0 if this node has no children; otherwise 1 (for this node) plus
 * the internal node counts of its children.
 */
std::size_t get_num_internal_nodes() const {
  if (children.empty()) {
    return 0;
  }
  std::size_t below = 0;
  for (auto it = children.begin(); it != children.end(); ++it) {
    below += (*it)->get_num_internal_nodes();
  }
  return below + 1;
}
/**
 * @return The mean depth of the leaves in the subtree rooted at this node:
 * the summed leaf depth divided by the number of leaves.
 */
double get_mean_depth() const {
  const double total_leaf_depth = static_cast<double>(get_sum_depth());
  return total_leaf_depth / get_num_leaves();
}
/**
 * Pure virtual; every concrete node type must implement it.
 *
 * @return The best known move from the current position for the current
 * player.
 */
virtual typename Board::MoveT get_best_move() const = 0;
protected:
/**
 * Select the next move to be searched from among our children recursively.
 *
 * Pure virtual; every concrete node type must implement it. Both the const
 * and the non-const select() forward to this function.
 *
 * @note We split this out so that we don't have to define a const and
 * non-const version of select in all subclasses.
 *
 * @return The best move from amongst our children and ourselves.
 */
virtual std::shared_ptr<const Node> virtual_select() const = 0;
};
#endif // GAME_TREE_NODE_H_INCLUDED