Commit ec9f3e0c authored by sbl1996@126.com's avatar sbl1996@126.com

support different rnn

parent b7d52f29
...@@ -2,9 +2,10 @@ import sys ...@@ -2,9 +2,10 @@ import sys
import time import time
import os import os
import random import random
from typing import Optional, Literal from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from tqdm import tqdm from tqdm import tqdm
from functools import partial
import ygoenv import ygoenv
import numpy as np import numpy as np
...@@ -17,7 +18,7 @@ import flax ...@@ -17,7 +18,7 @@ import flax
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import PPOLSTMAgent from ygoai.rl.jax.agent2 import RNNAgent
@dataclass @dataclass
...@@ -43,6 +44,10 @@ class Args: ...@@ -43,6 +44,10 @@ class Args:
"""the number of history actions to use""" """the number of history actions to use"""
num_embeddings: Optional[int] = None num_embeddings: Optional[int] = None
"""the number of embeddings of the agent""" """the number of embeddings of the agent"""
use_history1: bool = True
"""whether to use history actions as input for agent1"""
use_history2: bool = True
"""whether to use history actions as input for agent2"""
verbose: bool = False verbose: bool = False
"""whether to print debug information""" """whether to print debug information"""
...@@ -60,6 +65,10 @@ class Args: ...@@ -60,6 +65,10 @@ class Args:
"""the number of channels for the agent""" """the number of channels for the agent"""
rnn_channels: Optional[int] = 512 rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent""" """the number of rnn channels for the agent"""
rnn_type1: Optional[str] = "lstm"
"""the type of RNN to use for agent1, None for no RNN"""
rnn_type2: Optional[str] = "lstm"
"""the type of RNN to use for agent2, None for no RNN"""
checkpoint1: str = "checkpoints/agent.pt" checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, must be a `flax_model` file""" """the checkpoint to load for the first agent, must be a `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt" checkpoint2: str = "checkpoints/agent.pt"
...@@ -72,19 +81,25 @@ class Args: ...@@ -72,19 +81,25 @@ class Args:
"""the number of threads to use for envpool, defaults to `num_envs`""" """the number of threads to use for envpool, defaults to `num_envs`"""
def create_agent(args): def create_agent1(args):
return PPOLSTMAgent( return RNNAgent(
channels=args.num_channels, channels=args.num_channels,
num_layers=args.num_layers, num_layers=args.num_layers,
lstm_channels=args.rnn_channels, rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
use_history=args.use_history1,
rnn_type=args.rnn_type1,
) )
def init_rnn_state(num_envs, rnn_channels): def create_agent2(args):
return ( return RNNAgent(
np.zeros((num_envs, rnn_channels)), channels=args.num_channels,
np.zeros((num_envs, rnn_channels)), num_layers=args.num_layers,
rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
use_history=args.use_history2,
rnn_type=args.rnn_type2,
) )
...@@ -137,28 +152,34 @@ if __name__ == "__main__": ...@@ -137,28 +152,34 @@ if __name__ == "__main__":
envs.num_envs = num_envs envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2) key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels) agent1 = create_agent1(args)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs)) rstate = agent1.init_rnn_state(1)
params1 = jax.jit(agent1.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint1, "rb") as f: with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read()) params1 = flax.serialization.from_bytes(params1, f.read())
if args.checkpoint1 == args.checkpoint2: if args.checkpoint1 == args.checkpoint2:
params2 = params1 params2 = params1
else: else:
agent2 = create_agent2(args)
rstate = agent2.init_rnn_state(1)
params2 = jax.jit(agent2.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint2, "rb") as f: with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read()) params2 = flax.serialization.from_bytes(params2, f.read())
params1 = jax.device_put(params1) params1 = jax.device_put(params1)
params2 = jax.device_put(params2) params2 = jax.device_put(params2)
@jax.jit @partial(jax.jit, static_argnums=(4,))
def get_probs(params, rstate, obs, done=None): def get_probs(params, rstate, obs, done=None, model_id=1):
agent = create_agent(args) if model_id == 1:
agent = create_agent1(args)
else:
agent = create_agent2(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2] next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1) probs = jax.nn.softmax(logits, axis=-1)
if done is not None: if done is not None:
...@@ -168,8 +189,8 @@ if __name__ == "__main__": ...@@ -168,8 +189,8 @@ if __name__ == "__main__":
if args.num_envs != 1: if args.num_envs != 1:
@jax.jit @jax.jit
def get_probs2(params1, params2, rstate1, rstate2, obs, main, done): def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, probs1 = get_probs(params1, rstate1, obs) next_rstate1, probs1 = get_probs(params1, rstate1, obs, None, 1)
next_rstate2, probs2 = get_probs(params2, rstate2, obs) next_rstate2, probs2 = get_probs(params2, rstate2, obs, None, 2)
probs = jnp.where(main[:, None], probs1, probs2) probs = jnp.where(main[:, None], probs1, probs2)
rstate1 = jax.tree.map( rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1) lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
...@@ -185,9 +206,9 @@ if __name__ == "__main__": ...@@ -185,9 +206,9 @@ if __name__ == "__main__":
else: else:
def predict_fn(rstate1, rstate2, obs, main, done): def predict_fn(rstate1, rstate2, obs, main, done):
if main[0]: if main[0]:
rstate1, probs = get_probs(params1, rstate1, obs, done) rstate1, probs = get_probs(params1, rstate1, obs, done, 1)
else: else:
rstate2, probs = get_probs(params2, rstate2, obs, done) rstate2, probs = get_probs(params2, rstate2, obs, done, 2)
return rstate1, rstate2, np.array(probs) return rstate1, rstate2, np.array(probs)
obs, infos = envs.reset() obs, infos = envs.reset()
...@@ -209,7 +230,8 @@ if __name__ == "__main__": ...@@ -209,7 +230,8 @@ if __name__ == "__main__":
np.zeros(num_envs // 2, dtype=np.int64), np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64) np.ones(num_envs - num_envs // 2, dtype=np.int64)
]) ])
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels) rstate1 = agent1.init_rnn_state(num_envs)
rstate2 = agent2.init_rnn_state(num_envs)
if not args.verbose: if not args.verbose:
pbar = tqdm(total=args.num_episodes) pbar = tqdm(total=args.num_episodes)
......
...@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter ...@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import LSTMAgent from ygoai.rl.jax.agent2 import RNNAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
...@@ -80,8 +80,6 @@ class Args: ...@@ -80,8 +80,6 @@ class Args:
"""whether to use history actions as input for agent""" """whether to use history actions as input for agent"""
eval_use_history: bool = True eval_use_history: bool = True
"""whether to use history actions as input for eval agent""" """whether to use history actions as input for eval agent"""
use_rnn: bool = True
"""whether to use RNN for the agent"""
total_timesteps: int = 50000000000 total_timesteps: int = 50000000000
"""total timesteps of the experiments""" """total timesteps of the experiments"""
...@@ -150,6 +148,10 @@ class Args: ...@@ -150,6 +148,10 @@ class Args:
"""the number of channels for the agent""" """the number of channels for the agent"""
rnn_channels: int = 512 rnn_channels: int = 512
"""the number of channels for the RNN in the agent""" """the number of channels for the RNN in the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use, None for no RNN"""
eval_rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for evaluation, None for no RNN"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1]) actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use""" """the device ids that actor workers will use"""
...@@ -222,18 +224,18 @@ class Transition(NamedTuple): ...@@ -222,18 +224,18 @@ class Transition(NamedTuple):
def create_agent(args, multi_step=False, eval=False): def create_agent(args, multi_step=False, eval=False):
return LSTMAgent( return RNNAgent(
channels=args.num_channels, channels=args.num_channels,
num_layers=args.num_layers, num_layers=args.num_layers,
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32, dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32, param_dtype=jnp.float32,
lstm_channels=args.rnn_channels, rnn_channels=args.rnn_channels,
switch=args.switch, switch=args.switch,
multi_step=multi_step, multi_step=multi_step,
freeze_id=args.freeze_id, freeze_id=args.freeze_id,
use_history=args.use_history if not eval else args.eval_use_history, use_history=args.use_history if not eval else args.eval_use_history,
no_rnn=(not args.use_rnn) if not eval else False rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
) )
...@@ -284,23 +286,20 @@ def rollout( ...@@ -284,23 +286,20 @@ def rollout(
other_time = 0 other_time = 0
avg_ep_returns = deque(maxlen=1000) avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000) avg_win_rates = deque(maxlen=1000)
@partial(jax.jit, static_argnums=(2,)) agent = create_agent(args)
def get_logits( eval_agent = create_agent(args, eval=True)
params: flax.core.FrozenDict, inputs, eval=False):
rstate, logits = create_agent(args, eval=eval).apply(params, inputs)[:2]
return rstate, logits
@jax.jit @jax.jit
def get_action( def get_action(
params: flax.core.FrozenDict, inputs): params: flax.core.FrozenDict, inputs):
rstate, logits = get_logits(params, inputs) rstate, logits = eval_agent.apply(params, inputs)[:2]
return rstate, logits.argmax(axis=1) return rstate, logits.argmax(axis=1)
@jax.jit @jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done): def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs)) next_rstate1, logits1 = agent.apply(params1, (rstate1, obs))[:2]
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), True) next_rstate2, logits2 = eval_agent.apply(params2, (rstate2, obs))[:2]
logits = jnp.where(main[:, None], logits1, logits2) logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map( rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1) lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
...@@ -320,7 +319,7 @@ def rollout( ...@@ -320,7 +319,7 @@ def rollout(
rstate = jax.tree.map( rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2) lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs)) rstate, logits = agent.apply(params, (rstate, next_obs))[:2]
rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1) rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1)
rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2) rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2)
rstate1, rstate2 = jax.tree.map( rstate1, rstate2 = jax.tree.map(
...@@ -335,10 +334,11 @@ def rollout( ...@@ -335,10 +334,11 @@ def rollout(
next_obs, info = envs.reset() next_obs, info = envs.reset()
next_to_play = info["to_play"] next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_) next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state( next_rstate1 = next_rstate2 = agent.init_rnn_state(args.local_num_envs)
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state( eval_rstate1 = agent.init_rnn_state(args.local_eval_episodes)
args.local_eval_episodes, args.rnn_channels) eval_rstate2 = eval_agent.init_rnn_state(args.local_eval_episodes)
main_player = np.concatenate([ main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64), np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
...@@ -452,11 +452,11 @@ def rollout( ...@@ -452,11 +452,11 @@ def rollout(
if eval_mode == 'bot': if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x) predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate( eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate) eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
else: else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x) predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle( eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate) eval_envs, args.local_eval_episodes, predict_fn, eval_rstate1, eval_rstate2)
eval_time = time.time() - _start eval_time = time.time() - _start
other_time += eval_time other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32) eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
...@@ -606,8 +606,9 @@ if __name__ == "__main__": ...@@ -606,8 +606,9 @@ if __name__ == "__main__":
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels) # rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
rstate = agent.init_rnn_state(1)
params = agent.init(init_key, (rstate, sample_obs)) params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None: if embeddings is not None:
unknown_embed = embeddings.mean(axis=0) unknown_embed = embeddings.mean(axis=0)
...@@ -641,20 +642,15 @@ if __name__ == "__main__": ...@@ -641,20 +642,15 @@ if __name__ == "__main__":
# print(agent.tabulate(agent_key, sample_obs)) # print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint: if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1)
eval_params = eval_agent.init(init_key, (eval_rstate, sample_obs))
with open(args.eval_checkpoint, "rb") as f: with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read()) eval_params = flax.serialization.from_bytes(eval_params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}") print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else: else:
eval_params = None eval_params = None
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
def loss_fn( def loss_fn(
params, rstate1, rstate2, obs, dones, next_dones, params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, mask, next_value): switch_or_mains, actions, logits, rewards, mask, next_value):
...@@ -671,7 +667,9 @@ if __name__ == "__main__": ...@@ -671,7 +667,9 @@ if __name__ == "__main__":
dones = dones | next_dones dones = dones | next_dones
inputs = (rstate1, rstate2, obs, dones, switch_or_mains) inputs = (rstate1, rstate2, obs, dones, switch_or_mains)
new_logits, new_values = get_logits_and_value(params, inputs) _rstate, new_logits, new_values, _valid = create_agent(
args, multi_step=True).apply(params, inputs)
new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical( ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions) new_logits), distrax.Categorical(logits), actions)
......
...@@ -63,6 +63,8 @@ class Args: ...@@ -63,6 +63,8 @@ class Args:
"""the number of channels for the agent""" """the number of channels for the agent"""
rnn_channels: Optional[int] = 512 rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent""" """the number of rnn channels for the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for agent, None for no RNN"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the checkpoint to load, must be a `flax_model` file""" """the checkpoint to load, must be a `flax_model` file"""
...@@ -75,18 +77,12 @@ class Args: ...@@ -75,18 +77,12 @@ class Args:
def create_agent(args): def create_agent(args):
return PPOLSTMAgent( return RNNAgent(
channels=args.num_channels, channels=args.num_channels,
num_layers=args.num_layers, num_layers=args.num_layers,
lstm_channels=args.rnn_channels, rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
) rnn_type=args.rnn_type,
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
) )
...@@ -139,7 +135,7 @@ if __name__ == "__main__": ...@@ -139,7 +135,7 @@ if __name__ == "__main__":
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax import flax
from ygoai.rl.jax.agent2 import PPOLSTMAgent from ygoai.rl.jax.agent2 import RNNAgent
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax")) cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
...@@ -148,7 +144,7 @@ if __name__ == "__main__": ...@@ -148,7 +144,7 @@ if __name__ == "__main__":
key, agent_key = jax.random.split(key, 2) key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels) rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs)) params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint, "rb") as f: with open(args.checkpoint, "rb") as f:
...@@ -158,7 +154,7 @@ if __name__ == "__main__": ...@@ -158,7 +154,7 @@ if __name__ == "__main__":
@jax.jit @jax.jit
def get_probs_and_value(params, rstate, obs, done): def get_probs_and_value(params, rstate, obs, done):
agent = create_agent(args) agent = agent
next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3] next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
probs = jax.nn.softmax(logits, axis=-1) probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map( next_rstate = jax.tree.map(
...@@ -173,6 +169,7 @@ if __name__ == "__main__": ...@@ -173,6 +169,7 @@ if __name__ == "__main__":
obs, infos = envs.reset() obs, infos = envs.reset()
print(obs)
next_to_play = infos['to_play'] next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_) dones = np.zeros(num_envs, dtype=np.bool_)
......
from typing import Tuple, Union, Optional, Sequence from typing import Tuple, Union, Optional, Sequence
from functools import partial from functools import partial
import numpy as np
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax.linen as nn import flax.linen as nn
...@@ -272,8 +273,6 @@ class Encoder(nn.Module): ...@@ -272,8 +273,6 @@ class Encoder(nn.Module):
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1) f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state) f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state) f_state = layer_norm(dtype=self.dtype)(f_state)
# TODO: LSTM
return f_actions, f_state, a_mask, valid return f_actions, f_state, a_mask, valid
...@@ -309,10 +308,34 @@ class Critic(nn.Module): ...@@ -309,10 +308,34 @@ class Critic(nn.Module):
return x return x
class LSTMAgent(nn.Module): def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, switch=True):
if switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y
else:
def body_fn(cell, carry, x, done, main):
rstate1, rstate2 = carry
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return (rstate1, rstate2), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
rstate, f_state = scan(rnn_layer, (rstate1, rstate2), f_state, done, switch_or_main)
return rstate, f_state
class RNNAgent(nn.Module):
channels: int = 128 channels: int = 128
num_layers: int = 2 num_layers: int = 2
lstm_channels: int = 512 rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
...@@ -320,15 +343,13 @@ class LSTMAgent(nn.Module): ...@@ -320,15 +343,13 @@ class LSTMAgent(nn.Module):
switch: bool = True switch: bool = True
freeze_id: bool = False freeze_id: bool = False
use_history: bool = True use_history: bool = True
no_rnn: bool = False rnn_type: str = 'lstm'
@nn.compact @nn.compact
def __call__(self, inputs): def __call__(self, inputs):
if self.multi_step: if self.multi_step:
# (num_steps * batch_size, ...) # (num_steps * batch_size, ...)
rstate1, rstate2, x, done, switch_or_main = inputs *rstate, x, done, switch_or_main = inputs
batch_size = rstate1[0].shape[0]
num_steps = done.shape[0] // batch_size
else: else:
rstate, x = inputs rstate, x = inputs
...@@ -345,43 +366,48 @@ class LSTMAgent(nn.Module): ...@@ -345,43 +366,48 @@ class LSTMAgent(nn.Module):
f_actions, f_state, mask, valid = encoder(x) f_actions, f_state, mask, valid = encoder(x)
lstm_layer = nn.OptimizedLSTMCell( if self.rnn_type in ['lstm', 'none']:
self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0)) rnn_layer = nn.OptimizedLSTMCell(
if self.multi_step: self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
if self.switch: elif self.rnn_type == 'gru':
def body_fn(cell, carry, x, done, switch): rnn_layer = nn.GRUCell(
rstate, init_rstate2 = carry self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
rstate, y = cell(rstate, x) elif self.rnn_type is None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate) rnn_layer = None
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y if rnn_layer is None:
else: f_state_r = f_state
def body_fn(cell, carry, x, done, main): elif self.rnn_type == 'none':
rstate1, rstate2 = carry f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return (rstate1, rstate2), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = scan(lstm_layer, (rstate1, rstate2), f_state_r, done, switch_or_main)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else: else:
rstate, f_state_r = lstm_layer(rstate, f_state) if self.multi_step:
rstate1, rstate2 = rstate
batch_size = jax.tree.leaves(rstate1)[0].shape[0]
num_steps = done.shape[0] // batch_size
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate1, rstate2, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state_r = rnn_layer(rstate, f_state)
actor = Actor( actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype) channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic( critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype) channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
if self.no_rnn:
f_state_r = jnp.concatenate([f_state for i in range(self.lstm_channels // c)], axis=-1)
logits = actor(f_state_r, f_actions, mask) logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r) value = critic(f_state_r)
return rstate, logits, value, valid return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']:
return (
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels)),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
else:
return None
\ No newline at end of file
...@@ -36,7 +36,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None): ...@@ -36,7 +36,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
return eval_return, eval_ep_len, eval_win_rate return eval_return, eval_ep_len, eval_win_rate
def battle(envs, num_episodes, predict_fn, init_rnn_state=None): def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
num_envs = envs.num_envs num_envs = envs.num_envs
episode_rewards = [] episode_rewards = []
episode_lengths = [] episode_lengths = []
...@@ -50,7 +50,6 @@ def battle(envs, num_episodes, predict_fn, init_rnn_state=None): ...@@ -50,7 +50,6 @@ def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
np.zeros(num_envs // 2, dtype=np.int64), np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64) np.ones(num_envs - num_envs // 2, dtype=np.int64)
]) ])
rstate1 = rstate2 = init_rnn_state
while True: while True:
main = next_to_play == main_player main = next_to_play == main_player
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment