import hmac
import os

import pymysql
from dotenv import load_dotenv
from fastapi import FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse
from starlette.middleware.cors import CORSMiddleware
from zipstream import ZipStream, ZIP_DEFLATED


# Populate the process environment from a .env file when one is present.
# This must run before the os.getenv() calls below so their values are seen.
load_dotenv()

# Database connection parameters, passed as keyword arguments to
# pymysql.connect() by get_db().
DB_CONFIG = dict(
    host=os.getenv("DB_HOST", "localhost"),
    port=int(os.getenv("DB_PORT", "3306")),
    user=os.getenv("DB_USER", "root"),
    password=os.getenv("DB_PASS", ""),
    database=os.getenv("DB_NAME", ""),
    charset="utf8mb4",
)

# Directory containing the uploaded files that get zipped. A relative value
# is resolved against the process working directory.
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "./Uploads")

# Shared secret expected in the X-Server-Token header. When unset or empty,
# the token check in the download endpoint is skipped.
SERVER_TOKEN = os.getenv("SERVER_TOKEN")


app = FastAPI()

# CORS is fully open: every origin, HTTP method and request header is accepted.
# NOTE(review): with allow_credentials=True and a "*" origin list, Starlette
# reflects the caller's Origin instead of sending "*", so any site may issue
# credentialed cross-origin requests. Authentication here is the
# X-Server-Token header rather than cookies; confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

def get_db():
    """Open and return a new pymysql connection built from DB_CONFIG.

    Every call creates a fresh connection; the caller must close it.
    """
    connection = pymysql.connect(**DB_CONFIG)
    return connection

def _token_matches(supplied):
    """Return True if *supplied* equals SERVER_TOKEN, compared in constant time.

    A missing header (None) never matches. Only call this when SERVER_TOKEN
    is configured (non-empty).
    """
    if supplied is None:
        return False
    # compare_digest avoids leaking the secret through comparison timing.
    # Encoding both sides to bytes also accepts non-ASCII header values,
    # which compare_digest rejects when given as str.
    return hmac.compare_digest(
        supplied.encode("utf-8"), SERVER_TOKEN.encode("utf-8")
    )


def _archive_name(row, format_fields):
    """Return the in-archive file name for one registration row.

    Each entry of *format_fields* names a column of *row*; surrounding braces
    are stripped first. The column values are joined with "+", and the
    extension of the row's ``filename`` column is appended. A missing column
    or a NULL value contributes an empty string.
    """
    parts = []
    for field in format_fields:
        value = row.get(field.strip("{}"))
        text = "" if value is None else str(value)
        # Column values presumably come from user registrations. Replace path
        # separators so a value cannot create directories inside the zip.
        parts.append(text.replace("/", "_").replace("\\", "_"))
    ext = os.path.splitext(row["filename"] or "")[1]
    return "+".join(parts) + ext


def _upload_path(upload_root, stored_name):
    """Return the absolute path of a stored upload, or None if it is unusable.

    *upload_root* must be an absolute, normalised directory path. None is
    returned when *stored_name* is empty/NULL, when it resolves outside
    *upload_root* (for example ``../x`` or an absolute path), or when it does
    not name a regular file. Directories are excluded because ``add_path``
    would otherwise add them recursively.
    """
    if not stored_name:
        return None
    path = os.path.abspath(os.path.join(upload_root, stored_name))
    try:
        inside = os.path.commonpath([upload_root, path]) == upload_root
    except ValueError:  # e.g. paths on different drives (Windows)
        return None
    if not inside or not os.path.isfile(path):
        return None
    return path


@app.get("/api/cmpt/{cmptid}.zip")
def stream_zip(cmptid: int, x_server_token: str = Header(default=None)):
    """Stream a zip of the files of all active registrations for *cmptid*.

    The naming format is read from ``mc_cmpt_downloadformat``. The files
    belong to ``mc_cmpt_reg`` rows with ``status = 0``. Each file is located
    under UPLOAD_DIR via the row's ``ydk`` column and stored in the archive
    under a name built from the row (see ``_archive_name``). Rows whose file
    is missing, is not a regular file, or lies outside UPLOAD_DIR are skipped.

    Raises:
        HTTPException(403): SERVER_TOKEN is set and the X-Server-Token header
            is missing or wrong.
        HTTPException(404): no (non-empty) naming format exists, or there are
            no matching records.
    """
    # The token check is optional; it applies only when SERVER_TOKEN is set.
    if SERVER_TOKEN and not _token_matches(x_server_token):
        raise HTTPException(status_code=403, detail="Invalid or missing token")

    db = get_db()
    try:
        with db.cursor(pymysql.cursors.DictCursor) as cursor:
            # Fetch the naming format, e.g. "{col_a}+{col_b}".
            cursor.execute(
                "SELECT format FROM mc_cmpt_downloadformat WHERE cmptid = %s",
                (cmptid,),
            )
            format_row = cursor.fetchone()
            if not format_row or not format_row["format"]:
                raise HTTPException(status_code=404, detail="No format found")

            # Fetch the registration records to include.
            cursor.execute(
                "SELECT * FROM mc_cmpt_reg WHERE cmptid = %s AND status = 0",
                (cmptid,),
            )
            records = cursor.fetchall()
    finally:
        # Always release the connection, including on the 404 paths above.
        db.close()

    if not records:
        raise HTTPException(status_code=404, detail="No records found")

    format_fields = format_row["format"].split("+")
    upload_root = os.path.abspath(UPLOAD_DIR)

    z = ZipStream(compress_type=ZIP_DEFLATED, compress_level=9)
    for row in records:
        file_path = _upload_path(upload_root, row["ydk"])
        if file_path is None:
            continue
        # NOTE(review): two rows that produce the same archive name yield
        # duplicate entries; confirm how zipstream handles that.
        z.add_path(file_path, arcname=_archive_name(row, format_fields))

    return StreamingResponse(
        z,
        media_type='application/zip',
        headers={"Content-Disposition": f'attachment; filename="{cmptid}.zip"'}
    )


if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on every interface,
    # port 3000, with proxy (X-Forwarded-*) header handling enabled.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=3000, proxy_headers=True)

