Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
671ed3c6
Commit
671ed3c6
authored
Apr 11, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Replace tree_map with tree.map
parent
9670ed68
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
8 deletions
+8
-8
scripts/battle.py
scripts/battle.py
+1
-1
scripts/eval.py
scripts/eval.py
+1
-1
scripts/jax/battle.py
scripts/jax/battle.py
+2
-2
scripts/jax/impala.py
scripts/jax/impala.py
+4
-4
No files found.
scripts/battle.py
View file @
671ed3c6
...
@@ -207,7 +207,7 @@ if __name__ == "__main__":
...
@@ -207,7 +207,7 @@ if __name__ == "__main__":
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
_
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
params
=
agent
.
init
(
agent_key
,
sample_obs
)
params
=
agent
.
init
(
agent_key
,
sample_obs
)
print
(
jax
.
tree
.
leaves
(
params
)[
0
]
.
devices
())
print
(
jax
.
tree
.
leaves
(
params
)[
0
]
.
devices
())
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
...
...
scripts/eval.py
View file @
671ed3c6
...
@@ -224,7 +224,7 @@ if __name__ == "__main__":
...
@@ -224,7 +224,7 @@ if __name__ == "__main__":
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
_
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
params
=
agent
.
init
(
agent_key
,
sample_obs
)
params
=
agent
.
init
(
agent_key
,
sample_obs
)
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
...
...
scripts/jax/battle.py
View file @
671ed3c6
...
@@ -153,7 +153,7 @@ if __name__ == "__main__":
...
@@ -153,7 +153,7 @@ if __name__ == "__main__":
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
_
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
...
@@ -171,7 +171,7 @@ if __name__ == "__main__":
...
@@ -171,7 +171,7 @@ if __name__ == "__main__":
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
next_rstate
=
jax
.
tree
_
map
(
next_rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
return
next_rstate
,
probs
return
next_rstate
,
probs
...
...
scripts/jax/impala.py
View file @
671ed3c6
...
@@ -237,7 +237,7 @@ def rollout(
...
@@ -237,7 +237,7 @@ def rollout(
next_obs
,
next_obs
,
key
:
jax
.
random
.
PRNGKey
,
key
:
jax
.
random
.
PRNGKey
,
):
):
next_obs
=
jax
.
tree
_
map
(
lambda
x
:
jnp
.
array
(
x
),
next_obs
)
next_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
(
x
),
next_obs
)
logits
=
apply_fn
(
params
,
next_obs
)[
0
]
logits
=
apply_fn
(
params
,
next_obs
)[
0
]
# sample action: Gumbel-softmax trick
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
...
@@ -263,7 +263,7 @@ def rollout(
...
@@ -263,7 +263,7 @@ def rollout(
@
jax
.
jit
@
jax
.
jit
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
return
jax
.
tree
_
map
(
lambda
*
xs
:
jnp
.
split
(
jnp
.
stack
(
xs
),
len
(
learner_devices
),
axis
=
1
),
*
storage
)
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
split
(
jnp
.
stack
(
xs
),
len
(
learner_devices
),
axis
=
1
),
*
storage
)
for
update
in
range
(
1
,
args
.
num_updates
+
2
):
for
update
in
range
(
1
,
args
.
num_updates
+
2
):
if
update
==
10
:
if
update
==
10
:
...
@@ -469,7 +469,7 @@ if __name__ == "__main__":
...
@@ -469,7 +469,7 @@ if __name__ == "__main__":
obs_space
=
envs
.
observation_space
obs_space
=
envs
.
observation_space
action_shape
=
envs
.
action_space
.
shape
action_shape
=
envs
.
action_space
.
shape
print
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
print
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
sample_obs
=
jax
.
tree
_
map
(
lambda
x
:
jnp
.
array
([
np
.
zeros
((
args
.
local_num_envs
,)
+
x
.
shape
[
1
:])]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
np
.
zeros
((
args
.
local_num_envs
,)
+
x
.
shape
[
1
:])]),
obs_space
.
sample
())
envs
.
close
()
envs
.
close
()
del
envs
del
envs
...
@@ -579,7 +579,7 @@ if __name__ == "__main__":
...
@@ -579,7 +579,7 @@ if __name__ == "__main__":
sharded_storages
:
List
[
Transition
],
sharded_storages
:
List
[
Transition
],
key
:
jax
.
random
.
PRNGKey
,
key
:
jax
.
random
.
PRNGKey
,
):
):
storage
=
jax
.
tree
_
map
(
lambda
*
x
:
jnp
.
hstack
(
x
),
*
sharded_storages
)
storage
=
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
hstack
(
x
),
*
sharded_storages
)
impala_loss_grad_fn
=
jax
.
value_and_grad
(
impala_loss
,
has_aux
=
True
)
impala_loss_grad_fn
=
jax
.
value_and_grad
(
impala_loss
,
has_aux
=
True
)
def
update_minibatch
(
agent_state
,
minibatch
):
def
update_minibatch
(
agent_state
,
minibatch
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment