Commit f6ad8563 authored by sbl1996@126.com's avatar sbl1996@126.com

Add expire for duel states in server

parent 7888000e
......@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
__version__ = "0.0.1"
INSTALL_REQUIRES = [
"numpy",
"numpy==1.26.4",
"optree",
"fastapi",
"uvicorn[standard]",
......
from enum import Enum
from itertools import combinations
import time
from typing import List
......@@ -1060,7 +1061,9 @@ class PredictState:
self.history_actions = HistoryActions()
self.reset()
self._timestamp = time.time()
def reset(self):
self._probs = None
self._actions = None
......@@ -1073,13 +1076,14 @@ class PredictState:
action = self._actions[idx1]
self.history_actions.update(action, self._turn, self._phase)
self.reset()
def record(self, input: Input, actions, probs):
self._probs = probs
self._actions = actions
self._action_msg = input.action_msg
self._turn = input.global_.turn
self._phase = input.global_.phase
self._timestamp = time.time()
def revert_pad_truncate(probs, n_actions):
if len(probs) < n_actions:
......
......@@ -3,6 +3,7 @@ os.environ.setdefault("JAX_PLATFORMS", "cpu")
from typing import Union, Dict
import time
import threading
import uuid
from contextlib import asynccontextmanager
......@@ -25,13 +26,28 @@ class Settings(BaseSettings):
code_list: str = "code_list.txt"
checkpoint: str = "latest.flax_model"
enable_cors: bool = Field(default=True, description="Enable CORS")
state_expire: int = Field(default=3600, description="Duel state expire time in seconds")
test_duel_id: str = Field(default="9654823a-23fd-4850-bb-6fec241740b0", description="Test duel id")
settings = Settings()
all_models = {}
duel_states: Dict[str, PredictState] = {}
def delete_outdated_states():
while True:
current_time = time.time()
for k, v in list(duel_states.items()):
if k == settings.test_duel_id:
continue
if current_time - v._timestamp > settings.state_expire:
del duel_states[k]
time.sleep(600)
# Start the thread to delete outdated states
thread = threading.Thread(target=delete_outdated_states)
thread.daemon = True
thread.start()
@asynccontextmanager
async def lifespan(app: FastAPI):
......@@ -43,7 +59,7 @@ async def lifespan(app: FastAPI):
print(f"loaded checkpoint from {checkpoint}")
state = new_state()
test_duel_id = "9654823a-23fd-4850-bb-6fec241740b0"
test_duel_id = settings.test_duel_id
duel_states[test_duel_id] = state
yield
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment