Commit 6af7c4b4 authored by sbl1996@126.com

Tensorboard log only on local_rank 0

parent 45885f81
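For readers skimming the diff below: the commit makes only the rank-0 process create a real TensorBoard SummaryWriter, while every other rank gets a no-op stand-in built from types.SimpleNamespace, so calls like writer.add_scalar(...) work unchanged on all ranks. A minimal sketch of that pattern, assuming SummaryWriter comes from torch.utils.tensorboard (the repo may import it from another package) and using a hypothetical make_writer helper that is not part of the diff:

from types import SimpleNamespace
from torch.utils.tensorboard import SummaryWriter

def make_writer(local_rank: int, tb_log_dir: str):
    # Rank 0 creates a real SummaryWriter; other ranks get a stand-in
    # whose add_scalar silently drops its arguments, so call sites need
    # no per-rank checks.
    if local_rank == 0:
        return SummaryWriter(tb_log_dir)
    dummy = SimpleNamespace()
    dummy.add_scalar = lambda tag, value, step: None
    return dummy

# Usage: every process calls add_scalar, but only rank 0 writes events.
writer = make_writer(local_rank=1, tb_log_dir="runs/example")
writer.add_scalar("charts/episode_return", 0.0, 0)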
 import os
+import shutil
 import queue
 import random
 import threading
...
@@ -527,13 +528,18 @@ if __name__ == "__main__":
     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"

+    dummy_writer = SimpleNamespace()
+    dummy_writer.add_scalar = lambda x, y, z: None
+
     tb_log_dir = f"{args.tb_dir}/{run_name}"
-    writer = SummaryWriter(tb_log_dir)
-    writer.add_text(
-        "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
-    )
+    if args.local_rank == 0:
+        writer = SummaryWriter(tb_log_dir)
+        writer.add_text(
+            "hyperparameters",
+            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        )
+    else:
+        writer = dummy_writer

     def save_fn(obj, path):
         with open(path, "wb") as f:
...
@@ -782,8 +788,6 @@ if __name__ == "__main__":
     params_queues = []
     rollout_queues = []
     eval_queue = queue.Queue()
-    dummy_writer = SimpleNamespace()
-    dummy_writer.add_scalar = lambda x, y, z: None

     unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
     for d_idx, d_id in enumerate(args.actor_device_ids):
...
@@ -875,8 +879,11 @@ if __name__ == "__main__":
         ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
         ckpt_maneger.save(unreplicated_params, ckpt_name)
         if args.gcs_bucket is not None:
+            lastest_path = ckpt_maneger.get_latest()
+            copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
+            shutil.copyfile(lastest_path, copy_path)
             zip_file_path = "latest.zip"
-            zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
+            zip_files(zip_file_path, [str(copy_path), tb_log_dir])
             sync_to_gcs(args.gcs_bucket, zip_file_path)
         if learner_policy_version >= args.num_updates:
...
 import os
-import shutil
 from pathlib import Path
 import zipfile
...
@@ -37,17 +36,13 @@ class ModelCheckpoint(object):
         self._saved.append(path)
         print(f"Saved model to {path}")
-        # Copy the lastest checkpoint as latest
-        lastest_path = path.with_name("latest" + path.suffix)
-        shutil.copyfile(path, lastest_path)
         if len(self._saved) > self._n_saved:
             path = self._saved.pop(0)
             os.remove(path)

     def get_latest(self):
         path = self._saved[-1]
-        return str(path.with_name("latest" + path.suffix))
+        return path

 def sync_to_gcs(bucket, source, dest=None):
...
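For context on the second file: get_latest() now returns the newest checkpoint path itself, and the training loop (first file) makes the "latest"-named copy just before zipping and syncing to GCS. A rough, self-contained sketch of that division of labor, with ModelCheckpoint reduced to the relevant parts; the constructor arguments and the write_bytes placeholder for the real flax serialization are assumptions, not the repo's actual code:

import shutil
from pathlib import Path

class ModelCheckpoint:
    # Simplified stand-in: keeps at most n_saved checkpoints on disk and
    # hands back the newest path from get_latest().
    def __init__(self, dirname, n_saved=2):
        self._dirname = Path(dirname)
        self._dirname.mkdir(parents=True, exist_ok=True)
        self._n_saved = n_saved
        self._saved = []

    def save(self, data: bytes, name: str) -> Path:
        path = self._dirname / name
        path.write_bytes(data)  # placeholder for the real flax serialization
        self._saved.append(path)
        if len(self._saved) > self._n_saved:
            self._saved.pop(0).unlink()  # drop the oldest checkpoint
        return path

    def get_latest(self) -> Path:
        # After this commit: the actual newest checkpoint, not a "latest" copy.
        return self._saved[-1]

ckpt = ModelCheckpoint("/tmp/ckpts")
ckpt.save(b"\x00", "100_1M.flax_model")
latest = ckpt.get_latest()
# The caller now makes the "latest" copy before zipping/uploading:
copy_path = latest.with_name("latest" + latest.suffix)
shutil.copyfile(latest, copy_path)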