Commit 308011ee authored by sbl1996@126.com's avatar sbl1996@126.com

Add docs for running

parent f89ac8fe
......@@ -8,29 +8,21 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
## ygoai
`ygoai` is a set of AI agents for playing Yu-Gi-Oh! It aims to achieve superhuman performance like AlphaGo and AlphaZero, with or without human knowledge. Currently, we focus on using reinforcement learning to train the agents.
## TODO
### Documentation
- Add documentation for building and running
### Training
- Eval with old models during training
- MCTS-based training
## Usage
### Inference
- MCTS-based planning
- Support for playing in YGOPro
### Obtain a trained agent
We provide some trained agents in the [releases](https://github.com/sbl1996/ygo-agent/releases/tag/v0.1). Check these `ptj` TorchScript files and download them to your local machine. The usage instructions below assume you have downloaded one of them.
## Usage
Note that the provided `ptj` files can only run on GPU, not on CPU. The agent itself can run in real time on CPU; we will provide a CPU version in the future.
### Serialize agent
### Play against the agent
After training, we can serialize the trained agent model to a file for later use without keeping the source code of the model. The serialized model file will end with the `.ptj` (PyTorch JIT) extension.
We can use `eval.py` to play against the trained agent with a MUD-like interface in the terminal.
```bash
python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embeddings 999 --convert --optimize
python -u eval.py --agent --deck ../assets/deck --lang chinese --checkpoint checkpoints/1234_1000M.ptj --play
```
### Battle between two agents
......@@ -41,8 +33,31 @@ We can use `battle.py` to let two agents play against each other and find out wh
python -u battle.py --deck ../assets/deck --checkpoint1 checkpoints/1234_1000M.ptj --checkpoint2 checkpoints/9876_100M.ptj --num-episodes=256 --num_envs=32 --seed 0
```
### Running
TODO
You can set `--num_envs=1 --verbose --record` to generate `.yrp` replay files.
### Serialize agent
After training, we can serialize the trained agent model to a file for later use without keeping the source code of the model. The serialized model file will end with the `.ptj` (PyTorch JIT) extension.
```bash
python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embeddings 999 --convert --optimize
```
## TODO
### Documentation
- Add documentation for building and running
### Training
- Evaluation with old models during training
- LSTM for memory
- League training following AlphaStar and ROA-Star
### Inference
- MCTS-based planning
- Support for playing in YGOPro
## Related Projects
......
......@@ -142,50 +142,53 @@ if __name__ == "__main__":
envs = RecordEpisodeStatistics(envs)
if args.agent:
# count lines of code_list
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
# agent = agent.eval()
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
if args.convert:
# Don't support dynamic shapes and very slow inference
raise NotImplementedError
# obs = create_obs(envs.observation_space, (num_envs,), device=device)
# dynamic_shapes = {"x": {}}
# # batch_dim = torch.export.Dim("batch", min=1, max=64)
# batch_dim = None
# for k, v in obs.items():
# dynamic_shapes["x"][k] = {0: batch_dim}
# program = torch.export.export(
# agent, (obs,),
# dynamic_shapes=dynamic_shapes,
# )
# torch.export.save(program, args.checkpoint + "2")
# exit(0)
agent = torch.compile(agent, mode='reduce-overhead')
elif args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
# Trace `agent` with the example observation `obs` (read from the enclosing
# scope, not passed as an argument) and return the traced module after
# passing it through torch.jit.optimize_for_inference.
def optimize_for_inference(agent):
# Gradient tracking is disabled while the model is traced.
with torch.no_grad():
# check_trace=False turns off the post-trace verification run.
# NOTE(review): check_tolerance is documented as a float tolerance; False is
# passed here and is presumably unused when check_trace=False -- confirm.
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
return torch.jit.optimize_for_inference(traced_model)
agent = optimize_for_inference(agent)
if args.convert:
torch.jit.save(agent, args.checkpoint + "j")
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
if args.checkpoint.endswith(".ptj"):
agent = torch.jit.load(args.checkpoint)
else:
# count lines of code_list
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
# agent = agent.eval()
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
if args.convert:
# Don't support dynamic shapes and very slow inference
raise NotImplementedError
# obs = create_obs(envs.observation_space, (num_envs,), device=device)
# dynamic_shapes = {"x": {}}
# # batch_dim = torch.export.Dim("batch", min=1, max=64)
# batch_dim = None
# for k, v in obs.items():
# dynamic_shapes["x"][k] = {0: batch_dim}
# program = torch.export.export(
# agent, (obs,),
# dynamic_shapes=dynamic_shapes,
# )
# torch.export.save(program, args.checkpoint + "2")
# exit(0)
agent = torch.compile(agent, mode='reduce-overhead')
elif args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
# Trace `agent` with the example observation `obs` (read from the enclosing
# scope, not passed as an argument) and return the traced module after
# passing it through torch.jit.optimize_for_inference.
def optimize_for_inference(agent):
# Gradient tracking is disabled while the model is traced.
with torch.no_grad():
# check_trace=False turns off the post-trace verification run.
# NOTE(review): check_tolerance is documented as a float tolerance; False is
# passed here and is presumably unused when check_trace=False -- confirm.
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
return torch.jit.optimize_for_inference(traced_model)
agent = optimize_for_inference(agent)
if args.convert:
torch.jit.save(agent, args.checkpoint + "j")
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
obs, infos = envs.reset()
next_to_play = infos['to_play']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment