|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
from ultralytics import YOLO |
|
|
from huggingface_hub import snapshot_download |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import pickle |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repositories that are searched below for the YOLO ".pt"
# file and the ArcFace ".onnx" file.
YOLO_REPO = "arnabdhar/YOLOv8-Face-Detection"
ARCFACE_ONNX_REPO = "garavv/arcface-onnx"

# Pickle file that persists the registered faces (name -> embedding).
KNOWN_FACES_PKL = "known_faces.pkl"

# Download each repository snapshot and keep the local directory it lands in.
yolo_dir = snapshot_download(YOLO_REPO)
arcface_dir = snapshot_download(ARCFACE_ONNX_REPO)
|
|
|
|
|
def find_model(folder, ext):
    """Return the path of the first file under *folder* whose name ends with *ext*.

    The tree is walked top-down, and both sub-directories and file names are
    visited in sorted order, so the result does not depend on the order in
    which the filesystem happens to list entries.

    Args:
        folder: Root directory to search recursively.
        ext: Suffix to match with ``str.endswith`` (for example ``".pt"``).

    Returns:
        The joined path of the first match, or ``None`` when nothing matches
        (including when *folder* does not exist).
    """
    for root, dirs, files in os.walk(folder):
        # Sorting in place tells os.walk which order to descend in.
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith(ext):
                return os.path.join(root, filename)
    return None
|
|
|
|
|
# Locate the weight files inside the downloaded snapshots.
yolo_model_file = find_model(yolo_dir, ".pt")
arcface_file = find_model(arcface_dir, ".onnx")

# find_model returns None when nothing matches; stop here with a clear
# message instead of handing None to the model loaders below.
if yolo_model_file is None:
    raise FileNotFoundError(f"No .pt model file found under {yolo_dir!r}")
if arcface_file is None:
    raise FileNotFoundError(f"No .onnx model file found under {arcface_dir!r}")

# Face detector.
yolo = YOLO(yolo_model_file)

# ArcFace embedding model, run on the CPU execution provider.
arcface_sess = ort.InferenceSession(arcface_file, providers=["CPUExecutionProvider"])
# Name of the session's first input; used as the feed key in get_embedding.
arcface_input = arcface_sess.get_inputs()[0].name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding(face):
    """Compute a normalised ArcFace embedding for a cropped face image.

    The crop is resized to 112x112, converted from BGR to RGB, scaled with
    ``(pixel - 127.5) / 128.0`` and passed to the ONNX session as a
    single-item batch in NHWC layout, i.e. shape (1, 112, 112, 3).

    Returns:
        ``None`` when the crop is empty or shorter than 10 pixels along
        either of its first two axes; otherwise the first row of the
        session's first output, divided by its L2 norm.
    """
    # Reject crops that are empty or too small to be useful.
    if face.size == 0:
        return None
    if face.shape[0] < 10 or face.shape[1] < 10:
        return None

    rgb = cv2.cvtColor(cv2.resize(face, (112, 112)), cv2.COLOR_BGR2RGB)

    # Scale pixel values, then add the leading batch axis.
    scaled = (rgb.astype(np.float32) - 127.5) / 128.0
    batch = scaled[np.newaxis, ...]

    outputs = arcface_sess.run(None, {arcface_input: batch})
    raw = outputs[0][0]

    # The small epsilon keeps the division defined for an all-zero vector.
    return raw / (np.linalg.norm(raw) + 1e-8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_known_faces():
    """Load the registered faces from ``KNOWN_FACES_PKL``.

    If the file does not exist, an empty database file is created
    (best effort) and an empty dict is returned.

    Returns:
        The unpickled dict, or an empty dict when the file is missing,
        cannot be read, or does not contain a dict.
    """
    if not os.path.exists(KNOWN_FACES_PKL):
        print(f"Creating new known faces database: {KNOWN_FACES_PKL}")
        empty_dict = {}
        try:
            with open(KNOWN_FACES_PKL, 'wb') as f:
                pickle.dump(empty_dict, f)
        except Exception as e:
            # Not fatal: the app can still run with an in-memory registry.
            print(f"Error creating pickle file: {e}")
        return empty_dict

    try:
        with open(KNOWN_FACES_PKL, 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code. Only load a
            # file this application wrote itself; never point
            # KNOWN_FACES_PKL at untrusted data.
            data = pickle.load(f)
    except Exception as e:
        print(f"Error loading known faces: {e}")
        return {}

    # Callers index and iterate the result as a dict, so reject anything else.
    if not isinstance(data, dict):
        print(f"Error loading known faces: expected a dict, got {type(data).__name__}")
        return {}
    return data
|
|
|
|
|
def save_known_faces(known_faces):
    """Persist *known_faces* to ``KNOWN_FACES_PKL``.

    The data is pickled to a temporary sibling file first and then moved
    over the real file with ``os.replace``, so a failure while writing does
    not leave a truncated database behind.

    Args:
        known_faces: The mapping to pickle.

    Returns:
        ``True`` on success, ``False`` if anything went wrong (the error and
        its traceback are printed).
    """
    tmp_path = f"{KNOWN_FACES_PKL}.tmp"
    try:
        # dirname is "" for a bare file name; fall back to the current directory.
        os.makedirs(os.path.dirname(KNOWN_FACES_PKL) or '.', exist_ok=True)

        with open(tmp_path, 'wb') as f:
            pickle.dump(known_faces, f)
        os.replace(tmp_path, KNOWN_FACES_PKL)

        print(f"Successfully saved {len(known_faces)} faces to {KNOWN_FACES_PKL}")
        return True
    except Exception as e:
        print(f"Error saving known faces: {e}")
        import traceback
        traceback.print_exc()
        # Best-effort cleanup of a partially written temp file; it may not exist.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In-memory registry (name -> embedding) shared by the handlers below,
# initialised from the pickle file at start-up.
known_faces = load_known_faces()
print(f"Loaded {len(known_faces)} known faces from storage.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_face(name, image):
    """Register the face found in *image* under *name* and persist it.

    Args:
        name: Label for the person; surrounding whitespace is removed
            before it is used as the registry key.
        image: Image to search for a face (converted with ``np.array``);
            ``None`` is rejected.

    Returns:
        A Markdown status string describing success or the reason for
        failure. Exceptions are caught and reported in the returned string
        together with the traceback.
    """
    import traceback

    # Normalise once so "  Alice " and "Alice" refer to the same entry.
    name = name.strip() if name else ""
    if not name:
        return "❌ Please enter a name before uploading an image."
    if image is None:
        return "❌ Please upload a valid face image."

    try:
        # The detector and get_embedding work on BGR arrays.
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        results = yolo.predict(source=img, conf=0.25, verbose=False)

        # Use the first box of the first result that yields a usable embedding.
        emb = None
        for r in results:
            if len(r.boxes) == 0:
                continue
            x1, y1, x2, y2 = map(int, r.boxes[0].xyxy[0].tolist())
            candidate = get_embedding(img[y1:y2, x1:x2])
            if candidate is not None and not np.isnan(candidate).any():
                emb = candidate
                break

        if emb is None:
            return f"❌ No clear face detected for {name}. Please upload a clearer image with a visible face."

        # Registering an existing name overwrites its previous embedding.
        known_faces[name] = emb

        if save_known_faces(known_faces):
            return f"✅ Registered face for **{name}** and saved to storage. Total known faces: {len(known_faces)}"
        return f"⚠️ Registered face for **{name}** but failed to save to storage."

    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING REGISTER FACE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete_face(name):
    """Remove a registered face from memory and from the pickle file.

    Args:
        name: Name of the entry to delete. An exact match is preferred;
            otherwise the name is looked up with surrounding whitespace
            removed.

    Returns:
        A Markdown status string describing the outcome.
    """
    if not name or name.strip() == "":
        return "❌ Please enter a name to delete."

    # Validation above tolerates surrounding whitespace, so the lookup does too.
    key = name if name in known_faces else name.strip()
    if key not in known_faces:
        return f"❌ No face registered with name '{name}'."

    try:
        del known_faces[key]
        if save_known_faces(known_faces):
            return f"✅ Deleted face for **{key}**. Remaining faces: {len(known_faces)}"
        return "⚠️ Deleted from memory but failed to update storage."
    except Exception as e:
        return f"❌ Error deleting face: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_faces():
    """Return a Markdown summary of all registered names.

    Returns:
        A fixed message when the registry is empty; otherwise a bold header
        with the total count followed by one Markdown list item per name.
    """
    if not known_faces:
        return "No faces registered yet."

    # Real Markdown list items: plain lines separated by a single newline
    # would be merged into one paragraph when rendered.
    face_list = "\n".join(f"- {name}" for name in known_faces)
    return f"**Registered Faces ({len(known_faces)} total):**\n\n{face_list}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_and_recognize(image):
    """Detect faces in *image* and label each one with its closest registered name.

    Every detected face is embedded with ``get_embedding`` and compared by
    cosine similarity against all registered embeddings. The best match is
    used when its score is at least 0.35; otherwise the face is labelled
    "Unknown". Boxes and labels are drawn onto the image.

    Args:
        image: Image to analyse (converted with ``np.array``); ``None`` is
            rejected.

    Returns:
        A ``(status_markdown, annotated_image)`` tuple. The image is ``None``
        when the input is missing, no faces are registered, or an exception
        occurred (the error and traceback are then part of the status text).
    """
    import traceback

    if image is None:
        return "❌ Please upload an image.", None
    if not known_faces:
        return "⚠️ No known faces registered yet!", None

    try:
        # The detector and get_embedding work on BGR arrays.
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        results = yolo.predict(source=img, conf=0.35, verbose=False)

        # Build the reference matrix once instead of once per detected face.
        ref_names = list(known_faces)
        ref_embs = np.stack([known_faces[n] for n in ref_names])

        names_found = []
        for r in results:
            for box in r.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                emb = get_embedding(img[y1:y2, x1:x2])
                # get_embedding returns None for empty or very small crops;
                # such faces cannot be compared, so skip them.
                if emb is None or np.isnan(emb).any():
                    continue

                scores = cosine_similarity([emb], ref_embs)[0]
                best_idx = int(np.argmax(scores))
                # Non-positive similarities are reported as 0, i.e. no match.
                best_score = max(float(scores[best_idx]), 0.0)
                best_name = ref_names[best_idx] if best_score >= 0.35 else "Unknown"

                label = f"{best_name} ({best_score:.2f})"
                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Keep the label inside the image for boxes touching the top edge.
                cv2.putText(img, label, (x1, max(y1 - 10, 15)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                names_found.append(label)

        result_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        if names_found:
            return f"✅ Detected faces: {', '.join(names_found)}", result_img
        return "⚠️ No faces detected in the image.", result_img

    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING DETECT & RECOGNIZE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: three tabs wired to the handler functions defined above.
# NOTE(review): the leading "π…" characters in the title and tab labels look
# like mis-decoded emoji; confirm the intended glyphs before shipping.
with gr.Blocks() as demo:
    gr.Markdown("# π€ Face Detection & Recognition (YOLO + ArcFace)")
    gr.Markdown("Register faces to build your database, then detect and recognize them in new images!")

    # Tab 1: add a person to the registry.
    with gr.Tab("π Register Face"):
        gr.Markdown("### Add a new face to the database")
        name_input = gr.Textbox(label="Person Name", placeholder="Enter person's name...")
        face_input = gr.Image(label="Upload Face Image", type="pil")
        register_btn = gr.Button("Register Face", variant="primary")
        register_output = gr.Markdown(label="Status")
        register_btn.click(
            fn=register_face,
            inputs=[name_input, face_input],
            outputs=register_output,
        )

    # Tab 2: run detection and recognition on an uploaded image.
    with gr.Tab("π Detect & Recognize"):
        gr.Markdown("### Detect and recognize faces in an image")
        img_input = gr.Image(label="Upload Test Image", type="pil")
        detect_btn = gr.Button("Detect Faces", variant="primary")
        text_output = gr.Markdown(label="Results")
        img_output = gr.Image(label="Output Image")
        detect_btn.click(
            fn=detect_and_recognize,
            inputs=img_input,
            outputs=[text_output, img_output],
        )

    # Tab 3: list and delete registered entries.
    with gr.Tab("π Manage Faces"):
        gr.Markdown("### View and manage registered faces")

        with gr.Row():
            with gr.Column():
                list_btn = gr.Button("Show All Registered Faces")
                list_output = gr.Markdown(label="Registered Faces")
                list_btn.click(fn=list_faces, outputs=list_output)

            with gr.Column():
                delete_name = gr.Textbox(label="Name to Delete", placeholder="Enter name to remove...")
                delete_btn = gr.Button("Delete Face", variant="stop")
                delete_output = gr.Markdown(label="Status")
                delete_btn.click(
                    fn=delete_face,
                    inputs=delete_name,
                    outputs=delete_output,
                )

demo.launch()