Commit 7f80e9a8 authored by Biluo Shen's avatar Biluo Shen

GCS save checkpoints and tensorboard logs

parent 718bbe9b
......@@ -23,7 +23,7 @@ from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
......@@ -46,7 +46,10 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
checkpoint_dir: str = "checkpoints"
tb_dir: str = "runs"
"""the directory to save the tensorboard logs"""
ckpt_dir: str = "checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket: Optional[str] = None
"""the GCS bucket to save the model checkpoints"""
......@@ -524,7 +527,8 @@ if __name__ == "__main__":
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
tb_log_dir = f"{args.tb_dir}/{run_name}"
writer = SummaryWriter(tb_log_dir)
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
......@@ -536,7 +540,7 @@ if __name__ == "__main__":
f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint(
args.checkpoint_dir, save_fn, n_saved=3, gcs_bucket=args.gcs_bucket)
args.ckpt_dir, save_fn, n_saved=3)
# seeding
random.seed(args.seed)
......@@ -613,7 +617,7 @@ if __name__ == "__main__":
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
def ppo_loss(
def loss_fn(
params, rstate1, rstate2, obs, dones, next_dones,
switch, actions, logits, rewards, mask, next_value):
# (num_steps * local_num_envs // n_mb))
......@@ -701,7 +705,7 @@ if __name__ == "__main__":
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
......@@ -733,7 +737,7 @@ if __name__ == "__main__":
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
......@@ -870,6 +874,10 @@ if __name__ == "__main__":
M_steps = args.batch_size * learner_policy_version // (2**20)
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None:
zip_file_path = "latest.zip"
zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
sync_to_gcs(args.gcs_bucket, zip_file_path)
if learner_policy_version >= args.num_updates:
break
......
import os
import shutil
from pathlib import Path
import zipfile
class ModelCheckpoint(object):
......@@ -12,17 +14,12 @@ class ModelCheckpoint(object):
Function that will be called to save the object. It should have the signature `save_fn(obj, path)`.
n_saved (int, optional):
Number of objects that should be kept on disk. Older files will be removed.
gcs_bucket (str, optional):
If provided, will sync the saved model to the specified GCS bucket.
"""
def __init__(self, dirname, save_fn, n_saved=1, gcs_bucket=None):
def __init__(self, dirname, save_fn, n_saved=1):
self._dirname = Path(dirname).expanduser()
self._n_saved = n_saved
self._save_fn = save_fn
if gcs_bucket.startswith("gs://"):
gcs_bucket = gcs_bucket[5:]
self._gcs_bucket = gcs_bucket
self._saved = []
def _check_dir(self):
......@@ -33,20 +30,55 @@ class ModelCheckpoint(object):
raise ValueError(
"Directory path '{}' is not found".format(self._dirname))
def save(self, obj, name, sync_gcs=True):
def save(self, obj, name):
self._check_dir()
path = self._dirname / name
self._save_fn(obj, str(path))
self._saved.append(path)
print(f"Saved model to {path}")
if self._gcs_bucket is not None and sync_gcs:
fname = "latest" + path.suffix
gcs_url = Path(self._gcs_bucket) / fname
gcs_url = f"gs://{gcs_url}"
os.system(f"gsutil cp {path} {gcs_url} >> gcs_sync.log 2>&1 &")
print("Sync to GCS: ", gcs_url)
# Copy the lastest checkpoint as latest
lastest_path = path.with_name("latest" + path.suffix)
shutil.copyfile(path, lastest_path)
if len(self._saved) > self._n_saved:
path = self._saved.pop(0)
os.remove(path)
def get_latest(self):
path = self._saved[-1]
return str(path.with_name("latest" + path.suffix))
def sync_to_gcs(bucket, source, dest=None):
if bucket.startswith("gs://"):
bucket = bucket[5:]
if dest is None:
dest = Path(source).name
gcs_url = Path(bucket) / dest
gcs_url = f"gs://{gcs_url}"
os.system(f"gsutil cp {source} {gcs_url} > /dev/null 2>&1 &")
print(f"Sync to GCS: {gcs_url}")
def zip_files(zip_file_path, files_to_zip):
"""
Creates a zip file at the specified path, containing the files and directories
specified in files_to_zip.
Args:
zip_file_path (str): The path to the zip file to be created.
files_to_zip (list): A list of paths to files and directories to be zipped.
"""
with zipfile.ZipFile(zip_file_path, mode='w') as zip_file:
for file_path in files_to_zip:
# Check if the path is a file or a directory
if os.path.isfile(file_path):
# If it's a file, add it to the zip file
zip_file.write(file_path)
elif os.path.isdir(file_path):
# If it's a directory, add all its files and subdirectories to the zip file
for root, dirs, files in os.walk(file_path):
for file in files:
file_path = os.path.join(root, file)
zip_file.write(file_path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment