# main.py (tuned for faster Phi-3.1-mini / Phi-2-mini style models on CPU)
"""Small FastAPI service that wraps a local llama.cpp (gguf) model and exposes
a single /analyze endpoint returning a fixed-schema health-summary JSON object.

All tunables can be overridden via LLM_* environment variables; defaults are
sized for a small CPU-only VM (model load is eager, at import time).
"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import os
import sys
import threading
import json
import time
import multiprocessing
from functools import lru_cache
from types import SimpleNamespace

# llama_cpp import
from llama_cpp import Llama

# -------------------------
# Configuration (env overrides allowed)
# -------------------------
DEFAULT_MODEL = os.environ.get(
    "LLM_MODEL_PATH", "models/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf"
)

# sensible defaults for your 13GB / 9-core VM: leave one core for the server
CPU_CORES = multiprocessing.cpu_count()
DEFAULT_THREADS = max(1, CPU_CORES - 1)

MODEL_PATH = DEFAULT_MODEL
N_THREADS = int(os.environ.get("LLM_N_THREADS", DEFAULT_THREADS))
N_CTX = int(os.environ.get("LLM_N_CTX", 1024))      # lowered from 1536/4096
N_BATCH = int(os.environ.get("LLM_N_BATCH", 512))   # larger batch for CPU speed
MAX_TOKENS = int(os.environ.get("LLM_MAX_TOKENS", 150))
TEMPERATURE = float(os.environ.get("LLM_TEMPERATURE", 0.2))
TOP_P = float(os.environ.get("LLM_TOP_P", 0.9))
CONCURRENCY = int(os.environ.get("LLM_CONCURRENCY", 1))

# memory mapping options
USE_MMAP = os.environ.get("LLM_USE_MMAP", "1") == "1"
USE_MLOCK = os.environ.get("LLM_USE_MLOCK", "0") == "1"  # usually false on non-root

# optional in-memory cache to speed repeated identical requests
ENABLE_CACHE = os.environ.get("LLM_ENABLE_CACHE", "0") == "1"
CACHE_MAXSIZE = int(os.environ.get("LLM_CACHE_MAXSIZE", 256))

# -------------------------
# FastAPI
# -------------------------
app = FastAPI(root_path="/api")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------------
# Model load (with safe guards)
# -------------------------
print(f"Loading model from: {MODEL_PATH}", file=sys.stderr)

if not os.path.exists(MODEL_PATH):
    # FIX: the original message had a literal newline inside a single-quoted
    # f-string (a SyntaxError); the same text is now a valid one-line message.
    raise FileNotFoundError(
        f"Model file not found at: {MODEL_PATH}. "
        "Place the gguf file there or set LLM_MODEL_PATH env var."
    )

try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=USE_MMAP,
        use_mlock=USE_MLOCK,
    )
    print("Model loaded successfully.", file=sys.stderr)
except Exception:  # FIX: `as e` was bound but never used
    # print traceback to stderr and raise so process doesn't silently continue
    import traceback
    print("ERROR loading model:", file=sys.stderr)
    traceback.print_exc()
    raise

# concurrency control (1 by default to avoid OOM)
llm_sem = threading.Semaphore(CONCURRENCY)


# -------------------------
# Input schema (unchanged)
# -------------------------
class HealthInput(BaseModel):
    # weight in kg, height in cm (BMI computation divides height by 100)
    weight: float
    height: float
    age: int
    spo2: int
    heart_rate: int
    activity: str


# -------------------------
# Utility: extract JSON object from model text
# -------------------------
def extract_json_object(text: str) -> str:
    """Return the first brace-balanced ``{...}`` substring of *text*.

    Falls back to ``"{}"`` when no ``{`` is present, and to the unbalanced
    tail when the closing brace is missing (json.loads will then fail and the
    caller substitutes defaults).  NOTE(review): braces inside string values
    are not accounted for — acceptable for the constrained model output here.
    """
    start = text.find("{")
    if start == -1:
        return "{}"
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return text[start:]


# -------------------------
# Normalization + typo fixes (keeps output schema intact)
# -------------------------
def normalize_keys_and_fix_typos(result: dict) -> dict:
    """Map model-produced keys (typos included) onto the canonical schema.

    Guarantees every canonical key exists (filling defaults for missing/empty
    values) and collapses "a | b" style enum answers to a single allowed value.
    """
    canonical = {}
    for k, v in result.items():
        lk = k.strip().lower()
        if lk.startswith("sum") or lk.startswith("summ") or lk == "summery":
            canonical["summary"] = v
        elif lk.startswith("bmi") and "category" in lk:
            canonical["bmi_category"] = v
        elif lk == "bmi":
            canonical["bmi"] = v
        elif lk.startswith("spo2") or lk.startswith("sp02"):
            canonical["spo2_status"] = v
        elif "heart" in lk and ("rate" in lk or "hr" in lk):
            canonical["heart_rate_status"] = v
        elif "risk" in lk:
            canonical["health_risk"] = v
        elif lk.startswith("recom") or "recomend" in lk:
            canonical["recommendation"] = v
        elif lk.startswith("expl") or "expla" in lk:
            canonical["explanation"] = v
        else:
            canonical[k] = v

    defaults = {
        "summary": "Unable to produce summary.",
        "bmi": None,
        "bmi_category": "normal",
        "spo2_status": "normal",
        "heart_rate_status": "normal",
        "health_risk": "low",
        "recommendation": "Maintain healthy lifestyle: balanced diet, regular exercise, and follow-up with healthcare if needed.",
        "explanation": "No detailed explanation provided.",
    }
    for k, dv in defaults.items():
        if k not in canonical or canonical[k] is None or canonical[k] == "":
            canonical[k] = dv

    def pick_most_likely(val, allowed):
        # Collapse free-text / "a | b" answers to one of the allowed values;
        # non-strings (e.g. the numeric bmi default) pass through untouched.
        if not isinstance(val, str):
            return val
        if "|" in val:
            parts = [p.strip().lower() for p in val.split("|")]
            for p in parts:
                if p in allowed:
                    return p
            return parts[0]
        lv = val.strip().lower()
        for a in allowed:
            if a in lv:
                return a
        return val

    canonical["bmi_category"] = pick_most_likely(
        canonical["bmi_category"], ["underweight", "normal", "overweight"])
    canonical["spo2_status"] = pick_most_likely(
        canonical["spo2_status"], ["normal", "low", "dangerous"])
    canonical["heart_rate_status"] = pick_most_likely(
        canonical["heart_rate_status"], ["normal", "low", "high"])
    canonical["health_risk"] = pick_most_likely(
        canonical["health_risk"], ["low", "moderate", "high"])
    return canonical


# -------------------------
# Prompt builder (short and strict)
# -------------------------
SYSTEM_PROMPT = (
    "You are a concise medical assistant. Output ONLY a single JSON object with keys: "
    "summary, bmi, bmi_category, spo2_status, heart_rate_status, health_risk, recommendation, explanation."
)


def build_user_prompt(data, bmi) -> str:
    """Render the user message from an attribute bag with HealthInput's fields.

    *data* may be a HealthInput or any object exposing the same attributes
    (the inference path passes a SimpleNamespace); *bmi* may be None when it
    could not be computed.
    """
    return (
        f"Health data:\n"
        f"weight: {data.weight}\n"
        f"height: {data.height}\n"
        f"age: {data.age}\n"
        f"spo2: {data.spo2}\n"
        f"heart_rate: {data.heart_rate}\n"
        f"activity: {data.activity}\n"
        f"bmi: {bmi}\n\n"
        "Return only a single JSON object using the exact keys."
    )


# -------------------------
# Optional simple LRU cache for identical inputs (disabled by default)
# -------------------------
if ENABLE_CACHE:
    @lru_cache(maxsize=CACHE_MAXSIZE)
    def cached_infer_cache_key(payload_str: str):
        # payload_str is a canonical (sort_keys) JSON string of the inputs —
        # a hashable key compatible with lru_cache; delegates to the real
        # non-cached inference function.
        payload = json.loads(payload_str)
        return _infer_from_parsed(payload)


# -------------------------
# Core inference (non-cached)
# -------------------------
def _infer_from_parsed(payload: dict) -> dict:
    """Run one LLM inference for a validated request payload.

    Computes BMI locally (authoritative — overwrites whatever the model
    returns), prompts the model for the structured assessment, and normalizes
    the parsed output to the canonical schema.  Never raises for malformed
    model output; defaults are substituted instead.
    """
    # compute BMI; height arrives in cm.  Falsy height/weight (0/None) or any
    # arithmetic failure yields bmi_value = None rather than an error.
    try:
        height = payload.get("height")
        weight = payload.get("weight")
        height_m = height / 100.0 if height else None
        bmi_value = round(weight / (height_m ** 2), 1) if height_m and weight else None
    except Exception:
        bmi_value = None

    # FIX: replaced the ad-hoc `class Tmp: pass` attribute bag with
    # SimpleNamespace — same duck-typed access build_user_prompt needs,
    # without a throwaway class per call.
    data = SimpleNamespace(
        weight=payload.get("weight"),
        height=payload.get("height"),
        age=payload.get("age"),
        spo2=payload.get("spo2"),
        heart_rate=payload.get("heart_rate"),
        activity=payload.get("activity"),
    )
    user_prompt = build_user_prompt(data, bmi_value)

    # run LLM under the semaphore so at most CONCURRENCY inferences share RAM
    with llm_sem:
        start_t = time.time()
        kwargs = dict(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            top_p=TOP_P,
        )
        # prefer response_format if supported — it's faster/cleaner; older
        # llama_cpp versions raise TypeError on the unknown kwarg.
        try:
            kwargs_with_rf = kwargs.copy()
            kwargs_with_rf["response_format"] = {"type": "json_object"}
            response = llm.create_chat_completion(**kwargs_with_rf)
        except TypeError:
            response = llm.create_chat_completion(**kwargs)
        duration = time.time() - start_t
        print(
            f"[inference] elapsed={duration:.2f}s tokens_limit={MAX_TOKENS} threads={N_THREADS} batch={N_BATCH}",
            file=sys.stderr,
        )

    # extract generated text safely; fall back to the stringified response
    raw_out = ""
    try:
        raw_out = response.get("choices", [{}])[0].get("message", {}).get("content", "")
    except Exception:
        raw_out = str(response)

    json_text = extract_json_object(raw_out)
    try:
        parsed = json.loads(json_text)
    except Exception:
        parsed = {}

    normalized = normalize_keys_and_fix_typos(parsed)
    normalized["bmi"] = bmi_value  # locally computed value is authoritative
    return normalized


# -------------------------
# Endpoint: /analyze (keeps existing JSON schema)
# -------------------------
@app.post("/analyze")
async def analyze(request: Request):
    """Analyze health metrics; returns the canonical JSON schema or a JSON error."""
    # parse request body
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "invalid JSON body"})

    # validate minimal fields (fast check)
    required = ["weight", "height", "age", "spo2", "heart_rate", "activity"]
    for k in required:
        if k not in payload:
            return JSONResponse(status_code=400, content={"error": "missing_field", "field": k})

    try:
        if ENABLE_CACHE:
            # sort_keys gives a canonical key so equal payloads hit the cache;
            # FIX: copy the cached dict so callers can never mutate the entry.
            key = json.dumps(payload, sort_keys=True)
            result = dict(cached_infer_cache_key(key))
        else:
            result = _infer_from_parsed(payload)
    except Exception as e:
        # log and return JSON error so frontend doesn't get HTML
        import traceback
        tb = traceback.format_exc()
        print("Inference error:", file=sys.stderr)
        print(tb, file=sys.stderr)
        return JSONResponse(
            status_code=500,
            content={"error": "model_inference_failed", "details": str(e)},
        )

    return JSONResponse(status_code=200, content=result)


# -------------------------
# Root health check
# -------------------------
@app.get("/")
def root():
    """Liveness probe exposing the active model and key tuning parameters."""
    return {
        "status": "ok",
        "model": os.path.basename(MODEL_PATH),
        "n_ctx": N_CTX,
        "n_threads": N_THREADS,
    }