Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
e7a19464
Commit
e7a19464
authored
Feb 21, 2024
by
biluo.shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Improve for compile
parent
eba0e134
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
61 additions
and
21 deletions
+61
-21
scripts/eval.py
scripts/eval.py
+8
-3
scripts/ppo.py
scripts/ppo.py
+38
-13
ygoai/rl/agent.py
ygoai/rl/agent.py
+14
-5
ygoai/rl/dist.py
ygoai/rl/dist.py
+1
-0
No files found.
scripts/eval.py
View file @
e7a19464
...
@@ -14,7 +14,7 @@ import tyro
...
@@ -14,7 +14,7 @@ import tyro
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.agent
import
Agent
from
ygoai.rl.agent
import
PPOAgent
as
Agent
from
ygoai.rl.buffer
import
create_obs
from
ygoai.rl.buffer
import
create_obs
...
@@ -171,8 +171,13 @@ if __name__ == "__main__":
...
@@ -171,8 +171,13 @@ if __name__ == "__main__":
_start
=
time
.
time
()
_start
=
time
.
time
()
obs
=
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
),
obs
)
obs
=
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
),
obs
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
values
=
agent
(
obs
)[
0
]
logits
,
values
=
agent
(
obs
)
actions
=
torch
.
argmax
(
values
,
dim
=
1
)
.
cpu
()
.
numpy
()
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
probs
=
probs
.
cpu
()
.
numpy
()
if
args
.
play
:
print
(
probs
[
probs
!=
0
]
.
tolist
())
print
(
values
)
actions
=
probs
.
argmax
(
axis
=
1
)
model_time
+=
time
.
time
()
-
_start
model_time
+=
time
.
time
()
-
_start
else
:
else
:
if
args
.
strategy
==
"random"
:
if
args
.
strategy
==
"random"
:
...
...
scripts/ppo.py
View file @
e7a19464
...
@@ -13,7 +13,9 @@ import tyro
...
@@ -13,7 +13,9 @@ import tyro
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
from
torch.distributions
import
Categorical
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.cuda.amp
import
GradScaler
,
autocast
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.utils
import
RecordEpisodeStatistics
...
@@ -44,7 +46,7 @@ class Args:
...
@@ -44,7 +46,7 @@ class Args:
"""the deck file for the second player"""
"""the deck file for the second player"""
code_list_file
:
str
=
"code_list.txt"
code_list_file
:
str
=
"code_list.txt"
"""the code list file for card embeddings"""
"""the code list file for card embeddings"""
embedding_file
:
str
=
"embeddings_en.npy"
embedding_file
:
Optional
[
str
]
=
"embeddings_en.npy"
"""the embedding file for card embeddings"""
"""the embedding file for card embeddings"""
max_options
:
int
=
24
max_options
:
int
=
24
"""the maximum number of options"""
"""the maximum number of options"""
...
@@ -101,6 +103,10 @@ class Args:
...
@@ -101,6 +103,10 @@ class Args:
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads
:
Optional
[
int
]
=
None
env_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for envpool, defaults to `num_envs`"""
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train
:
bool
=
False
"""if toggled, training will be done in fp16 precision"""
fp16_eval
:
bool
=
False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir
:
str
=
"./runs"
tb_dir
:
str
=
"./runs"
"""tensorboard log directory"""
"""tensorboard log directory"""
...
@@ -199,20 +205,28 @@ def run(local_rank, world_size):
...
@@ -199,20 +205,28 @@ def run(local_rank, world_size):
envs
=
RecordEpisodeStatistics
(
envs
)
envs
=
RecordEpisodeStatistics
(
envs
)
if
args
.
embedding_file
:
embeddings
=
np
.
load
(
args
.
embedding_file
)
embeddings
=
np
.
load
(
args
.
embedding_file
)
embedding_shape
=
embeddings
.
shape
else
:
embedding_shape
=
None
L
=
args
.
num_layers
L
=
args
.
num_layers
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
1
,
embeddings
.
shape
)
.
to
(
device
)
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
1
,
embedding_shape
)
.
to
(
device
)
if
args
.
embedding_file
:
agent
.
load_embeddings
(
embeddings
)
agent
.
load_embeddings
(
embeddings
)
if
args
.
compile
:
#
if args.compile:
agent
.
get_action_and_value
=
torch
.
compile
(
agent
.
get_action_and_value
,
mode
=
args
.
compile_mode
)
#
agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
optimizer
=
optim
.
Adam
(
agent
.
parameters
(),
lr
=
args
.
learning_rate
,
eps
=
1e-5
)
optimizer
=
optim
.
Adam
(
agent
.
parameters
(),
lr
=
args
.
learning_rate
,
eps
=
1e-5
)
scaler
=
GradScaler
(
enabled
=
args
.
fp16_train
)
def
masked_mean
(
x
,
valid
):
def
masked_mean
(
x
,
valid
):
x
=
x
.
masked_fill
(
~
valid
,
0
)
x
=
x
.
masked_fill
(
~
valid
,
0
)
return
x
.
sum
()
/
valid
.
float
()
.
sum
()
return
x
.
sum
()
/
valid
.
float
()
.
sum
()
def
train_step
(
agent
,
mb_obs
,
mb_actions
,
mb_logprobs
,
mb_advantages
,
mb_returns
,
mb_values
):
def
train_step
(
agent
,
scaler
,
mb_obs
,
mb_actions
,
mb_logprobs
,
mb_advantages
,
mb_returns
,
mb_values
):
with
autocast
(
enabled
=
args
.
fp16_train
):
_
,
newlogprob
,
entropy
,
newvalue
,
valid
=
agent
.
get_action_and_value
(
mb_obs
,
mb_actions
.
long
())
_
,
newlogprob
,
entropy
,
newvalue
,
valid
=
agent
.
get_action_and_value
(
mb_obs
,
mb_actions
.
long
())
logratio
=
newlogprob
-
mb_logprobs
logratio
=
newlogprob
-
mb_logprobs
ratio
=
logratio
.
exp
()
ratio
=
logratio
.
exp
()
...
@@ -251,12 +265,20 @@ def run(local_rank, world_size):
...
@@ -251,12 +265,20 @@ def run(local_rank, world_size):
entropy_loss
=
masked_mean
(
entropy
,
valid
)
entropy_loss
=
masked_mean
(
entropy
,
valid
)
loss
=
pg_loss
-
args
.
ent_coef
*
entropy_loss
+
v_loss
*
args
.
vf_coef
loss
=
pg_loss
-
args
.
ent_coef
*
entropy_loss
+
v_loss
*
args
.
vf_coef
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
loss
.
backward
()
scaler
.
scale
(
loss
)
.
backward
()
scaler
.
unscale_
(
optimizer
)
reduce_gradidents
(
agent
,
args
.
world_size
)
reduce_gradidents
(
agent
,
args
.
world_size
)
return
old_approx_kl
,
approx_kl
,
clipfrac
,
pg_loss
,
v_loss
,
entropy_loss
return
old_approx_kl
,
approx_kl
,
clipfrac
,
pg_loss
,
v_loss
,
entropy_loss
def
predict_step
(
agent
,
next_obs
):
with
torch
.
no_grad
():
with
autocast
(
enabled
=
args
.
fp16_eval
):
logits
,
values
=
agent
(
next_obs
)
return
logits
,
values
if
args
.
compile
:
if
args
.
compile
:
train_step
=
torch
.
compile
(
train_step
,
mode
=
args
.
compile_mode
)
train_step
=
torch
.
compile
(
train_step
,
mode
=
args
.
compile_mode
)
predict_step
=
torch
.
compile
(
predict_step
,
mode
=
args
.
compile_mode
)
def
to_tensor
(
x
,
dtype
=
torch
.
float32
):
def
to_tensor
(
x
,
dtype
=
torch
.
float32
):
return
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
True
),
x
)
return
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
True
),
x
)
...
@@ -296,8 +318,10 @@ def run(local_rank, world_size):
...
@@ -296,8 +318,10 @@ def run(local_rank, world_size):
dones
[
step
]
=
next_done
dones
[
step
]
=
next_done
_start
=
time
.
time
()
_start
=
time
.
time
()
with
torch
.
no_grad
():
logits
,
value
=
predict_step
(
agent
,
next_obs
)
action
,
logprob
,
_
,
value
,
valid
=
agent
.
get_action_and_value
(
next_obs
)
probs
=
Categorical
(
logits
=
logits
)
action
=
probs
.
sample
()
logprob
=
probs
.
log_prob
(
action
)
values
[
step
]
=
value
.
flatten
()
values
[
step
]
=
value
.
flatten
()
actions
[
step
]
=
action
actions
[
step
]
=
action
...
@@ -374,10 +398,11 @@ def run(local_rank, world_size):
...
@@ -374,10 +398,11 @@ def run(local_rank, world_size):
k
:
v
[
mb_inds
]
for
k
,
v
in
b_obs
.
items
()
k
:
v
[
mb_inds
]
for
k
,
v
in
b_obs
.
items
()
}
}
old_approx_kl
,
approx_kl
,
clipfrac
,
pg_loss
,
v_loss
,
entropy_loss
=
\
old_approx_kl
,
approx_kl
,
clipfrac
,
pg_loss
,
v_loss
,
entropy_loss
=
\
train_step
(
agent
,
mb_obs
,
b_actions
[
mb_inds
],
b_logprobs
[
mb_inds
],
b_advantages
[
mb_inds
],
train_step
(
agent
,
scaler
,
mb_obs
,
b_actions
[
mb_inds
],
b_logprobs
[
mb_inds
],
b_advantages
[
mb_inds
],
b_returns
[
mb_inds
],
b_values
[
mb_inds
])
b_returns
[
mb_inds
],
b_values
[
mb_inds
])
nn
.
utils
.
clip_grad_norm_
(
agent
.
parameters
(),
args
.
max_grad_norm
)
nn
.
utils
.
clip_grad_norm_
(
agent
.
parameters
(),
args
.
max_grad_norm
)
optimizer
.
step
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
clipfracs
.
append
(
clipfrac
.
item
())
clipfracs
.
append
(
clipfrac
.
item
())
if
args
.
target_kl
is
not
None
and
approx_kl
>
args
.
target_kl
:
if
args
.
target_kl
is
not
None
and
approx_kl
>
args
.
target_kl
:
...
...
ygoai/rl/agent.py
View file @
e7a19464
...
@@ -45,7 +45,7 @@ class Encoder(nn.Module):
...
@@ -45,7 +45,7 @@ class Encoder(nn.Module):
self
.
bin_intervals
=
nn
.
Parameter
(
bin_intervals
,
requires_grad
=
False
)
self
.
bin_intervals
=
nn
.
Parameter
(
bin_intervals
,
requires_grad
=
False
)
if
embedding_shape
is
None
:
if
embedding_shape
is
None
:
n_embed
,
embed_dim
=
1
5
0
,
1024
n_embed
,
embed_dim
=
1
00
0
,
1024
else
:
else
:
n_embed
,
embed_dim
=
embedding_shape
n_embed
,
embed_dim
=
embedding_shape
n_embed
=
1
+
n_embed
# 1 (index 0) for unknown
n_embed
=
1
+
n_embed
# 1 (index 0) for unknown
...
@@ -339,19 +339,28 @@ class PPOAgent(nn.Module):
...
@@ -339,19 +339,28 @@ class PPOAgent(nn.Module):
f
=
(
f_actions
*
c_mask
)
.
sum
(
dim
=
1
)
/
c_mask
.
sum
(
dim
=
1
)
f
=
(
f_actions
*
c_mask
)
.
sum
(
dim
=
1
)
/
c_mask
.
sum
(
dim
=
1
)
return
self
.
critic
(
f
)
return
self
.
critic
(
f
)
def
get_action_and_value
(
self
,
x
,
action
=
None
):
def
get_action_and_value
(
self
,
x
,
action
):
f_actions
,
mask
,
valid
=
self
.
encoder
(
x
)
f_actions
,
mask
,
valid
=
self
.
encoder
(
x
)
c_mask
=
1
-
mask
.
unsqueeze
(
-
1
)
.
float
()
c_mask
=
1
-
mask
.
unsqueeze
(
-
1
)
.
float
()
f
=
(
f_actions
*
c_mask
)
.
sum
(
dim
=
1
)
/
c_mask
.
sum
(
dim
=
1
)
f
=
(
f_actions
*
c_mask
)
.
sum
(
dim
=
1
)
/
c_mask
.
sum
(
dim
=
1
)
logits
=
self
.
actor
(
f_actions
)[
...
,
0
]
logits
=
self
.
actor
(
f_actions
)[
...
,
0
]
logits
=
logits
.
float
()
logits
=
logits
.
masked_fill
(
mask
,
float
(
"-inf"
))
logits
=
logits
.
masked_fill
(
mask
,
float
(
"-inf"
))
probs
=
Categorical
(
logits
=
logits
)
probs
=
Categorical
(
logits
=
logits
)
if
action
is
None
:
action
=
probs
.
sample
()
return
action
,
probs
.
log_prob
(
action
),
probs
.
entropy
(),
self
.
critic
(
f
),
valid
return
action
,
probs
.
log_prob
(
action
),
probs
.
entropy
(),
self
.
critic
(
f
),
valid
def
forward
(
self
,
x
):
f_actions
,
mask
,
valid
=
self
.
encoder
(
x
)
c_mask
=
1
-
mask
.
unsqueeze
(
-
1
)
.
float
()
f
=
(
f_actions
*
c_mask
)
.
sum
(
dim
=
1
)
/
c_mask
.
sum
(
dim
=
1
)
logits
=
self
.
actor
(
f_actions
)[
...
,
0
]
logits
=
logits
.
float
()
logits
=
logits
.
masked_fill
(
mask
,
float
(
"-inf"
))
return
logits
,
self
.
critic
(
f
)
class
DMCAgent
(
nn
.
Module
):
class
DMCAgent
(
nn
.
Module
):
...
...
ygoai/rl/dist.py
View file @
e7a19464
...
@@ -39,6 +39,7 @@ def mp_start(run):
...
@@ -39,6 +39,7 @@ def mp_start(run):
if
world_size
==
1
:
if
world_size
==
1
:
run
(
local_rank
=
0
,
world_size
=
world_size
)
run
(
local_rank
=
0
,
world_size
=
world_size
)
else
:
else
:
mp
.
set_start_method
(
'spawn'
)
children
=
[]
children
=
[]
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
subproc
=
mp
.
Process
(
target
=
run
,
args
=
(
i
,
world_size
))
subproc
=
mp
.
Process
(
target
=
run
,
args
=
(
i
,
world_size
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment