Lenguaje_senias / app.py
RaquelTP25's picture
Update app.py
1681229 verified
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import os
import torch.nn.functional as F
# --- Configuración del modelo ---
MODEL_ID = "prithivMLmods/Alphabet-Sign-Language-Detection"
TITLE = "🤖 Clasificador de Alfabeto Americano de Señas (ASL)"
# --- Cargar modelo y procesador ---
processor, model = None, None
print(f"⏳ Intentando cargar el modelo '{MODEL_ID}'...")
try:
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
print(f"✅ Modelo cargado correctamente.")
except Exception as e:
print("❌ Error al cargar el modelo:")
print(e)
# --- Función de predicción (Top-3 ordenado verticalmente) ---
def predict_sign(image: Image.Image):
"""Realiza la predicción y devuelve el Top-3 en formato Markdown."""
if processor is None or model is None:
return "❌ El modelo no está disponible o falló al cargar. Revisa los logs."
try:
# 1. Preprocesamiento
image = image.convert("RGB")
inputs = processor(images=image, return_tensors="pt")
# 2. Predicción
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
# 3. Obtener Top-3
topk = torch.topk(probs, k=3)
# 4. Formatear resultado
resultados_md = "### 🔠 Predicciones más probables\n"
# Predicción Principal
main_label = model.config.id2label.get(topk.indices.tolist()[0], "Desconocido")
main_score = topk.values.tolist()[0] * 100
# Estilo para la predicción principal
result = f"## 🏆 Predicción Principal: **{main_label}**\n"
result += f"### Confianza: **{main_score:.2f}%**\n\n"
result += "---"
# Top 3 (incluyendo la principal)
result += "\n\n### Otras Predicciones (Top 3)\n"
for idx, score in zip(topk.indices.tolist(), topk.values.tolist()):
label = model.config.id2label.get(idx, "Desconocido")
result += f"🖐 **{label}** — {score * 100:.2f}% de confianza\n"
return result
except Exception as e:
return f"⚠️ Error durante la predicción: {str(e)}"
# --- Función para obtener ejemplos ---
def get_examples():
"""Busca ejemplos locales para la interfaz."""
examples_dir = "examples" # <--- Carpeta "examples"
examples = []
# Nota: Si solo usas 'senia.jpeg', cámbialo a buscar todos los archivos.
# El código original de Gradio solo incluía una.
# Usemos una versión que recoja todos los archivos en la carpeta:
if os.path.exists(examples_dir):
# Recorre todos los archivos en la carpeta
for img_file in os.listdir(examples_dir):
if img_file.endswith(('.jpg', '.jpeg', '.png')):
# Agrega la ruta completa de la imagen como un elemento de la lista
examples.append([os.path.join(examples_dir, img_file)])
return examples if examples else None
# --- Interfaz de Gradio (Usando gr.Blocks para el nuevo diseño) ---
with gr.Blocks(title=TITLE, theme=gr.themes.Soft()) as interface:
# Título principal y descripción
gr.Markdown(f"""
# {TITLE}
Esta aplicación utiliza el modelo de Hugging Face **`{MODEL_ID}`** para clasificar
imágenes de señas del **Alfabeto Americano (ASL)**.
""")
# Consejos para mejores resultados
gr.Markdown("""
**📸 Consejos para mejores resultados:**
- Usa imágenes con la mano clara y bien iluminada.
- Asegúrate de que la seña sea reconocible según el ASL.
- El modelo predice las 3 opciones más probables.
""")
with gr.Row():
# Columna de entrada de imagen
with gr.Column(scale=1):
gr.Markdown("## 📤 Sube una Imagen")
input_image = gr.Image(
label="📸 Sube la imagen de la seña",
type="pil",
height=350
)
predict_btn = gr.Button("🔍 Clasificar Seña", variant="primary", size="lg")
# Columna de salida del resultado
with gr.Column(scale=1):
gr.Markdown("## 📊 Resultado de la Predicción")
output_text = gr.Markdown(
label="Clasificación",
)
# Conectar la función de predicción al botón
predict_btn.click(
fn=predict_sign,
inputs=input_image,
outputs=output_text,
)
# Información adicional y ejemplos
gr.Markdown("---")
# Agregar ejemplos si existen
examples = get_examples()
if examples:
gr.Markdown("### 📸 Ejemplos para probar:")
gr.Examples(
examples=examples,
inputs=input_image,
outputs=output_text,
fn=predict_sign,
cache_examples=True
)
# Pie de página o créditos
gr.Markdown(f"""
<p style='text-align: center; font-size: 0.8em; color: #555;'>
Modelo basado en el trabajo de <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>prithivMLmods en Hugging Face</a>.
</p>
""")
if __name__ == "__main__":
interface.launch()