File size: 5,385 Bytes
6286b6e
346d26e
bd6af11
346d26e
c7544fd
1681229
774db5c
1681229
c7544fd
1681229
 
 
 
 
f9fcb47
f2889d6
346d26e
c7544fd
1681229
f2889d6
1681229
 
f2889d6
 
1681229
 
 
346d26e
1681229
f2889d6
 
1681229
c7544fd
f2889d6
7587404
1681229
f2889d6
 
7587404
 
1681229
7587404
1681229
 
346d26e
1681229
 
 
 
 
 
 
 
 
 
 
 
 
346d26e
 
1681229
 
 
7587404
c7544fd
1681229
ec01ae8
1681229
ec01ae8
1681229
 
 
 
 
 
ec01ae8
1681229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec01ae8
1681229
 
 
 
ec01ae8
 
 
1681229
 
 
 
ec01ae8
1681229
ec01ae8
1681229
 
 
 
 
 
 
 
 
ec01ae8
 
1681229
 
 
 
 
 
ec01ae8
1681229
 
 
 
 
ec01ae8
1681229
ec01ae8
 
 
 
 
 
 
1681229
 
 
 
 
 
ec01ae8
 
346d26e
bd6af11
1681229
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import os
import torch.nn.functional as F

# --- Configuración del modelo ---
MODEL_ID = "prithivMLmods/Alphabet-Sign-Language-Detection"
TITLE = "🤖 Clasificador de Alfabeto Americano de Señas (ASL)"

# --- Cargar modelo y procesador ---
processor, model = None, None
print(f"⏳ Intentando cargar el modelo '{MODEL_ID}'...")

try:
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    print(f"✅ Modelo cargado correctamente.")
except Exception as e:
    print("❌ Error al cargar el modelo:")
    print(e)


# --- Función de predicción (Top-3 ordenado verticalmente) ---
def predict_sign(image: Image.Image):
    """Realiza la predicción y devuelve el Top-3 en formato Markdown."""
    if processor is None or model is None:
        return "❌ El modelo no está disponible o falló al cargar. Revisa los logs."

    try:
        # 1. Preprocesamiento
        image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")

        # 2. Predicción
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

        # 3. Obtener Top-3
        topk = torch.topk(probs, k=3)
        
        # 4. Formatear resultado
        resultados_md = "### 🔠 Predicciones más probables\n"
        
        # Predicción Principal
        main_label = model.config.id2label.get(topk.indices.tolist()[0], "Desconocido")
        main_score = topk.values.tolist()[0] * 100
        
        # Estilo para la predicción principal
        result = f"## 🏆 Predicción Principal: **{main_label}**\n"
        result += f"### Confianza: **{main_score:.2f}%**\n\n"
        result += "---"
        
        # Top 3 (incluyendo la principal)
        result += "\n\n### Otras Predicciones (Top 3)\n"
        
        for idx, score in zip(topk.indices.tolist(), topk.values.tolist()):
            label = model.config.id2label.get(idx, "Desconocido")
            result += f"🖐 **{label}** — {score * 100:.2f}% de confianza\n"
            
        return result

    except Exception as e:
        return f"⚠️ Error durante la predicción: {str(e)}"

# --- Función para obtener ejemplos ---
def get_examples():
    """Busca ejemplos locales para la interfaz."""
    examples_dir = "examples"  # <--- Carpeta "examples"
    examples = []
    # Nota: Si solo usas 'senia.jpeg', cámbialo a buscar todos los archivos.
    # El código original de Gradio solo incluía una.
    # Usemos una versión que recoja todos los archivos en la carpeta:
    
    if os.path.exists(examples_dir):
        # Recorre todos los archivos en la carpeta
        for img_file in os.listdir(examples_dir):
            if img_file.endswith(('.jpg', '.jpeg', '.png')):
                # Agrega la ruta completa de la imagen como un elemento de la lista
                examples.append([os.path.join(examples_dir, img_file)])
                
    return examples if examples else None

# --- Interfaz de Gradio (Usando gr.Blocks para el nuevo diseño) ---
with gr.Blocks(title=TITLE, theme=gr.themes.Soft()) as interface:
    
    # Título principal y descripción
    gr.Markdown(f"""
    # {TITLE}
    Esta aplicación utiliza el modelo de Hugging Face **`{MODEL_ID}`** para clasificar 
    imágenes de señas del **Alfabeto Americano (ASL)**.
    """)
    
    # Consejos para mejores resultados
    gr.Markdown("""
    **📸 Consejos para mejores resultados:**
    - Usa imágenes con la mano clara y bien iluminada.
    - Asegúrate de que la seña sea reconocible según el ASL.
    - El modelo predice las 3 opciones más probables.
    """)

    with gr.Row():
        
        # Columna de entrada de imagen
        with gr.Column(scale=1):
            gr.Markdown("## 📤 Sube una Imagen")
            input_image = gr.Image(
                label="📸 Sube la imagen de la seña",
                type="pil",
                height=350
            )
            predict_btn = gr.Button("🔍 Clasificar Seña", variant="primary", size="lg")
            
        # Columna de salida del resultado
        with gr.Column(scale=1):
            gr.Markdown("## 📊 Resultado de la Predicción")
            output_text = gr.Markdown(
                label="Clasificación",
            )

    # Conectar la función de predicción al botón
    predict_btn.click(
        fn=predict_sign,
        inputs=input_image,
        outputs=output_text,
    )

    # Información adicional y ejemplos
    gr.Markdown("---")
    
    # Agregar ejemplos si existen
    examples = get_examples()
    if examples:
        gr.Markdown("### 📸 Ejemplos para probar:")
        gr.Examples(
            examples=examples,
            inputs=input_image,
            outputs=output_text,
            fn=predict_sign,
            cache_examples=True
        )
    
    # Pie de página o créditos
    gr.Markdown(f"""
    <p style='text-align: center; font-size: 0.8em; color: #555;'>
    Modelo basado en el trabajo de <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>prithivMLmods en Hugging Face</a>.
    </p>
    """)


if __name__ == "__main__":
    interface.launch()