Biluo Shen / ygo-agent / Commits

Commit 1487b081
(WIP) add cleanba_ppo
Authored Mar 22, 2024 by Biluo Shen
Parent 80707a8c

Showing 6 changed files with 607 additions and 11 deletions (+607, -11)
docs/feature_engineering.md    +4    -1
scripts/battle.py              +1    -1
scripts/ppo_c.py               +573  -0
scripts/ppo_osfp.py            +13   -8
ygoai/rl/agent.py              +2    -1
ygoai/rl/ppo.py                +14   -0
docs/feature_engineering.md
@@ -88,4 +88,7 @@
 ## History Actions
 - 0,1: card id, uint16 -> 2 uint8
-- others same as legal actions
+- 2-12 same as legal actions
+- 13: player, discrete, 0: me, 1: oppo
+- 14: turn, discrete, trunc to 3
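
For illustration, a minimal sketch of how one history-action vector could be unpacked under this layout. The index layout comes from the doc above; the helper itself, the assumed high-byte-first split of the card id, and all variable names are illustrative, not repository code:

import numpy as np

def decode_history_action(feat):
    # feat: one uint8 feature vector laid out as documented above.
    card_id = (int(feat[0]) << 8) | int(feat[1])  # 0,1: card id (assumes high byte first)
    action_fields = feat[2:13]                    # 2-12: same layout as legal actions
    player = int(feat[13])                        # 13: 0 = me, 1 = opponent
    turn = int(feat[14])                          # 14: turn number, truncated to 3
    return card_id, action_fields, player, turn

feat = np.zeros(15, dtype=np.uint8)
feat[0], feat[1] = 0x12, 0x34                     # card id 0x1234
feat[13], feat[14] = 1, 2                         # opponent's action on turn 2
print(decode_history_action(feat))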
scripts/battle.py
@@ -41,7 +41,7 @@ class Args:
     """the language to use"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 16
+    n_history_actions: int = 32
     """the number of history actions to use"""
     num_embeddings: Optional[int] = None
     """the number of embeddings of the agent"""
scripts/ppo_c.py (new file, mode 100644)
+573 -0 (diff collapsed)
scripts/ppo_osfp.py
@@ -69,7 +69,7 @@ class Args:
     """the number of parallel game environments"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
-    anneal_lr: bool = True
+    anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
     """the discount factor gamma"""
@@ -329,21 +329,17 @@ def main():
     global_step = 0
     warmup_steps = 0
     start_time = time.time()
-    next_obs, info = envs.reset()
-    next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
-    next_to_play_ = info["to_play"]
-    next_to_play = to_tensor(next_to_play_, device)
     next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
     ai_player1_ = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(ai_player1_)
-    ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
+    ai_player1 = to_tensor(ai_player1_, device)
     next_value1 = next_value2 = 0
     step = 0
+    ts = []
     lp_count = 0
-    ts = sample_target(history)
     for iteration in range(args.num_iterations):
         # Annealing the rate if instructed to do so.
@@ -351,6 +347,15 @@ def main():
             frac = 1.0 - (iteration % args.iter_per_lp) / args.iter_per_lp
             lrnow = frac * args.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow
+        if iteration % args.iter_per_lp == 0:
+            next_obs, info = envs.reset()
+            next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
+            next_to_play_ = info["to_play"]
+            next_to_play = to_tensor(next_to_play_, device)
+            next_value1 = next_value2 = 0
+            step = 0
+            ts = []
         if len(ts) == 0:
             ts = sample_target(history)
@@ -538,7 +543,7 @@ def main():
         if (iteration + 1) % args.iter_per_lp == 0:
             lp_count += 1
             win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
-            if np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
+            if len(history) == 0 or np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
                 agent_t.load_state_dict(agent.state_dict())
                 with torch.no_grad():
                     traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
...
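
Read together, these ppo_osfp.py hunks restructure the outer loop: environments are re-reset at the start of every iter_per_lp window, opponent targets are re-sampled lazily from history via sample_target, and the target network is now also refreshed when history is still empty. A self-contained toy sketch of that control flow follows; only iter_per_lp, lp_count, ts, history, and sample_target are names taken from the diff, while every other helper and value is a placeholder, not the repository API:

import random

num_iterations, iter_per_lp, max_lp = 12, 4, 8       # toy values
history = []                                          # past agent checkpoints

def sample_target(h):
    # placeholder: pick one past checkpoint, or a stand-in before any exist
    return [random.choice(h)] if h else ["initial-agent"]

def win_rates_ok():
    return random.random() > 0.5                      # stand-in for the synced win-rate test

ts, lp_count = [], 0
for iteration in range(num_iterations):
    frac = 1.0 - (iteration % iter_per_lp) / iter_per_lp   # LR anneals within each window
    if iteration % iter_per_lp == 0:
        ts, step = [], 0              # fresh rollout state at the start of every window
    if len(ts) == 0:
        ts = sample_target(history)   # refill targets lazily from the checkpoint history
    # ... rollout collection and PPO updates would go here ...
    if (iteration + 1) % iter_per_lp == 0:
        lp_count += 1
        if len(history) == 0 or win_rates_ok() or lp_count >= max_lp:
            history.append(f"checkpoint@{iteration}")  # promote the current agent as a new target
print(history)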
ygoai/rl/agent.py
@@ -343,7 +343,8 @@ class Encoder(nn.Module):
         mask = x_actions[:, :, 2] == 0  # msg == 0
         valid = x['global_'][:, -1] == 0
-        mask[:, 0] &= valid
+        mask[:, 0] = False
+        # mask[:, 0] &= valid
         for layer in self.action_card_net:
             f_actions = layer(
                 f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)
...
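
The agent.py change forces the first action slot to stay unmasked (mask[:, 0] = False) instead of tying it to the global valid flag. This looks like a guard against a row of tgt_key_padding_mask being entirely True: with every key masked, the attention softmax has nothing to normalize over and the output turns to NaN on typical PyTorch builds. A minimal standalone illustration of that failure mode, using a bare nn.MultiheadAttention rather than the repository's action_card_net layers:

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
x = torch.randn(1, 4, 8)                       # (batch, seq, embed)

mask = torch.ones(1, 4, dtype=torch.bool)      # True = ignore this position
out, _ = attn(x, x, x, key_padding_mask=mask)
print(out.isnan().any().item())                # True: fully masked rows yield NaN

mask[:, 0] = False                             # always keep one position attendable
out, _ = attn(x, x, x, key_padding_mask=mask)
print(out.isnan().any().item())                # False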
ygoai/rl/ppo.py
@@ -54,6 +54,20 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
     return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss


+def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
+    num_steps = rewards.size(0)
+    advantages = torch.zeros_like(rewards)
+    lastgaelam = 0
+    for t in reversed(range(num_steps)):
+        if t == num_steps - 1:
+            nextnonterminal = 1.0 - next_done
+            nextvalues = nextvalues
+        else:
+            nextnonterminal = 1.0 - dones[t + 1]
+            nextvalues = values[t + 1]
+        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
+        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
+
 def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done, gamma, gae_lambda):
     num_steps = rewards.size(0)
...
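
For reference, the new bootstrap_value helper walks the rollout backwards with the standard GAE(lambda) recurrence: delta_t = r_t + gamma * V_{t+1} * (1 - done_{t+1}) - V_t and A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}. Below is a self-contained sketch of the same recurrence on dummy data; it is a paraphrase that also returns the advantages and bootstrapped value targets, which the helper as committed does not yet do:

import torch

def gae_sketch(values, rewards, dones, next_value, next_done, gamma=1.0, gae_lambda=0.95):
    # values, rewards, dones: (num_steps, num_envs); next_value, next_done: (num_envs,)
    num_steps = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done.float()
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1].float()
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    return advantages, advantages + values   # advantages and bootstrapped value targets

T, N = 128, 8   # e.g. num_steps x num_envs
adv, ret = gae_sketch(values=torch.zeros(T, N), rewards=torch.randn(T, N),
                      dones=torch.zeros(T, N, dtype=torch.bool),
                      next_value=torch.zeros(N), next_done=torch.zeros(N, dtype=torch.bool))
print(adv.shape, ret.shape)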