Commit 385bd1cb authored by biluo.shen

Add PPO selfplay

parent af60d012
@@ -147,21 +147,21 @@ if __name__ == "__main__":
     embedding_shape = len(code_list)
     L = args.num_layers
     agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-    agent = agent.eval()
+    # agent = agent.eval()
     if args.checkpoint:
         state_dict = torch.load(args.checkpoint, map_location=device)
     else:
         state_dict = None
     if args.compile:
-        agent = torch.compile(agent, mode='reduce-overhead')
         if state_dict:
-            agent.load_state_dict(state_dict)
+            print(agent.load_state_dict(state_dict))
+        agent = torch.compile(agent, mode='reduce-overhead')
     else:
         prefix = "_orig_mod."
         if state_dict:
             state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-            agent.load_state_dict(state_dict)
+            print(agent.load_state_dict(state_dict))
     if args.optimize:
         obs = create_obs(envs.observation_space, (num_envs,), device=device)
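The reordering above matters because `torch.compile` wraps the module and stores its parameters under an `_orig_mod.` prefix in saved checkpoints. A minimal sketch of the same compatibility handling, assuming a checkpoint that may or may not have been saved from a compiled model (the `load_checkpoint` helper and its signature are illustrative, not part of the repository):

```python
import torch

def load_checkpoint(agent: torch.nn.Module, path: str, device="cpu", compile_model=False):
    # Checkpoints saved from a torch.compile-wrapped module prefix every key
    # with "_orig_mod."; an eager (uncompiled) module does not expect that.
    state_dict = torch.load(path, map_location=device)
    if compile_model:
        # Load into the eager module first, then compile; the compiled wrapper
        # shares parameters with the underlying module.
        print(agent.load_state_dict(state_dict))
        agent = torch.compile(agent, mode='reduce-overhead')
    else:
        # Strip the prefix added by torch.compile so the keys match again.
        prefix = "_orig_mod."
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in state_dict.items()}
        print(agent.load_state_dict(state_dict))
    return agent
```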
@@ -170,6 +170,7 @@ if __name__ == "__main__":
         agent = torch.jit.optimize_for_inference(traced_model)
     obs, infos = envs.reset()
+    next_to_play = infos['to_play']
     episode_rewards = []
     episode_lengths = []
@@ -191,7 +192,7 @@ if __name__ == "__main__":
         _start = time.time()
         obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
         with torch.no_grad():
-            logits, values = agent(obs)
+            logits, values, _valid = agent(obs)
         probs = torch.softmax(logits, dim=-1)
         probs = probs.cpu().numpy()
         if args.play:
@@ -212,9 +213,11 @@ if __name__ == "__main__":
             # print(k, v.tolist())
             # print(infos)
             # print(actions[0])
+        to_play = next_to_play
         _start = time.time()
         obs, rewards, dones, infos = envs.step(actions)
+        next_to_play = infos['to_play']
         env_time += time.time() - _start
         step += 1
@@ -225,7 +228,7 @@ if __name__ == "__main__":
                 episode_length = infos['l'][idx]
                 episode_reward = infos['r'][idx]
                 if args.selfplay:
-                    pl = 1 if infos['to_play'][idx] == 0 else -1
+                    pl = 1 if to_play[idx] == 0 else -1
                     winner = 0 if episode_reward * pl > 0 else 1
                     win = 1 - winner
                 else:
...
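The switch from `infos['to_play']` to the cached `to_play` above is the core of the self-play bookkeeping: the turn owner is recorded before `envs.step`, so the terminal reward is attributed to the player who took the last action rather than to whoever moves next. A small worked sketch, assuming the terminal reward is signed from the acting player's perspective (the values are illustrative, not from a real run):

```python
# Player 1 took the final action and received a negative terminal reward.
to_play = [1]            # turn owner recorded *before* envs.step(), per env
episode_reward = -1.0    # reward reported for env 0 when its episode ended

idx = 0
pl = 1 if to_play[idx] == 0 else -1      # -1: player 1 acted last
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
print(winner, win)                       # -> 0 1, i.e. player 0 is the winner
```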
This diff is collapsed.
@@ -320,66 +320,88 @@ class Encoder(nn.Module):
         return f_actions, f_state, mask, valid
-class PPOCritic(nn.Module):
-    def __init__(self, channels):
-        super(PPOCritic, self).__init__()
-        c = channels
-        self.net = nn.Sequential(
-            nn.Linear(c * 2, c // 2),
-            nn.ReLU(),
-            nn.Linear(c // 2, 1),
-        )
+# class PPOCritic(nn.Module):
+#     def __init__(self, channels):
+#         super(PPOCritic, self).__init__()
+#         c = channels
+#         self.net = nn.Sequential(
+#             nn.Linear(c * 2, c // 2),
+#             nn.ReLU(),
+#             nn.Linear(c // 2, 1),
+#         )
+#     def forward(self, f_state):
+#         return self.net(f_state)
+# class PPOActor(nn.Module):
+#     def __init__(self, channels):
+#         super(PPOActor, self).__init__()
+#         c = channels
+#         self.trans = nn.TransformerEncoderLayer(
+#             c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
+#         self.head = nn.Sequential(
+#             nn.Linear(c, c // 4),
+#             nn.ReLU(),
+#             nn.Linear(c // 4, 1),
+#         )
-    def forward(self, f_state):
-        return self.net(f_state)
+#     def forward(self, f_actions, mask, action):
+#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
+#         logits = self.head(f_actions)[..., 0]
+#         logits = logits.float()
+#         logits = logits.masked_fill(mask, float("-inf"))
+#         probs = Categorical(logits=logits)
+#         return probs.log_prob(action), probs.entropy()
-class PPOActor(nn.Module):
+#     def predict(self, f_actions, mask):
+#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
+#         logits = self.head(f_actions)[..., 0]
+#         logits = logits.float()
+#         logits = logits.masked_fill(mask, float("-inf"))
+#         return logits
-    def __init__(self, channels):
-        super(PPOActor, self).__init__()
+class Actor(nn.Module):
+    def __init__(self, channels, use_transformer=False):
+        super(Actor, self).__init__()
         c = channels
-        self.trans = nn.TransformerEncoderLayer(
-            c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
+        self.use_transformer = use_transformer
+        if use_transformer:
+            self.transformer = nn.TransformerEncoderLayer(
+                c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
         self.head = nn.Sequential(
             nn.Linear(c, c // 4),
             nn.ReLU(),
             nn.Linear(c // 4, 1),
         )
-    def forward(self, f_actions, mask, action):
-        f_actions = self.trans(f_actions, src_key_padding_mask=mask)
-        logits = self.head(f_actions)[..., 0]
-        logits = logits.float()
-        logits = logits.masked_fill(mask, float("-inf"))
-        probs = Categorical(logits=logits)
-        return probs.log_prob(action), probs.entropy()
-    def predict(self, f_actions, mask):
-        f_actions = self.trans(f_actions, src_key_padding_mask=mask)
+    def forward(self, f_actions, mask):
+        if self.use_transformer:
+            f_actions = self.transformer(f_actions, src_key_padding_mask=mask)
         logits = self.head(f_actions)[..., 0]
         logits = logits.float()
         logits = logits.masked_fill(mask, float("-inf"))
         return logits
 class PPOAgent(nn.Module):
     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
-                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+                 num_history_action_layers=2, embedding_shape=None, bias=False,
+                 affine=True, a_trans=True):
         super(PPOAgent, self).__init__()
         self.encoder = Encoder(
             channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
         c = channels
-        self.actor = nn.Sequential(
-            nn.Linear(c, c // 4),
-            nn.ReLU(),
-            nn.Linear(c // 4, 1),
-        )
+        self.actor = Actor(c, a_trans)
         self.critic = nn.Sequential(
             nn.Linear(c * 2, c // 2),
@@ -390,24 +412,15 @@ class PPOAgent(nn.Module):
     def load_embeddings(self, embeddings, freeze=True):
         self.encoder.load_embeddings(embeddings, freeze)
-    def get_value(self, x):
+    def get_logit(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
-        return self.critic(f_state)
+        return self.actor(f_actions, mask)
-    def get_action_and_value(self, x, action):
+    def get_value(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
-        logits = self.actor(f_actions)[..., 0]
-        logits = logits.float()
-        logits = logits.masked_fill(mask, float("-inf"))
-        probs = Categorical(logits=logits)
-        return action, probs.log_prob(action), probs.entropy(), self.critic(f_state), valid
+        return self.critic(f_state)
     def forward(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
-        logits = self.actor(f_actions)[..., 0]
-        logits = logits.float()
-        logits = logits.masked_fill(mask, float("-inf"))
-        return logits, self.critic(f_state)
+        logits = self.actor(f_actions, mask)
+        return logits, self.critic(f_state), valid
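With the `Actor` refactor, `PPOAgent.forward` returns `(logits, value, valid)` and the logits of invalid actions are already masked to `-inf`. A hedged sketch of consuming such output; the toy logits tensor below stands in for the model's actual output and is not taken from the repository:

```python
import torch
from torch.distributions import Categorical

# In the evaluation loop, `logits, values, _valid = agent(obs)` yields per-action
# logits with invalid actions already masked to -inf.
# Toy stand-in: 2 environments, 4 candidate actions each.
logits = torch.tensor([[0.5, float("-inf"), 1.2, float("-inf")],
                       [0.1, 0.3, float("-inf"), float("-inf")]])

dist = Categorical(logits=logits)   # masked actions get probability 0
action = dist.sample()              # or logits.argmax(dim=-1) for greedy evaluation
print(action, dist.log_prob(action))
```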
@@ -2935,7 +2935,6 @@ private:
         return;
       }
       auto player = read_u8();
-      to_play_ = player;
       auto size = read_u8();
       std::vector<Card> cards;
       for (int i = 0; i < size; ++i) {
@@ -3315,7 +3314,6 @@ private:
       throw std::runtime_error("Retry");
     } else if (msg_ == MSG_SELECT_BATTLECMD) {
       auto player = read_u8();
-      to_play_ = player;
       auto activatable = read_cardlist_spec(true);
       auto attackable = read_cardlist_spec(true, true);
       bool to_m2 = read_u8();
@@ -3366,6 +3364,7 @@ private:
       }
       int n_activatables = activatable.size();
       int n_attackables = attackable.size();
+      to_play_ = player;
       callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) {
         if (idx < n_activatables) {
           OCG_SetResponsei(pduel_, idx << 16);
@@ -3382,7 +3381,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_UNSELECT_CARD) {
       auto player = read_u8();
-      to_play_ = player;
       bool finishable = read_u8();
       bool cancelable = read_u8();
       auto min = read_u8();
@@ -3435,6 +3433,7 @@ private:
       // cancelable and finishable not needed
+      to_play_ = player;
       callback_ = [this](int idx) {
         if (options_[idx] == "f") {
           OCG_SetResponsei(pduel_, -1);
@@ -3447,7 +3446,6 @@ private:
     } else if (msg_ == MSG_SELECT_CARD) {
       auto player = read_u8();
-      to_play_ = player;
       bool cancelable = read_u8();
       auto min = read_u8();
       auto max = read_u8();
@@ -3535,6 +3533,7 @@ private:
         }
       }
+      to_play_ = player;
       callback_ = [this, combs](int idx) {
         const auto &comb = combs[idx];
         resp_buf_[0] = comb.size();
@@ -3545,7 +3544,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_TRIBUTE) {
       auto player = read_u8();
-      to_play_ = player;
       bool cancelable = read_u8();
       auto min = read_u8();
       auto max = read_u8();
@@ -3621,6 +3619,7 @@ private:
         options_.push_back(option);
       }
+      to_play_ = player;
       callback_ = [this, combs](int idx) {
         const auto &comb = combs[idx];
         resp_buf_[0] = comb.size();
@@ -3632,7 +3631,6 @@ private:
     } else if (msg_ == MSG_SELECT_SUM) {
       auto mode = read_u8();
       auto player = read_u8();
-      to_play_ = player;
       auto val = read_u32();
       auto min = read_u8();
       auto max = read_u8();
@@ -3761,6 +3759,7 @@ private:
         options_.push_back(option);
       }
+      to_play_ = player;
       callback_ = [this, combs, must_select_size](int idx) {
         const auto &comb = combs[idx];
         resp_buf_[0] = must_select_size + comb.size();
@@ -3775,7 +3774,6 @@ private:
     } else if (msg_ == MSG_SELECT_CHAIN) {
       auto player = read_u8();
-      to_play_ = player;
       auto size = read_u8();
       auto spe_count = read_u8();
       bool forced = read_u8();
@@ -3872,6 +3870,7 @@ private:
       if (!forced) {
         options_.push_back("c");
       }
+      to_play_ = player;
       callback_ = [this, forced](int idx) {
         const auto &option = options_[idx];
         if ((option == "c") && (!forced)) {
@@ -3882,7 +3881,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_YESNO) {
       auto player = read_u8();
-      to_play_ = player;
       if (verbose_) {
         auto desc = read_u32();
@@ -3907,6 +3905,7 @@ private:
         dp_ += 4;
       }
       options_ = {"y", "n"};
+      to_play_ = player;
       callback_ = [this](int idx) {
         if (idx == 0) {
           OCG_SetResponsei(pduel_, 1);
@@ -3918,7 +3917,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_EFFECTYN) {
       auto player = read_u8();
-      to_play_ = player;
       std::string spec;
       if (verbose_) {
@@ -3981,6 +3979,7 @@ private:
         spec = ls_to_spec(loc, seq, pos, c != player);
       }
       options_ = {"y " + spec, "n " + spec};
+      to_play_ = player;
       callback_ = [this](int idx) {
         if (idx == 0) {
           OCG_SetResponsei(pduel_, 1);
@@ -3992,7 +3991,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_OPTION) {
       auto player = read_u8();
-      to_play_ = player;
       auto size = read_u8();
       if (verbose_) {
         auto pl = players_[player];
@@ -4016,6 +4014,7 @@ private:
           options_.push_back(std::to_string(i + 1));
         }
       }
+      to_play_ = player;
       callback_ = [this](int idx) {
         if (verbose_) {
           players_[to_play_]->notify("You selected option " + options_[idx] +
@@ -4029,7 +4028,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_IDLECMD) {
       int32_t player = read_u8();
-      to_play_ = player;
       auto summonable_ = read_cardlist_spec();
       auto spsummon_ = read_cardlist_spec();
       auto repos_ = read_cardlist_spec();
@@ -4134,6 +4132,7 @@ private:
         }
       }
+      to_play_ = player;
       callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset,
                    activate_offset](int idx) {
         const auto &option = options_[idx];
@@ -4169,7 +4168,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_PLACE) {
       auto player = read_u8();
-      to_play_ = player;
       auto count = read_u8();
       if (count == 0) {
         count = 1;
@@ -4189,6 +4187,7 @@ private:
                      " places for card, from " + specs_str + ".");
         }
       }
+      to_play_ = player;
       callback_ = [this, player](int idx) {
         int y = player + 1;
         std::string spec = options_[idx];
@@ -4205,7 +4204,6 @@ private:
       };
     } else if (msg_ == MSG_SELECT_DISFIELD) {
       auto player = read_u8();
-      to_play_ = player;
       auto count = read_u8();
       if (count == 0) {
         count = 1;
@@ -4225,6 +4223,7 @@ private:
                      std::to_string(count) + " not implemented");
         }
       }
+      to_play_ = player;
       callback_ = [this, player](int idx) {
         int y = player + 1;
         std::string spec = options_[idx];
@@ -4241,7 +4240,6 @@ private:
       };
     } else if (msg_ == MSG_ANNOUNCE_NUMBER) {
       auto player = read_u8();
-      to_play_ = player;
       auto count = read_u8();
       std::vector<int> numbers;
       for (int i = 0; i < count; ++i) {
@@ -4265,12 +4263,12 @@ private:
         str += "]";
         pl->notify(str);
       }
+      to_play_ = player;
       callback_ = [this](int idx) {
         OCG_SetResponsei(pduel_, idx);
       };
     } else if (msg_ == MSG_ANNOUNCE_ATTRIB) {
       auto player = read_u8();
-      to_play_ = player;
       auto count = read_u8();
       auto flag = read_u32();
@@ -4310,6 +4308,7 @@ private:
         options_.push_back(option);
       }
+      to_play_ = player;
       callback_ = [this](int idx) {
         const auto &option = options_[idx];
         uint32_t resp = 0;
@@ -4323,7 +4322,6 @@ private:
     } else if (msg_ == MSG_SELECT_POSITION) {
       auto player = read_u8();
-      to_play_ = player;
       auto code = read_u32();
       auto valid_pos = read_u8();
@@ -4348,6 +4346,7 @@ private:
         i++;
       }
+      to_play_ = player;
       callback_ = [this](int idx) {
         uint8_t pos = options_[idx][0] - '1';
         OCG_SetResponsei(pduel_, 1 << pos);
...