"""HTTP API serving a ToxicTextClassifier plus a static front-end.

POST /predict classifies a text, optionally paired with a context string.
All other paths are served from the static ``out`` directory.
"""

import logging
import os
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from cold.classifier import ToxicTextClassifier

logger = logging.getLogger(__name__)

# Maximum accepted length (in characters) for both `text` and `context`.
MAX_INPUT_CHARS = 512
# Inference device; loading and prediction are both pinned to CPU.
DEVICE = "cpu"
# Checkpoint location, overridable via env var; the default matches the
# original hard-coded path (resolved relative to the working directory).
MODEL_PATH = os.environ.get("TOXIC_MODEL_PATH", "output/lited_best.pth")

app = FastAPI()

# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any origin make credentialed requests (Starlette echoes the caller's origin
# when cookies are present). Tighten the origin list if the API ever relies on
# cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

model = ToxicTextClassifier()
# NOTE: torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# ToxicTextClassifier exposes load_state_dict, so it is presumably a
# torch.nn.Module; eval() switches layers such as dropout to inference
# behaviour. Harmless if predict() already does this — TODO confirm.
model.eval()


class PredictionInput(BaseModel):
    """Request body for /predict."""

    text: str = Field(..., title="Text to classify", description="The text to classify for malicious content")
    context: Optional[str] = Field(None, title="Context for classification", description="Optional context to provide additional information for classification")


@app.post("/predict")
def predict(input: PredictionInput):  # noqa: A002 - parameter name kept for compatibility
    """Classify ``input.text``, optionally paired with ``input.context``.

    Returns a dict with ``text``, ``context`` (only when a non-empty context
    was supplied), ``prediction`` and ``probabilities``. The last two are taken
    from the first element returned by ``model.predict``.

    Raises:
        HTTPException(400): text is empty, or text/context exceeds
            MAX_INPUT_CHARS characters.
        HTTPException(500): the model call or result unpacking failed; the
            detail is the stringified underlying error.
    """
    # Validation happens outside the try block so these 400s reach the client
    # unchanged instead of being caught and re-raised as 500s below.
    if not input.text:
        raise HTTPException(status_code=400, detail="Text input is required")
    if len(input.text) > MAX_INPUT_CHARS:
        raise HTTPException(
            status_code=400,
            detail=f"Text input exceeds maximum length of {MAX_INPUT_CHARS} characters",
        )
    if input.context and len(input.context) > MAX_INPUT_CHARS:
        raise HTTPException(
            status_code=400,
            detail=f"Context input exceeds maximum length of {MAX_INPUT_CHARS} characters",
        )

    try:
        # Inference only: no autograd graph needs to be recorded.
        with torch.no_grad():
            if input.context:
                # Context is passed as a [text, context] pair inside a batch.
                result = model.predict([[input.text, input.context]], device=DEVICE)
            else:
                result = model.predict(input.text, device=DEVICE)
        logger.debug("Prediction result: %s", result)
        first = result[0]
        prediction = first["prediction"]
        probabilities = first["probabilities"]
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Build the response in the original key order: text, [context],
    # prediction, probabilities.
    response = {"text": input.text}
    if input.context:
        response["context"] = input.context
    response["prediction"] = prediction
    response["probabilities"] = probabilities
    return response


# Mounted last: a mount at "/" matches every path, so API routes must be
# registered before it to take precedence. The directory must exist at startup.
app.mount("/", StaticFiles(directory="out", html=True), name="static")