ygo-agent / Commit 81d80f7f
authored Apr 30, 2024 by sbl1996@126.com

Fix done for rstate

parent 3bf0bc91

Showing 5 changed files with 66 additions and 42 deletions
scripts/battle.py       +7   -5
scripts/impala.py       +31  -20
scripts/ppo.py          +16  -12
ygoai/rl/jax/agent2.py  +7   -4
ygoai/utils.py          +5   -1
scripts/battle.py

...
@@ -157,24 +157,26 @@ if __name__ == "__main__":
    params2 = jax.device_put(params2)

    @jax.jit
-    def get_probs(params, rstate, obs, done):
+    def get_probs(params, rstate, obs, done=None):
        agent = create_agent(args)
        next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
        probs = jax.nn.softmax(logits, axis=-1)
-        next_rstate = jax.tree.map(
-            lambda x: jnp.where(done[:, None], 0, x), next_rstate)
+        if done is not None:
+            next_rstate = jnp.where(done[:, None], 0, next_rstate)
        return next_rstate, probs

    if args.num_envs != 1:
        @jax.jit
        def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
-            next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
-            next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
+            next_rstate1, probs1 = get_probs(params1, rstate1, obs)
+            next_rstate2, probs2 = get_probs(params2, rstate2, obs)
            probs = jnp.where(main[:, None], probs1, probs2)
            rstate1 = jax.tree.map(
                lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
            rstate2 = jax.tree.map(
                lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+            rstate1, rstate2 = jax.tree.map(
+                lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
            return rstate1, rstate2, probs

        def predict_fn(rstate1, rstate2, obs, main, done):
...
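The heart of this commit: the per-player helpers no longer zero their own recurrent state; the `done` reset happens once, after the freshly computed state has been merged back for whichever player acted. A minimal sketch of that ordering, using made-up shapes and values for `next_rstate1`, `rstate1`, `rstate2`, `main`, and `done`:

import jax
import jax.numpy as jnp

N, H = 4, 8                                          # hypothetical batch and hidden sizes
next_rstate1 = (jnp.ones((N, H)), jnp.ones((N, H)))  # state just produced for player 1
rstate1 = (jnp.zeros((N, H)), jnp.zeros((N, H)))     # carried state for player 1
rstate2 = (jnp.zeros((N, H)), jnp.zeros((N, H)))     # carried state for player 2
main = jnp.array([True, False, True, False])         # envs where player 1 acted
done = jnp.array([False, False, True, True])         # envs that just terminated

# First keep the new state only where player 1 actually acted ...
rstate1 = jax.tree.map(
    lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)

# ... and only then zero BOTH players' states for finished episodes.
rstate1, rstate2 = jax.tree.map(
    lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))

Resetting after the merge clears both players' states on a terminal step, rather than only the state of the player who happened to act last, which is presumably why the reset was moved here.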
scripts/impala.py

 import os
+import shutil
 import queue
 import random
 import threading
...
@@ -45,6 +46,8 @@ class Args:
    """the frequency of saving the model (in terms of `updates`)"""
    checkpoint: Optional[str] = None
    """the path to the model checkpoint to load"""
+    debug: bool = False
+    """whether to run the script in debug mode"""
    tb_dir: str = "runs"
    """the directory to save the tensorboard logs"""
...
@@ -156,7 +159,7 @@ class Args:
    actor_devices: Optional[List[str]] = None
    learner_devices: Optional[List[str]] = None
    num_embeddings: Optional[int] = None
-    freeze_id: bool = False
+    freeze_id: Optional[bool] = None

def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
...
@@ -253,28 +256,27 @@ def rollout(
    @jax.jit
    def get_logits(
-        params: flax.core.FrozenDict, inputs, done):
+        params: flax.core.FrozenDict, inputs):
        rstate, logits = create_agent(args).apply(params, inputs)[:2]
-        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
        return rstate, logits

    @jax.jit
    def get_action(
        params: flax.core.FrozenDict, inputs):
-        batch_size = jax.tree.leaves(inputs)[0].shape[0]
-        done = jnp.zeros(batch_size, dtype=jnp.bool_)
-        rstate, logits = get_logits(params, inputs, done)
+        rstate, logits = get_logits(params, inputs)
        return rstate, logits.argmax(axis=1)

    @jax.jit
    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
        logits = jnp.where(main[:, None], logits1, logits2)
        rstate1 = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
        rstate2 = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
        return rstate1, rstate2, logits.argmax(axis=1)

    @jax.jit
...
@@ -284,12 +286,14 @@ def rollout(
        next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
        done = jnp.array(done)
        main = jnp.array(main)
        rstate = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs), done)
+        rstate, logits = get_logits(params, (rstate, next_obs))
        rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
        rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
        action, key = categorical_sample(logits, key)
        return next_obs, done, main, rstate1, rstate2, action, logits, key
...
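After the reset, the actor step samples an action from the selected logits. `categorical_sample` is a helper from this repository; the sketch below is only a rough stand-in built on `jax.random`, with a hypothetical `[num_envs, num_actions]` logits array, to illustrate what that sampling step amounts to:

import jax
import jax.numpy as jnp

def categorical_sample_sketch(logits, key):
    # Split the PRNG key so the next call gets fresh randomness,
    # then draw one action per environment from the categorical distribution.
    key, subkey = jax.random.split(key)
    action = jax.random.categorical(subkey, logits, axis=-1)
    return action, key

key = jax.random.PRNGKey(0)
logits = jnp.zeros((4, 8))  # hypothetical [num_envs, num_actions]
action, key = categorical_sample_sketch(logits, key)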
@@ -517,13 +521,18 @@ if __name__ == "__main__":
    timestamp = int(time.time())
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
-    tb_log_dir = f"{args.tb_dir}/{run_name}"
-    writer = SummaryWriter(tb_log_dir)
-    writer.add_text(
-        "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
-    )
+    dummy_writer = SimpleNamespace()
+    dummy_writer.add_scalar = lambda x, y, z: None
+
+    tb_log_dir = f"{args.tb_dir}/{run_name}"
+    if args.local_rank == 0 and not args.debug:
+        writer = SummaryWriter(tb_log_dir)
+        writer.add_text(
+            "hyperparameters",
+            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        )
+    else:
+        writer = dummy_writer

    def save_fn(obj, path):
        with open(path, "wb") as f:
...
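With the new `debug` flag, only rank 0 in a non-debug run gets a real TensorBoard writer; everything else gets a no-op object, so logging calls need no further guards. A minimal sketch of the pattern, assuming `torch.utils.tensorboard.SummaryWriter` (the actual import in the script may differ) and a hypothetical `args`:

from types import SimpleNamespace
from torch.utils.tensorboard import SummaryWriter  # assumed writer class

args = SimpleNamespace(local_rank=0, debug=True, tb_dir="runs")
run_name = "example_run"  # hypothetical

# A writer whose logging methods silently discard everything.
dummy_writer = SimpleNamespace(add_scalar=lambda *a, **k: None,
                               add_text=lambda *a, **k: None)

if args.local_rank == 0 and not args.debug:
    writer = SummaryWriter(f"{args.tb_dir}/{run_name}")
else:
    writer = dummy_writer

writer.add_scalar("losses/loss", 0.0, 0)  # a no-op in debug mode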
@@ -669,6 +678,7 @@ if __name__ == "__main__":
        learn_opponent: bool = False,
    ):
        storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
+        # TODO: rstate will be out-date after the first update, maybe consider R2D2
        next_inputs, init_rstate1, init_rstate2 = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
...
@@ -756,8 +766,6 @@ if __name__ == "__main__":
    params_queues = []
    rollout_queues = []
    eval_queue = queue.Queue()
-    dummy_writer = SimpleNamespace()
-    dummy_writer.add_scalar = lambda x, y, z: None
    unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
    for d_idx, d_id in enumerate(args.actor_device_ids):
...
@@ -844,13 +852,16 @@ if __name__ == "__main__":
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
            writer.add_scalar("losses/loss", loss, global_step)

-        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
+        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
            M_steps = args.batch_size * learner_policy_version // (2**20)
            ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
            ckpt_maneger.save(unreplicated_params, ckpt_name)
            if args.gcs_bucket is not None:
+                lastest_path = ckpt_maneger.get_latest()
+                copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
+                shutil.copyfile(lastest_path, copy_path)
                zip_file_path = "latest.zip"
-                zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
+                zip_files(zip_file_path, [str(copy_path), tb_log_dir])
                sync_to_gcs(args.gcs_bucket, zip_file_path)

        if learner_policy_version >= args.num_updates:
...
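The GCS branch now zips a stable `latest.*` copy of the newest checkpoint instead of the timestamped file itself. `Path.with_name` swaps only the final path component, as in this small example with a hypothetical checkpoint path:

from pathlib import Path

lastest_path = Path("checkpoints/1714448000_12M.flax_model")  # hypothetical checkpoint
copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
print(copy_path)  # checkpoints/latest.flax_model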
scripts/ppo.py

...
@@ -47,6 +47,8 @@ class Args:
    """the frequency of saving the model (in terms of `updates`)"""
    checkpoint: Optional[str] = None
    """the path to the model checkpoint to load"""
+    debug: bool = False
+    """whether to run the script in debug mode"""
    tb_dir: str = "runs"
    """the directory to save the tensorboard logs"""
...
@@ -156,7 +158,7 @@ class Args:
    actor_devices: Optional[List[str]] = None
    learner_devices: Optional[List[str]] = None
    num_embeddings: Optional[int] = None
-    freeze_id: bool = False
+    freeze_id: Optional[bool] = None

def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
...
@@ -254,28 +256,27 @@ def rollout(
    @jax.jit
    def get_logits(
-        params: flax.core.FrozenDict, inputs, done):
+        params: flax.core.FrozenDict, inputs):
        rstate, logits = create_agent(args).apply(params, inputs)[:2]
-        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
        return rstate, logits

    @jax.jit
    def get_action(
        params: flax.core.FrozenDict, inputs):
-        batch_size = jax.tree.leaves(inputs)[0].shape[0]
-        done = jnp.zeros(batch_size, dtype=jnp.bool_)
-        rstate, logits = get_logits(params, inputs, done)
+        rstate, logits = get_logits(params, inputs)
        return rstate, logits.argmax(axis=1)

    @jax.jit
    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
        logits = jnp.where(main[:, None], logits1, logits2)
        rstate1 = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
        rstate2 = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
        return rstate1, rstate2, logits.argmax(axis=1)

    @jax.jit
...
@@ -285,12 +286,14 @@ def rollout(
        next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
        done = jnp.array(done)
        main = jnp.array(main)
        rstate = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs), done)
+        rstate, logits = get_logits(params, (rstate, next_obs))
        rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
        rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
        action, key = categorical_sample(logits, key)
        return next_obs, done, main, rstate1, rstate2, action, logits, key
...
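Both scripts rely on the same broadcasting trick for the reset: `done` has shape `[num_envs]`, so `done[:, None]` has shape `[num_envs, 1]` and broadcasts against every `[num_envs, hidden]` leaf of the recurrent state. A tiny worked example with made-up numbers:

import jax.numpy as jnp

done = jnp.array([False, True])          # per-env termination flags, shape [2]
h = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # one state leaf, shape [2, hidden]

# done[:, None] has shape [2, 1] and broadcasts across the hidden dimension,
# so only the second row (the finished env) is zeroed.
print(jnp.where(done[:, None], 0, h))
# [[1. 2.]
#  [0. 0.]]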
@@ -532,7 +535,7 @@ if __name__ == "__main__":
    dummy_writer.add_scalar = lambda x, y, z: None
    tb_log_dir = f"{args.tb_dir}/{run_name}"
-    if args.local_rank == 0:
+    if args.local_rank == 0 and not args.debug:
        writer = SummaryWriter(tb_log_dir)
        writer.add_text(
            "hyperparameters",
...
@@ -692,6 +695,7 @@ if __name__ == "__main__":
        learn_opponent: bool = False,
    ):
        storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
+        # TODO: rstate will be out-date after the first update, maybe consider R2D2
        next_inputs, init_rstate1, init_rstate2 = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
...
@@ -874,7 +878,7 @@ if __name__ == "__main__":
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
            writer.add_scalar("losses/loss", loss, global_step)

-        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
+        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
            M_steps = args.batch_size * learner_policy_version // (2**20)
            ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
            ckpt_maneger.save(unreplicated_params, ckpt_name)
...
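Checkpoints are now skipped entirely in debug mode; otherwise their names encode progress in millions of samples via `args.batch_size * learner_policy_version // (2**20)`. A worked example with hypothetical numbers:

batch_size = 65536              # hypothetical samples per learner update
learner_policy_version = 160    # hypothetical update counter
M_steps = batch_size * learner_policy_version // (2**20)
print(M_steps)                  # 10 -> checkpoint named f"{timestamp}_10M.flax_model"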
ygoai/rl/jax/agent2.py

...
@@ -169,8 +169,6 @@ class Encoder(nn.Module):
        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

        id_embed = embed(n_embed, embed_dim)
-        if self.freeze_id:
-            id_embed = lambda x: jax.lax.stop_gradient(id_embed(x))
        action_encoder = ActionEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
...
@@ -184,6 +182,8 @@ class Encoder(nn.Module):
        x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
        x_id = id_embed(x_id)
+        if self.freeze_id:
+            x_id = jax.lax.stop_gradient(x_id)

        # Cards
        f_cards = CardEncoder(
...
@@ -215,9 +215,12 @@ class Encoder(nn.Module):
        h_mask = h_mask.at[:, 0].set(False)

        x_h_id = decode_id(x_h_actions[..., :2])
+        x_h_id = id_embed(x_h_id)
+        if self.freeze_id:
+            x_h_id = jax.lax.stop_gradient(x_h_id)
        x_h_id = MLP(
            (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
-            kernel_init=default_fc_init2)(id_embed(x_h_id))
+            kernel_init=default_fc_init2)(x_h_id)

        x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
...
@@ -379,9 +382,9 @@ class PPOLSTMAgent(nn.Module):
            rstate1, rstate2 = carry
            rstate = jax.tree.map(
                lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
            rstate, y = cell(rstate, x)
-            rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
            rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
            rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+            rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
            return (rstate1, rstate2), y

        scan = nn.scan(body_fn, variable_broadcast='params',
...
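`freeze_id` now freezes the card-id embedding by applying `jax.lax.stop_gradient` to the embedding output at each use site; the removed wrapper (`id_embed = lambda x: jax.lax.stop_gradient(id_embed(x))`) would call itself once the name is rebound. A self-contained sketch of the stop-gradient pattern, with a hypothetical Flax module and made-up sizes:

import jax
import jax.numpy as jnp
import flax.linen as nn

class FrozenIdEncoder(nn.Module):
    n_embed: int = 100     # hypothetical vocabulary size
    embed_dim: int = 16    # hypothetical embedding width
    freeze_id: bool = True

    @nn.compact
    def __call__(self, x_id):
        x = nn.Embed(self.n_embed, self.embed_dim)(x_id)
        if self.freeze_id:
            # Gradients flowing back from downstream layers stop here,
            # so the embedding table is never updated.
            x = jax.lax.stop_gradient(x)
        return x.sum()

model = FrozenIdEncoder()
params = model.init(jax.random.PRNGKey(0), jnp.array([1, 2, 3]))
grads = jax.grad(lambda p: model.apply(p, jnp.array([1, 2, 3])))(params)
# All embedding gradients are zero when freeze_id=True.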
ygoai/utils.py

...
@@ -48,7 +48,7 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
    return deck_name

-def load_embeddings(embedding_file, code_list_file):
+def load_embeddings(embedding_file, code_list_file, pad_to=999):
    with open(embedding_file, "rb") as f:
        embeddings = pickle.load(f)
    with open(code_list_file, "r") as f:
...
@@ -56,4 +56,8 @@ def load_embeddings(embedding_file, code_list_file):
    code_list = [int(code.strip()) for code in code_list]
    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
+    if pad_to is not None:
+        assert pad_to >= len(embeddings), f"pad_to={pad_to} < len(embeddings)={len(embeddings)}"
+        pad = np.zeros((pad_to - len(embeddings), embeddings.shape[1]), dtype=np.float32)
+        embeddings = np.concatenate([embeddings, pad], axis=0)
    return embeddings
\ No newline at end of file
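`load_embeddings` now zero-pads the stacked embedding matrix up to `pad_to` rows (999 by default), presumably so the downstream embedding table keeps a static shape regardless of how many card codes are listed. A small sketch with made-up dimensions:

import numpy as np

embeddings = np.random.rand(7, 4).astype(np.float32)  # hypothetical: 7 cards, 4-dim embeddings
pad_to = 10

pad = np.zeros((pad_to - len(embeddings), embeddings.shape[1]), dtype=np.float32)
padded = np.concatenate([embeddings, pad], axis=0)
print(padded.shape)  # (10, 4): original rows first, zero rows after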