Commit 308011ee authored by sbl1996@126.com's avatar sbl1996@126.com

Add docs for running

parent f89ac8fe
...@@ -8,29 +8,21 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL). ...@@ -8,29 +8,21 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
## ygoai ## ygoai
`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo and AlphaZero, with or without human knowledge. Currently, we focus on using reinforcement learning to train the agents. `ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo and AlphaZero, with or without human knowledge. Currently, we focus on using reinforcement learning to train the agents.
## TODO
### Documentation
- Add documentations of building and running
### Training ## Usage
- Eval with old models during training
- MTCS-based training
### Inference ### Obtain a trained agent
- MCTS-based planning
- Support of play in YGOPro
We provide some trained agents in the [releases](https://github.com/sbl1996/ygo-agent/releases/tag/v0.1). Check these `ptj` TorchScript files and download them to your local machine. The following usage assumes you have it.
## Usage Notice that the provided `ptj` can only run on GPU, but not CPU. Actually, the agent can run in real-time on CPU, we will provide a CPU version in the future.
### Serialize agent ### Play against the agent
After training, we can serialize the trained agent model to a file for later use without keeping source code of the model. The serialized model file will end with `.ptj` (PyTorch JIT) extension. We can use `eval.py` to play against the trained agent with a MUD-like interface in the terminal.
```bash ```bash
python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embeddings 999 --convert --optimize python -u eval.py --agent --deck ../assets/deck --lang chinese --checkpoint checkpoints/1234_1000M.ptj --play
``` ```
### Battle between two agents ### Battle between two agents
...@@ -41,8 +33,31 @@ We can use `battle.py` to let two agents play against each other and find out wh ...@@ -41,8 +33,31 @@ We can use `battle.py` to let two agents play against each other and find out wh
python -u battle.py --deck ../assets/deck --checkpoint1 checkpoints/1234_1000M.ptj --checkpoint2 checkpoints/9876_100M.ptj --num-episodes=256 --num_envs=32 --seed 0 python -u battle.py --deck ../assets/deck --checkpoint1 checkpoints/1234_1000M.ptj --checkpoint2 checkpoints/9876_100M.ptj --num-episodes=256 --num_envs=32 --seed 0
``` ```
### Running You can set `--num_envs=1 --verbose --record` to generate `.yrp` replay files.
TODO
### Serialize agent
After training, we can serialize the trained agent model to a file for later use without keeping source code of the model. The serialized model file will end with `.ptj` (PyTorch JIT) extension.
```bash
python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embeddings 999 --convert --optimize
```
## TODO
### Documentation
- Add documentations of building and running
### Training
- Evaluation with old models during training
- LSTM for memory
- League training following AlphaStar and ROA-Star
### Inference
- MCTS-based planning
- Support of play in YGOPro
## Related Projects ## Related Projects
......
...@@ -142,50 +142,53 @@ if __name__ == "__main__": ...@@ -142,50 +142,53 @@ if __name__ == "__main__":
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
if args.agent: if args.agent:
# count lines of code_list if args.checkpoint.endswith(".ptj"):
embedding_shape = args.num_embeddings agent = torch.jit.load(args.checkpoint)
if embedding_shape is None: else:
with open(args.code_list_file, "r") as f: # count lines of code_list
code_list = f.readlines() embedding_shape = args.num_embeddings
embedding_shape = len(code_list) if embedding_shape is None:
L = args.num_layers with open(args.code_list_file, "r") as f:
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) code_list = f.readlines()
# agent = agent.eval() embedding_shape = len(code_list)
if args.checkpoint: L = args.num_layers
state_dict = torch.load(args.checkpoint, map_location=device) agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if not args.compile: # agent = agent.eval()
prefix = "_orig_mod." if args.checkpoint:
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()} state_dict = torch.load(args.checkpoint, map_location=device)
print(agent.load_state_dict(state_dict)) if not args.compile:
prefix = "_orig_mod."
if args.compile: state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
if args.convert: print(agent.load_state_dict(state_dict))
# Don't support dynamic shapes and very slow inference
raise NotImplementedError if args.compile:
# obs = create_obs(envs.observation_space, (num_envs,), device=device) if args.convert:
# dynamic_shapes = {"x": {}} # Don't support dynamic shapes and very slow inference
# # batch_dim = torch.export.Dim("batch", min=1, max=64) raise NotImplementedError
# batch_dim = None # obs = create_obs(envs.observation_space, (num_envs,), device=device)
# for k, v in obs.items(): # dynamic_shapes = {"x": {}}
# dynamic_shapes["x"][k] = {0: batch_dim} # # batch_dim = torch.export.Dim("batch", min=1, max=64)
# program = torch.export.export( # batch_dim = None
# agent, (obs,), # for k, v in obs.items():
# dynamic_shapes=dynamic_shapes, # dynamic_shapes["x"][k] = {0: batch_dim}
# ) # program = torch.export.export(
# torch.export.save(program, args.checkpoint + "2") # agent, (obs,),
# exit(0) # dynamic_shapes=dynamic_shapes,
agent = torch.compile(agent, mode='reduce-overhead') # )
elif args.optimize: # torch.export.save(program, args.checkpoint + "2")
obs = create_obs(envs.observation_space, (num_envs,), device=device) # exit(0)
def optimize_for_inference(agent): agent = torch.compile(agent, mode='reduce-overhead')
with torch.no_grad(): elif args.optimize:
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False) obs = create_obs(envs.observation_space, (num_envs,), device=device)
return torch.jit.optimize_for_inference(traced_model) def optimize_for_inference(agent):
agent = optimize_for_inference(agent) with torch.no_grad():
if args.convert: traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
torch.jit.save(agent, args.checkpoint + "j") return torch.jit.optimize_for_inference(traced_model)
print(f"Optimized model saved to {args.checkpoint}j") agent = optimize_for_inference(agent)
exit(0) if args.convert:
torch.jit.save(agent, args.checkpoint + "j")
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
obs, infos = envs.reset() obs, infos = envs.reset()
next_to_play = infos['to_play'] next_to_play = infos['to_play']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment