Commit b725f2dc authored by sbl1996@126.com

Prepare for release

parent 662b300f
......@@ -20,7 +20,6 @@ import optax
import distrax
import tyro
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
......@@ -55,7 +54,7 @@ class Args:
debug: bool = False
"""whether to run the script in debug mode"""
tb_dir: str = "runs"
tb_dir: Optional[str] = "runs"
"""the directory to save the tensorboard logs"""
tb_offset: int = 0
"""the step offset of the tensorboard logs"""
......@@ -696,8 +695,9 @@ def main():
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
tb_log_dir = f"{args.tb_dir}/{run_name}"
if args.local_rank == 0 and not args.debug:
if args.local_rank == 0 and not args.debug and args.tb_dir is not None:
from tensorboardX import SummaryWriter
tb_log_dir = f"{args.tb_dir}/{run_name}"
writer = SummaryWriter(tb_log_dir)
writer.add_text(
"hyperparameters",
......
......@@ -210,17 +210,6 @@ if __name__ == "__main__":
for idx, d in enumerate(dones):
if not d:
continue
# for i in range(2):
# deck_time = infos['step_time'][idx][i]
# deck_name = deck_names[infos['deck'][idx][i]]
# time_count = deck_time_count[deck_name]
# avg_time = deck_times[deck_name]
# avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
# deck_times[deck_name] = avg_time
# deck_time_count[deck_name] += 1
# if deck_time_count[deck_name] % 100 == 0:
# print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
......@@ -235,7 +224,7 @@ if __name__ == "__main__":
if len(episode_lengths) >= args.num_episodes:
break
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
print(f"len={np.mean(episode_lengths):.4f}, reward={np.mean(episode_rewards):.4f}, win_rate={np.mean(win_rates):.4f}, win_reason={np.mean(win_reasons):.4f}")
if not args.play:
total_time = time.time() - start
total_steps = (step - start_step) * num_envs
......
......@@ -15,6 +15,7 @@ VERSION = None
# Package names bound to REQUIRED.
# NOTE(review): presumably consumed as the install-time dependency list by a
# setup() call elsewhere in this file — the consumer is not visible here; confirm.
REQUIRED = [
"tyro",
"pandas",
"tensorboardX",
]
here = os.path.dirname(os.path.abspath(__file__))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment