Commit a93e7a88 authored by biluo.shen's avatar biluo.shen

Fix eval

parent 29e1a24b
...@@ -41,7 +41,7 @@ class Args: ...@@ -41,7 +41,7 @@ class Args:
"""the language to use""" """the language to use"""
max_options: int = 24 max_options: int = 24
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 8 n_history_actions: int = 16
"""the number of history actions to use""" """the number of history actions to use"""
player: int = -1 player: int = -1
...@@ -71,7 +71,7 @@ class Args: ...@@ -71,7 +71,7 @@ class Args:
"""the number of channels for the agent""" """the number of channels for the agent"""
checkpoint: str = "checkpoints/agent.pt" checkpoint: str = "checkpoints/agent.pt"
"""the checkpoint to load""" """the checkpoint to load"""
embedding_file: str = "embeddings_en.npy" embedding_file: Optional[str] = "embeddings_en.npy"
"""the embedding file for card embeddings""" """the embedding file for card embeddings"""
compile: bool = False compile: bool = False
...@@ -130,9 +130,13 @@ if __name__ == "__main__": ...@@ -130,9 +130,13 @@ if __name__ == "__main__":
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
if args.agent: if args.agent:
embeddings = np.load(args.embedding_file) if args.embedding_file:
embeddings = np.load(args.embedding_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device) agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = agent.eval() agent = agent.eval()
state_dict = torch.load(args.checkpoint, map_location=device) state_dict = torch.load(args.checkpoint, map_location=device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment