Commit c90562de authored by sbl1996@126.com

Env fault tolerant training via timeout

parent b7e43382
...@@ -153,7 +153,7 @@ if __name__ == "__main__": ...@@ -153,7 +153,7 @@ if __name__ == "__main__":
play_mode='self', play_mode='self',
async_reset=False, async_reset=False,
verbose=args.verbose, verbose=args.verbose,
record=args.record, record=args.record,
) )
envs1 = ygoenv.make( envs1 = ygoenv.make(
task_id=env_id1, task_id=env_id1,
...@@ -311,6 +311,7 @@ if __name__ == "__main__": ...@@ -311,6 +311,7 @@ if __name__ == "__main__":
for idx, d in enumerate(dones1): for idx, d in enumerate(dones1):
if not d or (args.accurate and collected[idx]): if not d or (args.accurate and collected[idx]):
continue continue
# c1 = collected[idx]
collected[idx] = True collected[idx] = True
win_reason = infos1['win_reason'][idx] win_reason = infos1['win_reason'][idx]
pl = 1 if main[idx] else -1 pl = 1 if main[idx] else -1
...@@ -323,7 +324,8 @@ if __name__ == "__main__": ...@@ -323,7 +324,8 @@ if __name__ == "__main__":
win_players.append(win_player) win_players.append(win_player)
win_agent = 1 if main_reward > 0 else 2 win_agent = 1 if main_reward > 0 else 2
win_agents.append(win_agent) win_agents.append(win_agent)
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}") # if not c1:
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
episode_lengths.append(episode_length) episode_lengths.append(episode_length)
episode_rewards.append(main_reward) episode_rewards.append(main_reward)
win_rates.append(win) win_rates.append(win)
......
...@@ -49,6 +49,8 @@ class Args: ...@@ -49,6 +49,8 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)""" """the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the path to the model checkpoint to load""" """the path to the model checkpoint to load"""
timeout: int = 600
"""the timeout of the environment step"""
debug: bool = False debug: bool = False
"""whether to run the script in debug mode""" """whether to run the script in debug mode"""
...@@ -208,6 +210,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off ...@@ -208,6 +210,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
async_reset=False, async_reset=False,
greedy_reward=args.greedy_reward if not eval else True, greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode, play_mode=mode,
timeout=args.timeout,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
return envs return envs
......
...@@ -210,17 +210,17 @@ if __name__ == "__main__": ...@@ -210,17 +210,17 @@ if __name__ == "__main__":
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if not d: if not d:
continue continue
for i in range(2): # for i in range(2):
deck_time = infos['step_time'][idx][i] # deck_time = infos['step_time'][idx][i]
deck_name = deck_names[infos['deck'][idx][i]] # deck_name = deck_names[infos['deck'][idx][i]]
time_count = deck_time_count[deck_name] # time_count = deck_time_count[deck_name]
avg_time = deck_times[deck_name] # avg_time = deck_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1) # avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
deck_times[deck_name] = avg_time # deck_times[deck_name] = avg_time
deck_time_count[deck_name] += 1 # deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % 100 == 0: # if deck_time_count[deck_name] % 100 == 0:
print(f"Deck {deck_name}: {avg_time:.4f}") # print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx] win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment