Commit 745c67f9 authored by Biluo Shen's avatar Biluo Shen

Fix multi select error

parent 15e1e4e7
...@@ -140,8 +140,8 @@ if __name__ == "__main__": ...@@ -140,8 +140,8 @@ if __name__ == "__main__":
code_list = f.readlines() code_list = f.readlines()
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent1 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent2 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]): for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device) state_dict = torch.load(ckpt, map_location=device)
......
...@@ -153,7 +153,6 @@ if __name__ == "__main__": ...@@ -153,7 +153,6 @@ if __name__ == "__main__":
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
# agent = agent.eval()
if args.checkpoint: if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device) state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile: if not args.compile:
......
...@@ -275,6 +275,8 @@ def main(): ...@@ -275,6 +275,8 @@ def main():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False) traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile) train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
# ALGO Logic: Storage setup # ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device) obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
......
...@@ -1392,7 +1392,6 @@ protected: ...@@ -1392,7 +1392,6 @@ protected:
ankerl::unordered_dense::map<std::string, int> ms_spec2idx_; ankerl::unordered_dense::map<std::string, int> ms_spec2idx_;
std::vector<int> ms_r_idxs_; std::vector<int> ms_r_idxs_;
// discard hand cards // discard hand cards
bool discard_hand_ = false; bool discard_hand_ = false;
...@@ -1470,6 +1469,7 @@ public: ...@@ -1470,6 +1469,7 @@ public:
} }
turn_count_ = 0; turn_count_ = 0;
ms_idx_ = -1;
history_actions_0_.Zero(); history_actions_0_.Zero();
history_actions_1_.Zero(); history_actions_1_.Zero();
...@@ -1710,7 +1710,6 @@ public: ...@@ -1710,7 +1710,6 @@ public:
for (int i = 0; i < ms_r_idxs_.size(); ++i) { for (int i = 0; i < ms_r_idxs_.size(); ++i) {
resp_buf_[i + 1] = ms_r_idxs_[i]; resp_buf_[i + 1] = ms_r_idxs_[i];
} }
// fmt::println("{}, {}", ms_r_idxs_.size(), ms_r_idxs_);
YGO_SetResponseb(pduel_, resp_buf_); YGO_SetResponseb(pduel_, resp_buf_);
} else { } else {
ms_idx_++; ms_idx_++;
...@@ -1750,6 +1749,14 @@ public: ...@@ -1750,6 +1749,14 @@ public:
fmt::println("turn: {}, phase: {}, tplayer: {}", turn_count_, phase_to_string(current_phase_), tp_); fmt::println("turn: {}, phase: {}, tplayer: {}", turn_count_, phase_to_string(current_phase_), tp_);
} }
void show_buffer() const {
fmt::println("msg: {}, dp: {}, dl: {}", msg_to_string(msg_), dp_, dl_);
for (int i = 0; i < dl_; ++i) {
fmt::print("{:02x} ", data_[i]);
}
fmt::print("\n");
}
void show_deck(PlayerId player) const { void show_deck(PlayerId player) const {
fmt::print("Player {}'s deck:\n", player); fmt::print("Player {}'s deck:\n", player);
show_deck(player == 0 ? main_deck0_ : main_deck1_, "Main"); show_deck(player == 0 ? main_deck0_ : main_deck1_, "Main");
...@@ -1997,11 +2004,16 @@ private: ...@@ -1997,11 +2004,16 @@ private:
if (it == spec2index.end()) { if (it == spec2index.end()) {
// TODO: find the root cause // TODO: find the root cause
// print spec2index // print spec2index
fmt::println("Spec2index:"); show_deck(0);
show_deck(1);
show_buffer();
show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec);
for (auto &[k, v] : spec2index) { for (auto &[k, v] : spec2index) {
fmt::println("{}: {}", k, v); fmt::println("{}: {}", k, v);
} }
// throw std::runtime_error("Spec not found: " + spec); throw std::runtime_error("Spec not found: " + spec);
idx = 1; idx = 1;
} else { } else {
idx = it->second; idx = it->second;
...@@ -4533,11 +4545,7 @@ private: ...@@ -4533,11 +4545,7 @@ private:
} else { } else {
show_deck(0); show_deck(0);
show_deck(1); show_deck(1);
// print byte by byte show_buffer();
for (int i = 0; i < dp_; ++i) {
fmt::print("{:02x} ", data_[i]);
}
fmt::print("\n");
throw std::runtime_error( throw std::runtime_error(
fmt::format("Unknown message {}, length {}, dp {}", fmt::format("Unknown message {}, length {}, dp {}",
msg_to_string(msg_), dl_, dp_)); msg_to_string(msg_), dl_, dp_));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment