Commit 366e5f3b authored by sbl1996@126.com

Add ppo_sp

parent 598465e8
...@@ -18,6 +18,20 @@ ...@@ -18,6 +18,20 @@
## Global ## Global
- lp: 2, max 65535 to 2 bytes - lp: 2, max 65535 to 2 bytes
- oppo_lp: 2, max 65535 to 2 bytes - oppo_lp: 2, max 65535 to 2 bytes
- n_my_decks: 1, int
- n_my_extras: 1, int
- n_my_hands: 1, int
- n_my_graves: 1, int
- n_my_removes: 1, int
- n_my_monsters: 1, int
- n_my_spell_traps: 1, int
- n_op_decks: 1, int
- n_op_extras: 1, int
- n_op_hands: 1, int
- n_op_graves: 1, int
- n_op_removes: 1, int
- n_op_monsters: 1, int
- n_op_spell_traps: 1, int
- turn: 1, int, trunc to 8 - turn: 1, int, trunc to 8
- phase: 1, int, one-hot (10) - phase: 1, int, one-hot (10)
- is_first: 1, int, 0: False, 1: True - is_first: 1, int, 0: False, 1: True
......
...@@ -43,6 +43,8 @@ class Args: ...@@ -43,6 +43,8 @@ class Args:
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 16 n_history_actions: int = 16
"""the number of history actions to use""" """the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
player: int = -1 player: int = -1
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player""" """the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
...@@ -138,9 +140,11 @@ if __name__ == "__main__": ...@@ -138,9 +140,11 @@ if __name__ == "__main__":
if args.agent: if args.agent:
# count lines of code_list # count lines of code_list
with open(args.code_list_file, "r") as f: embedding_shape = args.num_embeddings
code_list = f.readlines() if embedding_shape is None:
embedding_shape = len(code_list) with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = agent.eval() agent = agent.eval()
......
...@@ -375,7 +375,7 @@ def run(local_rank, world_size): ...@@ -375,7 +375,7 @@ def run(local_rank, world_size):
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values returns = advantages + values
_start = time.time() _start = time.time()
# flatten the batch # flatten the batch
b_obs = { b_obs = {
......
This diff is collapsed.
...@@ -105,6 +105,7 @@ class Encoder(nn.Module): ...@@ -105,6 +105,7 @@ class Encoder(nn.Module):
self.a_option_embed = nn.Embedding(6, c // divisor // 2) self.a_option_embed = nn.Embedding(6, c // divisor // 2)
self.a_number_embed = nn.Embedding(13, c // divisor // 2) self.a_number_embed = nn.Embedding(13, c // divisor // 2)
self.a_place_embed = nn.Embedding(31, c // divisor // 2) self.a_place_embed = nn.Embedding(31, c // divisor // 2)
# TODO: consider sharing the same embedding as attribute_embed
self.a_attrib_embed = nn.Embedding(10, c // divisor // 2) self.a_attrib_embed = nn.Embedding(10, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine) self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment