ygo-agent (Biluo Shen)

Commit 831e92ff
Authored Jun 08, 2024 by sbl1996@126.com
Parent: e03d45b6

    Add sep_value

Showing 2 changed files with 104 additions and 78 deletions:

    scripts/cleanba.py          +20  -5
    ygoai/rl/jax/__init__.py    +84  -73
scripts/cleanba.py  (+20, -5)
@@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone
 from collections import deque
 from dataclasses import dataclass, field, asdict
 from types import SimpleNamespace
-from typing import List, NamedTuple, Optional
+from typing import List, NamedTuple, Optional, Literal
 from functools import partial
 
 import ygoenv
@@ -28,7 +28,8 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent import RNNAgent, ModelArgs
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
+from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
+    ach_loss, policy_gradient_loss, vtrace, vtrace_2p0s, truncated_gae
 from ygoai.rl.jax.switch import truncated_gae_2p0s as gae_2p0s_switch
@@ -116,6 +117,10 @@ class Args:
     upgo: bool = True
     """Toggle the use of UPGO for advantages"""
+    sep_value: bool = True
+    """Whether separate value function computation for each player"""
+    value: Literal["vtrace", "gae"] = "vtrace"
+    """the method to learn the value function"""
     gae_lambda: float = 0.95
     """the lambda for the general advantage estimation"""
     c_clip_min: float = 0.001
@@ -738,13 +743,23 @@ def main():
         # Advantages and target values
         if args.switch:
+            if args.value == "vtrace" or args.sep_value:
+                raise NotImplementedError
             target_values, advantages = gae_2p0s_switch(
                 next_value, new_values_, rewards, next_dones, switch_or_mains,
                 args.gamma, args.gae_lambda, args.upgo)
         else:
             # TODO: TD(lambda) for multi-step
             ratios_ = reshape_time_series(ratios)
-            target_values, advantages = vtrace_2p0s(
-                next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains,
-                args.gamma, args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
+            if args.value == "gae":
+                if not args.sep_value:
+                    raise NotImplementedError
+                target_values, advantages = truncated_gae(
+                    next_value, new_values_, rewards, next_dones, switch_or_mains,
+                    args.gamma, args.gae_lambda, args.upgo)
+            else:
+                vtrace_fn = vtrace if args.sep_value else vtrace_2p0s
+                target_values, advantages = vtrace_fn(
+                    next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains,
+                    args.gamma, args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
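Read together, the two new Args fields gate which estimator main() dispatches to. The sketch below restates that dispatch as a small helper; the name select_advantage_fn and its return strings are illustrative only and do not exist in the repository.

    def select_advantage_fn(switch: bool, value: str, sep_value: bool) -> str:
        # Hypothetical helper mirroring the branch added to main() above.
        if switch:
            # Switch-style training still goes through the shared-value two-player GAE.
            if value == "vtrace" or sep_value:
                raise NotImplementedError
            return "gae_2p0s_switch"
        if value == "gae":
            # Truncated GAE is only wired up for separate per-player values.
            if not sep_value:
                raise NotImplementedError
            return "truncated_gae"
        # V-trace: separate values use the new single-stream vtrace,
        # shared values keep the existing vtrace_2p0s.
        return "vtrace" if sep_value else "vtrace_2p0s"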
ygoai/rl/jax/__init__.py  (+84, -73)
@@ -7,68 +7,6 @@ import chex
 import distrax
 
-# class VTraceOutput(NamedTuple):
-#     q_estimate: jnp.ndarray
-#     errors: jnp.ndarray
-
-# def vtrace(
-#     v_tm1,
-#     v_t,
-#     r_t,
-#     discount_t,
-#     rho_tm1,
-#     lambda_=1.0,
-#     c_clip_min: float = 0.001,
-#     c_clip_max: float = 1.007,
-#     rho_clip_min: float = 0.001,
-#     rho_clip_max: float = 1.007,
-#     stop_target_gradients: bool = True,
-# ):
-#     """
-#     Args:
-#       v_tm1: values at time t-1.
-#       v_t: values at time t.
-#       r_t: reward at time t.
-#       discount_t: discount at time t.
-#       rho_tm1: importance sampling ratios at time t-1.
-#       lambda_: mixing parameter; a scalar or a vector for timesteps t.
-#       clip_rho_threshold: clip threshold for importance weights.
-#       stop_target_gradients: whether or not to apply stop gradient to targets.
-#     """
-#     # Clip importance sampling ratios.
-#     lambda_ = jnp.ones_like(discount_t) * lambda_
-#     c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
-#     clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
-#     # Compute the temporal difference errors.
-#     td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
-#     # Work backwards computing the td-errors.
-#     def _body(acc, xs):
-#         td_error, discount, c = xs
-#         acc = td_error + discount * c * acc
-#         return acc, acc
-#     _, errors = jax.lax.scan(
-#         _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
-#     # Return errors, maybe disabling gradient flow through bootstrap targets.
-#     errors = jax.lax.select(
-#         stop_target_gradients,
-#         jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
-#         errors)
-#     targets_tm1 = errors + v_tm1
-#     q_bootstrap = jnp.concatenate([
-#         lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
-#         v_t[-1:],
-#     ], axis=0)
-#     q_estimate = r_t + discount_t * q_bootstrap
-#     return VTraceOutput(q_estimate=q_estimate, errors=errors)
 
 def entropy_loss(logits):
     return distrax.Softmax(logits=logits).entropy()
@@ -255,6 +193,57 @@ def vtrace_rnad(
     return targets, q_estimate
 
 
+def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
+    v, next_value, last_return, next_q, next_main = carry
+    ratio, cur_value, next_done, reward, main = inp
+
+    v = jnp.where(next_done, 0, v)
+    next_value = jnp.where(next_done, 0, next_value)
+    sign = jnp.where(main == next_main, 1, -1)
+    v = v * sign
+    next_value = next_value * sign
+
+    discount = gamma * (1.0 - next_done)
+    q_t = reward + discount * v
+    rho_t = jnp.clip(ratio, rho_min, rho_max)
+    c_t = jnp.clip(ratio, c_min, c_max)
+    sig_v = rho_t * (reward + discount * next_value - cur_value)
+    v = cur_value + sig_v + c_t * discount * (v - next_value)
+
+    # UPGO advantage (not corrected by importance sampling, unlike V-trace)
+    last_return = last_return * sign
+    next_q = next_q * sign
+    last_return = reward + discount * jnp.where(
+        next_q >= next_value, last_return, next_value)
+    next_q = reward + discount * next_value
+
+    carry = v, cur_value, last_return, next_q, main
+    return carry, (v, q_t, last_return)
+
+
+def vtrace(
+    next_value, ratios, values, rewards, next_dones, mains,
+    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+):
+    v = last_return = next_q = next_value
+    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+    carry = v, next_value, last_return, next_q, next_main
+
+    _, (targets, q_estimate, return_t) = jax.lax.scan(
+        partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max,
+                c_min=c_min, c_max=c_max),
+        carry, (ratios, values, next_dones, rewards, mains), reverse=True)
+
+    advantages = q_estimate - values
+    if upgo:
+        advantages += return_t - values
+    targets = jax.lax.stop_gradient(targets)
+    return targets, advantages
+
+
 def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
         last_return1, last_return2, next_q1, next_q2 = carry
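The new vtrace follows the V-trace recursion of Espeholt et al. (2018), v_t = V(x_t) + rho_t (r_t + gamma V(x_{t+1}) - V(x_t)) + c_t gamma (v_{t+1} - V(x_{t+1})), with values zeroed at episode boundaries and an extra sign flip whenever the acting player changes between steps. Below is a minimal usage sketch under assumed shapes (time-major (T, B) arrays, mains marking the acting player per step, as in cleanba.py); the toy values are illustrative, not repository code.

    import numpy as np
    import jax.numpy as jnp
    from ygoai.rl.jax import vtrace

    T, B = 8, 4                                 # T time steps, B parallel envs
    rng = np.random.default_rng(0)
    ratios = jnp.asarray(rng.uniform(0.5, 1.5, (T, B)), dtype=jnp.float32)  # pi/mu ratios
    values = jnp.asarray(rng.normal(size=(T, B)), dtype=jnp.float32)        # V(s_t)
    rewards = jnp.zeros((T, B), jnp.float32).at[-1].set(1.0)                # terminal win reward
    next_dones = jnp.zeros((T, B), jnp.float32).at[-1].set(1.0)             # episode ends at t = T-1
    mains = jnp.asarray(rng.integers(0, 2, (T, B)), dtype=jnp.bool_)        # acting player per step
    next_value = jnp.zeros((B,), jnp.float32)                               # bootstrap value after last step

    targets, advantages = vtrace(
        next_value, ratios, values, rewards, next_dones, mains,
        gamma=1.0, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=True)
    print(targets.shape, advantages.shape)      # both (T, B)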
@@ -412,24 +401,46 @@ def truncated_gae_2p0s(
 
 
 def truncated_gae_loop(carry, inp, gamma, gae_lambda):
-    lastgaelam, next_value = carry
-    cur_value, next_done, reward = inp
-    nextnonterminal = 1.0 - next_done
+    lastgaelam, next_value, last_return, next_q, next_main = carry
+    cur_value, next_done, reward, main = inp
 
-    delta = reward + gamma * next_value * nextnonterminal - cur_value
-    lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
-    carry = lastgaelam, cur_value
-    return carry, lastgaelam
+    lastgaelam = jnp.where(next_done, 0, lastgaelam)
+    next_value = jnp.where(next_done, 0, next_value)
+    sign = jnp.where(main == next_main, 1, -1)
+    lastgaelam = lastgaelam * sign
+    next_value = next_value * sign
+
+    discount = gamma * (1.0 - next_done)
+    delta = reward + discount * next_value - cur_value
+    lastgaelam = delta + discount * gae_lambda * lastgaelam
+
+    # UPGO advantage
+    last_return = last_return * sign
+    next_q = next_q * sign
+    last_return = reward + discount * jnp.where(
+        next_q >= next_value, last_return, next_value)
+    next_q = reward + discount * next_value
+
+    carry = lastgaelam, cur_value, last_return, next_q, main
+    return carry, (lastgaelam, last_return)
 
 
-def truncated_gae(next_value, values, rewards, next_dones, gamma, gae_lambda):
+def truncated_gae(next_value, values, rewards, next_dones, mains,
+                  gamma, gae_lambda, upgo=False):
     lastgaelam = jnp.zeros_like(next_value)
-    carry = lastgaelam, next_value
-    _, advantages = jax.lax.scan(
+    last_return = next_q = next_value
+    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+    carry = lastgaelam, next_value, last_return, next_q, next_main
+    _, (advantages, return_t) = jax.lax.scan(
         partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
-        carry, (values, next_dones, rewards), reverse=True)
+        carry, (values, next_dones, rewards, mains), reverse=True)
     targets = values + advantages
+    if upgo:
+        advantages += return_t - values
     targets = jax.lax.stop_gradient(targets)
    return targets, advantages
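As a sanity check on the extended truncated_gae: when every step belongs to the same player (mains all True), the sign flips are identities and the recursion should collapse to ordinary single-agent GAE. The sketch below illustrates that reduction against the code above; it is a sketch under that assumption, not a repository test, and uses upgo=False so only the plain GAE path is compared.

    import numpy as np
    import jax.numpy as jnp
    from ygoai.rl.jax import truncated_gae

    T, B, gamma, lam = 6, 3, 0.99, 0.95
    rng = np.random.default_rng(1)
    values = rng.normal(size=(T, B)).astype(np.float32)
    rewards = rng.normal(size=(T, B)).astype(np.float32)
    next_dones = (rng.random((T, B)) < 0.2).astype(np.float32)
    next_value = rng.normal(size=(B,)).astype(np.float32)
    mains = np.ones((T, B), dtype=bool)         # single player: sign flips become identities

    # Reference: textbook reverse-time GAE
    ref = np.zeros((T, B), np.float32)
    lastgaelam = np.zeros(B, np.float32)
    nv = next_value
    for t in reversed(range(T)):
        nonterminal = 1.0 - next_dones[t]
        delta = rewards[t] + gamma * nv * nonterminal - values[t]
        lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
        ref[t] = lastgaelam
        nv = values[t]

    targets, adv = truncated_gae(
        jnp.asarray(next_value), jnp.asarray(values), jnp.asarray(rewards),
        jnp.asarray(next_dones), jnp.asarray(mains), gamma, lam, upgo=False)
    print(np.allclose(np.asarray(adv), ref, atol=1e-5))   # expected: True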