Biluo Shen / ygo-agent · Commits · 43ca871e

Commit 43ca871e — Refactor switch
Authored Apr 18, 2024 by sbl1996@126.com · parent 93bc3723

Showing 8 changed files with 502 additions and 481 deletions (+502 −481)
scripts/jax/battle.py        +11   -1
scripts/jax/impala.py        +125  -101
scripts/jax/ppo.py           +72   -65
ygoai/rl/jax/__init__.py     +236  -303
ygoai/rl/jax/agent2.py       +5    -0
ygoai/rl/jax/switch.py       +39   -0
ygoai/rl/utils.py            +0    -10
ygoai/utils.py               +14   -1
scripts/jax/battle.py

@@ -4,6 +4,7 @@ import os
 import random
 from typing import Optional, Literal
 from dataclasses import dataclass
+from tqdm import tqdm

 import ygoenv
 import numpy as np
@@ -220,6 +221,9 @@ if __name__ == "__main__":
     ])
     rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)

+    if not args.verbose:
+        pbar = tqdm(total=args.num_episodes)
+
     model_time = env_time = 0
     while True:
         if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
@@ -255,7 +259,11 @@ if __name__ == "__main__":
             episode_rewards.append(episode_reward)
             win_rates.append(win)
             win_reasons.append(1 if win_reason == 1 else 0)
-            sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+            if args.verbose:
+                print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+            else:
+                pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
+                pbar.update(1)

             # Only when num_envs=1, we switch the player here
             if args.verbose:
@@ -264,6 +272,8 @@ if __name__ == "__main__":
         if len(episode_lengths) >= args.num_episodes:
             break

+    if not args.verbose:
+        pbar.close()
     print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")

     total_time = time.time() - start
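The change above is purely about reporting: when --verbose is off, per-episode stderr logging is replaced by a tqdm progress bar that carries running means in its postfix. A minimal standalone sketch of that pattern, with placeholder episode results instead of the script's real rollout:

import numpy as np
from tqdm import tqdm

def run_episodes(num_episodes, verbose=False):
    pbar = None if verbose else tqdm(total=num_episodes)
    lengths, rewards, wins = [], [], []
    for i in range(num_episodes):
        length, reward, win = 50, 1.0, 1  # placeholder outcome of one episode
        lengths.append(length)
        rewards.append(reward)
        wins.append(win)
        if verbose:
            print(f"Episode {i}: length={length}, reward={reward}, win={win}")
        else:
            # Running means shown next to the bar, updated once per episode.
            pbar.set_postfix(len=np.mean(lengths), reward=np.mean(rewards), win_rate=np.mean(wins))
            pbar.update(1)
    if pbar is not None:
        pbar.close()
    print(f"len={np.mean(lengths)}, reward={np.mean(rewards)}, win_rate={np.mean(wins)}")

run_episodes(10)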
scripts/jax/impala.py

@@ -16,18 +16,17 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
-import rlax
 import distrax
 import tyro
 from flax.training.train_state import TrainState
 from rich.pretty import pprint
 from tensorboardX import SummaryWriter

-from ygoai.utils import init_ygopro
+from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
-from ygoai.rl.jax.eval import evaluate
-from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss
+from ygoai.rl.jax.eval import evaluate, battle
+from ygoai.rl.jax import vtrace_2p0s, clipped_surrogate_pg_loss, policy_gradient_loss, mse_loss, entropy_loss

 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -63,10 +62,12 @@ class Args:
     """the maximum number of options"""
     n_history_actions: int = 32
     """the number of history actions to use"""
+    greedy_reward: bool = True
+    """whether to use greedy reward (faster kill higher reward)"""

     total_timesteps: int = 5000000000
     """total timesteps of the experiments"""
-    learning_rate: float = 1e-4
+    learning_rate: float = 1e-3
     """the learning rate of the optimizer"""
     local_num_envs: int = 128
     """the number of parallel game environments"""
@@ -74,15 +75,15 @@ class Args:
     """the number of threads to use for environment"""
     num_actor_threads: int = 2
     """the number of actor threads to use"""
-    num_steps: int = 32
+    num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
-    collect_length: Optional[int] = None
-    """the number of steps to compute the advantages"""
     anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
     """the discount factor gamma"""
-    num_minibatches: int = 4
+    upgo: bool = False
+    """Toggle the use of UPGO for advantages"""
+    num_minibatches: int = 8
     """the number of mini-batches"""
     update_epochs: int = 2
     """the K epochs to update the policy"""
@@ -94,12 +95,12 @@ class Args:
     """the minimum value of the importance sampling clipping"""
     rho_clip_max: float = 1.007
     """the maximum value of the importance sampling clipping"""
-    upgo: bool = False
-    """whether to use UPGO for policy update"""
     ppo_clip: bool = True
     """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
     clip_coef: float = 0.25
     """the PPO surrogate clipping coefficient"""
+    dual_clip_coef: Optional[float] = None
+    """the dual surrogate clipping coefficient"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
     vf_coef: float = 0.5
@@ -122,11 +123,13 @@ class Args:
     """whether to use `jax.distirbuted`"""
     concurrency: bool = True
     """whether to run the actor and learner concurrently"""
-    bfloat16: bool = True
+    bfloat16: bool = False
     """whether to use bfloat16 for the agent"""
     thread_affinity: bool = False
     """whether to use thread affinity for the environment"""

+    eval_checkpoint: Optional[str] = None
+    """the path to the model checkpoint to evaluate"""
     local_eval_episodes: int = 32
     """the number of episodes to evaluate the model"""
     eval_interval: int = 50
@@ -145,6 +148,7 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
+    freeze_id: bool = False


 def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
@@ -164,6 +168,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
         async_reset=False,
+        greedy_reward=args.greedy_reward if mode == 'self' else True,
         play_mode=mode,
     )
     envs.num_envs = num_envs
@@ -177,7 +182,6 @@ class Transition(NamedTuple):
     logits: list
     rewards: list
     mains: list
-    next_dones: list


 def create_agent(args, multi_step=False):
@@ -189,6 +193,7 @@ def create_agent(args, multi_step=False):
         param_dtype=jnp.float32,
         lstm_channels=args.rnn_channels,
         multi_step=multi_step,
+        freeze_id=args.freeze_id,
     )
@@ -209,6 +214,10 @@ def rollout(
     learner_devices,
     device_thread_id,
 ):
+    eval_mode = 'self' if args.eval_checkpoint else 'bot'
+    if eval_mode != 'bot':
+        eval_params = params_queue.get()
+
     envs = make_env(
         args,
         args.seed + jax.process_index() + device_thread_id,
@@ -222,7 +231,7 @@ def rollout(
         args,
         args.seed + jax.process_index() + device_thread_id,
         args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode='bot')
+        args.local_eval_episodes // 4, mode=eval_mode)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)
@@ -249,6 +258,17 @@ def rollout(
         rstate, logits = get_logits(params, inputs, done)
         return rstate, logits.argmax(axis=1)

+    @jax.jit
+    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        logits = jnp.where(main[:, None], logits1, logits2)
+        rstate1 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
+        rstate2 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        return rstate1, rstate2, logits.argmax(axis=1)
+
     @jax.jit
     def sample_action(
         params: flax.core.FrozenDict,
@@ -281,7 +301,6 @@ def rollout(
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(main_player)
-    start_step = 0
     storage = []

     @jax.jit
@@ -312,7 +331,7 @@ def rollout(
         rollout_time_start = time.time()
         init_rstate1, init_rstate2 = jax.tree.map(
             lambda x: x.copy(), (next_rstate1, next_rstate2))
-        for _ in range(start_step, args.collect_length):
+        for _ in range(args.num_steps):
             global_step += args.local_num_envs * n_actors * args.world_size

             cached_next_obs = next_obs
@@ -340,7 +359,6 @@ def rollout(
                     actions=action,
                     logits=logits,
                     rewards=next_reward,
-                    next_dones=next_done,
                 )
             )
@@ -348,15 +366,6 @@ def rollout(
                 if not d:
                     continue
                 cur_main = main[idx]
-                for j in reversed(range(len(storage) - 1)):
-                    t = storage[j]
-                    if t.next_dones[idx]:
-                        # For OTK where player may not switch
-                        break
-                    if t.mains[idx] != cur_main:
-                        t.next_dones[idx] = True
-                        t.rewards[idx] = -next_reward[idx]
-                        break
                 episode_reward = info['r'][idx] * (1 if cur_main else -1)
                 win = 1 if episode_reward > 0 else 0
                 avg_ep_returns.append(episode_reward)
@@ -364,10 +373,8 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)
-        start_step = args.collect_length - args.num_steps

         partitioned_storage = prepare_data(storage)
-        storage = storage[args.num_steps:]
+        storage = []
         sharded_storage = []
         for x in partitioned_storage:
             if isinstance(x, dict):
@@ -384,7 +391,7 @@ def rollout(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
+            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
         learn_opponent = False
         payload = (
             global_step,
@@ -403,10 +410,13 @@ def rollout(
             SPS_update = int(args.batch_size / (time.time() - update_time_start))
             if device_thread_id == 0:
                 print(
-                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
+                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
                 )
                 time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
-                print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
+                print(
+                    f"{time_now} SPS: {SPS}, update: {SPS_update}, "
+                    f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}")
             writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
             writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
             writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
@@ -419,19 +429,28 @@ def rollout(
         if args.eval_interval and update % args.eval_interval == 0:
             # Eval with rule-based policy
             _start = time.time()
-            eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
+            if eval_mode == 'bot':
+                predict_fn = lambda x: get_action(params, x)
+                eval_stat = evaluate(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
+                metric_name = "eval_return"
+            else:
+                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
+                eval_stat = battle(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
+                metric_name = "eval_win_rate"
             if device_thread_id != 0:
-                eval_queue.put(eval_return)
+                eval_queue.put(eval_stat)
             else:
                 eval_stats = []
-                eval_stats.append(eval_return)
+                eval_stats.append(eval_stat)
                 for _ in range(1, n_actors):
                     eval_stats.append(eval_queue.get())
                 eval_stats = np.mean(eval_stats)
-                writer.add_scalar("charts/eval_return", eval_stats, global_step)
+                writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
                 if device_thread_id == 0:
                     eval_time = time.time() - _start
-                    print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
+                    print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                     other_time += eval_time
@@ -461,8 +480,15 @@ if __name__ == "__main__":
     args.minibatch_size = args.local_minibatch_size * args.world_size
     args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
     args.local_env_threads = args.local_env_threads or args.local_num_envs
-    args.collect_length = args.collect_length or args.num_steps
-    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
+
+    if args.embedding_file:
+        embeddings = load_embeddings(args.embedding_file, args.code_list_file)
+        embedding_shape = embeddings.shape
+        args.num_embeddings = embedding_shape
+        args.freeze_id = True if args.freeze_id is None else args.freeze_id
+    else:
+        embeddings = None
+        embedding_shape = None

     local_devices = jax.local_devices()
     global_devices = jax.devices()
@@ -517,6 +543,13 @@ if __name__ == "__main__":
     rstate = init_rnn_state(1, args.rnn_channels)
     agent = create_agent(args)
     params = agent.init(agent_key, (rstate, sample_obs))
+    if embeddings is not None:
+        unknown_embed = embeddings.mean(axis=0)
+        embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
+        params = flax.core.unfreeze(params)
+        params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
+        params = flax.core.freeze(params)
+
     tx = optax.MultiSteps(
         optax.chain(
             optax.clip_by_global_norm(args.max_grad_norm),
@@ -541,6 +574,13 @@ if __name__ == "__main__":
     agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
     # print(agent.tabulate(agent_key, sample_obs))

+    if args.eval_checkpoint:
+        with open(args.eval_checkpoint, "rb") as f:
+            eval_params = flax.serialization.from_bytes(params, f.read())
+        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
+    else:
+        eval_params = None
+
     @jax.jit
     def get_logits_and_value(
         params: flax.core.FrozenDict, inputs,
@@ -550,67 +590,54 @@ if __name__ == "__main__":
         return logits, value.squeeze(-1)

     def ppo_loss(
-            params, rstate1, rstate2, obs, dones, next_dones,
-            switch, actions, logits, rewards, mask, next_value):
+            params, rstate1, rstate2, obs, dones, mains,
+            actions, logits, rewards, mask, next_value, next_done):
         # (num_steps * local_num_envs // n_mb))
         num_envs = next_value.shape[0]
         num_steps = dones.shape[0] // num_envs

-        mask = mask & (~dones)
+        mask = mask * (1.0 - dones)
         n_valids = jnp.sum(mask)
-        real_dones = dones | next_dones
-        inputs = (rstate1, rstate2, obs, real_dones, switch)
+        inputs = (rstate1, rstate2, obs, dones, mains)
         new_logits, new_values = get_logits_and_value(params, inputs)

-        new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map(
+        new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
             lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
-            (new_logits, new_values, logits, actions, rewards, next_dones, switch, mask),
+            (new_logits, new_values, logits, actions, rewards, dones, mains, mask),
         )
+        next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)

-        v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0)
-
-        ratio = distrax.importance_sampling_ratios(distrax.Categorical(
+        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
-        logratio = jnp.log(ratio)
-        approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
-
-        discounts = (1.0 - next_dones) * args.gamma
-
-        vtrace_fn = partial(
-            vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
-        vtrace_returns = jax.vmap(
-            vtrace_fn, in_axes=1, out_axes=1)(
-            v_tm1, v_t, rewards, discounts, ratio)
-
-        if args.upgo:
-            advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
-                rewards, v_t, discounts) - v_tm1
-        else:
-            advs = vtrace_returns.q_estimate - v_tm1
+
+        # TODO: TD(lambda) for multi-step
+        target_values, advantages = vtrace_2p0s(
+            # TODO: use switch to calculate the correct value
+            next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
+            args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
+
+        logratio = jnp.log(ratios)
+        approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

         if args.ppo_clip:
-            pg_loss = jax.vmap(
-                partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
-                ratio, advs, mask) * num_steps
-            pg_loss = jnp.sum(pg_loss)
+            pg_loss = clipped_surrogate_pg_loss(
+                ratios, advantages, args.clip_coef, args.dual_clip_coef)
         else:
-            pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs
-            pg_loss = jax.vmap(
-                rlax.policy_gradient_loss, in_axes=1)(
-                new_logits, actions, pg_advs, mask) * num_steps
-            pg_loss = jnp.sum(pg_loss)
+            pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
+            pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)
+        pg_loss = jnp.sum(pg_loss * mask)

-        v_loss = 0.5 * (vtrace_returns.errors ** 2)
+        v_loss = mse_loss(new_values, target_values)
         v_loss = jnp.sum(v_loss * mask)

-        entropy_loss = distrax.Softmax(new_logits).entropy()
-        entropy_loss = jnp.sum(entropy_loss * mask)
+        ent_loss = entropy_loss(new_logits)
+        ent_loss = jnp.sum(ent_loss * mask)

         pg_loss = pg_loss / n_valids
         v_loss = v_loss / n_valids
-        entropy_loss = entropy_loss / n_valids
+        ent_loss = ent_loss / n_valids

-        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
-        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
+        loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
+        return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))

     def single_device_update(
         agent_state: TrainState,
@@ -618,6 +645,7 @@ if __name__ == "__main__":
         sharded_init_rstate1: List,
         sharded_init_rstate2: List,
         sharded_next_inputs: List,
+        sharded_next_done: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
         learn_opponent: bool = False,
@@ -627,20 +655,13 @@ if __name__ == "__main__":
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
         ]
-        next_main, = [
-            jnp.concatenate(x) for x in [sharded_next_main]
+        next_main, next_done = [
+            jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]
         ]

         # reorder storage of individual players
         # main first, opponent second
         num_steps, num_envs = storage.rewards.shape
-        T = jnp.arange(num_steps, dtype=jnp.int32)
-        B = jnp.arange(num_envs, dtype=jnp.int32)
-        mains = storage.mains.astype(jnp.int32)
-        indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
-        switch_steps = jnp.sum(mains, axis=0)
-        switch = T[:, None] == (switch_steps[None, :] - 1)
-        storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

         ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
@@ -650,9 +671,7 @@ if __name__ == "__main__":
         next_value = create_agent(args).apply(
             agent_state.params, next_inputs)[2].squeeze(-1)
-        sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
-        next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
+        # TODO: check if this is correct
+        next_value = jnp.where(next_main, next_value, -next_value)

         def convert_data(x: jnp.ndarray, num_steps):
             if args.update_epochs > 1:
@@ -666,10 +685,11 @@ if __name__ == "__main__":
                 x = jnp.reshape(x, (N, -1) + x.shape[1:])
             return x

-        shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
-            partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
-        shuffled_storage, shuffled_switch = jax.tree.map(
-            partial(convert_data, num_steps=num_steps), (storage, switch))
+        shuffled_init_rstate1, shuffled_init_rstate2, \
+            shuffled_next_value, shuffled_next_done = jax.tree.map(
+            partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
+        shuffled_storage = jax.tree.map(
+            partial(convert_data, num_steps=num_steps), storage)
         shuffled_mask = jnp.ones_like(shuffled_storage.mains)

         def update_minibatch(agent_state, minibatch):
@@ -687,13 +707,13 @@ if __name__ == "__main__":
                     shuffled_init_rstate2,
                     shuffled_storage.obs,
                     shuffled_storage.dones,
-                    shuffled_storage.next_dones,
-                    shuffled_switch,
+                    shuffled_storage.mains,
                     shuffled_storage.actions,
                     shuffled_storage.logits,
                     shuffled_storage.rewards,
                     shuffled_mask,
                     shuffled_next_value,
+                    shuffled_next_done,
                 ),
             )
             return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
@@ -712,7 +732,7 @@ if __name__ == "__main__":
         single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
-        static_broadcasted_argnums=(7,),
+        static_broadcasted_argnums=(8,),
     )

     params_queues = []
@@ -727,7 +747,9 @@ if __name__ == "__main__":
         for thread_id in range(args.num_actor_threads):
             params_queues.append(queue.Queue(maxsize=1))
             rollout_queues.append(queue.Queue(maxsize=1))
-            params_queues[-1].put(device_params)
+            if eval_params:
+                params_queues[-1].put(
+                    jax.device_put(eval_params, local_devices[d_id]))
             threading.Thread(
                 target=rollout,
                 args=(
@@ -741,6 +763,7 @@ if __name__ == "__main__":
                     d_idx * args.num_actor_threads + thread_id,
                 ),
             ).start()
+            params_queues[-1].put(device_params)

     rollout_queue_get_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)
@@ -790,8 +813,9 @@ if __name__ == "__main__":
             writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
             writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
             print(
-                global_step,
-                f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
+                f"{global_step} actor_update={update}, "
+                f"train_time={time.time() - training_time_start:.2f}, "
+                f"data_time={rollout_queue_get_time[-1]:.2f}"
             )
             writer.add_scalar(
                 "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
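Beyond the hyperparameter and loss changes, the notable addition above is get_action_battle, which pits the current parameters against a fixed evaluation checkpoint inside the same vectorized environments: each environment is routed to one of two parameter sets by the per-env `main` mask, and only the acting player's recurrent state advances. The reduced sketch below mirrors that masking pattern; the linear dummy_policy, the array-valued recurrent state, and all shapes are placeholders (the real agent keeps a pytree of LSTM state and applies the mask with jax.tree.map).

import jax
import jax.numpy as jnp

def dummy_policy(params, rstate, obs):
    # Stand-in for the LSTM agent: params is a single weight matrix and the
    # recurrent state is one array; only the masking pattern matters here.
    logits = obs @ params
    return rstate + 1.0, logits

@jax.jit
def act_battle(params1, params2, rstate1, rstate2, obs, main):
    # Evaluate both policies, then select per environment with the `main` mask,
    # advancing only the acting player's recurrent state (as in get_action_battle).
    next_r1, logits1 = dummy_policy(params1, rstate1, obs)
    next_r2, logits2 = dummy_policy(params2, rstate2, obs)
    logits = jnp.where(main[:, None], logits1, logits2)
    rstate1 = jnp.where(main[:, None], next_r1, rstate1)
    rstate2 = jnp.where(main[:, None], rstate2, next_r2)
    return rstate1, rstate2, logits.argmax(axis=1)

num_envs, obs_dim, num_actions, rnn_dim = 4, 8, 5, 16
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params1 = jax.random.normal(key1, (obs_dim, num_actions))
params2 = jax.random.normal(key2, (obs_dim, num_actions))
rstate1 = jnp.zeros((num_envs, rnn_dim))
rstate2 = jnp.zeros((num_envs, rnn_dim))
obs = jax.random.normal(key3, (num_envs, obs_dim))
main = jnp.array([True, False, True, False])
rstate1, rstate2, actions = act_battle(params1, params2, rstate1, rstate2, obs, main)
print(actions.shape)  # (4,)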
scripts/jax/ppo.py

@@ -22,11 +22,12 @@ from flax.training.train_state import TrainState
 from rich.pretty import pprint
 from tensorboardX import SummaryWriter

-from ygoai.utils import init_ygopro
+from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage
+from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
+from ygoai.rl.jax.switch import truncated_gae_2p0s

 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -77,8 +78,6 @@ class Args:
     """the number of actor threads to use"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
-    collect_length: Optional[int] = None
-    """the number of steps to compute the advantages"""
     anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
@@ -95,8 +94,10 @@ class Args:
     """Toggles advantages normalization"""
     clip_coef: float = 0.25
     """the surrogate clipping coefficient"""
+    dual_clip_coef: Optional[float] = None
+    """the dual surrogate clipping coefficient, typically 3.0"""
     spo_kld_max: Optional[float] = None
-    """the maximum KLD for the SPO policy"""
+    """the maximum KLD for the SPO policy, typically 0.02"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
     vf_coef: float = 0.5
@@ -144,9 +145,10 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
+    freeze_id: bool = False


-def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
+def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
     if not args.thread_affinity:
         thread_affinity_offset = -1
     if thread_affinity_offset >= 0:
@@ -163,7 +165,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
         async_reset=False,
-        greedy_reward=args.greedy_reward,
+        greedy_reward=args.greedy_reward if not eval else True,
         play_mode=mode,
     )
     envs.num_envs = num_envs
@@ -189,6 +191,7 @@ def create_agent(args, multi_step=False):
         param_dtype=jnp.float32,
         lstm_channels=args.rnn_channels,
         multi_step=multi_step,
+        freeze_id=args.freeze_id,
     )
@@ -226,7 +229,7 @@ def rollout(
         args,
         args.seed + jax.process_index() + device_thread_id,
         args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode=eval_mode)
+        args.local_eval_episodes // 4, mode=eval_mode, eval=True)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)
@@ -296,7 +299,6 @@ def rollout(
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(main_player)
-    start_step = 0
     storage = []

     @jax.jit
@@ -327,7 +329,7 @@ def rollout(
         rollout_time_start = time.time()
         init_rstate1, init_rstate2 = jax.tree.map(
             lambda x: x.copy(), (next_rstate1, next_rstate2))
-        for _ in range(start_step, args.collect_length):
+        for _ in range(args.num_steps):
             global_step += args.local_num_envs * n_actors * args.world_size

             cached_next_obs = next_obs
@@ -379,10 +381,8 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)
-        start_step = args.collect_length - args.num_steps

         partitioned_storage = prepare_data(storage)
-        storage = storage[args.num_steps:]
+        storage = []
         sharded_storage = []
         for x in partitioned_storage:
             if isinstance(x, dict):
@@ -418,10 +418,13 @@ def rollout(
             SPS_update = int(args.batch_size / (time.time() - update_time_start))
             if device_thread_id == 0:
                 print(
-                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
+                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
                 )
                 time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
-                print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
+                print(
+                    f"{time_now} SPS: {SPS}, update: {SPS_update}, "
+                    f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}")
             writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
             writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
             writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
@@ -436,14 +439,13 @@ def rollout(
             _start = time.time()
             if eval_mode == 'bot':
                 predict_fn = lambda x: get_action(params, x)
-                eval_stat = evaluate(
-                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
-                metric_name = "eval_return"
+                eval_return, eval_ep_len, eval_win_rate = evaluate(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
             else:
                 predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
-                eval_stat = battle(
-                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
-                metric_name = "eval_win_rate"
+                eval_return, eval_ep_len, eval_win_rate = battle(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+            eval_stat = np.array([eval_return, eval_win_rate])
             if device_thread_id != 0:
                 eval_queue.put(eval_stat)
             else:
@@ -451,12 +453,14 @@ def rollout(
                 eval_stats.append(eval_stat)
                 for _ in range(1, n_actors):
                     eval_stats.append(eval_queue.get())
-                eval_stats = np.mean(eval_stats)
-                writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
+                eval_stats = np.stack(eval_stats)
+                eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
+                writer.add_scalar(f"charts/eval_return", eval_return, global_step)
+                writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
                 if device_thread_id == 0:
                     eval_time = time.time() - _start
-                    print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                     other_time += eval_time
+                    print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")


 if __name__ == "__main__":
@@ -485,8 +489,15 @@ if __name__ == "__main__":
     args.minibatch_size = args.local_minibatch_size * args.world_size
     args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
     args.local_env_threads = args.local_env_threads or args.local_num_envs
-    args.collect_length = args.collect_length or args.num_steps
-    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
+
+    if args.embedding_file:
+        embeddings = load_embeddings(args.embedding_file, args.code_list_file)
+        embedding_shape = embeddings.shape
+        args.num_embeddings = embedding_shape
+        args.freeze_id = True if args.freeze_id is None else args.freeze_id
+    else:
+        embeddings = None
+        embedding_shape = None

     local_devices = jax.local_devices()
     global_devices = jax.devices()
@@ -541,6 +552,13 @@ if __name__ == "__main__":
     rstate = init_rnn_state(1, args.rnn_channels)
     agent = create_agent(args)
     params = agent.init(agent_key, (rstate, sample_obs))
+    if embeddings is not None:
+        unknown_embed = embeddings.mean(axis=0)
+        embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
+        params = flax.core.unfreeze(params)
+        params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
+        params = flax.core.freeze(params)
+
     tx = optax.MultiSteps(
         optax.chain(
             optax.clip_by_global_norm(args.max_grad_norm),
@@ -587,66 +605,53 @@ if __name__ == "__main__":
         num_envs = next_value.shape[0]
         num_steps = dones.shape[0] // num_envs

-        mask = mask & (~dones)
+        mask = mask * (1.0 - dones)
         n_valids = jnp.sum(mask)
         real_dones = dones | next_dones
         inputs = (rstate1, rstate2, obs, real_dones, switch)
         new_logits, new_values = get_logits_and_value(params, inputs)

-        values, rewards, next_dones, switch = jax.tree.map(
-            lambda x: jnp.reshape(x, (num_steps, num_envs)),
-            (jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
+        new_values_, rewards, next_dones, switch = jax.tree.map(
+            lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
+            (new_values, rewards, next_dones, switch),
         )

-        advantages, target_values = compute_gae_2p0s(
-            next_value, values, rewards, next_dones, switch,
-            args.gamma, args.gae_lambda)
-        if args.upgo:
-            advantages = advantages + upgo_advantage(
-                next_value, values, rewards, next_dones, switch, args.gamma)
-        advantages, target_values = jax.tree.map(
-            lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
-
-        ratio = distrax.importance_sampling_ratios(distrax.Categorical(
+        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
-        logratio = jnp.log(ratio)
-        approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
+
+        target_values, advantages = truncated_gae_2p0s(
+            next_value, new_values_, rewards, next_dones, switch,
+            args.gamma, args.gae_lambda, args.upgo)
+        target_values, advantages = jax.tree.map(
+            lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
+
+        logratio = jnp.log(ratios)
+        approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

         if args.norm_adv:
             advantages = masked_normalize(advantages, mask, eps=1e-8)

         # Policy loss
         if args.spo_kld_max is not None:
-            probs = jax.nn.softmax(logits)
-            new_probs = jax.nn.softmax(new_logits)
-            eps = 1e-8
-            kld = jnp.sum(probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
-            kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
-            d_ratio = kld_clip / (kld + eps)
-            d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
-            sign_a = jnp.sign(advantages)
-            result = (d_ratio + sign_a - 1) * sign_a
-            pg_loss = -advantages * ratio * result
+            pg_loss = simple_policy_loss(
+                ratios, logits, new_logits, advantages, args.spo_kld_max)
         else:
-            pg_loss1 = -advantages * ratio
-            pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
-            pg_loss = jnp.maximum(pg_loss1, pg_loss2)
+            pg_loss = clipped_surrogate_pg_loss(
+                ratios, advantages, args.clip_coef, args.dual_clip_coef)
         pg_loss = jnp.sum(pg_loss * mask)

-        v_loss = 0.5 * ((new_values - target_values) ** 2)
+        v_loss = mse_loss(new_values, target_values)
         v_loss = jnp.sum(v_loss * mask)

-        entropy_loss = distrax.Softmax(new_logits).entropy()
-        entropy_loss = jnp.sum(entropy_loss * mask)
+        ent_loss = entropy_loss(new_logits)
+        ent_loss = jnp.sum(ent_loss * mask)

         pg_loss = pg_loss / n_valids
         v_loss = v_loss / n_valids
-        entropy_loss = entropy_loss / n_valids
+        ent_loss = ent_loss / n_valids

-        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
-        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
+        loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
+        return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))

     def single_device_update(
         agent_state: TrainState,
@@ -702,7 +707,8 @@ if __name__ == "__main__":
             x = jnp.reshape(x, (N, -1) + x.shape[1:])
             return x

-        shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
+        shuffled_init_rstate1, shuffled_init_rstate2, \
+            shuffled_next_value = jax.tree.map(
             partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
         shuffled_storage, shuffled_switch = jax.tree.map(
             partial(convert_data, num_steps=num_steps), (storage, switch))
@@ -829,8 +835,9 @@ if __name__ == "__main__":
             writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
             writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
             print(
-                global_step,
-                f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
+                f"{global_step} actor_update={update}, "
+                f"train_time={time.time() - training_time_start:.2f}, "
+                f"data_time={rollout_queue_get_time[-1]:.2f}"
             )
             writer.add_scalar(
                 "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
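Two of the helpers now imported from ygoai.rl.jax deserve a note. clipped_surrogate_pg_loss is called as clipped_surrogate_pg_loss(ratios, advantages, args.clip_coef, args.dual_clip_coef), so the dual_clip_coef argument presumably enables dual clipping in the style of Ye et al. (2020), which bounds the per-sample loss when the advantage is negative. The sketch below only illustrates that idea under those assumed semantics; it is not the repository's implementation, which may differ in signature, masking, and reduction.

import jax.numpy as jnp

def dual_clipped_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None):
    # Standard PPO clipped surrogate, elementwise; masking/reduction left to the caller.
    pg_loss1 = -advantages * ratios
    pg_loss2 = -advantages * jnp.clip(ratios, 1.0 - clip_coef, 1.0 + clip_coef)
    pg_loss = jnp.maximum(pg_loss1, pg_loss2)
    if dual_clip_coef is not None:
        # Dual clip: when the advantage is negative, cap the loss at -c * A so a
        # very large ratio on a bad action cannot dominate the update.
        pg_loss = jnp.where(
            advantages < 0,
            jnp.minimum(pg_loss, -dual_clip_coef * advantages),
            pg_loss)
    return pg_loss

ratios = jnp.array([0.5, 1.0, 3.0, 5.0])
advantages = jnp.array([1.0, -1.0, 1.0, -1.0])
print(dual_clipped_pg_loss(ratios, advantages, clip_coef=0.25, dual_clip_coef=3.0))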
ygoai/rl/jax/__init__.py
View file @
43ca871e
...
@@ -2,340 +2,273 @@ from functools import partial
...
@@ -2,340 +2,273 @@ from functools import partial
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
typing
import
NamedTuple
class
VTraceOutput
(
NamedTuple
):
q_estimate
:
jnp
.
ndarray
errors
:
jnp
.
ndarray
def
vtrace
(
v_tm1
,
v_t
,
r_t
,
discount_t
,
rho_tm1
,
lambda_
=
1.0
,
c_clip_min
:
float
=
0.001
,
c_clip_max
:
float
=
1.007
,
rho_clip_min
:
float
=
0.001
,
rho_clip_max
:
float
=
1.007
,
stop_target_gradients
:
bool
=
True
,
):
"""
Args:
v_tm1: values at time t-1.
v_t: values at time t.
r_t: reward at time t.
discount_t: discount at time t.
rho_tm1: importance sampling ratios at time t-1.
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
"""
# Clip importance sampling ratios.
lambda_
=
jnp
.
ones_like
(
discount_t
)
*
lambda_
c_tm1
=
jnp
.
clip
(
rho_tm1
,
c_clip_min
,
c_clip_max
)
*
lambda_
clipped_rhos_tm1
=
jnp
.
clip
(
rho_tm1
,
rho_clip_min
,
rho_clip_max
)
# Compute the temporal difference errors.
td_errors
=
clipped_rhos_tm1
*
(
r_t
+
discount_t
*
v_t
-
v_tm1
)
# Work backwards computing the td-errors.
def
_body
(
acc
,
xs
):
td_error
,
discount
,
c
=
xs
acc
=
td_error
+
discount
*
c
*
acc
return
acc
,
acc
_
,
errors
=
jax
.
lax
.
scan
(
_body
,
0.0
,
(
td_errors
,
discount_t
,
c_tm1
),
reverse
=
True
)
# Return errors, maybe disabling gradient flow through bootstrap targets.
errors
=
jax
.
lax
.
select
(
stop_target_gradients
,
jax
.
lax
.
stop_gradient
(
errors
+
v_tm1
)
-
v_tm1
,
errors
)
targets_tm1
=
errors
+
v_tm1
q_bootstrap
=
jnp
.
concatenate
([
lambda_
[:
-
1
]
*
targets_tm1
[
1
:]
+
(
1
-
lambda_
[:
-
1
])
*
v_tm1
[
1
:],
v_t
[
-
1
:],
],
axis
=
0
)
q_estimate
=
r_t
+
discount_t
*
q_bootstrap
return
VTraceOutput
(
q_estimate
=
q_estimate
,
errors
=
errors
)
def
clipped_surrogate_pg_loss
(
prob_ratios_t
,
adv_t
,
mask
,
epsilon
,
use_stop_gradient
=
True
):
adv_t
=
jax
.
lax
.
select
(
use_stop_gradient
,
jax
.
lax
.
stop_gradient
(
adv_t
),
adv_t
)
clipped_ratios_t
=
jnp
.
clip
(
prob_ratios_t
,
1.
-
epsilon
,
1.
+
epsilon
)
clipped_objective
=
jnp
.
fmin
(
prob_ratios_t
*
adv_t
,
clipped_ratios_t
*
adv_t
)
return
-
jnp
.
mean
(
clipped_objective
*
mask
)
@
partial
(
jax
.
jit
,
static_argnums
=
(
5
,
6
))
def
compute_gae_2p0s
(
next_value
,
values
,
rewards
,
next_dones
,
switch
,
gamma
,
gae_lambda
,
):
def
body_fn
(
carry
,
inp
):
boot_value
,
boot_done
,
next_value
,
lastgaelam
=
carry
next_done
,
cur_value
,
reward
,
switch
=
inp
next_done
=
jnp
.
where
(
switch
,
boot_done
,
next_done
)
import
chex
next_value
=
jnp
.
where
(
switch
,
-
boot_value
,
next_value
)
import
distrax
lastgaelam
=
jnp
.
where
(
switch
,
0
,
lastgaelam
)
gamma_
=
gamma
*
(
1.0
-
next_done
)
delta
=
reward
+
gamma_
*
next_value
-
cur_value
lastgaelam
=
delta
+
gae_lambda
*
gamma_
*
lastgaelam
return
(
boot_value
,
boot_done
,
cur_value
,
lastgaelam
),
lastgaelam
next_done
=
next_dones
[
-
1
]
# class VTraceOutput(NamedTuple):
lastgaelam
=
jnp
.
zeros_like
(
next_value
)
# q_estimate: jnp.ndarray
carry
=
next_value
,
next_done
,
next_value
,
lastgaelam
# errors: jnp.ndarray
_
,
advantages
=
jax
.
lax
.
scan
(
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
)
target_values
=
advantages
+
values
return
advantages
,
target_values
@partial(jax.jit, static_argnums=(5,))
def upgo_advantage(
    next_value, values, rewards, next_dones, switch, gamma):
    def body_fn(carry, inp):
        boot_value, boot_done, next_value, next_q, last_return = carry
        next_done, cur_value, reward, switch = inp

        next_done = jnp.where(switch, boot_done, next_done)
        next_value = jnp.where(switch, -boot_value, next_value)
        next_q = jnp.where(switch, -boot_value * gamma, next_q)
        last_return = jnp.where(switch, -boot_value, last_return)

        gamma_ = gamma * (1.0 - next_done)
        last_return = reward + gamma_ * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + gamma_ * next_value
        carry = boot_value, boot_done, cur_value, next_q, last_return
        return carry, last_return

    next_done = next_dones[-1]
    carry = next_value, next_done, next_value, next_value, next_value

    _, returns = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards, switch), reverse=True
    )
    return returns - values
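As a follow-up sketch (also not from the commit), the UPGO term can be stacked on top of the GAE advantages, mirroring the `advantages += returns - values` pattern used by truncated_gae_2p0s later in this diff; it reuses the placeholder tensors from the previous sketch.

upgo_adv = upgo_advantage(next_value, values, rewards, next_dones, switch, 0.997)
mixed_advantages = advantages + upgo_adv  # one possible way to combine GAE and UPGO signals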
# def compute_gae_once(carry, inp, gamma, gae_lambda):
# v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2 = carry
# rho, cur_values, log_ratio, next_done, r_t, corr_r_t, main = inp
# v = jnp.where(main, v1, v2)
# next_values = jnp.where(main, next_values1, next_values2)
# reward = jnp.where(main, reward1, reward2)
# xi = jnp.where(main, xi1, xi2)
# p_t = c_t = jnp.minimum(1.0, rho * xi)
# sig_v = p_t * (r_t + reward * rho + next_values - cur_values)
# reg_r = jnp.log(p / p_reg)
# q = r_t + rho * (reward + v)
# q = -eta * + cur_values
# v = cur_values + sig_v + c_t * (v - next_values)
# v1 = jnp.where(main, v, v1)
# v2 = jnp.where(main, v2, v)
# next_values1 = jnp.where(main, cur_values, next_values1)
# next_values2 = jnp.where(main, next_values2, cur_values)
# reward1 = jnp.where(main, 0, r_t + rho * reward1)
# reward2 = jnp.where(main, r_t + rho * reward2, 0)
# xi1 = jnp.where(main, 1, rho * xi1)
# xi2 = jnp.where(main, rho * xi2, 1)
# learn1 = learn
# learn2 = ~learn
# factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
# reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
# reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
# real_done1 = next_done | ~done_used1
# nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
# lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
# real_done2 = next_done | ~done_used2
# nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
# lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
# done_used1 = jnp.where(
# next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
# done_used2 = jnp.where(
# next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
# delta1 = reward1 + gamma * nextvalues1 - curvalues
# delta2 = reward2 + gamma * nextvalues2 - curvalues
# lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
# lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
# advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
# nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
# nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
# lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
# lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
# carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# return carry, advantages
# @partial(jax.jit, static_argnums=(6, 7))
# def vtrace_rnad(
# next_value, next_done, values, rewards, dones, learns,
# gamma, gae_lambda,
# ):
# next_value1 = next_value
# next_value2 = -next_value1
# done_used1 = jnp.ones_like(next_done)
# done_used2 = jnp.ones_like(next_done)
# reward1 = jnp.zeros_like(next_value)
# reward2 = jnp.zeros_like(next_value)
# lastgaelam1 = jnp.zeros_like(next_value)
# lastgaelam2 = jnp.zeros_like(next_value)
# carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
# _, advantages = jax.lax.scan(
# partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
# carry, (dones[1:], values, rewards, learns), reverse=True
# )
# target_values = advantages + values
# return advantages, target_values
def compute_gae_once(carry, inp, gamma, gae_lambda):
    nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
    next_done, curvalues, reward, learn = inp

    learn1 = learn
    learn2 = ~learn
    factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
    real_done1 = next_done | ~done_used1
    nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    done_used1 = jnp.where(
        next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))

    delta1 = reward1 + gamma * nextvalues1 - curvalues
    delta2 = reward2 + gamma * nextvalues2 - curvalues
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
    nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
    nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
    lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)

    carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    return carry, advantages


@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
    next_value, next_done, next_learn,
    values, rewards, dones, learns,
    gamma, gae_lambda,
):
    next_value1 = jnp.where(next_learn, next_value, -next_value)
    next_value2 = -next_value1
    done_used1 = jnp.ones_like(next_done)
    done_used2 = jnp.ones_like(next_done)
    reward1 = jnp.zeros_like(next_value)
    reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = jnp.zeros_like(next_value)
    lastgaelam2 = jnp.zeros_like(next_value)
    carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2

    dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    _, advantages = jax.lax.scan(
        partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
        carry, (dones[1:], values, rewards, learns), reverse=True
    )
    target_values = advantages + values
    return advantages, target_values
def entropy_loss(logits):
    return distrax.Softmax(logits=logits).entropy()


def mse_loss(y_true, y_pred):
    return 0.5 * ((y_true - y_pred) ** 2)


def policy_gradient_loss(logits, actions, advantages):
    chex.assert_type([logits, actions, advantages], [float, int, float])
    advs = jax.lax.stop_gradient(advantages)
    log_probs = distrax.Softmax(logits=logits).log_prob(actions)
    pg_loss = -log_probs * advs
    return pg_loss


def clipped_surrogate_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None):
    # dual clip from JueWu (Mastering Complex Control in MOBA Games with Deep Reinforcement Learning)
    advs = jax.lax.stop_gradient(advantages)
    clipped_ratios = jnp.clip(ratios, 1 - clip_coef, 1 + clip_coef)
    clipped_obj = jnp.fmin(ratios * advs, clipped_ratios * advs)
    if dual_clip_coef is not None:
        clipped_obj = jnp.where(
            advs >= 0,
            clipped_obj,
            jnp.fmax(clipped_obj, dual_clip_coef * advs)
        )
    pg_loss = -clipped_obj
    return pg_loss
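A hedged example of the dual-clip surrogate loss above (values invented for illustration): the function returns per-element losses, so the caller masks and reduces them as needed.

ratios = jnp.asarray([1.3, 0.7, 1.05])   # pi_new / pi_old for the taken actions
advs = jnp.asarray([0.5, -2.0, 1.0])
per_step = clipped_surrogate_pg_loss(ratios, advs, clip_coef=0.2, dual_clip_coef=3.0)
loss = per_step.mean()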
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
    v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
        last_return1, last_return2, next_q1, next_q2 = carry
    ratio, cur_values, next_done, r_t, main = inp

    v1 = jnp.where(next_done, 0, v1)
    v2 = jnp.where(next_done, 0, v2)
    next_values1 = jnp.where(next_done, 0, next_values1)
    next_values2 = jnp.where(next_done, 0, next_values2)
    reward1 = jnp.where(next_done, 0, reward1)
    reward2 = jnp.where(next_done, 0, reward2)
    xi1 = jnp.where(next_done, 1, xi1)
    xi2 = jnp.where(next_done, 1, xi2)

    discount = gamma * (1.0 - next_done)
    v = jnp.where(main, v1, v2)
    next_values = jnp.where(main, next_values1, next_values2)
    reward = jnp.where(main, reward1, reward2)
    xi = jnp.where(main, xi1, xi2)

    q_t = r_t + ratio * reward + discount * v
    rho_t = jnp.clip(ratio * xi, rho_min, rho_max)
    c_t = jnp.clip(ratio * xi, c_min, c_max)
    sig_v = rho_t * (r_t + ratio * reward + discount * next_values - cur_values)
    v = cur_values + sig_v + c_t * discount * (v - next_values)

    # UPGO advantage (not corrected by importance sampling, unlike V-trace)
    return_t = jnp.where(main, last_return1, last_return2)
    next_q = jnp.where(main, next_q1, next_q2)
    factor = jnp.where(main, jnp.ones_like(r_t), -jnp.ones_like(r_t))
    return_t = r_t + discount * jnp.where(
        next_q >= next_values, return_t, next_values)
    last_return1 = jnp.where(
        next_done, r_t * factor, jnp.where(main, return_t, last_return1))
    last_return2 = jnp.where(
        next_done, r_t * -factor, jnp.where(main, last_return2, return_t))
    next_q = r_t + discount * next_values
    next_q1 = jnp.where(
        next_done, r_t * factor, jnp.where(main, next_q, next_q1))
    next_q2 = jnp.where(
        next_done, r_t * -factor, jnp.where(main, next_q2, next_q))

    v1 = jnp.where(main, v, v1)
    v2 = jnp.where(main, v2, v)
    next_values1 = jnp.where(main, cur_values, next_values1)
    next_values2 = jnp.where(main, next_values2, cur_values)
    reward1 = jnp.where(main, 0, -r_t + ratio * reward1)
    reward2 = jnp.where(main, -r_t + ratio * reward2, 0)
    xi1 = jnp.where(main, 1, ratio * xi1)
    xi2 = jnp.where(main, ratio * xi2, 1)

    carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
        last_return1, last_return2, next_q1, next_q2
    return carry, (v, q_t, return_t)
def vtrace_2p0s(
    next_value, ratios, values, rewards, next_dones, mains, gamma,
    rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
):
    next_value1 = next_value
    next_value2 = -next_value1
    v1 = return1 = next_q1 = next_value1
    v2 = return2 = next_q2 = next_value2
    reward1 = reward2 = jnp.zeros_like(next_value)
    xi1 = xi2 = jnp.ones_like(next_value)
    carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
        return1, return2, next_q1, next_q2

    _, (targets, q_estimate, return_t) = jax.lax.scan(
        partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max,
                c_min=c_min, c_max=c_max),
        carry, (ratios, values, next_dones, rewards, mains), reverse=True
    )
    advantages = q_estimate - values
    if upgo:
        advantages += return_t - values
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
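A usage sketch for vtrace_2p0s (assumed shapes only, reusing the placeholder tensors from the earlier GAE sketch): `ratios` are new-over-old action probabilities and `mains` marks the steps taken by the learning player.

ratios = jnp.ones((T, N), dtype=jnp.float32)   # on-policy data would give ratios of 1
mains = jnp.ones((T, N), dtype=jnp.bool_)
targets, vtrace_advantages = vtrace_2p0s(
    next_value, ratios, values, rewards, next_dones, mains, 0.997, upgo=True)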
def compute_gae_once_upgo(carry, inp, gamma, gae_lambda):
    next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
    next_done, curvalues, reward, learn = inp

    learn1 = learn
    learn2 = ~learn
    factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
    real_done1 = next_done | ~done_used1
    next_value1 = jnp.where(real_done1, 0, next_value1)
    last_return1 = jnp.where(real_done1, 0, last_return1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    next_value2 = jnp.where(real_done2, 0, next_value2)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    done_used1 = jnp.where(
        next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))

    last_return1_ = reward1 + gamma * jnp.where(
        next_q1 >= next_value1, last_return1, next_value1)
    last_return2_ = reward2 + gamma * jnp.where(
        next_q2 >= next_value2, last_return2, next_value2)
    next_q1_ = reward1 + gamma * next_value1
    next_q2_ = reward2 + gamma * next_value2
    delta1 = next_q1_ - curvalues
    delta2 = next_q2_ - curvalues
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    returns = jnp.where(learn1, last_return1_, last_return2_)
    advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
    next_value1 = jnp.where(learn1, curvalues, next_value1)
    next_value2 = jnp.where(learn2, curvalues, next_value2)
    lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
    next_q1 = jnp.where(learn1, next_q1_, next_q1)
    next_q2 = jnp.where(learn2, next_q2_, next_q1)
    last_return1 = jnp.where(learn1, last_return1_, last_return1)
    last_return2 = jnp.where(learn2, last_return2_, last_return2)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    return carry, (advantages, returns)


def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
    lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
    cur_value, next_done, reward, main = inp

    main1 = main
    main2 = ~main
    factor = jnp.where(main1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(main1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(main2 & done_used2, 0, reward2))
    real_done1 = next_done | ~done_used1
    next_value1 = jnp.where(real_done1, 0, next_value1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    next_value2 = jnp.where(real_done2, 0, next_value2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    done_used1 = jnp.where(
        next_done, main1, jnp.where(main1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, main2, jnp.where(main2 & ~done_used2, True, done_used2))

    # UPGO advantage
    last_return1 = jnp.where(real_done1, 0, last_return1)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    last_return1_ = reward1 + gamma * jnp.where(
        next_q1 >= next_value1, last_return1, next_value1)
    last_return2_ = reward2 + gamma * jnp.where(
        next_q2 >= next_value2, last_return2, next_value2)
    next_q1_ = reward1 + gamma * next_value1
    next_q2_ = reward2 + gamma * next_value2
    next_q1 = jnp.where(main1, next_q1_, next_q1)
    next_q2 = jnp.where(main2, next_q2_, next_q1)
    last_return1 = jnp.where(main1, last_return1_, last_return1)
    last_return2 = jnp.where(main2, last_return2_, last_return2)
    returns = jnp.where(main1, last_return1_, last_return2_)

    delta1 = next_q1_ - cur_value
    delta2 = next_q2_ - cur_value
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    advantages = jnp.where(main1, lastgaelam1_, lastgaelam2_)
    next_value1 = jnp.where(main1, cur_value, next_value1)
    next_value2 = jnp.where(main2, cur_value, next_value2)
    lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)
    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
    return carry, (advantages, returns)
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
    next_value, next_done, next_learn,
    values, rewards, dones, learns,
    gamma, gae_lambda,
):
    next_value1 = jnp.where(next_learn, next_value, -next_value)
    next_value2 = -next_value1
    last_return1 = next_q1 = next_value1
    last_return2 = next_q2 = next_value2
    done_used1 = jnp.ones_like(next_done)
    done_used2 = jnp.ones_like(next_done)
    reward1 = jnp.zeros_like(next_value)
    reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = jnp.zeros_like(next_value)
    lastgaelam2 = jnp.zeros_like(next_value)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2

    dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    _, (advantages, returns) = jax.lax.scan(
        partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda),
        carry, (dones[1:], values, rewards, learns), reverse=True
    )
    return returns - values, advantages + values


def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, mains,
    gamma, gae_lambda, upgo,
):
    next_value1 = next_value
    next_value2 = -next_value1
    last_return1 = next_q1 = next_value1
    last_return2 = next_q2 = next_value2
    done_used1 = jnp.ones_like(next_dones[-1])
    done_used2 = jnp.ones_like(next_dones[-1])
    reward1 = reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2

    _, (advantages, returns) = jax.lax.scan(
        partial(truncated_gae_upgo_loop, gamma=gamma, gae_lambda=gae_lambda),
        carry, (values, next_dones, rewards, mains), reverse=True
    )
    if upgo:
        advantages += returns - values
    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
def simple_policy_loss(ratios, logits, new_logits, advantages, kld_max, eps=1e-12):
    advs = jax.lax.stop_gradient(advantages)
    probs = jax.nn.softmax(logits)
    new_probs = jax.nn.softmax(new_logits)
    kld = jnp.sum(probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
    kld_clip = jnp.clip(kld, 0, kld_max)
    d_ratio = kld_clip / (kld + eps)
    # e == 1 and t == 1
    d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
    sign_a = jnp.sign(advs)
    result = (d_ratio + sign_a - 1) * sign_a
    pg_loss = -advs * ratios * result
    return pg_loss
\ No newline at end of file
ygoai/rl/jax/agent2.py
View file @ 43ca871e

@@ -150,6 +150,7 @@ class Encoder(nn.Module):
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
+    freeze_id: bool = False

     @nn.compact
     def __call__(self, x):
@@ -168,6 +169,8 @@ class Encoder(nn.Module):
         fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

         id_embed = embed(n_embed, embed_dim)
+        if self.freeze_id:
+            id_embed = lambda x: jax.lax.stop_gradient(id_embed(x))
         action_encoder = ActionEncoder(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
@@ -337,6 +340,7 @@ class PPOLSTMAgent(nn.Module):
     param_dtype: jnp.dtype = jnp.float32
     multi_step: bool = False
     switch: bool = True
+    freeze_id: bool = False

     @nn.compact
     def __call__(self, inputs):
@@ -355,6 +359,7 @@ class PPOLSTMAgent(nn.Module):
             embedding_shape=self.embedding_shape,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
+            freeze_id=self.freeze_id,
         )
         f_actions, f_state, mask, valid = encoder(x)
ygoai/rl/jax/switch.py 0 → 100644
View file @ 43ca871e

import jax
import jax.numpy as jnp


def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo):

    def body_fn(carry, inp):
        boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry
        next_done, cur_value, reward, switch = inp

        next_done = jnp.where(switch, boot_done, next_done)
        next_value = jnp.where(switch, -boot_value, next_value)
        lastgaelam = jnp.where(switch, 0, lastgaelam)
        next_q = jnp.where(switch, -boot_value * gamma, next_q)
        last_return = jnp.where(switch, -boot_value, last_return)

        discount = gamma * (1.0 - next_done)
        last_return = reward + discount * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + discount * next_value
        delta = next_q - cur_value
        lastgaelam = delta + gae_lambda * discount * lastgaelam
        carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return
        return carry, (lastgaelam, last_return)

    next_done = next_dones[-1]
    lastgaelam = jnp.zeros_like(next_value)
    next_q = last_return = next_value
    carry = next_value, next_done, next_value, lastgaelam, next_q, last_return

    _, (advantages, returns) = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards, switch), reverse=True
    )
    if upgo:
        advantages += returns - values
    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
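A self-contained sketch (not from the commit) of how this new helper might be driven from a training script; the shapes, hyper-parameters, and random rollout below are invented for illustration.

import numpy as np
import jax.numpy as jnp
from ygoai.rl.jax.switch import truncated_gae_2p0s

T, N = 128, 16  # hypothetical rollout length and number of parallel envs
rng = np.random.default_rng(0)
values = jnp.asarray(rng.normal(size=(T, N)).astype(np.float32))
rewards = jnp.zeros((T, N), dtype=jnp.float32)
next_dones = jnp.zeros((T, N), dtype=jnp.bool_)
switch = jnp.asarray(rng.integers(0, 2, size=(T, N)).astype(bool))  # acting player flips here
next_value = jnp.zeros((N,), dtype=jnp.float32)

targets, advantages = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch,
    gamma=0.997, gae_lambda=0.95, upgo=True)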
ygoai/rl/utils.py
View file @ 43ca871e

@@ -58,13 +58,3 @@ def masked_normalize(x, valid, eps=1e-8):

 def to_tensor(x, device, dtype=None):
     return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)

-def load_embeddings(embedding_file, code_list_file):
-    with open(embedding_file, "rb") as f:
-        embeddings = pickle.load(f)
-    with open(code_list_file, "r") as f:
-        code_list = f.readlines()
-        code_list = [int(code.strip()) for code in code_list]
-    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
-    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
-    return embeddings
ygoai/utils.py
View file @ 43ca871e

+import pickle
+import numpy as np
 from pathlib import Path

@@ -43,4 +45,15 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
     elif 'EDOPro' in env_id:
         from ygoenv.edopro import init_module
         init_module(str(db_path), code_list_file, decks)
     return deck_name
\ No newline at end of file
+
+def load_embeddings(embedding_file, code_list_file):
+    with open(embedding_file, "rb") as f:
+        embeddings = pickle.load(f)
+    with open(code_list_file, "r") as f:
+        code_list = f.readlines()
+        code_list = [int(code.strip()) for code in code_list]
+    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
+    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
+    return embeddings
\ No newline at end of file
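For illustration only (not part of the commit): after this change `load_embeddings` is imported from `ygoai.utils`; the file names below are hypothetical placeholders.

from ygoai.utils import load_embeddings

embeddings = load_embeddings("card_embeddings.pkl", "code_list.txt")  # hypothetical paths
print(embeddings.shape)  # (num_cards, embedding_dim), rows ordered as in code_list.txt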