FrAnKu34t23 committed on
Commit
0f0c715
·
verified ·
1 Parent(s): b94cd58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -77
app.py CHANGED
@@ -1,92 +1,210 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- # --- 1. Import Existing Baselines ---
4
- # Wrapped in try-except so the app doesn't crash if files are temporarily missing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  try:
6
  from baseline.baseline_convnext import predict_convnext
7
  except ImportError:
8
- def predict_convnext(image): return {"Error": "ConvNeXt module missing"}
9
-
10
  try:
11
  from baseline.baseline_infer import predict_baseline
12
  except ImportError:
13
- def predict_baseline(image): return {"Error": "Baseline module missing"}
14
-
15
- # --- 2. Import NEW SPA Approach ---
16
- # This imports the function from: new_approach/spa_ensemble.py
17
  try:
18
  from new_approach.spa_ensemble import predict_spa
19
  except ImportError:
20
- def predict_spa(image): return {"Error": "SPA module missing. Check 'new_approach' folder."}
21
-
22
 
23
- # --- Placeholder models (for future extensions) ---
24
- def predict_placeholder_2(image):
25
- if image is None:
26
- return "Please upload an image."
27
- return "Model 4 is not available yet. Please check back later."
28
 
29
- # --- Main Prediction Logic ---
30
  def predict(model_choice, image):
31
- if image is None: return None
32
 
33
- if model_choice == "Herbarium Species Classifier":
34
- # Friend's ConvNeXt mix-stream CNN baseline
35
- return predict_convnext(image)
36
-
37
  elif model_choice == "Baseline (DINOv2 + LogReg)":
38
- # Original baseline
39
- return predict_baseline(image)
40
-
41
  elif model_choice == "SPA Ensemble (New Approach)":
42
- # YOUR NEW CODE: DINOv2 + BioCLIP + Handcrafted + SPA
43
- return predict_spa(image)
44
-
45
  elif model_choice == "Future Model 2 (Placeholder)":
46
- return predict_placeholder_2(image)
47
-
48
  else:
49
- return "Invalid model selected."
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # --- Gradio Interface ---
52
  with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
53
  with gr.Column(elem_id="app-wrapper"):
54
- # Header
55
  gr.Markdown(
56
  """
57
  <div id="app-header">
58
  <h1>🌿 Plant Species Classification</h1>
59
- <h3>AML Group Project – PsychicFireSong</h3>
60
- </div>
61
- """,
62
- elem_id="app-header",
63
- )
64
-
65
- # Badges row
66
- gr.Markdown(
67
- """
68
- <div id="badge-row">
69
- <span class="badge">Herbarium + Field images</span>
70
- <span class="badge">ConvNeXtV2</span>
71
- <span class="badge">SPA Ensemble</span>
72
  </div>
73
- """,
74
- elem_id="badge-row",
75
  )
76
 
77
- # Main card
78
  with gr.Row(elem_id="main-card"):
79
- # Left side: model + image
80
- with gr.Column(scale=1, elem_id="left-panel"):
81
  model_selector = gr.Dropdown(
82
  label="Select model",
83
  choices=[
84
- "Herbarium Species Classifier",
85
  "Baseline (DINOv2 + LogReg)",
86
  "SPA Ensemble (New Approach)",
87
  "Future Model 2 (Placeholder)",
88
  ],
89
- value="SPA Ensemble (New Approach)", # Default to your new model
90
  )
91
 
