# Source listing metadata (from the file viewer this was copied from): 311 lines, 10 KiB, Python.
# main.py (tuned for faster Phi-3.1-mini / Phi-2-mini style models on CPU)

import asyncio
import json
import math
import multiprocessing
import os
import sys
import threading
import time
import traceback
import types
from functools import lru_cache

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# llama_cpp import
from llama_cpp import Llama
# -------------------------
# Configuration (env overrides allowed)
# -------------------------
DEFAULT_MODEL = os.environ.get("LLM_MODEL_PATH", "models/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf")

# sensible defaults for your 13GB / 9-core VM
# Prefer the CPUs this process may actually run on (respects taskset/cpuset
# pinning) where the platform exposes it; fall back to the machine-wide count.
if hasattr(os, "sched_getaffinity"):
    CPU_CORES = len(os.sched_getaffinity(0))
else:
    CPU_CORES = multiprocessing.cpu_count()

# Leave one core for the web server / OS, but never drop below one thread.
DEFAULT_THREADS = max(1, CPU_CORES - 1)

MODEL_PATH = DEFAULT_MODEL
# Clamped to >= 1 so a bad env value cannot request zero/negative threads.
N_THREADS = max(1, int(os.environ.get("LLM_N_THREADS", DEFAULT_THREADS)))
N_CTX = int(os.environ.get("LLM_N_CTX", 1024))  # lowered from 1536/4096
N_BATCH = int(os.environ.get("LLM_N_BATCH", 512))  # larger batch for CPU speed
MAX_TOKENS = int(os.environ.get("LLM_MAX_TOKENS", 150))
TEMPERATURE = float(os.environ.get("LLM_TEMPERATURE", 0.2))
TOP_P = float(os.environ.get("LLM_TOP_P", 0.9))
# Clamped to >= 1: llm_sem is built from this value, and Semaphore(0) would
# block every request forever while a negative value raises at import time.
CONCURRENCY = max(1, int(os.environ.get("LLM_CONCURRENCY", 1)))

# memory mapping options (flags are enabled only by the exact string "1")
USE_MMAP = os.environ.get("LLM_USE_MMAP", "1") == "1"
USE_MLOCK = os.environ.get("LLM_USE_MLOCK", "0") == "1"  # usually false on non-root

# optional in-memory cache to speed repeated identical requests
ENABLE_CACHE = os.environ.get("LLM_ENABLE_CACHE", "0") == "1"
CACHE_MAXSIZE = int(os.environ.get("LLM_CACHE_MAXSIZE", 256))
# -------------------------
# FastAPI
# -------------------------
# CORS policy applied to every route.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the API publicly.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# All routes are mounted behind a reverse-proxy prefix of "/api".
app = FastAPI(root_path="/api")
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# -------------------------
# Model load (with safe guards)
# -------------------------
print(f"Loading model from: {MODEL_PATH}", file=sys.stderr)

# isfile (not exists): a directory path should fail here with a clear message
# rather than being handed to Llama().
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}. Place the gguf file there or set LLM_MODEL_PATH env var.")

try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=USE_MMAP,
        use_mlock=USE_MLOCK,
    )
except Exception:
    # Print the traceback to stderr and re-raise so the process does not
    # silently continue without a model.
    print("ERROR loading model:", file=sys.stderr)
    traceback.print_exc()
    raise
else:
    print("Model loaded successfully.", file=sys.stderr)
# Gate on how many requests may be inside the model call at once.
# CONCURRENCY defaults to 1 (LLM_CONCURRENCY) to avoid running out of memory.
llm_sem = threading.Semaphore(value=CONCURRENCY)
# -------------------------
# Input schema (unchanged)
# -------------------------
class HealthInput(BaseModel):
    """Health metrics submitted by the client for analysis.

    NOTE(review): in this file the /analyze endpoint parses the raw JSON body
    itself and does not validate it against this model; the model is only
    referenced as a type hint on build_user_prompt().
    """

    weight: float  # presumably kilograms -- TODO confirm with the frontend
    height: float  # BMI code divides this by 100, so it looks like centimetres
    age: int
    spo2: int  # blood-oxygen saturation; presumably a percentage -- confirm
    heart_rate: int  # presumably beats per minute -- confirm
    activity: str  # free-text activity description, copied into the prompt
# -------------------------
# Utility: extract JSON object from model text
# -------------------------
def extract_json_object(text: str) -> str:
    """Return the first brace-balanced ``{...}`` span found in *text*.

    Braces inside JSON string literals (including escaped quotes) are
    ignored, so a value such as ``"use {x}"`` or ``"}"`` does not cut the
    object short.

    Args:
        text: Raw model output; may be empty or None.

    Returns:
        ``"{}"`` when *text* is empty/None or contains no ``{``; the balanced
        span when one is found; otherwise everything from the first ``{`` to
        the end (an unterminated object, left for the caller's JSON parser to
        reject).
    """
    if not text:
        return "{}"
    start = text.find("{")
    if start == -1:
        return "{}"

    depth = 0
    in_string = False
    escaped = False
    for i, ch in enumerate(text[start:], start):
        if in_string:
            # Inside a string literal only quotes and escapes matter.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return text[start:]
# -------------------------
# Normalization + typo fixes (keeps output schema intact)
# -------------------------
# Fallback values used when the model omits a key or leaves it blank.
_NORMALIZED_DEFAULTS = {
    "summary": "Unable to produce summary.",
    "bmi": None,
    "bmi_category": "normal",
    "spo2_status": "normal",
    "heart_rate_status": "normal",
    "health_risk": "low",
    "recommendation": "Maintain healthy lifestyle: balanced diet, regular exercise, and follow-up with healthcare if needed.",
    "explanation": "No detailed explanation provided.",
}

# Closed vocabularies for the categorical fields, in priority order.
_ALLOWED_VALUES = {
    "bmi_category": ("underweight", "normal", "overweight"),
    "spo2_status": ("normal", "low", "dangerous"),
    "heart_rate_status": ("normal", "low", "high"),
    "health_risk": ("low", "moderate", "high"),
}


def _canonical_key(key):
    """Map a possibly misspelled model key to its canonical name, or None."""
    lk = str(key).strip().lower()
    if lk.startswith("sum"):  # summary, summery, summ...
        return "summary"
    if lk.startswith("bmi") and "category" in lk:
        return "bmi_category"
    if lk == "bmi":
        return "bmi"
    if lk.startswith(("spo2", "sp02")):
        return "spo2_status"
    if "heart" in lk and ("rate" in lk or "hr" in lk):
        return "heart_rate_status"
    if "risk" in lk:
        return "health_risk"
    # "comend"/"commend" also catch typos such as "reccomendation".
    if lk.startswith("recom") or "comend" in lk or "commend" in lk:
        return "recommendation"
    if lk.startswith("expl") or "expla" in lk:
        return "explanation"
    return None


def _pick_most_likely(val, allowed):
    """Snap a free-text status onto one of *allowed*; return *val* if none fits."""
    if not isinstance(val, str):
        return val
    if "|" in val:
        # The model echoed an option list such as "low|moderate|high":
        # take the first listed option that is allowed.
        parts = [p.strip().lower() for p in val.split("|")]
        for p in parts:
            if p in allowed:
                return p
        return parts[0]
    # Whole-word match so that e.g. "abnormal" is NOT mistaken for "normal".
    words = "".join(ch if ch.isalnum() else " " for ch in val.lower()).split()
    for a in allowed:
        if a in words:
            return a
    return val


def normalize_keys_and_fix_typos(result: dict) -> dict:
    """Canonicalize the model's JSON into the fixed response schema.

    Misspelled keys are mapped to canonical names; unknown keys are kept
    unchanged. Missing, None, or blank values are replaced by defaults, and
    the categorical fields are snapped onto their allowed vocabularies.

    Args:
        result: Parsed model output (a non-dict is treated as empty).

    Returns:
        A new dict that always contains every key in the response schema.
    """
    if not isinstance(result, dict):
        result = {}

    canonical = {}
    for k, v in result.items():
        canonical[_canonical_key(k) or k] = v

    for key, default in _NORMALIZED_DEFAULTS.items():
        value = canonical.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            canonical[key] = default

    for field, allowed in _ALLOWED_VALUES.items():
        canonical[field] = _pick_most_likely(canonical[field], allowed)

    return canonical
# -------------------------
# Prompt builder (short and strict)
# -------------------------
# The allowed values listed here mirror the vocabularies that
# normalize_keys_and_fix_typos() snaps the model output onto. The model is told
# the categories up front instead of having to guess them.
SYSTEM_PROMPT = (
    "You are a concise medical assistant. Output ONLY a single JSON object with keys: "
    "summary, bmi, bmi_category, spo2_status, heart_rate_status, health_risk, recommendation, explanation. "
    "Allowed values: bmi_category is one of underweight, normal, overweight; "
    "spo2_status is one of normal, low, dangerous; "
    "heart_rate_status is one of normal, low, high; "
    "health_risk is one of low, moderate, high."
)
def build_user_prompt(data: HealthInput, bmi: float) -> str:
    """Render the user turn: one ``name: value`` line per metric, then the instruction.

    *data* only needs attribute access for weight, height, age, spo2,
    heart_rate and activity. *bmi* is printed as-is (it may be None).
    """
    fields = (
        ("weight", data.weight),
        ("height", data.height),
        ("age", data.age),
        ("spo2", data.spo2),
        ("heart_rate", data.heart_rate),
        ("activity", data.activity),
        ("bmi", bmi),
    )
    metric_lines = "".join(f"{name}: {value}\n" for name, value in fields)
    instruction = "Return only a single JSON object using the exact keys."
    return "Health data:\n" + metric_lines + "\n" + instruction
# -------------------------
# Optional simple LRU cache for identical inputs (disabled by default)
# -------------------------
if ENABLE_CACHE:
    @lru_cache(maxsize=CACHE_MAXSIZE)
    def cached_infer_cache_key(payload_str: str):
        """Memoized inference keyed by the request payload serialized as JSON.

        lru_cache needs a hashable argument, so callers pass the payload as a
        JSON string; it is decoded here and handed to the uncached
        _infer_from_parsed(). The same result dict object is returned on
        every cache hit.
        """
        return _infer_from_parsed(json.loads(payload_str))
# -------------------------
# Core inference (non-cached)
# -------------------------
# A single Llama instance is shared by every request. Calls into it are
# serialized because it presumably holds one context/KV cache and is not safe
# to drive from several threads at once -- confirm before relaxing.
_LLM_CALL_LOCK = threading.Lock()


def _compute_bmi(payload: dict):
    """Return BMI rounded to one decimal, or None when it cannot be computed.

    ``height`` is divided by 100 to obtain metres, so it is expected in
    centimetres; ``weight`` is presumably kilograms (TODO confirm). Missing,
    zero, negative, non-numeric or non-finite inputs yield None.
    """
    weight = payload.get("weight")
    height = payload.get("height")
    if not weight or not height:
        return None
    try:
        if weight <= 0 or height <= 0:
            return None
        height_m = height / 100.0
        bmi = round(weight / (height_m ** 2), 1)
    except (TypeError, ValueError, ZeroDivisionError, OverflowError):
        return None
    # NaN/inf are not valid JSON values, so never pass them on to the response.
    return bmi if math.isfinite(bmi) else None


def _run_chat_completion(user_prompt: str):
    """Run one chat completion under llm_sem and return the raw response.

    JSON mode is requested via ``response_format``. If that call raises
    TypeError (presumably an llama_cpp build without that keyword), the
    request is retried without it.
    """
    kwargs = dict(
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        top_p=TOP_P,
    )
    with llm_sem, _LLM_CALL_LOCK:
        start_t = time.perf_counter()
        try:
            response = llm.create_chat_completion(**kwargs, response_format={"type": "json_object"})
        except TypeError:
            response = llm.create_chat_completion(**kwargs)
        duration = time.perf_counter() - start_t
        print(f"[inference] elapsed={duration:.2f}s tokens_limit={MAX_TOKENS} threads={N_THREADS} batch={N_BATCH}", file=sys.stderr)
    return response


def _response_text(response) -> str:
    """Return the assistant message text from a chat-completion response ('' if absent)."""
    try:
        content = response.get("choices", [{}])[0].get("message", {}).get("content", "")
    except (AttributeError, IndexError, KeyError, TypeError):
        # Unexpected response shape: fall back to its string form so any JSON
        # embedded in it can still be found.
        return str(response)
    # content may be None; downstream parsing needs a string.
    return content if isinstance(content, str) else ""


def _parse_model_json(raw_out: str) -> dict:
    """Parse the first JSON object in *raw_out*; {} if none, malformed, or not an object."""
    json_text = extract_json_object(raw_out)
    try:
        parsed = json.loads(json_text)
    except (ValueError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _infer_from_parsed(payload: dict) -> dict:
    """Run the full analysis pipeline for one request payload.

    Computes BMI locally, prompts the model, extracts and normalizes its JSON
    answer, and overrides ``bmi`` with the locally computed value.

    Args:
        payload: Decoded request body (the endpoint checks required keys).

    Returns:
        The normalized response dict.
    """
    bmi_value = _compute_bmi(payload)

    # build_user_prompt reads attributes, so expose the payload fields that way.
    prompt_data = types.SimpleNamespace(
        weight=payload.get("weight"),
        height=payload.get("height"),
        age=payload.get("age"),
        spo2=payload.get("spo2"),
        heart_rate=payload.get("heart_rate"),
        activity=payload.get("activity"),
    )
    response = _run_chat_completion(build_user_prompt(prompt_data, bmi_value))

    normalized = normalize_keys_and_fix_typos(_parse_model_json(_response_text(response)))
    # Always report the locally computed BMI, whatever the model returned.
    normalized["bmi"] = bmi_value
    return normalized
# -------------------------
# Endpoint: /analyze (keeps existing JSON schema)
# -------------------------
@app.post("/analyze")
async def analyze(request: Request):
    """Validate the request body and return the model's normalized health analysis.

    Responses:
        400 ``{"error": "invalid JSON body"}`` if the body is not a JSON object;
        400 ``{"error": "missing_field", "field": k}`` if a required key is absent;
        500 ``{"error": "model_inference_failed", "details": ...}`` on inference errors;
        200 with the normalized result otherwise.
    """
    # HTTP boundary: any failure to decode the body is a client error.
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})

    # An array/string/number is valid JSON but cannot carry the named fields.
    if not isinstance(payload, dict):
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})

    # validate minimal fields (fast check)
    required = ["weight", "height", "age", "spo2", "heart_rate", "activity"]
    for k in required:
        if k not in payload:
            return JSONResponse(status_code=400, content={"error": "missing_field", "field": k})

    try:
        if ENABLE_CACHE:
            infer, arg = cached_infer_cache_key, json.dumps(payload, sort_keys=True)
        else:
            infer, arg = _infer_from_parsed, payload
        # The inference path is synchronous (llm.create_chat_completion is not
        # awaited). Running it in the default executor keeps the event loop
        # free to serve other requests while llm_sem queues model calls.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, infer, arg)
        # Render inside the try so a serialization failure still yields JSON.
        return JSONResponse(status_code=200, content=result)
    except Exception as e:
        # log and return JSON error so frontend doesn't get HTML
        tb = traceback.format_exc()
        print("Inference error:", file=sys.stderr)
        print(tb, file=sys.stderr)
        return JSONResponse(status_code=500, content={"error": "model_inference_failed", "details": str(e)})
# -------------------------
# Root health check
# -------------------------
@app.get("/")
def root():
    """Liveness probe that also reports the loaded model file and key runtime settings."""
    info = {"status": "ok"}
    info["model"] = os.path.basename(MODEL_PATH)
    info["n_ctx"] = N_CTX
    info["n_threads"] = N_THREADS
    return info