Commit c02fbd19 authored by sbl1996@126.com

Add option to use history

parent 0e9969c5
......@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.agent2 import LSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
......@@ -76,6 +76,10 @@ class Args:
"""the number of history actions to use"""
greedy_reward: bool = False
"""whether to use greedy reward (faster kill higher reward)"""
use_history: bool = True
"""whether to use history actions as input for agent"""
eval_use_history: bool = True
"""whether to use history actions as input for eval agent"""
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
......@@ -210,8 +214,8 @@ class Transition(NamedTuple):
next_dones: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
def create_agent(args, multi_step=False, eval=False):
return LSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
......@@ -221,6 +225,7 @@ def create_agent(args, multi_step=False):
switch=args.switch,
multi_step=multi_step,
freeze_id=args.freeze_id,
use_history=args.use_history if not eval else args.eval_use_history,
)
......@@ -272,10 +277,10 @@ def rollout(
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
@partial(jax.jit, static_argnums=(2,))
def get_logits(
params: flax.core.FrozenDict, inputs):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
params: flax.core.FrozenDict, inputs, eval=False):
rstate, logits = create_agent(args, eval=eval).apply(params, inputs)[:2]
return rstate, logits
@jax.jit
......@@ -287,7 +292,7 @@ def rollout(
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), True)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
......
......@@ -151,6 +151,7 @@ class Encoder(nn.Module):
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
freeze_id: bool = False
use_history: bool = True
@nn.compact
def __call__(self, x):
......@@ -266,6 +267,8 @@ class Encoder(nn.Module):
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
# State
if not self.use_history:
f_g_h_actions = jnp.zeros_like(f_g_h_actions)
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
......@@ -306,35 +309,7 @@ class Critic(nn.Module):
return x
class PPOAgent(nn.Module):
    """Actor-critic agent without recurrence, built on a shared encoder.

    A single ``Encoder`` turns the observation into per-action features,
    a state feature, an action mask and a ``valid`` output. An ``Actor``
    head maps these to action logits and a ``Critic`` head maps the state
    feature to a value.
    """

    # Channel width handed to the encoder, the actor and every critic layer.
    channels: int = 128
    # Forwarded unchanged to the encoder as ``num_layers``.
    num_layers: int = 2
    # Forwarded unchanged to the encoder as ``embedding_shape``.
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    # Computation dtype used by the encoder and the critic.
    dtype: jnp.dtype = jnp.float32
    # Parameter dtype used by all three submodules.
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Run the encoder and both heads on ``x``.

        Args:
            x: Observation passed straight through to the encoder.

        Returns:
            A ``(logits, value, valid)`` tuple: the actor output, the
            critic output, and the ``valid`` value produced by the encoder.
        """
        width = self.channels

        # Submodules are instantiated in the order encoder, actor, critic.
        encode = Encoder(
            channels=width,
            num_layers=self.num_layers,
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        # The actor is pinned to float32 regardless of ``self.dtype``
        # (presumably to keep logits in full precision -- TODO confirm).
        policy_head = Actor(
            channels=width,
            dtype=jnp.float32,
            param_dtype=self.param_dtype,
        )
        value_head = Critic(
            channels=[width] * 3,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

        action_feats, state_feat, action_mask, valid = encode(x)
        logits = policy_head(state_feat, action_feats, action_mask)
        value = value_head(state_feat)
        return logits, value, valid
class PPOLSTMAgent(nn.Module):
class LSTMAgent(nn.Module):
channels: int = 128
num_layers: int = 2
lstm_channels: int = 512
......@@ -344,6 +319,7 @@ class PPOLSTMAgent(nn.Module):
multi_step: bool = False
switch: bool = True
freeze_id: bool = False
use_history: bool = True
@nn.compact
def __call__(self, inputs):
......@@ -363,6 +339,7 @@ class PPOLSTMAgent(nn.Module):
dtype=self.dtype,
param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
use_history=self.use_history,
)
f_actions, f_state, mask, valid = encoder(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment