Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
1b81fc0d
Commit
1b81fc0d
authored
Jun 29, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix orbax
parent
d9334955
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
3 deletions
+4
-3
scripts/cleanba_nnx.py
scripts/cleanba_nnx.py
+4
-3
No files found.
scripts/cleanba_nnx.py
View file @
1b81fc0d
...
...
@@ -1154,10 +1154,10 @@ def main():
*
list
(
zip
(
*
sharded_data_list
)),
learner_keys
,
)
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
get_state
(
agent_state
)
)
new_state
=
get_state
(
agent_state
)
params_queue_put_time
=
0
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
device_params
=
jax
.
device_put
(
flax
.
jax_utils
.
unreplicate
(
new_state
)
,
local_devices
[
d_id
])
device_params
[
"encoder"
][
'id_embed'
][
"embedding"
]
.
value
.
block_until_ready
()
params_queue_put_start
=
time
.
time
()
for
thread_id
in
range
(
args
.
num_actor_threads
):
...
...
@@ -1197,7 +1197,8 @@ def main():
if
learner_policy_version
%
args
.
save_interval
==
0
and
not
args
.
debug
:
M_steps
=
tb_global_step
//
2
**
20
ckpt_name
=
f
"{timestamp}_{M_steps}M"
ckpt_maneger
.
save
(
unreplicated_params
,
ckpt_name
)
new_state
=
jax
.
tree
.
map
(
orbax
.
utils
.
fully_replicated_host_local_array_to_global_array
,
new_state
)
ckpt_maneger
.
save
(
new_state
,
ckpt_name
)
if
learner_policy_version
>=
args
.
num_updates
:
break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment