Commit 1487b081 authored by Biluo Shen

(WIP) add cleanba_ppo

parent 80707a8c
@@ -88,4 +88,7 @@
 ## History Actions
 - 0,1: card id, uint16 -> 2 uint8
-- others same as legal actions
+- 2-12 same as legal actions
+- 13: player, discrete, 0: me, 1: oppo
+- 14: turn, discrete, trunc to 3
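The new layout makes each history-action row a fixed-width integer vector: bytes 0-1 hold the card id split from a uint16, slots 2-12 mirror the legal-action features, slot 13 flags the acting player, and slot 14 stores the turn truncated to 3. A minimal packing sketch under those assumptions (the 15-slot width, the high/low byte order, and all field values are illustrative; the real feature builder is not part of this diff):

```python
import numpy as np

def encode_history_action(card_id, action_feats, player, turn):
    """Pack one history action into a 15-slot uint8 row (layout per the doc above)."""
    row = np.zeros(15, dtype=np.uint8)
    row[0], row[1] = card_id >> 8, card_id & 0xFF  # 0,1: card id, uint16 -> 2 uint8 (byte order assumed)
    row[2:13] = action_feats                       # 2-12: same layout as legal actions
    row[13] = player                               # 13: player, 0: me, 1: oppo
    row[14] = min(turn, 3)                         # 14: turn, truncated to 3
    return row

# Hypothetical example row.
print(encode_history_action(card_id=0x1234, action_feats=np.arange(11), player=1, turn=7))
```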
@@ -41,7 +41,7 @@ class Args:
     """the language to use"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 16
+    n_history_actions: int = 32
     """the number of history actions to use"""
     num_embeddings: Optional[int] = None
     """the number of embeddings of the agent"""
@@ -69,7 +69,7 @@ class Args:
     """the number of parallel game environments"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
-    anneal_lr: bool = True
+    anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
     """the discount factor gamma"""
@@ -329,21 +329,17 @@ def main():
     global_step = 0
     warmup_steps = 0
     start_time = time.time()
-    next_obs, info = envs.reset()
-    next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
-    next_to_play_ = info["to_play"]
-    next_to_play = to_tensor(next_to_play_, device)
     next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
     ai_player1_ = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(ai_player1_)
-    ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
+    ai_player1 = to_tensor(ai_player1_, device)
     next_value1 = next_value2 = 0
     step = 0
-    ts = []
     lp_count = 0
+    ts = sample_target(history)
 
     for iteration in range(args.num_iterations):
         # Annealing the rate if instructed to do so.
@@ -351,6 +347,15 @@ def main():
             frac = 1.0 - (iteration % args.iter_per_lp) / args.iter_per_lp
             lrnow = frac * args.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow
 
+        if iteration % args.iter_per_lp == 0:
+            next_obs, info = envs.reset()
+            next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
+            next_to_play_ = info["to_play"]
+            next_to_play = to_tensor(next_to_play_, device)
+            next_value1 = next_value2 = 0
+            step = 0
+            ts = []
+
         if len(ts) == 0:
             ts = sample_target(history)
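Because the annealing fraction is now computed from iteration % args.iter_per_lp instead of the global iteration count, the learning rate decays linearly within each learning phase and snaps back to the base rate when the next phase starts, at the same boundary where the environments are reset above. A standalone sketch of that sawtooth schedule, with hypothetical values for the base rate and phase length:

```python
# Hypothetical values; the real ones come from Args.
learning_rate = 2.5e-4
iter_per_lp = 4  # iterations per learning phase

for iteration in range(10):
    frac = 1.0 - (iteration % iter_per_lp) / iter_per_lp
    lrnow = frac * learning_rate
    print(f"iter {iteration}: lr = {lrnow:.2e}")
# Steps through 2.5e-4, 1.88e-4, 1.25e-4, 6.25e-5, then jumps back to 2.5e-4 at iteration 4.
```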
@@ -538,7 +543,7 @@ def main():
         if (iteration + 1) % args.iter_per_lp == 0:
             lp_count += 1
             win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
-            if np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
+            if len(history) == 0 or np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
                 agent_t.load_state_dict(agent.state_dict())
                 with torch.no_grad():
                     traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
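The added len(history) == 0 clause makes the first target refresh unconditional: before any opponents have been stored there are no win rates worth gating on. A minimal sketch of just the gating decision, separated from the surrounding sync and tracing code (should_refresh_target is not a function in the repo, and the threshold/budget defaults are hypothetical):

```python
import numpy as np

def should_refresh_target(win_rates, history_len, lp_count,
                          update_win_rate=0.55, max_lp=10):
    """Decide whether to snapshot the learner into the frozen target agent.

    Refresh when the opponent pool is empty (nothing to evaluate against),
    when the learner beats every tracked opponent above the threshold,
    or when the learning-phase budget is exhausted.
    """
    return (
        history_len == 0
        or bool(np.all(np.asarray(win_rates) > update_win_rate))
        or lp_count >= max_lp
    )

print(should_refresh_target(win_rates=[0.60, 0.58], history_len=2, lp_count=1))  # True
print(should_refresh_target(win_rates=[0.60, 0.40], history_len=2, lp_count=1))  # False
```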
@@ -343,7 +343,8 @@ class Encoder(nn.Module):
         mask = x_actions[:, :, 2] == 0  # msg == 0
         valid = x['global_'][:, -1] == 0
-        mask[:, 0] &= valid
+        mask[:, 0] = False
+        # mask[:, 0] &= valid
         for layer in self.action_card_net:
             f_actions = layer(
                 f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)
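A plausible reason for forcing mask[:, 0] = False is to guarantee that every sample keeps at least one unmasked action slot: if a key-padding mask blanks out an entire row, the attention softmax for that row is taken over all -inf scores and yields NaNs. A small self-contained illustration of that failure mode with a generic nn.MultiheadAttention (not the project's Encoder or its decoder layers):

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
q = k = v = torch.randn(1, 4, 8)

# Every key position masked: softmax over an empty set -> NaN outputs.
full_mask = torch.ones(1, 4, dtype=torch.bool)
out_bad, _ = attn(q, k, v, key_padding_mask=full_mask)
print(torch.isnan(out_bad).any())   # tensor(True)

# Keeping position 0 unmasked (as mask[:, 0] = False does) avoids this.
safe_mask = full_mask.clone()
safe_mask[:, 0] = False
out_ok, _ = attn(q, k, v, key_padding_mask=safe_mask)
print(torch.isnan(out_ok).any())    # tensor(False)
```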
@@ -54,6 +54,20 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
     return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
 
+def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
+    num_steps = rewards.size(0)
+    advantages = torch.zeros_like(rewards)
+    lastgaelam = 0
+    for t in reversed(range(num_steps)):
+        if t == num_steps - 1:
+            nextnonterminal = 1.0 - next_done
+            nextvalues = nextvalues
+        else:
+            nextnonterminal = 1.0 - dones[t + 1]
+            nextvalues = values[t + 1]
+        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
+        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
+
 def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done, gamma, gae_lambda):
     num_steps = rewards.size(0)
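The new bootstrap_value is the standard GAE(lambda) recursion from cleanRL-style PPO: delta_t = r_t + gamma * V_{t+1} * (1 - done_{t+1}) - V_t, accumulated backwards with factor gamma * lambda. As written in the diff it fills advantages in place but never returns it, so the usage sketch below assumes the finished version ends with `return advantages`; the buffer shapes and hyperparameters are hypothetical:

```python
import torch

num_steps, num_envs = 128, 8           # hypothetical rollout dimensions
gamma, gae_lambda = 1.0, 0.95          # hypothetical; Args defines the real values

values = torch.randn(num_steps, num_envs)    # V(s_t) predicted during the rollout
rewards = torch.randn(num_steps, num_envs)
dones = torch.zeros(num_steps, num_envs)     # episode-termination flags aligned with values
nextvalues = torch.randn(num_envs)           # V(s_T) for the state after the last step
next_done = torch.zeros(num_envs)

# Assumes bootstrap_value (added above) is completed with `return advantages`.
advantages = bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda)
returns = advantages + values                # PPO value-function targets
```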