Commit 8293ff41 authored by biluo.shen's avatar biluo.shen

Support replay

parent 1925293b
......@@ -50,6 +50,8 @@ class Args:
"""whether to play the game"""
selfplay: bool = False
"""whether to use selfplay"""
record: bool = False
"""whether to record the game as YGOPro replays"""
num_episodes: int = 1024
"""the number of episodes to run"""
......@@ -87,6 +89,10 @@ if __name__ == "__main__":
if args.play:
args.num_envs = 1
args.verbose = True
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
......@@ -125,6 +131,7 @@ if __name__ == "__main__":
n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
verbose=args.verbose,
record=args.record,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
......
......@@ -112,8 +112,10 @@ class Args:
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 100
save_interval: int = 1000
"""the number of iterations to save the model"""
log_p: float = 0.1
"""the probability of logging"""
port: int = 12355
"""the port to use for distributed training"""
......@@ -339,7 +341,7 @@ def run(local_rank, world_size):
continue
for idx, d in enumerate(next_done_):
if d:
if d and random.random() < args.log_p:
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
......@@ -420,7 +422,7 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"{iteration}.pth"))
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
......
......@@ -1359,7 +1359,7 @@ protected:
bool record_ = false;
// uint8_t *replay_data_;
// uint8_t *rdata_;
FILE* fp_;
FILE* fp_ = nullptr;
bool is_recording = false;
public:
......@@ -1454,7 +1454,7 @@ public:
ha_p_0_ = 0;
ha_p_1_ = 0;
unsigned long duel_seed = dist_int_(gen_);
auto duel_seed = dist_int_(gen_);
std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx);
pduel_ = OCG_CreateDuel(duel_seed);
......@@ -1580,12 +1580,12 @@ public:
if (done_) {
float base_reward = 1.0;
int win_turn = turn_count_ - winner_;
if (win_turn <= 5) {
base_reward = 2.0;
if (win_turn <= 1) {
base_reward = 4.0;
} else if (win_turn <= 3) {
base_reward = 3.0;
} else if (win_turn <= 1) {
base_reward = 4.0;
} else if (win_turn <= 5) {
base_reward = 2.0;
}
if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player
......@@ -1599,6 +1599,14 @@ public:
} else if (win_reason_ == 0x02) {
reason = -1;
}
if (record_) {
if (!is_recording || fp_ == nullptr) {
throw std::runtime_error("Recording is not started");
}
fclose(fp_);
is_recording = false;
}
}
WriteState(reward, win_reason_);
......@@ -1942,11 +1950,12 @@ private:
void str_to_uint16(const char* src, uint16_t* dest) {
for (int i = 0; i < strlen(src); i += 2) {
dest[i / 2] = src[i] | (src[i + 1] << 8);
for (int i = 0; i < strlen(src); i += 1) {
dest[i] = src[i];
}
// Add null terminator
dest[(strlen(src) + 1) / 2] = '\0';
dest[strlen(src) + 1] = '\0';
}
void ReplayWriteInt8(int8_t value) {
......@@ -1958,7 +1967,7 @@ private:
}
// ygopro-core API
intptr_t OCG_CreateDuel(uint_fast32_t seed) {
intptr_t OCG_CreateDuel(uint32_t seed) {
if (record_) {
ReplayHeader rh;
rh.id = 0x31707279;
......@@ -1969,18 +1978,21 @@ private:
fwrite(&rh, sizeof(rh), 1, fp_);
fflush(fp_);
}
return create_duel(seed);
std::mt19937 rnd(seed);
return create_duel(rnd());
}
void OCG_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) {
if (record_ && playerid == 0) {
{
uint16_t name[20];
memset(name, 0, 40);
str_to_uint16("Alice", name);
fwrite(name, 40, 1, fp_);
}
{
uint16_t name[20];
memset(name, 0, 40);
str_to_uint16("Bob", name);
fwrite(name, 40, 1, fp_);
}
......@@ -2030,7 +2042,7 @@ private:
void OCG_SetResponsei(intptr_t pduel, int32 value) {
if (record_) {
ReplayWriteInt32(4);
ReplayWriteInt8(4);
ReplayWriteInt32(value);
}
set_responsei(pduel, value);
......@@ -2038,8 +2050,21 @@ private:
void OCG_SetResponseb(intptr_t pduel, byte* buf) {
if (record_) {
ReplayWriteInt8(buf[0]);
fwrite(buf + 1, buf[0], 1, fp_);
switch (msg_) {
case MSG_SORT_CARD:
ReplayWriteInt8(1);
fwrite(buf, 1, 1, fp_);
break;
case MSG_SELECT_PLACE:
case MSG_SELECT_DISFIELD:
ReplayWriteInt8(3);
fwrite(buf, 3, 1, fp_);
break;
default:
ReplayWriteInt8(buf[0] + 1);
fwrite(buf, buf[0] + 1, 1, fp_);
break;
}
}
set_responseb(pduel, buf);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment