Commit 3a79013f authored by sbl1996@126.com's avatar sbl1996@126.com

Add tflite support for inf server

parent a727022b
...@@ -11,6 +11,7 @@ dev: assets script py_install ygoenv_so ...@@ -11,6 +11,7 @@ dev: assets script py_install ygoenv_so
py_install: py_install:
pip install -e ygoenv pip install -e ygoenv
pip install -e ygoinf
pip install -e . pip install -e .
ygoenv_so: ygoenv/ygoenv/ygopro/ygopro_ygoenv.so ygoenv_so: ygoenv/ygoenv/ygopro/ygopro_ygoenv.so
...@@ -46,6 +47,4 @@ assets/locale/zh/strings.conf: assets/locale/zh ...@@ -46,6 +47,4 @@ assets/locale/zh/strings.conf: assets/locale/zh
clean: clean:
rm -rf scripts/script rm -rf scripts/script
rm -rf assets/locale/en assets/locale/zh rm -rf assets/locale/en assets/locale/zh
pip uninstall -y ygoenv \ No newline at end of file
pip uninstall -y .
\ No newline at end of file
...@@ -16,9 +16,6 @@ REQUIRED = [ ...@@ -16,9 +16,6 @@ REQUIRED = [
"tyro", "tyro",
"pandas", "pandas",
"tensorboardX", "tensorboardX",
"fastapi",
"uvicorn[standard]",
"pydantic_settings",
"tqdm", "tqdm",
] ]
......
...@@ -242,7 +242,7 @@ class Encoder(nn.Module): ...@@ -242,7 +242,7 @@ class Encoder(nn.Module):
x_global = x['global_'] x_global = x['global_']
x_actions = x['actions_'] x_actions = x['actions_']
x_h_actions = x['h_actions_'] x_h_actions = x['h_actions_']
mask = x['mask_'] mask = x.get('mask_', None)
batch_size = x_global.shape[0] batch_size = x_global.shape[0]
valid = x_global[:, -1] == 0 valid = x_global[:, -1] == 0
...@@ -296,7 +296,8 @@ class Encoder(nn.Module): ...@@ -296,7 +296,8 @@ class Encoder(nn.Module):
# History actions # History actions
x_h_actions = x_h_actions.astype(jnp.int32) x_h_actions = x_h_actions.astype(jnp.int32)
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0 h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False) h_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=h_mask.dtype), h_mask[:, 1:]], axis=1)
# h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., 1:3]) x_h_id = decode_id(x_h_actions[..., 1:3])
x_h_id = id_embed(x_h_id) x_h_id = id_embed(x_h_id)
...@@ -355,7 +356,8 @@ class Encoder(nn.Module): ...@@ -355,7 +356,8 @@ class Encoder(nn.Module):
f_actions = x_a_feats + f_actions f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0 a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False) a_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=a_mask.dtype), a_mask[:, 1:]], axis=1)
# a_mask = a_mask.at[:, 0].set(False)
g_feats = [f_g_card, f_global] g_feats = [f_g_card, f_global]
if self.use_history: if self.use_history:
...@@ -698,17 +700,17 @@ class RNNAgent(nn.Module): ...@@ -698,17 +700,17 @@ class RNNAgent(nn.Module):
def init_rnn_state(self, batch_size): def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']: if self.rnn_type in ['lstm', 'none']:
return ( return (
np.zeros((batch_size, self.rnn_channels)), np.zeros((batch_size, self.rnn_channels), dtype=np.float32),
np.zeros((batch_size, self.rnn_channels)), np.zeros((batch_size, self.rnn_channels), dtype=np.float32),
) )
elif self.rnn_type == 'gru': elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels)) return np.zeros((batch_size, self.rnn_channels), dtype=np.float32)
elif self.rnn_type == 'rwkv': elif self.rnn_type == 'rwkv':
head_size = self.rwkv_head_size head_size = self.rwkv_head_size
num_heads = self.rnn_channels // self.rwkv_head_size num_heads = self.rnn_channels // self.rwkv_head_size
return ( return (
np.zeros((batch_size, num_heads*head_size)), np.zeros((batch_size, num_heads*head_size), dtype=np.float32),
np.zeros((batch_size, num_heads*head_size*head_size)), np.zeros((batch_size, num_heads*head_size*head_size), dtype=np.float32),
) )
else: else:
return None return None
......
from setuptools import setup, find_packages
__version__ = "0.0.1"
INSTALL_REQUIRES = [
"numpy",
"optree",
"fastapi",
"uvicorn[standard]",
"pydantic_settings",
"tflite-runtime",
]
setup(
name="ygoinf",
version=__version__,
packages=find_packages(include='ygoinf*'),
long_description="",
install_requires=INSTALL_REQUIRES,
python_requires=">=3.10",
)
\ No newline at end of file
...@@ -40,10 +40,13 @@ def combinations_with_weight2(weights, r): ...@@ -40,10 +40,13 @@ def combinations_with_weight2(weights, r):
N_CARD_FEATURES = 41 N_CARD_FEATURES = 41
MAX_CARDS = 80 MAX_CARDS = 80
MAX_ACTIONS = 24 MAX_ACTIONS = 24
N_GLOBAL_FEATURES = 23
N_ACTION_FEATURES = 12 N_ACTION_FEATURES = 12
N_GLOBAL_FEATURES = 23
N_HISTORY_ACTIONS = 32 N_HISTORY_ACTIONS = 32
H_ACTIONS_SHAPE = (N_HISTORY_ACTIONS, N_ACTION_FEATURES + 2) H_ACTIONS_FEATS = 14
N_RNN_CHANNELS = 512
H_ACTIONS_SHAPE = (N_HISTORY_ACTIONS, H_ACTIONS_FEATS)
DESCRIPTION_LIMIT = 10000 DESCRIPTION_LIMIT = 10000
CARD_EFFECT_OFFSET = 10010 CARD_EFFECT_OFFSET = 10010
...@@ -58,9 +61,13 @@ def sample_input(): ...@@ -58,9 +61,13 @@ def sample_input():
"global_": global_, "global_": global_,
"actions_": legal_actions, "actions_": legal_actions,
"h_actions_": history_actions, "h_actions_": history_actions,
"mask_": None,
} }
def init_rstate():
return (
np.zeros((1, N_RNN_CHANNELS), dtype=np.float32),
np.zeros((1, N_RNN_CHANNELS), dtype=np.float32),
)
system_strings = [ system_strings = [
1050, 1051, 1052, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1050, 1051, 1052, 1054, 1055, 1056, 1057, 1058, 1059, 1060,
...@@ -1047,8 +1054,8 @@ class HistoryActions: ...@@ -1047,8 +1054,8 @@ class HistoryActions:
class PredictState: class PredictState:
def __init__(self, init_rstate): def __init__(self):
self.rstate = init_rstate self.rstate = init_rstate()
self.index = 0 self.index = 0
self.history_actions = HistoryActions() self.history_actions = HistoryActions()
...@@ -1097,7 +1104,6 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState): ...@@ -1097,7 +1104,6 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState):
"global_": global_, "global_": global_,
"actions_": actions, "actions_": actions,
"h_actions_": h_actions, "h_actions_": h_actions,
"mask_": None,
} }
if n_actions == 1: if n_actions == 1:
probs = [1.0] probs = [1.0]
...@@ -1123,3 +1129,24 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState): ...@@ -1123,3 +1129,24 @@ def predict(model_fn, input: Input, prev_action_idx, state: PredictState):
state.record(input, actions, probs) state.record(input, actions, probs)
state.index += 1 state.index += 1
return predict_results return predict_results
class Predictor:
def __init__(self, loaded, predict_fn):
self.loaded = loaded
self.predict_fn = predict_fn
def predict(self, rstate, sample_obs):
return self.predict_fn(self.loaded, rstate, sample_obs)
@staticmethod
def load(checkpoint):
sample_obs = sample_input()
rstate = init_rstate()
if checkpoint.endswith(".flax_model"):
from .jax_inf import load_model, predict_fn
elif checkpoint.endswith(".tflite"):
from .tflite_inf import load_model, predict_fn
predictor = Predictor(load_model(checkpoint, rstate, sample_obs), predict_fn)
predictor.predict(rstate, sample_obs)
return predictor
import numpy as np
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent import RNNAgent
def create_agent():
return RNNAgent(
num_layers=2,
rnn_channels=512,
use_history=True,
rnn_type='lstm',
num_channels=128,
film=True,
noam=True,
version=2,
)
@jax.jit
def get_probs_and_value(params, rstate, obs):
agent = create_agent()
next_rstate, logits, value = agent.apply(params, obs, rstate)[:3]
probs = jax.nn.softmax(logits, axis=-1)
return next_rstate, probs, value
def predict_fn(params, rstate, obs):
obs = jax.tree.map(lambda x: jnp.array([x]), obs)
rstate, probs, value = get_probs_and_value(params, rstate, obs)
return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0])
def load_model(checkpoint, rstate, sample_obs):
agent = create_agent()
key = jax.random.PRNGKey(0)
key, agent_key = jax.random.split(key, 2)
sample_obs_ = jax.tree.map(lambda x: jnp.array([x]), sample_obs)
params = jax.jit(agent.init)(agent_key, sample_obs_, rstate)
with open(checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
params = jax.device_put(params)
return params
...@@ -8,13 +8,9 @@ from contextlib import asynccontextmanager ...@@ -8,13 +8,9 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, Path from fastapi import FastAPI, Path
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import Field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
import numpy as np
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent import RNNAgent
from .models import ( from .models import (
DuelCreateResponse, DuelCreateResponse,
...@@ -22,45 +18,21 @@ from .models import ( ...@@ -22,45 +18,21 @@ from .models import (
DuelPredictResponse, DuelPredictResponse,
DuelPredictErrorResponse, DuelPredictErrorResponse,
) )
from .features import predict, sample_input, init_code_list, PredictState from .features import predict, init_code_list, PredictState, Predictor
class Settings(BaseSettings): class Settings(BaseSettings):
code_list: str = "code_list.txt" code_list: str = "code_list.txt"
checkpoint: str = "latest.flax_model" checkpoint: str = "latest.flax_model"
enable_cors: bool = Field(default=True, description="Enable CORS")
settings = Settings() settings = Settings()
def create_agent():
return RNNAgent(
num_layers=2,
rnn_channels=512,
use_history=True,
rnn_type='lstm',
num_channels=128,
film=True,
noam=True,
version=2,
)
@jax.jit
def get_probs_and_value(params, rstate, obs):
agent = create_agent()
next_rstate, logits, value = agent.apply(params, obs, rstate)[:3]
probs = jax.nn.softmax(logits, axis=-1)
return next_rstate, probs, value
def predict_fn(params, rstate, obs):
obs = jax.tree.map(lambda x: jnp.array([x]), obs)
rstate, probs, value = get_probs_and_value(params, rstate, obs)
return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0])
all_models = {} all_models = {}
duel_states: Dict[str, PredictState] = {} duel_states: Dict[str, PredictState] = {}
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
...@@ -68,26 +40,9 @@ async def lifespan(app: FastAPI): ...@@ -68,26 +40,9 @@ async def lifespan(app: FastAPI):
init_code_list(settings.code_list) init_code_list(settings.code_list)
agent = create_agent()
key = jax.random.PRNGKey(0)
key, agent_key = jax.random.split(key, 2)
sample_obs = sample_input()
sample_obs_ = jax.tree.map(lambda x: jnp.array([x]), sample_obs)
rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, sample_obs_, rstate)
checkpoint = settings.checkpoint checkpoint = settings.checkpoint
with open(checkpoint, "rb") as f: predictor = Predictor.load(checkpoint)
params = flax.serialization.from_bytes(params, f.read()) all_models["default"] = predictor
params = jax.device_put(params)
all_models["param"] = params
all_models["agent"] = agent
predict_fn(params, rstate, sample_obs)
print(f"loaded checkpoint from {checkpoint}") print(f"loaded checkpoint from {checkpoint}")
state = new_state() state = new_state()
...@@ -103,16 +58,17 @@ app = FastAPI( ...@@ -103,16 +58,17 @@ app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
) )
app.add_middleware( if settings.enable_cors:
CORSMiddleware, app.add_middleware(
allow_origins=["*"], CORSMiddleware,
allow_credentials=True, allow_origins=["*"],
allow_methods=["*"], allow_credentials=True,
allow_headers=["*"], allow_methods=["*"],
) allow_headers=["*"],
)
def new_state(): def new_state():
return PredictState(all_models["agent"].init_rnn_state(1)) return PredictState()
@app.post('/v0/duels', response_model=DuelCreateResponse) @app.post('/v0/duels', response_model=DuelCreateResponse)
async def create_duel() -> DuelCreateResponse: async def create_duel() -> DuelCreateResponse:
...@@ -153,10 +109,10 @@ async def duel_predict( ...@@ -153,10 +109,10 @@ async def duel_predict(
error=f"index mismatch: expected {duel_state.index}, got {index}" error=f"index mismatch: expected {duel_state.index}, got {index}"
) )
params = all_models["param"] predictor = all_models["default"]
model_fn = predictor.predict
_start = time.time() _start = time.time()
model_fn = lambda r, x: predict_fn(params, r, x)
try: try:
predict_results = predict(model_fn, body.input, body.prev_action_idx, duel_state) predict_results = predict(model_fn, body.input, body.prev_action_idx, duel_state)
except (KeyError, NotImplementedError) as e: except (KeyError, NotImplementedError) as e:
......
import numpy as np
import optree
import tflite_runtime.interpreter as tf_lite
def tflite_predict(interpreter, rstate, obs):
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
inputs = rstate, obs
for i, x in enumerate(optree.tree_leaves(inputs)):
interpreter.set_tensor(input_details[i]["index"], x)
interpreter.invoke()
results = [
interpreter.get_tensor(o["index"]) for o in output_details]
rstate1, rstate2, probs, value = results
rstate = (rstate1, rstate2)
return rstate, probs, value
def predict_fn(interpreter, rstate, obs):
obs = optree.tree_map(lambda x: np.array([x]), obs)
rstate, probs, value = tflite_predict(interpreter, rstate, obs)
prob = probs[0].tolist()
value = float(value[0])
return rstate, prob, value
def load_model(checkpoint, *args):
with open(checkpoint, "rb") as f:
tflite_model = f.read()
interpreter = tf_lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
return interpreter
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment