Commit d2369b3e authored by biluo.shen

Add mcts

parent 722dd65a
@@ -6,7 +6,22 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).

 `ygoenv` is a high performance game environment for Yu-Gi-Oh! It is initially inspired by [yugioh-ai](https://github.com/melvinzhang/yugioh-ai) and [yugioh-game](https://github.com/tspivey/yugioh-game), and now implemented on top of [envpool](https://github.com/sail-sg/envpool).

 ## ygoai
-`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo. Currently, only an RL-based agent is implemented.
+`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo and AlphaZero, with or without human knowledge. Currently, we focus on using reinforcement learning to train the agents.
+
+## TODO
+
+### Documentation
+- Add documentation for building and running
+
+### Training
+- Evaluate against old models during training
+- MCTS-based training
+
+### Inference
+- MCTS-based planning
+- Support for playing in YGOPro

 ## Usage
 TODO
...
@@ -15,7 +15,7 @@ Not supported
 - `min` > 5 throws an error
 - `max` > 5 is truncated to 5

-### Unsupported
+### Related cards
 - Fairy Tail - Snow (min=max=7)
 - Pot of Prosperity (min=max=6)
...
from typing import List, Optional, Tuple, Union
import numpy as np
from . import alphazero_mcts as tree
# ==============================================================
# AlphaZero
# ==============================================================
def select_action(
visit_counts: np.ndarray,
temperature: float = 1.0,
deterministic: bool = True
):
"""
Select action from visit counts of the root node.
Parameters
----------
visit_counts: np.ndarray, shape (n_legal_actions,)
The visit counts of the root node.
temperature: float, default: 1.0
The temperature used to adjust the sampling distribution.
deterministic: bool, default: True
Whether to enable deterministic mode in action selection. True means to
select the argmax result, False indicates to sample action from the distribution.
Returns
-------
action_pos: np.int64
The selected action position (index).
"""
if deterministic:
action_pos = np.argmax(visit_counts)
else:
if temperature != 1:
visit_counts = visit_counts ** (1 / temperature)
action_probs = visit_counts / np.sum(visit_counts)
action_pos = np.random.choice(len(visit_counts), p=action_probs)
return action_pos
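For intuition, a quick illustration of how the temperature reshapes the sampling distribution (the visit counts here are hypothetical):

```python
import numpy as np

counts = np.array([10., 5., 1.])
probs_t1 = counts / counts.sum()        # temperature 1.0 -> [0.62, 0.31, 0.06]
flat = counts ** (1 / 2.0)              # temperature 2.0 flattens the counts
probs_t2 = flat / flat.sum()            # -> [0.49, 0.35, 0.16]
assert select_action(counts, deterministic=True) == 0  # argmax
```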
class AlphaZeroMCTSCtree(object):
"""
    MCTSCtree for AlphaZero. The core ``batch_traverse``, ``batch_expand`` and ``batch_backpropagate`` functions are implemented in C++.
Interfaces
----------
__init__, tree_search
"""
def __init__(
self,
env,
predict_fn=None,
root_dirichlet_alpha: float = 0.3,
root_exploration_fraction: float = 0.25,
pb_c_init: float = 1.25,
pb_c_base: float = 19652,
discount_factor=0.99,
value_delta_max=0.01,
seed: int = 0,
):
"""
Initialize the MCTSCtree for AlphaZero.
Parameters
----------
env:
The game model.
        predict_fn: Callable[[Obs], Tuple[np.ndarray, np.ndarray]]
            The function used to predict the policy logits and value.
        root_dirichlet_alpha: float, default: 0.3
The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree.
root_exploration_fraction: float, default: 0.25
The noise weight at the root node of the search tree.
pb_c_init: float, default: 1.25
The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search.
        pb_c_base: float, default: 19652
The base constant used in the PUCT formula for balancing exploration and exploitation during tree search.
discount_factor: float, default: 0.99
The discount factor used to calculate the return.
value_delta_max: float, default: 0.01
The maximum change in value allowed during the backup step of the search tree update.
seed: int, default: 0
The random seed.
"""
self._env = env
self._predict_fn = predict_fn
self._root_dirichlet_alpha = root_dirichlet_alpha
self._root_exploration_fraction = root_exploration_fraction
self._pb_c_init = pb_c_init
self._pb_c_base = pb_c_base
self._discount_factor = discount_factor
self._value_delta_max = value_delta_max
tree.init_module(seed)
def set_predict_fn(self, predict_fn):
self._predict_fn = predict_fn
def tree_search(
self,
init_states,
num_simulations: int,
temperature: float = 1.0,
        root_exploration_fraction: Optional[float] = None,
sample: bool = False,
return_roots: bool = False
) -> Union[
Tuple[np.ndarray, np.ndarray, np.ndarray],
Tuple[np.ndarray, np.ndarray, np.ndarray, tree.Roots]]:
"""
Perform MCTS for a batch of root nodes in parallel using the cpp ctree.
Parameters
----------
init_states : State, shape (parallel_episodes,)
The states of the roots.
num_simulations : int
The number of simulations to run for each root.
        temperature : float, default: 1.0
            The temperature used to adjust the sampling distribution.
        root_exploration_fraction : float, optional
            The Dirichlet-noise weight at the root; if None, the value given at construction is used.
        sample : bool, default: False
            Whether to sample the action used for acting. If False, select the argmax result.
return_roots : bool, default: False
Whether to return the roots.
Returns
-------
        probs : np.ndarray, shape (parallel_episodes, action_dim)
            The target action probabilities of the roots for learning.
        values : np.ndarray, shape (parallel_episodes,)
            The target Q values of the roots.
        actions : np.ndarray, shape (parallel_episodes,)
            The selected action of each root for acting.
roots : Roots, optional
The roots after search. Only returned if return_roots is True.
"""
assert self._predict_fn is not None, "The predict function is not set."
if root_exploration_fraction is None:
root_exploration_fraction = self._root_exploration_fraction
batch_size = len(init_states) # parallel_episodes
obs, all_legal_actions, n_legal_actions = self._env.observation(init_states)
legal_actions_list = []
offset = 0
for i in range(batch_size):
legal_actions_list.append(all_legal_actions[offset: offset + n_legal_actions[i]].tolist())
offset += n_legal_actions[i]
logits, pred_values = self._predict_fn(obs)
game_over, rewards = self._env.terminal(init_states)
init_legal_actions_list = legal_actions_list
roots = tree.Roots(batch_size)
roots.prepare(rewards, logits, all_legal_actions, n_legal_actions, root_exploration_fraction, self._root_dirichlet_alpha)
# the data storage of states: storing the state of all the nodes in the search.
# shape: (num_simulations, batch_size)
state_batch_in_search_path = [init_states]
# minimax value storage
min_max_stats_lst = tree.MinMaxStatsList(batch_size)
min_max_stats_lst.set_delta(self._value_delta_max)
for i_sim in range(1, num_simulations + 1):
            # Each simulation expands one new node, so one search adds at most ``num_simulations`` nodes.
# prepare a result wrapper to transport results between python and c++ parts
results = tree.SearchResults(batch_size)
            # state_index_in_search_path: the first index of the leaf-node state in state_batch_in_search_path, i.e. the simulation step at which the state was stored.
            # state_index_in_batch: the second index, i.e. the index within the batch (bounded by ``batch_size``).
            # e.g. the state of leaf node (x, y) is state_batch_in_search_path[x][y], where x is the state index and y is the batch index.
"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
"""
state_index_in_search_path, state_index_in_batch, last_actions = tree.batch_traverse(
roots, self._pb_c_base, self._pb_c_init, self._discount_factor, min_max_stats_lst, results)
# obtain the state for leaf node
states = []
for ix, iy in zip(state_index_in_search_path, state_index_in_batch):
states.append(state_batch_in_search_path[ix][iy])
states = self._env.from_state_list(states)
last_actions = np.array(last_actions)
"""
MCTS stage 2: Expansion
            At the final time-step l of the simulation, the next state and reward are computed by the environment model.
            Then we calculate the policy logits and value of the leaf node (state) with the prediction function (aka. evaluation).
"""
states, obs, all_legal_actions, n_legal_actions, rewards, dones = self._env.step(states, last_actions)
logits, pred_values = self._predict_fn(obs)
values = pred_values.reshape(-1)
values = np.where(dones, rewards, values)
state_batch_in_search_path.append(states)
tree.batch_expand(
i_sim, dones, rewards, logits, all_legal_actions, n_legal_actions, results)
"""
MCTS stage 3: Backup
At the end of the simulation, the statistics along the trajectory are updated.
"""
            # In ``batch_backpropagate()``, we propagate the predicted value of the newly expanded
            # leaf node along the search path and update the statistics.
tree.batch_backpropagate(
self._discount_factor, values, min_max_stats_lst, results)
all_probs, all_values, all_actions = self._predict(
roots, init_legal_actions_list, temperature, sample)
if return_roots:
return all_probs, all_values, all_actions, roots
return all_probs, all_values, all_actions
def _predict(
self,
roots: tree.Roots,
legal_actions_list: List[List[int]],
temperature: float = 1.0,
sample: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the target action probabilities, values for learning and actions
for acting from the roots after search.
Parameters
----------
roots : Roots
The roots after search.
legal_actions_list : List[List[int]], shape (n_legals_1,), (n_legals_2,), ..., (n_legals_n,)
The list of legal actions for each state.
temperature : float, default: 1.0
The temperature used to adjust the sampling distribution.
        sample : bool, default: True
            Whether to sample the action used for acting. If False, select the argmax result.
Returns
-------
        probs : np.ndarray, shape (parallel_episodes, action_dim)
            The target action probabilities of the roots for learning.
        values : np.ndarray, shape (parallel_episodes,)
            The target Q values of the roots.
        actions : np.ndarray, shape (parallel_episodes,)
            The selected action of each root for acting.
"""
action_dim = self._env.action_space.n
batch_size = roots.num
# list: (batch_size, n_legal_actions)
roots_visit_counts = roots.get_distributions()
roots_values = roots.get_values() # list: (batch_size,)
all_probs = np.zeros((batch_size, action_dim), dtype=np.float32)
all_actions = np.zeros(batch_size, dtype=np.int32)
all_values = np.zeros(batch_size, dtype=np.float32)
for i in range(batch_size):
visit_counts, value = roots_visit_counts[i], roots_values[i]
visit_counts = np.array(visit_counts)
action_index_in_legal_action_set = select_action(
visit_counts, temperature=temperature, deterministic=not sample)
# NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
# entire action set.
legal_actions = legal_actions_list[i]
action = legal_actions[action_index_in_legal_action_set]
probs = visit_counts / sum(visit_counts)
all_probs[i, legal_actions] = probs
all_actions[i] = action
all_values[i] = value
return all_probs, all_values, all_actions
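A minimal end-to-end sketch of how the wrapper is meant to be driven. `game_env`, `policy_value_net` and `init_states` are hypothetical stand-ins; the real interfaces are whatever game model and network the trainer wires in:

```python
# Hypothetical usage; only the AlphaZeroMCTSCtree calls are from this file.
mcts = AlphaZeroMCTSCtree(game_env, root_dirichlet_alpha=0.3, seed=42)
mcts.set_predict_fn(policy_value_net)  # assumed: obs -> (policy logits, values)
probs, values, actions = mcts.tree_search(
    init_states, num_simulations=100, temperature=1.0, sample=True)
# probs/values are training targets for the policy/value heads; actions drive the env.
```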
from __future__ import annotations
import numpy
__all__ = ['MinMaxStatsList', 'Roots', 'SearchResults', 'batch_backpropagate', 'batch_expand', 'batch_traverse', 'init_module']
class MinMaxStatsList:
def __init__(self, arg0: int) -> None:
...
def set_delta(self, arg0: float) -> None:
...
class Roots:
def __init__(self, arg0: int) -> None:
...
def get_distributions(self) -> list[list[int]]:
...
def get_values(self) -> list[float]:
...
def prepare(self, arg0: numpy.ndarray[numpy.float32], arg1: numpy.ndarray[numpy.float32], arg2: numpy.ndarray[numpy.int32], arg3: numpy.ndarray[numpy.int32], arg4: float, arg5: float) -> None:
...
@property
def num(self) -> int:
...
class SearchResults:
def __init__(self, arg0: int) -> None:
...
def get_search_len(self) -> list[int]:
...
def batch_backpropagate(arg0: float, arg1: numpy.ndarray[numpy.float32], arg2: MinMaxStatsList, arg3: SearchResults) -> None:
...
def batch_expand(arg0: int, arg1: numpy.ndarray[bool], arg2: numpy.ndarray[numpy.float32], arg3: numpy.ndarray[numpy.float32], arg4: numpy.ndarray[numpy.int32], arg5: numpy.ndarray[numpy.int32], arg6: SearchResults) -> None:
...
def batch_traverse(arg0: Roots, arg1: int, arg2: float, arg3: float, arg4: MinMaxStatsList, arg5: SearchResults) -> tuple:
...
def init_module(seed: int) -> None:
"""
(asd
,)
"""
#include <algorithm>
#include <map>
#include <cassert>
#include "mcts/alphazero/cnode.h"
namespace tree
{
std::mt19937 rng_ = std::mt19937(time(NULL));
void init_module(int seed) {
rng_ = std::mt19937(seed);
}
template <class RealType>
std::vector<RealType> random_dirichlet(RealType alpha, int n) {
std::gamma_distribution<RealType> gamma(alpha, 1);
std::vector<RealType> x(n);
RealType sum = 0.0;
for (int i = 0; i < n; i++){
x[i] = gamma(rng_);
sum += x[i];
}
for (int i = 0; i < n; i++) {
x[i] = x[i] / sum;
}
return x;
}
SearchResults::SearchResults()
{
/*
Overview:
Initialization of SearchResults, the default result number is set to 0.
*/
this->num = 0;
}
SearchResults::SearchResults(int num)
{
/*
Overview:
Initialization of SearchResults with result number.
*/
this->num = num;
for (int i = 0; i < num; ++i)
{
this->search_paths.push_back(std::vector<Node *>());
}
}
SearchResults::~SearchResults() {}
//*********************************************************
Node::Node()
{
/*
Overview:
Initialization of Node.
*/
this->prior = 0;
this->visit_count = 0;
this->value_sum = 0;
this->best_action = -1;
this->reward = 0.0;
}
Node::Node(float prior)
{
    /*
    Overview:
        Initialization of Node with a prior value.
    Arguments:
        - prior: the prior value of this node.
    */
this->prior = prior;
this->visit_count = 0;
this->value_sum = 0;
this->best_action = -1;
this->reward = 0.0;
this->state_index = -1;
this->batch_index = -1;
}
Node::~Node() {}
void Node::expand(
int state_index, int batch_index, float reward, const Array &logits, const Array &legal_actions)
{
    /*
    Overview:
        Expand the child nodes of the current node.
    Arguments:
        - state_index: the simulation step at which this node's state is stored in the search-path state storage.
        - batch_index: the index of this node's state within the batch.
        - reward: the reward of the current node.
        - logits: the policy logits over all actions.
        - legal_actions: the legal actions of this node.
    */
this->state_index = state_index;
this->batch_index = batch_index;
this->reward = reward;
float temp_policy;
float policy_sum = 0.0;
int n_actions = logits.Shape(0);
int n_legal_actions = legal_actions.Shape(0);
    // Softmax over logits of legal actions
    std::vector<float> policy(n_actions);
    float policy_max = FLOAT_MIN;
for (int i = 0; i < n_legal_actions; ++i)
{
int a = legal_actions[i];
float logit = logits[a];
if (policy_max < logit)
{
policy_max = logit;
}
}
for (int i = 0; i < n_legal_actions; ++i)
{
int a = legal_actions[i];
float logit = logits[a];
temp_policy = exp(logit - policy_max);
policy_sum += temp_policy;
policy[a] = temp_policy;
}
float prior;
for (int i = 0; i < n_legal_actions; ++i)
{
int a = legal_actions[i];
prior = policy[a] / policy_sum;
this->children[a] = Node(prior);
}
}
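The loop above is a numerically stable softmax restricted to the legal actions. The same computation in numpy, as a sketch:

```python
import numpy as np

def legal_softmax(logits: np.ndarray, legal_actions: np.ndarray) -> np.ndarray:
    # Priors over legal actions only; subtracting the max logit avoids overflow in exp.
    z = logits[legal_actions]
    z = np.exp(z - z.max())
    return z / z.sum()   # priors[i] is the prior of legal_actions[i]
```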
void Node::add_exploration_noise(float exploration_fraction, float dirichlet_alpha)
{
    /*
    Overview:
        Add noise to the priors of the child nodes.
    Arguments:
        - exploration_fraction: the fraction of each prior replaced by noise.
        - dirichlet_alpha: the concentration parameter of the Dirichlet noise.
    */
std::vector<float> noises = random_dirichlet(dirichlet_alpha, this->children.size());
float noise, prior;
int i = 0;
for (auto &[a, child] : this->children)
{
noise = noises[i++];
prior = child.prior;
child.prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
}
}
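The root-noise mixing is the standard AlphaZero recipe: replace a fraction of each prior with a sample from a symmetric Dirichlet. In numpy (a sketch, not part of the commit):

```python
import numpy as np

def with_root_noise(priors: np.ndarray, alpha: float = 0.3, frac: float = 0.25) -> np.ndarray:
    # prior' = (1 - frac) * prior + frac * noise, with noise ~ Dirichlet(alpha, ..., alpha)
    noise = np.random.default_rng().dirichlet([alpha] * len(priors))
    return priors * (1 - frac) + noise * frac
```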
float Node::compute_mean_q(int isRoot, float parent_q, float discount_factor)
{
/*
Overview:
Compute the mean q value of the current node.
Arguments:
- isRoot: whether the current node is a root node.
- parent_q: the q value of the parent node.
- discount_factor: the discount_factor of reward.
*/
float total_unsigned_q = 0.0;
int total_visits = 0;
for (const auto &[a, child] : this->children)
{
if (child.visit_count > 0)
{
float qsa = child.reward + discount_factor * child.value();
total_unsigned_q += qsa;
total_visits += 1;
}
}
float mean_q = 0.0;
if (isRoot && total_visits > 0)
{
mean_q = total_unsigned_q / total_visits;
}
else
{
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
}
return mean_q;
}
int Node::expanded() const
{
/*
Overview:
Return whether the current node is expanded.
*/
return this->children.size() > 0;
}
float Node::value() const
{
    /*
    Overview:
        Return the estimated value of the current node, i.e. value_sum / visit_count (0 if unvisited).
    */
float true_value = 0.0;
if (this->visit_count == 0)
{
return true_value;
}
else
{
true_value = this->value_sum / this->visit_count;
return true_value;
}
}
std::vector<int> Node::get_trajectory()
{
    /*
    Overview:
        Find the current best trajectory starting from the current node.
    Outputs:
        - traj: a vector of actions along the current best trajectory from this node.
    */
std::vector<int> traj;
Node *node = this;
int best_action = node->best_action;
while (best_action >= 0)
{
traj.push_back(best_action);
node = node->get_child(best_action);
best_action = node->best_action;
}
return traj;
}
std::vector<int> Node::get_children_distribution()
{
/*
Overview:
Get the distribution of child nodes in the format of visit_count.
Outputs:
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
*/
std::vector<int> distribution;
if (this->expanded())
{
for (const auto &[a, child] : this->children)
{
distribution.push_back(child.visit_count);
}
}
return distribution;
}
Node *Node::get_child(int action)
{
/*
Overview:
Get the child node corresponding to the input action.
Arguments:
- action: the action to get child.
*/
auto it = this->children.find(action);
if (it != this->children.end())
{
// The action exists in the map, return a pointer to the corresponding Node.
return &(it->second);
}
else
{
throw std::out_of_range("Action not found in children");
}
}
//*********************************************************
Roots::Roots()
{
/*
Overview:
The initialization of Roots.
*/
this->root_num = 0;
}
Roots::Roots(int root_num)
{
    /*
    Overview:
        The initialization of Roots with the number of roots.
    Arguments:
        - root_num: the number of roots.
    */
this->root_num = root_num;
for (int i = 0; i < root_num; ++i)
{
// root node has no prior
this->roots.push_back(Node(0));
}
}
Roots::~Roots() {}
void Roots::prepare(
const Array &rewards, const Array &logits,
const Array &all_legal_actions, const Array &n_legal_actions,
float exploration_fraction, float dirichlet_alpha)
{
    /*
    Overview:
        Expand the roots and add exploration noise.
    Arguments:
        - rewards: the reward of each root.
        - logits: the policy logits of each root.
        - all_legal_actions: the legal actions of all roots, flattened.
        - n_legal_actions: the number of legal actions of each root.
        - exploration_fraction: the fraction to add noise, 0 means no noise.
        - dirichlet_alpha: the Dirichlet concentration parameter.
    Note:
        Do not include terminal states because they have no legal actions and cannot be expanded.
    */
int batch_size = this->root_num;
int offset = 0;
int n_actions = logits.Shape(1);
for (int i = 0; i < batch_size; ++i)
{
int n_legal_action = n_legal_actions[i];
const Array &legal_actions = all_legal_actions.Slice(offset, offset + n_legal_action);
this->roots[i].expand(
0, i, rewards[i], logits[i], legal_actions);
if (exploration_fraction > 0) {
this->roots[i].add_exploration_noise(exploration_fraction, dirichlet_alpha);
}
this->roots[i].visit_count += 1;
offset += n_legal_action;
}
}
void Roots::clear()
{
/*
Overview:
Clear the roots vector.
*/
this->roots.clear();
}
std::vector<std::vector<int> > Roots::get_trajectories()
{
    /*
    Overview:
        Find the current best trajectory starting from each root.
    Outputs:
        - trajs: one best trajectory (a vector of actions) per root.
    */
std::vector<std::vector<int> > trajs;
trajs.reserve(this->root_num);
for (int i = 0; i < this->root_num; ++i)
{
trajs.push_back(this->roots[i].get_trajectory());
}
return trajs;
}
std::vector<std::vector<int> > Roots::get_distributions()
{
/*
Overview:
Get the children distribution of each root.
Outputs:
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
*/
std::vector<std::vector<int> > distributions;
distributions.reserve(this->root_num);
for (int i = 0; i < this->root_num; ++i)
{
distributions.push_back(this->roots[i].get_children_distribution());
}
return distributions;
}
std::vector<float> Roots::get_values()
{
/*
Overview:
Return the real value of each root.
*/
std::vector<float> values;
for (int i = 0; i < this->root_num; ++i)
{
values.push_back(roots[i].value());
}
return values;
}
//*********************************************************
//
void update_tree_q(Node *root, MinMaxStats &min_max_stats, float discount_factor)
{
/*
Overview:
Update the q value of the root and its child nodes.
Arguments:
- root: the root that update q value from.
- min_max_stats: a tool used to min-max normalize the q value.
- discount_factor: the discount factor of reward.
*/
std::stack<Node *> node_stack;
node_stack.push(root);
while (node_stack.size() > 0)
{
Node *node = node_stack.top();
node_stack.pop();
if (node != root)
{
float true_reward = node->reward;
float qsa;
qsa = true_reward + discount_factor * node->value();
min_max_stats.update(qsa);
}
for (auto it = node->children.begin(); it != node->children.end(); ++it) {
Node *child = &(it->second);
if (child->expanded()) {
node_stack.push(child);
}
}
}
}
void backpropagate(std::vector<Node *> &search_path, MinMaxStats &min_max_stats, float value, float discount_factor)
{
/*
Overview:
Update the value sum and visit count of nodes along the search path.
Arguments:
- search_path: a vector of nodes on the search path.
- min_max_stats: a tool used to min-max normalize the q value.
- value: the value to propagate along the search path.
- discount_factor: the discount factor of reward.
*/
float bootstrap_value = value;
int path_len = search_path.size();
for (int i = path_len - 1; i >= 0; --i)
{
Node *node = search_path[i];
node->value_sum += bootstrap_value;
node->visit_count += 1;
float true_reward = node->reward;
min_max_stats.update(true_reward + discount_factor * node->value());
bootstrap_value = true_reward + discount_factor * bootstrap_value;
}
}
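The backup walks the path from leaf to root: each node's value sum is incremented by the discounted bootstrap value accumulated so far. A direct Python transcription (a sketch; nodes reduced to dicts):

```python
def backup(path, leaf_value, discount):
    # path: nodes from root to leaf, each a dict with 'reward', 'value_sum', 'visit_count'.
    g = leaf_value
    for node in reversed(path):
        node['value_sum'] += g
        node['visit_count'] += 1
        # (the C++ version also feeds reward + discount * value() into MinMaxStats here)
        g = node['reward'] + discount * g  # bootstrap value seen by the parent
```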
void batch_expand(
int state_index, const Array &game_over, const Array &rewards, const Array &logits /* 2D array */,
const Array &all_legal_actions, const Array &n_legal_actions, SearchResults &results)
{
int batch_size = results.num;
int offset = 0;
int n_actions = logits.Shape(1);
for (int i = 0; i < batch_size; ++i)
{
Node *node = results.nodes[i];
int n_legal_action = n_legal_actions[i];
if (game_over[i]) {
node->state_index = state_index;
node->batch_index = i;
node->reward = rewards[i];
}
else {
const Array &legal_actions = all_legal_actions.Slice(offset, offset + n_legal_action);
node->expand(
state_index, i, rewards[i], logits[i], legal_actions);
}
offset += n_legal_action;
}
}
void batch_backpropagate(float discount_factor, const Array &values, MinMaxStatsList &min_max_stats_lst, SearchResults &results)
{
    /*
    Overview:
        Propagate the leaf values along the search paths and update the statistics.
    Arguments:
        - discount_factor: the discount factor of reward.
        - values: the values to propagate along the search paths.
        - min_max_stats_lst: the tools used to min-max normalize the q values.
        - results: the search results holding the search paths.
    */
for (int i = 0; i < results.num; ++i)
{
backpropagate(results.search_paths[i], min_max_stats_lst.stats_lst[i], values[i], discount_factor);
}
}
int select_child(Node *root, const MinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q)
{
    /*
    Overview:
        Select a child of the given node according to the ucb scores.
    Arguments:
        - root: the node to select a child from.
        - min_max_stats: a tool used to min-max normalize the score.
        - pb_c_base: constant c2 of the PUCT formula (as in MuZero).
        - pb_c_init: constant c1 of the PUCT formula (as in MuZero).
        - discount_factor: the discount factor of reward.
        - mean_q: the mean q value of the parent node.
    Outputs:
        - action: the selected action.
    */
float max_score = FLOAT_MIN;
const float epsilon = 0.000001;
std::vector<Action> max_index_lst;
for (const auto &[a, child] : root->children)
{
float temp_score = ucb_score(child, min_max_stats, mean_q, root->visit_count, pb_c_base, pb_c_init, discount_factor);
if (max_score < temp_score)
{
max_score = temp_score;
max_index_lst.clear();
max_index_lst.push_back(a);
}
else if (temp_score >= max_score - epsilon)
{
max_index_lst.push_back(a);
}
}
int action = 0;
if (max_index_lst.size() > 0)
{
std::uniform_int_distribution<int> dist(0, max_index_lst.size() - 1);
int rand_index = dist(rng_);
action = max_index_lst[rand_index];
}
return action;
}
float ucb_score(const Node &child, const MinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor)
{
    /*
    Overview:
        Compute the ucb score of the child.
    Arguments:
        - child: the child node to compute the ucb score for.
        - min_max_stats: a tool used to min-max normalize the score.
        - parent_mean_q: the mean q value of the parent node.
        - total_children_visit_counts: the visit count of the parent node.
        - pb_c_base: constant c2 of the PUCT formula (as in MuZero).
        - pb_c_init: constant c1 of the PUCT formula (as in MuZero).
        - discount_factor: the discount factor of reward.
    Outputs:
        - ucb_value: the ucb score of the child.
    */
float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
pb_c *= (sqrt(total_children_visit_counts) / (child.visit_count + 1));
prior_score = pb_c * child.prior;
if (child.visit_count == 0) {
value_score = parent_mean_q;
}
else {
value_score = child.reward + discount_factor * child.value();
}
value_score = min_max_stats.normalize(value_score);
if (value_score < 0)
value_score = 0;
if (value_score > 1)
value_score = 1;
float ucb_value = prior_score + value_score;
return ucb_value;
}
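This is the PUCT rule from AlphaZero/MuZero, condensed into one Python function (a sketch; `q` stands for the min-max-normalized value term computed above, and the C++ version additionally clamps it to [0, 1]):

```python
import math

def puct(q: float, prior: float, parent_visits: int, child_visits: int,
         pb_c_base: float = 19652, pb_c_init: float = 1.25) -> float:
    # score = Q(s, a) + c(s) * P(a) * sqrt(N(s)) / (1 + N(s, a))
    c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    return q + c * prior * math.sqrt(parent_visits) / (child_visits + 1)
```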
void batch_traverse(Roots &roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList &min_max_stats_lst, SearchResults &results)
{
    /*
    Overview:
        Search one node path from each root.
    Arguments:
        - roots: the roots to search from.
        - pb_c_base: constant c2 of the PUCT formula (as in MuZero).
        - pb_c_init: constant c1 of the PUCT formula (as in MuZero).
        - discount_factor: the discount factor of reward.
        - min_max_stats_lst: the tools used to min-max normalize the scores.
        - results: the search results to fill.
    */
int last_action = -1;
float parent_q = 0.0;
results.search_lens = std::vector<int>();
for (int i = 0; i < results.num; ++i)
{
Node *node = &(roots.roots[i]);
int is_root = 1;
int search_len = 0;
results.search_paths[i].push_back(node);
while (node->expanded())
{
float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
is_root = 0;
parent_q = mean_q;
int action = select_child(node, min_max_stats_lst.stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q);
node->best_action = action;
// next
node = node->get_child(action);
last_action = action;
results.search_paths[i].push_back(node);
search_len += 1;
}
Node *parent = results.search_paths[i][results.search_paths[i].size() - 2];
results.state_index_in_search_path.push_back(parent->state_index);
results.state_index_in_batch.push_back(parent->batch_index);
results.last_actions.push_back(last_action);
results.search_lens.push_back(search_len);
results.nodes.push_back(node);
}
}
}
#ifndef AZ_CNODE_H
#define AZ_CNODE_H
#include <math.h>
#include <vector>
#include <stack>
#include <stdlib.h>
#include <time.h>
#include <cmath>
#include <random>
#include <sys/timeb.h>
#include <ctime>
#include <map>
#include "mcts/core/minimax.h"
#include "mcts/core/array.h"
const int DEBUG_MODE = 0;
namespace tree {
void init_module(int seed);
using Action = int;
class Node {
public:
int visit_count, state_index, batch_index, best_action;
float reward, prior, value_sum;
std::map<Action, Node> children;
Node();
Node(float prior);
~Node();
void expand(
int state_index, int batch_index, float reward, const Array &logits, const Array &legal_actions);
void add_exploration_noise(float exploration_fraction, float dirichlet_alpha);
float compute_mean_q(int isRoot, float parent_q, float discount_factor);
int expanded() const;
float value() const;
std::vector<int> get_trajectory();
std::vector<int> get_children_distribution();
Node* get_child(int action);
};
class Roots{
public:
int root_num;
std::vector<Node> roots;
Roots();
Roots(int root_num);
~Roots();
void prepare(
const Array &rewards, const Array &logits,
const Array &all_legal_actions, const Array &n_legal_actions,
float exploration_fraction, float dirichlet_alpha);
void clear();
std::vector<std::vector<int> > get_trajectories();
std::vector<std::vector<int> > get_distributions();
std::vector<float> get_values();
};
class SearchResults{
public:
int num;
std::vector<int> state_index_in_search_path, state_index_in_batch, last_actions, search_lens;
std::vector<Node*> nodes;
std::vector<std::vector<Node*> > search_paths;
SearchResults();
SearchResults(int num);
~SearchResults();
};
void update_tree_q(Node* root, MinMaxStats &min_max_stats, float discount_factor);
void backpropagate(std::vector<Node*> &search_path, MinMaxStats &min_max_stats, float value, float discount_factor);
void batch_expand(
int state_index, const Array &game_over, const Array &rewards, const Array &logits /* 2D array */,
const Array &all_legal_actions, const Array &n_legal_actions, SearchResults &results);
void batch_backpropagate(float discount_factor, const Array &values, MinMaxStatsList &min_max_stats_lst, SearchResults &results);
int select_child(Node* root, const MinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q);
float ucb_score(const Node &child, const MinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor);
void batch_traverse(Roots &roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList &min_max_stats_lst, SearchResults &results);
}
#endif // AZ_CNODE_H
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/cast.h>
#include "mcts/core/common.h"
#include "mcts/core/minimax.h"
#include "mcts/core/array.h"
#include "mcts/alphazero/cnode.h"
namespace py = pybind11;
PYBIND11_MODULE(alphazero_mcts, m) {
using namespace pybind11::literals;
py::class_<MinMaxStatsList>(m, "MinMaxStatsList")
.def(py::init<int>())
.def("set_delta", &MinMaxStatsList::set_delta);
py::class_<tree::SearchResults>(m, "SearchResults")
.def(py::init<int>())
.def("get_search_len", [](tree::SearchResults &results) {
return results.search_lens;
});
py::class_<tree::Roots>(m, "Roots")
.def(py::init<int>())
.def_readonly("num", &tree::Roots::root_num)
.def("prepare", [](
tree::Roots &roots, const py::array_t<float> &rewards,
const py::array_t<float> &logits, const py::array_t<int> &all_legal_actions,
const py::array_t<int> &n_legal_actions, float exploration_fraction,
float dirichlet_alpha) {
Array rewards_ = NumpyToArray(rewards);
Array logits_ = NumpyToArray(logits);
Array all_legal_actions_ = NumpyToArray(all_legal_actions);
Array n_legal_actions_ = NumpyToArray(n_legal_actions);
roots.prepare(rewards_, logits_, all_legal_actions_, n_legal_actions_, exploration_fraction, dirichlet_alpha);
})
.def("get_distributions", &tree::Roots::get_distributions)
.def("get_values", &tree::Roots::get_values);
m.def("batch_expand", [](
int state_index, const py::array_t<bool> &game_over, const py::array_t<float> &rewards, const py::array_t<float> &logits,
const py::array_t<int> &all_legal_actions, const py::array_t<int> &n_legal_actions, tree::SearchResults &results) {
Array game_over_ = NumpyToArray(game_over);
Array rewards_ = NumpyToArray(rewards);
Array logits_ = NumpyToArray(logits);
Array all_legal_actions_ = NumpyToArray(all_legal_actions);
Array n_legal_actions_ = NumpyToArray(n_legal_actions);
tree::batch_expand(state_index, game_over_, rewards_, logits_, all_legal_actions_, n_legal_actions_, results);
});
m.def("batch_backpropagate", [](
float discount_factor, const py::array_t<float> &values,
MinMaxStatsList &min_max_stats_lst, tree::SearchResults &results) {
Array values_ = NumpyToArray(values);
tree::batch_backpropagate(discount_factor, values_, min_max_stats_lst, results);
});
m.def("batch_traverse", [](
tree::Roots &roots, int pb_c_base, float pb_c_init, float discount_factor,
MinMaxStatsList &min_max_stats_lst, tree::SearchResults &results) {
tree::batch_traverse(roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results);
return py::make_tuple(results.state_index_in_search_path, results.state_index_in_batch, results.last_actions);
});
m.def("init_module", &tree::init_module, "", "seed"_a);
}
#ifndef MCTS_CORE_ARRAY_H_
#define MCTS_CORE_ARRAY_H_
#include <cstddef>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include "mcts/core/spec.h"
class Array {
public:
std::size_t size;
std::size_t ndim;
std::size_t element_size;
protected:
std::vector<std::size_t> shape_;
std::shared_ptr<char> ptr_;
template<class Shape, class Deleter>
Array(char *ptr, Shape &&shape, std::size_t element_size,// NOLINT
Deleter &&deleter)
: size(Prod(shape.data(), shape.size())),
ndim(shape.size()),
element_size(element_size),
shape_(std::forward<Shape>(shape)),
ptr_(ptr, std::forward<Deleter>(deleter)) {}
template<class Shape>
Array(std::shared_ptr<char> ptr, Shape &&shape, std::size_t element_size)
: size(Prod(shape.data(), shape.size())),
ndim(shape.size()),
element_size(element_size),
shape_(std::forward<Shape>(shape)),
ptr_(std::move(ptr)) {}
public:
Array() = default;
  /**
   * Construct an `Array` of shape defined by `spec`, with `data` as a pointer
   * to its raw memory. The empty deleter means the `Array` does not own the
   * memory.
   */
template<class Deleter>
Array(const ShapeSpec &spec, char *data, Deleter &&deleter)// NOLINT
: Array(data, spec.Shape(), spec.element_size,
std::forward<Deleter>(deleter)) {}
Array(const ShapeSpec &spec, char *data)
: Array(data, spec.Shape(), spec.element_size, [](char * /*unused*/) {}) {}
  /**
   * Construct an `Array` of shape defined by `spec`. This constructor
   * allocates and owns the memory.
   */
explicit Array(const ShapeSpec &spec)
: Array(spec, nullptr, [](char * /*unused*/) {}) {
ptr_.reset(new char[size * element_size](),
[](const char *p) { delete[] p; });
}
/**
* Take multidimensional index into the Array.
*/
template<typename... Index>
inline Array operator()(Index... index) const {
constexpr std::size_t num_index = sizeof...(Index);
std::size_t offset = 0;
std::size_t i = 0;
for (((offset = offset * shape_[i++] + index), ...); i < ndim; ++i) {
offset *= shape_[i];
}
return Array(
ptr_.get() + offset * element_size,
std::vector<std::size_t>(shape_.begin() + num_index, shape_.end()),
element_size, [](char * /*unused*/) {});
}
/**
* Index operator of array, takes the index along the first axis.
*/
inline Array operator[](int index) const { return this->operator()(index); }
/**
* Take a slice at the first axis of the Array.
*/
[[nodiscard]] Array Slice(std::size_t start, std::size_t end) const {
std::vector<std::size_t> new_shape(shape_);
new_shape[0] = end - start;
std::size_t offset = 0;
if (shape_[0] > 0) {
offset = start * size / shape_[0];
}
return {ptr_.get() + offset * element_size, std::move(new_shape),
element_size, [](char *p) {}};
}
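`Slice` returns a view over the same buffer: the flat offset of row `start` is `start * (size / shape[0])` elements. The same arithmetic checked against numpy (a sketch):

```python
import numpy as np

a = np.arange(24).reshape(4, 3, 2)        # size = 24, shape[0] = 4
start = 1
offset = start * (a.size // a.shape[0])   # 1 * 6, as in Array::Slice
view = a[start:3]
assert np.shares_memory(view, a)          # a view, like the C++ Slice
assert view.flat[0] == a.flat[offset]
```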
/**
* Copy the content of another Array to this Array.
*/
void Assign(const Array &value) const {
std::memcpy(ptr_.get(), value.ptr_.get(), size * element_size);
}
/**
* Return a clone of this array.
*/
Array Clone() const {
std::vector<int> shape;
for (int i = 0; i < ndim; i++) {
shape.push_back(shape_[i]);
}
auto spec = ShapeSpec(element_size, shape);
Array ret(spec);
ret.Assign(*this);
return ret;
}
/**
* Assign to this Array a scalar value. This Array needs to have a scalar
* shape.
*/
template<typename T>
void operator=(const T &value) const {
*reinterpret_cast<T *>(ptr_.get()) = value;
}
/**
* Fills this array with a scalar value of type T.
*/
template<typename T>
void Fill(const T &value) const {
auto *data = reinterpret_cast<T *>(ptr_.get());
std::fill(data, data + size, value);
}
  /**
   * Copy `sz` elements of type `T` starting at `buff` into the memory of this
   * Array.
   */
template<typename T>
void Assign(const T *buff, std::size_t sz) const {
std::memcpy(ptr_.get(), buff, sz * sizeof(T));
}
template<typename T>
void Assign(const T *buff, std::size_t sz, ptrdiff_t offset) const {
offset = offset * (element_size / sizeof(char));
std::memcpy(ptr_.get() + offset, buff, sz * sizeof(T));
}
/**
* Cast the Array to a scalar value of type `T`. This Array needs to have a
* scalar shape.
*/
template<typename T>
operator const T &() const {// NOLINT
return *reinterpret_cast<T *>(ptr_.get());
}
/**
* Cast the Array to a scalar value of type `T`. This Array needs to have a
* scalar shape.
*/
template<typename T>
operator T &() {// NOLINT
return *reinterpret_cast<T *>(ptr_.get());
}
/**
* Size of axis `dim`.
*/
[[nodiscard]] inline std::size_t Shape(std::size_t dim) const {
return shape_[dim];
}
/**
* Shape
*/
[[nodiscard]] inline const std::vector<std::size_t> &Shape() const {
return shape_;
}
/**
* Pointer to the raw memory.
*/
[[nodiscard]] inline void *Data() const { return ptr_.get(); }
/**
* Truncate the Array. Return a new Array that shares the same memory
* location but with a truncated shape.
*/
[[nodiscard]] Array Truncate(std::size_t end) const {
auto new_shape = std::vector<std::size_t>(shape_);
new_shape[0] = end;
Array ret(ptr_, std::move(new_shape), element_size);
return ret;
}
void Zero() const { std::memset(ptr_.get(), 0, size * element_size); }
[[nodiscard]] std::shared_ptr<char> SharedPtr() const { return ptr_; }
};
template<typename Dtype>
class TArray : public Array {
public:
explicit TArray(const Spec<Dtype> &spec) : Array(spec) {}
};
#endif // MCTS_CORE_ARRAY_H_
#ifndef MCTS_CORE_COMMON_H_
#define MCTS_CORE_COMMON_H_
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "mcts/core/array.h"
namespace py = pybind11;
template<typename dtype>
py::array_t<dtype> ArrayToNumpy(const Array &a) {
auto *ptr = new std::shared_ptr<char>(a.SharedPtr());
auto capsule = py::capsule(ptr, [](void *ptr) {
delete reinterpret_cast<std::shared_ptr<char> *>(ptr);
});
return py::array(a.Shape(), reinterpret_cast<dtype *>(a.Data()), capsule);
}
template<typename dtype>
Array NumpyToArray(const py::array_t<dtype> &arr) {
using ArrayT = py::array_t<dtype, py::array::c_style | py::array::forcecast>;
ArrayT arr_t(arr);
ShapeSpec spec(arr_t.itemsize(),
std::vector<int>(arr_t.shape(), arr_t.shape() + arr_t.ndim()));
return {spec, reinterpret_cast<char *>(arr_t.mutable_data())};
}
template <typename Sequence>
inline py::array_t<typename Sequence::value_type> as_pyarray(Sequence &&seq) {
using T = typename Sequence::value_type;
std::unique_ptr<Sequence> seq_ptr = std::make_unique<Sequence>(std::forward<Sequence>(seq));
return py::array_t<T>({seq_ptr->size()}, {sizeof(T)}, seq_ptr->data());
}
#endif // MCTS_CORE_COMMON_H_
#ifndef MCTS_CORE_MINIMAX_H_
#define MCTS_CORE_MINIMAX_H_
#include <iostream>
#include <vector>
const float FLOAT_MAX = 1000000.0;
const float FLOAT_MIN = -FLOAT_MAX;
class MinMaxStats {
public:
float maximum, minimum, value_delta_max;
MinMaxStats() {
this->maximum = FLOAT_MIN;
this->minimum = FLOAT_MAX;
this->value_delta_max = 0.;
}
~MinMaxStats() {}
void set_delta(float value_delta_max) {
this->value_delta_max = value_delta_max;
}
void update(float value) {
if(value > this->maximum){
this->maximum = value;
}
if(value < this->minimum){
this->minimum = value;
}
}
void clear() {
this->maximum = FLOAT_MIN;
this->minimum = FLOAT_MAX;
}
float normalize(float value) const {
float norm_value = value;
float delta = this->maximum - this->minimum;
if(delta > 0){
if(delta < this->value_delta_max){
norm_value = (norm_value - this->minimum) / this->value_delta_max;
}
else{
norm_value = (norm_value - this->minimum) / delta;
}
}
return norm_value;
}
};
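`normalize` maps values into [0, 1] using the observed spread, but never divides by less than `value_delta_max`, which keeps early noisy estimates from being stretched. Equivalent Python (a sketch):

```python
def normalize(value: float, minimum: float, maximum: float, value_delta_max: float) -> float:
    delta = maximum - minimum
    if delta <= 0:
        return value                 # nothing observed yet; pass through
    return (value - minimum) / max(delta, value_delta_max)
```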
class MinMaxStatsList {
public:
int num;
std::vector<MinMaxStats> stats_lst;
MinMaxStatsList() {
this->num = 0;
}
MinMaxStatsList(int num) {
this->num = num;
for(int i = 0; i < num; ++i){
this->stats_lst.push_back(MinMaxStats());
}
}
~MinMaxStatsList() {}
void set_delta(float value_delta_max) {
for(int i = 0; i < this->num; ++i){
this->stats_lst[i].set_delta(value_delta_max);
}
}
};
#endif // MCTS_CORE_MINIMAX_H_
#ifndef MCTS_CORE_SPEC_H_
#define MCTS_CORE_SPEC_H_
#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
static std::size_t Prod(const std::size_t *shape, std::size_t ndim) {
return std::accumulate(shape, shape + ndim, static_cast<std::size_t>(1),
std::multiplies<>());
}
class ShapeSpec {
public:
int element_size;
std::vector<int> shape;
ShapeSpec() = default;
ShapeSpec(int element_size, std::vector<int> shape_vec)
: element_size(element_size), shape(std::move(shape_vec)) {}
[[nodiscard]] ShapeSpec Batch(int batch_size) const {
std::vector<int> new_shape = {batch_size};
new_shape.insert(new_shape.end(), shape.begin(), shape.end());
return {element_size, std::move(new_shape)};
}
[[nodiscard]] std::vector<std::size_t> Shape() const {
auto s = std::vector<std::size_t>(shape.size());
for (std::size_t i = 0; i < shape.size(); ++i) {
s[i] = shape[i];
}
return s;
}
};
template<typename D>
class Spec : public ShapeSpec {
public:
using dtype = D;// NOLINT
std::tuple<dtype, dtype> bounds = {std::numeric_limits<dtype>::min(),
std::numeric_limits<dtype>::max()};
std::tuple<std::vector<dtype>, std::vector<dtype>> elementwise_bounds;
explicit Spec(std::vector<int> &&shape)
: ShapeSpec(sizeof(dtype), std::move(shape)) {}
explicit Spec(const std::vector<int> &shape)
: ShapeSpec(sizeof(dtype), shape) {}
/* init with constant bounds */
Spec(std::vector<int> &&shape, std::tuple<dtype, dtype> &&bounds)
: ShapeSpec(sizeof(dtype), std::move(shape)), bounds(std::move(bounds)) {}
Spec(const std::vector<int> &shape, const std::tuple<dtype, dtype> &bounds)
: ShapeSpec(sizeof(dtype), shape), bounds(bounds) {}
/* init with elementwise bounds */
Spec(std::vector<int> &&shape,
std::tuple<std::vector<dtype>, std::vector<dtype>> &&elementwise_bounds)
: ShapeSpec(sizeof(dtype), std::move(shape)),
elementwise_bounds(std::move(elementwise_bounds)) {}
Spec(const std::vector<int> &shape,
const std::tuple<std::vector<dtype>, std::vector<dtype>> &
elementwise_bounds)
: ShapeSpec(sizeof(dtype), shape),
elementwise_bounds(elementwise_bounds) {}
[[nodiscard]] Spec Batch(int batch_size) const {
std::vector<int> new_shape = {batch_size};
new_shape.insert(new_shape.end(), shape.begin(), shape.end());
return Spec(std::move(new_shape));
}
};
template<typename dtype>
class TArray;
template<typename dtype>
using Container = std::unique_ptr<TArray<dtype>>;
template<typename D>
class Spec<Container<D>> : public ShapeSpec {
public:
using dtype = Container<D>;// NOLINT
Spec<D> inner_spec;
explicit Spec(const std::vector<int> &shape, const Spec<D> &inner_spec)
: ShapeSpec(sizeof(Container<D>), shape), inner_spec(inner_spec) {}
explicit Spec(std::vector<int> &&shape, Spec<D> &&inner_spec)
: ShapeSpec(sizeof(Container<D>), std::move(shape)),
inner_spec(std::move(inner_spec)) {}
};
#endif // MCTS_CORE_SPEC_H_
from typing import Tuple, Sequence
import numpy as np
class State:
def __init__(self, batch_shape: Tuple[int, ...], store):
assert isinstance(store, np.ndarray)
self.store = store
self.batch_shape = batch_shape
self.ndim = len(batch_shape)
def get_state_keys(self):
return self.store
@classmethod
def from_item(cls, item):
return cls((1,), np.array([item], dtype=np.int32))
def reshape(self, batch_shape: Tuple[int, ...]):
self.batch_shape = batch_shape
self.ndim = len(batch_shape)
return self
def item(self):
assert self.ndim == 1 and self.batch_shape[0] == 1
return self.store[0]
@classmethod
def from_state_list(cls, state_list, batch_shape=None):
if isinstance(state_list[0], State):
batch_shape_ = (len(state_list),)
elif isinstance(state_list[0], Sequence):
batch_shape_ = (len(state_list), len(state_list[0]))
assert isinstance(state_list[0][0], State)
else:
raise ValueError("Invalid dim of states")
if batch_shape is None:
batch_shape = batch_shape_
else:
assert len(batch_shape) == 2 and len(batch_shape_) == 1
if len(batch_shape) == 2:
states = [s for ss in state_list for s in ss]
else:
states = state_list
state_keys = np.concatenate([s.store for s in states], dtype=np.int32, axis=0)
return State(batch_shape, state_keys)
def _get_by_index(self, batch_shape, indices):
state_keys = self.store[indices]
return State(batch_shape, state_keys)
def __getitem__(self, item):
return self.get(item)
def get(self, i):
if self.ndim == 2:
assert isinstance(i, tuple)
i = i[0] * self.batch_shape[1] + i[1]
i = np.array([i], dtype=np.int32)
return self._get_by_index((1,), i)
def __len__(self) -> int:
return len(self.store)
def __repr__(self) -> str:
return f'State(batch_shape={self.batch_shape}, ndim={self.ndim})'
def __str__(self) -> str:
return f'State(batch_shape={self.batch_shape}, ndim={self.ndim})'
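A small usage sketch of the `State` container (the stored integers are arbitrary state keys):

```python
s1 = State.from_item(7)
s2 = State.from_item(42)
batch = State.from_state_list([s1, s2])   # batch_shape == (2,)
assert len(batch) == 2
assert batch[0].item() == 7               # __getitem__ returns a batch of one
print(batch)                              # State(batch_shape=(2,), ndim=1)
```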
from setuptools import setup, find_packages
__version__ = "0.0.1"
INSTALL_REQUIRES = [
"setuptools",
"wheel",
"pybind11-stubgen",
"numpy",
]
setup(
name="mcts",
version=__version__,
    packages=find_packages(include=["mcts*"]),
long_description="",
install_requires=INSTALL_REQUIRES,
python_requires=">=3.7",
include_package_data=True,
)
@@ -39,7 +39,7 @@ class Args:
     # Algorithm specific arguments
     env_id: str = "YGOPro-v0"
     """the id of the environment"""
-    deck: str = "../assets/deck/OldSchool.ydk"
+    deck: str = "../assets/deck"
     """the deck file to use"""
     deck1: Optional[str] = None
     """the deck file for the first player"""
@@ -47,19 +47,21 @@ class Args:
     """the deck file for the second player"""
     code_list_file: str = "code_list.txt"
     """the code list file for card embeddings"""
-    embedding_file: Optional[str] = "embeddings_en.npy"
+    embedding_file: Optional[str] = None
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
     n_history_actions: int = 16
     """the number of history actions to use"""
     play_mode: str = "bot"
-    """the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""
+    """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
     num_layers: int = 2
     """the number of layers for the agent"""
     num_channels: int = 128
     """the number of channels for the agent"""
+    checkpoint: Optional[str] = None
+    """the checkpoint to load the model from"""
     total_timesteps: int = 1000000000
     """total timesteps of the experiments"""
@@ -236,8 +238,15 @@ def run(local_rank, world_size):
         embedding_shape = None
     L = args.num_layers
     agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-    if args.embedding_file:
+    if args.checkpoint:
+        agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
+        fprint(f"Loaded checkpoint from {args.checkpoint}")
+    elif args.embedding_file:
         agent.load_embeddings(embeddings)
+        fprint(f"Loaded embeddings from {args.embedding_file}")
+    if args.embedding_file:
+        agent.freeze_embeddings()

     optim_params = list(agent.parameters())
     optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
@@ -431,12 +440,10 @@ def run(local_rank, world_size):
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
             # TODO: optimize this
-            done_used1 = torch.zeros_like(next_done, dtype=torch.bool)
-            done_used2 = torch.zeros_like(next_done, dtype=torch.bool)
-            reward1 = 0
-            reward2 = 0
-            lastgaelam1 = 0
-            lastgaelam2 = 0
+            done_used1 = torch.ones_like(next_done, dtype=torch.bool)
+            done_used2 = torch.ones_like(next_done, dtype=torch.bool)
+            reward1 = reward2 = 0
+            lastgaelam1 = lastgaelam2 = 0
             for t in reversed(range(args.num_steps)):
                 # if learns[t]:
                 #     if dones[t+1]:
@@ -586,6 +593,7 @@ def run(local_rank, world_size):
         writer.add_scalar("charts/SPS", SPS, global_step)

         if iteration % args.eval_interval == 0:
+            # Eval with rule-based policy
             _start = time.time()
             episode_lengths = []
             episode_rewards = []
@@ -627,6 +635,8 @@ def run(local_rank, world_size):
             eval_time = time.time() - _start
             fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")

+    # Eval with old model
+
     if args.world_size > 1:
         dist.destroy_process_group()
     envs.close()
...
@@ -8,19 +8,6 @@ add_requires(
     "sqlitecpp 3.2.1")

--- target("dummy_ygopro")
---     add_rules("python.library")
---     add_files("ygoenv/ygoenv/dummy/*.cpp")
---     add_packages("pybind11", "fmt", "glog", "concurrentqueue")
---     set_languages("c++17")
---     add_includedirs("ygoenv")
---     after_build(function (target)
---         local install_target = "$(projectdir)/ygoenv/ygoenv/dummy"
---         os.cp(target:targetfile(), install_target)
---         print("Copy target to " .. install_target)
---     end)
 target("ygopro_ygoenv")
     add_rules("python.library")
     add_files("ygoenv/ygoenv/ygopro/*.cpp")
@@ -37,3 +24,22 @@ target("ygopro_ygoenv")
         os.cp(target:targetfile(), install_target)
         print("Copy target to " .. install_target)
     end)
+
+target("alphazero_mcts")
+    add_rules("python.library")
+    add_files("mcts/mcts/alphazero/*.cpp")
+    add_packages("pybind11")
+    set_languages("c++17")
+    if is_mode("release") then
+        set_policy("build.optimization.lto", true)
+        add_cxxflags("-march=native")
+    end
+    add_includedirs("mcts")
+    after_build(function (target)
+        local install_target = "$(projectdir)/mcts/mcts/alphazero"
+        os.cp(target:targetfile(), install_target)
+        print("Copy target to " .. install_target)
+        os.run("pybind11-stubgen mcts.alphazero.alphazero_mcts -o %s", "$(projectdir)/mcts")
+    end)
@@ -150,14 +150,15 @@ class Encoder(nn.Module):
         elif "fc_emb" in n:
             nn.init.uniform_(m.weight, -scale, scale)

-    def load_embeddings(self, embeddings, freeze=True):
+    def load_embeddings(self, embeddings):
         weight = self.id_embed.weight
         embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
         unknown_embed = embeddings.mean(dim=0, keepdim=True)
         embeddings = torch.cat([unknown_embed, embeddings], dim=0)
         weight.data.copy_(embeddings)
-        if freeze:
-            weight.requires_grad = False
+
+    def freeze_embeddings(self):
+        self.id_embed.weight.requires_grad = False

     def num_transform(self, x):
         return self.num_fc(bytes_to_bin(x, self.bin_points, self.bin_intervals))
@@ -409,8 +410,11 @@ class PPOAgent(nn.Module):
             nn.Linear(c // 2, 1),
         )

-    def load_embeddings(self, embeddings, freeze=True):
-        self.encoder.load_embeddings(embeddings, freeze)
+    def load_embeddings(self, embeddings):
+        self.encoder.load_embeddings(embeddings)
+
+    def freeze_embeddings(self):
+        self.encoder.freeze_embeddings()

     def get_logit(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
...
@@ -1827,15 +1827,18 @@ private:
       } else {
         auto it = spec2index.find(spec);
         if (it == spec2index.end()) {
+          // TODO: find the root cause
           // print spec2index
           fmt::println("Spec2index:");
           for (auto &[k, v] : spec2index) {
             fmt::println("{}: {}", k, v);
           }
-          throw std::runtime_error("Spec not found: " + spec);
-        }
+          // throw std::runtime_error("Spec not found: " + spec);
+          idx = 1;
+        } else {
           idx = it->second;
+        }
       }
       feat(i, 2 * j) = static_cast<uint8_t>(idx >> 8);
       feat(i, 2 * j + 1) = static_cast<uint8_t>(idx & 0xff);
     }
...