ygo-agent, commit 80707a8c
Authored Mar 16, 2024 by Biluo Shen
add OSFP

Parent: 11261948
Showing 3 changed files, with 168 additions and 154 deletions (+168, -154).
.gitignore           +1   -0
scripts/ppo.py       +40  -56
scripts/ppo_osfp.py  +127 -98
.gitignore

 *.pt
 *.ptj
+*.pkl
 # Xmake cache
 ...
scripts/ppo.py

@@ -3,7 +3,7 @@ import random
 import time
 from collections import deque
 from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional

 import ygoenv
@@ -52,10 +52,8 @@ class Args:
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 16
+    n_history_actions: int = 32
     """the number of history actions to use"""
-    play_mode: str = "bot"
-    """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
     num_layers: int = 2
     """the number of layers for the agent"""
@@ -74,9 +72,9 @@ class Args:
     """the number of steps to run in each environment per policy rollout"""
     anneal_lr: bool = True
     """Toggle learning rate annealing for policy and value networks"""
-    gamma: float = 0.997
+    gamma: float = 1.0
     """the discount factor gamma"""
-    gae_lambda: float = 0.95
+    gae_lambda: float = 0.98
     """the lambda for the general advantage estimation"""
     minibatch_size: int = 256
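For reference on the two values changed here: gamma moves from 0.997 to 1.0 (undiscounted bootstrapping) and gae_lambda from 0.95 to 0.98. A minimal sketch of how these two scalars enter generalized advantage estimation; the tensor names and shapes are illustrative, not this script's exact buffer layout:

import torch

def compute_gae(rewards, values, next_value, dones, gamma=1.0, gae_lambda=0.98):
    # rewards, values, dones: float tensors of shape [num_steps, num_envs]
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        nextvalues = next_value if t == num_steps - 1 else values[t + 1]
        nextnonterminal = 1.0 - dones[t]
        # TD error; with gamma = 1.0 the bootstrap term is undiscounted
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        # exponentially weighted sum of TD errors, controlled by gae_lambda
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    returns = advantages + values
    return advantages, returns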
@@ -85,7 +83,7 @@ class Args:
     """the K epochs to update the policy"""
     norm_adv: bool = True
     """Toggles advantages normalization"""
-    clip_coef: float = 0.1
+    clip_coef: float = 0.2
     """the surrogate clipping coefficient"""
     clip_vloss: bool = True
     """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
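clip_coef is the epsilon of PPO's clipped surrogate objective, here raised from 0.1 to 0.2. A minimal sketch of where it enters the policy loss; variable names are illustrative:

import torch

def ppo_policy_loss(new_logprob, old_logprob, advantages, clip_coef=0.2):
    # probability ratio between the current and the rollout policy
    ratio = (new_logprob - old_logprob).exp()
    loss_unclipped = -advantages * ratio
    loss_clipped = -advantages * torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
    # pessimistic (elementwise max) combination, averaged over the batch
    return torch.max(loss_unclipped, loss_clipped).mean()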
@@ -93,17 +91,13 @@ class Args:
     """coefficient of the entropy"""
     vf_coef: float = 0.5
     """coefficient of the value function"""
-    max_grad_norm: float = 0.5
+    max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""
-    target_kl: Optional[float] = None
-    """the target KL divergence threshold"""
-    learn_opponent: bool = True
+    learn_opponent: bool = False
     """if toggled, the samples from the opponent will be used to train the agent"""
     collect_length: Optional[int] = None
     """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-    backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-    """the backend for distributed training"""
     compile: Optional[str] = None
     """Compile mode of torch.compile, None for no compilation"""
     torch_threads: Optional[int] = None
@@ -125,7 +119,7 @@ class Args:
     """the probability of logging"""
     eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
-    eval_interval: int = 10
+    eval_interval: int = 50
     """the number of iterations to evaluate the model"""
     # to be filled in runtime
@@ -143,6 +137,23 @@ class Args:
     """the number of processes (computed in runtime)"""

+def make_env(args, num_envs, num_threads, mode='self'):
+    envs = ygoenv.make(
+        task_id=args.env_id,
+        env_type="gymnasium",
+        num_envs=num_envs,
+        num_threads=num_threads,
+        seed=args.seed,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        max_options=args.max_options,
+        n_history_actions=args.n_history_actions,
+        play_mode='self',
+    )
+    envs.num_envs = num_envs
+    envs = RecordEpisodeStatistics(envs)
+    return envs
+
 def main():
     rank = int(os.environ.get("RANK", 0))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
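For reference, the env-setup hunk further below calls this factory roughly as follows (a usage sketch, not new code from the commit). Note that make_env accepts a mode argument but, as written in this hunk, its body still passes play_mode='self' to ygoenv.make:

# Training envs (self-play)
envs = make_env(args, args.local_num_envs, local_env_threads)

# Evaluation envs: the caller passes mode='bot', but this version of
# make_env does not forward it to ygoenv.make
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')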
@@ -169,7 +180,7 @@ def main():
torch
.
set_float32_matmul_precision
(
'high'
)
if
args
.
world_size
>
1
:
torchrun_setup
(
args
.
backend
,
local_rank
)
torchrun_setup
(
'nccl'
,
local_rank
)
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
...
...
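With the backend field removed from Args, distributed initialization is now pinned to NCCL. torchrun_setup is this repository's own helper; the sketch below is only an assumption about what such a helper typically does under torchrun, not the project's actual implementation:

import torch
import torch.distributed as dist

def torchrun_setup(backend, local_rank):
    # Assumed behavior: torchrun sets RANK/WORLD_SIZE/MASTER_ADDR in the
    # environment, so env:// initialization picks them up, and each process
    # pins its own GPU.
    dist.init_process_group(backend=backend, init_method="env://")
    torch.cuda.set_device(local_rank)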
@@ -204,43 +215,17 @@ def main():
args
.
deck2
=
args
.
deck2
or
deck
# env setup
envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
num_envs
=
args
.
local_num_envs
,
num_threads
=
local_env_threads
,
seed
=
args
.
seed
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
play_mode
=
'self'
,
)
envs
.
num_envs
=
args
.
local_num_envs
obs_space
=
envs
.
observation_space
action_shape
=
envs
.
action_space
.
shape
envs
=
make_env
(
args
,
args
.
local_num_envs
,
local_env_threads
)
obs_space
=
envs
.
env
.
observation_space
action_shape
=
envs
.
env
.
action_space
.
shape
if
local_rank
==
0
:
fprint
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
envs_per_thread
=
args
.
local_num_envs
//
local_env_threads
local_eval_episodes
=
args
.
eval_episodes
//
args
.
world_size
local_eval_num_envs
=
local_eval_episodes
eval_envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
num_envs
=
local_eval_num_envs
,
num_threads
=
max
(
1
,
local_eval_num_envs
//
envs_per_thread
),
seed
=
args
.
seed
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
play_mode
=
args
.
play_mode
,
)
eval_envs
.
num_envs
=
local_eval_num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
local_eval_num_threads
=
max
(
1
,
local_eval_num_envs
//
envs_per_thread
)
eval_envs
=
make_env
(
args
,
local_eval_num_envs
,
local_eval_num_threads
,
mode
=
'bot'
)
if
args
.
embedding_file
:
embeddings
=
load_embeddings
(
args
.
embedding_file
,
args
.
code_list_file
)
...
...
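A worked example of the eval-env bookkeeping above (the env and thread counts are illustrative; only eval_episodes = 128 matches the Args default shown earlier):

# Illustrative numbers, not the repository's actual launch configuration.
local_num_envs = 64
local_env_threads = 16
envs_per_thread = local_num_envs // local_env_threads                    # 4
eval_episodes, world_size = 128, 1
local_eval_episodes = eval_episodes // world_size                        # 128
local_eval_num_envs = local_eval_episodes                                # 128
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)  # 32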
@@ -312,10 +297,10 @@ def main():
         next_value1 = next_value2 = 0
     step = 0
-    for iteration in range(1, args.num_iterations + 1):
+    for iteration in range(args.num_iterations):
         # Annealing the rate if instructed to do so.
         if args.anneal_lr:
-            frac = 1.0 - (iteration - 1.0) / args.num_iterations
+            frac = 1.0 - iteration / args.num_iterations
             lrnow = frac * args.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow
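The iteration loop is now zero-based, and the annealing expression is adjusted so the learning-rate schedule itself is unchanged; a quick equivalence check:

num_iterations = 1000
for old_it, new_it in [(1, 0), (500, 499), (1000, 999)]:
    old_frac = 1.0 - (old_it - 1.0) / num_iterations   # old 1-based loop
    new_frac = 1.0 - new_it / num_iterations           # new 0-based loop
    assert old_frac == new_frac  # identical learning-rate fractions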
@@ -372,7 +357,7 @@ def main():
if
random
.
random
()
<
args
.
log_p
:
n
=
100
if
random
.
random
()
<
10
/
n
or
iteration
<=
2
:
if
random
.
random
()
<
10
/
n
or
iteration
<=
1
:
writer
.
add_scalar
(
"charts/episodic_return"
,
info
[
"r"
][
idx
],
global_step
)
writer
.
add_scalar
(
"charts/episodic_length"
,
info
[
"l"
][
idx
],
global_step
)
fprint
(
f
"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}"
)
...
...
@@ -394,7 +379,7 @@ def main():
nextvalues1
=
torch
.
where
(
next_to_play
==
ai_player1
,
value
,
next_value1
)
nextvalues2
=
torch
.
where
(
next_to_play
!=
ai_player1
,
value
,
next_value2
)
if
step
>
0
and
iteration
!=
1
:
if
step
>
0
and
iteration
!=
0
:
# recalculate the values for the first few steps
v_steps
=
args
.
local_minibatch_size
*
4
//
args
.
local_num_envs
for
v_start
in
range
(
0
,
step
,
v_steps
):
...
...
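The condition change simply tracks the zero-based loop: the chunked recomputation is still skipped on the very first iteration. As a hedged sketch of what the loop above presumably does (the get_value method and the dict-of-tensors buffer layout are assumptions, not this repository's exact API):

import torch

def refresh_values(agent, obs, values, step, chunk):
    # Hypothetical helper mirroring the loop above: re-run the value head on
    # the first `step` buffer entries in chunks, so that partial-GAE
    # bootstrapping uses values produced by the current network instead of
    # stale ones from the previous iteration.
    with torch.no_grad():
        for v_start in range(0, step, chunk):
            v_end = min(v_start + chunk, step)
            for t in range(v_start, v_end):
                obs_t = {k: v[t] for k, v in obs.items()}
                values[t] = agent.get_value(obs_t).flatten()
    return values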
@@ -421,8 +406,11 @@ def main():
b_logprobs
=
logprobs
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_advantages
=
advantages
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_values
=
values
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_learns
=
torch
.
ones_like
(
b_values
,
dtype
=
torch
.
bool
)
if
args
.
learn_opponent
else
learns
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_returns
=
b_advantages
+
b_values
if
args
.
learn_opponent
:
b_learns
=
torch
.
ones_like
(
b_values
,
dtype
=
torch
.
bool
)
else
:
b_learns
=
learns
[:
args
.
num_steps
]
.
reshape
(
-
1
)
# Optimizing the policy and value network
b_inds
=
np
.
arange
(
args
.
local_batch_size
)
...
...
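With learn_opponent now defaulting to False, b_learns marks which flattened transitions belong to the learning agent. A hedged sketch of how such a mask is typically applied so that opponent transitions neither contribute gradient nor dilute the mean (illustrative, not this script's exact objective):

import torch

def masked_loss(per_sample_loss: torch.Tensor, learns: torch.Tensor) -> torch.Tensor:
    # learns: boolean mask, True where a transition should be trained on
    # (everywhere when args.learn_opponent is True).
    weights = learns.float()
    # Average only over the learnable samples; clamp guards against an
    # all-False mask producing a division by zero.
    return (per_sample_loss * weights).sum() / weights.sum().clamp_min(1.0)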
@@ -444,9 +432,6 @@ def main():
scaler
.
update
()
clipfracs
.
append
(
clipfrac
.
item
())
if
args
.
target_kl
is
not
None
and
approx_kl
>
args
.
target_kl
:
break
if
step
>
0
:
# TODO: use cyclic buffer to avoid copying
for
v
in
obs
.
values
():
...
...
@@ -463,7 +448,6 @@ def main():
var_y
=
np
.
var
(
y_true
)
explained_var
=
np
.
nan
if
var_y
==
0
else
1
-
np
.
var
(
y_true
-
y_pred
)
/
var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if
rank
==
0
:
if
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent.pt"
))
...
...
@@ -490,7 +474,7 @@ def main():
if
rank
==
0
:
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
if
iteration
%
args
.
eval_interval
==
0
:
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
:
# Eval with rule-based policy
_start
=
time
.
time
()
eval_return
=
evaluate
(
...
...
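The added args.eval_interval check lets eval_interval = 0 disable evaluation cleanly; without it, iteration % 0 would raise ZeroDivisionError. A minimal illustration:

eval_interval = 0  # illustrative: evaluation disabled
iteration = 10

# Old form: raises ZeroDivisionError when eval_interval is 0
# if iteration % eval_interval == 0: ...

# New form: `and` short-circuits, so the modulo is never evaluated
if eval_interval and iteration % eval_interval == 0:
    print("run evaluation")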
scripts/ppo_osfp.py

(This diff is collapsed in this view; +127, -98.)