Commit 8293ff41 authored by biluo.shen's avatar biluo.shen

Support replay

parent 1925293b
...@@ -50,6 +50,8 @@ class Args: ...@@ -50,6 +50,8 @@ class Args:
"""whether to play the game""" """whether to play the game"""
selfplay: bool = False selfplay: bool = False
"""whether to use selfplay""" """whether to use selfplay"""
record: bool = False
"""whether to record the game as YGOPro replays"""
num_episodes: int = 1024 num_episodes: int = 1024
"""the number of episodes to run""" """the number of episodes to run"""
...@@ -87,6 +89,10 @@ if __name__ == "__main__": ...@@ -87,6 +89,10 @@ if __name__ == "__main__":
if args.play: if args.play:
args.num_envs = 1 args.num_envs = 1
args.verbose = True args.verbose = True
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs) args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4")) args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
...@@ -125,6 +131,7 @@ if __name__ == "__main__": ...@@ -125,6 +131,7 @@ if __name__ == "__main__":
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")), play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
verbose=args.verbose, verbose=args.verbose,
record=args.record,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
......
...@@ -112,8 +112,10 @@ class Args: ...@@ -112,8 +112,10 @@ class Args:
"""tensorboard log directory""" """tensorboard log directory"""
ckpt_dir: str = "./checkpoints" ckpt_dir: str = "./checkpoints"
"""checkpoint directory""" """checkpoint directory"""
save_interval: int = 100 save_interval: int = 1000
"""the number of iterations to save the model""" """the number of iterations to save the model"""
log_p: float = 0.1
"""the probability of logging"""
port: int = 12355 port: int = 12355
"""the port to use for distributed training""" """the port to use for distributed training"""
...@@ -339,7 +341,7 @@ def run(local_rank, world_size): ...@@ -339,7 +341,7 @@ def run(local_rank, world_size):
continue continue
for idx, d in enumerate(next_done_): for idx, d in enumerate(next_done_):
if d: if d and random.random() < args.log_p:
episode_length = info['l'][idx] episode_length = info['l'][idx]
episode_reward = info['r'][idx] episode_reward = info['r'][idx]
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
...@@ -420,7 +422,7 @@ def run(local_rank, world_size): ...@@ -420,7 +422,7 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes # TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0: if local_rank == 0:
if iteration % args.save_interval == 0: if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"{iteration}.pth")) torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
......
...@@ -1359,7 +1359,7 @@ protected: ...@@ -1359,7 +1359,7 @@ protected:
bool record_ = false; bool record_ = false;
// uint8_t *replay_data_; // uint8_t *replay_data_;
// uint8_t *rdata_; // uint8_t *rdata_;
FILE* fp_; FILE* fp_ = nullptr;
bool is_recording = false; bool is_recording = false;
public: public:
...@@ -1454,7 +1454,7 @@ public: ...@@ -1454,7 +1454,7 @@ public:
ha_p_0_ = 0; ha_p_0_ = 0;
ha_p_1_ = 0; ha_p_1_ = 0;
unsigned long duel_seed = dist_int_(gen_); auto duel_seed = dist_int_(gen_);
std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx); std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx);
pduel_ = OCG_CreateDuel(duel_seed); pduel_ = OCG_CreateDuel(duel_seed);
...@@ -1580,12 +1580,12 @@ public: ...@@ -1580,12 +1580,12 @@ public:
if (done_) { if (done_) {
float base_reward = 1.0; float base_reward = 1.0;
int win_turn = turn_count_ - winner_; int win_turn = turn_count_ - winner_;
if (win_turn <= 5) { if (win_turn <= 1) {
base_reward = 2.0; base_reward = 4.0;
} else if (win_turn <= 3) { } else if (win_turn <= 3) {
base_reward = 3.0; base_reward = 3.0;
} else if (win_turn <= 1) { } else if (win_turn <= 5) {
base_reward = 4.0; base_reward = 2.0;
} }
if (play_mode_ == kSelfPlay) { if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player // to_play_ is the previous player
...@@ -1599,6 +1599,14 @@ public: ...@@ -1599,6 +1599,14 @@ public:
} else if (win_reason_ == 0x02) { } else if (win_reason_ == 0x02) {
reason = -1; reason = -1;
} }
if (record_) {
if (!is_recording || fp_ == nullptr) {
throw std::runtime_error("Recording is not started");
}
fclose(fp_);
is_recording = false;
}
} }
WriteState(reward, win_reason_); WriteState(reward, win_reason_);
...@@ -1942,11 +1950,12 @@ private: ...@@ -1942,11 +1950,12 @@ private:
void str_to_uint16(const char* src, uint16_t* dest) { void str_to_uint16(const char* src, uint16_t* dest) {
for (int i = 0; i < strlen(src); i += 2) { for (int i = 0; i < strlen(src); i += 1) {
dest[i / 2] = src[i] | (src[i + 1] << 8); dest[i] = src[i];
} }
// Add null terminator // Add null terminator
dest[(strlen(src) + 1) / 2] = '\0'; dest[strlen(src) + 1] = '\0';
} }
void ReplayWriteInt8(int8_t value) { void ReplayWriteInt8(int8_t value) {
...@@ -1958,7 +1967,7 @@ private: ...@@ -1958,7 +1967,7 @@ private:
} }
// ygopro-core API // ygopro-core API
intptr_t OCG_CreateDuel(uint_fast32_t seed) { intptr_t OCG_CreateDuel(uint32_t seed) {
if (record_) { if (record_) {
ReplayHeader rh; ReplayHeader rh;
rh.id = 0x31707279; rh.id = 0x31707279;
...@@ -1969,18 +1978,21 @@ private: ...@@ -1969,18 +1978,21 @@ private:
fwrite(&rh, sizeof(rh), 1, fp_); fwrite(&rh, sizeof(rh), 1, fp_);
fflush(fp_); fflush(fp_);
} }
return create_duel(seed); std::mt19937 rnd(seed);
return create_duel(rnd());
} }
void OCG_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) { void OCG_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) {
if (record_ && playerid == 0) { if (record_ && playerid == 0) {
{ {
uint16_t name[20]; uint16_t name[20];
memset(name, 0, 40);
str_to_uint16("Alice", name); str_to_uint16("Alice", name);
fwrite(name, 40, 1, fp_); fwrite(name, 40, 1, fp_);
} }
{ {
uint16_t name[20]; uint16_t name[20];
memset(name, 0, 40);
str_to_uint16("Bob", name); str_to_uint16("Bob", name);
fwrite(name, 40, 1, fp_); fwrite(name, 40, 1, fp_);
} }
...@@ -2030,7 +2042,7 @@ private: ...@@ -2030,7 +2042,7 @@ private:
void OCG_SetResponsei(intptr_t pduel, int32 value) { void OCG_SetResponsei(intptr_t pduel, int32 value) {
if (record_) { if (record_) {
ReplayWriteInt32(4); ReplayWriteInt8(4);
ReplayWriteInt32(value); ReplayWriteInt32(value);
} }
set_responsei(pduel, value); set_responsei(pduel, value);
...@@ -2038,8 +2050,21 @@ private: ...@@ -2038,8 +2050,21 @@ private:
void OCG_SetResponseb(intptr_t pduel, byte* buf) { void OCG_SetResponseb(intptr_t pduel, byte* buf) {
if (record_) { if (record_) {
ReplayWriteInt8(buf[0]); switch (msg_) {
fwrite(buf + 1, buf[0], 1, fp_); case MSG_SORT_CARD:
ReplayWriteInt8(1);
fwrite(buf, 1, 1, fp_);
break;
case MSG_SELECT_PLACE:
case MSG_SELECT_DISFIELD:
ReplayWriteInt8(3);
fwrite(buf, 3, 1, fp_);
break;
default:
ReplayWriteInt8(buf[0] + 1);
fwrite(buf, buf[0] + 1, 1, fp_);
break;
}
} }
set_responseb(pduel, buf); set_responseb(pduel, buf);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment