Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
cd2974ce
Commit
cd2974ce
authored
May 25, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change block location
parent
14bceecd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
10 deletions
+7
-10
scripts/cleanba.py
scripts/cleanba.py
+7
-10
No files found.
scripts/cleanba.py
View file @
cd2974ce
...
@@ -227,7 +227,7 @@ def create_agent(args, eval=False):
...
@@ -227,7 +227,7 @@ def create_agent(args, eval=False):
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
**
asdict
(
args
.
m2
),
**
asdict
(
args
.
m2
),
)
)
else
:
else
:
return
RNNAgent
(
return
RNNAgent
(
embedding_shape
=
args
.
num_embeddings
,
embedding_shape
=
args
.
num_embeddings
,
...
@@ -315,8 +315,8 @@ def rollout(
...
@@ -315,8 +315,8 @@ def rollout(
done
=
jnp
.
array
(
done
)
done
=
jnp
.
array
(
done
)
main
=
jnp
.
array
(
main
)
main
=
jnp
.
array
(
main
)
inputs
=
next_obs
,
(
rstate1
,
rstate2
),
done
,
main
(
rstate1
,
rstate2
),
logits
=
agent
.
apply
(
(
rstate1
,
rstate2
),
logits
=
agent
.
apply
(
params
,
*
inputs
)[:
2
]
params
,
next_obs
,
(
rstate1
,
rstate2
),
done
,
main
)[:
2
]
action
,
key
=
categorical_sample
(
logits
,
key
)
action
,
key
=
categorical_sample
(
logits
,
key
)
return
next_obs
,
done
,
main
,
rstate1
,
rstate2
,
action
,
logits
,
key
return
next_obs
,
done
,
main
,
rstate1
,
rstate2
,
action
,
logits
,
key
...
@@ -360,9 +360,7 @@ def rollout(
...
@@ -360,9 +360,7 @@ def rollout(
if
args
.
concurrency
:
if
args
.
concurrency
:
if
update
!=
2
:
if
update
!=
2
:
params
=
params_queue
.
get
()
params
=
params_queue
.
get
()
# params["params"]["Encoder_0"]['Embed_0'][
params
[
"params"
][
"Encoder_0"
][
'Embed_0'
][
"embedding"
]
.
block_until_ready
()
# "embedding"
# ].block_until_ready()
actor_policy_version
+=
1
actor_policy_version
+=
1
else
:
else
:
params
=
params_queue
.
get
()
params
=
params_queue
.
get
()
...
@@ -627,7 +625,6 @@ def main():
...
@@ -627,7 +625,6 @@ def main():
frac
=
1.0
-
(
count
//
(
args
.
num_minibatches
*
args
.
update_epochs
))
/
args
.
num_updates
frac
=
1.0
-
(
count
//
(
args
.
num_minibatches
*
args
.
update_epochs
))
/
args
.
num_updates
return
args
.
learning_rate
*
frac
return
args
.
learning_rate
*
frac
# rstate = init_rnn_state(1, args.rnn_channels)
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
rstate
=
agent
.
init_rnn_state
(
1
)
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
agent
.
init
(
init_key
,
sample_obs
,
rstate
)
params
=
agent
.
init
(
init_key
,
sample_obs
,
rstate
)
...
@@ -687,8 +684,8 @@ def main():
...
@@ -687,8 +684,8 @@ def main():
if
args
.
switch
:
if
args
.
switch
:
dones
=
dones
|
next_dones
dones
=
dones
|
next_dones
inputs
=
obs
,
(
rstate1
,
rstate2
),
dones
,
switch_or_mains
new_logits
,
new_values
=
create_agent
(
args
)
.
apply
(
new_logits
,
new_values
=
create_agent
(
args
)
.
apply
(
params
,
*
input
s
)[
1
:
3
]
params
,
obs
,
(
rstate1
,
rstate2
),
dones
,
switch_or_main
s
)[
1
:
3
]
new_values
=
new_values
.
squeeze
(
-
1
)
new_values
=
new_values
.
squeeze
(
-
1
)
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
...
@@ -938,7 +935,7 @@ def main():
...
@@ -938,7 +935,7 @@ def main():
params_queue_put_time
=
0
params_queue_put_time
=
0
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
device_params
[
"params"
][
"Encoder_0"
][
'Embed_0'
][
"embedding"
]
.
block_until_ready
()
#
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
params_queue_put_start
=
time
.
time
()
params_queue_put_start
=
time
.
time
()
for
thread_id
in
range
(
args
.
num_actor_threads
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
params_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
put
(
device_params
)
params_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
put
(
device_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment