Commit 03ffe718 authored by biluo.shen

Improve critic

parent e7a19464
@@ -50,7 +50,7 @@ class Args:
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 8
+    n_history_actions: int = 16
     """the number of history actions to use"""
     play_mode: str = "self"
     """the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""
@@ -60,7 +60,7 @@ class Args:
     num_channels: int = 128
     """the number of channels for the agent"""
-    total_timesteps: int = 100000000
+    total_timesteps: int = 1000000000
     """total timesteps of the experiments"""
     learning_rate: float = 2.5e-4
     """the learning rate of the optimizer"""
@@ -76,7 +76,7 @@ class Args:
     """the lambda for the general advantage estimation"""
     minibatch_size: int = 256
     """the mini-batch size"""
-    update_epochs: int = 4
+    update_epochs: int = 2
     """the K epochs to update the policy"""
     norm_adv: bool = True
     """Toggles advantages normalization"""
...@@ -219,7 +219,7 @@ def run(local_rank, world_size): ...@@ -219,7 +219,7 @@ def run(local_rank, world_size):
# agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode) # agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train) scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def masked_mean(x, valid): def masked_mean(x, valid):
x = x.masked_fill(~valid, 0) x = x.masked_fill(~valid, 0)
......
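For context, a minimal sketch of how a GradScaler with a lowered init_scale is typically used in an fp16 training step (the loss function, clip value, and variable names below are illustrative, not taken from this repository):

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler(enabled=True, init_scale=2 ** 8)  # smaller starting loss scale than the default 2 ** 16

def train_step(agent, optimizer, batch, compute_loss):
    with autocast():
        loss = compute_loss(agent, batch)      # hypothetical loss function
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()              # scale the loss so fp16 gradients do not underflow
    scaler.unscale_(optimizer)                 # unscale before clipping so the norm is in true units
    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=0.5)
    scaler.step(optimizer)                     # skips the update if inf/nan gradients were detected
    scaler.update()                            # grows or shrinks the scale for the next step

Starting from a smaller scale avoids the burst of skipped optimizer steps that the default 2 ** 16 scale usually produces at the beginning of training.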
...@@ -304,7 +304,12 @@ class Encoder(nn.Module): ...@@ -304,7 +304,12 @@ class Encoder(nn.Module):
f_actions = layer(f_actions, f_h_actions) f_actions = layer(f_actions, f_h_actions)
f_actions = self.action_norm(f_actions) f_actions = self.action_norm(f_actions)
return f_actions, mask, valid
f_s_cards_global = f_cards.mean(dim=1)
c_mask = 1 - mask.unsqueeze(-1).float()
f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
return f_actions, f_state, mask, valid
class PPOAgent(nn.Module): class PPOAgent(nn.Module):
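A self-contained sketch of the pooled state feature added above, under the assumption that f_cards is (batch, n_cards, channels), f_actions is (batch, max_options, channels), and mask is True for padded option slots; it also shows why the critic input in the next hunk doubles to c * 2:

import torch

batch, n_cards, max_options, c = 2, 40, 24, 128    # illustrative sizes, not from the repository
f_cards = torch.randn(batch, n_cards, c)
f_actions = torch.randn(batch, max_options, c)
mask = torch.zeros(batch, max_options, dtype=torch.bool)
mask[:, 12:] = True                                 # pretend only the first 12 options are valid

f_s_cards_global = f_cards.mean(dim=1)              # (batch, c): global summary of card features
c_mask = 1 - mask.unsqueeze(-1).float()             # 1.0 for valid options, 0.0 for padding
f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)   # masked mean over valid options
f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)        # (batch, 2 * c), the critic input

assert f_state.shape == (batch, 2 * c)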
@@ -324,67 +329,32 @@ class PPOAgent(nn.Module):
         )

         self.critic = nn.Sequential(
-            nn.Linear(c, c // 4),
+            nn.Linear(c * 2, c // 2),
             nn.ReLU(),
-            nn.Linear(c // 4, 1),
+            nn.Linear(c // 2, 1),
         )

     def load_embeddings(self, embeddings, freeze=True):
         self.encoder.load_embeddings(embeddings, freeze)

     def get_value(self, x):
-        f_actions, mask, valid = self.encoder(x)
-        c_mask = 1 - mask.unsqueeze(-1).float()
-        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
-        return self.critic(f)
+        f_actions, f_state, mask, valid = self.encoder(x)
+        return self.critic(f_state)

     def get_action_and_value(self, x, action):
-        f_actions, mask, valid = self.encoder(x)
-        c_mask = 1 - mask.unsqueeze(-1).float()
-        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        f_actions, f_state, mask, valid = self.encoder(x)
         logits = self.actor(f_actions)[..., 0]
         logits = logits.float()
         logits = logits.masked_fill(mask, float("-inf"))
         probs = Categorical(logits=logits)
-        return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
+        return action, probs.log_prob(action), probs.entropy(), self.critic(f_state), valid

     def forward(self, x):
-        f_actions, mask, valid = self.encoder(x)
-        c_mask = 1 - mask.unsqueeze(-1).float()
-        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        f_actions, f_state, mask, valid = self.encoder(x)
         logits = self.actor(f_actions)[..., 0]
         logits = logits.float()
         logits = logits.masked_fill(mask, float("-inf"))
-        return logits, self.critic(f)
-
-
-class DMCAgent(nn.Module):
-
-    def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
-                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
-        super(DMCAgent, self).__init__()
-        self.encoder = Encoder(
-            channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
-        c = channels
-        self.value_head = nn.Sequential(
-            nn.Linear(c, c // 4),
-            nn.ReLU(),
-            nn.Linear(c // 4, 1),
-        )
-
-    def load_embeddings(self, embeddings, freeze=True):
-        self.encoder.load_embeddings(embeddings, freeze)
-
-    def forward(self, x):
-        f_actions, mask, valid = self.encoder(f_actions)
-        values = self.value_head(f_actions)[..., 0]
-        # values = torch.tanh(values)
-        values = torch.where(mask, torch.full_like(values, -10), values)
-        return values, valid
+        return logits, self.critic(f_state)
\ No newline at end of file
@@ -1516,9 +1516,9 @@ public:
       if (win_turn <= 5) {
         base_reward = 2.0;
       } else if (win_turn <= 3) {
-        base_reward = 4.0;
+        base_reward = 3.0;
       } else if (win_turn <= 1) {
-        base_reward = 8.0;
+        base_reward = 4.0;
       }
       if (play_mode_ == kSelfPlay) {
         // to_play_ is the previous player
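As a side note, the branches above are ordered largest threshold first, so win_turn <= 5 always matches before the <= 3 and <= 1 cases can be reached; a small sketch of the schedule with the tightest threshold checked first, which is presumably the intent (the reordering and the fallback value are assumptions, not part of this commit):

def win_base_reward(win_turn: int) -> float:
    # Reward values from this commit; a smaller win_turn (faster win) pays more.
    if win_turn <= 1:
        return 4.0
    if win_turn <= 3:
        return 3.0
    if win_turn <= 5:
        return 2.0
    return 1.0  # assumed default for slower wins; the hunk does not show it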