Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
7c8b11c3
Commit
7c8b11c3
authored
Jun 14, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add collect_steps
parent
e1ff8f92
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
281 additions
and
142 deletions
+281
-142
scripts/cleanba.py
scripts/cleanba.py
+150
-92
ygoai/rl/jax/__init__.py
ygoai/rl/jax/__init__.py
+131
-50
No files found.
scripts/cleanba.py
View file @
7c8b11c3
This diff is collapsed.
Click to expand it.
ygoai/rl/jax/__init__.py
View file @
7c8b11c3
from
functools
import
partial
from
typing
import
NamedTuple
import
jax
import
jax.numpy
as
jnp
...
...
@@ -193,6 +194,14 @@ def vtrace_rnad(
return
targets
,
q_estimate
class
VtraceCarry
(
NamedTuple
):
v
:
jnp
.
ndarray
next_value
:
jnp
.
ndarray
last_return
:
jnp
.
ndarray
next_q
:
jnp
.
ndarray
next_main
:
jnp
.
ndarray
def
vtrace_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
v
,
next_value
,
last_return
,
next_q
,
next_main
=
carry
ratio
,
cur_value
,
next_done
,
reward
,
main
=
inp
...
...
@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
next_q
=
reward
+
discount
*
next_value
carry
=
v
,
cur_value
,
last_return
,
next_q
,
main
carry
=
VtraceCarry
(
v
,
next_value
,
last_return
,
next_q
,
main
)
return
carry
,
(
v
,
q_t
,
last_return
)
def
vtrace
(
next_value
,
ratios
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
upgo
=
False
,
next_v
,
ratios
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
upgo
=
False
,
return_carry
=
False
):
if
isinstance
(
next_v
,
(
tuple
,
list
)):
carry
=
next_v
else
:
next_value
=
next_v
v
=
last_return
=
next_q
=
next_value
next_main
=
jnp
.
ones_like
(
next_value
,
dtype
=
jnp
.
bool_
)
carry
=
v
,
next_value
,
last_return
,
next_q
,
next_main
carry
=
VtraceCarry
(
v
,
next_value
,
last_return
,
next_q
,
next_main
)
_
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
carry
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
partial
(
vtrace_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
carry
,
(
ratios
,
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
if
return_carry
:
return
carry
advantages
=
q_estimate
-
values
if
upgo
:
advantages
+=
return_t
-
values
...
...
@@ -244,9 +260,18 @@ def vtrace(
return
targets
,
advantages
class
VtraceSepCarry
(
NamedTuple
):
v
:
jnp
.
ndarray
next_values
:
jnp
.
ndarray
reward
:
jnp
.
ndarray
xi
:
jnp
.
ndarray
last_return
:
jnp
.
ndarray
next_q
:
jnp
.
ndarray
def
vtrace_sep_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
(
v1
,
next_values1
,
reward1
,
xi1
,
last_return1
,
next_q1
)
,
\
(
v2
,
next_values2
,
reward2
,
xi2
,
last_return2
,
next_q2
)
=
carry
ratio
,
cur_values
,
next_done
,
r_t
,
main
=
inp
v1
=
jnp
.
where
(
next_done
,
0
,
v1
)
...
...
@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
xi1
=
jnp
.
where
(
main
,
1
,
ratio
*
xi1
)
xi2
=
jnp
.
where
(
main
,
ratio
*
xi2
,
1
)
carry
=
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
last_return1
,
last_return2
,
next_q1
,
next_q2
return
carry
,
(
v
,
q_t
,
return_t
)
carry
1
=
VtraceSepCarry
(
v1
,
next_values1
,
reward1
,
xi1
,
last_return1
,
next_q1
)
carry2
=
VtraceSepCarry
(
v2
,
next_values2
,
reward2
,
xi2
,
last_return2
,
next_q2
)
return
(
carry1
,
carry2
)
,
(
v
,
q_t
,
return_t
)
def
vtrace_sep
(
next_value
,
ratios
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
upgo
=
False
,
next_v
,
ratios
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
upgo
=
False
,
return_carry
=
False
):
if
isinstance
(
next_v
,
(
tuple
,
list
)):
carry
=
next_v
else
:
next_value
=
next_v
next_value1
=
next_value
carry1
=
VtraceSepCarry
(
v
=
next_value1
,
next_values
=
next_value1
,
reward
=
jnp
.
zeros_like
(
next_value
),
xi
=
jnp
.
ones_like
(
next_value
),
last_return
=
next_value1
,
next_q
=
next_value1
,
)
next_value2
=
-
next_value1
v1
=
return1
=
next_q1
=
next_value1
v2
=
return2
=
next_q2
=
next_value2
reward1
=
reward2
=
jnp
.
zeros_like
(
next_value
)
xi1
=
xi2
=
jnp
.
ones_like
(
next_value
)
carry
=
v1
,
v2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
return1
,
return2
,
next_q1
,
next_q2
carry2
=
VtraceSepCarry
(
v
=
next_value2
,
next_values
=
next_value2
,
reward
=
jnp
.
zeros_like
(
next_value
),
xi
=
jnp
.
ones_like
(
next_value
),
last_return
=
next_value2
,
next_q
=
next_value2
,
)
carry
=
carry1
,
carry2
_
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
carry
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
partial
(
vtrace_sep_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
carry
,
(
ratios
,
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
if
return_carry
:
return
carry
advantages
=
q_estimate
-
values
if
upgo
:
advantages
+=
return_t
-
values
...
...
@@ -325,9 +368,18 @@ def vtrace_sep(
return
targets
,
advantages
class
GAESepCarry
(
NamedTuple
):
lastgaelam
:
jnp
.
ndarray
next_value
:
jnp
.
ndarray
reward
:
jnp
.
ndarray
done_used
:
jnp
.
ndarray
last_return
:
jnp
.
ndarray
next_q
:
jnp
.
ndarray
def
truncated_gae_sep_loop
(
carry
,
inp
,
gamma
,
gae_lambda
):
lastgaelam1
,
lastgaelam2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
\
done_used1
,
done_used2
,
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
(
lastgaelam1
,
next_value1
,
reward1
,
done_used1
,
last_return1
,
next_q1
)
,
\
(
lastgaelam2
,
next_value2
,
reward2
,
done_used2
,
last_return2
,
next_q2
)
=
carry
cur_value
,
next_done
,
reward
,
main
=
inp
main1
=
main
main2
=
~
main
...
...
@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1
=
jnp
.
where
(
main1
,
lastgaelam1_
,
lastgaelam1
)
lastgaelam2
=
jnp
.
where
(
main2
,
lastgaelam2_
,
lastgaelam2
)
carry
=
lastgaelam1
,
lastgaelam2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
\
done_used1
,
done_used2
,
last_return1
,
last_return2
,
next_q1
,
next_q2
return
carry
,
(
advantages
,
returns
)
carry
1
=
GAESepCarry
(
lastgaelam1
,
next_value1
,
reward1
,
done_used1
,
last_return1
,
next_q1
)
carry2
=
GAESepCarry
(
lastgaelam2
,
next_value2
,
reward2
,
done_used2
,
last_return2
,
next_q2
)
return
(
carry1
,
carry2
)
,
(
advantages
,
returns
)
def
truncated_gae_sep
(
next_value
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
gae_lambda
,
upgo
,
):
next_value1
=
next_value
next_value2
=
-
next_value1
last_return1
=
next_q1
=
next_value1
last_return2
=
next_q2
=
next_value2
done_used1
=
jnp
.
ones_like
(
next_dones
[
-
1
])
done_used2
=
jnp
.
ones_like
(
next_dones
[
-
1
])
reward1
=
reward2
=
jnp
.
zeros_like
(
next_value
)
lastgaelam1
=
lastgaelam2
=
jnp
.
zeros_like
(
next_value
)
carry
=
lastgaelam1
,
lastgaelam2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
\
done_used1
,
done_used2
,
last_return1
,
last_return2
,
next_q1
,
next_q2
_
,
(
advantages
,
returns
)
=
jax
.
lax
.
scan
(
next_v
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
gae_lambda
,
upgo
,
return_carry
=
False
):
if
isinstance
(
next_v
,
(
tuple
,
list
)):
carry
=
next_v
else
:
next_value
=
next_v
carry1
=
GAESepCarry
(
lastgaelam
=
jnp
.
zeros_like
(
next_value
),
next_value
=
next_value
,
reward
=
jnp
.
zeros_like
(
next_value
),
done_used
=
jnp
.
ones_like
(
next_dones
[
-
1
]),
last_return
=
next_value
,
next_q
=
next_value
,
)
carry2
=
GAESepCarry
(
lastgaelam
=
jnp
.
zeros_like
(
next_value
),
next_value
=-
next_value
,
reward
=
jnp
.
zeros_like
(
next_value
),
done_used
=
jnp
.
ones_like
(
next_dones
[
-
1
]),
last_return
=-
next_value
,
next_q
=-
next_value
,
)
carry
=
carry1
,
carry2
carry
,
(
advantages
,
returns
)
=
jax
.
lax
.
scan
(
partial
(
truncated_gae_sep_loop
,
gamma
=
gamma
,
gae_lambda
=
gae_lambda
),
carry
,
(
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
if
return_carry
:
return
carry
targets
=
values
+
advantages
if
upgo
:
advantages
+=
returns
-
values
...
...
@@ -400,6 +463,14 @@ def truncated_gae_sep(
return
targets
,
advantages
class
GAECarry
(
NamedTuple
):
lastgaelam
:
jnp
.
ndarray
next_value
:
jnp
.
ndarray
last_return
:
jnp
.
ndarray
next_q
:
jnp
.
ndarray
next_main
:
jnp
.
ndarray
def
truncated_gae_loop
(
carry
,
inp
,
gamma
,
gae_lambda
):
lastgaelam
,
next_value
,
last_return
,
next_q
,
next_main
=
carry
cur_value
,
next_done
,
reward
,
main
=
inp
...
...
@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda):
next_q
=
reward
+
discount
*
next_value
carry
=
lastgaelam
,
cur_value
,
last_return
,
next_q
,
main
carry
=
GAECarry
(
lastgaelam
,
cur_value
,
last_return
,
next_q
,
main
)
return
carry
,
(
lastgaelam
,
last_return
)
def
truncated_gae
(
next_value
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
gae_lambda
,
upgo
=
False
):
lastgaelam
=
jnp
.
zeros_like
(
next_value
)
last_return
=
next_q
=
next_value
next_main
=
jnp
.
ones_like
(
next_value
,
dtype
=
jnp
.
bool_
)
carry
=
lastgaelam
,
next_value
,
last_return
,
next_q
,
next_main
_
,
(
advantages
,
return_t
)
=
jax
.
lax
.
scan
(
next_v
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
gae_lambda
,
upgo
=
False
,
return_carry
=
False
):
if
isinstance
(
next_v
,
(
tuple
,
list
)):
carry
=
next_v
else
:
next_value
=
next_v
carry
=
GAECarry
(
lastgaelam
=
jnp
.
zeros_like
(
next_value
),
next_value
=
next_value
,
last_return
=
next_value
,
next_q
=
next_value
,
next_main
=
jnp
.
ones_like
(
next_value
,
dtype
=
jnp
.
bool_
),
)
carry
,
(
advantages
,
return_t
)
=
jax
.
lax
.
scan
(
partial
(
truncated_gae_loop
,
gamma
=
gamma
,
gae_lambda
=
gae_lambda
),
carry
,
(
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
if
return_carry
:
return
carry
targets
=
values
+
advantages
if
upgo
:
advantages
+=
return_t
-
values
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment