Commit 389fcd9c authored by sbl1996@126.com

fixing set_obs_cards seg fault

parent 46b4b5ae
SCRIPTS_REPO := "https://github.com/Fluorohydride/ygopro-scripts.git"
SCRIPTS_REPO := "https://github.com/mycard/ygopro-scripts.git"
SCRIPTS_DIR := "../ygopro-scripts"
DATABASE_REPO := "https://github.com/mycard/ygopro-database/raw/master/locales"
LOCALES := en zh
......
......@@ -11,7 +11,7 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
## Building
### prerequisites
### Prerequisites
- gcc 10+ or clang 11+
- [xmake](https://xmake.io/#/getting_started)
- PyTorch 2.0 or later with CUDA support
......
......@@ -32,7 +32,7 @@
35059553
24224830
24224830
84211599
97268402
73628505
40155014
40155014
......@@ -58,3 +58,4 @@
02857636
!side
27204312
92907249
......@@ -28,7 +28,7 @@
51697825
51697825
51697825
84211599
98645731
28126717
55521751
24224830
......
......@@ -19,7 +19,7 @@
80433039
17827173
24508238
84211599
98645731
75500286
98645731
49238328
......
......@@ -23,7 +23,7 @@
97268402
2511
2511
84211599
97268402
24224830
24224830
33407125
......
......@@ -24,7 +24,7 @@
1984618
1984618
35261759
84211599
49238328
49238328
84797028
84797028
......
......@@ -24,7 +24,7 @@
24508238
2295440
51405049
84211599
27204311
89023486
24224830
24224830
......@@ -58,3 +58,4 @@
94259633
!side
52340445
27204312
......@@ -840,3 +840,4 @@
35809262
92731385
74018812
92907249
......@@ -5,7 +5,6 @@ from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import tyro
......@@ -22,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_self
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.eval import evaluate
......@@ -79,11 +78,6 @@ class Args:
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
......@@ -265,28 +259,21 @@ def main():
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
logits_t, value_t, valid = agent_t(next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
return logits, value
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
agent = torch.compile(agent, mode=args.compile)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
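The hunk above swaps the second `torch.compile` call for `torch.jit.trace` on the inference-only model. A minimal standalone sketch of that trace-then-optimize pattern (the model and input shapes below are illustrative assumptions, not the project's actual `Agent`):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the inference-only policy network.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 8)).eval()
example_input = torch.zeros(16, 64)  # assumed batch of observations

with torch.no_grad():
    traced = torch.jit.trace(model, (example_input,), check_trace=False)
    traced = torch.jit.optimize_for_inference(traced)  # freeze + fuse for eval-only use

with torch.no_grad():
    logits = traced(example_input)  # reuse the traced module for rollouts/eval
```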
......@@ -298,7 +285,6 @@ def main():
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
......@@ -315,7 +301,7 @@ def main():
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value = 0
next_value1 = next_value2 = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
......@@ -324,8 +310,6 @@ def main():
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
agent.eval()
model_time = 0
env_time = 0
collect_start = time.time()
......@@ -339,7 +323,7 @@ def main():
learns[step] = learn
_start = time.time()
logits, value = predict_step(agent, traced_model_t, next_obs, learn)
logits, value = predict_step(traced_model, next_obs)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
......@@ -352,7 +336,8 @@ def main():
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value = torch.where(learn, value, next_value) * next_nonterminal
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
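In the lines above, `next_value1` and `next_value2` keep the most recent value estimate for each of the two self-play roles, with the `learn` mask selecting which slot the freshly predicted `value` belongs to. A small numeric sketch of that masking (tensor values are made up):

```python
import torch

learn = torch.tensor([True, False, True])         # envs where the learning player acts
value = torch.tensor([0.3, 0.5, -0.2])            # freshly predicted values this step
next_value1 = torch.tensor([0.0, 0.1, 0.0])       # last value seen for player 1
next_value2 = torch.tensor([0.2, 0.0, 0.4])       # last value seen for player 2
next_nonterminal = torch.tensor([1.0, 1.0, 0.0])  # zero out finished episodes

next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
print(next_value1)  # tensor([0.3000, 0.1000, -0.0000])
print(next_value2)  # tensor([0.2000, 0.5000, 0.0000])
```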
......@@ -393,16 +378,14 @@ def main():
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = agent(next_obs)[1].reshape(-1)
value_t = traced_model_t(next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
advantages = bootstrap_value_self(
values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
value = traced_model(next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
returns = advantages + values
bootstrap_time = time.time() - _start
agent.train()
_start = time.time()
# flatten the batch
b_obs = {
......@@ -475,31 +458,11 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
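The removed block above synchronized a rank-0 "update the frozen opponent" decision across workers with an all-reduce. For reference, a minimal standalone sketch of that pattern (the function name and setup are illustrative, not from the repo):

```python
import torch
import torch.distributed as dist

def broadcast_decision(decision: bool, device: torch.device) -> bool:
    # Encode the local decision as an int64 tensor so it can be all-reduced.
    flag = torch.tensor(int(decision), dtype=torch.int64, device=device)
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # non-zero if any rank voted yes
    return flag.item() > 0
```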
if iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
agent.eval()
eval_return = evaluate(
eval_envs, agent, local_eval_episodes, device, args.fp16_eval)
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
......
import math
import torch
import torch.nn as nn
......@@ -17,6 +19,29 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0)
return points, intervals
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Arguments:
x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
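A quick usage sketch for the `PositionalEncoding` module added above, following its docstring's sequence-first layout (sizes are arbitrary):

```python
import torch

pe = PositionalEncoding(d_model=128, dropout=0.0, max_len=64)
x = torch.zeros(16, 4, 128)   # [seq_len, batch_size, embedding_dim]
y = pe(x)                     # same shape; sine/cosine positions added per step
assert y.shape == (16, 4, 128)
```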
class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
......@@ -122,7 +147,6 @@ class Encoder(nn.Module):
nn.Linear(c, c),
)
self.h_id_fc_emb = linear(1024, c)
self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
......@@ -134,6 +158,7 @@ class Encoder(nn.Module):
for i in range(num_action_layers)
])
self.action_history_pe = PositionalEncoding(c, dropout=0.0)
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
......@@ -322,6 +347,7 @@ class Encoder(nn.Module):
x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.action_history_pe(f_h_actions)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
......
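The Encoder hunk above inserts the positional encoding before the action-history features are used as cross-attention memory. A self-contained sketch of that decoder-layer usage (channel count, head count, and sequence lengths are assumptions):

```python
import torch
import torch.nn as nn

c, num_heads = 128, 8
layer = nn.TransformerDecoderLayer(c, num_heads, c * 4, dropout=0.0,
                                   batch_first=True, norm_first=True)

f_actions = torch.zeros(2, 16, c)    # current action features: [batch, n_actions, c]
f_h_actions = torch.zeros(2, 32, c)  # encoded history actions: [batch, n_history, c]
out = layer(f_actions, f_h_actions)  # tgt attends to memory (the history features)
assert out.shape == f_actions.shape
```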
......@@ -108,7 +108,6 @@ def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done,
def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
# TODO: drop epsilon steps for estimated nextvalues
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
# TODO: optimize this
......
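For context on what `bootstrap_value_selfplay` generalizes: below is a textbook single-stream GAE bootstrap. The selfplay variant in this file additionally interleaves two value streams (`nextvalues1`/`nextvalues2`) selected by the `learns` mask, which is not shown here.

```python
import torch

def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
    # Standard GAE(lambda) computed backwards over the rollout; single-agent case only.
    num_steps = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_done, dtype=rewards.dtype)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done.float()
            next_values_t = nextvalues
        else:
            nextnonterminal = 1.0 - dones[t + 1].float()
            next_values_t = values[t + 1]
        delta = rewards[t] + gamma * next_values_t * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages
```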
......@@ -365,8 +365,7 @@ inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos) {
return spec;
}
inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos,
bool opponent) {
inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos, bool opponent) {
std::string spec = ls_to_spec(loc, seq, pos);
if (opponent) {
spec.insert(0, 1, 'o');
......@@ -1471,6 +1470,7 @@ public:
int init_lp = 8000;
int startcount = 5;
int drawcount = 1;
for (PlayerId i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
delete players_[i];
......@@ -1641,10 +1641,12 @@ public:
void Step(const Action &action) override {
// clock_t start = clock();
// fmt::println("Step");
int idx = action["action"_];
callback_(idx);
update_history_actions(to_play_, idx);
// update_history_actions(to_play_, idx);
// fmt::println("update_history_actions");
PlayerId player = to_play_;
......@@ -2153,6 +2155,10 @@ private:
ReplayWriteInt8(1);
fwrite(buf, 1, 1, fp_);
break;
case MSG_SELECT_COUNTER:
ReplayWriteInt8(2);
fwrite(buf, 2, 1, fp_);
break;
case MSG_SELECT_PLACE:
case MSG_SELECT_DISFIELD:
ReplayWriteInt8(3);
......@@ -2183,10 +2189,13 @@ private:
state["obs:global_"_][22] = uint8_t(1);
return;
}
// fmt::println("writestate");
auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
// fmt::println("_set_obs_cards");
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// fmt::println("_set_obs_global");
// we can't shuffle because idx must be stable in callback
if (n_options > max_options()) {
......@@ -2198,10 +2207,12 @@ private:
// fmt::println("{} {}", key, val);
// }
_set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
// _set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
// fmt::println("_set_obs_actions");
n_options = options_.size();
state["info:num_options"_] = n_options;
return;
// update h_card_ids from state
auto &h_card_ids = to_play_ == 0 ? h_card_ids_0_ : h_card_ids_1_;
......@@ -2222,6 +2233,7 @@ private:
}
h_card_ids[i] = card_ids;
}
// fmt::println("update h_card_ids");
// write history actions
......@@ -2235,6 +2247,7 @@ private:
n_action_feats * n1);
state["obs:h_actions_"_][n1].Assign((uint8_t *)history_actions.Data(),
n_action_feats * ha_p);
// fmt::println("write history actions");
}
void show_decision(int idx) {
......@@ -2307,8 +2320,8 @@ private:
if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) {
if (options_.size() == 1) {
callback_(0);
update_h_card_ids(to_play_, 0);
update_history_actions(to_play_, 0);
// update_h_card_ids(to_play_, 0);
// update_history_actions(to_play_, 0);
if (verbose_) {
show_decision(0);
}
......@@ -2488,6 +2501,7 @@ private:
}
cards.push_back(c);
}
fmt::println("qdp: {}, bl: {}, n: {}", qdp_, bl, cards.size());
return cards;
}
......@@ -2513,8 +2527,7 @@ private:
return cards;
}
std::vector<IdleCardSpec> read_cardlist_spec(bool extra = false,
bool extra8 = false) {
std::vector<IdleCardSpec> read_cardlist_spec(PlayerId player, bool extra = false, bool extra8 = false) {
std::vector<IdleCardSpec> card_specs;
auto count = read_u8();
card_specs.reserve(count);
......@@ -2531,7 +2544,7 @@ private:
data = read_u32();
}
}
card_specs.push_back({code, ls_to_spec(loc, seq, 0), data});
card_specs.push_back({code, ls_to_spec(loc, seq, 0, player != controller), data});
}
return card_specs;
}
......@@ -2988,6 +3001,7 @@ private:
if (verbose_) {
cards.push_back(get_card(c, loc, seq));
}
// TODO: check if this is correct
revealed_.push_back(ls_to_spec(loc, seq, 0, c == player));
}
if (!verbose_) {
......@@ -3406,8 +3420,8 @@ private:
throw std::runtime_error("Retry");
} else if (msg_ == MSG_SELECT_BATTLECMD) {
auto player = read_u8();
auto activatable = read_cardlist_spec(true);
auto attackable = read_cardlist_spec(true, true);
auto activatable = read_cardlist_spec(player, true);
auto attackable = read_cardlist_spec(player, true, true);
bool to_m2 = read_u8();
bool to_ep = read_u8();
......@@ -4122,12 +4136,12 @@ private:
};
} else if (msg_ == MSG_SELECT_IDLECMD) {
int32_t player = read_u8();
auto summonable_ = read_cardlist_spec();
auto spsummon_ = read_cardlist_spec();
auto repos_ = read_cardlist_spec();
auto idle_mset_ = read_cardlist_spec();
auto idle_set_ = read_cardlist_spec();
auto idle_activate_ = read_cardlist_spec(true);
auto summonable_ = read_cardlist_spec(player);
auto spsummon_ = read_cardlist_spec(player);
auto repos_ = read_cardlist_spec(player);
auto idle_mset_ = read_cardlist_spec(player);
auto idle_set_ = read_cardlist_spec(player);
auto idle_activate_ = read_cardlist_spec(player, true);
bool to_bp_ = read_u8();
bool to_ep_ = read_u8();
read_u8(); // can_shuffle
......@@ -4332,6 +4346,35 @@ private:
resp_buf_[2] = seq;
YGO_SetResponseb(pduel_, resp_buf_);
};
} else if (msg_ == MSG_SELECT_COUNTER) {
auto player = read_u8();
auto counter_type = read_u16();
auto counter_count = read_u16();
int count = read_u8();
if (count != 1) {
throw std::runtime_error("Select counter count " +
std::to_string(count) + " not implemented");
}
auto pl = players_[player];
if (verbose_) {
pl->notify(fmt::format("Type new {} for {} card(s), separated by spaces.", "UNKNOWN_COUNTER", count));
}
for (int i = 0; i < count; ++i) {
auto code = read_u32();
auto controller = read_u8();
auto loc = read_u8();
auto seq = read_u8();
auto counter = read_u16();
if (verbose_) {
pl->notify(c_get_card(code).name_ + ": " + std::to_string(counter));
}
// auto spec = ls_to_spec(loc, seq, 0, controller != player);
// options_.push_back(spec);
}
// TODO: implement action
uint16_t resp = counter_count & 0xffff;
memcpy(resp_buf_, &resp, 2);
YGO_SetResponseb(pduel_, resp_buf_);
} else if (msg_ == MSG_ANNOUNCE_NUMBER) {
auto player = read_u8();
int count = read_u8();
......@@ -4448,6 +4491,11 @@ private:
} else {
show_deck(0);
show_deck(1);
// print byte by byte
for (int i = 0; i < dp_; ++i) {
fmt::print("{:02x} ", data_[i]);
}
fmt::print("\n");
throw std::runtime_error(
fmt::format("Unknown message {}, length {}, dp {}",
msg_to_string(msg_), dl_, dp_));
......