Commit 6752ba72 authored by sbl1996@126.com

Add card_mask and film to agent

parent 1d35fed3
......@@ -3,7 +3,7 @@ import time
import os
import random
from typing import Optional
from dataclasses import dataclass
from dataclasses import dataclass, field, asdict
from tqdm import tqdm
from functools import partial
......@@ -18,7 +18,7 @@ import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import RNNAgent
from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
@dataclass
......@@ -44,10 +44,6 @@ class Args:
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
use_history1: bool = True
"""whether to use history actions as input for agent1"""
use_history2: bool = True
"""whether to use history actions as input for agent2"""
verbose: bool = False
"""whether to print debug information"""
......@@ -59,16 +55,11 @@ class Args:
num_envs: int = 64
"""the number of parallel game environments"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
rnn_type1: Optional[str] = "lstm"
"""the type of RNN to use for agent1, None for no RNN"""
rnn_type2: Optional[str] = "lstm"
"""the type of RNN to use for agent2, None for no RNN"""
m1: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the model arguments for agent1"""
m2: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the model arguments for agent2"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, must be a `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
......@@ -83,23 +74,15 @@ class Args:
def create_agent1(args):
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
rnn_channels=args.rnn_channels,
**asdict(args.m1),
embedding_shape=args.num_embeddings,
use_history=args.use_history1,
rnn_type=args.rnn_type1,
)
def create_agent2(args):
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
rnn_channels=args.rnn_channels,
**asdict(args.m2),
embedding_shape=args.num_embeddings,
use_history=args.use_history2,
rnn_type=args.rnn_type2,
)
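For context, a hedged sketch of how the refactored config is consumed: each agent gets its own ModelArgs, which is flattened with asdict and splatted into RNNAgent's keyword arguments. The concrete field values below are illustrative, not defaults from this commit.

# Illustrative only: two differently configured agents from one Args.
args = Args(
    m1=ModelArgs(num_layers=2, num_channels=128, rnn_type="lstm"),
    m2=ModelArgs(num_layers=4, num_channels=256, film=True, noam=True),
)
agent1 = create_agent1(args)  # RNNAgent(**asdict(args.m1), embedding_shape=...)
agent2 = create_agent2(args)  # RNNAgent(**asdict(args.m2), embedding_shape=...)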
......
......@@ -6,7 +6,7 @@ import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from dataclasses import dataclass, field, asdict
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
......@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import RNNAgent
from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
......@@ -80,10 +80,6 @@ class Args:
"""the number of history actions to use"""
greedy_reward: bool = False
"""whether to use greedy reward (faster kill higher reward)"""
use_history: bool = True
"""whether to use history actions as input for agent"""
eval_use_history: bool = True
"""whether to use history actions as input for eval agent"""
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
......@@ -146,16 +142,10 @@ class Args:
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use, None for no RNN"""
eval_rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for evaluation, None for no RNN"""
m1: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the model arguments for the agent"""
m2: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the model arguments for the eval agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
......@@ -228,17 +218,21 @@ class Transition(NamedTuple):
def create_agent(args, eval=False):
if eval:
return RNNAgent(
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
**asdict(args.m2),
)
else:
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
rnn_channels=args.rnn_channels,
switch=args.switch,
freeze_id=args.freeze_id,
use_history=args.use_history if not eval else args.eval_use_history,
rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
**asdict(args.m1),
)
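Usage note: with this split, the training agent takes its settings from args.m1, while the eval opponent takes them from args.m2, so the two architectures can differ, e.g.:

agent = create_agent(args)                  # training agent, from args.m1
eval_agent = create_agent(args, eval=True)  # eval opponent, from args.m2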
......
from typing import Tuple, Union, Optional, Sequence
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial
import numpy as np
......@@ -6,7 +7,7 @@ import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
......@@ -15,6 +16,13 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
    # Note: despite the "cls" in the name, this returns a constructed module,
    # ready to be applied: a LlamaEncoderLayer with RoPE disabled when noam is
    # set, otherwise the default pre-LN EncoderLayer.
    if noam:
        return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
    else:
        return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
......@@ -69,6 +77,8 @@ class CardEncoder(nn.Module):
x_id = layer_norm()(x_id)
x_loc = x1[:, :, 0]
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x1[:, :, 1]
......@@ -97,7 +107,7 @@ class CardEncoder(nn.Module):
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
return f_cards
return f_cards, c_mask
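A toy illustration (hypothetical values) of the new padding mask: x_loc == 0 marks empty card slots, and slot 0 is force-unmasked, presumably so no attention row ends up with every key masked.

import jax.numpy as jnp

x_loc = jnp.array([[2, 4, 0, 0],
                   [0, 0, 0, 0]])      # 0 = empty/padding slot (toy data)
c_mask = x_loc == 0                    # True = ignore this card in attention
c_mask = c_mask.at[:, 0].set(False)    # always keep at least one valid slot
# [[False False  True  True]
#  [False  True  True  True]]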
class GlobalEncoder(nn.Module):
......@@ -153,6 +163,8 @@ class Encoder(nn.Module):
param_dtype: jnp.dtype = jnp.float32
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
noam: bool = False
@nn.compact
def __call__(self, x):
......@@ -188,7 +200,7 @@ class Encoder(nn.Module):
x_id = jax.lax.stop_gradient(x_id)
# Cards
f_cards = CardEncoder(
f_cards, c_mask = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
g_card_embed = self.param(
'g_card_embed',
......@@ -196,10 +208,16 @@ class Encoder(nn.Module):
(1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
if self.card_mask:
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
else:
c_mask = None
num_heads = max(2, c // 128)
for _ in range(self.num_layers):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
f_cards = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_cards, src_key_padding_mask=c_mask)
f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0]
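When card_mask is off, c_mask stays None and behavior is unchanged; when on, the global-embedding slot is prepended as unmasked and the mask is passed as src_key_padding_mask. For intuition, a self-contained sketch of how such a key-padding mask is typically folded into attention scores (this is not the repo's MultiheadAttention, whose internals are outside this diff):

import jax.numpy as jnp

def apply_key_padding_mask(scores, key_padding_mask):
    # scores: (batch, heads, q_len, k_len); mask: (batch, k_len), True = padding
    big_neg = jnp.finfo(scores.dtype).min
    return jnp.where(key_padding_mask[:, None, None, :], big_neg, scores)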
......@@ -294,6 +312,32 @@ class Actor(nn.Module):
return logits
class FiLMActor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
noam: bool = False
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
num_heads = max(2, c // 128)
f_actions = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, mask, a_s, a_b, o_s, o_b)
logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
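FiLM (feature-wise linear modulation) conditions the action features on the recurrent state: a single Dense projection of f_state yields four per-channel (scale, bias) pairs, a_s/a_b for the attention branch and o_s/o_b for the MLP branch of the encoder layer. The modulation itself is just the following, shown as a minimal sketch:

import jax.numpy as jnp

def film(h, scale, bias):
    # h: (batch, seq, c); scale/bias: (batch, 1, c), broadcast along seq
    return h * scale + bias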
class Critic(nn.Module):
channels: Sequence[int] = (128, 128, 128)
dtype: Optional[jnp.dtype] = None
......@@ -340,10 +384,29 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
return rstate, f_state
@dataclass
class ModelArgs:
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
use_history: bool = True
"""whether to use history actions as input for the agent"""
card_mask: bool = False
"""whether to mask padding cards so the transformer ignores them"""
rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
"""the type of RNN to use; None (or 'none') for no RNN"""
film: bool = False
"""whether to use FiLM conditioning in the actor"""
noam: bool = False
"""whether to use the Noam (Llama-style) transformer layer"""
class RNNAgent(nn.Module):
channels: int = 128
num_layers: int = 2
num_channels: int = 128
rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
......@@ -351,11 +414,14 @@ class RNNAgent(nn.Module):
switch: bool = True
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.channels
c = self.num_channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
......@@ -364,6 +430,8 @@ class RNNAgent(nn.Module):
param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
use_history=self.use_history,
card_mask=self.card_mask,
noam=self.noam,
)
f_actions, f_state, mask, valid = encoder(x)
......@@ -401,6 +469,10 @@ class RNNAgent(nn.Module):
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
else:
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
......
......@@ -596,7 +596,6 @@ class GLUMlpBlock(nn.Module):
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
shard=self.shard,
) for _ in range(3)
]
......@@ -631,7 +630,10 @@ class EncoderLayer(nn.Module):
deterministic: bool = True
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
def __call__(
self, inputs, src_key_padding_mask=None,
attn_scale=None, attn_bias=None,
output_scale=None, output_bias=None):
inputs = jnp.asarray(inputs, self.dtype)
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(inputs)
......@@ -648,6 +650,11 @@ class EncoderLayer(nn.Module):
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
if attn_scale is not None:
x = x * attn_scale
if attn_bias is not None:
x = x + attn_bias
x = x + inputs
y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
......@@ -662,7 +669,13 @@ class EncoderLayer(nn.Module):
name="mlp")(y)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
if output_scale is not None:
y = y * output_scale
if output_bias is not None:
y = y + output_bias
y = x + y
return y
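A minimal init/apply sketch of the extended signature, assuming the layer's remaining fields keep their defaults; the conditioning shapes mirror FiLMActor above:

import jax
import jax.numpy as jnp

layer = EncoderLayer(2, dtype=jnp.float32)   # 2 attention heads
x = jnp.zeros((4, 8, 128))                   # (batch, seq, channels)
scale = jnp.ones((4, 1, 128))                # broadcast over the seq axis
bias = jnp.zeros((4, 1, 128))
variables = layer.init(jax.random.PRNGKey(0), x, None, scale, bias, scale, bias)
y = layer.apply(variables, x, None, scale, bias, scale, bias)  # (4, 8, 128)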
......@@ -733,8 +746,9 @@ class DecoderLayer(nn.Module):
class LlamaEncoderLayer(nn.Module):
n_heads: int
intermediate_size: int
intermediate_size: Optional[int] = None
n_positions: int = 512
rope: bool = True
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
......@@ -745,11 +759,17 @@ class LlamaEncoderLayer(nn.Module):
deterministic: bool = True
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
def __call__(
self, inputs, src_key_padding_mask=None,
attn_scale=None, attn_bias=None,
output_scale=None, output_bias=None):
features = inputs.shape[-1]
intermediate_size = self.intermediate_size or 2 * features
x = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention(
features=x.shape[-1],
features=features,
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
......@@ -757,19 +777,24 @@ class LlamaEncoderLayer(nn.Module):
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
rope=self.rope,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
if attn_scale is not None:
x = x * attn_scale
if attn_bias is not None:
x = x + attn_bias
x = x + inputs
y = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_2")(x)
y = GLUMlpBlock(
intermediate_size=self.intermediate_size,
intermediate_size=intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
......@@ -777,6 +802,12 @@ class LlamaEncoderLayer(nn.Module):
name="mlp")(y)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
if output_scale is not None:
y = y * output_scale
if output_bias is not None:
y = y + output_bias
y = x + y
return y
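With intermediate_size now optional and RoPE switchable, the Noam layer can be built from the head count alone, which is exactly how get_encoder_layer_cls constructs it. A hedged sketch under those defaults:

import jax
import jax.numpy as jnp

# rope=False: card tokens are an unordered set; GLU width defaults to 2 * features.
layer = LlamaEncoderLayer(2, rope=False, dtype=jnp.float32)
x = jnp.zeros((4, 8, 128))
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)   # (4, 8, 128)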
......@@ -785,6 +816,7 @@ class LlamaDecoderLayer(nn.Module):
n_heads: int
intermediate_size: int
n_positions: int = 512
rope: bool = True
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
......@@ -808,7 +840,7 @@ class LlamaDecoderLayer(nn.Module):
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
rope=self.rope,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
......@@ -827,7 +859,7 @@ class LlamaDecoderLayer(nn.Module):
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
rope=self.rope,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="cross_attn")(y, memory, memory, key_padding_mask=memory_key_padding_mask)
......