Commit 43ca871e authored by sbl1996@126.com

Refactor switch

parent 93bc3723
@@ -4,6 +4,7 @@ import os
 import random
 from typing import Optional, Literal
 from dataclasses import dataclass
+from tqdm import tqdm
 import ygoenv
 import numpy as np
@@ -220,6 +221,9 @@ if __name__ == "__main__":
     ])
     rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
+    if not args.verbose:
+        pbar = tqdm(total=args.num_episodes)
     model_time = env_time = 0
     while True:
         if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
@@ -255,7 +259,11 @@
             episode_rewards.append(episode_reward)
             win_rates.append(win)
             win_reasons.append(1 if win_reason == 1 else 0)
-            sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+            if args.verbose:
+                print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}")
+            else:
+                pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
+                pbar.update(1)
             # Only when num_envs == 1 do we switch the player here
             if args.verbose:
@@ -264,6 +272,8 @@ if __name__ == "__main__":
         if len(episode_lengths) >= args.num_episodes:
             break
+    if not args.verbose:
+        pbar.close()
     print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
     total_time = time.time() - start
......
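The non-verbose path added above follows the usual tqdm pattern: create a bar with a known total, update it once per finished episode, and surface running means via set_postfix. A minimal standalone sketch of that pattern (the episode source here is a stand-in, not from this commit):

from tqdm import tqdm
import numpy as np

episode_lengths = []
pbar = tqdm(total=100)
for _ in range(100):
    episode_lengths.append(np.random.randint(10, 50))  # stand-in for one finished episode
    pbar.set_postfix(len=np.mean(episode_lengths))
    pbar.update(1)
pbar.close()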
@@ -150,6 +150,7 @@ class Encoder(nn.Module):
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
+    freeze_id: bool = False
 
     @nn.compact
     def __call__(self, x):
@@ -168,6 +169,8 @@ class Encoder(nn.Module):
         fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
 
         id_embed = embed(n_embed, embed_dim)
+        if self.freeze_id:
+            id_embed = lambda x, _m=id_embed: jax.lax.stop_gradient(_m(x))  # bind via default arg; a bare lambda would capture the rebound name and recurse
 
         action_encoder = ActionEncoder(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
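The freeze_id path relies on jax.lax.stop_gradient to keep the (typically pretrained) card-ID embedding table fixed while the rest of the network trains. A self-contained sketch of the effect, independent of this repo's modules:

import jax
import jax.numpy as jnp

table = jnp.ones((4, 8))  # stand-in for the id embedding table

def loss(table, ids, freeze):
    emb = table[ids]
    if freeze:
        emb = jax.lax.stop_gradient(emb)  # no gradient flows back to `table`
    return jnp.sum(emb ** 2)

ids = jnp.array([0, 2])
g_train = jax.grad(loss)(table, ids, False)
g_frozen = jax.grad(loss)(table, ids, True)
print(jnp.abs(g_train).sum() > 0)    # True: rows 0 and 2 receive gradient
print(jnp.abs(g_frozen).sum() == 0)  # True: the table is effectively frozen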
@@ -337,6 +340,7 @@ class PPOLSTMAgent(nn.Module):
     param_dtype: jnp.dtype = jnp.float32
     multi_step: bool = False
     switch: bool = True
+    freeze_id: bool = False
 
     @nn.compact
     def __call__(self, inputs):
@@ -355,6 +359,7 @@ class PPOLSTMAgent(nn.Module):
             embedding_shape=self.embedding_shape,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
+            freeze_id=self.freeze_id,
         )
         f_actions, f_state, mask, valid = encoder(x)
......
import jax
import jax.numpy as jnp


def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo
):
    def body_fn(carry, inp):
        boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry
        next_done, cur_value, reward, switch = inp

        # At a switch step, control passes to the other player of the zero-sum
        # game: bootstrap from the negated boundary value and reset the GAE
        # and UPGO accumulators.
        next_done = jnp.where(switch, boot_done, next_done)
        next_value = jnp.where(switch, -boot_value, next_value)
        lastgaelam = jnp.where(switch, 0, lastgaelam)
        next_q = jnp.where(switch, -boot_value * gamma, next_q)
        last_return = jnp.where(switch, -boot_value, last_return)

        discount = gamma * (1.0 - next_done)
        # UPGO: follow the realized return only while the one-step Q estimate
        # beats the value estimate; otherwise bootstrap from the value.
        last_return = reward + discount * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + discount * next_value
        # Standard truncated GAE recursion.
        delta = next_q - cur_value
        lastgaelam = delta + gae_lambda * discount * lastgaelam
        carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return
        return carry, (lastgaelam, last_return)

    next_done = next_dones[-1]
    lastgaelam = jnp.zeros_like(next_value)
    next_q = last_return = next_value
    carry = next_value, next_done, next_value, lastgaelam, next_q, last_return

    # Scan backward in time over the trajectory.
    _, (advantages, returns) = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards, switch), reverse=True
    )

    if upgo:
        advantages += returns - values

    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
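A minimal smoke test of truncated_gae_2p0s. The time-major (T, B) shapes are an assumption based on jax.lax.scan consuming the leading axis; the switch flag marks steps where control passes to the other player:

import jax.numpy as jnp

T, B = 8, 4  # hypothetical: 8 timesteps, 4 parallel envs
next_value = jnp.zeros((B,))
values = jnp.zeros((T, B))
rewards = jnp.zeros((T, B)).at[-1, :].set(1.0)  # win/loss reward at the end
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1, :].set(True)
switch = jnp.zeros((T, B), dtype=jnp.bool_).at[3, :].set(True)  # player swap at t=3

targets, advantages = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch,
    gamma=0.99, gae_lambda=0.95, upgo=True,
)
print(targets.shape, advantages.shape)  # (8, 4) (8, 4)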
@@ -58,13 +58,3 @@ def masked_normalize(x, valid, eps=1e-8):
 def to_tensor(x, device, dtype=None):
     return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
 
-def load_embeddings(embedding_file, code_list_file):
-    with open(embedding_file, "rb") as f:
-        embeddings = pickle.load(f)
-    with open(code_list_file, "r") as f:
-        code_list = f.readlines()
-        code_list = [int(code.strip()) for code in code_list]
-    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
-    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
-    return embeddings
+import pickle
+import numpy as np
 from pathlib import Path
@@ -43,4 +45,15 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
     elif 'EDOPro' in env_id:
         from ygoenv.edopro import init_module
         init_module(str(db_path), code_list_file, decks)
-    return deck_name
\ No newline at end of file
+    return deck_name
+
+
+def load_embeddings(embedding_file, code_list_file):
+    with open(embedding_file, "rb") as f:
+        embeddings = pickle.load(f)
+    with open(code_list_file, "r") as f:
+        code_list = f.readlines()
+        code_list = [int(code.strip()) for code in code_list]
+    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
+    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
+    return embeddings
\ No newline at end of file
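A short usage sketch of the relocated load_embeddings helper. The file names are placeholders; the pickle is assumed to map each card code to its embedding vector, which is what the embeddings[code] lookup implies:

embeddings = load_embeddings("embeddings_en.pkl", "code_list.txt")  # hypothetical paths
print(embeddings.shape, embeddings.dtype)  # (n_cards, embed_dim), float32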