Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
1d35fed3
Commit
1d35fed3
authored
May 14, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor RNN inputs
parent
b8929b9c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
53 additions
and
53 deletions
+53
-53
scripts/battle.py
scripts/battle.py
+3
-3
scripts/cleanba.py
scripts/cleanba.py
+15
-23
scripts/eval.py
scripts/eval.py
+2
-3
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+31
-22
ygoai/rl/jax/eval.py
ygoai/rl/jax/eval.py
+2
-2
No files found.
scripts/battle.py
View file @
1d35fed3
...
...
@@ -158,7 +158,7 @@ if __name__ == "__main__":
agent1
=
create_agent1
(
args
)
rstate
=
agent1
.
init_rnn_state
(
1
)
params1
=
jax
.
jit
(
agent1
.
init
)(
agent_key
,
(
rstate
,
sample_obs
)
)
params1
=
jax
.
jit
(
agent1
.
init
)(
agent_key
,
sample_obs
,
rstate
)
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
params1
=
flax
.
serialization
.
from_bytes
(
params1
,
f
.
read
())
...
...
@@ -167,7 +167,7 @@ if __name__ == "__main__":
else
:
agent2
=
create_agent2
(
args
)
rstate
=
agent2
.
init_rnn_state
(
1
)
params2
=
jax
.
jit
(
agent2
.
init
)(
agent_key
,
(
rstate
,
sample_obs
)
)
params2
=
jax
.
jit
(
agent2
.
init
)(
agent_key
,
sample_obs
,
rstate
)
with
open
(
args
.
checkpoint2
,
"rb"
)
as
f
:
params2
=
flax
.
serialization
.
from_bytes
(
params2
,
f
.
read
())
...
...
@@ -180,7 +180,7 @@ if __name__ == "__main__":
agent
=
create_agent1
(
args
)
else
:
agent
=
create_agent2
(
args
)
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
)
)[:
2
]
next_rstate
,
logits
=
agent
.
apply
(
params
,
obs
,
rstate
)[:
2
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
if
done
is
not
None
:
next_rstate
=
jnp
.
where
(
done
[:,
None
],
0
,
next_rstate
)
...
...
scripts/cleanba.py
View file @
1d35fed3
...
...
@@ -294,15 +294,14 @@ def rollout(
eval_agent
=
create_agent
(
args
,
eval
=
True
)
@
jax
.
jit
def
get_action
(
params
:
flax
.
core
.
FrozenDict
,
inputs
):
rstate
,
logits
=
eval_agent
.
apply
(
params
,
inputs
)[:
2
]
def
get_action
(
params
,
obs
,
rstate
):
rstate
,
logits
=
eval_agent
.
apply
(
params
,
obs
,
rstate
)[:
2
]
return
rstate
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
get_action_battle
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
logits1
=
agent
.
apply
(
params1
,
(
rstate1
,
obs
)
)[:
2
]
next_rstate2
,
logits2
=
eval_agent
.
apply
(
params2
,
(
rstate2
,
obs
)
)[:
2
]
def
get_action_battle
(
params1
,
params2
,
obs
,
rstate1
,
rstate2
,
main
,
done
):
next_rstate1
,
logits1
=
agent
.
apply
(
params1
,
obs
,
rstate1
)[:
2
]
next_rstate2
,
logits2
=
eval_agent
.
apply
(
params2
,
obs
,
rstate2
)[:
2
]
logits
=
jnp
.
where
(
main
[:,
None
],
logits1
,
logits2
)
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
...
...
@@ -314,19 +313,13 @@ def rollout(
@
jax
.
jit
def
sample_action
(
params
:
flax
.
core
.
FrozenDict
,
next_obs
,
rstate1
,
rstate2
,
main
,
done
,
key
):
params
,
next_obs
,
rstate1
,
rstate2
,
main
,
done
,
key
):
next_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
(
x
),
next_obs
)
done
=
jnp
.
array
(
done
)
main
=
jnp
.
array
(
main
)
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
next_obs
))[:
2
]
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
inputs
=
next_obs
,
(
rstate1
,
rstate2
),
done
,
main
(
rstate1
,
rstate2
),
logits
=
agent
.
apply
(
params
,
*
inputs
)[:
2
]
action
,
key
=
categorical_sample
(
logits
,
key
)
return
next_obs
,
done
,
main
,
rstate1
,
rstate2
,
action
,
logits
,
key
...
...
@@ -448,12 +441,12 @@ def rollout(
lambda
x1
,
x2
:
jnp
.
where
(
next_main
[:,
None
],
x1
,
x2
),
next_rstate1
,
next_rstate2
)
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
(
init_rstate1
,
init_rstate2
,
(
next_
rstate
,
next_obs
),
next_main
))
(
init_rstate1
,
init_rstate2
,
(
next_
obs
,
next_rstate
),
next_main
))
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
predict_fn
=
lambda
*
x
:
get_action
(
params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate2
)
else
:
...
...
@@ -619,7 +612,7 @@ if __name__ == "__main__":
# rstate = init_rnn_state(1, args.rnn_channels)
agent
=
create_agent
(
args
)
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
agent
.
init
(
init_key
,
(
rstate
,
sample_obs
)
)
params
=
agent
.
init
(
init_key
,
sample_obs
,
rstate
)
if
embeddings
is
not
None
:
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
embeddings
=
np
.
concatenate
([
unknown_embed
[
None
,
:],
embeddings
],
axis
=
0
)
...
...
@@ -654,7 +647,7 @@ if __name__ == "__main__":
if
args
.
eval_checkpoint
:
eval_agent
=
create_agent
(
args
,
eval
=
True
)
eval_rstate
=
eval_agent
.
init_rnn_state
(
1
)
eval_params
=
eval_agent
.
init
(
init_key
,
(
eval_rstate
,
sample_obs
)
)
eval_params
=
eval_agent
.
init
(
init_key
,
sample_obs
,
eval_rstate
)
with
open
(
args
.
eval_checkpoint
,
"rb"
)
as
f
:
eval_params
=
flax
.
serialization
.
from_bytes
(
eval_params
,
f
.
read
())
print
(
f
"loaded eval checkpoint from {args.eval_checkpoint}"
)
...
...
@@ -676,9 +669,8 @@ if __name__ == "__main__":
if
args
.
switch
:
dones
=
dones
|
next_dones
inputs
=
(
rstate1
,
rstate2
,
obs
,
dones
,
switch_or_mains
)
_rstate
,
new_logits
,
new_values
,
_valid
=
create_agent
(
args
)
.
apply
(
params
,
inputs
)
inputs
=
obs
,
(
rstate1
,
rstate2
),
dones
,
switch_or_mains
new_logits
,
new_values
=
create_agent
(
args
)
.
apply
(
params
,
*
inputs
)[
1
:
3
]
new_values
=
new_values
.
squeeze
(
-
1
)
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
...
...
@@ -780,7 +772,7 @@ if __name__ == "__main__":
key
,
subkey
=
jax
.
random
.
split
(
key
)
next_value
=
create_agent
(
args
)
.
apply
(
agent_state
.
params
,
next_inputs
)[
2
]
.
squeeze
(
-
1
)
agent_state
.
params
,
*
next_inputs
)[
2
]
.
squeeze
(
-
1
)
if
args
.
switch
:
next_value
=
jnp
.
where
(
next_main
,
-
next_value
,
next_value
)
else
:
...
...
scripts/eval.py
View file @
1d35fed3
...
...
@@ -145,7 +145,7 @@ if __name__ == "__main__":
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
)
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
sample_obs
,
rstate
)
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
...
...
@@ -154,8 +154,7 @@ if __name__ == "__main__":
@
jax
.
jit
def
get_probs_and_value
(
params
,
rstate
,
obs
,
done
):
agent
=
agent
next_rstate
,
logits
,
value
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
3
]
next_rstate
,
logits
,
value
=
agent
.
apply
(
params
,
obs
,
rstate
)[:
3
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
next_rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
...
...
ygoai/rl/jax/agent2.py
View file @
1d35fed3
...
...
@@ -308,7 +308,21 @@ class Critic(nn.Module):
return
x
def
rnn_forward_2p
(
rnn_layer
,
rstate1
,
rstate2
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
def
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
main
):
if
main
is
not
None
:
rstate1
,
rstate2
=
rstate
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
f_state
=
rnn_layer
(
rstate
,
f_state
)
if
main
is
not
None
:
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate
=
rstate1
,
rstate2
if
done
is
not
None
:
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
return
rstate
,
f_state
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
...
...
@@ -318,20 +332,15 @@ def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, s
return
(
rstate
,
init_rstate2
),
y
else
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
rstate1
,
rstate2
=
carry
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
y
=
cell
(
rstate
,
x
)
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
return
(
rstate1
,
rstate2
),
y
return
rnn_step_by_main
(
cell
,
carry
,
x
,
done
,
main
)
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
rstate
,
f_state
=
scan
(
rnn_layer
,
(
rstate1
,
rstate2
)
,
f_state
,
done
,
switch_or_main
)
rstate
,
f_state
=
scan
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
return
rstate
,
f_state
class
RNNAgent
(
nn
.
Module
):
channels
:
int
=
128
num_layers
:
int
=
2
...
...
@@ -345,14 +354,7 @@ class RNNAgent(nn.Module):
rnn_type
:
str
=
'lstm'
@
nn
.
compact
def
__call__
(
self
,
inputs
):
multi_step
=
len
(
inputs
)
!=
2
if
multi_step
:
# (num_steps * batch_size, ...)
*
rstate
,
x
,
done
,
switch_or_main
=
inputs
else
:
rstate
,
x
=
inputs
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
c
=
self
.
channels
encoder
=
Encoder
(
channels
=
c
,
...
...
@@ -380,17 +382,24 @@ class RNNAgent(nn.Module):
elif
self
.
rnn_type
==
'none'
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
else
:
batch_size
=
jax
.
tree
.
leaves
(
rstate
)[
0
]
.
shape
[
0
]
num_steps
=
f_state
.
shape
[
0
]
//
batch_size
multi_step
=
num_steps
>
1
if
done
is
not
None
:
assert
switch_or_main
is
not
None
else
:
assert
not
multi_step
if
multi_step
:
rstate1
,
rstate2
=
rstate
batch_size
=
jax
.
tree
.
leaves
(
rstate1
)[
0
]
.
shape
[
0
]
num_steps
=
done
.
shape
[
0
]
//
batch_size
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
rnn_forward_2p
(
rnn_layer
,
rstate
1
,
rstate2
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
rnn_layer
,
rstate
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
rstate
,
f_state_r
=
rnn_layer
(
rstate
,
f_state
)
rstate
,
f_state_r
=
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
...
...
ygoai/rl/jax/eval.py
View file @
1d35fed3
...
...
@@ -11,7 +11,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
if
rnn_state
is
None
:
actions
=
predict_fn
(
obs
)
else
:
rnn_state
,
actions
=
predict_fn
(
(
rnn_state
,
obs
)
)
rnn_state
,
actions
=
predict_fn
(
obs
,
rnn_state
)
actions
=
np
.
array
(
actions
)
obs
,
rewards
,
dones
,
info
=
envs
.
step
(
actions
)
...
...
@@ -53,7 +53,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
while
True
:
main
=
next_to_play
==
main_player
rstate1
,
rstate2
,
actions
=
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
dones
)
rstate1
,
rstate2
,
actions
=
predict_fn
(
obs
,
rstate1
,
rstate2
,
main
,
dones
)
actions
=
np
.
array
(
actions
)
obs
,
rewards
,
dones
,
infos
=
envs
.
step
(
actions
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment