ygo-agent · Commit 45506246
authored Apr 23, 2024 by Biluo Shen

Add NeuRD and ACH loss

parent efe74c51

Showing 2 changed files with 64 additions and 1 deletion:
  scripts/ppo.py            +6  -1
  ygoai/rl/jax/__init__.py  +58 -0
scripts/ppo.py · view file @ 45506246
@@ -26,7 +26,7 @@ from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
+from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, ach_loss
 from ygoai.rl.jax.switch import truncated_gae_2p0s
@@ -98,6 +98,8 @@ class Args:
     """the dual surrogate clipping coefficient, typically 3.0"""
     spo_kld_max: Optional[float] = None
     """the maximum KLD for the SPO policy, typically 0.02"""
+    logits_threshold: Optional[float] = None
+    """the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
     vf_coef: float = 0.5
@@ -635,6 +637,9 @@ if __name__ == "__main__":
             if args.spo_kld_max is not None:
                 pg_loss = simple_policy_loss(
                     ratios, logits, new_logits, advantages, args.spo_kld_max)
+            elif args.logits_threshold is not None:
+                pg_loss = ach_loss(
+                    actions, logits, new_logits, advantages, args.logits_threshold, args.clip_coef, args.dual_clip_coef)
             else:
                 pg_loss = clipped_surrogate_pg_loss(
                     ratios, advantages, args.clip_coef, args.dual_clip_coef)
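For reference, the new Args field only selects which policy-gradient objective the training step builds: spo_kld_max takes precedence over logits_threshold, which takes precedence over the default clipped surrogate. A minimal sketch of that precedence, not code from the repository (select_pg_loss_name is a hypothetical helper for illustration only):

# Hypothetical illustration of the dispatch added above; not part of the commit.
from typing import Optional

def select_pg_loss_name(spo_kld_max: Optional[float], logits_threshold: Optional[float]) -> str:
    if spo_kld_max is not None:           # SPO path
        return "simple_policy_loss"
    elif logits_threshold is not None:    # NeuRD/ACH path added by this commit
        return "ach_loss"
    else:                                 # default PPO clipped surrogate
        return "clipped_surrogate_pg_loss"

assert select_pg_loss_name(None, 4.0) == "ach_loss"
assert select_pg_loss_name(0.02, 4.0) == "simple_policy_loss"
assert select_pg_loss_name(None, None) == "clipped_surrogate_pg_loss"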
ygoai/rl/jax/__init__.py · view file @ 45506246
@@ -100,6 +100,64 @@ def clipped_surrogate_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None
     return pg_loss
+
+
+def get_from_action(values, action):
+    num_categories = values.shape[-1]
+    value_one_hot = jax.nn.one_hot(action, num_categories, dtype=values.dtype)
+    return jnp.sum(distrax.multiply_no_nan(values, value_one_hot), axis=-1)
+
+
+def mean_legal(values, axis=None):
+    # TODO: use real action mask
+    no_nan_mask = values > -1e12
+    no_nan = jnp.where(no_nan_mask, values, 0)
+    count = jnp.sum(no_nan_mask, axis=axis)
+    return jnp.sum(no_nan, axis=axis) / jnp.maximum(count, 1)
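As a quick illustration (not part of the diff), the two helpers behave as follows on a toy batch in which illegal actions are encoded as very negative logits, matching the `values > -1e12` mask above; the import assumes the helpers stay exported from ygoai.rl.jax:

import jax.numpy as jnp
from ygoai.rl.jax import get_from_action, mean_legal

logits = jnp.array([[1.0, 2.0, -1e38],    # action 2 illegal (masked with a huge negative logit)
                    [0.5, -1e38, 1.5]])   # action 1 illegal
actions = jnp.array([1, 2])

print(get_from_action(logits, actions))   # [2.0, 1.5]  -- value of the taken action
print(mean_legal(logits, axis=-1))        # [1.5, 1.0]  -- mean over legal actions only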
+
+
+def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
+    # Neural Replicator Dynamics
+    # Differences from the original implementation:
+    # - all actions vs. sampled actions
+    # - original computes advantages with q values
+    # - original does not use importance sampling ratios
+    advs = jax.lax.stop_gradient(advantages)
+    probs_a = get_from_action(jax.nn.softmax(logits), actions)
+    probs_a = jnp.maximum(probs_a, 0.001)
+    new_logits_a = get_from_action(new_logits, actions)
+    new_logits_a_ = new_logits_a - mean_legal(new_logits, axis=-1)
+    can_decrease_1 = new_logits_a_ < logits_threshold
+    can_decrease_2 = new_logits_a_ > -logits_threshold
+    c = jnp.where(advs >= 0, can_decrease_1, can_decrease_2).astype(jnp.float32)
+    c = jax.lax.stop_gradient(c)
+    pg_loss = -c * new_logits_a_ / probs_a * advs
+    return pg_loss
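A toy invocation (not from the repository) to show the expected shapes: the advantages and the gate c are stop-gradiented, so parameter gradients enter only through new_logits, and the per-sample loss is reduced by the caller.

import jax
import jax.numpy as jnp
from ygoai.rl.jax import neurd_loss

B, A = 4, 3
logits = jax.random.normal(jax.random.PRNGKey(0), (B, A))        # behaviour-policy logits from the rollout
new_logits = logits + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (B, A))
actions = jnp.array([0, 2, 1, 0])
advantages = jnp.array([1.0, -0.5, 0.3, 2.0])

pg_loss = neurd_loss(actions, logits, new_logits, advantages, logits_threshold=2.0)
print(pg_loss.shape)  # (4,) -- per-sample NeuRD loss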
+
+
+def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coef, dual_clip_coef=None):
+    # Actor-Critic Hedge loss from "Actor-Critic Policy Optimization in a Large-Scale Imperfect-Information Game"
+    # notice entropy term is required but not included here
+    advs = jax.lax.stop_gradient(advantages)
+    probs_a = get_from_action(jax.nn.softmax(logits), actions)
+    probs_a = jnp.maximum(probs_a, 0.001)
+    new_logits_a = get_from_action(new_logits, actions)
+    new_logits_a_ = new_logits_a - mean_legal(new_logits, axis=-1)
+    ratios = distrax.importance_sampling_ratios(
+        distrax.Categorical(new_logits), distrax.Categorical(logits), actions)
+    can_decrease_1 = (ratios < 1 + clip_coef) * (new_logits_a_ < logits_threshold)
+    can_decrease_2 = (ratios > 1 - clip_coef) * (new_logits_a_ > -logits_threshold)
+    if dual_clip_coef is not None:
+        can_decrease_2 = can_decrease_2 * (ratios < dual_clip_coef)
+    c = jnp.where(advs >= 0, can_decrease_1, can_decrease_2).astype(jnp.float32)
+    c = jax.lax.stop_gradient(c)
+    pg_loss = -c * new_logits_a_ / probs_a * advs
+    return pg_loss
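The ACH variant differs from neurd_loss only in the gate: it additionally conditions the update on the importance-sampling ratio, mirroring PPO's clip range (and dual_clip_coef when set), whereas neurd_loss gates only on the centered logits. A hedged toy call, not from the repository, using the typical dual_clip_coef and logits_threshold values documented in Args and a standard PPO clip_coef of 0.2:

import jax
import jax.numpy as jnp
from ygoai.rl.jax import ach_loss

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
new_logits = logits + 0.05
actions = jnp.array([0, 2, 1, 0])
advantages = jnp.array([1.0, -0.5, 0.3, 2.0])

pg_loss = ach_loss(actions, logits, new_logits, advantages,
                   logits_threshold=4.0, clip_coef=0.2, dual_clip_coef=3.0)
print(pg_loss.shape)  # (4,) -- per-sample loss; the required entropy bonus is added by the caller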
 
 
 def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
         last_return1, last_return2, next_q1, next_q2 = carry
...