Commit 3c7a7080 authored by biluo.shen

Add PPO

parent ad9c3c34
@@ -60,17 +60,17 @@ class Args:
     total_timesteps: int = 100000000
     """total timesteps of the experiments"""
-    learning_rate: float = 2.5e-4
+    learning_rate: float = 5e-4
     """the learning rate of the optimizer"""
     num_envs: int = 64
     """the number of parallel game environments"""
-    num_steps: int = 100
+    num_steps: int = 200
     """the number of steps per env per iteration"""
-    buffer_size: int = 200000
+    buffer_size: int = 20000
     """the replay memory buffer size"""
     gamma: float = 0.99
     """the discount factor gamma"""
-    minibatch_size: int = 256
+    minibatch_size: int = 1024
     """the mini-batch size"""
     eps: float = 0.05
     """the epsilon for exploration"""
@@ -264,13 +264,13 @@ if __name__ == "__main__":
         # ALGO LOGIC: training.
         _start = time.time()
-        b_inds = rb.get_data_indices()
-        if len(b_inds) < args.minibatch_size:
+        if not rb.full:
             continue
+        b_inds = rb.get_data_indices()
         np.random.shuffle(b_inds)
         b_obs, b_actions, b_returns = rb._get_samples(b_inds)
         sample_time += time.time() - _start
-        for start in range(0, len(b_inds), args.minibatch_size):
+        for start in range(0, len(b_returns), args.minibatch_size):
             _start = time.time()
             end = start + args.minibatch_size
             mb_obs = {
...
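The hunk cuts off right after the opening brace of mb_obs. A hedged guess at how the minibatch slice continues, assuming b_obs is a dict of per-key tensors as that brace suggests:

    # Hedged sketch; the actual continuation lies outside the shown hunk,
    # and the per-key slicing is an assumption based on `mb_obs = {`.
    mb_obs = {k: v[start:end] for k, v in b_obs.items()}
    mb_actions = b_actions[start:end]
    mb_returns = b_returns[start:end]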
This diff is collapsed.
 import torch
 import torch.nn as nn
+from torch.distributions import Categorical
 def bytes_to_bin(x, points, intervals):
@@ -18,11 +19,11 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
     return points, intervals
-class Agent(nn.Module):
+class Encoder(nn.Module):
     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
                  num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
-        super(Agent, self).__init__()
+        super(Encoder, self).__init__()
         self.num_history_action_layers = num_history_action_layers
         c = channels
@@ -129,11 +130,6 @@ class Agent(nn.Module):
         ])
         self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
-        self.value_head = nn.Sequential(
-            nn.Linear(c, c // 4),
-            nn.ReLU(),
-            nn.Linear(c // 4, 1),
-        )
         self.init_embeddings()
@@ -148,7 +144,6 @@ class Agent(nn.Module):
             elif "fc_emb" in n:
                 nn.init.uniform_(m.weight, -scale, scale)
     def load_embeddings(self, embeddings, freeze=True):
         weight = self.id_embed.weight
         embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
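load_embeddings copies pretrained card-ID embeddings into the encoder's embedding table and optionally freezes them. A hedged usage sketch; the file name and the meaning of embedding_shape are assumptions, not taken from this commit:

    # Hypothetical usage of load_embeddings; the .npy path and embedding_shape
    # interpretation are assumptions.
    import numpy as np
    embeddings = np.load("card_embeddings.npy")        # assumed shape (num_ids, embed_dim)
    encoder = Encoder(channels=128, embedding_shape=embeddings.shape)
    encoder.load_embeddings(embeddings, freeze=True)   # copy into id_embed, keep frozen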
@@ -309,7 +304,78 @@ class Agent(nn.Module):
             f_actions = layer(f_actions, f_h_actions)
         f_actions = self.action_norm(f_actions)
+        return f_actions, mask, valid
+
+
+class PPOAgent(nn.Module):
+
+    def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
+                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+        super(PPOAgent, self).__init__()
+        self.encoder = Encoder(
+            channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
+
+        c = channels
+        self.actor = nn.Sequential(
+            nn.Linear(c, c // 4),
+            nn.ReLU(),
+            nn.Linear(c // 4, 1),
+        )
+        self.critic = nn.Sequential(
+            nn.Linear(c, c // 4),
+            nn.ReLU(),
+            nn.Linear(c // 4, 1),
+        )
+
+    def load_embeddings(self, embeddings, freeze=True):
+        self.encoder.load_embeddings(embeddings, freeze)
+
+    def get_value(self, x):
+        f_actions, mask, valid = self.encoder(x)
+        c_mask = 1 - mask.unsqueeze(-1).float()
+        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        return self.critic(f)
+
+    def get_action_and_value(self, x, action=None):
+        f_actions, mask, valid = self.encoder(x)
+        c_mask = 1 - mask.unsqueeze(-1).float()
+        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        logits = self.actor(f_actions)[..., 0]
+        logits = logits.masked_fill(mask, float("-inf"))
+        probs = Categorical(logits=logits)
+        if action is None:
+            action = probs.sample()
+        return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
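The commit message says "Add PPO", but the training script itself is not among the visible hunks. For orientation, here is a hedged sketch of the standard clipped-surrogate update that get_action_and_value is presumably driven by; every minibatch tensor name and coefficient below is an assumption:

    # Hedged sketch of a clipped-surrogate PPO step; mb_* tensors, clip_coef,
    # ent_coef, and vf_coef are placeholders, not taken from this commit.
    _, newlogprob, entropy, newvalue, _ = agent.get_action_and_value(mb_obs, mb_actions)
    ratio = (newlogprob - mb_logprobs).exp()

    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

    v_loss = 0.5 * ((newvalue.view(-1) - mb_returns) ** 2).mean()
    loss = pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss
    loss.backward()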
+
+
+class DMCAgent(nn.Module):
+
+    def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
+                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+        super(DMCAgent, self).__init__()
+        self.encoder = Encoder(
+            channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
+
+        c = channels
+        self.value_head = nn.Sequential(
+            nn.Linear(c, c // 4),
+            nn.ReLU(),
+            nn.Linear(c // 4, 1),
+        )
+
+    def load_embeddings(self, embeddings, freeze=True):
+        self.encoder.load_embeddings(embeddings, freeze)
+
+    def forward(self, x):
+        f_actions, mask, valid = self.encoder(x)
         values = self.value_head(f_actions)[..., 0]
         # values = torch.tanh(values)
         values = torch.where(mask, torch.full_like(values, -10), values)
         return values, valid
\ No newline at end of file
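DMCAgent keeps the old per-action value head: it scores every candidate action and pins illegal slots to -10, so play-time selection reduces to an argmax over the row. A hedged illustration; the commit's own rollout code is not shown:

    # Hedged sketch of greedy action selection with DMCAgent; treat as an
    # illustration only, not the script's actual logic.
    with torch.no_grad():
        values, valid = agent(obs)   # values: (batch, max_actions), illegal slots pinned to -10
    actions = values.argmax(dim=1)   # picks the best legal action, assuming legal
                                     # action values stay above -10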
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def reduce_gradidents(model, world_size):
    if world_size == 1:
        return
    all_grads_list = []
    for param in model.parameters():
        if param.grad is not None:
            all_grads_list.append(param.grad.view(-1))
    all_grads = torch.cat(all_grads_list)
    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
    offset = 0
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.copy_(
                all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size
            )
            offset += param.numel()
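reduce_gradidents flattens all gradients into one buffer, runs a single all_reduce(SUM), then copies the averaged slices back, so each rank ends up with the mean gradient at the cost of one collective instead of one per parameter. A hedged sketch of where it would sit in a training step; model, optimizer, loss, and the gradient clipping are placeholders, not from this commit:

    # Hedged usage sketch of reduce_gradidents inside a training step.
    optimizer.zero_grad()
    loss.backward()
    reduce_gradidents(model, world_size)   # average gradients across ranks (one all_reduce)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # optional, an assumption
    optimizer.step()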
def setup(backend, rank, world_size, port):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(port)
    dist.init_process_group(backend, rank=rank, world_size=world_size)
def mp_start(run):
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if world_size == 1:
        run(local_rank=0, world_size=world_size)
    else:
        children = []
        for i in range(world_size):
            subproc = mp.Process(target=run, args=(i, world_size))
            children.append(subproc)
            subproc.start()
        for i in range(world_size):
            children[i].join()
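mp_start either calls run directly in the single-process case or forks WORLD_SIZE worker processes, passing each its local rank. A hedged sketch of a run callable that would plug into it; the backend, port, and training body are assumptions:

    # Hedged sketch of a `run` callable compatible with mp_start; backend, port,
    # and the training body are assumptions, not taken from this commit.
    def run(local_rank, world_size):
        if world_size > 1:
            setup("nccl", rank=local_rank, world_size=world_size, port=29500)
        torch.cuda.set_device(local_rank)
        # ... build envs, agent, optimizer, then the training loop ...
        if world_size > 1:
            dist.destroy_process_group()

    if __name__ == "__main__":
        mp_start(run)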