92
  gr.Markdown(
@@ -96,43 +214,28 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
96
  <b>Baseline</b> – Simple DINOv2 + LogReg.<br>
97
  <b>SPA Ensemble</b> – <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
98
  </div>
99
- """,
100
- elem_id="model-help",
101
- )
102
-
103
- image_input = gr.Image(
104
- type="pil",
105
- label="Upload plant image",
106
  )
107
 
 
108
  submit_button = gr.Button("Classify 🌱", variant="primary")
109
 
110
- # Right side: predictions
111
- with gr.Column(scale=1, elem_id="right-panel"):
112
- output_label = gr.Label(
113
- label="Top 5 predictions",
114
- num_top_classes=5,
 
 
115
  )
116
 
117
  submit_button.click(
118
  fn=predict,
119
  inputs=[model_selector, image_input],
120
- outputs=output_label,
121
- )
122
-
123
- # Optional examples
124
- gr.Examples(
125
- examples=[],
126
- inputs=image_input,
127
- outputs=output_label,
128
- fn=lambda img: predict("SPA Ensemble (New Approach)", img),
129
- cache_examples=False,
130
  )
131
 
132
- gr.Markdown(
133
- "Built for the AML course – compare CNN vs. DINOv2 feature-extractor baselines.",
134
- elem_id="footer",
135
- )
136
 
137
  if __name__ == "__main__":
138
  demo.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import re
4
+ import pickle
5
+ import torch
6
+ from torchvision import transforms
7
+ from huggingface_hub import list_repo_files, hf_hub_download
8
 
9
+ # --- CONFIGURATION ---
10
+
11
# Dataset repository that stores the herbarium images. The base URL is the
# prefix to which "<class_id>/<file name>" is appended when building links
# to the images shown in the UI.
DATASET_ID = "FrAnKu34t23/Herbarium_Field"
DATASET_URL_BASE = (
    "https://huggingface.co/datasets/" + DATASET_ID + "/resolve/main/train/herbarium/"
)

# Model repository and file name of the pickled visual search index.
MODEL_REPO_ID = "FrAnKu34t23/ensemble_models_plant"
INDEX_FILENAME = "herbarium_index.pkl"

# Module-level state, filled in by load_resources() and build_fallback_map().
REFERENCE_IMAGE_MAP = {}  # class id -> one image file name (fallback lookup)
VECTOR_INDEX = None  # visual search index; stays None if loading fails
FEATURE_EXTRACTOR = None  # DINOv2 model used to embed the uploaded image
TRANSFORM = None  # preprocessing applied before the feature extractor
26
+
27
+ # --- SETUP: Load Resources ---
28
def load_resources():
    """
    Load everything the visual search needs, then build the fallback map.

    Downloads the pickled index named by INDEX_FILENAME from MODEL_REPO_ID,
    loads the DINOv2 ViT-S/14 model through torch.hub and builds the image
    preprocessing pipeline. The three globals VECTOR_INDEX, FEATURE_EXTRACTOR
    and TRANSFORM are published together: either all of them are set, or
    (on any failure) all of them are reset to None so callers never see a
    half-initialised retrieval engine.

    Failures are reported on stdout and never raised; the app keeps running
    with only the fallback map available.
    """
    global VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM

    print("🚀 App starting... Initializing resources.")

    # 1. Download and load the visual search index and the retrieval model.
    try:
        print(f"⬇️ Downloading {INDEX_FILENAME} from {MODEL_REPO_ID}...")

        # Downloads the file into the local Hub cache and returns its path.
        index_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=INDEX_FILENAME,
            repo_type="model",
        )

        print(f"✅ Downloaded to {index_path}. Loading pickle...")
        # SECURITY: unpickling runs arbitrary code contained in the file.
        # Only point MODEL_REPO_ID at a repository you control and trust.
        with open(index_path, "rb") as f:
            index = pickle.load(f)

        print("⬇️ Loading DINOv2 (Retrieval Engine)...")
        extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        extractor.eval()  # inference only: switch off training-time behaviour

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    except Exception as e:
        # Startup boundary: report the problem and degrade to the fallback map.
        print(f"⚠️ Smart Search initialization failed: {e}")
        print(f"ℹ️ Ensure '{INDEX_FILENAME}' is uploaded to '{MODEL_REPO_ID}'")
        VECTOR_INDEX = None
        FEATURE_EXTRACTOR = None
        TRANSFORM = None
    else:
        # Publish only after every step succeeded.
        VECTOR_INDEX = index
        FEATURE_EXTRACTOR = extractor
        TRANSFORM = transform
        print("🚀 Smart Search Ready!")

    # 2. Build the fallback map, used when a class is missing from the index
    #    or the visual search is unavailable.
    build_fallback_map()
68
+
69
def build_fallback_map():
    """
    Fill REFERENCE_IMAGE_MAP with one reference image per herbarium class.

    Lists the files of the dataset repository DATASET_ID and, for every image
    stored under ``train/herbarium/<class_id>/``, remembers the first one seen
    for that class. The stored value is the path relative to the class folder,
    so it can be appended to ``DATASET_URL_BASE + class_id + "/"`` even when
    the image sits in a sub-directory of the class folder.

    Errors are reported on stdout and never raised; entries collected before
    the error are kept.
    """
    prefix = "train/herbarium/"
    image_suffixes = (".jpg", ".jpeg", ".png")
    try:
        print(f"🔄 Scanning dataset {DATASET_ID} for fallback map...")
        # Note: If dataset is private, add token=os.environ.get("HF_TOKEN") inside list_repo_files
        all_files = list_repo_files(repo_id=DATASET_ID, repo_type="dataset")

        for file_path in all_files:
            if not file_path.startswith(prefix):
                continue
            if not file_path.lower().endswith(image_suffixes):
                continue
            # Remainder has the form "<class_id>/<path inside the class folder>".
            class_id, _, relative_path = file_path[len(prefix):].partition("/")
            if class_id and relative_path:
                # setdefault keeps the first image found for each class.
                REFERENCE_IMAGE_MAP.setdefault(class_id, relative_path)
        print(f"✅ Fallback map built for {len(REFERENCE_IMAGE_MAP)} classes.")
    except Exception as e:
        print(f"⚠️ Error scanning dataset: {e}")
90
+
91
+ # Load resources on startup
92
+ load_resources()
93
+
94
+ # --- Logic: Visual Similarity Search ---
95
def find_most_similar_herbarium_sheet(class_prediction, input_pil_image):
    """
    Return the URL of the herbarium image that best matches the input image.

    The class id is the first parenthesised run of digits found in
    ``class_prediction`` (for example ``"Some species (12345)"`` gives
    ``"12345"``).

    Strategy A (visual search) embeds ``input_pil_image`` with
    FEATURE_EXTRACTOR and picks the entry of ``VECTOR_INDEX[class_id]`` with
    the highest dot product against the L2-normalised embedding.
    Strategy B (fallback) returns the image recorded for the class in
    REFERENCE_IMAGE_MAP. It is used when the search is unavailable, the class
    is not in the index, or the search raises.

    Args:
        class_prediction: Predicted label text containing ``(<digits>)``.
        input_pil_image: The uploaded image; may be None, in which case only
            the fallback is tried.

    Returns:
        The image URL as a string, or None when no class id can be extracted
        or no image is known for the class.
    """
    if not class_prediction:
        return None

    # Extract the class id from the label text.
    match = re.search(r'\((\d+)\)', class_prediction)
    if not match:
        return None
    class_id = match.group(1)

    # Strategy A: visual similarity. Explicit None checks are used because the
    # truth value of a model or image object says nothing about availability.
    search_ready = (
        VECTOR_INDEX is not None
        and FEATURE_EXTRACTOR is not None
        and TRANSFORM is not None
        and input_pil_image is not None
    )
    if search_ready and class_id in VECTOR_INDEX:
        try:
            # Embed the input image (a batch of one) and L2-normalise it.
            img_tensor = TRANSFORM(input_pil_image).unsqueeze(0)
            with torch.no_grad():
                input_vec = FEATURE_EXTRACTOR(img_tensor)
                input_vec = torch.nn.functional.normalize(input_vec, p=2, dim=1)

            # Start below every possible score so the best candidate is always
            # selected, even when all similarities are -1.0 or lower.
            best_score = float("-inf")
            best_filename = None

            for item in VECTOR_INDEX[class_id]:
                # NOTE(review): assumes item["vector"] is a (1, D) tensor that
                # was L2-normalised when the index was built, so this dot
                # product is a cosine similarity -- confirm against the
                # script that creates the index.
                score = torch.mm(input_vec, item["vector"].T).item()
                if score > best_score:
                    best_score = score
                    best_filename = item["filename"]

            if best_filename:
                return f"{DATASET_URL_BASE}{class_id}/{best_filename}"

        except Exception as e:
            # Best effort: report and fall through to the fallback below.
            print(f"⚠️ Search failed: {e}")

    # Strategy B: first available image for the class.
    filename = REFERENCE_IMAGE_MAP.get(class_id)
    if filename:
        return f"{DATASET_URL_BASE}{class_id}/{filename}"
    return None
138
+
139
+ # --- Import User Models ---
140
+ # Safely import your existing model files
141
# Every model lives in its own module. When a module cannot be imported, a
# stub with the same name is defined instead, so the app still starts and
# shows the missing model as a single zero-score label.
try:
    from baseline.baseline_convnext import predict_convnext
except ImportError:
    def predict_convnext(image):
        """Stand-in used when the ConvNeXt module cannot be imported."""
        return {"Error: ConvNeXt missing": 0.0}

try:
    from baseline.baseline_infer import predict_baseline
except ImportError:
    def predict_baseline(image):
        """Stand-in used when the baseline module cannot be imported."""
        return {"Error: Baseline missing": 0.0}

try:
    from new_approach.spa_ensemble import predict_spa
except ImportError:
    def predict_spa(image):
        """Stand-in used when the SPA ensemble module cannot be imported."""
        return {"Error: SPA missing": 0.0}


def predict_placeholder_2(image):
    """Placeholder entry for a model that is not available yet."""
    return {"Model 4 Not Available": 0.0}
 
 
 
 
155
 
156
+ # --- Main App Logic ---
157
def predict(model_choice, image):
    """
    Classify ``image`` with the selected model and look up a reference image.

    Args:
        model_choice: Label of the model picked in the dropdown.
        image: The uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(predictions, reference_image_url)`` pair. Both are None when no
        image was given; the URL is None when no reference image is found or
        the lookup fails.
    """
    if image is None:
        return None, None

    # Step 1: run the classifier that belongs to the chosen label.
    if model_choice == "Herbarium Species Classifier (ConvNeXT)":
        scores = predict_convnext(image)
    elif model_choice == "Baseline (DINOv2 + LogReg)":
        scores = predict_baseline(image)
    elif model_choice == "SPA Ensemble (New Approach)":
        scores = predict_spa(image)
    elif model_choice == "Future Model 2 (Placeholder)":
        scores = predict_placeholder_2(image)
    else:
        scores = {"Invalid model": 0.0}

    # Step 2: retrieve a herbarium image for the highest-scoring label.
    # Only a non-empty dict of label -> score can be ranked.
    specimen_url = None
    if isinstance(scores, dict) and scores:
        try:
            best_label = max(scores, key=scores.get)
            is_message = "Error" in best_label or "Please" in best_label
            if not is_message:
                specimen_url = find_most_similar_herbarium_sheet(best_label, image)
        except Exception as e:
            # Retrieval is optional: report and still return the predictions.
            print(f"Error in retrieval: {e}")

    return scores, specimen_url
184
 
185
  # --- Gradio Interface ---
186
  with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
187
  with gr.Column(elem_id="app-wrapper"):
 
188
  gr.Markdown(
189
  """
190
  <div id="app-header">
191
  <h1>🌿 Plant Species Classification</h1>
192
+ <h3>AML Group Project – Group 8</h3>
 
 
 
 
 
 
 
 
 
 
 
 
193
  </div>
194
+ """, elem_id="app-header"
 
195
  )
196
 
 
197
  with gr.Row(elem_id="main-card"):
198
+ with gr.Column(scale=1):
 
199
  model_selector = gr.Dropdown(
200
  label="Select model",
201
  choices=[
202
+ "Herbarium Species Classifier (ConvNeXT)",
203
  "Baseline (DINOv2 + LogReg)",
204
  "SPA Ensemble (New Approach)",
205
  "Future Model 2 (Placeholder)",
206
  ],
207
+ value="SPA Ensemble (New Approach)",
208
  )
209
 
210
  gr.Markdown(
 
214
  <b>Baseline</b> – Simple DINOv2 + LogReg.<br>
215
  <b>SPA Ensemble</b> – <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
216
  </div>
217
+ """, elem_id="model-help"
 
 
 
 
 
 
218
  )
219
 
220
+ image_input = gr.Image(type="pil", label="Upload plant image")
221
  submit_button = gr.Button("Classify 🌱", variant="primary")
222
 
223
+ with gr.Column(scale=1):
224
+ output_label = gr.Label(label="Top 5 predictions", num_top_classes=5)
225
+ herbarium_output = gr.Image(
226
+ label="Matched Herbarium Specimen (Visual Reference)",
227
+ show_label=True,
228
+ interactive=False,
229
+ height=300
230
  )
231
 
232
  submit_button.click(
233
  fn=predict,
234
  inputs=[model_selector, image_input],
235
+ outputs=[output_label, herbarium_output],
 
 
 
 
 
 
 
 
 
236
  )
237
 
238
+ gr.Markdown("Built for the AML course – Group 8", elem_id="footer")
 
 
 
239
 
240
  if __name__ == "__main__":
241
  demo.launch()