Commit 3a79013f authored by sbl1996@126.com's avatar sbl1996@126.com

Add tflite support for inf server

parent a727022b
......@@ -11,6 +11,7 @@ dev: assets script py_install ygoenv_so
py_install:
pip install -e ygoenv
pip install -e ygoinf
pip install -e .
ygoenv_so: ygoenv/ygoenv/ygopro/ygopro_ygoenv.so
......@@ -46,6 +47,4 @@ assets/locale/zh/strings.conf: assets/locale/zh
clean:
rm -rf scripts/script
rm -rf assets/locale/en assets/locale/zh
pip uninstall -y ygoenv
pip uninstall -y .
\ No newline at end of file
rm -rf assets/locale/en assets/locale/zh
\ No newline at end of file
......@@ -16,9 +16,6 @@ REQUIRED = [
"tyro",
"pandas",
"tensorboardX",
"fastapi",
"uvicorn[standard]",
"pydantic_settings",
"tqdm",
]
......
......@@ -242,7 +242,7 @@ class Encoder(nn.Module):
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
mask = x['mask_']
mask = x.get('mask_', None)
batch_size = x_global.shape[0]
valid = x_global[:, -1] == 0
......@@ -296,7 +296,8 @@ class Encoder(nn.Module):
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
h_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=h_mask.dtype), h_mask[:, 1:]], axis=1)
# h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., 1:3])
x_h_id = id_embed(x_h_id)
......@@ -355,7 +356,8 @@ class Encoder(nn.Module):
f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False)
a_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=a_mask.dtype), a_mask[:, 1:]], axis=1)
# a_mask = a_mask.at[:, 0].set(False)
g_feats = [f_g_card, f_global]
if self.use_history:
......@@ -698,17 +700,17 @@ class RNNAgent(nn.Module):
def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']:
return (
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels), dtype=np.float32),
np.zeros((batch_size, self.rnn_channels), dtype=np.float32),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
return np.zeros((batch_size, self.rnn_channels), dtype=np.float32)
elif self.rnn_type == 'rwkv':
head_size = self.rwkv_head_size
num_heads = self.rnn_channels // self.rwkv_head_size
return (
np.zeros((batch_size, num_heads*head_size)),
np.zeros((batch_size, num_heads*head_size*head_size)),
np.zeros((batch_size, num_heads*head_size), dtype=np.float32),
np.zeros((batch_size, num_heads*head_size*head_size), dtype=np.float32),
)
else:
return None
......
from setuptools import setup, find_packages
__version__ = "0.0.1"
INSTALL_REQUIRES = [
"numpy",
"optree",
"fastapi",
"uvicorn[standard]",
"pydantic_settings",
"tflite-runtime",
]
setup(
name="ygoinf",
version=__version__,
packages=find_packages(include='ygoinf*'),
long_description="",
install_requires=INSTALL_REQUIRES,
python_requires=">=3.10",
)
\ No newline at end of file
......@@ -40,10 +40,13 @@ def combinations_with_weight2(weights, r):
N_CARD_FEATURES = 41
MAX_CARDS = 80
MAX_ACTIONS = 24
N_GLOBAL_FEATURES = 23
N_ACTION_FEATURES = 12
N_GLOBAL_FEATURES = 23
N_HISTORY_ACTIONS = 32
H_ACTIONS_SHAPE = (N_HISTORY_ACTIONS, N_ACTION_FEATURES + 2)
H_ACTIONS_FEATS = 14
N_RNN_CHANNELS = 512
H_ACTIONS_SHAPE = (N_HISTORY_ACTIONS, H_ACTIONS_FEATS)
DESCRIPTION_LIMIT = 10000
CARD_EFFECT_OFFSET = 10010
......@@ -58,9 +61,13 @@ def sample_input():
"global_": global_,
"actions_": legal_actions,
"h_actions_": history_actions,
"mask_": None,
}
def init_rstate():
return (
np.zeros((1, N_RNN_CHANNELS), dtype=np.float32),
np.zeros((1, N_RNN_CHANNELS), dtype=np.float32),
)
system_strings = [
1050, 1051, 1052, 1054, 1055, 1056, 1057, 1058, 1059, 1060,
......@@ -1047,8 +1054,8 @@ class HistoryActions:
class PredictState:
def __init__(self, init_rstate):
self.rstate = init_rstate
def __init__(self):
self.rstate = init_rstate()
self.index = 0
self.history_actions = HistoryActions()
......@@ -1097,7 +1104,6 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState):
"global_": global_,
"actions_": actions,
"h_actions_": h_actions,
"mask_": None,
}
if n_actions == 1:
probs = [1.0]
......@@ -1123,3 +1129,24 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState):
state.record(input, actions, probs)
state.index += 1
return predict_results
class Predictor:
def __init__(self, loaded, predict_fn):
self.loaded = loaded
self.predict_fn = predict_fn
def predict(self, rstate, sample_obs):
return self.predict_fn(self.loaded, rstate, sample_obs)
@staticmethod
def load(checkpoint):
sample_obs = sample_input()
rstate = init_rstate()
if checkpoint.endswith(".flax_model"):
from .jax_inf import load_model, predict_fn
elif checkpoint.endswith(".tflite"):
from .tflite_inf import load_model, predict_fn
predictor = Predictor(load_model(checkpoint, rstate, sample_obs), predict_fn)
predictor.predict(rstate, sample_obs)
return predictor
import numpy as np
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent import RNNAgent
def create_agent():
return RNNAgent(
num_layers=2,
rnn_channels=512,
use_history=True,
rnn_type='lstm',
num_channels=128,
film=True,
noam=True,
version=2,
)
@jax.jit
def get_probs_and_value(params, rstate, obs):
agent = create_agent()
next_rstate, logits, value = agent.apply(params, obs, rstate)[:3]
probs = jax.nn.softmax(logits, axis=-1)
return next_rstate, probs, value
def predict_fn(params, rstate, obs):
obs = jax.tree.map(lambda x: jnp.array([x]), obs)
rstate, probs, value = get_probs_and_value(params, rstate, obs)
return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0])
def load_model(checkpoint, rstate, sample_obs):
agent = create_agent()
key = jax.random.PRNGKey(0)
key, agent_key = jax.random.split(key, 2)
sample_obs_ = jax.tree.map(lambda x: jnp.array([x]), sample_obs)
params = jax.jit(agent.init)(agent_key, sample_obs_, rstate)
with open(checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
params = jax.device_put(params)
return params
......@@ -8,13 +8,9 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, Path
from fastapi.middleware.cors import CORSMiddleware
from pydantic import Field
from pydantic_settings import BaseSettings
import numpy as np
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent import RNNAgent
from .models import (
DuelCreateResponse,
......@@ -22,45 +18,21 @@ from .models import (
DuelPredictResponse,
DuelPredictErrorResponse,
)
from .features import predict, sample_input, init_code_list, PredictState
from .features import predict, init_code_list, PredictState, Predictor
class Settings(BaseSettings):
code_list: str = "code_list.txt"
checkpoint: str = "latest.flax_model"
enable_cors: bool = Field(default=True, description="Enable CORS")
settings = Settings()
def create_agent():
return RNNAgent(
num_layers=2,
rnn_channels=512,
use_history=True,
rnn_type='lstm',
num_channels=128,
film=True,
noam=True,
version=2,
)
@jax.jit
def get_probs_and_value(params, rstate, obs):
agent = create_agent()
next_rstate, logits, value = agent.apply(params, obs, rstate)[:3]
probs = jax.nn.softmax(logits, axis=-1)
return next_rstate, probs, value
def predict_fn(params, rstate, obs):
obs = jax.tree.map(lambda x: jnp.array([x]), obs)
rstate, probs, value = get_probs_and_value(params, rstate, obs)
return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0])
all_models = {}
duel_states: Dict[str, PredictState] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
from jax.experimental.compilation_cache import compilation_cache as cc
......@@ -68,26 +40,9 @@ async def lifespan(app: FastAPI):
init_code_list(settings.code_list)
agent = create_agent()
key = jax.random.PRNGKey(0)
key, agent_key = jax.random.split(key, 2)
sample_obs = sample_input()
sample_obs_ = jax.tree.map(lambda x: jnp.array([x]), sample_obs)
rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, sample_obs_, rstate)
checkpoint = settings.checkpoint
with open(checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
params = jax.device_put(params)
all_models["param"] = params
all_models["agent"] = agent
predict_fn(params, rstate, sample_obs)
predictor = Predictor.load(checkpoint)
all_models["default"] = predictor
print(f"loaded checkpoint from {checkpoint}")
state = new_state()
......@@ -103,16 +58,17 @@ app = FastAPI(
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if settings.enable_cors:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def new_state():
return PredictState(all_models["agent"].init_rnn_state(1))
return PredictState()
@app.post('/v0/duels', response_model=DuelCreateResponse)
async def create_duel() -> DuelCreateResponse:
......@@ -153,10 +109,10 @@ async def duel_predict(
error=f"index mismatch: expected {duel_state.index}, got {index}"
)
params = all_models["param"]
predictor = all_models["default"]
model_fn = predictor.predict
_start = time.time()
model_fn = lambda r, x: predict_fn(params, r, x)
try:
predict_results = predict(model_fn, body.input, body.prev_action_idx, duel_state)
except (KeyError, NotImplementedError) as e:
......
import numpy as np
import optree
import tflite_runtime.interpreter as tf_lite
def tflite_predict(interpreter, rstate, obs):
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
inputs = rstate, obs
for i, x in enumerate(optree.tree_leaves(inputs)):
interpreter.set_tensor(input_details[i]["index"], x)
interpreter.invoke()
results = [
interpreter.get_tensor(o["index"]) for o in output_details]
rstate1, rstate2, probs, value = results
rstate = (rstate1, rstate2)
return rstate, probs, value
def predict_fn(interpreter, rstate, obs):
obs = optree.tree_map(lambda x: np.array([x]), obs)
rstate, probs, value = tflite_predict(interpreter, rstate, obs)
prob = probs[0].tolist()
value = float(value[0])
return rstate, prob, value
def load_model(checkpoint, *args):
with open(checkpoint, "rb") as f:
tflite_model = f.read()
interpreter = tf_lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
return interpreter
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment