Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
81f63996
Commit
81f63996
authored
May 01, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix random seed
parent
3e538bc7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
18 deletions
+32
-18
scripts/impala.py
scripts/impala.py
+16
-9
scripts/ppo.py
scripts/ppo.py
+16
-9
No files found.
scripts/impala.py
View file @
81f63996
...
...
@@ -229,9 +229,12 @@ def rollout(
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
local_seed
=
args
.
seed
+
device_thread_id
np
.
random
.
seed
(
local_seed
)
envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_i
d
,
local_see
d
,
args
.
local_num_envs
,
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
...
...
@@ -240,7 +243,7 @@ def rollout(
eval_envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_i
d
,
local_see
d
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
...
...
@@ -542,11 +545,14 @@ if __name__ == "__main__":
args
.
ckpt_dir
,
save_fn
,
n_saved
=
3
)
# seeding
seed_offset
=
args
.
local_rank
*
10000
args
.
seed
+=
seed_offset
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
init_key
=
jax
.
random
.
PRNGKey
(
args
.
seed
-
seed_offset
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
learner_keys
=
jax
.
device_put_replicated
(
key
,
learner_devices
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
learner_keys
=
jax
.
device_put_sharded
(
learner_keys
,
devices
=
learner_devices
)
actor_keys
=
jax
.
random
.
split
(
key
,
len
(
actor_devices
)
*
args
.
num_actor_threads
)
deck
=
init_ygopro
(
args
.
env_id
,
"english"
,
args
.
deck
,
args
.
code_list_file
)
args
.
deck1
=
args
.
deck1
or
deck
...
...
@@ -569,7 +575,7 @@ if __name__ == "__main__":
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
agent
=
create_agent
(
args
)
params
=
agent
.
init
(
agen
t_key
,
(
rstate
,
sample_obs
))
params
=
agent
.
init
(
ini
t_key
,
(
rstate
,
sample_obs
))
if
embeddings
is
not
None
:
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
embeddings
=
np
.
concatenate
([
unknown_embed
[
None
,
:],
embeddings
],
axis
=
0
)
...
...
@@ -775,18 +781,19 @@ if __name__ == "__main__":
rollout_queues
.
append
(
queue
.
Queue
(
maxsize
=
1
))
if
eval_params
:
params_queues
[
-
1
]
.
put
(
jax
.
device_put
(
eval_params
,
local_devices
[
d_id
]))
jax
.
device_put
(
eval_params
,
local_devices
[
d_id
]))
actor_thread_id
=
d_idx
*
args
.
num_actor_threads
+
thread_id
threading
.
Thread
(
target
=
rollout
,
args
=
(
jax
.
device_put
(
key
,
local_devices
[
d_id
]),
jax
.
device_put
(
actor_keys
[
actor_thread_id
]
,
local_devices
[
d_id
]),
args
,
rollout_queues
[
-
1
],
params_queues
[
-
1
],
eval_queue
,
writer
if
d_idx
==
0
and
thread_id
==
0
else
dummy_writer
,
learner_devices
,
d_idx
*
args
.
num_actor_threads
+
thread_id
,
actor_
thread_id
,
),
)
.
start
()
params_queues
[
-
1
]
.
put
(
device_params
)
...
...
scripts/ppo.py
View file @
81f63996
...
...
@@ -229,9 +229,12 @@ def rollout(
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
local_seed
=
args
.
seed
+
device_thread_id
np
.
random
.
seed
(
local_seed
)
envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_i
d
,
local_see
d
,
args
.
local_num_envs
,
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
...
...
@@ -240,7 +243,7 @@ def rollout(
eval_envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_i
d
,
local_see
d
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
...
...
@@ -552,11 +555,14 @@ if __name__ == "__main__":
args
.
ckpt_dir
,
save_fn
,
n_saved
=
3
)
# seeding
seed_offset
=
args
.
local_rank
*
10000
args
.
seed
+=
seed_offset
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
init_key
=
jax
.
random
.
PRNGKey
(
args
.
seed
-
seed_offset
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
learner_keys
=
jax
.
device_put_replicated
(
key
,
learner_devices
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
learner_keys
=
jax
.
device_put_sharded
(
learner_keys
,
devices
=
learner_devices
)
actor_keys
=
jax
.
random
.
split
(
key
,
len
(
actor_devices
)
*
args
.
num_actor_threads
)
deck
=
init_ygopro
(
args
.
env_id
,
"english"
,
args
.
deck
,
args
.
code_list_file
)
args
.
deck1
=
args
.
deck1
or
deck
...
...
@@ -579,7 +585,7 @@ if __name__ == "__main__":
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
agent
=
create_agent
(
args
)
params
=
agent
.
init
(
agen
t_key
,
(
rstate
,
sample_obs
))
params
=
agent
.
init
(
ini
t_key
,
(
rstate
,
sample_obs
))
if
embeddings
is
not
None
:
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
embeddings
=
np
.
concatenate
([
unknown_embed
[
None
,
:],
embeddings
],
axis
=
0
)
...
...
@@ -801,18 +807,19 @@ if __name__ == "__main__":
rollout_queues
.
append
(
queue
.
Queue
(
maxsize
=
1
))
if
eval_params
:
params_queues
[
-
1
]
.
put
(
jax
.
device_put
(
eval_params
,
local_devices
[
d_id
]))
jax
.
device_put
(
eval_params
,
local_devices
[
d_id
]))
actor_thread_id
=
d_idx
*
args
.
num_actor_threads
+
thread_id
threading
.
Thread
(
target
=
rollout
,
args
=
(
jax
.
device_put
(
key
,
local_devices
[
d_id
]),
jax
.
device_put
(
actor_keys
[
actor_thread_id
]
,
local_devices
[
d_id
]),
args
,
rollout_queues
[
-
1
],
params_queues
[
-
1
],
eval_queue
,
writer
if
d_idx
==
0
and
thread_id
==
0
else
dummy_writer
,
learner_devices
,
d_idx
*
args
.
num_actor_threads
+
thread_id
,
actor_
thread_id
,
),
)
.
start
()
params_queues
[
-
1
]
.
put
(
device_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment