Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
4c0edbf8
Commit
4c0edbf8
authored
Jun 06, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add RND
parent
d43e5903
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1574 additions
and
62 deletions
+1574
-62
scripts/cleanba.py
scripts/cleanba.py
+35
-41
scripts/cleanba_rnd.py
scripts/cleanba_rnd.py
+1280
-0
ygoai/rl/jax/__init__.py
ygoai/rl/jax/__init__.py
+126
-6
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+84
-15
ygoai/rl/jax/utils.py
ygoai/rl/jax/utils.py
+49
-0
No files found.
scripts/cleanba.py
View file @
4c0edbf8
This diff is collapsed.
Click to expand it.
scripts/cleanba_rnd.py
0 → 100644
View file @
4c0edbf8
This diff is collapsed.
Click to expand it.
ygoai/rl/jax/__init__.py
View file @
4c0edbf8
...
@@ -107,15 +107,15 @@ def get_from_action(values, action):
...
@@ -107,15 +107,15 @@ def get_from_action(values, action):
return
jnp
.
sum
(
distrax
.
multiply_no_nan
(
values
,
value_one_hot
),
axis
=-
1
)
return
jnp
.
sum
(
distrax
.
multiply_no_nan
(
values
,
value_one_hot
),
axis
=-
1
)
def
mean_legal
(
values
,
axis
=
None
):
def
mean_legal
(
values
,
axis
=
None
,
keepdims
=
False
):
# TODO: use real action mask
# TODO: use real action mask
no_nan_mask
=
values
>
-
1e12
no_nan_mask
=
values
>
-
1e12
no_nan
=
jnp
.
where
(
no_nan_mask
,
values
,
0
)
no_nan
=
jnp
.
where
(
no_nan_mask
,
values
,
0
)
count
=
jnp
.
sum
(
no_nan_mask
,
axis
=
axis
)
count
=
jnp
.
sum
(
no_nan_mask
,
axis
=
axis
,
keepdims
=
keepdims
)
return
jnp
.
sum
(
no_nan
,
axis
=
axis
)
/
jnp
.
maximum
(
count
,
1
)
return
jnp
.
sum
(
no_nan
,
axis
=
axis
,
keepdims
=
keepdims
)
/
jnp
.
maximum
(
count
,
1
)
def
neurd_loss
(
actions
,
logits
,
new_logits
,
advantages
,
logits_threshold
):
def
neurd_loss
_2
(
actions
,
logits
,
new_logits
,
advantages
,
logits_threshold
):
# Neural Replicator Dynamics
# Neural Replicator Dynamics
# Differences from the original implementation:
# Differences from the original implementation:
# - all actions vs. sampled actions
# - all actions vs. sampled actions
...
@@ -136,6 +136,27 @@ def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
...
@@ -136,6 +136,27 @@ def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
return
pg_loss
return
pg_loss
def
neurd_loss
(
new_logits
,
advantages
,
logits_threshold
=
2.0
,
adv_threshold
=
1000.0
):
advs
=
jax
.
lax
.
stop_gradient
(
advantages
)
legal_mask
=
new_logits
>
-
1e12
legal_logits
=
jnp
.
where
(
legal_mask
,
new_logits
,
0
)
count
=
jnp
.
sum
(
legal_mask
,
axis
=-
1
,
keepdims
=
True
)
new_logits_
=
new_logits
-
jnp
.
sum
(
legal_logits
,
axis
=-
1
,
keepdims
=
True
)
/
jnp
.
maximum
(
count
,
1
)
can_increase
=
new_logits_
<
logits_threshold
can_decrease
=
new_logits_
>
-
logits_threshold
c
=
jnp
.
where
(
advs
>=
0
,
can_increase
,
can_decrease
)
.
astype
(
jnp
.
float32
)
c
=
jax
.
lax
.
stop_gradient
(
c
)
advs
=
jnp
.
clip
(
advs
,
-
adv_threshold
,
adv_threshold
)
# TODO: renormalize with player
pg_loss
=
-
c
*
new_logits_
*
advs
pg_loss
=
jnp
.
where
(
legal_mask
,
pg_loss
,
0
)
pg_loss
=
jnp
.
sum
(
pg_loss
,
axis
=-
1
)
return
pg_loss
def
ach_loss
(
actions
,
logits
,
new_logits
,
advantages
,
logits_threshold
,
clip_coef
,
dual_clip_coef
=
None
):
def
ach_loss
(
actions
,
logits
,
new_logits
,
advantages
,
logits_threshold
,
clip_coef
,
dual_clip_coef
=
None
):
# Actor-Critic Hedge loss from Actor-Critic Policy Optimization in a Large-Scale Imperfect-Information Game
# Actor-Critic Hedge loss from Actor-Critic Policy Optimization in a Large-Scale Imperfect-Information Game
# notice entropy term is required but not included here
# notice entropy term is required but not included here
...
@@ -158,7 +179,83 @@ def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coe
...
@@ -158,7 +179,83 @@ def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coe
return
pg_loss
return
pg_loss
def
vtrace_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
def
vtrace_rnad_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
reward_u1
,
reward_u2
=
carry
ratio
,
cur_values
,
r_t
,
eta_reg_entropy
,
probs
,
a_t
,
eta_log_policy
,
next_done
,
main
=
inp
v1
=
jnp
.
where
(
next_done
,
0
,
v1
)
v2
=
jnp
.
where
(
next_done
,
0
,
v2
)
next_values1
=
jnp
.
where
(
next_done
,
0
,
next_values1
)
next_values2
=
jnp
.
where
(
next_done
,
0
,
next_values2
)
reward1
=
jnp
.
where
(
next_done
,
0
,
reward1
)
reward2
=
jnp
.
where
(
next_done
,
0
,
reward2
)
xi1
=
jnp
.
where
(
next_done
,
1
,
xi1
)
xi2
=
jnp
.
where
(
next_done
,
1
,
xi2
)
reward_u1
=
jnp
.
where
(
next_done
,
0
,
reward_u1
)
reward_u2
=
jnp
.
where
(
next_done
,
0
,
reward_u2
)
discount
=
gamma
*
(
1.0
-
next_done
)
next_v
=
jnp
.
where
(
main
,
v1
,
v2
)
next_values
=
jnp
.
where
(
main
,
next_values1
,
next_values2
)
reward
=
jnp
.
where
(
main
,
reward1
,
reward2
)
xi
=
jnp
.
where
(
main
,
xi1
,
xi2
)
reward_u
=
jnp
.
where
(
main
,
reward_u1
,
reward_u2
)
reward_u
=
r_t
+
discount
*
reward_u
+
eta_reg_entropy
discounted_reward
=
r_t
+
discount
*
reward
rho_t
=
jnp
.
clip
(
ratio
*
xi
,
rho_min
,
rho_max
)
c_t
=
jnp
.
clip
(
ratio
*
xi
,
c_min
,
c_max
)
sig_v
=
rho_t
*
(
reward_u
+
discount
*
next_values
-
cur_values
)
v
=
cur_values
+
sig_v
+
c_t
*
discount
*
(
next_v
-
next_values
)
q_t
=
cur_values
[:,
None
]
+
eta_log_policy
n_actions
=
eta_log_policy
.
shape
[
-
1
]
q_t2
=
discounted_reward
+
discount
*
xi
*
next_v
-
cur_values
q_t
=
q_t
+
q_t2
[:,
None
]
*
distrax
.
multiply_no_nan
(
1.0
/
jnp
.
maximum
(
probs
,
1e-3
),
jax
.
nn
.
one_hot
(
a_t
,
n_actions
))
v1
=
jnp
.
where
(
main
,
v
,
discount
*
v1
)
v2
=
jnp
.
where
(
main
,
discount
*
v2
,
v
)
next_values1
=
jnp
.
where
(
main
,
cur_values
,
discount
*
next_values1
)
next_values2
=
jnp
.
where
(
main
,
discount
*
next_values2
,
cur_values
)
reward1
=
jnp
.
where
(
main
,
0
,
ratio
*
(
discount
*
reward1
-
r_t
)
-
eta_reg_entropy
)
reward2
=
jnp
.
where
(
main
,
ratio
*
(
discount
*
reward2
-
r_t
)
-
eta_reg_entropy
,
0
)
xi1
=
jnp
.
where
(
main
,
1
,
ratio
*
xi1
)
xi2
=
jnp
.
where
(
main
,
ratio
*
xi2
,
1
)
reward_u1
=
jnp
.
where
(
main
,
0
,
discount
*
reward_u1
-
r_t
-
eta_reg_entropy
)
reward_u2
=
jnp
.
where
(
main
,
discount
*
reward_u2
-
r_t
-
eta_reg_entropy
,
0
)
carry
=
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
reward_u1
,
reward_u2
return
carry
,
(
v
,
q_t
)
def
vtrace_rnad
(
next_value
,
ratios
,
logits
,
new_logits
,
actions
,
log_policy_reg
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
eta
=
0.2
,
):
probs
=
jax
.
nn
.
softmax
(
logits
)
new_probs
=
jax
.
nn
.
softmax
(
new_logits
)
eta_reg_entropy
=
-
eta
*
jnp
.
sum
(
new_probs
*
log_policy_reg
,
axis
=-
1
)
eta_log_policy
=
-
eta
*
log_policy_reg
next_value1
=
next_value
next_value2
=
-
next_value1
v1
=
next_value1
v2
=
next_value2
reward1
=
reward2
=
reward_u1
=
reward_u2
=
jnp
.
zeros_like
(
next_value
)
xi1
=
xi2
=
jnp
.
ones_like
(
next_value
)
carry
=
v1
,
v2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
xi1
,
xi2
,
reward_u1
,
reward_u2
_
,
(
targets
,
q_estimate
)
=
jax
.
lax
.
scan
(
partial
(
vtrace_rnad_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
carry
,
(
ratios
,
values
,
rewards
,
eta_reg_entropy
,
probs
,
actions
,
eta_log_policy
,
next_dones
,
mains
),
reverse
=
True
)
targets
=
jax
.
lax
.
stop_gradient
(
targets
)
return
targets
,
q_estimate
def
vtrace_2p0s_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
ratio
,
cur_values
,
next_done
,
r_t
,
main
=
inp
ratio
,
cur_values
,
next_done
,
r_t
,
main
=
inp
...
@@ -229,7 +326,7 @@ def vtrace_2p0s(
...
@@ -229,7 +326,7 @@ def vtrace_2p0s(
return1
,
return2
,
next_q1
,
next_q2
return1
,
return2
,
next_q1
,
next_q2
_
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
_
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
partial
(
vtrace_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
partial
(
vtrace_
2p0s_
loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
carry
,
(
ratios
,
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
carry
,
(
ratios
,
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
)
advantages
=
q_estimate
-
values
advantages
=
q_estimate
-
values
...
@@ -314,6 +411,29 @@ def truncated_gae_2p0s(
...
@@ -314,6 +411,29 @@ def truncated_gae_2p0s(
return
targets
,
advantages
return
targets
,
advantages
def
truncated_gae_loop
(
carry
,
inp
,
gamma
,
gae_lambda
):
lastgaelam
,
next_value
=
carry
cur_value
,
next_done
,
reward
=
inp
nextnonterminal
=
1.0
-
next_done
delta
=
reward
+
gamma
*
next_value
*
nextnonterminal
-
cur_value
lastgaelam
=
delta
+
gamma
*
gae_lambda
*
nextnonterminal
*
lastgaelam
carry
=
lastgaelam
,
cur_value
return
carry
,
lastgaelam
def
truncated_gae
(
next_value
,
values
,
rewards
,
next_dones
,
gamma
,
gae_lambda
):
lastgaelam
=
jnp
.
zeros_like
(
next_value
)
carry
=
lastgaelam
,
next_value
_
,
advantages
=
jax
.
lax
.
scan
(
partial
(
truncated_gae_loop
,
gamma
=
gamma
,
gae_lambda
=
gae_lambda
),
carry
,
(
values
,
next_dones
,
rewards
),
reverse
=
True
)
targets
=
values
+
advantages
targets
=
jax
.
lax
.
stop_gradient
(
targets
)
return
targets
,
advantages
def
simple_policy_loss
(
ratios
,
logits
,
new_logits
,
advantages
,
kld_max
,
eps
=
1e-12
):
def
simple_policy_loss
(
ratios
,
logits
,
new_logits
,
advantages
,
kld_max
,
eps
=
1e-12
):
advs
=
jax
.
lax
.
stop_gradient
(
advantages
)
advs
=
jax
.
lax
.
stop_gradient
(
advantages
)
probs
=
jax
.
nn
.
softmax
(
logits
)
probs
=
jax
.
nn
.
softmax
(
logits
)
...
...
ygoai/rl/jax/agent.py
View file @
4c0edbf8
...
@@ -506,40 +506,39 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
...
@@ -506,40 +506,39 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
@
dataclass
@
dataclass
class
Model
Args
:
class
Encoder
Args
:
num_layers
:
int
=
2
num_layers
:
int
=
2
"""the number of layers for the agent"""
"""the number of layers for the agent"""
num_channels
:
int
=
128
num_channels
:
int
=
128
"""the number of channels for the agent"""
"""the number of channels for the agent"""
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
use_history
:
bool
=
True
use_history
:
bool
=
True
"""whether to use history actions as input for agent"""
"""whether to use history actions as input for agent"""
card_mask
:
bool
=
False
card_mask
:
bool
=
False
"""whether to mask the padding card as ignored in the transformer"""
"""whether to mask the padding card as ignored in the transformer"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
False
"""whether to use FiLM for the actor"""
noam
:
bool
=
False
noam
:
bool
=
False
"""whether to use Noam architecture for the transformer layer"""
"""whether to use Noam architecture for the transformer layer"""
rwkv_head_size
:
int
=
32
"""the head size for the RWKV"""
action_feats
:
bool
=
True
action_feats
:
bool
=
True
"""whether to use action features for the global state"""
"""whether to use action features for the global state"""
version
:
int
=
0
version
:
int
=
0
"""the version of the environment and the agent"""
"""the version of the environment and the agent"""
@
dataclass
class
ModelArgs
(
EncoderArgs
):
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
False
"""whether to use FiLM for the actor"""
rwkv_head_size
:
int
=
32
"""the head size for the RWKV"""
class
RNNAgent
(
nn
.
Module
):
class
RNNAgent
(
nn
.
Module
):
num_layers
:
int
=
2
num_layers
:
int
=
2
num_channels
:
int
=
128
num_channels
:
int
=
128
rnn_channels
:
int
=
512
rnn_channels
:
int
=
512
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
switch
:
bool
=
True
freeze_id
:
bool
=
False
use_history
:
bool
=
True
use_history
:
bool
=
True
card_mask
:
bool
=
False
card_mask
:
bool
=
False
rnn_type
:
str
=
'lstm'
rnn_type
:
str
=
'lstm'
...
@@ -549,6 +548,13 @@ class RNNAgent(nn.Module):
...
@@ -549,6 +548,13 @@ class RNNAgent(nn.Module):
action_feats
:
bool
=
True
action_feats
:
bool
=
True
version
:
int
=
0
version
:
int
=
0
switch
:
bool
=
True
freeze_id
:
bool
=
False
int_head
:
bool
=
False
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
c
=
self
.
num_channels
c
=
self
.
num_channels
...
@@ -618,6 +624,11 @@ class RNNAgent(nn.Module):
...
@@ -618,6 +624,11 @@ class RNNAgent(nn.Module):
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
value
=
critic
(
f_state_r
)
if
self
.
int_head
:
critic_int
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
value_int
=
critic_int
(
f_state_r
)
value
=
(
value
,
value_int
)
return
rstate
,
logits
,
value
,
valid
return
rstate
,
logits
,
value
,
valid
def
init_rnn_state
(
self
,
batch_size
):
def
init_rnn_state
(
self
,
batch_size
):
...
@@ -636,4 +647,62 @@ class RNNAgent(nn.Module):
...
@@ -636,4 +647,62 @@ class RNNAgent(nn.Module):
np
.
zeros
((
batch_size
,
num_heads
*
head_size
*
head_size
)),
np
.
zeros
((
batch_size
,
num_heads
*
head_size
*
head_size
)),
)
)
else
:
else
:
return
None
return
None
\ No newline at end of file
default_rnd_args
=
EncoderArgs
(
num_layers
=
1
,
num_channels
=
128
,
use_history
=
True
,
card_mask
=
False
,
noam
=
True
,
action_feats
=
True
,
version
=
2
,
)
class
RNDModel
(
nn
.
Module
):
is_predictor
:
bool
=
False
num_layers
:
int
=
1
num_channels
:
int
=
128
use_history
:
bool
=
True
card_mask
:
bool
=
False
noam
:
bool
=
True
action_feats
:
bool
=
True
version
:
int
=
2
out_channels
:
Optional
[
int
]
=
None
freeze_id
:
bool
=
True
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
num_channels
oc
=
self
.
out_channels
or
c
*
2
encoder
=
Encoder
(
channels
=
c
,
out_channels
=
oc
,
num_layers
=
self
.
num_layers
,
embedding_shape
=
self
.
embedding_shape
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
freeze_id
=
self
.
freeze_id
,
use_history
=
self
.
use_history
,
card_mask
=
self
.
card_mask
,
noam
=
self
.
noam
,
action_feats
=
self
.
action_feats
,
version
=
self
.
version
,
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
c
=
f_state
.
shape
[
-
1
]
if
self
.
is_predictor
:
predictor
=
MLP
([
oc
,
oc
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
f_state
=
predictor
(
f_state
)
else
:
f_state
=
nn
.
Dense
(
oc
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
np
.
sqrt
(
2
)))(
f_state
)
return
f_state
\ No newline at end of file
ygoai/rl/jax/utils.py
View file @
4c0edbf8
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
flax
import
struct
from
ygoai.rl.env
import
RecordEpisodeStatistics
from
ygoai.rl.env
import
RecordEpisodeStatistics
...
@@ -24,3 +26,50 @@ def categorical_sample(logits, key):
...
@@ -24,3 +26,50 @@ def categorical_sample(logits, key):
u
=
jax
.
random
.
uniform
(
subkey
,
shape
=
logits
.
shape
)
u
=
jax
.
random
.
uniform
(
subkey
,
shape
=
logits
.
shape
)
action
=
jnp
.
argmax
(
logits
-
jnp
.
log
(
-
jnp
.
log
(
u
)),
axis
=-
1
)
action
=
jnp
.
argmax
(
logits
-
jnp
.
log
(
-
jnp
.
log
(
u
)),
axis
=-
1
)
return
action
,
key
return
action
,
key
class
RunningMeanStd
(
struct
.
PyTreeNode
):
"""Tracks the mean, variance and count of values."""
mean
:
jnp
.
ndarray
=
struct
.
field
(
pytree_node
=
True
)
var
:
jnp
.
ndarray
=
struct
.
field
(
pytree_node
=
True
)
count
:
jnp
.
ndarray
=
struct
.
field
(
pytree_node
=
True
)
@
classmethod
def
create
(
cls
,
shape
=
()):
return
cls
(
mean
=
jnp
.
zeros
(
shape
,
"float64"
),
var
=
jnp
.
ones
(
shape
,
"float64"
),
count
=
jnp
.
full
(
shape
,
1e-4
,
"float64"
),
)
def
update
(
self
,
x
):
"""Updates the mean, var and count from a batch of samples."""
batch_mean
=
jnp
.
mean
(
x
,
axis
=
0
)
batch_var
=
jnp
.
var
(
x
,
axis
=
0
)
batch_count
=
x
.
shape
[
0
]
return
self
.
update_from_moments
(
batch_mean
,
batch_var
,
batch_count
)
def
update_from_moments
(
self
,
batch_mean
,
batch_var
,
batch_count
):
"""Updates from batch mean, variance and count moments."""
mean
,
var
,
count
=
update_mean_var_count_from_moments
(
self
.
mean
,
self
.
var
,
self
.
count
,
batch_mean
,
batch_var
,
batch_count
)
return
self
.
replace
(
mean
=
mean
,
var
=
var
,
count
=
count
)
def
update_mean_var_count_from_moments
(
mean
,
var
,
count
,
batch_mean
,
batch_var
,
batch_count
):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
delta
=
batch_mean
-
mean
tot_count
=
count
+
batch_count
new_mean
=
mean
+
delta
*
batch_count
/
tot_count
m_a
=
var
*
count
m_b
=
batch_var
*
batch_count
M2
=
m_a
+
m_b
+
jnp
.
square
(
delta
)
*
count
*
batch_count
/
tot_count
new_var
=
M2
/
tot_count
new_count
=
tot_count
return
new_mean
,
new_var
,
new_count
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment