Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
2a419375
Commit
2a419375
authored
May 27, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add league training
parent
cd2974ce
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1047 additions
and
24 deletions
+1047
-24
scripts/cleanba.py
scripts/cleanba.py
+19
-24
scripts/cleanba_l.py
scripts/cleanba_l.py
+1028
-0
No files found.
scripts/cleanba.py
View file @
2a419375
...
@@ -43,8 +43,8 @@ class Args:
...
@@ -43,8 +43,8 @@ class Args:
"""seed of the experiment"""
"""seed of the experiment"""
log_frequency
:
int
=
10
log_frequency
:
int
=
10
"""the logging frequency of the model performance (in terms of `updates`)"""
"""the logging frequency of the model performance (in terms of `updates`)"""
time_log_freq
:
int
=
100
0
time_log_freq
:
int
=
0
"""the logging frequency of the deck time statistics"""
"""the logging frequency of the deck time statistics
, 0 to disable
"""
save_interval
:
int
=
400
save_interval
:
int
=
400
"""the frequency of saving the model (in terms of `updates`)"""
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint
:
Optional
[
str
]
=
None
checkpoint
:
Optional
[
str
]
=
None
...
@@ -360,7 +360,7 @@ def rollout(
...
@@ -360,7 +360,7 @@ def rollout(
if
args
.
concurrency
:
if
args
.
concurrency
:
if
update
!=
2
:
if
update
!=
2
:
params
=
params_queue
.
get
()
params
=
params_queue
.
get
()
params
[
"params"
][
"Encoder_0"
][
'Embed_0'
][
"embedding"
]
.
block_until_ready
()
#
params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
actor_policy_version
+=
1
actor_policy_version
+=
1
else
:
else
:
params
=
params_queue
.
get
()
params
=
params_queue
.
get
()
...
@@ -416,20 +416,21 @@ def rollout(
...
@@ -416,20 +416,21 @@ def rollout(
t
.
next_dones
[
idx
]
=
True
t
.
next_dones
[
idx
]
=
True
t
.
rewards
[
idx
]
=
-
next_reward
[
idx
]
t
.
rewards
[
idx
]
=
-
next_reward
[
idx
]
break
break
for
i
in
range
(
2
):
if
args
.
time_log_freq
:
deck_time
=
info
[
'step_time'
][
idx
][
i
]
for
i
in
range
(
2
):
deck_name
=
deck_names
[
info
[
'deck'
][
idx
][
i
]]
deck_time
=
info
[
'step_time'
][
idx
][
i
]
deck_name
=
deck_names
[
info
[
'deck'
][
idx
][
i
]]
time_count
=
deck_time_count
[
deck_name
]
avg_time
=
deck_avg_times
[
deck_name
]
time_count
=
deck_time_count
[
deck_name
]
avg_time
=
avg_time
*
(
time_count
/
(
time_count
+
1
))
+
deck_time
/
(
time_count
+
1
)
avg_time
=
deck_avg_times
[
deck_name
]
max_time
=
max
(
deck_time
,
deck_max_times
[
deck_name
])
avg_time
=
avg_time
*
(
time_count
/
(
time_count
+
1
))
+
deck_time
/
(
time_count
+
1
)
deck_avg_times
[
deck_name
]
=
avg_time
max_time
=
max
(
deck_time
,
deck_max_times
[
deck_name
])
deck_max_times
[
deck_name
]
=
max_time
deck_avg_times
[
deck_name
]
=
avg_time
deck_time_count
[
deck_name
]
+=
1
deck_max_times
[
deck_name
]
=
max_time
if
deck_time_count
[
deck_name
]
%
args
.
time_log_freq
==
0
:
deck_time_count
[
deck_name
]
+=
1
print
(
f
"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}"
)
if
deck_time_count
[
deck_name
]
%
args
.
time_log_freq
==
0
:
print
(
f
"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}"
)
episode_reward
=
info
[
'r'
][
idx
]
*
(
1
if
cur_main
else
-
1
)
episode_reward
=
info
[
'r'
][
idx
]
*
(
1
if
cur_main
else
-
1
)
win
=
1
if
episode_reward
>
0
else
0
win
=
1
if
episode_reward
>
0
else
0
...
@@ -474,14 +475,12 @@ def rollout(
...
@@ -474,14 +475,12 @@ def rollout(
else
:
else
:
eval_stats
=
None
eval_stats
=
None
learn_opponent
=
False
payload
=
(
payload
=
(
global_step
,
global_step
,
update
,
update
,
sharded_storage
,
sharded_storage
,
*
sharded_data
,
*
sharded_data
,
np
.
mean
(
params_queue_get_time
),
np
.
mean
(
params_queue_get_time
),
learn_opponent
,
eval_stats
,
eval_stats
,
)
)
rollout_queue
.
put
(
payload
)
rollout_queue
.
put
(
payload
)
...
@@ -758,7 +757,6 @@ def main():
...
@@ -758,7 +757,6 @@ def main():
sharded_next_inputs
:
List
,
sharded_next_inputs
:
List
,
sharded_next_main
:
List
,
sharded_next_main
:
List
,
key
:
jax
.
random
.
PRNGKey
,
key
:
jax
.
random
.
PRNGKey
,
learn_opponent
:
bool
=
False
,
):
):
storage
=
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
hstack
(
x
),
*
sharded_storages
)
storage
=
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
hstack
(
x
),
*
sharded_storages
)
# TODO: rstate will be out-date after the first update, maybe consider R2D2
# TODO: rstate will be out-date after the first update, maybe consider R2D2
...
@@ -862,7 +860,6 @@ def main():
...
@@ -862,7 +860,6 @@ def main():
single_device_update
,
single_device_update
,
axis_name
=
"local_devices"
,
axis_name
=
"local_devices"
,
devices
=
global_learner_decices
,
devices
=
global_learner_decices
,
static_broadcasted_argnums
=
(
7
,),
)
)
params_queues
=
[]
params_queues
=
[]
...
@@ -906,7 +903,6 @@ def main():
...
@@ -906,7 +903,6 @@ def main():
update
,
update
,
*
sharded_data
,
*
sharded_data
,
avg_params_queue_get_time
,
avg_params_queue_get_time
,
learn_opponent
,
eval_stats
,
eval_stats
,
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
sharded_data_list
.
append
(
sharded_data
)
sharded_data_list
.
append
(
sharded_data
)
...
@@ -929,13 +925,12 @@ def main():
...
@@ -929,13 +925,12 @@ def main():
agent_state
,
agent_state
,
*
list
(
zip
(
*
sharded_data_list
)),
*
list
(
zip
(
*
sharded_data_list
)),
learner_keys
,
learner_keys
,
learn_opponent
,
)
)
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
params_queue_put_time
=
0
params_queue_put_time
=
0
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
#
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
device_params
[
"params"
][
"Encoder_0"
][
'Embed_0'
][
"embedding"
]
.
block_until_ready
()
params_queue_put_start
=
time
.
time
()
params_queue_put_start
=
time
.
time
()
for
thread_id
in
range
(
args
.
num_actor_threads
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
params_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
put
(
device_params
)
params_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
put
(
device_params
)
...
...
scripts/cleanba_l.py
0 → 100644
View file @
2a419375
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment