Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
8cebfebf
Commit
8cebfebf
authored
May 31, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Recompute advantages every minibatch
parent
dd06205b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
46 deletions
+23
-46
scripts/cleanba.py
scripts/cleanba.py
+23
-46
No files found.
scripts/cleanba.py
View file @
8cebfebf
...
...
@@ -259,7 +259,7 @@ def reshape_minibatch(
# (n_mb, num_envs // n_mb, ...)
# else,
# n_mb_t = num_steps // segment_length
# n_mb_e = num_minibatches // n
um_minibatches1
# n_mb_e = num_minibatches // n
_mb_t
# if multi_step, from (num_steps, num_envs, ...)) to
# (n_mb_e, n_mb_t, segment_length * (num_envs // n_mb_e), ...)
# else, from (num_envs, ...) to
...
...
@@ -727,8 +727,8 @@ def main():
eval_params
=
None
def
advantage_fn
(
new_logits
,
new_values
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
,
next_value
):
new_logits
,
new_values
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
,
next_value
):
num_envs
=
next_value
.
shape
[
0
]
num_steps
=
next_dones
.
shape
[
0
]
//
num_envs
...
...
@@ -815,12 +815,23 @@ def main():
def
compute_advantage
(
params
,
rstate1
,
rstate2
,
obs
,
dones
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
,
next_value
):
segment_length
=
dones
.
shape
[
0
]
obs
,
dones
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
=
\
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
-
1
,)
+
x
.
shape
[
2
:]),
(
obs
,
dones
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
))
new_logits
,
new_values
=
apply_fn
(
params
,
obs
,
rstate1
,
rstate2
,
dones
,
next_dones
,
switch_or_mains
)[
1
:
3
]
target_values
,
advantages
=
advantage_fn
(
new_logits
,
new_values
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
,
next_value
)
target_values
,
advantages
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
segment_length
,
-
1
)
+
x
.
shape
[
2
:]),
(
target_values
,
advantages
))
return
target_values
,
advantages
def
compute_loss
(
...
...
@@ -888,40 +899,6 @@ def main():
else
:
loss_grad_fn
=
jax
.
value_and_grad
(
compute_loss
,
has_aux
=
True
)
def
compute_advantage_t
(
next_value
):
N
=
args
.
num_minibatches
//
4
def
convert_data1
(
x
:
jnp
.
ndarray
,
multi_step
=
True
):
return
reshape_minibatch
(
x
,
multi_step
,
N
,
num_steps
)
b_init_rstate1
,
b_init_rstate2
,
b_next_value
=
jax
.
tree
.
map
(
partial
(
convert_data1
,
multi_step
=
False
),
(
init_rstate1
,
init_rstate2
,
next_value
))
b_storage
=
jax
.
tree
.
map
(
convert_data1
,
storage
)
if
args
.
switch
:
b_switch_or_mains
=
convert_data1
(
switch
)
else
:
b_switch_or_mains
=
b_storage
.
mains
target_values
,
advantages
=
jax
.
lax
.
scan
(
lambda
x
,
y
:
(
x
,
compute_advantage
(
x
,
*
y
)),
agent_state
.
params
,
(
b_init_rstate1
,
b_init_rstate2
,
b_storage
.
obs
,
b_storage
.
dones
,
b_storage
.
next_dones
,
b_switch_or_mains
,
b_storage
.
actions
,
b_storage
.
logits
,
b_storage
.
rewards
,
b_next_value
,
))[
1
]
target_values
,
advantages
=
jax
.
tree
.
map
(
partial
(
reshape_batch
,
num_minibatches
=
N
,
num_steps
=
num_steps
),
(
target_values
,
advantages
))
return
target_values
,
advantages
def
update_epoch
(
carry
,
_
):
agent_state
,
key
=
carry
key
,
subkey
=
jax
.
random
.
split
(
key
)
...
...
@@ -938,7 +915,6 @@ def main():
return
reshape_minibatch
(
x
,
multi_step
,
args
.
num_minibatches
,
num_steps
,
args
.
segment_length
,
key
=
key
)
shuffled_init_rstate1
,
shuffled_init_rstate2
=
jax
.
tree
.
map
(
partial
(
convert_data
,
multi_step
=
False
),
(
init_rstate1
,
init_rstate2
))
shuffled_storage
=
jax
.
tree
.
map
(
convert_data
,
storage
)
...
...
@@ -947,10 +923,9 @@ def main():
else
:
switch_or_mains
=
shuffled_storage
.
mains
shuffled_mask
=
~
shuffled_storage
.
dones
shuffled_next_value
=
convert_data
(
next_value
,
multi_step
=
False
)
if
args
.
segment_length
is
None
:
shuffled_next_value
=
convert_data
(
next_value
,
multi_step
=
False
)
others
=
shuffled_storage
.
rewards
,
shuffled_next_value
,
shuffled_mask
def
update_minibatch
(
agent_state
,
minibatch
):
(
loss
,
(
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)),
grads
=
loss_grad_fn
(
agent_state
.
params
,
*
minibatch
)
...
...
@@ -958,10 +933,6 @@ def main():
agent_state
=
agent_state
.
apply_gradients
(
grads
=
grads
)
return
agent_state
,
(
loss
,
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)
else
:
target_values
,
advantages
=
compute_advantage_t
(
next_value
)
shuffled_target_values
,
shuffled_advantages
=
jax
.
tree
.
map
(
convert_data
,
(
target_values
,
advantages
))
others
=
shuffled_target_values
,
shuffled_advantages
,
shuffled_mask
def
update_minibatch
(
agent_state
,
minibatch
):
def
update_minibatch_t
(
carry
,
minibatch_t
):
agent_state
,
rstate1
,
rstate2
=
carry
...
...
@@ -972,7 +943,11 @@ def main():
agent_state
=
agent_state
.
apply_gradients
(
grads
=
grads
)
return
(
agent_state
,
rstate1
,
rstate2
),
(
loss
,
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)
rstate1
,
rstate2
,
*
minibatch_t
=
minibatch
rstate1
,
rstate2
,
*
minibatch_t
,
mask
=
minibatch
target_values
,
advantages
=
compute_advantage
(
agent_state
.
params
,
rstate1
,
rstate2
,
*
minibatch_t
)
minibatch_t
=
*
minibatch_t
[:
-
2
],
target_values
,
advantages
,
mask
(
agent_state
,
_rstate1
,
_rstate2
),
\
(
loss
,
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)
=
jax
.
lax
.
scan
(
update_minibatch_t
,
(
agent_state
,
rstate1
,
rstate2
),
minibatch_t
)
...
...
@@ -990,7 +965,9 @@ def main():
switch_or_mains
,
shuffled_storage
.
actions
,
shuffled_storage
.
logits
,
*
others
,
shuffled_storage
.
rewards
,
shuffled_next_value
,
shuffled_mask
),
)
return
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment