Biluo Shen / ygo-agent / Commits

Commit 93bc3723, authored Apr 15, 2024 by sbl1996@126.com (parent: 0ecf0a00)

    Add option for greedy_reward and correct upgo

Showing 4 changed files with 149 additions and 76 deletions (+149 -76)
Files changed:
  scripts/jax/ppo.py               +9   -4
  ygoai/rl/jax/__init__.py         +88  -36
  ygoai/rl/jax/agent2.py           +25  -14
  ygoenv/ygoenv/ygopro/ygopro.h    +27  -22
scripts/jax/ppo.py (+9 -4)

```diff
@@ -26,7 +26,7 @@ from ygoai.utils import init_ygopro
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
+from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage
 
 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -62,6 +62,8 @@ class Args:
     """the maximum number of options"""
     n_history_actions: int = 32
     """the number of history actions to use"""
+    greedy_reward: bool = True
+    """whether to use greedy reward (faster kill higher reward)"""
 
     total_timesteps: int = 5000000000
     """total timesteps of the experiments"""
@@ -117,7 +119,7 @@ class Args:
     """whether to use `jax.distirbuted`"""
     concurrency: bool = True
     """whether to run the actor and learner concurrently"""
-    bfloat16: bool = True
+    bfloat16: bool = False
     """whether to use bfloat16 for the agent"""
     thread_affinity: bool = False
     """whether to use thread affinity for the environment"""
@@ -161,6 +163,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
         async_reset=False,
+        greedy_reward=args.greedy_reward,
         play_mode=mode,
     )
     envs.num_envs = num_envs
@@ -596,10 +599,12 @@ if __name__ == "__main__":
             (jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
         )
 
-        compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
-        advantages, target_values = compute_gae_fn(
+        advantages, target_values = compute_gae_2p0s(
             next_value, values, rewards, next_dones, switch,
             args.gamma, args.gae_lambda)
+        if args.upgo:
+            advantages = advantages + upgo_advantage(
+                next_value, values, rewards, next_dones, switch, args.gamma)
         advantages, target_values = jax.tree.map(
             lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
```
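To make the new control flow above easier to follow, here is a minimal sketch of the estimator call path with dummy rollout tensors. The function names and argument order come from this commit; the time-major `[num_steps, num_envs]` shapes and the toy values are assumptions for illustration only.

```python
# Sketch only: mirrors the updated ppo.py path, with made-up rollout data.
import jax
import jax.numpy as jnp
from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage

num_steps, num_envs = 128, 8
gamma, gae_lambda, use_upgo = 1.0, 0.95, True

values = jax.random.normal(jax.random.PRNGKey(0), (num_steps, num_envs))
rewards = jnp.zeros((num_steps, num_envs)).at[-1].set(1.0)   # sparse terminal reward
next_dones = jnp.zeros((num_steps, num_envs)).at[-1].set(1.0)
switch = jnp.zeros((num_steps, num_envs), dtype=jnp.bool_)   # True where the acting player flips
next_value = jnp.zeros((num_envs,))                          # bootstrap value after the rollout

# GAE remains the single source of the value targets ...
advantages, target_values = compute_gae_2p0s(
    next_value, values, rewards, next_dones, switch, gamma, gae_lambda)
# ... and UPGO is applied as an additive advantage correction when enabled.
if use_upgo:
    advantages = advantages + upgo_advantage(
        next_value, values, rewards, next_dones, switch, gamma)
```

Compared with the removed `compute_gae_upgo_2p0s`, the UPGO term now adds to the GAE advantages instead of replacing them, while the value targets stay GAE-based.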
ygoai/rl/jax/__init__.py (+88 -36)

```diff
@@ -67,27 +67,6 @@ def vtrace(
     return VTraceOutput(q_estimate=q_estimate, errors=errors)
 
 
-def upgo_return(r_t, v_t, discount_t, stop_target_gradients: bool = True):
-    def _body(acc, xs):
-        r, v, q, discount = xs
-        acc = r + discount * jnp.where(q >= v, acc, v)
-        return acc, acc
-
-    # TODO: following alphastar, estimate q_t with one-step target
-    # It might be better to use network to estimate q_t
-    q_t = r_t[1:] + discount_t[1:] * v_t[1:]  # q[:-1]
-    _, returns = jax.lax.scan(
-        _body, q_t[-1], (r_t[:-1], v_t[:-1], q_t, discount_t[:-1]), reverse=True)
-    # Following rlax.vtrace_td_error_and_advantage, part of gradient is reserved
-    # Experiments show that where to stop gradient has no impact on the performance
-    returns = jax.lax.select(
-        stop_target_gradients, jax.lax.stop_gradient(returns), returns)
-    returns = jnp.concatenate([returns, q_t[-1:]], axis=0)
-    return returns
-
-
 def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
     adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
     clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
@@ -123,39 +102,112 @@ def compute_gae_2p0s(
     return advantages, target_values
 
 
-@partial(jax.jit, static_argnums=(5, 6))
-def compute_gae_upgo_2p0s(
-    next_value, values, rewards, next_dones, switch,
-    gamma, gae_lambda,
-):
+@partial(jax.jit, static_argnums=(5,))
+def upgo_advantage(
+    next_value, values, rewards, next_dones, switch, gamma):
     def body_fn(carry, inp):
-        boot_value, boot_done, next_value, next_q, last_return, lastgaelam = carry
+        boot_value, boot_done, next_value, next_q, last_return = carry
         next_done, cur_value, reward, switch = inp
         next_done = jnp.where(switch, boot_done, next_done)
         next_value = jnp.where(switch, -boot_value, next_value)
         next_q = jnp.where(switch, -boot_value * gamma, next_q)
         last_return = jnp.where(switch, -boot_value, last_return)
-        lastgaelam = jnp.where(switch, 0, lastgaelam)
         gamma_ = gamma * (1.0 - next_done)
         last_return = reward + gamma_ * jnp.where(
             next_q >= next_value, last_return, next_value)
         next_q = reward + gamma_ * next_value
-        delta = next_q - cur_value
-        lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
-        carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam
-        return carry, (lastgaelam, last_return)
+        carry = boot_value, boot_done, cur_value, next_q, last_return
+        return carry, last_return
 
     next_done = next_dones[-1]
-    lastgaelam = jnp.zeros_like(next_value)
-    carry = next_value, next_done, next_value, next_value, next_value, lastgaelam
-    _, (advantages, returns) = jax.lax.scan(
+    carry = next_value, next_done, next_value, next_value, next_value
+    _, returns = jax.lax.scan(
         body_fn, carry, (next_dones, values, rewards, switch), reverse=True)
-    return returns - values, advantages + values
+    return returns - values
+
+
+# def compute_gae_once(carry, inp, gamma, gae_lambda):
+#     v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2 = carry
+#     rho, cur_values, log_ratio, next_done, r_t, corr_r_t, main = inp
+#     v = jnp.where(main, v1, v2)
+#     next_values = jnp.where(main, next_values1, next_values2)
+#     reward = jnp.where(main, reward1, reward2)
+#     xi = jnp.where(main, xi1, xi2)
+#     p_t = c_t = jnp.minimum(1.0, rho * xi)
+#     sig_v = p_t * (r_t + reward * rho + next_values - cur_values)
+#     reg_r = jnp.log(p / p_reg)
+#     q = r_t + rho * (reward + v)
+#     q = -eta * + cur_values
+#     v = cur_values + sig_v + c_t * (v - next_values)
+#     v1 = jnp.where(main, v, v1)
+#     v2 = jnp.where(main, v2, v)
+#     next_values1 = jnp.where(main, cur_values, next_values1)
+#     next_values2 = jnp.where(main, next_values2, cur_values)
+#     reward1 = jnp.where(main, 0, r_t + rho * reward1)
+#     reward2 = jnp.where(main, r_t + rho * reward2, 0)
+#     xi1 = jnp.where(main, 1, rho * xi1)
+#     xi2 = jnp.where(main, rho * xi2, 1)
+#     learn1 = learn
+#     learn2 = ~learn
+#     factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
+#     reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
+#     reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
+#     real_done1 = next_done | ~done_used1
+#     nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
+#     lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
+#     real_done2 = next_done | ~done_used2
+#     nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
+#     lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
+#     done_used1 = jnp.where(
+#         next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
+#     done_used2 = jnp.where(
+#         next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
+#     delta1 = reward1 + gamma * nextvalues1 - curvalues
+#     delta2 = reward2 + gamma * nextvalues2 - curvalues
+#     lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
+#     lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
+#     advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
+#     nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
+#     nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
+#     lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
+#     lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
+#     carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
+#     return carry, advantages
+
+
+# @partial(jax.jit, static_argnums=(6, 7))
+# def vtrace_rnad(
+#     next_value, next_done, values, rewards, dones, learns,
+#     gamma, gae_lambda,
+# ):
+#     next_value1 = next_value
+#     next_value2 = -next_value1
+#     done_used1 = jnp.ones_like(next_done)
+#     done_used2 = jnp.ones_like(next_done)
+#     reward1 = jnp.zeros_like(next_value)
+#     reward2 = jnp.zeros_like(next_value)
+#     lastgaelam1 = jnp.zeros_like(next_value)
+#     lastgaelam2 = jnp.zeros_like(next_value)
+#     carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
+#     dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
+#     _, advantages = jax.lax.scan(
+#         partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
+#         carry, (dones[1:], values, rewards, learns), reverse=True
+#     )
+#     target_values = advantages + values
+#     return advantages, target_values
 
 
 def compute_gae_once(carry, inp, gamma, gae_lambda):
```
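For readers new to UPGO, the following is a hedged, single-player reduction of the new `upgo_advantage` (a hypothetical helper written for this note, not part of the commit). It drops the `switch`/sign-flip bookkeeping needed for the two-player zero-sum setting and keeps only the UPGO rule: the return keeps following the trajectory while the one-step estimate q(t+1) = r(t+1) + gamma * V(t+2) is at least V(t+1), and bootstraps from V(t+1) otherwise.

```python
# Simplified, single-player version of the UPGO recursion used above.
# Assumption: time-major arrays; next_dones[t] marks termination after step t.
import jax
import jax.numpy as jnp

def upgo_advantage_1p(next_value, values, rewards, next_dones, gamma):
    def body_fn(carry, inp):
        next_value, next_q, last_return = carry        # V(t+1), q(t+1), G(t+1)
        next_done, cur_value, reward = inp             # d(t), V(t), r(t)
        gamma_ = gamma * (1.0 - next_done)
        # UPGO: follow the return while the one-step estimate beats the value,
        # otherwise fall back to bootstrapping from the value.
        last_return = reward + gamma_ * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + gamma_ * next_value          # becomes q(t) for the next step
        return (cur_value, next_q, last_return), last_return

    carry = (next_value, next_value, next_value)       # bootstrap everything from the final value
    _, returns = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards), reverse=True)
    return returns - values                            # UPGO returns minus the value baseline
```

The two-player version above does the same thing, except that wherever `switch` marks a hand-over between players it restarts the recursion from the negated bootstrap value, since the two players' returns have opposite signs in the zero-sum setting.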
ygoai/rl/jax/agent2.py (+25 -14)

```diff
@@ -336,16 +336,17 @@ class PPOLSTMAgent(nn.Module):
     dtype: jnp.dtype = jnp.float32
     param_dtype: jnp.dtype = jnp.float32
     multi_step: bool = False
+    switch: bool = True
 
     @nn.compact
     def __call__(self, inputs):
         if self.multi_step:
             # (num_steps * batch_size, ...)
-            carry1, carry2, x, done, switch = inputs
-            batch_size = carry1[0].shape[0]
+            rstate1, rstate2, x, done, switch_or_main = inputs
+            batch_size = rstate1[0].shape[0]
             num_steps = done.shape[0] // batch_size
         else:
-            carry, x = inputs
+            rstate, x = inputs
 
         c = self.channels
         encoder = Encoder(
@@ -361,21 +362,31 @@ class PPOLSTMAgent(nn.Module):
         lstm_layer = nn.OptimizedLSTMCell(
             self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype,
             kernel_init=nn.initializers.orthogonal(1.0))
         if self.multi_step:
-            def body_fn(cell, carry, x, done, switch):
-                carry, init_carry = carry
-                carry, y = cell(carry, x)
-                carry = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), carry)
-                carry = jax.tree.map(
-                    lambda x, y: jnp.where(switch[:, None], x, y), init_carry, carry)
-                return (carry, init_carry), y
+            if self.switch:
+                def body_fn(cell, carry, x, done, switch):
+                    rstate, init_rstate2 = carry
+                    rstate, y = cell(rstate, x)
+                    rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
+                    rstate = jax.tree.map(
+                        lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
+                    return (rstate, init_rstate2), y
+            else:
+                def body_fn(cell, carry, x, done, main):
+                    rstate1, rstate2 = carry
+                    rstate = jax.tree.map(
+                        lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
+                    rstate, y = cell(rstate, x)
+                    rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
+                    rstate1 = jax.tree.map(
+                        lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
+                    rstate2 = jax.tree.map(
+                        lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+                    return (rstate1, rstate2), y
             scan = nn.scan(
                 body_fn, variable_broadcast='params', split_rngs={'params': False})
-            f_state, done, switch = jax.tree.map(
+            f_state, done, switch_or_main = jax.tree.map(
                 lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]),
-                (f_state, done, switch))
-            carry, f_state = scan(lstm_layer, (carry1, carry2), f_state, done, switch)
+                (f_state, done, switch_or_main))
+            rstate, f_state = scan(lstm_layer, (rstate1, rstate2), f_state, done, switch_or_main)
             f_state = f_state.reshape((-1, f_state.shape[-1]))
         else:
-            carry, f_state = lstm_layer(carry, f_state)
+            rstate, f_state = lstm_layer(rstate, f_state)
 
         actor = Actor(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
@@ -384,4 +395,4 @@ class PPOLSTMAgent(nn.Module):
         logits = actor(f_state, f_actions, mask)
         value = critic(f_state)
 
-        return carry, logits, value, valid
+        return rstate, logits, value, valid
```
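As a standalone illustration of the recurrent-state bookkeeping in the new `switch=False` branch (toy shapes and values, assumed for this sketch): every environment slot keeps one carry per player, `main` selects the acting player's carry for the cell step, finished episodes are zeroed, and the result is scattered back only into the acting player's slot.

```python
# Toy demo of the per-player carry selection performed by body_fn when switch=False.
import jax
import jax.numpy as jnp

B, H = 4, 8                                        # assumed batch size and hidden size
rstate1 = (jnp.ones((B, H)), jnp.ones((B, H)))     # player 1's (c, h) carry
rstate2 = (jnp.zeros((B, H)), jnp.zeros((B, H)))   # player 2's (c, h) carry
main = jnp.array([True, False, True, False])       # slots where player 1 is acting
done = jnp.array([False, False, True, False])      # episodes that just finished

# Pick the acting player's carry per environment.
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)

# ... the LSTM cell would consume and update `rstate` here ...

# Reset finished environments, then write the updated carry back to the
# acting player's slot while leaving the waiting player's slot untouched.
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
```

The `switch=True` branch keeps the previous behaviour: a single carry is advanced and, wherever `switch` is set, replaced with the stored `init_rstate2`.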
ygoenv/ygoenv/ygopro/ygopro.h (+27 -22)

```diff
@@ -1252,7 +1252,7 @@ public:
                     "play_mode"_.Bind(std::string("bot")), "verbose"_.Bind(false),
                     "max_options"_.Bind(16), "max_cards"_.Bind(80),
                     "n_history_actions"_.Bind(16),
-                    "record"_.Bind(false), "async_reset"_.Bind(true));
+                    "record"_.Bind(false), "async_reset"_.Bind(true), "greedy_reward"_.Bind(true));
   }
 
   template <typename Config>
   static decltype(auto) StateSpec(const Config &conf) {
@@ -1353,6 +1353,7 @@ protected:
   PlayerId winner_;
   uint8_t win_reason_;
+  bool greedy_reward_;
 
   int lp_[2];
@@ -1438,7 +1439,7 @@ public:
       play_modes_(parse_play_modes(spec.config["play_mode"_])),
       verbose_(spec.config["verbose"_]), record_(spec.config["record"_]),
       n_history_actions_(spec.config["n_history_actions"_]),
      pool_(BS::thread_pool(1)),
-      async_reset_(spec.config["async_reset"_]) {
+      async_reset_(spec.config["async_reset"_]), greedy_reward_(spec.config["greedy_reward"_]) {
     if (record_) {
       if (!verbose_) {
         throw std::runtime_error("record mode must be used with verbose mode and num_envs=1");
@@ -1879,29 +1880,33 @@ public:
     int reason = 0;
     if (done_) {
       float base_reward;
-      if (winner_ == 0) {
-        if (turn_count_ <= 1) {
-          // FTK
-          base_reward = 16.0;
-        } else if (turn_count_ <= 3) {
-          base_reward = 8.0;
-        } else if (turn_count_ <= 5) {
-          base_reward = 4.0;
-        } else if (turn_count_ <= 7) {
-          base_reward = 2.0;
-        } else {
-          base_reward = 0.5 + 1.0 / (turn_count_ - 7);
-        }
-      } else {
-        if (turn_count_ <= 1) {
-          base_reward = 8.0;
-        } else if (turn_count_ <= 3) {
-          base_reward = 4.0;
-        } else if (turn_count_ <= 5) {
-          base_reward = 2.0;
-        } else {
-          base_reward = 0.5 + 1.0 / (turn_count_ - 5);
-        }
+      if (greedy_reward_) {
+        if (winner_ == 0) {
+          if (turn_count_ <= 1) {
+            // FTK
+            base_reward = 16.0;
+          } else if (turn_count_ <= 3) {
+            base_reward = 8.0;
+          } else if (turn_count_ <= 5) {
+            base_reward = 4.0;
+          } else if (turn_count_ <= 7) {
+            base_reward = 2.0;
+          } else {
+            base_reward = 0.5 + 1.0 / (turn_count_ - 7);
+          }
+        } else {
+          if (turn_count_ <= 1) {
+            base_reward = 8.0;
+          } else if (turn_count_ <= 3) {
+            base_reward = 4.0;
+          } else if (turn_count_ <= 5) {
+            base_reward = 2.0;
+          } else {
+            base_reward = 0.5 + 1.0 / (turn_count_ - 5);
+          }
+        }
+      } else {
+        base_reward = 1.0;
       }
 
       if (play_mode_ == kSelfPlay) {
```
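For reference, a Python transcription of the end-of-game reward schedule added above (a hypothetical helper for illustration only; the C++ in ygopro.h is authoritative). With `greedy_reward` enabled, quicker wins earn larger rewards, and `winner == 0` (the FTK comment suggests the player who went first) gets one extra tier; with it disabled, every finished game uses a flat base reward of 1.0.

```python
def base_reward(greedy_reward: bool, winner: int, turn_count: int) -> float:
    # Mirrors the branch added in ygopro.h: faster kills give a higher reward.
    if not greedy_reward:
        return 1.0
    if winner == 0:
        if turn_count <= 1:
            return 16.0  # FTK
        if turn_count <= 3:
            return 8.0
        if turn_count <= 5:
            return 4.0
        if turn_count <= 7:
            return 2.0
        return 0.5 + 1.0 / (turn_count - 7)
    if turn_count <= 1:
        return 8.0
    if turn_count <= 3:
        return 4.0
    if turn_count <= 5:
        return 2.0
    return 0.5 + 1.0 / (turn_count - 5)
```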