Commit 4a0590bd authored by sbl1996@126.com

Rename entropy_loss to ent_loss

parent 04e61b91
......@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
......@@ -285,7 +285,7 @@ def rollout(
avg_win_rates = deque(maxlen=1000)
agent = create_agent(args)
eval_agent = create_agent(args, eval=True)
eval_agent = create_agent(args, eval=eval_mode != 'bot')
@jax.jit
def get_action(params, obs, rstate):
......@@ -492,7 +492,7 @@ def rollout(
writer.add_scalar("charts/SPS_update", SPS_update, tb_global_step)
if __name__ == "__main__":
def main():
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
......@@ -796,13 +796,13 @@ if __name__ == "__main__":
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
(loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
......@@ -819,17 +819,17 @@ if __name__ == "__main__":
shuffled_next_value,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
(agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
ent_loss = jax.lax.pmean(ent_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
return agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, key
all_reduce_value = jax.pmap(
lambda x: jax.lax.pmean(x, axis_name="main_devices"),
......@@ -872,7 +872,6 @@ if __name__ == "__main__":
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
......@@ -905,7 +904,7 @@ if __name__ == "__main__":
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
(agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
......@@ -943,7 +942,7 @@ if __name__ == "__main__":
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/entropy", ent_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), tb_global_step)
writer.add_scalar("losses/loss", loss, tb_global_step)
......@@ -966,3 +965,7 @@ if __name__ == "__main__":
jax.distributed.shutdown()
writer.close()
if __name__ == "__main__":
main()
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment