Commit ec9f3e0c authored by sbl1996@126.com's avatar sbl1996@126.com

support different rnn

parent b7d52f29
......@@ -2,9 +2,10 @@ import sys
import time
import os
import random
from typing import Optional, Literal
from typing import Optional
from dataclasses import dataclass
from tqdm import tqdm
from functools import partial
import ygoenv
import numpy as np
......@@ -17,7 +18,7 @@ import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.agent2 import RNNAgent
@dataclass
......@@ -43,6 +44,10 @@ class Args:
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
use_history1: bool = True
"""whether to use history actions as input for agent1"""
use_history2: bool = True
"""whether to use history actions as input for agent2"""
verbose: bool = False
"""whether to print debug information"""
......@@ -60,6 +65,10 @@ class Args:
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
rnn_type1: Optional[str] = "lstm"
"""the type of RNN to use for agent1, None for no RNN"""
rnn_type2: Optional[str] = "lstm"
"""the type of RNN to use for agent2, None for no RNN"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, must be a `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
......@@ -72,19 +81,25 @@ class Args:
"""the number of threads to use for envpool, defaults to `num_envs`"""
def create_agent(args):
return PPOLSTMAgent(
def create_agent1(args):
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
lstm_channels=args.rnn_channels,
rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
use_history=args.use_history1,
rnn_type=args.rnn_type1,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
def create_agent2(args):
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
use_history=args.use_history2,
rnn_type=args.rnn_type2,
)
......@@ -137,28 +152,34 @@ if __name__ == "__main__":
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
agent1 = create_agent1(args)
rstate = agent1.init_rnn_state(1)
params1 = jax.jit(agent1.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read())
params1 = flax.serialization.from_bytes(params1, f.read())
if args.checkpoint1 == args.checkpoint2:
params2 = params1
else:
agent2 = create_agent2(args)
rstate = agent2.init_rnn_state(1)
params2 = jax.jit(agent2.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
params2 = flax.serialization.from_bytes(params2, f.read())
params1 = jax.device_put(params1)
params2 = jax.device_put(params2)
@jax.jit
def get_probs(params, rstate, obs, done=None):
agent = create_agent(args)
@partial(jax.jit, static_argnums=(4,))
def get_probs(params, rstate, obs, done=None, model_id=1):
if model_id == 1:
agent = create_agent1(args)
else:
agent = create_agent2(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1)
if done is not None:
......@@ -168,8 +189,8 @@ if __name__ == "__main__":
if args.num_envs != 1:
@jax.jit
def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, probs1 = get_probs(params1, rstate1, obs)
next_rstate2, probs2 = get_probs(params2, rstate2, obs)
next_rstate1, probs1 = get_probs(params1, rstate1, obs, None, 1)
next_rstate2, probs2 = get_probs(params2, rstate2, obs, None, 2)
probs = jnp.where(main[:, None], probs1, probs2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
......@@ -185,9 +206,9 @@ if __name__ == "__main__":
else:
def predict_fn(rstate1, rstate2, obs, main, done):
if main[0]:
rstate1, probs = get_probs(params1, rstate1, obs, done)
rstate1, probs = get_probs(params1, rstate1, obs, done, 1)
else:
rstate2, probs = get_probs(params2, rstate2, obs, done)
rstate2, probs = get_probs(params2, rstate2, obs, done, 2)
return rstate1, rstate2, np.array(probs)
obs, infos = envs.reset()
......@@ -209,7 +230,8 @@ if __name__ == "__main__":
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
rstate1 = agent1.init_rnn_state(num_envs)
rstate2 = agent2.init_rnn_state(num_envs)
if not args.verbose:
pbar = tqdm(total=args.num_episodes)
......
......@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import LSTMAgent
from ygoai.rl.jax.agent2 import RNNAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
......@@ -80,8 +80,6 @@ class Args:
"""whether to use history actions as input for agent"""
eval_use_history: bool = True
"""whether to use history actions as input for eval agent"""
use_rnn: bool = True
"""whether to use RNN for the agent"""
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
......@@ -150,6 +148,10 @@ class Args:
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use, None for no RNN"""
eval_rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for evaluation, None for no RNN"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
......@@ -222,18 +224,18 @@ class Transition(NamedTuple):
def create_agent(args, multi_step=False, eval=False):
return LSTMAgent(
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
rnn_channels=args.rnn_channels,
switch=args.switch,
multi_step=multi_step,
freeze_id=args.freeze_id,
use_history=args.use_history if not eval else args.eval_use_history,
no_rnn=(not args.use_rnn) if not eval else False
rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
)
......@@ -285,22 +287,19 @@ def rollout(
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@partial(jax.jit, static_argnums=(2,))
def get_logits(
params: flax.core.FrozenDict, inputs, eval=False):
rstate, logits = create_agent(args, eval=eval).apply(params, inputs)[:2]
return rstate, logits
agent = create_agent(args)
eval_agent = create_agent(args, eval=True)
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
rstate, logits = get_logits(params, inputs)
rstate, logits = eval_agent.apply(params, inputs)[:2]
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), True)
next_rstate1, logits1 = agent.apply(params1, (rstate1, obs))[:2]
next_rstate2, logits2 = eval_agent.apply(params2, (rstate2, obs))[:2]
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
......@@ -320,7 +319,7 @@ def rollout(
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs))
rstate, logits = agent.apply(params, (rstate, next_obs))[:2]
rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1)
rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(
......@@ -335,10 +334,11 @@ def rollout(
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
next_rstate1 = next_rstate2 = agent.init_rnn_state(args.local_num_envs)
eval_rstate1 = agent.init_rnn_state(args.local_eval_episodes)
eval_rstate2 = eval_agent.init_rnn_state(args.local_eval_episodes)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
......@@ -452,11 +452,11 @@ def rollout(
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate1, eval_rstate2)
eval_time = time.time() - _start
other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
......@@ -606,8 +606,9 @@ if __name__ == "__main__":
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
# rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
rstate = agent.init_rnn_state(1)
params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
......@@ -641,20 +642,15 @@ if __name__ == "__main__":
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1)
eval_params = eval_agent.init(init_key, (eval_rstate, sample_obs))
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
eval_params = flax.serialization.from_bytes(eval_params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
def loss_fn(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, mask, next_value):
......@@ -671,7 +667,9 @@ if __name__ == "__main__":
dones = dones | next_dones
inputs = (rstate1, rstate2, obs, dones, switch_or_mains)
new_logits, new_values = get_logits_and_value(params, inputs)
_rstate, new_logits, new_values, _valid = create_agent(
args, multi_step=True).apply(params, inputs)
new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
......
......@@ -63,6 +63,8 @@ class Args:
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for agent, None for no RNN"""
checkpoint: Optional[str] = None
"""the checkpoint to load, must be a `flax_model` file"""
......@@ -75,18 +77,12 @@ class Args:
def create_agent(args):
return PPOLSTMAgent(
return RNNAgent(
channels=args.num_channels,
num_layers=args.num_layers,
lstm_channels=args.rnn_channels,
rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
rnn_type=args.rnn_type,
)
......@@ -139,7 +135,7 @@ if __name__ == "__main__":
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.agent2 import RNNAgent
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
......@@ -148,7 +144,7 @@ if __name__ == "__main__":
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels)
rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint, "rb") as f:
......@@ -158,7 +154,7 @@ if __name__ == "__main__":
@jax.jit
def get_probs_and_value(params, rstate, obs, done):
agent = create_agent(args)
agent = agent
next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map(
......@@ -173,6 +169,7 @@ if __name__ == "__main__":
obs, infos = envs.reset()
print(obs)
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
......
from typing import Tuple, Union, Optional, Sequence
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
......@@ -272,8 +273,6 @@ class Encoder(nn.Module):
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
# TODO: LSTM
return f_actions, f_state, a_mask, valid
......@@ -309,10 +308,34 @@ class Critic(nn.Module):
return x
class LSTMAgent(nn.Module):
def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, switch=True):
if switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y
else:
def body_fn(cell, carry, x, done, main):
rstate1, rstate2 = carry
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return (rstate1, rstate2), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
rstate, f_state = scan(rnn_layer, (rstate1, rstate2), f_state, done, switch_or_main)
return rstate, f_state
class RNNAgent(nn.Module):
channels: int = 128
num_layers: int = 2
lstm_channels: int = 512
rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
......@@ -320,15 +343,13 @@ class LSTMAgent(nn.Module):
switch: bool = True
freeze_id: bool = False
use_history: bool = True
no_rnn: bool = False
rnn_type: str = 'lstm'
@nn.compact
def __call__(self, inputs):
if self.multi_step:
# (num_steps * batch_size, ...)
rstate1, rstate2, x, done, switch_or_main = inputs
batch_size = rstate1[0].shape[0]
num_steps = done.shape[0] // batch_size
*rstate, x, done, switch_or_main = inputs
else:
rstate, x = inputs
......@@ -345,43 +366,48 @@ class LSTMAgent(nn.Module):
f_actions, f_state, mask, valid = encoder(x)
lstm_layer = nn.OptimizedLSTMCell(
self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
if self.multi_step:
if self.switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y
if self.rnn_type in ['lstm', 'none']:
rnn_layer = nn.OptimizedLSTMCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type == 'gru':
rnn_layer = nn.GRUCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type is None:
rnn_layer = None
if rnn_layer is None:
f_state_r = f_state
elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else:
def body_fn(cell, carry, x, done, main):
rstate1, rstate2 = carry
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return (rstate1, rstate2), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
if self.multi_step:
rstate1, rstate2 = rstate
batch_size = jax.tree.leaves(rstate1)[0].shape[0]
num_steps = done.shape[0] // batch_size
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = scan(lstm_layer, (rstate1, rstate2), f_state_r, done, switch_or_main)
rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate1, rstate2, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state_r = lstm_layer(rstate, f_state)
rstate, f_state_r = rnn_layer(rstate, f_state)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
if self.no_rnn:
f_state_r = jnp.concatenate([f_state for i in range(self.lstm_channels // c)], axis=-1)
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r)
return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']:
return (
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels)),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
else:
return None
\ No newline at end of file
......@@ -36,7 +36,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
return eval_return, eval_ep_len, eval_win_rate
def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
num_envs = envs.num_envs
episode_rewards = []
episode_lengths = []
......@@ -50,7 +50,6 @@ def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
rstate1 = rstate2 = init_rnn_state
while True:
main = next_to_play == main_player
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment