Commit 83777a2d authored by sbl1996@126.com

Add num_threads

parent 878f6a9e
......@@ -1144,13 +1144,13 @@ class Predictor:
return self.predict_fn(self.loaded, rstate, sample_obs)
@staticmethod
def load(checkpoint):
def load(checkpoint, num_threads):
sample_obs = sample_input()
rstate = init_rstate()
if checkpoint.endswith(".flax_model"):
from .jax_inf import load_model, predict_fn
elif checkpoint.endswith(".tflite"):
from .tflite_inf import load_model, predict_fn
predictor = Predictor(load_model(checkpoint, rstate, sample_obs), predict_fn)
predictor = Predictor(load_model(checkpoint, rstate, sample_obs, num_threads=num_threads), predict_fn)
predictor.predict(rstate, sample_obs)
return predictor
......@@ -35,7 +35,7 @@ def predict_fn(params, rstate, obs):
rstate, probs, value = get_probs_and_value(params, rstate, obs)
return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0])
def load_model(checkpoint, rstate, sample_obs):
def load_model(checkpoint, rstate, sample_obs, **kwargs):
agent = create_agent()
key = jax.random.PRNGKey(0)
key, agent_key = jax.random.split(key, 2)
......
......@@ -29,6 +29,7 @@ class Settings(BaseSettings):
enable_cors: bool = Field(default=True, description="Enable CORS")
state_expire: int = Field(default=3600, description="Duel state expire time in seconds")
test_duel_id: str = Field(default="9654823a-23fd-4850-bb-6fec241740b0", description="Test duel id")
ygo_num_threads: int = Field(default=1, description="Number of threads to use for YGO prediction")
settings = Settings()
......@@ -55,7 +56,7 @@ async def lifespan(app: FastAPI):
init_code_list(settings.code_list)
checkpoint = settings.checkpoint
predictor = Predictor.load(checkpoint)
predictor = Predictor.load(checkpoint, settings.ygo_num_threads)
all_models["default"] = predictor
print(f"loaded checkpoint from {checkpoint}")
......
......@@ -23,9 +23,10 @@ def predict_fn(interpreter, rstate, obs):
value = float(value[0])
return rstate, prob, value
def load_model(checkpoint, *args):
def load_model(checkpoint, *args, **kwargs):
with open(checkpoint, "rb") as f:
tflite_model = f.read()
interpreter = tf_lite.Interpreter(model_content=tflite_model)
interpreter = tf_lite.Interpreter(
model_content=tflite_model, num_threads=kwargs.get("num_threads", 1))
interpreter.allocate_tensors()
return interpreter
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment