Commit d2369b3e authored by biluo.shen's avatar biluo.shen

Add mcts

parent 722dd65a
......@@ -6,7 +6,22 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
`ygoenv` is a high performance game environment for Yu-Gi-Oh! It is initially inspired by [yugioh-ai](https://github.com/melvinzhang/yugioh-ai) and [yugioh-game](https://github.com/tspivey/yugioh-game), and now implemented on top of [envpool](https://github.com/sail-sg/envpool).
## ygoai
`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo. Currently, only a RL-based agent is implemented.
`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo and AlphaZero, with or without human knowledge. Currently, we focus on using reinforcement learning to train the agents.
## TODO
### Documentation
- Add documentation for building and running
### Training
- Eval with old models during training
- MCTS-based training
### Inference
- MCTS-based planning
- Support of play in YGOPro
## Usage
TODO
......
......@@ -15,7 +15,7 @@ Not supported
- `min` > 5 throws an error
- `max` > 5 is truncated to 5
### Unsupported
### related cards
- Fairy Tail - Snow (min=max=7)
- Pot of Prosperity (min=max=6)
......
This diff is collapsed.
from __future__ import annotations
import numpy
__all__ = ['MinMaxStatsList', 'Roots', 'SearchResults', 'batch_backpropagate', 'batch_expand', 'batch_traverse', 'init_module']
class MinMaxStatsList:
    """Stub for the C++ ``MinMaxStatsList`` bound in the ``alphazero_mcts`` module."""
    def __init__(self, arg0: int) -> None:
        """Create the list; ``arg0`` is the number of ``MinMaxStats`` entries to allocate."""
        ...
    def set_delta(self, arg0: float) -> None:
        """Set ``value_delta_max`` to ``arg0`` on every entry of the list."""
        ...
class Roots:
    """Stub for the C++ ``tree::Roots`` class (a batch of search-tree roots)."""
    def __init__(self, arg0: int) -> None:
        """Create the batch; ``arg0`` is forwarded to the C++ ``Roots(int root_num)`` constructor."""
        ...
    def get_distributions(self) -> list[list[int]]:
        """Return one list of ints per root (bound to ``tree::Roots::get_distributions``)."""
        ...
    def get_values(self) -> list[float]:
        """Return one float per root (bound to ``tree::Roots::get_values``)."""
        ...
    def prepare(self, arg0: numpy.ndarray[numpy.float32], arg1: numpy.ndarray[numpy.float32], arg2: numpy.ndarray[numpy.int32], arg3: numpy.ndarray[numpy.int32], arg4: float, arg5: float) -> None:
        """
        Forward the arrays to ``tree::Roots::prepare``.

        Positional arguments, in the order used by the binding:
        ``arg0`` rewards, ``arg1`` logits, ``arg2`` all_legal_actions,
        ``arg3`` n_legal_actions, ``arg4`` exploration_fraction,
        ``arg5`` dirichlet_alpha.
        """
        ...
    @property
    def num(self) -> int:
        """Read-only view of the C++ ``root_num`` field."""
        ...
class SearchResults:
    """Stub for the C++ ``tree::SearchResults`` class."""
    def __init__(self, arg0: int) -> None:
        """Create the object; ``arg0`` is forwarded to the C++ ``SearchResults(int num)`` constructor."""
        ...
    def get_search_len(self) -> list[int]:
        """Return a copy of the C++ ``search_lens`` vector."""
        ...
def batch_backpropagate(arg0: float, arg1: numpy.ndarray[numpy.float32], arg2: MinMaxStatsList, arg3: SearchResults) -> None:
    """
    Call ``tree::batch_backpropagate``.

    Positional arguments: ``arg0`` discount_factor, ``arg1`` values,
    ``arg2`` min_max_stats_lst, ``arg3`` results.
    """
    ...
def batch_expand(arg0: int, arg1: numpy.ndarray[bool], arg2: numpy.ndarray[numpy.float32], arg3: numpy.ndarray[numpy.float32], arg4: numpy.ndarray[numpy.int32], arg5: numpy.ndarray[numpy.int32], arg6: SearchResults) -> None:
    """
    Call ``tree::batch_expand``.

    Positional arguments: ``arg0`` state_index, ``arg1`` game_over,
    ``arg2`` rewards, ``arg3`` logits (documented as a 2D array in the C++
    header), ``arg4`` all_legal_actions, ``arg5`` n_legal_actions,
    ``arg6`` results.
    """
    ...
def batch_traverse(arg0: Roots, arg1: int, arg2: float, arg3: float, arg4: MinMaxStatsList, arg5: SearchResults) -> tuple:
    """
    Call ``tree::batch_traverse`` and return three lists taken from ``arg5``.

    Positional arguments: ``arg0`` roots, ``arg1`` pb_c_base,
    ``arg2`` pb_c_init, ``arg3`` discount_factor, ``arg4`` min_max_stats_lst,
    ``arg5`` results.

    Returns the tuple ``(state_index_in_search_path, state_index_in_batch,
    last_actions)`` read from the results object after the traversal.
    """
    ...
def init_module(seed: int) -> None:
    """
    Call ``tree::init_module`` with ``seed``.

    NOTE(review): presumably seeds the module's random number generator;
    the C++ implementation is not visible here - confirm.
    """
This diff is collapsed.
#ifndef AZ_CNODE_H
#define AZ_CNODE_H
#include <math.h>
#include <vector>
#include <stack>
#include <stdlib.h>
#include <time.h>
#include <cmath>
#include <random>
#include <sys/timeb.h>
#include <ctime>
#include <map>
#include "mcts/core/minimax.h"
#include "mcts/core/array.h"
// Compile-time debug switch; its consumers are not visible in this header.
const int DEBUG_MODE = 0;

// Declarations for the batched tree search used by the `alphazero_mcts`
// Python module. Only the interface lives here; behaviour described below is
// limited to what the declarations themselves show.
namespace tree {
  // Module initialisation taking a seed (presumably for the RNG - confirm in
  // the implementation file).
  void init_module(int seed);

  // Actions are plain integers and are used as keys of Node::children.
  using Action = int;

  // One node of the search tree. Children are owned by value and keyed by the
  // action leading to them.
  class Node {
   public:
    int visit_count, state_index, batch_index, best_action;
    float reward, prior, value_sum;
    std::map<Action, Node> children;

    Node();
    Node(float prior);
    ~Node();

    // Expands this node from the given logits and legal-action arrays.
    void expand(
      int state_index, int batch_index, float reward, const Array &logits, const Array &legal_actions);
    void add_exploration_noise(float exploration_fraction, float dirichlet_alpha);
    float compute_mean_q(int isRoot, float parent_q, float discount_factor);

    int expanded() const;
    float value() const;

    std::vector<int> get_trajectory();
    std::vector<int> get_children_distribution();
    // Returns a pointer to the child for `action`.
    // NOTE(review): behaviour for a missing action is not visible here.
    Node* get_child(int action);
  };

  // A batch of `root_num` root nodes, one per searched state.
  class Roots{
   public:
    int root_num;
    std::vector<Node> roots;

    Roots();
    Roots(int root_num);
    ~Roots();

    // Prepares all roots from batched arrays; see the Python binding for the
    // argument order.
    void prepare(
      const Array &rewards, const Array &logits,
      const Array &all_legal_actions, const Array &n_legal_actions,
      float exploration_fraction, float dirichlet_alpha);
    void clear();
    std::vector<std::vector<int> > get_trajectories();
    std::vector<std::vector<int> > get_distributions();
    std::vector<float> get_values();
  };

  // Scratch state shared between batch_traverse, batch_expand and
  // batch_backpropagate. Holds non-owning Node pointers into the trees.
  class SearchResults{
   public:
    int num;
    std::vector<int> state_index_in_search_path, state_index_in_batch, last_actions, search_lens;
    std::vector<Node*> nodes;
    std::vector<std::vector<Node*> > search_paths;

    SearchResults();
    SearchResults(int num);
    ~SearchResults();
  };

  // Single-tree helpers.
  void update_tree_q(Node* root, MinMaxStats &min_max_stats, float discount_factor);
  void backpropagate(std::vector<Node*> &search_path, MinMaxStats &min_max_stats, float value, float discount_factor);

  // Batched entry points exposed to Python.
  void batch_expand(
    int state_index, const Array &game_over, const Array &rewards, const Array &logits /* 2D array */,
    const Array &all_legal_actions, const Array &n_legal_actions, SearchResults &results);
  void batch_backpropagate(float discount_factor, const Array &values, MinMaxStatsList &min_max_stats_lst, SearchResults &results);

  // Child selection and its scoring function.
  int select_child(Node* root, const MinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q);
  float ucb_score(const Node &child, const MinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor);
  void batch_traverse(Roots &roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList &min_max_stats_lst, SearchResults &results);
}
#endif // AZ_CNODE_H
\ No newline at end of file
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/cast.h>
#include "mcts/core/common.h"
#include "mcts/core/minimax.h"
#include "mcts/core/array.h"
#include "mcts/alphazero/cnode.h"
namespace py = pybind11;
// Python bindings for the batched tree search. NumPy inputs are wrapped as
// `Array` views before being handed to the `tree::` functions.
PYBIND11_MODULE(alphazero_mcts, m) {
  using namespace pybind11::literals;

  // Min/max statistics, one entry per searched root.
  py::class_<MinMaxStatsList>(m, "MinMaxStatsList")
      .def(py::init<int>())
      .def("set_delta", &MinMaxStatsList::set_delta);

  // Scratch state carried between traverse / expand / backpropagate.
  py::class_<tree::SearchResults>(m, "SearchResults")
      .def(py::init<int>())
      .def("get_search_len",
           [](tree::SearchResults &self) { return self.search_lens; });

  // Batch of root nodes.
  py::class_<tree::Roots>(m, "Roots")
      .def(py::init<int>())
      .def_readonly("num", &tree::Roots::root_num)
      .def("prepare",
           [](tree::Roots &self,
              const py::array_t<float> &rewards,
              const py::array_t<float> &logits,
              const py::array_t<int> &all_legal_actions,
              const py::array_t<int> &n_legal_actions,
              float exploration_fraction,
              float dirichlet_alpha) {
             Array reward_view = NumpyToArray(rewards);
             Array logit_view = NumpyToArray(logits);
             Array legal_view = NumpyToArray(all_legal_actions);
             Array legal_count_view = NumpyToArray(n_legal_actions);
             self.prepare(reward_view, logit_view, legal_view, legal_count_view,
                          exploration_fraction, dirichlet_alpha);
           })
      .def("get_distributions", &tree::Roots::get_distributions)
      .def("get_values", &tree::Roots::get_values);

  m.def("batch_expand",
        [](int state_index,
           const py::array_t<bool> &game_over,
           const py::array_t<float> &rewards,
           const py::array_t<float> &logits,
           const py::array_t<int> &all_legal_actions,
           const py::array_t<int> &n_legal_actions,
           tree::SearchResults &results) {
          Array done_view = NumpyToArray(game_over);
          Array reward_view = NumpyToArray(rewards);
          Array logit_view = NumpyToArray(logits);
          Array legal_view = NumpyToArray(all_legal_actions);
          Array legal_count_view = NumpyToArray(n_legal_actions);
          tree::batch_expand(state_index, done_view, reward_view, logit_view,
                             legal_view, legal_count_view, results);
        });

  m.def("batch_backpropagate",
        [](float discount_factor,
           const py::array_t<float> &values,
           MinMaxStatsList &min_max_stats_lst,
           tree::SearchResults &results) {
          Array value_view = NumpyToArray(values);
          tree::batch_backpropagate(discount_factor, value_view,
                                    min_max_stats_lst, results);
        });

  // Runs the traversal, then hands the three index vectors filled in
  // `results` back to Python as a tuple.
  m.def("batch_traverse",
        [](tree::Roots &roots,
           int pb_c_base,
           float pb_c_init,
           float discount_factor,
           MinMaxStatsList &min_max_stats_lst,
           tree::SearchResults &results) {
          tree::batch_traverse(roots, pb_c_base, pb_c_init, discount_factor,
                               min_max_stats_lst, results);
          return py::make_tuple(results.state_index_in_search_path,
                                results.state_index_in_batch,
                                results.last_actions);
        });

  m.def("init_module", &tree::init_module, "", "seed"_a);
}
\ No newline at end of file
#ifndef MCTS_CORE_ARRAY_H_
#define MCTS_CORE_ARRAY_H_
#include <cstddef>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include "mcts/core/spec.h"
/**
 * Untyped n-dimensional array over a raw byte buffer.
 *
 * The buffer is held through a `std::shared_ptr<char>`; whether the Array
 * owns the memory depends on the deleter supplied at construction. Views
 * produced by `operator()`, `operator[]` and `Slice` use a no-op deleter and
 * therefore do NOT keep the underlying buffer alive.
 */
class Array {
 public:
  // Zero-initialised so that a default-constructed Array is a well-defined
  // empty array rather than carrying indeterminate sizes.
  std::size_t size = 0;          // number of elements (product of the shape)
  std::size_t ndim = 0;          // number of axes
  std::size_t element_size = 0;  // size of one element in bytes

 protected:
  std::vector<std::size_t> shape_;
  std::shared_ptr<char> ptr_;

  /**
   * Wrap `ptr` with the given shape; `deleter` is invoked when the last
   * Array sharing this pointer is destroyed.
   */
  template<class Shape, class Deleter>
  Array(char *ptr, Shape &&shape, std::size_t element_size,// NOLINT
        Deleter &&deleter)
      : size(Prod(shape.data(), shape.size())),
        ndim(shape.size()),
        element_size(element_size),
        shape_(std::forward<Shape>(shape)),
        ptr_(ptr, std::forward<Deleter>(deleter)) {}

  /**
   * Share an existing buffer (and its ownership) under a new shape.
   */
  template<class Shape>
  Array(std::shared_ptr<char> ptr, Shape &&shape, std::size_t element_size)
      : size(Prod(shape.data(), shape.size())),
        ndim(shape.size()),
        element_size(element_size),
        shape_(std::forward<Shape>(shape)),
        ptr_(std::move(ptr)) {}

 public:
  Array() = default;

  /**
   * Construct an `Array` of shape defined by `spec`, with `data` as pointer
   * to its raw memory. `deleter` is called with `data` once the last Array
   * sharing it goes away.
   */
  template<class Deleter>
  Array(const ShapeSpec &spec, char *data, Deleter &&deleter)// NOLINT
      : Array(data, spec.Shape(), spec.element_size,
              std::forward<Deleter>(deleter)) {}

  /**
   * Construct an `Array` of shape defined by `spec` over `data` with an
   * empty deleter, which means the Array does not own the memory.
   */
  Array(const ShapeSpec &spec, char *data)
      : Array(data, spec.Shape(), spec.element_size, [](char * /*unused*/) {}) {}

  /**
   * Construct an `Array` of shape defined by `spec`. This constructor
   * allocates (zero-initialised) and owns the memory.
   */
  explicit Array(const ShapeSpec &spec)
      : Array(spec, nullptr, [](char * /*unused*/) {}) {
    ptr_.reset(new char[size * element_size](),
               [](const char *p) { delete[] p; });
  }

  /**
   * Take multidimensional index into the Array. The result is a non-owning
   * view whose shape is the remaining (un-indexed) axes.
   */
  template<typename... Index>
  inline Array operator()(Index... index) const {
    constexpr std::size_t num_index = sizeof...(Index);
    std::size_t offset = 0;
    std::size_t i = 0;
    // The fold consumes one axis per supplied index; the loop then scales
    // the offset by every remaining axis to obtain an element offset.
    for (((offset = offset * shape_[i++] + index), ...); i < ndim; ++i) {
      offset *= shape_[i];
    }
    return Array(
        ptr_.get() + offset * element_size,
        std::vector<std::size_t>(shape_.begin() + num_index, shape_.end()),
        element_size, [](char * /*unused*/) {});
  }

  /**
   * Index operator of array, takes the index along the first axis.
   */
  inline Array operator[](int index) const { return this->operator()(index); }

  /**
   * Take a slice `[start, end)` at the first axis of the Array. The result
   * is a non-owning view. The Array must have at least one axis.
   */
  [[nodiscard]] Array Slice(std::size_t start, std::size_t end) const {
    std::vector<std::size_t> new_shape(shape_);
    new_shape[0] = end - start;
    std::size_t offset = 0;
    if (shape_[0] > 0) {
      // size / shape_[0] is the number of elements in one first-axis row.
      offset = start * size / shape_[0];
    }
    return {ptr_.get() + offset * element_size, std::move(new_shape),
            element_size, [](char * /*unused*/) {}};
  }

  /**
   * Copy the content of another Array to this Array. Copies
   * `size * element_size` bytes of THIS array; `value` must be at least
   * that large.
   */
  void Assign(const Array &value) const {
    std::memcpy(ptr_.get(), value.ptr_.get(), size * element_size);
  }

  /**
   * Return a deep copy of this array that owns its own memory.
   */
  Array Clone() const {
    std::vector<int> shape;
    shape.reserve(ndim);
    for (std::size_t i = 0; i < ndim; ++i) {
      shape.push_back(static_cast<int>(shape_[i]));
    }
    auto spec = ShapeSpec(static_cast<int>(element_size), shape);
    Array ret(spec);
    ret.Assign(*this);
    return ret;
  }

  /**
   * Assign to this Array a scalar value. This Array needs to have a scalar
   * shape.
   */
  template<typename T>
  void operator=(const T &value) const {
    *reinterpret_cast<T *>(ptr_.get()) = value;
  }

  /**
   * Fills this array with a scalar value of type T.
   */
  template<typename T>
  void Fill(const T &value) const {
    auto *data = reinterpret_cast<T *>(ptr_.get());
    std::fill(data, data + size, value);
  }

  /**
   * Copy `sz` elements of type `T` starting at `buff` to the beginning of
   * the memory of this Array.
   */
  template<typename T>
  void Assign(const T *buff, std::size_t sz) const {
    std::memcpy(ptr_.get(), buff, sz * sizeof(T));
  }

  /**
   * Copy `sz` elements of type `T` starting at `buff` into this Array,
   * beginning `offset` elements (of `element_size` bytes each) from its
   * start.
   */
  template<typename T>
  void Assign(const T *buff, std::size_t sz, ptrdiff_t offset) const {
    offset = offset * (element_size / sizeof(char));
    std::memcpy(ptr_.get() + offset, buff, sz * sizeof(T));
  }

  /**
   * Cast the Array to a scalar value of type `T`. This Array needs to have a
   * scalar shape.
   */
  template<typename T>
  operator const T &() const {// NOLINT
    return *reinterpret_cast<T *>(ptr_.get());
  }

  /**
   * Cast the Array to a scalar value of type `T`. This Array needs to have a
   * scalar shape.
   */
  template<typename T>
  operator T &() {// NOLINT
    return *reinterpret_cast<T *>(ptr_.get());
  }

  /**
   * Size of axis `dim`.
   */
  [[nodiscard]] inline std::size_t Shape(std::size_t dim) const {
    return shape_[dim];
  }

  /**
   * Shape
   */
  [[nodiscard]] inline const std::vector<std::size_t> &Shape() const {
    return shape_;
  }

  /**
   * Pointer to the raw memory.
   */
  [[nodiscard]] inline void *Data() const { return ptr_.get(); }

  /**
   * Truncate the Array. Return a new Array that shares the same memory
   * location (and ownership) but with the first axis truncated to `end`.
   */
  [[nodiscard]] Array Truncate(std::size_t end) const {
    auto new_shape = std::vector<std::size_t>(shape_);
    new_shape[0] = end;
    Array ret(ptr_, std::move(new_shape), element_size);
    return ret;
  }

  /** Set every byte of the buffer to zero. */
  void Zero() const { std::memset(ptr_.get(), 0, size * element_size); }

  /** Shared handle to the underlying buffer. */
  [[nodiscard]] std::shared_ptr<char> SharedPtr() const { return ptr_; }
};
template<typename Dtype>
class TArray : public Array {
public:
explicit TArray(const Spec<Dtype> &spec) : Array(spec) {}
};
#endif // MCTS_CORE_ARRAY_H_
\ No newline at end of file
#ifndef MCTS_CORE_COMMON_H_
#define MCTS_CORE_COMMON_H_
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "mcts/core/array.h"
namespace py = pybind11;
/**
 * Expose an `Array` to Python as a NumPy array without copying.
 *
 * A heap-allocated copy of the Array's shared pointer is attached to the
 * NumPy array as its base capsule, so the buffer stays alive until Python
 * releases the array.
 */
template<typename dtype>
py::array_t<dtype> ArrayToNumpy(const Array &a) {
  using Owner = std::shared_ptr<char>;
  auto *keep_alive = new Owner(a.SharedPtr());
  py::capsule release_hook(keep_alive, [](void *raw) {
    delete static_cast<Owner *>(raw);
  });
  auto *first_element = static_cast<dtype *>(a.Data());
  return py::array(a.Shape(), first_element, release_hook);
}
/**
 * Wrap a NumPy array as an `Array` view.
 *
 * The input is first coerced to a C-contiguous array of `dtype`. When the
 * input already satisfies that, the coerced handle refers to the same NumPy
 * object; otherwise NumPy creates a fresh copy. Either way the coerced
 * handle is stored inside the Array's deleter, so the memory the Array
 * points at stays alive for as long as the Array (or any Array sharing its
 * pointer) does. Previously a freshly made copy was destroyed on return,
 * leaving the Array dangling.
 *
 * NOTE: the last owning Array must be destroyed while the GIL is held, since
 * that drops a Python reference. Non-owning views made with `operator()` or
 * `Slice` do not extend the lifetime.
 *
 * @throws whatever pybind11 raises if the conversion fails or the array is
 *         not writeable (`mutable_data`).
 */
template<typename dtype>
Array NumpyToArray(const py::array_t<dtype> &arr) {
  using ArrayT = py::array_t<dtype, py::array::c_style | py::array::forcecast>;
  ArrayT arr_t(arr);
  ShapeSpec spec(arr_t.itemsize(),
                 std::vector<int>(arr_t.shape(), arr_t.shape() + arr_t.ndim()));
  // Fetch the pointer before moving the handle into the deleter.
  char *data = reinterpret_cast<char *>(arr_t.mutable_data());
  return Array(spec, data, [owner = std::move(arr_t)](char * /*unused*/) {});
}
/**
 * Convert a contiguous sequence (anything with `value_type`, `size()` and
 * `data()`, e.g. `std::vector<T>`) into a 1-D NumPy array.
 *
 * The array is created without a base object, so pybind11 copies the
 * elements; the returned array never aliases `seq`. Because of that copy
 * there is no need to move the sequence onto the heap first.
 *
 * Accepts both lvalues and rvalues: the element type is taken from the
 * reference-stripped sequence type, so `as_pyarray(vec)` compiles as well as
 * `as_pyarray(std::move(vec))`.
 */
template <typename Sequence>
inline py::array_t<typename std::remove_reference_t<Sequence>::value_type> as_pyarray(Sequence &&seq) {
  using T = typename std::remove_reference_t<Sequence>::value_type;
  return py::array_t<T>({seq.size()}, {sizeof(T)}, seq.data());
}
#endif // MCTS_CORE_COMMON_H_
\ No newline at end of file
#ifndef MCTS_CORE_MINIMAX_H_
#define MCTS_CORE_MINIMAX_H_
#include <iostream>
#include <vector>
// Sentinel bounds used as the "nothing observed yet" state of MinMaxStats.
const float FLOAT_MAX = 1000000.0f;
const float FLOAT_MIN = -FLOAT_MAX;

/**
 * Running minimum / maximum of observed values, used to rescale a value
 * relative to the observed range.
 */
class MinMaxStats {
 public:
  float maximum, minimum, value_delta_max;

  // Starts with an inverted (empty) range and no lower bound on the spread.
  MinMaxStats()
      : maximum(FLOAT_MIN), minimum(FLOAT_MAX), value_delta_max(0.0f) {}
  ~MinMaxStats() {}

  // Lower bound on the divisor used by normalize().
  void set_delta(float value_delta_max) {
    this->value_delta_max = value_delta_max;
  }

  // Widen the tracked range so that it contains `value`.
  void update(float value) {
    if (this->maximum < value) {
      this->maximum = value;
    }
    if (this->minimum > value) {
      this->minimum = value;
    }
  }

  // Forget the observed range; value_delta_max is kept.
  void clear() {
    this->minimum = FLOAT_MAX;
    this->maximum = FLOAT_MIN;
  }

  // Returns (value - minimum) divided by the larger of the observed spread
  // and value_delta_max. When the spread is not positive, `value` is
  // returned unchanged.
  float normalize(float value) const {
    const float spread = this->maximum - this->minimum;
    if (!(spread > 0)) {
      return value;
    }
    const float divisor =
        (spread < this->value_delta_max) ? this->value_delta_max : spread;
    return (value - this->minimum) / divisor;
  }
};
class MinMaxStatsList {
public:
int num;
std::vector<MinMaxStats> stats_lst;
MinMaxStatsList() {
this->num = 0;
}
MinMaxStatsList(int num) {
this->num = num;
for(int i = 0; i < num; ++i){
this->stats_lst.push_back(MinMaxStats());
}
}
~MinMaxStatsList() {}
void set_delta(float value_delta_max) {
for(int i = 0; i < this->num; ++i){
this->stats_lst[i].set_delta(value_delta_max);
}
}
};
#endif // MCTS_CORE_MINIMAX_H_
\ No newline at end of file
#ifndef MCTS_CORE_SPEC_H_
#define MCTS_CORE_SPEC_H_
#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
// Product of the first `ndim` entries of `shape`; 1 when `ndim` is 0.
static std::size_t Prod(const std::size_t *shape, std::size_t ndim) {
  std::size_t product = 1;
  for (std::size_t axis = 0; axis < ndim; ++axis) {
    product *= shape[axis];
  }
  return product;
}
/**
 * Describes the layout of an array: the size of one element in bytes and the
 * extent of each axis.
 */
class ShapeSpec {
 public:
  // Zero-initialised so a default-constructed spec has a determinate value
  // (previously `element_size` was left uninitialised).
  int element_size = 0;
  std::vector<int> shape;

  ShapeSpec() = default;

  ShapeSpec(int element_size, std::vector<int> shape_vec)
      : element_size(element_size), shape(std::move(shape_vec)) {}

  /**
   * Return a spec with `batch_size` prepended as a new leading axis.
   */
  [[nodiscard]] ShapeSpec Batch(int batch_size) const {
    std::vector<int> new_shape;
    new_shape.reserve(shape.size() + 1);
    new_shape.push_back(batch_size);
    new_shape.insert(new_shape.end(), shape.begin(), shape.end());
    return {element_size, std::move(new_shape)};
  }

  /**
   * Return the shape converted element-wise to `std::size_t`.
   */
  [[nodiscard]] std::vector<std::size_t> Shape() const {
    return std::vector<std::size_t>(shape.begin(), shape.end());
  }
};
/**
 * A ShapeSpec tagged with an element type `D`, optionally carrying value
 * bounds: either one (low, high) pair or per-element vectors.
 */
template<typename D>
class Spec : public ShapeSpec {
 public:
  using dtype = D;// NOLINT

  // Defaults to the numeric_limits min/max of the element type.
  std::tuple<dtype, dtype> bounds = {std::numeric_limits<dtype>::min(),
                                     std::numeric_limits<dtype>::max()};
  std::tuple<std::vector<dtype>, std::vector<dtype>> elementwise_bounds;

  /* shape only */
  explicit Spec(std::vector<int> &&dims)
      : ShapeSpec(sizeof(dtype), std::move(dims)) {}

  explicit Spec(const std::vector<int> &dims)
      : ShapeSpec(sizeof(dtype), dims) {}

  /* shape plus constant bounds */
  Spec(std::vector<int> &&dims, std::tuple<dtype, dtype> &&limits)
      : ShapeSpec(sizeof(dtype), std::move(dims)), bounds(std::move(limits)) {}

  Spec(const std::vector<int> &dims, const std::tuple<dtype, dtype> &limits)
      : ShapeSpec(sizeof(dtype), dims), bounds(limits) {}

  /* shape plus elementwise bounds */
  Spec(std::vector<int> &&dims,
       std::tuple<std::vector<dtype>, std::vector<dtype>> &&per_element)
      : ShapeSpec(sizeof(dtype), std::move(dims)),
        elementwise_bounds(std::move(per_element)) {}

  Spec(const std::vector<int> &dims,
       const std::tuple<std::vector<dtype>, std::vector<dtype>> &per_element)
      : ShapeSpec(sizeof(dtype), dims),
        elementwise_bounds(per_element) {}

  /**
   * Return a spec with `batch_size` prepended as a new leading axis. The
   * result is built from the shape alone, so it has default bounds.
   */
  [[nodiscard]] Spec Batch(int batch_size) const {
    std::vector<int> batched(1, batch_size);
    batched.insert(batched.end(), shape.begin(), shape.end());
    return Spec(std::move(batched));
  }
};
// Forward declaration; the definition lives in array.h.
template<typename dtype>
class TArray;

// Owning handle to a typed array.
template<typename dtype>
using Container = std::unique_ptr<TArray<dtype>>;

/**
 * Specialisation for arrays whose elements are themselves containers: the
 * outer shape describes the container slots, `inner_spec` the arrays held
 * in them.
 */
template<typename D>
class Spec<Container<D>> : public ShapeSpec {
 public:
  using dtype = Container<D>;// NOLINT
  Spec<D> inner_spec;

  explicit Spec(const std::vector<int> &outer_shape, const Spec<D> &element_spec)
      : ShapeSpec(sizeof(Container<D>), outer_shape),
        inner_spec(element_spec) {}

  explicit Spec(std::vector<int> &&outer_shape, Spec<D> &&element_spec)
      : ShapeSpec(sizeof(Container<D>), std::move(outer_shape)),
        inner_spec(std::move(element_spec)) {}
};
#endif // MCTS_CORE_SPEC_H_
\ No newline at end of file
from typing import Tuple, Sequence
import numpy as np
class State:
    """A batch of state keys kept in a flat NumPy array plus a logical batch shape.

    ``store`` holds the keys; ``batch_shape`` (1-D or 2-D) records how the
    flat array is meant to be interpreted and ``ndim`` is its length.
    """

    def __init__(self, batch_shape: Tuple[int, ...], store):
        """Wrap ``store`` (must be an ``np.ndarray``) under ``batch_shape``."""
        assert isinstance(store, np.ndarray)
        self.store = store
        self.batch_shape = batch_shape
        self.ndim = len(batch_shape)

    def get_state_keys(self):
        """Return the underlying key array (not a copy)."""
        return self.store

    @classmethod
    def from_item(cls, item):
        """Build a single-element ``State`` of batch shape ``(1,)`` from one key (stored as int32)."""
        return cls((1,), np.array([item], dtype=np.int32))

    def reshape(self, batch_shape: Tuple[int, ...]):
        """Change the logical batch shape in place and return ``self``.

        Only the metadata changes; ``store`` is left untouched and the new
        shape is not checked against its length.
        """
        self.batch_shape = batch_shape
        self.ndim = len(batch_shape)
        return self

    def item(self):
        """Return the single key of a ``(1,)``-shaped state."""
        assert self.ndim == 1 and self.batch_shape[0] == 1
        return self.store[0]

    @classmethod
    def from_state_list(cls, state_list, batch_shape=None):
        """Concatenate a list (or list of lists) of ``State`` objects into one ``State``.

        The batch shape is inferred from the nesting of ``state_list`` unless
        ``batch_shape`` is given, in which case it must be 2-D and
        ``state_list`` must be a flat list.
        """
        if isinstance(state_list[0], State):
            batch_shape_ = (len(state_list),)
        elif isinstance(state_list[0], Sequence):
            batch_shape_ = (len(state_list), len(state_list[0]))
            assert isinstance(state_list[0][0], State)
        else:
            raise ValueError("Invalid dim of states")
        if batch_shape is None:
            batch_shape = batch_shape_
        else:
            assert len(batch_shape) == 2 and len(batch_shape_) == 1
        # NOTE(review): when batch_shape is passed explicitly, state_list is a
        # flat list, so the inner loop iterates over each State via
        # __getitem__ - confirm this is the intended flattening.
        if len(batch_shape) == 2:
            states = [s for ss in state_list for s in ss]
        else:
            states = state_list
        state_keys = np.concatenate([s.store for s in states], dtype=np.int32, axis=0)
        return State(batch_shape, state_keys)

    def _get_by_index(self, batch_shape, indices):
        """Return a new ``State`` holding ``store[indices]`` under ``batch_shape``."""
        state_keys = self.store[indices]
        return State(batch_shape, state_keys)

    def __getitem__(self, item):
        """Alias for :meth:`get`."""
        return self.get(item)

    def get(self, i):
        """Return the element at ``i`` as a ``(1,)``-shaped ``State``.

        For a 2-D batch, ``i`` must be a ``(row, col)`` tuple, which is
        converted to the flat index ``row * batch_shape[1] + col``.
        """
        if self.ndim == 2:
            assert isinstance(i, tuple)
            i = i[0] * self.batch_shape[1] + i[1]
        i = np.array([i], dtype=np.int32)
        return self._get_by_index((1,), i)

    def __len__(self) -> int:
        """Number of entries in ``store``."""
        return len(self.store)

    def __repr__(self) -> str:
        return f'State(batch_shape={self.batch_shape}, ndim={self.ndim})'

    def __str__(self) -> str:
        return f'State(batch_shape={self.batch_shape}, ndim={self.ndim})'
from setuptools import setup, find_packages
__version__ = "0.0.1"

# Packages required at install time; pybind11-stubgen is used by the build to
# generate the .pyi stub for the compiled extension.
INSTALL_REQUIRES = [
    "setuptools",
    "wheel",
    "pybind11-stubgen",
    "numpy",
]

setup(
    name="mcts",
    version=__version__,
    # `include` takes an iterable of glob patterns. A bare string is unpacked
    # character by character, and its lone "*" then matches every package in
    # the tree, so the pattern must be wrapped in a list.
    packages=find_packages(include=["mcts*"]),
    long_description="",
    install_requires=INSTALL_REQUIRES,
    python_requires=">=3.7",
    include_package_data=True,
)
\ No newline at end of file
......@@ -39,7 +39,7 @@ class Args:
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
......@@ -47,19 +47,21 @@ class Args:
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = "embeddings_en.npy"
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 1000000000
"""total timesteps of the experiments"""
......@@ -236,8 +238,15 @@ def run(local_rank, world_size):
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
if args.embedding_file:
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
......@@ -431,12 +440,10 @@ def run(local_rank, world_size):
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
# TODO: optimize this
done_used1 = torch.zeros_like(next_done, dtype=torch.bool)
done_used2 = torch.zeros_like(next_done, dtype=torch.bool)
reward1 = 0
reward2 = 0
lastgaelam1 = 0
lastgaelam2 = 0
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(args.num_steps)):
# if learns[t]:
# if dones[t+1]:
......@@ -586,6 +593,7 @@ def run(local_rank, world_size):
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
episode_lengths = []
episode_rewards = []
......@@ -627,6 +635,8 @@ def run(local_rank, world_size):
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
# Eval with old model
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
......
......@@ -8,19 +8,6 @@ add_requires(
"sqlitecpp 3.2.1")
-- target("dummy_ygopro")
-- add_rules("python.library")
-- add_files("ygoenv/ygoenv/dummy/*.cpp")
-- add_packages("pybind11", "fmt", "glog", "concurrentqueue")
-- set_languages("c++17")
-- add_includedirs("ygoenv")
-- after_build(function (target)
-- local install_target = "$(projectdir)/ygoenv/ygoenv/dummy"
-- os.cp(target:targetfile(), install_target)
-- print("Copy target to " .. install_target)
-- end)
target("ygopro_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/ygopro/*.cpp")
......@@ -37,3 +24,22 @@ target("ygopro_ygoenv")
os.cp(target:targetfile(), install_target)
print("Copy target to " .. install_target)
end)
-- Python extension module providing the AlphaZero-style MCTS bindings.
target("alphazero_mcts")
    add_rules("python.library")
    add_files("mcts/mcts/alphazero/*.cpp")
    add_packages("pybind11")
    set_languages("c++17")
    if is_mode("release") then
        set_policy("build.optimization.lto", true)
        -- NOTE: -march=native ties the binary to the build machine's CPU.
        add_cxxflags("-march=native")
    end
    add_includedirs("mcts")

    -- After building: copy the library next to the Python package sources,
    -- then run pybind11-stubgen with the project's mcts directory as output.
    after_build(function (target)
        local install_target = "$(projectdir)/mcts/mcts/alphazero"
        os.cp(target:targetfile(), install_target)
        print("Copy target to " .. install_target)
        os.run("pybind11-stubgen mcts.alphazero.alphazero_mcts -o %s", "$(projectdir)/mcts")
    end)
