Commit 89c209f2 authored by Biluo Shen's avatar Biluo Shen

Add ModelCheckpoint

parent 45506246
...@@ -23,6 +23,7 @@ from rich.pretty import pprint ...@@ -23,6 +23,7 @@ from rich.pretty import pprint
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint
from ygoai.rl.jax.agent2 import PPOLSTMAgent from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle from ygoai.rl.jax.eval import evaluate, battle
...@@ -45,6 +46,10 @@ class Args: ...@@ -45,6 +46,10 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)""" """the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the path to the model checkpoint to load""" """the path to the model checkpoint to load"""
checkpoint_dir: str = "checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket: Optional[str] = None
"""the GCS bucket to save the model checkpoints"""
# Algorithm specific arguments # Algorithm specific arguments
env_id: str = "YGOPro-v0" env_id: str = "YGOPro-v0"
...@@ -525,6 +530,14 @@ if __name__ == "__main__": ...@@ -525,6 +530,14 @@ if __name__ == "__main__":
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
) )
def save_fn(obj, path):
with open(path, "wb") as f:
f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint(
args.checkpoint_dir, save_fn, n_saved=3, gcs_bucket=args.gcs_bucket)
# seeding # seeding
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -854,15 +867,9 @@ if __name__ == "__main__": ...@@ -854,15 +867,9 @@ if __name__ == "__main__":
writer.add_scalar("losses/loss", loss, global_step) writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0: if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20) M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model") ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
with open(model_path, "wb") as f: ckpt_maneger.save(unreplicated_params, ckpt_name)
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates: if learner_policy_version >= args.num_updates:
break break
......
import os
from pathlib import Path
class ModelCheckpoint(object):
""" ModelCheckpoint handler can be used to periodically save objects to disk.
Args:
dirname (str):
Directory path where objects will be saved.
save_fn (callable):
Function that will be called to save the object. It should have the signature `save_fn(obj, path)`.
n_saved (int, optional):
Number of objects that should be kept on disk. Older files will be removed.
gcs_bucket (str, optional):
If provided, will sync the saved model to the specified GCS bucket.
"""
def __init__(self, dirname, save_fn, n_saved=1, gcs_bucket=None):
self._dirname = Path(dirname).expanduser()
self._n_saved = n_saved
self._save_fn = save_fn
if gcs_bucket.startswith("gs://"):
gcs_bucket = gcs_bucket[5:]
self._gcs_bucket = gcs_bucket
self._saved = []
def _check_dir(self):
self._dirname.mkdir(parents=True, exist_ok=True)
# Ensure that dirname exists
if not self._dirname.exists():
raise ValueError(
"Directory path '{}' is not found".format(self._dirname))
def save(self, obj, name, sync_gcs=True):
self._check_dir()
path = self._dirname / name
self._save_fn(obj, str(path))
self._saved.append(path)
print(f"Saved model to {path}")
if self._gcs_bucket is not None and sync_gcs:
fname = "latest" + path.suffix
gcs_url = Path(self._gcs_bucket) / fname
gcs_url = f"gs://{gcs_url}"
os.system(f"gsutil cp {path} {gcs_url} >> gcs_sync.log 2>&1 &")
print("Sync to GCS: ", gcs_url)
if len(self._saved) > self._n_saved:
path = self._saved.pop(0)
os.remove(path)
...@@ -32,8 +32,8 @@ def truncated_gae_2p0s( ...@@ -32,8 +32,8 @@ def truncated_gae_2p0s(
_, (advantages, returns) = jax.lax.scan( _, (advantages, returns) = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True body_fn, carry, (next_dones, values, rewards, switch), reverse=True
) )
targets = values + advantages
if upgo: if upgo:
advantages += returns - values advantages += returns - values
targets = values + advantages
targets = jax.lax.stop_gradient(targets) targets = jax.lax.stop_gradient(targets)
return targets, advantages return targets, advantages
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment