Commit 632f551d authored by sbl1996@126.com's avatar sbl1996@126.com

Shuffle even epochs=1

parent bcf43ac9
...@@ -1003,9 +1003,8 @@ def main(): ...@@ -1003,9 +1003,8 @@ def main():
key, subkey = jax.random.split(key) key, subkey = jax.random.split(key)
def convert_data(x: jnp.ndarray, multi_step=True): def convert_data(x: jnp.ndarray, multi_step=True):
key = subkey if args.update_epochs > 1 else None
return reshape_minibatch( return reshape_minibatch(
x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key) x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=subkey)
b_init_rstate, b_next_data = \ b_init_rstate, b_next_data = \
jax.tree.map(partial(convert_data, multi_step=False), jax.tree.map(partial(convert_data, multi_step=False),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment