Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
3a79013f
Commit
3a79013f
authored
Jul 15, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add tflite support for inf server
parent
a727022b
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
158 additions
and
81 deletions
+158
-81
Makefile
Makefile
+2
-3
setup.py
setup.py
+0
-3
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+10
-8
ygoinf/setup.py
ygoinf/setup.py
+21
-0
ygoinf/ygoinf/__init__.py
ygoinf/ygoinf/__init__.py
+0
-0
ygoinf/ygoinf/features.py
ygoinf/ygoinf/features.py
+33
-6
ygoinf/ygoinf/jax_inf.py
ygoinf/ygoinf/jax_inf.py
+44
-0
ygoinf/ygoinf/models.py
ygoinf/ygoinf/models.py
+0
-0
ygoinf/ygoinf/server.py
ygoinf/ygoinf/server.py
+17
-61
ygoinf/ygoinf/tflite_inf.py
ygoinf/ygoinf/tflite_inf.py
+31
-0
No files found.
Makefile
View file @
3a79013f
...
...
@@ -11,6 +11,7 @@ dev: assets script py_install ygoenv_so
py_install
:
pip
install
-e
ygoenv
pip
install
-e
ygoinf
pip
install
-e
.
ygoenv_so
:
ygoenv/ygoenv/ygopro/ygopro_ygoenv.so
...
...
@@ -46,6 +47,4 @@ assets/locale/zh/strings.conf: assets/locale/zh
clean
:
rm
-rf
scripts/script
rm
-rf
assets/locale/en assets/locale/zh
pip uninstall
-y
ygoenv
pip uninstall
-y
.
\ No newline at end of file
rm
-rf
assets/locale/en assets/locale/zh
\ No newline at end of file
setup.py
View file @
3a79013f
...
...
@@ -16,9 +16,6 @@ REQUIRED = [
"tyro"
,
"pandas"
,
"tensorboardX"
,
"fastapi"
,
"uvicorn[standard]"
,
"pydantic_settings"
,
"tqdm"
,
]
...
...
ygoai/rl/jax/agent.py
View file @
3a79013f
...
...
@@ -242,7 +242,7 @@ class Encoder(nn.Module):
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
mask
=
x
[
'mask_'
]
mask
=
x
.
get
(
'mask_'
,
None
)
batch_size
=
x_global
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
...
...
@@ -296,7 +296,8 @@ class Encoder(nn.Module):
# History actions
x_h_actions
=
x_h_actions
.
astype
(
jnp
.
int32
)
h_mask
=
x_h_actions
[:,
:,
3
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
h_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
h_mask
.
dtype
),
h_mask
[:,
1
:]],
axis
=
1
)
# h_mask = h_mask.at[:, 0].set(False)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
1
:
3
])
x_h_id
=
id_embed
(
x_h_id
)
...
...
@@ -355,7 +356,8 @@ class Encoder(nn.Module):
f_actions
=
x_a_feats
+
f_actions
a_mask
=
x_actions
[:,
:,
3
]
==
0
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
a_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
a_mask
.
dtype
),
a_mask
[:,
1
:]],
axis
=
1
)
# a_mask = a_mask.at[:, 0].set(False)
g_feats
=
[
f_g_card
,
f_global
]
if
self
.
use_history
:
...
...
@@ -698,17 +700,17 @@ class RNNAgent(nn.Module):
def
init_rnn_state
(
self
,
batch_size
):
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
return
(
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)
,
dtype
=
np
.
float32
),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)
,
dtype
=
np
.
float32
),
)
elif
self
.
rnn_type
==
'gru'
:
return
np
.
zeros
((
batch_size
,
self
.
rnn_channels
))
return
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)
,
dtype
=
np
.
float32
)
elif
self
.
rnn_type
==
'rwkv'
:
head_size
=
self
.
rwkv_head_size
num_heads
=
self
.
rnn_channels
//
self
.
rwkv_head_size
return
(
np
.
zeros
((
batch_size
,
num_heads
*
head_size
)),
np
.
zeros
((
batch_size
,
num_heads
*
head_size
*
head_size
)),
np
.
zeros
((
batch_size
,
num_heads
*
head_size
)
,
dtype
=
np
.
float32
),
np
.
zeros
((
batch_size
,
num_heads
*
head_size
*
head_size
)
,
dtype
=
np
.
float32
),
)
else
:
return
None
...
...
ygoinf/setup.py
0 → 100644
View file @
3a79013f
"""Packaging script for the standalone ``ygoinf`` inference-server package."""
from setuptools import setup, find_packages

__version__ = "0.0.1"

# Third-party packages imported by the ygoinf modules (server + TFLite backend).
INSTALL_REQUIRES = [
    "numpy",
    "optree",
    "fastapi",
    "uvicorn[standard]",
    "pydantic_settings",
    "tflite-runtime",
]

setup(
    name="ygoinf",
    version=__version__,
    # `include` must be an iterable of glob patterns. A bare string is
    # iterated character by character, and its trailing "*" then matches
    # every package next to this file instead of only `ygoinf*`.
    packages=find_packages(include=["ygoinf*"]),
    long_description="",
    install_requires=INSTALL_REQUIRES,
    python_requires=">=3.10",
)
\ No newline at end of file
ygo
env/main
.py
→
ygo
inf/ygoinf/__init__
.py
View file @
3a79013f
File moved
ygo
ai/server
/features.py
→
ygo
inf/ygoinf
/features.py
View file @
3a79013f
...
...
@@ -40,10 +40,13 @@ def combinations_with_weight2(weights, r):
N_CARD_FEATURES
=
41
MAX_CARDS
=
80
MAX_ACTIONS
=
24
N_GLOBAL_FEATURES
=
23
N_ACTION_FEATURES
=
12
N_GLOBAL_FEATURES
=
23
N_HISTORY_ACTIONS
=
32
H_ACTIONS_SHAPE
=
(
N_HISTORY_ACTIONS
,
N_ACTION_FEATURES
+
2
)
H_ACTIONS_FEATS
=
14
N_RNN_CHANNELS
=
512
H_ACTIONS_SHAPE
=
(
N_HISTORY_ACTIONS
,
H_ACTIONS_FEATS
)
DESCRIPTION_LIMIT
=
10000
CARD_EFFECT_OFFSET
=
10010
...
...
@@ -58,9 +61,13 @@ def sample_input():
"global_"
:
global_
,
"actions_"
:
legal_actions
,
"h_actions_"
:
history_actions
,
"mask_"
:
None
,
}
def init_rstate():
    """Return a fresh, all-zero model state for a single duel.

    Returns:
        A 2-tuple of independent float32 arrays, each of shape
        ``(1, N_RNN_CHANNELS)``.
    """
    state_shape = (1, N_RNN_CHANNELS)
    # Two separate arrays are allocated so the halves never alias each other.
    return tuple(np.zeros(state_shape, dtype=np.float32) for _ in range(2))
system_strings
=
[
1050
,
1051
,
1052
,
1054
,
1055
,
1056
,
1057
,
1058
,
1059
,
1060
,
...
...
@@ -1047,8 +1054,8 @@ class HistoryActions:
class
PredictState
:
def
__init__
(
self
,
init_rstate
):
self
.
rstate
=
init_rstate
def
__init__
(
self
):
self
.
rstate
=
init_rstate
()
self
.
index
=
0
self
.
history_actions
=
HistoryActions
()
...
...
@@ -1097,7 +1104,6 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState):
"global_"
:
global_
,
"actions_"
:
actions
,
"h_actions_"
:
h_actions
,
"mask_"
:
None
,
}
if
n_actions
==
1
:
probs
=
[
1.0
]
...
...
@@ -1123,3 +1129,24 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState):
state
.
record
(
input
,
actions
,
probs
)
state
.
index
+=
1
return
predict_results
class Predictor:
    """Pair a loaded model with the backend function that evaluates it."""

    def __init__(self, loaded, predict_fn):
        """Store the backend model object and its prediction function.

        Args:
            loaded: Whatever the selected backend's ``load_model`` returned.
            predict_fn: Callable invoked as ``predict_fn(loaded, rstate, obs)``.
        """
        self.loaded = loaded
        self.predict_fn = predict_fn

    def predict(self, rstate, sample_obs):
        """Evaluate the model on one observation.

        Returns:
            The result of ``predict_fn(loaded, rstate, sample_obs)``, unchanged.
        """
        return self.predict_fn(self.loaded, rstate, sample_obs)

    @staticmethod
    def load(checkpoint):
        """Create a ready-to-use predictor for ``checkpoint``.

        The backend is selected from the file suffix: ``.flax_model`` uses
        the JAX backend and ``.tflite`` the TFLite backend. Backend modules
        are imported lazily so only the selected one has to be installed.

        Args:
            checkpoint: Path of the model file to load.

        Returns:
            A ``Predictor`` that has already run one prediction.

        Raises:
            ValueError: If the suffix matches no supported backend.
        """
        # Validate the suffix first so an unsupported file fails with a clear
        # message instead of an unbound-name error further down.
        if checkpoint.endswith(".flax_model"):
            from .jax_inf import load_model, predict_fn
        elif checkpoint.endswith(".tflite"):
            from .tflite_inf import load_model, predict_fn
        else:
            raise ValueError(
                f"unsupported checkpoint format: {checkpoint!r} "
                "(expected a '.flax_model' or '.tflite' file)"
            )

        sample_obs = sample_input()
        rstate = init_rstate()
        predictor = Predictor(
            load_model(checkpoint, rstate, sample_obs), predict_fn)
        # Run once on the sample input so one-time start-up work (for example
        # tracing the jitted JAX function) happens at load time rather than on
        # the first real request.
        predictor.predict(rstate, sample_obs)
        return predictor
ygoinf/ygoinf/jax_inf.py
0 → 100644
View file @
3a79013f
import
numpy
as
np
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.rl.jax.agent
import
RNNAgent
def create_agent():
    """Build the ``RNNAgent`` with the fixed hyper-parameters used for inference."""
    agent_config = dict(
        num_layers=2,
        rnn_channels=512,
        use_history=True,
        rnn_type='lstm',
        num_channels=128,
        film=True,
        noam=True,
        version=2,
    )
    return RNNAgent(**agent_config)
@jax.jit
def get_probs_and_value(params, rstate, obs):
    """Apply the agent and turn its logits into action probabilities.

    Args:
        params: Parameter tree passed to ``agent.apply``.
        rstate: Recurrent state passed to the agent.
        obs: Batched observation passed to the agent.

    Returns:
        ``(next_rstate, probs, value)`` where ``probs`` is the softmax of the
        agent's logits over the last axis.
    """
    outputs = create_agent().apply(params, obs, rstate)
    # Only the first three outputs of the agent are consumed here.
    next_rstate, logits, value = outputs[:3]
    return next_rstate, jax.nn.softmax(logits, axis=-1), value
def predict_fn(params, rstate, obs):
    """Run the JAX model on a single unbatched observation.

    Args:
        params: Parameter tree returned by ``load_model``.
        rstate: Recurrent state, forwarded to the model unchanged.
        obs: Observation tree; each leaf is wrapped in a length-1 leading axis.

    Returns:
        ``(next_rstate, probs, value)`` with ``probs`` as a plain Python list
        and ``value`` as a Python float, both taken from batch entry 0.
    """
    batched_obs = jax.tree.map(lambda leaf: jnp.array([leaf]), obs)
    next_rstate, probs, value = get_probs_and_value(params, rstate, batched_obs)
    # Convert device arrays to built-in Python types for the caller.
    action_probs = np.array(probs)[0].tolist()
    state_value = float(np.array(value)[0])
    return next_rstate, action_probs, state_value
def load_model(checkpoint, rstate, sample_obs):
    """Load serialized Flax parameters from ``checkpoint``.

    Args:
        checkpoint: Path of the serialized parameter file.
        rstate: Recurrent state used when initialising the template parameters.
        sample_obs: Unbatched sample observation used for the same purpose.

    Returns:
        The restored parameter tree after ``jax.device_put``.
    """
    agent = create_agent()
    _, agent_key = jax.random.split(jax.random.PRNGKey(0), 2)
    batched_obs = jax.tree.map(lambda leaf: jnp.array([leaf]), sample_obs)
    # `from_bytes` needs a target with the right tree structure, so initialise
    # a throwaway parameter tree to use as that template.
    template = jax.jit(agent.init)(agent_key, batched_obs, rstate)
    with open(checkpoint, "rb") as f:
        serialized = f.read()
    restored = flax.serialization.from_bytes(template, serialized)
    return jax.device_put(restored)
ygo
ai/server
/models.py
→
ygo
inf/ygoinf
/models.py
View file @
3a79013f
File moved
ygo
ai/server/main
.py
→
ygo
inf/ygoinf/server
.py
View file @
3a79013f
...
...
@@ -8,13 +8,9 @@ from contextlib import asynccontextmanager
from
fastapi
import
FastAPI
,
Path
from
fastapi.middleware.cors
import
CORSMiddleware
from
pydantic
import
Field
from
pydantic_settings
import
BaseSettings
import
numpy
as
np
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.rl.jax.agent
import
RNNAgent
from
.models
import
(
DuelCreateResponse
,
...
...
@@ -22,45 +18,21 @@ from .models import (
DuelPredictResponse
,
DuelPredictErrorResponse
,
)
from
.features
import
predict
,
sample_input
,
init_code_list
,
PredictState
from
.features
import
predict
,
init_code_list
,
PredictState
,
Predictor
class Settings(BaseSettings):
    # Path handed to init_code_list() when the app starts.
    code_list: str = "code_list.txt"
    # Model file to serve; Predictor.load picks the backend from its suffix.
    checkpoint: str = "latest.flax_model"
    # When true, the CORS middleware is added to the app.
    enable_cors: bool = Field(True, description="Enable CORS")
settings
=
Settings
()
def
create_agent
():
return
RNNAgent
(
num_layers
=
2
,
rnn_channels
=
512
,
use_history
=
True
,
rnn_type
=
'lstm'
,
num_channels
=
128
,
film
=
True
,
noam
=
True
,
version
=
2
,
)
@
jax
.
jit
def
get_probs_and_value
(
params
,
rstate
,
obs
):
agent
=
create_agent
()
next_rstate
,
logits
,
value
=
agent
.
apply
(
params
,
obs
,
rstate
)[:
3
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
return
next_rstate
,
probs
,
value
def
predict_fn
(
params
,
rstate
,
obs
):
obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs
)
rstate
,
probs
,
value
=
get_probs_and_value
(
params
,
rstate
,
obs
)
return
rstate
,
np
.
array
(
probs
)[
0
]
.
tolist
(),
float
(
np
.
array
(
value
)[
0
])
all_models
=
{}
duel_states
:
Dict
[
str
,
PredictState
]
=
{}
@
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
...
...
@@ -68,26 +40,9 @@ async def lifespan(app: FastAPI):
init_code_list
(
settings
.
code_list
)
agent
=
create_agent
()
key
=
jax
.
random
.
PRNGKey
(
0
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
sample_input
()
sample_obs_
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
sample_obs
)
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
sample_obs_
,
rstate
)
checkpoint
=
settings
.
checkpoint
with
open
(
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params
=
jax
.
device_put
(
params
)
all_models
[
"param"
]
=
params
all_models
[
"agent"
]
=
agent
predict_fn
(
params
,
rstate
,
sample_obs
)
predictor
=
Predictor
.
load
(
checkpoint
)
all_models
[
"default"
]
=
predictor
print
(
f
"loaded checkpoint from {checkpoint}"
)
state
=
new_state
()
...
...
@@ -103,16 +58,17 @@ app = FastAPI(
lifespan
=
lifespan
,
)
app
.
add_middleware
(
CORSMiddleware
,
allow_origins
=
[
"*"
],
allow_credentials
=
True
,
allow_methods
=
[
"*"
],
allow_headers
=
[
"*"
],
)
if
settings
.
enable_cors
:
app
.
add_middleware
(
CORSMiddleware
,
allow_origins
=
[
"*"
],
allow_credentials
=
True
,
allow_methods
=
[
"*"
],
allow_headers
=
[
"*"
],
)
def
new_state
():
return
PredictState
(
all_models
[
"agent"
]
.
init_rnn_state
(
1
)
)
return
PredictState
()
@
app
.
post
(
'/v0/duels'
,
response_model
=
DuelCreateResponse
)
async
def
create_duel
()
->
DuelCreateResponse
:
...
...
@@ -153,10 +109,10 @@ async def duel_predict(
error
=
f
"index mismatch: expected {duel_state.index}, got {index}"
)
params
=
all_models
[
"param"
]
predictor
=
all_models
[
"default"
]
model_fn
=
predictor
.
predict
_start
=
time
.
time
()
model_fn
=
lambda
r
,
x
:
predict_fn
(
params
,
r
,
x
)
try
:
predict_results
=
predict
(
model_fn
,
body
.
input
,
body
.
prev_action_idx
,
duel_state
)
except
(
KeyError
,
NotImplementedError
)
as
e
:
...
...
ygoinf/ygoinf/tflite_inf.py
0 → 100644
View file @
3a79013f
import
numpy
as
np
import
optree
import
tflite_runtime.interpreter
as
tf_lite
def tflite_predict(interpreter, rstate, obs):
    """Feed ``(rstate, obs)`` through a TFLite interpreter.

    The flattened leaves of ``(rstate, obs)`` are written to the interpreter's
    inputs position by position, so the leaf order is assumed to match the
    model's input order (NOTE(review): confirm against the exported model).

    Args:
        interpreter: Interpreter whose tensors are already allocated.
        rstate: Recurrent state tree.
        obs: Batched observation tree.

    Returns:
        ``((rstate1, rstate2), probs, value)`` built from the model's four
        outputs, taken in the order the interpreter reports them.
    """
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    leaves = optree.tree_leaves((rstate, obs))
    for position, leaf in enumerate(leaves):
        tensor_index = input_details[position]["index"]
        interpreter.set_tensor(tensor_index, leaf)

    interpreter.invoke()

    outputs = [
        interpreter.get_tensor(detail["index"]) for detail in output_details
    ]
    rstate1, rstate2, probs, value = outputs
    return (rstate1, rstate2), probs, value
def predict_fn(interpreter, rstate, obs):
    """Run the TFLite model on a single unbatched observation.

    Args:
        interpreter: Interpreter returned by ``load_model``.
        rstate: Recurrent state, forwarded to the model unchanged.
        obs: Observation tree; each leaf is wrapped in a length-1 leading axis.

    Returns:
        ``(next_rstate, prob, value)`` with ``prob`` as a plain Python list
        and ``value`` as a Python float, both taken from batch entry 0.
    """
    batched_obs = optree.tree_map(lambda leaf: np.array([leaf]), obs)
    next_rstate, probs, raw_value = tflite_predict(
        interpreter, rstate, batched_obs)
    prob = probs[0].tolist()
    value = float(raw_value[0])
    return next_rstate, prob, value
def load_model(checkpoint, *args):
    """Create a TFLite interpreter from the model file at ``checkpoint``.

    Args:
        checkpoint: Path of the ``.tflite`` model file.
        *args: Ignored; accepted so the call signature matches the JAX
            backend's ``load_model(checkpoint, rstate, sample_obs)``.

    Returns:
        An interpreter with its tensors already allocated.
    """
    with open(checkpoint, "rb") as model_file:
        model_bytes = model_file.read()
    interpreter = tf_lite.Interpreter(model_content=model_bytes)
    # Tensors must be allocated before any set_tensor/invoke call.
    interpreter.allocate_tensors()
    return interpreter
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment