Biluo Shen / ygo-agent

Commit 45506246, authored Apr 23, 2024 by Biluo Shen
Add NeuRD and ACH loss
Parent: efe74c51
Showing 2 changed files with 64 additions and 1 deletion.
scripts/ppo.py             +6  -1
ygoai/rl/jax/__init__.py   +58 -0
scripts/ppo.py
@@ -26,7 +26,7 @@ from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
+from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, ach_loss
 from ygoai.rl.jax.switch import truncated_gae_2p0s
@@ -98,6 +98,8 @@ class Args:
     """the dual surrogate clipping coefficient, typically 3.0"""
     spo_kld_max: Optional[float] = None
     """the maximum KLD for the SPO policy, typically 0.02"""
+    logits_threshold: Optional[float] = None
+    """the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
     vf_coef: float = 0.5
@@ -635,6 +637,9 @@ if __name__ == "__main__":
             if args.spo_kld_max is not None:
                 pg_loss = simple_policy_loss(
                     ratios, logits, new_logits, advantages, args.spo_kld_max)
+            elif args.logits_threshold is not None:
+                pg_loss = ach_loss(
+                    actions, logits, new_logits, advantages, args.logits_threshold, args.clip_coef, args.dual_clip_coef)
             else:
                 pg_loss = clipped_surrogate_pg_loss(
                     ratios, advantages, args.clip_coef, args.dual_clip_coef)
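
For orientation, here is a small self-contained sketch (not part of the commit) of the resulting selection order: spo_kld_max takes precedence, logits_threshold then switches the policy loss to ach_loss, and otherwise the existing clipped surrogate loss is kept. The LossArgs stand-in and its values are illustrative assumptions, not the project's real Args dataclass.

# Hypothetical restatement of the dispatch added above; the real code reads
# these fields from the Args dataclass in scripts/ppo.py.
from dataclasses import dataclass
from typing import Optional

@dataclass
class LossArgs:  # illustrative stand-in, not the real Args
    clip_coef: float = 0.25
    dual_clip_coef: Optional[float] = 3.0
    spo_kld_max: Optional[float] = None
    logits_threshold: Optional[float] = None

def pick_policy_loss(args: LossArgs) -> str:
    # Mirrors the if/elif/else chain: SPO first, then NeuRD/ACH, else plain PPO.
    if args.spo_kld_max is not None:
        return "simple_policy_loss"
    elif args.logits_threshold is not None:
        return "ach_loss"
    return "clipped_surrogate_pg_loss"

print(pick_policy_loss(LossArgs(logits_threshold=3.0)))  # -> ach_loss

Only one of the three policy losses is active per run; the new logits_threshold flag is ignored whenever spo_kld_max is set.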
ygoai/rl/jax/__init__.py
@@ -100,6 +100,64 @@ def clipped_surrogate_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None
     return pg_loss


+def get_from_action(values, action):
+    num_categories = values.shape[-1]
+    value_one_hot = jax.nn.one_hot(action, num_categories, dtype=values.dtype)
+    return jnp.sum(distrax.multiply_no_nan(values, value_one_hot), axis=-1)
+
+
+def mean_legal(values, axis=None):
+    # TODO: use real action mask
+    no_nan_mask = values > -1e12
+    no_nan = jnp.where(no_nan_mask, values, 0)
+    count = jnp.sum(no_nan_mask, axis=axis)
+    return jnp.sum(no_nan, axis=axis) / jnp.maximum(count, 1)
+
+
+def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
+    # Neural Replicator Dynamics
+    # Differences from the original implementation:
+    # - all actions vs. sampled actions
+    # - original computes advantages with q values
+    # - original does not use importance sampling ratios
+    advs = jax.lax.stop_gradient(advantages)
+    probs_a = get_from_action(jax.nn.softmax(logits), actions)
+    probs_a = jnp.maximum(probs_a, 0.001)
+    new_logits_a = get_from_action(new_logits, actions)
+    new_logits_a_ = new_logits_a - mean_legal(new_logits, axis=-1)
+    can_decrease_1 = new_logits_a_ < logits_threshold
+    can_decrease_2 = new_logits_a_ > -logits_threshold
+    c = jnp.where(advs >= 0, can_decrease_1, can_decrease_2).astype(jnp.float32)
+    c = jax.lax.stop_gradient(c)
+    pg_loss = -c * new_logits_a_ / probs_a * advs
+    return pg_loss
+
+
+def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coef, dual_clip_coef=None):
+    # Actor-Critic Hedge loss from Actor-Critic Policy Optimization in a Large-Scale Imperfect-Information Game
+    # notice entropy term is required but not included here
+    advs = jax.lax.stop_gradient(advantages)
+    probs_a = get_from_action(jax.nn.softmax(logits), actions)
+    probs_a = jnp.maximum(probs_a, 0.001)
+    new_logits_a = get_from_action(new_logits, actions)
+    new_logits_a_ = new_logits_a - mean_legal(new_logits, axis=-1)
+    ratios = distrax.importance_sampling_ratios(distrax.Categorical(new_logits), distrax.Categorical(logits), actions)
+    can_decrease_1 = (ratios < 1 + clip_coef) * (new_logits_a_ < logits_threshold)
+    can_decrease_2 = (ratios > 1 - clip_coef) * (new_logits_a_ > -logits_threshold)
+    if dual_clip_coef is not None:
+        can_decrease_2 = can_decrease_2 * (ratios < dual_clip_coef)
+    c = jnp.where(advs >= 0, can_decrease_1, can_decrease_2).astype(jnp.float32)
+    c = jax.lax.stop_gradient(c)
+    pg_loss = -c * new_logits_a_ / probs_a * advs
+    return pg_loss
+
+
 def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
         last_return1, last_return2, next_q1, next_q2 = carry
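
As a quick sanity check, here is a minimal toy sketch (not from the repository) of how the two new losses might be called, assuming neurd_loss and ach_loss are importable from ygoai.rl.jax as the diff suggests. The batch shapes, random values, and the convention of marking an illegal action with a very negative logit (so that mean_legal's > -1e12 mask excludes it) are assumptions for illustration.

# Toy sketch, not part of the commit: shapes and masking convention are assumed.
import jax
import jax.numpy as jnp

from ygoai.rl.jax import neurd_loss, ach_loss

batch, num_actions = 4, 5
key1, key2 = jax.random.split(jax.random.PRNGKey(0))

logits = jax.random.normal(key1, (batch, num_actions))                    # behaviour logits
new_logits = logits + 0.1 * jax.random.normal(key2, (batch, num_actions)) # current policy logits

# Mark the last action as illegal with a very negative logit so that
# mean_legal (mask: values > -1e12) excludes it when centering the logits.
logits = logits.at[:, -1].set(-1e38)
new_logits = new_logits.at[:, -1].set(-1e38)

actions = jnp.array([0, 2, 1, 3])               # sampled (legal) action indices
advantages = jnp.array([0.5, -0.3, 1.2, -0.8])  # e.g. from truncated GAE

# NeuRD: gradient through the centered logit of the taken action, gated by the threshold.
loss_neurd = neurd_loss(actions, logits, new_logits, advantages, logits_threshold=2.0)

# ACH: same gating, combined with PPO-style ratio clipping and optional dual clipping.
loss_ach = ach_loss(actions, logits, new_logits, advantages,
                    logits_threshold=2.0, clip_coef=0.25, dual_clip_coef=3.0)

print(loss_neurd.shape, loss_ach.shape)  # per-sample losses, here (4,) and (4,)

Both functions return per-sample losses; as with the other losses in this module, the caller (scripts/ppo.py) is responsible for reducing them, and for ACH an entropy bonus still has to be added separately, as the in-code comment notes.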