AlbertCAC commited on
Commit
1a02e83
·
1 Parent(s): 4e5d78f
Files changed (2) hide show
  1. app.py +30 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel, Field
3
+ from typing import Optional
4
+ from cold.classifier import ToxicTextClassifier
5
+
6
+ app = FastAPI()
7
+
8
+ model = ToxicTextClassifier()
9
+
10
+ class PredictionInput(BaseModel):
11
+ text: str = Field(..., title="Text to classify", description="The text to classify for malicious content")
12
+ context: Optional[str] = Field(None, title="Context for classification", description="Optional context to provide additional information for classification")
13
+
14
+ @app.post("/predict")
15
+ def predict(input: PredictionInput):
16
+ try:
17
+ if not input.text:
18
+ raise HTTPException(status_code=400, detail="Text input is required")
19
+ elif len(input.text) > 512:
20
+ raise HTTPException(status_code=400, detail="Text input exceeds maximum length of 512 characters")
21
+ if input.context and len(input.context) > 512:
22
+ raise HTTPException(status_code=400, detail="Context input exceeds maximum length of 512 characters")
23
+ if not input.context:
24
+ result = model.predict(input.text, device="cpu")
25
+ return {"text": input.text, "prediction": result[0]["prediction"], "probabilities": result[0]["probabilities"]}
26
+ else:
27
+ result = model.predict([[input.text,input.context]], device="cpu")
28
+ return {"text": input.text, "context": input.context, "prediction": result[0]["prediction"], "probabilities": result[0]["probabilities"]}
29
+ except Exception as e:
30
+ raise HTTPException(status_code=500, detail=str(e))
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ pydantic
3
+ uvicorn
4
+ torch
5
+ transformers
6
+ scikit-learn
7
+ pandas
8
+ tqdm