Commit 1d35fed3 authored by sbl1996@126.com's avatar sbl1996@126.com

Refactor RNN inputs

parent b8929b9c
...@@ -158,7 +158,7 @@ if __name__ == "__main__": ...@@ -158,7 +158,7 @@ if __name__ == "__main__":
agent1 = create_agent1(args) agent1 = create_agent1(args)
rstate = agent1.init_rnn_state(1) rstate = agent1.init_rnn_state(1)
params1 = jax.jit(agent1.init)(agent_key, (rstate, sample_obs)) params1 = jax.jit(agent1.init)(agent_key, sample_obs, rstate)
with open(args.checkpoint1, "rb") as f: with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params1, f.read()) params1 = flax.serialization.from_bytes(params1, f.read())
...@@ -167,7 +167,7 @@ if __name__ == "__main__": ...@@ -167,7 +167,7 @@ if __name__ == "__main__":
else: else:
agent2 = create_agent2(args) agent2 = create_agent2(args)
rstate = agent2.init_rnn_state(1) rstate = agent2.init_rnn_state(1)
params2 = jax.jit(agent2.init)(agent_key, (rstate, sample_obs)) params2 = jax.jit(agent2.init)(agent_key, sample_obs, rstate)
with open(args.checkpoint2, "rb") as f: with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params2, f.read()) params2 = flax.serialization.from_bytes(params2, f.read())
...@@ -180,7 +180,7 @@ if __name__ == "__main__": ...@@ -180,7 +180,7 @@ if __name__ == "__main__":
agent = create_agent1(args) agent = create_agent1(args)
else: else:
agent = create_agent2(args) agent = create_agent2(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2] next_rstate, logits = agent.apply(params, obs, rstate)[:2]
probs = jax.nn.softmax(logits, axis=-1) probs = jax.nn.softmax(logits, axis=-1)
if done is not None: if done is not None:
next_rstate = jnp.where(done[:, None], 0, next_rstate) next_rstate = jnp.where(done[:, None], 0, next_rstate)
......
...@@ -294,15 +294,14 @@ def rollout( ...@@ -294,15 +294,14 @@ def rollout(
eval_agent = create_agent(args, eval=True) eval_agent = create_agent(args, eval=True)
@jax.jit @jax.jit
def get_action( def get_action(params, obs, rstate):
params: flax.core.FrozenDict, inputs): rstate, logits = eval_agent.apply(params, obs, rstate)[:2]
rstate, logits = eval_agent.apply(params, inputs)[:2]
return rstate, logits.argmax(axis=1) return rstate, logits.argmax(axis=1)
@jax.jit @jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done): def get_action_battle(params1, params2, obs, rstate1, rstate2, main, done):
next_rstate1, logits1 = agent.apply(params1, (rstate1, obs))[:2] next_rstate1, logits1 = agent.apply(params1, obs, rstate1)[:2]
next_rstate2, logits2 = eval_agent.apply(params2, (rstate2, obs))[:2] next_rstate2, logits2 = eval_agent.apply(params2, obs, rstate2)[:2]
logits = jnp.where(main[:, None], logits1, logits2) logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map( rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1) lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
...@@ -314,19 +313,13 @@ def rollout( ...@@ -314,19 +313,13 @@ def rollout(
@jax.jit @jax.jit
def sample_action( def sample_action(
params: flax.core.FrozenDict, params, next_obs, rstate1, rstate2, main, done, key):
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs) next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done) done = jnp.array(done)
main = jnp.array(main) main = jnp.array(main)
rstate = jax.tree.map( inputs = next_obs, (rstate1, rstate2), done, main
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2) (rstate1, rstate2), logits = agent.apply(params, *inputs)[:2]
rstate, logits = agent.apply(params, (rstate, next_obs))[:2]
rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1)
rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
action, key = categorical_sample(logits, key) action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key return next_obs, done, main, rstate1, rstate2, action, logits, key
...@@ -448,12 +441,12 @@ def rollout( ...@@ -448,12 +441,12 @@ def rollout(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2) lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded( sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices), np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main)) (init_rstate1, init_rstate2, (next_obs, next_rstate), next_main))
if args.eval_interval and update % args.eval_interval == 0: if args.eval_interval and update % args.eval_interval == 0:
_start = time.time() _start = time.time()
if eval_mode == 'bot': if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x) predict_fn = lambda *x: get_action(params, *x)
eval_return, eval_ep_len, eval_win_rate = evaluate( eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2) eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
else: else:
...@@ -619,7 +612,7 @@ if __name__ == "__main__": ...@@ -619,7 +612,7 @@ if __name__ == "__main__":
# rstate = init_rnn_state(1, args.rnn_channels) # rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
rstate = agent.init_rnn_state(1) rstate = agent.init_rnn_state(1)
params = agent.init(init_key, (rstate, sample_obs)) params = agent.init(init_key, sample_obs, rstate)
if embeddings is not None: if embeddings is not None:
unknown_embed = embeddings.mean(axis=0) unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0) embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
...@@ -654,7 +647,7 @@ if __name__ == "__main__": ...@@ -654,7 +647,7 @@ if __name__ == "__main__":
if args.eval_checkpoint: if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True) eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1) eval_rstate = eval_agent.init_rnn_state(1)
eval_params = eval_agent.init(init_key, (eval_rstate, sample_obs)) eval_params = eval_agent.init(init_key, sample_obs, eval_rstate)
with open(args.eval_checkpoint, "rb") as f: with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(eval_params, f.read()) eval_params = flax.serialization.from_bytes(eval_params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}") print(f"loaded eval checkpoint from {args.eval_checkpoint}")
...@@ -676,9 +669,8 @@ if __name__ == "__main__": ...@@ -676,9 +669,8 @@ if __name__ == "__main__":
if args.switch: if args.switch:
dones = dones | next_dones dones = dones | next_dones
inputs = (rstate1, rstate2, obs, dones, switch_or_mains) inputs = obs, (rstate1, rstate2), dones, switch_or_mains
_rstate, new_logits, new_values, _valid = create_agent( new_logits, new_values = create_agent(args).apply(params, *inputs)[1:3]
args).apply(params, inputs)
new_values = new_values.squeeze(-1) new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical( ratios = distrax.importance_sampling_ratios(distrax.Categorical(
...@@ -780,7 +772,7 @@ if __name__ == "__main__": ...@@ -780,7 +772,7 @@ if __name__ == "__main__":
key, subkey = jax.random.split(key) key, subkey = jax.random.split(key)
next_value = create_agent(args).apply( next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1) agent_state.params, *next_inputs)[2].squeeze(-1)
if args.switch: if args.switch:
next_value = jnp.where(next_main, -next_value, next_value) next_value = jnp.where(next_main, -next_value, next_value)
else: else:
......
...@@ -145,7 +145,7 @@ if __name__ == "__main__": ...@@ -145,7 +145,7 @@ if __name__ == "__main__":
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = agent.init_rnn_state(1) rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs)) params = jax.jit(agent.init)(agent_key, sample_obs, rstate)
with open(args.checkpoint, "rb") as f: with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read()) params = flax.serialization.from_bytes(params, f.read())
...@@ -154,8 +154,7 @@ if __name__ == "__main__": ...@@ -154,8 +154,7 @@ if __name__ == "__main__":
@jax.jit @jax.jit
def get_probs_and_value(params, rstate, obs, done): def get_probs_and_value(params, rstate, obs, done):
agent = agent next_rstate, logits, value = agent.apply(params, obs, rstate)[:3]
next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
probs = jax.nn.softmax(logits, axis=-1) probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map( next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate) lambda x: jnp.where(done[:, None], 0, x), next_rstate)
......
...@@ -308,7 +308,21 @@ class Critic(nn.Module): ...@@ -308,7 +308,21 @@ class Critic(nn.Module):
return x return x
def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, switch=True): def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
if main is not None:
rstate1, rstate2 = rstate
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, f_state = rnn_layer(rstate, f_state)
if main is not None:
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate = rstate1, rstate2
if done is not None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
if switch: if switch:
def body_fn(cell, carry, x, done, switch): def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry rstate, init_rstate2 = carry
...@@ -318,20 +332,15 @@ def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, s ...@@ -318,20 +332,15 @@ def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, s
return (rstate, init_rstate2), y return (rstate, init_rstate2), y
else: else:
def body_fn(cell, carry, x, done, main): def body_fn(cell, carry, x, done, main):
rstate1, rstate2 = carry return rnn_step_by_main(cell, carry, x, done, main)
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return (rstate1, rstate2), y
scan = nn.scan( scan = nn.scan(
body_fn, variable_broadcast='params', body_fn, variable_broadcast='params',
split_rngs={'params': False}) split_rngs={'params': False})
rstate, f_state = scan(rnn_layer, (rstate1, rstate2), f_state, done, switch_or_main) rstate, f_state = scan(rnn_layer, rstate, f_state, done, switch_or_main)
return rstate, f_state return rstate, f_state
class RNNAgent(nn.Module): class RNNAgent(nn.Module):
channels: int = 128 channels: int = 128
num_layers: int = 2 num_layers: int = 2
...@@ -345,14 +354,7 @@ class RNNAgent(nn.Module): ...@@ -345,14 +354,7 @@ class RNNAgent(nn.Module):
rnn_type: str = 'lstm' rnn_type: str = 'lstm'
@nn.compact @nn.compact
def __call__(self, inputs): def __call__(self, x, rstate, done=None, switch_or_main=None):
multi_step = len(inputs) != 2
if multi_step:
# (num_steps * batch_size, ...)
*rstate, x, done, switch_or_main = inputs
else:
rstate, x = inputs
c = self.channels c = self.channels
encoder = Encoder( encoder = Encoder(
channels=c, channels=c,
...@@ -380,17 +382,24 @@ class RNNAgent(nn.Module): ...@@ -380,17 +382,24 @@ class RNNAgent(nn.Module):
elif self.rnn_type == 'none': elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1) f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else: else:
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1
if done is not None:
assert switch_or_main is not None
else:
assert not multi_step
if multi_step: if multi_step:
rstate1, rstate2 = rstate
batch_size = jax.tree.leaves(rstate1)[0].shape[0]
num_steps = done.shape[0] // batch_size
f_state_r, done, switch_or_main = jax.tree.map( f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main)) lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p( rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate1, rstate2, f_state_r, done, switch_or_main, self.switch) rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1])) f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else: else:
rstate, f_state_r = rnn_layer(rstate, f_state) rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
actor = Actor( actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype) channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
......
...@@ -11,7 +11,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None): ...@@ -11,7 +11,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
if rnn_state is None: if rnn_state is None:
actions = predict_fn(obs) actions = predict_fn(obs)
else: else:
rnn_state, actions = predict_fn((rnn_state, obs)) rnn_state, actions = predict_fn(obs, rnn_state)
actions = np.array(actions) actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions) obs, rewards, dones, info = envs.step(actions)
...@@ -53,7 +53,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): ...@@ -53,7 +53,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
while True: while True:
main = next_to_play == main_player main = next_to_play == main_player
rstate1, rstate2, actions = predict_fn(rstate1, rstate2, obs, main, dones) rstate1, rstate2, actions = predict_fn(obs, rstate1, rstate2, main, dones)
actions = np.array(actions) actions = np.array(actions)
obs, rewards, dones, infos = envs.step(actions) obs, rewards, dones, infos = envs.step(actions)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment