Biluo Shen / ygo-agent / Commits

Commit 3c7a7080
Authored Feb 20, 2024 by biluo.shen

Add PPO
Parent: ad9c3c34

Showing 4 changed files with 543 additions and 16 deletions:

    scripts/dmc.py        +7    -7
    scripts/ppo.py        +418  -0
    ygoai/rl/agent.py     +75   -9
    ygoai/rl/dist.py      +43   -0

scripts/dmc.py  +7 -7

@@ -60,17 +60,17 @@ class Args:
     total_timesteps: int = 100000000
     """total timesteps of the experiments"""
-    learning_rate: float = 2.5e-4
+    learning_rate: float = 5e-4
     """the learning rate of the optimizer"""
     num_envs: int = 64
     """the number of parallel game environments"""
-    num_steps: int = 100
+    num_steps: int = 200
     """the number of steps per env per iteration"""
-    buffer_size: int = 200000
+    buffer_size: int = 20000
     """the replay memory buffer size"""
     gamma: float = 0.99
     """the discount factor gamma"""
-    minibatch_size: int = 256
+    minibatch_size: int = 1024
     """the mini-batch size"""
     eps: float = 0.05
     """the epsilon for exploration"""
@@ -264,13 +264,13 @@ if __name__ == "__main__":
         # ALGO LOGIC: training.
         _start = time.time()
         b_inds = rb.get_data_indices()
         if len(b_inds) < args.minibatch_size:
             if not rb.full:
                 continue
             b_inds = rb.get_data_indices()
         np.random.shuffle(b_inds)
         b_obs, b_actions, b_returns = rb._get_samples(b_inds)
         sample_time += time.time() - _start
-        for start in range(0, len(b_inds), args.minibatch_size):
+        for start in range(0, len(b_returns), args.minibatch_size):
             _start = time.time()
             end = start + args.minibatch_size
             mb_obs = {

scripts/ppo.py  (new file 100644)  +418 -0

import os
import random
import time
from dataclasses import dataclass
from typing import Literal, Optional

import ygoenv
import numpy as np
import optree
import tyro

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist

from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, mp_start, setup
from ygoai.rl.buffer import create_obs


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[:-len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""

    # Algorithm specific arguments
    env_id: str = "YGOPro-v0"
    """the id of the environment"""
    deck: str = "../assets/deck/OldSchool.ydk"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    embedding_file: str = "embeddings_en.npy"
    """the embedding file for card embeddings"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 8
    """the number of history actions to use"""
    play_mode: str = "self"
    """the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""

    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""

    total_timesteps: int = 100000000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    minibatch_size: int = 256
    """the mini-batch size"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.1
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: Optional[float] = None
    """the target KL divergence threshold"""

    backend: Literal["gloo", "nccl", "mpi"] = "nccl"
    """the backend for distributed training"""
    compile: bool = True
    """whether to use torch.compile to compile the model and functions"""
    torch_threads: Optional[int] = None
    """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
    env_threads: Optional[int] = None
    """the number of threads to use for envpool, defaults to `num_envs`"""
    tb_dir: str = "./runs"
    """tensorboard log directory"""
    port: int = 12355
    """the port to use for distributed training"""

    # to be filled in runtime
    local_batch_size: int = 0
    """the local batch size in the local rank (computed in runtime)"""
    local_minibatch_size: int = 0
    """the local mini-batch size in the local rank (computed in runtime)"""
    local_num_envs: int = 0
    """the number of parallel game environments (in the local rank, computed in runtime)"""
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""
    world_size: int = 0
    """the number of processes (computed in runtime)"""


def run(local_rank, world_size):
    args = tyro.cli(Args)
    args.world_size = world_size
    args.local_num_envs = args.num_envs // args.world_size
    args.local_batch_size = int(args.local_num_envs * args.num_steps)
    args.local_minibatch_size = int(args.minibatch_size // args.world_size)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.num_iterations = args.total_timesteps // args.batch_size
    args.env_threads = args.env_threads or args.num_envs
    args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)

    local_torch_threads = args.torch_threads // args.world_size
    local_env_threads = args.env_threads // args.world_size

    torch.set_num_threads(local_torch_threads)
    torch.set_float32_matmul_precision('high')

    if args.world_size > 1:
        setup(args.backend, local_rank, args.world_size, args.port)

    timestamp = int(time.time())
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"

    writer = None
    if local_rank == 0:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

    # TRY NOT TO MODIFY: seeding
    # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
    args.seed += local_rank
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed - local_rank)
    if args.torch_deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")

    deck = init_ygopro("english", args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    # env setup
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=args.local_num_envs,
        num_threads=local_env_threads,
        seed=args.seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode=args.play_mode,
    )
    envs.num_envs = args.local_num_envs
    obs_space = envs.observation_space
    action_shape = envs.action_space.shape
    if local_rank == 0:
        print(f"obs_space={obs_space}, action_shape={action_shape}")

    envs = RecordEpisodeStatistics(envs)

    embeddings = np.load(args.embedding_file)
    L = args.num_layers
    agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
    agent.load_embeddings(embeddings)

    if args.compile:
        agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode='reduce-overhead')

    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
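
    # masked_mean averages a per-sample tensor only over entries the agent reports
    # as valid, so padded/invalid samples do not contribute to the losses.
    # train_step computes the PPO losses (clipped policy surrogate, optionally
    # clipped value loss, entropy bonus), backpropagates, and all-reduces gradients
    # across data-parallel workers; the caller clips gradients and steps the optimizer.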
    def masked_mean(x, valid):
        x = x.masked_fill(~valid, 0)
        return x.sum() / valid.float().sum()

    def train_step(agent, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
        _, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
        logratio = newlogprob - mb_logprobs
        ratio = logratio.exp()

        with torch.no_grad():
            # calculate approx_kl http://joschu.net/blog/kl-approx.html
            old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()

        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        # Policy loss
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
        pg_loss = torch.max(pg_loss1, pg_loss2)
        pg_loss = masked_mean(pg_loss, valid)

        # Value loss
        newvalue = newvalue.view(-1)
        if args.clip_vloss:
            v_loss_unclipped = (newvalue - mb_returns) ** 2
            v_clipped = mb_values + torch.clamp(
                newvalue - mb_values,
                -args.clip_coef,
                args.clip_coef,
            )
            v_loss_clipped = (v_clipped - mb_returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max
        else:
            v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
        v_loss = masked_mean(v_loss, valid)

        entropy_loss = masked_mean(entropy, valid)
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

        optimizer.zero_grad()
        loss.backward()
        reduce_gradidents(agent, args.world_size)
        return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss

    if args.compile:
        train_step = torch.compile(train_step, mode='reduce-overhead')

    def to_tensor(x, dtype=torch.float32):
        return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)

    # ALGO Logic: Storage setup
    obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
    actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
    avg_ep_returns = []
    avg_win_rates = []

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    warmup_steps = 0
    start_time = time.time()
    next_obs = to_tensor(envs.reset()[0], dtype=torch.uint8)
    next_done = torch.zeros(args.local_num_envs, device=device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        model_time = 0
        env_time = 0
        collect_start = time.time()
        for step in range(0, args.num_steps):
            global_step += args.num_envs

            for key in obs:
                obs[key][step] = next_obs[key]
            dones[step] = next_done

            _start = time.time()
            with torch.no_grad():
                action, logprob, _, value, valid = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob
            action = action.cpu().numpy()
            model_time += time.time() - _start

            _start = time.time()
            next_obs, reward, next_done_, info = envs.step(action)
            env_time += time.time() - _start
            rewards[step] = to_tensor(reward)
            next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_)

            if not writer:
                continue

            for idx, d in enumerate(next_done_):
                if d:
                    episode_length = info['l'][idx]
                    episode_reward = info['r'][idx]
                    writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                    writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)

                    avg_ep_returns.append(episode_reward)
                    winner = 0 if episode_reward > 0 else 1
                    avg_win_rates.append(1 - winner)

                    print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")

                    if len(avg_win_rates) > 100:
                        writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
                        writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
                        avg_win_rates = []
                        avg_ep_returns = []

        collect_time = time.time() - collect_start
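
        # Generalized Advantage Estimation, computed backwards over the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * gae_lambda * (1 - done_{t+1}) * A_{t+1}
        # The value of the final observation bootstraps the last step.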
        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        _start = time.time()
        # flatten the batch
        b_obs = {
            k: v.reshape((-1,) + v.shape[2:])
            for k, v in obs.items()
        }
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + action_shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
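
        # Each rank shuffles its own local batch and performs `update_epochs` passes
        # of minibatch updates; gradients are averaged across ranks inside train_step
        # (via reduce_gradidents), then clipped and applied below.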
        # Optimizing the policy and value network
        b_inds = np.arange(args.local_batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.local_batch_size, args.local_minibatch_size):
                end = start + args.local_minibatch_size
                mb_inds = b_inds[start:end]
                mb_obs = {
                    k: v[mb_inds] for k, v in b_obs.items()
                }
                old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
                    train_step(agent, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds],
                               b_advantages[mb_inds], b_returns[mb_inds], b_values[mb_inds])
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()
                clipfracs.append(clipfrac.item())

            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        train_time = time.time() - _start

        if local_rank == 0:
            print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if local_rank == 0:
            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
            writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
            writer.add_scalar("losses/explained_variance", explained_var, global_step)

            SPS = int((global_step - warmup_steps) / (time.time() - start_time))

            # Warmup at first few iterations for accurate SPS measurement
            SPS_warmup_iters = 10
            if iteration == SPS_warmup_iters:
                start_time = time.time()
                warmup_steps = global_step
            if iteration > SPS_warmup_iters:
                print("SPS:", SPS)
                writer.add_scalar("charts/SPS", SPS, global_step)

    if args.world_size > 1:
        dist.destroy_process_group()
    envs.close()
    if local_rank == 0:
        writer.close()


if __name__ == "__main__":
    mp_start(run)
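
Note: the core of train_step above is the standard PPO clipped surrogate. The snippet below is illustrative only and not part of the commit (the function name demo_clipped_pg_loss and the sample tensors are invented); it reproduces the policy-loss arithmetic on a few hand-picked values:

import torch

def demo_clipped_pg_loss(logratio, advantages, clip_coef=0.1):
    # Same form as in train_step: elementwise max of the unclipped and clipped
    # surrogate terms, then averaged (train_step additionally masks invalid samples).
    ratio = logratio.exp()
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    return torch.max(pg_loss1, pg_loss2).mean()

logratio = torch.tensor([0.0, 0.5, -0.5])      # new-policy log-prob minus old
advantages = torch.tensor([1.0, 1.0, -1.0])
print(demo_clipped_pg_loss(logratio, advantages))

For the second entry the ratio (about 1.65) exceeds 1 + clip_coef, so the clipped term is selected and that sample stops contributing gradient, which is what bounds the size of a single policy update.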

ygoai/rl/agent.py  +75 -9

 import torch
 import torch.nn as nn
+from torch.distributions import Categorical


 def bytes_to_bin(x, points, intervals):
@@ -18,11 +19,11 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
     return points, intervals


-class Agent(nn.Module):
+class Encoder(nn.Module):

     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
-        super(Agent, self).__init__()
+        super(Encoder, self).__init__()
         self.num_history_action_layers = num_history_action_layers

         c = channels
@@ -129,11 +130,6 @@ class Agent(nn.Module):
         ])
         self.action_norm = nn.LayerNorm(c, elementwise_affine=False)

-        self.value_head = nn.Sequential(
-            nn.Linear(c, c // 4),
-            nn.ReLU(),
-            nn.Linear(c // 4, 1),
-        )

         self.init_embeddings()
@@ -147,7 +143,6 @@ class Agent(nn.Module):
                 nn.init.uniform_(m.weight, -scale, scale)
             elif "fc_emb" in n:
                 nn.init.uniform_(m.weight, -scale, scale)

     def load_embeddings(self, embeddings, freeze=True):
         weight = self.id_embed.weight
@@ -309,7 +304,78 @@ class Agent(nn.Module):
             f_actions = layer(f_actions, f_h_actions)

         f_actions = self.action_norm(f_actions)
+        return f_actions, mask, valid
+
+
+class PPOAgent(nn.Module):
+
+    def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+        super(PPOAgent, self).__init__()
+        self.encoder = Encoder(channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
+
+        c = channels
+        self.actor = nn.Sequential(
+            nn.Linear(c, c // 4),
+            nn.ReLU(),
+            nn.Linear(c // 4, 1),
+        )
+        self.critic = nn.Sequential(
+            nn.Linear(c, c // 4),
+            nn.ReLU(),
+            nn.Linear(c // 4, 1),
+        )
+
+    def load_embeddings(self, embeddings, freeze=True):
+        self.encoder.load_embeddings(embeddings, freeze)
+
+    def get_value(self, x):
+        f_actions, mask, valid = self.encoder(x)
+        c_mask = 1 - mask.unsqueeze(-1).float()
+        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        return self.critic(f)
+
+    def get_action_and_value(self, x, action=None):
+        f_actions, mask, valid = self.encoder(x)
+        c_mask = 1 - mask.unsqueeze(-1).float()
+        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+
+        logits = self.actor(f_actions)[..., 0]
+        logits = logits.masked_fill(mask, float("-inf"))
+        probs = Categorical(logits=logits)
+        if action is None:
+            action = probs.sample()
+        return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
+
+
+class DMCAgent(nn.Module):
+
+    def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+        super(DMCAgent, self).__init__()
+        self.encoder = Encoder(channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
+
+        c = channels
+        self.value_head = nn.Sequential(
+            nn.Linear(c, c // 4),
+            nn.ReLU(),
+            nn.Linear(c // 4, 1),
+        )
+
+    def load_embeddings(self, embeddings, freeze=True):
+        self.encoder.load_embeddings(embeddings, freeze)
+
+    def forward(self, x):
+        f_actions, mask, valid = self.encoder(x)
         values = self.value_head(f_actions)[..., 0]
         # values = torch.tanh(values)
         values = torch.where(mask, torch.full_like(values, -10), values)
-        return values, valid
\ No newline at end of file
+        return values, valid
\ No newline at end of file
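
PPOAgent.get_action_and_value masks illegal options by filling their logits with float("-inf") before building the Categorical distribution, so sampling and log-probabilities only ever touch legal actions. A small self-contained illustration of that pattern with toy tensors (illustrative only, not taken from the repository):

import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 0.5, -0.1, 0.3]])
mask = torch.tensor([[False, False, True, True]])    # True marks an illegal option
probs = Categorical(logits=logits.masked_fill(mask, float("-inf")))
samples = probs.sample((1000,))
assert int(samples.max()) < 2    # illegal options are never sampled
print(probs.probs)               # masked entries get probability 0

The same masking keeps probs.log_prob(action) finite for every legal action, which is what the PPO ratio in scripts/ppo.py relies on.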

ygoai/rl/dist.py  (new file 100644)  +43 -0

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
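

# reduce_gradidents flattens every parameter gradient into a single buffer,
# all-reduces it (SUM) across ranks, then copies each slice back divided by
# world_size, i.e. it averages the workers' gradients with one collective call.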
def reduce_gradidents(model, world_size):
    if world_size == 1:
        return
    all_grads_list = []
    for param in model.parameters():
        if param.grad is not None:
            all_grads_list.append(param.grad.view(-1))
    all_grads = torch.cat(all_grads_list)
    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
    offset = 0
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.copy_(
                all_grads[offset:offset + param.numel()].view_as(param.grad.data) / world_size
            )
            offset += param.numel()


def setup(backend, rank, world_size, port):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(port)
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def mp_start(run):
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if world_size == 1:
        run(local_rank=0, world_size=world_size)
    else:
        children = []
        for i in range(world_size):
            subproc = mp.Process(target=run, args=(i, world_size))
            children.append(subproc)
            subproc.start()

        for i in range(world_size):
            children[i].join()
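
How these helpers fit together, as a standalone sketch (illustrative only, not part of the commit: toy_run is a made-up function, and the gloo backend with port 12355 is assumed so it can run on CPU). mp_start forks one process per rank based on the WORLD_SIZE environment variable, each rank joins the process group through setup, and reduce_gradidents averages gradients after backward():

import torch
import torch.distributed as dist
import torch.nn as nn
from ygoai.rl.dist import reduce_gradidents, mp_start, setup

def toy_run(local_rank, world_size):
    if world_size > 1:
        setup("gloo", local_rank, world_size, 12355)   # assumed backend and port
    model = nn.Linear(4, 1)
    loss = model(torch.ones(2, 4) * (local_rank + 1)).sum()
    loss.backward()
    reduce_gradidents(model, world_size)               # no-op when world_size == 1
    print(local_rank, model.weight.grad.view(-1))      # same averaged gradient on every rank
    if world_size > 1:
        dist.destroy_process_group()

if __name__ == "__main__":
    mp_start(toy_run)

Launching it with, for example, WORLD_SIZE=2 should print an identical gradient vector on both ranks, since the per-rank gradients are summed and divided by world_size.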