Commit 6af7c4b4 authored by sbl1996@126.com

Tensorboard log only on local_rank 0

parent 45885f81
import os
import shutil
import queue
import random
import threading
......@@ -527,13 +528,18 @@ if __name__ == "__main__":
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
tb_log_dir = f"{args.tb_dir}/{run_name}"
writer = SummaryWriter(tb_log_dir)
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
tb_log_dir = f"{args.tb_dir}/{run_name}"
if args.local_rank == 0:
writer = SummaryWriter(tb_log_dir)
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
else:
writer = dummy_writer
def save_fn(obj, path):
with open(path, "wb") as f:
......@@ -782,8 +788,6 @@ if __name__ == "__main__":
params_queues = []
rollout_queues = []
eval_queue = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
......@@ -875,8 +879,11 @@ if __name__ == "__main__":
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None:
lastest_path = ckpt_maneger.get_latest()
copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
shutil.copyfile(lastest_path, copy_path)
zip_file_path = "latest.zip"
zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
zip_files(zip_file_path, [str(copy_path), tb_log_dir])
sync_to_gcs(args.gcs_bucket, zip_file_path)
if learner_policy_version >= args.num_updates:
......
import os
import shutil
from pathlib import Path
import zipfile
......@@ -37,17 +36,13 @@ class ModelCheckpoint(object):
self._saved.append(path)
print(f"Saved model to {path}")
# Copy the most recent checkpoint to a file named "latest"
lastest_path = path.with_name("latest" + path.suffix)
shutil.copyfile(path, lastest_path)
if len(self._saved) > self._n_saved:
path = self._saved.pop(0)
os.remove(path)
def get_latest(self):
path = self._saved[-1]
return str(path.with_name("latest" + path.suffix))
return path
def sync_to_gcs(bucket, source, dest=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment