ygo-agent · Commit ff67e8c4

Add partial GAE

Authored Mar 13, 2024 by Biluo Shen
Parent: 745c67f9
Showing 3 changed files with 55 additions and 518 deletions (+55 −518).
scripts/ppo.py                   +46 −18
scripts/ppo2.py                  +0  −489
ygoenv/ygoenv/ygopro/ygopro.h    +9  −11
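The scripts/ppo.py changes below make each iteration collect a rollout buffer of `collect_length` steps (with `collect_length >= num_steps`), compute advantages over the whole buffer, train only on the first `num_steps`, and then copy the remaining `collect_length - num_steps` steps to the front of the buffer so the next iteration can reuse them with freshly recomputed value estimates ("partial GAE"). The ygopro.h hunks separately make the uint8-to-uint16 packing of card ids use explicit casts, and scripts/ppo2.py is removed. A minimal, self-contained sketch of the rollout pattern is below; everything in it is an illustrative stand-in (NumPy, toy numbers, random placeholders), not code from the repository:

```python
# Minimal sketch of the "partial GAE" rollout pattern this commit introduces in
# scripts/ppo.py. All values and helpers are toy stand-ins, not repository code.
import numpy as np

num_steps = 8          # steps actually trained on per iteration (toy value)
collect_length = 10    # rollout buffer length, must be >= num_steps (toy value)

rng = np.random.default_rng(0)
rewards = np.zeros(collect_length)
values = np.zeros(collect_length)

step = 0
for iteration in range(1, 4):
    # Fill the buffer; the first `step` slots already hold the tail carried
    # over from the previous iteration, so collection resumes from there.
    while step < collect_length:
        rewards[step] = rng.normal()   # stand-in for an environment step
        values[step] = rng.normal()    # stand-in for the value network
        step += 1

    step = collect_length - num_steps
    if step > 0 and iteration != 1:
        # Re-evaluate the carried-over head with the current value network so
        # its advantage targets use up-to-date estimates.
        values[:step] = rng.normal(size=step)

    # Stand-in for bootstrap_value_selfplay: advantages over the full buffer.
    advantages = rewards - values

    # Train only on the first num_steps ...
    b_advantages = advantages[:num_steps]

    # ... then copy the remaining tail to the front for the next iteration.
    for buf in (rewards, values):
        buf[:step] = buf[num_steps:]
```

With, for example, num_steps = 128 and collect_length = 160, the 128 trained steps have up to 32 additional future steps available in the buffer, so their GAE estimates depend less on the bootstrap value at the buffer boundary; the trailing 32 steps are carried over and trained in the next iteration. That reading of the motivation is an inference from the diff, not a statement from the commit message.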
scripts/ppo.py
@@ -4,7 +4,6 @@ import time
 from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional
 import pickle
 import ygoenv
 import numpy as np
@@ -99,6 +98,8 @@ class Args:
     """the target KL divergence threshold"""
     learn_opponent: bool = True
     """if toggled, the samples from the opponent will be used to train the agent"""
+    collect_length: int = None
+    """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
     backend: Literal["gloo", "nccl", "mpi"] = "nccl"
     """the backend for distributed training"""
@@ -156,6 +157,9 @@ def main():
     args.num_iterations = args.total_timesteps // args.batch_size
     args.env_threads = args.env_threads or args.num_envs
     args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
+    args.collect_length = args.collect_length or args.num_steps
+    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
 
     local_torch_threads = args.torch_threads // args.world_size
     local_env_threads = args.env_threads // args.world_size
@@ -279,13 +283,13 @@ def main():
         traced_model = agent
 
     # ALGO Logic: Storage setup
-    obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
-    actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
-    logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
-    values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
+    obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
+    actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
+    logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
+    values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
@@ -305,6 +309,7 @@ def main():
         np.random.shuffle(ai_player1_)
     ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
 
     next_value1 = next_value2 = 0
+    step = 0
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
@@ -316,7 +321,7 @@ def main():
         model_time = 0
         env_time = 0
         collect_start = time.time()
-        for step in range(0, args.num_steps):
+        while step < args.collect_length:
             global_step += args.num_envs
 
             for key in obs:
@@ -350,6 +355,7 @@ def main():
             env_time += time.time() - _start
             rewards[step] = to_tensor(reward, device)
             next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
+            step += 1
 
             if not writer:
                 continue
@@ -378,29 +384,44 @@ def main():
         if local_rank == 0:
             fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
 
+        step = args.collect_length - args.num_steps
+
         _start = time.time()
         # bootstrap value if not done
         with torch.no_grad():
             value = traced_model(next_obs)[1].reshape(-1)
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
 
+        if step > 0 and iteration != 1:
+            # recalculate the values for the first few steps
+            v_steps = args.local_minibatch_size * 4 // args.local_num_envs
+            for v_start in range(0, step, v_steps):
+                v_end = min(v_start + v_steps, step)
+                v_obs = {k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()}
+                with torch.no_grad():
+                    # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
+                    value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
+                values[v_start:v_end] = value
+
         advantages = bootstrap_value_selfplay(
             values, rewards, dones, learns, nextvalues1, nextvalues2,
             next_done, args.gamma, args.gae_lambda)
-        returns = advantages + values
 
         bootstrap_time = time.time() - _start
 
         _start = time.time()
         # flatten the batch
         b_obs = {
-            k: v.reshape((-1,) + v.shape[2:])
+            k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
            for k, v in obs.items()
         }
-        b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + action_shape)
-        b_advantages = advantages.reshape(-1)
-        b_returns = returns.reshape(-1)
-        b_values = values.reshape(-1)
-        b_learns = learns.reshape(-1)
+        b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
+        b_logprobs = logprobs[:args.num_steps].reshape(-1)
+        b_advantages = advantages[:args.num_steps].reshape(-1)
+        b_values = values[:args.num_steps].reshape(-1)
+        b_learns = learns[:args.num_steps].reshape(-1)
+        b_returns = b_advantages + b_values
 
         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
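`bootstrap_value_selfplay` is defined elsewhere in the repository and is not part of this diff. For reference, the standard GAE recursion such a routine computes for each step t — with reward r_t, value estimate V(s_t), done flag d_{t+1}, and the `args.gamma` (γ) and `args.gae_lambda` (λ) passed above — is sketched below; assuming the self-play variant applies it with per-player bootstrap values (`nextvalues1` / `nextvalues2`) is an inference, not something shown in this diff.

```latex
\delta_t  = r_t + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t)
\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_{t+1})\,\hat{A}_{t+1}
\hat{R}_t = \hat{A}_t + V(s_t)
```

The last identity is what the new `b_returns = b_advantages + b_values` line computes after slicing to the first `num_steps`.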
@@ -424,7 +445,14 @@ def main():
                 if args.target_kl is not None and approx_kl > args.target_kl:
                     break
 
+        if step > 0:
+            # TODO: use cyclic buffer to avoid copying
+            for v in obs.values():
+                v[:step] = v[args.num_steps:].clone()
+            for v in [actions, logprobs, rewards, dones, values, learns]:
+                v[:step] = v[args.num_steps:].clone()
+
         train_time = time.time() - _start
 
         if local_rank == 0:
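The `# TODO: use cyclic buffer to avoid copying` above refers to the per-iteration `clone()` of the last `collect_length - num_steps` steps into the front of every storage tensor. A hypothetical sketch of that alternative is below; none of these names exist in the repository, it only illustrates the indexing idea:

```python
# Hypothetical cyclic-buffer alternative to the tail-to-front copy above;
# not part of this commit, names are invented for illustration.
import torch

class RingStorage:
    """Fixed-size rollout storage addressed through a rolling write position."""

    def __init__(self, collect_length: int, num_envs: int, device: str = "cpu"):
        self.capacity = collect_length
        self.rewards = torch.zeros((collect_length, num_envs), device=device)
        self.pos = 0  # physical index of the next slot to write

    def append(self, reward: torch.Tensor) -> None:
        # Overwrite the oldest slot instead of shifting data to the front.
        self.rewards[self.pos] = reward
        self.pos = (self.pos + 1) % self.capacity

    def ordered(self) -> torch.Tensor:
        # Gather the buffer in oldest-first order when GAE needs it.
        idx = (torch.arange(self.capacity, device=self.rewards.device) + self.pos) % self.capacity
        return self.rewards[idx]
```

A full version would either run GAE directly over the wrapped indices or accept the gather in `ordered()`; the sketch only shows the indexing, not the GAE side.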
scripts/ppo2.py
deleted, mode 100644 → 0 (diff collapsed, not shown)
ygoenv/ygoenv/ygopro/ygopro.h
The first hunk only reformats whitespace in the combinations_with_weight2 signature:

@@ -112,8 +112,8 @@ inline bool sum_to2(const std::vector<std::vector<int>> &w,
 }
 
 inline std::vector<std::vector<int>>
 combinations_with_weight2(const std::vector<std::vector<int>> &weights, int r) {
   int n = weights.size();
   std::vector<std::vector<int>> results;
@@ -1771,9 +1771,9 @@ public:
       uint8_t msg_id = uint8_t(ha(i, 2));
       int msg = _msgs[msg_id - 1];
       fmt::print("msg: {},", msg_to_string(msg));
-      auto v1 = static_cast<CardId>(ha(i, 0));
-      auto v2 = static_cast<CardId>(ha(i, 1));
-      CardId card_id = (v1 << 8) + v2;
+      uint8_t v1 = ha(i, 0);
+      uint8_t v2 = ha(i, 1);
+      CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2);
       fmt::print(" {};", card_id);
       for (int j = 3; j < ha.Shape()[1]; j++) {
         fmt::print(" {}", uint8_t(ha(i, j)));
@@ -2326,15 +2326,13 @@ private:
     for (int i = 0; i < n_options; ++i) {
       uint8_t spec_index1 = state["obs:actions_"_](i, 0);
       uint8_t spec_index2 = state["obs:actions_"_](i, 1);
-      uint16_t spec_index = (spec_index1 << 8) + spec_index2;
+      uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
       if (spec_index == 0) {
         h_card_ids[i] = 0;
       } else {
-        uint16_t card_id1 =
-            static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 0));
-        uint16_t card_id2 =
-            static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 1));
-        h_card_ids[i] = (card_id1 << 8) + card_id2;
+        uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
+        uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
+        h_card_ids[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
       }
     }