Commit 43ca871e authored by sbl1996@126.com

Refactor switch

parent 93bc3723
@@ -4,6 +4,7 @@ import os
 import random
 from typing import Optional, Literal
 from dataclasses import dataclass
+from tqdm import tqdm

 import ygoenv
 import numpy as np
@@ -220,6 +221,9 @@ if __name__ == "__main__":
     ])
     rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)

+    if not args.verbose:
+        pbar = tqdm(total=args.num_episodes)
+
     model_time = env_time = 0
     while True:
         if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
@@ -255,7 +259,11 @@ if __name__ == "__main__":
                 episode_rewards.append(episode_reward)
                 win_rates.append(win)
                 win_reasons.append(1 if win_reason == 1 else 0)
-                sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+                if args.verbose:
+                    print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+                else:
+                    pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
+                    pbar.update(1)

                 # Only when num_envs=1, we switch the player here
                 if args.verbose:
@@ -264,6 +272,8 @@ if __name__ == "__main__":
         if len(episode_lengths) >= args.num_episodes:
             break

+    if not args.verbose:
+        pbar.close()
     print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")

     total_time = time.time() - start
...
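Taken together, these hunks replace per-episode stderr logging with a tqdm progress bar whenever --verbose is off. A condensed, self-contained sketch of that pattern (the loop structure, function, and variable names below are simplified stand-ins, not the actual eval script):

from tqdm import tqdm
import numpy as np

def evaluate(num_episodes, verbose, play_episode):
    # play_episode is a placeholder for one env/model rollout returning
    # (episode_length, episode_reward, win).
    lengths, rewards, wins = [], [], []
    pbar = None if verbose else tqdm(total=num_episodes)
    while len(lengths) < num_episodes:
        length, reward, win = play_episode()
        lengths.append(length)
        rewards.append(reward)
        wins.append(win)
        if verbose:
            print(f"Episode {len(lengths)}: length={length}, reward={reward}, win={win}")
        else:
            pbar.set_postfix(len=np.mean(lengths), reward=np.mean(rewards), win_rate=np.mean(wins))
            pbar.update(1)
    if pbar is not None:
        pbar.close()
    print(f"len={np.mean(lengths)}, reward={np.mean(rewards)}, win_rate={np.mean(wins)}")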
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
@@ -150,6 +150,7 @@ class Encoder(nn.Module):
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
+    freeze_id: bool = False

     @nn.compact
     def __call__(self, x):
@@ -168,6 +169,8 @@ class Encoder(nn.Module):
         fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

         id_embed = embed(n_embed, embed_dim)
+        if self.freeze_id:
+            id_embed = partial(lambda m, x: jax.lax.stop_gradient(m(x)), id_embed)

         action_encoder = ActionEncoder(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
@@ -337,6 +340,7 @@ class PPOLSTMAgent(nn.Module):
     param_dtype: jnp.dtype = jnp.float32
     multi_step: bool = False
     switch: bool = True
+    freeze_id: bool = False

     @nn.compact
     def __call__(self, inputs):
@@ -355,6 +359,7 @@ class PPOLSTMAgent(nn.Module):
             embedding_shape=self.embedding_shape,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
+            freeze_id=self.freeze_id,
         )
         f_actions, f_state, mask, valid = encoder(x)
...
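The freeze_id flag threads down from the agent to the Encoder so that pretrained card-id embeddings can be kept fixed while the rest of the network trains. A minimal standalone sketch of the idea, assuming a plain Flax nn.Embed table (the module and field names here are illustrative, not the repository's):

import jax
import jax.numpy as jnp
import flax.linen as nn

class IdEncoder(nn.Module):
    n_embed: int = 1000      # illustrative vocabulary size
    embed_dim: int = 64      # illustrative embedding width
    freeze_id: bool = False  # if True, gradients do not reach the id table

    @nn.compact
    def __call__(self, x_id):
        x = nn.Embed(self.n_embed, self.embed_dim)(x_id)
        if self.freeze_id:
            # The lookup result is detached, so the embedding parameters
            # receive zero gradient while downstream layers still train.
            x = jax.lax.stop_gradient(x)
        return x

# Hypothetical usage:
# model = IdEncoder(freeze_id=True)
# params = model.init(jax.random.PRNGKey(0), jnp.zeros((4,), jnp.int32))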
import jax
import jax.numpy as jnp


def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo
):
    def body_fn(carry, inp):
        boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry
        next_done, cur_value, reward, switch = inp

        # At a player switch, restart the recursion from the bootstrap value,
        # negated because the two players' returns are zero-sum.
        next_done = jnp.where(switch, boot_done, next_done)
        next_value = jnp.where(switch, -boot_value, next_value)
        lastgaelam = jnp.where(switch, 0, lastgaelam)
        next_q = jnp.where(switch, -boot_value * gamma, next_q)
        last_return = jnp.where(switch, -boot_value, last_return)

        discount = gamma * (1.0 - next_done)
        # UPGO-style target: keep following the realized return while the
        # one-step Q estimate is at least the value estimate, otherwise
        # fall back to bootstrapping from the value.
        last_return = reward + discount * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + discount * next_value
        # Standard GAE recursion on the TD error.
        delta = next_q - cur_value
        lastgaelam = delta + gae_lambda * discount * lastgaelam

        carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return
        return carry, (lastgaelam, last_return)

    next_done = next_dones[-1]
    lastgaelam = jnp.zeros_like(next_value)
    next_q = last_return = next_value
    carry = next_value, next_done, next_value, lastgaelam, next_q, last_return

    # Scan backwards over the rollout (time-major arrays).
    _, (advantages, returns) = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards, switch), reverse=True
    )

    if upgo:
        advantages += returns - values

    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
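A minimal usage sketch for the function above, with time-major [T, B] arrays; the shapes, constants, and dummy values are illustrative only:

import jax.numpy as jnp

T, B = 4, 2  # illustrative rollout length and batch size
values = jnp.zeros((T, B))                                        # V(s_t) from the rollout
rewards = jnp.zeros((T, B)).at[-1].set(1.0)                       # terminal reward only
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1].set(True)  # episode ends at the last step
switch = jnp.zeros((T, B), dtype=jnp.bool_)                       # True where the acting player changes
next_value = jnp.zeros((B,))                                      # bootstrap value after the last step

targets, advantages = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch,
    gamma=0.99, gae_lambda=0.95, upgo=True,
)
print(targets.shape, advantages.shape)  # (4, 2) (4, 2)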
@@ -58,13 +58,3 @@ def masked_normalize(x, valid, eps=1e-8):

 def to_tensor(x, device, dtype=None):
     return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
-
-def load_embeddings(embedding_file, code_list_file):
-    with open(embedding_file, "rb") as f:
-        embeddings = pickle.load(f)
-    with open(code_list_file, "r") as f:
-        code_list = f.readlines()
-    code_list = [int(code.strip()) for code in code_list]
-    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
-    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
-    return embeddings
+import pickle
+import numpy as np
 from pathlib import Path
@@ -43,4 +45,15 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
     elif 'EDOPro' in env_id:
         from ygoenv.edopro import init_module
         init_module(str(db_path), code_list_file, decks)
-    return deck_name
\ No newline at end of file
+    return deck_name
+
+
+def load_embeddings(embedding_file, code_list_file):
+    with open(embedding_file, "rb") as f:
+        embeddings = pickle.load(f)
+    with open(code_list_file, "r") as f:
+        code_list = f.readlines()
+    code_list = [int(code.strip()) for code in code_list]
+    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
+    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
+    return embeddings
\ No newline at end of file
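load_embeddings expects a pickled mapping from card code to vector plus a newline-separated code list, and returns a float32 array whose row order follows the code list. A hypothetical call (file names are placeholders):

embeddings = load_embeddings("embeddings.pkl", "code_list.txt")  # hypothetical paths
print(embeddings.shape, embeddings.dtype)  # (len(code_list), embed_dim), float32
# The array can then be copied into the model's id-embedding table, which the
# freeze_id flag above can keep fixed during training.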