Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
974fe861
Commit
974fe861
authored
Jun 10, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rename
parent
dab0733b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
16 deletions
+15
-16
scripts/cleanba.py
scripts/cleanba.py
+7
-8
ygoai/rl/jax/__init__.py
ygoai/rl/jax/__init__.py
+7
-7
ygoai/rl/jax/switch.py
ygoai/rl/jax/switch.py
+1
-1
No files found.
scripts/cleanba.py
View file @
974fe861
...
...
@@ -28,9 +28,9 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax.switch
import
truncated_gae_sep
as
gae_sep_switch
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
\
ach_loss
,
policy_gradient_loss
,
vtrace
,
vtrace_2p0s
,
truncated_gae
from
ygoai.rl.jax.switch
import
truncated_gae_2p0s
as
gae_2p0s_switch
ach_loss
,
policy_gradient_loss
,
vtrace
,
vtrace_sep
,
truncated_gae
,
truncated_gae_sep
os
.
environ
[
"XLA_FLAGS"
]
=
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
...
...
@@ -745,21 +745,20 @@ def main():
if
args
.
switch
:
if
args
.
value
==
"vtrace"
or
args
.
sep_value
:
raise
NotImplementedError
target_values
,
advantages
=
gae_
2p0s
_switch
(
target_values
,
advantages
=
gae_
sep
_switch
(
next_value
,
new_values_
,
rewards
,
next_dones
,
switch_or_mains
,
args
.
gamma
,
args
.
gae_lambda
,
args
.
upgo
)
else
:
# TODO: TD(lambda) for multi-step
ratios_
=
reshape_time_series
(
ratios
)
if
args
.
value
==
"gae"
:
if
args
.
sep_value
:
raise
NotImplementedError
target_values
,
advantages
=
truncated_gae
(
adv_fn
=
truncated_gae_sep
if
args
.
sep_value
else
truncated_gae
target_values
,
advantages
=
adv_fn
(
next_value
,
new_values_
,
rewards
,
next_dones
,
switch_or_mains
,
args
.
gamma
,
args
.
gae_lambda
,
args
.
upgo
)
else
:
vtrace_fn
=
vtrace_2p0s
if
args
.
sep_value
else
vtrace
target_values
,
advantages
=
vtrace
_fn
(
adv_fn
=
vtrace_sep
if
args
.
sep_value
else
vtrace
target_values
,
advantages
=
adv
_fn
(
next_value
,
ratios_
,
new_values_
,
rewards
,
next_dones
,
switch_or_mains
,
args
.
gamma
,
args
.
rho_clip_min
,
args
.
rho_clip_max
,
args
.
c_clip_min
,
args
.
c_clip_max
)
...
...
ygoai/rl/jax/__init__.py
View file @
974fe861
...
...
@@ -244,7 +244,7 @@ def vtrace(
return
targets
,
advantages
def
vtrace_
2p0s
_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
def
vtrace_
sep
_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
ratio
,
cur_values
,
next_done
,
r_t
,
main
=
inp
...
...
@@ -301,7 +301,7 @@ def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
return
carry
,
(
v
,
q_t
,
return_t
)
def
vtrace_
2p0s
(
def
vtrace_
sep
(
next_value
,
ratios
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
upgo
=
False
,
):
...
...
@@ -315,7 +315,7 @@ def vtrace_2p0s(
return1
,
return2
,
next_q1
,
next_q2
_
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
partial
(
vtrace_
2p0s
_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
partial
(
vtrace_
sep
_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
carry
,
(
ratios
,
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
advantages
=
q_estimate
-
values
...
...
@@ -325,7 +325,7 @@ def vtrace_2p0s(
return
targets
,
advantages
def
truncated_gae_
upgo
_loop
(
carry
,
inp
,
gamma
,
gae_lambda
):
def
truncated_gae_
sep
_loop
(
carry
,
inp
,
gamma
,
gae_lambda
):
lastgaelam1
,
lastgaelam2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
\
done_used1
,
done_used2
,
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
cur_value
,
next_done
,
reward
,
main
=
inp
...
...
@@ -375,7 +375,7 @@ def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
return
carry
,
(
advantages
,
returns
)
def
truncated_gae_
2p0s
(
def
truncated_gae_
sep
(
next_value
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
gae_lambda
,
upgo
,
):
next_value1
=
next_value
...
...
@@ -390,12 +390,12 @@ def truncated_gae_2p0s(
done_used1
,
done_used2
,
last_return1
,
last_return2
,
next_q1
,
next_q2
_
,
(
advantages
,
returns
)
=
jax
.
lax
.
scan
(
partial
(
truncated_gae_
upgo
_loop
,
gamma
=
gamma
,
gae_lambda
=
gae_lambda
),
partial
(
truncated_gae_
sep
_loop
,
gamma
=
gamma
,
gae_lambda
=
gae_lambda
),
carry
,
(
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
targets
=
values
+
advantages
if
upgo
:
advantages
+=
returns
-
values
targets
=
values
+
advantages
targets
=
jax
.
lax
.
stop_gradient
(
targets
)
return
targets
,
advantages
...
...
ygoai/rl/jax/switch.py
View file @
974fe861
...
...
@@ -2,7 +2,7 @@ import jax
import
jax.numpy
as
jnp
def
truncated_gae_
2p0s
(
def
truncated_gae_
sep
(
next_value
,
values
,
rewards
,
next_dones
,
switch
,
gamma
,
gae_lambda
,
upgo
):
def
body_fn
(
carry
,
inp
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment