Commit 350c29a8 authored by sbl1996@126.com

Update eval scripts

parent bbd36d86
...@@ -30,9 +30,9 @@ class Args: ...@@ -30,9 +30,9 @@ class Args:
deck: str = "../assets/deck" deck: str = "../assets/deck"
"""the deck file to use""" """the deck file to use"""
deck1: Optional[str] = None deck1: Optional[str] = None
"""the deck file for the first player""" """the deck name for the first player, for example, `Hero`"""
deck2: Optional[str] = None deck2: Optional[str] = None
"""the deck file for the second player""" """the deck name for the second player, for example, `CyberDragon`"""
code_list_file: str = "code_list.txt" code_list_file: str = "code_list.txt"
"""the code list file for card embeddings""" """the code list file for card embeddings"""
lang: str = "english" lang: str = "english"
...@@ -44,6 +44,8 @@ class Args: ...@@ -44,6 +44,8 @@ class Args:
num_embeddings: Optional[int] = None num_embeddings: Optional[int] = None
"""the number of embeddings of the agent""" """the number of embeddings of the agent"""
verbose: bool = False
"""whether to print debug information"""
record: bool = False record: bool = False
"""whether to record the game as YGOPro replays""" """whether to record the game as YGOPro replays"""
...@@ -51,8 +53,6 @@ class Args: ...@@ -51,8 +53,6 @@ class Args:
"""the number of episodes to run""" """the number of episodes to run"""
num_envs: int = 64 num_envs: int = 64
"""the number of parallel game environments""" """the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
num_layers: int = 2 num_layers: int = 2
"""the number of layers for the agent""" """the number of layers for the agent"""
...@@ -61,15 +61,14 @@ class Args: ...@@ -61,15 +61,14 @@ class Args:
rnn_channels: Optional[int] = 512 rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent""" """the number of rnn channels for the agent"""
checkpoint1: str = "checkpoints/agent.pt" checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file""" """the checkpoint to load for the first agent, must be a `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt" checkpoint2: str = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file""" """the checkpoint to load for the second agent, must be a `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`""" """the XLA device to use, `cpu` for forcing running on CPU"""
env_threads: Optional[int] = 16 env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`""" """the number of threads to use for envpool, defaults to `num_envs`"""
...@@ -96,8 +95,9 @@ if __name__ == "__main__": ...@@ -96,8 +95,9 @@ if __name__ == "__main__":
args = tyro.cli(Args) args = tyro.cli(Args)
if args.record: if args.record:
assert args.num_envs == 1, "Recording only works with a single environment" args.num_envs = 1
assert args.verbose, "Recording only works with verbose mode" args.verbose = True
print("Set num_envs=1 and verbose=True for recording")
if not os.path.exists("replay"): if not os.path.exists("replay"):
os.makedirs("replay") os.makedirs("replay")
...@@ -254,7 +254,7 @@ if __name__ == "__main__": ...@@ -254,7 +254,7 @@ if __name__ == "__main__":
win_rates.append(win) win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0) win_reasons.append(1 if win_reason == 1 else 0)
if args.verbose: if args.verbose:
print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={main_reward}, win={win}, win_reason={win_reason}\n") sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={main_reward}, win={win}, win_reason={win_reason}\n")
else: else:
pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates)) pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
pbar.update(1) pbar.update(1)
......
...@@ -24,9 +24,9 @@ class Args: ...@@ -24,9 +24,9 @@ class Args:
deck: str = "../assets/deck" deck: str = "../assets/deck"
"""the deck file to use""" """the deck file to use"""
deck1: Optional[str] = None deck1: Optional[str] = None
"""the deck file for the first player""" """the deck name for the first player, for example, `Hero`"""
deck2: Optional[str] = None deck2: Optional[str] = None
"""the deck file for the second player""" """the deck name for the second player, for example, `CyberDragon`"""
code_list_file: str = "code_list.txt" code_list_file: str = "code_list.txt"
"""the code list file for card embeddings""" """the code list file for card embeddings"""
lang: str = "english" lang: str = "english"
...@@ -42,6 +42,8 @@ class Args: ...@@ -42,6 +42,8 @@ class Args:
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player""" """the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False play: bool = False
"""whether to play the game""" """whether to play the game"""
verbose: bool = False
"""whether to print debug information"""
record: bool = False record: bool = False
"""whether to record the game as YGOPro replays""" """whether to record the game as YGOPro replays"""
...@@ -49,8 +51,6 @@ class Args: ...@@ -49,8 +51,6 @@ class Args:
"""the number of episodes to run""" """the number of episodes to run"""
num_envs: int = 64 num_envs: int = 64
"""the number of parallel game environments""" """the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
bot_type: Literal["random", "greedy"] = "greedy" bot_type: Literal["random", "greedy"] = "greedy"
"""the type of bot to use""" """the type of bot to use"""
...@@ -63,14 +63,14 @@ class Args: ...@@ -63,14 +63,14 @@ class Args:
"""the number of channels for the agent""" """the number of channels for the agent"""
rnn_channels: Optional[int] = 512 rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent""" """the number of rnn channels for the agent"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the checkpoint to load, `pt` or `flax_model` file""" """the checkpoint to load, must be a `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`""" """the XLA device to use, `cpu` for forcing running on CPU"""
env_threads: Optional[int] = 16 env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`""" """the number of threads to use for envpool, defaults to `num_envs`"""
...@@ -92,14 +92,11 @@ def init_rnn_state(num_envs, rnn_channels): ...@@ -92,14 +92,11 @@ def init_rnn_state(num_envs, rnn_channels):
if __name__ == "__main__": if __name__ == "__main__":
args = tyro.cli(Args) args = tyro.cli(Args)
if args.play: if args.play or args.record:
args.num_envs = 1 args.num_envs = 1
args.verbose = True args.verbose = True
print("Set num_envs=1 and verbose=True for recording or playing the game")
if args.record: if args.record and not os.path.exists("replay"):
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
if not os.path.exists("replay"):
os.makedirs("replay") os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs) args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
......
...@@ -9,7 +9,7 @@ DESCRIPTION = "A Yu-Gi-Oh! AI." ...@@ -9,7 +9,7 @@ DESCRIPTION = "A Yu-Gi-Oh! AI."
URL = 'https://github.com/sbl1996/ygo-agent' URL = 'https://github.com/sbl1996/ygo-agent'
EMAIL = 'sbl1996@gmail.com' EMAIL = 'sbl1996@gmail.com'
AUTHOR = 'Hastur' AUTHOR = 'Hastur'
REQUIRES_PYTHON = '>=3.8.0' REQUIRES_PYTHON = '>=3.10.0'
VERSION = None VERSION = None
REQUIRED = [ REQUIRED = [
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment