# example-project/qwen/main.py
# Retrieved 2026-03-30 18:13:09 +07:00 (224 lines, 8.0 KiB, Python)
# main.py
# FastAPI backend di-tune untuk Qwen3.5-2B.Q4_K_M.gguf
import json
import os
import sys
import threading
import time
import traceback

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from llama_cpp import Llama
from pydantic import BaseModel
# ---------- Config (adjust if needed; every value can be overridden via env vars) ----------
MODEL_DEFAULT = "models/Qwen3.5-2B.Q4_K_M.gguf"
MODEL_PATH = os.environ.get("LLM_MODEL_PATH", MODEL_DEFAULT)
# Sized for the deployment VM: 13 GB RAM, 9 cores -> safe defaults.
N_THREADS = int(os.environ.get("LLM_N_THREADS", 9))  # use all cores if desired
N_CTX = int(os.environ.get("LLM_N_CTX", 1536))  # NOTE(review): original comment said "2048 is safer than 4096" yet the default here is 1536 — confirm intent
N_BATCH = int(os.environ.get("LLM_N_BATCH", 256))  # larger batches speed up prompt processing
MAX_TOKENS = int(os.environ.get("LLM_MAX_TOKENS", 450))  # cap on tokens generated per completion
TEMPERATURE = float(os.environ.get("LLM_TEMPERATURE", 0.3))  # low temperature for conservative output
TOP_P = float(os.environ.get("LLM_TOP_P", 0.9))
CONCURRENCY = int(os.environ.get("LLM_CONCURRENCY", 1))  # max simultaneous inference calls (enforced by llm_sem below)
USE_MMAP = os.environ.get("LLM_USE_MMAP", "1") == "1"  # forwarded to Llama(use_mmap=...)
USE_MLOCK = os.environ.get("LLM_USE_MLOCK", "0") == "1"  # mlock can fail without privileges
# ---------- FastAPI ----------
app = FastAPI(root_path="/api")  # root_path set for reverse-proxy deployments that mount the app under /api
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec (credentialed requests require an
# explicit origin) — confirm whether credentials are actually needed here.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
# Concurrency control: limits simultaneous model calls to CONCURRENCY.
llm_sem = threading.Semaphore(CONCURRENCY)
# ---------- HealthInput (unchanged schema) ----------
class HealthInput(BaseModel):
    """Vitals payload accepted by POST /analyze (frontend uses the same schema)."""
    weight: float      # body weight; presumably kilograms — TODO confirm units with the frontend
    height: float      # centimetres (divided by 100 before BMI is computed)
    age: int
    spo2: int          # blood-oxygen saturation; presumably a 0-100 percentage — confirm
    heart_rate: int    # presumably beats per minute — confirm
    activity: str      # free-text activity description, passed verbatim to the model
# ---------- Try to load model (with informative error) ----------
# Runs at import time: the process refuses to start without a usable model.
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Loading model from: {MODEL_PATH}", file=sys.stderr)
if not os.path.exists(MODEL_PATH):
    # Fail fast with an actionable message instead of a cryptic llama.cpp error.
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}. Place your .gguf file there or set LLM_MODEL_PATH env var.")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=USE_MMAP,
        use_mlock=USE_MLOCK
    )
    print("Model loaded successfully.", file=sys.stderr)
except Exception as e:
    # If the model fails to load, print the stack and re-raise so the operator knows.
    print("ERROR loading model:", file=sys.stderr)
    traceback.print_exc()
    raise
# ---------- Helpers: JSON extraction & normalization ----------
def extract_json_object(text: str) -> str:
    """Return the first balanced ``{...}`` substring of *text*.

    Falls back to ``"{}"`` when no opening brace exists, and to the
    unterminated tail (from the first ``{`` onward) when the braces
    never balance out.
    """
    first = text.find("{")
    if first < 0:
        return "{}"
    depth = 0
    for pos, ch in enumerate(text[first:], start=first):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[first:pos + 1]
    # Ran off the end without closing every brace.
    return text[first:]
def sanitize_and_normalize(parsed: dict, bmi_value) -> dict:
    """Coerce the model's (possibly misspelled) JSON onto the canonical schema.

    Maps fuzzy/typo'd key names onto the eight canonical keys, fills defaults
    for anything missing or empty, picks a single value out of template strings
    like ``"normal|low|dangerous"``, and forces ``bmi`` to the server-computed
    value so the model can never override it.

    parsed:    dict decoded from the LLM output (may be empty or partial).
    bmi_value: server-computed BMI (or None when it could not be computed).
    Returns a dict containing at least the eight canonical keys; unrecognized
    keys from *parsed* are passed through untouched.
    """
    mapping = {}
    for key, value in parsed.items():
        lk = key.strip().lower()
        # BUGFIX(idiom): the original also tested "summary"/"recomend"/"expla",
        # which are dead conditions — "summ", "recom" and "expl" are prefixes
        # of those strings, so the first substring test already covers them.
        if "summ" in lk:
            mapping["summary"] = value
        elif lk == "bmi":
            mapping["bmi"] = value
        elif "bmi" in lk and "cat" in lk:
            mapping["bmi_category"] = value
        elif "spo2" in lk or "sp02" in lk:  # "sp02" = common zero-for-O typo
            mapping["spo2_status"] = value
        elif "heart" in lk and ("rate" in lk or "hr" in lk):
            mapping["heart_rate_status"] = value
        elif "risk" in lk:
            mapping["health_risk"] = value
        elif "recom" in lk:
            mapping["recommendation"] = value
        elif "expl" in lk:
            mapping["explanation"] = value
        else:
            mapping[key] = value  # unknown key: keep as-is
    # Defaults for anything the model omitted (or left None/empty).
    defaults = {
        "summary": "Unable to produce summary.",
        "bmi": bmi_value,
        "bmi_category": "normal",
        "spo2_status": "normal",
        "heart_rate_status": "normal",
        "health_risk": "low",
        "recommendation": "Maintain balanced diet, regular exercise, and consult a professional if concerned.",
        "explanation": "No detailed explanation provided."
    }
    for dk, dv in defaults.items():
        if dk not in mapping or mapping[dk] in (None, ""):
            mapping[dk] = dv

    def pick(val, allowed):
        """Reduce a template like "normal|low|dangerous" (or a sentence
        containing an allowed word) to a single allowed value; otherwise
        return *val* unchanged."""
        if not isinstance(val, str):
            return val
        if "|" in val:
            parts = [p.strip().lower() for p in val.split("|")]
            for p in parts:
                if p in allowed:
                    return p
            return parts[0]  # nothing recognized: keep the first option
        lv = val.strip().lower()
        for a in allowed:
            if a in lv:
                return a
        return val

    mapping["bmi_category"] = pick(mapping["bmi_category"], ["underweight", "normal", "overweight"])
    mapping["spo2_status"] = pick(mapping["spo2_status"], ["normal", "low", "dangerous"])
    mapping["heart_rate_status"] = pick(mapping["heart_rate_status"], ["normal", "low", "high"])
    mapping["health_risk"] = pick(mapping["health_risk"], ["low", "moderate", "high"])
    mapping["bmi"] = bmi_value  # always trust the server-side computation
    return mapping
# ---------- Prompt ----------
# System message constraining the model to emit exactly one JSON object with
# the eight canonical keys expected by sanitize_and_normalize.
SYSTEM_PROMPT = (
    "You are a conservative medical assistant. Output ONLY a single JSON object with keys: "
    "summary, bmi, bmi_category, spo2_status, heart_rate_status, health_risk, recommendation, explanation."
)
def build_user_prompt(data: HealthInput, bmi):
    """Render the user message: one ``name: value`` line per vital, the
    server-computed BMI, and a closing instruction to return strict JSON."""
    fields = [
        ("weight", data.weight),
        ("height", data.height),
        ("age", data.age),
        ("spo2", data.spo2),
        ("heart_rate", data.heart_rate),
        ("activity", data.activity),
        ("bmi", bmi),
    ]
    body = "\n".join(f"{name}: {value}" for name, value in fields)
    return f"Health data:\n{body}\n\nReturn only a single JSON object with the exact keys."
# ---------- Endpoints ----------
@app.get("/")
def root():
    """Liveness probe: report the loaded model file and key runtime settings."""
    return {
        "status": "ok",
        "model": os.path.basename(MODEL_PATH),
        "n_ctx": N_CTX,
        "n_threads": N_THREADS,
    }
@app.post("/analyze")
async def analyze(request: Request):
    """Analyze the submitted vitals with the local LLM.

    Returns a JSON 400 for a malformed body or invalid payload, a JSON 500
    when inference fails, otherwise a 200 with the normalized eight-key
    analysis object (see sanitize_and_normalize).
    """
    # Parse the body ourselves so a bad body yields a JSON 400 instead of
    # FastAPI's default validation response.
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})
    # Validate minimal fields (frontend uses the same schema).
    try:
        data = HealthInput(**payload)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "invalid payload", "details": str(e)})
    # Compute BMI server-side; the model's own value is never trusted.
    try:
        height_m = data.height / 100.0  # height arrives in centimetres
        bmi_value = round(data.weight / (height_m ** 2), 1)
    except Exception:
        bmi_value = None  # e.g. height == 0 -> ZeroDivisionError
    user_prompt = build_user_prompt(data, bmi_value)
    # BUGFIX: the blocking llama.cpp call (and the blocking semaphore wait)
    # previously ran directly inside this async handler, freezing the event
    # loop for every other request for the whole duration of inference.
    # Run the blocking work on Starlette's threadpool instead.
    try:
        response = await run_in_threadpool(_run_model, user_prompt)
    except Exception as e:
        # Log the traceback and return a JSON error (avoid an HTML 500 page).
        tb = traceback.format_exc()
        print("Model inference error:", file=sys.stderr)
        print(tb, file=sys.stderr)
        return JSONResponse(status_code=500, content={"error": "model_inference_failed", "details": str(e)})
    # Extract the first choice's message text; fall back to the raw repr.
    try:
        raw_text = response.get("choices", [{}])[0].get("message", {}).get("content", "")
    except Exception:
        raw_text = str(response)
    # Best-effort JSON parse; sanitize_and_normalize fills in any defaults.
    try:
        parsed = json.loads(extract_json_object(raw_text))
    except Exception:
        parsed = {}
    out = sanitize_and_normalize(parsed, bmi_value)
    return JSONResponse(status_code=200, content=out)


def _run_model(user_prompt: str):
    """Blocking helper: run one chat completion, serialized by llm_sem.

    Must be called from a worker thread (see run_in_threadpool above);
    raises whatever llm.create_chat_completion raises.
    """
    with llm_sem:
        kwargs = dict(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt}
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            top_p=TOP_P,
        )
        # Prefer strict JSON mode; older llama-cpp-python builds reject the
        # response_format kwarg with a TypeError, so fall back without it.
        try:
            return llm.create_chat_completion(response_format={"type": "json_object"}, **kwargs)
        except TypeError:
            return llm.create_chat_completion(**kwargs)