"""Gradio app for face detection (YOLOv8) and recognition (ArcFace ONNX).

Registered face embeddings live in the module-level ``known_faces`` dict and
are persisted to the pickle file named by ``KNOWN_FACES_PKL``.
"""

import contextlib
import os
import pickle
import traceback

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from ultralytics import YOLO

# ------------------------------
# MODEL REPOSITORIES
# ------------------------------
YOLO_REPO = "arnabdhar/YOLOv8-Face-Detection"  # Face detection model
ARCFACE_ONNX_REPO = "garavv/arcface-onnx"  # ArcFace ONNX model

# ------------------------------
# PKL FILE PATH
# ------------------------------
KNOWN_FACES_PKL = "known_faces.pkl"

# ------------------------------
# TUNABLE THRESHOLDS
# ------------------------------
REGISTER_DETECTION_CONF = 0.25  # YOLO confidence used when registering a face
RECOGNITION_DETECTION_CONF = 0.35  # YOLO confidence used when recognizing faces
MATCH_THRESHOLD = 0.35  # Minimum cosine similarity to accept a match
MIN_FACE_SIDE = 10  # Crops smaller than this (in pixels) are rejected
ARCFACE_INPUT_SIZE = (112, 112)  # Spatial size fed to the ArcFace model

# ------------------------------
# DOWNLOAD MODELS FROM HUGGING FACE
# ------------------------------
yolo_dir = snapshot_download(YOLO_REPO)
arcface_dir = snapshot_download(ARCFACE_ONNX_REPO)


def find_model(folder, ext):
    """Return the path of the first file under *folder* ending with *ext*.

    Walks the directory tree with ``os.walk`` and returns ``None`` when no
    file matches.
    """
    for root, _, files in os.walk(folder):
        for filename in files:
            if filename.endswith(ext):
                return os.path.join(root, filename)
    return None


def _require_model(folder, ext):
    """Like ``find_model`` but raise ``FileNotFoundError`` when nothing matches."""
    path = find_model(folder, ext)
    if path is None:
        raise FileNotFoundError(f"No '{ext}' model file found under {folder}")
    return path


# Fail early with a clear message instead of handing ``None`` to the loaders.
yolo_model_file = _require_model(yolo_dir, ".pt")
arcface_file = _require_model(arcface_dir, ".onnx")

# ------------------------------
# LOAD MODELS
# ------------------------------
yolo = YOLO(yolo_model_file)
arcface_sess = ort.InferenceSession(arcface_file, providers=["CPUExecutionProvider"])
arcface_input = arcface_sess.get_inputs()[0].name


# ------------------------------
# HELPER FUNCTIONS
# ------------------------------
def get_embedding(face):
    """Compute an L2-normalized ArcFace embedding for a BGR face crop.

    The crop is resized to 112x112, converted to RGB, scaled with
    ``(x - 127.5) / 128`` and fed to the ONNX session as a single NHWC
    batch of shape (1, 112, 112, 3).

    Returns:
        The normalized embedding, or ``None`` when the crop is empty or
        smaller than ``MIN_FACE_SIDE`` pixels on either side.
    """
    if face.size == 0 or face.shape[0] < MIN_FACE_SIDE or face.shape[1] < MIN_FACE_SIDE:
        return None

    img = cv2.resize(face, ARCFACE_INPUT_SIZE)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = (img.astype(np.float32) - 127.5) / 128.0
    img = np.expand_dims(img, axis=0)  # (1, 112, 112, 3)

    emb = arcface_sess.run(None, {arcface_input: img})[0][0]
    # The epsilon keeps the division safe if the model returns a zero vector.
    return emb / (np.linalg.norm(emb) + 1e-8)


def _to_bgr(image):
    """Convert a PIL image (or array-like) to an OpenCV BGR array."""
    arr = np.array(image)
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)


def _clip_box(box, image_shape):
    """Return the box's integer (x1, y1, x2, y2) clamped to the image bounds.

    Clamping matters because a negative coordinate used as a NumPy slice
    index would wrap around to the other side of the image.
    """
    height, width = image_shape[:2]
    x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
    x1 = max(0, min(x1, width))
    x2 = max(0, min(x2, width))
    y1 = max(0, min(y1, height))
    y2 = max(0, min(y2, height))
    return x1, y1, x2, y2


# ------------------------------
# LOAD/SAVE KNOWN FACES FROM/TO PKL
# ------------------------------
def load_known_faces():
    """Load the name -> embedding dict from ``KNOWN_FACES_PKL``.

    When the file does not exist an empty database file is created. Any
    load failure, or a file whose content is not a dict, yields an empty
    dict.
    """
    if os.path.exists(KNOWN_FACES_PKL):
        try:
            # SECURITY: pickle can execute arbitrary code while loading.
            # Only point KNOWN_FACES_PKL at files this app wrote itself.
            with open(KNOWN_FACES_PKL, "rb") as f:
                loaded = pickle.load(f)
        except Exception as e:
            print(f"Error loading known faces: {e}")
            return {}
        if not isinstance(loaded, dict):
            print(f"Error loading known faces: expected dict, got {type(loaded).__name__}")
            return {}
        return loaded

    # Create empty pickle file if it doesn't exist
    print(f"Creating new known faces database: {KNOWN_FACES_PKL}")
    empty_dict = {}
    try:
        with open(KNOWN_FACES_PKL, "wb") as f:
            pickle.dump(empty_dict, f)
    except Exception as e:
        print(f"Error creating pickle file: {e}")
    return empty_dict


def save_known_faces(known_faces):
    """Persist *known_faces* to ``KNOWN_FACES_PKL``.

    The data is written to a temporary file first and then moved over the
    target with ``os.replace`` so a failed write cannot leave a truncated
    database behind.

    Returns:
        ``True`` on success, ``False`` if anything went wrong (the error
        and traceback are printed).
    """
    tmp_path = f"{KNOWN_FACES_PKL}.tmp"
    try:
        # Ensure directory exists
        os.makedirs(os.path.dirname(KNOWN_FACES_PKL) or ".", exist_ok=True)

        with open(tmp_path, "wb") as f:
            pickle.dump(known_faces, f)
        os.replace(tmp_path, KNOWN_FACES_PKL)
        print(f"Successfully saved {len(known_faces)} faces to {KNOWN_FACES_PKL}")
        return True
    except Exception as e:
        print(f"Error saving known faces: {e}")
        traceback.print_exc()
        # Best-effort cleanup of a partially written temp file.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        return False


# ------------------------------
# LOAD KNOWN FACES ON STARTUP
# ------------------------------
known_faces = load_known_faces()
print(f"Loaded {len(known_faces)} known faces from storage.")


# ------------------------------
# REGISTER FACE FUNCTION (with PKL storage)
# ------------------------------
def register_face(name, image):
    """Register the first detected face in *image* under *name*.

    The name is stripped of surrounding whitespace before being stored.
    The embedding is added to ``known_faces`` and the database is saved.

    Returns:
        A Markdown status message describing success or failure.
    """
    if not name or not name.strip():
        return "❌ Please enter a name before uploading an image."
    if image is None:
        return "❌ Please upload a valid face image."

    # Store the normalized name so "Bob" and "Bob " are the same entry.
    name = name.strip()

    try:
        img = _to_bgr(image)

        # Detect face first using YOLO to ensure good quality
        results = yolo.predict(source=img, conf=REGISTER_DETECTION_CONF, verbose=False)

        emb = None
        for r in results:
            boxes = r.boxes
            if len(boxes) == 0:
                continue
            # Use the first detected face
            x1, y1, x2, y2 = _clip_box(boxes[0], img.shape)
            candidate = get_embedding(img[y1:y2, x1:x2])
            if candidate is not None and not np.isnan(candidate).any():
                emb = candidate
                break

        if emb is None:
            return f"❌ No clear face detected for {name}. Please upload a clearer image with a visible face."

        # Add to known faces dictionary
        known_faces[name] = emb

        # Save to PKL file
        if save_known_faces(known_faces):
            return f"✅ Registered face for **{name}** and saved to storage. Total known faces: {len(known_faces)}"
        return f"⚠️ Registered face for **{name}** but failed to save to storage."

    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING REGISTER FACE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}"


# ------------------------------
# DELETE FACE FUNCTION
# ------------------------------
def delete_face(name):
    """Delete a registered face from memory and from the PKL file.

    The exact name is tried first (so entries stored before names were
    normalized can still be removed), then the whitespace-stripped name.

    Returns:
        A Markdown status message describing success or failure.
    """
    if not name or not name.strip():
        return "❌ Please enter a name to delete."

    key = name if name in known_faces else name.strip()
    if key not in known_faces:
        return f"❌ No face registered with name '{name}'."

    try:
        del known_faces[key]
        if save_known_faces(known_faces):
            return f"✅ Deleted face for **{key}**. Remaining faces: {len(known_faces)}"
        return "⚠️ Deleted from memory but failed to update storage."
    except Exception as e:
        return f"❌ Error deleting face: {str(e)}"


# ------------------------------
# LIST ALL REGISTERED FACES
# ------------------------------
def list_faces():
    """Return a Markdown bullet list of all registered face names."""
    if not known_faces:
        return "No faces registered yet."

    face_list = "\n".join(f"• {name}" for name in known_faces)
    return f"**Registered Faces ({len(known_faces)} total):**\n\n{face_list}"


# ------------------------------
# DETECT + RECOGNIZE FUNCTION
# ------------------------------
def detect_and_recognize(image):
    """Detect faces in *image* and label each with its best known match.

    Every detected face is compared against all registered embeddings by
    cosine similarity. A face whose best score is below
    ``MATCH_THRESHOLD`` is labelled "Unknown". Boxes whose crop is empty
    or too small to embed are skipped.

    Returns:
        A ``(message, annotated PIL image or None)`` tuple.
    """
    if image is None:
        return "❌ Please upload an image.", None
    if not known_faces:
        return "⚠️ No known faces registered yet!", None

    try:
        img = _to_bgr(image)

        # Snapshot the references once so every face in this request is
        # scored with a single vectorized call.
        ref_names = list(known_faces)
        ref_matrix = np.stack([known_faces[ref] for ref in ref_names])

        results = yolo.predict(source=img, conf=RECOGNITION_DETECTION_CONF, verbose=False)
        names_found = []

        for r in results:
            for box in r.boxes:
                x1, y1, x2, y2 = _clip_box(box, img.shape)
                crop = img[y1:y2, x1:x2]
                if crop.size == 0:
                    continue

                emb = get_embedding(crop)
                if emb is None:
                    # Crop too small to embed; skip it rather than failing
                    # recognition for the whole image.
                    continue

                scores = cosine_similarity([emb], ref_matrix)[0]
                top = int(np.argmax(scores))
                best_name, best_score = "Unknown", 0
                if scores[top] > 0:
                    best_name, best_score = ref_names[top], float(scores[top])

                # Lower threshold for better matching (0.35 instead of 0.45)
                if best_score < MATCH_THRESHOLD:
                    best_name = "Unknown"

                label = f"{best_name} ({best_score:.2f})"
                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Keep the label inside the image when the box touches the top edge.
                cv2.putText(img, label, (x1, max(y1 - 10, 15)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                names_found.append(label)

        result_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        if names_found:
            return f"✅ Detected faces: {', '.join(names_found)}", result_img
        return "⚠️ No faces detected in the image.", result_img

    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING DETECT & RECOGNIZE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}", None


# ------------------------------
# GRADIO INTERFACE
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 👤 Face Detection & Recognition (YOLO + ArcFace)")
    gr.Markdown("Register faces to build your database, then detect and recognize them in new images!")

    with gr.Tab("📝 Register Face"):
        gr.Markdown("### Add a new face to the database")
        name_input = gr.Textbox(label="Person Name", placeholder="Enter person's name...")
        face_input = gr.Image(label="Upload Face Image", type="pil")
        register_btn = gr.Button("Register Face", variant="primary")
        register_output = gr.Markdown(label="Status")
        register_btn.click(register_face, [name_input, face_input], register_output)

    with gr.Tab("🔍 Detect & Recognize"):
        gr.Markdown("### Detect and recognize faces in an image")
        img_input = gr.Image(label="Upload Test Image", type="pil")
        detect_btn = gr.Button("Detect Faces", variant="primary")
        text_output = gr.Markdown(label="Results")
        img_output = gr.Image(label="Output Image")
        detect_btn.click(detect_and_recognize, img_input, [text_output, img_output])

    with gr.Tab("📋 Manage Faces"):
        gr.Markdown("### View and manage registered faces")

        with gr.Row():
            with gr.Column():
                list_btn = gr.Button("Show All Registered Faces")
                list_output = gr.Markdown(label="Registered Faces")
                list_btn.click(list_faces, outputs=list_output)

            with gr.Column():
                delete_name = gr.Textbox(label="Name to Delete", placeholder="Enter name to remove...")
                delete_btn = gr.Button("Delete Face", variant="stop")
                delete_output = gr.Markdown(label="Status")
                delete_btn.click(delete_face, delete_name, delete_output)

demo.launch()