Commit 1d35fed3 authored by sbl1996@126.com's avatar sbl1996@126.com

Refactor RNN inputs

parent b8929b9c
......@@ -158,7 +158,7 @@ if __name__ == "__main__":
agent1 = create_agent1(args)
rstate = agent1.init_rnn_state(1)
params1 = jax.jit(agent1.init)(agent_key, (rstate, sample_obs))
params1 = jax.jit(agent1.init)(agent_key, sample_obs, rstate)
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params1, f.read())
......@@ -167,7 +167,7 @@ if __name__ == "__main__":
else:
agent2 = create_agent2(args)
rstate = agent2.init_rnn_state(1)
params2 = jax.jit(agent2.init)(agent_key, (rstate, sample_obs))
params2 = jax.jit(agent2.init)(agent_key, sample_obs, rstate)
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params2, f.read())
......@@ -180,7 +180,7 @@ if __name__ == "__main__":
agent = create_agent1(args)
else:
agent = create_agent2(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
next_rstate, logits = agent.apply(params, obs, rstate)[:2]
probs = jax.nn.softmax(logits, axis=-1)
if done is not None:
next_rstate = jnp.where(done[:, None], 0, next_rstate)
......
......@@ -294,15 +294,14 @@ def rollout(
eval_agent = create_agent(args, eval=True)
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
rstate, logits = eval_agent.apply(params, inputs)[:2]
def get_action(params, obs, rstate):
rstate, logits = eval_agent.apply(params, obs, rstate)[:2]
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = agent.apply(params1, (rstate1, obs))[:2]
next_rstate2, logits2 = eval_agent.apply(params2, (rstate2, obs))[:2]
def get_action_battle(params1, params2, obs, rstate1, rstate2, main, done):
next_rstate1, logits1 = agent.apply(params1, obs, rstate1)[:2]
next_rstate2, logits2 = eval_agent.apply(params2, obs, rstate2)[:2]
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
......@@ -314,19 +313,13 @@ def rollout(
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
params, next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done)
main = jnp.array(main)
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = agent.apply(params, (rstate, next_obs))[:2]
rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1)
rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
inputs = next_obs, (rstate1, rstate2), done, main
(rstate1, rstate2), logits = agent.apply(params, *inputs)[:2]
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
......@@ -448,12 +441,12 @@ def rollout(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
(init_rstate1, init_rstate2, (next_obs, next_rstate), next_main))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
predict_fn = lambda *x: get_action(params, *x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
else:
......@@ -619,7 +612,7 @@ if __name__ == "__main__":
# rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
rstate = agent.init_rnn_state(1)
params = agent.init(init_key, (rstate, sample_obs))
params = agent.init(init_key, sample_obs, rstate)
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
......@@ -654,7 +647,7 @@ if __name__ == "__main__":
if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1)
eval_params = eval_agent.init(init_key, (eval_rstate, sample_obs))
eval_params = eval_agent.init(init_key, sample_obs, eval_rstate)
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(eval_params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
......@@ -676,9 +669,8 @@ if __name__ == "__main__":
if args.switch:
dones = dones | next_dones
inputs = (rstate1, rstate2, obs, dones, switch_or_mains)
_rstate, new_logits, new_values, _valid = create_agent(
args).apply(params, inputs)
inputs = obs, (rstate1, rstate2), dones, switch_or_mains
new_logits, new_values = create_agent(args).apply(params, *inputs)[1:3]
new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
......@@ -780,7 +772,7 @@ if __name__ == "__main__":
key, subkey = jax.random.split(key)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
agent_state.params, *next_inputs)[2].squeeze(-1)
if args.switch:
next_value = jnp.where(next_main, -next_value, next_value)
else:
......
......@@ -145,7 +145,7 @@ if __name__ == "__main__":
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
params = jax.jit(agent.init)(agent_key, sample_obs, rstate)
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
......@@ -154,8 +154,7 @@ if __name__ == "__main__":
@jax.jit
def get_probs_and_value(params, rstate, obs, done):
agent = agent
next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
next_rstate, logits, value = agent.apply(params, obs, rstate)[:3]
probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate)
......
......@@ -308,7 +308,21 @@ class Critic(nn.Module):
return x
def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, switch=True):
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
if main is not None:
rstate1, rstate2 = rstate
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, f_state = rnn_layer(rstate, f_state)
if main is not None:
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate = rstate1, rstate2
if done is not None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
if switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
......@@ -318,20 +332,15 @@ def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, s
return (rstate, init_rstate2), y
else:
def body_fn(cell, carry, x, done, main):
rstate1, rstate2 = carry
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return (rstate1, rstate2), y
return rnn_step_by_main(cell, carry, x, done, main)
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
rstate, f_state = scan(rnn_layer, (rstate1, rstate2), f_state, done, switch_or_main)
rstate, f_state = scan(rnn_layer, rstate, f_state, done, switch_or_main)
return rstate, f_state
class RNNAgent(nn.Module):
channels: int = 128
num_layers: int = 2
......@@ -345,14 +354,7 @@ class RNNAgent(nn.Module):
rnn_type: str = 'lstm'
@nn.compact
def __call__(self, inputs):
multi_step = len(inputs) != 2
if multi_step:
# (num_steps * batch_size, ...)
*rstate, x, done, switch_or_main = inputs
else:
rstate, x = inputs
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.channels
encoder = Encoder(
channels=c,
......@@ -380,17 +382,24 @@ class RNNAgent(nn.Module):
elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else:
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1
if done is not None:
assert switch_or_main is not None
else:
assert not multi_step
if multi_step:
rstate1, rstate2 = rstate
batch_size = jax.tree.leaves(rstate1)[0].shape[0]
num_steps = done.shape[0] // batch_size
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate1, rstate2, f_state_r, done, switch_or_main, self.switch)
rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state_r = rnn_layer(rstate, f_state)
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
......
......@@ -11,7 +11,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
if rnn_state is None:
actions = predict_fn(obs)
else:
rnn_state, actions = predict_fn((rnn_state, obs))
rnn_state, actions = predict_fn(obs, rnn_state)
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
......@@ -53,7 +53,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
while True:
main = next_to_play == main_player
rstate1, rstate2, actions = predict_fn(rstate1, rstate2, obs, main, dones)
rstate1, rstate2, actions = predict_fn(obs, rstate1, rstate2, main, dones)
actions = np.array(actions)
obs, rewards, dones, infos = envs.step(actions)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment