Commit af60d012 authored by biluo.shen

Add Actor and Critic

parent 32aa61bf
......@@ -95,10 +95,8 @@ class Args:
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: bool = True
"""whether to use torch.compile to compile the model and functions"""
compile_mode: Optional[str] = None
"""the mode to use for torch.compile"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
......
......@@ -24,6 +24,7 @@ class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
......@@ -319,6 +320,51 @@ class Encoder(nn.Module):
return f_actions, f_state, mask, valid
class PPOCritic(nn.Module):
    """Value head: a two-layer MLP mapping a state feature to one scalar.

    The input feature is expected to have ``2 * channels`` entries on its
    last dimension; the output keeps every leading dimension and has a
    trailing dimension of size 1.
    """

    def __init__(self, channels):
        super().__init__()
        in_features = channels * 2
        hidden = channels // 2
        # Layer order (and therefore the ``net.0`` / ``net.2`` parameter
        # names in the state dict) is kept as Linear -> ReLU -> Linear.
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, f_state):
        """Return the value estimate produced by the MLP for ``f_state``."""
        value = self.net(f_state)
        return value
class PPOActor(nn.Module):
def __init__(self, channels):
super(PPOActor, self).__init__()
c = channels
self.trans = nn.TransformerEncoderLayer(
c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
self.head = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
def forward(self, f_actions, mask, action):
f_actions = self.trans(f_actions, src_key_padding_mask=mask)
logits = self.head(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
return probs.log_prob(action), probs.entropy()
def predict(self, f_actions, mask):
f_actions = self.trans(f_actions, src_key_padding_mask=mask)
logits = self.head(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
......
......@@ -4,17 +4,17 @@ import torch.distributed as dist
import torch.multiprocessing as mp
def reduce_gradidents(model, world_size):
def reduce_gradidents(params, world_size):
if world_size == 1:
return
all_grads_list = []
for param in model.parameters():
for param in params:
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0
for param in model.parameters():
for param in params:
if param.grad is not None:
param.grad.data.copy_(
all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment