Biluo Shen / ygo-agent · Commits · e1ff8f92

Commit e1ff8f92, authored Jun 13, 2024 by sbl1996@126.com
Parent: 974fe861

Add more rnn options and batch norm

Showing 4 changed files with 373 additions and 83 deletions (+373, -83):

  scripts/cleanba.py        +99   -73
  ygoai/rl/jax/agent.py     +54    -8
  ygoai/rl/jax/modules.py  +144    -2
  ygoai/rl/jax/utils.py     +76    -0
scripts/cleanba.py
View file @ e1ff8f92

@@ -19,14 +19,13 @@ import numpy as np
 import optax
 import distrax
 import tyro
-from flax.training.train_state import TrainState
 from rich.pretty import pprint
 from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent import RNNAgent, ModelArgs
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
+from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample, TrainState
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
 from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
@@ -251,6 +250,14 @@ def create_agent(args, eval=False):
     )


+def get_variables(agent_state):
+    batch_stats = getattr(agent_state, "batch_stats", None)
+    variables = {'params': agent_state.params}
+    if batch_stats is not None:
+        variables['batch_stats'] = batch_stats
+    return variables
+
+
 def init_rnn_state(num_envs, rnn_channels):
     return (
         np.zeros((num_envs, rnn_channels)),
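A minimal usage sketch (not part of the commit) of the new get_variables helper: it packs the params collection, and the batch_stats collection when the train state carries one, into the variables dict that module.apply expects. The SimpleNamespace stand-ins below are illustrative only.

import types

def get_variables(agent_state):
    batch_stats = getattr(agent_state, "batch_stats", None)
    variables = {'params': agent_state.params}
    if batch_stats is not None:
        variables['batch_stats'] = batch_stats
    return variables

# Without batch_stats, only the params collection is returned ...
plain_state = types.SimpleNamespace(params={'Dense_0': {'kernel': [[1.0]]}})
print(get_variables(plain_state))
# ... with batch_stats, both collections are present, matching what
# agent.apply(variables, ...) needs once batch normalization is enabled.
bn_state = types.SimpleNamespace(params={'Dense_0': {'kernel': [[1.0]]}},
                                 batch_stats={'BatchRenorm_0': {'mean': [0.0]}})
print(get_variables(bn_state))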
@@ -502,11 +509,9 @@ def rollout(
             sharded_storage.append(x)
         sharded_storage = Transition(*sharded_storage)
         next_main = main_player == next_to_play
-        next_rstate = jax.tree.map(
-            lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (init_rstate1, init_rstate2, (next_obs, next_rstate), next_main))
+            (init_rstate1, init_rstate2, next_obs, next_main))

         if args.eval_interval and update % args.eval_interval == 0:
             _start = time.time()
@@ -683,13 +688,17 @@ def main():
     agent = create_agent(args)
     rstate = agent.init_rnn_state(1)
-    params = agent.init(init_key, sample_obs, rstate)
+    variables = agent.init(init_key, sample_obs, rstate)
+    variables = flax.core.unfreeze(variables)
     if embeddings is not None:
         unknown_embed = embeddings.mean(axis=0)
         embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
-        params = flax.core.unfreeze(params)
-        params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
-        params = flax.core.freeze(params)
+        variables['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
+    # variables = flax.core.freeze(variables)
+
+    if args.checkpoint:
+        with open(args.checkpoint, "rb") as f:
+            variables = flax.serialization.from_bytes(variables, f.read())
+            print(f"loaded checkpoint from {args.checkpoint}")

     tx = optax.MultiSteps(
         optax.chain(
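Checkpoints are now round-tripped as the whole variables dict (params plus batch_stats) through flax.serialization, rather than params alone. A standalone sketch of that round trip (not from the repo), using a stock nn.BatchNorm purely for illustration:

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

model = nn.BatchNorm(use_running_average=True)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))  # {'params', 'batch_stats'}

blob = flax.serialization.to_bytes(variables)              # what the checkpoint file would hold
restored = flax.serialization.from_bytes(variables, blob)  # same call pattern as in the diff
match = jax.tree.map(lambda a, b: bool(jnp.array_equal(a, b)), variables, restored)
print(jax.tree_util.tree_all(match))  # True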
@@ -701,29 +710,29 @@ def main():
             every_k_schedule=1,
         )
     tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
+    if 'batch_stats' not in variables:
+        # variables = flax.core.unfreeze(variables)
+        variables['batch_stats'] = {}
+        # variables = flax.core.freeze(variables)
     agent_state = TrainState.create(
         apply_fn=None,
-        params=params,
+        params=variables['params'],
         tx=tx,
+        batch_stats=variables['batch_stats'],
     )
-    if args.checkpoint:
-        with open(args.checkpoint, "rb") as f:
-            params = flax.serialization.from_bytes(params, f.read())
-            agent_state = agent_state.replace(params=params)
-            print(f"loaded checkpoint from {args.checkpoint}")
     agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
     # print(agent.tabulate(agent_key, sample_obs))

     if args.eval_checkpoint:
         eval_agent = create_agent(args, eval=True)
         eval_rstate = eval_agent.init_rnn_state(1)
-        eval_params = eval_agent.init(init_key, sample_obs, eval_rstate)
+        eval_variables = eval_agent.init(init_key, sample_obs, eval_rstate)
         with open(args.eval_checkpoint, "rb") as f:
-            eval_params = flax.serialization.from_bytes(eval_params, f.read())
+            eval_variables = flax.serialization.from_bytes(eval_variables, f.read())
             print(f"loaded eval checkpoint from {args.eval_checkpoint}")
     else:
-        eval_params = None
+        eval_variables = None

     def advantage_fn(
         new_logits, new_values, next_dones, switch_or_mains,
@@ -811,17 +820,29 @@ def main():
         loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
         return loss, pg_loss, v_loss, ent_loss, approx_kl

-    def apply_fn(params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
+    def apply_fn(variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
         if args.switch:
             dones = dones | next_dones
-        (rstate1, rstate2), new_logits, new_values = agent.apply(
-            params, obs, (rstate1, rstate2), dones, switch_or_mains)[:3]
+        ((rstate1, rstate2), new_logits, new_values, _), state_updates = agent.apply(
+            variables, obs, (rstate1, rstate2), dones, switch_or_mains,
+            train=True, mutable=["batch_stats"])
         new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
-        return (rstate1, rstate2), new_logits, new_values
+        return ((rstate1, rstate2), new_logits, new_values), state_updates
+
+    def compute_next_value(variables, rstate1, rstate2, next_obs, next_main):
+        rstate = jax.tree.map(
+            lambda x1, x2: jnp.where(next_main[:, None], x1, x2), rstate1, rstate2)
+        next_value = agent.apply(variables, next_obs, rstate)[2]
+        next_value = jax.tree.map(lambda x: x.squeeze(-1), next_value)
+        next_value = jax.lax.stop_gradient(next_value)
+        sign = -1 if args.switch else 1
+        next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
+        return next_value

     def compute_advantage(
-            params, rstate1, rstate2, obs, dones, next_dones,
-            switch_or_mains, actions, logits, rewards, next_value):
+            variables, rstate1, rstate2, obs, dones, next_dones,
+            switch_or_mains, actions, logits, rewards, next_obs, next_main):
         segment_length = dones.shape[0]
         obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
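The rewritten apply_fn relies on Flax's mutable-collection contract: calling apply with train=True and mutable=["batch_stats"] returns an (outputs, state_updates) pair, where state_updates["batch_stats"] holds the refreshed running statistics. A self-contained sketch of that contract with a toy module (not the repo's RNNAgent):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyNet(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Dense(8)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.Dense(1)(x)

x = jnp.ones((4, 3))
variables = TinyNet().init(jax.random.PRNGKey(0), x)   # {'params', 'batch_stats'}

# Training step: batch statistics are recomputed and returned separately.
out, state_updates = TinyNet().apply(variables, x, train=True, mutable=["batch_stats"])
variables = {'params': variables['params'], 'batch_stats': state_updates['batch_stats']}

# Evaluation: the stored running averages are used and nothing is mutated.
out = TinyNet().apply(variables, x, train=False)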
@@ -829,8 +850,11 @@ def main():
             lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
             (obs, dones, next_dones, switch_or_mains, actions, logits, rewards))

-        new_logits, new_values = apply_fn(
-            params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
+        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
+            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+        next_value = compute_next_value(variables, rstate1, rstate2, next_obs, next_main)

         target_values, advantages = advantage_fn(
             new_logits, new_values, next_dones, switch_or_mains,
@@ -842,10 +866,11 @@ def main():
         return target_values, advantages

     def compute_loss(
-            params, rstate1, rstate2, obs, dones, next_dones,
+            params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
             switch_or_mains, actions, logits, target_values, advantages, mask):
-        (rstate1, rstate2), new_logits, new_values = apply_fn(
-            params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+        variables = {'params': params, 'batch_stats': batch_stats}
+        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
+            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
         loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
             new_logits, new_values, actions, logits, target_values, advantages,
@@ -854,14 +879,19 @@ def main():
         loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
         approx_kl, rstate1, rstate2 = jax.tree.map(
             jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
-        return loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
+        return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)

     def compute_advantage_loss(
-            params, rstate1, rstate2, obs, dones, next_dones,
-            switch_or_mains, actions, logits, rewards, next_value, mask):
-        num_envs = jax.tree.leaves(next_value)[0].shape[0]
-        new_logits, new_values = apply_fn(
-            params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
+            params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
+            switch_or_mains, actions, logits, rewards, mask, next_obs, next_main):
+        num_envs = jax.tree.leaves(next_main)[0].shape[0]
+        variables = {'params': params, 'batch_stats': batch_stats}
+        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
+            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+        variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
+        next_value = compute_next_value(variables, rstate1, rstate2, next_obs, next_main)

         target_values, advantages = advantage_fn(
             new_logits, new_values, next_dones, switch_or_mains,
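compute_next_value flips the sign of the bootstrap value depending on whose turn the next observation belongs to and on args.switch, presumably so the value is expressed from a consistent player's perspective in the zero-sum self-play setup. A tiny numeric sketch of that jnp.where pattern, with made-up values:

import jax.numpy as jnp

next_value = jnp.array([0.8, -0.2, 0.5])    # critic output for the next state
next_main = jnp.array([True, False, True])  # does the main player act next?

for switch in (False, True):
    sign = -1 if switch else 1
    adjusted = jnp.where(next_main, sign * next_value, -sign * next_value)
    print(switch, adjusted)
# switch=False keeps the value where the main player acts and negates it otherwise;
# switch=True does the opposite.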
@@ -873,22 +903,21 @@ def main():
         loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
         approx_kl = jax.lax.stop_gradient(approx_kl)
-        return loss, (pg_loss, v_loss, ent_loss, approx_kl)
+        return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)

     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
         sharded_init_rstate1: List,
         sharded_init_rstate2: List,
-        sharded_next_inputs: List,
+        sharded_next_obs: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-        # TODO: rstate will be out-date after the first update, maybe consider R2D2
-        next_inputs, init_rstate1, init_rstate2 = [
+        next_obs, init_rstate1, init_rstate2 = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
+            for x in [sharded_next_obs, sharded_init_rstate1, sharded_init_rstate2]
         ]
         next_main = jnp.concatenate(sharded_next_main)
@@ -913,49 +942,45 @@ def main():
             agent_state, key = carry
             key, subkey = jax.random.split(key)

-            next_value = agent.apply(agent_state.params, *next_inputs)[2]
-            next_value = jax.tree.map(lambda x: jnp.squeeze(x, axis=-1), next_value)
-            sign = -1 if args.switch else 1
-            next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
-
             def convert_data(x: jnp.ndarray, multi_step=True):
                 key = subkey if args.update_epochs > 1 else None
                 return reshape_minibatch(
                     x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)

-            shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
-                partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
-            shuffled_storage = jax.tree.map(convert_data, storage)
+            b_init_rstate1, b_init_rstate2, b_next_obs, b_next_main = \
+                jax.tree.map(partial(convert_data, multi_step=False),
+                             (init_rstate1, init_rstate2, next_obs, next_main))
+            b_storage = jax.tree.map(convert_data, storage)
             if args.switch:
                 switch_or_mains = convert_data(switch)
             else:
-                switch_or_mains = shuffled_storage.mains
-            shuffled_mask = ~shuffled_storage.dones
-            shuffled_next_value = jax.tree.map(
-                partial(convert_data, multi_step=False), next_value)
-            shuffled_rewards = shuffled_storage.rewards
+                switch_or_mains = b_storage.mains
+            b_mask = ~b_storage.dones
+            b_rewards = b_storage.rewards

            if args.segment_length is None:
                def update_minibatch(agent_state, minibatch):
-                    (loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
-                        agent_state.params, *minibatch)
+                    (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)), grads = \
+                        loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch)
                    grads = jax.lax.pmean(grads, axis_name="local_devices")
                    agent_state = agent_state.apply_gradients(grads=grads)
+                    agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
                    return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
            else:
                def update_minibatch(carry, minibatch):
                    def update_minibatch_t(carry, minibatch_t):
                        agent_state, rstate1, rstate2 = carry
                        minibatch_t = rstate1, rstate2, *minibatch_t
-                        (loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
-                            grads = loss_grad_fn(agent_state.params, *minibatch_t)
+                        (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
+                            grads = loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch_t)
                        grads = jax.lax.pmean(grads, axis_name="local_devices")
                        agent_state = agent_state.apply_gradients(grads=grads)
+                        agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
                        return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)

                    rstate1, rstate2, *minibatch_t, mask = minibatch
                    target_values, advantages = compute_advantage(
-                        carry.params, rstate1, rstate2, *minibatch_t)
+                        get_variables(carry), rstate1, rstate2, *minibatch_t)
                    minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
                    (carry, _rstate1, _rstate2), \
@@ -967,17 +992,18 @@ def main():
                 update_minibatch,
                 agent_state,
                 (
-                    shuffled_init_rstate1,
-                    shuffled_init_rstate2,
-                    shuffled_storage.obs,
-                    shuffled_storage.dones,
-                    shuffled_storage.next_dones,
+                    b_init_rstate1,
+                    b_init_rstate2,
+                    b_storage.obs,
+                    b_storage.dones,
+                    b_storage.next_dones,
                     switch_or_mains,
-                    shuffled_storage.actions,
-                    shuffled_storage.logits,
-                    shuffled_rewards,
-                    shuffled_next_value,
-                    shuffled_mask,
+                    b_storage.actions,
+                    b_storage.logits,
+                    b_rewards,
+                    b_mask,
+                    b_next_obs,
+                    b_next_main,
                 ),
             )
             return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
@@ -1007,16 +1033,16 @@ def main():
     params_queues = []
     rollout_queues = []
-    unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
+    unreplicated_params = flax.jax_utils.unreplicate(get_variables(agent_state))
     for d_idx, d_id in enumerate(args.actor_device_ids):
         actor_device = local_devices[d_id]
         device_params = jax.device_put(unreplicated_params, actor_device)
         for thread_id in range(args.num_actor_threads):
             params_queues.append(queue.Queue(maxsize=1))
             rollout_queues.append(queue.Queue(maxsize=1))
-            if eval_params:
+            if eval_variables:
                 params_queues[-1].put(
-                    jax.device_put(eval_params, actor_device))
+                    jax.device_put(eval_variables, actor_device))
             actor_thread_id = d_idx * args.num_actor_threads + thread_id
             threading.Thread(
                 target=rollout,
@@ -1070,7 +1096,7 @@ def main():
             *list(zip(*sharded_data_list)),
             learner_keys,
         )
-        unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
+        unreplicated_params = flax.jax_utils.unreplicate(get_variables(agent_state))
         params_queue_put_time = 0
         for d_idx, d_id in enumerate(args.actor_device_ids):
             device_params = jax.device_put(unreplicated_params, local_devices[d_id])
ygoai/rl/jax/agent.py
View file @ e1ff8f92

@@ -8,7 +8,7 @@ import jax.numpy as jnp
 import flax.linen as nn

 from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
-from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
 from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

@@ -487,7 +487,7 @@ class Critic(nn.Module):
     param_dtype: jnp.dtype = jnp.float32

     @nn.compact
-    def __call__(self, f_state):
+    def __call__(self, f_state, train):
         f_state = f_state.astype(self.dtype)
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         x = mlp(self.channels, last_lin=False)(f_state)
@@ -495,6 +495,33 @@ class Critic(nn.Module):
         return x


+class CrossCritic(nn.Module):
+    channels: Sequence[int] = (128, 128, 128)
+    # dropout_rate: Optional[float] = None
+    batch_norm_momentum: float = 0.99
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, f_state, train):
+        x = f_state.astype(self.dtype)
+        linear = partial(nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
+        BN = partial(
+            BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
+            momentum=self.batch_norm_momentum, axis_name="local_devices",
+            use_running_average=not train)
+        x = BN()(x)
+        for c in self.channels:
+            x = linear(c)(x)
+            # if self.use_layer_norm:
+            #     x = nn.LayerNorm()(x)
+            x = nn.relu(x)
+            # x = nn.leaky_relu(x, negative_slope=0.1)
+            x = BN()(x)
+        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
+        return x
+
+
 class GlobalCritic(nn.Module):
     channels: Sequence[int] = (128, 128)
     dtype: Optional[jnp.dtype] = None
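For orientation, a runnable analogue of the CrossCritic layout above: a bias-free Dense / normalization / ReLU stack bracketed by normalization layers, with the train flag toggling use_running_average. Flax's built-in nn.BatchNorm stands in for the repo's BatchRenorm here only to keep the sketch self-contained.

from functools import partial
from typing import Sequence
import jax
import jax.numpy as jnp
import flax.linen as nn

class MiniCrossCritic(nn.Module):
    channels: Sequence[int] = (128, 128, 128)
    batch_norm_momentum: float = 0.99

    @nn.compact
    def __call__(self, f_state, train: bool):
        x = f_state
        norm = partial(nn.BatchNorm, use_running_average=not train,
                       momentum=self.batch_norm_momentum)
        x = norm()(x)
        for c in self.channels:
            x = nn.Dense(c, use_bias=False)(x)
            x = nn.relu(x)
            x = norm()(x)
        return nn.Dense(1, dtype=jnp.float32)(x)

variables = MiniCrossCritic().init(jax.random.PRNGKey(0), jnp.ones((2, 64)), train=False)
value, updates = MiniCrossCritic().apply(
    variables, jnp.ones((2, 64)), train=True, mutable=["batch_stats"])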
@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs):
     """whether to use FiLM for the actor"""
     oppo_info: bool = False
     """whether to use opponent's information"""
+    rnn_shortcut: bool = False
+    """whether to use shortcut for the RNN"""
+    batch_norm: bool = False
+    """whether to use batch normalization for the critic"""
+    critic_width: int = 128
+    """the width of the critic"""
+    critic_depth: int = 3
+    """the depth of the critic"""
     rwkv_head_size: int = 32
     """the head size for the RWKV"""

@@ -596,6 +631,10 @@ class RNNAgent(nn.Module):
     rwkv_head_size: int = 32
     action_feats: bool = True
     oppo_info: bool = False
+    rnn_shortcut: bool = False
+    batch_norm: bool = False
+    critic_width: int = 128
+    critic_depth: int = 3
     version: int = 0

     switch: bool = True

@@ -606,7 +645,7 @@ class RNNAgent(nn.Module):
     param_dtype: jnp.dtype = jnp.float32

     @nn.compact
-    def __call__(self, x, rstate, done=None, switch_or_main=None):
+    def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
         batch_size = jax.tree.leaves(rstate)[0].shape[0]
         c = self.num_channels

@@ -669,6 +708,10 @@ class RNNAgent(nn.Module):
             rstate, f_state_r = rnn_step_by_main(
                 rnn_layer, rstate, f_state, done, switch_or_main)

+        if self.rnn_shortcut:
+            # f_state_r = ReZero(channel_wise=True)(f_state_r)
+            f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
+
         if self.film:
             actor = FiLMActor(
                 channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)

@@ -694,13 +737,16 @@ class RNNAgent(nn.Module):
                 lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
             value = critic(rstate1_t, rstate2_t, f_g)
         else:
-            critic = Critic(
-                channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
-            value = critic(f_state_r)
+            CriticCls = CrossCritic if self.batch_norm else Critic
+            cs = [self.critic_width] * self.critic_depth
+            critic = CriticCls(channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+            value = critic(f_state_r, train)

         if self.int_head:
+            cs = [self.critic_width] * self.critic_depth
             critic_int = Critic(
-                channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
+                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
             value_int = critic_int(f_state_r)
             value = (value, value_int)
         return rstate, logits, value, valid
ygoai/rl/jax/modules.py
View file @ e1ff8f92

-from typing import Tuple, Union, Optional
+from typing import Tuple, Union, Optional, Any
 import functools

 import jax
 import jax.numpy as jnp
 import flax.linen as nn

+from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes


 def decode_id(x):

@@ -110,3 +111,144 @@ class RMSNorm(nn.Module):
         )
         x = x * scale
         return jnp.asarray(x, self.dtype)
+
+
+class ReZero(nn.Module):
+    channel_wise: bool = False
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, x):
+        shape = (x.shape[-1],) if self.channel_wise else ()
+        scale = self.param("scale", nn.initializers.zeros, shape, self.param_dtype)
+        return x * scale
+
+
+class BatchRenorm(nn.Module):
+    """BatchRenorm Module, implemented based on the Batch Renormalization paper (https://arxiv.org/abs/1702.03275).
+    and adapted from Flax's BatchNorm implementation:
+    https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228
+
+    Attributes:
+      use_running_average: if True, the statistics stored in batch_stats will be
+        used instead of computing the batch statistics on the input.
+      axis: the feature or non-batch axis of the input.
+      momentum: decay rate for the exponential moving average of the batch
+        statistics.
+      epsilon: a small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      param_dtype: the dtype passed to parameter initializers (default: float32).
+      use_bias: if True, bias (beta) is added.
+      use_scale: if True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: initializer for bias, by default, zero.
+      scale_init: initializer for scale, by default, one.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See `jax.pmap` for a description of axis names (default: None).
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+        examples on the first two and last two devices. See `jax.lax.psum` for
+        more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    use_running_average: Optional[bool] = None
+    axis: int = -1
+    momentum: float = 0.999
+    epsilon: float = 0.001
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+    use_bias: bool = True
+    use_scale: bool = True
+    bias_init: nn.initializers.Initializer = nn.initializers.zeros
+    scale_init: nn.initializers.Initializer = nn.initializers.ones
+    axis_name: Optional[str] = None
+    axis_index_groups: Any = None
+    use_fast_variance: bool = True
+
+    @nn.compact
+    def __call__(self, x, use_running_average: Optional[bool] = None):
+        """
+        Args:
+          x: the input to be normalized.
+          use_running_average: if true, the statistics stored in batch_stats will be
+            used instead of computing the batch statistics on the input.
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        use_running_average = nn.merge_param(
+            'use_running_average', self.use_running_average, use_running_average)
+        feature_axes = _canonicalize_axes(x.ndim, self.axis)
+        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
+        feature_shape = [x.shape[ax] for ax in feature_axes]
+
+        ra_mean = self.variable(
+            'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape)
+        ra_var = self.variable(
+            'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape)
+
+        r_max = self.variable('batch_stats', 'r_max', lambda s: s, 3)
+        d_max = self.variable('batch_stats', 'd_max', lambda s: s, 5)
+        steps = self.variable('batch_stats', 'steps', lambda s: s, 0)
+
+        if use_running_average:
+            mean, var = ra_mean.value, ra_var.value
+            custom_mean = mean
+            custom_var = var
+        else:
+            mean, var = _compute_stats(
+                x,
+                reduction_axes,
+                dtype=self.dtype,
+                axis_name=self.axis_name if not self.is_initializing() else None,
+                axis_index_groups=self.axis_index_groups,
+                use_fast_variance=self.use_fast_variance,
+            )
+            custom_mean = mean
+            custom_var = var
+            if not self.is_initializing():
+                # The code below is implemented following the Batch Renormalization paper
+                r = 1
+                d = 0
+                std = jnp.sqrt(var + self.epsilon)
+                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
+                r = jax.lax.stop_gradient(std / ra_std)
+                r = jnp.clip(r, 1 / r_max.value, r_max.value)
+                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
+                d = jnp.clip(d, -d_max.value, d_max.value)
+                tmp_var = var / (r**2)
+                tmp_mean = mean - d * jnp.sqrt(custom_var) / r
+
+                # Warm up batch renorm for 100_000 steps to build up proper running statistics
+                warmed_up = jnp.greater_equal(steps.value, 100_000).astype(jnp.float32)
+                custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
+                custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean
+
+                ra_mean.value = (
+                    self.momentum * ra_mean.value + (1 - self.momentum) * mean
+                )
+                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
+                steps.value += 1
+
+        return _normalize(
+            self,
+            x,
+            custom_mean,
+            custom_var,
+            reduction_axes,
+            feature_axes,
+            self.dtype,
+            self.param_dtype,
+            self.epsilon,
+            self.use_bias,
+            self.use_scale,
+            self.bias_init,
+            self.scale_init,
+        )
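A standalone numerical check (NumPy only, not part of the commit) of the identity the implementation above exploits: feeding the adjusted statistics custom_var = var / r**2 and custom_mean = mean - d * sqrt(var) / r into an ordinary (x - mean) / sqrt(var + eps) normalization reproduces the Batch Renormalization form ((x - batch_mean) / batch_std) * r + d, up to epsilon terms.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=0.5, scale=1.5, size=1000)

mean, var = x.mean(), x.var()
ra_mean, ra_var, eps = 0.3, 2.0, 1e-3   # pretend running statistics
r_max, d_max = 3.0, 5.0

std, ra_std = np.sqrt(var + eps), np.sqrt(ra_var + eps)
r = np.clip(std / ra_std, 1.0 / r_max, r_max)
d = np.clip((mean - ra_mean) / ra_std, -d_max, d_max)

# Batch Renormalization as written in the paper.
renorm = (x - mean) / std * r + d

# The same correction expressed through adjusted mean/var, as in BatchRenorm above.
custom_var = var / r**2
custom_mean = mean - d * np.sqrt(var) / r
via_adjusted = (x - custom_mean) / np.sqrt(custom_var + eps)

print(np.max(np.abs(renorm - via_adjusted)))  # tiny; differences come only from eps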
ygoai/rl/jax/utils.py
View file @ e1ff8f92

+from typing import Any, Callable
+
 import jax
 import jax.numpy as jnp
+from flax import core, struct
+from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
+import optax
 import numpy as np

 from ygoai.rl.env import RecordEpisodeStatistics

@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments(
         new_count = tot_count

     return new_mean, new_var, new_count
+
+
+class TrainState(struct.PyTreeNode):
+    step: int
+    apply_fn: Callable = struct.field(pytree_node=False)
+    params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
+    tx: optax.GradientTransformation = struct.field(pytree_node=False)
+    opt_state: optax.OptState = struct.field(pytree_node=True)
+    batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
+
+    def apply_gradients(self, *, grads, **kwargs):
+        """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
+
+        Note that internally this function calls ``.tx.update()`` followed by a call
+        to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
+
+        Args:
+          grads: Gradients that have the same pytree structure as ``.params``.
+          **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
+
+        Returns:
+          An updated instance of ``self`` with ``step`` incremented by one, ``params``
+          and ``opt_state`` updated by applying ``grads``, and additional attributes
+          replaced as specified by ``kwargs``.
+        """
+        if OVERWRITE_WITH_GRADIENT in grads:
+            grads_with_opt = grads['params']
+            params_with_opt = self.params['params']
+        else:
+            grads_with_opt = grads
+            params_with_opt = self.params
+
+        updates, new_opt_state = self.tx.update(
+            grads_with_opt, self.opt_state, params_with_opt)
+        new_params_with_opt = optax.apply_updates(params_with_opt, updates)
+
+        # As implied by the OWG name, the gradients are used directly to update the
+        # parameters.
+        if OVERWRITE_WITH_GRADIENT in grads:
+            new_params = {
+                'params': new_params_with_opt,
+                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
+            }
+        else:
+            new_params = new_params_with_opt
+        return self.replace(
+            step=self.step + 1,
+            params=new_params,
+            opt_state=new_opt_state,
+            **kwargs,
+        )
+
+    @classmethod
+    def create(cls, *, apply_fn, params, tx, **kwargs):
+        """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
+        # We exclude OWG params when present because they do not need opt states.
+        params_with_opt = (
+            params['params'] if OVERWRITE_WITH_GRADIENT in params else params
+        )
+        opt_state = tx.init(params_with_opt)
+        return cls(
+            step=0,
+            apply_fn=apply_fn,
+            params=params,
+            tx=tx,
+            opt_state=opt_state,
+            **kwargs,
+        )
\ No newline at end of file
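The custom TrainState above mirrors flax.training.train_state.TrainState while adding a batch_stats field. A hedged sketch of the equivalent stock-Flax pattern (a subclass with an extra field), with a toy model, optimizer, and dummy gradients chosen only to keep it runnable:

from typing import Any
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class TrainStateWithStats(train_state.TrainState):
    batch_stats: Any

model = nn.BatchNorm(momentum=0.99)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((4, 8)), use_running_average=False)
state = TrainStateWithStats.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(1e-2),
    batch_stats=variables['batch_stats'],
)

# Forward pass with mutable batch stats, then a (dummy) gradient step, then the
# running statistics are written back with .replace, as in the learner update.
_, updates = model.apply({'params': state.params, 'batch_stats': state.batch_stats},
                         jnp.ones((4, 8)), use_running_average=False,
                         mutable=['batch_stats'])
grads = jax.tree.map(jnp.zeros_like, state.params)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=updates['batch_stats'])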