Commit 330ee6af authored by sbl1996@126.com

Add truncated LSTM

parent cd59a6e9
This diff is collapsed.
This diff is collapsed.
...@@ -3,7 +3,7 @@ import time ...@@ -3,7 +3,7 @@ import time
import os import os
import random import random
from typing import Optional, Literal from typing import Optional, Literal
from dataclasses import dataclass from dataclasses import dataclass, field, asdict
import ygoenv import ygoenv
import numpy as np import numpy as np
...@@ -12,6 +12,7 @@ import tyro ...@@ -12,6 +12,7 @@ import tyro
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
@dataclass @dataclass
...@@ -57,14 +58,8 @@ class Args: ...@@ -57,14 +58,8 @@ class Args:
strategy: Literal["random", "greedy"] = "greedy" strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used""" """the strategy to use if agent is not used"""
num_layers: int = 2 m: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the number of layers for the agent""" """the model arguments for the agent1"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for agent, None for no RNN"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the checkpoint to load, must be a `flax_model` file""" """the checkpoint to load, must be a `flax_model` file"""
...@@ -78,11 +73,8 @@ class Args: ...@@ -78,11 +73,8 @@ class Args:
def create_agent(args): def create_agent(args):
return RNNAgent( return RNNAgent(
channels=args.num_channels, **asdict(args.m),
num_layers=args.num_layers,
rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
rnn_type=args.rnn_type,
) )
...@@ -97,12 +89,14 @@ if __name__ == "__main__": ...@@ -97,12 +89,14 @@ if __name__ == "__main__":
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs) args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file) deck, deck_names = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file, return_deck_names=True)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck args.deck2 = args.deck2 or deck
seed = args.seed seed = args.seed + 100000
random.seed(seed)
# Draw the working seed from the generator seeded on the line above.
# The upper bound is written as an int: random.randint no longer accepts
# float arguments (deprecated in Python 3.10, TypeError from 3.12).
seed = random.randint(0, 10**8)
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
...@@ -135,22 +129,21 @@ if __name__ == "__main__": ...@@ -135,22 +129,21 @@ if __name__ == "__main__":
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax import flax
from ygoai.rl.jax.agent import RNNAgent
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax")) cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
agent = create_agent(args) agent = create_agent(args)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = agent.init_rnn_state(1) rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, sample_obs, rstate) params = jax.jit(agent.init)(key, sample_obs, rstate)
with open(args.checkpoint, "rb") as f: with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read()) params = flax.serialization.from_bytes(params, f.read())
params = jax.device_put(params) params = jax.device_put(params)
rstate = agent.init_rnn_state(num_envs)
@jax.jit @jax.jit
def get_probs_and_value(params, rstate, obs, done): def get_probs_and_value(params, rstate, obs, done):
...@@ -180,6 +173,10 @@ if __name__ == "__main__": ...@@ -180,6 +173,10 @@ if __name__ == "__main__":
start = time.time() start = time.time()
start_step = step start_step = step
deck_names = sorted(deck_names)
deck_times = {name: 0 for name in deck_names}
deck_time_count = {name: 0 for name in deck_names}
model_time = env_time = 0 model_time = env_time = 0
while True: while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1): if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
...@@ -211,7 +208,20 @@ if __name__ == "__main__": ...@@ -211,7 +208,20 @@ if __name__ == "__main__":
step += 1 step += 1
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if d: if not d:
continue
for i in range(2):
deck_time = infos['step_time'][idx][i]
deck_name = deck_names[infos['deck'][idx][i]]
time_count = deck_time_count[deck_name]
avg_time = deck_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
deck_times[deck_name] = avg_time
deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % 100 == 0:
print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx] win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] episode_reward = infos['r'][idx]
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment