first commit
This commit is contained in:
commit
c73a2a067b
311
main.py
Normal file
311
main.py
Normal file
|
|
@ -0,0 +1,311 @@
|
||||||
|
# main.py (tuned for faster Phi-3.1-mini / Phi-2-mini style models on CPU)
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import multiprocessing
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
# llama_cpp import
|
||||||
|
from llama_cpp import Llama
|
||||||
|
|
||||||
|
# -------------------------
# Configuration (env overrides allowed)
# -------------------------
# Path to the GGUF model file; LLM_MODEL_PATH overrides the default.
DEFAULT_MODEL = os.environ.get("LLM_MODEL_PATH", "models/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf")

# sensible defaults for your 13GB / 9-core VM
CPU_CORES = multiprocessing.cpu_count()
# Leave one core free for the web server / OS.
DEFAULT_THREADS = max(1, CPU_CORES - 1)

MODEL_PATH = DEFAULT_MODEL
# Every knob below can be overridden via the LLM_* environment variables.
N_THREADS = int(os.environ.get("LLM_N_THREADS", DEFAULT_THREADS))
N_CTX = int(os.environ.get("LLM_N_CTX", 1024))  # context window; lowered from 1536/4096 to save RAM
N_BATCH = int(os.environ.get("LLM_N_BATCH", 512))  # larger batch for CPU prompt-processing speed
MAX_TOKENS = int(os.environ.get("LLM_MAX_TOKENS", 150))  # cap on generated tokens per request
TEMPERATURE = float(os.environ.get("LLM_TEMPERATURE", 0.2))  # low temperature -> deterministic-ish JSON
TOP_P = float(os.environ.get("LLM_TOP_P", 0.9))
CONCURRENCY = int(os.environ.get("LLM_CONCURRENCY", 1))  # parallel inferences allowed (semaphore size)

# memory mapping options
USE_MMAP = os.environ.get("LLM_USE_MMAP", "1") == "1"
USE_MLOCK = os.environ.get("LLM_USE_MLOCK", "0") == "1"  # usually false on non-root (mlock needs privileges)

# optional in-memory cache to speed repeated identical requests
ENABLE_CACHE = os.environ.get("LLM_ENABLE_CACHE", "0") == "1"
CACHE_MAXSIZE = int(os.environ.get("LLM_CACHE_MAXSIZE", 256))  # lru_cache entry limit when enabled
|
||||||
|
|
||||||
|
# -------------------------
# FastAPI
# -------------------------
# root_path="/api" — presumably the app is served behind a reverse proxy
# under an /api prefix; confirm against the deployment config.
app = FastAPI(root_path="/api")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# maximally permissive CORS policy — confirm this is acceptable for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||||
|
|
||||||
|
# -------------------------
# Model load (with safe guards)
# -------------------------
# Fail fast at import time: a missing model file should prevent the server
# from starting instead of erroring on the first request.
print(f"Loading model from: {MODEL_PATH}", file=sys.stderr)
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}. Place the gguf file there or set LLM_MODEL_PATH env var.")

try:
    # Load the GGUF model once at startup; the single Llama instance is shared
    # by all requests (access is serialized by llm_sem below).
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=USE_MMAP,
        use_mlock=USE_MLOCK,
    )
    print("Model loaded successfully.", file=sys.stderr)
except Exception as e:
    # print traceback to stderr and raise so process doesn't silently continue
    import traceback
    print("ERROR loading model:", file=sys.stderr)
    traceback.print_exc()
    raise

# concurrency control (1 by default to avoid OOM)
llm_sem = threading.Semaphore(CONCURRENCY)
|
||||||
|
|
||||||
|
# -------------------------
# Input schema (unchanged)
# -------------------------
class HealthInput(BaseModel):
    """Request payload for /analyze.

    Units: height is presumably in centimeters (BMI code divides by 100)
    and weight in kilograms (standard BMI formula) — confirm with callers.
    """
    weight: float       # body weight
    height: float       # body height (cm, per BMI computation)
    age: int            # age in years
    spo2: int           # blood oxygen saturation reading
    heart_rate: int     # heart rate reading
    activity: str       # free-text activity description
|
||||||
|
|
||||||
|
# -------------------------
# Utility: extract JSON object from model text
# -------------------------
def extract_json_object(text: str) -> str:
    """Return the first balanced-brace JSON object embedded in *text*.

    Scans from the first '{' and tracks brace depth. Returns "{}" when no
    '{' is present, and the unbalanced tail from the first '{' onward when
    the braces never close (json.loads will then fail and the caller falls
    back to defaults). Braces inside string literals are not special-cased.
    """
    opening = text.find("{")
    if opening < 0:
        return "{}"
    depth = 0
    for pos, ch in enumerate(text[opening:], opening):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[opening:pos + 1]
    # Unbalanced braces: hand back everything from the first '{'.
    return text[opening:]
|
||||||
|
|
||||||
|
# -------------------------
# Normalization + typo fixes (keeps output schema intact)
# -------------------------
def normalize_keys_and_fix_typos(result: dict) -> dict:
    """Coerce a loosely-keyed model response onto the canonical schema.

    Three passes: (1) map each emitted key (including common typos/case
    variants) to its canonical name, (2) backfill missing or blank fields
    with safe defaults, (3) collapse template-style values such as
    "normal|low|dangerous" to a single allowed option.
    """

    def _canonical_name(lowered: str):
        # Ordered rules — first match wins, mirroring the original chain.
        if lowered.startswith("sum") or lowered.startswith("summ") or lowered == "summery":
            return "summary"
        if lowered.startswith("bmi") and "category" in lowered:
            return "bmi_category"
        if lowered == "bmi":
            return "bmi"
        if lowered.startswith("spo2") or lowered.startswith("sp02"):
            return "spo2_status"
        if "heart" in lowered and ("rate" in lowered or "hr" in lowered):
            return "heart_rate_status"
        if "risk" in lowered:
            return "health_risk"
        if lowered.startswith("recom") or "recomend" in lowered:
            return "recommendation"
        if lowered.startswith("expl") or "expla" in lowered:
            return "explanation"
        return None  # unrecognized — keep the original key

    canonical = {}
    for raw_key, value in result.items():
        target = _canonical_name(raw_key.strip().lower())
        canonical[raw_key if target is None else target] = value

    defaults = {
        "summary": "Unable to produce summary.",
        "bmi": None,
        "bmi_category": "normal",
        "spo2_status": "normal",
        "heart_rate_status": "normal",
        "health_risk": "low",
        "recommendation": "Maintain healthy lifestyle: balanced diet, regular exercise, and follow-up with healthcare if needed.",
        "explanation": "No detailed explanation provided."
    }
    # Missing, None, or empty-string values fall back to the defaults above.
    for field, fallback in defaults.items():
        if canonical.get(field) in (None, ""):
            canonical[field] = fallback

    def _resolve(value, options):
        # Non-strings (e.g. numeric bmi) pass through untouched.
        if not isinstance(value, str):
            return value
        if "|" in value:
            candidates = [part.strip().lower() for part in value.split("|")]
            for candidate in candidates:
                if candidate in options:
                    return candidate
            return candidates[0]
        lowered = value.strip().lower()
        for option in options:
            if option in lowered:
                return option
        return value

    canonical["bmi_category"] = _resolve(canonical["bmi_category"], ["underweight", "normal", "overweight"])
    canonical["spo2_status"] = _resolve(canonical["spo2_status"], ["normal", "low", "dangerous"])
    canonical["heart_rate_status"] = _resolve(canonical["heart_rate_status"], ["normal", "low", "high"])
    canonical["health_risk"] = _resolve(canonical["health_risk"], ["low", "moderate", "high"])

    return canonical
|
||||||
|
|
||||||
|
# -------------------------
# Prompt builder (short and strict)
# -------------------------
# System prompt kept deliberately short: it fixes the exact JSON keys the
# normalization step expects, and "ONLY a single JSON object" steers the
# model away from prose wrappers.
SYSTEM_PROMPT = (
    "You are a concise medical assistant. Output ONLY a single JSON object with keys: "
    "summary, bmi, bmi_category, spo2_status, heart_rate_status, health_risk, recommendation, explanation."
)
|
||||||
|
|
||||||
|
def build_user_prompt(data: HealthInput, bmi: float) -> str:
    """Render the user message: one "key: value" line per vital sign plus
    the precomputed BMI, followed by the strict JSON-only instruction."""
    lines = [
        "Health data:",
        f"weight: {data.weight}",
        f"height: {data.height}",
        f"age: {data.age}",
        f"spo2: {data.spo2}",
        f"heart_rate: {data.heart_rate}",
        f"activity: {data.activity}",
        f"bmi: {bmi}",
        "",
        "Return only a single JSON object using the exact keys.",
    ]
    return "\n".join(lines)
|
||||||
|
|
||||||
|
# -------------------------
# Optional simple LRU cache for identical inputs (disabled by default)
# -------------------------
if ENABLE_CACHE:
    @lru_cache(maxsize=CACHE_MAXSIZE)
    def cached_infer_cache_key(payload_str: str):
        """Memoized wrapper around _infer_from_parsed.

        lru_cache needs hashable arguments, so the caller passes the payload
        as a canonical JSON string (json.dumps with sort_keys=True) and we
        decode it here before delegating.
        """
        # payload_str is JSON string of inputs
        # to keep function signature compatible with lru_cache, we accept string key
        payload = json.loads(payload_str)
        # we call the real infer function below by delegating to a non-cached function
        return _infer_from_parsed(payload)
|
||||||
|
|
||||||
|
# -------------------------
# Core inference (non-cached)
# -------------------------
def _infer_from_parsed(payload: dict) -> dict:
    """Run one LLM inference for a raw request payload dict.

    Computes BMI locally, builds the prompt, calls the shared Llama instance
    under the concurrency semaphore, then extracts/parses/normalizes the
    model's JSON output. Always returns a dict in the canonical schema; the
    locally computed BMI overrides whatever the model produced.
    """
    # map payload to HealthInput-like structure
    try:
        # compute BMI: height is taken as centimeters and converted to meters.
        # Falsy height/weight (missing or 0) yields bmi_value = None.
        height_m = payload.get("height", 0) / 100.0 if payload.get("height") else None
        bmi_value = round(payload["weight"] / (height_m ** 2), 1) if height_m and payload.get("weight") else None
    except Exception:
        bmi_value = None

    # build prompt — Tmp is a throwaway attribute bag that satisfies
    # build_user_prompt's attribute access without pydantic validation.
    class Tmp:
        pass
    tmp = Tmp()
    tmp.weight = payload.get("weight")
    tmp.height = payload.get("height")
    tmp.age = payload.get("age")
    tmp.spo2 = payload.get("spo2")
    tmp.heart_rate = payload.get("heart_rate")
    tmp.activity = payload.get("activity")

    user_prompt = build_user_prompt(tmp, bmi_value)

    # run LLM with semaphore — serializes access to the shared llm instance
    # (CONCURRENCY is 1 by default to avoid OOM).
    with llm_sem:
        start_t = time.time()
        kwargs = dict(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt}
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            top_p=TOP_P,
        )

        # prefer response_format if supported — it's faster/cleaner.
        # Older llama_cpp versions raise TypeError on the unknown kwarg,
        # in which case we retry without it.
        try:
            kwargs_with_rf = kwargs.copy()
            kwargs_with_rf["response_format"] = {"type": "json_object"}
            response = llm.create_chat_completion(**kwargs_with_rf)
        except TypeError:
            response = llm.create_chat_completion(**kwargs)

        duration = time.time() - start_t
        print(f"[inference] elapsed={duration:.2f}s tokens_limit={MAX_TOKENS} threads={N_THREADS} batch={N_BATCH}", file=sys.stderr)

    # extract generated text safely — fall back to str(response) if the
    # response does not have the expected chat-completion shape.
    raw_out = ""
    try:
        raw_out = response.get("choices", [{}])[0].get("message", {}).get("content", "")
    except Exception:
        raw_out = str(response)

    # Pull the first JSON object out of the raw text; any parse failure
    # degrades to an empty dict, which normalization fills with defaults.
    json_text = extract_json_object(raw_out)
    parsed = {}
    try:
        parsed = json.loads(json_text)
    except Exception:
        parsed = {}

    normalized = normalize_keys_and_fix_typos(parsed)
    # Trust the locally computed BMI over anything the model emitted.
    normalized["bmi"] = bmi_value
    return normalized
|
||||||
|
|
||||||
|
# -------------------------
# Endpoint: /analyze (keeps existing JSON schema)
# -------------------------
@app.post("/analyze")
async def analyze(request: Request):
    """Analyze one set of vitals and return the canonical JSON verdict.

    Validates that the required fields are present, then delegates to the
    (optionally cached) inference helper. All failure modes return JSON
    bodies so the frontend never receives an HTML error page.

    NOTE(review): the inference call is synchronous inside an async handler,
    so it blocks the event loop for the duration — confirm this is acceptable
    for the expected load.
    """
    # parse request body
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})

    # validate minimal fields (fast check) — presence only; types are not
    # checked here (the qwen variant uses pydantic for that).
    required = ["weight", "height", "age", "spo2", "heart_rate", "activity"]
    for k in required:
        if k not in payload:
            return JSONResponse(status_code=400, content={"error": "missing_field", "field": k})

    try:
        if ENABLE_CACHE:
            # sort_keys makes the cache key canonical so field order in the
            # request body doesn't cause cache misses.
            key = json.dumps(payload, sort_keys=True)
            result = cached_infer_cache_key(key)
        else:
            result = _infer_from_parsed(payload)
    except Exception as e:
        # log and return JSON error so frontend doesn't get HTML
        import traceback
        tb = traceback.format_exc()
        print("Inference error:", file=sys.stderr)
        print(tb, file=sys.stderr)
        return JSONResponse(status_code=500, content={"error": "model_inference_failed", "details": str(e)})

    return JSONResponse(status_code=200, content=result)
|
||||||
|
|
||||||
|
# -------------------------
# Root health check
# -------------------------
@app.get("/")
def root():
    """Liveness probe: reports the loaded model file and key runtime knobs."""
    info = {
        "status": "ok",
        "model": os.path.basename(MODEL_PATH),
        "n_ctx": N_CTX,
        "n_threads": N_THREADS,
    }
    return info
|
||||||
224
qwen/main.py
Normal file
224
qwen/main.py
Normal file
|
|
@ -0,0 +1,224 @@
|
||||||
|
# main.py
|
||||||
|
# FastAPI backend di-tune untuk Qwen3.5-2B.Q4_K_M.gguf
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import os, sys, threading, traceback, json, time
|
||||||
|
from llama_cpp import Llama
|
||||||
|
|
||||||
|
# ---------- Config (adjust if needed) ----------
# Path to the GGUF model; LLM_MODEL_PATH env var overrides the default.
MODEL_DEFAULT = "models/Qwen3.5-2B.Q4_K_M.gguf"
MODEL_PATH = os.environ.get("LLM_MODEL_PATH", MODEL_DEFAULT)

# VM: 13GB RAM, 9 cores -> safe defaults; every value below can be
# overridden via the LLM_* environment variables.
N_THREADS = int(os.environ.get("LLM_N_THREADS", 9))  # uses all cores by default
N_CTX = int(os.environ.get("LLM_N_CTX", 1536))  # smaller context is safer than 4096 for RAM
N_BATCH = int(os.environ.get("LLM_N_BATCH", 256))  # larger batch speeds up prompt processing
MAX_TOKENS = int(os.environ.get("LLM_MAX_TOKENS", 450))  # cap on generated tokens per request
TEMPERATURE = float(os.environ.get("LLM_TEMPERATURE", 0.3))
TOP_P = float(os.environ.get("LLM_TOP_P", 0.9))
CONCURRENCY = int(os.environ.get("LLM_CONCURRENCY", 1))  # parallel inferences (semaphore size)
USE_MMAP = os.environ.get("LLM_USE_MMAP", "1") == "1"
USE_MLOCK = os.environ.get("LLM_USE_MLOCK", "0") == "1"  # mlock can fail without privileges
|
||||||
|
|
||||||
|
# ---------- FastAPI ----------
# root_path="/api" — presumably served behind a proxy under /api; confirm
# against the deployment config.
app = FastAPI(root_path="/api")
# NOTE(review): wildcard origins with credentials enabled is maximally
# permissive CORS — confirm this is intended for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

# concurrency control — serializes access to the shared Llama instance
# (CONCURRENCY defaults to 1).
llm_sem = threading.Semaphore(CONCURRENCY)
|
||||||
|
|
||||||
|
# ---------- HealthInput (unchanged schema) ----------
class HealthInput(BaseModel):
    """Request payload for /analyze.

    Units: height is presumably centimeters (BMI code divides by 100) and
    weight kilograms (standard BMI formula) — confirm with callers.
    """
    weight: float       # body weight
    height: float       # body height (cm, per BMI computation)
    age: int            # age in years
    spo2: int           # blood oxygen saturation reading
    heart_rate: int     # heart rate reading
    activity: str       # free-text activity description
|
||||||
|
|
||||||
|
# ---------- Try to load model (with informative error) ----------
# Fail fast at import time: a missing model file should stop the server from
# starting rather than erroring on the first request.
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Loading model from: {MODEL_PATH}", file=sys.stderr)
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}. Place your .gguf file there or set LLM_MODEL_PATH env var.")

try:
    # Single shared Llama instance; all request access is serialized by llm_sem.
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=USE_MMAP,
        use_mlock=USE_MLOCK
    )
    print("Model loaded successfully.", file=sys.stderr)
except Exception as e:
    # If the model fails to load, print the stack trace and re-raise so the
    # operator sees the failure instead of a silently broken server.
    print("ERROR loading model:", file=sys.stderr)
    traceback.print_exc()
    raise
|
||||||
|
|
||||||
|
# ---------- Helpers: JSON extraction & normalization ----------
def extract_json_object(text: str) -> str:
    """Slice out the first brace-balanced JSON object in *text*.

    Returns "{}" when no '{' exists, and the unbalanced tail from the first
    '{' onward when the braces never close (the caller's json.loads will then
    fail and defaults kick in). Braces inside strings are not special-cased.
    """
    begin = text.find("{")
    if begin == -1:
        return "{}"
    depth = 0
    idx = begin
    length = len(text)
    while idx < length:
        ch = text[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[begin:idx + 1]
        idx += 1
    # Braces never balanced: return everything from the first '{'.
    return text[begin:]
|
||||||
|
|
||||||
|
def sanitize_and_normalize(parsed: dict, bmi_value) -> dict:
    """Coerce a loosely-keyed model response onto the canonical schema.

    Maps common key typos/variants to canonical names, backfills missing or
    blank fields with defaults, collapses template values such as
    "normal|low|dangerous" to one allowed option, and finally overwrites
    "bmi" with the locally computed *bmi_value*.
    """

    def _canonical_name(lowered: str):
        # Ordered rules — first match wins, mirroring the original chain.
        if "summ" in lowered or "summary" in lowered:
            return "summary"
        if lowered == "bmi":
            return "bmi"
        if "bmi" in lowered and "cat" in lowered:
            return "bmi_category"
        if "spo2" in lowered or "sp02" in lowered:
            return "spo2_status"
        if "heart" in lowered and ("rate" in lowered or "hr" in lowered):
            return "heart_rate_status"
        if "risk" in lowered:
            return "health_risk"
        if "recom" in lowered or "recomend" in lowered:
            return "recommendation"
        if "expl" in lowered or "expla" in lowered:
            return "explanation"
        return None  # unrecognized — keep the original key

    mapping = {}
    for raw_key, value in parsed.items():
        target = _canonical_name(raw_key.strip().lower())
        mapping[raw_key if target is None else target] = value

    # Backfill missing / None / empty-string fields with safe defaults.
    defaults = {
        "summary": "Unable to produce summary.",
        "bmi": bmi_value,
        "bmi_category": "normal",
        "spo2_status": "normal",
        "heart_rate_status": "normal",
        "health_risk": "low",
        "recommendation": "Maintain balanced diet, regular exercise, and consult a professional if concerned.",
        "explanation": "No detailed explanation provided."
    }
    for field, fallback in defaults.items():
        if field not in mapping or mapping[field] in (None, ""):
            mapping[field] = fallback

    # Collapse templates like "normal|low|dangerous" to one allowed option.
    def _resolve(value, options):
        if not isinstance(value, str):
            return value  # non-strings pass through untouched
        if "|" in value:
            candidates = [part.strip().lower() for part in value.split("|")]
            for candidate in candidates:
                if candidate in options:
                    return candidate
            return candidates[0]
        lowered = value.strip().lower()
        for option in options:
            if option in lowered:
                return option
        return value

    mapping["bmi_category"] = _resolve(mapping["bmi_category"], ["underweight", "normal", "overweight"])
    mapping["spo2_status"] = _resolve(mapping["spo2_status"], ["normal", "low", "dangerous"])
    mapping["heart_rate_status"] = _resolve(mapping["heart_rate_status"], ["normal", "low", "high"])
    mapping["health_risk"] = _resolve(mapping["health_risk"], ["low", "moderate", "high"])
    # The locally computed BMI always wins over whatever the model emitted.
    mapping["bmi"] = bmi_value
    return mapping
|
||||||
|
|
||||||
|
# ---------- Prompt ----------
# Short, strict system prompt: fixes the exact JSON keys the normalization
# step expects and forbids any output besides the JSON object.
SYSTEM_PROMPT = (
    "You are a conservative medical assistant. Output ONLY a single JSON object with keys: "
    "summary, bmi, bmi_category, spo2_status, heart_rate_status, health_risk, recommendation, explanation."
)
|
||||||
|
def build_user_prompt(data: HealthInput, bmi):
    """Render the user message: one "key: value" line per vital sign plus
    the precomputed BMI, followed by the strict JSON-only instruction."""
    lines = [
        "Health data:",
        f"weight: {data.weight}",
        f"height: {data.height}",
        f"age: {data.age}",
        f"spo2: {data.spo2}",
        f"heart_rate: {data.heart_rate}",
        f"activity: {data.activity}",
        f"bmi: {bmi}",
        "",
        "Return only a single JSON object with the exact keys.",
    ]
    return "\n".join(lines)
|
||||||
|
|
||||||
|
# ---------- Endpoints ----------
@app.get("/")
def root():
    """Liveness probe: reports the loaded model file and key runtime knobs."""
    info = {
        "status": "ok",
        "model": os.path.basename(MODEL_PATH),
        "n_ctx": N_CTX,
        "n_threads": N_THREADS,
    }
    return info
|
||||||
|
|
||||||
|
@app.post("/analyze")
async def analyze(request: Request):
    """Analyze one set of vitals and return the canonical JSON verdict.

    Validates the payload with pydantic, computes BMI locally, calls the
    shared Llama instance under the semaphore, then extracts/parses/
    normalizes the model output. All failure modes return JSON bodies so
    the frontend never receives an HTML error page.

    NOTE(review): the inference call is synchronous inside an async handler,
    so it blocks the event loop for its duration — confirm this is acceptable
    for the expected load.
    """
    # parse request body safely
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})

    # validate minimal fields (frontend uses same schema) — pydantic does the
    # type checking; failures surface as a 400 with details.
    try:
        data = HealthInput(**payload)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "invalid payload", "details": str(e)})

    # compute bmi — height is taken as centimeters; division-by-zero or other
    # arithmetic failures degrade to bmi_value = None.
    try:
        height_m = data.height / 100.0
        bmi_value = round(data.weight / (height_m ** 2), 1)
    except Exception:
        bmi_value = None

    user_prompt = build_user_prompt(data, bmi_value)

    # call model (with semaphore) — serializes access to the shared llm.
    with llm_sem:
        try:
            kwargs = dict(
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                top_p=TOP_P,
            )
            # try to include response_format when supported; older llama_cpp
            # versions raise TypeError on the unknown kwarg, so retry without.
            try:
                kwargs_rf = kwargs.copy()
                kwargs_rf["response_format"] = {"type": "json_object"}
                response = llm.create_chat_completion(**kwargs_rf)
            except TypeError:
                response = llm.create_chat_completion(**kwargs)
        except Exception as e:
            # log exception and return JSON error (avoid HTML)
            tb = traceback.format_exc()
            print("Model inference error:", file=sys.stderr)
            print(tb, file=sys.stderr)
            return JSONResponse(status_code=500, content={"error": "model_inference_failed", "details": str(e)})

    # extract text — fall back to str(response) if the response lacks the
    # expected chat-completion shape.
    raw_text = ""
    try:
        raw_text = response.get("choices", [{}])[0].get("message", {}).get("content", "")
    except Exception:
        raw_text = str(response)

    # attempt parse — any failure degrades to an empty dict, which the
    # normalizer fills with safe defaults.
    json_text = extract_json_object(raw_text)
    parsed = {}
    try:
        parsed = json.loads(json_text)
    except Exception:
        # fallback empty
        parsed = {}

    out = sanitize_and_normalize(parsed, bmi_value)
    return JSONResponse(status_code=200, content=out)
|
||||||
Loading…
Reference in New Issue
Block a user