Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
b8929b9c
Commit
b8929b9c
authored
May 13, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor agent2
parent
8b35ba28
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
32 deletions
+43
-32
scripts/cleanba.py
scripts/cleanba.py
+40
-29
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+3
-3
No files found.
scripts/cleanba.py
View file @
b8929b9c
...
@@ -52,6 +52,10 @@ class Args:
...
@@ -52,6 +52,10 @@ class Args:
tb_dir
:
str
=
"runs"
tb_dir
:
str
=
"runs"
"""the directory to save the tensorboard logs"""
"""the directory to save the tensorboard logs"""
tb_offset
:
int
=
0
"""the step offset of the tensorboard logs"""
run_name
:
Optional
[
str
]
=
None
"""the name of the tensorboard run"""
ckpt_dir
:
str
=
"checkpoints"
ckpt_dir
:
str
=
"checkpoints"
"""the directory to save the model checkpoints"""
"""the directory to save the model checkpoints"""
gcs_bucket
:
Optional
[
str
]
=
None
gcs_bucket
:
Optional
[
str
]
=
None
...
@@ -223,7 +227,7 @@ class Transition(NamedTuple):
...
@@ -223,7 +227,7 @@ class Transition(NamedTuple):
next_dones
:
list
next_dones
:
list
def
create_agent
(
args
,
multi_step
=
False
,
eval
=
False
):
def
create_agent
(
args
,
eval
=
False
):
return
RNNAgent
(
return
RNNAgent
(
channels
=
args
.
num_channels
,
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
num_layers
=
args
.
num_layers
,
...
@@ -232,7 +236,6 @@ def create_agent(args, multi_step=False, eval=False):
...
@@ -232,7 +236,6 @@ def create_agent(args, multi_step=False, eval=False):
param_dtype
=
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
rnn_channels
=
args
.
rnn_channels
,
rnn_channels
=
args
.
rnn_channels
,
switch
=
args
.
switch
,
switch
=
args
.
switch
,
multi_step
=
multi_step
,
freeze_id
=
args
.
freeze_id
,
freeze_id
=
args
.
freeze_id
,
use_history
=
args
.
use_history
if
not
eval
else
args
.
eval_use_history
,
use_history
=
args
.
use_history
if
not
eval
else
args
.
eval_use_history
,
rnn_type
=
args
.
rnn_type
if
not
eval
else
args
.
eval_rnn_type
,
rnn_type
=
args
.
rnn_type
if
not
eval
else
args
.
eval_rnn_type
,
...
@@ -480,23 +483,26 @@ def rollout(
...
@@ -480,23 +483,26 @@ def rollout(
avg_episodic_length
=
np
.
mean
(
envs
.
returned_episode_lengths
)
avg_episodic_length
=
np
.
mean
(
envs
.
returned_episode_lengths
)
SPS
=
int
((
global_step
-
warmup_step
)
/
(
time
.
time
()
-
start_time
-
other_time
))
SPS
=
int
((
global_step
-
warmup_step
)
/
(
time
.
time
()
-
start_time
-
other_time
))
SPS_update
=
int
(
args
.
batch_size
/
(
time
.
time
()
-
update_time_start
))
SPS_update
=
int
(
args
.
batch_size
/
(
time
.
time
()
-
update_time_start
))
tb_global_step
=
args
.
tb_offset
+
global_step
if
device_thread_id
==
0
:
if
device_thread_id
==
0
:
print
(
print
(
f
"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
f
"global_step={
tb_
global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
)
time_now
=
datetime
.
now
(
timezone
(
timedelta
(
hours
=
8
)))
.
strftime
(
"
%
H:
%
M:
%
S"
)
time_now
=
datetime
.
now
(
timezone
(
timedelta
(
hours
=
8
)))
.
strftime
(
"
%
H:
%
M:
%
S"
)
print
(
print
(
f
"{time_now} SPS: {SPS}, update: {SPS_update}, "
f
"{time_now} SPS: {SPS}, update: {SPS_update}, "
f
"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
f
"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
)
writer
.
add_scalar
(
"stats/rollout_time"
,
np
.
mean
(
rollout_time
),
global_step
)
writer
.
add_scalar
(
"stats/rollout_time"
,
np
.
mean
(
rollout_time
),
tb_
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_return"
,
avg_episodic_return
,
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_return"
,
avg_episodic_return
,
tb_
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_length"
,
avg_episodic_length
,
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_length"
,
avg_episodic_length
,
tb_
global_step
)
writer
.
add_scalar
(
"stats/params_queue_get_time"
,
np
.
mean
(
params_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/params_queue_get_time"
,
np
.
mean
(
params_queue_get_time
),
tb_
global_step
)
writer
.
add_scalar
(
"stats/inference_time"
,
inference_time
,
global_step
)
writer
.
add_scalar
(
"stats/inference_time"
,
inference_time
,
tb_
global_step
)
writer
.
add_scalar
(
"stats/env_time"
,
env_time
,
global_step
)
writer
.
add_scalar
(
"stats/env_time"
,
env_time
,
tb_
global_step
)
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
tb_
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
tb_
global_step
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -554,8 +560,12 @@ if __name__ == "__main__":
...
@@ -554,8 +560,12 @@ if __name__ == "__main__":
args
.
learner_devices
=
[
str
(
item
)
for
item
in
learner_devices
]
args
.
learner_devices
=
[
str
(
item
)
for
item
in
learner_devices
]
pprint
(
args
)
pprint
(
args
)
timestamp
=
int
(
time
.
time
())
if
args
.
run_name
is
None
:
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.exp_name}__{args.seed}__{timestamp}"
else
:
run_name
=
args
.
run_name
timestamp
=
int
(
run_name
.
split
(
"__"
)[
-
1
])
dummy_writer
=
SimpleNamespace
()
dummy_writer
=
SimpleNamespace
()
dummy_writer
.
add_scalar
=
lambda
x
,
y
,
z
:
None
dummy_writer
.
add_scalar
=
lambda
x
,
y
,
z
:
None
...
@@ -668,7 +678,7 @@ if __name__ == "__main__":
...
@@ -668,7 +678,7 @@ if __name__ == "__main__":
inputs
=
(
rstate1
,
rstate2
,
obs
,
dones
,
switch_or_mains
)
inputs
=
(
rstate1
,
rstate2
,
obs
,
dones
,
switch_or_mains
)
_rstate
,
new_logits
,
new_values
,
_valid
=
create_agent
(
_rstate
,
new_logits
,
new_values
,
_valid
=
create_agent
(
args
,
multi_step
=
True
)
.
apply
(
params
,
inputs
)
args
)
.
apply
(
params
,
inputs
)
new_values
=
new_values
.
squeeze
(
-
1
)
new_values
=
new_values
.
squeeze
(
-
1
)
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
...
@@ -897,13 +907,14 @@ if __name__ == "__main__":
...
@@ -897,13 +907,14 @@ if __name__ == "__main__":
if
eval_stats
is
not
None
:
if
eval_stats
is
not
None
:
eval_stat_list
.
append
(
eval_stats
)
eval_stat_list
.
append
(
eval_stats
)
tb_global_step
=
args
.
tb_offset
+
global_step
if
update
%
args
.
eval_interval
==
0
:
if
update
%
args
.
eval_interval
==
0
:
eval_stats
=
np
.
mean
(
eval_stat_list
,
axis
=
0
)
eval_stats
=
np
.
mean
(
eval_stat_list
,
axis
=
0
)
eval_stats
=
jax
.
device_put
(
eval_stats
,
local_devices
[
0
])
eval_stats
=
jax
.
device_put
(
eval_stats
,
local_devices
[
0
])
eval_stats
=
np
.
array
(
all_reduce_value
(
eval_stats
[
None
])[
0
])
eval_stats
=
np
.
array
(
all_reduce_value
(
eval_stats
[
None
])[
0
])
eval_time
,
eval_return
,
eval_win_rate
=
eval_stats
eval_time
,
eval_return
,
eval_win_rate
=
eval_stats
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
tb_
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
tb_
global_step
)
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
...
@@ -927,31 +938,31 @@ if __name__ == "__main__":
...
@@ -927,31 +938,31 @@ if __name__ == "__main__":
# record rewards for plotting purposes
# record rewards for plotting purposes
if
learner_policy_version
%
args
.
log_frequency
==
0
:
if
learner_policy_version
%
args
.
log_frequency
==
0
:
writer
.
add_scalar
(
"stats/rollout_queue_get_time"
,
np
.
mean
(
rollout_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/rollout_queue_get_time"
,
np
.
mean
(
rollout_queue_get_time
),
tb_
global_step
)
writer
.
add_scalar
(
writer
.
add_scalar
(
"stats/rollout_params_queue_get_time_diff"
,
"stats/rollout_params_queue_get_time_diff"
,
np
.
mean
(
rollout_queue_get_time
)
-
avg_params_queue_get_time
,
np
.
mean
(
rollout_queue_get_time
)
-
avg_params_queue_get_time
,
global_step
,
tb_
global_step
,
)
)
writer
.
add_scalar
(
"stats/training_time"
,
time
.
time
()
-
training_time_start
,
global_step
)
writer
.
add_scalar
(
"stats/training_time"
,
time
.
time
()
-
training_time_start
,
tb_
global_step
)
writer
.
add_scalar
(
"stats/rollout_queue_size"
,
rollout_queues
[
-
1
]
.
qsize
(),
global_step
)
writer
.
add_scalar
(
"stats/rollout_queue_size"
,
rollout_queues
[
-
1
]
.
qsize
(),
tb_
global_step
)
writer
.
add_scalar
(
"stats/params_queue_size"
,
params_queues
[
-
1
]
.
qsize
(),
global_step
)
writer
.
add_scalar
(
"stats/params_queue_size"
,
params_queues
[
-
1
]
.
qsize
(),
tb_
global_step
)
print
(
print
(
f
"{global_step} actor_update={update}, "
f
"{
tb_
global_step} actor_update={update}, "
f
"train_time={time.time() - training_time_start:.2f}, "
f
"train_time={time.time() - training_time_start:.2f}, "
f
"data_time={rollout_queue_get_time[-1]:.2f}"
f
"data_time={rollout_queue_get_time[-1]:.2f}"
)
)
writer
.
add_scalar
(
writer
.
add_scalar
(
"charts/learning_rate"
,
agent_state
.
opt_state
[
3
][
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
"charts/learning_rate"
,
agent_state
.
opt_state
[
3
][
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
tb_
global_step
)
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
[
-
1
]
.
item
(),
tb_
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
[
-
1
]
.
item
(),
tb_
global_step
)
writer
.
add_scalar
(
"losses/entropy"
,
entropy_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/entropy"
,
entropy_loss
[
-
1
]
.
item
(),
tb_
global_step
)
writer
.
add_scalar
(
"losses/approx_kl"
,
approx_kl
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/approx_kl"
,
approx_kl
[
-
1
]
.
item
(),
tb_
global_step
)
writer
.
add_scalar
(
"losses/loss"
,
loss
,
global_step
)
writer
.
add_scalar
(
"losses/loss"
,
loss
,
tb_
global_step
)
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
and
not
args
.
debug
:
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
and
not
args
.
debug
:
M_steps
=
args
.
batch_size
*
learner_policy_version
//
2
**
20
M_steps
=
tb_global_step
//
2
**
20
ckpt_name
=
f
"{timestamp}_{M_steps}M.flax_model"
ckpt_name
=
f
"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger
.
save
(
unreplicated_params
,
ckpt_name
)
ckpt_maneger
.
save
(
unreplicated_params
,
ckpt_name
)
if
args
.
gcs_bucket
is
not
None
:
if
args
.
gcs_bucket
is
not
None
:
...
...
ygoai/rl/jax/agent2.py
View file @
b8929b9c
...
@@ -339,7 +339,6 @@ class RNNAgent(nn.Module):
...
@@ -339,7 +339,6 @@ class RNNAgent(nn.Module):
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
multi_step
:
bool
=
False
switch
:
bool
=
True
switch
:
bool
=
True
freeze_id
:
bool
=
False
freeze_id
:
bool
=
False
use_history
:
bool
=
True
use_history
:
bool
=
True
...
@@ -347,7 +346,8 @@ class RNNAgent(nn.Module):
...
@@ -347,7 +346,8 @@ class RNNAgent(nn.Module):
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
inputs
):
def
__call__
(
self
,
inputs
):
if
self
.
multi_step
:
multi_step
=
len
(
inputs
)
!=
2
if
multi_step
:
# (num_steps * batch_size, ...)
# (num_steps * batch_size, ...)
*
rstate
,
x
,
done
,
switch_or_main
=
inputs
*
rstate
,
x
,
done
,
switch_or_main
=
inputs
else
:
else
:
...
@@ -380,7 +380,7 @@ class RNNAgent(nn.Module):
...
@@ -380,7 +380,7 @@ class RNNAgent(nn.Module):
elif
self
.
rnn_type
==
'none'
:
elif
self
.
rnn_type
==
'none'
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
else
:
else
:
if
self
.
multi_step
:
if
multi_step
:
rstate1
,
rstate2
=
rstate
rstate1
,
rstate2
=
rstate
batch_size
=
jax
.
tree
.
leaves
(
rstate1
)[
0
]
.
shape
[
0
]
batch_size
=
jax
.
tree
.
leaves
(
rstate1
)[
0
]
.
shape
[
0
]
num_steps
=
done
.
shape
[
0
]
//
batch_size
num_steps
=
done
.
shape
[
0
]
//
batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment