Commit ff67e8c4 authored by Biluo Shen

Add partial GAE

parent 745c67f9
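In this scheme the rollout buffer holds `collect_length` transitions (with `collect_length >= num_steps`). GAE is computed over the whole buffer, but only the first `num_steps` entries are used for the update; the remaining `collect_length - num_steps` entries are carried over to the next iteration, where their value estimates are refreshed with the updated critic and their advantage targets are completed with freshly collected steps, so they lean less on the bootstrap value. Below is a minimal, self-contained toy sketch of that buffering pattern, not the repo's code: single-agent, random stand-ins for the environment and networks, and a plain `compute_gae` helper in place of this repo's `bootstrap_value_selfplay`.

# Toy sketch of the partial-GAE buffer (assumed layout, not the repo's code)
import torch

num_steps, collect_length, num_envs = 4, 6, 2   # tiny sizes for illustration
carry = collect_length - num_steps              # steps carried between iterations
gamma, gae_lambda = 0.99, 0.95

rewards = torch.zeros(collect_length, num_envs)
values = torch.zeros(collect_length, num_envs)
dones = torch.zeros(collect_length, num_envs, dtype=torch.bool)

def compute_gae(rewards, values, dones, next_value):
    # plain GAE over the whole buffer, bootstrapping from next_value at the end
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros(num_envs)
    nextvalues, nextnonterminal = next_value, 1.0  # assume the buffer does not end on a done
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
        nextvalues, nextnonterminal = values[t], 1.0 - dones[t].float()
    return advantages

step = 0
for iteration in range(1, 4):
    # fill the buffer up to collect_length; after iteration 1 only num_steps entries are new
    while step < collect_length:
        rewards[step] = torch.randn(num_envs)   # stand-in for an env/policy step
        values[step] = torch.randn(num_envs)    # stand-in for the critic's estimate
        step += 1
    step = carry
    next_value = torch.randn(num_envs)          # stand-in for the bootstrap value of next_obs

    if step > 0 and iteration != 1:
        # carried-over value estimates are stale after the last update; refresh them
        values[:step] = torch.randn(step, num_envs)  # stand-in for re-running the critic

    advantages = compute_gae(rewards, values, dones, next_value)
    # only the first num_steps entries are trained on; their advantage estimates
    # already "see" the carried tail, so they depend less on the bootstrap value
    b_advantages = advantages[:num_steps].reshape(-1)
    b_returns = b_advantages + values[:num_steps].reshape(-1)
    # (a real implementation would feed b_advantages / b_returns into the PPO update)

    # move the untrained tail to the head of the buffer for the next iteration
    for buf in (rewards, values, dones):
        buf[:carry] = buf[num_steps:].clone()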
@@ -4,7 +4,6 @@ import time
 from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional
-import pickle
 
 import ygoenv
 import numpy as np
@@ -99,6 +98,8 @@ class Args:
     """the target KL divergence threshold"""
     learn_opponent: bool = True
     """if toggled, the samples from the opponent will be used to train the agent"""
+    collect_length: int = None
+    """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
 
     backend: Literal["gloo", "nccl", "mpi"] = "nccl"
     """the backend for distributed training"""
@@ -156,6 +157,9 @@ def main():
     args.num_iterations = args.total_timesteps // args.batch_size
     args.env_threads = args.env_threads or args.num_envs
     args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
+    args.collect_length = args.collect_length or args.num_steps
+    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
 
     local_torch_threads = args.torch_threads // args.world_size
     local_env_threads = args.env_threads // args.world_size
@@ -279,13 +283,13 @@ def main():
         traced_model = agent
 
     # ALGO Logic: Storage setup
-    obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
-    actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
-    logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
-    values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
+    obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
+    actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
+    logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
+    values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
 
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
@@ -305,6 +309,7 @@ def main():
     np.random.shuffle(ai_player1_)
     ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
     next_value1 = next_value2 = 0
+    step = 0
 
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
@@ -316,7 +321,7 @@ def main():
         model_time = 0
         env_time = 0
         collect_start = time.time()
-        for step in range(0, args.num_steps):
+        while step < args.collect_length:
            global_step += args.num_envs
 
            for key in obs:
@@ -350,6 +355,7 @@ def main():
            env_time += time.time() - _start
            rewards[step] = to_tensor(reward, device)
            next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
+            step += 1
 
            if not writer:
                continue
@@ -378,29 +384,44 @@ def main():
         if local_rank == 0:
             fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
+        step = args.collect_length - args.num_steps
 
         _start = time.time()
         # bootstrap value if not done
         with torch.no_grad():
             value = traced_model(next_obs)[1].reshape(-1)
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
+
+        if step > 0 and iteration != 1:
+            # recalculate the values for the first few steps
+            v_steps = args.local_minibatch_size * 4 // args.local_num_envs
+            for v_start in range(0, step, v_steps):
+                v_end = min(v_start + v_steps, step)
+                v_obs = {
+                    k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
+                }
+                with torch.no_grad():
+                    # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
+                    value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
+                values[v_start:v_end] = value
+
         advantages = bootstrap_value_selfplay(
             values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
-        returns = advantages + values
         bootstrap_time = time.time() - _start
 
         _start = time.time()
         # flatten the batch
         b_obs = {
-            k: v.reshape((-1,) + v.shape[2:])
+            k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
             for k, v in obs.items()
         }
-        b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + action_shape)
-        b_advantages = advantages.reshape(-1)
-        b_returns = returns.reshape(-1)
-        b_values = values.reshape(-1)
-        b_learns = learns.reshape(-1)
+        b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
+        b_logprobs = logprobs[:args.num_steps].reshape(-1)
+        b_advantages = advantages[:args.num_steps].reshape(-1)
+        b_values = values[:args.num_steps].reshape(-1)
+        b_learns = learns[:args.num_steps].reshape(-1)
+        b_returns = b_advantages + b_values
 
         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
@@ -424,7 +445,14 @@ def main():
                 if args.target_kl is not None and approx_kl > args.target_kl:
                     break
 
+        if step > 0:
+            # TODO: use cyclic buffer to avoid copying
+            for v in obs.values():
+                v[:step] = v[args.num_steps:].clone()
+            for v in [actions, logprobs, rewards, dones, values, learns]:
+                v[:step] = v[args.num_steps:].clone()
+
         train_time = time.time() - _start
 
         if local_rank == 0:
...
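The value-refresh loop in the `@@ -378,29 +384,44 @@` hunk re-evaluates the carried-over observations in chunks of about four local minibatches of samples per forward pass, keeping the no-grad re-evaluation bounded in memory. The following is a rough sketch of that chunking pattern, with a hypothetical toy `value_net` and dict-of-tensors buffer standing in for `traced_model` and the real `obs`:

# Rough sketch of the chunked, no-grad value refresh (hypothetical toy model/buffer)
import torch

value_net = torch.nn.Linear(8, 1)         # stand-in for the traced agent's value head
carry, num_envs, chunk = 6, 4, 2          # chunk ~= local_minibatch_size * 4 // local_num_envs
obs = {"feats": torch.randn(carry, num_envs, 8)}   # dict-of-tensors buffer, like `obs` above
values = torch.zeros(carry, num_envs)

for v_start in range(0, carry, chunk):
    v_end = min(v_start + chunk, carry)
    # flatten (time, env) into one batch dimension for the forward pass
    v_obs = {k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()}
    with torch.no_grad():
        value = value_net(v_obs["feats"]).reshape(v_end - v_start, -1)
    values[v_start:v_end] = value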
This diff is collapsed.
@@ -112,8 +112,8 @@ inline bool sum_to2(const std::vector<std::vector<int>> &w,
 }
 
 inline std::vector<std::vector<int>>
-combinations_with_weight2(const std::vector<std::vector<int>> &weights,
-                          int r) {
+combinations_with_weight2(
+    const std::vector<std::vector<int>> &weights, int r) {
   int n = weights.size();
   std::vector<std::vector<int>> results;
@@ -1771,9 +1771,9 @@ public:
       uint8_t msg_id = uint8_t(ha(i, 2));
       int msg = _msgs[msg_id - 1];
       fmt::print("msg: {},", msg_to_string(msg));
-      auto v1 = static_cast<CardId>(ha(i, 0));
-      auto v2 = static_cast<CardId>(ha(i, 1));
-      CardId card_id = (v1 << 8) + v2;
+      uint8_t v1 = ha(i, 0);
+      uint8_t v2 = ha(i, 1);
+      CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2);
       fmt::print(" {};", card_id);
       for (int j = 3; j < ha.Shape()[1]; j++) {
         fmt::print(" {}", uint8_t(ha(i, j)));
@@ -2326,15 +2326,13 @@ private:
     for (int i = 0; i < n_options; ++i) {
       uint8_t spec_index1 = state["obs:actions_"_](i, 0);
       uint8_t spec_index2 = state["obs:actions_"_](i, 1);
-      uint16_t spec_index = (spec_index1 << 8) + spec_index2;
+      uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
       if (spec_index == 0) {
         h_card_ids[i] = 0;
       } else {
-        uint16_t card_id1 =
-            static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 0));
-        uint16_t card_id2 =
-            static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 1));
-        h_card_ids[i] = (card_id1 << 8) + card_id2;
+        uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
+        uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
+        h_card_ids[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
       }
     }
...