Commit 662b300f authored by sbl1996@126.com

Update doc and defaults for release

parent 03416f14
# YGO Agent
YGO Agent is a project aimed at mastering the popular trading card game Yu-Gi-Oh! through deep learning. Based on a high-performance game environment (ygoenv), this project leverages reinforcement learning and large language models to develop advanced AI agents (ygoai) that aim to match or surpass human expert play. YGO Agent provides researchers and players with a platform for exploring AI in complex, strategic game environments.
[Discord](https://discord.gg/EqWYj4G4Ys)
## News🔥
- 2024.7.2 - We have a Discord channel for discussion now! We are also working with [neos-ts](https://github.com/DarkNeos/neos-ts) to implement human-AI battle.
- 2024.4.18 - LSTM has been implemented and well tested.
- 2024.4.7 - We have switched to JAX for training and evaluation due to its better performance and flexibility.
## Table of Contents
- [Subprojects](#subprojects)
  - [ygoenv](#ygoenv)
  - [ygoai](#ygoai)
- [Installation](#installation)
  - [Building from source](#building-from-source)
  - [Troubleshooting](#troubleshooting)
- [Evaluation](#evaluation)
  - [Obtain a trained agent](#obtain-a-trained-agent)
  - [Play against the agent](#play-against-the-agent)
  - [Battle between two agents](#battle-between-two-agents)
- [Training](#training)
  - [Single GPU Training](#single-gpu-training)
  - [Distributed Training](#distributed-training)
- [Roadmap](#roadmap)
  - [Environment](#environment)
  - [Training](#training-1)
  - [Inference](#inference)
......@@ -40,93 +37,102 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
## Subprojects
### ygoenv
`ygoenv` is a high-performance game environment for Yu-Gi-Oh!, implemented on top of [envpool](https://github.com/sail-sg/envpool) and [ygopro-core](https://github.com/Fluorohydride/ygopro-core). It provides a standard Gym interface for reinforcement learning.
### ygoai
`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo and AlphaZero, with or without human knowledge. Currently, we focus on using reinforcement learning to train the agents.
## Installation
Pre-built binaries are available for Ubuntu 22.04 or newer. If you are using them, follow the installation instructions below. Otherwise, build from source as described in [Building from source](#building-from-source).
1. Install JAX and other dependencies:
```bash
# Install JAX (CPU version)
pip install -U "jax<=0.4.28"
# Or with CUDA support
pip install -U "jax[cuda12]<=0.4.28"
# Install other dependencies
pip install flax distrax chex
```
2. Clone the repository and install the pre-built binary (Ubuntu 22.04 or newer):
```bash
git clone https://github.com/sbl1996/ygo-agent.git
cd ygo-agent
# Choose the appropriate version for your Python (cp310, cp311, or cp312)
wget -nv https://github.com/sbl1996/ygo-agent/releases/download/v0.1/ygopro_ygoenv_cp310.so
mv ygopro_ygoenv_cp310.so ygoenv/ygoenv/ygopro/ygopro_ygoenv.so
make
```
3. Verify the installation:
```bash
cd scripts
python -u eval.py --env-id "YGOPro-v1" --deck ../assets/deck/ --num_episodes 32 --strategy random --lang chinese --num_envs 16
```
If you see episode logs and the output contains a line like the one below, the environment is working correctly. For more usage examples, see the [Evaluation](#evaluation) section.
```
len=76.5758, reward=-0.1751, win_rate=0.3939, win_reason=0.9697
```
### Building from source
If you can't use the pre-built binary or prefer to build from source, follow these instructions. Note: These instructions are tested on Ubuntu 22.04 and may not work on other platforms.
#### Additional Prerequisites
- gcc 10+ or clang 11+
- CMake 3.12+
- [xmake](https://xmake.io/#/getting_started)
- jax 0.4.25+, flax 0.8.2+, distrax 0.1.5+ (CUDA is optional)
#### Build Instructions
```bash
git clone https://github.com/sbl1996/ygo-agent.git
cd ygo-agent
git checkout stable # switch to the stable branch
xmake f -y
make dev
```
After building, verify the installation as described in step 3 of [Installation](#installation).
### Troubleshooting
#### Package version not found by xmake
Delete the `repositories`, `cache`, and `packages` directories under `~/.xmake` and run `xmake f -y -c` again.
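For example, the cleanup can be done as follows (a sketch assuming the default `~/.xmake` location):
```bash
# Remove xmake's cached repositories, build cache, and installed packages,
# then reconfigure from scratch.
rm -rf ~/.xmake/repositories ~/.xmake/cache ~/.xmake/packages
xmake f -y -c
```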
#### Install packages failed with xmake
If xmake fails to install required libraries automatically (e.g., `glog` and `gflags`), install them manually (e.g., with `apt install`) and add them to the search path (`$LD_LIBRARY_PATH` or others); xmake will then find them.
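For example, on Ubuntu the two libraries mentioned above can be installed system-wide as follows (package names are Ubuntu's and may differ on other distributions; exporting `LD_LIBRARY_PATH` is usually only needed for non-standard install locations):
```bash
# Install glog and gflags from the system package manager, then make sure the
# library directory is on the search path.
sudo apt install libgoogle-glog-dev libgflags-dev
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```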
#### GLIBC and GLIBCXX version conflict
This usually happens because the `libstdc++` shipped in `$CONDA_PREFIX` is older than the system one: xmake compiles the libraries against the system `libstdc++`, while your programs run against the one in `$CONDA_PREFIX`. In that case, back up the old `libstdc++` in `$CONDA_PREFIX` and replace it with a soft link to the system one.
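A minimal sketch of that workaround, assuming a conda environment and the usual multiarch library path on Ubuntu (the exact paths and version suffix may differ on your system):
```bash
# Back up the conda-provided libstdc++ and replace it with a symlink to the
# system copy, so xmake-built libraries and your programs use the same one.
cd "$CONDA_PREFIX/lib"
mv libstdc++.so.6 libstdc++.so.6.bak
ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6
```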
#### Other issues
Open a new terminal and try again. If issues persist, join our [Discord channel](https://discord.gg/EqWYj4G4Ys) for help.
## Evaluation
### Obtain a trained agent
We provide trained agents in the [releases](https://github.com/sbl1996/ygo-agent/releases/tag/v0.1). Download one of the Flax checkpoint files named `{exp_id}_{step}.flax_model` (preferably the latest) to your local machine. The following usage assumes you have it.
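For reference, a checkpoint can be fetched directly from the release page. This is only a sketch: it assumes the `0546_16500M.flax_model` asset used in the examples below is attached to the v0.1 release, and that commands are run from the repository root.
```bash
# Download a checkpoint into scripts/checkpoints (a directory name used by the
# example commands below, not required by the scripts themselves).
mkdir -p scripts/checkpoints
wget -nv -P scripts/checkpoints \
  https://github.com/sbl1996/ygo-agent/releases/download/v0.1/0546_16500M.flax_model
```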
### Play against the agent
We can use `eval.py` to play against the trained agent with a MUD-like interface in the terminal. Adding `--xla_device cpu` runs the agent on the CPU.
```bash
python -u eval.py --deck ../assets/deck --lang chinese --xla_device cpu --checkpoint checkpoints/350c29a_7565_6700M.flax_model --play
```
Enter `quit` to exit the game. Run `python eval.py --help` for more options; for example, `--player 0` makes the agent play as the first player, and `--deck1 TenyiSword` forces the first player to use the TenyiSword deck.
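For instance, the flags above can be combined so that the agent takes the first turn with a fixed deck. This is just a sketch of combining the documented options; the checkpoint path is the same placeholder as in the command above.
```bash
# The agent plays as the first player (player 0), and the first player is
# forced to use the TenyiSword deck.
python -u eval.py --deck ../assets/deck --lang chinese --xla_device cpu \
  --checkpoint checkpoints/350c29a_7565_6700M.flax_model --play \
  --player 0 --deck1 TenyiSword
```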
We can play against the agent with any YGOPro client now (documentation TODO).
### Battle between two agents
We can use `battle.py` to let two agents play against each other and find out which one is better. Adding `--xla_device cpu` forces JAX to run on CPU.
```bash
python -u battle.py --xla_device cpu --checkpoint1 checkpoints/0546_16500M.flax_model --checkpoint2 checkpoints/0546_11300M.flax_model --num-episodes 32 --seed 0
```
Set `--record` to write `.yrp` replay files to the `replay` directory. These files can be replayed in YGOPro-compatible clients (YGOPro, YGOPro2, KoishiPro, MDPro). Change `--seed` to generate different games.
```bash
python -u battle.py --xla_device cpu --checkpoint1 checkpoints/0546_16500M.flax_model --checkpoint2 checkpoints/0546_11300M.flax_model --num-episodes 16 --seed 1 --record
```
## Training
Training an agent requires a lot of computational resources, typically 8x RTX 4090 GPUs and a 128-core CPU for a few days. We don't recommend training the agent on your local machine. Reducing the number of decks used for training may lower the required resources.
......@@ -138,10 +144,8 @@ We can train the agent with a single GPU using the following command:
```bash
cd scripts
python -u cleanba.py --actor-device-ids 0 --learner-device-ids 0 \
--local-num_envs 16 --num-minibatches 8 --learning-rate 1e-4 --vloss_clip 1.0 \
--save_interval 100 --local_eval_episodes 32 --eval_interval 50 --seed 0
```
#### Deck
......@@ -165,7 +169,7 @@ More hyperparameters can be found in the `cleanba.py` script. Tuning them may im
### Distributed Training
TODO
## Roadmap
### Environment
- Generation of yrpX replay files
......@@ -178,8 +182,8 @@ TODO
- Centralized critic with full observation
### Inference
- Export as SavedModel
- MCTS-based planning
- Support of play in YGOPro
### Documentation
- JAX training
......@@ -193,5 +197,6 @@ This work is supported with Cloud TPUs from Google's [TPU Research Cloud (TRC)](
## Related Projects
- [ygopro-core](https://github.com/Fluorohydride/ygopro-core)
- [envpool](https://github.com/sail-sg/envpool)
- [neos-ts](https://github.com/DarkNeos/neos-ts)
- [yugioh-ai](https://github.com/melvinzhang/yugioh-ai)
- [yugioh-game](https://github.com/tspivey/yugioh-game)
\ No newline at end of file
......@@ -131,7 +131,7 @@ if __name__ == "__main__":
seed = args.seed + 100000
random.seed(seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
random.seed(seed)
np.random.seed(seed)
......@@ -165,6 +165,7 @@ if __name__ == "__main__":
oppo_info=args.oppo_info,
**env_option,
)
envs1.num_envs = num_envs
envs1 = EnvPreprocess(envs1, skip_mask=not args.oppo_info)
if cross_env:
......@@ -175,11 +176,11 @@ if __name__ == "__main__":
deck2=deck2,
**env_option,
)
envs2.num_envs = num_envs
key = jax.random.PRNGKey(seed)
obs_space1 = envs1.observation_space
envs1.num_envs = num_envs
envs1 = RecordEpisodeStatistics(envs1)
sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample())
agent1 = create_agent1(args)
......@@ -190,7 +191,6 @@ if __name__ == "__main__":
if cross_env:
obs_space2 = envs2.observation_space
envs2.num_envs = num_envs
envs2 = RecordEpisodeStatistics(envs2)
sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample())
else:
......
......@@ -106,7 +106,7 @@ class Args:
"""the discount factor gamma"""
num_minibatches: int = 64
"""the number of mini-batches"""
update_epochs: int = 2
update_epochs: int = 1
"""the K epochs to update the policy"""
switch: bool = False
"""Toggle the use of switch mechanism"""
......@@ -119,7 +119,7 @@ class Args:
"""Toggle the use of UPGO for advantages"""
sep_value: bool = True
"""Whether separate value function computation for each player"""
value: Literal["vtrace", "gae"] = "vtrace"
value: Literal["vtrace", "gae"] = "gae"
"""the method to learn the value function"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
......@@ -715,14 +715,14 @@ def main():
# seeding
random.seed(args.seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, 1e8)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
......@@ -716,14 +716,14 @@ def main():
# seeding
random.seed(args.seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, 1e8)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
......@@ -743,14 +743,14 @@ def main():
# seeding
random.seed(args.seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, 1e8)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
......@@ -96,7 +96,7 @@ if __name__ == "__main__":
seed = args.seed + 100000
random.seed(seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
random.seed(seed)
np.random.seed(seed)
......
......@@ -83,11 +83,10 @@ class CardEncoder(nn.Module):
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
oppo_info: bool = False
version: int = 0
version: int = 2
@nn.compact
def __call__(self, x_id, x, mask):
assert self.version > 0
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
......@@ -105,13 +104,6 @@ class CardEncoder(nn.Module):
x_loc = x1[:, :, 0]
x_seq = x1[:, :, 1]
if self.version == 0:
x_id = mlp(
(c, c // 4), kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
f_loc = layer_norm()(embed(9, c)(x_loc))
f_seq = layer_norm()(embed(76, c)(x_seq))
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
......@@ -130,44 +122,34 @@ class CardEncoder(nn.Module):
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
if self.version == 0:
x_f = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
f_cards_g = None
x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
x_id = jax.nn.swish(x_id)
f_loc = embed(9, c // 16 * 2)(x_loc)
f_seq = embed(76, c // 16 * 2)(x_seq)
feats_g = [
x_id, f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
if mask is not None:
assert len(feats_g) == mask.shape[-1]
feats = [
jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
for i, f in enumerate(feats_g)
]
else:
x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
x_id = jax.nn.swish(x_id)
f_loc = embed(9, c // 16 * 2)(x_loc)
f_seq = embed(76, c // 16 * 2)(x_seq)
feats_g = [
x_id, f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
if mask is not None:
assert len(feats_g) == mask.shape[-1]
feats = [
jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
for i, f in enumerate(feats_g)
]
else:
feats = feats_g
x_cards = jnp.concatenate(feats[1:], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * feats[0]
f_cards = layer_norm()(x_cards)
# f_cards = f_cards.astype(self.dtype)
if self.oppo_info:
x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
x_cards_g = x_cards_g * feats_g[0]
f_cards_g = layer_norm()(x_cards_g)
# f_cards_g = f_cards_g.astype(self.dtype)
else:
f_cards_g = None
feats = feats_g
x_cards = jnp.concatenate(feats[1:], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * feats[0]
f_cards = layer_norm()(x_cards)
# f_cards = f_cards.astype(self.dtype)
if self.oppo_info:
x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
x_cards_g = x_cards_g * feats_g[0]
f_cards_g = layer_norm()(x_cards_g)
# f_cards_g = f_cards_g.astype(self.dtype)
else:
f_cards_g = None
return f_cards_g, f_cards, c_mask
......@@ -175,7 +157,7 @@ class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
version: int = 0
version: int = 2
@nn.compact
def __call__(self, x):
......@@ -230,7 +212,7 @@ class Encoder(nn.Module):
noam: bool = False
action_feats: bool = True
oppo_info: bool = False
version: int = 0
version: int = 2
@nn.compact
def __call__(self, x):
......@@ -252,7 +234,7 @@ class Encoder(nn.Module):
card_encoder = CardEncoder(
channels=c, dtype=self.dtype, param_dtype=self.param_dtype,
version=self.version, oppo_info=self.oppo_info)
ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
ActionEncoderCls = ActionEncoderV1
action_encoder = ActionEncoderCls(
channels=c, dtype=self.dtype, param_dtype=self.param_dtype)
......@@ -313,60 +295,33 @@ class Encoder(nn.Module):
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
if self.version == 0:
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = MLP(
(c, c), dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_h_id)
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm()(f_h_actions[:, 0])
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., 1:3])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = fc_layer(c)(x_h_id)
x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
x_h_a_phase = embed(12, c // 2)(x_h_actions[:, :, 13])
x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
x_h_a_feats = layer_norm()(x_h_a_feats)
x_h_a_feats = fc_layer(c)(x_h_a_feats)
if self.noam:
f_h_actions = LlamaEncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype,
rope=True, n_positions=64)(x_h_a_feats, src_key_padding_mask=h_mask)
else:
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., 1:3])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = fc_layer(c)(x_h_id)
x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
x_h_a_phase = embed(12, c // 2)(x_h_actions[:, :, 13])
x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
x_h_a_feats = layer_norm()(x_h_a_feats)
x_h_a_feats = fc_layer(c)(x_h_a_feats)
if self.noam:
f_h_actions = LlamaEncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype,
rope=True, n_positions=64)(x_h_a_feats, src_key_padding_mask=h_mask)
else:
x_h_a_feats = PositionalEncoding()(x_h_a_feats)
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
x_h_a_feats, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm()(f_h_actions[:, 0])
x_h_a_feats = PositionalEncoding()(x_h_a_feats)
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
x_h_a_feats, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm()(f_h_actions[:, 0])
# Actions
......@@ -379,63 +334,42 @@ class Encoder(nn.Module):
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
if self.version == 0:
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c)(f_a_cards)
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
x_a_feats = fc_layer(c)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
spec_index = x_actions[..., 0]
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
x_a_id = decode_id(x_actions[..., 1:3])
x_a_id = id_embed(x_a_id)
if self.freeze_id:
x_a_id = jax.lax.stop_gradient(x_a_id)
x_a_id = fc_layer(c)(x_a_id)
x_a_feats = action_encoder(x_actions[..., 3:])
x_a_feats.append(x_a_id)
x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
x_a_feats = layer_norm()(x_a_feats)
x_a_feats = fc_layer(c)(x_a_feats)
f_a_cards = fc_layer(c)(f_a_cards)
f_actions = jax.nn.silu(f_a_cards) * x_a_feats
f_actions = fc_layer(c)(f_actions)
f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False)
g_feats = [f_g_card, f_global]
if self.use_history:
g_feats.append(f_g_h_actions)
if self.action_feats:
f_actions_g = fc_layer(c)(f_actions)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
if not self.use_history:
f_g_h_actions = jnp.zeros_like(f_g_h_actions)
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
spec_index = x_actions[..., 0]
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
x_a_id = decode_id(x_actions[..., 1:3])
x_a_id = id_embed(x_a_id)
if self.freeze_id:
x_a_id = jax.lax.stop_gradient(x_a_id)
x_a_id = fc_layer(c)(x_a_id)
x_a_feats = action_encoder(x_actions[..., 3:])
x_a_feats.append(x_a_id)
x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
x_a_feats = layer_norm()(x_a_feats)
x_a_feats = fc_layer(c)(x_a_feats)
f_a_cards = fc_layer(c)(f_a_cards)
f_actions = jax.nn.silu(f_a_cards) * x_a_feats
f_actions = fc_layer(c)(f_actions)
f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False)
g_feats = [f_g_card, f_global]
if self.use_history:
g_feats.append(f_g_h_actions)
if self.action_feats:
f_actions_g = fc_layer(c)(f_actions)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
g_feats.append(f_g_actions)
f_state = jnp.concatenate(g_feats, axis=-1)
g_feats.append(f_g_actions)
f_state = jnp.concatenate(g_feats, axis=-1)
oc = self.out_channels or c
if self.version == 2:
f_state = GLUMlp(
......@@ -573,7 +507,7 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False)
return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True, return_state=False):
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=False, return_state=False):
if switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
......@@ -601,11 +535,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
noam: bool = False
noam: bool = True
"""whether to use Noam architecture for the transformer layer"""
action_feats: bool = True
"""whether to use action features for the global state"""
version: int = 0
version: int = 2
"""the version of the environment and the agent"""
......@@ -615,7 +549,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
film: bool = True
"""whether to use FiLM for the actor"""
oppo_info: bool = False
"""whether to use opponent's information"""
......@@ -638,8 +572,8 @@ class RNNAgent(nn.Module):
use_history: bool = True
card_mask: bool = False
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
film: bool = True
noam: bool = True
rwkv_head_size: int = 32
action_feats: bool = True
oppo_info: bool = False
......@@ -647,10 +581,10 @@ class RNNAgent(nn.Module):
batch_norm: bool = False
critic_width: int = 128
critic_depth: int = 3
version: int = 0
version: int = 2
q_head: bool = False
switch: bool = True
switch: bool = False
freeze_id: bool = False
int_head: bool = False
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
......
......@@ -646,11 +646,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
noam: bool = False
noam: bool = True
"""whether to use Noam architecture for the transformer layer"""
action_feats: bool = True
"""whether to use action features for the global state"""
version: int = 0
version: int = 2
"""the version of the environment and the agent"""
......@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
film: bool = True
"""whether to use FiLM for the actor"""
rnn_shortcut: bool = False
"""whether to use shortcut for the RNN"""
......@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
use_history: bool = True,
card_mask: bool = False,
rnn_type: str = 'lstm',
film: bool = False,
noam: bool = False,
film: bool = True,
noam: bool = True,
rwkv_head_size: int = 32,
action_feats: bool = True,
rnn_shortcut: bool = False,
batch_norm: bool = False,
critic_width: int = 128,
critic_depth: int = 3,
version: int = 0,
version: int = 2,
q_head: bool = False,
switch: bool = True,
freeze_id: bool = False,
......