Biluo Shen / ygo-agent · Commits

Commit ff67e8c4 authored Mar 13, 2024 by Biluo Shen
Add partial GAE

parent 745c67f9
Showing 3 changed files with 55 additions and 518 deletions.
scripts/ppo.py  +46 -18
scripts/ppo2.py  +0 -489
ygoenv/ygoenv/ygopro/ygopro.h  +9 -11
scripts/ppo.py  (view file @ ff67e8c4)
@@ -4,7 +4,6 @@ import time
 from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional
-import pickle
 import ygoenv
 import numpy as np
@@ -99,6 +98,8 @@ class Args:
     """the target KL divergence threshold"""
     learn_opponent: bool = True
     """if toggled, the samples from the opponent will be used to train the agent"""
+    collect_length: int = None
+    """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
     backend: Literal["gloo", "nccl", "mpi"] = "nccl"
     """the backend for distributed training"""
@@ -156,6 +157,9 @@ def main():
     args.num_iterations = args.total_timesteps // args.batch_size
     args.env_threads = args.env_threads or args.num_envs
     args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
+    args.collect_length = args.collect_length or args.num_steps
+    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
     local_torch_threads = args.torch_threads // args.world_size
     local_env_threads = args.env_threads // args.world_size
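Note (a reading of the defaulting above, not a statement from the commit): when collect_length is left unset it falls back to num_steps, so the carried-over segment has length zero and the partial-GAE machinery below becomes a no-op; the script should then behave as it did before this change.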
@@ -279,13 +283,13 @@ def main():
     traced_model = agent
 
     # ALGO Logic: Storage setup
-    obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
-    actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
-    logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
-    values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
+    obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
+    actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
+    logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
+    values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
@@ -305,6 +309,7 @@ def main():
     np.random.shuffle(ai_player1_)
     ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
     next_value1 = next_value2 = 0
+    step = 0
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
@@ -316,7 +321,7 @@ def main():
         model_time = 0
         env_time = 0
         collect_start = time.time()
-        for step in range(0, args.num_steps):
+        while step < args.collect_length:
             global_step += args.num_envs
             for key in obs:
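A worked example of the new step bookkeeping (128 and 160 are illustrative numbers, not repo defaults): with num_steps=128 and collect_length=160, the first iteration starts at step=0 and collects 160 transitions; every later iteration resumes at step = 160 - 128 = 32, because the last 32 rows of the previous buffer are copied to the front after training, so only 128 fresh environment steps are gathered per update while GAE still runs over all 160 rows.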
@@ -350,6 +355,7 @@ def main():
             env_time += time.time() - _start
             rewards[step] = to_tensor(reward, device)
             next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
+            step += 1
             if not writer:
                 continue
@@ -378,29 +384,44 @@ def main():
         if local_rank == 0:
             fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
 
+        step = args.collect_length - args.num_steps
+
         _start = time.time()
         # bootstrap value if not done
         with torch.no_grad():
             value = traced_model(next_obs)[1].reshape(-1)
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
+            if step > 0 and iteration != 1:
+                # recalculate the values for the first few steps
+                v_steps = args.local_minibatch_size * 4 // args.local_num_envs
+                for v_start in range(0, step, v_steps):
+                    v_end = min(v_start + v_steps, step)
+                    v_obs = {
+                        k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
+                    }
+                    with torch.no_grad():
+                        # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
+                        value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
+                    values[v_start:v_end] = value
             advantages = bootstrap_value_selfplay(
                 values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
-            returns = advantages + values
         bootstrap_time = time.time() - _start
 
         _start = time.time()
         # flatten the batch
         b_obs = {
-            k: v.reshape((-1,) + v.shape[2:])
+            k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
            for k, v in obs.items()
         }
-        b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + action_shape)
-        b_advantages = advantages.reshape(-1)
-        b_returns = returns.reshape(-1)
-        b_values = values.reshape(-1)
-        b_learns = learns.reshape(-1)
+        b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
+        b_logprobs = logprobs[:args.num_steps].reshape(-1)
+        b_advantages = advantages[:args.num_steps].reshape(-1)
+        b_values = values[:args.num_steps].reshape(-1)
+        b_learns = learns[:args.num_steps].reshape(-1)
+        b_returns = b_advantages + b_values
 
         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
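One reading of the re-evaluation block above (inferred from the diff, not stated in the commit message): the first `step` rows of the buffer are carried over from the previous iteration and were scored by the old network, so their cached values are refreshed with the current traced_model in chunks before bootstrap_value_selfplay runs; on iteration 1 nothing has been carried over yet, hence the iteration != 1 guard. The batch now also computes b_returns = b_advantages + b_values for the trained window instead of materializing returns over the whole buffer.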
@@ -424,7 +445,14 @@ def main():
                 if args.target_kl is not None and approx_kl > args.target_kl:
                     break
 
+        if step > 0:
+            # TODO: use cyclic buffer to avoid copying
+            for v in obs.values():
+                v[:step] = v[args.num_steps:].clone()
+            for v in [actions, logprobs, rewards, dones, values, learns]:
+                v[:step] = v[args.num_steps:].clone()
+
         train_time = time.time() - _start
         if local_rank == 0:
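The rollover above copies the untrained tail of every buffer to its head so the next iteration can resume filling from row `step`. A toy sketch of the same move with made-up shapes (num_steps=4, collect_length=6 are illustrative values, not repo defaults):

import torch

num_steps, collect_length, num_envs = 4, 6, 2
rewards = torch.arange(collect_length * num_envs, dtype=torch.float32).reshape(collect_length, num_envs)

step = collect_length - num_steps             # 2 rows collected but not yet trained on
rewards[:step] = rewards[num_steps:].clone()  # move the untrained tail to the front
# rows step..collect_length-1 are overwritten with fresh env steps next iteration

As the in-code TODO notes, a cyclic buffer with a moving write index could avoid these clones; the copy-based version keeps the later v[:args.num_steps] slicing straightforward.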
scripts/ppo2.py  deleted 100644 → 0  (view file @ 745c67f9)

This diff is collapsed.
ygoenv/ygoenv/ygopro/ygopro.h  (view file @ ff67e8c4)
@@ -112,8 +112,8 @@ inline bool sum_to2(const std::vector<std::vector<int>> &w,
 }
 
 inline std::vector<std::vector<int>>
-combinations_with_weight2(const std::vector<std::vector<int>> &weights,
-                          int r) {
+combinations_with_weight2(
+    const std::vector<std::vector<int>> &weights, int r) {
   int n = weights.size();
   std::vector<std::vector<int>> results;
@@ -1771,9 +1771,9 @@ public:
       uint8_t msg_id = uint8_t(ha(i, 2));
       int msg = _msgs[msg_id - 1];
       fmt::print("msg: {},", msg_to_string(msg));
-      auto v1 = static_cast<CardId>(ha(i, 0));
-      auto v2 = static_cast<CardId>(ha(i, 1));
-      CardId card_id = (v1 << 8) + v2;
+      uint8_t v1 = ha(i, 0);
+      uint8_t v2 = ha(i, 1);
+      CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2);
       fmt::print(" {};", card_id);
       for (int j = 3; j < ha.Shape()[1]; j++) {
         fmt::print(" {}", uint8_t(ha(i, j)));
@@ -2326,15 +2326,13 @@ private:
     for (int i = 0; i < n_options; ++i) {
       uint8_t spec_index1 = state["obs:actions_"_](i, 0);
       uint8_t spec_index2 = state["obs:actions_"_](i, 1);
-      uint16_t spec_index = (spec_index1 << 8) + spec_index2;
+      uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
       if (spec_index == 0) {
         h_card_ids[i] = 0;
       } else {
-        uint16_t card_id1 =
-            static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 0));
-        uint16_t card_id2 =
-            static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 1));
-        h_card_ids[i] = (card_id1 << 8) + card_id2;
+        uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
+        uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
+        h_card_ids[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
       }
     }
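Both ygopro.h hunks settle on the same pattern: keep the raw table entries as uint8_t and widen explicitly at the point where two bytes are combined into a 16-bit id. A numpy analogue of that combine step, for illustration only (the repo's code is C++ and the values are made up):

import numpy as np

hi = np.uint8(0x12)   # high byte of an id
lo = np.uint8(0x34)   # low byte

# widen before shifting; shifting 8-bit storage left by 8 can otherwise wrap to 0
card_id = (np.uint16(hi) << 8) + np.uint16(lo)
assert card_id == 0x1234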