Commit 745c67f9 authored by Biluo Shen's avatar Biluo Shen

Fix multi select error

parent 15e1e4e7
......@@ -140,8 +140,8 @@ if __name__ == "__main__":
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent1 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
......
......@@ -153,7 +153,6 @@ if __name__ == "__main__":
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
# agent = agent.eval()
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
......
......@@ -275,6 +275,8 @@ def main():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
......
......@@ -1392,7 +1392,6 @@ protected:
ankerl::unordered_dense::map<std::string, int> ms_spec2idx_;
std::vector<int> ms_r_idxs_;
// discard hand cards
bool discard_hand_ = false;
......@@ -1470,6 +1469,7 @@ public:
}
turn_count_ = 0;
ms_idx_ = -1;
history_actions_0_.Zero();
history_actions_1_.Zero();
......@@ -1710,7 +1710,6 @@ public:
for (int i = 0; i < ms_r_idxs_.size(); ++i) {
resp_buf_[i + 1] = ms_r_idxs_[i];
}
// fmt::println("{}, {}", ms_r_idxs_.size(), ms_r_idxs_);
YGO_SetResponseb(pduel_, resp_buf_);
} else {
ms_idx_++;
......@@ -1750,6 +1749,14 @@ public:
fmt::println("turn: {}, phase: {}, tplayer: {}", turn_count_, phase_to_string(current_phase_), tp_);
}
void show_buffer() const {
fmt::println("msg: {}, dp: {}, dl: {}", msg_to_string(msg_), dp_, dl_);
for (int i = 0; i < dl_; ++i) {
fmt::print("{:02x} ", data_[i]);
}
fmt::print("\n");
}
void show_deck(PlayerId player) const {
fmt::print("Player {}'s deck:\n", player);
show_deck(player == 0 ? main_deck0_ : main_deck1_, "Main");
......@@ -1997,11 +2004,16 @@ private:
if (it == spec2index.end()) {
// TODO: find the root cause
// print spec2index
fmt::println("Spec2index:");
show_deck(0);
show_deck(1);
show_buffer();
show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec);
for (auto &[k, v] : spec2index) {
fmt::println("{}: {}", k, v);
}
// throw std::runtime_error("Spec not found: " + spec);
throw std::runtime_error("Spec not found: " + spec);
idx = 1;
} else {
idx = it->second;
......@@ -4533,11 +4545,7 @@ private:
} else {
show_deck(0);
show_deck(1);
// print byte by byte
for (int i = 0; i < dp_; ++i) {
fmt::print("{:02x} ", data_[i]);
}
fmt::print("\n");
show_buffer();
throw std::runtime_error(
fmt::format("Unknown message {}, length {}, dp {}",
msg_to_string(msg_), dl_, dp_));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment