Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
5cd9807d
Commit
5cd9807d
authored
Jun 16, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix gae with upgo
parent
20da4bcc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
31 deletions
+26
-31
ygoai/rl/jax/__init__.py
ygoai/rl/jax/__init__.py
+26
-31
No files found.
ygoai/rl/jax/__init__.py
View file @
5cd9807d
...
@@ -262,7 +262,7 @@ def vtrace(
...
@@ -262,7 +262,7 @@ def vtrace(
class
VtraceSepCarry
(
NamedTuple
):
class
VtraceSepCarry
(
NamedTuple
):
v
:
jnp
.
ndarray
v
:
jnp
.
ndarray
next_value
s
:
jnp
.
ndarray
next_value
:
jnp
.
ndarray
reward
:
jnp
.
ndarray
reward
:
jnp
.
ndarray
xi
:
jnp
.
ndarray
xi
:
jnp
.
ndarray
last_return
:
jnp
.
ndarray
last_return
:
jnp
.
ndarray
...
@@ -270,22 +270,18 @@ class VtraceSepCarry(NamedTuple):
...
@@ -270,22 +270,18 @@ class VtraceSepCarry(NamedTuple):
def
vtrace_sep_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
def
vtrace_sep_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
(
v1
,
next_value
s
1
,
reward1
,
xi1
,
last_return1
,
next_q1
),
\
(
v1
,
next_value1
,
reward1
,
xi1
,
last_return1
,
next_q1
),
\
(
v2
,
next_value
s
2
,
reward2
,
xi2
,
last_return2
,
next_q2
)
=
carry
(
v2
,
next_value2
,
reward2
,
xi2
,
last_return2
,
next_q2
)
=
carry
ratio
,
cur_value
s
,
next_done
,
r_t
,
main
=
inp
ratio
,
cur_value
,
next_done
,
r_t
,
main
=
inp
v1
=
jnp
.
where
(
next_done
,
0
,
v1
)
v2
=
jnp
.
where
(
next_done
,
0
,
v2
)
v1
,
v2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
xi1
,
xi2
=
jax
.
tree
.
map
(
next_values1
=
jnp
.
where
(
next_done
,
0
,
next_values1
)
lambda
x
:
jnp
.
where
(
next_done
,
0
,
x
),
next_values2
=
jnp
.
where
(
next_done
,
0
,
next_values2
)
(
v1
,
v2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
xi1
,
xi2
))
reward1
=
jnp
.
where
(
next_done
,
0
,
reward1
)
reward2
=
jnp
.
where
(
next_done
,
0
,
reward2
)
xi1
=
jnp
.
where
(
next_done
,
1
,
xi1
)
xi2
=
jnp
.
where
(
next_done
,
1
,
xi2
)
discount
=
gamma
*
(
1.0
-
next_done
)
discount
=
gamma
*
(
1.0
-
next_done
)
v
=
jnp
.
where
(
main
,
v1
,
v2
)
v
=
jnp
.
where
(
main
,
v1
,
v2
)
next_value
s
=
jnp
.
where
(
main
,
next_values1
,
next_values
2
)
next_value
=
jnp
.
where
(
main
,
next_value1
,
next_value
2
)
reward
=
jnp
.
where
(
main
,
reward1
,
reward2
)
reward
=
jnp
.
where
(
main
,
reward1
,
reward2
)
xi
=
jnp
.
where
(
main
,
xi1
,
xi2
)
xi
=
jnp
.
where
(
main
,
xi1
,
xi2
)
...
@@ -293,20 +289,20 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
...
@@ -293,20 +289,20 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
rho_t
=
jnp
.
clip
(
ratio
*
xi
,
rho_min
,
rho_max
)
rho_t
=
jnp
.
clip
(
ratio
*
xi
,
rho_min
,
rho_max
)
c_t
=
jnp
.
clip
(
ratio
*
xi
,
c_min
,
c_max
)
c_t
=
jnp
.
clip
(
ratio
*
xi
,
c_min
,
c_max
)
sig_v
=
rho_t
*
(
r_t
+
ratio
*
reward
+
discount
*
next_value
s
-
cur_values
)
sig_v
=
rho_t
*
(
r_t
+
ratio
*
reward
+
discount
*
next_value
-
cur_value
)
v
=
cur_value
s
+
sig_v
+
c_t
*
discount
*
(
v
-
next_values
)
v
=
cur_value
+
sig_v
+
c_t
*
discount
*
(
v
-
next_value
)
# UPGO advantage (not corrected by importance sampling, unlike V-trace)
# UPGO advantage (not corrected by importance sampling, unlike V-trace)
return_t
=
jnp
.
where
(
main
,
last_return1
,
last_return2
)
return_t
=
jnp
.
where
(
main
,
last_return1
,
last_return2
)
next_q
=
jnp
.
where
(
main
,
next_q1
,
next_q2
)
next_q
=
jnp
.
where
(
main
,
next_q1
,
next_q2
)
factor
=
jnp
.
where
(
main
,
jnp
.
ones_like
(
r_t
),
-
jnp
.
ones_like
(
r_t
))
factor
=
jnp
.
where
(
main
,
jnp
.
ones_like
(
r_t
),
-
jnp
.
ones_like
(
r_t
))
return_t
=
r_t
+
discount
*
jnp
.
where
(
return_t
=
r_t
+
discount
*
jnp
.
where
(
next_q
>=
next_value
s
,
return_t
,
next_values
)
next_q
>=
next_value
,
return_t
,
next_value
)
last_return1
=
jnp
.
where
(
last_return1
=
jnp
.
where
(
next_done
,
r_t
*
factor
,
jnp
.
where
(
main
,
return_t
,
last_return1
))
next_done
,
r_t
*
factor
,
jnp
.
where
(
main
,
return_t
,
last_return1
))
last_return2
=
jnp
.
where
(
last_return2
=
jnp
.
where
(
next_done
,
r_t
*
-
factor
,
jnp
.
where
(
main
,
last_return2
,
return_t
))
next_done
,
r_t
*
-
factor
,
jnp
.
where
(
main
,
last_return2
,
return_t
))
next_q
=
r_t
+
discount
*
next_value
s
next_q
=
r_t
+
discount
*
next_value
next_q1
=
jnp
.
where
(
next_q1
=
jnp
.
where
(
next_done
,
r_t
*
factor
,
jnp
.
where
(
main
,
next_q
,
next_q1
))
next_done
,
r_t
*
factor
,
jnp
.
where
(
main
,
next_q
,
next_q1
))
next_q2
=
jnp
.
where
(
next_q2
=
jnp
.
where
(
...
@@ -314,15 +310,15 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
...
@@ -314,15 +310,15 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1
=
jnp
.
where
(
main
,
v
,
v1
)
v1
=
jnp
.
where
(
main
,
v
,
v1
)
v2
=
jnp
.
where
(
main
,
v2
,
v
)
v2
=
jnp
.
where
(
main
,
v2
,
v
)
next_value
s1
=
jnp
.
where
(
main
,
cur_values
,
next_values
1
)
next_value
1
=
jnp
.
where
(
main
,
cur_value
,
next_value
1
)
next_value
s2
=
jnp
.
where
(
main
,
next_values2
,
cur_values
)
next_value
2
=
jnp
.
where
(
main
,
next_value2
,
cur_value
)
reward1
=
jnp
.
where
(
main
,
0
,
-
r_t
+
ratio
*
reward1
)
reward1
=
jnp
.
where
(
main
,
0
,
-
r_t
+
ratio
*
reward1
)
reward2
=
jnp
.
where
(
main
,
-
r_t
+
ratio
*
reward2
,
0
)
reward2
=
jnp
.
where
(
main
,
-
r_t
+
ratio
*
reward2
,
0
)
xi1
=
jnp
.
where
(
main
,
1
,
ratio
*
xi1
)
xi1
=
jnp
.
where
(
main
,
1
,
ratio
*
xi1
)
xi2
=
jnp
.
where
(
main
,
ratio
*
xi2
,
1
)
xi2
=
jnp
.
where
(
main
,
ratio
*
xi2
,
1
)
carry1
=
VtraceSepCarry
(
v1
,
next_value
s
1
,
reward1
,
xi1
,
last_return1
,
next_q1
)
carry1
=
VtraceSepCarry
(
v1
,
next_value1
,
reward1
,
xi1
,
last_return1
,
next_q1
)
carry2
=
VtraceSepCarry
(
v2
,
next_value
s
2
,
reward2
,
xi2
,
last_return2
,
next_q2
)
carry2
=
VtraceSepCarry
(
v2
,
next_value2
,
reward2
,
xi2
,
last_return2
,
next_q2
)
return
(
carry1
,
carry2
),
(
v
,
q_t
,
return_t
)
return
(
carry1
,
carry2
),
(
v
,
q_t
,
return_t
)
...
@@ -338,7 +334,7 @@ def vtrace_sep(
...
@@ -338,7 +334,7 @@ def vtrace_sep(
next_value1
=
next_value
next_value1
=
next_value
carry1
=
VtraceSepCarry
(
carry1
=
VtraceSepCarry
(
v
=
next_value1
,
v
=
next_value1
,
next_value
s
=
next_value1
,
next_value
=
next_value1
,
reward
=
jnp
.
zeros_like
(
next_value
),
reward
=
jnp
.
zeros_like
(
next_value
),
xi
=
jnp
.
ones_like
(
next_value
),
xi
=
jnp
.
ones_like
(
next_value
),
last_return
=
next_value1
,
last_return
=
next_value1
,
...
@@ -347,7 +343,7 @@ def vtrace_sep(
...
@@ -347,7 +343,7 @@ def vtrace_sep(
next_value2
=
-
next_value1
next_value2
=
-
next_value1
carry2
=
VtraceSepCarry
(
carry2
=
VtraceSepCarry
(
v
=
next_value2
,
v
=
next_value2
,
next_value
s
=
next_value2
,
next_value
=
next_value2
,
reward
=
jnp
.
zeros_like
(
next_value
),
reward
=
jnp
.
zeros_like
(
next_value
),
xi
=
jnp
.
ones_like
(
next_value
),
xi
=
jnp
.
ones_like
(
next_value
),
last_return
=
next_value2
,
last_return
=
next_value2
,
...
@@ -397,23 +393,22 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
...
@@ -397,23 +393,22 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
done_used2
=
jnp
.
where
(
done_used2
=
jnp
.
where
(
next_done
,
main2
,
jnp
.
where
(
main2
&
~
done_used2
,
True
,
done_used2
))
next_done
,
main2
,
jnp
.
where
(
main2
&
~
done_used2
,
True
,
done_used2
))
# UPGO advantage
last_return1
=
jnp
.
where
(
real_done1
,
0
,
last_return1
)
last_return1
=
jnp
.
where
(
real_done1
,
0
,
last_return1
)
last_return2
=
jnp
.
where
(
real_done2
,
0
,
last_return2
)
last_return2
=
jnp
.
where
(
real_done2
,
0
,
last_return2
)
last_return1_
=
reward1
+
gamma
*
jnp
.
where
(
last_return1_
=
reward1
+
gamma
*
jnp
.
where
(
next_q1
>=
next_value1
,
last_return1
,
next_value1
)
next_q1
>=
next_value1
,
last_return1
,
next_value1
)
last_return2_
=
reward2
+
gamma
*
jnp
.
where
(
last_return2_
=
reward2
+
gamma
*
jnp
.
where
(
next_q2
>=
next_value2
,
last_return2
,
next_value2
)
next_q2
>=
next_value2
,
last_return2
,
next_value2
)
return_t
=
jnp
.
where
(
main1
,
last_return1_
,
last_return2_
)
last_return1
=
jnp
.
where
(
main1
,
last_return1_
,
last_return1
)
last_return2
=
jnp
.
where
(
main2
,
last_return2_
,
last_return2
)
next_q1_
=
reward1
+
gamma
*
next_value1
next_q1_
=
reward1
+
gamma
*
next_value1
next_q2_
=
reward2
+
gamma
*
next_value2
next_q2_
=
reward2
+
gamma
*
next_value2
next_q1
=
jnp
.
where
(
main1
,
next_q1_
,
next_q1
)
next_q1
=
jnp
.
where
(
main1
,
next_q1_
,
next_q1
)
next_q2
=
jnp
.
where
(
main2
,
next_q2_
,
next_q1
)
next_q2
=
jnp
.
where
(
main2
,
next_q2_
,
next_q2
)
last_return1
=
jnp
.
where
(
main1
,
last_return1_
,
last_return1
)
last_return2
=
jnp
.
where
(
main2
,
last_return2_
,
last_return2
)
returns
=
jnp
.
where
(
main1
,
last_return1_
,
last_return2_
)
delta1
=
next_q1_
-
cur_value
delta1
=
reward1
+
gamma
*
next_value1
-
cur_value
delta2
=
next_q2_
-
cur_value
delta2
=
reward2
+
gamma
*
next_value2
-
cur_value
lastgaelam1_
=
delta1
+
gamma
*
gae_lambda
*
lastgaelam1
lastgaelam1_
=
delta1
+
gamma
*
gae_lambda
*
lastgaelam1
lastgaelam2_
=
delta2
+
gamma
*
gae_lambda
*
lastgaelam2
lastgaelam2_
=
delta2
+
gamma
*
gae_lambda
*
lastgaelam2
advantages
=
jnp
.
where
(
main1
,
lastgaelam1_
,
lastgaelam2_
)
advantages
=
jnp
.
where
(
main1
,
lastgaelam1_
,
lastgaelam2_
)
...
@@ -424,7 +419,7 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
...
@@ -424,7 +419,7 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
carry1
=
GAESepCarry
(
lastgaelam1
,
next_value1
,
reward1
,
done_used1
,
last_return1
,
next_q1
)
carry1
=
GAESepCarry
(
lastgaelam1
,
next_value1
,
reward1
,
done_used1
,
last_return1
,
next_q1
)
carry2
=
GAESepCarry
(
lastgaelam2
,
next_value2
,
reward2
,
done_used2
,
last_return2
,
next_q2
)
carry2
=
GAESepCarry
(
lastgaelam2
,
next_value2
,
reward2
,
done_used2
,
last_return2
,
next_q2
)
return
(
carry1
,
carry2
),
(
advantages
,
return
s
)
return
(
carry1
,
carry2
),
(
advantages
,
return
_t
)
def
truncated_gae_sep
(
def
truncated_gae_sep
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment