Commit ff67e8c4 authored by Biluo Shen

Add partial GAE

parent 745c67f9
@@ -4,7 +4,6 @@ import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import pickle
import ygoenv
import numpy as np
@@ -99,6 +98,8 @@ class Args:
"""the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the rollout buffer; only the first `num_steps` steps are used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
@@ -156,6 +157,9 @@ def main():
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
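For orientation, a minimal sketch of the bookkeeping this hunk introduces (the numbers are illustrative, not the repository's defaults): the rollout buffer holds `collect_length` steps, each PPO update consumes the first `num_steps` of them, and the trailing `collect_length - num_steps` steps are carried into the next iteration so that the trained slice's GAE targets see real future transitions before falling back to a bootstrap value.

num_steps = 128           # steps consumed by each PPO update (illustrative)
collect_length = 160      # total steps held in the rollout buffer (illustrative)
assert collect_length >= num_steps
carry = collect_length - num_steps
print(f"train on steps [0, {num_steps}), carry {carry} steps into the next rollout")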
@@ -279,13 +283,13 @@ def main():
traced_model = agent
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@@ -305,6 +309,7 @@ def main():
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
@@ -316,7 +321,7 @@ def main():
model_time = 0
env_time = 0
collect_start = time.time()
for step in range(0, args.num_steps):
while step < args.collect_length:
global_step += args.num_envs
for key in obs:
@@ -350,6 +355,7 @@ def main():
env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer:
continue
@@ -378,29 +384,44 @@ def main():
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = traced_model(next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 1:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
returns = advantages + values
bootstrap_time = time.time() - _start
_start = time.time()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
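The actual advantage computation goes through `bootstrap_value_selfplay`, which switches between the two players' bootstrap values (`nextvalues1`/`nextvalues2`); that self-play logic is not reproduced here. As a point of reference, a sketch of plain single-agent GAE run over the full `collect_length` window, of which only the first `num_steps` advantages reach the optimizer (names and defaults are illustrative):

import torch

def gae_over_full_buffer(rewards, values, dones, next_value, next_done,
                         gamma=0.99, gae_lambda=0.95):
    # rewards/values/dones: (collect_length, num_envs); next_value/next_done: (num_envs,)
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        if t == T - 1:
            nextnonterminal = 1.0 - next_done.float()
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1].float()
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    # Only advantages[:num_steps] (and the matching returns) feed the PPO update;
    # the trailing collect_length - num_steps rows exist to make those estimates
    # less truncated.
    return advantages

With `collect_length > num_steps`, the lambda-weighted return behind each trained step is backed by up to `collect_length - num_steps` extra real transitions before it has to rely on the bootstrap value, which appears to be what the commit's "partial GAE" refers to.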
@@ -424,7 +445,14 @@ def main():
if args.target_kl is not None and approx_kl > args.target_kl:
break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
v[:step] = v[args.num_steps:].clone()
for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start
if local_rank == 0:
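A toy illustration of what the tail-to-head copy above does (self-contained, not the repository's buffers): the last `collect_length - num_steps` rows of the finished rollout become the first rows of the next one, and collection then resumes at `step = collect_length - num_steps`. In the training script the values for this carried prefix are additionally recomputed with the freshly updated network before the next GAE pass.

import torch

num_steps, carry = 4, 2                 # illustrative sizes
collect_length = num_steps + carry
buf = torch.arange(collect_length)      # stands in for one buffer (obs, rewards, ...)
buf[:carry] = buf[num_steps:].clone()   # same move as in the hunk above
print(buf)                              # tensor([4, 5, 2, 3, 4, 5]); rows 2..5 are refilled next rollout

As the TODO notes, a cyclic buffer would avoid this copy; the current version keeps the simpler contiguous layout at the cost of cloning the carried rows once per iteration.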
@@ -112,8 +112,8 @@ inline bool sum_to2(const std::vector<std::vector<int>> &w,
}
inline std::vector<std::vector<int>>
combinations_with_weight2(const std::vector<std::vector<int>> &weights,
int r) {
combinations_with_weight2(
const std::vector<std::vector<int>> &weights, int r) {
int n = weights.size();
std::vector<std::vector<int>> results;
@@ -1771,9 +1771,9 @@ public:
uint8_t msg_id = uint8_t(ha(i, 2));
int msg = _msgs[msg_id - 1];
fmt::print("msg: {},", msg_to_string(msg));
auto v1 = static_cast<CardId>(ha(i, 0));
auto v2 = static_cast<CardId>(ha(i, 1));
CardId card_id = (v1 << 8) + v2;
uint8_t v1 = ha(i, 0);
uint8_t v2 = ha(i, 1);
CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2);
fmt::print(" {};", card_id);
for (int j = 3; j < ha.Shape()[1]; j++) {
fmt::print(" {}", uint8_t(ha(i, j)));
@@ -2326,15 +2326,13 @@ private:
for (int i = 0; i < n_options; ++i) {
uint8_t spec_index1 = state["obs:actions_"_](i, 0);
uint8_t spec_index2 = state["obs:actions_"_](i, 1);
uint16_t spec_index = (spec_index1 << 8) + spec_index2;
uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
if (spec_index == 0) {
h_card_ids[i] = 0;
} else {
uint16_t card_id1 =
static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 0));
uint16_t card_id2 =
static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 1));
h_card_ids[i] = (card_id1 << 8) + card_id2;
uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
h_card_ids[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
}
}
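These hunks make the widening of the two uint8 observation bytes to a 16-bit id explicit at the shift. A small NumPy sketch of the same packing, with illustrative values (NumPy is used because the training side consumes these observation columns as uint8 arrays; the bytes below are not real data):

import numpy as np

# Illustrative bytes only: high byte and low byte of two card ids.
hi = np.array([0x12, 0x00], dtype=np.uint8)
lo = np.array([0x34, 0x2A], dtype=np.uint8)

# Widen before shifting, mirroring the explicit casts in the C++ change.
card_id = (hi.astype(np.uint16) << 8) + lo.astype(np.uint16)
print(card_id)        # [4660   42]

# Shifting the raw uint8 column keeps the uint8 dtype, so the high byte is lost.
print(hi << 8)        # [0 0]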