Commit 350c29a8 authored by sbl1996@126.com's avatar sbl1996@126.com

Update eval scripts

parent bbd36d86
......@@ -30,9 +30,9 @@ class Args:
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
"""the deck name for the first player, for example, `Hero`"""
deck2: Optional[str] = None
"""the deck file for the second player"""
"""the deck name for the second player, for example, `CyberDragon`"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
......@@ -44,6 +44,8 @@ class Args:
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
verbose: bool = False
"""whether to print debug information"""
record: bool = False
"""whether to record the game as YGOPro replays"""
......@@ -51,8 +53,6 @@ class Args:
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -61,15 +61,14 @@ class Args:
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
"""the checkpoint to load for the first agent, must be a `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
"""the checkpoint to load for the second agent, must be a `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
"""the XLA device to use, `cpu` for forcing running on CPU"""
env_threads: Optional[int] = 16
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
......@@ -96,8 +95,9 @@ if __name__ == "__main__":
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
args.num_envs = 1
args.verbose = True
print("Set num_envs=1 and verbose=True for recording")
if not os.path.exists("replay"):
os.makedirs("replay")
......@@ -254,7 +254,7 @@ if __name__ == "__main__":
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
if args.verbose:
print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={main_reward}, win={win}, win_reason={win_reason}\n")
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={main_reward}, win={win}, win_reason={win_reason}\n")
else:
pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
pbar.update(1)
......
......@@ -24,9 +24,9 @@ class Args:
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
"""the deck name for the first player, for example, `Hero`"""
deck2: Optional[str] = None
"""the deck file for the second player"""
"""the deck name for the second player, for example, `CyberDragon`"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
......@@ -42,6 +42,8 @@ class Args:
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False
"""whether to play the game"""
verbose: bool = False
"""whether to print debug information"""
record: bool = False
"""whether to record the game as YGOPro replays"""
......@@ -49,8 +51,6 @@ class Args:
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
bot_type: Literal["random", "greedy"] = "greedy"
"""the type of bot to use"""
......@@ -63,14 +63,14 @@ class Args:
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load, `pt` or `flax_model` file"""
"""the checkpoint to load, must be a `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
"""the XLA device to use, `cpu` for forcing running on CPU"""
env_threads: Optional[int] = 16
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
......@@ -92,14 +92,11 @@ def init_rnn_state(num_envs, rnn_channels):
if __name__ == "__main__":
args = tyro.cli(Args)
if args.play:
if args.play or args.record:
args.num_envs = 1
args.verbose = True
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
if not os.path.exists("replay"):
print("Set num_envs=1 and verbose=True for recording or playing the game")
if args.record and not os.path.exists("replay"):
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
......
......@@ -9,7 +9,7 @@ DESCRIPTION = "A Yu-Gi-Oh! AI."
URL = 'https://github.com/sbl1996/ygo-agent'
EMAIL = 'sbl1996@gmail.com'
AUTHOR = 'Hastur'
REQUIRES_PYTHON = '>=3.8.0'
REQUIRES_PYTHON = '>=3.10.0'
VERSION = None
REQUIRED = [
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment