# Fdet / app.py
# Last change: commit 01a1b3c by siyah1 ("Update app.py", verified).
import os
import cv2
import numpy as np
import onnxruntime as ort
from ultralytics import YOLO
from huggingface_hub import snapshot_download
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import gradio as gr
import pickle
# --- Hugging Face Hub repositories for the two models ----------------------
YOLO_REPO = "arnabdhar/YOLOv8-Face-Detection"  # YOLOv8 face detector (.pt weights)
ARCFACE_ONNX_REPO = "garavv/arcface-onnx"  # ArcFace embedding model (.onnx)

# --- On-disk store of registered embeddings --------------------------------
# Relative path, so it resolves against the process working directory.
KNOWN_FACES_PKL = "known_faces.pkl"

# --- Fetch (or reuse cached) snapshots of both repositories ----------------
yolo_dir = snapshot_download(YOLO_REPO)
arcface_dir = snapshot_download(ARCFACE_ONNX_REPO)
def find_model(folder, ext):
    """Return the path of the first file under *folder* whose name ends with *ext*.

    The tree is walked top-down with directory and file names sorted, so the
    result is deterministic when several files match (plain ``os.walk`` order
    depends on the filesystem). Files directly in a directory are considered
    before its subdirectories.

    Args:
        folder: Root directory to search. A missing directory yields ``None``.
        ext: Filename suffix to match, e.g. ``".pt"`` or ``".onnx"``.

    Returns:
        The joined path of the first match, or ``None`` if nothing matches.
    """
    for root, dirs, files in os.walk(folder):
        # Sorting in place makes os.walk descend into subdirectories in a stable order.
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith(ext):
                return os.path.join(root, filename)
    return None
yolo_model_file = find_model(yolo_dir, ".pt")
arcface_file = find_model(arcface_dir, ".onnx")
# Fail fast with a clear message instead of handing None to YOLO / onnxruntime,
# which would otherwise fail with an unrelated-looking error.
if yolo_model_file is None:
    raise FileNotFoundError(
        f"No '.pt' weights found in snapshot of {YOLO_REPO!r} at {yolo_dir}"
    )
if arcface_file is None:
    raise FileNotFoundError(
        f"No '.onnx' model found in snapshot of {ARCFACE_ONNX_REPO!r} at {arcface_dir}"
    )
# ------------------------------
# LOAD MODELS
# ------------------------------
yolo = YOLO(yolo_model_file)
# ArcFace runs on CPU only.
arcface_sess = ort.InferenceSession(arcface_file, providers=["CPUExecutionProvider"])
# Name of the model's first input tensor; used as the feed key in get_embedding().
arcface_input = arcface_sess.get_inputs()[0].name
# --- Helper functions ------------------------------------------------------
def get_embedding(face):
    """Compute an L2-normalised ArcFace embedding for a BGR face crop.

    The crop is resized to 112x112, converted BGR -> RGB, scaled to roughly
    [-1, 1] via (x - 127.5) / 128 and passed to the ONNX session as a single
    NHWC batch of shape (1, 112, 112, 3).

    Returns:
        The embedding vector divided by its L2 norm (plus 1e-8 to avoid a
        zero division), or None when the crop is empty or smaller than
        10 px in height or width.
    """
    too_small = face.size == 0 or face.shape[0] < 10 or face.shape[1] < 10
    if too_small:
        return None

    rgb = cv2.cvtColor(cv2.resize(face, (112, 112)), cv2.COLOR_BGR2RGB)
    batch = ((rgb.astype(np.float32) - 127.5) / 128.0)[np.newaxis, ...]

    vector = arcface_sess.run(None, {arcface_input: batch})[0][0]
    return vector / (np.linalg.norm(vector) + 1e-8)
# --- Persistence of known faces (PKL file) ---------------------------------
def load_known_faces():
    """Load the name -> embedding dictionary from KNOWN_FACES_PKL.

    When the file does not exist, an empty database file is created (errors
    while creating it are printed, not raised) and an empty dict is returned.
    When the file exists but cannot be read or unpickled, the error is printed
    and an empty dict is returned.
    """
    if not os.path.exists(KNOWN_FACES_PKL):
        print(f"Creating new known faces database: {KNOWN_FACES_PKL}")
        fresh = {}
        try:
            with open(KNOWN_FACES_PKL, 'wb') as handle:
                pickle.dump(fresh, handle)
        except Exception as exc:
            print(f"Error creating pickle file: {exc}")
        return fresh

    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # while KNOWN_FACES_PKL is written exclusively by this application.
    try:
        with open(KNOWN_FACES_PKL, 'rb') as handle:
            return pickle.load(handle)
    except Exception as exc:
        print(f"Error loading known faces: {exc}")
        return {}
def save_known_faces(known_faces, path=None):
    """Persist the name -> embedding dictionary to disk atomically.

    The data is pickled to ``<path>.tmp``, flushed and fsynced, then moved over
    the real file with ``os.replace``. A failed or interrupted dump therefore
    leaves the previous database intact instead of a truncated file.

    Args:
        known_faces: Mapping to pickle.
        path: Destination file; defaults to KNOWN_FACES_PKL. Its parent
            directory is created if needed.

    Returns:
        True on success. False on any error; the error and traceback are
        printed and the temporary file is removed.
    """
    target = KNOWN_FACES_PKL if path is None else path
    tmp_path = target + ".tmp"
    try:
        os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
        with open(tmp_path, "wb") as f:
            pickle.dump(known_faces, f)
            f.flush()
            os.fsync(f.fileno())
        # Atomic on POSIX (same directory => same filesystem): readers see the
        # old file or the new one, never a partial write.
        os.replace(tmp_path, target)
        print(f"Successfully saved {len(known_faces)} faces to {target}")
        return True
    except Exception as e:
        print(f"Error saving known faces: {e}")
        import traceback
        traceback.print_exc()
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup; the temp file may never have been created.
            pass
        return False
# --- Populate the in-memory database once, at import time ------------------
# name -> embedding; read and mutated by the UI handlers that follow.
known_faces = load_known_faces()
print(f"Loaded {len(known_faces)} known faces from storage.")
# ------------------------------
# REGISTER FACE FUNCTION (with PKL storage)
# ------------------------------
def register_face(name, image):
    """Register the face found in *image* under *name* and persist it to the PKL file.

    Args:
        name: Person's name. Surrounding whitespace is stripped before the
            name is used as the database key. An existing entry with the same
            name is overwritten.
        image: PIL image from the Gradio upload (colour or grayscale).

    Returns:
        A Markdown status message for the UI. Exceptions raised while
        processing the image are caught and reported in that message.
    """
    import traceback

    if not name or not name.strip():
        return "❌ Please enter a name before uploading an image."
    if image is None:
        return "❌ Please upload a valid face image."
    # Normalise the key so " Alice" and "Alice" do not become separate entries
    # (and so the stored name can later be matched when deleting).
    name = name.strip()
    try:
        # PIL (RGB or grayscale) -> OpenCV BGR
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Detect first so the embedding is computed on a face crop, not the whole photo.
        results = yolo.predict(source=img, conf=0.25, verbose=False)
        emb = None
        for r in results:
            if len(r.boxes) == 0:
                continue
            # Only the first box YOLO returns is used (presumably the most
            # confident detection -- TODO confirm ordering).
            x1, y1, x2, y2 = map(int, r.boxes[0].xyxy[0].tolist())
            # A negative start index would wrap around in NumPy slicing.
            x1, y1 = max(x1, 0), max(y1, 0)
            candidate = get_embedding(img[y1:y2, x1:x2])
            if candidate is not None and not np.isnan(candidate).any():
                emb = candidate
                break

        if emb is None:
            return f"❌ No clear face detected for {name}. Please upload a clearer image with a visible face."

        known_faces[name] = emb
        if save_known_faces(known_faces):
            return f"✅ Registered face for **{name}** and saved to storage. Total known faces: {len(known_faces)}"
        return f"⚠️ Registered face for **{name}** but failed to save to storage."
    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING REGISTER FACE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}"
# ------------------------------
# DELETE FACE FUNCTION
# ------------------------------
def delete_face(name):
    """Remove a registered face from memory and rewrite the PKL file.

    Surrounding whitespace in *name* is ignored. The exact key is tried first,
    so entries stored with stray whitespace can still be removed.

    Returns:
        A Markdown status message for the UI.
    """
    if not name or not name.strip():
        return "❌ Please enter a name to delete."
    key = name if name in known_faces else name.strip()
    if key not in known_faces:
        return f"❌ No face registered with name '{key}'."
    try:
        del known_faces[key]
        if save_known_faces(known_faces):
            return f"✅ Deleted face for **{key}**. Remaining faces: {len(known_faces)}"
        return "⚠️ Deleted from memory but failed to update storage."
    except Exception as e:
        return f"❌ Error deleting face: {e}"
# ------------------------------
# LIST ALL REGISTERED FACES
# ------------------------------
def list_faces():
    """Return a Markdown list of every registered name, in dict insertion order."""
    if not known_faces:
        return "No faces registered yet."
    # Use real Markdown list items: lines joined by a single "\n" would be
    # collapsed into one paragraph by the Markdown renderer.
    face_list = "\n".join(f"- {name}" for name in known_faces)
    return f"**Registered Faces ({len(known_faces)} total):**\n\n{face_list}"
# ------------------------------
# DETECT + RECOGNIZE FUNCTION
# ------------------------------
def detect_and_recognize(image):
    """Detect all faces in *image* and label each with its closest registered name.

    Each detected face is embedded with get_embedding() and compared with
    every stored embedding by cosine similarity. The best match is used when
    its score is at least 0.35; otherwise the face is labelled "Unknown". The
    printed score is the best similarity found. Faces that cannot be embedded
    (crops under 10 px, or a NaN embedding) are skipped rather than aborting
    the whole image.

    Returns:
        (markdown_status, annotated PIL image or None)
    """
    import traceback

    # Cosine-similarity cut-off below which a face is reported as "Unknown".
    # (Lowered from an earlier 0.45 to make matching more permissive.)
    match_threshold = 0.35

    if image is None:
        return "❌ Please upload an image.", None
    if not known_faces:
        return "⚠️ No known faces registered yet!", None
    try:
        # PIL (RGB or grayscale) -> OpenCV BGR
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Stack the stored embeddings once; each face is then scored against
        # all of them with a single cosine_similarity call.
        ref_names = list(known_faces)
        ref_matrix = np.stack([known_faces[n] for n in ref_names])

        results = yolo.predict(source=img, conf=0.35, verbose=False)
        names_found = []
        for r in results:
            for box in r.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                # A negative start index would wrap around in NumPy slicing.
                x1, y1 = max(x1, 0), max(y1, 0)
                crop = img[y1:y2, x1:x2]
                if crop.size == 0:
                    continue
                emb = get_embedding(crop)
                # get_embedding returns None for tiny crops; passing that to
                # cosine_similarity used to raise and fail the whole request.
                if emb is None or np.isnan(emb).any():
                    continue

                scores = cosine_similarity([emb], ref_matrix)[0]
                best_idx = int(np.argmax(scores))
                best_score = float(scores[best_idx])
                best_name = ref_names[best_idx] if best_score >= match_threshold else "Unknown"
                label = f"{best_name} ({best_score:.2f})"

                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Keep the label inside the image for faces touching the top edge.
                cv2.putText(img, label, (x1, max(y1 - 10, 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                names_found.append(label)

        result_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if names_found:
            return f"✅ Detected faces: {', '.join(names_found)}", result_img
        return "⚠️ No faces detected in the image.", result_img
    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING DETECT & RECOGNIZE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}", None
# ------------------------------
# GRADIO INTERFACE
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 👀 Face Detection & Recognition (YOLO + ArcFace)")
    gr.Markdown("Register faces to build your database, then detect and recognize them in new images!")

    with gr.Tab("📝 Register Face"):
        gr.Markdown("### Add a new face to the database")
        name_input = gr.Textbox(label="Person Name", placeholder="Enter person's name...")
        face_input = gr.Image(label="Upload Face Image", type="pil")
        register_btn = gr.Button("Register Face", variant="primary")
        register_output = gr.Markdown(label="Status")
        register_btn.click(register_face, [name_input, face_input], register_output)

    with gr.Tab("🔍 Detect & Recognize"):
        gr.Markdown("### Detect and recognize faces in an image")
        img_input = gr.Image(label="Upload Test Image", type="pil")
        detect_btn = gr.Button("Detect Faces", variant="primary")
        text_output = gr.Markdown(label="Results")
        img_output = gr.Image(label="Output Image")
        detect_btn.click(detect_and_recognize, img_input, [text_output, img_output])

    with gr.Tab("📋 Manage Faces"):
        gr.Markdown("### View and manage registered faces")
        with gr.Row():
            with gr.Column():
                list_btn = gr.Button("Show All Registered Faces")
                list_output = gr.Markdown(label="Registered Faces")
                list_btn.click(list_faces, outputs=list_output)
            with gr.Column():
                delete_name = gr.Textbox(label="Name to Delete", placeholder="Enter name to remove...")
                delete_btn = gr.Button("Delete Face", variant="stop")
                delete_output = gr.Markdown(label="Status")
                delete_btn.click(delete_face, delete_name, delete_output)

# Start the server only when run as a script (`python app.py`), so merely
# importing this module does not launch the app.
if __name__ == "__main__":
    demo.launch()