Commit 330ee6af authored by sbl1996@126.com

Add truncated LSTM

parent cd59a6e9
@@ -3,7 +3,7 @@ import time
 import os
 import random
 from typing import Optional, Literal
-from dataclasses import dataclass
+from dataclasses import dataclass, field, asdict
 import ygoenv
 import numpy as np
@@ -12,6 +12,7 @@ import tyro
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
+from ygoai.rl.jax.agent import RNNAgent, ModelArgs
 @dataclass
@@ -57,14 +58,8 @@ class Args:
     strategy: Literal["random", "greedy"] = "greedy"
     """the strategy to use if agent is not used"""
-    num_layers: int = 2
-    """the number of layers for the agent"""
-    num_channels: int = 128
-    """the number of channels for the agent"""
-    rnn_channels: Optional[int] = 512
-    """the number of rnn channels for the agent"""
-    rnn_type: Optional[str] = "lstm"
-    """the type of RNN to use for agent, None for no RNN"""
+    m: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent1"""
     checkpoint: Optional[str] = None
     """the checkpoint to load, must be a `flax_model` file"""
@@ -78,11 +73,8 @@ class Args:
 def create_agent(args):
     return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        rnn_channels=args.rnn_channels,
+        **asdict(args.m),
         embedding_shape=args.num_embeddings,
-        rnn_type=args.rnn_type,
     )
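For readers unfamiliar with the pattern above: the per-model flags removed from Args are consolidated into a single ModelArgs dataclass field, and asdict() expands that field back into keyword arguments for the agent constructor. A minimal self-contained sketch of the flow, with illustrative field names rather than the repo's actual ModelArgs definition:

from dataclasses import dataclass, field, asdict
from typing import Optional

@dataclass
class ModelArgsSketch:
    # illustrative fields only; the real ModelArgs lives in ygoai.rl.jax.agent
    num_layers: int = 2
    num_channels: int = 128
    rnn_channels: Optional[int] = 512
    rnn_type: Optional[str] = "lstm"

@dataclass
class ArgsSketch:
    # default_factory is required because dataclass defaults must not be mutable instances
    m: ModelArgsSketch = field(default_factory=lambda: ModelArgsSketch())
    num_embeddings: int = 1000

def create_agent_sketch(args):
    # asdict() turns the nested dataclass into {"num_layers": 2, ...},
    # which is splatted into the constructor alongside the other kwargs
    return dict(**asdict(args.m), embedding_shape=args.num_embeddings)

print(create_agent_sketch(ArgsSketch()))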
@@ -97,12 +89,14 @@ if __name__ == "__main__":
     args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
-    deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
+    deck, deck_names = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file, return_deck_names=True)
     args.deck1 = args.deck1 or deck
     args.deck2 = args.deck2 or deck
-    seed = args.seed
+    seed = args.seed + 100000
     random.seed(seed)
+    seed = random.randint(0, 1e8)
+    random.seed(seed)
     np.random.seed(seed)
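A short sketch of the reseeding scheme above, assuming the intent is to decorrelate the evaluation RNG stream from the training seed; note that random.randint expects integer bounds, so 10**8 is used here rather than the float literal 1e8:

import random
import numpy as np

def derive_eval_seed(base_seed: int, offset: int = 100000) -> int:
    # Seed Python's RNG with an offset of the base seed, then draw a fresh
    # seed from that stream: deterministic given base_seed, but distinct
    # from the stream used during training.
    random.seed(base_seed + offset)
    return random.randint(0, 10**8)

seed = derive_eval_seed(1)
random.seed(seed)
np.random.seed(seed)
print(seed)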
@@ -135,22 +129,21 @@ if __name__ == "__main__":
     import jax
     import jax.numpy as jnp
     import flax
-    from ygoai.rl.jax.agent import RNNAgent
     from jax.experimental.compilation_cache import compilation_cache as cc
     cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
     agent = create_agent(args)
-    key = jax.random.PRNGKey(args.seed)
-    key, agent_key = jax.random.split(key, 2)
+    key = jax.random.PRNGKey(seed)
     sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
     rstate = agent.init_rnn_state(1)
-    params = jax.jit(agent.init)(agent_key, sample_obs, rstate)
+    params = jax.jit(agent.init)(key, sample_obs, rstate)
     with open(args.checkpoint, "rb") as f:
         params = flax.serialization.from_bytes(params, f.read())
     params = jax.device_put(params)
     rstate = agent.init_rnn_state(num_envs)
     @jax.jit
     def get_probs_and_value(params, rstate, obs, done):
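The checkpoint-loading lines above follow the usual Flax pattern: build a parameter pytree with a dummy batch (batch size 1 is enough for shape inference), then overwrite it from serialized bytes. A self-contained sketch with a toy flax.linen module standing in for RNNAgent; the module, shapes, and names below are illustrative only:

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

class ToyModel(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)

model = ToyModel()
key = jax.random.PRNGKey(0)
sample_obs = jnp.zeros((1, 4))                # batch of one, just for shape inference
params = jax.jit(model.init)(key, sample_obs)

blob = flax.serialization.to_bytes(params)    # stands in for the contents of a `flax_model` file
params = flax.serialization.from_bytes(params, blob)  # restore into the same pytree structure
params = jax.device_put(params)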
@@ -180,6 +173,10 @@ if __name__ == "__main__":
     start = time.time()
     start_step = step
+    deck_names = sorted(deck_names)
+    deck_times = {name: 0 for name in deck_names}
+    deck_time_count = {name: 0 for name in deck_names}
     model_time = env_time = 0
     while True:
         if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
@@ -211,7 +208,20 @@
         step += 1
         for idx, d in enumerate(dones):
-            if d:
+            if not d:
+                continue
+            for i in range(2):
+                deck_time = infos['step_time'][idx][i]
+                deck_name = deck_names[infos['deck'][idx][i]]
+                time_count = deck_time_count[deck_name]
+                avg_time = deck_times[deck_name]
+                avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
+                deck_times[deck_name] = avg_time
+                deck_time_count[deck_name] += 1
+                if deck_time_count[deck_name] % 100 == 0:
+                    print(f"Deck {deck_name}: {avg_time:.4f}")
             win_reason = infos['win_reason'][idx]
             episode_length = infos['l'][idx]
             episode_reward = infos['r'][idx]
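The avg_time update above is an incremental mean, new_avg = avg * n/(n+1) + x/(n+1), so no per-deck list of step times has to be stored. A small self-contained check of the formula:

samples = [0.12, 0.08, 0.10, 0.11]
avg, n = 0.0, 0
for x in samples:
    # same update as the deck timing loop: fold the new sample into the mean
    avg = avg * (n / (n + 1)) + x / (n + 1)
    n += 1
assert abs(avg - sum(samples) / len(samples)) < 1e-12
print(f"running mean: {avg:.4f}")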