Commit 878f469c authored by sbl1996@126.com's avatar sbl1996@126.com

Update pybind to support numpy 2.0

parent 79854eae
......@@ -127,7 +127,7 @@ python -u battle.py --deck ../assets/deck --xla_device cpu --checkpoint1 checkpo
```
## Training (Deprecated, to be updated)
## Training
Training an agent requires a lot of computational resources, typically 8x4090 GPUs and 128-core CPU for a few days. We don't recommend training the agent on your local machine. Reducing the number of decks for training may reduce the computational resources required.
......@@ -136,9 +136,12 @@ Training an agent requires a lot of computational resources, typically 8x4090 GP
We can train the agent with a single GPU using the following command:
```bash
python -u ppo.py --deck ../assets/deck --seed 1 --embedding_file embed.pkl \
--minibatch-size 128 --learning-rate 1e-4 --update-epochs 2 --save_interval 100 \
--compile reduce-overhead --env_threads 16 --num_envs 64 --eval_episodes 32
cd scripts
python -u cleanba.py --actor-device-ids 0 --learner-device-ids 0 \
--local-num_envs 16 --num-minibatches 8 --learning-rate 1e-4 \
--update-epochs 1 --vloss_clip 1.0 --sep_value --value gae \
--save_interval 100 --seed 0 --m1.film --m1.noam --m1.version 2 \
--local_eval_episodes 32 --eval_interval 50
```
#### Deck
......@@ -151,43 +154,16 @@ To handle the diverse and complex card effects, we have converted the card infor
We provide one in the [releases](https://github.com/sbl1996/ygo-agent/releases/tag/v0.1), which named `embed{n}.pkl` where `n` is the number of cards in `code_list.txt`.
You can choose to not use the embeddings by skip the `--embedding_file` option. If you do it, remember to set `--num_embeddings` to `999` in the `eval.py` script.
#### Compile
We use `torch.compile` to speed up the overall training process. It is very important and can reduce the overall time by 2x or more. If the compilation fails, you may update the PyTorch version to the latest one.
You can choose to not use the embeddings by skip the `--embedding_file` option.
#### Seed
The `seed` option is used to set the random seed for reproducibility. However, many optimizations used in the training are not deterministic, so the results may still vary.
For debugging, you can set `--compile None --torch-deterministic` with the same seed to get a deterministic result.
The `seed` option is used to set the random seed for reproducibility. The training and and evaluation will be exactly the same under the same seed.
#### Hyperparameters
More PPO hyperparameters can be found in the `ppo.py` script. Tuning them may improve the performance but requires more computational resources.
More hyperparameters can be found in the `cleanba.py` script. Tuning them may improve the performance but requires more computational resources.
### Distributed Training
The `ppo.py` script supports single-node and multi-node distributed training with `torchrun`. Start distributed training like this:
```bash
# single node
OMP_NUM_THREADS=4 torchrun --standalone --nnodes=1 --nproc-per-node=8 ppo.py \
# multi node on nodes 0
OMP_NUM_THREADS=4 torchrun --nnodes=2 --nproc-per-node=8 --node-rank=0 \
--rdzv-id=12941 --master-addr=$MASTER_ADDR --master-port=$MASTER_PORT ppo.py \
# multi node on nodes 1
OMP_NUM_THREADS=4 torchrun --nnodes=2 --nproc-per-node=8 --node-rank=1 \
--rdzv-id=12941 --master-addr=$MASTER_ADDR --master-port=$MASTER_PORT ppo.py \
# script options
--deck ../assets/deck --seed 1 --embedding_file embed.pkl \
--minibatch-size 2048 --learning-rate 5e-4 --update-epochs 2 --save_interval 100 \
--compile reduce-overhead --env_threads 128 --num_envs 1024 --eval_episodes 128
```
The script options are mostly the same as the single GPU training. We only scale the batch size and the number of environments to the number of available CPUs and GPUs. The learning rate is then scaled according to the batch size.
TODO
## Plan
......@@ -218,4 +194,4 @@ This work is supported with Cloud TPUs from Google's [TPU Research Cloud (TRC)](
- [ygopro-core](https://github.com/Fluorohydride/ygopro-core)
- [envpool](https://github.com/sail-sg/envpool)
- [yugioh-ai](https://github.com/melvinzhang/yugioh-ai)
- [yugioh-game](https://github.com/tspivey/yugioh-game)
- [yugioh-game](https://github.com/tspivey/yugioh-game)
\ No newline at end of file
......@@ -17,7 +17,7 @@ import jax.numpy as jnp
import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
......@@ -46,6 +46,8 @@ class Args:
"""the number of history actions to use for the environment1"""
n_history_actions2: Optional[int] = None
"""the number of history actions to use for the environment2, defaults to `n_history_actions1`"""
oppo_info: bool = False
"""whether to use opponent information"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
accurate: bool = True
......@@ -160,8 +162,11 @@ if __name__ == "__main__":
n_history_actions=args.n_history_actions1,
deck1=args.deck1,
deck2=args.deck2,
oppo_info=args.oppo_info,
**env_option,
)
envs1 = EnvPreprocess(envs1, skip_mask=not args.oppo_info)
if cross_env:
envs2 = ygoenv.make(
task_id=env_id2,
......
......@@ -63,8 +63,6 @@ class Args:
"""the name of the tensorboard run"""
ckpt_dir: str = "checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket: Optional[str] = None
"""the GCS bucket to save the model checkpoints"""
# Algorithm specific arguments
env_id: str = "YGOPro-v1"
......@@ -218,7 +216,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
timeout=args.timeout,
oppo_info=args.m2.oppo_info if eval else args.m1.oppo_info,
oppo_info=False,
)
envs.num_envs = num_envs
return envs
......@@ -262,13 +260,6 @@ def get_variables(agent_state):
return variables
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def reshape_minibatch(
x, multi_step, num_minibatches, num_steps, segment_length=None, key=None):
# if segment_length is None,
......@@ -357,7 +348,7 @@ def rollout(
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = EnvPreprocess(envs, skip_mask=not args.m1.oppo_info)
envs = EnvPreprocess(envs, skip_mask=True)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
......@@ -378,17 +369,19 @@ def rollout(
avg_win_rates = deque(maxlen=1000)
agent = create_agent(args)
apply_fn = agent.apply
eval_agent = create_agent(args, eval=eval_mode != 'bot')
eval_apply_fn = eval_agent.apply
@jax.jit
def get_action(params, obs, rstate):
rstate, logits = eval_agent.apply(params, obs, rstate)[:2]
rstate, logits = eval_apply_fn(params, obs, rstate)[:2]
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, obs, rstate1, rstate2, main, done):
next_rstate1, logits1 = agent.apply(params1, obs, rstate1)[:2]
next_rstate2, logits2 = eval_agent.apply(params2, obs, rstate2)[:2]
next_rstate1, logits1 = apply_fn(params1, obs, rstate1)[:2]
next_rstate2, logits2 = eval_apply_fn(params2, obs, rstate2)[:2]
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
......@@ -401,7 +394,7 @@ def rollout(
@jax.jit
def sample_action(
params, next_obs, rstate1, rstate2, main, done, key):
(rstate1, rstate2), logits, value = agent.apply(
(rstate1, rstate2), logits, value = apply_fn(
params, next_obs, (rstate1, rstate2), done, main)[:3]
value = jnp.squeeze(value, axis=-1)
action, key = categorical_sample(logits, key)
......@@ -608,6 +601,7 @@ def rollout(
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
max_episode_length = np.max(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
......@@ -625,6 +619,7 @@ def rollout(
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), tb_global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, tb_global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, tb_global_step)
writer.add_scalar("charts/max_episode_length", max_episode_length, tb_global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), tb_global_step)
writer.add_scalar("stats/inference_time", inference_time, tb_global_step)
writer.add_scalar("stats/env_time", env_time, tb_global_step)
......@@ -780,9 +775,7 @@ def main():
tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
if 'batch_stats' not in variables:
# variables = flax.core.unfreeze(variables)
variables['batch_stats'] = {}
# variables = flax.core.freeze(variables)
agent_state = TrainState.create(
apply_fn=None,
params=variables['params'],
......@@ -1046,7 +1039,7 @@ def main():
(loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch_t, (carry, init_rstate), minibatch_t)
return carry, (loss, pg_loss, v_loss, ent_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
......
......@@ -3,7 +3,7 @@ add_rules("mode.debug", "mode.release")
add_repositories("my-repo repo")
add_requires(
"ygopro-core", "edopro-core", "pybind11 2.10.*", "fmt 10.2.*", "glog 0.6.0",
"ygopro-core", "edopro-core", "pybind11 2.13.*", "fmt 10.2.*", "glog 0.6.0",
"sqlite3 3.43.0+200", "concurrentqueue 1.0.4", "unordered_dense 4.4.*",
"sqlitecpp 3.2.1")
......
......@@ -2077,6 +2077,7 @@ public:
if (c[0] == idx) {
c.erase(c.begin());
if (c.empty()) {
// TODO: maybe finish too early
_callback_multi_select_2_finish();
return;
} else {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment