Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
9d8d4386
Commit
9d8d4386
authored
Apr 13, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix bug: shuffle rstate in channels
parent
671ed3c6
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
835 additions
and
12 deletions
+835
-12
scripts/jax/ppo_lstm2.py
scripts/jax/ppo_lstm2.py
+816
-0
ygoai/rl/jax/__init__.py
ygoai/rl/jax/__init__.py
+8
-11
ygoai/rl/jax/utils.py
ygoai/rl/jax/utils.py
+11
-1
No files found.
scripts/jax/ppo_lstm2.py
0 → 100644
View file @
9d8d4386
This diff is collapsed.
Click to expand it.
ygoai/rl/jax/__init__.py
View file @
9d8d4386
...
...
@@ -95,10 +95,9 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad
return
-
jnp
.
mean
(
clipped_objective
*
mask
)
@
partial
(
jax
.
jit
,
static_argnums
=
(
6
,
7
))
@
partial
(
jax
.
jit
,
static_argnums
=
(
5
,
6
))
def
compute_gae_2p0s
(
next_value
,
next_done
,
values
,
rewards
,
dones
,
switch
,
gamma
,
gae_lambda
,
next_value
,
values
,
rewards
,
next_dones
,
switch
,
gamma
,
gae_lambda
,
):
def
body_fn
(
carry
,
inp
):
boot_value
,
boot_done
,
next_value
,
lastgaelam
=
carry
...
...
@@ -113,21 +112,20 @@ def compute_gae_2p0s(
lastgaelam
=
delta
+
gae_lambda
*
gamma_
*
lastgaelam
return
(
boot_value
,
boot_done
,
cur_value
,
lastgaelam
),
lastgaelam
dones
=
jnp
.
concatenate
([
dones
,
next_done
[
None
,
:]],
axis
=
0
)
next_done
=
next_dones
[
-
1
]
lastgaelam
=
jnp
.
zeros_like
(
next_value
)
carry
=
next_value
,
next_done
,
next_value
,
lastgaelam
_
,
advantages
=
jax
.
lax
.
scan
(
body_fn
,
carry
,
(
dones
[
1
:]
,
values
,
rewards
,
switch
),
reverse
=
True
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
)
target_values
=
advantages
+
values
return
advantages
,
target_values
@
partial
(
jax
.
jit
,
static_argnums
=
(
6
,
7
))
@
partial
(
jax
.
jit
,
static_argnums
=
(
5
,
6
))
def
compute_gae_upgo_2p0s
(
next_value
,
next_done
,
values
,
rewards
,
dones
,
switch
,
next_value
,
values
,
rewards
,
next_
dones
,
switch
,
gamma
,
gae_lambda
,
):
def
body_fn
(
carry
,
inp
):
...
...
@@ -150,13 +148,12 @@ def compute_gae_upgo_2p0s(
carry
=
boot_value
,
boot_done
,
cur_value
,
next_q
,
last_return
,
lastgaelam
return
carry
,
(
lastgaelam
,
last_return
)
dones
=
jnp
.
concatenate
([
dones
,
next_done
[
None
,
:]],
axis
=
0
)
next_done
=
next_dones
[
-
1
]
lastgaelam
=
jnp
.
zeros_like
(
next_value
)
carry
=
next_value
,
next_done
,
next_value
,
next_value
,
next_value
,
lastgaelam
_
,
(
advantages
,
returns
)
=
jax
.
lax
.
scan
(
body_fn
,
carry
,
(
dones
[
1
:]
,
values
,
rewards
,
switch
),
reverse
=
True
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
)
return
returns
-
values
,
advantages
+
values
...
...
ygoai/rl/jax/utils.py
View file @
9d8d4386
import
jax
import
jax.numpy
as
jnp
from
ygoai.rl.env
import
RecordEpisodeStatistics
...
...
@@ -13,4 +14,13 @@ def masked_normalize(x, valid, epsilon=1e-8):
n
=
valid
.
sum
()
mean
=
x
.
sum
()
/
n
variance
=
jnp
.
square
(
x
-
mean
)
.
sum
()
/
n
return
(
x
-
mean
)
/
jnp
.
sqrt
(
variance
+
epsilon
)
\ No newline at end of file
return
(
x
-
mean
)
/
jnp
.
sqrt
(
variance
+
epsilon
)
def
categorical_sample
(
logits
,
key
):
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key
,
subkey
=
jax
.
random
.
split
(
key
)
u
=
jax
.
random
.
uniform
(
subkey
,
shape
=
logits
.
shape
)
action
=
jnp
.
argmax
(
logits
-
jnp
.
log
(
-
jnp
.
log
(
u
)),
axis
=-
1
)
return
action
,
key
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment