Commit b7d52f29 authored by sbl1996@126.com

Add option for no_rnn

parent a1e6193c
......@@ -80,6 +80,8 @@ class Args:
"""whether to use history actions as input for agent"""
eval_use_history: bool = True
"""whether to use history actions as input for eval agent"""
use_rnn: bool = True
"""whether to use RNN for the agent"""
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
......@@ -231,6 +233,7 @@ def create_agent(args, multi_step=False, eval=False):
multi_step=multi_step,
freeze_id=args.freeze_id,
use_history=args.use_history if not eval else args.eval_use_history,
no_rnn=(not args.use_rnn) if not eval else False
)
......@@ -318,8 +321,8 @@ def rollout(
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs))
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1)
rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
action, key = categorical_sample(logits, key)
......
......@@ -320,6 +320,7 @@ class LSTMAgent(nn.Module):
switch: bool = True
freeze_id: bool = False
use_history: bool = True
no_rnn: bool = False
@nn.compact
def __call__(self, inputs):
......@@ -366,18 +367,21 @@ class LSTMAgent(nn.Module):
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
f_state, done, switch_or_main = jax.tree.map(
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state = scan(lstm_layer, (rstate1, rstate2), f_state, done, switch_or_main)
f_state = f_state.reshape((-1, f_state.shape[-1]))
rstate, f_state_r = scan(lstm_layer, (rstate1, rstate2), f_state_r, done, switch_or_main)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state = lstm_layer(rstate, f_state)
rstate, f_state_r = lstm_layer(rstate, f_state)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
if self.no_rnn:
f_state_r = jnp.concatenate([f_state for i in range(self.lstm_channels // c)], axis=-1)
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r)
return rstate, logits, value, valid
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment