Biluo Shen / ygo-agent · Commit 3dfee5f5
Authored Jun 07, 2024 by sbl1996@126.com
Parent: e6dc7744

Fix int_reward compute

Showing 2 changed files with 42 additions and 32 deletions:
  scripts/cleanba_rnd.py   +29 -12
  ygoai/rl/jax/utils.py    +13 -20
scripts/cleanba_rnd.py

@@ -162,7 +162,7 @@ class Args:
     """proportion of exp used for predictor update"""
     rnd_episodic: bool = False
     """whether to use episodic intrinsic reward for RND"""
-    rnd_norm: Literal["default", "min_max"] = "default"
+    rnd_norm: Literal["default", "min_max", "min_max2"] = "default"
     """the normalization method for RND intrinsic reward"""
     int_coef: float = 0.5
     """coefficient of intrinsic reward, 0.0 to disable RND"""
@@ -393,6 +393,15 @@ def rollout(
         rstate1, rstate2 = jax.tree.map(
             lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         return rstate1, rstate2, logits.argmax(axis=1)
 
+    @jax.jit
+    def compute_int_rew(params_rt, params_rp, obs):
+        target_feats = rnd_target.apply(params_rt, obs)
+        predict_feats = rnd_predictor.apply(params_rp, obs)
+        int_rewards = jnp.sum((target_feats - predict_feats) ** 2, axis=-1) / 2
+        if args.rnd_norm == 'min_max':
+            int_rewards = (int_rewards - int_rewards.min()) / (int_rewards.max() - int_rewards.min() + 1e-8)
+        return target_feats, int_rewards
+
     @jax.jit
     def sample_action(
@@ -403,12 +412,8 @@ def rollout(
         action, key = categorical_sample(logits, key)
 
         if args.enable_rnd:
-            target_feats = rnd_target.apply(params_rt, next_obs)
-            predict_feats = rnd_predictor.apply(params_rp, next_obs)
-            int_rewards = jnp.sum((target_feats - predict_feats) ** 2, axis=-1) / 2
-            if args.rnd_norm == 'min_max':
-                int_rewards = (int_rewards - int_rewards.min()) / (int_rewards.max() - int_rewards.min() + 1e-8)
-            else:
+            target_feats, int_rewards = compute_int_rew(params_rt, params_rp, next_obs)
+            if args.rnd_norm == 'default':
                 rewems = rewems * args.int_gamma + int_rewards
         else:
             target_feats = int_rewards = None
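Aside on the 'default' branch kept above: rewems accumulates a discounted running return of the intrinsic rewards, one scalar per environment. Later in the rollout, reward_rms is updated from all_dis_int_rewards, which by its name appears to collect these discounted returns (its construction is outside this diff), and the intrinsic rewards are divided by the resulting running standard deviation, following the RND paper's normalization. A two-line illustration with made-up numbers:

import jax.numpy as jnp

int_gamma = 0.99
rewems = jnp.zeros(4)                      # one accumulator per env
for int_rewards in (jnp.ones(4), jnp.ones(4)):
    # exponentially discounted sum of intrinsic rewards per environment
    rewems = rewems * int_gamma + int_rewards
print(rewems)  # [1.99 1.99 1.99 1.99]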
@@ -442,7 +447,7 @@ def rollout(
     np.random.shuffle(main_player)
 
     storage = []
-    reward_rms = jax.device_put(RunningMeanStd.create())
+    reward_rms = RunningMeanStd()
     rewems = jnp.zeros(args.local_num_envs, dtype=jnp.float32, device=actor_device)
 
     @jax.jit
@@ -550,16 +555,28 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)
 
         if args.enable_rnd:
+            next_int_reward = compute_int_rew(params_rt, params_rp, next_obs)[1]
+            all_int_rewards = all_int_rewards[1:] + [next_int_reward]
             # TODO: update every step
             all_int_rewards = jnp.stack(all_int_rewards)
             if args.rnd_norm == 'default':
-                reward_rms = reward_rms.update(jnp.array(all_dis_int_rewards).flatten())
-                all_int_rewards = all_int_rewards / jnp.sqrt(reward_rms.var)
+                all_dis_int_rewards = jnp.concatenate(all_dis_int_rewards)
+                mean, std = jax.device_get((all_dis_int_rewards.mean(), all_dis_int_rewards.std()))
+                count = len(all_dis_int_rewards)
+                reward_rms.update_from_moments(mean, std ** 2, count)
+                all_int_rewards = all_int_rewards / np.sqrt(reward_rms.var)
+            elif args.rnd_norm == 'min_max2':
+                max_int_rewards = jnp.max(all_int_rewards)
+                min_int_rewards = jnp.min(all_int_rewards)
+                all_int_rewards = (all_int_rewards - min_int_rewards) / (max_int_rewards - min_int_rewards)
+            mean_int_rewards = jnp.mean(all_int_rewards)
+            max_int_rewards = jnp.max(all_int_rewards)
             for k in range(args.num_steps):
                 int_rewards = all_int_rewards[k]
                 storage[k] = storage[k]._replace(int_rewards=int_rewards)
-            mean_int_rewards = jnp.mean(all_int_rewards)
-            max_int_rewards = jnp.max(all_int_rewards)
 
         partitioned_storage = prepare_data(storage)
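Note on the refactor above: the new compute_int_rew helper jits the Random Network Distillation bonus, i.e. half the squared error between the features of a frozen, randomly initialized target network and those of a trained predictor network. Below is a minimal self-contained sketch of the same computation; RNDNet is a hypothetical stand-in, since the real rnd_target/rnd_predictor architectures are defined elsewhere in cleanba_rnd.py and do not appear in this diff, and RND_NORM stands in for args.rnd_norm.

import jax
import jax.numpy as jnp
import flax.linen as nn

RND_NORM = "min_max"  # stand-in for args.rnd_norm, resolved at trace time

class RNDNet(nn.Module):
    # Hypothetical feature network; only its input/output shapes matter here.
    feat_dim: int = 64

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(256)(x))
        return nn.Dense(self.feat_dim)(x)

rnd_target = RNDNet()      # frozen: its params are never trained
rnd_predictor = RNDNet()   # trained to imitate the target's features

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
obs = jax.random.normal(k3, (8, 32))  # batch of 8 toy observations
params_rt = rnd_target.init(k1, obs)
params_rp = rnd_predictor.init(k2, obs)

@jax.jit
def compute_int_rew(params_rt, params_rp, obs):
    target_feats = rnd_target.apply(params_rt, obs)
    predict_feats = rnd_predictor.apply(params_rp, obs)
    # Intrinsic reward: half the squared prediction error per observation.
    int_rewards = jnp.sum((target_feats - predict_feats) ** 2, axis=-1) / 2
    if RND_NORM == "min_max":
        # Rescale the batch to [0, 1]; the Python if is resolved when jitting.
        int_rewards = (int_rewards - int_rewards.min()) / (
            int_rewards.max() - int_rewards.min() + 1e-8)
    return target_feats, int_rewards

_, int_rew = compute_int_rew(params_rt, params_rp, obs)
print(int_rew.shape)  # (8,): one intrinsic reward per environment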
ygoai/rl/jax/utils.py

 import jax
 import jax.numpy as jnp
-from flax import struct
+import numpy as np
 
 from ygoai.rl.env import RecordEpisodeStatistics
@@ -28,35 +28,28 @@ def categorical_sample(logits, key):
     return action, key
 
 
-class RunningMeanStd(struct.PyTreeNode):
+class RunningMeanStd:
     """Tracks the mean, variance and count of values."""
 
-    mean: jnp.ndarray = struct.field(pytree_node=True)
-    var: jnp.ndarray = struct.field(pytree_node=True)
-    count: jnp.ndarray = struct.field(pytree_node=True)
-
-    @classmethod
-    def create(cls, shape=()):
-        # TODO: use numpy and float64
-        return cls(
-            mean=jnp.zeros(shape, "float32"),
-            var=jnp.ones(shape, "float32"),
-            count=jnp.full(shape, 1e-4, "float32"),
-        )
+    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+    def __init__(self, epsilon=1e-4, shape=()):
+        """Tracks the mean, variance and count of values."""
+        self.mean = np.zeros(shape, "float64")
+        self.var = np.ones(shape, "float64")
+        self.count = epsilon
 
     def update(self, x):
        """Updates the mean, var and count from a batch of samples."""
-        batch_mean = jnp.mean(x, axis=0)
-        batch_var = jnp.var(x, axis=0)
+        batch_mean = np.mean(x, axis=0)
+        batch_var = np.var(x, axis=0)
         batch_count = x.shape[0]
-        return self.update_from_moments(batch_mean, batch_var, batch_count)
+        self.update_from_moments(batch_mean, batch_var, batch_count)
 
     def update_from_moments(self, batch_mean, batch_var, batch_count):
         """Updates from batch mean, variance and count moments."""
-        mean, var, count = update_mean_var_count_from_moments(
+        self.mean, self.var, self.count = update_mean_var_count_from_moments(
             self.mean, self.var, self.count, batch_mean, batch_var, batch_count
         )
-        return self.replace(mean=mean, var=var, count=count)
 
 
 def update_mean_var_count_from_moments(
@@ -69,7 +62,7 @@ def update_mean_var_count_from_moments(
     new_mean = mean + delta * batch_count / tot_count
 
     m_a = var * count
     m_b = batch_var * batch_count
-    M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
+    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
     new_var = M2 / tot_count
     new_count = tot_count
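For reference, the rewritten class can be exercised on its own. The sketch below copies the class and the parallel-moments update (Chan et al.'s algorithm, per the Wikipedia link in the diff) and drives them the same way the new 'default' branch in cleanba_rnd.py does: compute batch moments on the host, fold them in with update_from_moments, then divide rewards by the running standard deviation. The delta and tot_count lines are reconstructed from the standard gymnasium implementation, since the hunk does not show them, and the random test data is illustrative only.

import numpy as np

def update_mean_var_count_from_moments(
        mean, var, count, batch_mean, batch_var, batch_count):
    # Combine (mean, var, count) of the data seen so far with the moments of a
    # new batch, as if the statistics had been computed over the union.
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count
    return new_mean, new_var, new_count

class RunningMeanStd:
    """Tracks the mean, variance and count of values (as in the diff above)."""

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.count = epsilon  # tiny pseudo-count avoids division by zero

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )

# Illustrative usage mirroring the 'default' normalization path: moments are
# pulled to the host (jax.device_get in the real code), folded in, and the
# rewards scaled by the running std.
rms = RunningMeanStd()
rewards = np.random.default_rng(0).normal(5.0, 2.0, size=1024)
rms.update_from_moments(rewards.mean(), rewards.std() ** 2, len(rewards))
normalized = rewards / np.sqrt(rms.var)
print(np.sqrt(rms.var))  # close to 2.0, the true std of the toy data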