Manav Sarkar committed on
Commit 363d6df · 1 Parent(s): 442d120

age model using mivolo

Files changed (38)
  1. .gitignore +2 -0
  2. AttributesHolder.py +47 -0
  3. Dockerfile +3 -1
  4. __pycache__/AttributesHolder.cpython-311.pyc +0 -0
  5. __pycache__/app.cpython-311.pyc +0 -0
  6. app.py +50 -4
  7. deepface/__pycache__/DeepFace.cpython-311.pyc +0 -0
  8. mivolo/__init__.py +0 -0
  9. mivolo/__pycache__/__init__.cpython-311.pyc +0 -0
  10. mivolo/__pycache__/predictor.cpython-311.pyc +0 -0
  11. mivolo/__pycache__/structures.cpython-311.pyc +0 -0
  12. mivolo/data/__init__.py +0 -0
  13. mivolo/data/__pycache__/__init__.cpython-311.pyc +0 -0
  14. mivolo/data/__pycache__/data_reader.cpython-311.pyc +0 -0
  15. mivolo/data/__pycache__/misc.cpython-311.pyc +0 -0
  16. mivolo/data/data_reader.py +125 -0
  17. mivolo/data/dataset/__init__.py +66 -0
  18. mivolo/data/dataset/age_gender_dataset.py +194 -0
  19. mivolo/data/dataset/age_gender_loader.py +169 -0
  20. mivolo/data/dataset/classification_dataset.py +47 -0
  21. mivolo/data/dataset/reader_age_gender.py +492 -0
  22. mivolo/data/misc.py +246 -0
  23. mivolo/model/__init__.py +0 -0
  24. mivolo/model/__pycache__/__init__.cpython-311.pyc +0 -0
  25. mivolo/model/__pycache__/create_timm_model.cpython-311.pyc +0 -0
  26. mivolo/model/__pycache__/cross_bottleneck_attn.cpython-311.pyc +0 -0
  27. mivolo/model/__pycache__/mi_volo.cpython-311.pyc +0 -0
  28. mivolo/model/__pycache__/mivolo_model.cpython-311.pyc +0 -0
  29. mivolo/model/__pycache__/yolo_detector.cpython-311.pyc +0 -0
  30. mivolo/model/create_timm_model.py +107 -0
  31. mivolo/model/cross_bottleneck_attn.py +116 -0
  32. mivolo/model/mi_volo.py +243 -0
  33. mivolo/model/mivolo_model.py +404 -0
  34. mivolo/model/yolo_detector.py +46 -0
  35. mivolo/predictor.py +68 -0
  36. mivolo/structures.py +472 -0
  37. mivolo/version.py +1 -0
  38. requirements.txt +6 -1
.gitignore ADDED
@@ -0,0 +1,2 @@
+ /weights
+ /models
AttributesHolder.py ADDED
@@ -0,0 +1,47 @@
+ class _AttributeHolder(object):
+     """Abstract base class that provides __repr__.
+ 
+     The __repr__ method returns a string in the format::
+         ClassName(attr=name, attr=name, ...)
+     The attributes are determined either by a class-level attribute,
+     '_kwarg_names', or by inspecting the instance __dict__.
+     """
+ 
+     def __repr__(self):
+         type_name = type(self).__name__
+         arg_strings = []
+         star_args = {}
+         for arg in self._get_args():
+             arg_strings.append(repr(arg))
+         for name, value in self._get_kwargs():
+             if name.isidentifier():
+                 arg_strings.append('%s=%r' % (name, value))
+             else:
+                 star_args[name] = value
+         if star_args:
+             arg_strings.append('**%s' % repr(star_args))
+         return '%s(%s)' % (type_name, ', '.join(arg_strings))
+ 
+     def _get_kwargs(self):
+         return list(self.__dict__.items())
+ 
+     def _get_args(self):
+         return []
+ class Namespace(_AttributeHolder):
+     """Simple object for storing attributes.
+ 
+     Implements equality by attribute names and values, and provides a simple
+     string representation.
+     """
+ 
+     def __init__(self, **kwargs):
+         for name in kwargs:
+             setattr(self, name, kwargs[name])
+ 
+     def __eq__(self, other):
+         if not isinstance(other, Namespace):
+             return NotImplemented
+         return vars(self) == vars(other)
+ 
+     def __contains__(self, key):
+         return key in self.__dict__
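
For illustration, a minimal sketch of how the Namespace helper added above can be used; the attribute names and values here are arbitrary examples, not part of the commit:

from AttributesHolder import Namespace

ns = Namespace(device="cpu", draw=False)            # keyword args become attributes via setattr
print(ns)                                           # Namespace(device='cpu', draw=False)
print("draw" in ns)                                 # True, provided by __contains__
print(ns == Namespace(device="cpu", draw=False))    # True, equality compares vars()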
Dockerfile CHANGED
@@ -12,6 +12,7 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
  # Create a folder named 'weights'
  RUN mkdir /weights
  RUN mkdir /code/weights
+ RUN mkdir /code/models
 
  RUN wget -O /code/weights/age_model_weights.h5 https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5
  RUN wget -O /code/weights/facial_expression_model_weights.h5 https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5
@@ -23,7 +24,8 @@ RUN bunzip2 /code/weights/shape_predictor_5_face_landmarks.dat.bz2
 
  RUN wget -O /code/weights/ghostfacenet_v1.h5 https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5
 
- 
+ RUN wget -O /code/models/mivolo_imbd.pth.tar "https://firebasestorage.googleapis.com/v0/b/vidcorder-4bf3e.appspot.com/o/mivolo_imbd.pth.tar?alt=media&token=b8e9954a-b9fa-4908-84b9-3a7bd90ccd3d"
+ RUN wget -O /code/models/yolov8x_person_face.pt "https://firebasestorage.googleapis.com/v0/b/vidcorder-4bf3e.appspot.com/o/yolov8x_person_face.pt?alt=media&token=64d72b75-21fe-4d70-a461-6009dae2f65c"
 
  COPY . .
__pycache__/AttributesHolder.cpython-311.pyc ADDED
Binary file (3.23 kB).
 
__pycache__/app.cpython-311.pyc ADDED
Binary file (4.94 kB).
 
app.py CHANGED
@@ -3,10 +3,35 @@ import deepface.DeepFace as DeepFace
  from fastapi.responses import JSONResponse
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
- 
+ import numpy as np
+ from mivolo.predictor import Predictor
  import base64
  from PIL import Image
  import numpy as np
+ from AttributesHolder import Namespace
+ import cv2
+ 
+ config = {
+     "checkpoint": 'models/mivolo_imbd.pth.tar',
+     "detector_weights": 'models/yolov8x_person_face.pt',
+     "device": 'cpu',
+     "draw": False,
+     "with_persons": True,
+     "disable_faces": False,
+     "output": 'output'
+ }
+ 
+ namespace = Namespace()
+ setattr(namespace, 'checkpoint', 'models/mivolo_imbd.pth.tar')
+ setattr(namespace, 'detector_weights', 'models/yolov8x_person_face.pt')
+ setattr(namespace, 'device', 'cpu')
+ setattr(namespace, 'draw', False)
+ setattr(namespace, 'with_persons', True)
+ setattr(namespace, 'disable_faces', False)
+ setattr(namespace, 'output', 'output')
+ 
+ predictor = Predictor(config=namespace)
+ 
  app = FastAPI()
 
  class Base64Data(BaseModel):
@@ -31,11 +56,17 @@ def index():
  async def create_upload_file(contents: Base64Data):
      try:
          # Read the file
+         loaded_image = base64_to_cv2(contents.base64_data)
+         detected_objects, out_im = predictor.recognize(loaded_image)
+         age = detected_objects.ages[0]
+         gender = detected_objects.genders[0]
          contents = contents.base64_data
-         res = DeepFace.analyze(img_path=contents, actions=['age', 'gender'])
          df = DeepFace.find(img_path = contents, db_path = "dataset/",model_name ='GhostFaceNet', threshold=0.9)
          print(df[0].head())
-         return JSONResponse(content={"res": res, "celeb":df[0].head()['identity'][0] }, status_code=200)
+         return JSONResponse(content={"celeb":df[0].head()['identity'][0], "res":{
+             "age": age,
+             "gender": gender
+         } }, status_code=200)
      except Exception as e:
          return JSONResponse(content={"message": "Error processing the file.", "error": str(e)}, status_code=500)
 
@@ -55,4 +86,19 @@ async def file_to_base64(file):
 
      except Exception as e:
          print(str(e))
-         return None
+         return None
+ 
+ 
+ 
+ def base64_to_cv2(base64_string):
+     base64_string = base64_string.split(",")[1]
+     # Decode the base64 string into bytes
+     decoded_bytes = base64.b64decode(base64_string)
+ 
+     # Convert bytes to numpy array
+     np_array = np.frombuffer(decoded_bytes, np.uint8)
+ 
+     # Decode the numpy array into an image
+     image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
+ 
+     return image
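
As a reference for the change above, a minimal standalone sketch of the new MiVOLO code path. It assumes the weights downloaded by the Dockerfile are present under models/ and that a hypothetical local image test.jpg exists; the Namespace arguments mirror the setattr calls in app.py:

import cv2
from AttributesHolder import Namespace
from mivolo.predictor import Predictor

args = Namespace(
    checkpoint='models/mivolo_imbd.pth.tar',
    detector_weights='models/yolov8x_person_face.pt',
    device='cpu', with_persons=True, disable_faces=False,
    draw=False, output='output',
)
predictor = Predictor(config=args)

image = cv2.imread("test.jpg")                     # BGR array, same as base64_to_cv2() returns
detected_objects, _ = predictor.recognize(image)   # same call used in create_upload_file()
print(detected_objects.ages[0], detected_objects.genders[0])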
deepface/__pycache__/DeepFace.cpython-311.pyc CHANGED
Binary files a/deepface/__pycache__/DeepFace.cpython-311.pyc and b/deepface/__pycache__/DeepFace.cpython-311.pyc differ
 
mivolo/__init__.py ADDED
File without changes
mivolo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (146 Bytes).
 
mivolo/__pycache__/predictor.cpython-311.pyc ADDED
Binary file (4.07 kB).
 
mivolo/__pycache__/structures.cpython-311.pyc ADDED
Binary file (29.3 kB).
 
mivolo/data/__init__.py ADDED
File without changes
mivolo/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (151 Bytes).
 
mivolo/data/__pycache__/data_reader.cpython-311.pyc ADDED
Binary file (8.81 kB).
 
mivolo/data/__pycache__/misc.cpython-311.pyc ADDED
Binary file (13.6 kB).
 
mivolo/data/data_reader.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import pandas as pd
8
+
9
+ IMAGES_EXT: Tuple = (".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif")
10
+ VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm")
11
+
12
+
13
+ @dataclass
14
+ class PictureInfo:
15
+ image_path: str
16
+ age: Optional[str] # age or age range(start;end format) or "-1"
17
+ gender: Optional[str] # "M" or "F" or "-1"
18
+ bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy
19
+ person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy
20
+
21
+ @property
22
+ def has_person_bbox(self) -> bool:
23
+ return any(coord != -1 for coord in self.person_bbox)
24
+
25
+ @property
26
+ def has_face_bbox(self) -> bool:
27
+ return any(coord != -1 for coord in self.bbox)
28
+
29
+ def has_gt(self, only_age: bool = False) -> bool:
30
+ if only_age:
31
+ return self.age != "-1"
32
+ else:
33
+ return not (self.age == "-1" and self.gender == "-1")
34
+
35
+ def clear_person_bbox(self):
36
+ self.person_bbox = [-1, -1, -1, -1]
37
+
38
+ def clear_face_bbox(self):
39
+ self.bbox = [-1, -1, -1, -1]
40
+
41
+
42
+ class AnnotType(Enum):
43
+ ORIGINAL = "original"
44
+ PERSONS = "persons"
45
+ NONE = "none"
46
+
47
+ @classmethod
48
+ def _missing_(cls, value):
49
+ print(f"WARN: Unknown annotation type {value}.")
50
+ return AnnotType.NONE
51
+
52
+
53
+ def get_all_files(path: str, extensions: Tuple = IMAGES_EXT):
54
+ files_all = []
55
+ for root, subFolders, files in os.walk(path):
56
+ for name in files:
57
+ # skip Linux ".directory" metadata entries, which os.walk still reports as files
58
+ if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0:
59
+ files_all.append(os.path.join(root, name))
60
+ return files_all
61
+
62
+
63
+ class InputType(Enum):
64
+ Image = 0
65
+ Video = 1
66
+ VideoStream = 2
67
+
68
+
69
+ def get_input_type(input_path: str) -> InputType:
70
+ if os.path.isdir(input_path):
71
+ print("Input is a folder, only images will be processed")
72
+ return InputType.Image
73
+ elif os.path.isfile(input_path):
74
+ if input_path.endswith(VIDEO_EXT):
75
+ return InputType.Video
76
+ if input_path.endswith(IMAGES_EXT):
77
+ return InputType.Image
78
+ else:
79
+ raise ValueError(
80
+ f"Unknown or unsupported input file format {input_path}, \
81
+ supported video formats: {VIDEO_EXT}, \
82
+ supported image formats: {IMAGES_EXT}"
83
+ )
84
+ elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT):
85
+ return InputType.VideoStream
86
+ else:
87
+ raise ValueError(f"Unknown input {input_path}")
88
+
89
+
90
+ def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False):
91
+ bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list)
92
+
93
+ df = pd.read_csv(annotation_file, sep=",")
94
+
95
+ annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original")
96
+ print(f"Reading {annotation_file} (type: {annot_type})...")
97
+
98
+ missing_images = 0
99
+ for index, row in df.iterrows():
100
+ img_path = os.path.join(images_dir, row["img_name"])
101
+ if not os.path.exists(img_path):
102
+ missing_images += 1
103
+ continue
104
+
105
+ face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"]
106
+ age, gender = str(row["age"]), str(row["gender"])
107
+
108
+ if ignore_without_gt and (age == "-1" or gender == "-1"):
109
+ continue
110
+
111
+ if annot_type == AnnotType.PERSONS:
112
+ p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"]
113
+ person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2]))
114
+ else:
115
+ person_bbox = [-1, -1, -1, -1]
116
+
117
+ bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2]))
118
+ pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox)
119
+ assert isinstance(pic_info.person_bbox, list)
120
+
121
+ bboxes_per_image[img_path].append(pic_info)
122
+
123
+ if missing_images > 0:
124
+ print(f"WARNING: Missing images: {missing_images}/{len(df)}")
125
+ return bboxes_per_image, annot_type
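
For context, a short sketch of how read_csv_annotation_file above could be driven; the annotation CSV and image directory paths are hypothetical, and the CSV is assumed to have the img_name, face_*, age, and gender columns the function reads:

from mivolo.data.data_reader import AnnotType, read_csv_annotation_file

bboxes_per_image, annot_type = read_csv_annotation_file("annotation/train.csv", "images/")
print(annot_type)  # AnnotType.PERSONS if the CSV has person_* columns, else AnnotType.ORIGINAL
for img_path, samples in bboxes_per_image.items():
    for s in samples:
        print(img_path, s.age, s.gender, s.bbox, s.person_bbox)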
mivolo/data/dataset/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from mivolo.model.mi_volo import MiVOLO
5
+
6
+ from .age_gender_dataset import AgeGenderDataset
7
+ from .age_gender_loader import create_loader
8
+ from .classification_dataset import AdienceDataset, FairFaceDataset
9
+
10
+ DATASET_CLASS_MAP = {
11
+ "utk": AgeGenderDataset,
12
+ "lagenda": AgeGenderDataset,
13
+ "imdb": AgeGenderDataset,
14
+ "agedb": AgeGenderDataset,
15
+ "cacd": AgeGenderDataset,
16
+ "adience": AdienceDataset,
17
+ "fairface": FairFaceDataset,
18
+ }
19
+
20
+
21
+ def build(
22
+ name: str,
23
+ images_path: str,
24
+ annotations_path: str,
25
+ split: str,
26
+ mivolo_model: MiVOLO,
27
+ workers: int,
28
+ batch_size: int,
29
+ ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
30
+
31
+ dataset_class = DATASET_CLASS_MAP[name]
32
+
33
+ dataset: torch.utils.data.Dataset = dataset_class(
34
+ images_path=images_path,
35
+ annotations_path=annotations_path,
36
+ name=name,
37
+ split=split,
38
+ target_size=mivolo_model.input_size,
39
+ max_age=mivolo_model.meta.max_age,
40
+ min_age=mivolo_model.meta.min_age,
41
+ model_with_persons=mivolo_model.meta.with_persons_model,
42
+ use_persons=mivolo_model.meta.use_persons,
43
+ disable_faces=mivolo_model.meta.disable_faces,
44
+ only_age=mivolo_model.meta.only_age,
45
+ )
46
+
47
+ data_config = mivolo_model.data_config
48
+
49
+ in_chans = 3 if not mivolo_model.meta.with_persons_model else 6
50
+ input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)
51
+
52
+ dataset_loader: torch.utils.data.DataLoader = create_loader(
53
+ dataset,
54
+ input_size=input_size,
55
+ batch_size=batch_size,
56
+ mean=data_config["mean"],
57
+ std=data_config["std"],
58
+ num_workers=workers,
59
+ crop_pct=data_config["crop_pct"],
60
+ crop_mode=data_config["crop_mode"],
61
+ pin_memory=False,
62
+ device=mivolo_model.device,
63
+ target_type=dataset.target_dtype,
64
+ )
65
+
66
+ return dataset, dataset_loader
mivolo/data/dataset/age_gender_dataset.py ADDED
@@ -0,0 +1,194 @@
1
+ import logging
2
+ from typing import Any, List, Optional, Set
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
+ _logger = logging.getLogger("AgeGenderDataset")
12
+
13
+
14
+ class AgeGenderDataset(torch.utils.data.Dataset):
15
+ def __init__(
16
+ self,
17
+ images_path,
18
+ annotations_path,
19
+ name=None,
20
+ split="train",
21
+ load_bytes=False,
22
+ img_mode="RGB",
23
+ transform=None,
24
+ is_training=False,
25
+ seed=1234,
26
+ target_size=224,
27
+ min_age=None,
28
+ max_age=None,
29
+ model_with_persons=False,
30
+ use_persons=False,
31
+ disable_faces=False,
32
+ only_age=False,
33
+ ):
34
+ reader = ReaderAgeGender(
35
+ images_path,
36
+ annotations_path,
37
+ split=split,
38
+ seed=seed,
39
+ target_size=target_size,
40
+ with_persons=use_persons,
41
+ disable_faces=disable_faces,
42
+ only_age=only_age,
43
+ )
44
+
45
+ self.name = name
46
+ self.model_with_persons = model_with_persons
47
+ self.reader = reader
48
+ self.load_bytes = load_bytes
49
+ self.img_mode = img_mode
50
+ self.transform = transform
51
+ self._consecutive_errors = 0
52
+ self.is_training = is_training
53
+ self.random_flip = 0.0
54
+
55
+ # Setting up classes.
56
+ # If min and max classes are passed - use them to have the same preprocessing for validation
57
+ self.max_age: float = None
58
+ self.min_age: float = None
59
+ self.avg_age: float = None
60
+ self.set_ages_min_max(min_age, max_age)
61
+
62
+ self.genders = ["M", "F"]
63
+ self.num_classes_gender = len(self.genders)
64
+
65
+ self.age_classes: Optional[List[str]] = self.set_age_classes()
66
+
67
+ self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
68
+ self.num_classes: int = self.num_classes_age + self.num_classes_gender
69
+ self.target_dtype = torch.float32
70
+
71
+ def set_age_classes(self) -> Optional[List[str]]:
72
+ return None # for regression dataset
73
+
74
+ def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
75
+
76
+ assert all(age is None for age in [min_age, max_age]) or all(
77
+ age is not None for age in [min_age, max_age]
78
+ ), "Both min and max age must be passed or none of them"
79
+
80
+ if max_age is not None and min_age is not None:
81
+ _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
82
+ self.max_age = max_age
83
+ self.min_age = min_age
84
+ else:
85
+ # collect statistics from loaded dataset
86
+ all_ages_set: Set[int] = set()
87
+ for img_path, image_samples in self.reader._ann.items():
88
+ for image_sample_info in image_samples:
89
+ if image_sample_info.age == "-1":
90
+ continue
91
+ age = round(float(image_sample_info.age))
92
+ all_ages_set.add(age)
93
+
94
+ self.max_age = max(all_ages_set)
95
+ self.min_age = min(all_ages_set)
96
+
97
+ self.avg_age = (self.max_age + self.min_age) / 2.0
98
+
99
+ def _norm_age(self, age):
100
+ return (age - self.avg_age) / (self.max_age - self.min_age)
101
+
102
+ def parse_gender(self, _gender: str) -> float:
103
+ if _gender != "-1":
104
+ gender = float(0 if _gender == "M" or _gender == "0" else 1)
105
+ else:
106
+ gender = -1
107
+ return gender
108
+
109
+ def parse_target(self, _age: str, gender: str) -> List[Any]:
110
+ if _age != "-1":
111
+ age = round(float(_age))
112
+ age = self._norm_age(float(age))
113
+ else:
114
+ age = -1
115
+
116
+ target: List[float] = [age, self.parse_gender(gender)]
117
+ return target
118
+
119
+ @property
120
+ def transform(self):
121
+ return self._transform
122
+
123
+ @transform.setter
124
+ def transform(self, transform):
125
+ # Disable pretrained monkey-patched transforms
126
+ if not transform:
127
+ return
128
+
129
+ _trans = []
130
+ for trans in transform.transforms:
131
+ if "Resize" in str(trans):
132
+ continue
133
+ if "Crop" in str(trans):
134
+ continue
135
+ _trans.append(trans)
136
+ self._transform = transforms.Compose(_trans)
137
+
138
+ def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
139
+ if image is None:
140
+ return None
141
+
142
+ if self.transform is None:
143
+ return image
144
+
145
+ image = convert_to_pil(image, self.img_mode)
146
+ for trans in self.transform.transforms:
147
+ image = trans(image)
148
+ return image
149
+
150
+ def __getitem__(self, index):
151
+ # get preprocessed face and person crops (np.ndarray)
152
+ # resize + pad, for person crops: cut off other bboxes
153
+ images, target = self.reader[index]
154
+
155
+ target = self.parse_target(*target)
156
+
157
+ if self.model_with_persons:
158
+ face_image, person_image = images
159
+ person_image: np.ndarray = self.apply_tranforms(person_image)
160
+ else:
161
+ face_image = images[0]
162
+ person_image = None
163
+
164
+ face_image: np.ndarray = self.apply_tranforms(face_image)
165
+
166
+ if person_image is not None:
167
+ img = np.concatenate([face_image, person_image], axis=0)
168
+ else:
169
+ img = face_image
170
+
171
+ return img, target
172
+
173
+ def __len__(self):
174
+ return len(self.reader)
175
+
176
+ def filename(self, index, basename=False, absolute=False):
177
+ return self.reader.filename(index, basename, absolute)
178
+
179
+ def filenames(self, basename=False, absolute=False):
180
+ return self.reader.filenames(basename, absolute)
181
+
182
+
183
+ def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image":
184
+ if cv_im is None:
185
+ return None
186
+
187
+ if img_mode == "RGB":
188
+ cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
189
+ else:
190
+ raise Exception("Incorrect image mode has been passed!")
191
+
192
+ cv_im = np.ascontiguousarray(cv_im)
193
+ pil_image = Image.fromarray(cv_im)
194
+ return pil_image
mivolo/data/dataset/age_gender_loader.py ADDED
@@ -0,0 +1,169 @@
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import logging
8
+ from contextlib import suppress
9
+ from functools import partial
10
+ from itertools import repeat
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.utils.data
15
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from timm.data.dataset import IterableImageDataset
17
+ from timm.data.loader import PrefetchLoader, _worker_init
18
+ from timm.data.transforms_factory import create_transform
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+
23
+ def fast_collate(batch, target_dtype=torch.uint8):
24
+ """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
25
+ assert isinstance(batch[0], tuple)
26
+ batch_size = len(batch)
27
+ if isinstance(batch[0][0], np.ndarray):
28
+ targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
29
+ assert len(targets) == batch_size
30
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
31
+ for i in range(batch_size):
32
+ tensor[i] += torch.from_numpy(batch[i][0])
33
+ return tensor, targets
34
+ else:
35
+ raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")
36
+
37
+
38
+ def adapt_to_chs(x, n):
39
+ if not isinstance(x, (tuple, list)):
40
+ x = tuple(repeat(x, n))
41
+ elif len(x) != n:
42
+ # doubled channels
43
+ if len(x) * 2 == n:
44
+ x = np.concatenate((x, x))
45
+ _logger.warning(f"Pretrained mean/std different shape than model (doubled channes), using concat: {x}.")
46
+ else:
47
+ x_mean = np.mean(x).item()
48
+ x = (x_mean,) * n
49
+ _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
50
+ else:
51
+ assert len(x) == n, "normalization stats must match image channels"
52
+ return x
53
+
54
+
55
+ class PrefetchLoaderForMultiInput(PrefetchLoader):
56
+ def __init__(
57
+ self,
58
+ loader,
59
+ mean=IMAGENET_DEFAULT_MEAN,
60
+ std=IMAGENET_DEFAULT_STD,
61
+ channels=3,
62
+ device=torch.device("cuda"),
63
+ img_dtype=torch.float32,
64
+ ):
65
+
66
+ mean = adapt_to_chs(mean, channels)
67
+ std = adapt_to_chs(std, channels)
68
+ normalization_shape = (1, channels, 1, 1)
69
+
70
+ self.loader = loader
71
+ self.device = device
72
+ self.img_dtype = img_dtype
73
+ self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
74
+ self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
75
+
76
+ self.is_cuda = torch.cuda.is_available() and device.type == "cuda"
77
+
78
+ def __iter__(self):
79
+ first = True
80
+ if self.is_cuda:
81
+ stream = torch.cuda.Stream()
82
+ stream_context = partial(torch.cuda.stream, stream=stream)
83
+ else:
84
+ stream = None
85
+ stream_context = suppress
86
+
87
+ for next_input, next_target in self.loader:
88
+
89
+ with stream_context():
90
+ next_input = next_input.to(device=self.device, non_blocking=True)
91
+ next_target = next_target.to(device=self.device, non_blocking=True)
92
+ next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
93
+
94
+ if not first:
95
+ yield input, target # noqa: F823, F821
96
+ else:
97
+ first = False
98
+
99
+ if stream is not None:
100
+ torch.cuda.current_stream().wait_stream(stream)
101
+
102
+ input = next_input
103
+ target = next_target
104
+
105
+ yield input, target
106
+
107
+
108
+ def create_loader(
109
+ dataset,
110
+ input_size,
111
+ batch_size,
112
+ mean=IMAGENET_DEFAULT_MEAN,
113
+ std=IMAGENET_DEFAULT_STD,
114
+ num_workers=1,
115
+ crop_pct=None,
116
+ crop_mode=None,
117
+ pin_memory=False,
118
+ img_dtype=torch.float32,
119
+ device=torch.device("cuda"),
120
+ persistent_workers=True,
121
+ worker_seeding="all",
122
+ target_type=torch.int64,
123
+ ):
124
+
125
+ transform = create_transform(
126
+ input_size,
127
+ is_training=False,
128
+ use_prefetcher=True,
129
+ mean=mean,
130
+ std=std,
131
+ crop_pct=crop_pct,
132
+ crop_mode=crop_mode,
133
+ )
134
+ dataset.transform = transform
135
+
136
+ if isinstance(dataset, IterableImageDataset):
137
+ # give Iterable datasets early knowledge of num_workers so that sample estimates
138
+ # are correct before worker processes are launched
139
+ dataset.set_loader_cfg(num_workers=num_workers)
140
+ raise ValueError("Incorrect dataset type: IterableImageDataset")
141
+
142
+ loader_class = torch.utils.data.DataLoader
143
+ loader_args = dict(
144
+ batch_size=batch_size,
145
+ shuffle=False,
146
+ num_workers=num_workers,
147
+ sampler=None,
148
+ collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
149
+ pin_memory=pin_memory,
150
+ drop_last=False,
151
+ worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
152
+ persistent_workers=persistent_workers,
153
+ )
154
+ try:
155
+ loader = loader_class(dataset, **loader_args)
156
+ except TypeError:
157
+ loader_args.pop("persistent_workers") # only in Pytorch 1.7+
158
+ loader = loader_class(dataset, **loader_args)
159
+
160
+ loader = PrefetchLoaderForMultiInput(
161
+ loader,
162
+ mean=mean,
163
+ std=std,
164
+ channels=input_size[0],
165
+ device=device,
166
+ img_dtype=img_dtype,
167
+ )
168
+
169
+ return loader
mivolo/data/dataset/classification_dataset.py ADDED
@@ -0,0 +1,47 @@
1
+ from typing import Any, List, Optional
2
+
3
+ import torch
4
+
5
+ from .age_gender_dataset import AgeGenderDataset
6
+
7
+
8
+ class ClassificationDataset(AgeGenderDataset):
9
+ def __init__(self, *args, **kwargs):
10
+ super().__init__(*args, **kwargs)
11
+
12
+ self.target_dtype = torch.int32
13
+
14
+ def set_age_classes(self) -> Optional[List[str]]:
15
+ raise NotImplementedError
16
+
17
+ def parse_target(self, age: str, gender: str) -> List[Any]:
18
+ assert self.age_classes is not None
19
+ if age != "-1":
20
+ assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
21
+ age_ind = self.age_classes.index(age)
22
+ else:
23
+ age_ind = -1
24
+
25
+ target: List[int] = [age_ind, int(self.parse_gender(gender))]
26
+ return target
27
+
28
+
29
+ class FairFaceDataset(ClassificationDataset):
30
+ def set_age_classes(self) -> Optional[List[str]]:
31
+ age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
32
+ # a[i-1] <= v < a[i] => age_classes[i-1]
33
+ self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
34
+ return age_classes
35
+
36
+
37
+ class AdienceDataset(ClassificationDataset):
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+
41
+ self.target_dtype = torch.int32
42
+
43
+ def set_age_classes(self) -> Optional[List[str]]:
44
+ age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
45
+ # a[i-1] <= v < a[i] => age_classes[i-1]
46
+ self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
47
+ return age_classes
mivolo/data/dataset/reader_age_gender.py ADDED
@@ -0,0 +1,492 @@
1
+ import logging
2
+ import os
3
+ from functools import partial
4
+ from multiprocessing.pool import ThreadPool
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+ from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
10
+ from mivolo.data.misc import IOU, class_letterbox
11
+ from timm.data.readers.reader import Reader
12
+ from tqdm import tqdm
13
+
14
+ CROP_ROUND_TOL = 0.3
15
+ MIN_PERSON_SIZE = 100
16
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
17
+
18
+ _logger = logging.getLogger("ReaderAgeGender")
19
+
20
+
21
+ class ReaderAgeGender(Reader):
22
+ """
23
+ Reader for almost original imdb-wiki cleaned dataset.
24
+ Two changes:
25
+ 1. Your annotation must be in ./annotation subdir of dataset root
26
+ 2. Images must be in images subdir
27
+
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ images_path,
33
+ annotations_path,
34
+ split="validation",
35
+ target_size=224,
36
+ min_size=5,
37
+ seed=1234,
38
+ with_persons=False,
39
+ min_person_size=MIN_PERSON_SIZE,
40
+ disable_faces=False,
41
+ only_age=False,
42
+ min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
43
+ crop_round_tol=CROP_ROUND_TOL,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.with_persons = with_persons
48
+ self.disable_faces = disable_faces
49
+ self.only_age = only_age
50
+
51
+ # can be only black for now, even though it's not very good with further normalization
52
+ self.crop_out_color = (0, 0, 0)
53
+
54
+ self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
55
+ self.empty_crop = self.empty_crop.astype(np.uint8)
56
+
57
+ self.min_person_size = min_person_size
58
+ self.min_person_aftercut_ratio = min_person_aftercut_ratio
59
+ self.crop_round_tol = crop_round_tol
60
+
61
+ splits = split.split(",")
62
+ self.splits = [split.strip() for split in splits if len(split.strip())]
63
+ assert len(self.splits), "Incorrect split arg"
64
+
65
+ self.min_size = min_size
66
+ self.seed = seed
67
+ self.target_size = target_size
68
+
69
+ # Reading annotations. Can be multiple files if annotations_path dir
70
+ self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image
71
+ self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
72
+ self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__
73
+
74
+ self._read_annotations(images_path, annotations_path)
75
+ _logger.info(f"Dataset length: {len(self._faces_list)} crops")
76
+
77
+ def __getitem__(self, index):
78
+ return self._read_img_and_label(index)
79
+
80
+ def __len__(self):
81
+ return len(self._faces_list)
82
+
83
+ def _filename(self, index, basename=False, absolute=False):
84
+ img_p = self._faces_list[index][0]
85
+ return os.path.basename(img_p) if basename else img_p
86
+
87
+ def _read_annotations(self, images_path, csvs_path):
88
+ self._ann = {}
89
+ self._faces_list = []
90
+ self._associated_objects = {}
91
+
92
+ csvs = get_all_files(csvs_path, [".csv"])
93
+ csvs = [c for c in csvs if any(split_name in os.path.basename(c) for split_name in self.splits)]
94
+
95
+ # load annotations per image
96
+ for csv in csvs:
97
+ db, ann_type = read_csv_annotation_file(csv, images_path)
98
+ if self.with_persons and ann_type != AnnotType.PERSONS:
99
+ raise ValueError(
100
+ f"Annotation type in file {csv} contains no persons, "
101
+ f"but annotations with persons are requested."
102
+ )
103
+ self._ann.update(db)
104
+
105
+ if len(self._ann) == 0:
106
+ raise ValueError("Annotations are empty!")
107
+
108
+ self._ann, self._associated_objects = self.prepare_annotations()
109
+ images_list = list(self._ann.keys())
110
+
111
+ for img_path in images_list:
112
+ for index, image_sample_info in enumerate(self._ann[img_path]):
113
+ assert image_sample_info.has_gt(
114
+ self.only_age
115
+ ), "Annotations must be checked with self.prepare_annotations() func"
116
+ self._faces_list.append((img_path, index))
117
+
118
+ def _read_img_and_label(self, index):
119
+ if not isinstance(index, int):
120
+ raise TypeError("ReaderAgeGender expected index to be integer")
121
+
122
+ img_p, face_index = self._faces_list[index]
123
+ ann: PictureInfo = self._ann[img_p][face_index]
124
+ img = cv2.imread(img_p)
125
+
126
+ face_empty = True
127
+ if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
128
+ face_crop, face_empty = self._get_crop(ann.bbox, img)
129
+
130
+ if not self.with_persons and face_empty:
131
+ # model without persons
132
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
133
+
134
+ if face_empty:
135
+ face_crop = self.empty_crop
136
+
137
+ person_empty = True
138
+ if self.with_persons or self.disable_faces:
139
+ if ann.has_person_bbox:
140
+ # cut off all associated objects from person crop
141
+ objects = self._associated_objects[img_p][face_index]
142
+ person_crop, person_empty = self._get_crop(
143
+ ann.person_bbox,
144
+ img,
145
+ crop_out_color=self.crop_out_color,
146
+ asced_objects=objects,
147
+ )
148
+
149
+ if face_empty and person_empty:
150
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
151
+
152
+ if person_empty:
153
+ person_crop = self.empty_crop
154
+
155
+ return (face_crop, person_crop), [ann.age, ann.gender]
156
+
157
+ def _get_crop(
158
+ self,
159
+ bbox,
160
+ img,
161
+ asced_objects=None,
162
+ crop_out_color=(0, 0, 0),
163
+ ) -> Tuple[np.ndarray, bool]:
164
+
165
+ empty_bbox = False
166
+
167
+ xmin, ymin, xmax, ymax = bbox
168
+ assert not (
169
+ ymax - ymin < self.min_size or xmax - xmin < self.min_size
170
+ ), "Annotations must be checked with self.prepare_annotations() func"
171
+
172
+ crop = img[ymin:ymax, xmin:xmax]
173
+
174
+ if asced_objects:
175
+ # cut off other objects for person crop
176
+ crop, empty_bbox = _cropout_asced_objs(
177
+ asced_objects,
178
+ bbox,
179
+ crop.copy(),
180
+ crop_out_color=crop_out_color,
181
+ min_person_size=self.min_person_size,
182
+ crop_round_tol=self.crop_round_tol,
183
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
184
+ )
185
+ if empty_bbox:
186
+ crop = self.empty_crop
187
+
188
+ crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
189
+ return crop, empty_bbox
190
+
191
+ def prepare_annotations(self):
192
+
193
+ good_anns: Dict[str, List[PictureInfo]] = {}
194
+ all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
195
+
196
+ if not self.with_persons:
197
+ # remove all persons
198
+ for img_path, bboxes in self._ann.items():
199
+ for sample in bboxes:
200
+ sample.clear_person_bbox()
201
+
202
+ # check dataset and collect associated_objects
203
+ verify_images_func = partial(
204
+ verify_images,
205
+ min_size=self.min_size,
206
+ min_person_size=self.min_person_size,
207
+ with_persons=self.with_persons,
208
+ disable_faces=self.disable_faces,
209
+ crop_round_tol=self.crop_round_tol,
210
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
211
+ only_age=self.only_age,
212
+ )
213
+ num_threads = min(8, os.cpu_count())
214
+
215
+ all_msgs = []
216
+ broken = 0
217
+ skipped = 0
218
+ all_skipped_crops = 0
219
+ desc = "Check annotations..."
220
+ with ThreadPool(num_threads) as pool:
221
+ pbar = tqdm(
222
+ pool.imap_unordered(verify_images_func, list(self._ann.items())),
223
+ desc=desc,
224
+ total=len(self._ann),
225
+ )
226
+
227
+ for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
228
+ broken += 1 if is_corrupted else 0
229
+ all_msgs.extend(msgs)
230
+ all_skipped_crops += skipped_crops
231
+ skipped += 1 if is_empty_annotations else 0
232
+ if img_info is not None:
233
+ img_path, img_samples = img_info
234
+ good_anns[img_path] = img_samples
235
+ all_associated_objects.update({img_path: associated_objects})
236
+
237
+ pbar.desc = (
238
+ f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
239
+ f"{broken} images corrupted"
240
+ )
241
+
242
+ pbar.close()
243
+
244
+ for msg in all_msgs:
245
+ print(msg)
246
+ print(f"\nLeft images: {len(good_anns)}")
247
+
248
+ return good_anns, all_associated_objects
249
+
250
+
251
+ def verify_images(
252
+ img_info,
253
+ min_size: int,
254
+ min_person_size: int,
255
+ with_persons: bool,
256
+ disable_faces: bool,
257
+ crop_round_tol: float,
258
+ min_person_aftercut_ratio: float,
259
+ only_age: bool,
260
+ ):
261
+ # If crop is too small, if image can not be read or if image does not exist
262
+ # then filter out this sample
263
+
264
+ disable_faces = disable_faces and with_persons
265
+ kwargs = dict(
266
+ min_person_size=min_person_size,
267
+ disable_faces=disable_faces,
268
+ with_persons=with_persons,
269
+ crop_round_tol=crop_round_tol,
270
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
271
+ only_age=only_age,
272
+ )
273
+
274
+ def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
275
+ ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
276
+ crop_h, crop_w = ymax - ymin, xmax - xmin
277
+ if crop_h < min_size or crop_w < min_size:
278
+ return False, [-1, -1, -1, -1]
279
+ bbox = [xmin, ymin, xmax, ymax]
280
+ return True, bbox
281
+
282
+ msgs = []
283
+ skipped_crops = 0
284
+ is_corrupted = False
285
+ is_empty_annotations = False
286
+
287
+ img_path: str = img_info[0]
288
+ img_samples: List[PictureInfo] = img_info[1]
289
+ try:
290
+ im_cv = cv2.imread(img_path)
291
+ im_h, im_w = im_cv.shape[:2]
292
+ except Exception:
293
+ msgs.append(f"Can not load image {img_path}")
294
+ is_corrupted = True
295
+ return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops
296
+
297
+ out_samples: List[PictureInfo] = []
298
+ for sample in img_samples:
299
+ # correct face bbox
300
+ if sample.has_face_bbox:
301
+ is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
302
+ if not is_correct and sample.has_gt(only_age):
303
+ msgs.append("Small face. Passing..")
304
+ skipped_crops += 1
305
+
306
+ # correct person bbox
307
+ if sample.has_person_bbox:
308
+ is_correct, sample.person_bbox = bbox_correct(
309
+ sample.person_bbox, max(min_person_size, min_size), im_h, im_w
310
+ )
311
+ if not is_correct and sample.has_gt(only_age):
312
+ msgs.append(f"Small person {img_path}. Passing..")
313
+ skipped_crops += 1
314
+
315
+ if sample.has_face_bbox or sample.has_person_bbox:
316
+ out_samples.append(sample)
317
+ elif sample.has_gt(only_age):
318
+ msgs.append("Sample has no face and no body. Passing..")
319
+ skipped_crops += 1
320
+
321
+ # sort that samples with undefined age and gender be the last
322
+ out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)
323
+
324
+ # for each person find other faces and persons bboxes, intersected with it
325
+ associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)
326
+
327
+ out_samples, associated_objects, skipped_crops = filter_bad_samples(
328
+ out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
329
+ )
330
+
331
+ out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
332
+ if len(out_samples) == 0:
333
+ out_img_info = None
334
+ is_empty_annotations = True
335
+
336
+ return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
337
+
338
+
339
+ def filter_bad_samples(
340
+ out_samples: List[PictureInfo],
341
+ associated_objects: dict,
342
+ im_cv: np.ndarray,
343
+ msgs: List[str],
344
+ skipped_crops: int,
345
+ **kwargs,
346
+ ):
347
+ with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
348
+ kwargs["with_persons"],
349
+ kwargs["disable_faces"],
350
+ kwargs["min_person_size"],
351
+ kwargs["crop_round_tol"],
352
+ kwargs["min_person_aftercut_ratio"],
353
+ kwargs["only_age"],
354
+ )
355
+
356
+ # keep only samples with annotations
357
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
358
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
359
+
360
+ if kwargs["disable_faces"]:
361
+ # clear all faces
362
+ for ind, sample in enumerate(out_samples):
363
+ sample.clear_face_bbox()
364
+
365
+ # keep only samples with person_bbox
366
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
367
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
368
+
369
+ if with_persons or disable_faces:
370
+ # check that preprocessing func
371
+ # _cropout_asced_objs() return not empty person_image for each out sample
372
+
373
+ inds = []
374
+ for ind, sample in enumerate(out_samples):
375
+ person_empty = True
376
+ if sample.has_person_bbox:
377
+ xmin, ymin, xmax, ymax = sample.person_bbox
378
+ crop = im_cv[ymin:ymax, xmin:xmax]
379
+ # cut off all associated objects from person crop
380
+ _, person_empty = _cropout_asced_objs(
381
+ associated_objects[ind],
382
+ sample.person_bbox,
383
+ crop.copy(),
384
+ min_person_size=min_person_size,
385
+ crop_round_tol=crop_round_tol,
386
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
387
+ )
388
+
389
+ if person_empty and not sample.has_face_bbox:
390
+ msgs.append("Small person after preprocessing. Passing..")
391
+ skipped_crops += 1
392
+ else:
393
+ inds.append(ind)
394
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
395
+
396
+ assert len(associated_objects) == len(out_samples)
397
+ return out_samples, associated_objects, skipped_crops
398
+
399
+
400
+ def _filter_by_ind(out_samples, associated_objects, inds):
401
+ _associated_objects = {}
402
+ _out_samples = []
403
+ for ind, sample in enumerate(out_samples):
404
+ if ind in inds:
405
+ _associated_objects[len(_out_samples)] = associated_objects[ind]
406
+ _out_samples.append(sample)
407
+
408
+ return _out_samples, _associated_objects
409
+
410
+
411
+ def find_associated_objects(
412
+ image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
413
+ ) -> Dict[int, List[List[int]]]:
414
+ """
415
+ For each person (which has gt age and gt gender) find other faces and persons bboxes, intersected with it
416
+ """
417
+ associated_objects: Dict[int, List[List[int]]] = {}
418
+
419
+ for iindex, image_sample_info in enumerate(image_samples):
420
+ # add own face
421
+ associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []
422
+
423
+ if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
424
+ # if a sample has no gt => it will not be used
425
+ continue
426
+
427
+ iperson_box = image_sample_info.person_bbox
428
+ for jindex, other_image_sample in enumerate(image_samples):
429
+ if iindex == jindex:
430
+ continue
431
+ if other_image_sample.has_face_bbox:
432
+ jface_bbox = other_image_sample.bbox
433
+ iou = _get_iou(jface_bbox, iperson_box)
434
+ if iou >= iou_thresh:
435
+ associated_objects[iindex].append(jface_bbox)
436
+ if other_image_sample.has_person_bbox:
437
+ jperson_bbox = other_image_sample.person_bbox
438
+ iou = _get_iou(jperson_bbox, iperson_box)
439
+ if iou >= iou_thresh:
440
+ associated_objects[iindex].append(jperson_bbox)
441
+
442
+ return associated_objects
443
+
444
+
445
+ def _cropout_asced_objs(
446
+ asced_objects,
447
+ person_bbox,
448
+ crop,
449
+ min_person_size,
450
+ crop_round_tol,
451
+ min_person_aftercut_ratio,
452
+ crop_out_color=(0, 0, 0),
453
+ ):
454
+ empty = False
455
+ xmin, ymin, xmax, ymax = person_bbox
456
+
457
+ for a_obj in asced_objects:
458
+ aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj
459
+
460
+ aobj_ymin = int(max(aobj_ymin - ymin, 0))
461
+ aobj_xmin = int(max(aobj_xmin - xmin, 0))
462
+ aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
463
+ aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))
464
+
465
+ crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color
466
+
467
+ # calc useful non-black area
468
+ remain_ratio = np.count_nonzero(crop) / (crop.shape[0] * crop.shape[1] * crop.shape[2])
469
+ if (crop.shape[0] < min_person_size or crop.shape[1] < min_person_size) or remain_ratio < min_person_aftercut_ratio:
470
+ crop = None
471
+ empty = True
472
+
473
+ return crop, empty
474
+
475
+
476
+ def _correct_bbox(bbox, h, w):
477
+ xmin, ymin, xmax, ymax = bbox
478
+ ymin = min(max(ymin, 0), h)
479
+ ymax = min(max(ymax, 0), h)
480
+ xmin = min(max(xmin, 0), w)
481
+ xmax = min(max(xmax, 0), w)
482
+ return ymin, ymax, xmin, xmax
483
+
484
+
485
+ def _get_iou(bbox1, bbox2):
486
+ xmin1, ymin1, xmax1, ymax1 = bbox1
487
+ xmin2, ymin2, xmax2, ymax2 = bbox2
488
+ iou = IOU(
489
+ [ymin1, xmin1, ymax1, xmax1],
490
+ [ymin2, xmin2, ymax2, xmax2],
491
+ )
492
+ return iou
mivolo/data/misc.py ADDED
@@ -0,0 +1,246 @@
1
+ import argparse
2
+ import ast
3
+ import re
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torchvision.transforms.functional as F
10
+ from scipy.optimize import linear_sum_assignment
11
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
+
13
+ CROP_ROUND_RATE = 0.1
14
+ MIN_PERSON_CROP_NONZERO = 0.5
15
+
16
+
17
+ def aggregate_votes_winsorized(ages, max_age_dist=6):
18
+ # Replace any annotation that is more than a max_age_dist away from the median
19
+ # with the median + max_age_dist if higher, or the median - max_age_dist if below
20
+ median = np.median(ages)
21
+ ages = np.clip(ages, median - max_age_dist, median + max_age_dist)
22
+ return np.mean(ages)
23
+
24
+
25
+ def natural_key(string_):
26
+ """See http://www.codinghorror.com/blog/archives/001018.html"""
27
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
28
+
29
+
30
+ def add_bool_arg(parser, name, default=False, help=""):
31
+ dest_name = name.replace("-", "_")
32
+ group = parser.add_mutually_exclusive_group(required=False)
33
+ group.add_argument("--" + name, dest=dest_name, action="store_true", help=help)
34
+ group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help)
35
+ parser.set_defaults(**{dest_name: default})
36
+
37
+
38
+ def cumulative_score(pred_ages, gt_ages, L, tol=1e-6):
39
+ n = pred_ages.shape[0]
40
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol)
41
+ cs_score = num_correct / n
42
+ return cs_score
43
+
44
+
45
+ def cumulative_error(pred_ages, gt_ages, L, tol=1e-6):
46
+ n = pred_ages.shape[0]
47
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol)
48
+ cs_score = num_correct / n
49
+ return cs_score
50
+
51
+
52
+ class ParseKwargs(argparse.Action):
53
+ def __call__(self, parser, namespace, values, option_string=None):
54
+ kw = {}
55
+ for value in values:
56
+ key, value = value.split("=")
57
+ try:
58
+ kw[key] = ast.literal_eval(value)
59
+ except ValueError:
60
+ kw[key] = str(value) # fallback to string (avoid need to escape on command line)
61
+ setattr(namespace, self.dest, kw)
62
+
63
+
64
+ def box_iou(box1, box2, over_second=False):
65
+ """
66
+ Return intersection-over-union (Jaccard index) of boxes.
67
+ If over_second == True, return mean(intersection-over-union, (inter / area2))
68
+
69
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
70
+
71
+ Arguments:
72
+ box1 (Tensor[N, 4])
73
+ box2 (Tensor[M, 4])
74
+ Returns:
75
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
76
+ IoU values for every element in boxes1 and boxes2
77
+ """
78
+
79
+ def box_area(box):
80
+ # box = 4xn
81
+ return (box[2] - box[0]) * (box[3] - box[1])
82
+
83
+ area1 = box_area(box1.T)
84
+ area2 = box_area(box2.T)
85
+
86
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
87
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
88
+
89
+ iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
90
+ if over_second:
91
+ return (inter / area2 + iou) / 2 # mean(inter / area2, iou)
92
+ else:
93
+ return iou
94
+
95
+
96
+ def split_batch(bs: int, dev: int) -> Tuple[int, int]:
97
+ full_bs = (bs // dev) * dev
98
+ part_bs = bs - full_bs
99
+ return full_bs, part_bs
100
+
101
+
102
+ def assign_faces(
103
+ persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001
104
+ ) -> Tuple[List[Optional[int]], List[int]]:
105
+ """
106
+ Assign person to each face if it is possible.
107
+ Return:
108
+ - assigned_faces List[Optional[int]]: mapping of face_ind to person_ind
109
+ ( assigned_faces[face_ind] = person_ind ). person_ind can be None
110
+ - unassigned_persons_inds List[int]: persons indexes without any assigned face
111
+ """
112
+
113
+ assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))]
114
+ unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))]
115
+
116
+ if len(persons_bboxes) == 0 or len(faces_bboxes) == 0:
117
+ return assigned_faces, unassigned_persons_inds
118
+
119
+ cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy()
120
+ persons_indexes, face_indexes = [], []
121
+
122
+ if len(cost_matrix) > 0:
123
+ persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True)
124
+
125
+ matched_persons = set()
126
+ for person_idx, face_idx in zip(persons_indexes, face_indexes):
127
+ ciou = cost_matrix[person_idx][face_idx]
128
+ if ciou > iou_thresh:
129
+ if person_idx in matched_persons:
130
+ # Person can not be assigned twice, in reality this should not happen
131
+ continue
132
+ assigned_faces[face_idx] = person_idx
133
+ matched_persons.add(person_idx)
134
+
135
+ unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons]
136
+
137
+ return assigned_faces, unassigned_persons_inds
138
+
139
+
140
+ def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
141
+ # Resize and pad image while meeting stride-multiple constraints
142
+ shape = im.shape[:2] # current shape [height, width]
143
+ if isinstance(new_shape, int):
144
+ new_shape = (new_shape, new_shape)
145
+
146
+ if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
147
+ return im
148
+
149
+ # Scale ratio (new / old)
150
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
151
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
152
+ r = min(r, 1.0)
153
+
154
+ # Compute padding
155
+ # ratio = r, r # width, height ratios
156
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
157
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
158
+
159
+ dw /= 2 # divide padding into 2 sides
160
+ dh /= 2
161
+
162
+ if shape[::-1] != new_unpad: # resize
163
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
164
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
165
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
166
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
167
+ return im
168
+
169
+
170
+ def prepare_classification_images(
171
+ img_list: List[Optional[np.ndarray]],
172
+ target_size: int = 224,
173
+ mean=IMAGENET_DEFAULT_MEAN,
174
+ std=IMAGENET_DEFAULT_STD,
175
+ device=None,
176
+ ) -> torch.tensor:
177
+
178
+ prepared_images: List[torch.tensor] = []
179
+
180
+ for img in img_list:
181
+ if img is None:
182
+ img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
183
+ img = F.normalize(img, mean=mean, std=std)
184
+ img = img.unsqueeze(0)
185
+ prepared_images.append(img)
186
+ continue
187
+ img = class_letterbox(img, new_shape=(target_size, target_size))
188
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
189
+
190
+ img = img / 255.0
191
+ img = (img - mean) / std
192
+ img = img.astype(dtype=np.float32)
193
+
194
+ img = img.transpose((2, 0, 1))
195
+ img = np.ascontiguousarray(img)
196
+ img = torch.from_numpy(img)
197
+ img = img.unsqueeze(0)
198
+
199
+ prepared_images.append(img)
200
+
201
+ if len(prepared_images) == 0:
202
+ return None
203
+
204
+ prepared_input = torch.concat(prepared_images)
205
+
206
+ if device:
207
+ prepared_input = prepared_input.to(device)
208
+
209
+ return prepared_input
210
+
211
+
212
+ def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float:
213
+ # expects [ymin, xmin, ymax, xmax]; doesn't matter whether coordinates are absolute or relative
214
+ assert bb1[1] < bb1[3]
215
+ assert bb1[0] < bb1[2]
216
+ assert bb2[1] < bb2[3]
217
+ assert bb2[0] < bb2[2]
218
+
219
+ # determine the coordinates of the intersection rectangle
220
+ x_left = max(bb1[1], bb2[1])
221
+ y_top = max(bb1[0], bb2[0])
222
+ x_right = min(bb1[3], bb2[3])
223
+ y_bottom = min(bb1[2], bb2[2])
224
+
225
+ if x_right < x_left or y_bottom < y_top:
226
+ return 0.0
227
+
228
+ # The intersection of two axis-aligned bounding boxes is always an
229
+ # axis-aligned bounding box
230
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
231
+ # compute the area of both AABBs
232
+ bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0])
233
+ bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0])
234
+ if not norm_second_bbox:
235
+ # compute the intersection over union by taking the intersection
236
+ # area and dividing it by the sum of prediction + ground-truth
237
+ # areas - the intersection area
238
+ iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
239
+ else:
240
+ # for cases when we search if second bbox is inside first one
241
+ iou = intersection_area / float(bb2_area)
242
+
243
+ assert iou >= 0.0
244
+ assert iou <= 1.01
245
+
246
+ return iou
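A minimal usage sketch of the matching helpers above (not part of this commit; the toy boxes are illustrative): assign_faces takes lists of xyxy box tensors, scores person/face pairs with an IoU cost matrix, and keeps only Hungarian matches above the threshold.

import torch
from mivolo.data.misc import assign_faces

person_boxes = [torch.tensor([0, 0, 100, 220]), torch.tensor([150, 0, 260, 220])]
face_boxes = [torch.tensor([30, 10, 70, 60])]  # overlaps only the first person box

assigned_faces, unassigned_persons = assign_faces(person_boxes, face_boxes)
# assigned_faces[i] holds the matched person index for face i (or None if unmatched);
# person 1 has no overlapping face, so it always ends up in unassigned_persons.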
mivolo/model/__init__.py ADDED
File without changes
mivolo/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (152 Bytes).
 
mivolo/model/__pycache__/create_timm_model.cpython-311.pyc ADDED
Binary file (4.94 kB).
 
mivolo/model/__pycache__/cross_bottleneck_attn.cpython-311.pyc ADDED
Binary file (6.96 kB).
 
mivolo/model/__pycache__/mi_volo.cpython-311.pyc ADDED
Binary file (13.7 kB).
 
mivolo/model/__pycache__/mivolo_model.cpython-311.pyc ADDED
Binary file (16.1 kB).
 
mivolo/model/__pycache__/yolo_detector.cpython-311.pyc ADDED
Binary file (2.69 kB).
 
mivolo/model/create_timm_model.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import os
8
+ from typing import Any, Dict, Optional, Union
9
+
10
+ import timm
11
+
12
+ # register new models
13
+ from mivolo.model.mivolo_model import * # noqa: F403, F401
14
+ from timm.layers import set_layer_config
15
+ from timm.models._factory import parse_model_name
16
+ from timm.models._helpers import load_state_dict, remap_checkpoint
17
+ from timm.models._hub import load_model_config_from_hf
18
+ from timm.models._pretrained import PretrainedCfg, split_model_name_tag
19
+ from timm.models._registry import is_model, model_entrypoint
20
+
21
+
22
+ def load_checkpoint(
23
+ model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
24
+ ):
25
+ if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
26
+ # numpy checkpoint, try to load via model specific load_pretrained fn
27
+ if hasattr(model, "load_pretrained"):
28
+ timm.models._model_builder.load_pretrained(checkpoint_path)
29
+ else:
30
+ raise NotImplementedError("Model cannot load numpy checkpoint")
31
+ return
32
+ state_dict = load_state_dict(checkpoint_path, use_ema)
33
+ if remap:
34
+ state_dict = remap_checkpoint(model, state_dict)
35
+ if filter_keys:
36
+ for sd_key in list(state_dict.keys()):
37
+ for filter_key in filter_keys:
38
+ if filter_key in sd_key:
39
+ if sd_key in state_dict:
40
+ del state_dict[sd_key]
41
+
42
+ rep = []
43
+ if state_dict_map is not None:
44
+ # 'patch_embed.conv1.' : 'patch_embed.conv.'
45
+ for state_k in list(state_dict.keys()):
46
+ for target_k, target_v in state_dict_map.items():
47
+ if target_v in state_k:
48
+ target_name = state_k.replace(target_v, target_k)
49
+ state_dict[target_name] = state_dict[state_k]
50
+ rep.append(state_k)
51
+ for r in rep:
52
+ if r in state_dict:
53
+ del state_dict[r]
54
+
55
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
56
+ return incompatible_keys
57
+
58
+
59
+ def create_model(
60
+ model_name: str,
61
+ pretrained: bool = False,
62
+ pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
63
+ pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
64
+ checkpoint_path: str = "",
65
+ scriptable: Optional[bool] = None,
66
+ exportable: Optional[bool] = None,
67
+ no_jit: Optional[bool] = None,
68
+ filter_keys=None,
69
+ state_dict_map=None,
70
+ **kwargs,
71
+ ):
72
+ """Create a model
73
+ Lookup model's entrypoint function and pass relevant args to create a new model.
74
+ """
75
+ # Parameters that aren't supported by all models or are intended to only override model defaults if set
76
+ # should default to None in command line args/cfg. Remove them if they are present and not set so that
77
+ # non-supporting models don't break and default args remain in effect.
78
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
79
+
80
+ model_source, model_name = parse_model_name(model_name)
81
+ if model_source == "hf-hub":
82
+ assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
83
+ # For model names specified in the form `hf-hub:path/architecture_name@revision`,
84
+ # load model weights + pretrained_cfg from Hugging Face hub.
85
+ pretrained_cfg, model_name = load_model_config_from_hf(model_name)
86
+ else:
87
+ model_name, pretrained_tag = split_model_name_tag(model_name)
88
+ if not pretrained_cfg:
89
+ # a valid pretrained_cfg argument takes priority over tag in model name
90
+ pretrained_cfg = pretrained_tag
91
+
92
+ if not is_model(model_name):
93
+ raise RuntimeError("Unknown model (%s)" % model_name)
94
+
95
+ create_fn = model_entrypoint(model_name)
96
+ with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
97
+ model = create_fn(
98
+ pretrained=pretrained,
99
+ pretrained_cfg=pretrained_cfg,
100
+ pretrained_cfg_overlay=pretrained_cfg_overlay,
101
+ **kwargs,
102
+ )
103
+
104
+ if checkpoint_path:
105
+ load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
106
+
107
+ return model
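A hypothetical invocation of create_model above (the checkpoint path is a placeholder; the argument values mirror what MiVOLO passes in mivolo/model/mi_volo.py):

from mivolo.model.create_timm_model import create_model

model = create_model(
    "mivolo_d1_224",            # entrypoint registered in mivolo/model/mivolo_model.py
    num_classes=3,              # 2 gender logits + 1 age output
    in_chans=6,                 # 3 face channels + 3 person channels for the fused model
    pretrained=False,
    checkpoint_path="models/mivolo_checkpoint.pth.tar",  # placeholder path
    filter_keys=["fds."],       # drop FDS-specific keys before load_state_dict
)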
mivolo/model/cross_bottleneck_attn.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Code based on timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from timm.layers.bottleneck_attn import PosEmbedRel
10
+ from timm.layers.helpers import make_divisible
11
+ from timm.layers.mlp import Mlp
12
+ from timm.layers.trace_utils import _assert
13
+ from timm.layers.weight_init import trunc_normal_
14
+
15
+
16
+ class CrossBottleneckAttn(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim,
20
+ dim_out=None,
21
+ feat_size=None,
22
+ stride=1,
23
+ num_heads=4,
24
+ dim_head=None,
25
+ qk_ratio=1.0,
26
+ qkv_bias=False,
27
+ scale_pos_embed=False,
28
+ ):
29
+ super().__init__()
30
+ assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
31
+ dim_out = dim_out or dim
32
+ assert dim_out % num_heads == 0
33
+
34
+ self.num_heads = num_heads
35
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
36
+ self.dim_head_v = dim_out // self.num_heads
37
+ self.dim_out_qk = num_heads * self.dim_head_qk
38
+ self.dim_out_v = num_heads * self.dim_head_v
39
+ self.scale = self.dim_head_qk**-0.5
40
+ self.scale_pos_embed = scale_pos_embed
41
+
42
+ self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
43
+ self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
44
+
45
+ # NOTE I'm only supporting relative pos embedding for now
46
+ self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
47
+
48
+ self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
49
+ mlp_ratio = 4
50
+ self.mlp = Mlp(
51
+ in_features=self.dim_out_v * 2,
52
+ hidden_features=int(dim * mlp_ratio),
53
+ act_layer=nn.GELU,
54
+ out_features=dim_out,
55
+ drop=0,
56
+ use_conv=True,
57
+ )
58
+
59
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
60
+ self.reset_parameters()
61
+
62
+ def reset_parameters(self):
63
+ trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
64
+ trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
65
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
66
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
67
+
68
+ def get_qkv(self, x, qvk_conv):
69
+ B, C, H, W = x.shape
70
+
71
+ x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
72
+
73
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
74
+
75
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
76
+ k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
77
+ v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
78
+
79
+ return q, k, v
80
+
81
+ def apply_attn(self, q, k, v, B, H, W, dropout=None):
82
+ if self.scale_pos_embed:
83
+ attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
84
+ else:
85
+ attn = (q @ k) * self.scale + self.pos_embed(q)
86
+ attn = attn.softmax(dim=-1)
87
+ if dropout:
88
+ attn = dropout(attn)
89
+
90
+ out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
91
+ return out
92
+
93
+ def forward(self, x):
94
+ B, C, H, W = x.shape
95
+
96
+ dim = int(C / 2)
97
+ x1 = x[:, :dim, :, :]
98
+ x2 = x[:, dim:, :, :]
99
+
100
+ _assert(H == self.pos_embed.height, "")
101
+ _assert(W == self.pos_embed.width, "")
102
+
103
+ q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
104
+ q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
105
+
106
+ # person to face
107
+ out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
108
+ # face to person
109
+ out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
110
+
111
+ x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
112
+ x_pf = self.norm(x_pf)
113
+ x_pf = self.mlp(x_pf) # B, dim_out, H, W
114
+
115
+ out = self.pool(x_pf)
116
+ return out
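An illustrative shape check for CrossBottleneckAttn (not part of this commit): dim is the per-branch channel count, the input stacks the face and person feature maps along channels (2 * dim), and the output has dim_out channels.

import torch
from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn

attn = CrossBottleneckAttn(dim=192, dim_out=192, num_heads=1, feat_size=(28, 28))
x = torch.randn(2, 384, 28, 28)  # face half (192 ch) concatenated with person half (192 ch)
out = attn(x)
print(out.shape)  # torch.Size([2, 192, 28, 28]); stride=1 keeps the spatial size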
mivolo/model/mi_volo.py ADDED
@@ -0,0 +1,243 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from mivolo.data.misc import prepare_classification_images
7
+ from mivolo.model.create_timm_model import create_model
8
+ from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
9
+ from timm.data import resolve_data_config
10
+
11
+ _logger = logging.getLogger("MiVOLO")
12
+ has_compile = hasattr(torch, "compile")
13
+
14
+
15
+ class Meta:
16
+ def __init__(self):
17
+ self.min_age = None
18
+ self.max_age = None
19
+ self.avg_age = None
20
+ self.num_classes = None
21
+
22
+ self.in_chans = 3
23
+ self.with_persons_model = False
24
+ self.disable_faces = False
25
+ self.use_persons = True
26
+ self.only_age = False
27
+
28
+ self.num_classes_gender = 2
29
+ self.input_size = 224
30
+
31
+ def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
32
+
33
+ state = torch.load(ckpt_path, map_location="cpu")
34
+
35
+ self.min_age = state["min_age"]
36
+ self.max_age = state["max_age"]
37
+ self.avg_age = state["avg_age"]
38
+ self.only_age = state["no_gender"]
39
+
40
+ only_age = state["no_gender"]
41
+
42
+ self.disable_faces = disable_faces
43
+ if "with_persons_model" in state:
44
+ self.with_persons_model = state["with_persons_model"]
45
+ else:
46
+ self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
47
+
48
+ self.num_classes = 1 if only_age else 3
49
+ self.in_chans = 3 if not self.with_persons_model else 6
50
+ self.use_persons = use_persons and self.with_persons_model
51
+
52
+ if not self.with_persons_model and self.disable_faces:
53
+ raise ValueError("You cannot use disable-faces with a faces-only model")
54
+ if self.with_persons_model and self.disable_faces and not self.use_persons:
55
+ raise ValueError(
56
+ "You can not disable faces and persons together. "
57
+ "Set --with-persons if you want to run with --disable-faces"
58
+ )
59
+ self.input_size = state["state_dict"]["pos_embed"].shape[1] * 16
60
+ return self
61
+
62
+ def __str__(self):
63
+ attrs = vars(self)
64
+ attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
65
+ return ", ".join("%s: %s" % item for item in attrs.items())
66
+
67
+ @property
68
+ def use_person_crops(self) -> bool:
69
+ return self.with_persons_model and self.use_persons
70
+
71
+ @property
72
+ def use_face_crops(self) -> bool:
73
+ return not self.disable_faces or not self.with_persons_model
74
+
75
+
76
+ class MiVOLO:
77
+ def __init__(
78
+ self,
79
+ ckpt_path: str,
80
+ device: str = "cuda",
81
+ half: bool = True,
82
+ disable_faces: bool = False,
83
+ use_persons: bool = True,
84
+ verbose: bool = False,
85
+ torchcompile: Optional[str] = None,
86
+ ):
87
+ self.verbose = verbose
88
+ self.device = torch.device(device)
89
+ self.half = half and self.device.type != "cpu"
90
+
91
+ self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
92
+ if self.verbose:
93
+ _logger.info(f"Model meta:\n{str(self.meta)}")
94
+
95
+ model_name = f"mivolo_d1_{self.meta.input_size}"
96
+ self.model = create_model(
97
+ model_name=model_name,
98
+ num_classes=self.meta.num_classes,
99
+ in_chans=self.meta.in_chans,
100
+ pretrained=False,
101
+ checkpoint_path=ckpt_path,
102
+ filter_keys=["fds."],
103
+ )
104
+ self.param_count = sum([m.numel() for m in self.model.parameters()])
105
+ _logger.info(f"Model {model_name} created, param count: {self.param_count}")
106
+
107
+ self.data_config = resolve_data_config(
108
+ model=self.model,
109
+ verbose=verbose,
110
+ use_test_size=True,
111
+ )
112
+
113
+ self.data_config["crop_pct"] = 1.0
114
+ c, h, w = self.data_config["input_size"]
115
+ assert h == w, "Incorrect data_config"
116
+ self.input_size = w
117
+
118
+ self.model = self.model.to(self.device)
119
+
120
+ if torchcompile:
121
+ assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
122
+ torch._dynamo.reset()
123
+ self.model = torch.compile(self.model, backend=torchcompile)
124
+
125
+ self.model.eval()
126
+ if self.half:
127
+ self.model = self.model.half()
128
+
129
+ def warmup(self, batch_size: int, steps=10):
130
+ if self.meta.with_persons_model:
131
+ input_size = (6, self.input_size, self.input_size)
132
+ else:
133
+ input_size = self.data_config["input_size"]
134
+
135
+ input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
136
+
137
+ for _ in range(steps):
138
+ out = self.inference(input) # noqa: F841
139
+
140
+ if torch.cuda.is_available():
141
+ torch.cuda.synchronize()
142
+
143
+ def inference(self, model_input: torch.tensor) -> torch.tensor:
144
+
145
+ with torch.no_grad():
146
+ if self.half:
147
+ model_input = model_input.half()
148
+ output = self.model(model_input)
149
+ return output
150
+
151
+ def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
152
+ if (
153
+ (detected_bboxes.n_objects == 0)
154
+ or (not self.meta.use_persons and detected_bboxes.n_faces == 0)
155
+ or (self.meta.disable_faces and detected_bboxes.n_persons == 0)
156
+ ):
157
+ # nothing to process
158
+ return
159
+
160
+ faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
161
+
162
+ if faces_input is None and person_input is None:
163
+ # nothing to process
164
+ return
165
+
166
+ if self.meta.with_persons_model:
167
+ model_input = torch.cat((faces_input, person_input), dim=1)
168
+ else:
169
+ model_input = faces_input
170
+ output = self.inference(model_input)
171
+
172
+ # write gender and age results into detected_bboxes
173
+ self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
174
+
175
+ def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
176
+ if self.meta.only_age:
177
+ age_output = output
178
+ gender_probs, gender_indx = None, None
179
+ else:
180
+ age_output = output[:, 2]
181
+ gender_output = output[:, :2].softmax(-1)
182
+ gender_probs, gender_indx = gender_output.topk(1)
183
+
184
+ assert output.shape[0] == len(faces_inds) == len(bodies_inds)
185
+
186
+ # per face
187
+ for index in range(output.shape[0]):
188
+ face_ind = faces_inds[index]
189
+ body_ind = bodies_inds[index]
190
+
191
+ # get_age
192
+ age = age_output[index].item()
193
+ age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
194
+ age = round(age, 2)
195
+
196
+ detected_bboxes.set_age(face_ind, age)
197
+ detected_bboxes.set_age(body_ind, age)
198
+
199
+ _logger.info(f"\tage: {age}")
200
+
201
+ if gender_probs is not None:
202
+ gender = "male" if gender_indx[index].item() == 0 else "female"
203
+ gender_score = gender_probs[index].item()
204
+
205
+ _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
206
+
207
+ detected_bboxes.set_gender(face_ind, gender, gender_score)
208
+ detected_bboxes.set_gender(body_ind, gender, gender_score)
209
+
210
+ def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
211
+
212
+ if self.meta.use_person_crops and self.meta.use_face_crops:
213
+ detected_bboxes.associate_faces_with_persons()
214
+
215
+ crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
216
+ (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
217
+ self.meta.use_person_crops, self.meta.use_face_crops
218
+ )
219
+
220
+ if not self.meta.use_face_crops:
221
+ assert all(f is None for f in faces_crops)
222
+
223
+ faces_input = prepare_classification_images(
224
+ faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
225
+ )
226
+
227
+ if not self.meta.use_person_crops:
228
+ assert all(p is None for p in bodies_crops)
229
+
230
+ person_input = prepare_classification_images(
231
+ bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
232
+ )
233
+
234
+ _logger.info(
235
+ f"faces_input: {faces_input.shape if faces_input is not None else None}, "
236
+ f"person_input: {person_input.shape if person_input is not None else None}"
237
+ )
238
+
239
+ return faces_input, person_input, faces_inds, bodies_inds
240
+
241
+
242
+ if __name__ == "__main__":
243
+ model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
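One detail worth spelling out from fill_in_results above: the head predicts a normalized age that is mapped back with age * (max_age - min_age) + avg_age, using statistics stored in the checkpoint. For example, with hypothetical checkpoint values min_age=1, max_age=95 and avg_age=48, a raw output of 0.10 becomes 0.10 * (95 - 1) + 48 = 57.4 years.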
mivolo/model/mivolo_model.py ADDED
@@ -0,0 +1,404 @@
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn
10
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.layers import trunc_normal_
12
+ from timm.models._builder import build_model_with_cfg
13
+ from timm.models._registry import register_model
14
+ from timm.models.volo import VOLO
15
+
16
+ __all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
17
+
18
+
19
+ def _cfg(url="", **kwargs):
20
+ return {
21
+ "url": url,
22
+ "num_classes": 1000,
23
+ "input_size": (3, 224, 224),
24
+ "pool_size": None,
25
+ "crop_pct": 0.96,
26
+ "interpolation": "bicubic",
27
+ "fixed_input_size": True,
28
+ "mean": IMAGENET_DEFAULT_MEAN,
29
+ "std": IMAGENET_DEFAULT_STD,
30
+ "first_conv": None,
31
+ "classifier": ("head", "aux_head"),
32
+ **kwargs,
33
+ }
34
+
35
+
36
+ default_cfgs = {
37
+ "mivolo_d1_224": _cfg(
38
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
39
+ ),
40
+ "mivolo_d1_384": _cfg(
41
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
42
+ crop_pct=1.0,
43
+ input_size=(3, 384, 384),
44
+ ),
45
+ "mivolo_d2_224": _cfg(
46
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
47
+ ),
48
+ "mivolo_d2_384": _cfg(
49
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
50
+ crop_pct=1.0,
51
+ input_size=(3, 384, 384),
52
+ ),
53
+ "mivolo_d3_224": _cfg(
54
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
55
+ ),
56
+ "mivolo_d3_448": _cfg(
57
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
58
+ crop_pct=1.0,
59
+ input_size=(3, 448, 448),
60
+ ),
61
+ "mivolo_d4_224": _cfg(
62
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
63
+ ),
64
+ "mivolo_d4_448": _cfg(
65
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
66
+ crop_pct=1.15,
67
+ input_size=(3, 448, 448),
68
+ ),
69
+ "mivolo_d5_224": _cfg(
70
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
71
+ ),
72
+ "mivolo_d5_448": _cfg(
73
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
74
+ crop_pct=1.15,
75
+ input_size=(3, 448, 448),
76
+ ),
77
+ "mivolo_d5_512": _cfg(
78
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
79
+ crop_pct=1.15,
80
+ input_size=(3, 512, 512),
81
+ ),
82
+ }
83
+
84
+
85
+ def get_output_size(input_shape, conv_layer):
86
+ padding = conv_layer.padding
87
+ dilation = conv_layer.dilation
88
+ kernel_size = conv_layer.kernel_size
89
+ stride = conv_layer.stride
90
+
91
+ output_size = [
92
+ ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
93
+ ]
94
+ return output_size
95
+
96
+
97
+ def get_output_size_module(input_size, stem):
98
+ output_size = input_size
99
+
100
+ for module in stem:
101
+ if isinstance(module, nn.Conv2d):
102
+ output_size = [
103
+ (
104
+ (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
105
+ // module.stride[i]
106
+ )
107
+ + 1
108
+ for i in range(2)
109
+ ]
110
+
111
+ return output_size
112
+
113
+
114
+ class PatchEmbed(nn.Module):
115
+ """Image to Patch Embedding."""
116
+
117
+ def __init__(
118
+ self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
119
+ ):
120
+ super().__init__()
121
+ assert patch_size in [4, 8, 16]
122
+ assert in_chans in [3, 6]
123
+ self.with_persons_model = in_chans == 6
124
+ self.use_cross_attn = True
125
+
126
+ if stem_conv:
127
+ if not self.with_persons_model:
128
+ self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
129
+ else:
130
+ self.conv = True # just to match interface
131
+ # split
132
+ self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
133
+ self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
134
+ else:
135
+ self.conv = None
136
+
137
+ if self.with_persons_model:
138
+
139
+ self.proj1 = nn.Conv2d(
140
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
141
+ )
142
+ self.proj2 = nn.Conv2d(
143
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
144
+ )
145
+
146
+ stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
147
+ self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
148
+
149
+ self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
150
+
151
+ else:
152
+ self.proj = nn.Conv2d(
153
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
154
+ )
155
+
156
+ self.patch_dim = img_size // patch_size
157
+ self.num_patches = self.patch_dim**2
158
+
159
+ def create_stem(self, stem_stride, in_chans, hidden_dim):
160
+ return nn.Sequential(
161
+ nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
162
+ nn.BatchNorm2d(hidden_dim),
163
+ nn.ReLU(inplace=True),
164
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
165
+ nn.BatchNorm2d(hidden_dim),
166
+ nn.ReLU(inplace=True),
167
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
168
+ nn.BatchNorm2d(hidden_dim),
169
+ nn.ReLU(inplace=True),
170
+ )
171
+
172
+ def forward(self, x):
173
+ if self.conv is not None:
174
+ if self.with_persons_model:
175
+ x1 = x[:, :3]
176
+ x2 = x[:, 3:]
177
+
178
+ x1 = self.conv1(x1)
179
+ x1 = self.proj1(x1)
180
+
181
+ x2 = self.conv2(x2)
182
+ x2 = self.proj2(x2)
183
+
184
+ x = torch.cat([x1, x2], dim=1)
185
+ x = self.map(x)
186
+ else:
187
+ x = self.conv(x)
188
+ x = self.proj(x) # B, C, H, W
189
+
190
+ return x
191
+
192
+
193
+ class MiVOLOModel(VOLO):
194
+ """
195
+ Vision Outlooker, the main class of our model
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ layers,
201
+ img_size=224,
202
+ in_chans=3,
203
+ num_classes=1000,
204
+ global_pool="token",
205
+ patch_size=8,
206
+ stem_hidden_dim=64,
207
+ embed_dims=None,
208
+ num_heads=None,
209
+ downsamples=(True, False, False, False),
210
+ outlook_attention=(True, False, False, False),
211
+ mlp_ratio=3.0,
212
+ qkv_bias=False,
213
+ drop_rate=0.0,
214
+ attn_drop_rate=0.0,
215
+ drop_path_rate=0.0,
216
+ norm_layer=nn.LayerNorm,
217
+ post_layers=("ca", "ca"),
218
+ use_aux_head=True,
219
+ use_mix_token=False,
220
+ pooling_scale=2,
221
+ ):
222
+ super().__init__(
223
+ layers,
224
+ img_size,
225
+ in_chans,
226
+ num_classes,
227
+ global_pool,
228
+ patch_size,
229
+ stem_hidden_dim,
230
+ embed_dims,
231
+ num_heads,
232
+ downsamples,
233
+ outlook_attention,
234
+ mlp_ratio,
235
+ qkv_bias,
236
+ drop_rate,
237
+ attn_drop_rate,
238
+ drop_path_rate,
239
+ norm_layer,
240
+ post_layers,
241
+ use_aux_head,
242
+ use_mix_token,
243
+ pooling_scale,
244
+ )
245
+
246
+ im_size = img_size[0] if isinstance(img_size, tuple) else img_size
247
+ self.patch_embed = PatchEmbed(
248
+ img_size=im_size,
249
+ stem_conv=True,
250
+ stem_stride=2,
251
+ patch_size=patch_size,
252
+ in_chans=in_chans,
253
+ hidden_dim=stem_hidden_dim,
254
+ embed_dim=embed_dims[0],
255
+ )
256
+
257
+ trunc_normal_(self.pos_embed, std=0.02)
258
+ self.apply(self._init_weights)
259
+
260
+ def forward_features(self, x):
261
+ x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
262
+
263
+ # step2: tokens learning in the two stages
264
+ x = self.forward_tokens(x)
265
+
266
+ # step3: post network, apply class attention or not
267
+ if self.post_network is not None:
268
+ x = self.forward_cls(x)
269
+ x = self.norm(x)
270
+ return x
271
+
272
+ def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
273
+ if self.global_pool == "avg":
274
+ out = x.mean(dim=1)
275
+ elif self.global_pool == "token":
276
+ out = x[:, 0]
277
+ else:
278
+ out = x
279
+ if pre_logits:
280
+ return out
281
+
282
+ features = out
283
+ fds_enabled = hasattr(self, "_fds_forward")
284
+ if fds_enabled:
285
+ features = self._fds_forward(features, targets, epoch)
286
+
287
+ out = self.head(features)
288
+ if self.aux_head is not None:
289
+ # generate classes in all feature tokens, see token labeling
290
+ aux = self.aux_head(x[:, 1:])
291
+ out = out + 0.5 * aux.max(1)[0]
292
+
293
+ return (out, features) if (fds_enabled and self.training) else out
294
+
295
+ def forward(self, x, targets=None, epoch=None):
296
+ """simplified forward (without mix token training)"""
297
+ x = self.forward_features(x)
298
+ x = self.forward_head(x, targets=targets, epoch=epoch)
299
+ return x
300
+
301
+
302
+ def _create_mivolo(variant, pretrained=False, **kwargs):
303
+ if kwargs.get("features_only", None):
304
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
305
+ return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
306
+
307
+
308
+ @register_model
309
+ def mivolo_d1_224(pretrained=False, **kwargs):
310
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
311
+ model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
312
+ return model
313
+
314
+
315
+ @register_model
316
+ def mivolo_d1_384(pretrained=False, **kwargs):
317
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
318
+ model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
319
+ return model
320
+
321
+
322
+ @register_model
323
+ def mivolo_d2_224(pretrained=False, **kwargs):
324
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
325
+ model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
326
+ return model
327
+
328
+
329
+ @register_model
330
+ def mivolo_d2_384(pretrained=False, **kwargs):
331
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
332
+ model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
333
+ return model
334
+
335
+
336
+ @register_model
337
+ def mivolo_d3_224(pretrained=False, **kwargs):
338
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
339
+ model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
340
+ return model
341
+
342
+
343
+ @register_model
344
+ def mivolo_d3_448(pretrained=False, **kwargs):
345
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
346
+ model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
347
+ return model
348
+
349
+
350
+ @register_model
351
+ def mivolo_d4_224(pretrained=False, **kwargs):
352
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
353
+ model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
354
+ return model
355
+
356
+
357
+ @register_model
358
+ def mivolo_d4_448(pretrained=False, **kwargs):
359
+ """VOLO-D4 model, Params: 193M"""
360
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
361
+ model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
362
+ return model
363
+
364
+
365
+ @register_model
366
+ def mivolo_d5_224(pretrained=False, **kwargs):
367
+ model_args = dict(
368
+ layers=(12, 12, 20, 4),
369
+ embed_dims=(384, 768, 768, 768),
370
+ num_heads=(12, 16, 16, 16),
371
+ mlp_ratio=4,
372
+ stem_hidden_dim=128,
373
+ **kwargs
374
+ )
375
+ model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
376
+ return model
377
+
378
+
379
+ @register_model
380
+ def mivolo_d5_448(pretrained=False, **kwargs):
381
+ model_args = dict(
382
+ layers=(12, 12, 20, 4),
383
+ embed_dims=(384, 768, 768, 768),
384
+ num_heads=(12, 16, 16, 16),
385
+ mlp_ratio=4,
386
+ stem_hidden_dim=128,
387
+ **kwargs
388
+ )
389
+ model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
390
+ return model
391
+
392
+
393
+ @register_model
394
+ def mivolo_d5_512(pretrained=False, **kwargs):
395
+ model_args = dict(
396
+ layers=(12, 12, 20, 4),
397
+ embed_dims=(384, 768, 768, 768),
398
+ num_heads=(12, 16, 16, 16),
399
+ mlp_ratio=4,
400
+ stem_hidden_dim=128,
401
+ **kwargs
402
+ )
403
+ model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
404
+ return model
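A quick, illustrative sanity check of the stem/patch geometry that get_output_size and get_output_size_module compute above, for the 224-input, patch_size=8, stem_stride=2 configuration:

import torch.nn as nn
from mivolo.model.mivolo_model import get_output_size, get_output_size_module

# stem conv: kernel 7, stride 2, padding 3 -> (224 + 2*3 - 6 - 1) // 2 + 1 = 112
stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False))
print(get_output_size_module((224, 224), stem))  # [112, 112]

# projection: kernel = stride = patch_size // stem_stride = 4 -> 112 // 4 = 28
proj = nn.Conv2d(64, 384, kernel_size=4, stride=4)
print(get_output_size((112, 112), proj))  # [28, 28], i.e. the 28x28 feat_size used by the cross attention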
mivolo/model/yolo_detector.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ from typing import Dict, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from mivolo.structures import PersonAndFaceResult
8
+ from ultralytics import YOLO
9
+ from ultralytics.engine.results import Results
10
+
11
+ # because of an ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after importing the module
12
+ os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
13
+
14
+
15
+ class Detector:
16
+ def __init__(
17
+ self,
18
+ weights: str,
19
+ device: str = "cuda",
20
+ half: bool = True,
21
+ verbose: bool = False,
22
+ conf_thresh: float = 0.4,
23
+ iou_thresh: float = 0.7,
24
+ ):
25
+ self.yolo = YOLO(weights)
26
+ self.yolo.fuse()
27
+
28
+ self.device = torch.device(device)
29
+ self.half = half and self.device.type != "cpu"
30
+
31
+ if self.half:
32
+ self.yolo.model = self.yolo.model.half()
33
+
34
+ self.detector_names: Dict[int, str] = self.yolo.model.names
35
+
36
+ # init yolo.predictor
37
+ self.detector_kwargs = {"conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose}
38
+ # self.yolo.predict(**self.detector_kwargs)
39
+
40
+ def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
41
+ results: Results = self.yolo.predict(image, **self.detector_kwargs)[0]
42
+ return PersonAndFaceResult(results)
43
+
44
+ def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
45
+ results: Results = self.yolo.track(image, persist=True, **self.detector_kwargs)[0]
46
+ return PersonAndFaceResult(results)
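A minimal sketch of driving the Detector wrapper above (the weights path and image are placeholders):

import cv2
from mivolo.model.yolo_detector import Detector

detector = Detector("models/yolov8x_person_face.pt", device="cpu", half=False)
image = cv2.imread("example.jpg")
result = detector.predict(image)  # PersonAndFaceResult
print(result.n_persons, result.n_faces)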
mivolo/predictor.py ADDED
@@ -0,0 +1,68 @@
1
+ from collections import defaultdict
2
+ from typing import Dict, Generator, List, Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import tqdm
7
+ from mivolo.model.mi_volo import MiVOLO
8
+ from mivolo.model.yolo_detector import Detector
9
+ from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
10
+
11
+
12
+ class Predictor:
13
+ def __init__(self, config, verbose: bool = False):
14
+ self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
15
+ self.age_gender_model = MiVOLO(
16
+ config.checkpoint,
17
+ config.device,
18
+ half=True,
19
+ use_persons=config.with_persons,
20
+ disable_faces=config.disable_faces,
21
+ verbose=verbose,
22
+ )
23
+ self.draw = config.draw
24
+
25
+ def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
26
+ detected_objects: PersonAndFaceResult = self.detector.predict(image)
27
+ self.age_gender_model.predict(image, detected_objects)
28
+
29
+ out_im = None
30
+ if self.draw:
31
+ # plot results on image
32
+ out_im = detected_objects.plot()
33
+
34
+ return detected_objects, out_im
35
+
36
+ def recognize_video(self, source: str) -> Generator:
37
+ video_capture = cv2.VideoCapture(source)
38
+ if not video_capture.isOpened():
39
+ raise ValueError(f"Failed to open video source {source}")
40
+
41
+ detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
42
+
43
+ total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
44
+ for _ in tqdm.tqdm(range(total_frames)):
45
+ ret, frame = video_capture.read()
46
+ if not ret:
47
+ break
48
+
49
+ detected_objects: PersonAndFaceResult = self.detector.track(frame)
50
+ self.age_gender_model.predict(frame, detected_objects)
51
+
52
+ current_frame_objs = detected_objects.get_results_for_tracking()
53
+ cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
54
+ cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
55
+
56
+ # add tr_persons and tr_faces to history
57
+ for guid, data in cur_persons.items():
58
+ # not useful for tracking :)
59
+ if None not in data:
60
+ detected_objects_history[guid].append(data)
61
+ for guid, data in cur_faces.items():
62
+ if None not in data:
63
+ detected_objects_history[guid].append(data)
64
+
65
+ detected_objects.set_tracked_age_gender(detected_objects_history)
66
+ if self.draw:
67
+ frame = detected_objects.plot()
68
+ yield detected_objects_history, frame
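A hypothetical driver for Predictor above (all paths are placeholders); any object exposing the attributes read here works as the config, so a SimpleNamespace stands in for the app's own settings object:

import cv2
from types import SimpleNamespace
from mivolo.predictor import Predictor

config = SimpleNamespace(
    detector_weights="models/yolov8x_person_face.pt",  # placeholder
    checkpoint="models/mivolo_checkpoint.pth.tar",     # placeholder
    device="cuda:0",
    with_persons=True,
    disable_faces=False,
    draw=True,
)
predictor = Predictor(config, verbose=True)
detected_objects, annotated_image = predictor.recognize(cv2.imread("example.jpg"))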
mivolo/structures.py ADDED
@@ -0,0 +1,472 @@
1
+ import math
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou
10
+ from ultralytics.engine.results import Results
11
+ from ultralytics.utils.plotting import Annotator, colors
12
+
13
+ # because of an ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after importing the module
14
+ os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
15
+
16
+ AGE_GENDER_TYPE = Tuple[float, str]
17
+
18
+
19
+ class PersonAndFaceCrops:
20
+ def __init__(self):
21
+ # int: index of person along results
22
+ self.crops_persons: Dict[int, np.ndarray] = {}
23
+
24
+ # int: index of face along results
25
+ self.crops_faces: Dict[int, np.ndarray] = {}
26
+
27
+ # int: index of face along results
28
+ self.crops_faces_wo_body: Dict[int, np.ndarray] = {}
29
+
30
+ # int: index of person along results
31
+ self.crops_persons_wo_face: Dict[int, np.ndarray] = {}
32
+
33
+ def _add_to_output(
34
+ self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]]
35
+ ):
36
+ inds_to_add = list(crops.keys())
37
+ crops_to_add = list(crops.values())
38
+ out_crops.extend(crops_to_add)
39
+ out_crop_inds.extend(inds_to_add)
40
+
41
+ def _get_all_faces(
42
+ self, use_persons: bool, use_faces: bool
43
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
44
+ """
45
+ Returns
46
+ if use_persons and use_faces
47
+ faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face)
48
+ if use_persons and not use_faces
49
+ faces: [None] * n_persons
50
+ if not use_persons and use_faces:
51
+ faces: faces_with_bodies + faces_without_bodies
52
+ """
53
+
54
+ def add_none_to_output(faces_inds, faces_crops, num):
55
+ faces_inds.extend([None for _ in range(num)])
56
+ faces_crops.extend([None for _ in range(num)])
57
+
58
+ faces_inds: List[Optional[int]] = []
59
+ faces_crops: List[Optional[np.ndarray]] = []
60
+
61
+ if not use_faces:
62
+ add_none_to_output(faces_inds, faces_crops, len(self.crops_persons) + len(self.crops_persons_wo_face))
63
+ return faces_inds, faces_crops
64
+
65
+ self._add_to_output(self.crops_faces, faces_crops, faces_inds)
66
+ self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds)
67
+
68
+ if use_persons:
69
+ add_none_to_output(faces_inds, faces_crops, len(self.crops_persons_wo_face))
70
+
71
+ return faces_inds, faces_crops
72
+
73
+ def _get_all_bodies(
74
+ self, use_persons: bool, use_faces: bool
75
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
76
+ """
77
+ Returns
78
+ if use_persons and use_faces
79
+ persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces
80
+ if use_persons and not use_faces
81
+ persons: bodies_with_faces + bodies_without_faces
82
+ if not use_persons and use_faces
83
+ persons: [None] * n_faces
84
+ """
85
+
86
+ def add_none_to_output(bodies_inds, bodies_crops, num):
87
+ bodies_inds.extend([None for _ in range(num)])
88
+ bodies_crops.extend([None for _ in range(num)])
89
+
90
+ bodies_inds: List[Optional[int]] = []
91
+ bodies_crops: List[Optional[np.ndarray]] = []
92
+
93
+ if not use_persons:
94
+ add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces) + len(self.crops_faces_wo_body))
95
+ return bodies_inds, bodies_crops
96
+
97
+ self._add_to_output(self.crops_persons, bodies_crops, bodies_inds)
98
+ if use_faces:
99
+ add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces_wo_body))
100
+
101
+ self._add_to_output(self.crops_persons_wo_face, bodies_crops, bodies_inds)
102
+
103
+ return bodies_inds, bodies_crops
104
+
105
+ def get_faces_with_bodies(self, use_persons: bool, use_faces: bool):
106
+ """
107
+ Return
108
+ faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face)
109
+ persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces
110
+ """
111
+
112
+ bodies_inds, bodies_crops = self._get_all_bodies(use_persons, use_faces)
113
+ faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces)
114
+
115
+ return (bodies_inds, bodies_crops), (faces_inds, faces_crops)
116
+
117
+ def save(self, out_dir="output"):
118
+ ind = 0
119
+ os.makedirs(out_dir, exist_ok=True)
120
+ for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, self.crops_persons_wo_face]:
121
+ for crop in crops.values():
122
+ if crop is None:
123
+ continue
124
+ out_name = os.path.join(out_dir, f"{ind}_crop.jpg")
125
+ cv2.imwrite(out_name, crop)
126
+ ind += 1
127
+
128
+
129
+ class PersonAndFaceResult:
130
+ def __init__(self, results: Results):
131
+
132
+ self.yolo_results = results
133
+ names = set(results.names.values())
134
+ assert "person" in names and "face" in names
135
+
136
+ # initially no faces and persons are associated to each other
137
+ self.face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in self.get_bboxes_inds("face")}
138
+ self.unassigned_persons_inds: List[int] = self.get_bboxes_inds("person")
139
+ n_objects = len(self.yolo_results.boxes)
140
+ self.ages: List[Optional[float]] = [None for _ in range(n_objects)]
141
+ self.genders: List[Optional[str]] = [None for _ in range(n_objects)]
142
+ self.gender_scores: List[Optional[float]] = [None for _ in range(n_objects)]
143
+
144
+ @property
145
+ def n_objects(self) -> int:
146
+ return len(self.yolo_results.boxes)
147
+
148
+ @property
149
+ def n_faces(self) -> int:
150
+ return len(self.get_bboxes_inds("face"))
151
+
152
+ @property
153
+ def n_persons(self) -> int:
154
+ return len(self.get_bboxes_inds("person"))
155
+
156
+ def get_bboxes_inds(self, category: str) -> List[int]:
157
+ bboxes: List[int] = []
158
+ for ind, det in enumerate(self.yolo_results.boxes):
159
+ name = self.yolo_results.names[int(det.cls)]
160
+ if name == category:
161
+ bboxes.append(ind)
162
+
163
+ return bboxes
164
+
165
+ def get_distance_to_center(self, bbox_ind: int) -> float:
166
+ """
167
+ Calculate the Euclidean distance between the bbox center and the image center.
168
+ """
169
+ im_h, im_w = self.yolo_results[bbox_ind].orig_shape
170
+ x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy()
171
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
172
+ dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2])
173
+ return dist
174
+
175
+ def plot(
176
+ self,
177
+ conf=False,
178
+ line_width=None,
179
+ font_size=None,
180
+ font="Arial.ttf",
181
+ pil=False,
182
+ img=None,
183
+ labels=True,
184
+ boxes=True,
185
+ probs=True,
186
+ ages=True,
187
+ genders=True,
188
+ gender_probs=False,
189
+ ):
190
+ """
191
+ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
192
+ Args:
193
+ conf (bool): Whether to plot the detection confidence score.
194
+ line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
195
+ font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
196
+ font (str): The font to use for the text.
197
+ pil (bool): Whether to return the image as a PIL Image.
198
+ img (numpy.ndarray): Plot onto this image instead; if None, plot onto the original image.
199
+ labels (bool): Whether to plot the label of bounding boxes.
200
+ boxes (bool): Whether to plot the bounding boxes.
201
+ probs (bool): Whether to plot classification probability
202
+ ages (bool): Whether to plot the age of bounding boxes.
203
+ genders (bool): Whether to plot the genders of bounding boxes.
204
+ gender_probs (bool): Whether to plot gender classification probability
205
+ Returns:
206
+ (numpy.ndarray): A numpy array of the annotated image.
207
+ """
208
+
209
+ # return self.yolo_results.plot()
210
+ colors_by_ind = {}
211
+ for face_ind, person_ind in self.face_to_person_map.items():
212
+ if person_ind is not None:
213
+ colors_by_ind[face_ind] = face_ind + 2
214
+ colors_by_ind[person_ind] = face_ind + 2
215
+ else:
216
+ colors_by_ind[face_ind] = 0
217
+ for person_ind in self.unassigned_persons_inds:
218
+ colors_by_ind[person_ind] = 1
219
+
220
+ names = self.yolo_results.names
221
+ annotator = Annotator(
222
+ deepcopy(self.yolo_results.orig_img if img is None else img),
223
+ line_width,
224
+ font_size,
225
+ font,
226
+ pil,
227
+ example=names,
228
+ )
229
+ pred_boxes, show_boxes = self.yolo_results.boxes, boxes
230
+ pred_probs, show_probs = self.yolo_results.probs, probs
231
+
232
+ if pred_boxes and show_boxes:
233
+ for bb_ind, (d, age, gender, gender_score) in enumerate(
234
+ zip(pred_boxes, self.ages, self.genders, self.gender_scores)
235
+ ):
236
+ c, conf, guid = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
237
+ name = ("" if guid is None else f"id:{guid} ") + names[c]
238
+ label = (f"{name} {conf:.2f}" if conf else name) if labels else None
239
+ if ages and age is not None:
240
+ label += f" {age:.1f}"
241
+ if genders and gender is not None:
242
+ label += f" {'F' if gender == 'female' else 'M'}"
243
+ if gender_probs and gender_score is not None:
244
+ label += f" ({gender_score:.1f})"
245
+ annotator.box_label(d.xyxy.squeeze(), label, color=colors(colors_by_ind[bb_ind], True))
246
+
247
+ if pred_probs is not None and show_probs:
248
+ text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
249
+ annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
250
+
251
+ return annotator.result()
252
+
253
+ def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]):
254
+ """
255
+ Update age and gender for objects based on history from tracked_objects.
256
+ Args:
257
+ tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid
258
+ """
259
+
260
+ for face_ind, person_ind in self.face_to_person_map.items():
261
+ pguid = self._get_id_by_ind(person_ind)
262
+ fguid = self._get_id_by_ind(face_ind)
263
+
264
+ if fguid == -1 and pguid == -1:
265
+ # YOLO might not assign ids for some objects in some cases:
266
+ # https://github.com/ultralytics/ultralytics/issues/3830
267
+ continue
268
+ age, gender = self._gather_tracking_result(tracked_objects, fguid, pguid)
269
+ if age is None or gender is None:
270
+ continue
271
+ self.set_age(face_ind, age)
272
+ self.set_gender(face_ind, gender, 1.0)
273
+ if pguid != -1:
274
+ self.set_gender(person_ind, gender, 1.0)
275
+ self.set_age(person_ind, age)
276
+
277
+ for person_ind in self.unassigned_persons_inds:
278
+ pid = self._get_id_by_ind(person_ind)
279
+ if pid == -1:
280
+ continue
281
+ age, gender = self._gather_tracking_result(tracked_objects, -1, pid)
282
+ if age is None or gender is None:
283
+ continue
284
+ self.set_gender(person_ind, gender, 1.0)
285
+ self.set_age(person_ind, age)
286
+
287
+ def _get_id_by_ind(self, ind: Optional[int] = None) -> int:
288
+ if ind is None:
289
+ return -1
290
+ obj_id = self.yolo_results.boxes[ind].id
291
+ if obj_id is None:
292
+ return -1
293
+ return obj_id.item()
294
+
295
+ def get_bbox_by_ind(self, ind: int, im_h: int = None, im_w: int = None) -> torch.tensor:
296
+ bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32)
297
+ if im_h is not None and im_w is not None:
298
+ bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1)
299
+ bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1)
300
+ bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1)
301
+ bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1)
302
+ return bb
303
+
304
+ def set_age(self, ind: Optional[int], age: float):
305
+ if ind is not None:
306
+ self.ages[ind] = age
307
+
308
+ def set_gender(self, ind: Optional[int], gender: str, gender_score: float):
309
+ if ind is not None:
310
+ self.genders[ind] = gender
311
+ self.gender_scores[ind] = gender_score
312
+
313
+ @staticmethod
314
+ def _gather_tracking_result(
315
+ tracked_objects: Dict[int, List[AGE_GENDER_TYPE]],
316
+ fguid: int = -1,
317
+ pguid: int = -1,
318
+ minimum_sample_size: int = 10,
319
+ ) -> AGE_GENDER_TYPE:
320
+
321
+ assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour"
322
+
323
+ face_ages = [r[0] for r in tracked_objects[fguid] if r[0] is not None] if fguid in tracked_objects else []
324
+ face_genders = [r[1] for r in tracked_objects[fguid] if r[1] is not None] if fguid in tracked_objects else []
325
+ person_ages = [r[0] for r in tracked_objects[pguid] if r[0] is not None] if pguid in tracked_objects else []
326
+ person_genders = [r[1] for r in tracked_objects[pguid] if r[1] is not None] if pguid in tracked_objects else []
327
+
328
+ if not face_ages and not person_ages: # both empty
329
+ return None, None
330
+
331
+ # You can play here with different aggregation strategies
332
+ # Face ages - predictions based on face or face + person, depends on history of object
333
+ # Person ages - predictions based on person or face + person, depends on history of object
334
+
335
+ if len(person_ages + face_ages) >= minimum_sample_size:
336
+ age = aggregate_votes_winsorized(person_ages + face_ages)
337
+ else:
338
+ face_age = np.mean(face_ages) if face_ages else None
339
+ person_age = np.mean(person_ages) if person_ages else None
340
+ if face_age is None:
341
+ face_age = person_age
342
+ if person_age is None:
343
+ person_age = face_age
344
+ age = (face_age + person_age) / 2.0
345
+
346
+ genders = face_genders + person_genders
347
+ assert len(genders) > 0
348
+ # take mode of genders
349
+ gender = max(set(genders), key=genders.count)
350
+
351
+ return age, gender
352
+
353
+ def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], Dict[int, AGE_GENDER_TYPE]]:
354
+ """
355
+ Get objects from current frame
356
+ """
357
+ persons: Dict[int, AGE_GENDER_TYPE] = {}
358
+ faces: Dict[int, AGE_GENDER_TYPE] = {}
359
+
360
+ names = self.yolo_results.names
361
+ pred_boxes = self.yolo_results.boxes
362
+ for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)):
363
+ if det.id is None:
364
+ continue
365
+ cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item())
366
+ name = names[cat_id]
367
+ if name == "person":
368
+ persons[guid] = (age, gender)
369
+ elif name == "face":
370
+ faces[guid] = (age, gender)
371
+
372
+ return persons, faces
373
+
374
+ def associate_faces_with_persons(self):
375
+ face_bboxes_inds: List[int] = self.get_bboxes_inds("face")
376
+ person_bboxes_inds: List[int] = self.get_bboxes_inds("person")
377
+
378
+ face_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in face_bboxes_inds]
379
+ person_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
380
+
381
+ self.face_to_person_map = {ind: None for ind in face_bboxes_inds}
382
+ assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes)
383
+
384
+ for face_ind, person_ind in enumerate(assigned_faces):
385
+ face_ind = face_bboxes_inds[face_ind]
386
+ person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
387
+ self.face_to_person_map[face_ind] = person_ind
388
+
389
+ self.unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
390
+
391
+ def crop_object(
392
+ self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None
393
+ ) -> Optional[np.ndarray]:
394
+
395
+ IOU_THRESH = 0.000001
396
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
397
+ CROP_ROUND_RATE = 0.3
398
+ MIN_PERSON_SIZE = 50
399
+
400
+ obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2])
401
+ x1, y1, x2, y2 = obj_bbox
402
+ cur_cat = self.yolo_results.names[int(self.yolo_results.boxes[ind].cls)]
403
+ # get crop of face or person
404
+ obj_image = full_image[y1:y2, x1:x2].copy()
405
+ crop_h, crop_w = obj_image.shape[:2]
406
+
407
+ if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE):
408
+ return None
409
+
410
+ if not cut_other_classes:
411
+ return obj_image
412
+
413
+ # calc iou between obj_bbox and other bboxes
414
+ other_bboxes: List[torch.tensor] = [
415
+ self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes))
416
+ ]
417
+
418
+ iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(other_bboxes)).cpu().numpy()[0]
419
+
420
+ # cut out other objects in case of intersection
421
+ for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)):
422
+ other_cat = self.yolo_results.names[int(det.cls)]
423
+ if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes:
424
+ continue
425
+ o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32)
426
+
427
+ # remap current_person_bbox to reference_person_bbox coordinates
428
+ o_x1 = max(o_x1 - x1, 0)
429
+ o_y1 = max(o_y1 - y1, 0)
430
+ o_x2 = min(o_x2 - x1, crop_w)
431
+ o_y2 = min(o_y2 - y1, crop_h)
432
+
433
+ if other_cat != "face":
434
+ if (o_y1 / crop_h) < CROP_ROUND_RATE:
435
+ o_y1 = 0
436
+ if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE:
437
+ o_y2 = crop_h
438
+ if (o_x1 / crop_w) < CROP_ROUND_RATE:
439
+ o_x1 = 0
440
+ if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE:
441
+ o_x2 = crop_w
442
+
443
+ obj_image[o_y1:o_y2, o_x1:o_x2] = 0
444
+
445
+ remain_ratio = np.count_nonzero(obj_image) / (obj_image.shape[0] * obj_image.shape[1] * obj_image.shape[2])
446
+ if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO:
447
+ return None
448
+
449
+ return obj_image
450
+
451
+ def collect_crops(self, image) -> PersonAndFaceCrops:
452
+
453
+ crops_data = PersonAndFaceCrops()
454
+ for face_ind, person_ind in self.face_to_person_map.items():
455
+ face_image = self.crop_object(image, face_ind, cut_other_classes=[])
456
+
457
+ if person_ind is None:
458
+ crops_data.crops_faces_wo_body[face_ind] = face_image
459
+ continue
460
+
461
+ person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
462
+
463
+ crops_data.crops_faces[face_ind] = face_image
464
+ crops_data.crops_persons[person_ind] = person_image
465
+
466
+ for person_ind in self.unassigned_persons_inds:
467
+ person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
468
+ crops_data.crops_persons_wo_face[person_ind] = person_image
469
+
470
+ # uncomment to save preprocessed crops
471
+ # crops_data.save()
472
+ return crops_data
mivolo/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.6.0dev"
requirements.txt CHANGED
@@ -14,4 +14,9 @@ gunicorn>=20.1.0
14
  uvicorn>=0.29.0
15
  fastapi>=0.110.0
16
  tf-keras>=2.16.0
17
- python-multipart>=0.0.9
14
  uvicorn>=0.29.0
15
  fastapi>=0.110.0
16
  tf-keras>=2.16.0
17
+ python-multipart>=0.0.9
18
+ huggingface_hub
19
+ ultralytics==8.1.0
20
+ timm==0.8.13.dev0
21
+ yt_dlp
22
+ lapx>=0.5.2