......@@ -150,14 +150,15 @@ class Encoder(nn.Module):
elif "fc_emb" in n:
nn.init.uniform_(m.weight, -scale, scale)
def load_embeddings(self, embeddings, freeze=True):
def load_embeddings(self, embeddings):
weight = self.id_embed.weight
embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
unknown_embed = embeddings.mean(dim=0, keepdim=True)
embeddings = torch.cat([unknown_embed, embeddings], dim=0)
weight.data.copy_(embeddings)
if freeze:
weight.requires_grad = False
def freeze_embeddings(self):
self.id_embed.weight.requires_grad = False
def num_transform(self, x):
return self.num_fc(bytes_to_bin(x, self.bin_points, self.bin_intervals))
......@@ -409,8 +410,11 @@ class PPOAgent(nn.Module):
nn.Linear(c // 2, 1),
)
def load_embeddings(self, embeddings, freeze=True):
self.encoder.load_embeddings(embeddings, freeze)
def load_embeddings(self, embeddings):
self.encoder.load_embeddings(embeddings)
def freeze_embeddings(self):
self.encoder.freeze_embeddings()
def get_logit(self, x):
f_actions, f_state, mask, valid = self.encoder(x)
......
......@@ -1827,14 +1827,17 @@ private:
} else {
auto it = spec2index.find(spec);
if (it == spec2index.end()) {
// TODO: find the root cause
// print spec2index
fmt::println("Spec2index:");
for (auto &[k, v] : spec2index) {
fmt::println("{}: {}", k, v);
}
throw std::runtime_error("Spec not found: " + spec);
// throw std::runtime_error("Spec not found: " + spec);
idx = 1;
} else {
idx = it->second;
}
idx = it->second;
}
feat(i, 2 * j) = static_cast<uint8_t>(idx >> 8);
feat(i, 2 * j + 1) = static_cast<uint8_t>(idx & 0xff);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment