224 lines
8.0 KiB
Python
224 lines
8.0 KiB
Python
# main.py
# FastAPI backend tuned for Qwen3.5-2B.Q4_K_M.gguf
|
|
import json
import math
import os
import sys
import threading
import time
import traceback

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from llama_cpp import Llama
from pydantic import BaseModel
|
|
|
|
# ---------- Config (adjust as needed) ----------
# Every setting can be overridden with an LLM_* environment variable; the
# second argument of os.environ.get() is the default used when it is unset.
MODEL_DEFAULT = "models/Qwen3.5-2B.Q4_K_M.gguf"
MODEL_PATH = os.environ.get("LLM_MODEL_PATH", MODEL_DEFAULT)

# VM: 13GB RAM, 9 cores -> safe defaults
N_THREADS = int(os.environ.get("LLM_N_THREADS", 9))  # use all cores if desired
# Context window. The original note rated 2048 as safer than 4096; the
# default used here is lower still.
N_CTX = int(os.environ.get("LLM_N_CTX", 1536))
N_BATCH = int(os.environ.get("LLM_N_BATCH", 256))  # a larger batch speeds things up
MAX_TOKENS = int(os.environ.get("LLM_MAX_TOKENS", 450))
TEMPERATURE = float(os.environ.get("LLM_TEMPERATURE", 0.3))
TOP_P = float(os.environ.get("LLM_TOP_P", 0.9))
# Clamped to at least 1: this value sizes a threading.Semaphore, and a
# semaphore created with 0 permits would block every request forever
# (a negative value makes threading.Semaphore raise ValueError).
CONCURRENCY = max(1, int(os.environ.get("LLM_CONCURRENCY", 1)))
# Boolean flags: only the exact string "1" enables them.
USE_MMAP = os.environ.get("LLM_USE_MMAP", "1") == "1"
USE_MLOCK = os.environ.get("LLM_USE_MLOCK", "0") == "1"  # mlock can fail without privileges
|
|
|
|
# ---------- FastAPI ----------
|
|
# Application instance; root_path="/api" is handed straight to FastAPI.
app = FastAPI(root_path="/api")

# Wide-open CORS: any origin, method and header, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# very permissive -- confirm this is intended before exposing the service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Concurrency control: a semaphore with CONCURRENCY permits that the
# /analyze handler holds while it calls the model.
llm_sem = threading.Semaphore(CONCURRENCY)
|
|
|
|
# ---------- HealthInput (unchanged) ----------
class HealthInput(BaseModel):
    """Request payload validated by the /analyze endpoint."""

    weight: float  # presumably kilograms (fed into the BMI formula) -- TODO confirm
    height: float  # centimetres: the /analyze handler divides by 100 to get metres
    age: int
    spo2: int  # presumably blood-oxygen saturation in percent -- TODO confirm
    heart_rate: int  # presumably beats per minute -- TODO confirm
    activity: str  # free text; inserted verbatim into the model prompt
|
|
|
|
# ---------- Load the model at import time (with an informative error) ----------
# Everything here runs as a module-level side effect: the process refuses to
# start when the model file is missing or llama.cpp cannot load it.
print(
    f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Loading model from: {MODEL_PATH}",
    file=sys.stderr,
)

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"Model file not found: {MODEL_PATH}. "
        "Place your .gguf file there or set LLM_MODEL_PATH env var."
    )

try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=USE_MMAP,
        use_mlock=USE_MLOCK,
    )
except Exception:
    # If loading fails, dump the stack trace and re-raise so the operator
    # sees the failure instead of a half-started service.
    print("ERROR loading model:", file=sys.stderr)
    traceback.print_exc()
    raise
else:
    print("Model loaded successfully.", file=sys.stderr)
|
|
|
|
# ---------- Helpers: JSON extraction & normalization ----------
|
|
def extract_json_object(text: str) -> str:
    """Return the first brace-balanced ``{...}`` span found in *text*.

    The scan starts at the first ``{`` and tracks nesting depth. Braces that
    appear inside double-quoted JSON strings (including after backslash
    escapes) are ignored, so a value such as ``"}"`` no longer truncates the
    extracted object.

    Args:
        text: Raw model output that may wrap a JSON object in other text.

    Returns:
        The balanced object as a string; ``"{}"`` when *text* is not a string
        or contains no ``{``; the remainder of *text* from the first ``{``
        when the braces never balance (the caller's JSON parse then fails
        and falls back).
    """
    if not isinstance(text, str):
        return "{}"

    start = text.find("{")
    if start == -1:
        return "{}"

    depth = 0
    in_string = False
    escaped = False
    for idx, ch in enumerate(text[start:], start):
        if in_string:
            if escaped:
                # The previous character was a backslash: this one is literal.
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue

        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:idx + 1]

    # Unbalanced input: hand back everything from the first brace onwards.
    return text[start:]
|
|
|
|
def _canonical_key(raw_key):
    """Map a (possibly misspelled) model output key to its canonical name."""
    lk = str(raw_key).strip().lower()
    if "summ" in lk:
        return "summary"
    if lk == "bmi":
        return "bmi"
    if "bmi" in lk and "cat" in lk:
        return "bmi_category"
    if "spo2" in lk or "sp02" in lk:
        return "spo2_status"
    if "heart" in lk and ("rate" in lk or "hr" in lk):
        return "heart_rate_status"
    if "risk" in lk:
        return "health_risk"
    if "recom" in lk:
        return "recommendation"
    if "expl" in lk:
        return "explanation"
    return raw_key


def _pick_allowed(val, allowed):
    """Reduce a free-form status string to one of the *allowed* labels.

    Non-string values are returned untouched. A template echo such as
    ``"normal|low|dangerous"`` resolves to its first allowed part. Otherwise
    the value is split into whole words and the first allowed label that
    appears as a word wins; matching whole words (not substrings) keeps
    "abnormal" from being reported as "normal". Unmatched values are
    returned unchanged.
    """
    if not isinstance(val, str):
        return val
    if "|" in val:
        parts = [p.strip().lower() for p in val.split("|")]
        for part in parts:
            if part in allowed:
                return part
        return parts[0]
    words = set(
        "".join(ch if ch.isalnum() else " " for ch in val.lower()).split()
    )
    for option in allowed:
        if option in words:
            return option
    return val


def _default_bmi_category(bmi_value):
    """Derive the fallback BMI category from the computed BMI.

    Uses the standard adult cut-offs (below 18.5 underweight, below 25
    normal, otherwise overweight). Falls back to "normal" when no usable
    number is available.
    """
    if isinstance(bmi_value, bool) or not isinstance(bmi_value, (int, float)):
        return "normal"
    if bmi_value != bmi_value:  # NaN compares unequal to itself
        return "normal"
    if bmi_value < 18.5:
        return "underweight"
    if bmi_value < 25:
        return "normal"
    return "overweight"


def sanitize_and_normalize(parsed: dict, bmi_value) -> dict:
    """Normalize the model's JSON output into the response schema.

    Args:
        parsed: Dict decoded from the model output. Anything that is not a
            dict is treated as empty.
        bmi_value: BMI computed by the server (may be None). It always
            overrides whatever the model reported under "bmi".

    Returns:
        A dict that always contains summary, bmi, bmi_category, spo2_status,
        heart_rate_status, health_risk, recommendation and explanation.
        Unrecognized keys from the model are passed through unchanged.
    """
    if not isinstance(parsed, dict):
        parsed = {}

    # Map common typos to canonical keys.
    mapping = {}
    exact_keys = set()
    for key, value in parsed.items():
        canonical = _canonical_key(key)
        if canonical == key:
            exact_keys.add(canonical)
        elif canonical in exact_keys:
            # A fuzzy match must not clobber an exactly-named key.
            continue
        mapping[canonical] = value

    # Defaults for anything the model left out or left empty.
    # NOTE(review): "normal"/"low" are optimistic fallbacks for the SpO2,
    # heart-rate and risk fields -- confirm that is acceptable to callers.
    defaults = {
        "summary": "Unable to produce summary.",
        "bmi": bmi_value,
        "bmi_category": _default_bmi_category(bmi_value),
        "spo2_status": "normal",
        "heart_rate_status": "normal",
        "health_risk": "low",
        "recommendation": "Maintain balanced diet, regular exercise, and consult a professional if concerned.",
        "explanation": "No detailed explanation provided.",
    }
    for name, fallback in defaults.items():
        if name not in mapping or mapping[name] in (None, ""):
            mapping[name] = fallback

    # Clean templates like "normal|low|dangerous" down to a single label.
    allowed_labels = {
        "bmi_category": ["underweight", "normal", "overweight"],
        "spo2_status": ["normal", "low", "dangerous"],
        "heart_rate_status": ["normal", "low", "high"],
        "health_risk": ["low", "moderate", "high"],
    }
    for name, allowed in allowed_labels.items():
        mapping[name] = _pick_allowed(mapping[name], allowed)

    # The server-side BMI is authoritative.
    mapping["bmi"] = bmi_value
    return mapping
|
|
|
|
# ---------- Prompt ----------
|
|
# System message: sets the assistant persona and names the JSON keys that
# the reply must contain.
SYSTEM_PROMPT = " ".join(
    (
        "You are a conservative medical assistant. Output ONLY a single JSON object with keys:",
        "summary, bmi, bmi_category, spo2_status, heart_rate_status, health_risk, recommendation, explanation.",
    )
)
|
|
def build_user_prompt(data: HealthInput, bmi):
    """Render the user message sent to the model.

    Lists each field of *data* plus the precomputed *bmi* as ``name: value``
    lines under a "Health data:" heading, followed by a blank line and the
    instruction to answer with a single JSON object.
    """
    measurements = (
        ("weight", data.weight),
        ("height", data.height),
        ("age", data.age),
        ("spo2", data.spo2),
        ("heart_rate", data.heart_rate),
        ("activity", data.activity),
        ("bmi", bmi),
    )
    lines = ["Health data:"]
    lines.extend(f"{label}: {value}" for label, value in measurements)
    body = "\n".join(lines)
    return f"{body}\n\nReturn only a single JSON object with the exact keys."
|
|
|
|
# ---------- Endpoints ----------
|
|
@app.get("/")
def root():
    """Report service status together with the model file name and key settings."""
    info = {"status": "ok"}
    info["model"] = os.path.basename(MODEL_PATH)
    info["n_ctx"] = N_CTX
    info["n_threads"] = N_THREADS
    return info
|
|
|
|
def _compute_bmi(weight, height_cm):
    """Return BMI rounded to one decimal, or None when it cannot be computed.

    None is returned for non-positive weight/height and for NaN/infinite
    results, which JSONResponse would otherwise refuse to serialize.
    """
    try:
        if weight <= 0 or height_cm <= 0:
            return None
        height_m = height_cm / 100.0
        bmi = round(weight / (height_m ** 2), 1)
    except (ArithmeticError, TypeError, ValueError):
        return None
    return bmi if math.isfinite(bmi) else None


def _run_chat_completion(user_prompt):
    """Run one blocking chat completion while holding the concurrency semaphore.

    Meant to be executed in a worker thread. Tries JSON-constrained output
    first and retries without ``response_format`` when the installed
    llama-cpp-python rejects that argument with a TypeError.
    """
    kwargs = dict(
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        top_p=TOP_P,
    )
    # NOTE(review): all calls share the single `llm` object; presumably it
    # is not safe for concurrent use, so keep LLM_CONCURRENCY at 1 unless
    # that is verified.
    with llm_sem:
        try:
            return llm.create_chat_completion(
                response_format={"type": "json_object"}, **kwargs
            )
        except TypeError:
            return llm.create_chat_completion(**kwargs)


def _response_text(response):
    """Extract the assistant message text from a chat-completion response.

    Falls back to ``str(response)`` when the structure is unexpected and to
    an empty string when the content is missing or not a string (e.g. None).
    """
    try:
        content = response.get("choices", [{}])[0].get("message", {}).get("content", "")
    except (AttributeError, IndexError, KeyError, TypeError):
        return str(response)
    return content if isinstance(content, str) else ""


@app.post("/analyze")
async def analyze(request: Request):
    """Analyze a health payload with the local model and return normalized JSON.

    Responses:
        400 {"error": "invalid JSON body"}: the body is not valid JSON.
        400 {"error": "invalid payload", ...}: the body fails HealthInput validation.
        500 {"error": "model_inference_failed", ...}: the model call raised.
        200: the normalized result dict from sanitize_and_normalize.
    """
    # Parse the request body safely.
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})

    # Validate the fields (the frontend uses the same schema). A non-mapping
    # payload makes ``**payload`` raise TypeError, which is also reported as 400.
    try:
        data = HealthInput(**payload)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "invalid payload", "details": str(e)})

    bmi_value = _compute_bmi(data.weight, data.height)
    user_prompt = build_user_prompt(data, bmi_value)

    # Inference is blocking and CPU-bound: run it in a worker thread so the
    # event loop stays free to serve other requests (e.g. the "/" status
    # check) while the model is generating.
    try:
        response = await run_in_threadpool(_run_chat_completion, user_prompt)
    except Exception as e:
        # Log the exception and return a JSON error (avoid an HTML error page).
        print("Model inference error:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return JSONResponse(status_code=500, content={"error": "model_inference_failed", "details": str(e)})

    raw_text = _response_text(response)

    # Attempt to parse; fall back to an empty dict so defaults are applied.
    json_text = extract_json_object(raw_text)
    try:
        parsed = json.loads(json_text)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass.
        print("Model output was not valid JSON; using default fields.", file=sys.stderr)
        parsed = {}
    if not isinstance(parsed, dict):
        parsed = {}

    out = sanitize_and_normalize(parsed, bmi_value)
    return JSONResponse(status_code=200, content=out)