Commit 16fca7eb authored by biluo.shen's avatar biluo.shen

Add edopro

parent d2369b3e
......@@ -531,8 +531,11 @@ namespace tree
float max_score = FLOAT_MIN;
const float epsilon = 0.000001;
std::vector<Action> max_index_lst;
int action = 0;
for (const auto &[a, child] : root->children)
{
action = a;
float temp_score = ucb_score(child, min_max_stats, mean_q, root->visit_count, pb_c_base, pb_c_init, discount_factor);
if (max_score < temp_score)
{
......@@ -547,7 +550,6 @@ namespace tree
}
}
int action = 0;
if (max_index_lst.size() > 0)
{
std::uniform_int_distribution<int> dist(0, max_index_lst.size() - 1);
......
package("edopro-core")
set_homepage("https://github.com/edo9300/ygopro-core")
set_urls("https://github.com/edo9300/ygopro-core.git")
add_deps("lua")
on_install("linux", function (package)
io.writefile("xmake.lua", [[
add_rules("mode.debug", "mode.release")
add_requires("lua")
target("edopro-core")
set_kind("static")
set_languages("c++17")
add_files("*.cpp")
add_headerfiles("*.h")
add_headerfiles("RNG/*.hpp")
add_packages("lua")
]])
local check_and_insert = function(file, line, insert)
local lines = table.to_array(io.lines(file))
if lines[line] ~= insert then
table.insert(lines, line, insert)
io.writefile(file, table.concat(lines, "\n"))
end
end
local configs = {}
if package:config("shared") then
configs.kind = "shared"
end
import("package.tools.xmake").install(package)
os.cp("*.h", package:installdir("include", "edopro-core"))
os.cp("RNG", package:installdir("include", "edopro-core"))
end)
package_end()
\ No newline at end of file
......@@ -63,7 +63,7 @@ class Args:
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 1000000000
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
......
......@@ -3,7 +3,7 @@ add_rules("mode.debug", "mode.release")
add_repositories("my-repo repo")
add_requires(
"ygopro-core", "pybind11 2.10.*", "fmt 10.2.*", "glog 0.6.0",
"ygopro-core", "edopro-core", "pybind11 2.10.*", "fmt 10.2.*", "glog 0.6.0",
"sqlite3 3.43.0+200", "concurrentqueue 1.0.4", "unordered_dense 4.4.*",
"sqlitecpp 3.2.1")
......@@ -26,6 +26,24 @@ target("ygopro_ygoenv")
end)
target("edopro_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/edopro/*.cpp")
add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "edopro-core")
set_languages("c++17")
if is_mode("release") then
set_policy("build.optimization.lto", true)
add_cxxflags("-march=native")
end
add_includedirs("ygoenv")
after_build(function (target)
local install_target = "$(projectdir)/ygoenv/ygoenv/edopro"
os.cp(target:targetfile(), install_target)
print("Copy target to " .. install_target)
end)
target("alphazero_mcts")
add_rules("python.library")
add_files("mcts/mcts/alphazero/*.cpp")
......
from ygoenv.python.api import py_env
from .edopro_ygoenv import (
_EDOProEnvPool,
_EDOProEnvSpec,
init_module,
)
(
EDOProEnvSpec,
EDOProDMEnvPool,
EDOProGymEnvPool,
EDOProGymnasiumEnvPool,
) = py_env(_EDOProEnvSpec, _EDOProEnvPool)
__all__ = [
"EDOProEnvSpec",
"EDOProDMEnvPool",
"EDOProGymEnvPool",
"EDOProGymnasiumEnvPool",
]
#include "ygoenv/edopro/edopro.h"
#include "ygoenv/core/py_envpool.h"
using EDOProEnvSpec = PyEnvSpec<edopro::EDOProEnvSpec>;
using EDOProEnvPool = PyEnvPool<edopro::EDOProEnvPool>;
PYBIND11_MODULE(edopro_ygoenv, m) {
REGISTER(m, EDOProEnvSpec, EDOProEnvPool)
m.def("init_module", &edopro::init_module);
}
This source diff could not be displayed because it is too large. You can view the blob instead.
from ygoenv.registration import register
register(
task_id="EDOPro-v0",
import_path="ygoenv.edopro",
spec_cls="EDOProEnvSpec",
dm_cls="EDOProDMEnvPool",
gym_cls="EDOProGymEnvPool",
gymnasium_cls="EDOProGymnasiumEnvPool",
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment