Commit c90562de authored by sbl1996@126.com's avatar sbl1996@126.com

Env fault tolerant training via timeout

parent b7e43382
......@@ -153,7 +153,7 @@ if __name__ == "__main__":
play_mode='self',
async_reset=False,
verbose=args.verbose,
record=args.record,
record=args.record,
)
envs1 = ygoenv.make(
task_id=env_id1,
......@@ -311,6 +311,7 @@ if __name__ == "__main__":
for idx, d in enumerate(dones1):
if not d or (args.accurate and collected[idx]):
continue
# c1 = collected[idx]
collected[idx] = True
win_reason = infos1['win_reason'][idx]
pl = 1 if main[idx] else -1
......@@ -323,7 +324,8 @@ if __name__ == "__main__":
win_players.append(win_player)
win_agent = 1 if main_reward > 0 else 2
win_agents.append(win_agent)
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
# if not c1:
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
episode_lengths.append(episode_length)
episode_rewards.append(main_reward)
win_rates.append(win)
......
......@@ -49,6 +49,8 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
timeout: int = 600
"""the timeout of the environment step"""
debug: bool = False
"""whether to run the script in debug mode"""
......@@ -208,6 +210,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
async_reset=False,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
timeout=args.timeout,
)
envs.num_envs = num_envs
return envs
......
......@@ -210,17 +210,17 @@ if __name__ == "__main__":
for idx, d in enumerate(dones):
if not d:
continue
for i in range(2):
deck_time = infos['step_time'][idx][i]
deck_name = deck_names[infos['deck'][idx][i]]
time_count = deck_time_count[deck_name]
avg_time = deck_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
deck_times[deck_name] = avg_time
deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % 100 == 0:
print(f"Deck {deck_name}: {avg_time:.4f}")
# for i in range(2):
# deck_time = infos['step_time'][idx][i]
# deck_name = deck_names[infos['deck'][idx][i]]
# time_count = deck_time_count[deck_name]
# avg_time = deck_times[deck_name]
# avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
# deck_times[deck_name] = avg_time
# deck_time_count[deck_name] += 1
# if deck_time_count[deck_name] % 100 == 0:
# print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment