Commit c02fbd19 authored by sbl1996@126.com

Add option to use history

parent 0e9969c5
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
-from ygoai.rl.jax.agent2 import PPOLSTMAgent
+from ygoai.rl.jax.agent2 import LSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
@@ -76,6 +76,10 @@ class Args:
     """the number of history actions to use"""
     greedy_reward: bool = False
     """whether to use greedy reward (faster kill higher reward)"""
+    use_history: bool = True
+    """whether to use history actions as input for agent"""
+    eval_use_history: bool = True
+    """whether to use history actions as input for eval agent"""
     total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
@@ -210,8 +214,8 @@ class Transition(NamedTuple):
     next_dones: list
 
 
-def create_agent(args, multi_step=False):
-    return PPOLSTMAgent(
+def create_agent(args, multi_step=False, eval=False):
+    return LSTMAgent(
         channels=args.num_channels,
         num_layers=args.num_layers,
         embedding_shape=args.num_embeddings,
@@ -221,6 +225,7 @@ def create_agent(args, multi_step=False):
         switch=args.switch,
         multi_step=multi_step,
         freeze_id=args.freeze_id,
+        use_history=args.use_history if not eval else args.eval_use_history,
     )
@@ -272,10 +277,10 @@ def rollout(
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
 
-    @jax.jit
+    @partial(jax.jit, static_argnums=(2,))
     def get_logits(
-            params: flax.core.FrozenDict, inputs):
-        rstate, logits = create_agent(args).apply(params, inputs)[:2]
+            params: flax.core.FrozenDict, inputs, eval=False):
+        rstate, logits = create_agent(args, eval=eval).apply(params, inputs)[:2]
         return rstate, logits
 
     @jax.jit
@@ -287,7 +292,7 @@ def rollout(
     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
         next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), True)
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
......
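For context on the static_argnums change above: eval ends up selecting module configuration, so it has to be a compile-time constant rather than a traced value. A minimal, self-contained sketch of the pattern, unrelated to the project's actual modules:

import jax
import jax.numpy as jnp
from functools import partial

# Marking argument 1 as static makes jax.jit specialize (retrace and
# recompile) once per distinct Python value of `flag`.
@partial(jax.jit, static_argnums=(1,))
def forward(x, flag=False):
    # Plain Python control flow on a static argument is allowed.
    return x * 0.0 if flag else x

x = jnp.ones(3)
print(forward(x))        # compiled with flag=False
print(forward(x, True))  # separate compilation with flag=True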
@@ -151,6 +151,7 @@ class Encoder(nn.Module):
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
     freeze_id: bool = False
+    use_history: bool = True
 
     @nn.compact
     def __call__(self, x):
@@ -266,6 +267,8 @@ class Encoder(nn.Module):
         f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
 
         # State
+        if not self.use_history:
+            f_g_h_actions = jnp.zeros_like(f_g_h_actions)
         f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
         f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         f_state = layer_norm(dtype=self.dtype)(f_state)
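The added branch zeroes the history-action features rather than dropping them from the concatenation, presumably so the state width and all downstream parameter shapes stay identical whether or not history is used, keeping checkpoints shape-compatible across both settings. A toy illustration with made-up feature sizes (the real tensors come from the encoder's card/global/action branches):

import jax.numpy as jnp

f_g_card = jnp.ones((4, 128))
f_global = jnp.ones((4, 128))
f_g_h_actions = jnp.ones((4, 128))   # history-action features
f_g_actions = jnp.ones((4, 128))     # current legal-action features

use_history = False
if not use_history:
    # Zero instead of remove: the concatenated width is unchanged.
    f_g_h_actions = jnp.zeros_like(f_g_h_actions)

f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
print(f_state.shape)  # (4, 512) either way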
@@ -306,35 +309,7 @@ class Critic(nn.Module):
         return x
 
 
-class PPOAgent(nn.Module):
-    channels: int = 128
-    num_layers: int = 2
-    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
-    dtype: jnp.dtype = jnp.float32
-    param_dtype: jnp.dtype = jnp.float32
-
-    @nn.compact
-    def __call__(self, x):
-        c = self.channels
-        encoder = Encoder(
-            channels=c,
-            num_layers=self.num_layers,
-            embedding_shape=self.embedding_shape,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-        )
-        actor = Actor(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
-        critic = Critic(
-            channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
-
-        f_actions, f_state, mask, valid = encoder(x)
-        logits = actor(f_state, f_actions, mask)
-        value = critic(f_state)
-        return logits, value, valid
-
-
-class PPOLSTMAgent(nn.Module):
+class LSTMAgent(nn.Module):
     channels: int = 128
     num_layers: int = 2
     lstm_channels: int = 512
@@ -344,6 +319,7 @@ class PPOLSTMAgent(nn.Module):
     multi_step: bool = False
     switch: bool = True
     freeze_id: bool = False
+    use_history: bool = True
 
     @nn.compact
     def __call__(self, inputs):
@@ -363,6 +339,7 @@ class PPOLSTMAgent(nn.Module):
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             freeze_id=self.freeze_id,
+            use_history=self.use_history,
         )
 
         f_actions, f_state, mask, valid = encoder(x)
......