Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
b5d92a22
Commit
b5d92a22
authored
May 28, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor random seed
parent
2a419375
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
10 deletions
+14
-10
scripts/cleanba.py
scripts/cleanba.py
+7
-5
scripts/cleanba_l.py
scripts/cleanba_l.py
+7
-5
No files found.
scripts/cleanba.py
View file @
b5d92a22
...
...
@@ -184,6 +184,7 @@ class Args:
num_embeddings
:
Optional
[
int
]
=
None
freeze_id
:
Optional
[
bool
]
=
None
deck_names
:
Optional
[
List
[
str
]]
=
None
real_seed
:
Optional
[
int
]
=
None
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
,
eval
=
False
):
...
...
@@ -259,7 +260,7 @@ def rollout(
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
local_seed
=
args
.
seed
+
device_thread_id
*
100
local_seed
=
args
.
real_seed
+
device_thread_id
*
args
.
local_num_envs
np
.
random
.
seed
(
local_seed
)
envs
=
make_env
(
...
...
@@ -273,7 +274,7 @@ def rollout(
eval_envs
=
make_env
(
args
,
local_seed
+
10000
,
local_seed
+
10000
0
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
...
...
@@ -595,11 +596,12 @@ def main():
args
.
ckpt_dir
,
save_fn
,
n_saved
=
2
)
# seeding
seed_offset
=
args
.
local_rank
*
1000
seed_offset
=
args
.
local_rank
args
.
seed
+=
seed_offset
random
.
seed
(
args
.
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
init_key
=
jax
.
random
.
PRNGKey
(
args
.
seed
-
seed_offset
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_
seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
learner_keys
=
jax
.
device_put_sharded
(
learner_keys
,
devices
=
learner_devices
)
actor_keys
=
jax
.
random
.
split
(
key
,
len
(
actor_devices
)
*
args
.
num_actor_threads
)
...
...
@@ -610,7 +612,7 @@ def main():
args
.
deck2
=
args
.
deck2
or
deck
# env setup
envs
=
make_env
(
args
,
args
.
seed
,
8
,
1
)
envs
=
make_env
(
args
,
0
,
2
,
1
)
obs_space
=
envs
.
observation_space
action_shape
=
envs
.
action_space
.
shape
print
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
...
...
scripts/cleanba_l.py
View file @
b5d92a22
...
...
@@ -191,6 +191,7 @@ class Args:
num_embeddings
:
Optional
[
int
]
=
None
freeze_id
:
Optional
[
bool
]
=
None
deck_names
:
Optional
[
List
[
str
]]
=
None
real_seed
:
Optional
[
int
]
=
None
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
,
eval
=
False
):
...
...
@@ -266,7 +267,7 @@ def rollout(
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
local_seed
=
args
.
seed
+
device_thread_id
*
100
local_seed
=
args
.
real_seed
+
device_thread_id
*
args
.
local_num_envs
np
.
random
.
seed
(
local_seed
)
envs
=
make_env
(
...
...
@@ -280,7 +281,7 @@ def rollout(
eval_envs
=
make_env
(
args
,
local_seed
+
10000
,
local_seed
+
10000
0
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
...
...
@@ -619,11 +620,12 @@ def main():
args
.
ckpt_dir
,
save_fn
,
n_saved
=
2
)
# seeding
seed_offset
=
args
.
local_rank
*
1000
seed_offset
=
args
.
local_rank
args
.
seed
+=
seed_offset
random
.
seed
(
args
.
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
init_key
=
jax
.
random
.
PRNGKey
(
args
.
seed
-
seed_offset
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_
seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
learner_keys
=
jax
.
device_put_sharded
(
learner_keys
,
devices
=
learner_devices
)
actor_keys
=
jax
.
random
.
split
(
key
,
len
(
actor_devices
)
*
args
.
num_actor_threads
)
...
...
@@ -634,7 +636,7 @@ def main():
args
.
deck2
=
args
.
deck2
or
deck
# env setup
envs
=
make_env
(
args
,
args
.
seed
,
8
,
1
)
envs
=
make_env
(
args
,
0
,
2
,
1
)
obs_space
=
envs
.
observation_space
action_shape
=
envs
.
action_space
.
shape
print
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment