Student0809 committed on
Commit 8c78b88 · verified · 1 Parent(s): 7feac49

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. docs/resources/grpo_countdown.png +3 -0
  3. docs/resources/grpo_geoqa.png +3 -0
  4. docs/resources/grpo_openr1_multimodal.png +3 -0
  5. docs/transformers/build/lib/transformers/models/depth_anything/convert_distill_any_depth_to_hf.py +246 -0
  6. docs/transformers/build/lib/transformers/models/depth_anything/modeling_depth_anything.py +469 -0
  7. docs/transformers/build/lib/transformers/models/depth_pro/configuration_depth_pro.py +205 -0
  8. docs/transformers/build/lib/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py +254 -0
  9. docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro.py +392 -0
  10. docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro_fast.py +189 -0
  11. docs/transformers/build/lib/transformers/models/depth_pro/modeling_depth_pro.py +1218 -0
  12. docs/transformers/build/lib/transformers/models/detr/__init__.py +31 -0
  13. docs/transformers/build/lib/transformers/models/detr/configuration_detr.py +289 -0
  14. docs/transformers/build/lib/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py +277 -0
  15. docs/transformers/build/lib/transformers/models/detr/convert_detr_to_pytorch.py +385 -0
  16. docs/transformers/build/lib/transformers/models/detr/feature_extraction_detr.py +48 -0
  17. docs/transformers/build/lib/transformers/models/detr/image_processing_detr_fast.py +1312 -0
  18. docs/transformers/build/lib/transformers/models/detr/modeling_detr.py +1815 -0
  19. docs/transformers/build/lib/transformers/models/dialogpt/__init__.py +0 -0
  20. docs/transformers/build/lib/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py +46 -0
  21. docs/transformers/build/lib/transformers/models/diffllama/__init__.py +27 -0
  22. docs/transformers/build/lib/transformers/models/diffllama/configuration_diffllama.py +199 -0
  23. docs/transformers/build/lib/transformers/models/esm/openfold_utils/rigid_utils.py +1242 -0
  24. docs/transformers/build/lib/transformers/models/falcon/configuration_falcon.py +211 -0
  25. docs/transformers/build/lib/transformers/models/falcon/convert_custom_code_checkpoint.py +74 -0
  26. docs/transformers/build/lib/transformers/models/falcon/modeling_falcon.py +1566 -0
  27. docs/transformers/build/lib/transformers/models/falcon_mamba/__init__.py +27 -0
  28. docs/transformers/build/lib/transformers/models/falcon_mamba/configuration_falcon_mamba.py +162 -0
  29. docs/transformers/build/lib/transformers/models/falcon_mamba/modeling_falcon_mamba.py +873 -0
  30. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/__init__.py +28 -0
  31. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +480 -0
  32. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch.py +210 -0
  33. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_hifigan.py +134 -0
  34. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_model_with_hifigan.py +102 -0
  35. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +1697 -0
  36. docs/transformers/build/lib/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +188 -0
  37. docs/transformers/build/lib/transformers/models/flaubert/__init__.py +29 -0
  38. docs/transformers/build/lib/transformers/models/flaubert/configuration_flaubert.py +235 -0
  39. docs/transformers/build/lib/transformers/models/flaubert/modeling_flaubert.py +1739 -0
  40. docs/transformers/build/lib/transformers/models/flaubert/modeling_tf_flaubert.py +1344 -0
  41. docs/transformers/build/lib/transformers/models/flaubert/tokenization_flaubert.py +568 -0
  42. docs/transformers/build/lib/transformers/models/flava/__init__.py +31 -0
  43. docs/transformers/build/lib/transformers/models/flava/configuration_flava.py +701 -0
  44. docs/transformers/build/lib/transformers/models/flava/convert_dalle_to_flava_codebook.py +102 -0
  45. docs/transformers/build/lib/transformers/models/flava/convert_flava_original_pytorch_to_hf.py +99 -0
  46. docs/transformers/build/lib/transformers/models/flava/feature_extraction_flava.py +38 -0
  47. docs/transformers/build/lib/transformers/models/flava/image_processing_flava.py +705 -0
  48. docs/transformers/build/lib/transformers/models/flava/image_processing_flava_fast.py +549 -0
  49. docs/transformers/build/lib/transformers/models/flava/modeling_flava.py +2127 -0
  50. docs/transformers/build/lib/transformers/models/flava/processing_flava.py +168 -0
.gitattributes CHANGED
@@ -48,3 +48,6 @@ wandb/offline-run-20250624_115955-iye05c18/run-iye05c18.wandb filter=lfs diff=lf
  wandb/offline-run-20250721_000454-up3efnok/run-up3efnok.wandb filter=lfs diff=lfs merge=lfs -text
  wandb/offline-run-20250722_003110-femxkckf/run-femxkckf.wandb filter=lfs diff=lfs merge=lfs -text
  seamless_interaction/assets/banner.gif filter=lfs diff=lfs merge=lfs -text
+ docs/resources/grpo_countdown.png filter=lfs diff=lfs merge=lfs -text
+ docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
+ docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
docs/resources/grpo_countdown.png ADDED

Git LFS Details

  • SHA256: 1b55fe6864e0c92549940d6989d92b3ab22be38a035cff3694525252737fc91e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
docs/resources/grpo_geoqa.png ADDED

Git LFS Details

  • SHA256: 71246376b16f2ff288542dca2ff31532b16ef99f5e862797463d548e447e1f8d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.24 MB
docs/resources/grpo_openr1_multimodal.png ADDED

Git LFS Details

  • SHA256: 050f56792468a4c9797a90314e322c16dd916bde3be24a7ce7c7b96381e70d9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.3 MB
docs/transformers/build/lib/transformers/models/depth_anything/convert_distill_any_depth_to_hf.py ADDED
@@ -0,0 +1,246 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Distill Any Depth checkpoints from the original repository. URL:
16
+ https://github.com/Westlake-AGI-Lab/Distill-Any-Depth"""
17
+
18
+ import argparse
19
+ import re
20
+ from pathlib import Path
21
+
22
+ import requests
23
+ import torch
24
+ from huggingface_hub import hf_hub_download
25
+ from PIL import Image
26
+ from safetensors.torch import load_file
27
+
28
+ from transformers import DepthAnythingConfig, DepthAnythingForDepthEstimation, Dinov2Config, DPTImageProcessor
29
+ from transformers.utils import logging
30
+
31
+
32
+ logging.set_verbosity_info()
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
37
+ r"(backbone|pretrained)\.cls_token": r"backbone.embeddings.cls_token",
38
+ r"(backbone|pretrained)\.mask_token": r"backbone.embeddings.mask_token",
39
+ r"(backbone|pretrained)\.pos_embed": r"backbone.embeddings.position_embeddings",
40
+ r"(backbone|pretrained)\.patch_embed\.proj\.(weight|bias)": r"backbone.embeddings.patch_embeddings.projection.\2",
41
+ r"(backbone|pretrained)\.norm\.(weight|bias)": r"backbone.layernorm.\2",
42
+ r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.attn\.proj\.(weight|bias)": r"backbone.encoder.layer.\4.attention.output.dense.\5",
43
+ r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.ls(1|2)\.gamma": r"backbone.encoder.layer.\4.layer_scale\5.lambda1",
44
+ r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.mlp\.fc(1|2)\.(weight|bias)": r"backbone.encoder.layer.\4.mlp.fc\5.\6",
45
+ r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.norm(1|2)\.(weight|bias)": r"backbone.encoder.layer.\4.norm\5.\6",
46
+ r"depth_head\.projects\.(\d+)\.(weight|bias)": r"neck.reassemble_stage.layers.\1.projection.\2",
47
+ r"depth_head\.resize_layers\.(?!2)(\d+)\.(weight|bias)": r"neck.reassemble_stage.layers.\1.resize.\2",
48
+ r"depth_head\.scratch\.layer(\d+)_rn\.weight": lambda m: f"neck.convs.{int(m[1]) - 1}.weight",
49
+ r"depth_head\.scratch\.output_conv(\d+)(?:\.(\d+))?\.(weight|bias)": lambda m: (
50
+ f"head.conv{int(m[1]) + (int(m[2]) // 2 if m[2] else 0)}.{m[3]}" if m[1] == "2" else f"head.conv{m[1]}.{m[3]}"
51
+ ),
52
+ r"depth_head\.scratch\.refinenet(\d+)\.out_conv\.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{3 - (int(m[1]) - 1)}.projection.{m[2]}",
53
+ r"depth_head\.scratch\.refinenet(\d+)\.resConfUnit(\d+)\.conv(\d+)\.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{3 - (int(m[1]) - 1)}.residual_layer{m[2]}.convolution{m[3]}.{m[4]}",
54
+ }
55
+
56
+
57
+ def get_dpt_config(model_name):
58
+ if "small" in model_name:
59
+ out_indices = [3, 6, 9, 12]
60
+ backbone_config = Dinov2Config.from_pretrained(
61
+ "facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
62
+ )
63
+ fusion_hidden_size = 64
64
+ neck_hidden_sizes = [48, 96, 192, 384]
65
+ elif "base" in model_name:
66
+ out_indices = [3, 6, 9, 12]
67
+ backbone_config = Dinov2Config.from_pretrained(
68
+ "facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
69
+ )
70
+ fusion_hidden_size = 128
71
+ neck_hidden_sizes = [96, 192, 384, 768]
72
+ elif "large" in model_name:
73
+ out_indices = [5, 12, 18, 24]
74
+ backbone_config = Dinov2Config.from_pretrained(
75
+ "facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
76
+ )
77
+ fusion_hidden_size = 256
78
+ neck_hidden_sizes = [256, 512, 1024, 1024]
79
+ else:
80
+ raise NotImplementedError(f"Model not supported: {model_name}")
81
+
82
+ depth_estimation_type = "relative"
83
+ max_depth = None
84
+
85
+ config = DepthAnythingConfig(
86
+ reassemble_hidden_size=backbone_config.hidden_size,
87
+ patch_size=backbone_config.patch_size,
88
+ backbone_config=backbone_config,
89
+ fusion_hidden_size=fusion_hidden_size,
90
+ neck_hidden_sizes=neck_hidden_sizes,
91
+ depth_estimation_type=depth_estimation_type,
92
+ max_depth=max_depth,
93
+ )
94
+
95
+ return config
96
+
97
+
98
+ def convert_key_pattern(key, mapping):
99
+ for pattern, replacement in mapping.items():
100
+ match = re.fullmatch(pattern, key)
101
+ if match:
102
+ if callable(replacement):
103
+ return replacement(match)
104
+ return re.sub(pattern, replacement, key)
105
+ return None
106
+
107
+
108
+ def convert_keys(state_dict, config):
109
+ new_state_dict = {}
110
+ qkv_pattern = r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.attn\.qkv\.(weight|bias)"
111
+ qkv_keys = [k for k in list(state_dict.keys()) if re.match(qkv_pattern, k)]
112
+ for old_key in qkv_keys:
113
+ value = state_dict.pop(old_key)
114
+ match = re.match(qkv_pattern, old_key)
115
+ _, _, _, layer, attr = match.groups()
116
+ hidden_size = config.backbone_config.hidden_size
117
+ q = value[:hidden_size]
118
+ k = value[hidden_size : hidden_size * 2]
119
+ v = value[-hidden_size:]
120
+
121
+ for proj, tensor in zip(["query", "key", "value"], [q, k, v]):
122
+ new_key = f"backbone.encoder.layer.{layer}.attention.attention.{proj}.{attr}"
123
+ new_state_dict[new_key] = tensor
124
+
125
+ for old_key in list(state_dict.keys()):
126
+ value = state_dict.pop(old_key)
127
+ new_key = convert_key_pattern(old_key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)
128
+
129
+ new_state_dict[new_key] = value
130
+
131
+ return new_state_dict
132
+
133
+
134
+ def prepare_img():
135
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
136
+ return Image.open(requests.get(url, stream=True).raw)
137
+
138
+
139
+ name_to_checkpoint = {
140
+ "distill-any-depth-small": "small/model.safetensors",
141
+ "distill-any-depth-base": "base/model.safetensors",
142
+ "distill-any-depth-large": "large/model.safetensors",
143
+ }
144
+
145
+
146
+ @torch.no_grad()
147
+ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits):
148
+ config = get_dpt_config(model_name)
149
+
150
+ repo_id = "xingyang1/Distill-Any-Depth"
151
+ filepath = hf_hub_download(repo_id=repo_id, filename=name_to_checkpoint[model_name])
152
+ state_dict = load_file(filepath)
153
+
154
+ converted_state_dict = convert_keys(state_dict, config)
155
+
156
+ model = DepthAnythingForDepthEstimation(config)
157
+ model.load_state_dict(converted_state_dict)
158
+ model.eval()
159
+
160
+ processor = DPTImageProcessor(
161
+ do_resize=True,
162
+ size={"height": 518, "width": 518},
163
+ ensure_multiple_of=14,
164
+ keep_aspect_ratio=True,
165
+ do_rescale=True,
166
+ do_normalize=True,
167
+ image_mean=[0.485, 0.456, 0.406],
168
+ image_std=[0.229, 0.224, 0.225],
169
+ )
170
+
171
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
172
+ image = Image.open(requests.get(url, stream=True).raw)
173
+
174
+ pixel_values = processor(image, return_tensors="pt").pixel_values
175
+
176
+ with torch.no_grad():
177
+ outputs = model(pixel_values)
178
+ predicted_depth = outputs.predicted_depth
179
+
180
+ print("Shape of predicted depth:", predicted_depth.shape)
181
+ print("First values:", predicted_depth[0, :3, :3])
182
+
183
+ if verify_logits:
184
+ print("Verifying logits...")
185
+ expected_shape = torch.Size([1, 518, 686])
186
+
187
+ if model_name == "distill-any-depth-small":
188
+ expected_slice = torch.tensor(
189
+ [[2.5653, 2.5249, 2.5570], [2.4897, 2.5235, 2.5355], [2.5255, 2.5261, 2.5422]]
190
+ )
191
+ elif model_name == "distill-any-depth-base":
192
+ expected_slice = torch.tensor(
193
+ [[4.8976, 4.9075, 4.9403], [4.8872, 4.8906, 4.9448], [4.8712, 4.8898, 4.8838]]
194
+ )
195
+ elif model_name == "distill-any-depth-large":
196
+ expected_slice = torch.tensor(
197
+ [[55.1067, 51.1828, 51.6803], [51.9098, 50.7529, 51.4494], [50.1745, 50.5491, 50.8818]]
198
+ )
199
+ else:
200
+ raise ValueError("Not supported")
201
+
202
+ assert predicted_depth.shape == torch.Size(expected_shape)
203
+ assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4)
204
+ print("Looks ok!")
205
+
206
+ if pytorch_dump_folder_path is not None:
207
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
208
+ print(f"Saving model and processor to {pytorch_dump_folder_path}")
209
+ model.save_pretrained(pytorch_dump_folder_path)
210
+ processor.save_pretrained(pytorch_dump_folder_path)
211
+
212
+ if push_to_hub:
213
+ print("Pushing model and processor to hub...")
214
+ model.push_to_hub(repo_id=f"{model_name.title()}-hf")
215
+ processor.push_to_hub(repo_id=f"{model_name.title()}-hf")
216
+
217
+
218
+ if __name__ == "__main__":
219
+ parser = argparse.ArgumentParser()
220
+ parser.add_argument(
221
+ "--model_name",
222
+ default="distill-any-depth-small",
223
+ type=str,
224
+ choices=name_to_checkpoint.keys(),
225
+ help="Name of the model you'd like to convert.",
226
+ )
227
+ parser.add_argument(
228
+ "--pytorch_dump_folder_path",
229
+ default=None,
230
+ type=str,
231
+ help="Path to the output PyTorch model directory.",
232
+ )
233
+ parser.add_argument(
234
+ "--push_to_hub",
235
+ action="store_true",
236
+ help="Whether to push the model to the hub after conversion.",
237
+ )
238
+ parser.add_argument(
239
+ "--verify_logits",
240
+ action="store_true",
241
+ required=False,
242
+ help="Whether to verify the logits after conversion.",
243
+ )
244
+
245
+ args = parser.parse_args()
246
+ convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits)
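The conversion above renames every original checkpoint key by full-matching it against the regex patterns in `ORIGINAL_TO_CONVERTED_KEY_MAPPING`, where a replacement is either a backreference string or a callable applied to the match. A minimal, self-contained sketch of that renaming step (the two patterns are copied from the mapping above; the helper name and the sample keys are illustrative only, not part of this file):

```python
import re

# Small subset of the mapping above; values may be backreference strings or callables.
MAPPING = {
    r"(backbone|pretrained)\.cls_token": r"backbone.embeddings.cls_token",
    r"depth_head\.scratch\.layer(\d+)_rn\.weight": lambda m: f"neck.convs.{int(m[1]) - 1}.weight",
}

def convert_key(key, mapping):
    # Return the converted key, or None if no pattern matches (mirrors convert_key_pattern).
    for pattern, replacement in mapping.items():
        match = re.fullmatch(pattern, key)
        if match:
            return replacement(match) if callable(replacement) else re.sub(pattern, replacement, key)
    return None

print(convert_key("pretrained.cls_token", MAPPING))                 # backbone.embeddings.cls_token
print(convert_key("depth_head.scratch.layer3_rn.weight", MAPPING))  # neck.convs.2.weight
```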
docs/transformers/build/lib/transformers/models/depth_anything/modeling_depth_anything.py ADDED
@@ -0,0 +1,469 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 TikTok and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Depth Anything model."""
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+
23
+ from ...file_utils import (
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ replace_return_docstrings,
27
+ )
28
+ from ...modeling_outputs import DepthEstimatorOutput
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import logging
31
+ from ...utils.backbone_utils import load_backbone
32
+ from .configuration_depth_anything import DepthAnythingConfig
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ # General docstring
38
+ _CONFIG_FOR_DOC = "DepthAnythingConfig"
39
+
40
+
41
+ DEPTH_ANYTHING_START_DOCSTRING = r"""
42
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
43
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
44
+ behavior.
45
+
46
+ Parameters:
47
+ config ([`DepthAnythingConfig`]): Model configuration class with all the parameters of the model.
48
+ Initializing with a config file does not load the weights associated with the model, only the
49
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
50
+ """
51
+
52
+ DEPTH_ANYTHING_INPUTS_DOCSTRING = r"""
53
+ Args:
54
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
55
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
56
+ for details.
57
+ output_attentions (`bool`, *optional*):
58
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
59
+ tensors for more detail.
60
+ output_hidden_states (`bool`, *optional*):
61
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
62
+ more detail.
63
+ return_dict (`bool`, *optional*):
64
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
65
+ """
66
+
67
+
68
+ class DepthAnythingReassembleLayer(nn.Module):
69
+ def __init__(self, config, channels, factor):
70
+ super().__init__()
71
+ self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)
72
+
73
+ # up/down sampling depending on factor
74
+ if factor > 1:
75
+ self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
76
+ elif factor == 1:
77
+ self.resize = nn.Identity()
78
+ elif factor < 1:
79
+ # so should downsample
80
+ self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
81
+
82
+ # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward
83
+ def forward(self, hidden_state):
84
+ hidden_state = self.projection(hidden_state)
85
+ hidden_state = self.resize(hidden_state)
86
+
87
+ return hidden_state
88
+
89
+
90
+ class DepthAnythingReassembleStage(nn.Module):
91
+ """
92
+ This class reassembles the hidden states of the backbone into image-like feature representations at various
93
+ resolutions.
94
+
95
+ This happens in 3 stages:
96
+ 1. Take the patch embeddings and reshape them to image-like feature representations.
97
+ 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
98
+ 3. Resizing the spatial dimensions (height, width).
99
+
100
+ Args:
101
+ config (`[DepthAnythingConfig]`):
102
+ Model configuration class defining the model architecture.
103
+ """
104
+
105
+ def __init__(self, config):
106
+ super().__init__()
107
+
108
+ self.config = config
109
+ self.layers = nn.ModuleList()
110
+ for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors):
111
+ self.layers.append(DepthAnythingReassembleLayer(config, channels=channels, factor=factor))
112
+
113
+ def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
114
+ """
115
+ Args:
116
+ hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
117
+ List of hidden states from the backbone.
118
+ """
119
+ out = []
120
+
121
+ for i, hidden_state in enumerate(hidden_states):
122
+ # reshape to (batch_size, num_channels, height, width)
123
+ hidden_state = hidden_state[:, 1:]
124
+ batch_size, _, num_channels = hidden_state.shape
125
+ hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
126
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
127
+ hidden_state = self.layers[i](hidden_state)
128
+ out.append(hidden_state)
129
+
130
+ return out
131
+
132
+
133
+ class DepthAnythingPreActResidualLayer(nn.Module):
134
+ """
135
+ ResidualConvUnit, pre-activate residual unit.
136
+
137
+ Args:
138
+ config (`[DepthAnythingConfig]`):
139
+ Model configuration class defining the model architecture.
140
+ """
141
+
142
+ def __init__(self, config):
143
+ super().__init__()
144
+
145
+ self.activation1 = nn.ReLU()
146
+ self.convolution1 = nn.Conv2d(
147
+ config.fusion_hidden_size,
148
+ config.fusion_hidden_size,
149
+ kernel_size=3,
150
+ stride=1,
151
+ padding=1,
152
+ bias=True,
153
+ )
154
+
155
+ self.activation2 = nn.ReLU()
156
+ self.convolution2 = nn.Conv2d(
157
+ config.fusion_hidden_size,
158
+ config.fusion_hidden_size,
159
+ kernel_size=3,
160
+ stride=1,
161
+ padding=1,
162
+ bias=True,
163
+ )
164
+
165
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
166
+ residual = hidden_state
167
+ hidden_state = self.activation1(hidden_state)
168
+ hidden_state = self.convolution1(hidden_state)
169
+ hidden_state = self.activation2(hidden_state)
170
+ hidden_state = self.convolution2(hidden_state)
171
+
172
+ return hidden_state + residual
173
+
174
+
175
+ class DepthAnythingFeatureFusionLayer(nn.Module):
176
+ """Feature fusion layer, merges feature maps from different stages.
177
+
178
+ Args:
179
+ config (`[DepthAnythingConfig]`):
180
+ Model configuration class defining the model architecture.
181
+ """
182
+
183
+ def __init__(self, config):
184
+ super().__init__()
185
+
186
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
187
+
188
+ self.residual_layer1 = DepthAnythingPreActResidualLayer(config)
189
+ self.residual_layer2 = DepthAnythingPreActResidualLayer(config)
190
+
191
+ def forward(self, hidden_state, residual=None, size=None):
192
+ if residual is not None:
193
+ if hidden_state.shape != residual.shape:
194
+ residual = nn.functional.interpolate(
195
+ residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
196
+ )
197
+ hidden_state = hidden_state + self.residual_layer1(residual)
198
+
199
+ hidden_state = self.residual_layer2(hidden_state)
200
+
201
+ modifier = {"scale_factor": 2} if size is None else {"size": size}
202
+
203
+ hidden_state = nn.functional.interpolate(
204
+ hidden_state,
205
+ **modifier,
206
+ mode="bilinear",
207
+ align_corners=True,
208
+ )
209
+ hidden_state = self.projection(hidden_state)
210
+
211
+ return hidden_state
212
+
213
+
214
+ class DepthAnythingFeatureFusionStage(nn.Module):
215
+ # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage.__init__ with DPT->DepthAnything
216
+ def __init__(self, config):
217
+ super().__init__()
218
+ self.layers = nn.ModuleList()
219
+ for _ in range(len(config.neck_hidden_sizes)):
220
+ self.layers.append(DepthAnythingFeatureFusionLayer(config))
221
+
222
+ def forward(self, hidden_states, size=None):
223
+ # reversing the hidden_states, we start from the last
224
+ hidden_states = hidden_states[::-1]
225
+
226
+ fused_hidden_states = []
227
+ fused_hidden_state = None
228
+
229
+ for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
230
+ size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
231
+
232
+ if fused_hidden_state is None:
233
+ # first layer only uses the last hidden_state
234
+ fused_hidden_state = layer(hidden_state, size=size)
235
+ else:
236
+ fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
237
+
238
+ fused_hidden_states.append(fused_hidden_state)
239
+
240
+ return fused_hidden_states
241
+
242
+
243
+ # Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
244
+ # avoiding sdpa and flash_attn_2 support, it's done in the backend
245
+ class DepthAnythingPreTrainedModel(PreTrainedModel):
246
+ """
247
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
248
+ models.
249
+ """
250
+
251
+ config_class = DepthAnythingConfig
252
+ base_model_prefix = "depth_anything"
253
+ main_input_name = "pixel_values"
254
+ supports_gradient_checkpointing = True
255
+
256
+ def _init_weights(self, module):
257
+ """Initialize the weights"""
258
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
259
+ # Slightly different from the TF version which uses truncated_normal for initialization
260
+ # cf https://github.com/pytorch/pytorch/pull/5617
261
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
262
+ if module.bias is not None:
263
+ module.bias.data.zero_()
264
+ elif isinstance(module, nn.LayerNorm):
265
+ module.bias.data.zero_()
266
+ module.weight.data.fill_(1.0)
267
+
268
+
269
+ class DepthAnythingNeck(nn.Module):
270
+ """
271
+ DepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
272
+ input and produces another list of tensors as output. For DepthAnything, it includes 2 stages:
273
+
274
+ * DepthAnythingReassembleStage
275
+ * DepthAnythingFeatureFusionStage.
276
+
277
+ Args:
278
+ config (dict): config dict.
279
+ """
280
+
281
+ def __init__(self, config):
282
+ super().__init__()
283
+ self.config = config
284
+
285
+ self.reassemble_stage = DepthAnythingReassembleStage(config)
286
+
287
+ self.convs = nn.ModuleList()
288
+ for channel in config.neck_hidden_sizes:
289
+ self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
290
+
291
+ # fusion
292
+ self.fusion_stage = DepthAnythingFeatureFusionStage(config)
293
+
294
+ def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
295
+ """
296
+ Args:
297
+ hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
298
+ List of hidden states from the backbone.
299
+ """
300
+ if not isinstance(hidden_states, (tuple, list)):
301
+ raise TypeError("hidden_states should be a tuple or list of tensors")
302
+
303
+ if len(hidden_states) != len(self.config.neck_hidden_sizes):
304
+ raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
305
+
306
+ # postprocess hidden states
307
+ hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
308
+
309
+ features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
310
+
311
+ # fusion blocks
312
+ output = self.fusion_stage(features)
313
+
314
+ return output
315
+
316
+
317
+ class DepthAnythingDepthEstimationHead(nn.Module):
318
+ """
319
+ Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
320
+ the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's
321
+ supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation
322
+ type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining.
323
+ """
324
+
325
+ def __init__(self, config):
326
+ super().__init__()
327
+
328
+ self.head_in_index = config.head_in_index
329
+ self.patch_size = config.patch_size
330
+
331
+ features = config.fusion_hidden_size
332
+ self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
333
+ self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1)
334
+ self.activation1 = nn.ReLU()
335
+ self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0)
336
+ if config.depth_estimation_type == "relative":
337
+ self.activation2 = nn.ReLU()
338
+ elif config.depth_estimation_type == "metric":
339
+ self.activation2 = nn.Sigmoid()
340
+ else:
341
+ raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}")
342
+ self.max_depth = config.max_depth
343
+
344
+ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor:
345
+ hidden_states = hidden_states[self.head_in_index]
346
+
347
+ predicted_depth = self.conv1(hidden_states)
348
+ predicted_depth = nn.functional.interpolate(
349
+ predicted_depth,
350
+ (int(patch_height * self.patch_size), int(patch_width * self.patch_size)),
351
+ mode="bilinear",
352
+ align_corners=True,
353
+ )
354
+ predicted_depth = self.conv2(predicted_depth)
355
+ predicted_depth = self.activation1(predicted_depth)
356
+ predicted_depth = self.conv3(predicted_depth)
357
+ predicted_depth = self.activation2(predicted_depth) * self.max_depth
358
+ predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width)
359
+
360
+ return predicted_depth
361
+
362
+
363
+ @add_start_docstrings(
364
+ """
365
+ Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
366
+ """,
367
+ DEPTH_ANYTHING_START_DOCSTRING,
368
+ )
369
+ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
370
+ _no_split_modules = ["DPTViTEmbeddings"]
371
+
372
+ def __init__(self, config):
373
+ super().__init__(config)
374
+
375
+ self.backbone = load_backbone(config)
376
+ self.neck = DepthAnythingNeck(config)
377
+ self.head = DepthAnythingDepthEstimationHead(config)
378
+
379
+ # Initialize weights and apply final processing
380
+ self.post_init()
381
+
382
+ @add_start_docstrings_to_model_forward(DEPTH_ANYTHING_INPUTS_DOCSTRING)
383
+ @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
384
+ def forward(
385
+ self,
386
+ pixel_values: torch.FloatTensor,
387
+ labels: Optional[torch.LongTensor] = None,
388
+ output_attentions: Optional[bool] = None,
389
+ output_hidden_states: Optional[bool] = None,
390
+ return_dict: Optional[bool] = None,
391
+ ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
392
+ r"""
393
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
394
+ Ground truth depth estimation maps for computing the loss.
395
+
396
+ Returns:
397
+
398
+ Examples:
399
+ ```python
400
+ >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
401
+ >>> import torch
402
+ >>> import numpy as np
403
+ >>> from PIL import Image
404
+ >>> import requests
405
+
406
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
407
+ >>> image = Image.open(requests.get(url, stream=True).raw)
408
+
409
+ >>> image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
410
+ >>> model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
411
+
412
+ >>> # prepare image for the model
413
+ >>> inputs = image_processor(images=image, return_tensors="pt")
414
+
415
+ >>> with torch.no_grad():
416
+ ... outputs = model(**inputs)
417
+
418
+ >>> # interpolate to original size
419
+ >>> post_processed_output = image_processor.post_process_depth_estimation(
420
+ ... outputs,
421
+ ... target_sizes=[(image.height, image.width)],
422
+ ... )
423
+
424
+ >>> # visualize the prediction
425
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
426
+ >>> depth = predicted_depth * 255 / predicted_depth.max()
427
+ >>> depth = depth.detach().cpu().numpy()
428
+ >>> depth = Image.fromarray(depth.astype("uint8"))
429
+ ```"""
430
+ loss = None
431
+ if labels is not None:
432
+ raise NotImplementedError("Training is not implemented yet")
433
+
434
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
435
+ output_hidden_states = (
436
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
437
+ )
438
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
439
+
440
+ outputs = self.backbone.forward_with_filtered_kwargs(
441
+ pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
442
+ )
443
+ hidden_states = outputs.feature_maps
444
+
445
+ _, _, height, width = pixel_values.shape
446
+ patch_size = self.config.patch_size
447
+ patch_height = height // patch_size
448
+ patch_width = width // patch_size
449
+
450
+ hidden_states = self.neck(hidden_states, patch_height, patch_width)
451
+
452
+ predicted_depth = self.head(hidden_states, patch_height, patch_width)
453
+
454
+ if not return_dict:
455
+ if output_hidden_states:
456
+ output = (predicted_depth,) + outputs[1:]
457
+ else:
458
+ output = (predicted_depth,) + outputs[2:]
459
+ return ((loss,) + output) if loss is not None else output
460
+
461
+ return DepthEstimatorOutput(
462
+ loss=loss,
463
+ predicted_depth=predicted_depth,
464
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
465
+ attentions=outputs.attentions,
466
+ )
467
+
468
+
469
+ __all__ = ["DepthAnythingForDepthEstimation", "DepthAnythingPreTrainedModel"]
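`DepthAnythingReassembleStage.forward` turns the backbone's patch tokens back into image-like feature maps: it drops the CLS token, reshapes the sequence to the patch grid, and moves channels first. A toy sketch of that reshape (the sizes are made up for illustration; only the tensor manipulation mirrors the code above):

```python
import torch

batch_size, patch_height, patch_width, hidden_size = 2, 37, 37, 384  # toy values
tokens = torch.randn(batch_size, patch_height * patch_width + 1, hidden_size)  # +1 for the CLS token

tokens = tokens[:, 1:]  # drop the CLS token
feature_map = (
    tokens.reshape(batch_size, patch_height, patch_width, hidden_size)
    .permute(0, 3, 1, 2)  # (batch, channels, height, width)
    .contiguous()
)
print(feature_map.shape)  # torch.Size([2, 384, 37, 37])
```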
docs/transformers/build/lib/transformers/models/depth_pro/configuration_depth_pro.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DepthPro model configuration"""
16
+
17
+ from copy import deepcopy
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+ from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class DepthProConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`DepthProModel`]. It is used to instantiate a
30
+ DepthPro model according to the specified arguments, defining the model architecture. Instantiating a configuration
31
+ with the defaults will yield a similar configuration to that of the DepthPro
32
+ [apple/DepthPro](https://huggingface.co/apple/DepthPro) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+ Args:
38
+ fusion_hidden_size (`int`, *optional*, defaults to 256):
39
+ The number of channels before fusion.
40
+ patch_size (`int`, *optional*, defaults to 384):
41
+ The size (resolution) of each patch. This is also the image_size for backbone model.
42
+ initializer_range (`float`, *optional*, defaults to 0.02):
43
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
44
+ intermediate_hook_ids (`List[int]`, *optional*, defaults to `[11, 5]`):
45
+ Indices of the intermediate hidden states from the patch encoder to use for fusion.
46
+ intermediate_feature_dims (`List[int]`, *optional*, defaults to `[256, 256]`):
47
+ Hidden state dimensions during upsampling for each intermediate hidden state in `intermediate_hook_ids`.
48
+ scaled_images_ratios (`List[float]`, *optional*, defaults to `[0.25, 0.5, 1]`):
49
+ Ratios of scaled images to be used by the patch encoder.
50
+ scaled_images_overlap_ratios (`List[float]`, *optional*, defaults to `[0.0, 0.5, 0.25]`):
51
+ Overlap ratios between patches for each scaled image in `scaled_images_ratios`.
52
+ scaled_images_feature_dims (`List[int]`, *optional*, defaults to `[1024, 1024, 512]`):
53
+ Hidden state dimensions during upsampling for each scaled image in `scaled_images_ratios`.
54
+ merge_padding_value (`int`, *optional*, defaults to 3):
55
+ When merging smaller patches back to the image size, overlapping sections of this size are removed.
56
+ use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
57
+ Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
58
+ use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
59
+ Whether to use bias in the pre-activate residual units of the fusion blocks.
60
+ use_fov_model (`bool`, *optional*, defaults to `False`):
61
+ Whether to use `DepthProFovModel` to generate the field of view.
62
+ num_fov_head_layers (`int`, *optional*, defaults to 2):
63
+ Number of convolution layers in the head of `DepthProFovModel`.
64
+ image_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
65
+ The configuration of the image encoder model, which is loaded using the [`AutoModel`] API.
66
+ By default, Dinov2 model is used as backbone.
67
+ patch_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
68
+ The configuration of the patch encoder model, which is loaded using the [`AutoModel`] API.
69
+ By default, Dinov2 model is used as backbone.
70
+ fov_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
71
+ The configuration of the fov encoder model, which is loaded using the [`AutoModel`] API.
72
+ By default, Dinov2 model is used as backbone.
73
+
74
+ Example:
75
+
76
+ ```python
77
+ >>> from transformers import DepthProConfig, DepthProModel
78
+
79
+ >>> # Initializing a DepthPro apple/DepthPro style configuration
80
+ >>> configuration = DepthProConfig()
81
+
82
+ >>> # Initializing a model (with random weights) from the apple/DepthPro style configuration
83
+ >>> model = DepthProModel(configuration)
84
+
85
+ >>> # Accessing the model configuration
86
+ >>> configuration = model.config
87
+ ```"""
88
+
89
+ model_type = "depth_pro"
90
+ sub_configs = {"image_model_config": AutoConfig, "patch_model_config": AutoConfig, "fov_model_config": AutoConfig}
91
+
92
+ def __init__(
93
+ self,
94
+ fusion_hidden_size=256,
95
+ patch_size=384,
96
+ initializer_range=0.02,
97
+ intermediate_hook_ids=[11, 5],
98
+ intermediate_feature_dims=[256, 256],
99
+ scaled_images_ratios=[0.25, 0.5, 1],
100
+ scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
101
+ scaled_images_feature_dims=[1024, 1024, 512],
102
+ merge_padding_value=3,
103
+ use_batch_norm_in_fusion_residual=False,
104
+ use_bias_in_fusion_residual=True,
105
+ use_fov_model=False,
106
+ num_fov_head_layers=2,
107
+ image_model_config=None,
108
+ patch_model_config=None,
109
+ fov_model_config=None,
110
+ **kwargs,
111
+ ):
112
+ super().__init__(**kwargs)
113
+
114
+ # scaled_images_ratios is sorted
115
+ if scaled_images_ratios != sorted(scaled_images_ratios):
116
+ raise ValueError(
117
+ f"Values in scaled_images_ratios={scaled_images_ratios} should be sorted from low to high"
118
+ )
119
+
120
+ # scaled_images_ratios, scaled_images_overlap_ratios, scaled_images_feature_dims should be consistent
121
+ if not (len(scaled_images_ratios) == len(scaled_images_overlap_ratios) == len(scaled_images_feature_dims)):
122
+ raise ValueError(
123
+ f"len(scaled_images_ratios)={len(scaled_images_ratios)} and "
124
+ f"len(scaled_images_overlap_ratios)={len(scaled_images_overlap_ratios)} and "
125
+ f"len(scaled_images_feature_dims)={len(scaled_images_feature_dims)}, "
126
+ f"should match in config."
127
+ )
128
+
129
+ # intermediate_hook_ids, intermediate_feature_dims should be consistent
130
+ if not (len(intermediate_hook_ids) == len(intermediate_feature_dims)):
131
+ raise ValueError(
132
+ f"len(intermediate_hook_ids)={len(intermediate_hook_ids)} and "
133
+ f"len(intermediate_feature_dims)={len(intermediate_feature_dims)}, "
134
+ f"should match in config."
135
+ )
136
+
137
+ # fusion_hidden_size should be consistent with num_fov_head_layers
138
+ if fusion_hidden_size // 2**num_fov_head_layers == 0:
139
+ raise ValueError(
140
+ f"fusion_hidden_size={fusion_hidden_size} should be consistent with num_fov_head_layers={num_fov_head_layers} "
141
+ "i.e fusion_hidden_size // 2**num_fov_head_layers > 0"
142
+ )
143
+
144
+ self.fusion_hidden_size = fusion_hidden_size
145
+ self.patch_size = patch_size
146
+ self.initializer_range = initializer_range
147
+ self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
148
+ self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
149
+ self.use_fov_model = use_fov_model
150
+ self.num_fov_head_layers = num_fov_head_layers
151
+ self.intermediate_hook_ids = intermediate_hook_ids
152
+ self.intermediate_feature_dims = intermediate_feature_dims
153
+ self.scaled_images_ratios = scaled_images_ratios
154
+ self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
155
+ self.scaled_images_feature_dims = scaled_images_feature_dims
156
+ self.merge_padding_value = merge_padding_value
157
+ self.image_model_config = image_model_config
158
+ self.patch_model_config = patch_model_config
159
+ self.fov_model_config = fov_model_config
160
+
161
+ for sub_config_key in self.sub_configs.keys():
162
+ sub_config = getattr(self, sub_config_key)
163
+
164
+ if sub_config is None:
165
+ sub_config = CONFIG_MAPPING["dinov2"](image_size=patch_size)
166
+ logger.info(
167
+ f"`{sub_config_key}` is `None`. Initializing `{sub_config_key}` with the `Dinov2Config` "
168
+ f"with default values except `{sub_config_key}.image_size` is set to `config.patch_size`."
169
+ )
170
+ elif isinstance(sub_config, dict):
171
+ sub_config = deepcopy(sub_config)
172
+ if "model_type" not in sub_config:
173
+ raise KeyError(
174
+ f"The `model_type` key is missing in the `{sub_config_key}` dictionary. Please provide the model type."
175
+ )
176
+ elif sub_config["model_type"] not in CONFIG_MAPPING:
177
+ raise ValueError(
178
+ f"The model type `{sub_config['model_type']}` in `{sub_config_key}` is not supported. Please provide a valid model type."
179
+ )
180
+ image_size = sub_config.get("image_size")
181
+ if image_size != patch_size:
182
+ logger.info(
183
+ f"The `image_size` in `{sub_config_key}` is set to `{image_size}`, "
184
+ f"but it does not match the required `patch_size` of `{patch_size}`. "
185
+ f"Updating `image_size` to `{patch_size}` for consistency. "
186
+ f"Ensure that `image_size` aligns with `patch_size` in the configuration."
187
+ )
188
+ sub_config.update({"image_size": patch_size})
189
+ sub_config = CONFIG_MAPPING[sub_config["model_type"]](**sub_config)
190
+ elif isinstance(sub_config, PretrainedConfig):
191
+ sub_config = sub_config
192
+ image_size = getattr(sub_config, "image_size", None)
193
+ if image_size != patch_size:
194
+ raise ValueError(
195
+ f"`config.{sub_config_key}.image_size={image_size}` should match `config.patch_size={patch_size}`."
196
+ )
197
+ else:
198
+ raise TypeError(
199
+ f"Invalid type for `sub_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(sub_config)}."
200
+ )
201
+
202
+ setattr(self, sub_config_key, sub_config)
203
+
204
+
205
+ __all__ = ["DepthProConfig"]
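`DepthProConfig` accepts each sub-config as a `PretrainedConfig`, as a plain dict (which must include `model_type` and whose `image_size` is forced to match `patch_size`), or as `None` (which falls back to a default `Dinov2Config`). A short sketch of the dict path, assuming an installed `transformers` version that ships DepthPro:

```python
from transformers import DepthProConfig  # assumes a transformers release that includes DepthPro

# Plain dict sub-config: `model_type` is required, `image_size` should match `patch_size`.
backbone = {"model_type": "dinov2", "image_size": 384}
config = DepthProConfig(patch_size=384, image_model_config=backbone)

print(type(config.image_model_config).__name__)  # Dinov2Config
print(config.image_model_config.image_size)      # 384
```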
docs/transformers/build/lib/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import gc
17
+ import os
18
+
19
+ import regex as re
20
+ import torch
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ from transformers import (
24
+ DepthProConfig,
25
+ DepthProForDepthEstimation,
26
+ DepthProImageProcessorFast,
27
+ )
28
+
29
+
30
+ # fmt: off
31
+ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
32
+
33
+ # encoder
34
+ r"encoder.(patch|image)_encoder.cls_token": r"depth_pro.encoder.\1_encoder.model.embeddings.cls_token",
35
+ r"encoder.(patch|image)_encoder.pos_embed": r"depth_pro.encoder.\1_encoder.model.embeddings.position_embeddings",
36
+ r"encoder.(patch|image)_encoder.patch_embed.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.embeddings.patch_embeddings.projection.\2",
37
+ r"encoder.(patch|image)_encoder.blocks.(\d+).norm(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.norm\3.\4",
38
+ r"encoder.(patch|image)_encoder.blocks.(\d+).attn.qkv.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.attention.attention.(query|key|value).\3",
39
+ r"encoder.(patch|image)_encoder.blocks.(\d+).attn.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.attention.output.dense.\3",
40
+ r"encoder.(patch|image)_encoder.blocks.(\d+).ls(\d+).gamma": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.layer_scale\3.lambda1",
41
+ r"encoder.(patch|image)_encoder.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.mlp.fc\3.\4",
42
+ r"encoder.(patch|image)_encoder.norm.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.layernorm.\2",
43
+ r"encoder.fuse_lowres.(weight|bias)": r"depth_pro.neck.fuse_image_with_low_res.\1",
44
+
45
+ # fov
46
+ r"fov.encoder.0.cls_token": r"fov_model.fov_encoder.model.embeddings.cls_token",
47
+ r"fov.encoder.0.pos_embed": r"fov_model.fov_encoder.model.embeddings.position_embeddings",
48
+ r"fov.encoder.0.patch_embed.proj.(weight|bias)": r"fov_model.fov_encoder.model.embeddings.patch_embeddings.projection.\1",
49
+ r"fov.encoder.0.blocks.(\d+).norm(\d+).(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.norm\2.\3",
50
+ r"fov.encoder.0.blocks.(\d+).attn.qkv.(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.attention.attention.(query|key|value).\2",
51
+ r"fov.encoder.0.blocks.(\d+).attn.proj.(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.attention.output.dense.\2",
52
+ r"fov.encoder.0.blocks.(\d+).ls(\d+).gamma": r"fov_model.fov_encoder.model.encoder.layer.\1.layer_scale\2.lambda1",
53
+ r"fov.encoder.0.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.mlp.fc\2.\3",
54
+ r"fov.encoder.0.norm.(weight|bias)": r"fov_model.fov_encoder.model.layernorm.\1",
55
+ r"fov.downsample.0.(weight|bias)": r"fov_model.conv.\1",
56
+ r"fov.encoder.1.(weight|bias)": r"fov_model.fov_encoder.neck.\1",
57
+ r"fov.head.(\d+).(weight|bias)": r"fov_model.head.layers.\1.\2",
58
+
59
+ # head
60
+ r"head.(\d+).(weight|bias)": r"head.layers.\1.\2",
61
+
62
+ # upsamples
63
+ r"encoder.upsample_lowres.(weight|bias)": r"depth_pro.neck.feature_upsample.image_block.layers.0.\1",
64
+ r"encoder.upsample_latent(\d+).(\d+).(weight|bias)": lambda match: (
65
+ f"depth_pro.neck.feature_upsample.intermediate.{1-int(match.group(1))}.layers.{match.group(2)}.{match.group(3)}"
66
+ ),
67
+ r"encoder.upsample(\d+).(\d+).(weight|bias)": lambda match: (
68
+ f"depth_pro.neck.feature_upsample.scaled_images.{2-int(match.group(1))}.layers.{match.group(2)}.{match.group(3)}"
69
+ ),
70
+
71
+ # projections between encoder and fusion
72
+ r"decoder.convs.(\d+).weight": lambda match: (
73
+ f"depth_pro.neck.feature_projection.projections.{4-int(match.group(1))}.weight"
74
+ ),
75
+
76
+ # fusion stage
77
+ r"decoder.fusions.([1234]).resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
78
+ f"fusion_stage.intermediate.{4-int(match.group(1))}.residual_layer{match.group(2)}.convolution{(int(match.group(3))+1)//2}.{match.group(4)}"
79
+ ),
80
+ r"decoder.fusions.0.resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
81
+ f"fusion_stage.final.residual_layer{match.group(1)}.convolution{(int(match.group(2))+1)//2}.{match.group(3)}"
82
+ ),
83
+ r"decoder.fusions.([1234]).out_conv.(weight|bias)": lambda match: (
84
+ f"fusion_stage.intermediate.{4-int(match.group(1))}.projection.{match.group(2)}"
85
+ ),
86
+ r"decoder.fusions.0.out_conv.(weight|bias)": lambda match: (
87
+ f"fusion_stage.final.projection.{match.group(1)}"
88
+ ),
89
+ r"decoder.fusions.(\d+).deconv.(weight|bias)": lambda match: (
90
+ f"fusion_stage.intermediate.{4-int(match.group(1))}.deconv.{match.group(2)}"
91
+ ),
92
+ }
93
+ # fmt: on
94
+
95
+
96
+ def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
97
+ output_dict = {}
98
+ if state_dict_keys is not None:
99
+ old_text = "\n".join(state_dict_keys)
100
+ new_text = old_text
101
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
102
+ if replacement is None:
103
+ new_text = re.sub(pattern, "", new_text) # an empty line
104
+ continue
105
+ new_text = re.sub(pattern, replacement, new_text)
106
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
107
+ return output_dict
108
+
109
+
110
+ def get_qkv_state_dict(key, parameter):
111
+ """
112
+ new key which looks like this
113
+ xxxx.(q|k|v).xxx (m, n)
114
+
115
+ is converted to
116
+ xxxx.q.xxxx (m//3, n)
117
+ xxxx.k.xxxx (m//3, n)
118
+ xxxx.v.xxxx (m//3, n)
119
+ """
120
+ qkv_state_dict = {}
121
+ placeholder = re.search(r"(\(.*?\))", key).group(1) # finds "(query|key|value)"
122
+ replacements_keys = placeholder[1:-1].split("|") # creates ['query', 'key', 'value']
123
+ replacements_vals = torch.split(
124
+ parameter, split_size_or_sections=parameter.size(0) // len(replacements_keys), dim=0
125
+ )
126
+ for replacement_key, replacement_val in zip(replacements_keys, replacements_vals):
127
+ qkv_state_dict[key.replace(placeholder, replacement_key)] = replacement_val
128
+ return qkv_state_dict
129
+
130
+
131
+ def write_model(
132
+ hf_repo_id: str,
133
+ output_dir: str,
134
+ safe_serialization: bool = True,
135
+ ):
136
+ os.makedirs(output_dir, exist_ok=True)
137
+
138
+ # ------------------------------------------------------------
139
+ # Create and save config
140
+ # ------------------------------------------------------------
141
+
142
+ # create config
143
+ backbone_config = {
144
+ "model_type": "dinov2",
145
+ "num_hidden_layers": 24,
146
+ "patch_size": 16,
147
+ "hidden_size": 1024,
148
+ "num_attention_heads": 16,
149
+ "image_size": 384,
150
+ "use_mask_token": False,
151
+ }
152
+ config = DepthProConfig(
153
+ # original implementation uses same config for all 3 models
154
+ image_model_config=backbone_config,
155
+ patch_model_config=backbone_config,
156
+ fov_model_config=backbone_config,
157
+ use_fov_model=True,
158
+ )
159
+
160
+ # save config
161
+ config.save_pretrained(output_dir)
162
+ print("Model config saved successfully...")
163
+
164
+ # ------------------------------------------------------------
165
+ # Convert weights
166
+ # ------------------------------------------------------------
167
+
168
+ # download and load state_dict from hf repo
169
+ file_path = hf_hub_download(hf_repo_id, "depth_pro.pt")
170
+ loaded = torch.load(file_path, weights_only=True)
171
+
172
+ print("Converting model...")
173
+ all_keys = list(loaded.keys())
174
+ new_keys = convert_old_keys_to_new_keys(all_keys)
175
+
176
+ state_dict = {}
177
+ for key in all_keys:
178
+ new_key = new_keys[key]
179
+ current_parameter = loaded.pop(key)
180
+
181
+ if "qkv" in key:
182
+ qkv_state_dict = get_qkv_state_dict(new_key, current_parameter)
183
+ state_dict.update(qkv_state_dict)
184
+ else:
185
+ state_dict[new_key] = current_parameter
186
+
187
+ print("Loading the checkpoint in a DepthPro model.")
188
+ model = DepthProForDepthEstimation(config)
189
+ model.load_state_dict(state_dict, strict=True, assign=True)
190
+ print("Checkpoint loaded successfully.")
191
+
192
+ print("Saving the model.")
193
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
194
+ del state_dict, model
195
+
196
+ # Safety check: reload the converted model
197
+ gc.collect()
198
+ print("Reloading the model to check if it's saved correctly.")
199
+ model = DepthProForDepthEstimation.from_pretrained(output_dir, device_map="auto")
200
+ print("Model reloaded successfully.")
201
+ return model
202
+
203
+
204
+ def write_image_processor(output_dir: str):
205
+ image_processor = DepthProImageProcessorFast()
206
+ image_processor.save_pretrained(output_dir)
207
+ return image_processor
208
+
209
+
210
+ def main():
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument(
213
+ "--hf_repo_id",
214
+ default="apple/DepthPro",
215
+ help="Location of official weights from apple on HF",
216
+ )
217
+ parser.add_argument(
218
+ "--output_dir",
219
+ default="apple_DepthPro",
220
+ help="Location to write the converted model and processor",
221
+ )
222
+ parser.add_argument(
223
+ "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
224
+ )
225
+ parser.add_argument(
226
+ "--push_to_hub",
227
+ action=argparse.BooleanOptionalAction,
228
+ help="Whether or not to push the converted model to the huggingface hub.",
229
+ )
230
+ parser.add_argument(
231
+ "--hub_repo_id",
232
+ default="apple/DepthPro-hf",
233
+ help="Huggingface hub repo to write the converted model and processor",
234
+ )
235
+ args = parser.parse_args()
236
+
237
+ model = write_model(
238
+ hf_repo_id=args.hf_repo_id,
239
+ output_dir=args.output_dir,
240
+ safe_serialization=args.safe_serialization,
241
+ )
242
+
243
+ image_processor = write_image_processor(
244
+ output_dir=args.output_dir,
245
+ )
246
+
247
+ if args.push_to_hub:
248
+ print("Pushing to hub...")
249
+ model.push_to_hub(args.hub_repo_id)
250
+ image_processor.push_to_hub(args.hub_repo_id)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()
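For reference, a minimal sketch of driving the conversion above programmatically; the output directory is a placeholder and the CLI line in the comment simply mirrors the argparse flags defined in `main()`:

```python
# Assumes the functions defined in this script are in scope.
model = write_model(
    hf_repo_id="apple/DepthPro",
    output_dir="converted_depth_pro",  # placeholder path
    safe_serialization=True,
)
image_processor = write_image_processor(output_dir="converted_depth_pro")

# Roughly equivalent CLI invocation:
#   python convert_depth_pro_weights_to_hf.py --output_dir converted_depth_pro --push_to_hub
```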
docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro.py ADDED
@@ -0,0 +1,392 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for DepthPro."""
16
+
17
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...utils.import_utils import requires
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from .modeling_depth_pro import DepthProDepthEstimatorOutput
26
+
27
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
28
+ from ...image_transforms import to_channel_dimension_format
29
+ from ...image_utils import (
30
+ IMAGENET_STANDARD_MEAN,
31
+ IMAGENET_STANDARD_STD,
32
+ ChannelDimension,
33
+ ImageInput,
34
+ PILImageResampling,
35
+ infer_channel_dimension_format,
36
+ is_scaled_image,
37
+ is_torch_available,
38
+ make_list_of_images,
39
+ pil_torch_interpolation_mapping,
40
+ to_numpy_array,
41
+ valid_images,
42
+ )
43
+ from ...utils import (
44
+ TensorType,
45
+ filter_out_non_signature_kwargs,
46
+ logging,
47
+ requires_backends,
48
+ )
49
+
50
+
51
+ if is_torch_available():
52
+ import torch
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ @requires(backends=("torchvision", "torch"))
59
+ class DepthProImageProcessor(BaseImageProcessor):
60
+ r"""
61
+ Constructs a DepthPro image processor.
62
+
63
+ Args:
64
+ do_resize (`bool`, *optional*, defaults to `True`):
65
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
66
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
67
+ size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
68
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
69
+ method.
70
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
71
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
72
+ `preprocess` method.
73
+ do_rescale (`bool`, *optional*, defaults to `True`):
74
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
75
+ parameter in the `preprocess` method.
76
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
77
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
78
+ `preprocess` method.
79
+ do_normalize (`bool`, *optional*, defaults to `True`):
80
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
81
+ method.
82
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
83
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
84
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
85
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
86
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
87
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
88
+ """
89
+
90
+ model_input_names = ["pixel_values"]
91
+
92
+ def __init__(
93
+ self,
94
+ do_resize: bool = True,
95
+ size: Optional[Dict[str, int]] = None,
96
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
97
+ do_rescale: bool = True,
98
+ rescale_factor: Union[int, float] = 1 / 255,
99
+ do_normalize: bool = True,
100
+ image_mean: Optional[Union[float, List[float]]] = None,
101
+ image_std: Optional[Union[float, List[float]]] = None,
102
+ **kwargs,
103
+ ):
104
+ super().__init__(**kwargs)
105
+ size = size if size is not None else {"height": 1536, "width": 1536}
106
+ size = get_size_dict(size)
107
+ self.do_resize = do_resize
108
+ self.do_rescale = do_rescale
109
+ self.do_normalize = do_normalize
110
+ self.size = size
111
+ self.resample = resample
112
+ self.rescale_factor = rescale_factor
113
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
114
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
115
+
116
+ def resize(
117
+ self,
118
+ image: np.ndarray,
119
+ size: Dict[str, int],
120
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
121
+ data_format: Optional[Union[str, ChannelDimension]] = None,
122
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
123
+ **kwargs,
124
+ ) -> np.ndarray:
125
+ """
126
+ Resize an image to `(size["height"], size["width"])`.
127
+
128
+ Args:
129
+ image (`np.ndarray`):
130
+ Image to resize.
131
+ size (`Dict[str, int]`):
132
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
133
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
134
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
135
+ data_format (`ChannelDimension` or `str`, *optional*):
136
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
137
+ image is used. Can be one of:
138
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
139
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
140
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
141
+ input_data_format (`ChannelDimension` or `str`, *optional*):
142
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
143
+ from the input image. Can be one of:
144
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
145
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
146
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
147
+
148
+ Returns:
149
+ `np.ndarray`: The resized images.
150
+ """
151
+ requires_backends(self, "torch")
152
+
153
+ size = get_size_dict(size)
154
+ if "height" not in size or "width" not in size:
155
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
156
+ output_size = (size["height"], size["width"])
157
+
158
+ # we use torch interpolation instead of image.resize because DepthProImageProcessor
159
+ # rescales and normalizes the image (which can make some values negative) before resizing it.
160
+ # image.resize expects all values to be in range [0, 1] or [0, 255] and throws an exception otherwise,
161
+ # however pytorch interpolation works with negative values.
162
+ # relevant issue here: https://github.com/huggingface/transformers/issues/34920
163
+ # input should be (B, C, H, W)
164
+ image_tensor = torch.from_numpy(image).unsqueeze(0)
165
+ resized_image = torch.nn.functional.interpolate(
166
+ input=image_tensor,
167
+ size=output_size,
168
+ mode=pil_torch_interpolation_mapping[resample].value,
169
+ )
170
+ resized_image = resized_image.squeeze(0).numpy()
171
+ return resized_image
172
+
173
+ def _validate_input_arguments(
174
+ self,
175
+ do_resize: bool,
176
+ size: Dict[str, int],
177
+ resample: PILImageResampling,
178
+ do_rescale: bool,
179
+ rescale_factor: float,
180
+ do_normalize: bool,
181
+ image_mean: Union[float, List[float]],
182
+ image_std: Union[float, List[float]],
183
+ data_format: Union[str, ChannelDimension],
184
+ ):
185
+ if do_resize and None in (size, resample):
186
+ raise ValueError("Size and resample must be specified if do_resize is True.")
187
+
188
+ if do_rescale and rescale_factor is None:
189
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
190
+
191
+ if do_normalize and None in (image_mean, image_std):
192
+ raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
193
+
194
+ @filter_out_non_signature_kwargs()
195
+ def preprocess(
196
+ self,
197
+ images: ImageInput,
198
+ do_resize: Optional[bool] = None,
199
+ size: Optional[Dict[str, int]] = None,
200
+ resample: Optional[PILImageResampling] = None,
201
+ do_rescale: Optional[bool] = None,
202
+ rescale_factor: Optional[float] = None,
203
+ do_normalize: Optional[bool] = None,
204
+ image_mean: Optional[Union[float, List[float]]] = None,
205
+ image_std: Optional[Union[float, List[float]]] = None,
206
+ return_tensors: Optional[Union[str, TensorType]] = None,
207
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
208
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
209
+ ):
210
+ """
211
+ Preprocess an image or batch of images.
212
+
213
+ Args:
214
+ images (`ImageInput`):
215
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
216
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
217
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
218
+ Whether to resize the image.
219
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
220
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
221
+ resizing.
222
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
223
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
224
+ an effect if `do_resize` is set to `True`.
225
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
226
+ Whether to rescale the image values between [0 - 1].
227
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
228
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
229
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
230
+ Whether to normalize the image.
231
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
232
+ Image mean to use if `do_normalize` is set to `True`.
233
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
234
+ Image standard deviation to use if `do_normalize` is set to `True`.
235
+ return_tensors (`str` or `TensorType`, *optional*):
236
+ The type of tensors to return. Can be one of:
237
+ - Unset: Return a list of `np.ndarray`.
238
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
239
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
240
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
241
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
242
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
243
+ The channel dimension format for the output image. Can be one of:
244
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
245
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
246
+ - Unset: Use the channel dimension format of the input image.
247
+ input_data_format (`ChannelDimension` or `str`, *optional*):
248
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
249
+ from the input image. Can be one of:
250
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
251
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
252
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
253
+ """
254
+ do_resize = do_resize if do_resize is not None else self.do_resize
255
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
256
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
257
+ resample = resample if resample is not None else self.resample
258
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
259
+ image_mean = image_mean if image_mean is not None else self.image_mean
260
+ image_std = image_std if image_std is not None else self.image_std
261
+
262
+ size = size if size is not None else self.size
263
+
264
+ images = make_list_of_images(images)
265
+
266
+ if not valid_images(images):
267
+ raise ValueError(
268
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
269
+ "torch.Tensor, tf.Tensor or jax.ndarray."
270
+ )
271
+ self._validate_input_arguments(
272
+ do_resize=do_resize,
273
+ size=size,
274
+ resample=resample,
275
+ do_rescale=do_rescale,
276
+ rescale_factor=rescale_factor,
277
+ do_normalize=do_normalize,
278
+ image_mean=image_mean,
279
+ image_std=image_std,
280
+ data_format=data_format,
281
+ )
282
+
283
+ # All transformations expect numpy arrays.
284
+ images = [to_numpy_array(image) for image in images]
285
+
286
+ if is_scaled_image(images[0]) and do_rescale:
287
+ logger.warning_once(
288
+ "It looks like you are trying to rescale already rescaled images. If the input"
289
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
290
+ )
291
+
292
+ if input_data_format is None:
293
+ # We assume that all images have the same channel dimension format.
294
+ input_data_format = infer_channel_dimension_format(images[0])
295
+
296
+ all_images = []
297
+ for image in images:
298
+ if do_rescale:
299
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
300
+
301
+ if do_normalize:
302
+ image = self.normalize(
303
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
304
+ )
305
+
306
+ # depth-pro rescales and normalizes the image before resizing it
307
+ # uses torch interpolation which requires ChannelDimension.FIRST
308
+ if do_resize:
309
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
310
+ image = self.resize(image=image, size=size, resample=resample)
311
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
312
+ else:
313
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
314
+
315
+ all_images.append(image)
316
+
317
+ data = {"pixel_values": all_images}
318
+ return BatchFeature(data=data, tensor_type=return_tensors)
319
+
320
+ def post_process_depth_estimation(
321
+ self,
322
+ outputs: "DepthProDepthEstimatorOutput",
323
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
324
+ ) -> Dict[str, List[TensorType]]:
325
+ """
326
+ Post-processes the raw depth predictions from the model to generate
327
+ final depth predictions, which are calibrated using the field of view if provided
328
+ and resized to specified target sizes if provided.
329
+
330
+ Args:
331
+ outputs ([`DepthProDepthEstimatorOutput`]):
332
+ Raw outputs of the model.
333
+ target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
334
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
335
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
336
+ is performed.
337
+
338
+ Returns:
339
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
340
+ predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.
341
+
342
+ Raises:
343
+ `ValueError`:
344
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
345
+ """
346
+ requires_backends(self, "torch")
347
+
348
+ predicted_depth = outputs.predicted_depth
349
+ fov = outputs.field_of_view
350
+
351
+ batch_size = len(predicted_depth)
352
+
353
+ if target_sizes is not None and batch_size != len(target_sizes):
354
+ raise ValueError(
355
+ "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
356
+ )
357
+
358
+ results = []
359
+ fov = [None] * batch_size if fov is None else fov
360
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
361
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
362
+ focal_length = None
363
+ if target_size is not None:
364
+ # scale image w.r.t fov
365
+ if fov_value is not None:
366
+ width = target_size[1]
367
+ focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
368
+ depth = depth * width / focal_length
369
+
370
+ # interpolate
371
+ depth = torch.nn.functional.interpolate(
372
+ # input should be (B, C, H, W)
373
+ input=depth.unsqueeze(0).unsqueeze(1),
374
+ size=target_size,
375
+ mode=pil_torch_interpolation_mapping[self.resample].value,
376
+ ).squeeze()
377
+
378
+ # inverse the depth
379
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
380
+
381
+ results.append(
382
+ {
383
+ "predicted_depth": depth,
384
+ "field_of_view": fov_value,
385
+ "focal_length": focal_length,
386
+ }
387
+ )
388
+
389
+ return results
390
+
391
+
392
+ __all__ = ["DepthProImageProcessor"]
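A hedged end-to-end sketch of the slow processor: preprocess an image, run depth estimation, then post-process back to the original resolution. The checkpoint name and image URL follow the examples used elsewhere for this model and are illustrative:

```python
import requests
import torch
from PIL import Image
from transformers import DepthProForDepthEstimation, DepthProImageProcessor

image = Image.open(
    requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
)

image_processor = DepthProImageProcessor()
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resizes depth to (height, width) and, when a field of view is predicted,
# rescales it by width / focal_length before inverting, as in the method above.
results = image_processor.post_process_depth_estimation(
    outputs, target_sizes=[(image.height, image.width)]
)
depth = results[0]["predicted_depth"]  # (height, width) depth map
```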
docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro_fast.py ADDED
@@ -0,0 +1,189 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for DepthPro."""
16
+
17
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
18
+
19
+ from ...image_processing_base import BatchFeature
20
+ from ...image_processing_utils_fast import (
21
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
22
+ BaseImageProcessorFast,
23
+ group_images_by_shape,
24
+ reorder_images,
25
+ )
26
+ from ...image_utils import (
27
+ IMAGENET_STANDARD_MEAN,
28
+ IMAGENET_STANDARD_STD,
29
+ PILImageResampling,
30
+ SizeDict,
31
+ )
32
+ from ...utils import (
33
+ TensorType,
34
+ add_start_docstrings,
35
+ is_torch_available,
36
+ is_torchvision_available,
37
+ is_torchvision_v2_available,
38
+ logging,
39
+ requires_backends,
40
+ )
41
+ from ...utils.import_utils import requires
42
+
43
+
44
+ if TYPE_CHECKING:
45
+ from .modeling_depth_pro import DepthProDepthEstimatorOutput
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ if is_torch_available():
51
+ import torch
52
+
53
+
54
+ if is_torchvision_available():
55
+ from ...image_utils import pil_torch_interpolation_mapping
56
+
57
+ if is_torchvision_v2_available():
58
+ from torchvision.transforms.v2 import functional as F
59
+ else:
60
+ from torchvision.transforms import functional as F
61
+
62
+
63
+ @add_start_docstrings(
64
+ "Constructs a fast DepthPro image processor.",
65
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
66
+ )
67
+ @requires(backends=("torchvision", "torch"))
68
+ class DepthProImageProcessorFast(BaseImageProcessorFast):
69
+ resample = PILImageResampling.BILINEAR
70
+ image_mean = IMAGENET_STANDARD_MEAN
71
+ image_std = IMAGENET_STANDARD_STD
72
+ size = {"height": 1536, "width": 1536}
73
+ do_resize = True
74
+ do_rescale = True
75
+ do_normalize = True
76
+
77
+ # DepthPro resizes image after rescaling and normalizing,
78
+ # which makes it different from BaseImageProcessorFast._preprocess
79
+ def _preprocess(
80
+ self,
81
+ images: List["torch.Tensor"],
82
+ do_resize: bool,
83
+ size: SizeDict,
84
+ interpolation: Optional["F.InterpolationMode"],
85
+ do_center_crop: bool,
86
+ crop_size: SizeDict,
87
+ do_rescale: bool,
88
+ rescale_factor: float,
89
+ do_normalize: bool,
90
+ image_mean: Optional[Union[float, List[float]]],
91
+ image_std: Optional[Union[float, List[float]]],
92
+ return_tensors: Optional[Union[str, TensorType]],
93
+ ) -> BatchFeature:
94
+ # Group images by size for batched scaling
95
+ grouped_images, grouped_images_index = group_images_by_shape(images)
96
+ processed_images_grouped = {}
97
+ for shape, stacked_images in grouped_images.items():
98
+ # Fused rescale and normalize
99
+ stacked_images = self.rescale_and_normalize(
100
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
101
+ )
102
+ if do_resize:
103
+ stacked_images = self.resize(
104
+ image=stacked_images,
105
+ size=size,
106
+ interpolation=interpolation,
107
+ antialias=False,
108
+ )
109
+ processed_images_grouped[shape] = stacked_images
110
+
111
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
112
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
113
+
114
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
115
+
116
+ # Copied from transformers.models.depth_pro.image_processing_depth_pro.DepthProImageProcessor.post_process_depth_estimation
117
+ def post_process_depth_estimation(
118
+ self,
119
+ outputs: "DepthProDepthEstimatorOutput",
120
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
121
+ ) -> Dict[str, List[TensorType]]:
122
+ """
123
+ Post-processes the raw depth predictions from the model to generate
124
+ final depth predictions, which are calibrated using the field of view if provided
125
+ and resized to specified target sizes if provided.
126
+
127
+ Args:
128
+ outputs ([`DepthProDepthEstimatorOutput`]):
129
+ Raw outputs of the model.
130
+ target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
131
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
132
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
133
+ is performed.
134
+
135
+ Returns:
136
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
137
+ predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.
138
+
139
+ Raises:
140
+ `ValueError`:
141
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
142
+ """
143
+ requires_backends(self, "torch")
144
+
145
+ predicted_depth = outputs.predicted_depth
146
+ fov = outputs.field_of_view
147
+
148
+ batch_size = len(predicted_depth)
149
+
150
+ if target_sizes is not None and batch_size != len(target_sizes):
151
+ raise ValueError(
152
+ "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
153
+ )
154
+
155
+ results = []
156
+ fov = [None] * batch_size if fov is None else fov
157
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
158
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
159
+ focal_length = None
160
+ if target_size is not None:
161
+ # scale image w.r.t fov
162
+ if fov_value is not None:
163
+ width = target_size[1]
164
+ focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
165
+ depth = depth * width / focal_length
166
+
167
+ # interpolate
168
+ depth = torch.nn.functional.interpolate(
169
+ # input should be (B, C, H, W)
170
+ input=depth.unsqueeze(0).unsqueeze(1),
171
+ size=target_size,
172
+ mode=pil_torch_interpolation_mapping[self.resample].value,
173
+ ).squeeze()
174
+
175
+ # inverse the depth
176
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
177
+
178
+ results.append(
179
+ {
180
+ "predicted_depth": depth,
181
+ "field_of_view": fov_value,
182
+ "focal_length": focal_length,
183
+ }
184
+ )
185
+
186
+ return results
187
+
188
+
189
+ __all__ = ["DepthProImageProcessorFast"]
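The fast variant exposes the same interface but operates on batched torch tensors and, as in `_preprocess` above, fuses rescale and normalize before resizing. A minimal usage sketch (image URL illustrative):

```python
import requests
from PIL import Image
from transformers import DepthProImageProcessorFast

image = Image.open(
    requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
)

# With the default 1536x1536 size defined above, this yields a
# (1, 3, 1536, 1536) pixel_values tensor.
fast_processor = DepthProImageProcessorFast()
inputs = fast_processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)
```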
docs/transformers/build/lib/transformers/models/depth_pro/modeling_depth_pro.py ADDED
@@ -0,0 +1,1218 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Apple Research Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DepthPro model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch import nn
24
+
25
+ from ...modeling_outputs import BaseModelOutput
26
+ from ...modeling_utils import PreTrainedModel
27
+ from ...utils import (
28
+ ModelOutput,
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ logging,
32
+ replace_return_docstrings,
33
+ torch_int,
34
+ )
35
+ from ..auto import AutoModel
36
+ from .configuration_depth_pro import DepthProConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ @dataclass
43
+ class DepthProOutput(ModelOutput):
44
+ """
45
+ Base class for DepthPro's outputs.
46
+
47
+ Args:
48
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`):
49
+ Sequence of hidden-states at the output of the last layer of the model.
50
+ features (`Union[torch.FloatTensor, List[torch.FloatTensor]]`, *optional*):
51
+ Features from encoders. Can be a single feature or a list of features.
52
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
53
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
54
+ one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
55
+
56
+ Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
57
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
58
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
59
+ sequence_length)`.
60
+
61
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
62
+ heads.
63
+ """
64
+
65
+ last_hidden_state: Optional[torch.FloatTensor] = None
66
+ features: Union[torch.FloatTensor, List[torch.FloatTensor]] = None
67
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
68
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
69
+
70
+
71
+ @dataclass
72
+ class DepthProDepthEstimatorOutput(ModelOutput):
73
+ """
74
+ Base class for DepthProForDepthEstimation's output.
75
+
76
+ Args:
77
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
78
+ Classification (or regression if config.num_labels==1) loss.
79
+ predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
80
+ Predicted depth for each pixel.
81
+ field_of_view (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned when `use_fov_model` is provided):
82
+ Field of View Scaler.
83
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
84
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
85
+ one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
86
+
87
+ Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
88
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
89
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
90
+ sequence_length)`.
91
+
92
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
93
+ heads.
94
+ """
95
+
96
+ loss: Optional[torch.FloatTensor] = None
97
+ predicted_depth: Optional[torch.FloatTensor] = None
98
+ field_of_view: Optional[torch.FloatTensor] = None
99
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
100
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
101
+
102
+
103
+ def split_to_patches(pixel_values: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
104
+ """Creates Patches from Batch."""
105
+ batch_size, num_channels, height, width = pixel_values.shape
106
+
107
+ if height == width == patch_size:
108
+ # create patches only if scaled image is not already equal to patch size
109
+ return pixel_values
110
+
111
+ stride = torch_int(patch_size * (1 - overlap_ratio))
112
+
113
+ patches = F.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
114
+ patches = patches.permute(2, 0, 1)
115
+ patches = patches.reshape(-1, num_channels, patch_size, patch_size)
116
+
117
+ return patches
118
+
119
+
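`split_to_patches` is a thin wrapper around `torch.nn.functional.unfold`: the stride is derived from the overlap ratio and the unfolded windows are reshaped into a patch-major batch. A rough numerical sketch; the sizes are illustrative, not necessarily the DepthPro defaults:

```python
import torch
import torch.nn.functional as F

# A 1536x1536 image cut into 384x384 patches with 25% overlap (illustrative numbers).
pixel_values = torch.randn(1, 3, 1536, 1536)
patch_size, overlap_ratio = 384, 0.25
stride = int(patch_size * (1 - overlap_ratio))  # 288

patches = F.unfold(pixel_values, kernel_size=patch_size, stride=stride)
patches = patches.permute(2, 0, 1).reshape(-1, 3, patch_size, patch_size)
print(patches.shape)  # torch.Size([25, 3, 384, 384]): (1536 - 384) // 288 + 1 = 5 windows per side
```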
120
+ def reshape_features(hidden_states: torch.Tensor) -> torch.Tensor:
121
+ """Discard class token and reshape 1D feature map to a 2D grid."""
122
+ n_samples, seq_len, hidden_size = hidden_states.shape
123
+ size = torch_int(seq_len**0.5)
124
+
125
+ hidden_states = hidden_states[:, -(size**2) :, :] # remove special tokens if there are any
126
+ hidden_states = hidden_states.reshape(n_samples, size, size, hidden_size)
127
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
128
+
129
+ return hidden_states
130
+
131
+
132
+ def merge_patches(patches: torch.Tensor, batch_size: int, padding: int) -> torch.Tensor:
133
+ """Merges smaller patches into image-like feature map."""
134
+ n_patches, hidden_size, out_size, out_size = patches.shape
135
+ n_patches_per_batch = n_patches // batch_size
136
+ sqrt_n_patches_per_batch = torch_int(n_patches_per_batch**0.5)
137
+ new_out_size = sqrt_n_patches_per_batch * out_size
138
+
139
+ if n_patches == batch_size:
140
+ # merge only if the patches were created from scaled image
141
+ # patches are not created when scaled image size is equal to patch size
142
+ return patches
143
+
144
+ if n_patches_per_batch < 4:
145
+ # for each batch, at least 4 small patches are required to
146
+ # recreate a large square patch from merging them and later padding is applied
147
+ # 3 x (8x8) patches becomes 1 x ( 8x8 ) patch (extra patch ignored, no padding)
148
+ # 4 x (8x8) patches becomes 1 x (16x16) patch (padding later)
149
+ # 5 x (8x8) patches becomes 1 x (16x16) patch (extra patch ignored, padding later)
150
+ # 9 x (8x8) patches becomes 1 x (24x24) patch (padding later)
151
+ # thus the following code only rearranges the patches and removes extra ones
152
+ padding = 0
153
+
154
+ # make sure padding is not large enough to remove more than half of the patch
155
+ padding = min(out_size // 4, padding)
156
+
157
+ if padding == 0:
158
+ # faster when no padding is required
159
+ merged = patches.reshape(n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
160
+ merged = merged.permute(1, 2, 0, 3, 4)
161
+ merged = merged[:, :, : sqrt_n_patches_per_batch**2, :, :]
162
+ merged = merged.reshape(
163
+ batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size
164
+ )
165
+ merged = merged.permute(0, 1, 2, 4, 3, 5)
166
+ merged = merged.reshape(batch_size, hidden_size, new_out_size, new_out_size)
167
+ else:
168
+ # padding example:
169
+ # let out_size = 8, new_out_size = 32, padding = 2
170
+ # each patch is separated by "|"
171
+ # and padding is applied to the merging edges of each patch
172
+ # 00 01 02 03 04 05 06 07 | 08 09 10 11 12 13 14 15 | 16 17 18 19 20 21 22 23 | 24 25 26 27 28 29 30 31
173
+ # 00 01 02 03 04 05 -- -- | -- -- 10 11 12 13 -- -- | -- -- 18 19 20 21 -- -- | -- -- 26 27 28 29 30 31
174
+ i = 0
175
+ boxes = []
176
+ for h in range(sqrt_n_patches_per_batch):
177
+ boxes_in_row = []
178
+ for w in range(sqrt_n_patches_per_batch):
179
+ box = patches[batch_size * i : batch_size * (i + 1)]
180
+
181
+ # collect paddings
182
+ paddings = [0, 0, 0, 0]
183
+ if h != 0:
184
+ # remove pad from height if box is not at top border
185
+ paddings[0] = padding
186
+ if w != 0:
187
+ # remove pad from width if box is not at left border
188
+ paddings[2] = padding
189
+ if h != sqrt_n_patches_per_batch - 1:
190
+ # remove pad from height if box is not at bottom border
191
+ paddings[1] = padding
192
+ if w != sqrt_n_patches_per_batch - 1:
193
+ # remove pad from width if box is not at right border
194
+ paddings[3] = padding
195
+
196
+ # remove paddings
197
+ _, _, box_h, box_w = box.shape
198
+ pad_top, pad_bottom, pad_left, pad_right = paddings
199
+ box = box[:, :, pad_top : box_h - pad_bottom, pad_left : box_w - pad_right]
200
+
201
+ boxes_in_row.append(box)
202
+ i += 1
203
+ boxes_in_row = torch.cat(boxes_in_row, dim=-1)
204
+ boxes.append(boxes_in_row)
205
+ merged = torch.cat(boxes, dim=-2)
206
+
207
+ return merged
208
+
209
+
210
+ def reconstruct_feature_maps(
211
+ hidden_state: torch.Tensor, batch_size: int, padding: int, output_size: Tuple[float, float]
212
+ ) -> torch.Tensor:
213
+ """
214
+ Reconstructs feature maps from the hidden state produced by any of the encoder. Converts the hidden state of shape
215
+ `(n_patches_per_batch * batch_size, seq_len, hidden_size)` to feature maps of shape
216
+ `(batch_size, hidden_size, output_size[0], output_size[1])`.
217
+
218
+ Args:
219
+ hidden_state (torch.Tensor): Input tensor of shape `(n_patches_per_batch * batch_size, seq_len, hidden_size)`
220
+ representing the encoded patches.
221
+ batch_size (int): The number of samples in a batch.
222
+ padding (int): The amount of padding to be removed when merging patches.
223
+ output_size (Tuple[float, float]): The desired output size for the feature maps, specified as `(height, width)`.
224
+
225
+ Returns:
226
+ torch.Tensor: Reconstructed feature maps of shape `(batch_size, hidden_size, output_size[0], output_size[1])`.
227
+ """
228
+ # reshape back to image like
229
+ features = reshape_features(hidden_state)
230
+
231
+ # merge all patches in a batch to create one large patch per batch
232
+ features = merge_patches(
233
+ features,
234
+ batch_size=batch_size,
235
+ padding=padding,
236
+ )
237
+
238
+ # interpolate patches to base size
239
+ features = F.interpolate(
240
+ features,
241
+ size=output_size,
242
+ mode="bilinear",
243
+ align_corners=False,
244
+ )
245
+
246
+ return features
247
+
248
+
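Put together, `reconstruct_feature_maps` drops the class token, rebuilds each patch as a 2D grid, stitches the patches of every image back into one map, and bilinearly resizes it. A shape-only sketch with illustrative dimensions, assuming the helpers above are in scope:

```python
import torch

# 2 images, 4 patches each, a 24x24 token grid plus a class token, hidden size 32.
batch_size, n_patches_per_batch, grid, hidden = 2, 4, 24, 32
hidden_state = torch.randn(n_patches_per_batch * batch_size, grid * grid + 1, hidden)

features = reconstruct_feature_maps(
    hidden_state, batch_size=batch_size, padding=0, output_size=(96, 96)
)
print(features.shape)  # torch.Size([2, 32, 96, 96])
```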
249
+ class DepthProPatchEncoder(nn.Module):
250
+ def __init__(self, config: DepthProConfig):
251
+ super().__init__()
252
+ self.config = config
253
+
254
+ self.intermediate_hook_ids = config.intermediate_hook_ids
255
+ self.intermediate_feature_dims = config.intermediate_feature_dims
256
+ self.scaled_images_ratios = config.scaled_images_ratios
257
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
258
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
259
+ self.merge_padding_value = config.merge_padding_value
260
+
261
+ self.n_scaled_images = len(config.scaled_images_ratios)
262
+ self.n_intermediate_hooks = len(config.intermediate_hook_ids)
263
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
264
+
265
+ self.model = AutoModel.from_config(config.patch_model_config)
266
+
267
+ def forward(
268
+ self,
269
+ pixel_values: torch.Tensor,
270
+ head_mask: Optional[torch.Tensor] = None,
271
+ ) -> List[torch.Tensor]:
272
+ batch_size, num_channels, height, width = pixel_values.shape
273
+
274
+ if min(self.scaled_images_ratios) * min(height, width) < self.config.patch_size:
275
+ raise ValueError(
276
+ f"Image size {height}x{width} is too small to be scaled "
277
+ f"with scaled_images_ratios={self.scaled_images_ratios} "
278
+ f"when patch_size={self.config.patch_size}."
279
+ )
280
+
281
+ # STEP 1: create 3-level image
282
+
283
+ scaled_images = []
284
+ for ratio in self.scaled_images_ratios:
285
+ scaled_images.append(
286
+ F.interpolate(
287
+ pixel_values,
288
+ scale_factor=ratio,
289
+ mode="bilinear",
290
+ align_corners=False,
291
+ )
292
+ )
293
+
294
+ # STEP 2: create patches
295
+
296
+ for i in range(self.n_scaled_images):
297
+ scaled_images[i] = split_to_patches(
298
+ scaled_images[i],
299
+ patch_size=self.config.patch_size,
300
+ overlap_ratio=self.scaled_images_overlap_ratios[i],
301
+ )
302
+ n_patches_per_scaled_image = [len(i) for i in scaled_images]
303
+ patches = torch.cat(scaled_images[::-1], dim=0) # -1 as patch encoder expects high res patches first
304
+
305
+ # STEP 3: apply patch encoder
306
+
307
+ encodings = self.model(
308
+ # each patch is processed as a separate batch
309
+ patches,
310
+ head_mask=head_mask,
311
+ # required for intermediate features
312
+ output_hidden_states=self.n_intermediate_hooks > 0,
313
+ )
314
+
315
+ scaled_images_last_hidden_state = torch.split_with_sizes(encodings[0], n_patches_per_scaled_image[::-1])
316
+ # -1 (reverse list) as patch encoder returns high res patches first, we need low res first
317
+ scaled_images_last_hidden_state = scaled_images_last_hidden_state[::-1]
318
+
319
+ # calculate base height and width
320
+ # base height and width are the dimensions of the lowest resolution features
321
+ exponent_value = torch_int(math.log2(width / self.out_size))
322
+ base_height = height // 2**exponent_value
323
+ base_width = width // 2**exponent_value
324
+
325
+ # STEP 4: get patch features (high_res, med_res, low_res) - (3-5) in diagram
326
+
327
+ scaled_images_features = []
328
+ for i in range(self.n_scaled_images):
329
+ hidden_state = scaled_images_last_hidden_state[i]
330
+ batch_size = batch_size
331
+ padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[i]))
332
+ output_height = base_height * 2**i
333
+ output_width = base_width * 2**i
334
+ features = reconstruct_feature_maps(
335
+ hidden_state,
336
+ batch_size=batch_size,
337
+ padding=padding,
338
+ output_size=(output_height, output_width),
339
+ )
340
+ scaled_images_features.append(features)
341
+
342
+ # STEP 5: get intermediate features - (1-2) in diagram
343
+
344
+ intermediate_features = []
345
+ for i in range(self.n_intermediate_hooks):
346
+ # +1 to correct index position as hidden_states contain embedding output as well
347
+ hidden_state = encodings[2][self.intermediate_hook_ids[i] + 1]
348
+ padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[-1]))
349
+ output_height = base_height * 2 ** (self.n_scaled_images - 1)
350
+ output_width = base_width * 2 ** (self.n_scaled_images - 1)
351
+ features = reconstruct_feature_maps(
352
+ hidden_state,
353
+ batch_size=batch_size,
354
+ padding=padding,
355
+ output_size=(output_height, output_width),
356
+ )
357
+ intermediate_features.append(features)
358
+
359
+ # STEP 7: combine all features
360
+ features = [*scaled_images_features, *intermediate_features]
361
+
362
+ return features
363
+
364
+
365
+ class DepthProImageEncoder(nn.Module):
366
+ def __init__(self, config: DepthProConfig):
367
+ super().__init__()
368
+ self.config = config
369
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
370
+
371
+ self.model = AutoModel.from_config(config.image_model_config)
372
+
373
+ def forward(
374
+ self,
375
+ pixel_values: torch.Tensor,
376
+ head_mask: Optional[torch.Tensor] = None,
377
+ output_attentions: bool = False,
378
+ output_hidden_states: bool = False,
379
+ return_dict: bool = True,
380
+ ) -> Union[tuple, DepthProOutput]:
381
+ batch_size, num_channels, height, width = pixel_values.shape
382
+
383
+ # scale the image for image_encoder
384
+ size = self.config.image_model_config.image_size
385
+ pixel_values = F.interpolate(
386
+ pixel_values,
387
+ size=(size, size),
388
+ mode="bilinear",
389
+ align_corners=False,
390
+ )
391
+ encodings = self.model(
392
+ pixel_values=pixel_values,
393
+ head_mask=head_mask,
394
+ output_attentions=output_attentions,
395
+ output_hidden_states=output_hidden_states,
396
+ )
397
+
398
+ # calculate base height and width
399
+ # base height and width are the dimensions of the lowest resolution features
400
+ exponent_value = torch_int(math.log2(width / self.out_size))
401
+ base_height = height // 2**exponent_value
402
+ base_width = width // 2**exponent_value
403
+
404
+ features = reconstruct_feature_maps(
405
+ encodings[0],
406
+ batch_size=batch_size,
407
+ padding=0,
408
+ output_size=(base_height, base_width),
409
+ )
410
+
411
+ if not return_dict:
412
+ return (encodings[0], features) + encodings[2:] # encodings[2:] skips last_hidden_state and pooler output
413
+
414
+ return DepthProOutput(
415
+ last_hidden_state=encodings.last_hidden_state,
416
+ features=features,
417
+ hidden_states=encodings.hidden_states,
418
+ attentions=encodings.attentions,
419
+ )
420
+
421
+
422
+ class DepthProEncoder(nn.Module):
423
+ def __init__(self, config: DepthProConfig):
424
+ super().__init__()
425
+ self.config = config
426
+ self.intermediate_hook_ids = config.intermediate_hook_ids
427
+ self.intermediate_feature_dims = config.intermediate_feature_dims
428
+ self.scaled_images_ratios = config.scaled_images_ratios
429
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
430
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
431
+ self.merge_padding_value = config.merge_padding_value
432
+
433
+ self.n_scaled_images = len(self.scaled_images_ratios)
434
+ self.n_intermediate_hooks = len(self.intermediate_hook_ids)
435
+
436
+ self.patch_encoder = DepthProPatchEncoder(config)
437
+ self.image_encoder = DepthProImageEncoder(config)
438
+
439
+ def forward(
440
+ self,
441
+ pixel_values: torch.Tensor,
442
+ head_mask: Optional[torch.Tensor] = None,
443
+ output_attentions: bool = False,
444
+ output_hidden_states: bool = False,
445
+ return_dict: bool = True,
446
+ ) -> Union[tuple, DepthProOutput]:
447
+ batch_size, num_channels, height, width = pixel_values.shape
448
+
449
+ patch_features = self.patch_encoder(
450
+ pixel_values,
451
+ head_mask=head_mask,
452
+ )
453
+ image_encodings = self.image_encoder(
454
+ pixel_values,
455
+ head_mask=head_mask,
456
+ output_attentions=output_attentions,
457
+ output_hidden_states=output_hidden_states,
458
+ return_dict=return_dict,
459
+ )
460
+ image_features = image_encodings[1] # index 1 contains features
461
+
462
+ features = [image_features, *patch_features]
463
+
464
+ if not return_dict:
465
+ return (image_encodings[0], features) + image_encodings[2:]
466
+
467
+ return DepthProOutput(
468
+ last_hidden_state=image_encodings.last_hidden_state,
469
+ features=features,
470
+ hidden_states=image_encodings.hidden_states,
471
+ attentions=image_encodings.attentions,
472
+ )
473
+
474
+
475
+ class DepthProFeatureUpsampleBlock(nn.Module):
476
+ def __init__(
477
+ self,
478
+ config: DepthProConfig,
479
+ input_dims: int,
480
+ intermediate_dims: int,
481
+ output_dims: int,
482
+ n_upsample_layers: int,
483
+ use_proj: bool = True,
484
+ bias: bool = False,
485
+ ):
486
+ super().__init__()
487
+ self.config = config
488
+ self.layers = nn.ModuleList()
489
+
490
+ # create first projection layer
491
+ if use_proj:
492
+ proj = nn.Conv2d(
493
+ in_channels=input_dims,
494
+ out_channels=intermediate_dims,
495
+ kernel_size=1,
496
+ stride=1,
497
+ padding=0,
498
+ bias=bias,
499
+ )
500
+ self.layers.append(proj)
501
+
502
+ # create following upsample layers
503
+ for i in range(n_upsample_layers):
504
+ in_channels = intermediate_dims if i == 0 else output_dims
505
+ layer = nn.ConvTranspose2d(
506
+ in_channels=in_channels,
507
+ out_channels=output_dims,
508
+ kernel_size=2,
509
+ stride=2,
510
+ padding=0,
511
+ bias=bias,
512
+ )
513
+ self.layers.append(layer)
514
+
515
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
516
+ for layer in self.layers:
517
+ features = layer(features)
518
+ return features
519
+
520
+
521
+ class DepthProFeatureUpsample(nn.Module):
522
+ def __init__(self, config: DepthProConfig):
523
+ super().__init__()
524
+ self.config = config
525
+ self.n_scaled_images = len(self.config.scaled_images_ratios)
526
+ self.n_intermediate_hooks = len(self.config.intermediate_hook_ids)
527
+
528
+ # for image_features
529
+ self.image_block = DepthProFeatureUpsampleBlock(
530
+ config=config,
531
+ input_dims=config.image_model_config.hidden_size,
532
+ intermediate_dims=config.image_model_config.hidden_size,
533
+ output_dims=config.scaled_images_feature_dims[0],
534
+ n_upsample_layers=1,
535
+ use_proj=False,
536
+ bias=True,
537
+ )
538
+
539
+ # for scaled_images_features
540
+ self.scaled_images = nn.ModuleList()
541
+ for i, feature_dims in enumerate(config.scaled_images_feature_dims):
542
+ block = DepthProFeatureUpsampleBlock(
543
+ config=config,
544
+ input_dims=config.patch_model_config.hidden_size,
545
+ intermediate_dims=feature_dims,
546
+ output_dims=feature_dims,
547
+ n_upsample_layers=1,
548
+ )
549
+ self.scaled_images.append(block)
550
+
551
+ # for intermediate_features
552
+ self.intermediate = nn.ModuleList()
553
+ for i, feature_dims in enumerate(config.intermediate_feature_dims):
554
+ intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims
555
+ block = DepthProFeatureUpsampleBlock(
556
+ config=config,
557
+ input_dims=config.patch_model_config.hidden_size,
558
+ intermediate_dims=intermediate_dims,
559
+ output_dims=feature_dims,
560
+ n_upsample_layers=2 + i,
561
+ )
562
+ self.intermediate.append(block)
563
+
564
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
565
+ features[0] = self.image_block(features[0])
566
+
567
+ for i in range(self.n_scaled_images):
568
+ features[i + 1] = self.scaled_images[i](features[i + 1])
569
+
570
+ for i in range(self.n_intermediate_hooks):
571
+ features[self.n_scaled_images + i + 1] = self.intermediate[i](features[self.n_scaled_images + i + 1])
572
+
573
+ return features
574
+
575
+
576
+ class DepthProFeatureProjection(nn.Module):
577
+ def __init__(self, config: DepthProConfig):
578
+ super().__init__()
579
+ self.config = config
580
+
581
+ combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
582
+ self.projections = nn.ModuleList()
583
+ for i, in_channels in enumerate(combined_feature_dims):
584
+ if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
585
+ # projection for last layer can be ignored if input and output channels already match
586
+ self.projections.append(nn.Identity())
587
+ else:
588
+ self.projections.append(
589
+ nn.Conv2d(
590
+ in_channels=in_channels,
591
+ out_channels=config.fusion_hidden_size,
592
+ kernel_size=3,
593
+ stride=1,
594
+ padding=1,
595
+ bias=False,
596
+ )
597
+ )
598
+
599
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
600
+ projected_features = []
601
+ for i, projection in enumerate(self.projections):
602
+ upsampled_feature = projection(features[i])
603
+ projected_features.append(upsampled_feature)
604
+ return projected_features
605
+
606
+
607
+ class DepthProNeck(nn.Module):
608
+ def __init__(self, config: DepthProConfig):
609
+ super().__init__()
610
+ self.config = config
611
+
612
+ self.feature_upsample = DepthProFeatureUpsample(config)
613
+ self.fuse_image_with_low_res = nn.Conv2d(
614
+ in_channels=config.scaled_images_feature_dims[0] * 2,
615
+ out_channels=config.scaled_images_feature_dims[0],
616
+ kernel_size=1,
617
+ stride=1,
618
+ padding=0,
619
+ bias=True,
620
+ )
621
+ self.feature_projection = DepthProFeatureProjection(config)
622
+
623
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
624
+ features = self.feature_upsample(features)
625
+ # global features = low res features + image features
626
+ global_features = torch.cat((features[1], features[0]), dim=1)
627
+ global_features = self.fuse_image_with_low_res(global_features)
628
+ features = [global_features, *features[2:]]
629
+ features = self.feature_projection(features)
630
+ return features
631
+
632
+
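The neck's `fuse_image_with_low_res` step is a plain channel-wise fusion: global image features and the lowest-resolution patch features are concatenated along the channel axis and mixed by a 1x1 convolution. A standalone sketch with illustrative shapes:

```python
import torch
from torch import nn

image_feats = torch.randn(2, 256, 48, 48)    # global image features
low_res_feats = torch.randn(2, 256, 48, 48)  # lowest-resolution patch features
fuse = nn.Conv2d(in_channels=256 * 2, out_channels=256, kernel_size=1)

global_feats = fuse(torch.cat((low_res_feats, image_feats), dim=1))
print(global_feats.shape)  # torch.Size([2, 256, 48, 48])
```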
633
+ # General docstring
634
+ _CONFIG_FOR_DOC = "DepthProConfig"
635
+
636
+
637
+ DEPTH_PRO_START_DOCSTRING = r"""
638
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
639
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
640
+ behavior.
641
+
642
+ Parameters:
643
+ config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
644
+ Initializing with a config file does not load the weights associated with the model, only the
645
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
646
+ """
647
+
648
+ DEPTH_PRO_INPUTS_DOCSTRING = r"""
649
+ Args:
650
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
651
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
652
+ for details.
653
+
654
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
655
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
656
+
657
+ - 1 indicates the head is **not masked**,
658
+ - 0 indicates the head is **masked**.
659
+
660
+ output_attentions (`bool`, *optional*):
661
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
662
+ tensors for more detail.
663
+ output_hidden_states (`bool`, *optional*):
664
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
665
+ more detail.
666
+ return_dict (`bool`, *optional*):
667
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
668
+ """
669
+
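As a small illustration of the `head_mask` convention described above (1 keeps a head, 0 prunes it); the layer and head counts below are hypothetical, not read from a DepthPro checkpoint:

```python
import torch

# Hypothetical sizes, for illustration only.
num_layers, num_heads = 24, 16

# Keep every head by default, then mask (zero out) head 3 of layer 0.
head_mask = torch.ones(num_layers, num_heads)
head_mask[0, 3] = 0.0

# `head_mask` can then be passed alongside `pixel_values`,
# e.g. model(pixel_values, head_mask=head_mask)
```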
670
+ DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING = r"""
671
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
672
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
673
+ behavior.
674
+
675
+ Parameters:
676
+ config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
677
+ Initializing with a config file does not load the weights associated with the model, only the
678
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
679
+ use_fov_model (`bool`, *optional*, defaults to `True`):
680
+ Whether to use `DepthProFovModel` to generate the field of view.
681
+ """
682
+
683
+
684
+ class DepthProPreTrainedModel(PreTrainedModel):
685
+ """
686
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
687
+ models.
688
+ """
689
+
690
+ config_class = DepthProConfig
691
+ base_model_prefix = "depth_pro"
692
+ main_input_name = "pixel_values"
693
+ supports_gradient_checkpointing = True
694
+ _supports_sdpa = True
695
+ _no_split_modules = ["DepthProPreActResidualLayer"]
696
+ _keys_to_ignore_on_load_unexpected = ["fov_model.*"]
697
+
698
+ def _init_weights(self, module):
699
+ """Initialize the weights"""
700
+ if isinstance(module, nn.Linear):
701
+ # Slightly different from the TF version which uses truncated_normal for initialization
702
+ # cf https://github.com/pytorch/pytorch/pull/5617
703
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
704
+ if module.bias is not None:
705
+ module.bias.data.zero_()
706
+ elif isinstance(module, nn.LayerNorm):
707
+ module.bias.data.zero_()
708
+ module.weight.data.fill_(1.0)
709
+ elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
710
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
711
+ if module.bias is not None:
712
+ module.bias.data.zero_()
713
+
714
+
715
+ @add_start_docstrings(
716
+ "The bare DepthPro Model transformer outputting raw hidden-states without any specific head on top.",
717
+ DEPTH_PRO_START_DOCSTRING,
718
+ )
719
+ class DepthProModel(DepthProPreTrainedModel):
720
+ def __init__(self, config):
721
+ super().__init__(config)
722
+ self.config = config
723
+ self.encoder = DepthProEncoder(config)
724
+ self.neck = DepthProNeck(config)
725
+ # Initialize weights and apply final processing
726
+ self.post_init()
727
+
728
+ def get_input_embeddings(self):
729
+ return self.encoder.image_encoder.model.get_input_embeddings()
730
+
731
+ @add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
732
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
733
+ def forward(
734
+ self,
735
+ pixel_values: torch.FloatTensor,
736
+ head_mask: Optional[torch.FloatTensor] = None,
737
+ output_attentions: Optional[bool] = None,
738
+ output_hidden_states: Optional[bool] = None,
739
+ return_dict: Optional[bool] = None,
740
+ ) -> Union[Tuple, DepthProOutput]:
741
+ r"""
742
+ Returns:
743
+
744
+ Examples:
745
+
746
+ ```python
747
+ >>> import torch
748
+ >>> from PIL import Image
749
+ >>> import requests
750
+ >>> from transformers import AutoProcessor, DepthProModel
751
+
752
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
753
+ >>> image = Image.open(requests.get(url, stream=True).raw)
754
+
755
+ >>> checkpoint = "apple/DepthPro-hf"
756
+ >>> processor = AutoProcessor.from_pretrained(checkpoint)
757
+ >>> model = DepthProModel.from_pretrained(checkpoint)
758
+
759
+ >>> # prepare image for the model
760
+ >>> inputs = processor(images=image, return_tensors="pt")
761
+
762
+ >>> with torch.no_grad():
763
+ ... output = model(**inputs)
764
+
765
+ >>> output.last_hidden_state.shape
766
+ torch.Size([1, 35, 577, 1024])
767
+ ```"""
768
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
769
+ output_hidden_states = (
770
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
771
+ )
772
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
773
+
774
+ encodings = self.encoder(
775
+ pixel_values,
776
+ head_mask=head_mask,
777
+ output_attentions=output_attentions,
778
+ output_hidden_states=output_hidden_states,
779
+ return_dict=return_dict,
780
+ )
781
+ features = encodings[1] # index 1 contains features
782
+ features = self.neck(features)
783
+
784
+ if not return_dict:
785
+ return (encodings[0], features) + encodings[2:]
786
+
787
+ return DepthProOutput(
788
+ last_hidden_state=encodings.last_hidden_state,
789
+ features=features,
790
+ hidden_states=encodings.hidden_states,
791
+ attentions=encodings.attentions,
792
+ )
793
+
794
+
795
+ # Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer DPT->DepthPro
796
+ class DepthProPreActResidualLayer(nn.Module):
797
+ """
798
+ ResidualConvUnit, pre-activate residual unit.
799
+
800
+ Args:
801
+ config (`[DepthProConfig]`):
802
+ Model configuration class defining the model architecture.
803
+ """
804
+
805
+ def __init__(self, config):
806
+ super().__init__()
807
+
808
+ self.use_batch_norm = config.use_batch_norm_in_fusion_residual
809
+ use_bias_in_fusion_residual = (
810
+ config.use_bias_in_fusion_residual
811
+ if config.use_bias_in_fusion_residual is not None
812
+ else not self.use_batch_norm
813
+ )
814
+
815
+ self.activation1 = nn.ReLU()
816
+ self.convolution1 = nn.Conv2d(
817
+ config.fusion_hidden_size,
818
+ config.fusion_hidden_size,
819
+ kernel_size=3,
820
+ stride=1,
821
+ padding=1,
822
+ bias=use_bias_in_fusion_residual,
823
+ )
824
+
825
+ self.activation2 = nn.ReLU()
826
+ self.convolution2 = nn.Conv2d(
827
+ config.fusion_hidden_size,
828
+ config.fusion_hidden_size,
829
+ kernel_size=3,
830
+ stride=1,
831
+ padding=1,
832
+ bias=use_bias_in_fusion_residual,
833
+ )
834
+
835
+ if self.use_batch_norm:
836
+ self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
837
+ self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
838
+
839
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
840
+ residual = hidden_state
841
+ hidden_state = self.activation1(hidden_state)
842
+
843
+ hidden_state = self.convolution1(hidden_state)
844
+
845
+ if self.use_batch_norm:
846
+ hidden_state = self.batch_norm1(hidden_state)
847
+
848
+ hidden_state = self.activation2(hidden_state)
849
+ hidden_state = self.convolution2(hidden_state)
850
+
851
+ if self.use_batch_norm:
852
+ hidden_state = self.batch_norm2(hidden_state)
853
+
854
+ return hidden_state + residual
855
+
856
+
857
+ # Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer
858
+ # except it uses deconv and skip_add and needs no interpolation
859
+ class DepthProFeatureFusionLayer(nn.Module):
860
+ def __init__(self, config: DepthProConfig, use_deconv: bool = True):
861
+ super().__init__()
862
+ self.config = config
863
+ self.use_deconv = use_deconv
864
+
865
+ self.residual_layer1 = DepthProPreActResidualLayer(config)
866
+ self.residual_layer2 = DepthProPreActResidualLayer(config)
867
+
868
+ if self.use_deconv:
869
+ self.deconv = nn.ConvTranspose2d(
870
+ in_channels=config.fusion_hidden_size,
871
+ out_channels=config.fusion_hidden_size,
872
+ kernel_size=2,
873
+ stride=2,
874
+ padding=0,
875
+ bias=False,
876
+ )
877
+
878
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
879
+
880
+ def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
881
+ if residual is not None:
882
+ residual = self.residual_layer1(residual)
883
+ hidden_state = hidden_state + residual
884
+
885
+ hidden_state = self.residual_layer2(hidden_state)
886
+ if self.use_deconv:
887
+ hidden_state = self.deconv(hidden_state)
888
+ hidden_state = self.projection(hidden_state)
889
+
890
+ return hidden_state
891
+
892
+
893
+ # Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
894
+ # with deconv and reversed layers
895
+ class DepthProFeatureFusionStage(nn.Module):
896
+ def __init__(self, config):
897
+ super().__init__()
898
+ self.config = config
899
+
900
+ self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
901
+ self.intermediate = nn.ModuleList()
902
+ for _ in range(self.num_layers - 1):
903
+ self.intermediate.append(DepthProFeatureFusionLayer(config))
904
+
905
+ # final layer does not require deconvolution
906
+ self.final = DepthProFeatureFusionLayer(config, use_deconv=False)
907
+
908
+ def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
909
+ if self.num_layers != len(hidden_states):
910
+ raise ValueError(
911
+ f"num_layers={self.num_layers} in DepthProFeatureFusionStage"
912
+ f"doesnot match len(hidden_states)={len(hidden_states)}"
913
+ )
914
+
915
+ fused_hidden_states = []
916
+ fused_hidden_state = None
917
+ for hidden_state, layer in zip(hidden_states[:-1], self.intermediate):
918
+ if fused_hidden_state is None:
919
+ # first layer only uses the last hidden_state
920
+ fused_hidden_state = layer(hidden_state)
921
+ else:
922
+ fused_hidden_state = layer(fused_hidden_state, hidden_state)
923
+ fused_hidden_states.append(fused_hidden_state)
924
+
925
+ hidden_state = hidden_states[-1]
926
+ fused_hidden_state = self.final(fused_hidden_state, hidden_state)
927
+ fused_hidden_states.append(fused_hidden_state)
928
+
929
+ return fused_hidden_states
930
+
931
+
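A back-of-the-envelope sketch of the fusion order implemented above, using made-up spatial sizes: the running state starts from the first (lowest-resolution) map, every intermediate layer adds the next map and doubles the spatial size through its deconvolution, and the final layer only adds the last map without upsampling:

```python
# Hypothetical spatial sizes of the incoming feature maps, lowest resolution first;
# each map is assumed to be 2x the resolution of the previous one.
input_sizes = [24, 48, 96, 192, 384]

fused_size = input_sizes[0]
for i, size in enumerate(input_sizes):
    assert fused_size == size  # residual addition requires matching resolutions
    if i < len(input_sizes) - 1:
        fused_size *= 2  # intermediate layers upsample via ConvTranspose2d(stride=2)

print(fused_size)  # 384: the final layer fuses without upsampling
```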
932
+ class DepthProFovEncoder(nn.Module):
933
+ def __init__(self, config: DepthProConfig):
934
+ super().__init__()
935
+ self.config = config
936
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
937
+
938
+ self.model = AutoModel.from_config(config.fov_model_config)
939
+ self.neck = nn.Linear(config.fov_model_config.hidden_size, config.fusion_hidden_size // 2)
940
+
941
+ def forward(
942
+ self,
943
+ pixel_values: torch.Tensor,
944
+ head_mask: Optional[torch.Tensor] = None,
945
+ ) -> torch.Tensor:
946
+ batch_size, num_channels, height, width = pixel_values.shape
947
+
948
+ # scale the image for fov_encoder
949
+ size = self.config.fov_model_config.image_size
950
+ pixel_values = F.interpolate(
951
+ pixel_values,
952
+ size=(size, size),
953
+ mode="bilinear",
954
+ align_corners=False,
955
+ )
956
+ encodings = self.model(
957
+ pixel_values=pixel_values,
958
+ head_mask=head_mask,
959
+ )
960
+ hidden_state = encodings[0]
961
+ hidden_state = self.neck(hidden_state)
962
+
963
+ # calculate base height and width
964
+ # base height and width are the dimensions of the lowest resolution features
965
+ exponent_value = torch_int(math.log2(width / self.out_size))
966
+ base_height = height // 2**exponent_value
967
+ base_width = width // 2**exponent_value
968
+
969
+ features = reconstruct_feature_maps(
970
+ hidden_state,
971
+ batch_size=batch_size,
972
+ padding=0,
973
+ output_size=(base_height, base_width),
974
+ )
975
+
976
+ return features
977
+
978
+
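The base-resolution arithmetic above, worked through with illustrative numbers (a 1536x1536 network input and a 24-patch-per-side image encoder); these values are assumptions, not read from a released config, and plain `int` stands in for `torch_int`:

```python
import math

height = width = 1536   # illustrative network input size
out_size = 24           # image_size // patch_size of the image encoder

exponent_value = int(math.log2(width / out_size))   # log2(64) = 6
base_height = height // 2**exponent_value
base_width = width // 2**exponent_value
print(base_height, base_width)  # 24 24 -> size of the lowest-resolution feature map
```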
979
+ class DepthProFovHead(nn.Module):
980
+ def __init__(self, config: DepthProConfig):
981
+ super().__init__()
982
+ self.config = config
983
+ self.fusion_hidden_size = config.fusion_hidden_size
984
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
985
+
986
+ # create initial head layers
987
+ self.layers = nn.ModuleList()
988
+ for i in range(config.num_fov_head_layers):
989
+ self.layers.append(
990
+ nn.Conv2d(
991
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 1)),
992
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 2)),
993
+ kernel_size=3,
994
+ stride=2,
995
+ padding=1,
996
+ )
997
+ )
998
+ self.layers.append(nn.ReLU(True))
999
+ # calculate expected shapes to finally generate a scalar output from final head layer
1000
+ final_in_channels = math.ceil(self.fusion_hidden_size / 2 ** (config.num_fov_head_layers + 1))
1001
+ final_kernel_size = torch_int((self.out_size - 1) / 2**config.num_fov_head_layers + 1)
1002
+ self.layers.append(
1003
+ nn.Conv2d(
1004
+ in_channels=final_in_channels, out_channels=1, kernel_size=final_kernel_size, stride=1, padding=0
1005
+ )
1006
+ )
1007
+
1008
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
1009
+ features = F.interpolate(
1010
+ features,
1011
+ size=(self.out_size, self.out_size),
1012
+ mode="bilinear",
1013
+ align_corners=False,
1014
+ )
1015
+ for layer in self.layers:
1016
+ features = layer(features)
1017
+ return features
1018
+
1019
+
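The `final_in_channels` / `final_kernel_size` arithmetic above can be checked by hand. The numbers below are illustrative defaults, not values read from a released checkpoint (and plain `int` stands in for `torch_int`):

```python
import math

# Illustrative values only (the real ones come from DepthProConfig).
fusion_hidden_size = 256
out_size = 24            # image_size // patch_size of the image encoder
num_fov_head_layers = 2

# Each stride-2 conv halves the channel count (starting from fusion_hidden_size // 2)
# and roughly halves the spatial size of the (out_size x out_size) feature map.
final_in_channels = math.ceil(fusion_hidden_size / 2 ** (num_fov_head_layers + 1))
final_kernel_size = int((out_size - 1) / 2**num_fov_head_layers + 1)

print(final_in_channels, final_kernel_size)  # 32 6 -> a 6x6 conv collapses the 6x6 map to a scalar
```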
1020
+ class DepthProFovModel(nn.Module):
1021
+ def __init__(self, config: DepthProConfig):
1022
+ super().__init__()
1023
+ self.config = config
1024
+ self.fusion_hidden_size = config.fusion_hidden_size
1025
+
1026
+ self.fov_encoder = DepthProFovEncoder(config)
1027
+ self.conv = nn.Conv2d(
1028
+ self.fusion_hidden_size, self.fusion_hidden_size // 2, kernel_size=3, stride=2, padding=1
1029
+ )
1030
+ self.activation = nn.ReLU(inplace=True)
1031
+ self.head = DepthProFovHead(config)
1032
+
1033
+ def forward(
1034
+ self,
1035
+ pixel_values: torch.Tensor,
1036
+ global_features: torch.Tensor,
1037
+ head_mask: Optional[torch.Tensor] = None,
1038
+ ) -> torch.Tensor:
1039
+ fov_features = self.fov_encoder(pixel_values, head_mask)
1040
+
1041
+ global_features = self.conv(global_features)
1042
+ global_features = self.activation(global_features)
1043
+
1044
+ fov_features = fov_features + global_features
1045
+ fov_output = self.head(fov_features)
1046
+ fov_output = fov_output.flatten()
1047
+
1048
+ return fov_output
1049
+
1050
+
1051
+ class DepthProDepthEstimationHead(nn.Module):
1052
+ """
1053
+ The DepthProDepthEstimationHead module serves as the output head for depth estimation tasks.
1054
+ This module comprises a sequence of convolutional and transposed convolutional layers
1055
+ that process the feature map from the fusion to produce a single-channel depth map.
1056
+ Key operations include dimensionality reduction and upsampling to match the input resolution.
1057
+ """
1058
+
1059
+ def __init__(self, config):
1060
+ super().__init__()
1061
+ self.config = config
1062
+
1063
+ features = config.fusion_hidden_size
1064
+ self.layers = nn.ModuleList(
1065
+ [
1066
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
1067
+ nn.ConvTranspose2d(
1068
+ in_channels=features // 2,
1069
+ out_channels=features // 2,
1070
+ kernel_size=2,
1071
+ stride=2,
1072
+ padding=0,
1073
+ bias=True,
1074
+ ),
1075
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
1076
+ nn.ReLU(True),
1077
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
1078
+ nn.ReLU(),
1079
+ ]
1080
+ )
1081
+
1082
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1083
+ for layer in self.layers:
1084
+ hidden_states = layer(hidden_states)
1085
+
1086
+ predicted_depth = hidden_states.squeeze(dim=1)
1087
+ return predicted_depth
1088
+
1089
+
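A minimal shape walkthrough of the head above, assuming an illustrative `fusion_hidden_size` of 256 and a 96x96 fused feature map (both values are assumptions, not taken from a released config):

```python
import torch
import torch.nn as nn

features = 256  # illustrative fusion_hidden_size
head = nn.Sequential(
    nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
    nn.ConvTranspose2d(features // 2, features // 2, kernel_size=2, stride=2),  # 2x upsample
    nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(True),
    nn.Conv2d(32, 1, kernel_size=1),
    nn.ReLU(),
)

fused = torch.randn(1, features, 96, 96)
depth = head(fused).squeeze(dim=1)  # drop the single channel dimension
print(depth.shape)  # torch.Size([1, 192, 192])
```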
1090
+ @add_start_docstrings(
1091
+ """
1092
+ DepthPro Model with a depth estimation head on top (consisting of 3 convolutional layers).
1093
+ """,
1094
+ DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING,
1095
+ )
1096
+ class DepthProForDepthEstimation(DepthProPreTrainedModel):
1097
+ def __init__(self, config, use_fov_model=None):
1098
+ super().__init__(config)
1099
+ self.config = config
1100
+ self.use_fov_model = use_fov_model if use_fov_model is not None else self.config.use_fov_model
1101
+
1102
+ # dinov2 (vit) like encoders
1103
+ self.depth_pro = DepthProModel(config)
1104
+
1105
+ # dpt (vit) like fusion stage
1106
+ self.fusion_stage = DepthProFeatureFusionStage(config)
1107
+
1108
+ # depth estimation head
1109
+ self.head = DepthProDepthEstimationHead(config)
1110
+
1111
+ # dinov2 (vit) like encoder
1112
+ self.fov_model = DepthProFovModel(config) if self.use_fov_model else None
1113
+
1114
+ # Initialize weights and apply final processing
1115
+ self.post_init()
1116
+
1117
+ @add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
1118
+ @replace_return_docstrings(output_type=DepthProDepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
1119
+ def forward(
1120
+ self,
1121
+ pixel_values: torch.FloatTensor,
1122
+ head_mask: Optional[torch.FloatTensor] = None,
1123
+ labels: Optional[torch.LongTensor] = None,
1124
+ output_attentions: Optional[bool] = None,
1125
+ output_hidden_states: Optional[bool] = None,
1126
+ return_dict: Optional[bool] = None,
1127
+ ) -> Union[Tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
1128
+ r"""
1129
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
1130
+ Ground truth depth estimation maps for computing the loss.
1131
+
1132
+ Returns:
1133
+
1134
+ Examples:
1135
+
1136
+ ```python
1137
+ >>> from transformers import AutoImageProcessor, DepthProForDepthEstimation
1138
+ >>> import torch
1139
+ >>> from PIL import Image
1140
+ >>> import requests
1141
+
1142
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1143
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1144
+
1145
+ >>> checkpoint = "apple/DepthPro-hf"
1146
+ >>> processor = AutoImageProcessor.from_pretrained(checkpoint)
1147
+ >>> model = DepthProForDepthEstimation.from_pretrained(checkpoint)
1148
+
1149
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1150
+ >>> model.to(device)
1151
+
1152
+ >>> # prepare image for the model
1153
+ >>> inputs = processor(images=image, return_tensors="pt").to(device)
1154
+
1155
+ >>> with torch.no_grad():
1156
+ ... outputs = model(**inputs)
1157
+
1158
+ >>> # interpolate to original size
1159
+ >>> post_processed_output = processor.post_process_depth_estimation(
1160
+ ... outputs, target_sizes=[(image.height, image.width)],
1161
+ ... )
1162
+
1163
+ >>> # get the field of view (fov) predictions
1164
+ >>> field_of_view = post_processed_output[0]["field_of_view"]
1165
+ >>> focal_length = post_processed_output[0]["focal_length"]
1166
+
1167
+ >>> # visualize the prediction
1168
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
1169
+ >>> depth = predicted_depth * 255 / predicted_depth.max()
1170
+ >>> depth = depth.detach().cpu().numpy()
1171
+ >>> depth = Image.fromarray(depth.astype("uint8"))
1172
+ ```"""
1173
+ loss = None
1174
+ if labels is not None:
1175
+ raise NotImplementedError("Training is not implemented yet")
1176
+
1177
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1178
+ output_hidden_states = (
1179
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1180
+ )
1181
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1182
+
1183
+ depth_pro_outputs = self.depth_pro(
1184
+ pixel_values=pixel_values,
1185
+ head_mask=head_mask,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ return_dict=True,
1189
+ )
1190
+ features = depth_pro_outputs.features
1191
+ fused_hidden_states = self.fusion_stage(features)
1192
+ predicted_depth = self.head(fused_hidden_states[-1])
1193
+
1194
+ if self.use_fov_model:
1195
+ # frozen features from encoder are used
1196
+ features_for_fov = features[0].detach()
1197
+ fov = self.fov_model(
1198
+ pixel_values=pixel_values,
1199
+ global_features=features_for_fov,
1200
+ head_mask=head_mask,
1201
+ )
1202
+ else:
1203
+ fov = None
1204
+
1205
+ if not return_dict:
1206
+ outputs = [loss, predicted_depth, fov, depth_pro_outputs.hidden_states, depth_pro_outputs.attentions]
1207
+ return tuple(v for v in outputs if v is not None)
1208
+
1209
+ return DepthProDepthEstimatorOutput(
1210
+ loss=loss,
1211
+ predicted_depth=predicted_depth,
1212
+ field_of_view=fov,
1213
+ hidden_states=depth_pro_outputs.hidden_states,
1214
+ attentions=depth_pro_outputs.attentions,
1215
+ )
1216
+
1217
+
1218
+ __all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]
docs/transformers/build/lib/transformers/models/detr/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ...utils import _LazyModule
18
+ from ...utils.import_utils import define_import_structure
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from .configuration_detr import *
23
+ from .feature_extraction_detr import *
24
+ from .image_processing_detr import *
25
+ from .image_processing_detr_fast import *
26
+ from .modeling_detr import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/detr/configuration_detr.py ADDED
@@ -0,0 +1,289 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Facebook AI Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DETR model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from packaging import version
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...onnx import OnnxConfig
24
+ from ...utils import logging
25
+ from ...utils.backbone_utils import verify_backbone_config_arguments
26
+ from ..auto import CONFIG_MAPPING
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class DetrConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`DetrModel`]. It is used to instantiate a DETR
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the DETR
37
+ [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ use_timm_backbone (`bool`, *optional*, defaults to `True`):
44
+ Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
45
+ API.
46
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
47
+ The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
48
+ case it will default to `ResNetConfig()`.
49
+ num_channels (`int`, *optional*, defaults to 3):
50
+ The number of input channels.
51
+ num_queries (`int`, *optional*, defaults to 100):
52
+ Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
53
+ detect in a single image. For COCO, we recommend 100 queries.
54
+ d_model (`int`, *optional*, defaults to 256):
55
+ Dimension of the layers, i.e. the hidden size used by the Transformer encoder and decoder layers and by the projection layers.
56
+ encoder_layers (`int`, *optional*, defaults to 6):
57
+ Number of encoder layers.
58
+ decoder_layers (`int`, *optional*, defaults to 6):
59
+ Number of decoder layers.
60
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
61
+ Number of attention heads for each attention layer in the Transformer encoder.
62
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
63
+ Number of attention heads for each attention layer in the Transformer decoder.
64
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
65
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
66
+ encoder_ffn_dim (`int`, *optional*, defaults to 2048):
67
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
68
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
69
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
70
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
71
+ dropout (`float`, *optional*, defaults to 0.1):
72
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
73
+ attention_dropout (`float`, *optional*, defaults to 0.0):
74
+ The dropout ratio for the attention probabilities.
75
+ activation_dropout (`float`, *optional*, defaults to 0.0):
76
+ The dropout ratio for activations inside the fully connected layer.
77
+ init_std (`float`, *optional*, defaults to 0.02):
78
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
79
+ init_xavier_std (`float`, *optional*, defaults to 1):
80
+ The scaling factor used for the Xavier initialization gain in the HM Attention map module.
81
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
82
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
83
+ for more details.
84
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
85
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
86
+ for more details.
87
+ auxiliary_loss (`bool`, *optional*, defaults to `False`):
88
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
89
+ position_embedding_type (`str`, *optional*, defaults to `"sine"`):
90
+ Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
91
+ backbone (`str`, *optional*, defaults to `"resnet50"`):
92
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
93
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
94
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
95
+ use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
96
+ Whether to use pretrained weights for the backbone.
97
+ backbone_kwargs (`dict`, *optional*):
98
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
99
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
100
+ dilation (`bool`, *optional*, defaults to `False`):
101
+ Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
102
+ `use_timm_backbone` = `True`.
103
+ class_cost (`float`, *optional*, defaults to 1):
104
+ Relative weight of the classification error in the Hungarian matching cost.
105
+ bbox_cost (`float`, *optional*, defaults to 5):
106
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
107
+ giou_cost (`float`, *optional*, defaults to 2):
108
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
109
+ mask_loss_coefficient (`float`, *optional*, defaults to 1):
110
+ Relative weight of the Focal loss in the panoptic segmentation loss.
111
+ dice_loss_coefficient (`float`, *optional*, defaults to 1):
112
+ Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
113
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
114
+ Relative weight of the L1 bounding box loss in the object detection loss.
115
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
116
+ Relative weight of the generalized IoU loss in the object detection loss.
117
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
118
+ Relative classification weight of the 'no-object' class in the object detection loss.
119
+
120
+ Examples:
121
+
122
+ ```python
123
+ >>> from transformers import DetrConfig, DetrModel
124
+
125
+ >>> # Initializing a DETR facebook/detr-resnet-50 style configuration
126
+ >>> configuration = DetrConfig()
127
+
128
+ >>> # Initializing a model (with random weights) from the facebook/detr-resnet-50 style configuration
129
+ >>> model = DetrModel(configuration)
130
+
131
+ >>> # Accessing the model configuration
132
+ >>> configuration = model.config
133
+ ```"""
134
+
135
+ model_type = "detr"
136
+ keys_to_ignore_at_inference = ["past_key_values"]
137
+ attribute_map = {
138
+ "hidden_size": "d_model",
139
+ "num_attention_heads": "encoder_attention_heads",
140
+ }
141
+
142
+ def __init__(
143
+ self,
144
+ use_timm_backbone=True,
145
+ backbone_config=None,
146
+ num_channels=3,
147
+ num_queries=100,
148
+ encoder_layers=6,
149
+ encoder_ffn_dim=2048,
150
+ encoder_attention_heads=8,
151
+ decoder_layers=6,
152
+ decoder_ffn_dim=2048,
153
+ decoder_attention_heads=8,
154
+ encoder_layerdrop=0.0,
155
+ decoder_layerdrop=0.0,
156
+ is_encoder_decoder=True,
157
+ activation_function="relu",
158
+ d_model=256,
159
+ dropout=0.1,
160
+ attention_dropout=0.0,
161
+ activation_dropout=0.0,
162
+ init_std=0.02,
163
+ init_xavier_std=1.0,
164
+ auxiliary_loss=False,
165
+ position_embedding_type="sine",
166
+ backbone="resnet50",
167
+ use_pretrained_backbone=True,
168
+ backbone_kwargs=None,
169
+ dilation=False,
170
+ class_cost=1,
171
+ bbox_cost=5,
172
+ giou_cost=2,
173
+ mask_loss_coefficient=1,
174
+ dice_loss_coefficient=1,
175
+ bbox_loss_coefficient=5,
176
+ giou_loss_coefficient=2,
177
+ eos_coefficient=0.1,
178
+ **kwargs,
179
+ ):
180
+ # We default to values which were previously hard-coded in the model. This enables configurability of the config
181
+ # while keeping the default behavior the same.
182
+ if use_timm_backbone and backbone_kwargs is None:
183
+ backbone_kwargs = {}
184
+ if dilation:
185
+ backbone_kwargs["output_stride"] = 16
186
+ backbone_kwargs["out_indices"] = [1, 2, 3, 4]
187
+ backbone_kwargs["in_chans"] = num_channels
188
+ # Backwards compatibility
189
+ elif not use_timm_backbone and backbone in (None, "resnet50"):
190
+ if backbone_config is None:
191
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
192
+ backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
193
+ elif isinstance(backbone_config, dict):
194
+ backbone_model_type = backbone_config.get("model_type")
195
+ config_class = CONFIG_MAPPING[backbone_model_type]
196
+ backbone_config = config_class.from_dict(backbone_config)
197
+ backbone = None
198
+ # set timm attributes to None
199
+ dilation = None
200
+
201
+ verify_backbone_config_arguments(
202
+ use_timm_backbone=use_timm_backbone,
203
+ use_pretrained_backbone=use_pretrained_backbone,
204
+ backbone=backbone,
205
+ backbone_config=backbone_config,
206
+ backbone_kwargs=backbone_kwargs,
207
+ )
208
+
209
+ self.use_timm_backbone = use_timm_backbone
210
+ self.backbone_config = backbone_config
211
+ self.num_channels = num_channels
212
+ self.num_queries = num_queries
213
+ self.d_model = d_model
214
+ self.encoder_ffn_dim = encoder_ffn_dim
215
+ self.encoder_layers = encoder_layers
216
+ self.encoder_attention_heads = encoder_attention_heads
217
+ self.decoder_ffn_dim = decoder_ffn_dim
218
+ self.decoder_layers = decoder_layers
219
+ self.decoder_attention_heads = decoder_attention_heads
220
+ self.dropout = dropout
221
+ self.attention_dropout = attention_dropout
222
+ self.activation_dropout = activation_dropout
223
+ self.activation_function = activation_function
224
+ self.init_std = init_std
225
+ self.init_xavier_std = init_xavier_std
226
+ self.encoder_layerdrop = encoder_layerdrop
227
+ self.decoder_layerdrop = decoder_layerdrop
228
+ self.num_hidden_layers = encoder_layers
229
+ self.auxiliary_loss = auxiliary_loss
230
+ self.position_embedding_type = position_embedding_type
231
+ self.backbone = backbone
232
+ self.use_pretrained_backbone = use_pretrained_backbone
233
+ self.backbone_kwargs = backbone_kwargs
234
+ self.dilation = dilation
235
+ # Hungarian matcher
236
+ self.class_cost = class_cost
237
+ self.bbox_cost = bbox_cost
238
+ self.giou_cost = giou_cost
239
+ # Loss coefficients
240
+ self.mask_loss_coefficient = mask_loss_coefficient
241
+ self.dice_loss_coefficient = dice_loss_coefficient
242
+ self.bbox_loss_coefficient = bbox_loss_coefficient
243
+ self.giou_loss_coefficient = giou_loss_coefficient
244
+ self.eos_coefficient = eos_coefficient
245
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
246
+
247
+ @property
248
+ def num_attention_heads(self) -> int:
249
+ return self.encoder_attention_heads
250
+
251
+ @property
252
+ def hidden_size(self) -> int:
253
+ return self.d_model
254
+
255
+ @classmethod
256
+ def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
257
+ """Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
258
+
259
+ Args:
260
+ backbone_config ([`PretrainedConfig`]):
261
+ The backbone configuration.
262
+ Returns:
263
+ [`DetrConfig`]: An instance of a configuration object
264
+ """
265
+ return cls(backbone_config=backbone_config, **kwargs)
266
+
267
+
268
+ class DetrOnnxConfig(OnnxConfig):
269
+ torch_onnx_minimum_version = version.parse("1.11")
270
+
271
+ @property
272
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
273
+ return OrderedDict(
274
+ [
275
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
276
+ ("pixel_mask", {0: "batch"}),
277
+ ]
278
+ )
279
+
280
+ @property
281
+ def atol_for_validation(self) -> float:
282
+ return 1e-5
283
+
284
+ @property
285
+ def default_onnx_opset(self) -> int:
286
+ return 12
287
+
288
+
289
+ __all__ = ["DetrConfig", "DetrOnnxConfig"]
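The ONNX config above only declares export metadata (dynamic axes, validation tolerance, opset). A minimal sketch of inspecting it, assuming the standard `OnnxConfig(config)` constructor:

```python
from transformers import DetrConfig
from transformers.models.detr.configuration_detr import DetrOnnxConfig

config = DetrConfig()
onnx_config = DetrOnnxConfig(config)

print(onnx_config.inputs)               # dynamic axes for pixel_values / pixel_mask
print(onnx_config.default_onnx_opset)   # 12
print(onnx_config.atol_for_validation)  # 1e-05
```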
docs/transformers/build/lib/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,277 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert DETR checkpoints with timm backbone."""
16
+
17
+ import argparse
18
+ import json
19
+ from collections import OrderedDict
20
+ from pathlib import Path
21
+
22
+ import requests
23
+ import torch
24
+ from huggingface_hub import hf_hub_download
25
+ from PIL import Image
26
+
27
+ from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor
28
+ from transformers.utils import logging
29
+
30
+
31
+ logging.set_verbosity_info()
32
+ logger = logging.get_logger(__name__)
33
+
34
+ # here we list all keys to be renamed (original name on the left, our name on the right)
35
+ rename_keys = []
36
+ for i in range(6):
37
+ # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
38
+ rename_keys.append(
39
+ (f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
40
+ )
41
+ rename_keys.append(
42
+ (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
43
+ )
44
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
45
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
46
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
47
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
48
+ rename_keys.append(
49
+ (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
50
+ )
51
+ rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
52
+ rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
53
+ rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
54
+ # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
55
+ rename_keys.append(
56
+ (f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
57
+ )
58
+ rename_keys.append(
59
+ (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
60
+ )
61
+ rename_keys.append(
62
+ (
63
+ f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight",
64
+ f"decoder.layers.{i}.encoder_attn.out_proj.weight",
65
+ )
66
+ )
67
+ rename_keys.append(
68
+ (
69
+ f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias",
70
+ f"decoder.layers.{i}.encoder_attn.out_proj.bias",
71
+ )
72
+ )
73
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
74
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
75
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
76
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
77
+ rename_keys.append(
78
+ (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
79
+ )
80
+ rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
81
+ rename_keys.append(
82
+ (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
83
+ )
84
+ rename_keys.append(
85
+ (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
86
+ )
87
+ rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
88
+ rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
89
+
90
+ # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
91
+ rename_keys.extend(
92
+ [
93
+ ("input_proj.weight", "input_projection.weight"),
94
+ ("input_proj.bias", "input_projection.bias"),
95
+ ("query_embed.weight", "query_position_embeddings.weight"),
96
+ ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
97
+ ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
98
+ ("class_embed.weight", "class_labels_classifier.weight"),
99
+ ("class_embed.bias", "class_labels_classifier.bias"),
100
+ ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
101
+ ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
102
+ ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
103
+ ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
104
+ ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
105
+ ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
106
+ ]
107
+ )
108
+
109
+
110
+ def rename_key(state_dict, old, new):
111
+ val = state_dict.pop(old)
112
+ state_dict[new] = val
113
+
114
+
115
+ def rename_backbone_keys(state_dict):
116
+ new_state_dict = OrderedDict()
117
+ for key, value in state_dict.items():
118
+ if "backbone.0.body" in key:
119
+ new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
120
+ new_state_dict[new_key] = value
121
+ else:
122
+ new_state_dict[key] = value
123
+
124
+ return new_state_dict
125
+
126
+
127
+ def read_in_q_k_v(state_dict, is_panoptic=False):
128
+ prefix = ""
129
+ if is_panoptic:
130
+ prefix = "detr."
131
+
132
+ # first: transformer encoder
133
+ for i in range(6):
134
+ # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
135
+ in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
136
+ in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
137
+ # next, add query, keys and values (in that order) to the state dict
138
+ state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
139
+ state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
140
+ state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
141
+ state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
142
+ state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
143
+ state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
144
+ # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
145
+ for i in range(6):
146
+ # read in weights + bias of input projection layer of self-attention
147
+ in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight")
148
+ in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias")
149
+ # next, add query, keys and values (in that order) to the state dict
150
+ state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
151
+ state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
152
+ state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
153
+ state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
154
+ state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
155
+ state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
156
+ # read in weights + bias of input projection layer of cross-attention
157
+ in_proj_weight_cross_attn = state_dict.pop(
158
+ f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight"
159
+ )
160
+ in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias")
161
+ # next, add query, keys and values (in that order) of cross-attention to the state dict
162
+ state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
163
+ state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
164
+ state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
165
+ state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
166
+ state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
167
+ state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]
168
+
169
+
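The slicing above relies on PyTorch's `nn.MultiheadAttention` storing the query, key and value projections stacked in a single `(3 * d_model, d_model)` matrix; a minimal sketch with DETR's `d_model = 256`:

```python
import torch

d_model = 256  # DETR's hidden size, so in_proj_weight has shape (3 * 256, 256)
in_proj_weight = torch.randn(3 * d_model, d_model)
in_proj_bias = torch.randn(3 * d_model)

# query, key and value projections are stacked in that order
q_w, k_w, v_w = in_proj_weight[:256, :], in_proj_weight[256:512, :], in_proj_weight[-256:, :]
q_b, k_b, v_b = in_proj_bias[:256], in_proj_bias[256:512], in_proj_bias[-256:]

assert q_w.shape == k_w.shape == v_w.shape == (d_model, d_model)
```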
170
+ # We will verify our results on an image of cute cats
171
+ def prepare_img():
172
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
173
+ im = Image.open(requests.get(url, stream=True).raw)
174
+
175
+ return im
176
+
177
+
178
+ @torch.no_grad()
179
+ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
180
+ """
181
+ Copy/paste/tweak model's weights to our DETR structure.
182
+ """
183
+
184
+ # load default config
185
+ config = DetrConfig()
186
+ # set backbone and dilation attributes
187
+ if "resnet101" in model_name:
188
+ config.backbone = "resnet101"
189
+ if "dc5" in model_name:
190
+ config.dilation = True
191
+ is_panoptic = "panoptic" in model_name
192
+ if is_panoptic:
193
+ config.num_labels = 250
194
+ else:
195
+ config.num_labels = 91
196
+ repo_id = "huggingface/label-files"
197
+ filename = "coco-detection-id2label.json"
198
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
199
+ id2label = {int(k): v for k, v in id2label.items()}
200
+ config.id2label = id2label
201
+ config.label2id = {v: k for k, v in id2label.items()}
202
+
203
+ # load image processor
204
+ format = "coco_panoptic" if is_panoptic else "coco_detection"
205
+ image_processor = DetrImageProcessor(format=format)
206
+
207
+ # prepare image
208
+ img = prepare_img()
209
+ encoding = image_processor(images=img, return_tensors="pt")
210
+ pixel_values = encoding["pixel_values"]
211
+
212
+ logger.info(f"Converting model {model_name}...")
213
+
214
+ # load original model from torch hub
215
+ detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
216
+ state_dict = detr.state_dict()
217
+ # rename keys
218
+ for src, dest in rename_keys:
219
+ if is_panoptic:
220
+ src = "detr." + src
221
+ rename_key(state_dict, src, dest)
222
+ state_dict = rename_backbone_keys(state_dict)
223
+ # query, key and value matrices need special treatment
224
+ read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
225
+ # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
226
+ prefix = "detr.model." if is_panoptic else "model."
227
+ for key in state_dict.copy().keys():
228
+ if is_panoptic:
229
+ if (
230
+ key.startswith("detr")
231
+ and not key.startswith("class_labels_classifier")
232
+ and not key.startswith("bbox_predictor")
233
+ ):
234
+ val = state_dict.pop(key)
235
+ state_dict["detr.model" + key[4:]] = val
236
+ elif "class_labels_classifier" in key or "bbox_predictor" in key:
237
+ val = state_dict.pop(key)
238
+ state_dict["detr." + key] = val
239
+ elif key.startswith("bbox_attention") or key.startswith("mask_head"):
240
+ continue
241
+ else:
242
+ val = state_dict.pop(key)
243
+ state_dict[prefix + key] = val
244
+ else:
245
+ if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
246
+ val = state_dict.pop(key)
247
+ state_dict[prefix + key] = val
248
+ # finally, create HuggingFace model and load state dict
249
+ model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
250
+ model.load_state_dict(state_dict)
251
+ model.eval()
252
+ # verify our conversion
253
+ original_outputs = detr(pixel_values)
254
+ outputs = model(pixel_values)
255
+ assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
256
+ assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
257
+ if is_panoptic:
258
+ assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
259
+
260
+ # Save model and image processor
261
+ logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
262
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
263
+ model.save_pretrained(pytorch_dump_folder_path)
264
+ image_processor.save_pretrained(pytorch_dump_folder_path)
265
+
266
+
267
+ if __name__ == "__main__":
268
+ parser = argparse.ArgumentParser()
269
+
270
+ parser.add_argument(
271
+ "--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert."
272
+ )
273
+ parser.add_argument(
274
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
275
+ )
276
+ args = parser.parse_args()
277
+ convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
docs/transformers/build/lib/transformers/models/detr/convert_detr_to_pytorch.py ADDED
@@ -0,0 +1,385 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert DETR checkpoints with native (Transformers) backbone."""
16
+
17
+ import argparse
18
+ import json
19
+ from pathlib import Path
20
+
21
+ import requests
22
+ import torch
23
+ from huggingface_hub import hf_hub_download
24
+ from PIL import Image
25
+
26
+ from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor, ResNetConfig
27
+ from transformers.utils import logging
28
+
29
+
30
+ logging.set_verbosity_info()
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ def get_detr_config(model_name):
35
+ # initialize config
36
+ if "resnet-50" in model_name:
37
+ backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
38
+ elif "resnet-101" in model_name:
39
+ backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
40
+ else:
41
+ raise ValueError("Model name should include either resnet50 or resnet101")
42
+
43
+ config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)
44
+
45
+ # set label attributes
46
+ is_panoptic = "panoptic" in model_name
47
+ if is_panoptic:
48
+ config.num_labels = 250
49
+ else:
50
+ config.num_labels = 91
51
+ repo_id = "huggingface/label-files"
52
+ filename = "coco-detection-id2label.json"
53
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
54
+ id2label = {int(k): v for k, v in id2label.items()}
55
+ config.id2label = id2label
56
+ config.label2id = {v: k for k, v in id2label.items()}
57
+
58
+ return config, is_panoptic
59
+
60
+
61
+ def create_rename_keys(config):
62
+ # here we list all keys to be renamed (original name on the left, our name on the right)
63
+ rename_keys = []
64
+
65
+ # stem
66
+ # fmt: off
67
+ rename_keys.append(("backbone.0.body.conv1.weight", "backbone.conv_encoder.model.embedder.embedder.convolution.weight"))
68
+ rename_keys.append(("backbone.0.body.bn1.weight", "backbone.conv_encoder.model.embedder.embedder.normalization.weight"))
69
+ rename_keys.append(("backbone.0.body.bn1.bias", "backbone.conv_encoder.model.embedder.embedder.normalization.bias"))
70
+ rename_keys.append(("backbone.0.body.bn1.running_mean", "backbone.conv_encoder.model.embedder.embedder.normalization.running_mean"))
71
+ rename_keys.append(("backbone.0.body.bn1.running_var", "backbone.conv_encoder.model.embedder.embedder.normalization.running_var"))
72
+ # stages
73
+ for stage_idx in range(len(config.backbone_config.depths)):
74
+ for layer_idx in range(config.backbone_config.depths[stage_idx]):
75
+ # shortcut
76
+ if layer_idx == 0:
77
+ rename_keys.append(
78
+ (
79
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.0.weight",
80
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
81
+ )
82
+ )
83
+ rename_keys.append(
84
+ (
85
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.weight",
86
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
87
+ )
88
+ )
89
+ rename_keys.append(
90
+ (
91
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.bias",
92
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
93
+ )
94
+ )
95
+ rename_keys.append(
96
+ (
97
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_mean",
98
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
99
+ )
100
+ )
101
+ rename_keys.append(
102
+ (
103
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_var",
104
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
105
+ )
106
+ )
107
+ # 3 convs
108
+ for i in range(3):
109
+ rename_keys.append(
110
+ (
111
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv{i+1}.weight",
112
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
113
+ )
114
+ )
115
+ rename_keys.append(
116
+ (
117
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.weight",
118
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
119
+ )
120
+ )
121
+ rename_keys.append(
122
+ (
123
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.bias",
124
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
125
+ )
126
+ )
127
+ rename_keys.append(
128
+ (
129
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_mean",
130
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
131
+ )
132
+ )
133
+ rename_keys.append(
134
+ (
135
+ f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_var",
136
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
137
+ )
138
+ )
139
+ # fmt: on
140
+
141
+ for i in range(config.encoder_layers):
142
+ # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
143
+ rename_keys.append(
144
+ (
145
+ f"transformer.encoder.layers.{i}.self_attn.out_proj.weight",
146
+ f"encoder.layers.{i}.self_attn.out_proj.weight",
147
+ )
148
+ )
149
+ rename_keys.append(
150
+ (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
151
+ )
152
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
153
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
154
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
155
+ rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
156
+ rename_keys.append(
157
+ (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
158
+ )
159
+ rename_keys.append(
160
+ (f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias")
161
+ )
162
+ rename_keys.append(
163
+ (f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight")
164
+ )
165
+ rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
166
+ # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
167
+ rename_keys.append(
168
+ (
169
+ f"transformer.decoder.layers.{i}.self_attn.out_proj.weight",
170
+ f"decoder.layers.{i}.self_attn.out_proj.weight",
171
+ )
172
+ )
173
+ rename_keys.append(
174
+ (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
175
+ )
176
+ rename_keys.append(
177
+ (
178
+ f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight",
179
+ f"decoder.layers.{i}.encoder_attn.out_proj.weight",
180
+ )
181
+ )
182
+ rename_keys.append(
183
+ (
184
+ f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias",
185
+ f"decoder.layers.{i}.encoder_attn.out_proj.bias",
186
+ )
187
+ )
188
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
189
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
190
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
191
+ rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
192
+ rename_keys.append(
193
+ (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
194
+ )
195
+ rename_keys.append(
196
+ (f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias")
197
+ )
198
+ rename_keys.append(
199
+ (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
200
+ )
201
+ rename_keys.append(
202
+ (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
203
+ )
204
+ rename_keys.append(
205
+ (f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight")
206
+ )
207
+ rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
208
+
209
+ # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
210
+ rename_keys.extend(
211
+ [
212
+ ("input_proj.weight", "input_projection.weight"),
213
+ ("input_proj.bias", "input_projection.bias"),
214
+ ("query_embed.weight", "query_position_embeddings.weight"),
215
+ ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
216
+ ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
217
+ ("class_embed.weight", "class_labels_classifier.weight"),
218
+ ("class_embed.bias", "class_labels_classifier.bias"),
219
+ ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
220
+ ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
221
+ ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
222
+ ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
223
+ ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
224
+ ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
225
+ ]
226
+ )
227
+
228
+ return rename_keys
229
+
230
+
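For orientation, a minimal sketch (toy dictionary, placeholder values instead of real tensors) of how the (src, dest) pairs built above are consumed by the rename_key helper defined next:
# Toy illustration of the renaming step: each (old, new) pair produced by
# create_rename_keys() is applied by popping the old key and re-inserting its
# value under the new key. Values are placeholders, not real tensors.
toy_state_dict = {"backbone.0.body.bn1.weight": 0, "input_proj.weight": 1}
toy_renames = [
    ("backbone.0.body.bn1.weight", "backbone.conv_encoder.model.embedder.embedder.normalization.weight"),
    ("input_proj.weight", "input_projection.weight"),
]
for old, new in toy_renames:
    toy_state_dict[new] = toy_state_dict.pop(old)
assert "input_projection.weight" in toy_state_dict and "input_proj.weight" not in toy_state_dict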
231
+ def rename_key(state_dict, old, new):
232
+ val = state_dict.pop(old)
233
+ state_dict[new] = val
234
+
235
+
236
+ def read_in_q_k_v(state_dict, is_panoptic=False):
237
+ prefix = ""
238
+ if is_panoptic:
239
+ prefix = "detr."
240
+
241
+ # first: transformer encoder
242
+ for i in range(6):
243
+ # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
244
+ in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
245
+ in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
246
+ # next, add query, keys and values (in that order) to the state dict
247
+ state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
248
+ state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
249
+ state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
250
+ state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
251
+ state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
252
+ state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
253
+ # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
254
+ for i in range(6):
255
+ # read in weights + bias of input projection layer of self-attention
256
+ in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight")
257
+ in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias")
258
+ # next, add query, keys and values (in that order) to the state dict
259
+ state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
260
+ state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
261
+ state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
262
+ state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
263
+ state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
264
+ state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
265
+ # read in weights + bias of input projection layer of cross-attention
266
+ in_proj_weight_cross_attn = state_dict.pop(
267
+ f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight"
268
+ )
269
+ in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias")
270
+ # next, add query, keys and values (in that order) of cross-attention to the state dict
271
+ state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
272
+ state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
273
+ state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
274
+ state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
275
+ state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
276
+ state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]
277
+
278
+
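As a sanity-check sketch of the splitting performed by read_in_q_k_v above: PyTorch's fused MultiheadAttention projection of shape (3 * hidden_size, hidden_size) is sliced into query, key and value blocks of 256 rows each (hidden_size is 256 for these checkpoints):
import torch

# Random stand-in for one layer's fused in_proj weight/bias (hidden_size = 256).
hidden_size = 256
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)
in_proj_bias = torch.randn(3 * hidden_size)

# Same slicing as read_in_q_k_v: rows [0:256] -> q, [256:512] -> k, [512:768] -> v.
q_w, k_w, v_w = in_proj_weight[:256, :], in_proj_weight[256:512, :], in_proj_weight[-256:, :]
q_b, k_b, v_b = in_proj_bias[:256], in_proj_bias[256:512], in_proj_bias[-256:]
assert torch.equal(torch.cat([q_w, k_w, v_w], dim=0), in_proj_weight)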
279
+ # We will verify our results on an image of cute cats
280
+ def prepare_img():
281
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
282
+ im = Image.open(requests.get(url, stream=True).raw)
283
+
284
+ return im
285
+
286
+
287
+ @torch.no_grad()
288
+ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
289
+ """
290
+ Copy/paste/tweak model's weights to our DETR structure.
291
+ """
292
+
293
+ # load default config
294
+ config, is_panoptic = get_detr_config(model_name)
295
+
296
+ # load original model from torch hub
297
+ model_name_to_original_name = {
298
+ "detr-resnet-50": "detr_resnet50",
299
+ "detr-resnet-101": "detr_resnet101",
300
+ }
301
+ logger.info(f"Converting model {model_name}...")
302
+ detr = torch.hub.load("facebookresearch/detr", model_name_to_original_name[model_name], pretrained=True).eval()
303
+ state_dict = detr.state_dict()
304
+ # rename keys
305
+ for src, dest in create_rename_keys(config):
306
+ if is_panoptic:
307
+ src = "detr." + src
308
+ rename_key(state_dict, src, dest)
309
+ # query, key and value matrices need special treatment
310
+ read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
311
+ # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
312
+ prefix = "detr.model." if is_panoptic else "model."
313
+ for key in state_dict.copy().keys():
314
+ if is_panoptic:
315
+ if (
316
+ key.startswith("detr")
317
+ and not key.startswith("class_labels_classifier")
318
+ and not key.startswith("bbox_predictor")
319
+ ):
320
+ val = state_dict.pop(key)
321
+ state_dict["detr.model" + key[4:]] = val
322
+ elif "class_labels_classifier" in key or "bbox_predictor" in key:
323
+ val = state_dict.pop(key)
324
+ state_dict["detr." + key] = val
325
+ elif key.startswith("bbox_attention") or key.startswith("mask_head"):
326
+ continue
327
+ else:
328
+ val = state_dict.pop(key)
329
+ state_dict[prefix + key] = val
330
+ else:
331
+ if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
332
+ val = state_dict.pop(key)
333
+ state_dict[prefix + key] = val
334
+
335
+ # finally, create HuggingFace model and load state dict
336
+ model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
337
+ model.load_state_dict(state_dict)
338
+ model.eval()
339
+
340
+ # verify our conversion on an image
341
+ format = "coco_panoptic" if is_panoptic else "coco_detection"
342
+ processor = DetrImageProcessor(format=format)
343
+
344
+ encoding = processor(images=prepare_img(), return_tensors="pt")
345
+ pixel_values = encoding["pixel_values"]
346
+
347
+ original_outputs = detr(pixel_values)
348
+ outputs = model(pixel_values)
349
+
350
+ assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
351
+ assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
352
+ if is_panoptic:
353
+ assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
354
+ print("Looks ok!")
355
+
356
+ if pytorch_dump_folder_path is not None:
357
+ # Save model and image processor
358
+ logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
359
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
360
+ model.save_pretrained(pytorch_dump_folder_path)
361
+ processor.save_pretrained(pytorch_dump_folder_path)
362
+
363
+ if push_to_hub:
364
+ # Upload model and image processor to the hub
365
+ logger.info("Uploading PyTorch model and image processor to the hub...")
366
+ model.push_to_hub(f"nielsr/{model_name}")
367
+ processor.push_to_hub(f"nielsr/{model_name}")
368
+
369
+
370
+ if __name__ == "__main__":
371
+ parser = argparse.ArgumentParser()
372
+
373
+ parser.add_argument(
374
+ "--model_name",
375
+ default="detr-resnet-50",
376
+ type=str,
377
+ choices=["detr-resnet-50", "detr-resnet-101"],
378
+ help="Name of the DETR model you'd like to convert.",
379
+ )
380
+ parser.add_argument(
381
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
382
+ )
383
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
384
+ args = parser.parse_args()
385
+ convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
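A hedged follow-up sketch: after running the script with --pytorch_dump_folder_path, the converted checkpoint can be reloaded from that folder (the folder name below is only illustrative):
# Illustrative only: "./detr-resnet-50-converted" stands for whatever folder
# was passed as --pytorch_dump_folder_path above.
from transformers import DetrForObjectDetection, DetrImageProcessor

model = DetrForObjectDetection.from_pretrained("./detr-resnet-50-converted")
processor = DetrImageProcessor.from_pretrained("./detr-resnet-50-converted")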
docs/transformers/build/lib/transformers/models/detr/feature_extraction_detr.py ADDED
@@ -0,0 +1,48 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for DETR."""
16
+
17
+ import warnings
18
+
19
+ from ...image_transforms import rgb_to_id as _rgb_to_id
20
+ from ...utils import logging
21
+ from ...utils.import_utils import requires
22
+ from .image_processing_detr import DetrImageProcessor
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def rgb_to_id(x):
29
+ warnings.warn(
30
+ "rgb_to_id has moved and will not be importable from this module from v5. "
31
+ "Please import from transformers.image_transforms instead.",
32
+ FutureWarning,
33
+ )
34
+ return _rgb_to_id(x)
35
+
36
+
37
+ @requires(backends=("vision",))
38
+ class DetrFeatureExtractor(DetrImageProcessor):
39
+ def __init__(self, *args, **kwargs) -> None:
40
+ warnings.warn(
41
+ "The class DetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
42
+ " Please use DetrImageProcessor instead.",
43
+ FutureWarning,
44
+ )
45
+ super().__init__(*args, **kwargs)
46
+
47
+
48
+ __all__ = ["DetrFeatureExtractor"]
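A short usage sketch of the shim above: the deprecated class still works, but emits a FutureWarning and is functionally just DetrImageProcessor:
import warnings

from transformers import DetrFeatureExtractor, DetrImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    feature_extractor = DetrFeatureExtractor()

assert isinstance(feature_extractor, DetrImageProcessor)
assert any(issubclass(w.category, FutureWarning) for w in caught)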
docs/transformers/build/lib/transformers/models/detr/image_processing_detr_fast.py ADDED
@@ -0,0 +1,1312 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for DETR."""
16
+
17
+ import io
18
+ import pathlib
19
+ from collections import defaultdict
20
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
21
+
22
+ from ...image_processing_utils import BatchFeature, get_size_dict
23
+ from ...image_processing_utils_fast import (
24
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
25
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
26
+ BaseImageProcessorFast,
27
+ DefaultFastImageProcessorKwargs,
28
+ SizeDict,
29
+ get_image_size_for_max_height_width,
30
+ get_max_height_width,
31
+ safe_squeeze,
32
+ )
33
+ from ...image_transforms import (
34
+ center_to_corners_format,
35
+ corners_to_center_format,
36
+ id_to_rgb,
37
+ )
38
+ from ...image_utils import (
39
+ IMAGENET_DEFAULT_MEAN,
40
+ IMAGENET_DEFAULT_STD,
41
+ AnnotationFormat,
42
+ AnnotationType,
43
+ ChannelDimension,
44
+ ImageInput,
45
+ PILImageResampling,
46
+ get_image_size,
47
+ validate_annotations,
48
+ )
49
+ from ...processing_utils import Unpack
50
+ from ...utils import (
51
+ TensorType,
52
+ add_start_docstrings,
53
+ is_torch_available,
54
+ is_torchvision_available,
55
+ is_torchvision_v2_available,
56
+ is_vision_available,
57
+ logging,
58
+ )
59
+ from ...utils.import_utils import requires
60
+ from .image_processing_detr import (
61
+ compute_segments,
62
+ convert_segmentation_to_rle,
63
+ get_size_with_aspect_ratio,
64
+ remove_low_and_no_objects,
65
+ )
66
+
67
+
68
+ if is_torch_available():
69
+ import torch
70
+ from torch import nn
71
+
72
+ if is_vision_available():
73
+ import PIL
74
+
75
+
76
+ if is_torchvision_v2_available():
77
+ from torchvision.io import read_image
78
+ from torchvision.transforms.v2 import functional as F
79
+ elif is_torchvision_available():
80
+ from torchvision.io import read_image
81
+ from torchvision.transforms import functional as F
82
+
83
+
84
+ logger = logging.get_logger(__name__)
85
+
86
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
87
+
88
+
89
+ # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
90
+ def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor:
91
+ """
92
+ Convert a COCO polygon annotation to a mask.
93
+
94
+ Args:
95
+ segmentations (`List[List[float]]`):
96
+ List of polygons, each polygon represented by a list of x-y coordinates.
97
+ height (`int`):
98
+ Height of the mask.
99
+ width (`int`):
100
+ Width of the mask.
101
+ """
102
+ try:
103
+ from pycocotools import mask as coco_mask
104
+ except ImportError:
105
+ raise ImportError("Pycocotools is not installed in your environment.")
106
+
107
+ masks = []
108
+ for polygons in segmentations:
109
+ rles = coco_mask.frPyObjects(polygons, height, width)
110
+ mask = coco_mask.decode(rles)
111
+ if len(mask.shape) < 3:
112
+ mask = mask[..., None]
113
+ mask = torch.as_tensor(mask, dtype=torch.uint8, device=device)
114
+ mask = torch.any(mask, axis=2)
115
+ masks.append(mask)
116
+ if masks:
117
+ masks = torch.stack(masks, axis=0)
118
+ else:
119
+ masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device)
120
+
121
+ return masks
122
+
123
+
124
+ # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
125
+ def prepare_coco_detection_annotation(
126
+ image,
127
+ target,
128
+ return_segmentation_masks: bool = False,
129
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
130
+ ):
131
+ """
132
+ Convert the target in COCO format into the format expected by DETR.
133
+ """
134
+ image_height, image_width = image.size()[-2:]
135
+
136
+ image_id = target["image_id"]
137
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
138
+
139
+ # Get all COCO annotations for the given image.
140
+ annotations = target["annotations"]
141
+ classes = []
142
+ area = []
143
+ boxes = []
144
+ keypoints = []
145
+ for obj in annotations:
146
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
147
+ classes.append(obj["category_id"])
148
+ area.append(obj["area"])
149
+ boxes.append(obj["bbox"])
150
+ if "keypoints" in obj:
151
+ keypoints.append(obj["keypoints"])
152
+
153
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
154
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
155
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
156
+ # guard against no boxes via resizing
157
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
158
+ boxes[:, 2:] += boxes[:, :2]
159
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
160
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
161
+
162
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
163
+
164
+ new_target = {
165
+ "image_id": image_id,
166
+ "class_labels": classes[keep],
167
+ "boxes": boxes[keep],
168
+ "area": area[keep],
169
+ "iscrowd": iscrowd[keep],
170
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
171
+ }
172
+
173
+ if keypoints:
174
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
175
+ # Apply the keep mask here to filter the relevant annotations
176
+ keypoints = keypoints[keep]
177
+ num_keypoints = keypoints.shape[0]
178
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
179
+ new_target["keypoints"] = keypoints
180
+
181
+ if return_segmentation_masks:
182
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
183
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device)
184
+ new_target["masks"] = masks[keep]
185
+
186
+ return new_target
187
+
188
+
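A self-contained sketch of the box handling in prepare_coco_detection_annotation above: COCO (x, y, w, h) boxes are converted to (x0, y0, x1, y1), clipped to the image, and boxes that collapse to zero width or height are dropped:
import torch

image_height, image_width = 4, 6
# Two COCO-style boxes in (x, y, w, h); the second starts at the image border
# and collapses once it is clipped to the image.
boxes = torch.tensor([[1.0, 1.0, 2.0, 2.0],
                      [6.0, 4.0, 3.0, 2.0]])
boxes[:, 2:] += boxes[:, :2]                                   # (x, y, w, h) -> (x0, y0, x1, y1)
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)   # clip x coordinates
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)  # clip y coordinates
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
print(boxes[keep])  # only the first box survives: tensor([[1., 1., 3., 3.]])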
189
+ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Compute the bounding boxes around the provided panoptic segmentation masks.
192
+
193
+ Args:
194
+ masks: masks in format `[number_masks, height, width]`, where `number_masks` is the number of masks
195
+
196
+ Returns:
197
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
198
+ """
199
+ if masks.numel() == 0:
200
+ return torch.zeros((0, 4), device=masks.device)
201
+
202
+ h, w = masks.shape[-2:]
203
+ y = torch.arange(0, h, dtype=torch.float32, device=masks.device)
204
+ x = torch.arange(0, w, dtype=torch.float32, device=masks.device)
205
+ # see https://github.com/pytorch/pytorch/issues/50276
206
+ y, x = torch.meshgrid(y, x, indexing="ij")
207
+
208
+ x_mask = masks * torch.unsqueeze(x, 0)
209
+ x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0]
210
+ x_min = (
211
+ torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
212
+ )
213
+
214
+ y_mask = masks * torch.unsqueeze(y, 0)
215
+ y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0]
216
+ y_min = (
217
+ torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
218
+ )
219
+
220
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
221
+
222
+
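A toy check of masks_to_boxes above (assuming the function is in scope, e.g. when run from this module): a single 2x2 blob produces the expected xyxy box:
import torch

mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 1:3, 2:4] = True                # a 2x2 blob covering rows 1-2, columns 2-3
print(masks_to_boxes(mask))             # tensor([[2., 1., 3., 2.]]) in (x_min, y_min, x_max, y_max)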
223
+ # 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
224
+ # Copyright (c) 2018, Alexander Kirillov
225
+ # All rights reserved.
226
+ def rgb_to_id(color):
227
+ """
228
+ Converts RGB color to unique ID.
229
+ """
230
+ if isinstance(color, torch.Tensor) and len(color.shape) == 3:
231
+ if color.dtype == torch.uint8:
232
+ color = color.to(torch.int32)
233
+ return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
234
+ return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
235
+
236
+
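For intuition about rgb_to_id above: the panoptic ID is the RGB triplet read as a base-256 number with the red channel as the least significant byte. A minimal check:
import torch

color = torch.tensor([[[1, 2, 3]]], dtype=torch.uint8)      # one pixel, shape (1, 1, 3)
ids = color.to(torch.int32)
ids = ids[:, :, 0] + 256 * ids[:, :, 1] + 256 * 256 * ids[:, :, 2]
assert ids.item() == 1 + 256 * 2 + 256 * 256 * 3 == 197121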
237
+ def prepare_coco_panoptic_annotation(
238
+ image: torch.Tensor,
239
+ target: Dict,
240
+ masks_path: Union[str, pathlib.Path],
241
+ return_masks: bool = True,
242
+ input_data_format: Union[ChannelDimension, str] = None,
243
+ ) -> Dict:
244
+ """
245
+ Prepare a coco panoptic annotation for DETR.
246
+ """
247
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
248
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
249
+
250
+ new_target = {}
251
+ new_target["image_id"] = torch.as_tensor(
252
+ [target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device
253
+ )
254
+ new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
255
+ new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
256
+
257
+ if "segments_info" in target:
258
+ masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
259
+ masks = rgb_to_id(masks)
260
+
261
+ ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
262
+ masks = masks == ids[:, None, None]
263
+ masks = masks.to(torch.bool)
264
+ if return_masks:
265
+ new_target["masks"] = masks
266
+ new_target["boxes"] = masks_to_boxes(masks)
267
+ new_target["class_labels"] = torch.as_tensor(
268
+ [segment_info["category_id"] for segment_info in target["segments_info"]],
269
+ dtype=torch.int64,
270
+ device=image.device,
271
+ )
272
+ new_target["iscrowd"] = torch.as_tensor(
273
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]],
274
+ dtype=torch.int64,
275
+ device=image.device,
276
+ )
277
+ new_target["area"] = torch.as_tensor(
278
+ [segment_info["area"] for segment_info in target["segments_info"]],
279
+ dtype=torch.float32,
280
+ device=image.device,
281
+ )
282
+
283
+ return new_target
284
+
285
+
286
+ class DetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
287
+ format: Optional[Union[str, AnnotationFormat]]
288
+ do_convert_annotations: Optional[bool]
289
+ do_pad: Optional[bool]
290
+ pad_size: Optional[Dict[str, int]]
291
+ return_segmentation_masks: Optional[bool]
292
+
293
+
294
+ @add_start_docstrings(
295
+ "Constructs a fast Detr image processor.",
296
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
297
+ """
298
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
299
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
300
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
301
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
302
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
303
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
304
+ do_pad (`bool`, *optional*, defaults to `True`):
305
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
306
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
307
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
308
+ Otherwise, the image will be padded to the maximum height and width of the batch.
309
+ pad_size (`Dict[str, int]`, *optional*):
310
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
311
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
312
+ height and width in the batch.
313
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
314
+ Whether to return segmentation masks.
315
+ """,
316
+ )
317
+ @requires(backends=("torchvision", "torch"))
318
+ class DetrImageProcessorFast(BaseImageProcessorFast):
319
+ resample = PILImageResampling.BILINEAR
320
+ image_mean = IMAGENET_DEFAULT_MEAN
321
+ image_std = IMAGENET_DEFAULT_STD
322
+ format = AnnotationFormat.COCO_DETECTION
323
+ do_resize = True
324
+ do_rescale = True
325
+ do_normalize = True
326
+ do_pad = True
327
+ size = {"shortest_edge": 800, "longest_edge": 1333}
328
+ default_to_square = False
329
+ model_input_names = ["pixel_values", "pixel_mask"]
330
+ valid_kwargs = DetrFastImageProcessorKwargs
331
+
332
+ def __init__(self, **kwargs: Unpack[DetrFastImageProcessorKwargs]) -> None:
333
+ if "pad_and_return_pixel_mask" in kwargs:
334
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
335
+
336
+ size = kwargs.pop("size", None)
337
+ if "max_size" in kwargs:
338
+ logger.warning_once(
339
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
340
+ "Please specify in `size['longest_edge'] instead`.",
341
+ )
342
+ max_size = kwargs.pop("max_size")
343
+ else:
344
+ max_size = None if size is None else 1333
345
+
346
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
347
+ self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
348
+
349
+ # Backwards compatibility
350
+ do_convert_annotations = kwargs.get("do_convert_annotations", None)
351
+ do_normalize = kwargs.get("do_normalize", None)
352
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
353
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
354
+
355
+ super().__init__(**kwargs)
356
+
357
+ @classmethod
358
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
359
+ """
360
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
361
+ created using from_dict and kwargs e.g. `DetrImageProcessorFast.from_pretrained(checkpoint, size=600,
362
+ max_size=800)`
363
+ """
364
+ image_processor_dict = image_processor_dict.copy()
365
+ if "max_size" in kwargs:
366
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
367
+ if "pad_and_return_pixel_mask" in kwargs:
368
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
369
+ return super().from_dict(image_processor_dict, **kwargs)
370
+
371
+ def prepare_annotation(
372
+ self,
373
+ image: torch.Tensor,
374
+ target: Dict,
375
+ format: Optional[AnnotationFormat] = None,
376
+ return_segmentation_masks: Optional[bool] = None,
377
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
378
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
379
+ ) -> Dict:
380
+ """
381
+ Prepare an annotation for feeding into DETR model.
382
+ """
383
+ format = format if format is not None else self.format
384
+
385
+ if format == AnnotationFormat.COCO_DETECTION:
386
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
387
+ target = prepare_coco_detection_annotation(
388
+ image, target, return_segmentation_masks, input_data_format=input_data_format
389
+ )
390
+ elif format == AnnotationFormat.COCO_PANOPTIC:
391
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
392
+ target = prepare_coco_panoptic_annotation(
393
+ image,
394
+ target,
395
+ masks_path=masks_path,
396
+ return_masks=return_segmentation_masks,
397
+ input_data_format=input_data_format,
398
+ )
399
+ else:
400
+ raise ValueError(f"Format {format} is not supported.")
401
+ return target
402
+
403
+ def resize(
404
+ self,
405
+ image: torch.Tensor,
406
+ size: SizeDict,
407
+ interpolation: "F.InterpolationMode" = None,
408
+ **kwargs,
409
+ ) -> torch.Tensor:
410
+ """
411
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
412
+ int, smaller edge of the image will be matched to this number.
413
+
414
+ Args:
415
+ image (`torch.Tensor`):
416
+ Image to resize.
417
+ size (`SizeDict`):
418
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
419
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
420
+ Do NOT keep the aspect ratio.
421
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
422
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
423
+ less or equal to `longest_edge`.
424
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
425
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
426
+ `max_width`.
427
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
428
+ Resampling filter to use if resizing the image.
429
+ """
430
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
431
+ if size.shortest_edge and size.longest_edge:
432
+ # Resize the image so that the shortest edge or the longest edge is of the given size
433
+ # while maintaining the aspect ratio of the original image.
434
+ new_size = get_size_with_aspect_ratio(
435
+ image.size()[-2:],
436
+ size["shortest_edge"],
437
+ size["longest_edge"],
438
+ )
439
+ elif size.max_height and size.max_width:
440
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
441
+ elif size.height and size.width:
442
+ new_size = (size["height"], size["width"])
443
+ else:
444
+ raise ValueError(
445
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
446
+ f" {size.keys()}."
447
+ )
448
+
449
+ image = F.resize(
450
+ image,
451
+ size=new_size,
452
+ interpolation=interpolation,
453
+ **kwargs,
454
+ )
455
+ return image
456
+
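A rough sketch of the shortest/longest-edge policy handled by get_size_with_aspect_ratio in the branch above (the helper's exact rounding may differ by a pixel):
# Scale so the shortest edge reaches `shortest_edge`, unless that would push the
# longest edge past `longest_edge`, in which case scale to fit the longest edge instead.
height, width = 480, 640
shortest_edge, longest_edge = 800, 1333

scale = shortest_edge / min(height, width)
if scale * max(height, width) > longest_edge:
    scale = longest_edge / max(height, width)
new_height, new_width = round(height * scale), round(width * scale)
print(new_height, new_width)   # roughly (800, 1067) for this example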
457
+ def resize_annotation(
458
+ self,
459
+ annotation: Dict[str, Any],
460
+ orig_size: Tuple[int, int],
461
+ target_size: Tuple[int, int],
462
+ threshold: float = 0.5,
463
+ interpolation: "F.InterpolationMode" = None,
464
+ ):
465
+ """
466
+ Resizes an annotation to a target size.
467
+
468
+ Args:
469
+ annotation (`Dict[str, Any]`):
470
+ The annotation dictionary.
471
+ orig_size (`Tuple[int, int]`):
472
+ The original size of the input image.
473
+ target_size (`Tuple[int, int]`):
474
+ The target size of the image, as returned by the preprocessing `resize` step.
475
+ threshold (`float`, *optional*, defaults to 0.5):
476
+ The threshold used to binarize the segmentation masks.
477
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.NEAREST`):
478
+ The resampling filter to use when resizing the masks.
479
+ """
480
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST
481
+ ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
482
+
483
+ new_annotation = {}
484
+ new_annotation["size"] = target_size
485
+
486
+ for key, value in annotation.items():
487
+ if key == "boxes":
488
+ boxes = value
489
+ scaled_boxes = boxes * torch.as_tensor(
490
+ [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
491
+ )
492
+ new_annotation["boxes"] = scaled_boxes
493
+ elif key == "area":
494
+ area = value
495
+ scaled_area = area * (ratio_width * ratio_height)
496
+ new_annotation["area"] = scaled_area
497
+ elif key == "masks":
498
+ masks = value[:, None]
499
+ masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
500
+ masks = torch.stack(masks).to(torch.float32)
501
+ masks = masks[:, 0] > threshold
502
+ new_annotation["masks"] = masks
503
+ elif key == "size":
504
+ new_annotation["size"] = target_size
505
+ else:
506
+ new_annotation[key] = value
507
+
508
+ return new_annotation
509
+
510
+ def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
511
+ image_height, image_width = image_size
512
+ norm_annotation = {}
513
+ for key, value in annotation.items():
514
+ if key == "boxes":
515
+ boxes = value
516
+ boxes = corners_to_center_format(boxes)
517
+ boxes /= torch.as_tensor(
518
+ [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
519
+ )
520
+ norm_annotation[key] = boxes
521
+ else:
522
+ norm_annotation[key] = value
523
+ return norm_annotation
524
+
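A small numeric sketch of the box transform in normalize_annotation above: corner boxes become (center_x, center_y, width, height) and are divided by (image_width, image_height):
import torch

image_height, image_width = 100, 200
boxes_xyxy = torch.tensor([[20.0, 10.0, 60.0, 50.0]])

# Corner -> center format (mirroring what corners_to_center_format is used for here).
x0, y0, x1, y1 = boxes_xyxy.unbind(-1)
boxes_cxcywh = torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0], dim=-1)

# Normalize to [0, 1] by the image width/height.
boxes_cxcywh /= torch.tensor([image_width, image_height, image_width, image_height])
print(boxes_cxcywh)   # tensor([[0.2000, 0.3000, 0.2000, 0.4000]])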
525
+ def _update_annotation_for_padded_image(
526
+ self,
527
+ annotation: Dict,
528
+ input_image_size: Tuple[int, int],
529
+ output_image_size: Tuple[int, int],
530
+ padding,
531
+ update_bboxes,
532
+ ) -> Dict:
533
+ """
534
+ Update the annotation for a padded image.
535
+ """
536
+ new_annotation = {}
537
+ new_annotation["size"] = output_image_size
538
+ ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))
539
+
540
+ for key, value in annotation.items():
541
+ if key == "masks":
542
+ masks = value
543
+ masks = F.pad(
544
+ masks,
545
+ padding,
546
+ fill=0,
547
+ )
548
+ masks = safe_squeeze(masks, 1)
549
+ new_annotation["masks"] = masks
550
+ elif key == "boxes" and update_bboxes:
551
+ boxes = value
552
+ boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
553
+ new_annotation["boxes"] = boxes
554
+ elif key == "size":
555
+ new_annotation["size"] = output_image_size
556
+ else:
557
+ new_annotation[key] = value
558
+ return new_annotation
559
+
560
+ def pad(
561
+ self,
562
+ image: torch.Tensor,
563
+ padded_size: Tuple[int, int],
564
+ annotation: Optional[Dict[str, Any]] = None,
565
+ update_bboxes: bool = True,
566
+ fill: int = 0,
567
+ ):
568
+ original_size = image.size()[-2:]
569
+ padding_bottom = padded_size[0] - original_size[0]
570
+ padding_right = padded_size[1] - original_size[1]
571
+ if padding_bottom < 0 or padding_right < 0:
572
+ raise ValueError(
573
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
574
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
575
+ )
576
+ if original_size != padded_size:
577
+ padding = [0, 0, padding_right, padding_bottom]
578
+ image = F.pad(image, padding, fill=fill)
579
+ if annotation is not None:
580
+ annotation = self._update_annotation_for_padded_image(
581
+ annotation, original_size, padded_size, padding, update_bboxes
582
+ )
583
+
584
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
585
+ pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
586
+ pixel_mask[: original_size[0], : original_size[1]] = 1
587
+
588
+ return image, pixel_mask, annotation
589
+
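A toy illustration of the padding contract implemented by pad above, written with torch.nn.functional.pad (whose argument order differs from the torchvision F.pad used in the method): padding goes on the bottom/right and pixel_mask marks valid pixels with 1:
import torch
import torch.nn.functional as F

image = torch.ones(3, 2, 3)                        # (channels, height, width)
padded_size = (4, 5)                               # target (height, width)
padding_bottom = padded_size[0] - image.shape[-2]
padding_right = padded_size[1] - image.shape[-1]

# torch.nn.functional.pad takes (left, right, top, bottom) for the last two dims.
padded = F.pad(image, (0, padding_right, 0, padding_bottom), value=0)

pixel_mask = torch.zeros(padded_size, dtype=torch.int64)
pixel_mask[: image.shape[-2], : image.shape[-1]] = 1
print(padded.shape, int(pixel_mask.sum()))         # torch.Size([3, 4, 5]) 6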
590
+ @add_start_docstrings(
591
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
592
+ """
593
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
594
+ List of annotations associated with the image or batch of images. If annotation is for object
595
+ detection, the annotations should be a dictionary with the following keys:
596
+ - "image_id" (`int`): The image id.
597
+ - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
598
+ dictionary. An image can have no annotations, in which case the list should be empty.
599
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
600
+ - "image_id" (`int`): The image id.
601
+ - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
602
+ An image can have no segments, in which case the list should be empty.
603
+ - "file_name" (`str`): The file name of the image.
604
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
605
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
606
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
607
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
608
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
609
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
610
+ do_pad (`bool`, *optional*, defaults to `True`):
611
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
612
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
613
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
614
+ Otherwise, the image will be padded to the maximum height and width of the batch.
615
+ pad_size (`Dict[str, int]`, *optional*):
616
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
617
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
618
+ height and width in the batch.
619
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
620
+ Whether to return segmentation masks.
621
+ masks_path (`str` or `pathlib.Path`, *optional*):
622
+ Path to the directory containing the segmentation masks.
623
+ """,
624
+ )
625
+ def preprocess(
626
+ self,
627
+ images: ImageInput,
628
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
629
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
630
+ **kwargs: Unpack[DetrFastImageProcessorKwargs],
631
+ ) -> BatchFeature:
632
+ if "pad_and_return_pixel_mask" in kwargs:
633
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
634
+ logger.warning_once(
635
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
636
+ "use `do_pad` instead."
637
+ )
638
+
639
+ if "max_size" in kwargs:
640
+ logger.warning_once(
641
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
642
+ " `size['longest_edge']` instead."
643
+ )
644
+ kwargs["size"] = kwargs.pop("max_size")
645
+
646
+ return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
647
+
648
+ def _preprocess(
649
+ self,
650
+ images: List["torch.Tensor"],
651
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
652
+ return_segmentation_masks: bool,
653
+ masks_path: Optional[Union[str, pathlib.Path]],
654
+ do_resize: bool,
655
+ size: SizeDict,
656
+ interpolation: Optional["F.InterpolationMode"],
657
+ do_center_crop: bool,
658
+ crop_size: SizeDict,
659
+ do_rescale: bool,
660
+ rescale_factor: float,
661
+ do_normalize: bool,
662
+ do_convert_annotations: bool,
663
+ image_mean: Optional[Union[float, List[float]]],
664
+ image_std: Optional[Union[float, List[float]]],
665
+ do_pad: bool,
666
+ pad_size: Optional[Dict[str, int]],
667
+ format: Optional[Union[str, AnnotationFormat]],
668
+ return_tensors: Optional[Union[str, TensorType]],
669
+ ) -> BatchFeature:
670
+ """
671
+ Preprocess an image or a batch of images so that it can be used by the model.
672
+ """
673
+ if annotations is not None and isinstance(annotations, dict):
674
+ annotations = [annotations]
675
+
676
+ if annotations is not None and len(images) != len(annotations):
677
+ raise ValueError(
678
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
679
+ )
680
+
681
+ format = AnnotationFormat(format)
682
+ if annotations is not None:
683
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
684
+
685
+ if (
686
+ masks_path is not None
687
+ and format == AnnotationFormat.COCO_PANOPTIC
688
+ and not isinstance(masks_path, (pathlib.Path, str))
689
+ ):
690
+ raise ValueError(
691
+ "The path to the directory containing the mask PNG files should be provided as a"
692
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
693
+ )
694
+
695
+ data = {}
696
+
697
+ processed_images = []
698
+ processed_annotations = []
699
+ pixel_masks = [] # Initialize pixel_masks here
700
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
701
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
702
+ if annotations is not None:
703
+ annotation = self.prepare_annotation(
704
+ image,
705
+ annotation,
706
+ format,
707
+ return_segmentation_masks=return_segmentation_masks,
708
+ masks_path=masks_path,
709
+ input_data_format=ChannelDimension.FIRST,
710
+ )
711
+
712
+ if do_resize:
713
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
714
+ if annotations is not None:
715
+ annotation = self.resize_annotation(
716
+ annotation,
717
+ orig_size=image.size()[-2:],
718
+ target_size=resized_image.size()[-2:],
719
+ )
720
+ image = resized_image
721
+ # Fused rescale and normalize
722
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
723
+ if do_convert_annotations and annotations is not None:
724
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
725
+
726
+ processed_images.append(image)
727
+ processed_annotations.append(annotation)
728
+ images = processed_images
729
+ annotations = processed_annotations if annotations is not None else None
730
+
731
+ if do_pad:
732
+ # depends on all resized image shapes so we need another loop
733
+ if pad_size is not None:
734
+ padded_size = (pad_size["height"], pad_size["width"])
735
+ else:
736
+ padded_size = get_max_height_width(images)
737
+
738
+ padded_images = []
739
+ padded_annotations = []
740
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
741
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
742
+ if padded_size == image.size()[-2:]:
743
+ padded_images.append(image)
744
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
745
+ padded_annotations.append(annotation)
746
+ continue
747
+ image, pixel_mask, annotation = self.pad(
748
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
749
+ )
750
+ padded_images.append(image)
751
+ padded_annotations.append(annotation)
752
+ pixel_masks.append(pixel_mask)
753
+ images = padded_images
754
+ annotations = padded_annotations if annotations is not None else None
755
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
756
+
757
+ data.update({"pixel_values": torch.stack(images, dim=0)})
758
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
759
+ if annotations is not None:
760
+ encoded_inputs["labels"] = [
761
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
762
+ ]
763
+ return encoded_inputs
764
+
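A minimal end-to-end usage sketch of the processor defined above, assuming a random image tensor and a single COCO-detection style annotation (all values are illustrative):
import torch
from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast()
image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
annotation = {
    "image_id": 0,
    "annotations": [
        {"category_id": 1, "bbox": [10.0, 10.0, 20.0, 20.0], "area": 400.0, "iscrowd": 0}
    ],
}

encoding = processor(images=image, annotations=annotation, return_tensors="pt")
print(encoding["pixel_values"].shape)   # (1, 3, H, W) after resize + padding
print(encoding["labels"][0]["boxes"])   # normalized (center_x, center_y, width, height)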
765
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process
766
+ def post_process(self, outputs, target_sizes):
767
+ """
768
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
769
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
770
+
771
+ Args:
772
+ outputs ([`DetrObjectDetectionOutput`]):
773
+ Raw outputs of the model.
774
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
775
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
776
+ original image size (before any data augmentation). For visualization, this should be the image size
777
+ after data augment, but before padding.
778
+ Returns:
779
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
780
+ in the batch as predicted by the model.
781
+ """
782
+ logger.warning_once(
783
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
784
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
785
+ )
786
+
787
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
788
+
789
+ if len(out_logits) != len(target_sizes):
790
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
791
+ if target_sizes.shape[1] != 2:
792
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
793
+
794
+ prob = nn.functional.softmax(out_logits, -1)
795
+ scores, labels = prob[..., :-1].max(-1)
796
+
797
+ # convert to [x0, y0, x1, y1] format
798
+ boxes = center_to_corners_format(out_bbox)
799
+ # and from relative [0, 1] to absolute [0, height] coordinates
800
+ img_h, img_w = target_sizes.unbind(1)
801
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
802
+ boxes = boxes * scale_fct[:, None, :]
803
+
804
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
805
+ return results
806
+
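In isolation, the box rescaling at the heart of post_process above: relative (center_x, center_y, width, height) boxes are turned into absolute (x0, y0, x1, y1) corners and scaled by each image's (width, height):
import torch

target_sizes = torch.tensor([[480, 640]])             # one image, (height, width)
out_bbox = torch.tensor([[[0.5, 0.5, 0.2, 0.4]]])      # one query, relative (cx, cy, w, h)

# Center -> corner format (mirroring what center_to_corners_format is used for here).
cx, cy, w, h = out_bbox.unbind(-1)
boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)

img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.dtype)
boxes = boxes * scale_fct[:, None, :]
print(boxes)   # tensor([[[256., 144., 384., 336.]]])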
807
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_segmentation
808
+ def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
809
+ """
810
+ Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
811
+
812
+ Args:
813
+ outputs ([`DetrSegmentationOutput`]):
814
+ Raw outputs of the model.
815
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
816
+ Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
817
+ threshold (`float`, *optional*, defaults to 0.9):
818
+ Threshold to use to filter out queries.
819
+ mask_threshold (`float`, *optional*, defaults to 0.5):
820
+ Threshold to use when turning the predicted masks into binary values.
821
+ Returns:
822
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
823
+ in the batch as predicted by the model.
824
+ """
825
+ logger.warning_once(
826
+ "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
827
+ " `post_process_semantic_segmentation`.",
828
+ )
829
+ out_logits, raw_masks = outputs.logits, outputs.pred_masks
830
+ empty_label = out_logits.shape[-1] - 1
831
+ preds = []
832
+
833
+ def to_tuple(tup):
834
+ if isinstance(tup, tuple):
835
+ return tup
836
+ return tuple(tup.tolist())
837
+
838
+ for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
839
+ # we filter out empty queries and detections below the threshold
840
+ cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
841
+ keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
842
+ cur_scores = cur_scores[keep]
843
+ cur_labels = cur_labels[keep]
844
+ cur_masks = cur_masks[keep]
845
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
846
+ cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
847
+
848
+ predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
849
+ preds.append(predictions)
850
+ return preds
851
+
852
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance
853
+ def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
854
+ """
855
+ Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
856
+ PyTorch.
857
+
858
+ Args:
859
+ results (`List[Dict]`):
860
+ Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
861
+ outputs ([`DetrSegmentationOutput`]):
862
+ Raw outputs of the model.
863
+ orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
864
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
865
+ image size (before any data augmentation).
866
+ max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
867
+ Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
868
+ original image size (before any data augmentation).
869
+ threshold (`float`, *optional*, defaults to 0.5):
870
+ Threshold to use when turning the predicted masks into binary values.
871
+ Returns:
872
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
873
+ image in the batch as predicted by the model.
874
+ """
875
+ logger.warning_once(
876
+ "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
877
+ " `post_process_instance_segmentation`.",
878
+ )
879
+
880
+ if len(orig_target_sizes) != len(max_target_sizes):
881
+ raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
882
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
883
+ outputs_masks = outputs.pred_masks.squeeze(2)
884
+ outputs_masks = nn.functional.interpolate(
885
+ outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
886
+ )
887
+ outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
888
+
889
+ for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
890
+ img_h, img_w = t[0], t[1]
891
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
892
+ results[i]["masks"] = nn.functional.interpolate(
893
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
894
+ ).byte()
895
+
896
+ return results
897
+
898
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic
899
+ def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
900
+ """
901
+ Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
902
+
903
+ Args:
904
+ outputs ([`DetrSegmentationOutput`]):
905
+ Raw outputs of the model.
906
+ processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
907
+ Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
908
+ augmentation but before batching.
909
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
910
+ Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
911
+ If left to None, it will default to the `processed_sizes`.
912
+ is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
913
+ Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
914
+ If not set, defaults to the `is_thing_map` of COCO panoptic.
915
+ threshold (`float`, *optional*, defaults to 0.85):
916
+ Threshold to use to filter out queries.
917
+ Returns:
918
+ `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
919
+ an image in the batch as predicted by the model.
920
+ """
921
+ logger.warning_once(
922
+ "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
923
+ " `post_process_panoptic_segmentation`.",
924
+ )
925
+ if target_sizes is None:
926
+ target_sizes = processed_sizes
927
+ if len(processed_sizes) != len(target_sizes):
928
+ raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
929
+
930
+ if is_thing_map is None:
931
+ # default to is_thing_map of COCO panoptic
932
+ is_thing_map = {i: i <= 90 for i in range(201)}
933
+
934
+ out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
935
+ if not len(out_logits) == len(raw_masks) == len(target_sizes):
936
+ raise ValueError(
937
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
938
+ )
939
+ empty_label = out_logits.shape[-1] - 1
940
+ preds = []
941
+
942
+ def to_tuple(tup):
943
+ if isinstance(tup, tuple):
944
+ return tup
945
+ return tuple(tup.tolist())
946
+
947
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
948
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
949
+ ):
950
+ # we filter empty queries and detection below threshold
951
+ cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
952
+ keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
953
+ cur_scores = cur_scores[keep]
954
+ cur_labels = cur_labels[keep]
955
+ cur_masks = cur_masks[keep]
956
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
957
+ cur_boxes = center_to_corners_format(cur_boxes[keep])
958
+
959
+ h, w = cur_masks.shape[-2:]
960
+ if len(cur_boxes) != len(cur_labels):
961
+ raise ValueError("Not as many boxes as there are classes")
962
+
963
+ # It may be that we have several predicted masks for the same stuff class.
964
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
965
+ cur_masks = cur_masks.flatten(1)
966
+ stuff_equiv_classes = defaultdict(lambda: [])
967
+ for k, label in enumerate(cur_labels):
968
+ if not is_thing_map[label.item()]:
969
+ stuff_equiv_classes[label.item()].append(k)
970
+
971
+ def get_ids_area(masks, scores, dedup=False):
972
+ # This helper function creates the final panoptic segmentation image
973
+ # It also returns the area of the masks that appears on the image
974
+
975
+ m_id = masks.transpose(0, 1).softmax(-1)
976
+
977
+ if m_id.shape[-1] == 0:
978
+ # We didn't detect any mask :(
979
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
980
+ else:
981
+ m_id = m_id.argmax(-1).view(h, w)
982
+
983
+ if dedup:
984
+ # Merge the masks corresponding to the same stuff class
985
+ for equiv in stuff_equiv_classes.values():
986
+ if len(equiv) > 1:
987
+ for eq_id in equiv:
988
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
989
+
990
+ final_h, final_w = to_tuple(target_size)
991
+
992
+ seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
993
+ seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)
994
+
995
+ np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
996
+ np_seg_img = np_seg_img.view(final_h, final_w, 3)
997
+ np_seg_img = np_seg_img.numpy()
998
+
999
+ m_id = torch.from_numpy(rgb_to_id(np_seg_img))
1000
+
1001
+ area = []
1002
+ for i in range(len(scores)):
1003
+ area.append(m_id.eq(i).sum().item())
1004
+ return area, seg_img
1005
+
1006
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
1007
+ if cur_labels.numel() > 0:
1008
+                # We now filter empty masks as long as we find some
1009
+ while True:
1010
+ filtered_small = torch.as_tensor(
1011
+ [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
1012
+ )
1013
+ if filtered_small.any().item():
1014
+ cur_scores = cur_scores[~filtered_small]
1015
+ cur_labels = cur_labels[~filtered_small]
1016
+ cur_masks = cur_masks[~filtered_small]
1017
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
1018
+ else:
1019
+ break
1020
+
1021
+ else:
1022
+ cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
1023
+
1024
+ segments_info = []
1025
+ for i, a in enumerate(area):
1026
+ cat = cur_labels[i].item()
1027
+ segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
1028
+ del cur_labels
1029
+
1030
+ with io.BytesIO() as out:
1031
+ seg_img.save(out, format="PNG")
1032
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
1033
+ preds.append(predictions)
1034
+ return preds
1035
+
1036
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_object_detection
1037
+ def post_process_object_detection(
1038
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
1039
+ ):
1040
+ """
1041
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
1042
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1043
+
1044
+ Args:
1045
+ outputs ([`DetrObjectDetectionOutput`]):
1046
+ Raw outputs of the model.
1047
+ threshold (`float`, *optional*):
1048
+ Score threshold to keep object detection predictions.
1049
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
1050
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
1051
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
1052
+ Returns:
1053
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1054
+ in the batch as predicted by the model.
1055
+ """
1056
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1057
+
1058
+ if target_sizes is not None:
1059
+ if len(out_logits) != len(target_sizes):
1060
+ raise ValueError(
1061
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1062
+ )
1063
+
1064
+ prob = nn.functional.softmax(out_logits, -1)
1065
+ scores, labels = prob[..., :-1].max(-1)
1066
+
1067
+ # Convert to [x0, y0, x1, y1] format
1068
+ boxes = center_to_corners_format(out_bbox)
1069
+
1070
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
1071
+ if target_sizes is not None:
1072
+ if isinstance(target_sizes, List):
1073
+ img_h = torch.Tensor([i[0] for i in target_sizes])
1074
+ img_w = torch.Tensor([i[1] for i in target_sizes])
1075
+ else:
1076
+ img_h, img_w = target_sizes.unbind(1)
1077
+
1078
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1079
+ boxes = boxes * scale_fct[:, None, :]
1080
+
1081
+ results = []
1082
+ for s, l, b in zip(scores, labels, boxes):
1083
+ score = s[s > threshold]
1084
+ label = l[s > threshold]
1085
+ box = b[s > threshold]
1086
+ results.append({"scores": score, "labels": label, "boxes": box})
1087
+
1088
+ return results
1089
+
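For reference, a minimal usage sketch of the object-detection post-processing above (illustrative only, not part of the diffed file; it assumes a hypothetical local image `cats.jpg` and the public `facebook/detr-resnet-50` checkpoint):

import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

image = Image.open("cats.jpg")  # hypothetical local image
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes holds the original (height, width) so boxes are rescaled to pixel coordinates
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), [round(v, 1) for v in box.tolist()])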
1090
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_semantic_segmentation
1091
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple[int, int]] = None):
1092
+ """
1093
+ Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.
1094
+
1095
+ Args:
1096
+ outputs ([`DetrForSegmentation`]):
1097
+ Raw outputs of the model.
1098
+ target_sizes (`List[Tuple[int, int]]`, *optional*):
1099
+ A list of tuples (`Tuple[int, int]`) containing the target size (height, width) of each image in the
1100
+ batch. If unset, predictions will not be resized.
1101
+ Returns:
1102
+ `List[torch.Tensor]`:
1103
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
1104
+ corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
1105
+ `torch.Tensor` correspond to a semantic class id.
1106
+ """
1107
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1108
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1109
+
1110
+ # Remove the null class `[..., :-1]`
1111
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
1112
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1113
+
1114
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
1115
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
1116
+ batch_size = class_queries_logits.shape[0]
1117
+
1118
+ # Resize logits and compute semantic segmentation maps
1119
+ if target_sizes is not None:
1120
+ if batch_size != len(target_sizes):
1121
+ raise ValueError(
1122
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1123
+ )
1124
+
1125
+ semantic_segmentation = []
1126
+ for idx in range(batch_size):
1127
+ resized_logits = nn.functional.interpolate(
1128
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
1129
+ )
1130
+ semantic_map = resized_logits[0].argmax(dim=0)
1131
+ semantic_segmentation.append(semantic_map)
1132
+ else:
1133
+ semantic_segmentation = segmentation.argmax(dim=1)
1134
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
1135
+
1136
+ return semantic_segmentation
1137
+
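As an illustration (again not part of the file), a hedged sketch of calling the semantic-segmentation post-processing with the public `facebook/detr-resnet-50-panoptic` checkpoint and a hypothetical local image:

import torch
from PIL import Image
from transformers import DetrForSegmentation, DetrImageProcessor

image = Image.open("street.jpg")  # hypothetical local image
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One (height, width) tuple per image; each returned map is resized to it
maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
print(maps[0].shape)  # (height, width); every entry is a semantic class id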
1138
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance_segmentation
1139
+ def post_process_instance_segmentation(
1140
+ self,
1141
+ outputs,
1142
+ threshold: float = 0.5,
1143
+ mask_threshold: float = 0.5,
1144
+ overlap_mask_area_threshold: float = 0.8,
1145
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
1146
+ return_coco_annotation: Optional[bool] = False,
1147
+ ) -> List[Dict]:
1148
+ """
1149
+ Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
1150
+
1151
+ Args:
1152
+ outputs ([`DetrForSegmentation`]):
1153
+ Raw outputs of the model.
1154
+ threshold (`float`, *optional*, defaults to 0.5):
1155
+ The probability score threshold to keep predicted instance masks.
1156
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1157
+ Threshold to use when turning the predicted masks into binary values.
1158
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1159
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1160
+ instance mask.
1161
+ target_sizes (`List[Tuple]`, *optional*):
1162
+                List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
1163
+ final size (height, width) of each prediction. If unset, predictions will not be resized.
1164
+ return_coco_annotation (`bool`, *optional*):
1165
+ Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
1166
+ format.
1167
+ Returns:
1168
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1169
+ - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
1170
+ `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
1171
+              `True`. Set to `None` if no mask is found above `threshold`.
1172
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1173
+ - **id** -- An integer representing the `segment_id`.
1174
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1175
+ - **score** -- Prediction score of segment with `segment_id`.
1176
+ """
1177
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1178
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1179
+
1180
+ batch_size = class_queries_logits.shape[0]
1181
+ num_labels = class_queries_logits.shape[-1] - 1
1182
+
1183
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1184
+
1185
+ # Predicted label and score of each query (batch_size, num_queries)
1186
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
1187
+
1188
+ # Loop over items in batch size
1189
+ results: List[Dict[str, TensorType]] = []
1190
+
1191
+ for i in range(batch_size):
1192
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
1193
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
1194
+ )
1195
+
1196
+ # No mask found
1197
+ if mask_probs_item.shape[0] <= 0:
1198
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
1199
+ segmentation = torch.zeros((height, width)) - 1
1200
+ results.append({"segmentation": segmentation, "segments_info": []})
1201
+ continue
1202
+
1203
+ # Get segmentation map and segment information of batch item
1204
+ target_size = target_sizes[i] if target_sizes is not None else None
1205
+ segmentation, segments = compute_segments(
1206
+ mask_probs=mask_probs_item,
1207
+ pred_scores=pred_scores_item,
1208
+ pred_labels=pred_labels_item,
1209
+ mask_threshold=mask_threshold,
1210
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
1211
+ label_ids_to_fuse=[],
1212
+ target_size=target_size,
1213
+ )
1214
+
1215
+ # Return segmentation map in run-length encoding (RLE) format
1216
+ if return_coco_annotation:
1217
+ segmentation = convert_segmentation_to_rle(segmentation)
1218
+
1219
+ results.append({"segmentation": segmentation, "segments_info": segments})
1220
+ return results
1221
+
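A short, hedged example of the RLE option described above (illustrative only; it reuses the `processor` and `outputs` from the semantic-segmentation sketch earlier in this section, and the target size is hypothetical):

# Assumes `processor` and `outputs` were produced as in the semantic-segmentation sketch above.
instance_results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    target_sizes=[(480, 640)],        # hypothetical original image size
    return_coco_annotation=True,      # segmentation maps come back as COCO run-length encodings
)
for segment in instance_results[0]["segments_info"]:
    print(segment["id"], segment["label_id"], round(segment["score"], 3))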
1222
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic_segmentation
1223
+ def post_process_panoptic_segmentation(
1224
+ self,
1225
+ outputs,
1226
+ threshold: float = 0.5,
1227
+ mask_threshold: float = 0.5,
1228
+ overlap_mask_area_threshold: float = 0.8,
1229
+ label_ids_to_fuse: Optional[Set[int]] = None,
1230
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
1231
+ ) -> List[Dict]:
1232
+ """
1233
+ Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
1234
+ PyTorch.
1235
+
1236
+ Args:
1237
+ outputs ([`DetrForSegmentation`]):
1238
+ The outputs from [`DetrForSegmentation`].
1239
+ threshold (`float`, *optional*, defaults to 0.5):
1240
+ The probability score threshold to keep predicted instance masks.
1241
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1242
+ Threshold to use when turning the predicted masks into binary values.
1243
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1244
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1245
+ instance mask.
1246
+ label_ids_to_fuse (`Set[int]`, *optional*):
1247
+                The labels in this set will have all their instances fused together. For instance we could say
1248
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
1249
+ set, but not the one for person.
1250
+ target_sizes (`List[Tuple]`, *optional*):
1251
+                List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
1252
+ final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
1253
+ Returns:
1254
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1255
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
1256
+              `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized to
1257
+ the corresponding `target_sizes` entry.
1258
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1259
+ - **id** -- an integer representing the `segment_id`.
1260
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1261
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
1262
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
1263
+ - **score** -- Prediction score of segment with `segment_id`.
1264
+ """
1265
+
1266
+ if label_ids_to_fuse is None:
1267
+ logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
1268
+ label_ids_to_fuse = set()
1269
+
1270
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1271
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1272
+
1273
+ batch_size = class_queries_logits.shape[0]
1274
+ num_labels = class_queries_logits.shape[-1] - 1
1275
+
1276
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1277
+
1278
+ # Predicted label and score of each query (batch_size, num_queries)
1279
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
1280
+
1281
+ # Loop over items in batch size
1282
+ results: List[Dict[str, TensorType]] = []
1283
+
1284
+ for i in range(batch_size):
1285
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
1286
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
1287
+ )
1288
+
1289
+ # No mask found
1290
+ if mask_probs_item.shape[0] <= 0:
1291
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
1292
+ segmentation = torch.zeros((height, width)) - 1
1293
+ results.append({"segmentation": segmentation, "segments_info": []})
1294
+ continue
1295
+
1296
+ # Get segmentation map and segment information of batch item
1297
+ target_size = target_sizes[i] if target_sizes is not None else None
1298
+ segmentation, segments = compute_segments(
1299
+ mask_probs=mask_probs_item,
1300
+ pred_scores=pred_scores_item,
1301
+ pred_labels=pred_labels_item,
1302
+ mask_threshold=mask_threshold,
1303
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
1304
+ label_ids_to_fuse=label_ids_to_fuse,
1305
+ target_size=target_size,
1306
+ )
1307
+
1308
+ results.append({"segmentation": segmentation, "segments_info": segments})
1309
+ return results
1310
+
1311
+
1312
+ __all__ = ["DetrImageProcessorFast"]
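For completeness, a hedged usage sketch of the panoptic variant closing out this image processor (illustrative, not part of the file; it continues from the `DetrForSegmentation` sketch above, and the fused label id and target size are hypothetical):

# Assumes `processor` and `outputs` from the semantic-segmentation sketch above.
panoptic = processor.post_process_panoptic_segmentation(
    outputs,
    threshold=0.5,
    label_ids_to_fuse={186},          # hypothetical "stuff" id whose instances should be merged
    target_sizes=[(480, 640)],        # hypothetical original image size
)
segmentation_map = panoptic[0]["segmentation"]       # (480, 640) tensor of segment ids
for segment in panoptic[0]["segments_info"]:
    print(segment["id"], segment["label_id"], segment["was_fused"], round(segment["score"], 3))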
docs/transformers/build/lib/transformers/models/detr/modeling_detr.py ADDED
@@ -0,0 +1,1815 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DETR model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import Tensor, nn
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
26
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ is_timm_available,
33
+ logging,
34
+ replace_return_docstrings,
35
+ requires_backends,
36
+ )
37
+ from ...utils.backbone_utils import load_backbone
38
+ from .configuration_detr import DetrConfig
39
+
40
+
41
+ if is_timm_available():
42
+ from timm import create_model
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _CONFIG_FOR_DOC = "DetrConfig"
48
+ _CHECKPOINT_FOR_DOC = "facebook/detr-resnet-50"
49
+
50
+
51
+ @dataclass
52
+ class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
53
+ """
54
+ Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
55
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
56
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
57
+
58
+ Args:
59
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
60
+ Sequence of hidden-states at the output of the last layer of the model.
61
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
62
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
63
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
64
+ plus the initial embedding outputs.
65
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
66
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
67
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
68
+ the self-attention heads.
69
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
70
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
71
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
72
+ used to compute the weighted average in the cross-attention heads.
73
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
74
+ Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
75
+ layernorm.
76
+ """
77
+
78
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
79
+
80
+
81
+ @dataclass
82
+ class DetrModelOutput(Seq2SeqModelOutput):
83
+ """
84
+ Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
85
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
86
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
87
+
88
+ Args:
89
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
90
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
91
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
92
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
93
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
94
+ layer plus the initial embedding outputs.
95
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
96
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
97
+ sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
98
+ weighted average in the self-attention heads.
99
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
100
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
101
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
102
+ used to compute the weighted average in the cross-attention heads.
103
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
104
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
105
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
106
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
107
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
108
+ layer plus the initial embedding outputs.
109
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
110
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
111
+ sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
112
+ weighted average in the self-attention heads.
113
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
114
+ Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
115
+ layernorm.
116
+ """
117
+
118
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
119
+
120
+
121
+ @dataclass
122
+ class DetrObjectDetectionOutput(ModelOutput):
123
+ """
124
+ Output type of [`DetrForObjectDetection`].
125
+
126
+ Args:
127
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
128
+            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
129
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
130
+ scale-invariant IoU loss.
131
+ loss_dict (`Dict`, *optional*):
132
+ A dictionary containing the individual losses. Useful for logging.
133
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
134
+ Classification logits (including no-object) for all queries.
135
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
136
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
137
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
138
+ possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
139
+ unnormalized bounding boxes.
140
+ auxiliary_outputs (`list[Dict]`, *optional*):
141
+            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
142
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
143
+ `pred_boxes`) for each decoder layer.
144
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
145
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
146
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
147
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
148
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
149
+ layer plus the initial embedding outputs.
150
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
151
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
152
+ sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
153
+ weighted average in the self-attention heads.
154
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
155
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
156
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
157
+ used to compute the weighted average in the cross-attention heads.
158
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
159
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
160
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
161
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
162
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
163
+ layer plus the initial embedding outputs.
164
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
165
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
166
+ sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
167
+ weighted average in the self-attention heads.
168
+ """
169
+
170
+ loss: Optional[torch.FloatTensor] = None
171
+ loss_dict: Optional[Dict] = None
172
+ logits: Optional[torch.FloatTensor] = None
173
+ pred_boxes: Optional[torch.FloatTensor] = None
174
+ auxiliary_outputs: Optional[List[Dict]] = None
175
+ last_hidden_state: Optional[torch.FloatTensor] = None
176
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
177
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
178
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
179
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
180
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
181
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
182
+
183
+
184
+ @dataclass
185
+ class DetrSegmentationOutput(ModelOutput):
186
+ """
187
+ Output type of [`DetrForSegmentation`].
188
+
189
+ Args:
190
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
191
+            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
192
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
193
+ scale-invariant IoU loss.
194
+ loss_dict (`Dict`, *optional*):
195
+ A dictionary containing the individual losses. Useful for logging.
196
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
197
+ Classification logits (including no-object) for all queries.
198
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
199
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
200
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
201
+ possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
202
+ unnormalized bounding boxes.
203
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
204
+ Segmentation masks logits for all queries. See also
205
+ [`~DetrImageProcessor.post_process_semantic_segmentation`] or
206
+ [`~DetrImageProcessor.post_process_instance_segmentation`]
207
+            [`~DetrImageProcessor.post_process_instance_segmentation`] or
208
+ segmentation masks respectively.
209
+ auxiliary_outputs (`list[Dict]`, *optional*):
210
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
211
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
212
+ `pred_boxes`) for each decoder layer.
213
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
214
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
215
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
216
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
217
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
218
+ layer plus the initial embedding outputs.
219
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
220
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
221
+ sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
222
+ weighted average in the self-attention heads.
223
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
224
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
225
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
226
+ used to compute the weighted average in the cross-attention heads.
227
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
228
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
229
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
230
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
231
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
232
+ layer plus the initial embedding outputs.
233
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
234
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
235
+ sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
236
+ weighted average in the self-attention heads.
237
+ """
238
+
239
+ loss: Optional[torch.FloatTensor] = None
240
+ loss_dict: Optional[Dict] = None
241
+ logits: Optional[torch.FloatTensor] = None
242
+ pred_boxes: Optional[torch.FloatTensor] = None
243
+ pred_masks: Optional[torch.FloatTensor] = None
244
+ auxiliary_outputs: Optional[List[Dict]] = None
245
+ last_hidden_state: Optional[torch.FloatTensor] = None
246
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
247
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
248
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
249
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
250
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
251
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
252
+
253
+
254
+ # BELOW: utilities copied from
255
+ # https://github.com/facebookresearch/detr/blob/master/backbone.py
256
+ class DetrFrozenBatchNorm2d(nn.Module):
257
+ """
258
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
259
+
260
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which models other than
261
+ torchvision.models.resnet[18,34,50,101] produce nans.
262
+ """
263
+
264
+ def __init__(self, n):
265
+ super().__init__()
266
+ self.register_buffer("weight", torch.ones(n))
267
+ self.register_buffer("bias", torch.zeros(n))
268
+ self.register_buffer("running_mean", torch.zeros(n))
269
+ self.register_buffer("running_var", torch.ones(n))
270
+
271
+ def _load_from_state_dict(
272
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
273
+ ):
274
+ num_batches_tracked_key = prefix + "num_batches_tracked"
275
+ if num_batches_tracked_key in state_dict:
276
+ del state_dict[num_batches_tracked_key]
277
+
278
+ super()._load_from_state_dict(
279
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
280
+ )
281
+
282
+ def forward(self, x):
283
+ # move reshapes to the beginning
284
+        # to make it fuser-friendly
285
+ weight = self.weight.reshape(1, -1, 1, 1)
286
+ bias = self.bias.reshape(1, -1, 1, 1)
287
+ running_var = self.running_var.reshape(1, -1, 1, 1)
288
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
289
+ epsilon = 1e-5
290
+ scale = weight * (running_var + epsilon).rsqrt()
291
+ bias = bias - running_mean * scale
292
+ return x * scale + bias
293
+
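A small, self-contained check (an illustrative sketch, not part of the file) that the scale/shift folding above reproduces an eval-mode `nn.BatchNorm2d` with the same `eps`:

import torch
from torch import nn

bn = nn.BatchNorm2d(8).eval()                # uses running statistics; eps defaults to 1e-5
x = torch.randn(2, 8, 4, 4)

eps = 1e-5
scale = bn.weight.reshape(1, -1, 1, 1) * (bn.running_var.reshape(1, -1, 1, 1) + eps).rsqrt()
shift = bn.bias.reshape(1, -1, 1, 1) - bn.running_mean.reshape(1, -1, 1, 1) * scale

print(torch.allclose(x * scale + shift, bn(x), atol=1e-5))  # expected: True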
294
+
295
+ def replace_batch_norm(model):
296
+ r"""
297
+ Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.
298
+
299
+ Args:
300
+ model (torch.nn.Module):
301
+ input model
302
+ """
303
+ for name, module in model.named_children():
304
+ if isinstance(module, nn.BatchNorm2d):
305
+ new_module = DetrFrozenBatchNorm2d(module.num_features)
306
+
307
+ if not module.weight.device == torch.device("meta"):
308
+ new_module.weight.data.copy_(module.weight)
309
+ new_module.bias.data.copy_(module.bias)
310
+ new_module.running_mean.data.copy_(module.running_mean)
311
+ new_module.running_var.data.copy_(module.running_var)
312
+
313
+ model._modules[name] = new_module
314
+
315
+ if len(list(module.children())) > 0:
316
+ replace_batch_norm(module)
317
+
318
+
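An illustrative sanity check (a sketch under the assumption that `replace_batch_norm` and `DetrFrozenBatchNorm2d` from this module are in scope) showing that the replacement recurses into nested submodules:

import torch
from torch import nn

net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.Sequential(nn.BatchNorm2d(8)),   # nested module to exercise the recursion
)
with torch.no_grad():
    replace_batch_norm(net)

# No plain BatchNorm2d layers should remain anywhere in the module tree
print(any(isinstance(m, nn.BatchNorm2d) for m in net.modules()))  # expected: False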
319
+ class DetrConvEncoder(nn.Module):
320
+ """
321
+ Convolutional backbone, using either the AutoBackbone API or one from the timm library.
322
+
323
+ nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
324
+
325
+ """
326
+
327
+ def __init__(self, config):
328
+ super().__init__()
329
+
330
+ self.config = config
331
+
332
+ # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
333
+ if config.use_timm_backbone:
334
+ # We default to values which were previously hard-coded. This enables configurability from the config
335
+ # using backbone arguments, while keeping the default behavior the same.
336
+ requires_backends(self, ["timm"])
337
+ kwargs = getattr(config, "backbone_kwargs", {})
338
+ kwargs = {} if kwargs is None else kwargs.copy()
339
+ out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
340
+ num_channels = kwargs.pop("in_chans", config.num_channels)
341
+ if config.dilation:
342
+ kwargs["output_stride"] = kwargs.get("output_stride", 16)
343
+ backbone = create_model(
344
+ config.backbone,
345
+ pretrained=config.use_pretrained_backbone,
346
+ features_only=True,
347
+ out_indices=out_indices,
348
+ in_chans=num_channels,
349
+ **kwargs,
350
+ )
351
+ else:
352
+ backbone = load_backbone(config)
353
+
354
+ # replace batch norm by frozen batch norm
355
+ with torch.no_grad():
356
+ replace_batch_norm(backbone)
357
+ self.model = backbone
358
+ self.intermediate_channel_sizes = (
359
+ self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
360
+ )
361
+
362
+ backbone_model_type = None
363
+ if config.backbone is not None:
364
+ backbone_model_type = config.backbone
365
+ elif config.backbone_config is not None:
366
+ backbone_model_type = config.backbone_config.model_type
367
+ else:
368
+ raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
369
+
370
+ if "resnet" in backbone_model_type:
371
+ for name, parameter in self.model.named_parameters():
372
+ if config.use_timm_backbone:
373
+ if "layer2" not in name and "layer3" not in name and "layer4" not in name:
374
+ parameter.requires_grad_(False)
375
+ else:
376
+ if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
377
+ parameter.requires_grad_(False)
378
+
379
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
380
+ # send pixel_values through the model to get list of feature maps
381
+ features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
382
+
383
+ out = []
384
+ for feature_map in features:
385
+ # downsample pixel_mask to match shape of corresponding feature_map
386
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
387
+ out.append((feature_map, mask))
388
+ return out
389
+
390
+
391
+ class DetrConvModel(nn.Module):
392
+ """
393
+ This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
394
+ """
395
+
396
+ def __init__(self, conv_encoder, position_embedding):
397
+ super().__init__()
398
+ self.conv_encoder = conv_encoder
399
+ self.position_embedding = position_embedding
400
+
401
+ def forward(self, pixel_values, pixel_mask):
402
+ # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
403
+ out = self.conv_encoder(pixel_values, pixel_mask)
404
+ pos = []
405
+ for feature_map, mask in out:
406
+ # position encoding
407
+ pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
408
+
409
+ return out, pos
410
+
411
+
412
+ class DetrSinePositionEmbedding(nn.Module):
413
+ """
414
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
415
+ need paper, generalized to work on images.
416
+ """
417
+
418
+ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
419
+ super().__init__()
420
+ self.embedding_dim = embedding_dim
421
+ self.temperature = temperature
422
+ self.normalize = normalize
423
+ if scale is not None and normalize is False:
424
+ raise ValueError("normalize should be True if scale is passed")
425
+ if scale is None:
426
+ scale = 2 * math.pi
427
+ self.scale = scale
428
+
429
+ def forward(self, pixel_values, pixel_mask):
430
+ if pixel_mask is None:
431
+ raise ValueError("No pixel mask provided")
432
+ y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
433
+ x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
434
+ if self.normalize:
435
+ y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
436
+ x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
437
+
438
+ dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
439
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
440
+
441
+ pos_x = x_embed[:, :, :, None] / dim_t
442
+ pos_y = y_embed[:, :, :, None] / dim_t
443
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
444
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
445
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
446
+ return pos
447
+
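A brief shape sketch (illustrative; it assumes the `DetrSinePositionEmbedding` class defined above is in scope): with `embedding_dim=128` per axis, the concatenated y/x embeddings yield a channel dimension of 256, matching DETR's default `d_model`.

import torch

pos_embed = DetrSinePositionEmbedding(embedding_dim=128, normalize=True)
pixel_values = torch.randn(2, 3, 32, 32)
pixel_mask = torch.ones(2, 32, 32, dtype=torch.long)   # all pixels valid
print(pos_embed(pixel_values, pixel_mask).shape)       # torch.Size([2, 256, 32, 32])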
448
+
449
+ class DetrLearnedPositionEmbedding(nn.Module):
450
+ """
451
+ This module learns positional embeddings up to a fixed maximum size.
452
+ """
453
+
454
+ def __init__(self, embedding_dim=256):
455
+ super().__init__()
456
+ self.row_embeddings = nn.Embedding(50, embedding_dim)
457
+ self.column_embeddings = nn.Embedding(50, embedding_dim)
458
+
459
+ def forward(self, pixel_values, pixel_mask=None):
460
+ height, width = pixel_values.shape[-2:]
461
+ width_values = torch.arange(width, device=pixel_values.device)
462
+ height_values = torch.arange(height, device=pixel_values.device)
463
+ x_emb = self.column_embeddings(width_values)
464
+ y_emb = self.row_embeddings(height_values)
465
+ pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
466
+ pos = pos.permute(2, 0, 1)
467
+ pos = pos.unsqueeze(0)
468
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
469
+ return pos
470
+
471
+
472
+ def build_position_encoding(config):
473
+ n_steps = config.d_model // 2
474
+ if config.position_embedding_type == "sine":
475
+ # TODO find a better way of exposing other arguments
476
+ position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
477
+ elif config.position_embedding_type == "learned":
478
+ position_embedding = DetrLearnedPositionEmbedding(n_steps)
479
+ else:
480
+ raise ValueError(f"Not supported {config.position_embedding_type}")
481
+
482
+ return position_embedding
483
+
484
+
485
+ class DetrAttention(nn.Module):
486
+ """
487
+ Multi-headed attention from 'Attention Is All You Need' paper.
488
+
489
+ Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
490
+ """
491
+
492
+ def __init__(
493
+ self,
494
+ embed_dim: int,
495
+ num_heads: int,
496
+ dropout: float = 0.0,
497
+ bias: bool = True,
498
+ ):
499
+ super().__init__()
500
+ self.embed_dim = embed_dim
501
+ self.num_heads = num_heads
502
+ self.dropout = dropout
503
+ self.head_dim = embed_dim // num_heads
504
+ if self.head_dim * num_heads != self.embed_dim:
505
+ raise ValueError(
506
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
507
+ f" {num_heads})."
508
+ )
509
+ self.scaling = self.head_dim**-0.5
510
+
511
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
512
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
513
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
514
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
515
+
516
+ def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
517
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
518
+
519
+ def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
520
+ return tensor if object_queries is None else tensor + object_queries
521
+
522
+ def forward(
523
+ self,
524
+ hidden_states: torch.Tensor,
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ object_queries: Optional[torch.Tensor] = None,
527
+ key_value_states: Optional[torch.Tensor] = None,
528
+ spatial_position_embeddings: Optional[torch.Tensor] = None,
529
+ output_attentions: bool = False,
530
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
531
+ """Input shape: Batch x Time x Channel"""
532
+ # if key_value_states are provided this layer is used as a cross-attention layer
533
+ # for the decoder
534
+ is_cross_attention = key_value_states is not None
535
+ batch_size, target_len, embed_dim = hidden_states.size()
536
+
537
+ # add position embeddings to the hidden states before projecting to queries and keys
538
+ if object_queries is not None:
539
+ hidden_states_original = hidden_states
540
+ hidden_states = self.with_pos_embed(hidden_states, object_queries)
541
+
542
+ # add key-value position embeddings to the key value states
543
+ if spatial_position_embeddings is not None:
544
+ key_value_states_original = key_value_states
545
+ key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
546
+
547
+ # get query proj
548
+ query_states = self.q_proj(hidden_states) * self.scaling
549
+ # get key, value proj
550
+ if is_cross_attention:
551
+ # cross_attentions
552
+ key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
553
+ value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
554
+ else:
555
+ # self_attention
556
+ key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
557
+ value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
558
+
559
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
560
+ query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
561
+ key_states = key_states.view(*proj_shape)
562
+ value_states = value_states.view(*proj_shape)
563
+
564
+ source_len = key_states.size(1)
565
+
566
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
567
+
568
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
569
+ raise ValueError(
570
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
571
+ f" {attn_weights.size()}"
572
+ )
573
+
574
+ if attention_mask is not None:
575
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
576
+ raise ValueError(
577
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
578
+ f" {attention_mask.size()}"
579
+ )
580
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
581
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
582
+
583
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
584
+
585
+ if output_attentions:
586
+ # this operation is a bit awkward, but it's required to
587
+ # make sure that attn_weights keeps its gradient.
588
+            # In order to do so, attn_weights have to be reshaped
589
+ # twice and have to be reused in the following
590
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
591
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
592
+ else:
593
+ attn_weights_reshaped = None
594
+
595
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
596
+
597
+ attn_output = torch.bmm(attn_probs, value_states)
598
+
599
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
600
+ raise ValueError(
601
+                f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is"
602
+ f" {attn_output.size()}"
603
+ )
604
+
605
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
606
+ attn_output = attn_output.transpose(1, 2)
607
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
608
+
609
+ attn_output = self.out_proj(attn_output)
610
+
611
+ return attn_output, attn_weights_reshaped
612
+
613
+
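An illustrative shape check for this attention block (a sketch assuming the `DetrAttention` class above is in scope; the values mirror DETR's 100 object queries and 8 heads):

import torch

attn = DetrAttention(embed_dim=256, num_heads=8)
hidden_states = torch.randn(2, 100, 256)        # e.g. 100 object queries per image
object_queries = torch.randn(2, 100, 256)       # position embeddings, added to queries/keys only
out, weights = attn(hidden_states, object_queries=object_queries, output_attentions=True)
print(out.shape, weights.shape)                 # (2, 100, 256) and (2, 8, 100, 100)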
614
+ class DetrEncoderLayer(nn.Module):
615
+ def __init__(self, config: DetrConfig):
616
+ super().__init__()
617
+ self.embed_dim = config.d_model
618
+ self.self_attn = DetrAttention(
619
+ embed_dim=self.embed_dim,
620
+ num_heads=config.encoder_attention_heads,
621
+ dropout=config.attention_dropout,
622
+ )
623
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
624
+ self.dropout = config.dropout
625
+ self.activation_fn = ACT2FN[config.activation_function]
626
+ self.activation_dropout = config.activation_dropout
627
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
628
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
629
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
630
+
631
+ def forward(
632
+ self,
633
+ hidden_states: torch.Tensor,
634
+ attention_mask: torch.Tensor,
635
+ object_queries: Optional[torch.Tensor] = None,
636
+ output_attentions: bool = False,
637
+ ):
638
+ """
639
+ Args:
640
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
641
+ attention_mask (`torch.FloatTensor`): attention mask of size
642
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
643
+ values.
644
+ object_queries (`torch.FloatTensor`, *optional*):
645
+ Object queries (also called content embeddings), to be added to the hidden states.
646
+ output_attentions (`bool`, *optional*):
647
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
648
+ returned tensors for more detail.
649
+ """
650
+ residual = hidden_states
651
+ hidden_states, attn_weights = self.self_attn(
652
+ hidden_states=hidden_states,
653
+ attention_mask=attention_mask,
654
+ object_queries=object_queries,
655
+ output_attentions=output_attentions,
656
+ )
657
+
658
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
659
+ hidden_states = residual + hidden_states
660
+ hidden_states = self.self_attn_layer_norm(hidden_states)
661
+
662
+ residual = hidden_states
663
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
664
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
665
+
666
+ hidden_states = self.fc2(hidden_states)
667
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
668
+
669
+ hidden_states = residual + hidden_states
670
+ hidden_states = self.final_layer_norm(hidden_states)
671
+
672
+ if self.training:
673
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
674
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
675
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
676
+
677
+ outputs = (hidden_states,)
678
+
679
+ if output_attentions:
680
+ outputs += (attn_weights,)
681
+
682
+ return outputs
683
+
684
+
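A minimal forward-pass sketch for a single encoder layer (illustrative; it assumes the `DetrEncoderLayer` class above is in scope and uses `DetrConfig` defaults, i.e. `d_model=256` with 8 encoder attention heads):

import torch
from transformers import DetrConfig

config = DetrConfig()                                  # default d_model=256
layer = DetrEncoderLayer(config)
hidden_states = torch.randn(1, 850, config.d_model)    # flattened backbone feature-map tokens
object_queries = torch.randn(1, 850, config.d_model)   # sine position embeddings for those tokens
outputs = layer(hidden_states, attention_mask=None, object_queries=object_queries)
print(outputs[0].shape)                                # torch.Size([1, 850, 256])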
685
+ class DetrDecoderLayer(nn.Module):
686
+ def __init__(self, config: DetrConfig):
687
+ super().__init__()
688
+ self.embed_dim = config.d_model
689
+
690
+ self.self_attn = DetrAttention(
691
+ embed_dim=self.embed_dim,
692
+ num_heads=config.decoder_attention_heads,
693
+ dropout=config.attention_dropout,
694
+ )
695
+ self.dropout = config.dropout
696
+ self.activation_fn = ACT2FN[config.activation_function]
697
+ self.activation_dropout = config.activation_dropout
698
+
699
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
700
+ self.encoder_attn = DetrAttention(
701
+ self.embed_dim,
702
+ config.decoder_attention_heads,
703
+ dropout=config.attention_dropout,
704
+ )
705
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
706
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
707
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
708
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
709
+
710
+ def forward(
711
+ self,
712
+ hidden_states: torch.Tensor,
713
+ attention_mask: Optional[torch.Tensor] = None,
714
+ object_queries: Optional[torch.Tensor] = None,
715
+ query_position_embeddings: Optional[torch.Tensor] = None,
716
+ encoder_hidden_states: Optional[torch.Tensor] = None,
717
+ encoder_attention_mask: Optional[torch.Tensor] = None,
718
+ output_attentions: Optional[bool] = False,
719
+ ):
720
+ """
721
+ Args:
722
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
723
+ attention_mask (`torch.FloatTensor`): attention mask of size
724
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
725
+ values.
726
+ object_queries (`torch.FloatTensor`, *optional*):
727
+ object_queries that are added to the hidden states
728
+ in the cross-attention layer.
729
+ query_position_embeddings (`torch.FloatTensor`, *optional*):
730
+ position embeddings that are added to the queries and keys
731
+ in the self-attention layer.
732
+ encoder_hidden_states (`torch.FloatTensor`):
733
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
734
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
735
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
736
+ values.
737
+ output_attentions (`bool`, *optional*):
738
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
739
+ returned tensors for more detail.
740
+ """
741
+ residual = hidden_states
742
+
743
+ # Self Attention
744
+ hidden_states, self_attn_weights = self.self_attn(
745
+ hidden_states=hidden_states,
746
+ object_queries=query_position_embeddings,
747
+ attention_mask=attention_mask,
748
+ output_attentions=output_attentions,
749
+ )
750
+
751
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
752
+ hidden_states = residual + hidden_states
753
+ hidden_states = self.self_attn_layer_norm(hidden_states)
754
+
755
+ # Cross-Attention Block
756
+ cross_attn_weights = None
757
+ if encoder_hidden_states is not None:
758
+ residual = hidden_states
759
+
760
+ hidden_states, cross_attn_weights = self.encoder_attn(
761
+ hidden_states=hidden_states,
762
+ object_queries=query_position_embeddings,
763
+ key_value_states=encoder_hidden_states,
764
+ attention_mask=encoder_attention_mask,
765
+ spatial_position_embeddings=object_queries,
766
+ output_attentions=output_attentions,
767
+ )
768
+
769
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
770
+ hidden_states = residual + hidden_states
771
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
772
+
773
+ # Fully Connected
774
+ residual = hidden_states
775
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
776
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
777
+ hidden_states = self.fc2(hidden_states)
778
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
779
+ hidden_states = residual + hidden_states
780
+ hidden_states = self.final_layer_norm(hidden_states)
781
+
782
+ outputs = (hidden_states,)
783
+
784
+ if output_attentions:
785
+ outputs += (self_attn_weights, cross_attn_weights)
786
+
787
+ return outputs
788
+
789
+
790
+ class DetrPreTrainedModel(PreTrainedModel):
791
+ config_class = DetrConfig
792
+ base_model_prefix = "model"
793
+ main_input_name = "pixel_values"
794
+ _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
795
+
796
+ def _init_weights(self, module):
797
+ std = self.config.init_std
798
+ xavier_std = self.config.init_xavier_std
799
+
800
+ if isinstance(module, DetrMHAttentionMap):
801
+ nn.init.zeros_(module.k_linear.bias)
802
+ nn.init.zeros_(module.q_linear.bias)
803
+ nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
804
+ nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
805
+ elif isinstance(module, DetrLearnedPositionEmbedding):
806
+ nn.init.uniform_(module.row_embeddings.weight)
807
+ nn.init.uniform_(module.column_embeddings.weight)
808
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
809
+ # Slightly different from the TF version which uses truncated_normal for initialization
810
+ # cf https://github.com/pytorch/pytorch/pull/5617
811
+ module.weight.data.normal_(mean=0.0, std=std)
812
+ if module.bias is not None:
813
+ module.bias.data.zero_()
814
+ elif isinstance(module, nn.Embedding):
815
+ module.weight.data.normal_(mean=0.0, std=std)
816
+ if module.padding_idx is not None:
817
+ module.weight.data[module.padding_idx].zero_()
818
+
819
+
820
+ DETR_START_DOCSTRING = r"""
821
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
822
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
823
+ etc.)
824
+
825
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
826
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
827
+ and behavior.
828
+
829
+ Parameters:
830
+ config ([`DetrConfig`]):
831
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
832
+ load the weights associated with the model, only the configuration. Check out the
833
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
834
+ """
835
+
836
+ DETR_INPUTS_DOCSTRING = r"""
837
+ Args:
838
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
839
+ Pixel values. Padding will be ignored by default should you provide it.
840
+
841
+ Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.
842
+
843
+ pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
844
+ Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
845
+
846
+ - 1 for pixels that are real (i.e. **not masked**),
847
+ - 0 for pixels that are padding (i.e. **masked**).
848
+
849
+ [What are attention masks?](../glossary#attention-mask)
850
+
851
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
852
+ Not used by default. Can be used to mask object queries.
853
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
854
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
855
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
856
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
857
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
858
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
859
+ can choose to directly pass a flattened representation of an image.
860
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
861
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
862
+ embedded representation.
863
+ output_attentions (`bool`, *optional*):
864
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
865
+ tensors for more detail.
866
+ output_hidden_states (`bool`, *optional*):
867
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
868
+ more detail.
869
+ return_dict (`bool`, *optional*):
870
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
871
+ """
872
+
873
+
874
+ class DetrEncoder(DetrPreTrainedModel):
875
+ """
876
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
877
+ [`DetrEncoderLayer`].
878
+
879
+ The encoder updates the flattened feature map through multiple self-attention layers.
880
+
881
+ Small tweak for DETR:
882
+
883
+ - object_queries are added to the forward pass.
884
+
885
+ Args:
886
+ config: DetrConfig
887
+ """
888
+
889
+ def __init__(self, config: DetrConfig):
890
+ super().__init__(config)
891
+
892
+ self.dropout = config.dropout
893
+ self.layerdrop = config.encoder_layerdrop
894
+
895
+ self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])
896
+
897
+ # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
898
+
899
+ # Initialize weights and apply final processing
900
+ self.post_init()
901
+
902
+ def forward(
903
+ self,
904
+ inputs_embeds=None,
905
+ attention_mask=None,
906
+ object_queries=None,
907
+ output_attentions=None,
908
+ output_hidden_states=None,
909
+ return_dict=None,
910
+ ):
911
+ r"""
912
+ Args:
913
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
914
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
915
+
916
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
917
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
918
+
919
+ - 1 for pixel features that are real (i.e. **not masked**),
920
+ - 0 for pixel features that are padding (i.e. **masked**).
921
+
922
+ [What are attention masks?](../glossary#attention-mask)
923
+
924
+ object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
925
+ Object queries that are added to the queries in each self-attention layer.
926
+
927
+ output_attentions (`bool`, *optional*):
928
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
929
+ returned tensors for more detail.
930
+ output_hidden_states (`bool`, *optional*):
931
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
932
+ for more detail.
933
+ return_dict (`bool`, *optional*):
934
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
935
+ """
936
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
937
+ output_hidden_states = (
938
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
939
+ )
940
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
941
+
942
+ hidden_states = inputs_embeds
943
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
944
+
945
+ # expand attention_mask
946
+ if attention_mask is not None:
947
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
948
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
949
+
950
+ encoder_states = () if output_hidden_states else None
951
+ all_attentions = () if output_attentions else None
952
+ for i, encoder_layer in enumerate(self.layers):
953
+ if output_hidden_states:
954
+ encoder_states = encoder_states + (hidden_states,)
955
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
956
+ to_drop = False
957
+ if self.training:
958
+ dropout_probability = torch.rand([])
959
+ if dropout_probability < self.layerdrop: # skip the layer
960
+ to_drop = True
961
+
962
+ if to_drop:
963
+ layer_outputs = (None, None)
964
+ else:
965
+ # we add object_queries as extra input to the encoder_layer
966
+ layer_outputs = encoder_layer(
967
+ hidden_states,
968
+ attention_mask,
969
+ object_queries=object_queries,
970
+ output_attentions=output_attentions,
971
+ )
972
+
973
+ hidden_states = layer_outputs[0]
974
+
975
+ if output_attentions:
976
+ all_attentions = all_attentions + (layer_outputs[1],)
977
+
978
+ if output_hidden_states:
979
+ encoder_states = encoder_states + (hidden_states,)
980
+
981
+ if not return_dict:
982
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
983
+ return BaseModelOutput(
984
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
985
+ )
986
+
987
+
988
+ class DetrDecoder(DetrPreTrainedModel):
989
+ """
990
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
991
+
992
+ The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
993
+
994
+ Some small tweaks for DETR:
995
+
996
+ - object_queries and query_position_embeddings are added to the forward pass.
997
+ - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
998
+
999
+ Args:
1000
+ config: DetrConfig
1001
+ """
1002
+
1003
+ def __init__(self, config: DetrConfig):
1004
+ super().__init__(config)
1005
+ self.dropout = config.dropout
1006
+ self.layerdrop = config.decoder_layerdrop
1007
+
1008
+ self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
1009
+ # in DETR, the decoder uses layernorm after the last decoder layer output
1010
+ self.layernorm = nn.LayerNorm(config.d_model)
1011
+
1012
+ self.gradient_checkpointing = False
1013
+ # Initialize weights and apply final processing
1014
+ self.post_init()
1015
+
1016
+ def forward(
1017
+ self,
1018
+ inputs_embeds=None,
1019
+ attention_mask=None,
1020
+ encoder_hidden_states=None,
1021
+ encoder_attention_mask=None,
1022
+ object_queries=None,
1023
+ query_position_embeddings=None,
1024
+ output_attentions=None,
1025
+ output_hidden_states=None,
1026
+ return_dict=None,
1027
+ ):
1028
+ r"""
1029
+ Args:
1030
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1031
+ The query embeddings that are passed into the decoder.
1032
+
1033
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1034
+ Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:
1035
+
1036
+ - 1 for queries that are **not masked**,
1037
+ - 0 for queries that are **masked**.
1038
+
1039
+ [What are attention masks?](../glossary#attention-mask)
1040
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1041
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1042
+ of the decoder.
1043
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1044
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
1045
+ in `[0, 1]`:
1046
+
1047
+ - 1 for pixels that are real (i.e. **not masked**),
1048
+ - 0 for pixels that are padding (i.e. **masked**).
1049
+
1050
+ object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1051
+ Object queries that are added to the queries and keys in each cross-attention layer.
1052
+ query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1053
+ Position embeddings that are added to the queries and keys in each self-attention layer.
1054
+
1055
+ output_attentions (`bool`, *optional*):
1056
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1057
+ returned tensors for more detail.
1058
+ output_hidden_states (`bool`, *optional*):
1059
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1060
+ for more detail.
1061
+ return_dict (`bool`, *optional*):
1062
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1063
+ """
1064
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1065
+ output_hidden_states = (
1066
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1067
+ )
1068
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1069
+
1070
+ if inputs_embeds is not None:
1071
+ hidden_states = inputs_embeds
1072
+ input_shape = inputs_embeds.size()[:-1]
1073
+
1074
+ combined_attention_mask = None
1075
+
1076
+ if attention_mask is not None and combined_attention_mask is not None:
1077
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
1078
+ combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
1079
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1080
+ )
1081
+
1082
+ # expand encoder attention mask
1083
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1084
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
1085
+ encoder_attention_mask = _prepare_4d_attention_mask(
1086
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1087
+ )
1088
+
1089
+ # optional intermediate hidden states
1090
+ intermediate = () if self.config.auxiliary_loss else None
1091
+
1092
+ # decoder layers
1093
+ all_hidden_states = () if output_hidden_states else None
1094
+ all_self_attns = () if output_attentions else None
1095
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1096
+
1097
+ for idx, decoder_layer in enumerate(self.layers):
1098
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1099
+ if output_hidden_states:
1100
+ all_hidden_states += (hidden_states,)
1101
+ if self.training:
1102
+ dropout_probability = torch.rand([])
1103
+ if dropout_probability < self.layerdrop:
1104
+ continue
1105
+
1106
+ if self.gradient_checkpointing and self.training:
1107
+ layer_outputs = self._gradient_checkpointing_func(
1108
+ decoder_layer.__call__,
1109
+ hidden_states,
1110
+ combined_attention_mask,
1111
+ encoder_hidden_states,
1112
+ encoder_attention_mask,
1113
+ None,
1114
+ )
1115
+ else:
1116
+ layer_outputs = decoder_layer(
1117
+ hidden_states,
1118
+ attention_mask=combined_attention_mask,
1119
+ object_queries=object_queries,
1120
+ query_position_embeddings=query_position_embeddings,
1121
+ encoder_hidden_states=encoder_hidden_states,
1122
+ encoder_attention_mask=encoder_attention_mask,
1123
+ output_attentions=output_attentions,
1124
+ )
1125
+
1126
+ hidden_states = layer_outputs[0]
1127
+
1128
+ if self.config.auxiliary_loss:
1129
+ hidden_states = self.layernorm(hidden_states)
1130
+ intermediate += (hidden_states,)
1131
+
1132
+ if output_attentions:
1133
+ all_self_attns += (layer_outputs[1],)
1134
+
1135
+ if encoder_hidden_states is not None:
1136
+ all_cross_attentions += (layer_outputs[2],)
1137
+
1138
+ # finally, apply layernorm
1139
+ hidden_states = self.layernorm(hidden_states)
1140
+
1141
+ # add hidden states from the last decoder layer
1142
+ if output_hidden_states:
1143
+ all_hidden_states += (hidden_states,)
1144
+
1145
+ # stack intermediate decoder activations
1146
+ if self.config.auxiliary_loss:
1147
+ intermediate = torch.stack(intermediate)
1148
+
1149
+ if not return_dict:
1150
+ return tuple(
1151
+ v
1152
+ for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
1153
+ if v is not None
1154
+ )
1155
+ return DetrDecoderOutput(
1156
+ last_hidden_state=hidden_states,
1157
+ hidden_states=all_hidden_states,
1158
+ attentions=all_self_attns,
1159
+ cross_attentions=all_cross_attentions,
1160
+ intermediate_hidden_states=intermediate,
1161
+ )
1162
+
1163
+
1164
+ @add_start_docstrings(
1165
+ """
1166
+ The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
1167
+ any specific head on top.
1168
+ """,
1169
+ DETR_START_DOCSTRING,
1170
+ )
1171
+ class DetrModel(DetrPreTrainedModel):
1172
+ def __init__(self, config: DetrConfig):
1173
+ super().__init__(config)
1174
+
1175
+ # Create backbone + positional encoding
1176
+ backbone = DetrConvEncoder(config)
1177
+ object_queries = build_position_encoding(config)
1178
+ self.backbone = DetrConvModel(backbone, object_queries)
1179
+
1180
+ # Create projection layer
1181
+ self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1182
+
1183
+ self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
1184
+
1185
+ self.encoder = DetrEncoder(config)
1186
+ self.decoder = DetrDecoder(config)
1187
+
1188
+ # Initialize weights and apply final processing
1189
+ self.post_init()
1190
+
1191
+ def get_encoder(self):
1192
+ return self.encoder
1193
+
1194
+ def get_decoder(self):
1195
+ return self.decoder
1196
+
1197
+ def freeze_backbone(self):
1198
+ for name, param in self.backbone.conv_encoder.model.named_parameters():
1199
+ param.requires_grad_(False)
1200
+
1201
+ def unfreeze_backbone(self):
1202
+ for name, param in self.backbone.conv_encoder.model.named_parameters():
1203
+ param.requires_grad_(True)
1204
+
1205
+ @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
1206
+ @replace_return_docstrings(output_type=DetrModelOutput, config_class=_CONFIG_FOR_DOC)
1207
+ def forward(
1208
+ self,
1209
+ pixel_values: torch.FloatTensor,
1210
+ pixel_mask: Optional[torch.LongTensor] = None,
1211
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
1212
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1213
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1214
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1215
+ output_attentions: Optional[bool] = None,
1216
+ output_hidden_states: Optional[bool] = None,
1217
+ return_dict: Optional[bool] = None,
1218
+ ) -> Union[Tuple[torch.FloatTensor], DetrModelOutput]:
1219
+ r"""
1220
+ Returns:
1221
+
1222
+ Examples:
1223
+
1224
+ ```python
1225
+ >>> from transformers import AutoImageProcessor, DetrModel
1226
+ >>> from PIL import Image
1227
+ >>> import requests
1228
+
1229
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1230
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1231
+
1232
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
1233
+ >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
1234
+
1235
+ >>> # prepare image for the model
1236
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1237
+
1238
+ >>> # forward pass
1239
+ >>> outputs = model(**inputs)
1240
+
1241
+ >>> # the last hidden states are the final query embeddings of the Transformer decoder
1242
+ >>> # these are of shape (batch_size, num_queries, hidden_size)
1243
+ >>> last_hidden_states = outputs.last_hidden_state
1244
+ >>> list(last_hidden_states.shape)
1245
+ [1, 100, 256]
1246
+ ```"""
1247
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1248
+ output_hidden_states = (
1249
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1250
+ )
1251
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1252
+
1253
+ batch_size, num_channels, height, width = pixel_values.shape
1254
+ device = pixel_values.device
1255
+
1256
+ if pixel_mask is None:
1257
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
1258
+
1259
+ # First, send pixel_values + pixel_mask through the backbone to obtain the features
1260
+ # pixel_values should be of shape (batch_size, num_channels, height, width)
1261
+ # pixel_mask should be of shape (batch_size, height, width)
1262
+ features, object_queries_list = self.backbone(pixel_values, pixel_mask)
1263
+
1264
+ # get final feature map and downsampled mask
1265
+ feature_map, mask = features[-1]
1266
+
1267
+ if mask is None:
1268
+ raise ValueError("Backbone does not return downsampled pixel mask")
1269
+
1270
+ # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1271
+ projected_feature_map = self.input_projection(feature_map)
1272
+
1273
+ # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1274
+ # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1275
+ flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1276
+ object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1277
+
1278
+ flattened_mask = mask.flatten(1)
1279
+
1280
+ # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
1281
+ # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1282
+ # flattened_mask is a Tensor of shape (batch_size, height*width)
1283
+ if encoder_outputs is None:
1284
+ encoder_outputs = self.encoder(
1285
+ inputs_embeds=flattened_features,
1286
+ attention_mask=flattened_mask,
1287
+ object_queries=object_queries,
1288
+ output_attentions=output_attentions,
1289
+ output_hidden_states=output_hidden_states,
1290
+ return_dict=return_dict,
1291
+ )
1292
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1293
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1294
+ encoder_outputs = BaseModelOutput(
1295
+ last_hidden_state=encoder_outputs[0],
1296
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1297
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1298
+ )
1299
+
1300
+ # Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
1301
+ query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
1302
+ queries = torch.zeros_like(query_position_embeddings)
1303
+
1304
+ # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
1305
+ decoder_outputs = self.decoder(
1306
+ inputs_embeds=queries,
1307
+ attention_mask=None,
1308
+ object_queries=object_queries,
1309
+ query_position_embeddings=query_position_embeddings,
1310
+ encoder_hidden_states=encoder_outputs[0],
1311
+ encoder_attention_mask=flattened_mask,
1312
+ output_attentions=output_attentions,
1313
+ output_hidden_states=output_hidden_states,
1314
+ return_dict=return_dict,
1315
+ )
1316
+
1317
+ if not return_dict:
1318
+ return decoder_outputs + encoder_outputs
1319
+
1320
+ return DetrModelOutput(
1321
+ last_hidden_state=decoder_outputs.last_hidden_state,
1322
+ decoder_hidden_states=decoder_outputs.hidden_states,
1323
+ decoder_attentions=decoder_outputs.attentions,
1324
+ cross_attentions=decoder_outputs.cross_attentions,
1325
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1326
+ encoder_hidden_states=encoder_outputs.hidden_states,
1327
+ encoder_attentions=encoder_outputs.attentions,
1328
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
1329
+ )
1330
+
1331
+
1332
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1333
+ class DetrMLPPredictionHead(nn.Module):
1334
+ """
1335
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1336
+ height and width of a bounding box w.r.t. an image.
1337
+
1338
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1339
+
1340
+ """
1341
+
1342
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
1343
+ super().__init__()
1344
+ self.num_layers = num_layers
1345
+ h = [hidden_dim] * (num_layers - 1)
1346
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1347
+
1348
+ def forward(self, x):
1349
+ for i, layer in enumerate(self.layers):
1350
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1351
+ return x
1352
+
1353
+
1354
+ @add_start_docstrings(
1355
+ """
1356
+ DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
1357
+ such as COCO detection.
1358
+ """,
1359
+ DETR_START_DOCSTRING,
1360
+ )
1361
+ class DetrForObjectDetection(DetrPreTrainedModel):
1362
+ def __init__(self, config: DetrConfig):
1363
+ super().__init__(config)
1364
+
1365
+ # DETR encoder-decoder model
1366
+ self.model = DetrModel(config)
1367
+
1368
+ # Object detection heads
1369
+ self.class_labels_classifier = nn.Linear(
1370
+ config.d_model, config.num_labels + 1
1371
+ ) # We add one for the "no object" class
1372
+ self.bbox_predictor = DetrMLPPredictionHead(
1373
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
1374
+ )
1375
+
1376
+ # Initialize weights and apply final processing
1377
+ self.post_init()
1378
+
1379
+ @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
1380
+ @replace_return_docstrings(output_type=DetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
1381
+ def forward(
1382
+ self,
1383
+ pixel_values: torch.FloatTensor,
1384
+ pixel_mask: Optional[torch.LongTensor] = None,
1385
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
1386
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1387
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1388
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1389
+ labels: Optional[List[dict]] = None,
1390
+ output_attentions: Optional[bool] = None,
1391
+ output_hidden_states: Optional[bool] = None,
1392
+ return_dict: Optional[bool] = None,
1393
+ ) -> Union[Tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
1394
+ r"""
1395
+ labels (`List[Dict]` of len `(batch_size,)`, *optional*):
1396
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1397
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1398
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1399
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1400
+
1401
+ Returns:
1402
+
1403
+ Examples:
1404
+
1405
+ ```python
1406
+ >>> from transformers import AutoImageProcessor, DetrForObjectDetection
1407
+ >>> import torch
1408
+ >>> from PIL import Image
1409
+ >>> import requests
1410
+
1411
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1412
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1413
+
1414
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
1415
+ >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
1416
+
1417
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1418
+ >>> outputs = model(**inputs)
1419
+
1420
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
1421
+ >>> target_sizes = torch.tensor([image.size[::-1]])
1422
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
1423
+ ... 0
1424
+ ... ]
1425
+
1426
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
1427
+ ... box = [round(i, 2) for i in box.tolist()]
1428
+ ... print(
1429
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
1430
+ ... f"{round(score.item(), 3)} at location {box}"
1431
+ ... )
1432
+ Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
1433
+ Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
1434
+ Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
1435
+ Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
1436
+ Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
1437
+ ```"""
1438
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1439
+
1440
+ # First, send images through the DETR base model to obtain encoder + decoder outputs
1441
+ outputs = self.model(
1442
+ pixel_values,
1443
+ pixel_mask=pixel_mask,
1444
+ decoder_attention_mask=decoder_attention_mask,
1445
+ encoder_outputs=encoder_outputs,
1446
+ inputs_embeds=inputs_embeds,
1447
+ decoder_inputs_embeds=decoder_inputs_embeds,
1448
+ output_attentions=output_attentions,
1449
+ output_hidden_states=output_hidden_states,
1450
+ return_dict=return_dict,
1451
+ )
1452
+
1453
+ sequence_output = outputs[0]
1454
+
1455
+ # class logits + predicted bounding boxes
1456
+ logits = self.class_labels_classifier(sequence_output)
1457
+ pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
1458
+
1459
+ loss, loss_dict, auxiliary_outputs = None, None, None
1460
+ if labels is not None:
1461
+ outputs_class, outputs_coord = None, None
1462
+ if self.config.auxiliary_loss:
1463
+ intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
1464
+ outputs_class = self.class_labels_classifier(intermediate)
1465
+ outputs_coord = self.bbox_predictor(intermediate).sigmoid()
1466
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
1467
+ logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
1468
+ )
1469
+
1470
+ if not return_dict:
1471
+ if auxiliary_outputs is not None:
1472
+ output = (logits, pred_boxes) + auxiliary_outputs + outputs
1473
+ else:
1474
+ output = (logits, pred_boxes) + outputs
1475
+ return ((loss, loss_dict) + output) if loss is not None else output
1476
+
1477
+ return DetrObjectDetectionOutput(
1478
+ loss=loss,
1479
+ loss_dict=loss_dict,
1480
+ logits=logits,
1481
+ pred_boxes=pred_boxes,
1482
+ auxiliary_outputs=auxiliary_outputs,
1483
+ last_hidden_state=outputs.last_hidden_state,
1484
+ decoder_hidden_states=outputs.decoder_hidden_states,
1485
+ decoder_attentions=outputs.decoder_attentions,
1486
+ cross_attentions=outputs.cross_attentions,
1487
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1488
+ encoder_hidden_states=outputs.encoder_hidden_states,
1489
+ encoder_attentions=outputs.encoder_attentions,
1490
+ )
1491
+
1492
+
1493
+ @add_start_docstrings(
1494
+ """
1495
+ DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
1496
+ such as COCO panoptic.
1497
+
1498
+ """,
1499
+ DETR_START_DOCSTRING,
1500
+ )
1501
+ class DetrForSegmentation(DetrPreTrainedModel):
1502
+ def __init__(self, config: DetrConfig):
1503
+ super().__init__(config)
1504
+
1505
+ # object detection model
1506
+ self.detr = DetrForObjectDetection(config)
1507
+
1508
+ # segmentation head
1509
+ hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
1510
+ intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes
1511
+
1512
+ self.mask_head = DetrMaskHeadSmallConv(
1513
+ hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
1514
+ )
1515
+
1516
+ self.bbox_attention = DetrMHAttentionMap(
1517
+ hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
1518
+ )
1519
+ # Initialize weights and apply final processing
1520
+ self.post_init()
1521
+
1522
+ @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
1523
+ @replace_return_docstrings(output_type=DetrSegmentationOutput, config_class=_CONFIG_FOR_DOC)
1524
+ def forward(
1525
+ self,
1526
+ pixel_values: torch.FloatTensor,
1527
+ pixel_mask: Optional[torch.LongTensor] = None,
1528
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
1529
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1530
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1531
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1532
+ labels: Optional[List[dict]] = None,
1533
+ output_attentions: Optional[bool] = None,
1534
+ output_hidden_states: Optional[bool] = None,
1535
+ return_dict: Optional[bool] = None,
1536
+ ) -> Union[Tuple[torch.FloatTensor], DetrSegmentationOutput]:
1537
+ r"""
1538
+ labels (`List[Dict]` of len `(batch_size,)`, *optional*):
1539
+ Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
1540
+ dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
1541
+ bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
1542
+ should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
1543
+ `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
1544
+ `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.
1545
+
1546
+ Returns:
1547
+
1548
+ Examples:
1549
+
1550
+ ```python
1551
+ >>> import io
1552
+ >>> import requests
1553
+ >>> from PIL import Image
1554
+ >>> import torch
1555
+ >>> import numpy
1556
+
1557
+ >>> from transformers import AutoImageProcessor, DetrForSegmentation
1558
+ >>> from transformers.image_transforms import rgb_to_id
1559
+
1560
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1561
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1562
+
1563
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
1564
+ >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
1565
+
1566
+ >>> # prepare image for the model
1567
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1568
+
1569
+ >>> # forward pass
1570
+ >>> outputs = model(**inputs)
1571
+
1572
+ >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
1573
+ >>> # Segmentation results are returned as a list of dictionaries
1574
+ >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
1575
+
1576
+ >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
1577
+ >>> panoptic_seg = result[0]["segmentation"]
1578
+ >>> # Get prediction score and segment_id to class_id mapping of each segment
1579
+ >>> panoptic_segments_info = result[0]["segments_info"]
1580
+ ```"""
1581
+
1582
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1583
+
1584
+ batch_size, num_channels, height, width = pixel_values.shape
1585
+ device = pixel_values.device
1586
+
1587
+ if pixel_mask is None:
1588
+ pixel_mask = torch.ones((batch_size, height, width), device=device)
1589
+
1590
+ # First, get list of feature maps and position embeddings
1591
+ features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
1592
+
1593
+ # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1594
+ feature_map, mask = features[-1]
1595
+ batch_size, num_channels, height, width = feature_map.shape
1596
+ projected_feature_map = self.detr.model.input_projection(feature_map)
1597
+
1598
+ # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1599
+ # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1600
+ flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1601
+ object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1602
+
1603
+ flattened_mask = mask.flatten(1)
1604
+
1605
+ # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
1606
+ # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1607
+ # flattened_mask is a Tensor of shape (batch_size, height*width)
1608
+ if encoder_outputs is None:
1609
+ encoder_outputs = self.detr.model.encoder(
1610
+ inputs_embeds=flattened_features,
1611
+ attention_mask=flattened_mask,
1612
+ object_queries=object_queries,
1613
+ output_attentions=output_attentions,
1614
+ output_hidden_states=output_hidden_states,
1615
+ return_dict=return_dict,
1616
+ )
1617
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1618
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1619
+ encoder_outputs = BaseModelOutput(
1620
+ last_hidden_state=encoder_outputs[0],
1621
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1622
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1623
+ )
1624
+
1625
+ # Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
1626
+ query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
1627
+ batch_size, 1, 1
1628
+ )
1629
+ queries = torch.zeros_like(query_position_embeddings)
1630
+
1631
+ # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
1632
+ decoder_outputs = self.detr.model.decoder(
1633
+ inputs_embeds=queries,
1634
+ attention_mask=None,
1635
+ object_queries=object_queries,
1636
+ query_position_embeddings=query_position_embeddings,
1637
+ encoder_hidden_states=encoder_outputs[0],
1638
+ encoder_attention_mask=flattened_mask,
1639
+ output_attentions=output_attentions,
1640
+ output_hidden_states=output_hidden_states,
1641
+ return_dict=return_dict,
1642
+ )
1643
+
1644
+ sequence_output = decoder_outputs[0]
1645
+
1646
+ # Sixth, compute logits, pred_boxes and pred_masks
1647
+ logits = self.detr.class_labels_classifier(sequence_output)
1648
+ pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()
1649
+
1650
+ memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
1651
+ mask = flattened_mask.view(batch_size, height, width)
1652
+
1653
+ # FIXME h_boxes takes the last one computed, keep this in mind
1654
+ # important: we need to reverse the mask, since in the original implementation the mask works reversed
1655
+ # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
1656
+ bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
1657
+
1658
+ seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
1659
+
1660
+ pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
1661
+
1662
+ loss, loss_dict, auxiliary_outputs = None, None, None
1663
+ if labels is not None:
1664
+ outputs_class, outputs_coord = None, None
1665
+ if self.config.auxiliary_loss:
1666
+ intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
1667
+ outputs_class = self.detr.class_labels_classifier(intermediate)
1668
+ outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
1669
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
1670
+ logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
1671
+ )
1672
+
1673
+ if not return_dict:
1674
+ if auxiliary_outputs is not None:
1675
+ output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
1676
+ else:
1677
+ output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
1678
+ return ((loss, loss_dict) + output) if loss is not None else output
1679
+
1680
+ return DetrSegmentationOutput(
1681
+ loss=loss,
1682
+ loss_dict=loss_dict,
1683
+ logits=logits,
1684
+ pred_boxes=pred_boxes,
1685
+ pred_masks=pred_masks,
1686
+ auxiliary_outputs=auxiliary_outputs,
1687
+ last_hidden_state=decoder_outputs.last_hidden_state,
1688
+ decoder_hidden_states=decoder_outputs.hidden_states,
1689
+ decoder_attentions=decoder_outputs.attentions,
1690
+ cross_attentions=decoder_outputs.cross_attentions,
1691
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1692
+ encoder_hidden_states=encoder_outputs.hidden_states,
1693
+ encoder_attentions=encoder_outputs.attentions,
1694
+ )
1695
+
1696
+
1697
+ def _expand(tensor, length: int):
1698
+ return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
1699
+
1700
+
1701
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
1702
+ class DetrMaskHeadSmallConv(nn.Module):
1703
+ """
1704
+ Simple convolutional head, using group norm. Upsampling is done using a FPN approach
1705
+ """
1706
+
1707
+ def __init__(self, dim, fpn_dims, context_dim):
1708
+ super().__init__()
1709
+
1710
+ if dim % 8 != 0:
1711
+ raise ValueError(
1712
+ "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
1713
+ " GroupNorm is set to 8"
1714
+ )
1715
+
1716
+ inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
1717
+
1718
+ self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
1719
+ self.gn1 = nn.GroupNorm(8, dim)
1720
+ self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
1721
+ self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
1722
+ self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
1723
+ self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
1724
+ self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
1725
+ self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
1726
+ self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
1727
+ self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
1728
+ self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
1729
+
1730
+ self.dim = dim
1731
+
1732
+ self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
1733
+ self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
1734
+ self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
1735
+
1736
+ for m in self.modules():
1737
+ if isinstance(m, nn.Conv2d):
1738
+ nn.init.kaiming_uniform_(m.weight, a=1)
1739
+ nn.init.constant_(m.bias, 0)
1740
+
1741
+ def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
1742
+ # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
1743
+ # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
1744
+ # We expand the projected feature map to match the number of heads.
1745
+ x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
1746
+
1747
+ x = self.lay1(x)
1748
+ x = self.gn1(x)
1749
+ x = nn.functional.relu(x)
1750
+ x = self.lay2(x)
1751
+ x = self.gn2(x)
1752
+ x = nn.functional.relu(x)
1753
+
1754
+ cur_fpn = self.adapter1(fpns[0])
1755
+ if cur_fpn.size(0) != x.size(0):
1756
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1757
+ x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1758
+ x = self.lay3(x)
1759
+ x = self.gn3(x)
1760
+ x = nn.functional.relu(x)
1761
+
1762
+ cur_fpn = self.adapter2(fpns[1])
1763
+ if cur_fpn.size(0) != x.size(0):
1764
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1765
+ x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1766
+ x = self.lay4(x)
1767
+ x = self.gn4(x)
1768
+ x = nn.functional.relu(x)
1769
+
1770
+ cur_fpn = self.adapter3(fpns[2])
1771
+ if cur_fpn.size(0) != x.size(0):
1772
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1773
+ x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1774
+ x = self.lay5(x)
1775
+ x = self.gn5(x)
1776
+ x = nn.functional.relu(x)
1777
+
1778
+ x = self.out_lay(x)
1779
+ return x
1780
+
1781
+
1782
+ class DetrMHAttentionMap(nn.Module):
1783
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
1784
+
1785
+ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
1786
+ super().__init__()
1787
+ self.num_heads = num_heads
1788
+ self.hidden_dim = hidden_dim
1789
+ self.dropout = nn.Dropout(dropout)
1790
+
1791
+ self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1792
+ self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1793
+
1794
+ self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
1795
+
1796
+ def forward(self, q, k, mask: Optional[Tensor] = None):
1797
+ q = self.q_linear(q)
1798
+ k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
1799
+ queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
1800
+ keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
1801
+ weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
1802
+
1803
+ if mask is not None:
1804
+ weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
1805
+ weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
1806
+ weights = self.dropout(weights)
1807
+ return weights
1808
+
1809
+
1810
+ __all__ = [
1811
+ "DetrForObjectDetection",
1812
+ "DetrForSegmentation",
1813
+ "DetrModel",
1814
+ "DetrPreTrainedModel",
1815
+ ]
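
The segmentation code above relies on `DetrMHAttentionMap` returning one spatial attention map per query and per head, rather than attended values. A minimal shape-check sketch (illustrative only, not part of this diff; it assumes a transformers installation that ships this module and uses random tensors in place of real decoder and encoder outputs):

```python
import torch

from transformers.models.detr.modeling_detr import DetrMHAttentionMap

# d_model=256 and 8 heads mirror the DetrConfig defaults used by DetrForSegmentation above
attention_map = DetrMHAttentionMap(query_dim=256, hidden_dim=256, num_heads=8)

decoder_output = torch.randn(1, 100, 256)     # (batch_size, num_queries, d_model)
encoder_memory = torch.randn(1, 256, 25, 34)  # (batch_size, d_model, height/32, width/32)

# one spatial attention map per query and per head, softmaxed over the spatial dimensions
weights = attention_map(decoder_output, encoder_memory)
print(weights.shape)  # torch.Size([1, 100, 8, 25, 34])
```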
docs/transformers/build/lib/transformers/models/dialogpt/__init__.py ADDED
File without changes
docs/transformers/build/lib/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+
18
+ import torch
19
+
20
+ from transformers.utils import WEIGHTS_NAME
21
+
22
+
23
+ DIALOGPT_MODELS = ["small", "medium", "large"]
24
+
25
+ OLD_KEY = "lm_head.decoder.weight"
26
+ NEW_KEY = "lm_head.weight"
27
+
28
+
29
+ def convert_dialogpt_checkpoint(checkpoint_path: str, pytorch_dump_folder_path: str):
30
+ d = torch.load(checkpoint_path, weights_only=True)
31
+ d[NEW_KEY] = d.pop(OLD_KEY)
32
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
33
+ torch.save(d, os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME))
34
+
35
+
36
+ if __name__ == "__main__":
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--dialogpt_path", default=".", type=str)
39
+ args = parser.parse_args()
40
+ for MODEL in DIALOGPT_MODELS:
41
+ checkpoint_path = os.path.join(args.dialogpt_path, f"{MODEL}_ft.pkl")
42
+ pytorch_dump_folder_path = f"./DialoGPT-{MODEL}"
43
+ convert_dialogpt_checkpoint(
44
+ checkpoint_path,
45
+ pytorch_dump_folder_path,
46
+ )
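
The conversion script above amounts to a single key rename per checkpoint. A rough sketch of one iteration of its loop (hedged: `small_ft.pkl` is a hypothetical local checkpoint path, and `pytorch_model.bin` is the value of `WEIGHTS_NAME` that the script saves under):

```python
import os

import torch

# Load the fine-tuned DialoGPT state dict and rename the tied LM head key so it
# matches the naming expected by the transformers GPT-2 implementation.
state_dict = torch.load("small_ft.pkl", weights_only=True)  # hypothetical local path
state_dict["lm_head.weight"] = state_dict.pop("lm_head.decoder.weight")

os.makedirs("./DialoGPT-small", exist_ok=True)
torch.save(state_dict, os.path.join("./DialoGPT-small", "pytorch_model.bin"))
```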
docs/transformers/build/lib/transformers/models/diffllama/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_diffllama import *
22
+ from .modeling_diffllama import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
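
Because the package module is swapped for a `_LazyModule`, submodules are only imported when their attributes are first accessed. A small illustration (assuming a transformers build that includes this model; the printed values follow the defaults documented in the configuration file added below):

```python
# Accessing DiffLlamaConfig triggers the import of configuration_diffllama only;
# the torch-heavy modeling_diffllama module stays untouched until it is needed.
from transformers.models.diffllama import DiffLlamaConfig

config = DiffLlamaConfig()       # all defaults
print(config.hidden_size)        # 2048
print(config.num_hidden_layers)  # 16
```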
docs/transformers/build/lib/transformers/models/diffllama/configuration_diffllama.py ADDED
@@ -0,0 +1,199 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on Llama implementations in this library and Microsoft's
5
+ # Differential Transformer implementations.
6
+
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """DiffLlama model configuration"""
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...modeling_rope_utils import rope_config_validation
22
+
23
+
24
+ class DiffLlamaConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate an DiffLlama
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
28
+ will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 32000):
36
+ Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
37
+ `inputs_ids` passed when calling [`DiffLlamaModel`]
38
+ hidden_size (`int`, *optional*, defaults to 2048):
39
+ Dimension of the hidden representations.
40
+ intermediate_size (`int`, *optional*, defaults to 8192):
41
+ Dimension of the MLP representations.
42
+ num_hidden_layers (`int`, *optional*, defaults to 16):
43
+ Number of hidden layers in the Transformer decoder.
44
+ num_attention_heads (`int`, *optional*, defaults to 32):
45
+ Number of attention heads for each attention layer in the Transformer decoder.
46
+ num_key_value_heads (`int`, *optional*):
47
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
48
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
49
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
50
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
51
+ by meanpooling all the original heads within that group. For more details checkout [this
52
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
53
+ `num_attention_heads`.
54
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
55
+ The non-linear activation function (function or string) in the decoder.
56
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
57
+ The maximum sequence length that this model might ever be used with.
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
61
+ The epsilon used by the rms normalization layers.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*, defaults to 1):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 2):
70
+ End of stream token id.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+ rope_scaling (`Dict`, *optional*):
76
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
77
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
78
+ accordingly.
79
+ Expected contents:
80
+ `rope_type` (`str`):
81
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
82
+ 'diffllama3'], with 'default' being the original RoPE implementation.
83
+ `factor` (`float`, *optional*):
84
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
85
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
86
+ original maximum pre-trained length.
87
+ `original_max_position_embeddings` (`int`, *optional*):
88
+ Used with 'dynamic', 'longrope' and 'diffllama3'. The original max position embeddings used during
89
+ pretraining.
90
+ `attention_factor` (`float`, *optional*):
91
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
92
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
93
+ `factor` field to infer the suggested value.
94
+ `beta_fast` (`float`, *optional*):
95
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
96
+ ramp function. If unspecified, it defaults to 32.
97
+ `beta_slow` (`float`, *optional*):
98
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
99
+ ramp function. If unspecified, it defaults to 1.
100
+ `short_factor` (`List[float]`, *optional*):
101
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
102
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
103
+ size divided by the number of attention heads divided by 2
104
+ `long_factor` (`List[float]`, *optional*):
105
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
106
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
107
+ size divided by the number of attention heads divided by 2
108
+ `low_freq_factor` (`float`, *optional*):
109
+ Only used with 'diffllama3'. Scaling factor applied to low frequency components of the RoPE
110
+ `high_freq_factor` (`float`, *optional*):
111
+ Only used with 'diffllama3'. Scaling factor applied to high frequency components of the RoPE
112
+ attention_bias (`bool`, *optional*, defaults to `False`):
113
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
114
+ attention_dropout (`float`, *optional*, defaults to 0.0):
115
+ The dropout ratio for the attention probabilities.
116
+ lambda_std_dev (`float`, *optional*, defaults to 0.1):
117
+ The standard deviation for initialization of parameter lambda in attention layer.
118
+ head_dim (`int`, *optional*):
119
+ The attention head dimension. If None, it will default to hidden_size // num_heads
120
+
121
+ ```python
122
+ >>> from transformers import DiffLlamaModel, DiffLlamaConfig
123
+
124
+ >>> # Initializing a DiffLlama diffllama-7b style configuration
125
+ >>> configuration = DiffLlamaConfig()
126
+
127
+ >>> # Initializing a model from the diffllama-7b style configuration
128
+ >>> model = DiffLlamaModel(configuration)
129
+
130
+ >>> # Accessing the model configuration
131
+ >>> configuration = model.config
132
+ ```"""
133
+
134
+ model_type = "diffllama"
135
+ keys_to_ignore_at_inference = ["past_key_values"]
136
+
137
+ def __init__(
138
+ self,
139
+ vocab_size=32000,
140
+ hidden_size=2048,
141
+ intermediate_size=8192,
142
+ num_hidden_layers=16,
143
+ num_attention_heads=32,
144
+ num_key_value_heads=None,
145
+ hidden_act="silu",
146
+ max_position_embeddings=2048,
147
+ initializer_range=0.02,
148
+ rms_norm_eps=1e-5,
149
+ use_cache=True,
150
+ pad_token_id=None,
151
+ bos_token_id=1,
152
+ eos_token_id=2,
153
+ tie_word_embeddings=False,
154
+ rope_theta=10000.0,
155
+ rope_scaling=None,
156
+ attention_bias=False,
157
+ attention_dropout=0.0,
158
+ lambda_std_dev=0.1,
159
+ head_dim=None,
160
+ **kwargs,
161
+ ):
162
+ self.vocab_size = vocab_size
163
+ self.max_position_embeddings = max_position_embeddings
164
+ self.hidden_size = hidden_size
165
+ self.intermediate_size = intermediate_size
166
+ self.num_hidden_layers = num_hidden_layers
167
+ self.num_attention_heads = num_attention_heads
168
+
169
+ # for backward compatibility
170
+ if num_key_value_heads is None:
171
+ num_key_value_heads = num_attention_heads
172
+
173
+ self.num_key_value_heads = num_key_value_heads
174
+ self.hidden_act = hidden_act
175
+ self.initializer_range = initializer_range
176
+ self.rms_norm_eps = rms_norm_eps
177
+ self.use_cache = use_cache
178
+ self.rope_theta = rope_theta
179
+ self.rope_scaling = rope_scaling
180
+ self.attention_bias = attention_bias
181
+ self.attention_dropout = attention_dropout
182
+ self.lambda_std_dev = lambda_std_dev
183
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
184
+ # Validate the correctness of rotary position embeddings parameters
185
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
186
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
187
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
188
+ rope_config_validation(self)
189
+
190
+ super().__init__(
191
+ pad_token_id=pad_token_id,
192
+ bos_token_id=bos_token_id,
193
+ eos_token_id=eos_token_id,
194
+ tie_word_embeddings=tie_word_embeddings,
195
+ **kwargs,
196
+ )
197
+
198
+
199
+ __all__ = ["DiffLlamaConfig"]
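
The grouped-query and `head_dim` defaults above can be exercised directly. A minimal sketch, assuming a `transformers` build that ships the DiffLlama model added in this commit:

```python
# Minimal sketch (assumption: DiffLlamaConfig is importable from this transformers build).
from transformers import DiffLlamaConfig

config = DiffLlamaConfig(
    hidden_size=2048,
    num_attention_heads=32,
    num_key_value_heads=8,  # fewer KV heads than attention heads -> grouped-query attention
)

# head_dim falls back to hidden_size // num_attention_heads when left unset
assert config.head_dim == 2048 // 32  # 64
assert config.num_key_value_heads == 8
```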
docs/transformers/build/lib/transformers/models/esm/openfold_utils/rigid_utils.py ADDED
@@ -0,0 +1,1242 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from functools import lru_cache
19
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+
25
+ def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
26
+ """
27
+ Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting.
28
+
29
+ Args:
30
+ a: [*, 3, 3] left multiplicand
31
+ b: [*, 3, 3] right multiplicand
32
+ Returns:
33
+ The product ab
34
+ """
35
+
36
+ def row_mul(i: int) -> torch.Tensor:
37
+ return torch.stack(
38
+ [
39
+ a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
40
+ a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1],
41
+ a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2],
42
+ ],
43
+ dim=-1,
44
+ )
45
+
46
+ return torch.stack(
47
+ [
48
+ row_mul(0),
49
+ row_mul(1),
50
+ row_mul(2),
51
+ ],
52
+ dim=-2,
53
+ )
54
+
55
+
56
+ def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
57
+ """
58
+ Applies a rotation to a vector. Written out by hand to avoid AMP downcasting.
59
+
60
+ Args:
61
+ r: [*, 3, 3] rotation matrices
62
+ t: [*, 3] coordinate tensors
63
+ Returns:
64
+ [*, 3] rotated coordinates
65
+ """
66
+ x, y, z = torch.unbind(t, dim=-1)
67
+ return torch.stack(
68
+ [
69
+ r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
70
+ r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
71
+ r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
72
+ ],
73
+ dim=-1,
74
+ )
75
+
76
+
77
+ @lru_cache(maxsize=None)
78
+ def identity_rot_mats(
79
+ batch_dims: Tuple[int, ...],
80
+ dtype: Optional[torch.dtype] = None,
81
+ device: Optional[torch.device] = None,
82
+ requires_grad: bool = True,
83
+ ) -> torch.Tensor:
84
+ rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
85
+ rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
86
+ rots = rots.expand(*batch_dims, -1, -1)
87
+ rots = rots.contiguous()
88
+
89
+ return rots
90
+
91
+
92
+ @lru_cache(maxsize=None)
93
+ def identity_trans(
94
+ batch_dims: Tuple[int, ...],
95
+ dtype: Optional[torch.dtype] = None,
96
+ device: Optional[torch.device] = None,
97
+ requires_grad: bool = True,
98
+ ) -> torch.Tensor:
99
+ trans = torch.zeros((*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad)
100
+ return trans
101
+
102
+
103
+ @lru_cache(maxsize=None)
104
+ def identity_quats(
105
+ batch_dims: Tuple[int, ...],
106
+ dtype: Optional[torch.dtype] = None,
107
+ device: Optional[torch.device] = None,
108
+ requires_grad: bool = True,
109
+ ) -> torch.Tensor:
110
+ quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad)
111
+
112
+ with torch.no_grad():
113
+ quat[..., 0] = 1
114
+
115
+ return quat
116
+
117
+
118
+ _quat_elements: List[str] = ["a", "b", "c", "d"]
119
+ _qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
120
+ _qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}
121
+
122
+
123
+ def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
124
+ mat = np.zeros((4, 4))
125
+ for key, value in pairs:
126
+ ind = _qtr_ind_dict[key]
127
+ mat[ind // 4][ind % 4] = value
128
+
129
+ return mat
130
+
131
+
132
+ _QTR_MAT = np.zeros((4, 4, 3, 3))
133
+ _QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
134
+ _QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
135
+ _QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
136
+ _QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
137
+ _QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
138
+ _QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
139
+ _QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
140
+ _QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
141
+ _QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
142
+
143
+
144
+ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
145
+ """
146
+ Converts a quaternion to a rotation matrix.
147
+
148
+ Args:
149
+ quat: [*, 4] quaternions
150
+ Returns:
151
+ [*, 3, 3] rotation matrices
152
+ """
153
+ # [*, 4, 4]
154
+ quat = quat[..., None] * quat[..., None, :]
155
+
156
+ # [4, 4, 3, 3]
157
+ mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
158
+
159
+ # [*, 4, 4, 3, 3]
160
+ shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
161
+ quat = quat[..., None, None] * shaped_qtr_mat
162
+
163
+ # [*, 3, 3]
164
+ return torch.sum(quat, dim=(-3, -4))
165
+
166
+
167
+ def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
168
+ if rot.shape[-2:] != (3, 3):
169
+ raise ValueError("Input rotation is incorrectly shaped")
170
+
171
+ [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]
172
+
173
+ k = [
174
+ [
175
+ xx + yy + zz,
176
+ zy - yz,
177
+ xz - zx,
178
+ yx - xy,
179
+ ],
180
+ [
181
+ zy - yz,
182
+ xx - yy - zz,
183
+ xy + yx,
184
+ xz + zx,
185
+ ],
186
+ [
187
+ xz - zx,
188
+ xy + yx,
189
+ yy - xx - zz,
190
+ yz + zy,
191
+ ],
192
+ [
193
+ yx - xy,
194
+ xz + zx,
195
+ yz + zy,
196
+ zz - xx - yy,
197
+ ],
198
+ ]
199
+
200
+ _, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
201
+ return vectors[..., -1]
202
+
203
+
204
+ _QUAT_MULTIPLY = np.zeros((4, 4, 4))
205
+ _QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]
206
+
207
+ _QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]
208
+
209
+ _QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]
210
+
211
+ _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]
212
+
213
+ _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
214
+
215
+ _CACHED_QUATS: Dict[str, np.ndarray] = {
216
+ "_QTR_MAT": _QTR_MAT,
217
+ "_QUAT_MULTIPLY": _QUAT_MULTIPLY,
218
+ "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
219
+ }
220
+
221
+
222
+ @lru_cache(maxsize=None)
223
+ def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
224
+ return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
225
+
226
+
227
+ def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
228
+ """Multiply a quaternion by another quaternion."""
229
+ mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
230
+ reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
231
+ return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
232
+
233
+
234
+ def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
235
+ """Multiply a quaternion by a pure-vector quaternion."""
236
+ mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
237
+ reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
238
+ return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
239
+
240
+
241
+ def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
242
+ return rot_mat.transpose(-1, -2)
243
+
244
+
245
+ def invert_quat(quat: torch.Tensor) -> torch.Tensor:
246
+ quat_prime = quat.clone()
247
+ quat_prime[..., 1:] *= -1
248
+ inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
249
+ return inv
250
+
251
+
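
The cached `_QTR_MAT` table and the eigen-decomposition in `rot_to_quat` give a quaternion/matrix round trip that is exact only up to the sign of the quaternion. A minimal sketch, assuming this file is importable as `transformers.models.esm.openfold_utils.rigid_utils`:

```python
# Minimal sketch (assumption: this module ships at the path below).
import torch
from transformers.models.esm.openfold_utils.rigid_utils import quat_to_rot, rot_to_quat

# Unit quaternion (w, x, y, z) for a 90-degree rotation about the z axis
q = torch.tensor([[0.7071068, 0.0, 0.0, 0.7071068]])

r = quat_to_rot(q)       # [1, 3, 3] rotation matrix
q_back = rot_to_quat(r)  # eigenvector-based reconstruction, sign-ambiguous

# Quaternions are only recovered up to sign
assert torch.allclose(q.abs(), q_back.abs(), atol=1e-4)
```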
252
+ class Rotation:
253
+ """
254
+ A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix
255
+ or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the
256
+ underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the
257
+ behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ rot_mats: Optional[torch.Tensor] = None,
263
+ quats: Optional[torch.Tensor] = None,
264
+ normalize_quats: bool = True,
265
+ ):
266
+ """
267
+ Args:
268
+ rot_mats:
269
+ A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats
270
+ quats:
271
+ A [*, 4] quaternion. Mutually exclusive with rot_mats. If normalize_quats is not True, must be a unit
272
+ quaternion
273
+ normalize_quats:
274
+ If quats is specified, whether to normalize quats
275
+ """
276
+ if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None):
277
+ raise ValueError("Exactly one input argument must be specified")
278
+
279
+ if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4):
280
+ raise ValueError("Incorrectly shaped rotation matrix or quaternion")
281
+
282
+ # Force full-precision
283
+ if quats is not None:
284
+ quats = quats.to(dtype=torch.float32)
285
+ if rot_mats is not None:
286
+ rot_mats = rot_mats.to(dtype=torch.float32)
287
+
288
+ if quats is not None and normalize_quats:
289
+ quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
290
+
291
+ self._rot_mats = rot_mats
292
+ self._quats = quats
293
+
294
+ @staticmethod
295
+ def identity(
296
+ shape,
297
+ dtype: Optional[torch.dtype] = None,
298
+ device: Optional[torch.device] = None,
299
+ requires_grad: bool = True,
300
+ fmt: str = "quat",
301
+ ) -> Rotation:
302
+ """
303
+ Returns an identity Rotation.
304
+
305
+ Args:
306
+ shape:
307
+ The "shape" of the resulting Rotation object. See documentation for the shape property
308
+ dtype:
309
+ The torch dtype for the rotation
310
+ device:
311
+ The torch device for the new rotation
312
+ requires_grad:
313
+ Whether the underlying tensors in the new rotation object should require gradient computation
314
+ fmt:
315
+ One of "quat" or "rot_mat". Determines the underlying format of the new object's rotation
316
+ Returns:
317
+ A new identity rotation
318
+ """
319
+ if fmt == "rot_mat":
320
+ rot_mats = identity_rot_mats(
321
+ shape,
322
+ dtype,
323
+ device,
324
+ requires_grad,
325
+ )
326
+ return Rotation(rot_mats=rot_mats, quats=None)
327
+ elif fmt == "quat":
328
+ quats = identity_quats(shape, dtype, device, requires_grad)
329
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
330
+ else:
331
+ raise ValueError(f"Invalid format: {fmt}")
332
+
333
+ # Magic methods
334
+
335
+ def __getitem__(self, index: Any) -> Rotation:
336
+ """
337
+ Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape
338
+ property.
339
+
340
+ Args:
341
+ index:
342
+ A torch index. E.g. (1, 3, 2), or (slice(None,))
343
+ Returns:
344
+ The indexed rotation
345
+ """
346
+ if type(index) is not tuple:
347
+ index = (index,)
348
+
349
+ if self._rot_mats is not None:
350
+ rot_mats = self._rot_mats[index + (slice(None), slice(None))]
351
+ return Rotation(rot_mats=rot_mats)
352
+ elif self._quats is not None:
353
+ quats = self._quats[index + (slice(None),)]
354
+ return Rotation(quats=quats, normalize_quats=False)
355
+ else:
356
+ raise ValueError("Both rotations are None")
357
+
358
+ def __mul__(self, right: torch.Tensor) -> Rotation:
359
+ """
360
+ Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
361
+
362
+ Args:
363
+ right:
364
+ The tensor multiplicand
365
+ Returns:
366
+ The product
367
+ """
368
+ if not (isinstance(right, torch.Tensor)):
369
+ raise TypeError("The other multiplicand must be a Tensor")
370
+
371
+ if self._rot_mats is not None:
372
+ rot_mats = self._rot_mats * right[..., None, None]
373
+ return Rotation(rot_mats=rot_mats, quats=None)
374
+ elif self._quats is not None:
375
+ quats = self._quats * right[..., None]
376
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
377
+ else:
378
+ raise ValueError("Both rotations are None")
379
+
380
+ def __rmul__(self, left: torch.Tensor) -> Rotation:
381
+ """
382
+ Reverse pointwise multiplication of the rotation with a tensor.
383
+
384
+ Args:
385
+ left:
386
+ The left multiplicand
387
+ Returns:
388
+ The product
389
+ """
390
+ return self.__mul__(left)
391
+
392
+ # Properties
393
+
394
+ @property
395
+ def shape(self) -> torch.Size:
396
+ """
397
+ Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the
398
+ underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix
399
+ tensor, for example, the resulting shape would be [10].
400
+
401
+ Returns:
402
+ The virtual shape of the rotation object
403
+ """
404
+ if self._rot_mats is not None:
405
+ return self._rot_mats.shape[:-2]
406
+ elif self._quats is not None:
407
+ return self._quats.shape[:-1]
408
+ else:
409
+ raise ValueError("Both rotations are None")
410
+
411
+ @property
412
+ def dtype(self) -> torch.dtype:
413
+ """
414
+ Returns the dtype of the underlying rotation.
415
+
416
+ Returns:
417
+ The dtype of the underlying rotation
418
+ """
419
+ if self._rot_mats is not None:
420
+ return self._rot_mats.dtype
421
+ elif self._quats is not None:
422
+ return self._quats.dtype
423
+ else:
424
+ raise ValueError("Both rotations are None")
425
+
426
+ @property
427
+ def device(self) -> torch.device:
428
+ """
429
+ The device of the underlying rotation
430
+
431
+ Returns:
432
+ The device of the underlying rotation
433
+ """
434
+ if self._rot_mats is not None:
435
+ return self._rot_mats.device
436
+ elif self._quats is not None:
437
+ return self._quats.device
438
+ else:
439
+ raise ValueError("Both rotations are None")
440
+
441
+ @property
442
+ def requires_grad(self) -> bool:
443
+ """
444
+ Returns the requires_grad property of the underlying rotation
445
+
446
+ Returns:
447
+ The requires_grad property of the underlying tensor
448
+ """
449
+ if self._rot_mats is not None:
450
+ return self._rot_mats.requires_grad
451
+ elif self._quats is not None:
452
+ return self._quats.requires_grad
453
+ else:
454
+ raise ValueError("Both rotations are None")
455
+
456
+ def get_rot_mats(self) -> torch.Tensor:
457
+ """
458
+ Returns the underlying rotation as a rotation matrix tensor.
459
+
460
+ Returns:
461
+ The rotation as a rotation matrix tensor
462
+ """
463
+ if self._rot_mats is not None:
464
+ return self._rot_mats
465
+ elif self._quats is not None:
466
+ return quat_to_rot(self._quats)
467
+ else:
468
+ raise ValueError("Both rotations are None")
469
+
470
+ def get_quats(self) -> torch.Tensor:
471
+ """
472
+ Returns the underlying rotation as a quaternion tensor.
473
+
474
+ Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.
475
+
476
+ Returns:
477
+ The rotation as a quaternion tensor.
478
+ """
479
+ if self._rot_mats is not None:
480
+ return rot_to_quat(self._rot_mats)
481
+ elif self._quats is not None:
482
+ return self._quats
483
+ else:
484
+ raise ValueError("Both rotations are None")
485
+
486
+ def get_cur_rot(self) -> torch.Tensor:
487
+ """
488
+ Return the underlying rotation in its current form
489
+
490
+ Returns:
491
+ The stored rotation
492
+ """
493
+ if self._rot_mats is not None:
494
+ return self._rot_mats
495
+ elif self._quats is not None:
496
+ return self._quats
497
+ else:
498
+ raise ValueError("Both rotations are None")
499
+
500
+ # Rotation functions
501
+
502
+ def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation:
503
+ """
504
+ Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion
505
+ update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the
506
+ desired (not necessarily unit) quaternion update.
507
+
508
+ Args:
509
+ q_update_vec:
510
+ A [*, 3] quaternion update tensor
511
+ normalize_quats:
512
+ Whether to normalize the output quaternion
513
+ Returns:
514
+ An updated Rotation
515
+ """
516
+ quats = self.get_quats()
517
+ new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
518
+ return Rotation(
519
+ rot_mats=None,
520
+ quats=new_quats,
521
+ normalize_quats=normalize_quats,
522
+ )
523
+
524
+ def compose_r(self, r: Rotation) -> Rotation:
525
+ """
526
+ Compose the rotation matrices of the current Rotation object with those of another.
527
+
528
+ Args:
529
+ r:
530
+ An update rotation object
531
+ Returns:
532
+ An updated rotation object
533
+ """
534
+ r1 = self.get_rot_mats()
535
+ r2 = r.get_rot_mats()
536
+ new_rot_mats = rot_matmul(r1, r2)
537
+ return Rotation(rot_mats=new_rot_mats, quats=None)
538
+
539
+ def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
540
+ """
541
+ Compose the quaternions of the current Rotation object with those of another.
542
+
543
+ Depending on whether either Rotation was initialized with quaternions, this function may call
544
+ torch.linalg.eigh.
545
+
546
+ Args:
547
+ r:
548
+ An update rotation object
549
+ Returns:
550
+ An updated rotation object
551
+ """
552
+ q1 = self.get_quats()
553
+ q2 = r.get_quats()
554
+ new_quats = quat_multiply(q1, q2)
555
+ return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
556
+
557
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
558
+ """
559
+ Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
560
+
561
+ Args:
562
+ pts:
563
+ A [*, 3] set of points
564
+ Returns:
565
+ [*, 3] rotated points
566
+ """
567
+ rot_mats = self.get_rot_mats()
568
+ return rot_vec_mul(rot_mats, pts)
569
+
570
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
571
+ """
572
+ The inverse of the apply() method.
573
+
574
+ Args:
575
+ pts:
576
+ A [*, 3] set of points
577
+ Returns:
578
+ [*, 3] inverse-rotated points
579
+ """
580
+ rot_mats = self.get_rot_mats()
581
+ inv_rot_mats = invert_rot_mat(rot_mats)
582
+ return rot_vec_mul(inv_rot_mats, pts)
583
+
584
+ def invert(self) -> Rotation:
585
+ """
586
+ Returns the inverse of the current Rotation.
587
+
588
+ Returns:
589
+ The inverse of the current Rotation
590
+ """
591
+ if self._rot_mats is not None:
592
+ return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
593
+ elif self._quats is not None:
594
+ return Rotation(
595
+ rot_mats=None,
596
+ quats=invert_quat(self._quats),
597
+ normalize_quats=False,
598
+ )
599
+ else:
600
+ raise ValueError("Both rotations are None")
601
+
602
+ # "Tensor" stuff
603
+
604
+ def unsqueeze(self, dim: int) -> Rotation:
605
+ """
606
+ Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
607
+
608
+ Args:
609
+ dim: A positive or negative dimension index.
610
+ Returns:
611
+ The unsqueezed Rotation.
612
+ """
613
+ if dim >= len(self.shape):
614
+ raise ValueError("Invalid dimension")
615
+
616
+ if self._rot_mats is not None:
617
+ rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
618
+ return Rotation(rot_mats=rot_mats, quats=None)
619
+ elif self._quats is not None:
620
+ quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
621
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
622
+ else:
623
+ raise ValueError("Both rotations are None")
624
+
625
+ @staticmethod
626
+ def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
627
+ """
628
+ Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
629
+
630
+ Note that the output of this operation is always a rotation matrix, regardless of the format of input
631
+ rotations.
632
+
633
+ Args:
634
+ rs:
635
+ A list of rotation objects
636
+ dim:
637
+ The dimension along which the rotations should be concatenated
638
+ Returns:
639
+ A concatenated Rotation object in rotation matrix format
640
+ """
641
+ rot_mats = torch.cat(
642
+ [r.get_rot_mats() for r in rs],
643
+ dim=dim if dim >= 0 else dim - 2,
644
+ )
645
+
646
+ return Rotation(rot_mats=rot_mats, quats=None)
647
+
648
+ def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
649
+ """
650
+ Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
651
+ be used e.g. to sum out a one-hot batch dimension.
652
+
653
+ Args:
654
+ fn:
655
+ A Tensor -> Tensor function to be mapped over the Rotation
656
+ Returns:
657
+ The transformed Rotation object
658
+ """
659
+ if self._rot_mats is not None:
660
+ rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
661
+ rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)
662
+ rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
663
+ return Rotation(rot_mats=rot_mats, quats=None)
664
+ elif self._quats is not None:
665
+ quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)
666
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
667
+ else:
668
+ raise ValueError("Both rotations are None")
669
+
670
+ def cuda(self) -> Rotation:
671
+ """
672
+ Analogous to the cuda() method of torch Tensors
673
+
674
+ Returns:
675
+ A copy of the Rotation in CUDA memory
676
+ """
677
+ if self._rot_mats is not None:
678
+ return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
679
+ elif self._quats is not None:
680
+ return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False)
681
+ else:
682
+ raise ValueError("Both rotations are None")
683
+
684
+ def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> Rotation:
685
+ """
686
+ Analogous to the to() method of torch Tensors
687
+
688
+ Args:
689
+ device:
690
+ A torch device
691
+ dtype:
692
+ A torch dtype
693
+ Returns:
694
+ A copy of the Rotation using the new device and dtype
695
+ """
696
+ if self._rot_mats is not None:
697
+ return Rotation(
698
+ rot_mats=self._rot_mats.to(device=device, dtype=dtype),
699
+ quats=None,
700
+ )
701
+ elif self._quats is not None:
702
+ return Rotation(
703
+ rot_mats=None,
704
+ quats=self._quats.to(device=device, dtype=dtype),
705
+ normalize_quats=False,
706
+ )
707
+ else:
708
+ raise ValueError("Both rotations are None")
709
+
710
+ def detach(self) -> Rotation:
711
+ """
712
+ Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.
713
+
714
+ Returns:
715
+ A copy of the Rotation whose underlying Tensor has been detached from its torch graph
716
+ """
717
+ if self._rot_mats is not None:
718
+ return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
719
+ elif self._quats is not None:
720
+ return Rotation(
721
+ rot_mats=None,
722
+ quats=self._quats.detach(),
723
+ normalize_quats=False,
724
+ )
725
+ else:
726
+ raise ValueError("Both rotations are None")
727
+
728
+
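
A minimal sketch of the `Rotation` API above (identity construction, `apply`, and composition with the inverse), under the same module-path assumption as the previous snippet:

```python
# Minimal sketch (assumption: module path as below).
import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rotation

rot = Rotation.identity(
    (2,), dtype=torch.float32, device=torch.device("cpu"), requires_grad=False, fmt="quat"
)

pts = torch.randn(2, 3)
# The identity rotation leaves points unchanged
assert torch.allclose(rot.apply(pts), pts, atol=1e-6)

mats = rot.get_rot_mats()  # quaternion storage converted on demand
assert mats.shape == (2, 3, 3)

# Composing a rotation with its inverse yields the identity matrix
composed = rot.compose_r(rot.invert())
assert torch.allclose(composed.get_rot_mats(), torch.eye(3).expand(2, 3, 3), atol=1e-6)
```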
729
+ class Rigid:
730
+ """
731
+ A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a
732
+ [*, 3] translation. Designed to behave approximately like a single torch tensor with the shape of the shared batch
733
+ dimensions of its component parts.
734
+ """
735
+
736
+ def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
737
+ """
738
+ Args:
739
+ rots: A [*, 3, 3] rotation tensor
740
+ trans: A corresponding [*, 3] translation tensor
741
+ """
742
+ # (we need device, dtype, etc. from at least one input)
743
+
744
+ batch_dims, dtype, device, requires_grad = None, None, None, None
745
+ if trans is not None:
746
+ batch_dims = trans.shape[:-1]
747
+ dtype = trans.dtype
748
+ device = trans.device
749
+ requires_grad = trans.requires_grad
750
+ elif rots is not None:
751
+ batch_dims = rots.shape
752
+ dtype = rots.dtype
753
+ device = rots.device
754
+ requires_grad = rots.requires_grad
755
+ else:
756
+ raise ValueError("At least one input argument must be specified")
757
+
758
+ if rots is None:
759
+ rots = Rotation.identity(
760
+ batch_dims,
761
+ dtype,
762
+ device,
763
+ requires_grad,
764
+ )
765
+ elif trans is None:
766
+ trans = identity_trans(
767
+ batch_dims,
768
+ dtype,
769
+ device,
770
+ requires_grad,
771
+ )
772
+
773
+ assert rots is not None
774
+ assert trans is not None
775
+
776
+ if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
777
+ raise ValueError("Rots and trans incompatible")
778
+
779
+ # Force full precision. Happens to the rotations automatically.
780
+ trans = trans.to(dtype=torch.float32)
781
+
782
+ self._rots = rots
783
+ self._trans = trans
784
+
785
+ @staticmethod
786
+ def identity(
787
+ shape: Tuple[int, ...],
788
+ dtype: Optional[torch.dtype] = None,
789
+ device: Optional[torch.device] = None,
790
+ requires_grad: bool = True,
791
+ fmt: str = "quat",
792
+ ) -> Rigid:
793
+ """
794
+ Constructs an identity transformation.
795
+
796
+ Args:
797
+ shape:
798
+ The desired shape
799
+ dtype:
800
+ The dtype of both internal tensors
801
+ device:
802
+ The device of both internal tensors
803
+ requires_grad:
804
+ Whether grad should be enabled for the internal tensors
805
+ Returns:
806
+ The identity transformation
807
+ """
808
+ return Rigid(
809
+ Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
810
+ identity_trans(shape, dtype, device, requires_grad),
811
+ )
812
+
813
+ def __getitem__(self, index: Any) -> Rigid:
814
+ """
815
+ Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
816
+ both the rotation and the translation.
817
+
818
+ E.g.::
819
+
820
+ r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
+ t = Rigid(r, torch.rand(10, 10, 3))
+ indexed = t[3, 4:6]
+ assert(indexed.shape == (2,))
+ assert(indexed.get_rots().shape == (2,))
+ assert(indexed.get_trans().shape == (2, 3))
823
+
824
+ Args:
825
+ index: A standard torch tensor index. E.g. 8, (10, None, 3),
826
+ or (3, slice(0, 1, None))
827
+ Returns:
828
+ The indexed tensor
829
+ """
830
+ if type(index) is not tuple:
831
+ index = (index,)
832
+
833
+ return Rigid(
834
+ self._rots[index],
835
+ self._trans[index + (slice(None),)],
836
+ )
837
+
838
+ def __mul__(self, right: torch.Tensor) -> Rigid:
839
+ """
840
+ Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.
841
+
842
+ Args:
843
+ right:
844
+ The tensor multiplicand
845
+ Returns:
846
+ The product
847
+ """
848
+ if not (isinstance(right, torch.Tensor)):
849
+ raise TypeError("The other multiplicand must be a Tensor")
850
+
851
+ new_rots = self._rots * right
852
+ new_trans = self._trans * right[..., None]
853
+
854
+ return Rigid(new_rots, new_trans)
855
+
856
+ def __rmul__(self, left: torch.Tensor) -> Rigid:
857
+ """
858
+ Reverse pointwise multiplication of the transformation with a tensor.
859
+
860
+ Args:
861
+ left:
862
+ The left multiplicand
863
+ Returns:
864
+ The product
865
+ """
866
+ return self.__mul__(left)
867
+
868
+ @property
869
+ def shape(self) -> torch.Size:
870
+ """
871
+ Returns the shape of the shared dimensions of the rotation and the translation.
872
+
873
+ Returns:
874
+ The shape of the transformation
875
+ """
876
+ return self._trans.shape[:-1]
877
+
878
+ @property
879
+ def device(self) -> torch.device:
880
+ """
881
+ Returns the device on which the Rigid's tensors are located.
882
+
883
+ Returns:
884
+ The device on which the Rigid's tensors are located
885
+ """
886
+ return self._trans.device
887
+
888
+ def get_rots(self) -> Rotation:
889
+ """
890
+ Getter for the rotation.
891
+
892
+ Returns:
893
+ The rotation object
894
+ """
895
+ return self._rots
896
+
897
+ def get_trans(self) -> torch.Tensor:
898
+ """
899
+ Getter for the translation.
900
+
901
+ Returns:
902
+ The stored translation
903
+ """
904
+ return self._trans
905
+
906
+ def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
907
+ """
908
+ Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
909
+ represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.
910
+
911
+ Args:
912
+ q_update_vec: The quaternion update vector.
913
+ Returns:
914
+ The composed transformation.
915
+ """
916
+ q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
917
+ new_rots = self._rots.compose_q_update_vec(q_vec)
918
+
919
+ trans_update = self._rots.apply(t_vec)
920
+ new_translation = self._trans + trans_update
921
+
922
+ return Rigid(new_rots, new_translation)
923
+
924
+ def compose(self, r: Rigid) -> Rigid:
925
+ """
926
+ Composes the current rigid object with another.
927
+
928
+ Args:
929
+ r:
930
+ Another Rigid object
931
+ Returns:
932
+ The composition of the two transformations
933
+ """
934
+ new_rot = self._rots.compose_r(r._rots)
935
+ new_trans = self._rots.apply(r._trans) + self._trans
936
+ return Rigid(new_rot, new_trans)
937
+
938
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
939
+ """
940
+ Applies the transformation to a coordinate tensor.
941
+
942
+ Args:
943
+ pts: A [*, 3] coordinate tensor.
944
+ Returns:
945
+ The transformed points.
946
+ """
947
+ rotated = self._rots.apply(pts)
948
+ return rotated + self._trans
949
+
950
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
951
+ """
952
+ Applies the inverse of the transformation to a coordinate tensor.
953
+
954
+ Args:
955
+ pts: A [*, 3] coordinate tensor
956
+ Returns:
957
+ The transformed points.
958
+ """
959
+ pts = pts - self._trans
960
+ return self._rots.invert_apply(pts)
961
+
962
+ def invert(self) -> Rigid:
963
+ """
964
+ Inverts the transformation.
965
+
966
+ Returns:
967
+ The inverse transformation.
968
+ """
969
+ rot_inv = self._rots.invert()
970
+ trn_inv = rot_inv.apply(self._trans)
971
+
972
+ return Rigid(rot_inv, -1 * trn_inv)
973
+
974
+ def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
975
+ """
976
+ Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
977
+ translation/rotation dimensions respectively.
978
+
979
+ Args:
980
+ fn:
981
+ A Tensor -> Tensor function to be mapped over the Rigid
982
+ Returns:
983
+ The transformed Rigid object
984
+ """
985
+ new_rots = self._rots.map_tensor_fn(fn)
986
+ new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1)
987
+
988
+ return Rigid(new_rots, new_trans)
989
+
990
+ def to_tensor_4x4(self) -> torch.Tensor:
991
+ """
992
+ Converts a transformation to a homogeneous transformation tensor.
993
+
994
+ Returns:
995
+ A [*, 4, 4] homogeneous transformation tensor
996
+ """
997
+ tensor = self._trans.new_zeros((*self.shape, 4, 4))
998
+ tensor[..., :3, :3] = self._rots.get_rot_mats()
999
+ tensor[..., :3, 3] = self._trans
1000
+ tensor[..., 3, 3] = 1
1001
+ return tensor
1002
+
1003
+ @staticmethod
1004
+ def from_tensor_4x4(t: torch.Tensor) -> Rigid:
1005
+ """
1006
+ Constructs a transformation from a homogeneous transformation tensor.
1007
+
1008
+ Args:
1009
+ t: [*, 4, 4] homogeneous transformation tensor
1010
+ Returns:
1011
+ T object with shape [*]
1012
+ """
1013
+ if t.shape[-2:] != (4, 4):
1014
+ raise ValueError("Incorrectly shaped input tensor")
1015
+
1016
+ rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
1017
+ trans = t[..., :3, 3]
1018
+
1019
+ return Rigid(rots, trans)
1020
+
1021
+ def to_tensor_7(self) -> torch.Tensor:
1022
+ """
1023
+ Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the
1024
+ translation.
1025
+
1026
+ Returns:
1027
+ A [*, 7] tensor representation of the transformation
1028
+ """
1029
+ tensor = self._trans.new_zeros((*self.shape, 7))
1030
+ tensor[..., :4] = self._rots.get_quats()
1031
+ tensor[..., 4:] = self._trans
1032
+
1033
+ return tensor
1034
+
1035
+ @staticmethod
1036
+ def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
1037
+ if t.shape[-1] != 7:
1038
+ raise ValueError("Incorrectly shaped input tensor")
1039
+
1040
+ quats, trans = t[..., :4], t[..., 4:]
1041
+
1042
+ rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats)
1043
+
1044
+ return Rigid(rots, trans)
1045
+
1046
+ @staticmethod
1047
+ def from_3_points(
1048
+ p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8
1049
+ ) -> Rigid:
1050
+ """
1051
+ Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm.
1052
+
1053
+ Args:
1054
+ p_neg_x_axis: [*, 3] coordinates
1055
+ origin: [*, 3] coordinates used as frame origins
1056
+ p_xy_plane: [*, 3] coordinates
1057
+ eps: Small epsilon value
1058
+ Returns:
1059
+ A transformation object of shape [*]
1060
+ """
1061
+ p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
1062
+ origin_unbound = torch.unbind(origin, dim=-1)
1063
+ p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)
1064
+
1065
+ e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
1066
+ e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]
1067
+
1068
+ denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
1069
+ e0 = [c / denom for c in e0]
1070
+ dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
1071
+ e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
1072
+ denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
1073
+ e1 = [c / denom for c in e1]
1074
+ e2 = [
1075
+ e0[1] * e1[2] - e0[2] * e1[1],
1076
+ e0[2] * e1[0] - e0[0] * e1[2],
1077
+ e0[0] * e1[1] - e0[1] * e1[0],
1078
+ ]
1079
+
1080
+ rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
1081
+ rots = rots.reshape(rots.shape[:-1] + (3, 3))
1082
+
1083
+ rot_obj = Rotation(rot_mats=rots, quats=None)
1084
+
1085
+ return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
1086
+
1087
+ def unsqueeze(self, dim: int) -> Rigid:
1088
+ """
1089
+ Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
1090
+
1091
+ Args:
1092
+ dim: A positive or negative dimension index.
1093
+ Returns:
1094
+ The unsqueezed transformation.
1095
+ """
1096
+ if dim >= len(self.shape):
1097
+ raise ValueError("Invalid dimension")
1098
+ rots = self._rots.unsqueeze(dim)
1099
+ trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
1100
+
1101
+ return Rigid(rots, trans)
1102
+
1103
+ @staticmethod
1104
+ def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
1105
+ """
1106
+ Concatenates transformations along one of the batch dimensions.
1107
+
1108
+ Args:
1109
+ ts:
1110
+ A list of T objects
1111
+ dim:
1112
+ The dimension along which the transformations should be concatenated
1113
+ Returns:
1114
+ A concatenated transformation object
1115
+ """
1116
+ rots = Rotation.cat([t._rots for t in ts], dim)
1117
+ trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1)
1118
+
1119
+ return Rigid(rots, trans)
1120
+
1121
+ def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
1122
+ """
1123
+ Applies a Rotation -> Rotation function to the stored rotation object.
1124
+
1125
+ Args:
1126
+ fn: A function of type Rotation -> Rotation
1127
+ Returns:
1128
+ A transformation object with a transformed rotation.
1129
+ """
1130
+ return Rigid(fn(self._rots), self._trans)
1131
+
1132
+ def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
1133
+ """
1134
+ Applies a Tensor -> Tensor function to the stored translation.
1135
+
1136
+ Args:
1137
+ fn:
1138
+ A function of type Tensor -> Tensor to be applied to the translation
1139
+ Returns:
1140
+ A transformation object with a transformed translation.
1141
+ """
1142
+ return Rigid(self._rots, fn(self._trans))
1143
+
1144
+ def scale_translation(self, trans_scale_factor: float) -> Rigid:
1145
+ """
1146
+ Scales the translation by a constant factor.
1147
+
1148
+ Args:
1149
+ trans_scale_factor:
1150
+ The constant factor
1151
+ Returns:
1152
+ A transformation object with a scaled translation.
1153
+ """
1154
+ return self.apply_trans_fn(lambda t: t * trans_scale_factor)
1155
+
1156
+ def stop_rot_gradient(self) -> Rigid:
1157
+ """
1158
+ Detaches the underlying rotation object
1159
+
1160
+ Returns:
1161
+ A transformation object with detached rotations
1162
+ """
1163
+ return self.apply_rot_fn(lambda r: r.detach())
1164
+
1165
+ @staticmethod
1166
+ def make_transform_from_reference(
1167
+ n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
1168
+ ) -> Rigid:
1169
+ """
1170
+ Returns a transformation object from reference coordinates.
1171
+
1172
+ Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard
1173
+ way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
1174
+ need to take care of such cases in your code.
1175
+
1176
+ Args:
1177
+ n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
1178
+ ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
1179
+ c_xyz: A [*, 3] tensor of carbon xyz coordinates.
1180
+ Returns:
1181
+ A transformation object. After applying the translation and rotation to the reference backbone, the
1182
+ coordinates will approximately equal the input coordinates.
1183
+ """
1184
+ translation = -1 * ca_xyz
1185
+ n_xyz = n_xyz + translation
1186
+ c_xyz = c_xyz + translation
1187
+
1188
+ c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
1189
+ norm = torch.sqrt(eps + c_x**2 + c_y**2)
1190
+ sin_c1 = -c_y / norm
1191
+ cos_c1 = c_x / norm
1192
+
1193
+ c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
1194
+ c1_rots[..., 0, 0] = cos_c1
1195
+ c1_rots[..., 0, 1] = -1 * sin_c1
1196
+ c1_rots[..., 1, 0] = sin_c1
1197
+ c1_rots[..., 1, 1] = cos_c1
1198
+ c1_rots[..., 2, 2] = 1
1199
+
1200
+ norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
1201
+ sin_c2 = c_z / norm
1202
+ cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
1203
+
1204
+ c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1205
+ c2_rots[..., 0, 0] = cos_c2
1206
+ c2_rots[..., 0, 2] = sin_c2
1207
+ c2_rots[..., 1, 1] = 1
1208
+ c2_rots[..., 2, 0] = -1 * sin_c2
1209
+ c2_rots[..., 2, 2] = cos_c2
1210
+
1211
+ c_rots = rot_matmul(c2_rots, c1_rots)
1212
+ n_xyz = rot_vec_mul(c_rots, n_xyz)
1213
+
1214
+ _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
1215
+ norm = torch.sqrt(eps + n_y**2 + n_z**2)
1216
+ sin_n = -n_z / norm
1217
+ cos_n = n_y / norm
1218
+
1219
+ n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1220
+ n_rots[..., 0, 0] = 1
1221
+ n_rots[..., 1, 1] = cos_n
1222
+ n_rots[..., 1, 2] = -1 * sin_n
1223
+ n_rots[..., 2, 1] = sin_n
1224
+ n_rots[..., 2, 2] = cos_n
1225
+
1226
+ rots = rot_matmul(n_rots, c_rots)
1227
+
1228
+ rots = rots.transpose(-1, -2)
1229
+ translation = -1 * translation
1230
+
1231
+ rot_obj = Rotation(rot_mats=rots, quats=None)
1232
+
1233
+ return Rigid(rot_obj, translation)
1234
+
1235
+ def cuda(self) -> Rigid:
1236
+ """
1237
+ Moves the transformation object to GPU memory
1238
+
1239
+ Returns:
1240
+ A version of the transformation on GPU
1241
+ """
1242
+ return Rigid(self._rots.cuda(), self._trans.cuda())
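
A minimal sketch of the `Rigid` wrapper above, checking that `apply`/`invert_apply` cancel and that the 4x4 homogeneous round trip preserves the translation; the module path is again an assumption:

```python
# Minimal sketch (assumption: module path as below).
import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid

rigid = Rigid.identity((4,), dtype=torch.float32, device=torch.device("cpu"), requires_grad=False)

pts = torch.randn(4, 3)
# apply() followed by invert_apply() recovers the original points
recovered = rigid.invert_apply(rigid.apply(pts))
assert torch.allclose(recovered, pts, atol=1e-5)

# Round trip through the [*, 4, 4] homogeneous representation keeps the translation
t4 = rigid.to_tensor_4x4()
rigid2 = Rigid.from_tensor_4x4(t4)
assert torch.allclose(rigid2.get_trans(), rigid.get_trans())
```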
docs/transformers/build/lib/transformers/models/falcon/configuration_falcon.py ADDED
@@ -0,0 +1,211 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Falcon configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class FalconConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the
29
+ [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 65024):
37
+ Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`FalconModel`]
39
+ hidden_size (`int`, *optional*, defaults to 4544):
40
+ Dimension of the hidden representations.
41
+ num_hidden_layers (`int`, *optional*, defaults to 32):
42
+ Number of hidden layers in the Transformer decoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 71):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ num_ln_in_parallel_attn (`int`, *optional*):
46
+ Set to 2 if separate layer norms are to be used for the MLP and the attention output when using parallel
47
+ attention, otherwise, 1.
48
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
49
+ The epsilon used by the layer normalization layers.
50
+ initializer_range (`float`, *optional*, defaults to 0.02):
51
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether the model should return the last key/values attentions (not used by all models). Only relevant if
54
+ `config.is_decoder=True`.
55
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
56
+ The dropout probability for MLP layers.
57
+ attention_dropout (`float`, *optional*, defaults to 0.0):
58
+ The dropout probability for attention layers.
59
+ num_kv_heads (`int`, *optional*):
60
+ Number of key-value heads to use per attention layer. If unset, defaults to the same value as
61
+ `num_attention_heads`.
62
+ alibi (`bool`, *optional*, defaults to `False`):
63
+ Whether to use ALiBi positional biases during self-attention.
64
+ new_decoder_architecture (`bool`, *optional*, defaults to `False`):
65
+ Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
66
+ arguments are ignored, as the new decoder always uses parallel attention.
67
+ multi_query (`bool`, *optional*, defaults to `True`):
68
+ Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
69
+ parallel_attn (`bool`, *optional*, defaults to `True`):
70
+ Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
71
+ instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
72
+ bias (`bool`, *optional*, defaults to `False`):
73
+ Whether to use bias on Linear layers.
74
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
75
+ The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
76
+ Falcon models with RoPE support up to 2048 tokens.
77
+ rope_theta (`float`, *optional*, defaults to 10000.0):
78
+ The base period of the RoPE embeddings.
79
+ rope_scaling (`Dict`, *optional*):
80
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
81
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
82
+ accordingly.
83
+ Expected contents:
84
+ `rope_type` (`str`):
85
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
86
+ 'llama3'], with 'default' being the original RoPE implementation.
87
+ `factor` (`float`, *optional*):
88
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
89
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
90
+ original maximum pre-trained length.
91
+ `original_max_position_embeddings` (`int`, *optional*):
92
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
93
+ pretraining.
94
+ `attention_factor` (`float`, *optional*):
95
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
96
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
97
+ `factor` field to infer the suggested value.
98
+ `beta_fast` (`float`, *optional*):
99
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
100
+ ramp function. If unspecified, it defaults to 32.
101
+ `beta_slow` (`float`, *optional*):
102
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
103
+ ramp function. If unspecified, it defaults to 1.
104
+ `short_factor` (`List[float]`, *optional*):
105
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
106
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
107
+ size divided by the number of attention heads divided by 2.
108
+ `long_factor` (`List[float]`, *optional*):
109
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
110
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
111
+ size divided by the number of attention heads divided by 2.
112
+ `low_freq_factor` (`float`, *optional*):
113
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
114
+ `high_freq_factor` (`float`, *optional*):
115
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
116
+ bos_token_id (`int`, *optional*, defaults to 11):
117
+ The id of the "beginning-of-sequence" token.
118
+ eos_token_id (`int`, *optional*, defaults to 11):
119
+ The id of the "end-of-sequence" token.
120
+ ffn_hidden_size (`int`, *optional*):
121
+ The hidden size of the feedforward layer in the Transformer decoder. If unset, it defaults to
122
+ 4 times `hidden_size`.
123
+ activation (`str`, *optional*, defaults to `"gelu"`):
124
+ The activation function used in the feedforward layer.
125
+
126
+ Example:
127
+
128
+ ```python
129
+ >>> from transformers import FalconModel, FalconConfig
130
+
131
+ >>> # Initializing a small (2-layer) Falcon configuration
132
+ >>> configuration = FalconConfig(num_hidden_layers=2)
133
+
134
+ >>> # Initializing a model from the small configuration
135
+ >>> model = FalconModel(configuration)
136
+
137
+ >>> # Accessing the model configuration
138
+ >>> configuration = model.config
139
+ ```"""
140
+
141
+ model_type = "falcon"
142
+ keys_to_ignore_at_inference = ["past_key_values"]
143
+
144
+ def __init__(
145
+ self,
146
+ vocab_size=65024,
147
+ hidden_size=4544,
148
+ num_hidden_layers=32,
149
+ num_attention_heads=71,
150
+ num_ln_in_parallel_attn=None,
151
+ layer_norm_epsilon=1e-5,
152
+ initializer_range=0.02,
153
+ use_cache=True,
154
+ hidden_dropout=0.0,
155
+ attention_dropout=0.0,
156
+ num_kv_heads=None,
157
+ alibi=False,
158
+ new_decoder_architecture=False,
159
+ multi_query=True,
160
+ parallel_attn=True,
161
+ bias=False,
162
+ max_position_embeddings=2048,
163
+ rope_theta=10000.0,
164
+ rope_scaling=None,
165
+ bos_token_id=11,
166
+ eos_token_id=11,
167
+ ffn_hidden_size=None,
168
+ activation="gelu",
169
+ **kwargs,
170
+ ):
171
+ self.vocab_size = vocab_size
172
+ # Backward compatibility with n_embed kwarg
173
+ n_embed = kwargs.pop("n_embed", None)
174
+ self.hidden_size = hidden_size if n_embed is None else n_embed
175
+ self.num_hidden_layers = num_hidden_layers
176
+ self.num_attention_heads = num_attention_heads
177
+ self.layer_norm_epsilon = layer_norm_epsilon
178
+ self.initializer_range = initializer_range
179
+ self.use_cache = use_cache
180
+ self.hidden_dropout = hidden_dropout
181
+ self.attention_dropout = attention_dropout
182
+ self.bos_token_id = bos_token_id
183
+ self.eos_token_id = eos_token_id
184
+ self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
185
+ self.alibi = alibi
186
+ self.new_decoder_architecture = new_decoder_architecture
187
+ self.multi_query = multi_query # Ignored when new_decoder_architecture is True
188
+ self.parallel_attn = parallel_attn
189
+ self.bias = bias
190
+ self.num_ln_in_parallel_attn = num_ln_in_parallel_attn
191
+ self.max_position_embeddings = max_position_embeddings
192
+ self.rope_theta = rope_theta
193
+ self.rope_scaling = rope_scaling
194
+ self.activation = activation
195
+ if ffn_hidden_size is None:
196
+ self.ffn_hidden_size = hidden_size * 4
197
+ else:
198
+ self.ffn_hidden_size = ffn_hidden_size
199
+
200
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
201
+
202
+ @property
203
+ def head_dim(self):
204
+ return self.hidden_size // self.num_attention_heads
205
+
206
+ @property
207
+ def rotary(self):
208
+ return not self.alibi
209
+
210
+
211
+ __all__ = ["FalconConfig"]
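
To complement the docstring example above, the following snippet is a minimal illustrative sketch (not part of the upstream file; all values are arbitrary) that builds a small rotary-embedding configuration with linear RoPE scaling and inspects the two convenience properties defined at the end of the class.

```python
from transformers import FalconConfig

# A small, illustrative configuration: 2 layers, rotary embeddings (alibi=False),
# multi-query attention, and linear RoPE scaling with a factor of 2.
config = FalconConfig(
    num_hidden_layers=2,
    hidden_size=256,
    num_attention_heads=4,
    alibi=False,
    multi_query=True,
    max_position_embeddings=2048,
    rope_scaling={"rope_type": "linear", "factor": 2.0},
)

print(config.head_dim)  # 256 // 4 == 64
print(config.rotary)    # True, because alibi is False
```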
docs/transformers/build/lib/transformers/models/falcon/convert_custom_code_checkpoint.py ADDED
@@ -0,0 +1,74 @@
1
+ import json
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+
6
+ """
7
+ This script converts Falcon custom code checkpoints to modern Falcon checkpoints that use code in the Transformers
8
+ library. After conversion, performance (especially for generation) should improve and the checkpoint can be loaded
9
+ without needing trust_remote_code=True.
10
+ """
11
+
12
+ if __name__ == "__main__":
13
+ parser = ArgumentParser()
14
+ parser.add_argument(
15
+ "--checkpoint_dir",
16
+ type=Path,
17
+ required=True,
18
+ help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
19
+ )
20
+ args = parser.parse_args()
21
+
22
+ if not args.checkpoint_dir.is_dir():
23
+ raise ValueError("--checkpoint_dir argument should be a directory!")
24
+
25
+ if (
26
+ not (args.checkpoint_dir / "configuration_RW.py").is_file()
27
+ or not (args.checkpoint_dir / "modelling_RW.py").is_file()
28
+ ):
29
+ raise ValueError(
30
+ "The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
31
+ )
32
+ (args.checkpoint_dir / "configuration_RW.py").unlink()
33
+ (args.checkpoint_dir / "modelling_RW.py").unlink()
34
+
35
+ config = args.checkpoint_dir / "config.json"
36
+ text = config.read_text()
37
+ text = text.replace("RWForCausalLM", "FalconForCausalLM")
38
+ text = text.replace("RefinedWebModel", "falcon")
39
+ text = text.replace("RefinedWeb", "falcon")
40
+ json_config = json.loads(text)
41
+ del json_config["auto_map"]
42
+
43
+ if "n_head" in json_config:
44
+ json_config["num_attention_heads"] = json_config.pop("n_head")
45
+ if "n_layer" in json_config:
46
+ json_config["num_hidden_layers"] = json_config.pop("n_layer")
47
+ if "n_head_kv" in json_config:
48
+ json_config["num_kv_heads"] = json_config.pop("n_head_kv")
49
+ json_config["new_decoder_architecture"] = True
50
+ else:
51
+ json_config["new_decoder_architecture"] = False
52
+ bos_token_id = json_config.get("bos_token_id", 1)
53
+ eos_token_id = json_config.get("eos_token_id", 2)
54
+ config.unlink()
55
+ config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
56
+
57
+ tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
58
+ if tokenizer_config.is_file():
59
+ text = tokenizer_config.read_text()
60
+ json_config = json.loads(text)
61
+ if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
62
+ json_config["model_input_names"] = ["input_ids", "attention_mask"]
63
+ tokenizer_config.unlink()
64
+ tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
65
+
66
+ generation_config_path = args.checkpoint_dir / "generation_config.json"
67
+ generation_dict = {
68
+ "_from_model_config": True,
69
+ "bos_token_id": bos_token_id,
70
+ "eos_token_id": eos_token_id,
71
+ "transformers_version": "4.33.0.dev0",
72
+ }
73
+ generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
74
+ print("Done! Please double-check that the new checkpoint works as expected.")
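
After the script above has rewritten the checkpoint in place, the directory can be loaded with the stock Falcon classes. The snippet below is a hedged usage sketch; the local path is a placeholder for whatever directory was passed to `--checkpoint_dir`.

```python
from transformers import AutoTokenizer, FalconForCausalLM

checkpoint_dir = "path/to/converted_checkpoint"  # placeholder: the directory given to --checkpoint_dir

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = FalconForCausalLM.from_pretrained(checkpoint_dir)  # no trust_remote_code=True needed anymore

inputs = tokenizer("Hello, Falcon!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```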
docs/transformers/build/lib/transformers/models/falcon/modeling_falcon.py ADDED
@@ -0,0 +1,1566 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Falcon model."""
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
+ from torch.nn import functional as F
25
+
26
+ from ...activations import get_activation
27
+ from ...cache_utils import Cache, DynamicCache, StaticCache
28
+ from ...generation import GenerationMixin
29
+ from ...modeling_attn_mask_utils import (
30
+ AttentionMaskConverter,
31
+ )
32
+ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
33
+ from ...modeling_outputs import (
34
+ BaseModelOutputWithPastAndCrossAttentions,
35
+ CausalLMOutputWithCrossAttentions,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
41
+ from ...modeling_utils import PreTrainedModel
42
+ from ...utils import (
43
+ add_code_sample_docstrings,
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ logging,
47
+ )
48
+ from .configuration_falcon import FalconConfig
49
+
50
+
51
+ if TYPE_CHECKING:
52
+ from ...configuration_utils import PretrainedConfig
53
+
54
+ if is_flash_attn_available():
55
+ from ...modeling_flash_attention_utils import _flash_attention_forward
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+
60
+ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
61
+ _CONFIG_FOR_DOC = "FalconConfig"
62
+
63
+
64
+ # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training; this means that there's one additional quantization to bfloat16 between the operations.
65
+ # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
66
+ class FalconLinear(nn.Linear):
67
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
68
+ hidden_states = input @ self.weight.T
69
+ if self.bias is None:
70
+ return hidden_states
71
+ return hidden_states + self.bias
72
+
73
+
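
As a small illustration of the note above, the sketch below (illustrative only; it assumes this module is importable as `transformers.models.falcon.modeling_falcon`) compares the unfused matmul-plus-bias of `FalconLinear` against `torch.nn.functional.linear` on random data.

```python
import torch
from torch.nn import functional as F
from transformers.models.falcon.modeling_falcon import FalconLinear  # assumed import path

torch.manual_seed(0)
layer = FalconLinear(8, 4, bias=True)  # arbitrary toy shapes
x = torch.randn(2, 8)

# In float32 the two computations agree to within rounding error; the extra
# bfloat16 rounding mentioned in the note only matters in reduced precision.
diff = (layer(x) - F.linear(x, layer.weight, layer.bias)).abs().max()
print(f"max abs difference: {diff:.3e}")
```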
74
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
75
+ def rotate_half(x):
76
+ """Rotates half the hidden dims of the input."""
77
+ x1 = x[..., : x.shape[-1] // 2]
78
+ x2 = x[..., x.shape[-1] // 2 :]
79
+ return torch.cat((-x2, x1), dim=-1)
80
+
81
+
82
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
83
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
84
+ """Applies Rotary Position Embedding to the query and key tensors.
85
+
86
+ Args:
87
+ q (`torch.Tensor`): The query tensor.
88
+ k (`torch.Tensor`): The key tensor.
89
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
90
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
91
+ position_ids (`torch.Tensor`, *optional*):
92
+ Deprecated and unused.
93
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
94
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
95
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
96
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
97
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
98
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
99
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
100
+ Returns:
101
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
102
+ """
103
+ cos = cos.unsqueeze(unsqueeze_dim)
104
+ sin = sin.unsqueeze(unsqueeze_dim)
105
+ q_embed = (q * cos) + (rotate_half(q) * sin)
106
+ k_embed = (k * cos) + (rotate_half(k) * sin)
107
+ return q_embed, k_embed
108
+
109
+
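
To make the broadcasting contract in the `apply_rotary_pos_emb` docstring concrete, here is a self-contained toy sketch (arbitrary shapes; it assumes the function is importable as shown) that builds `cos`/`sin` from an explicit inverse-frequency schedule and rotates random query/key tensors laid out as `[batch, heads, seq, head_dim]`.

```python
import torch
from transformers.models.falcon.modeling_falcon import apply_rotary_pos_emb  # assumed import path

batch, heads, seq, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# cos/sin normally come from the rotary embedding module with shape [batch, seq, head_dim];
# here we fake them with an explicit inverse-frequency schedule (theta = 10000.0).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(seq).float()
freqs = torch.outer(positions, inv_freq)             # [seq, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)[None, :, :]  # [1, seq, head_dim]
cos, sin = emb.cos(), emb.sin()

# unsqueeze_dim=1 (the default) broadcasts cos/sin over the heads dimension.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 5, 8]) twice
```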
110
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
111
+ class FalconRotaryEmbedding(nn.Module):
112
+ def __init__(self, config: FalconConfig, device=None):
113
+ super().__init__()
114
+ # BC: "rope_type" was originally "type"
115
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
116
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
117
+ else:
118
+ self.rope_type = "default"
119
+ self.max_seq_len_cached = config.max_position_embeddings
120
+ self.original_max_seq_len = config.max_position_embeddings
121
+
122
+ self.config = config
123
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
124
+
125
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
126
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
127
+ self.original_inv_freq = self.inv_freq
128
+
129
+ @torch.no_grad()
130
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
131
+ def forward(self, x, position_ids):
132
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
133
+ position_ids_expanded = position_ids[:, None, :].float()
134
+
135
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
136
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
137
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
138
+ emb = torch.cat((freqs, freqs), dim=-1)
139
+ cos = emb.cos() * self.attention_scaling
140
+ sin = emb.sin() * self.attention_scaling
141
+
142
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
143
+
144
+
145
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
146
+ batch_size, seq_length = attention_mask.shape
147
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
148
+ base = torch.tensor(
149
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
150
+ )
151
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
152
+ slopes = torch.pow(base, powers)
153
+
154
+ if closest_power_of_2 != num_heads:
155
+ extra_base = torch.tensor(
156
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
157
+ )
158
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
159
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
160
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
161
+
162
+ # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
163
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
164
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
165
+ # => the query_length dimension will then be broadcasted correctly
166
+ # This is more or less identical to T5's relative position bias:
167
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
168
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
169
+ alibi = slopes[..., None].bfloat16() * arange_tensor
170
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
171
+
172
+
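
As a quick sanity check of the shape contract spelled out in the comments of `build_alibi_tensor` (output `[batch_size * num_heads, 1, seq_length]`), the toy sketch below exercises the function directly; the import path is assumed from the module location.

```python
import torch
from transformers.models.falcon.modeling_falcon import build_alibi_tensor  # assumed import path

batch_size, seq_length, num_heads = 2, 6, 4
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[1, :2] = 0  # simulate left padding on the second sequence

alibi = build_alibi_tensor(attention_mask, num_heads, dtype=torch.float32)
print(alibi.shape)  # torch.Size([8, 1, 6]) == [batch_size * num_heads, 1, seq_length]
```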
173
+ # Copied from transformers.models.bloom.modeling_bloom.dropout_add
174
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
175
+ """
176
+ Dropout add function
177
+
178
+ Args:
179
+ x (`torch.tensor`):
180
+ input tensor
181
+ residual (`torch.tensor`):
182
+ residual tensor
183
+ prob (`float`):
184
+ dropout probability
185
+ training (`bool`):
186
+ training mode
187
+ """
188
+ out = F.dropout(x, p=prob, training=training)
189
+ out = residual + out
190
+ return out
191
+
192
+
193
+ class FalconAttention(nn.Module):
194
+ def __init__(self, config: FalconConfig, layer_idx=None):
195
+ super().__init__()
196
+
197
+ self.config = config
198
+ self.hidden_size = config.hidden_size
199
+ self.num_heads = config.num_attention_heads
200
+ self.head_dim = self.hidden_size // self.num_heads
201
+ self.split_size = self.hidden_size
202
+ self.hidden_dropout = config.hidden_dropout
203
+ self.max_position_embeddings = config.max_position_embeddings
204
+ self.rope_theta = config.rope_theta
205
+ self.is_causal = True
206
+ self._use_sdpa = config._attn_implementation == "sdpa"
207
+ self.layer_idx = layer_idx
208
+ if layer_idx is None:
209
+ logger.warning_once(
210
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
211
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
212
+ "when creating this class."
213
+ )
214
+
215
+ if self.head_dim * self.num_heads != self.hidden_size:
216
+ raise ValueError(
217
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
218
+ f" {self.num_heads})."
219
+ )
220
+
221
+ # Layer-wise attention scaling
222
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
223
+ self.beta = self.inv_norm_factor
224
+ if config.new_decoder_architecture:
225
+ qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
226
+ elif config.multi_query:
227
+ qkv_out_dim = self.hidden_size + 2 * self.head_dim
228
+ else:
229
+ qkv_out_dim = 3 * self.hidden_size
230
+ self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
231
+ self.new_decoder_architecture = config.new_decoder_architecture
232
+ self.multi_query = config.multi_query
233
+ self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
234
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
235
+ self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
236
+
237
+ # TODO (raushan): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
238
+ if config.rotary:
239
+ self.rotary_emb = FalconRotaryEmbedding(config=self.config)
240
+
241
+ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
242
+ """
243
+ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
244
+
245
+ Args:
246
+ fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
247
+
248
+ Returns:
249
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
250
+ value: [batch_size, seq_length, num_heads, head_dim]
251
+ """
252
+ if self.new_decoder_architecture:
253
+ batch, seq_len, _ = fused_qkv.shape
254
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
255
+ query = qkv[:, :, :, :-2]
256
+ key = qkv[:, :, :, [-2]]
257
+ value = qkv[:, :, :, [-1]]
258
+ key = torch.broadcast_to(key, query.shape)
259
+ value = torch.broadcast_to(value, query.shape)
260
+
261
+ query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
262
+ return query, key, value
263
+ elif not self.multi_query:
264
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
265
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
266
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
267
+ else:
268
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
269
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
270
+ return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
271
+
272
+ # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
273
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
274
+ """
275
+ Merge heads together over the last dimension
276
+
277
+ Args:
278
+ x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
279
+
280
+ Returns:
281
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
282
+ """
283
+ # What we want to achieve is:
284
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
285
+ batch_size_and_num_heads, seq_length, _ = x.shape
286
+ batch_size = batch_size_and_num_heads // self.num_heads
287
+
288
+ # First view to decompose the batch size
289
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
290
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
291
+
292
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
293
+ x = x.permute(0, 2, 1, 3)
294
+
295
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
296
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
297
+
298
+ def forward(
299
+ self,
300
+ hidden_states: torch.Tensor,
301
+ alibi: Optional[torch.Tensor],
302
+ attention_mask: torch.Tensor,
303
+ position_ids: Optional[torch.LongTensor] = None,
304
+ layer_past: Optional[Cache] = None,
305
+ head_mask: Optional[torch.Tensor] = None,
306
+ use_cache: bool = False,
307
+ output_attentions: bool = False,
308
+ cache_position: Optional[torch.LongTensor] = None,
309
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
310
+ ):
311
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
312
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
313
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
314
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
315
+
316
+ batch_size, query_length, _, _ = query_layer.shape
317
+
318
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
319
+ key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
320
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
321
+
322
+ if alibi is None:
323
+ cos, sin = position_embeddings
324
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
325
+
326
+ if layer_past is not None:
327
+ cache_kwargs = {"cache_position": cache_position}
328
+ if alibi is None:
329
+ cache_kwargs.update({"sin": sin, "cos": cos})
330
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
331
+
332
+ kv_length = key_layer.shape[-2]
333
+ if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
334
+ # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
335
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
336
+ query_layer = query_layer.contiguous()
337
+ key_layer = key_layer.contiguous()
338
+ value_layer = value_layer.contiguous()
339
+
340
+ if attention_mask is not None:
341
+ attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
342
+
343
+ if alibi is None:
344
+ if self._use_sdpa and not output_attentions:
345
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
346
+ # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
347
+ # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
348
+ # create a causal mask in case query_length == 1.
349
+ is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
350
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
351
+ query_layer,
352
+ key_layer,
353
+ value_layer,
354
+ attn_mask=attention_mask,
355
+ dropout_p=0.0,
356
+ is_causal=is_causal,
357
+ )
358
+ attention_scores = None
359
+ else:
360
+ attention_scores = query_layer @ key_layer.transpose(-1, -2)
361
+ attention_scores /= math.sqrt(self.head_dim)
362
+
363
+ attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
364
+ # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
365
+ attn_output = attention_scores @ value_layer
366
+
367
+ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
368
+ attn_output = attn_output.permute(0, 2, 1, 3)
369
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
370
+
371
+ attn_output = self.dense(attn_output)
372
+
373
+ if output_attentions:
374
+ return attn_output, layer_past, attention_scores
375
+ else:
376
+ return attn_output, layer_past
377
+
378
+ else:
379
+ if self._use_sdpa and not output_attentions and head_mask is None:
380
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
381
+ # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
382
+ is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
383
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
384
+ query_layer,
385
+ key_layer,
386
+ value_layer,
387
+ attn_mask=attention_mask,
388
+ dropout_p=self.attention_dropout.p if self.training else 0.0,
389
+ is_causal=is_causal,
390
+ )
391
+ attn_output = attn_output.transpose(1, 2)
392
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
393
+
394
+ attn_output = self.dense(attn_output)
395
+ else:
396
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
397
+
398
+ # change view to [batch_size, num_heads, q_length, kv_length]
399
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
400
+
401
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
402
+ input_dtype = attention_scores.dtype
403
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
404
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
405
+ attention_scores = attention_scores.to(torch.float32)
406
+
407
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
408
+ attention_logits *= self.inv_norm_factor
409
+ attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
410
+ # [batch_size, num_heads, q_length, kv_length]
411
+ attention_probs = self.attention_dropout(attention_probs)
412
+
413
+ if head_mask is not None:
414
+ attention_probs = attention_probs * head_mask
415
+
416
+ # change view [batch_size, num_heads, q_length, kv_length]
417
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
418
+
419
+ # matmul: [batch_size * num_heads, q_length, head_dim]
420
+ attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
421
+
422
+ # change view [batch_size, q_length, num_heads * head_dim]
423
+ attn_output = self._merge_heads(attn_output)
424
+
425
+ attn_output = self.dense(attn_output)
426
+
427
+ if output_attentions:
428
+ return attn_output, layer_past, attention_probs
429
+ else:
430
+ return attn_output, layer_past
431
+
432
+
433
+ class FalconFlashAttention2(FalconAttention):
434
+ """
435
+ Falcon flash attention module. This module inherits from `FalconAttention`, as the weights of the module stay
436
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
437
+ flash attention and deal with padding tokens if the input contains any.
438
+ """
439
+
440
+ def __init__(self, *args, **kwargs):
441
+ super().__init__(*args, **kwargs)
442
+
443
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
444
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
445
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
446
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
447
+
448
+ def forward(
449
+ self,
450
+ hidden_states: torch.Tensor,
451
+ alibi: Optional[torch.Tensor],
452
+ attention_mask: torch.Tensor,
453
+ position_ids: Optional[torch.LongTensor] = None,
454
+ layer_past: Optional[Cache] = None,
455
+ head_mask: Optional[torch.Tensor] = None,
456
+ use_cache: bool = False,
457
+ output_attentions: bool = False,
458
+ cache_position: Optional[torch.LongTensor] = None,
459
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
460
+ ):
461
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
462
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
463
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
464
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
465
+
466
+ batch_size, query_length, _, _ = query_layer.shape
467
+
468
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
469
+ key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
470
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
471
+
472
+ if alibi is None:
473
+ cos, sin = position_embeddings
474
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
475
+
476
+ if layer_past is not None:
477
+ cache_kwargs = {"cache_position": cache_position}
478
+ if alibi is None:
479
+ cache_kwargs.update({"sin": sin, "cos": cos})
480
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
481
+
482
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
483
+ # to be able to avoid many of these transpose/reshape/view.
484
+ query_layer = query_layer.transpose(1, 2)
485
+ key_layer = key_layer.transpose(1, 2)
486
+ value_layer = value_layer.transpose(1, 2)
487
+
488
+ if alibi is not None:
489
+ raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
490
+
491
+ attn_dropout = self.config.attention_dropout if self.training else 0.0
492
+
493
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
494
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
495
+ # cast them back to float16 just to be sure everything works as expected.
496
+ input_dtype = query_layer.dtype
497
+ if input_dtype == torch.float32:
498
+ if torch.is_autocast_enabled():
499
+ target_dtype = torch.get_autocast_gpu_dtype()
500
+ # Handle the case where the model is quantized
501
+ elif hasattr(self.config, "_pre_quantization_dtype"):
502
+ target_dtype = self.config._pre_quantization_dtype
503
+ else:
504
+ target_dtype = self.query_key_value.weight.dtype
505
+
506
+ logger.warning_once(
507
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
508
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
509
+ f" {target_dtype}."
510
+ )
511
+
512
+ query_layer = query_layer.to(target_dtype)
513
+ key_layer = key_layer.to(target_dtype)
514
+ value_layer = value_layer.to(target_dtype)
515
+
516
+ attn_output = _flash_attention_forward(
517
+ query_layer,
518
+ key_layer,
519
+ value_layer,
520
+ attention_mask,
521
+ query_length,
522
+ position_ids=position_ids,
523
+ dropout=attn_dropout,
524
+ is_causal=self.is_causal,
525
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
526
+ )
527
+
528
+ attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
529
+ attn_output = self.dense(attn_weights)
530
+
531
+ if not output_attentions:
532
+ attn_weights = None
533
+
534
+ return attn_output, layer_past, attn_weights
535
+
536
+
537
+ class FalconMLP(nn.Module):
538
+ def __init__(self, config: FalconConfig):
539
+ super().__init__()
540
+ hidden_size = config.hidden_size
541
+
542
+ self.dense_h_to_4h = FalconLinear(hidden_size, config.ffn_hidden_size, bias=config.bias)
543
+ self.act = get_activation(config.activation)
544
+ self.dense_4h_to_h = FalconLinear(config.ffn_hidden_size, hidden_size, bias=config.bias)
545
+ self.hidden_dropout = config.hidden_dropout
546
+
547
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
548
+ x = self.act(self.dense_h_to_4h(x))
549
+ x = self.dense_4h_to_h(x)
550
+ return x
551
+
552
+
553
+ FALCON_ATTENTION_CLASSES = {
554
+ "eager": FalconAttention,
555
+ "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
556
+ "flash_attention_2": FalconFlashAttention2,
557
+ }
558
+
559
+
560
+ class FalconDecoderLayer(nn.Module):
561
+ def __init__(self, config: FalconConfig, layer_idx=None):
562
+ super().__init__()
563
+ hidden_size = config.hidden_size
564
+ self.num_heads = config.num_attention_heads
565
+
566
+ self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
567
+ self.mlp = FalconMLP(config)
568
+ self.hidden_dropout = config.hidden_dropout
569
+ self.config = config
570
+
571
+ if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture:
572
+ config.num_ln_in_parallel_attn = 2
573
+
574
+ if not config.parallel_attn:
575
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
576
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
577
+ else:
578
+ if config.num_ln_in_parallel_attn == 2:
579
+ # The layer norm before self-attention
580
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
581
+ # The layer norm before the MLP
582
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
583
+ else:
584
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
585
+
586
+ def forward(
587
+ self,
588
+ hidden_states: torch.Tensor,
589
+ alibi: Optional[torch.Tensor],
590
+ attention_mask: torch.Tensor,
591
+ position_ids: Optional[torch.LongTensor] = None,
592
+ layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
593
+ head_mask: Optional[torch.Tensor] = None,
594
+ use_cache: bool = False,
595
+ output_attentions: bool = False,
596
+ cache_position: Optional[torch.LongTensor] = None,
597
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
598
+ **kwargs,
599
+ ):
600
+ residual = hidden_states
601
+
602
+ if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
603
+ attention_layernorm_out = self.ln_attn(hidden_states)
604
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
605
+ else:
606
+ attention_layernorm_out = self.input_layernorm(hidden_states)
607
+
608
+ # Self attention.
609
+ attn_outputs = self.self_attention(
610
+ attention_layernorm_out,
611
+ layer_past=layer_past,
612
+ attention_mask=attention_mask,
613
+ position_ids=position_ids,
614
+ alibi=alibi,
615
+ head_mask=head_mask,
616
+ use_cache=use_cache,
617
+ output_attentions=output_attentions,
618
+ cache_position=cache_position,
619
+ position_embeddings=position_embeddings,
620
+ )
621
+
622
+ attention_output = attn_outputs[0]
623
+
624
+ if not self.config.new_decoder_architecture:
625
+ if self.config.parallel_attn:
626
+ mlp_layernorm_out = attention_layernorm_out
627
+ else:
628
+ residual = dropout_add(
629
+ attention_output, residual, self.config.attention_dropout, training=self.training
630
+ )
631
+ mlp_layernorm_out = self.post_attention_layernorm(residual)
632
+
633
+ if (
634
+ self.config.new_decoder_architecture
635
+ and self.config.parallel_attn
636
+ and self.config.num_ln_in_parallel_attn == 1
637
+ ):
638
+ mlp_layernorm_out = attention_layernorm_out
639
+
640
+ outputs = attn_outputs[1:]
641
+
642
+ # MLP.
643
+ mlp_output = self.mlp(mlp_layernorm_out)
644
+
645
+ if self.config.new_decoder_architecture or self.config.parallel_attn:
646
+ mlp_output += attention_output
647
+
648
+ output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
649
+
650
+ if use_cache:
651
+ outputs = (output,) + outputs
652
+ else:
653
+ outputs = (output,) + outputs[1:]
654
+
655
+ return outputs # hidden_states, past_kv, attentions
656
+
657
+
658
+ FALCON_START_DOCSTRING = r"""
659
+
660
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
661
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
662
+
663
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
664
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
665
+ and behavior.
666
+
667
+ Parameters:
668
+ config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
669
+ Initializing with a config file does not load the weights associated with the model, only the
670
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
+ """
672
+
673
+ FALCON_INPUTS_DOCSTRING = r"""
674
+ Args:
675
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
676
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
677
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
678
+
679
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
680
+ `input_ids`.
681
+
682
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
683
+ [`PreTrainedTokenizer.__call__`] for details.
684
+
685
+ [What are input IDs?](../glossary#input-ids)
686
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
687
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
688
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
689
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
690
+
691
+ Two formats are allowed:
692
+ - a [`~cache_utils.Cache`] instance, see our
693
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
694
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
695
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
696
+ cache format.
697
+
698
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
699
+ legacy cache format will be returned.
700
+
701
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
702
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
703
+ of shape `(batch_size, sequence_length)`.
704
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
705
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
706
+
707
+ - 1 for tokens that are **not masked**,
708
+ - 0 for tokens that are **masked**.
709
+
710
+ [What are attention masks?](../glossary#attention-mask)
711
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
712
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
713
+ config.n_positions - 1]`.
714
+
715
+ [What are position IDs?](../glossary#position-ids)
716
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
717
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
718
+
719
+ - 1 indicates the head is **not masked**,
720
+ - 0 indicates the head is **masked**.
721
+
722
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
723
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
724
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
725
+ model's internal embedding lookup matrix.
726
+
727
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
728
+ `past_key_values`).
729
+ use_cache (`bool`, *optional*):
730
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
731
+ `past_key_values`).
732
+ output_attentions (`bool`, *optional*):
733
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
734
+ tensors for more detail.
735
+ output_hidden_states (`bool`, *optional*):
736
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
737
+ more detail.
738
+ return_dict (`bool`, *optional*):
739
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
740
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
741
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
742
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
743
+ the complete sequence length.
744
+ """
745
+
746
+
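
The `past_key_values` documentation above lists the two accepted cache formats. The following hedged sketch (a tiny randomly initialized model rather than a pretrained checkpoint, purely for illustration) uses the recommended `Cache` path with `DynamicCache` and feeds only the newly produced token on the second step.

```python
import torch
from transformers import DynamicCache, FalconConfig, FalconForCausalLM

config = FalconConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, vocab_size=128)
model = FalconForCausalLM(config).eval()  # random weights, for illustration only

cache = DynamicCache()
input_ids = torch.randint(0, config.vocab_size, (1, 5))
with torch.no_grad():
    out = model(input_ids, past_key_values=cache, use_cache=True)
    # On the next step, pass only the new token; the cache carries the previous ones.
    next_token = out.logits[:, -1:].argmax(-1)
    out = model(next_token, past_key_values=out.past_key_values, use_cache=True)

print(out.past_key_values.get_seq_length())  # 6 tokens cached
```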
747
+ class FalconPreTrainedModel(PreTrainedModel):
748
+ """
749
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
750
+ models.
751
+ """
752
+
753
+ config_class = FalconConfig
754
+ base_model_prefix = "transformer"
755
+ supports_gradient_checkpointing = True
756
+ _no_split_modules = ["FalconDecoderLayer"]
757
+ _supports_flash_attn_2 = True
758
+ _supports_sdpa = True
759
+ _supports_cache_class = True
760
+ _supports_quantized_cache = True
761
+ _supports_static_cache = True
762
+
763
+ def __init__(self, *inputs, **kwargs):
764
+ super().__init__(*inputs, **kwargs)
765
+
766
+ def _init_weights(self, module: nn.Module):
767
+ """Initialize the weights."""
768
+ if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
769
+ # Slightly different from the TF version which uses truncated_normal for initialization
770
+ # cf https://github.com/pytorch/pytorch/pull/5617
771
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
772
+ if module.bias is not None:
773
+ module.bias.data.zero_()
774
+ elif isinstance(module, nn.Embedding):
775
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
776
+ if module.padding_idx is not None:
777
+ module.weight.data[module.padding_idx].zero_()
778
+ elif isinstance(module, LayerNorm):
779
+ module.bias.data.zero_()
780
+ module.weight.data.fill_(1.0)
781
+
782
+ # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
783
+ @classmethod
784
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
785
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
786
+ if _is_bettertransformer:
787
+ return config
788
+
789
+ if not hard_check_only:
790
+ config._attn_implementation = "sdpa"
791
+ return config
792
+
793
+
794
+ @add_start_docstrings(
795
+ "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
796
+ FALCON_START_DOCSTRING,
797
+ )
798
+ class FalconModel(FalconPreTrainedModel):
799
+ def __init__(self, config: FalconConfig):
800
+ super().__init__(config)
801
+
802
+ self.embed_dim = config.hidden_size
803
+ self.num_heads = config.num_attention_heads
804
+ self.use_alibi = config.alibi
805
+
806
+ # Embedding + LN Embedding
807
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
808
+
809
+ # Transformer blocks
810
+ self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
811
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
812
+ self._use_sdpa = config._attn_implementation == "sdpa"
813
+
814
+ # Final Layer Norm
815
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
816
+
817
+ self.rotary_emb = FalconRotaryEmbedding(config=config)
818
+
819
+ self.gradient_checkpointing = False
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+
824
+ def get_input_embeddings(self):
825
+ return self.word_embeddings
826
+
827
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
828
+ self.word_embeddings = new_embeddings
829
+
830
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
831
+ @add_code_sample_docstrings(
832
+ checkpoint=_CHECKPOINT_FOR_DOC,
833
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
834
+ config_class=_CONFIG_FOR_DOC,
835
+ )
836
+ def forward(
837
+ self,
838
+ input_ids: Optional[torch.LongTensor] = None,
839
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
840
+ attention_mask: Optional[torch.Tensor] = None,
841
+ position_ids: Optional[torch.LongTensor] = None,
842
+ head_mask: Optional[torch.LongTensor] = None,
843
+ inputs_embeds: Optional[torch.LongTensor] = None,
844
+ use_cache: Optional[bool] = None,
845
+ output_attentions: Optional[bool] = None,
846
+ output_hidden_states: Optional[bool] = None,
847
+ return_dict: Optional[bool] = None,
848
+ cache_position: Optional[torch.LongTensor] = None,
849
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
850
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
851
+ output_hidden_states = (
852
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
853
+ )
854
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
855
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
856
+
857
+ if (input_ids is None) ^ (inputs_embeds is not None):
858
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
859
+
860
+ if self.gradient_checkpointing and self.training:
861
+ if use_cache:
862
+ logger.warning_once(
863
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
864
+ )
865
+ use_cache = False
866
+
867
+ if inputs_embeds is None:
868
+ inputs_embeds = self.word_embeddings(input_ids)
869
+
870
+ # kept for BC (non `Cache` `past_key_values` inputs)
871
+ return_legacy_cache = False
872
+ if use_cache and not isinstance(past_key_values, Cache):
873
+ return_legacy_cache = True
874
+ if past_key_values is None:
875
+ past_key_values = DynamicCache()
876
+ else:
877
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
878
+ logger.warning_once(
879
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
880
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
881
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
882
+ )
883
+
884
+ # Compute alibi tensor: check build_alibi_tensor documentation
885
+ alibi = None
886
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
887
+ batch_size, seq_length, _ = inputs_embeds.shape
888
+ if self.use_alibi:
889
+ mask = (
890
+ torch.ones(
891
+ (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
892
+ )
893
+ if attention_mask is None
894
+ else attention_mask
895
+ )
896
+ alibi = build_alibi_tensor(mask, self.num_heads, dtype=inputs_embeds.dtype)
897
+
898
+ if cache_position is None:
899
+ cache_position = torch.arange(
900
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
901
+ )
902
+
903
+ if position_ids is None:
904
+ position_ids = cache_position.unsqueeze(0)
905
+
906
+ causal_mask = self._update_causal_mask(
907
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, head_mask, alibi
908
+ )
909
+
910
+ # Prepare head mask if needed
911
+ # 1.0 in head_mask indicate we keep the head
912
+ # attention_probs has shape batch_size x num_heads x N x N
913
+ # head_mask has shape n_layer x batch x num_heads x N x N
914
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
915
+ hidden_states = inputs_embeds
916
+
917
+ # create position embeddings to be shared across the decoder layers
918
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
919
+
920
+ next_decoder_cache = None
921
+ all_self_attentions = () if output_attentions else None
922
+ all_hidden_states = () if output_hidden_states else None
923
+
924
+ for i, block in enumerate(self.h):
925
+ if output_hidden_states:
926
+ all_hidden_states = all_hidden_states + (hidden_states,)
927
+
928
+ if self.gradient_checkpointing and self.training:
929
+ outputs = self._gradient_checkpointing_func(
930
+ block.__call__,
931
+ hidden_states,
932
+ alibi,
933
+ causal_mask,
934
+ position_ids,
935
+ head_mask[i],
936
+ past_key_values,
937
+ use_cache,
938
+ output_attentions,
939
+ cache_position,
940
+ position_embeddings,
941
+ )
942
+ else:
943
+ outputs = block(
944
+ hidden_states,
945
+ layer_past=past_key_values,
946
+ attention_mask=causal_mask,
947
+ position_ids=position_ids,
948
+ head_mask=head_mask[i],
949
+ use_cache=use_cache,
950
+ output_attentions=output_attentions,
951
+ alibi=alibi,
952
+ cache_position=cache_position,
953
+ position_embeddings=position_embeddings,
954
+ )
955
+
956
+ hidden_states = outputs[0]
957
+ if use_cache is True:
958
+ next_decoder_cache = outputs[1]
959
+
960
+ if output_attentions:
961
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
962
+
963
+ # Add last hidden state
964
+ hidden_states = self.ln_f(hidden_states)
965
+
966
+ if output_hidden_states:
967
+ all_hidden_states = all_hidden_states + (hidden_states,)
968
+
969
+ next_cache = next_decoder_cache if use_cache else None
970
+ if return_legacy_cache:
971
+ next_cache = next_cache.to_legacy_cache()
972
+
973
+ if not return_dict:
974
+ return tuple(
975
+ v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
976
+ )
977
+
978
+ return BaseModelOutputWithPastAndCrossAttentions(
979
+ last_hidden_state=hidden_states,
980
+ past_key_values=next_cache,
981
+ hidden_states=all_hidden_states,
982
+ attentions=all_self_attentions,
983
+ )
984
+
985
+ def _update_causal_mask(
986
+ self,
987
+ attention_mask: torch.Tensor,
988
+ input_tensor: torch.Tensor,
989
+ cache_position: torch.Tensor,
990
+ past_key_values: Cache,
991
+ output_attentions: bool,
992
+ head_mask: torch.Tensor,
993
+ alibi: torch.Tensor,
994
+ ):
995
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
996
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
997
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
998
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
999
+
1000
+ if self.config._attn_implementation == "flash_attention_2":
1001
+ if attention_mask is not None and 0.0 in attention_mask:
1002
+ return attention_mask
1003
+ return None
1004
+
1005
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1006
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1007
+ # to infer the attention mask.
1008
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1009
+ using_static_cache = isinstance(past_key_values, StaticCache)
1010
+
1011
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1012
+ if (
1013
+ self.config._attn_implementation == "sdpa"
1014
+ and not using_static_cache
1015
+ and not output_attentions
1016
+ and head_mask is None
1017
+ and alibi is None
1018
+ ):
1019
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1020
+ attention_mask,
1021
+ inputs_embeds=input_tensor,
1022
+ past_key_values_length=past_seen_tokens,
1023
+ is_training=self.training,
1024
+ ):
1025
+ return None
1026
+
1027
+ dtype, device = input_tensor.dtype, input_tensor.device
1028
+ min_dtype = torch.finfo(dtype).min
1029
+ batch_size, sequence_length, _ = input_tensor.shape
1030
+ if using_static_cache:
1031
+ target_length = past_key_values.get_max_cache_shape()
1032
+ else:
1033
+ target_length = (
1034
+ attention_mask.shape[-1]
1035
+ if isinstance(attention_mask, torch.Tensor)
1036
+ else past_seen_tokens + sequence_length
1037
+ )
1038
+
1039
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1040
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1041
+ attention_mask,
1042
+ sequence_length=sequence_length,
1043
+ target_length=target_length,
1044
+ dtype=dtype,
1045
+ device=device,
1046
+ cache_position=cache_position,
1047
+ batch_size=input_tensor.shape[0],
1048
+ )
1049
+
1050
+ # We take care to integrate alibi bias in the causal_mask here
1051
+ if head_mask is None and alibi is not None:
1052
+ alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
1053
+ causal_mask = torch.masked_fill(
1054
+ alibi / math.sqrt(self.config.hidden_size // self.num_heads),
1055
+ causal_mask < -1,
1056
+ min_dtype,
1057
+ )
1058
+
1059
+ if (
1060
+ self.config._attn_implementation == "sdpa"
1061
+ and attention_mask is not None
1062
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1063
+ and not output_attentions
1064
+ ):
1065
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1066
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1067
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1068
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1069
+
1070
+ return causal_mask
1071
+
1072
+ @staticmethod
1073
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
1074
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1075
+ attention_mask: torch.Tensor,
1076
+ sequence_length: int,
1077
+ target_length: int,
1078
+ dtype: torch.dtype,
1079
+ device: torch.device,
1080
+ cache_position: torch.Tensor,
1081
+ batch_size: int,
1082
+ **kwargs,
1083
+ ):
1084
+ """
1085
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1086
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1087
+
1088
+ Args:
1089
+ attention_mask (`torch.Tensor`):
1090
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1091
+ `(batch_size, 1, query_length, key_value_length)`.
1092
+ sequence_length (`int`):
1093
+ The sequence length being processed.
1094
+ target_length (`int`):
1095
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1096
+ to account for the 0 padding (the part of the cache that is not yet filled).
1097
+ dtype (`torch.dtype`):
1098
+ The dtype to use for the 4D attention mask.
1099
+ device (`torch.device`):
1100
+ The device to place the 4D attention mask on.
1101
+ cache_position (`torch.Tensor`):
1102
+ Indices depicting the position of the input sequence tokens in the sequence.
1103
+ batch_size (`torch.Tensor`):
1104
+ Batch size.
1105
+ """
1106
+ if attention_mask is not None and attention_mask.dim() == 4:
1107
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1108
+ causal_mask = attention_mask
1109
+ else:
1110
+ min_dtype = torch.finfo(dtype).min
1111
+ causal_mask = torch.full(
1112
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1113
+ )
1114
+ if sequence_length != 1:
1115
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1116
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1117
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1118
+ if attention_mask is not None:
1119
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1120
+ mask_length = attention_mask.shape[-1]
1121
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1122
+ causal_mask.device
1123
+ )
1124
+ padding_mask = padding_mask == 0
1125
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1126
+ padding_mask, min_dtype
1127
+ )
1128
+
1129
+ return causal_mask
1130
+
1131
+
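The mask construction in `_prepare_4d_causal_attention_mask_with_cache_position` above can be hard to follow from the tensor ops alone. Below is a minimal sketch (the helper name and toy sizes are ours; batch expansion and padding handling are omitted) showing what the `triu` plus `cache_position` comparison produces for two new tokens decoded against a static cache of length 4:

```python
import torch

def build_causal_mask(sequence_length, target_length, cache_position, dtype=torch.float32):
    # Simplified re-implementation of the steps above (no batch dim, no padding mask).
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # After this step, key position j stays visible to query row i only when j <= cache_position[i].
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    return mask

mask = build_causal_mask(sequence_length=2, target_length=4, cache_position=torch.tensor([2, 3]))
# Row 0 (query at absolute position 2) attends to key positions 0-2; position 3 stays at min_dtype.
# Row 1 (query at absolute position 3) attends to all four key positions.
```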
1132
+ @add_start_docstrings(
1133
+ "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
1134
+ FALCON_START_DOCSTRING,
1135
+ )
1136
+ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
1137
+ _tied_weights_keys = ["lm_head.weight"]
1138
+
1139
+ def __init__(self, config: FalconConfig):
1140
+ super().__init__(config)
1141
+ self.transformer = FalconModel(config)
1142
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1143
+
1144
+ # Initialize weights and apply final processing
1145
+ self.post_init()
1146
+
1147
+ def get_output_embeddings(self):
1148
+ return self.lm_head
1149
+
1150
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
1151
+ self.lm_head = new_embeddings
1152
+
1153
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1154
+ @add_code_sample_docstrings(
1155
+ checkpoint=_CHECKPOINT_FOR_DOC,
1156
+ output_type=CausalLMOutputWithCrossAttentions,
1157
+ config_class=_CONFIG_FOR_DOC,
1158
+ )
1159
+ def forward(
1160
+ self,
1161
+ input_ids: Optional[torch.LongTensor] = None,
1162
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
1163
+ attention_mask: Optional[torch.Tensor] = None,
1164
+ position_ids: Optional[torch.LongTensor] = None,
1165
+ head_mask: Optional[torch.Tensor] = None,
1166
+ inputs_embeds: Optional[torch.Tensor] = None,
1167
+ labels: Optional[torch.Tensor] = None,
1168
+ use_cache: Optional[bool] = None,
1169
+ output_attentions: Optional[bool] = None,
1170
+ output_hidden_states: Optional[bool] = None,
1171
+ return_dict: Optional[bool] = None,
1172
+ cache_position: Optional[torch.LongTensor] = None,
1173
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1174
+ **kwargs,
1175
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1176
+ r"""
1177
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1178
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1179
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1180
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1181
+
1182
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
1183
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
1184
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1185
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1186
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
1187
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
1188
+ """
1189
+
1190
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1191
+
1192
+ transformer_outputs = self.transformer(
1193
+ input_ids,
1194
+ past_key_values=past_key_values,
1195
+ attention_mask=attention_mask,
1196
+ position_ids=position_ids,
1197
+ head_mask=head_mask,
1198
+ inputs_embeds=inputs_embeds,
1199
+ use_cache=use_cache,
1200
+ output_attentions=output_attentions,
1201
+ output_hidden_states=output_hidden_states,
1202
+ return_dict=return_dict,
1203
+ cache_position=cache_position,
1204
+ )
1205
+ hidden_states = transformer_outputs[0]
1206
+
1207
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1208
+ lm_logits = self.lm_head(hidden_states[:, slice_indices, :])
1209
+
1210
+ loss = None
1211
+ if labels is not None:
1212
+ loss = self.loss_function(
1213
+ lm_logits,
1214
+ labels,
1215
+ vocab_size=self.config.vocab_size,
1216
+ **kwargs,
1217
+ )
1218
+
1219
+ if not return_dict:
1220
+ output = (lm_logits,) + transformer_outputs[1:]
1221
+ return ((loss,) + output) if loss is not None else output
1222
+
1223
+ return CausalLMOutputWithCrossAttentions(
1224
+ loss=loss,
1225
+ logits=lm_logits,
1226
+ past_key_values=transformer_outputs.past_key_values,
1227
+ hidden_states=transformer_outputs.hidden_states,
1228
+ attentions=transformer_outputs.attentions,
1229
+ )
1230
+
1231
+ def _reorder_cache(
1232
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1233
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1234
+ """
1235
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1236
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1237
+ beam_idx at every generation step.
1238
+
1239
+ Output shares the same memory storage as `past`.
1240
+ """
1241
+
1242
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
1243
+ device_to_beam_idx = {
1244
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
1245
+ }
1246
+ reordered_past = tuple(
1247
+ (
1248
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
1249
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
1250
+ )
1251
+ for layer_past in past
1252
+ )
1253
+ return reordered_past
1254
+
1255
+
1256
+ @add_start_docstrings(
1257
+ """
1258
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
1259
+
1260
+ [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1261
+ (e.g. GPT-1) do.
1262
+
1263
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1264
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1265
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1266
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1267
+ each row of the batch).
1268
+ """,
1269
+ FALCON_START_DOCSTRING,
1270
+ )
1271
+ class FalconForSequenceClassification(FalconPreTrainedModel):
1272
+ def __init__(self, config: FalconConfig):
1273
+ super().__init__(config)
1274
+ self.num_labels = config.num_labels
1275
+ self.transformer = FalconModel(config)
1276
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
1277
+
1278
+ # Initialize weights and apply final processing
1279
+ self.post_init()
1280
+
1281
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1282
+ @add_code_sample_docstrings(
1283
+ checkpoint=_CHECKPOINT_FOR_DOC,
1284
+ output_type=SequenceClassifierOutputWithPast,
1285
+ config_class=_CONFIG_FOR_DOC,
1286
+ )
1287
+ def forward(
1288
+ self,
1289
+ input_ids: Optional[torch.LongTensor] = None,
1290
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1291
+ attention_mask: Optional[torch.Tensor] = None,
1292
+ head_mask: Optional[torch.Tensor] = None,
1293
+ inputs_embeds: Optional[torch.Tensor] = None,
1294
+ labels: Optional[torch.Tensor] = None,
1295
+ use_cache: Optional[bool] = None,
1296
+ output_attentions: Optional[bool] = None,
1297
+ output_hidden_states: Optional[bool] = None,
1298
+ return_dict: Optional[bool] = None,
1299
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1300
+ r"""
1301
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1302
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1303
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1304
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1305
+ """
1306
+
1307
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1308
+
1309
+ transformer_outputs = self.transformer(
1310
+ input_ids,
1311
+ past_key_values=past_key_values,
1312
+ attention_mask=attention_mask,
1313
+ head_mask=head_mask,
1314
+ inputs_embeds=inputs_embeds,
1315
+ use_cache=use_cache,
1316
+ output_attentions=output_attentions,
1317
+ output_hidden_states=output_hidden_states,
1318
+ return_dict=return_dict,
1319
+ )
1320
+
1321
+ hidden_states = transformer_outputs[0]
1322
+ logits = self.score(hidden_states)
1323
+
1324
+ if input_ids is not None:
1325
+ batch_size = input_ids.shape[0]
1326
+ else:
1327
+ batch_size = inputs_embeds.shape[0]
1328
+
1329
+ if self.config.pad_token_id is None and batch_size != 1:
1330
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1331
+ if self.config.pad_token_id is None:
1332
+ last_non_pad_token = -1
1333
+ elif input_ids is not None:
1334
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1335
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1336
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1337
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1338
+ else:
1339
+ last_non_pad_token = -1
1340
+ logger.warning_once(
1341
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1342
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1343
+ )
1344
+
1345
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1346
+
1347
+ loss = None
1348
+ if labels is not None:
1349
+ if self.config.problem_type is None:
1350
+ if self.num_labels == 1:
1351
+ self.config.problem_type = "regression"
1352
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1353
+ self.config.problem_type = "single_label_classification"
1354
+ else:
1355
+ self.config.problem_type = "multi_label_classification"
1356
+
1357
+ if self.config.problem_type == "regression":
1358
+ loss_fct = MSELoss()
1359
+ if self.num_labels == 1:
1360
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1361
+ else:
1362
+ loss = loss_fct(pooled_logits, labels)
1363
+ elif self.config.problem_type == "single_label_classification":
1364
+ loss_fct = CrossEntropyLoss()
1365
+ loss = loss_fct(pooled_logits, labels)
1366
+ elif self.config.problem_type == "multi_label_classification":
1367
+ loss_fct = BCEWithLogitsLoss()
1368
+ loss = loss_fct(pooled_logits, labels)
1369
+ if not return_dict:
1370
+ output = (pooled_logits,) + transformer_outputs[1:]
1371
+ return ((loss,) + output) if loss is not None else output
1372
+
1373
+ return SequenceClassifierOutputWithPast(
1374
+ loss=loss,
1375
+ logits=pooled_logits,
1376
+ past_key_values=transformer_outputs.past_key_values,
1377
+ hidden_states=transformer_outputs.hidden_states,
1378
+ attentions=transformer_outputs.attentions,
1379
+ )
1380
+
1381
+
1382
+ @add_start_docstrings(
1383
+ """
1384
+ Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1385
+ Named-Entity-Recognition (NER) tasks.
1386
+ """,
1387
+ FALCON_START_DOCSTRING,
1388
+ )
1389
+ class FalconForTokenClassification(FalconPreTrainedModel):
1390
+ def __init__(self, config: FalconConfig):
1391
+ super().__init__(config)
1392
+ self.num_labels = config.num_labels
1393
+
1394
+ self.transformer = FalconModel(config)
1395
+ if getattr(config, "classifier_dropout", None) is not None:
1396
+ classifier_dropout = config.classifier_dropout
1397
+ elif getattr(config, "hidden_dropout", None) is not None:
1398
+ classifier_dropout = config.hidden_dropout
1399
+ else:
1400
+ classifier_dropout = 0.1
1401
+ self.dropout = nn.Dropout(classifier_dropout)
1402
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1403
+
1404
+ # Initialize weights and apply final processing
1405
+ self.post_init()
1406
+
1407
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1408
+ @add_code_sample_docstrings(
1409
+ checkpoint=_CHECKPOINT_FOR_DOC,
1410
+ output_type=TokenClassifierOutput,
1411
+ config_class=_CONFIG_FOR_DOC,
1412
+ )
1413
+ def forward(
1414
+ self,
1415
+ input_ids: Optional[torch.LongTensor] = None,
1416
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1417
+ attention_mask: Optional[torch.Tensor] = None,
1418
+ head_mask: Optional[torch.Tensor] = None,
1419
+ inputs_embeds: Optional[torch.Tensor] = None,
1420
+ labels: Optional[torch.Tensor] = None,
1421
+ use_cache: Optional[bool] = None,
1422
+ output_attentions: Optional[bool] = None,
1423
+ output_hidden_states: Optional[bool] = None,
1424
+ return_dict: Optional[bool] = None,
1425
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1426
+ r"""
1427
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1428
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1429
+ config.num_labels - 1]`. A cross-entropy loss is computed over the `config.num_labels` classes for
1430
+ each token.
1431
+ """
1432
+
1433
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1434
+
1435
+ transformer_outputs = self.transformer(
1436
+ input_ids,
1437
+ past_key_values=past_key_values,
1438
+ attention_mask=attention_mask,
1439
+ head_mask=head_mask,
1440
+ inputs_embeds=inputs_embeds,
1441
+ use_cache=use_cache,
1442
+ output_attentions=output_attentions,
1443
+ output_hidden_states=output_hidden_states,
1444
+ return_dict=return_dict,
1445
+ )
1446
+
1447
+ hidden_states = transformer_outputs[0]
1448
+ hidden_states = self.dropout(hidden_states)
1449
+ logits = self.classifier(hidden_states)
1450
+
1451
+ loss = None
1452
+ if labels is not None:
1453
+ batch_size, seq_length = labels.shape
1454
+ loss_fct = CrossEntropyLoss()
1455
+ loss = loss_fct(
1456
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1457
+ )
1458
+
1459
+ if not return_dict:
1460
+ output = (logits,) + transformer_outputs[2:]
1461
+ return ((loss,) + output) if loss is not None else output
1462
+
1463
+ return TokenClassifierOutput(
1464
+ loss=loss,
1465
+ logits=logits,
1466
+ hidden_states=transformer_outputs.hidden_states,
1467
+ attentions=transformer_outputs.attentions,
1468
+ )
1469
+
1470
+
1471
+ @add_start_docstrings(
1472
+ """
1473
+ The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1474
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1475
+ """,
1476
+ FALCON_START_DOCSTRING,
1477
+ )
1478
+ class FalconForQuestionAnswering(FalconPreTrainedModel):
1479
+ def __init__(self, config):
1480
+ super().__init__(config)
1481
+ self.transformer = FalconModel(config)
1482
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1483
+
1484
+ # Initialize weights and apply final processing
1485
+ self.post_init()
1486
+
1487
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1488
+ def forward(
1489
+ self,
1490
+ input_ids: Optional[torch.LongTensor] = None,
1491
+ attention_mask: Optional[torch.FloatTensor] = None,
1492
+ head_mask: Optional[torch.FloatTensor] = None,
1493
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1494
+ start_positions: Optional[torch.LongTensor] = None,
1495
+ end_positions: Optional[torch.LongTensor] = None,
1496
+ output_attentions: Optional[bool] = None,
1497
+ output_hidden_states: Optional[bool] = None,
1498
+ return_dict: Optional[bool] = None,
1499
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1500
+ r"""
1501
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1502
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1503
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1504
+ are not taken into account for computing the loss.
1505
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1506
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1507
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1508
+ are not taken into account for computing the loss.
1509
+ """
1510
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1511
+
1512
+ outputs = self.transformer(
1513
+ input_ids,
1514
+ attention_mask=attention_mask,
1515
+ head_mask=head_mask,
1516
+ inputs_embeds=inputs_embeds,
1517
+ output_attentions=output_attentions,
1518
+ output_hidden_states=output_hidden_states,
1519
+ return_dict=return_dict,
1520
+ )
1521
+
1522
+ sequence_output = outputs[0]
1523
+
1524
+ logits = self.qa_outputs(sequence_output)
1525
+ start_logits, end_logits = logits.split(1, dim=-1)
1526
+ start_logits = start_logits.squeeze(-1).contiguous()
1527
+ end_logits = end_logits.squeeze(-1).contiguous()
1528
+
1529
+ total_loss = None
1530
+ if start_positions is not None and end_positions is not None:
1531
+ # If we are on multi-GPU, split add a dimension
1532
+ if len(start_positions.size()) > 1:
1533
+ start_positions = start_positions.squeeze(-1)
1534
+ if len(end_positions.size()) > 1:
1535
+ end_positions = end_positions.squeeze(-1)
1536
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1537
+ ignored_index = start_logits.size(1)
1538
+ start_positions = start_positions.clamp(0, ignored_index)
1539
+ end_positions = end_positions.clamp(0, ignored_index)
1540
+
1541
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1542
+ start_loss = loss_fct(start_logits, start_positions)
1543
+ end_loss = loss_fct(end_logits, end_positions)
1544
+ total_loss = (start_loss + end_loss) / 2
1545
+
1546
+ if not return_dict:
1547
+ output = (start_logits, end_logits) + outputs[2:]
1548
+ return ((total_loss,) + output) if total_loss is not None else output
1549
+
1550
+ return QuestionAnsweringModelOutput(
1551
+ loss=total_loss,
1552
+ start_logits=start_logits,
1553
+ end_logits=end_logits,
1554
+ hidden_states=outputs.hidden_states,
1555
+ attentions=outputs.attentions,
1556
+ )
1557
+
1558
+
1559
+ __all__ = [
1560
+ "FalconForCausalLM",
1561
+ "FalconModel",
1562
+ "FalconPreTrainedModel",
1563
+ "FalconForSequenceClassification",
1564
+ "FalconForTokenClassification",
1565
+ "FalconForQuestionAnswering",
1566
+ ]
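One detail in `FalconForSequenceClassification` above that is easy to misread: `(token_indices * non_pad_mask).argmax(-1)` selects the rightmost non-padding position of every row, so the same pooling works for both left- and right-padded batches. A small standalone sketch (the `pad_token_id` value and the toy batch are hypothetical, for illustration only):

```python
import torch

pad_token_id = 0  # hypothetical value, for illustration only
input_ids = torch.tensor([[5, 6, 0, 0],   # right-padded row
                          [0, 7, 8, 9]])  # left-padded row
non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
# tensor([1, 3]): the score head pools the hidden state at these positions for each row.
```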
docs/transformers/build/lib/transformers/models/falcon_mamba/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_falcon_mamba import *
22
+ from .modeling_falcon_mamba import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
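The `__init__.py` above registers a `_LazyModule`, so the heavy configuration and modeling submodules are only imported when one of their attributes is first accessed. As a rough illustration of the idea (a simplified analogue using PEP 562 module-level `__getattr__`, not the actual transformers implementation):

```python
# lazy_init_sketch.py -- simplified analogue of the lazy-import pattern, for illustration only.
import importlib

_LAZY_ATTRS = {
    "FalconMambaConfig": ".configuration_falcon_mamba",
    "FalconMambaModel": ".modeling_falcon_mamba",
}

def __getattr__(name):
    # Import the submodule only when the attribute is actually requested.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```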
docs/transformers/build/lib/transformers/models/falcon_mamba/configuration_falcon_mamba.py ADDED
@@ -0,0 +1,162 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """FALCONMAMBA configuration"""
16
+
17
+ import math
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class FalconMambaConfig(PretrainedConfig):
27
+ """
28
+ This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the FALCON_MAMBA
31
+ [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 50280):
39
+ Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`FalconMambaModel`].
41
+ hidden_size (`int`, *optional*, defaults to 768):
42
+ Dimensionality of the embeddings and hidden states.
43
+ state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the model.
46
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
47
+ The epsilon to use in the layer normalization layers.
48
+ pad_token_id (`int`, *optional*, defaults to 0):
49
+ Padding token id.
50
+ bos_token_id (`int`, *optional*, defaults to 0):
51
+ The id of the beginning of sentence token in the vocabulary.
52
+ eos_token_id (`int`, *optional*, defaults to 0):
53
+ The id of the end of sentence token in the vocabulary.
54
+ expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
55
+ conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
56
+ use_bias (`bool`, *optional*, defaults to `False`):
57
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
58
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
59
+ Whether or not to use bias in the convolution layer of the mixer block.
60
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
61
+ The non-linear activation function (function or string) in the decoder.
62
+ initializer_range (`float`, *optional*, defaults to 0.1):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ residual_in_fp32 (`bool`, *optional*, defaults to `True`):
65
+ Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
66
+ time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
67
+ Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
68
+ time_step_scale (`float`, *optional*, defaults to 1.0):
69
+ Scale used used to scale `dt_proj.bias`.
70
+ time_step_min (`float`, *optional*, defaults to 0.001):
71
+ Minimum `time_step` used to bound `dt_proj.bias`.
72
+ time_step_max (`float`, *optional*, defaults to 0.1):
73
+ Maximum `time_step` used to bound `dt_proj.bias`.
74
+ time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
75
+ Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
76
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
77
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
78
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
79
+ Whether or not to rescale `out_proj` weights when initializing.
80
+ use_cache (`bool`, *optional*, defaults to `True`):
81
+ Whether or not the cache should be used.
82
+ use_mambapy (`bool`, *optional*, defaults to `False`):
83
+ Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
84
+ mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
85
+ The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import FalconMambaConfig, FalconMambaModel
90
+
91
+ >>> # Initializing a FalconMamba configuration
92
+ >>> configuration = FalconMambaConfig()
93
+
94
+ >>> # Initializing a model (with random weights) from the configuration
95
+ >>> model = FalconMambaModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+
101
+ model_type = "falcon_mamba"
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=50280,
106
+ hidden_size=768,
107
+ state_size=16,
108
+ num_hidden_layers=32,
109
+ layer_norm_epsilon=1e-5,
110
+ pad_token_id=0,
111
+ bos_token_id=0,
112
+ eos_token_id=0,
113
+ expand=2,
114
+ conv_kernel=4,
115
+ use_bias=False,
116
+ use_conv_bias=True,
117
+ hidden_act="silu",
118
+ initializer_range=0.1,
119
+ residual_in_fp32=True,
120
+ time_step_rank="auto",
121
+ time_step_scale=1.0,
122
+ time_step_min=0.001,
123
+ time_step_max=0.1,
124
+ time_step_init_scheme="random",
125
+ time_step_floor=1e-4,
126
+ rescale_prenorm_residual=False,
127
+ use_cache=True,
128
+ use_mambapy=False,
129
+ mixer_rms_eps=1e-6,
130
+ **kwargs,
131
+ ):
132
+ self.vocab_size = vocab_size
133
+ self.hidden_size = hidden_size
134
+ self.state_size = state_size
135
+ self.num_hidden_layers = num_hidden_layers
136
+ self.layer_norm_epsilon = layer_norm_epsilon
137
+ self.conv_kernel = conv_kernel
138
+ self.expand = expand
139
+ self.intermediate_size = int(expand * self.hidden_size)
140
+ self.bos_token_id = bos_token_id
141
+ self.eos_token_id = eos_token_id
142
+ self.pad_token_id = pad_token_id
143
+ self.use_bias = use_bias
144
+ self.use_conv_bias = use_conv_bias
145
+ self.hidden_act = hidden_act
146
+ self.initializer_range = initializer_range
147
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
148
+ self.time_step_scale = time_step_scale
149
+ self.time_step_min = time_step_min
150
+ self.time_step_max = time_step_max
151
+ self.time_step_init_scheme = time_step_init_scheme
152
+ self.time_step_floor = time_step_floor
153
+ self.rescale_prenorm_residual = rescale_prenorm_residual
154
+ self.residual_in_fp32 = residual_in_fp32
155
+ self.use_cache = use_cache
156
+ self.use_mambapy = use_mambapy
157
+ self.mixer_rms_eps = mixer_rms_eps
158
+
159
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
160
+
161
+
162
+ __all__ = ["FalconMambaConfig"]
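Two attributes of `FalconMambaConfig` are derived rather than taken verbatim from the arguments: `intermediate_size` is `expand * hidden_size`, and `time_step_rank` resolves to `ceil(hidden_size / 16)` when left at `"auto"`. A quick sanity check, assuming a transformers version that ships FalconMamba:

```python
import math

from transformers import FalconMambaConfig

config = FalconMambaConfig()  # library defaults documented above
assert config.intermediate_size == int(config.expand * config.hidden_size)  # 2 * 768 = 1536
assert config.time_step_rank == math.ceil(config.hidden_size / 16)          # ceil(768 / 16) = 48
```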
docs/transformers/build/lib/transformers/models/falcon_mamba/modeling_falcon_mamba.py ADDED
@@ -0,0 +1,873 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch FALCONMAMBA model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...cache_utils import MambaCache
28
+ from ...generation import GenerationMixin
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ logging,
36
+ )
37
+ from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
38
+ from .configuration_falcon_mamba import FalconMambaConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ if is_mambapy_available():
44
+ from mambapy.pscan import pscan
45
+ else:
46
+ pscan = None
47
+
48
+ if is_mamba_ssm_available():
49
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
50
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
51
+
52
+ from ...kernels.falcon_mamba import mamba_inner_fn
53
+ else:
54
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
55
+
56
+ if is_causal_conv1d_available():
57
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
58
+ else:
59
+ causal_conv1d_update, causal_conv1d_fn = None, None
60
+
61
+ is_fast_path_available = all(
62
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
63
+ )
64
+
65
+ _CHECKPOINT_FOR_DOC = "tiiuae/falcon-mamba-7b"
66
+ _CONFIG_FOR_DOC = "FalconMambaConfig"
67
+
68
+
69
+ def rms_forward(hidden_states, variance_epsilon=1e-6):
70
+ """
71
+ Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
72
+ leverage this in order to multiply the final result with the RMSNorm weight
73
+
74
+ Args:
75
+ hidden_states (`torch.Tensor`):
76
+ Hidden states to normalize
77
+ variance_epsilon (`float`):
78
+ The eps value to add in the square root scaling factor
79
+ """
80
+ input_dtype = hidden_states.dtype
81
+ hidden_states = hidden_states.to(torch.float32)
82
+
83
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
84
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
85
+ return hidden_states.to(input_dtype)
86
+
87
+
88
+ class FalconMambaMixer(nn.Module):
89
+ """
90
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
91
+ A, D are input independent (see FalconMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
92
+ ∆, B, C are input-dependent (this is a key difference between FalconMamba and the linear time invariant S4,
93
+ and is why FalconMamba is called **selective** state spaces)
94
+ """
95
+
96
+ def __init__(self, config: FalconMambaConfig, layer_idx: int):
97
+ super().__init__()
98
+ self.config = config
99
+ self.hidden_size = config.hidden_size
100
+ self.ssm_state_size = config.state_size
101
+ self.conv_kernel_size = config.conv_kernel
102
+ self.intermediate_size = config.intermediate_size
103
+ self.time_step_rank = int(config.time_step_rank)
104
+ self.layer_idx = layer_idx
105
+ self.use_conv_bias = config.use_conv_bias
106
+ self.conv1d = nn.Conv1d(
107
+ in_channels=self.intermediate_size,
108
+ out_channels=self.intermediate_size,
109
+ bias=config.use_conv_bias,
110
+ kernel_size=config.conv_kernel,
111
+ groups=self.intermediate_size,
112
+ padding=config.conv_kernel - 1,
113
+ )
114
+
115
+ self.activation = config.hidden_act
116
+ self.act = ACT2FN[config.hidden_act]
117
+
118
+ self.use_mambapy = config.use_mambapy
119
+
120
+ # projection of the input hidden states
121
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
122
+ # selective projection used to make dt, B and C input dependent
123
+ self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
124
+ # time step projection (discretization)
125
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
126
+
127
+ # S4D real initialization. These are not discretized!
128
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
129
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
130
+ A = A.expand(self.intermediate_size, -1).contiguous()
131
+
132
+ self.A_log = nn.Parameter(torch.log(A))
133
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
134
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
135
+ self.use_bias = config.use_bias
136
+
137
+ # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
138
+ self.register_buffer(
139
+ "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
140
+ )
141
+ self.register_buffer(
142
+ "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
143
+ )
144
+ self.rms_eps = config.mixer_rms_eps
145
+
146
+ if not is_fast_path_available:
147
+ if self.use_mambapy:
148
+ if is_mambapy_available():
149
+ logger.warning_once(
150
+ "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
151
+ " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
152
+ " https://github.com/Dao-AILab/causal-conv1d"
153
+ )
154
+ else:
155
+ raise ImportError(
156
+ "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
157
+ )
158
+ else:
159
+ logger.warning_once(
160
+ "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
161
+ " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
162
+ " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
163
+ )
164
+
165
+ def cuda_kernels_forward(
166
+ self,
167
+ hidden_states: torch.Tensor,
168
+ cache_params: Optional[MambaCache] = None,
169
+ cache_position: Optional[torch.LongTensor] = None,
170
+ attention_mask: Optional[torch.LongTensor] = None,
171
+ ):
172
+ # 1. Gated MLP's linear projection
173
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
174
+
175
+ if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
176
+ contextualized_states = mamba_inner_fn(
177
+ projected_states,
178
+ self.conv1d.weight,
179
+ self.conv1d.bias if self.use_conv_bias else None,
180
+ self.x_proj.weight,
181
+ self.dt_proj.weight,
182
+ self.out_proj.weight,
183
+ self.out_proj.bias.float() if self.use_bias else None,
184
+ -torch.exp(self.A_log.float()),
185
+ None, # input-dependent B
186
+ None, # input-dependent C
187
+ self.D.float(),
188
+ delta_bias=self.dt_proj.bias.float(),
189
+ delta_softplus=True,
190
+ b_rms_weight=self.b_c_rms,
191
+ c_rms_weight=self.b_c_rms,
192
+ dt_rms_weight=self.dt_rms,
193
+ b_c_dt_rms_eps=self.rms_eps,
194
+ )
195
+
196
+ else:
197
+ hidden_states, gate = projected_states.chunk(2, dim=1)
198
+
199
+ if attention_mask is not None:
200
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
201
+
202
+ # 2. Convolution sequence transformation
203
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
204
+ if cache_params is not None and cache_position[0] > 0:
205
+ hidden_states = causal_conv1d_update(
206
+ hidden_states.squeeze(-1),
207
+ cache_params.conv_states[self.layer_idx],
208
+ conv_weights,
209
+ self.conv1d.bias,
210
+ self.activation,
211
+ )
212
+ hidden_states = hidden_states.unsqueeze(-1)
213
+ else:
214
+ if cache_params is not None:
215
+ conv_states = nn.functional.pad(
216
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
217
+ )
218
+ cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
219
+ hidden_states = causal_conv1d_fn(
220
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
221
+ )
222
+
223
+ if attention_mask is not None:
224
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
225
+
226
+ # 3. State Space Model sequence transformation
227
+ # 3.a. input varying initialization of time_step, B and C
228
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
229
+ time_step, B, C = torch.split(
230
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
231
+ )
232
+
233
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
234
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
235
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
236
+
237
+ # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
238
+ # at the price of a small overhead.
239
+ if hasattr(self.config, "_pre_quantization_dtype"):
240
+ discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
241
+ else:
242
+ discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
243
+
244
+ A = -torch.exp(self.A_log.float())
245
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
246
+ time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
247
+ if cache_params is not None and cache_position[0] > 0:
248
+ scan_outputs = selective_state_update(
249
+ cache_params.ssm_states[self.layer_idx],
250
+ hidden_states[..., 0],
251
+ discrete_time_step[..., 0],
252
+ A,
253
+ B[:, 0],
254
+ C[:, 0],
255
+ self.D,
256
+ gate[..., 0],
257
+ time_proj_bias,
258
+ dt_softplus=True,
259
+ ).unsqueeze(-1)
260
+ else:
261
+ scan_outputs, ssm_state = selective_scan_fn(
262
+ hidden_states,
263
+ discrete_time_step,
264
+ A,
265
+ B.transpose(1, 2),
266
+ C.transpose(1, 2),
267
+ self.D.float(),
268
+ gate,
269
+ time_proj_bias,
270
+ delta_softplus=True,
271
+ return_last_state=True,
272
+ )
273
+ if ssm_state is not None and cache_params is not None:
274
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
275
+
276
+ # 4. Final linear projection
277
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
278
+ return contextualized_states
279
+
280
+ def slow_forward(
281
+ self,
282
+ input_states,
283
+ cache_params: Optional[MambaCache] = None,
284
+ cache_position: Optional[torch.LongTensor] = None,
285
+ attention_mask: Optional[torch.LongTensor] = None,
286
+ ):
287
+ batch_size, seq_len, _ = input_states.shape
288
+ dtype = input_states.dtype
289
+ # 1. Gated MLP's linear projection
290
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
291
+ hidden_states, gate = projected_states.chunk(2, dim=1)
292
+
293
+ if attention_mask is not None:
294
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
295
+
296
+ # 2. Convolution sequence transformation
297
+ if cache_params is not None:
298
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
299
+ ssm_state = ssm_state.to(hidden_states.device)
300
+ # use `cache_position.shape[0]` to check whether we are in prefill
301
+ # stage, it's equivalent to check `cache_position[0] == 0`, which
302
+ # breaks dynamo fullgraph constraints
303
+ if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
304
+ conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
305
+
306
+ cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
307
+ hidden_states = self.act(
308
+ self.conv1d(hidden_states)[..., :seq_len]
309
+ ) # [batch, intermediate_size, seq_len]
310
+ else:
311
+ conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
312
+ conv_state = conv_state.to(self.conv1d.weight.device)
313
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
314
+ if self.use_conv_bias:
315
+ hidden_states += self.conv1d.bias
316
+ hidden_states = (
317
+ self.act(hidden_states).to(dtype).unsqueeze(-1)
318
+ ) # [batch, intermediate_size, 1] : decoding
319
+ else:
320
+ ssm_state = torch.zeros(
321
+ (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
322
+ )
323
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
324
+
325
+ if attention_mask is not None:
326
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
327
+
328
+ # 3. State Space Model sequence transformation
329
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
330
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
331
+ time_step, B, C = torch.split(
332
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
333
+ )
334
+
335
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
336
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
337
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
338
+
339
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
340
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
341
+ 1, 2
342
+ ) # [batch, intermediate_size, seq_len]
343
+
344
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
345
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
346
+ discrete_A = torch.exp(
347
+ A[None, :, None, :] * discrete_time_step[:, :, :, None]
348
+ ) # [batch, intermediate_size, seq_len, ssm_state_size]
349
+ discrete_B = (
350
+ discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
351
+ ) # [batch, intermediate_size, seq_len, ssm_state_size]
352
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
353
+
354
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
355
+ if self.use_mambapy and self.training and cache_params is None:
356
+ hs = pscan(
357
+ discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
358
+ ) # [batch, seq_len, intermediate_size, ssm_state_size]
359
+ scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
360
+ scan_output = scan_output + hidden_states * self.D[None, :, None]
361
+ scan_output = scan_output * self.act(gate)
362
+ else:
363
+ scan_outputs = []
364
+ for i in range(seq_len):
365
+ ssm_state = (
366
+ discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
367
+ ) # [batch, intermediate_size, ssm_state]
368
+ scan_output = torch.matmul(
369
+ ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
370
+ ) # [batch, intermediate_size, 1]
371
+ scan_outputs.append(scan_output[:, :, 0])
372
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
373
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
374
+ scan_output = scan_output * self.act(gate)
375
+
376
+ if cache_params is not None:
377
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
378
+
379
+ # 4. Final linear projection
380
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
381
+ return contextualized_states
382
+
383
+ # Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward
384
+ def forward(
385
+ self,
386
+ hidden_states,
387
+ cache_params: Optional[MambaCache] = None,
388
+ cache_position: Optional[torch.LongTensor] = None,
389
+ attention_mask: Optional[torch.LongTensor] = None,
390
+ ):
391
+ if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
392
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
393
+ return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
394
+
395
+
396
+ # Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
397
+ class FalconMambaRMSNorm(nn.Module):
398
+ def __init__(self, hidden_size, eps=1e-6):
399
+ """
400
+ FalconMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
401
+ """
402
+ super().__init__()
403
+ self.weight = nn.Parameter(torch.ones(hidden_size))
404
+ self.variance_epsilon = eps
405
+
406
+ def extra_repr(self):
407
+ return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
408
+
409
+ # Ignore copy
410
+ def forward(self, hidden_states):
411
+ return self.weight.to(hidden_states.device) * rms_forward(
412
+ hidden_states, variance_epsilon=self.variance_epsilon
413
+ )
414
+
415
+
416
+ # Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
417
+ class FalconMambaBlock(nn.Module):
418
+ def __init__(self, config, layer_idx):
419
+ super().__init__()
420
+ self.config = config
421
+ self.layer_idx = layer_idx
422
+ self.residual_in_fp32 = config.residual_in_fp32
423
+ self.norm = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
424
+ self.mixer = FalconMambaMixer(config, layer_idx=layer_idx)
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states,
429
+ cache_params: Optional[MambaCache] = None,
430
+ cache_position: Optional[torch.LongTensor] = None,
431
+ attention_mask: Optional[torch.LongTensor] = None,
432
+ ):
433
+ residual = hidden_states
434
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
435
+ if self.residual_in_fp32:
436
+ residual = residual.to(torch.float32)
437
+
438
+ hidden_states = self.mixer(
439
+ hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
440
+ )
441
+ hidden_states = residual + hidden_states
442
+ return hidden_states
443
+
444
+
445
+ # Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba
446
+ class FalconMambaPreTrainedModel(PreTrainedModel):
447
+ """
448
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
449
+ models.
450
+ """
451
+
452
+ config_class = FalconMambaConfig
453
+ base_model_prefix = "backbone"
454
+ _no_split_modules = ["FalconMambaBlock", "FalconMambaMixer"]
455
+ supports_gradient_checkpointing = True
456
+ _is_stateful = True
457
+
458
+ def _init_weights(self, module):
459
+ """Initialize the weights."""
460
+ if isinstance(module, FalconMambaMixer):
461
+ module.A_log._no_weight_decay = True
462
+ module.D._no_weight_decay = True
463
+
464
+ dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
465
+ if self.config.time_step_init_scheme == "constant":
466
+ nn.init.constant_(module.dt_proj.weight, dt_init_std)
467
+ elif self.config.time_step_init_scheme == "random":
468
+ nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
469
+
470
+ dt = torch.exp(
471
+ torch.rand(self.config.intermediate_size)
472
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
473
+ + math.log(self.config.time_step_min)
474
+ ).clamp(min=self.config.time_step_floor)
475
+ # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
476
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
477
+ with torch.no_grad():
478
+ module.dt_proj.bias.copy_(inv_dt)
479
+ module.dt_proj.bias._no_reinit = True
480
+
481
+ if isinstance(module, nn.Linear):
482
+ if module.bias is not None:
483
+ if not getattr(module.bias, "_no_reinit", False):
484
+ nn.init.zeros_(module.bias)
485
+ elif isinstance(module, nn.Embedding):
486
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
487
+
488
+ if self.config.rescale_prenorm_residual:
489
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
490
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
491
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
492
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
493
+ #
494
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
495
+ for name, p in module.named_parameters():
496
+ if name in ["out_proj.weight"]:
497
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
498
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
499
+ # We need to reinit p since this code could be called multiple times
500
+ # Having just p *= scale would repeatedly scale it down
501
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
502
+ with torch.no_grad():
503
+ p /= math.sqrt(self.config.num_hidden_layers)
504
+
505
+
506
+ @dataclass
507
+ # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache
508
+ class FalconMambaOutput(ModelOutput):
509
+ """
510
+ Class for the FALCONMAMBA model outputs.
511
+
512
+ Args:
513
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
514
+ Sequence of hidden-states at the output of the last layer of the model.
515
+ cache_params (`MambaCache`):
516
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
517
+ avoid providing the old `input_ids`.
518
+
519
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
520
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
521
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
522
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
523
+
524
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
525
+ """
526
+
527
+ last_hidden_state: Optional[torch.FloatTensor] = None
528
+ cache_params: Optional[MambaCache] = None
529
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
530
+
531
+
532
+ @dataclass
533
+ # Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache
534
+ class FalconMambaCausalLMOutput(ModelOutput):
535
+ """
536
+ Base class for causal language model (or autoregressive) outputs.
537
+
538
+ Args:
539
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
540
+ Language modeling loss (for next-token prediction).
541
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
542
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
543
+ cache_params (`MambaCache`):
544
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
545
+ avoid providing the old `input_ids`.
546
+
547
+ Includes both the state-space model state matrices after the selective scan and the convolutional states.
548
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
549
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
550
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
551
+
552
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
553
+ """
554
+
555
+ loss: Optional[torch.FloatTensor] = None
556
+ logits: Optional[torch.FloatTensor] = None
557
+ cache_params: Optional[MambaCache] = None
558
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
559
+
560
+
561
+ FALCONMAMBA_START_DOCSTRING = r"""
562
+
563
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
564
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
565
+ etc.)
566
+
567
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
568
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
569
+ and behavior.
570
+
571
+ Parameters:
572
+ config ([`FalconMambaConfig`]): Model configuration class with all the parameters of the model.
573
+ Initializing with a config file does not load the weights associated with the model, only the
574
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
575
+ """
576
+
577
+ FALCONMAMBA_INPUTS_DOCSTRING = r"""
578
+ Args:
579
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
580
+ Indices of input sequence tokens in the vocabulary.
581
+
582
+ If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
583
+ `input_ids`.
584
+
585
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
586
+ [`PreTrainedTokenizer.__call__`] for details.
587
+
588
+ [What are input IDs?](../glossary#input-ids)
589
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
590
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
591
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
592
+ model's internal embedding lookup matrix.
593
+ cache_params (`MambaCache`, *optional*):
594
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
595
+ `input_ids` provided as if the model added `state_input_ids + input_ids` as context).
596
+ use_cache (`bool`, *optional*):
597
+ If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
598
+ output_hidden_states (`bool`, *optional*):
599
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
600
+ more detail.
601
+ return_dict (`bool`, *optional*):
602
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
603
+ """
604
+
605
+
606
+ @add_start_docstrings(
607
+ "The bare FALCONMAMBA Model transformer outputting raw hidden-states without any specific head on top.",
608
+ FALCONMAMBA_START_DOCSTRING,
609
+ )
610
+ class FalconMambaModel(FalconMambaPreTrainedModel):
611
+ def __init__(self, config):
612
+ super().__init__(config)
613
+
614
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
615
+ self.layers = nn.ModuleList(
616
+ [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
617
+ )
618
+
619
+ self.gradient_checkpointing = False
620
+ self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
621
+ # Initialize weights and apply final processing
622
+ self.post_init()
623
+
624
+ def get_input_embeddings(self):
625
+ return self.embeddings
626
+
627
+ def set_input_embeddings(self, new_embeddings):
628
+ self.embeddings = new_embeddings
629
+
630
+ @add_start_docstrings_to_model_forward(FALCONMAMBA_INPUTS_DOCSTRING)
631
+ @add_code_sample_docstrings(
632
+ checkpoint=_CHECKPOINT_FOR_DOC,
633
+ output_type=FalconMambaOutput,
634
+ config_class=_CONFIG_FOR_DOC,
635
+ )
636
+ def forward(
637
+ self,
638
+ input_ids: Optional[torch.LongTensor] = None,
639
+ inputs_embeds: Optional[torch.LongTensor] = None,
640
+ cache_params: Optional[MambaCache] = None,
641
+ use_cache: Optional[bool] = None,
642
+ output_hidden_states: Optional[bool] = None,
643
+ return_dict: Optional[bool] = None,
644
+ cache_position: Optional[torch.LongTensor] = None,
645
+ attention_mask: Optional[torch.LongTensor] = None,
646
+ ) -> Union[Tuple, FalconMambaOutput]:
647
+ output_hidden_states = (
648
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
649
+ )
650
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
651
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
652
+
653
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
654
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
655
+
656
+ if inputs_embeds is None:
657
+ inputs_embeds = self.embeddings(input_ids)
658
+
659
+ if self.gradient_checkpointing and self.training and use_cache:
660
+ use_cache = False
661
+
662
+ if use_cache:
663
+ if cache_params is None:
664
+ cache_params = MambaCache(
665
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
666
+ )
667
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
668
+ elif cache_position is None:
669
+ # cases when we do manual forward instead of using `model.generate` which will initiate
670
+ # `cache_position` and makes sure it is not None, throw error here instead of doing some
671
+ # hack to conjecture the current cache position
672
+ raise ValueError(
673
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
674
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
675
+ "be initialized for you automatically"
676
+ )
677
+ else:
678
+ cache_params = None
679
+ hidden_states = inputs_embeds
680
+ all_hidden_states = () if output_hidden_states else None
681
+ for mixer_block in self.layers:
682
+ if self.gradient_checkpointing and self.training:
683
+ hidden_states = self._gradient_checkpointing_func(
684
+ mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
685
+ )
686
+ else:
687
+ hidden_states = mixer_block(
688
+ hidden_states,
689
+ cache_params=cache_params,
690
+ cache_position=cache_position,
691
+ attention_mask=attention_mask,
692
+ )
693
+
694
+ if output_hidden_states:
695
+ all_hidden_states = all_hidden_states + (hidden_states,)
696
+
697
+ hidden_states = self.norm_f(hidden_states)
698
+
699
+ if output_hidden_states:
700
+ all_hidden_states = all_hidden_states + (hidden_states,)
701
+
702
+ if not return_dict:
703
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
704
+
705
+ return FalconMambaOutput(
706
+ last_hidden_state=hidden_states,
707
+ cache_params=cache_params if use_cache else None,
708
+ hidden_states=all_hidden_states,
709
+ )
710
+
711
+
712
+ @add_start_docstrings(
713
+ """
714
+ The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
715
+ embeddings).
716
+ """,
717
+ FALCONMAMBA_START_DOCSTRING,
718
+ )
719
+ # Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache
720
+ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
721
+ _tied_weights_keys = ["lm_head.weight"]
722
+
723
+ def __init__(self, config):
724
+ super().__init__(config)
725
+ self.backbone = FalconMambaModel(config)
726
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
727
+ # Initialize weights and apply final processing
728
+ self.post_init()
729
+
730
+ def get_output_embeddings(self):
731
+ return self.lm_head
732
+
733
+ def set_output_embeddings(self, new_embeddings):
734
+ self.lm_head = new_embeddings
735
+
736
+ def get_input_embeddings(self):
737
+ return self.backbone.get_input_embeddings()
738
+
739
+ def set_input_embeddings(self, new_embeddings):
740
+ return self.backbone.set_input_embeddings(new_embeddings)
741
+
742
+ def _update_model_kwargs_for_generation(
743
+ self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
744
+ ) -> Dict[str, Any]:
745
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
746
+ if (
747
+ model_kwargs.get("use_cache", True)
748
+ and "cache_position" in model_kwargs
749
+ and model_kwargs["cache_position"] is not None
750
+ ):
751
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
752
+
753
+ if "attention_mask" in model_kwargs:
754
+ attention_mask = model_kwargs["attention_mask"]
755
+ model_kwargs["attention_mask"] = torch.cat(
756
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
757
+ )
758
+
759
+ return model_kwargs
760
+
761
+ def prepare_inputs_for_generation(
762
+ self,
763
+ input_ids,
764
+ inputs_embeds=None,
765
+ use_cache=None,
766
+ cache_params: Optional[MambaCache] = None,
767
+ cache_position: Optional[torch.LongTensor] = None,
768
+ attention_mask: Optional[torch.LongTensor] = None,
769
+ **kwargs,
770
+ ):
771
+ # Overwritten -- uses `cache_params` as opposed to `past_key_values`
772
+
773
+ if use_cache:
774
+ # `cache_position` should have been initialized in `generate`
775
+ if cache_position is None:
776
+ raise ValueError(
777
+ "`cache_position` should not be None as it should have been initialized in "
778
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
779
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
780
+ )
781
+ if cache_position[0] > 0:
782
+ input_ids = input_ids[:, -1].unsqueeze(-1)
783
+
784
+ if attention_mask is not None:
785
+ attention_mask = None
786
+
787
+ else:
788
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
789
+ # considering padding will be applied when input length is shorter, and truncation
790
+ # will be applied when it is longer, so it will be equivalent to always have it match
791
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
792
+ cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
793
+
794
+ if inputs_embeds is not None and cache_params is None:
795
+ model_inputs = {"inputs_embeds": inputs_embeds}
796
+ else:
797
+ model_inputs = {"input_ids": input_ids.contiguous()}
798
+
799
+ model_inputs.update(
800
+ {
801
+ "cache_params": cache_params,
802
+ "use_cache": use_cache,
803
+ "cache_position": cache_position,
804
+ "attention_mask": attention_mask,
805
+ }
806
+ )
807
+ return model_inputs
808
+
809
+ @add_start_docstrings_to_model_forward(FALCONMAMBA_INPUTS_DOCSTRING)
810
+ @add_code_sample_docstrings(
811
+ checkpoint=_CHECKPOINT_FOR_DOC,
812
+ output_type=FalconMambaCausalLMOutput,
813
+ config_class=_CONFIG_FOR_DOC,
814
+ )
815
+ def forward(
816
+ self,
817
+ input_ids: Optional[torch.LongTensor] = None,
818
+ attention_mask: Optional[torch.LongTensor] = None,
819
+ inputs_embeds: Optional[torch.FloatTensor] = None,
820
+ cache_params: Optional[MambaCache] = None,
821
+ labels: Optional[torch.LongTensor] = None,
822
+ output_hidden_states: Optional[bool] = None,
823
+ return_dict: Optional[bool] = None,
824
+ use_cache: Optional[bool] = None,
825
+ cache_position: Optional[torch.Tensor] = None,
826
+ **kwargs, # for now we need this for generation
827
+ ) -> Union[Tuple, FalconMambaCausalLMOutput]:
828
+ r"""
829
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
830
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
831
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
832
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
833
+ """
834
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
835
+
836
+ falcon_mamba_outputs = self.backbone(
837
+ input_ids,
838
+ cache_params=cache_params,
839
+ inputs_embeds=inputs_embeds,
840
+ output_hidden_states=output_hidden_states,
841
+ return_dict=return_dict,
842
+ use_cache=use_cache,
843
+ cache_position=cache_position,
844
+ attention_mask=attention_mask,
845
+ )
846
+ hidden_states = falcon_mamba_outputs[0]
847
+
848
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
849
+
850
+ loss = None
851
+ if labels is not None:
852
+ # move labels to correct device to enable model parallelism
853
+ labels = labels.to(logits.device)
854
+ # Shift so that tokens < n predict n
855
+ shift_logits = logits[..., :-1, :].contiguous()
856
+ shift_labels = labels[..., 1:].contiguous()
857
+ # Flatten the tokens
858
+ loss_fct = CrossEntropyLoss()
859
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
860
+
861
+ if not return_dict:
862
+ output = (logits,) + falcon_mamba_outputs[1:]
863
+ return ((loss,) + output) if loss is not None else output
864
+
865
+ return FalconMambaCausalLMOutput(
866
+ loss=loss,
867
+ logits=logits,
868
+ cache_params=falcon_mamba_outputs.cache_params,
869
+ hidden_states=falcon_mamba_outputs.hidden_states,
870
+ )
871
+
872
+
873
+ __all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"]
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_fastspeech2_conformer import *
22
+ from .modeling_fastspeech2_conformer import *
23
+ from .tokenization_fastspeech2_conformer import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
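A brief, hedged note on the lazy-import wiring above: consumers never touch `_LazyModule` directly; the exported names simply resolve on first access.

```python
# The heavy modeling/tokenization modules are imported only when one of the
# exported names is actually accessed.
from transformers.models.fastspeech2_conformer import FastSpeech2ConformerConfig

print(FastSpeech2ConformerConfig().model_type)  # "fastspeech2_conformer"
```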
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py ADDED
@@ -0,0 +1,480 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """FastSpeech2Conformer model configuration"""
16
+
17
+ from typing import Dict
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class FastSpeech2ConformerConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`FastSpeech2ConformerModel`]. It is used to
29
+ instantiate a FastSpeech2Conformer model according to the specified arguments, defining the model architecture.
30
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
31
+ FastSpeech2Conformer [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer)
32
+ architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+ Args:
38
+ hidden_size (`int`, *optional*, defaults to 384):
39
+ The dimensionality of the hidden layers.
40
+ vocab_size (`int`, *optional*, defaults to 78):
41
+ The size of the vocabulary.
42
+ num_mel_bins (`int`, *optional*, defaults to 80):
43
+ The number of mel filters used in the filter bank.
44
+ encoder_num_attention_heads (`int`, *optional*, defaults to 2):
45
+ The number of attention heads in the encoder.
46
+ encoder_layers (`int`, *optional*, defaults to 4):
47
+ The number of layers in the encoder.
48
+ encoder_linear_units (`int`, *optional*, defaults to 1536):
49
+ The number of units in the linear layer of the encoder.
50
+ decoder_layers (`int`, *optional*, defaults to 4):
51
+ The number of layers in the decoder.
52
+ decoder_num_attention_heads (`int`, *optional*, defaults to 2):
53
+ The number of attention heads in the decoder.
54
+ decoder_linear_units (`int`, *optional*, defaults to 1536):
55
+ The number of units in the linear layer of the decoder.
56
+ speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
57
+ The number of layers in the post-net of the speech decoder.
58
+ speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
59
+ The number of units in the post-net layers of the speech decoder.
60
+ speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
61
+ The kernel size in the post-net of the speech decoder.
62
+ positionwise_conv_kernel_size (`int`, *optional*, defaults to 3):
63
+ The size of the convolution kernel used in the position-wise layer.
64
+ encoder_normalize_before (`bool`, *optional*, defaults to `False`):
65
+ Specifies whether to normalize before encoder layers.
66
+ decoder_normalize_before (`bool`, *optional*, defaults to `False`):
67
+ Specifies whether to normalize before decoder layers.
68
+ encoder_concat_after (`bool`, *optional*, defaults to `False`):
69
+ Specifies whether to concatenate after encoder layers.
70
+ decoder_concat_after (`bool`, *optional*, defaults to `False`):
71
+ Specifies whether to concatenate after decoder layers.
72
+ reduction_factor (`int`, *optional*, defaults to 1):
73
+ The factor by which the speech frame rate is reduced.
74
+ speaking_speed (`float`, *optional*, defaults to 1.0):
75
+ The speed of the speech produced.
76
+ use_macaron_style_in_conformer (`bool`, *optional*, defaults to `True`):
77
+ Specifies whether to use macaron style in the conformer.
78
+ use_cnn_in_conformer (`bool`, *optional*, defaults to `True`):
79
+ Specifies whether to use convolutional neural networks in the conformer.
80
+ encoder_kernel_size (`int`, *optional*, defaults to 7):
81
+ The kernel size used in the encoder.
82
+ decoder_kernel_size (`int`, *optional*, defaults to 31):
83
+ The kernel size used in the decoder.
84
+ duration_predictor_layers (`int`, *optional*, defaults to 2):
85
+ The number of layers in the duration predictor.
86
+ duration_predictor_channels (`int`, *optional*, defaults to 256):
87
+ The number of channels in the duration predictor.
88
+ duration_predictor_kernel_size (`int`, *optional*, defaults to 3):
89
+ The kernel size used in the duration predictor.
90
+ energy_predictor_layers (`int`, *optional*, defaults to 2):
91
+ The number of layers in the energy predictor.
92
+ energy_predictor_channels (`int`, *optional*, defaults to 256):
93
+ The number of channels in the energy predictor.
94
+ energy_predictor_kernel_size (`int`, *optional*, defaults to 3):
95
+ The kernel size used in the energy predictor.
96
+ energy_predictor_dropout (`float`, *optional*, defaults to 0.5):
97
+ The dropout rate in the energy predictor.
98
+ energy_embed_kernel_size (`int`, *optional*, defaults to 1):
99
+ The kernel size used in the energy embed layer.
100
+ energy_embed_dropout (`float`, *optional*, defaults to 0.0):
101
+ The dropout rate in the energy embed layer.
102
+ stop_gradient_from_energy_predictor (`bool`, *optional*, defaults to `False`):
103
+ Specifies whether to stop gradients from the energy predictor.
104
+ pitch_predictor_layers (`int`, *optional*, defaults to 5):
105
+ The number of layers in the pitch predictor.
106
+ pitch_predictor_channels (`int`, *optional*, defaults to 256):
107
+ The number of channels in the pitch predictor.
108
+ pitch_predictor_kernel_size (`int`, *optional*, defaults to 5):
109
+ The kernel size used in the pitch predictor.
110
+ pitch_predictor_dropout (`float`, *optional*, defaults to 0.5):
111
+ The dropout rate in the pitch predictor.
112
+ pitch_embed_kernel_size (`int`, *optional*, defaults to 1):
113
+ The kernel size used in the pitch embed layer.
114
+ pitch_embed_dropout (`float`, *optional*, defaults to 0.0):
115
+ The dropout rate in the pitch embed layer.
116
+ stop_gradient_from_pitch_predictor (`bool`, *optional*, defaults to `True`):
117
+ Specifies whether to stop gradients from the pitch predictor.
118
+ encoder_dropout_rate (`float`, *optional*, defaults to 0.2):
119
+ The dropout rate in the encoder.
120
+ encoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
121
+ The positional dropout rate in the encoder.
122
+ encoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
123
+ The attention dropout rate in the encoder.
124
+ decoder_dropout_rate (`float`, *optional*, defaults to 0.2):
125
+ The dropout rate in the decoder.
126
+ decoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
127
+ The positional dropout rate in the decoder.
128
+ decoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
129
+ The attention dropout rate in the decoder.
130
+ duration_predictor_dropout_rate (`float`, *optional*, defaults to 0.2):
131
+ The dropout rate in the duration predictor.
132
+ speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
133
+ The dropout rate in the speech decoder postnet.
134
+ max_source_positions (`int`, *optional*, defaults to 5000):
135
+ if `"relative"` position embeddings are used, defines the maximum source input positions.
136
+ use_masking (`bool`, *optional*, defaults to `True`):
137
+ Specifies whether to use masking in the model.
138
+ use_weighted_masking (`bool`, *optional*, defaults to `False`):
139
+ Specifies whether to use weighted masking in the model.
140
+ num_speakers (`int`, *optional*):
141
+ Number of speakers. If set to > 1, assume that the speaker ids will be provided as the input and use
142
+ speaker id embedding layer.
143
+ num_languages (`int`, *optional*):
144
+ Number of languages. If set to > 1, assume that the language ids will be provided as the input and use the
145
+ language id embedding layer.
146
+ speaker_embed_dim (`int`, *optional*):
147
+ Speaker embedding dimension. If set to > 0, assume that speaker_embedding will be provided as the input.
148
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
149
+ Specifies whether the model is an encoder-decoder.
150
+
151
+ Example:
152
+
153
+ ```python
154
+ >>> from transformers import FastSpeech2ConformerModel, FastSpeech2ConformerConfig
155
+
156
+ >>> # Initializing a FastSpeech2Conformer style configuration
157
+ >>> configuration = FastSpeech2ConformerConfig()
158
+
159
+ >>> # Initializing a model from the FastSpeech2Conformer style configuration
160
+ >>> model = FastSpeech2ConformerModel(configuration)
161
+
162
+ >>> # Accessing the model configuration
163
+ >>> configuration = model.config
164
+ ```"""
165
+
166
+ model_type = "fastspeech2_conformer"
167
+ base_config_key = "model_config"
168
+ attribute_map = {"num_hidden_layers": "encoder_layers", "num_attention_heads": "encoder_num_attention_heads"}
169
+
170
+ def __init__(
171
+ self,
172
+ hidden_size=384,
173
+ vocab_size=78,
174
+ num_mel_bins=80,
175
+ encoder_num_attention_heads=2,
176
+ encoder_layers=4,
177
+ encoder_linear_units=1536,
178
+ decoder_layers=4,
179
+ decoder_num_attention_heads=2,
180
+ decoder_linear_units=1536,
181
+ speech_decoder_postnet_layers=5,
182
+ speech_decoder_postnet_units=256,
183
+ speech_decoder_postnet_kernel=5,
184
+ positionwise_conv_kernel_size=3,
185
+ encoder_normalize_before=False,
186
+ decoder_normalize_before=False,
187
+ encoder_concat_after=False,
188
+ decoder_concat_after=False,
189
+ reduction_factor=1,
190
+ speaking_speed=1.0,
191
+ use_macaron_style_in_conformer=True,
192
+ use_cnn_in_conformer=True,
193
+ encoder_kernel_size=7,
194
+ decoder_kernel_size=31,
195
+ duration_predictor_layers=2,
196
+ duration_predictor_channels=256,
197
+ duration_predictor_kernel_size=3,
198
+ energy_predictor_layers=2,
199
+ energy_predictor_channels=256,
200
+ energy_predictor_kernel_size=3,
201
+ energy_predictor_dropout=0.5,
202
+ energy_embed_kernel_size=1,
203
+ energy_embed_dropout=0.0,
204
+ stop_gradient_from_energy_predictor=False,
205
+ pitch_predictor_layers=5,
206
+ pitch_predictor_channels=256,
207
+ pitch_predictor_kernel_size=5,
208
+ pitch_predictor_dropout=0.5,
209
+ pitch_embed_kernel_size=1,
210
+ pitch_embed_dropout=0.0,
211
+ stop_gradient_from_pitch_predictor=True,
212
+ encoder_dropout_rate=0.2,
213
+ encoder_positional_dropout_rate=0.2,
214
+ encoder_attention_dropout_rate=0.2,
215
+ decoder_dropout_rate=0.2,
216
+ decoder_positional_dropout_rate=0.2,
217
+ decoder_attention_dropout_rate=0.2,
218
+ duration_predictor_dropout_rate=0.2,
219
+ speech_decoder_postnet_dropout=0.5,
220
+ max_source_positions=5000,
221
+ use_masking=True,
222
+ use_weighted_masking=False,
223
+ num_speakers=None,
224
+ num_languages=None,
225
+ speaker_embed_dim=None,
226
+ is_encoder_decoder=True,
227
+ **kwargs,
228
+ ):
229
+ if positionwise_conv_kernel_size % 2 == 0:
230
+ raise ValueError(
231
+ f"positionwise_conv_kernel_size must be odd, but got {positionwise_conv_kernel_size} instead."
232
+ )
233
+ if encoder_kernel_size % 2 == 0:
234
+ raise ValueError(f"encoder_kernel_size must be odd, but got {encoder_kernel_size} instead.")
235
+ if decoder_kernel_size % 2 == 0:
236
+ raise ValueError(f"decoder_kernel_size must be odd, but got {decoder_kernel_size} instead.")
237
+ if duration_predictor_kernel_size % 2 == 0:
238
+ raise ValueError(
239
+ f"duration_predictor_kernel_size must be odd, but got {duration_predictor_kernel_size} instead."
240
+ )
241
+ if energy_predictor_kernel_size % 2 == 0:
242
+ raise ValueError(
243
+ f"energy_predictor_kernel_size must be odd, but got {energy_predictor_kernel_size} instead."
244
+ )
245
+ if energy_embed_kernel_size % 2 == 0:
246
+ raise ValueError(f"energy_embed_kernel_size must be odd, but got {energy_embed_kernel_size} instead.")
247
+ if pitch_predictor_kernel_size % 2 == 0:
248
+ raise ValueError(
249
+ f"pitch_predictor_kernel_size must be odd, but got {pitch_predictor_kernel_size} instead."
250
+ )
251
+ if pitch_embed_kernel_size % 2 == 0:
252
+ raise ValueError(f"pitch_embed_kernel_size must be odd, but got {pitch_embed_kernel_size} instead.")
253
+ if hidden_size % encoder_num_attention_heads != 0:
254
+ raise ValueError("The hidden_size must be evenly divisible by encoder_num_attention_heads.")
255
+ if hidden_size % decoder_num_attention_heads != 0:
256
+ raise ValueError("The hidden_size must be evenly divisible by decoder_num_attention_heads.")
257
+ if use_masking and use_weighted_masking:
258
+ raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.")
259
+
260
+ self.hidden_size = hidden_size
261
+ self.vocab_size = vocab_size
262
+ self.num_mel_bins = num_mel_bins
263
+ self.encoder_config = {
264
+ "num_attention_heads": encoder_num_attention_heads,
265
+ "layers": encoder_layers,
266
+ "kernel_size": encoder_kernel_size,
267
+ "attention_dropout_rate": encoder_attention_dropout_rate,
268
+ "dropout_rate": encoder_dropout_rate,
269
+ "positional_dropout_rate": encoder_positional_dropout_rate,
270
+ "linear_units": encoder_linear_units,
271
+ "normalize_before": encoder_normalize_before,
272
+ "concat_after": encoder_concat_after,
273
+ }
274
+ self.decoder_config = {
275
+ "num_attention_heads": decoder_num_attention_heads,
276
+ "layers": decoder_layers,
277
+ "kernel_size": decoder_kernel_size,
278
+ "attention_dropout_rate": decoder_attention_dropout_rate,
279
+ "dropout_rate": decoder_dropout_rate,
280
+ "positional_dropout_rate": decoder_positional_dropout_rate,
281
+ "linear_units": decoder_linear_units,
282
+ "normalize_before": decoder_normalize_before,
283
+ "concat_after": decoder_concat_after,
284
+ }
285
+ self.encoder_num_attention_heads = encoder_num_attention_heads
286
+ self.encoder_layers = encoder_layers
287
+ self.duration_predictor_channels = duration_predictor_channels
288
+ self.duration_predictor_kernel_size = duration_predictor_kernel_size
289
+ self.duration_predictor_layers = duration_predictor_layers
290
+ self.energy_embed_dropout = energy_embed_dropout
291
+ self.energy_embed_kernel_size = energy_embed_kernel_size
292
+ self.energy_predictor_channels = energy_predictor_channels
293
+ self.energy_predictor_dropout = energy_predictor_dropout
294
+ self.energy_predictor_kernel_size = energy_predictor_kernel_size
295
+ self.energy_predictor_layers = energy_predictor_layers
296
+ self.pitch_embed_dropout = pitch_embed_dropout
297
+ self.pitch_embed_kernel_size = pitch_embed_kernel_size
298
+ self.pitch_predictor_channels = pitch_predictor_channels
299
+ self.pitch_predictor_dropout = pitch_predictor_dropout
300
+ self.pitch_predictor_kernel_size = pitch_predictor_kernel_size
301
+ self.pitch_predictor_layers = pitch_predictor_layers
302
+ self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
303
+ self.speech_decoder_postnet_units = speech_decoder_postnet_units
304
+ self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout
305
+ self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel
306
+ self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
307
+ self.reduction_factor = reduction_factor
308
+ self.speaking_speed = speaking_speed
309
+ self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
310
+ self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
311
+ self.max_source_positions = max_source_positions
312
+ self.use_cnn_in_conformer = use_cnn_in_conformer
313
+ self.use_macaron_style_in_conformer = use_macaron_style_in_conformer
314
+ self.use_masking = use_masking
315
+ self.use_weighted_masking = use_weighted_masking
316
+ self.num_speakers = num_speakers
317
+ self.num_languages = num_languages
318
+ self.speaker_embed_dim = speaker_embed_dim
319
+ self.duration_predictor_dropout_rate = duration_predictor_dropout_rate
320
+ self.is_encoder_decoder = is_encoder_decoder
321
+
322
+ super().__init__(
323
+ is_encoder_decoder=is_encoder_decoder,
324
+ **kwargs,
325
+ )
326
+
327
+
328
+ class FastSpeech2ConformerHifiGanConfig(PretrainedConfig):
329
+ r"""
330
+ This is the configuration class to store the configuration of a [`FastSpeech2ConformerHifiGanModel`]. It is used to
331
+ instantiate a FastSpeech2Conformer HiFi-GAN vocoder model according to the specified arguments, defining the model
332
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
333
+ FastSpeech2Conformer
334
+ [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architecture.
335
+
336
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
337
+ documentation from [`PretrainedConfig`] for more information.
338
+
339
+ Args:
340
+ model_in_dim (`int`, *optional*, defaults to 80):
341
+ The number of frequency bins in the input log-mel spectrogram.
342
+ upsample_initial_channel (`int`, *optional*, defaults to 512):
343
+ The number of input channels into the upsampling network.
344
+ upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 2, 2]`):
345
+ A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The
346
+ length of *upsample_rates* defines the number of convolutional layers and has to match the length of
347
+ *upsample_kernel_sizes*.
348
+ upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[16, 16, 4, 4]`):
349
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The
350
+ length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of
351
+ *upsample_rates*.
352
+ resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
353
+ A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field
354
+ fusion (MRF) module.
355
+ resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
356
+ A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
357
+ multi-receptive field fusion (MRF) module.
358
+ initializer_range (`float`, *optional*, defaults to 0.01):
359
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
360
+ leaky_relu_slope (`float`, *optional*, defaults to 0.1):
361
+ The angle of the negative slope used by the leaky ReLU activation.
362
+ normalize_before (`bool`, *optional*, defaults to `True`):
363
+ Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance.
364
+
365
+ Example:
366
+
367
+ ```python
368
+ >>> from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig
369
+
370
+ >>> # Initializing a FastSpeech2ConformerHifiGan configuration
371
+ >>> configuration = FastSpeech2ConformerHifiGanConfig()
372
+
373
+ >>> # Initializing a model (with random weights) from the configuration
374
+ >>> model = FastSpeech2ConformerHifiGan(configuration)
375
+
376
+ >>> # Accessing the model configuration
377
+ >>> configuration = model.config
378
+ ```"""
379
+
380
+ model_type = "hifigan"
381
+ base_config_key = "vocoder_config"
382
+
383
+ def __init__(
384
+ self,
385
+ model_in_dim=80,
386
+ upsample_initial_channel=512,
387
+ upsample_rates=[8, 8, 2, 2],
388
+ upsample_kernel_sizes=[16, 16, 4, 4],
389
+ resblock_kernel_sizes=[3, 7, 11],
390
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
391
+ initializer_range=0.01,
392
+ leaky_relu_slope=0.1,
393
+ normalize_before=True,
394
+ **kwargs,
395
+ ):
396
+ self.model_in_dim = model_in_dim
397
+ self.upsample_initial_channel = upsample_initial_channel
398
+ self.upsample_rates = upsample_rates
399
+ self.upsample_kernel_sizes = upsample_kernel_sizes
400
+ self.resblock_kernel_sizes = resblock_kernel_sizes
401
+ self.resblock_dilation_sizes = resblock_dilation_sizes
402
+ self.initializer_range = initializer_range
403
+ self.leaky_relu_slope = leaky_relu_slope
404
+ self.normalize_before = normalize_before
405
+ super().__init__(**kwargs)
406
+
407
+
408
+ class FastSpeech2ConformerWithHifiGanConfig(PretrainedConfig):
409
+ """
410
+ This is the configuration class to store the configuration of a [`FastSpeech2ConformerWithHifiGan`]. It is used to
411
+ instantiate a `FastSpeech2ConformerWithHifiGanModel` model according to the specified sub-models configurations,
412
+ defining the model architecture.
413
+
414
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
415
+ FastSpeech2ConformerModel [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer) and
416
+ FastSpeech2ConformerHifiGan
417
+ [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architectures.
418
+
419
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
420
+ documentation from [`PretrainedConfig`] for more information.
421
+
422
+ Args:
423
+ model_config (`typing.Dict`, *optional*):
424
+ Configuration of the text-to-speech model (see [`FastSpeech2ConformerConfig`]).
425
+ vocoder_config (`typing.Dict`, *optional*):
426
+ Configuration of the vocoder model (see [`FastSpeech2ConformerHifiGanConfig`]).
431
+
432
+ Example:
433
+
434
+ ```python
435
+ >>> from transformers import (
436
+ ... FastSpeech2ConformerConfig,
437
+ ... FastSpeech2ConformerHifiGanConfig,
438
+ ... FastSpeech2ConformerWithHifiGanConfig,
439
+ ... FastSpeech2ConformerWithHifiGan,
440
+ ... )
441
+
442
+ >>> # Initializing FastSpeech2ConformerWithHifiGan sub-modules configurations.
443
+ >>> model_config = FastSpeech2ConformerConfig()
444
+ >>> vocoder_config = FastSpeech2ConformerHifiGanConfig()
445
+
446
+ >>> # Initializing a FastSpeech2ConformerWithHifiGan module style configuration
447
+ >>> configuration = FastSpeech2ConformerWithHifiGanConfig(model_config.to_dict(), vocoder_config.to_dict())
448
+
449
+ >>> # Initializing a model (with random weights)
450
+ >>> model = FastSpeech2ConformerWithHifiGan(configuration)
451
+
452
+ >>> # Accessing the model configuration
453
+ >>> configuration = model.config
454
+ ```
455
+ """
456
+
457
+ model_type = "fastspeech2_conformer_with_hifigan"
458
+ sub_configs = {"model_config": FastSpeech2ConformerConfig, "vocoder_config": FastSpeech2ConformerHifiGanConfig}
459
+
460
+ def __init__(
461
+ self,
462
+ model_config: Dict = None,
463
+ vocoder_config: Dict = None,
464
+ **kwargs,
465
+ ):
466
+ if model_config is None:
467
+ model_config = {}
468
+ logger.info("model_config is None. initializing the model with default values.")
469
+
470
+ if vocoder_config is None:
471
+ vocoder_config = {}
472
+ logger.info("vocoder_config is None. initializing the coarse model with default values.")
473
+
474
+ self.model_config = FastSpeech2ConformerConfig(**model_config)
475
+ self.vocoder_config = FastSpeech2ConformerHifiGanConfig(**vocoder_config)
476
+
477
+ super().__init__(**kwargs)
478
+
479
+
480
+ __all__ = ["FastSpeech2ConformerConfig", "FastSpeech2ConformerHifiGanConfig", "FastSpeech2ConformerWithHifiGanConfig"]
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,210 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert FastSpeech2Conformer checkpoint."""
16
+
17
+ import argparse
18
+ import json
19
+ import re
20
+ from pathlib import Path
21
+ from tempfile import TemporaryDirectory
22
+
23
+ import torch
24
+ import yaml
25
+
26
+ from transformers import (
27
+ FastSpeech2ConformerConfig,
28
+ FastSpeech2ConformerModel,
29
+ FastSpeech2ConformerTokenizer,
30
+ logging,
31
+ )
32
+
33
+
34
+ logging.set_verbosity_info()
35
+ logger = logging.get_logger("transformers.models.FastSpeech2Conformer")
36
+
37
+ CONFIG_MAPPING = {
38
+ "adim": "hidden_size",
39
+ "aheads": "num_attention_heads",
40
+ "conformer_dec_kernel_size": "decoder_kernel_size",
41
+ "conformer_enc_kernel_size": "encoder_kernel_size",
42
+ "decoder_normalize_before": "decoder_normalize_before",
43
+ "dlayers": "decoder_layers",
44
+ "dunits": "decoder_linear_units",
45
+ "duration_predictor_chans": "duration_predictor_channels",
46
+ "duration_predictor_kernel_size": "duration_predictor_kernel_size",
47
+ "duration_predictor_layers": "duration_predictor_layers",
48
+ "elayers": "encoder_layers",
49
+ "encoder_normalize_before": "encoder_normalize_before",
50
+ "energy_embed_dropout": "energy_embed_dropout",
51
+ "energy_embed_kernel_size": "energy_embed_kernel_size",
52
+ "energy_predictor_chans": "energy_predictor_channels",
53
+ "energy_predictor_dropout": "energy_predictor_dropout",
54
+ "energy_predictor_kernel_size": "energy_predictor_kernel_size",
55
+ "energy_predictor_layers": "energy_predictor_layers",
56
+ "eunits": "encoder_linear_units",
57
+ "pitch_embed_dropout": "pitch_embed_dropout",
58
+ "pitch_embed_kernel_size": "pitch_embed_kernel_size",
59
+ "pitch_predictor_chans": "pitch_predictor_channels",
60
+ "pitch_predictor_dropout": "pitch_predictor_dropout",
61
+ "pitch_predictor_kernel_size": "pitch_predictor_kernel_size",
62
+ "pitch_predictor_layers": "pitch_predictor_layers",
63
+ "positionwise_conv_kernel_size": "positionwise_conv_kernel_size",
64
+ "postnet_chans": "speech_decoder_postnet_units",
65
+ "postnet_filts": "speech_decoder_postnet_kernel",
66
+ "postnet_layers": "speech_decoder_postnet_layers",
67
+ "reduction_factor": "reduction_factor",
68
+ "stop_gradient_from_energy_predictor": "stop_gradient_from_energy_predictor",
69
+ "stop_gradient_from_pitch_predictor": "stop_gradient_from_pitch_predictor",
70
+ "transformer_dec_attn_dropout_rate": "decoder_attention_dropout_rate",
71
+ "transformer_dec_dropout_rate": "decoder_dropout_rate",
72
+ "transformer_dec_positional_dropout_rate": "decoder_positional_dropout_rate",
73
+ "transformer_enc_attn_dropout_rate": "encoder_attention_dropout_rate",
74
+ "transformer_enc_dropout_rate": "encoder_dropout_rate",
75
+ "transformer_enc_positional_dropout_rate": "encoder_positional_dropout_rate",
76
+ "use_cnn_in_conformer": "use_cnn_in_conformer",
77
+ "use_macaron_style_in_conformer": "use_macaron_style_in_conformer",
78
+ "use_masking": "use_masking",
79
+ "use_weighted_masking": "use_weighted_masking",
80
+ "idim": "input_dim",
81
+ "odim": "num_mel_bins",
82
+ "spk_embed_dim": "speaker_embed_dim",
83
+ "langs": "num_languages",
84
+ "spks": "num_speakers",
85
+ }
86
+
87
+
88
+ def remap_model_yaml_config(yaml_config_path):
89
+ with Path(yaml_config_path).open("r", encoding="utf-8") as f:
90
+ args = yaml.safe_load(f)
91
+ args = argparse.Namespace(**args)
92
+
93
+ remapped_config = {}
94
+
95
+ model_params = args.tts_conf["text2mel_params"]
96
+ # espnet_config_key -> hf_config_key, any keys not included are ignored
97
+ for espnet_config_key, hf_config_key in CONFIG_MAPPING.items():
98
+ if espnet_config_key in model_params:
99
+ remapped_config[hf_config_key] = model_params[espnet_config_key]
100
+
101
+ return remapped_config, args.g2p, args.token_list
102
+
103
+
104
+ def convert_espnet_state_dict_to_hf(state_dict):
105
+ new_state_dict = {}
106
+ for key in state_dict:
107
+ if "tts.generator.text2mel." in key:
108
+ new_key = key.replace("tts.generator.text2mel.", "")
109
+ if "postnet" in key:
110
+ new_key = new_key.replace("postnet.postnet", "speech_decoder_postnet.layers")
111
+ new_key = new_key.replace(".0.weight", ".conv.weight")
112
+ new_key = new_key.replace(".1.weight", ".batch_norm.weight")
113
+ new_key = new_key.replace(".1.bias", ".batch_norm.bias")
114
+ new_key = new_key.replace(".1.running_mean", ".batch_norm.running_mean")
115
+ new_key = new_key.replace(".1.running_var", ".batch_norm.running_var")
116
+ new_key = new_key.replace(".1.num_batches_tracked", ".batch_norm.num_batches_tracked")
117
+ if "feat_out" in key:
118
+ if "weight" in key:
119
+ new_key = "speech_decoder_postnet.feat_out.weight"
120
+ if "bias" in key:
121
+ new_key = "speech_decoder_postnet.feat_out.bias"
122
+ if "encoder.embed.0.weight" in key:
123
+ new_key = new_key.replace("0.", "")
124
+ if "w_1" in key:
125
+ new_key = new_key.replace("w_1", "conv1")
126
+ if "w_2" in key:
127
+ new_key = new_key.replace("w_2", "conv2")
128
+ if "predictor.conv" in key:
129
+ new_key = new_key.replace(".conv", ".conv_layers")
130
+ pattern = r"(\d)\.(\d)"
131
+ replacement = (
132
+ r"\1.conv" if ("2.weight" not in new_key) and ("2.bias" not in new_key) else r"\1.layer_norm"
133
+ )
134
+ new_key = re.sub(pattern, replacement, new_key)
135
+ if "pitch_embed" in key or "energy_embed" in key:
136
+ new_key = new_key.replace("0", "conv")
137
+ if "encoders" in key:
138
+ new_key = new_key.replace("encoders", "conformer_layers")
139
+ new_key = new_key.replace("norm_final", "final_layer_norm")
140
+ new_key = new_key.replace("norm_mha", "self_attn_layer_norm")
141
+ new_key = new_key.replace("norm_ff_macaron", "ff_macaron_layer_norm")
142
+ new_key = new_key.replace("norm_ff", "ff_layer_norm")
143
+ new_key = new_key.replace("norm_conv", "conv_layer_norm")
144
+ if "lid_emb" in key:
145
+ new_key = new_key.replace("lid_emb", "language_id_embedding")
146
+ if "sid_emb" in key:
147
+ new_key = new_key.replace("sid_emb", "speaker_id_embedding")
148
+
149
+ new_state_dict[new_key] = state_dict[key]
150
+
151
+ return new_state_dict
152
+
153
+
154
+ @torch.no_grad()
155
+ def convert_FastSpeech2ConformerModel_checkpoint(
156
+ checkpoint_path,
157
+ yaml_config_path,
158
+ pytorch_dump_folder_path,
159
+ repo_id=None,
160
+ ):
161
+ model_params, tokenizer_name, vocab = remap_model_yaml_config(yaml_config_path)
162
+ config = FastSpeech2ConformerConfig(**model_params)
163
+
164
+ # Prepare the model
165
+ model = FastSpeech2ConformerModel(config)
166
+
167
+ espnet_checkpoint = torch.load(checkpoint_path, weights_only=True)
168
+ hf_compatible_state_dict = convert_espnet_state_dict_to_hf(espnet_checkpoint)
169
+
170
+ model.load_state_dict(hf_compatible_state_dict)
171
+
172
+ model.save_pretrained(pytorch_dump_folder_path)
173
+
174
+ # Prepare the tokenizer
175
+ with TemporaryDirectory() as tempdir:
176
+ vocab = {token: id for id, token in enumerate(vocab)}
177
+ vocab_file = Path(tempdir) / "vocab.json"
178
+ with open(vocab_file, "w") as f:
179
+ json.dump(vocab, f)
180
+ should_strip_spaces = "no_space" in tokenizer_name
181
+ tokenizer = FastSpeech2ConformerTokenizer(str(vocab_file), should_strip_spaces=should_strip_spaces)
182
+
183
+ tokenizer.save_pretrained(pytorch_dump_folder_path)
184
+
185
+ if repo_id:
186
+ print("Pushing to the hub...")
187
+ model.push_to_hub(repo_id)
188
+ tokenizer.push_to_hub(repo_id)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
194
+ parser.add_argument(
195
+ "--yaml_config_path", required=True, default=None, type=str, help="Path to config.yaml of model to convert"
196
+ )
197
+ parser.add_argument(
198
+ "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
199
+ )
200
+ parser.add_argument(
201
+ "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
202
+ )
203
+
204
+ args = parser.parse_args()
205
+ convert_FastSpeech2ConformerModel_checkpoint(
206
+ args.checkpoint_path,
207
+ args.yaml_config_path,
208
+ args.pytorch_dump_folder_path,
209
+ args.push_to_hub,
210
+ )
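If you prefer to drive the conversion from Python rather than through the CLI entry point above, the whole script reduces to one function call. The paths below are placeholders (not files shipped with this repo), and the tokenizer step assumes the optional g2p dependency is installed:

```python
# Hypothetical paths -- substitute your own ESPnet checkpoint, its config.yaml,
# and an output directory.
from transformers.models.fastspeech2_conformer.convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch import (
    convert_FastSpeech2ConformerModel_checkpoint,
)

convert_FastSpeech2ConformerModel_checkpoint(
    checkpoint_path="espnet_fastspeech2.pth",
    yaml_config_path="config.yaml",
    pytorch_dump_folder_path="./fastspeech2_conformer_hf",
    repo_id=None,  # set to "user/model-name" to push the converted weights to the Hub
)
```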
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_hifigan.py ADDED
@@ -0,0 +1,134 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert FastSpeech2Conformer HiFi-GAN checkpoint."""
16
+
17
+ import argparse
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ import yaml
22
+
23
+ from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig, logging
24
+
25
+
26
+ logging.set_verbosity_info()
27
+ logger = logging.get_logger("transformers.models.FastSpeech2Conformer")
28
+
29
+
30
+ def load_weights(checkpoint, hf_model, config):
31
+ vocoder_key_prefix = "tts.generator.vocoder."
32
+ checkpoint = {k.replace(vocoder_key_prefix, ""): v for k, v in checkpoint.items() if vocoder_key_prefix in k}
33
+
34
+ hf_model.apply_weight_norm()
35
+
36
+ hf_model.conv_pre.weight_g.data = checkpoint["input_conv.weight_g"]
37
+ hf_model.conv_pre.weight_v.data = checkpoint["input_conv.weight_v"]
38
+ hf_model.conv_pre.bias.data = checkpoint["input_conv.bias"]
39
+
40
+ for i in range(len(config.upsample_rates)):
41
+ hf_model.upsampler[i].weight_g.data = checkpoint[f"upsamples.{i}.1.weight_g"]
42
+ hf_model.upsampler[i].weight_v.data = checkpoint[f"upsamples.{i}.1.weight_v"]
43
+ hf_model.upsampler[i].bias.data = checkpoint[f"upsamples.{i}.1.bias"]
44
+
45
+ for i in range(len(config.upsample_rates) * len(config.resblock_kernel_sizes)):
46
+ for j in range(len(config.resblock_dilation_sizes)):
47
+ hf_model.resblocks[i].convs1[j].weight_g.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_g"]
48
+ hf_model.resblocks[i].convs1[j].weight_v.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_v"]
49
+ hf_model.resblocks[i].convs1[j].bias.data = checkpoint[f"blocks.{i}.convs1.{j}.1.bias"]
50
+
51
+ hf_model.resblocks[i].convs2[j].weight_g.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_g"]
52
+ hf_model.resblocks[i].convs2[j].weight_v.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_v"]
53
+ hf_model.resblocks[i].convs2[j].bias.data = checkpoint[f"blocks.{i}.convs2.{j}.1.bias"]
54
+
55
+ hf_model.conv_post.weight_g.data = checkpoint["output_conv.1.weight_g"]
56
+ hf_model.conv_post.weight_v.data = checkpoint["output_conv.1.weight_v"]
57
+ hf_model.conv_post.bias.data = checkpoint["output_conv.1.bias"]
58
+
59
+ hf_model.remove_weight_norm()
60
+
61
+
62
+ def remap_hifigan_yaml_config(yaml_config_path):
63
+ with Path(yaml_config_path).open("r", encoding="utf-8") as f:
64
+ args = yaml.safe_load(f)
65
+ args = argparse.Namespace(**args)
66
+
67
+ vocoder_type = args.tts_conf["vocoder_type"]
68
+ if vocoder_type != "hifigan_generator":
69
+ raise TypeError(f"Vocoder config must be for `hifigan_generator`, but got {vocoder_type}")
70
+
71
+ remapped_dict = {}
72
+ vocoder_params = args.tts_conf["vocoder_params"]
73
+
74
+ # espnet_config_key -> hf_config_key
75
+ key_mappings = {
76
+ "channels": "upsample_initial_channel",
77
+ "in_channels": "model_in_dim",
78
+ "resblock_dilations": "resblock_dilation_sizes",
79
+ "resblock_kernel_sizes": "resblock_kernel_sizes",
80
+ "upsample_kernel_sizes": "upsample_kernel_sizes",
81
+ "upsample_scales": "upsample_rates",
82
+ }
83
+ for espnet_config_key, hf_config_key in key_mappings.items():
84
+ remapped_dict[hf_config_key] = vocoder_params[espnet_config_key]
85
+ remapped_dict["sampling_rate"] = args.tts_conf["sampling_rate"]
86
+ remapped_dict["normalize_before"] = False
87
+ remapped_dict["leaky_relu_slope"] = vocoder_params["nonlinear_activation_params"]["negative_slope"]
88
+
89
+ return remapped_dict
90
+
91
+
92
+ @torch.no_grad()
93
+ def convert_hifigan_checkpoint(
94
+ checkpoint_path,
95
+ pytorch_dump_folder_path,
96
+ yaml_config_path=None,
97
+ repo_id=None,
98
+ ):
99
+ if yaml_config_path is not None:
100
+ config_kwargs = remap_hifigan_yaml_config(yaml_config_path)
101
+ config = FastSpeech2ConformerHifiGanConfig(**config_kwargs)
102
+ else:
103
+ config = FastSpeech2ConformerHifiGanConfig()
104
+
105
+ model = FastSpeech2ConformerHifiGan(config)
106
+
107
+ orig_checkpoint = torch.load(checkpoint_path, weights_only=True)
108
+ load_weights(orig_checkpoint, model, config)
109
+
110
+ model.save_pretrained(pytorch_dump_folder_path)
111
+
112
+ if repo_id:
113
+ print("Pushing to the hub...")
114
+ model.push_to_hub(repo_id)
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
120
+ parser.add_argument("--yaml_config_path", default=None, type=str, help="Path to config.yaml of model to convert")
121
+ parser.add_argument(
122
+ "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
123
+ )
124
+ parser.add_argument(
125
+ "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
126
+ )
127
+
128
+ args = parser.parse_args()
129
+ convert_hifigan_checkpoint(
130
+ args.checkpoint_path,
131
+ args.pytorch_dump_folder_path,
132
+ args.yaml_config_path,
133
+ args.push_to_hub,
134
+ )
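The conversion can also be driven from Python instead of the CLI; a minimal sketch with placeholder paths:

    convert_hifigan_checkpoint(
        checkpoint_path="train.total_count.ave_10best.pth",        # hypothetical ESPnet vocoder checkpoint
        pytorch_dump_folder_path="./fastspeech2_conformer_hifigan",
        yaml_config_path="config.yaml",                            # optional; default config values are used if omitted
        repo_id=None,                                              # e.g. "username/model-name" to push to the Hub
    )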
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_model_with_hifigan.py ADDED
@@ -0,0 +1,102 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert FastSpeech2Conformer checkpoint."""
16
+
17
+ import argparse
18
+
19
+ import torch
20
+
21
+ from transformers import (
22
+ FastSpeech2ConformerConfig,
23
+ FastSpeech2ConformerHifiGan,
24
+ FastSpeech2ConformerHifiGanConfig,
25
+ FastSpeech2ConformerModel,
26
+ FastSpeech2ConformerWithHifiGan,
27
+ FastSpeech2ConformerWithHifiGanConfig,
28
+ logging,
29
+ )
30
+
31
+ from .convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch import (
32
+ convert_espnet_state_dict_to_hf,
33
+ remap_model_yaml_config,
34
+ )
35
+ from .convert_hifigan import load_weights, remap_hifigan_yaml_config
36
+
37
+
38
+ logging.set_verbosity_info()
39
+ logger = logging.get_logger("transformers.models.FastSpeech2Conformer")
40
+
41
+
42
+ def convert_FastSpeech2ConformerWithHifiGan_checkpoint(
43
+ checkpoint_path,
44
+ yaml_config_path,
45
+ pytorch_dump_folder_path,
46
+ repo_id=None,
47
+ ):
48
+ # Prepare the model
49
+ model_params, *_ = remap_model_yaml_config(yaml_config_path)
50
+ model_config = FastSpeech2ConformerConfig(**model_params)
51
+
52
+ model = FastSpeech2ConformerModel(model_config)
53
+
54
+ espnet_checkpoint = torch.load(checkpoint_path, weights_only=True)
55
+ hf_compatible_state_dict = convert_espnet_state_dict_to_hf(espnet_checkpoint)
56
+ model.load_state_dict(hf_compatible_state_dict)
57
+
58
+ # Prepare the vocoder
59
+ config_kwargs = remap_hifigan_yaml_config(yaml_config_path)
60
+ vocoder_config = FastSpeech2ConformerHifiGanConfig(**config_kwargs)
61
+
62
+ vocoder = FastSpeech2ConformerHifiGan(vocoder_config)
63
+ load_weights(espnet_checkpoint, vocoder, vocoder_config)
64
+
65
+ # Prepare the model + vocoder
66
+ config = FastSpeech2ConformerWithHifiGanConfig.from_sub_model_configs(model_config, vocoder_config)
67
+ with_hifigan_model = FastSpeech2ConformerWithHifiGan(config)
68
+ with_hifigan_model.model = model
69
+ with_hifigan_model.vocoder = vocoder
70
+
71
+ with_hifigan_model.save_pretrained(pytorch_dump_folder_path)
72
+
73
+ if repo_id:
74
+ print("Pushing to the hub...")
75
+ with_hifigan_model.push_to_hub(repo_id)
76
+
77
+
78
+ if __name__ == "__main__":
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
81
+ parser.add_argument(
82
+ "--yaml_config_path", required=True, default=None, type=str, help="Path to config.yaml of model to convert"
83
+ )
84
+ parser.add_argument(
85
+ "--pytorch_dump_folder_path",
86
+ required=True,
87
+ default=None,
88
+ type=str,
89
+ help="Path to the output `FastSpeech2ConformerModel` PyTorch model.",
90
+ )
91
+ parser.add_argument(
92
+ "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
93
+ )
94
+
95
+ args = parser.parse_args()
96
+
97
+ convert_FastSpeech2ConformerWithHifiGan_checkpoint(
98
+ args.checkpoint_path,
99
+ args.yaml_config_path,
100
+ args.pytorch_dump_folder_path,
101
+ args.push_to_hub,
102
+ )
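Because this script uses relative imports, the CLI form is typically run as a module (e.g. `python -m transformers.models.fastspeech2_conformer.convert_model_with_hifigan ...`). The programmatic equivalent, with placeholder paths:

    convert_FastSpeech2ConformerWithHifiGan_checkpoint(
        checkpoint_path="train.loss.ave_5best.pth",   # hypothetical ESPnet checkpoint holding model + vocoder weights
        yaml_config_path="config.yaml",               # required here, unlike in convert_hifigan.py
        pytorch_dump_folder_path="./fastspeech2_conformer_with_hifigan",
        repo_id=None,
    )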
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py ADDED
@@ -0,0 +1,1697 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Espnet authors, IMS Toucan authors, and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch FastSpeech2Conformer model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ from ...modeling_outputs import BaseModelOutput
25
+ from ...modeling_utils import PreTrainedModel
26
+ from ...utils import ModelOutput, add_start_docstrings, logging, replace_return_docstrings
27
+ from .configuration_fastspeech2_conformer import (
28
+ FastSpeech2ConformerConfig,
29
+ FastSpeech2ConformerHifiGanConfig,
30
+ FastSpeech2ConformerWithHifiGanConfig,
31
+ )
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ @dataclass
38
+ class FastSpeech2ConformerModelOutput(ModelOutput):
39
+ """
40
+ Output type of [`FastSpeech2ConformerModel`].
41
+
42
+ Args:
43
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
44
+ Spectrogram generation loss.
45
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
46
+ The predicted spectrogram.
47
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
48
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
49
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
50
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
51
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
52
+
53
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
54
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
55
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
56
+ sequence_length)`.
57
+
58
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
59
+ self-attention heads.
60
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
61
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
62
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
63
+
64
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
65
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
66
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
67
+ sequence_length)`.
68
+
69
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
70
+ self-attention heads.
71
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
72
+ Outputs of the duration predictor.
73
+ pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
74
+ Outputs of the pitch predictor.
75
+ energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
76
+ Outputs of the energy predictor.
77
+
78
+ """
79
+
80
+ loss: Optional[torch.FloatTensor] = None
81
+ spectrogram: Optional[torch.FloatTensor] = None
82
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
83
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
84
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
85
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
86
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
87
+ duration_outputs: Optional[torch.LongTensor] = None
88
+ pitch_outputs: Optional[torch.FloatTensor] = None
89
+ energy_outputs: Optional[torch.FloatTensor] = None
90
+
91
+
92
+ @dataclass
93
+ class FastSpeech2ConformerWithHifiGanOutput(FastSpeech2ConformerModelOutput):
94
+ """
95
+ Output type of [`FastSpeech2ConformerWithHifiGan`].
96
+
97
+ Args:
98
+ waveform (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
99
+ Speech output as a result of passing the predicted mel spectrogram through the vocoder.
100
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
101
+ Spectrogram generation loss.
102
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
103
+ The predicted spectrogram.
104
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
105
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
106
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
107
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
108
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
109
+
110
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
111
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
112
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
113
+ sequence_length)`.
114
+
115
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
116
+ self-attention heads.
117
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
118
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
119
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
120
+
121
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
122
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
123
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
124
+ sequence_length)`.
125
+
126
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
127
+ self-attention heads.
128
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
129
+ Outputs of the duration predictor.
130
+ pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
131
+ Outputs of the pitch predictor.
132
+ energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
133
+ Outputs of the energy predictor.
134
+ """
135
+
136
+ waveform: Optional[torch.FloatTensor] = None
137
+
138
+
139
+ _CONFIG_FOR_DOC = "FastSpeech2ConformerConfig"
140
+
141
+ FASTSPEECH2_CONFORMER_START_DOCSTRING = r"""
142
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
143
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
144
+ etc.)
145
+
146
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
147
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
148
+ and behavior.
149
+
150
+ Parameters:
151
+ config ([`FastSpeech2ConformerConfig`]):
152
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
153
+ load the weights associated with the model, only the configuration. Check out the
154
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
155
+ """
156
+
157
+
158
+ HIFIGAN_START_DOCSTRING = r"""
159
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
160
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
161
+ etc.)
162
+
163
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
164
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
165
+ and behavior.
166
+
167
+ Parameters:
168
+ config ([`FastSpeech2ConformerHifiGanConfig`]):
169
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
170
+ load the weights associated with the model, only the configuration. Check out the
171
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
172
+ """
173
+
174
+ FASTSPEECH2_CONFORMER_WITH_HIFIGAN_START_DOCSTRING = r"""
175
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
176
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
177
+ etc.)
178
+
179
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
180
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
181
+ and behavior.
182
+
183
+ Parameters:
184
+ config ([`FastSpeech2ConformerWithHifiGanConfig`]):
185
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
186
+ load the weights associated with the model, only the configuration. Check out the
187
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
188
+ """
189
+
190
+
191
+ def length_regulator(encoded_embeddings, duration_labels, speaking_speed=1.0):
192
+ """
193
+ Length regulator for feed-forward Transformer.
194
+
195
+ This is the length regulator module described in `FastSpeech: Fast, Robust and Controllable Text to Speech`
196
+ https://arxiv.org/pdf/1905.09263.pdf. The length regulator expands char or phoneme-level embedding features to
197
+ frame-level by repeating each feature based on the corresponding predicted durations.
198
+
199
+ Args:
200
+ encoded_embeddings (`torch.Tensor` of shape `(batch_size, max_text_length, embedding_dim)`):
201
+ Batch of sequences of char or phoneme embeddings.
202
+ duration_labels (`torch.LongTensor` of shape `(batch_size, time)`):
203
+ Batch of durations of each frame.
204
+ speaking_speed (`float`, *optional*, defaults to 1.0):
205
+ Value to control speed of speech.
206
+
207
+ Returns:
208
+ `torch.Tensor`:
209
+ Replicated input tensor based on durations (batch_size, time*, embedding_dim).
210
+ """
211
+
212
+ if speaking_speed <= 0:
213
+ raise ValueError("`speaking_speed` must be greater than 0.")
214
+ elif speaking_speed != 1.0:
215
+ duration_labels = torch.round(duration_labels.float() * speaking_speed).long()
216
+
217
+ if duration_labels.sum() == 0:
218
+ duration_labels[duration_labels.sum(dim=1).eq(0)] = 1
219
+
220
+ # Calculate the maximum length needed
221
+ max_len = torch.sum(duration_labels, dim=1).max()
222
+
223
+ # Create a padded tensor to hold the results
224
+ hidden_states = torch.zeros(
225
+ (encoded_embeddings.size(0), max_len, encoded_embeddings.size(2)),
226
+ dtype=torch.float,
227
+ device=encoded_embeddings.device,
228
+ )
229
+
230
+ # Loop through the batch and fill in the data
231
+ for i, (encoded_embedding, target_duration) in enumerate(zip(encoded_embeddings, duration_labels)):
232
+ repeated = torch.repeat_interleave(encoded_embedding, target_duration, dim=0)
233
+ hidden_states[i, : repeated.size(0)] = repeated
234
+
235
+ return hidden_states
236
+
237
+
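A toy illustration of the length regulator above (tensor values invented): with per-token durations summing to 5 frames, the batch is padded to that length.

    import torch

    embeddings = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # (batch=1, text_len=3, dim=2)
    durations = torch.tensor([[2, 0, 3]])                          # frames assigned to each token
    expanded = length_regulator(embeddings, durations)
    print(expanded.shape)                                          # torch.Size([1, 5, 2])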
238
+ class FastSpeech2ConformerDurationPredictor(nn.Module):
239
+ """
240
+ Duration predictor module.
241
+
242
+ This is a module of duration predictor described in the paper 'FastSpeech: Fast, Robust and Controllable Text to
243
+ Speech' https://arxiv.org/pdf/1905.09263.pdf. The duration predictor predicts the duration of each frame in the log domain
244
+ from the hidden embeddings of the encoder.
245
+
246
+ Note:
247
+ The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`, the
248
+ outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
249
+
250
+ """
251
+
252
+ def __init__(self, config: FastSpeech2ConformerConfig):
253
+ super().__init__()
254
+
255
+ self.conv_layers = nn.ModuleList()
256
+ self.log_domain_offset = 1.0
257
+
258
+ for layer_idx in range(config.duration_predictor_layers):
259
+ num_chans = config.duration_predictor_channels
260
+ input_channels = config.hidden_size if layer_idx == 0 else num_chans
261
+ layer = FastSpeech2ConformerPredictorLayer(
262
+ input_channels,
263
+ num_chans,
264
+ config.duration_predictor_kernel_size,
265
+ config.duration_predictor_dropout_rate,
266
+ )
267
+ self.conv_layers.append(layer)
268
+ self.linear = nn.Linear(config.duration_predictor_channels, 1)
269
+
270
+ def forward(self, encoder_hidden_states):
271
+ """
272
+ Args:
273
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
274
+ Batch of input sequences.
275
+ padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
276
+ Batch of masks indicating padded part.
277
+
278
+ Returns:
279
+ `torch.Tensor`: Batch of predicted durations in log domain `(batch_size, max_text_length)`.
280
+
281
+ """
282
+ # (batch_size, input_dim, max_text_length)
283
+ hidden_states = encoder_hidden_states.transpose(1, -1)
284
+ for layer in self.conv_layers:
285
+ hidden_states = layer(hidden_states)
286
+
287
+ # NOTE: calculate in log domain, (batch_size, max_text_length)
288
+ hidden_states = self.linear(hidden_states.transpose(1, -1)).squeeze(-1)
289
+
290
+ if not self.training:
291
+ # NOTE: calculate in linear domain
292
+ hidden_states = torch.clamp(torch.round(hidden_states.exp() - self.log_domain_offset), min=0).long()
293
+
294
+ return hidden_states
295
+
296
+
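A short usage sketch of the duration predictor above, assuming a hypothetical `config`: in training mode it returns float log-durations, while in eval mode it rounds `exp(x) - 1` and clamps at zero to get integer frame counts.

    predictor = FastSpeech2ConformerDurationPredictor(config)  # hypothetical config
    encoder_states = torch.randn(2, 17, config.hidden_size)    # (batch, max_text_length, hidden_size)
    log_durations = predictor(encoder_states)                  # float tensor, shape (2, 17)
    predictor.eval()
    frame_counts = predictor(encoder_states)                   # LongTensor, shape (2, 17)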
297
+ # Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5BatchNormConvLayer
298
+ class FastSpeech2ConformerBatchNormConvLayer(nn.Module):
299
+ def __init__(self, config, layer_id=0):
300
+ super().__init__()
301
+
302
+ if layer_id == 0:
303
+ in_conv_dim = config.num_mel_bins
304
+ else:
305
+ in_conv_dim = config.speech_decoder_postnet_units
306
+
307
+ if layer_id == config.speech_decoder_postnet_layers - 1:
308
+ out_conv_dim = config.num_mel_bins
309
+ else:
310
+ out_conv_dim = config.speech_decoder_postnet_units
311
+
312
+ self.conv = nn.Conv1d(
313
+ in_conv_dim,
314
+ out_conv_dim,
315
+ kernel_size=config.speech_decoder_postnet_kernel,
316
+ stride=1,
317
+ padding=(config.speech_decoder_postnet_kernel - 1) // 2,
318
+ bias=False,
319
+ )
320
+ self.batch_norm = nn.BatchNorm1d(out_conv_dim)
321
+
322
+ if layer_id < config.speech_decoder_postnet_layers - 1:
323
+ self.activation = nn.Tanh()
324
+ else:
325
+ self.activation = None
326
+
327
+ self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
328
+
329
+ def forward(self, hidden_states):
330
+ hidden_states = self.conv(hidden_states)
331
+ hidden_states = self.batch_norm(hidden_states)
332
+ if self.activation is not None:
333
+ hidden_states = self.activation(hidden_states)
334
+ hidden_states = self.dropout(hidden_states)
335
+ return hidden_states
336
+
337
+
338
+ class FastSpeech2ConformerSpeechDecoderPostnet(nn.Module):
339
+ def __init__(self, config):
340
+ super().__init__()
341
+ self.config = config
342
+ self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
343
+ self.layers = nn.ModuleList(
344
+ [FastSpeech2ConformerBatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
345
+ )
346
+
347
+ def forward(self, hidden_states: torch.Tensor):
348
+ outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
349
+ layer_output = outputs_before_postnet.transpose(1, 2)
350
+ for layer in self.layers:
351
+ layer_output = layer(layer_output)
352
+ outputs_after_postnet = outputs_before_postnet + layer_output.transpose(1, 2)
353
+ return outputs_before_postnet, outputs_after_postnet
354
+
355
+
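Shape-wise, the postnet above first projects each decoder state to `num_mel_bins * reduction_factor` values and reshapes, then adds a residual refinement from the conv stack; a hedged sketch with a hypothetical `config`:

    postnet = FastSpeech2ConformerSpeechDecoderPostnet(config)  # hypothetical config
    decoder_states = torch.randn(1, 100, config.hidden_size)
    before, after = postnet(decoder_states)
    print(before.shape)  # (1, 100 * config.reduction_factor, config.num_mel_bins)
    print(after.shape)   # same shape; `after` adds the postnet residual to `before`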
356
+ class FastSpeech2ConformerPredictorLayer(nn.Module):
357
+ def __init__(self, input_channels, num_chans, kernel_size, dropout_rate):
358
+ super().__init__()
359
+ self.conv = nn.Conv1d(
360
+ input_channels,
361
+ num_chans,
362
+ kernel_size,
363
+ stride=1,
364
+ padding=(kernel_size - 1) // 2,
365
+ )
366
+ self.activation = nn.ReLU()
367
+ self.layer_norm = nn.LayerNorm(num_chans)
368
+ self.dropout = nn.Dropout(dropout_rate)
369
+
370
+ def forward(self, hidden_states):
371
+ hidden_states = self.conv(hidden_states)
372
+ hidden_states = self.activation(hidden_states)
373
+
374
+ # Perform layer norm on dimension 1
375
+ hidden_states = hidden_states.transpose(1, -1)
376
+ hidden_states = self.layer_norm(hidden_states)
377
+ hidden_states = hidden_states.transpose(1, -1)
378
+
379
+ hidden_states = self.dropout(hidden_states)
380
+
381
+ return hidden_states
382
+
383
+
384
+ class FastSpeech2ConformerVariancePredictor(nn.Module):
385
+ def __init__(
386
+ self,
387
+ config: FastSpeech2ConformerConfig,
388
+ num_layers=2,
389
+ num_chans=384,
390
+ kernel_size=3,
391
+ dropout_rate=0.5,
392
+ ):
393
+ """
394
+ Initialize variance predictor module.
395
+
396
+ Args:
397
+ input_dim (`int`): Input dimension.
398
+ num_layers (`int`, *optional*, defaults to 2): Number of convolutional layers.
399
+ num_chans (`int`, *optional*, defaults to 384): Number of channels of convolutional layers.
400
+ kernel_size (`int`, *optional*, defaults to 3): Kernel size of convolutional layers.
401
+ dropout_rate (`float`, *optional*, defaults to 0.5): Dropout rate.
402
+ """
403
+ super().__init__()
404
+ self.conv_layers = nn.ModuleList()
405
+ for idx in range(num_layers):
406
+ input_channels = config.hidden_size if idx == 0 else num_chans
407
+ layer = FastSpeech2ConformerPredictorLayer(input_channels, num_chans, kernel_size, dropout_rate)
408
+ self.conv_layers.append(layer)
409
+ self.linear = nn.Linear(num_chans, 1)
410
+
411
+ def forward(self, encoder_hidden_states, padding_masks=None):
412
+ """
413
+ Calculate forward propagation.
414
+
415
+ Args:
416
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
417
+ Batch of input sequences.
418
+ padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
419
+ Batch of masks indicating padded part.
420
+
421
+ Returns:
422
+ Tensor: Batch of predicted sequences `(batch_size, max_text_length, 1)`.
423
+ """
424
+ # (batch_size, input_dim, max_text_length)
425
+ hidden_states = encoder_hidden_states.transpose(1, -1)
426
+ for layer in self.conv_layers:
427
+ hidden_states = layer(hidden_states)
428
+
429
+ hidden_states = self.linear(hidden_states.transpose(1, 2))
430
+
431
+ if padding_masks is not None:
432
+ hidden_states = hidden_states.masked_fill(padding_masks, 0.0)
433
+
434
+ return hidden_states
435
+
436
+
437
+ class FastSpeech2ConformerVarianceEmbedding(nn.Module):
438
+ def __init__(
439
+ self,
440
+ in_channels=1,
441
+ out_channels=384,
442
+ kernel_size=1,
443
+ padding=0,
444
+ dropout_rate=0.0,
445
+ ):
446
+ super().__init__()
447
+ self.conv = nn.Conv1d(
448
+ in_channels=in_channels,
449
+ out_channels=out_channels,
450
+ kernel_size=kernel_size,
451
+ padding=padding,
452
+ )
453
+ self.dropout = nn.Dropout(dropout_rate)
454
+
455
+ def forward(self, hidden_states):
456
+ hidden_states = hidden_states.transpose(1, 2)
457
+ hidden_states = self.conv(hidden_states)
458
+ hidden_states = self.dropout(hidden_states)
459
+ hidden_states = hidden_states.transpose(1, 2)
460
+ return hidden_states
461
+
462
+
463
+ class FastSpeech2ConformerAttention(nn.Module):
464
+ """
465
+ Multi-Head attention layer with relative position encoding. Details can be found in
466
+ https://github.com/espnet/espnet/pull/2816. Paper: https://arxiv.org/abs/1901.02860.
467
+ """
468
+
469
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
470
+ """Construct an FastSpeech2ConformerAttention object."""
471
+ super().__init__()
472
+ # We assume d_v always equals dim_key
473
+ self.num_heads = module_config["num_attention_heads"]
474
+ self.hidden_size = config.hidden_size
475
+ self.dim_key = self.hidden_size // self.num_heads
476
+ self.head_dim = self.hidden_size // self.num_heads
477
+ self.linear_q = nn.Linear(self.hidden_size, self.hidden_size)
478
+ self.linear_k = nn.Linear(self.hidden_size, self.hidden_size)
479
+ self.linear_v = nn.Linear(self.hidden_size, self.hidden_size)
480
+ self.linear_out = nn.Linear(self.hidden_size, self.hidden_size)
481
+ self.dropout = nn.Dropout(p=module_config["attention_dropout_rate"])
482
+
483
+ # linear transformation for positional encoding
484
+ self.linear_pos = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
485
+ # these two learnable biases are used in matrix c and matrix d
486
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
487
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
488
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
489
+
490
+ def shift_relative_position_tensor(self, pos_tensor):
491
+ """
492
+ Args:
493
+ pos_tensor (torch.Tensor of shape (batch_size, head, time1, 2*time1-1)): Input tensor.
494
+ """
495
+ zero_pad = torch.zeros((*pos_tensor.size()[:3], 1), device=pos_tensor.device, dtype=pos_tensor.dtype)
496
+ pos_tensor_padded = torch.cat([zero_pad, pos_tensor], dim=-1)
497
+
498
+ pos_tensor_padded = pos_tensor_padded.view(*pos_tensor.size()[:2], pos_tensor.size(3) + 1, pos_tensor.size(2))
499
+ # only keep the positions from 0 to time2
500
+ pos_tensor = pos_tensor_padded[:, :, 1:].view_as(pos_tensor)[:, :, :, : pos_tensor.size(-1) // 2 + 1]
501
+
502
+ return pos_tensor
503
+
504
+ def forward(
505
+ self,
506
+ hidden_states: torch.Tensor,
507
+ attention_mask: Optional[torch.Tensor] = None,
508
+ pos_emb: Optional[torch.Tensor] = None,
509
+ output_attentions: Optional[torch.Tensor] = False,
510
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
511
+ """
512
+ Compute 'Scaled Dot Product Attention' with rel. positional encoding.
513
+
514
+ Args:
515
+ hidden_states (`torch.Tensor` of shape `(batch, time2, size)`): Values of the hidden states
516
+ attention_mask (`torch.Tensor` of shape `(batch, time1, time2)`): Mask tensor.
517
+ pos_emb (`torch.Tensor` of shape `(batch, 2*time1-1, size)`): Positional embedding tensor.
518
+ output_attentions (`bool`, *optional*):
519
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
520
+ returned tensors for more detail.
521
+ Returns:
522
+ `torch.Tensor`: Output tensor of shape `(batch, time1, d_model)`.
523
+ """
524
+ bsz, q_len, _ = hidden_states.size()
525
+ query_states = self.linear_q(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
526
+ key_states = self.linear_k(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
527
+ value_states = self.linear_v(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
528
+
529
+ bsz_pos = pos_emb.size(0)
530
+ pos_encoding = self.linear_pos(pos_emb).view(bsz_pos, -1, self.num_heads, self.head_dim)
531
+
532
+ # (batch_size, head, time1, dim_key)
533
+ query_with_bias_u = (query_states + self.pos_bias_u).transpose(1, 2)
534
+ # (batch_size, head, time1, dim_key)
535
+ query_with_bias_v = (query_states + self.pos_bias_v).transpose(1, 2)
536
+
537
+ # compute attention score
538
+ # first compute matrix a and matrix c
539
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
540
+ # (batch_size, head, time1, time2)
541
+ matrix_ac = torch.matmul(query_with_bias_u, key_states.permute(0, 2, 3, 1))
542
+
543
+ # compute matrix b and matrix d
544
+ # (batch_size, head, time1, 2*time1-1)
545
+ matrix_bd = torch.matmul(query_with_bias_v, pos_encoding.permute(0, 2, 3, 1))
546
+ matrix_bd = self.shift_relative_position_tensor(matrix_bd)
547
+
548
+ # (batch_size, head, time1, time2)
549
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.dim_key)
550
+
551
+ # Forward attention
552
+ if attention_mask is not None:
553
+ expected_size = (bsz, 1, q_len)
554
+ if attention_mask.size() != expected_size:
555
+ raise ValueError(f"Attention mask should be of size {expected_size}, but is {attention_mask.size()}")
556
+ attention_mask = attention_mask.unsqueeze(1).eq(0)
557
+ min_value = float(torch.finfo(scores.dtype).min)
558
+ scores = scores.masked_fill(attention_mask, min_value)
559
+ attn_weights = torch.softmax(scores, dim=-1).masked_fill(attention_mask, 0.0)
560
+ else:
561
+ attn_weights = torch.softmax(scores, dim=-1)
562
+
563
+ attn_weights = self.dropout(attn_weights)
564
+ attn_output = torch.matmul(attn_weights, value_states.transpose(1, 2))
565
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
566
+
567
+ attn_output = self.linear_out(attn_output)
568
+
569
+ if not output_attentions:
570
+ attn_weights = None
571
+
572
+ return attn_output, attn_weights
573
+
574
+
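`shift_relative_position_tensor` above is the Transformer-XL style relative shift: the last axis of its input indexes relative positions from +(time1-1) down to -(time1-1), and after the shift entry `[..., i, j]` holds the score for relative distance `i - j`, giving a square time-by-time tensor. A hedged shape check, assuming hypothetical config objects:

    attention = FastSpeech2ConformerAttention(config, config.encoder_config)  # hypothetical configs
    scores = torch.randn(1, attention.num_heads, 4, 2 * 4 - 1)                # (batch, heads, time1, 2*time1 - 1)
    shifted = attention.shift_relative_position_tensor(scores)
    print(shifted.shape)                                                      # (1, num_heads, 4, 4)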
575
+ class FastSpeech2ConformerConvolutionModule(nn.Module):
576
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
577
+ super().__init__()
578
+ # kernel_size should be an odd number for 'SAME' padding
579
+ channels = config.hidden_size
580
+ kernel_size = module_config["kernel_size"]
581
+ self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
582
+ self.depthwise_conv = nn.Conv1d(
583
+ channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=True
584
+ )
585
+ self.norm = nn.BatchNorm1d(channels)
586
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
587
+
588
+ def forward(self, hidden_states):
589
+ """
590
+ Compute convolution module.
591
+
592
+ Args:
593
+ hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
594
+
595
+ Returns:
596
+ `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
597
+
598
+ """
599
+ # exchange the temporal dimension and the feature dimension
600
+ hidden_states = hidden_states.transpose(1, 2)
601
+
602
+ # GLU mechanism, (batch_size, 2*channel, dim)
603
+ hidden_states = self.pointwise_conv1(hidden_states)
604
+ # (batch_size, channel, dim)
605
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
606
+
607
+ # 1D Depthwise Conv
608
+ hidden_states = self.depthwise_conv(hidden_states)
609
+ hidden_states = self.norm(hidden_states)
610
+
611
+ hidden_states = hidden_states * torch.sigmoid(hidden_states)
612
+
613
+ hidden_states = self.pointwise_conv2(hidden_states)
614
+
615
+ return hidden_states.transpose(1, 2)
616
+
617
+
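In the convolution module above, the first pointwise conv doubles the channels so the GLU can gate them back down, the depthwise conv mixes information along time only, and `x * torch.sigmoid(x)` is the Swish/SiLU activation; the `(batch, time, channels)` shape is preserved end to end. A hedged sketch with hypothetical configs:

    conv_module = FastSpeech2ConformerConvolutionModule(config, config.encoder_config)  # hypothetical configs
    x = torch.randn(2, 40, config.hidden_size)
    print(conv_module(x).shape)  # torch.Size([2, 40, hidden_size]) -- same shape as the input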
618
+ class FastSpeech2ConformerEncoderLayer(nn.Module):
619
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
620
+ super().__init__()
621
+
622
+ # self-attention module definition
623
+ self.self_attn = FastSpeech2ConformerAttention(config, module_config)
624
+
625
+ # feed-forward module definition
626
+ self.feed_forward = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
627
+
628
+ self.macaron_style = config.use_macaron_style_in_conformer
629
+ if self.macaron_style:
630
+ self.feed_forward_macaron = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
631
+ self.ff_macaron_layer_norm = nn.LayerNorm(config.hidden_size)
632
+ self.ff_scale = 0.5
633
+ else:
634
+ self.ff_scale = 1.0
635
+
636
+ # convolution module definition
637
+ self.use_cnn_module = config.use_cnn_in_conformer
638
+ if self.use_cnn_module:
639
+ self.conv_module = FastSpeech2ConformerConvolutionModule(config, module_config)
640
+ self.conv_layer_norm = nn.LayerNorm(config.hidden_size)
641
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
642
+
643
+ self.ff_layer_norm = nn.LayerNorm(config.hidden_size)
644
+
645
+ self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
646
+
647
+ self.dropout = nn.Dropout(module_config["dropout_rate"])
648
+ self.size = config.hidden_size
649
+ self.normalize_before = module_config["normalize_before"]
650
+ self.concat_after = module_config["concat_after"]
651
+ if self.concat_after:
652
+ self.concat_linear = nn.Linear(config.hidden_size + config.hidden_size, config.hidden_size)
653
+
654
+ def forward(
655
+ self,
656
+ hidden_states: torch.Tensor,
657
+ pos_emb: Optional[torch.Tensor] = None,
658
+ attention_mask: Optional[torch.Tensor] = None,
659
+ output_attentions: Optional[torch.Tensor] = False,
660
+ ):
661
+ """
662
+ Compute encoded features.
663
+
664
+ Args:
665
+ hidden_states (`torch.Tensor` of shape `(batch, time, size)`): Input tensor.
666
+ pos_emb (`torch.Tensor` of shape `(1, time, size)`): Positional embeddings tensor.
667
+ attention_mask (`torch.Tensor` of shape `(batch, time)`): Attention mask tensor for the input.
668
+ output_attentions (`bool`, *optional*):
669
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
670
+ returned tensors for more detail.
671
+ Returns:
672
+ `torch.Tensor`: Output tensor of shape `(batch, time, size)`.
673
+
674
+ """
675
+ # whether to use macaron style
676
+ if self.macaron_style:
677
+ residual = hidden_states
678
+ if self.normalize_before:
679
+ hidden_states = self.ff_macaron_layer_norm(hidden_states)
680
+ hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden_states))
681
+ if not self.normalize_before:
682
+ hidden_states = self.ff_macaron_layer_norm(hidden_states)
683
+
684
+ # multi-headed self-attention module
685
+ residual = hidden_states
686
+ if self.normalize_before:
687
+ hidden_states = self.self_attn_layer_norm(hidden_states)
688
+
689
+ attention_output, attention_scores = self.self_attn(
690
+ hidden_states, attention_mask=attention_mask, pos_emb=pos_emb, output_attentions=output_attentions
691
+ )
692
+
693
+ if self.concat_after:
694
+ x_concat = torch.cat((hidden_states, attention_output), dim=-1)
695
+ hidden_states = self.concat_linear(x_concat)
696
+ hidden_states = residual + hidden_states
697
+ else:
698
+ hidden_states = self.dropout(attention_output)
699
+ hidden_states = residual + hidden_states
700
+ if not self.normalize_before:
701
+ hidden_states = self.self_attn_layer_norm(hidden_states)
702
+
703
+ # convolution module
704
+ if self.use_cnn_module:
705
+ residual = hidden_states
706
+ if self.normalize_before:
707
+ hidden_states = self.conv_layer_norm(hidden_states)
708
+ hidden_states = self.conv_module(hidden_states)
709
+ hidden_states = self.dropout(hidden_states)
710
+ hidden_states = residual + hidden_states
711
+ if not self.normalize_before:
712
+ hidden_states = self.conv_layer_norm(hidden_states)
713
+
714
+ # feed forward module
715
+ residual = hidden_states
716
+ if self.normalize_before:
717
+ hidden_states = self.ff_layer_norm(hidden_states)
718
+ hidden_states = self.feed_forward(hidden_states)
719
+ hidden_states = self.dropout(hidden_states)
720
+ hidden_states = residual + self.ff_scale * hidden_states
721
+ if not self.normalize_before:
722
+ hidden_states = self.ff_layer_norm(hidden_states)
723
+
724
+ if self.conv_module is not None:
725
+ hidden_states = self.final_layer_norm(hidden_states)
726
+
727
+ outputs = (hidden_states,)
728
+
729
+ if output_attentions:
730
+ outputs += (attention_scores,)
731
+
732
+ return outputs
733
+
734
+
735
+ class FastSpeech2ConformerMultiLayeredConv1d(nn.Module):
736
+ """
737
+ Multi-layered conv1d for Transformer block.
738
+
739
+ This is a module of multi-layered conv1d designed to replace positionwise feed-forward network in Transformer
740
+ block, which is introduced in 'FastSpeech: Fast, Robust and Controllable Text to Speech'
741
+ https://arxiv.org/pdf/1905.09263.pdf
742
+ """
743
+
744
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
745
+ """
746
+ Initialize FastSpeech2ConformerMultiLayeredConv1d module.
747
+
748
+ Args:
749
+ input_channels (`int`): Number of input channels.
750
+ hidden_channels (`int`): Number of hidden channels.
751
+ kernel_size (`int`): Kernel size of conv1d.
752
+ dropout_rate (`float`): Dropout rate.
753
+ """
754
+ super().__init__()
755
+ input_channels = config.hidden_size
756
+ hidden_channels = module_config["linear_units"]
757
+ kernel_size = config.positionwise_conv_kernel_size
758
+ self.conv1 = nn.Conv1d(input_channels, hidden_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
759
+ self.conv2 = nn.Conv1d(hidden_channels, input_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
760
+ self.dropout = nn.Dropout(module_config["dropout_rate"])
761
+
762
+ def forward(self, hidden_states):
763
+ """
764
+ Calculate forward propagation.
765
+
766
+ Args:
767
+ hidden_states (torch.Tensor): Batch of input tensors (batch_size, time, input_channels).
768
+
769
+ Returns:
770
+ torch.Tensor: Batch of output tensors (batch_size, time, input_channels).
771
+ """
772
+ hidden_states = hidden_states.transpose(-1, 1)
773
+ hidden_states = self.conv1(hidden_states)
774
+ hidden_states = torch.relu(hidden_states)
775
+ hidden_states = self.dropout(hidden_states)
776
+ hidden_states = self.conv2(hidden_states)
777
+ hidden_states = hidden_states.transpose(-1, 1)
778
+ return hidden_states
779
+
780
+
781
+ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
782
+ """
783
+ Relative positional encoding module (new implementation). Details can be found in
784
+ https://github.com/espnet/espnet/pull/2816. See: Appendix B in https://arxiv.org/abs/1901.02860
785
+ Args:
786
+ config (`FastSpeech2ConformerConfig`):
787
+ FastSpeech2ConformerConfig instance.
788
+ module_config (`dict`):
789
+ Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
790
+ """
791
+
792
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
793
+ """
794
+ Construct a PositionalEncoding object.
795
+ """
796
+ super().__init__()
797
+ self.embed_dim = config.hidden_size
798
+ self.input_scale = math.sqrt(self.embed_dim)
799
+ self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
800
+ self.pos_enc = None
801
+ self.max_len = 5000
802
+ self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
803
+
804
+ def extend_pos_enc(self, x):
805
+ """Reset the positional encodings."""
806
+ if self.pos_enc is not None:
807
+ # self.pos_enc contains both positive and negative parts
808
+ # the length of self.pos_enc is 2 * input_len - 1
809
+ if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
810
+ if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
811
+ self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
812
+ return
813
+ # Suppose `i` means the position of the query vector and `j` means the
814
+ # position of the key vector. We use positive relative positions when keys
815
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
816
+ pos_enc_positive = torch.zeros(x.size(1), self.embed_dim)
817
+ pos_enc_negative = torch.zeros(x.size(1), self.embed_dim)
818
+ position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
819
+ div_term = torch.exp(
820
+ torch.arange(0, self.embed_dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.embed_dim)
821
+ )
822
+ pos_enc_positive[:, 0::2] = torch.sin(position * div_term)
823
+ pos_enc_positive[:, 1::2] = torch.cos(position * div_term)
824
+ pos_enc_negative[:, 0::2] = torch.sin(-1 * position * div_term)
825
+ pos_enc_negative[:, 1::2] = torch.cos(-1 * position * div_term)
826
+
827
+ # Reverse the order of positive indices and concat both positive and
828
+ # negative indices. This is used to support the shifting trick
829
+ # as in https://arxiv.org/abs/1901.02860
830
+ pos_enc_positive = torch.flip(pos_enc_positive, [0]).unsqueeze(0)
831
+ pos_enc_negative = pos_enc_negative[1:].unsqueeze(0)
832
+ pos_enc = torch.cat([pos_enc_positive, pos_enc_negative], dim=1)
833
+ self.pos_enc = pos_enc.to(device=x.device, dtype=x.dtype)
834
+
835
+ def forward(self, feature_representation):
836
+ """
837
+ Args:
838
+ feature_representation (`torch.Tensor` of shape (batch_size, time, `*`)):
839
+ Input tensor.
840
+
841
+ Returns:
842
+ `torch.Tensor`: Encoded tensor (batch_size, time, `*`).
843
+ """
844
+ self.extend_pos_enc(feature_representation)
845
+ hidden_states = feature_representation * self.input_scale
846
+ center_idx = self.pos_enc.size(1) // 2
847
+ pos_emb = self.pos_enc[:, center_idx - hidden_states.size(1) + 1 : center_idx + hidden_states.size(1)]
848
+ return self.dropout(hidden_states), self.dropout(pos_emb)
849
+
850
+
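One property of the encoding above worth spelling out: for an input with `time` frames, the returned `pos_emb` spans all `2 * time - 1` relative offsets, which is exactly what the attention's relative shift expects. A hedged sketch with hypothetical configs:

    pos_enc = FastSpeech2ConformerRelPositionalEncoding(config, config.encoder_config)  # hypothetical configs
    x = torch.randn(1, 50, config.hidden_size)
    scaled, pos_emb = pos_enc(x)
    print(scaled.shape, pos_emb.shape)  # (1, 50, hidden_size) (1, 99, hidden_size)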
851
+ class FastSpeech2ConformerEncoder(nn.Module):
852
+ """
853
+ FastSpeech2ConformerEncoder encoder module.
854
+
855
+ Args:
856
+ config (`FastSpeech2ConformerConfig`):
857
+ FastSpeech2ConformerConfig instance.
858
+ module_config (`dict`):
859
+ Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
860
+ use_encoder_input_layer (`bool`, *optional*, defaults to `False`):
861
+ Whether to prepend an embedding input layer (used when the inputs are token ids rather than hidden states).
862
+ """
863
+
864
+ def __init__(
865
+ self,
866
+ config: FastSpeech2ConformerConfig,
867
+ module_config,
868
+ use_encoder_input_layer=False,
869
+ ):
870
+ super().__init__()
871
+
872
+ self.embed = None
873
+ if use_encoder_input_layer:
874
+ self.embed = nn.Embedding(
875
+ num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=0
876
+ )
877
+
878
+ self.pos_enc = FastSpeech2ConformerRelPositionalEncoding(config, module_config)
879
+
880
+ self.conformer_layers = nn.ModuleList(
881
+ [FastSpeech2ConformerEncoderLayer(config, module_config) for _ in range(module_config["layers"])]
882
+ )
883
+
884
+ def forward(
885
+ self,
886
+ input_tensor: torch.LongTensor,
887
+ attention_mask: Optional[bool] = None,
888
+ output_hidden_states: Optional[bool] = None,
889
+ output_attentions: Optional[bool] = False,
890
+ return_dict: Optional[bool] = None,
891
+ ):
892
+ """
893
+ Args:
894
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
895
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
896
+ provide it.
897
+
898
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
899
+ [`PreTrainedTokenizer.__call__`] for details.
900
+
901
+ [What are input IDs?](../glossary#input-ids)
902
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
903
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
904
+
905
+ - 1 for tokens that are **not masked**,
906
+ - 0 for tokens that are **masked**.
907
+
908
+ [What are attention masks?](../glossary#attention-mask)
909
+ output_hidden_states (`bool`, *optional*):
910
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
911
+ for more detail.
912
+ output_attentions (`bool`, *optional*):
913
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
914
+ returned tensors for more detail.
915
+ return_dict (`bool`, *optional*):
916
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
917
+ Returns:
918
+ `torch.Tensor`:
919
+ Output tensor of shape `(batch, time, attention_dim)`.
920
+ """
921
+ feature_representation = input_tensor
922
+ if self.embed is not None:
923
+ feature_representation = self.embed(feature_representation)
924
+
925
+ hidden_states, pos_emb = self.pos_enc(feature_representation)
926
+
927
+ all_hidden_states = () if output_hidden_states else None
928
+ all_self_attentions = () if output_attentions else None
929
+
930
+ for conformer_layer in self.conformer_layers:
931
+ if output_hidden_states:
932
+ all_hidden_states = all_hidden_states + (hidden_states,)
933
+
934
+ layer_outputs = conformer_layer(hidden_states, pos_emb, attention_mask, output_attentions)
935
+ hidden_states = layer_outputs[0]
936
+
937
+ if output_attentions:
938
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
939
+
940
+ # Add last layer
941
+ if output_hidden_states:
942
+ all_hidden_states = all_hidden_states + (hidden_states,)
943
+
944
+ if not return_dict:
945
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
946
+ return BaseModelOutput(
947
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
948
+ )
949
+
950
+
951
+ class FastSpeech2ConformerLoss(nn.Module):
952
+ def __init__(self, config: FastSpeech2ConformerConfig):
953
+ super().__init__()
954
+
955
+ use_masking = config.use_masking
956
+ use_weighted_masking = config.use_weighted_masking
957
+
958
+ if use_masking and use_weighted_masking:
959
+ raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.")
960
+
961
+ self.use_masking = use_masking
962
+ self.use_weighted_masking = use_weighted_masking
963
+
964
+ # define criterions
965
+ reduction = "none" if self.use_weighted_masking else "mean"
966
+ self.l1_criterion = nn.L1Loss(reduction=reduction)
967
+ self.mse_criterion = nn.MSELoss(reduction=reduction)
968
+ self.duration_criterion = nn.MSELoss(reduction=reduction)
969
+ self.log_domain_offset = 1.0
970
+
971
+ def forward(
972
+ self,
973
+ outputs_after_postnet,
974
+ outputs_before_postnet,
975
+ duration_outputs,
976
+ pitch_outputs,
977
+ energy_outputs,
978
+ spectrogram_labels,
979
+ duration_labels,
980
+ pitch_labels,
981
+ energy_labels,
982
+ duration_mask,
983
+ spectrogram_mask,
984
+ ):
985
+ """
986
+ Args:
987
+ outputs_after_postnet (`torch.Tensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`):
988
+ Batch of outputs after postnet.
989
+ outputs_before_postnet (`torch.Tensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`):
990
+ Batch of outputs before postnet.
991
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length)`):
992
+ Batch of outputs of duration predictor.
993
+ pitch_outputs (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
994
+ Batch of outputs of pitch predictor.
995
+ energy_outputs (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
996
+ Batch of outputs of energy predictor.
997
+ spectrogram_labels (`torch.Tensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`):
998
+ Batch of target features.
999
+ duration_labels (`torch.LongTensor` of shape `(batch_size, max_text_length)`): Batch of durations.
1000
+ pitch_labels (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
1001
+ Batch of target token-averaged pitch.
1002
+ energy_labels (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
1003
+ Batch of target token-averaged energy.
1004
+ duration_mask (`torch.LongTensor`):
1005
+ Mask used to discern which values the duration loss should be calculated for.
1006
+ spectrogram_mask (`torch.LongTensor`):
1007
+ Mask used to discern which values the spectrogram loss should be calculated for.
1008
+
1009
+ Returns:
1010
+ `tuple(torch.FloatTensor)`: Tuple of tensors containing, in order, the L1 loss value, duration predictor
1011
+ loss value, pitch predictor loss value, and energy predictor loss value.
1012
+
1013
+ """
1014
+ pitch_and_energy_masks = duration_mask.unsqueeze(-1)
1015
+
1016
+ # apply mask to remove padded part
1017
+ if self.use_masking:
1018
+ outputs_before_postnet = outputs_before_postnet.masked_select(spectrogram_mask)
1019
+ if outputs_after_postnet is not None:
1020
+ outputs_after_postnet = outputs_after_postnet.masked_select(spectrogram_mask)
1021
+ spectrogram_labels = spectrogram_labels.masked_select(spectrogram_mask)
1022
+ duration_outputs = duration_outputs.masked_select(duration_mask)
1023
+ duration_labels = duration_labels.masked_select(duration_mask)
1024
+ pitch_outputs = pitch_outputs.masked_select(pitch_and_energy_masks)
1025
+ energy_outputs = energy_outputs.masked_select(pitch_and_energy_masks)
1026
+ pitch_labels = pitch_labels.masked_select(pitch_and_energy_masks)
1027
+ energy_labels = energy_labels.masked_select(pitch_and_energy_masks)
1028
+
1029
+ # calculate loss
1030
+ l1_loss = self.l1_criterion(outputs_before_postnet, spectrogram_labels)
1031
+ if outputs_after_postnet is not None:
1032
+ l1_loss = l1_loss + self.l1_criterion(outputs_after_postnet, spectrogram_labels)
1033
+ duration_labels = torch.log(duration_labels.float() + self.log_domain_offset)
1034
+ duration_loss = self.duration_criterion(duration_outputs, duration_labels)
1035
+ pitch_loss = self.mse_criterion(pitch_outputs, pitch_labels)
1036
+ energy_loss = self.mse_criterion(energy_outputs, energy_labels)
1037
+
1038
+ # make weighted mask and apply it
1039
+ if self.use_weighted_masking:
1040
+ spectrogram_mask = nn.functional.pad(
1041
+ spectrogram_mask.transpose(1, 2),
1042
+ [0, spectrogram_labels.size(1) - spectrogram_mask.size(1), 0, 0, 0, 0],
1043
+ value=False,
1044
+ ).transpose(1, 2)
1045
+
1046
+ out_weights = spectrogram_mask.float() / spectrogram_mask.sum(dim=1, keepdim=True).float()
1047
+ out_weights /= spectrogram_labels.size(0) * spectrogram_labels.size(2)
1048
+ duration_weights = duration_mask.float() / duration_mask.sum(dim=1, keepdim=True).float()
1049
+ duration_weights /= duration_labels.size(0)
1050
+
1051
+ # apply weight
1052
+ l1_loss = l1_loss.mul(out_weights).masked_select(spectrogram_mask).sum()
1053
+ duration_loss = duration_loss.mul(duration_weights).masked_select(duration_mask).sum()
1054
+ pitch_weights = duration_weights.unsqueeze(-1)
1055
+ pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_and_energy_masks).sum()
1056
+ energy_loss = energy_loss.mul(pitch_weights).masked_select(pitch_and_energy_masks).sum()
1057
+
1058
+ return l1_loss + duration_loss + pitch_loss + energy_loss
1059
+
1060
+
1061
+ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
1062
+ """
1063
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1064
+ models.
1065
+ """
1066
+
1067
+ config_class = FastSpeech2ConformerConfig
1068
+ base_model_prefix = "fastspeech2_conformer"
1069
+
1070
+ main_input_name = "input_ids"
1071
+
1072
+ def _init_weights(self, module):
1073
+ """Initialize the weights"""
1074
+ if isinstance(module, (nn.LayerNorm)):
1075
+ module.bias.data.zero_()
1076
+ module.weight.data.fill_(1.0)
1077
+ elif isinstance(module, nn.Conv1d):
1078
+ nn.init.kaiming_normal_(module.weight)
1079
+ if module.bias is not None:
1080
+ key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1081
+ nn.init.uniform_(module.bias, a=-key, b=key)
1082
+ elif isinstance(module, nn.Embedding):
1083
+ module.weight.data.normal_()
1084
+ if module.padding_idx is not None:
1085
+ module.weight.data[module.padding_idx].zero_()
1086
+ elif isinstance(module, FastSpeech2ConformerAttention):
1087
+ nn.init.xavier_uniform_(module.pos_bias_u)
1088
+ nn.init.xavier_uniform_(module.pos_bias_v)
1089
+
1090
+ def _set_gradient_checkpointing(self, module, value=False):
1091
+ if isinstance(module, FastSpeech2ConformerEncoder):
1092
+ module.gradient_checkpointing = value
1093
+
1094
+
1095
+ @add_start_docstrings(
1096
+ """FastSpeech2Conformer Model.""",
1097
+ FASTSPEECH2_CONFORMER_START_DOCSTRING,
1098
+ )
1099
+ class FastSpeech2ConformerModel(FastSpeech2ConformerPreTrainedModel):
1100
+ """
1101
+ FastSpeech 2 module.
1102
+
1103
+ This is a module of FastSpeech 2 described in 'FastSpeech 2: Fast and High-Quality End-to-End Text to Speech'
1104
+ https://arxiv.org/abs/2006.04558. Instead of quantized pitch and energy, we use the token-averaged values introduced in
1105
+ FastPitch: Parallel Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers instead of regular
1106
+ Transformers.
1107
+ """
1108
+
1109
+ def __init__(self, config: FastSpeech2ConformerConfig):
1110
+ super().__init__(config)
1111
+ self.config = config
1112
+
1113
+ # store hyperparameters
1114
+ self.vocab_size = config.vocab_size
1115
+ self.num_mel_bins = config.num_mel_bins
1116
+ self.hidden_size = config.hidden_size
1117
+ self.reduction_factor = config.reduction_factor
1118
+ self.stop_gradient_from_pitch_predictor = config.stop_gradient_from_pitch_predictor
1119
+ self.stop_gradient_from_energy_predictor = config.stop_gradient_from_energy_predictor
1120
+
1121
+ self.multilingual_model = config.num_languages is not None and config.num_languages > 1
1122
+ if self.multilingual_model:
1123
+ self.language_id_embedding = torch.nn.Embedding(config.num_languages, self.hidden_size)
1124
+
1125
+ self.multispeaker_model = config.num_speakers is not None and config.num_speakers > 1
1126
+ if self.multispeaker_model:
1127
+ self.speaker_id_embedding = torch.nn.Embedding(config.num_speakers, config.hidden_size)
1128
+
1129
+ self.speaker_embed_dim = config.speaker_embed_dim
1130
+ if self.speaker_embed_dim:
1131
+ self.projection = nn.Linear(config.hidden_size + self.speaker_embed_dim, config.hidden_size)
1132
+
1133
+ self.encoder = FastSpeech2ConformerEncoder(config, config.encoder_config, use_encoder_input_layer=True)
1134
+
1135
+ self.duration_predictor = FastSpeech2ConformerDurationPredictor(config)
1136
+
1137
+ self.pitch_predictor = FastSpeech2ConformerVariancePredictor(
1138
+ config,
1139
+ num_layers=config.pitch_predictor_layers,
1140
+ num_chans=config.pitch_predictor_channels,
1141
+ kernel_size=config.pitch_predictor_kernel_size,
1142
+ dropout_rate=config.pitch_predictor_dropout,
1143
+ )
1144
+ # continuous pitch + FastPitch style avg
1145
+ self.pitch_embed = FastSpeech2ConformerVarianceEmbedding(
1146
+ out_channels=self.hidden_size,
1147
+ kernel_size=config.pitch_embed_kernel_size,
1148
+ padding=(config.pitch_embed_kernel_size - 1) // 2,
1149
+ dropout_rate=config.pitch_embed_dropout,
1150
+ )
1151
+
1152
+ self.energy_predictor = FastSpeech2ConformerVariancePredictor(
1153
+ config,
1154
+ num_layers=config.energy_predictor_layers,
1155
+ num_chans=config.energy_predictor_channels,
1156
+ kernel_size=config.energy_predictor_kernel_size,
1157
+ dropout_rate=config.energy_predictor_dropout,
1158
+ )
1159
+ # continuous energy + FastPitch style avg
1160
+ self.energy_embed = FastSpeech2ConformerVarianceEmbedding(
1161
+ out_channels=self.hidden_size,
1162
+ kernel_size=config.energy_embed_kernel_size,
1163
+ padding=(config.energy_embed_kernel_size - 1) // 2,
1164
+ dropout_rate=config.energy_embed_dropout,
1165
+ )
1166
+
1167
+ # The decoder is an encoder
1168
+ self.decoder = FastSpeech2ConformerEncoder(config, config.decoder_config, use_encoder_input_layer=False)
1169
+
1170
+ self.speech_decoder_postnet = FastSpeech2ConformerSpeechDecoderPostnet(config)
1171
+
1172
+ self.criterion = FastSpeech2ConformerLoss(config)
1173
+
1174
+ self.post_init()
1175
+
1176
+ @replace_return_docstrings(output_type=FastSpeech2ConformerModelOutput, config_class=_CONFIG_FOR_DOC)
1177
+ def forward(
1178
+ self,
1179
+ input_ids: torch.LongTensor,
1180
+ attention_mask: Optional[torch.LongTensor] = None,
1181
+ spectrogram_labels: Optional[torch.FloatTensor] = None,
1182
+ duration_labels: Optional[torch.LongTensor] = None,
1183
+ pitch_labels: Optional[torch.FloatTensor] = None,
1184
+ energy_labels: Optional[torch.FloatTensor] = None,
1185
+ speaker_ids: Optional[torch.LongTensor] = None,
1186
+ lang_ids: Optional[torch.LongTensor] = None,
1187
+ speaker_embedding: Optional[torch.FloatTensor] = None,
1188
+ return_dict: Optional[bool] = None,
1189
+ output_attentions: Optional[bool] = None,
1190
+ output_hidden_states: Optional[bool] = None,
1191
+ ) -> Union[Tuple, FastSpeech2ConformerModelOutput]:
1192
+ """
1193
+ Args:
1194
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1195
+ Input sequence of text vectors.
1196
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to `None`):
1197
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
1198
+ `[0, 1]`: 0 for tokens that are **masked**, 1 for tokens that are **not masked**.
1199
+ spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
1200
+ Batch of padded target features.
1201
+ duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
1202
+ Batch of padded durations.
1203
+ pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
1204
+ Batch of padded token-averaged pitch.
1205
+ energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
1206
+ Batch of padded token-averaged energy.
1207
+ speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
1208
+ Speaker ids used to condition features of speech output by the model.
1209
+ lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
1210
+ Language ids used to condition features of speech output by the model.
1211
+ speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
1212
+ Embedding containing conditioning signals for the features of the speech.
1213
+ return_dict (`bool`, *optional*, defaults to `None`):
1214
+ Whether or not to return a [`FastSpeech2ConformerModelOutput`] instead of a plain tuple.
1215
+ output_attentions (`bool`, *optional*, defaults to `None`):
1216
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1217
+ returned tensors for more detail.
1218
+ output_hidden_states (`bool`, *optional*, defaults to `None`):
1219
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1220
+ for more detail.
1221
+
1222
+ Returns:
1223
+
1224
+ Example:
1225
+
1226
+ ```python
1227
+ >>> from transformers import (
1228
+ ... FastSpeech2ConformerTokenizer,
1229
+ ... FastSpeech2ConformerModel,
1230
+ ... FastSpeech2ConformerHifiGan,
1231
+ ... )
1232
+
1233
+ >>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
1234
+ >>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
1235
+ >>> input_ids = inputs["input_ids"]
1236
+
1237
+ >>> model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer")
1238
+ >>> output_dict = model(input_ids, return_dict=True)
1239
+ >>> spectrogram = output_dict["spectrogram"]
1240
+
1241
+ >>> vocoder = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan")
1242
+ >>> waveform = vocoder(spectrogram)
1243
+ >>> print(waveform.shape)
1244
+ torch.Size([1, 49664])
1245
+ ```
1246
+ """
1247
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1248
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1249
+ output_hidden_states = (
1250
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1251
+ )
1252
+
1253
+ if attention_mask is None:
1254
+ attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
1255
+
1256
+ has_missing_labels = (
1257
+ spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
1258
+ )
1259
+ if self.training and has_missing_labels:
1260
+ raise ValueError("All labels must be provided to run in training mode.")
1261
+
1262
+ # forward encoder
1263
+ text_masks = attention_mask.unsqueeze(-2)
1264
+
1265
+ encoder_outputs = self.encoder(
1266
+ input_ids,
1267
+ text_masks,
1268
+ output_hidden_states=output_hidden_states,
1269
+ output_attentions=output_attentions,
1270
+ return_dict=return_dict,
1271
+ )
1272
+ hidden_states = encoder_outputs[0]
1273
+
1274
+ # Integrate with language id, speaker id, and speaker embedding
1275
+ if self.multispeaker_model and speaker_ids is not None:
1276
+ speaker_id_embeddings = self.speaker_id_embedding(speaker_ids.view(-1))
1277
+ hidden_states = hidden_states + speaker_id_embeddings.unsqueeze(1)
1278
+
1279
+ if self.multilingual_model and lang_ids is not None:
1280
+ language_id_embeddings = self.language_id_embedding(lang_ids.view(-1))
1281
+ hidden_states = hidden_states + language_id_embeddings.unsqueeze(1)
1282
+
1283
+ if self.speaker_embed_dim is not None and speaker_embedding is not None:
1284
+ embeddings_expanded = (
1285
+ nn.functional.normalize(speaker_embedding).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
1286
+ )
1287
+ hidden_states = self.projection(torch.cat([hidden_states, embeddings_expanded], dim=-1))
1288
+
1289
+ # forward duration predictor and variance predictors
1290
+ duration_mask = ~attention_mask.bool()
1291
+
1292
+ if self.stop_gradient_from_pitch_predictor:
1293
+ pitch_predictions = self.pitch_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
1294
+ else:
1295
+ pitch_predictions = self.pitch_predictor(hidden_states, duration_mask.unsqueeze(-1))
1296
+
1297
+ if self.stop_gradient_from_energy_predictor:
1298
+ energy_predictions = self.energy_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
1299
+ else:
1300
+ energy_predictions = self.energy_predictor(hidden_states, duration_mask.unsqueeze(-1))
1301
+
1302
+ duration_predictions = self.duration_predictor(hidden_states)
1303
+ duration_predictions = duration_predictions.masked_fill(duration_mask, 0.0)
1304
+
1305
+ if not self.training:
1306
+ # use prediction in inference
1307
+ embedded_pitch_curve = self.pitch_embed(pitch_predictions)
1308
+ embedded_energy_curve = self.energy_embed(energy_predictions)
1309
+ hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
1310
+ hidden_states = length_regulator(hidden_states, duration_predictions, self.config.speaking_speed)
1311
+ else:
1312
+ # use groundtruth in training
1313
+ embedded_pitch_curve = self.pitch_embed(pitch_labels)
1314
+ embedded_energy_curve = self.energy_embed(energy_labels)
1315
+ hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
1316
+ hidden_states = length_regulator(hidden_states, duration_labels)
1317
+
1318
+ # forward decoder
1319
+ if not self.training:
1320
+ hidden_mask = None
1321
+ else:
1322
+ spectrogram_mask = (spectrogram_labels != -100).any(dim=-1)
1323
+ spectrogram_mask = spectrogram_mask.int()
1324
+ if self.reduction_factor > 1:
1325
+ length_dim = spectrogram_mask.shape[1] - spectrogram_mask.shape[1] % self.reduction_factor
1326
+ spectrogram_mask = spectrogram_mask[:, :length_dim]  # mask is (batch, time) at this point
1327
+ hidden_mask = spectrogram_mask.unsqueeze(-2)
1328
+
1329
+ decoder_outputs = self.decoder(
1330
+ hidden_states,
1331
+ hidden_mask,
1332
+ output_hidden_states=output_hidden_states,
1333
+ output_attentions=output_attentions,
1334
+ return_dict=return_dict,
1335
+ )
1336
+
1337
+ outputs_before_postnet, outputs_after_postnet = self.speech_decoder_postnet(decoder_outputs[0])
1338
+
1339
+ loss = None
1340
+ if self.training:
1341
+ # calculate loss
1342
+ loss_duration_mask = ~duration_mask
1343
+ loss_spectrogram_mask = spectrogram_mask.unsqueeze(-1).bool()
1344
+ loss = self.criterion(
1345
+ outputs_after_postnet=outputs_after_postnet,
1346
+ outputs_before_postnet=outputs_before_postnet,
1347
+ duration_outputs=duration_predictions,
1348
+ pitch_outputs=pitch_predictions,
1349
+ energy_outputs=energy_predictions,
1350
+ spectrogram_labels=spectrogram_labels,
1351
+ duration_labels=duration_labels,
1352
+ pitch_labels=pitch_labels,
1353
+ energy_labels=energy_labels,
1354
+ duration_mask=loss_duration_mask,
1355
+ spectrogram_mask=loss_spectrogram_mask,
1356
+ )
1357
+
1358
+ if not return_dict:
1359
+ postnet_outputs = (outputs_after_postnet,)
1360
+ audio_feature_predictions = (
1361
+ duration_predictions,
1362
+ pitch_predictions,
1363
+ energy_predictions,
1364
+ )
1365
+ outputs = postnet_outputs + encoder_outputs + decoder_outputs[1:] + audio_feature_predictions
1366
+ return ((loss,) + outputs) if loss is not None else outputs
1367
+
1368
+ return FastSpeech2ConformerModelOutput(
1369
+ loss=loss,
1370
+ spectrogram=outputs_after_postnet,
1371
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1372
+ encoder_hidden_states=encoder_outputs.hidden_states,
1373
+ encoder_attentions=encoder_outputs.attentions,
1374
+ decoder_hidden_states=decoder_outputs.hidden_states,
1375
+ decoder_attentions=decoder_outputs.attentions,
1376
+ duration_outputs=duration_predictions,
1377
+ pitch_outputs=pitch_predictions,
1378
+ energy_outputs=energy_predictions,
1379
+ )
1380
+
1381
+
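The `length_regulator` call used in the forward pass above (defined earlier in this file) upsamples token-level hidden states to frame level using the predicted or ground-truth durations. The sketch below only illustrates the idea with `repeat_interleave` on a single un-batched example; it is not the library implementation.

```python
import torch

# Each token's hidden vector is repeated `duration` times so the decoder
# operates at frame resolution.
hidden_states = torch.randn(3, 8)        # (tokens, hidden_size)
durations = torch.tensor([2, 1, 3])      # predicted frames per token
frame_level = torch.repeat_interleave(hidden_states, durations, dim=0)
print(frame_level.shape)                 # torch.Size([6, 8]) — sum of durations frames
```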
1382
+ # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
1383
+ class HifiGanResidualBlock(nn.Module):
1384
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
1385
+ super().__init__()
1386
+ self.leaky_relu_slope = leaky_relu_slope
1387
+
1388
+ self.convs1 = nn.ModuleList(
1389
+ [
1390
+ nn.Conv1d(
1391
+ channels,
1392
+ channels,
1393
+ kernel_size,
1394
+ stride=1,
1395
+ dilation=dilation[i],
1396
+ padding=self.get_padding(kernel_size, dilation[i]),
1397
+ )
1398
+ for i in range(len(dilation))
1399
+ ]
1400
+ )
1401
+ self.convs2 = nn.ModuleList(
1402
+ [
1403
+ nn.Conv1d(
1404
+ channels,
1405
+ channels,
1406
+ kernel_size,
1407
+ stride=1,
1408
+ dilation=1,
1409
+ padding=self.get_padding(kernel_size, 1),
1410
+ )
1411
+ for _ in range(len(dilation))
1412
+ ]
1413
+ )
1414
+
1415
+ def get_padding(self, kernel_size, dilation=1):
1416
+ return (kernel_size * dilation - dilation) // 2
1417
+
1418
+ def apply_weight_norm(self):
1419
+ weight_norm = nn.utils.weight_norm
1420
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
1421
+ weight_norm = nn.utils.parametrizations.weight_norm
1422
+
1423
+ for layer in self.convs1:
1424
+ weight_norm(layer)
1425
+ for layer in self.convs2:
1426
+ weight_norm(layer)
1427
+
1428
+ def remove_weight_norm(self):
1429
+ for layer in self.convs1:
1430
+ nn.utils.remove_weight_norm(layer)
1431
+ for layer in self.convs2:
1432
+ nn.utils.remove_weight_norm(layer)
1433
+
1434
+ def forward(self, hidden_states):
1435
+ for conv1, conv2 in zip(self.convs1, self.convs2):
1436
+ residual = hidden_states
1437
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
1438
+ hidden_states = conv1(hidden_states)
1439
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
1440
+ hidden_states = conv2(hidden_states)
1441
+ hidden_states = hidden_states + residual
1442
+ return hidden_states
1443
+
1444
+
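A quick numeric check of the `get_padding` rule in the residual block above: for odd kernel sizes it keeps the output length equal to the input length by splitting the dilated receptive field evenly on both sides. Plain Python, no tensors required.

```python
# padding = (kernel_size * dilation - dilation) // 2 gives "same" length for
# odd kernels: twice the padding equals the dilated receptive field minus one.
for kernel_size, dilation in [(3, 1), (3, 3), (5, 5), (11, 1)]:
    padding = (kernel_size * dilation - dilation) // 2
    receptive_field = dilation * (kernel_size - 1) + 1
    assert 2 * padding == receptive_field - 1
    print(kernel_size, dilation, padding)
```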
1445
+ @add_start_docstrings(
1446
+ """HiFi-GAN vocoder.""",
1447
+ HIFIGAN_START_DOCSTRING,
1448
+ )
1449
+ # Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5HifiGan with SpeechT5->FastSpeech2Conformer
1450
+ class FastSpeech2ConformerHifiGan(PreTrainedModel):
1451
+ config_class = FastSpeech2ConformerHifiGanConfig
1452
+ main_input_name = "spectrogram"
1453
+
1454
+ def __init__(self, config: FastSpeech2ConformerHifiGanConfig):
1455
+ super().__init__(config)
1456
+ self.num_kernels = len(config.resblock_kernel_sizes)
1457
+ self.num_upsamples = len(config.upsample_rates)
1458
+ self.conv_pre = nn.Conv1d(
1459
+ config.model_in_dim,
1460
+ config.upsample_initial_channel,
1461
+ kernel_size=7,
1462
+ stride=1,
1463
+ padding=3,
1464
+ )
1465
+
1466
+ self.upsampler = nn.ModuleList()
1467
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
1468
+ self.upsampler.append(
1469
+ nn.ConvTranspose1d(
1470
+ config.upsample_initial_channel // (2**i),
1471
+ config.upsample_initial_channel // (2 ** (i + 1)),
1472
+ kernel_size=kernel_size,
1473
+ stride=upsample_rate,
1474
+ padding=(kernel_size - upsample_rate) // 2,
1475
+ )
1476
+ )
1477
+
1478
+ self.resblocks = nn.ModuleList()
1479
+ for i in range(len(self.upsampler)):
1480
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
1481
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
1482
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
1483
+
1484
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
1485
+
1486
+ self.register_buffer("mean", torch.zeros(config.model_in_dim))
1487
+ self.register_buffer("scale", torch.ones(config.model_in_dim))
1488
+
1489
+ # Initialize weights and apply final processing
1490
+ self.post_init()
1491
+
1492
+ def _init_weights(self, module):
1493
+ """Initialize the weights."""
1494
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
1495
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1496
+ if module.bias is not None:
1497
+ module.bias.data.zero_()
1498
+
1499
+ def apply_weight_norm(self):
1500
+ weight_norm = nn.utils.weight_norm
1501
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
1502
+ weight_norm = nn.utils.parametrizations.weight_norm
1503
+
1504
+ weight_norm(self.conv_pre)
1505
+ for layer in self.upsampler:
1506
+ weight_norm(layer)
1507
+ for layer in self.resblocks:
1508
+ layer.apply_weight_norm()
1509
+ weight_norm(self.conv_post)
1510
+
1511
+ def remove_weight_norm(self):
1512
+ nn.utils.remove_weight_norm(self.conv_pre)
1513
+ for layer in self.upsampler:
1514
+ nn.utils.remove_weight_norm(layer)
1515
+ for layer in self.resblocks:
1516
+ layer.remove_weight_norm()
1517
+ nn.utils.remove_weight_norm(self.conv_post)
1518
+
1519
+ def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
1520
+ r"""
1521
+ Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
1522
+ of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
1523
+ waveform.
1524
+
1525
+ Args:
1526
+ spectrogram (`torch.FloatTensor`):
1527
+ Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
1528
+ config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
1529
+
1530
+ Returns:
1531
+ `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
1532
+ shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
1533
+ """
1534
+ if self.config.normalize_before:
1535
+ spectrogram = (spectrogram - self.mean) / self.scale
1536
+
1537
+ is_batched = spectrogram.dim() == 3
1538
+ if not is_batched:
1539
+ spectrogram = spectrogram.unsqueeze(0)
1540
+
1541
+ hidden_states = spectrogram.transpose(2, 1)
1542
+
1543
+ hidden_states = self.conv_pre(hidden_states)
1544
+ for i in range(self.num_upsamples):
1545
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
1546
+ hidden_states = self.upsampler[i](hidden_states)
1547
+
1548
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
1549
+ for j in range(1, self.num_kernels):
1550
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
1551
+ hidden_states = res_state / self.num_kernels
1552
+
1553
+ hidden_states = nn.functional.leaky_relu(hidden_states)
1554
+ hidden_states = self.conv_post(hidden_states)
1555
+ hidden_states = torch.tanh(hidden_states)
1556
+
1557
+ if not is_batched:
1558
+ # remove batch dim and collapse tensor to 1-d audio waveform
1559
+ waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
1560
+ else:
1561
+ # remove seq-len dim since this collapses to 1
1562
+ waveform = hidden_states.squeeze(1)
1563
+
1564
+ return waveform
1565
+
1566
+
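Each spectrogram frame is expanded into `prod(upsample_rates)` waveform samples by the stack of transposed convolutions above. The rates below are illustrative HiFi-GAN-style values, not numbers read from a checkpoint; the real ones live in `FastSpeech2ConformerHifiGanConfig`.

```python
from math import prod

upsample_rates = [8, 8, 2, 2]             # assumed example configuration
samples_per_frame = prod(upsample_rates)  # 256 audio samples per mel frame
num_frames = 194
print(num_frames * samples_per_frame)     # 49664, consistent with the docstring example
```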
1567
+ @add_start_docstrings(
1568
+ "The FastSpeech2ConformerModel with a FastSpeech2ConformerHifiGan vocoder head that performs text-to-speech (waveform).",
1569
+ FASTSPEECH2_CONFORMER_WITH_HIFIGAN_START_DOCSTRING,
1570
+ )
1571
+ class FastSpeech2ConformerWithHifiGan(PreTrainedModel):
1572
+ config_class = FastSpeech2ConformerWithHifiGanConfig
1573
+
1574
+ def __init__(self, config: FastSpeech2ConformerWithHifiGanConfig):
1575
+ super().__init__(config)
1576
+
1577
+ self.model = FastSpeech2ConformerModel(config.model_config)
1578
+ self.vocoder = FastSpeech2ConformerHifiGan(config.vocoder_config)
1579
+
1580
+ self.config = config
1581
+
1582
+ @replace_return_docstrings(
1583
+ output_type=FastSpeech2ConformerWithHifiGanOutput, config_class=FastSpeech2ConformerWithHifiGanConfig
1584
+ )
1585
+ def forward(
1586
+ self,
1587
+ input_ids: torch.LongTensor,
1588
+ attention_mask: Optional[torch.LongTensor] = None,
1589
+ spectrogram_labels: Optional[torch.FloatTensor] = None,
1590
+ duration_labels: Optional[torch.LongTensor] = None,
1591
+ pitch_labels: Optional[torch.FloatTensor] = None,
1592
+ energy_labels: Optional[torch.FloatTensor] = None,
1593
+ speaker_ids: Optional[torch.LongTensor] = None,
1594
+ lang_ids: Optional[torch.LongTensor] = None,
1595
+ speaker_embedding: Optional[torch.FloatTensor] = None,
1596
+ return_dict: Optional[bool] = None,
1597
+ output_attentions: Optional[bool] = None,
1598
+ output_hidden_states: Optional[bool] = None,
1599
+ ) -> Union[Tuple, FastSpeech2ConformerWithHifiGanOutput]:
1600
+ """
1601
+ Args:
1602
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1603
+ Input sequence of text vectors.
1604
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to `None`):
1605
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
1606
+ `[0, 1]`: 0 for tokens that are **masked**, 1 for tokens that are **not masked**.
1607
+ spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
1608
+ Batch of padded target features.
1609
+ duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
1610
+ Batch of padded durations.
1611
+ pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
1612
+ Batch of padded token-averaged pitch.
1613
+ energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
1614
+ Batch of padded token-averaged energy.
1615
+ speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
1616
+ Speaker ids used to condition features of speech output by the model.
1617
+ lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
1618
+ Language ids used to condition features of speech output by the model.
1619
+ speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
1620
+ Embedding containing conditioning signals for the features of the speech.
1621
+ return_dict (`bool`, *optional*, defaults to `None`):
1622
+ Whether or not to return a [`FastSpeech2ConformerModelOutput`] instead of a plain tuple.
1623
+ output_attentions (`bool`, *optional*, defaults to `None`):
1624
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1625
+ returned tensors for more detail.
1626
+ output_hidden_states (`bool`, *optional*, defaults to `None`):
1627
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1628
+ for more detail.
1629
+
1630
+ Returns:
1631
+
1632
+ Example:
1633
+
1634
+ ```python
1635
+ >>> from transformers import (
1636
+ ... FastSpeech2ConformerTokenizer,
1637
+ ... FastSpeech2ConformerWithHifiGan,
1638
+ ... )
1639
+
1640
+ >>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
1641
+ >>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
1642
+ >>> input_ids = inputs["input_ids"]
1643
+
1644
+ >>> model = FastSpeech2ConformerWithHifiGan.from_pretrained("espnet/fastspeech2_conformer_with_hifigan")
1645
+ >>> output_dict = model(input_ids, return_dict=True)
1646
+ >>> waveform = output_dict["waveform"]
1647
+ >>> print(waveform.shape)
1648
+ torch.Size([1, 49664])
1649
+ ```
1650
+ """
1651
+ return_dict = return_dict if return_dict is not None else self.config.model_config.use_return_dict
1652
+ output_attentions = (
1653
+ output_attentions if output_attentions is not None else self.config.model_config.output_attentions
1654
+ )
1655
+ output_hidden_states = (
1656
+ output_hidden_states if output_hidden_states is not None else self.config.model_config.output_hidden_states
1657
+ )
1658
+
1659
+ model_outputs = self.model(
1660
+ input_ids,
1661
+ attention_mask,
1662
+ spectrogram_labels=spectrogram_labels,
1663
+ duration_labels=duration_labels,
1664
+ pitch_labels=pitch_labels,
1665
+ energy_labels=energy_labels,
1666
+ speaker_ids=speaker_ids,
1667
+ lang_ids=lang_ids,
1668
+ speaker_embedding=speaker_embedding,
1669
+ return_dict=return_dict,
1670
+ output_attentions=output_attentions,
1671
+ output_hidden_states=output_hidden_states,
1672
+ )
1673
+
1674
+ if not return_dict:
1675
+ has_missing_labels = (
1676
+ spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
1677
+ )
1678
+ if has_missing_labels:
1679
+ spectrogram = model_outputs[0]
1680
+ else:
1681
+ spectrogram = model_outputs[1]
1682
+ else:
1683
+ spectrogram = model_outputs["spectrogram"]
1684
+ waveform = self.vocoder(spectrogram)
1685
+
1686
+ if not return_dict:
1687
+ return model_outputs + (waveform,)
1688
+
1689
+ return FastSpeech2ConformerWithHifiGanOutput(waveform=waveform, **model_outputs)
1690
+
1691
+
1692
+ __all__ = [
1693
+ "FastSpeech2ConformerWithHifiGan",
1694
+ "FastSpeech2ConformerHifiGan",
1695
+ "FastSpeech2ConformerModel",
1696
+ "FastSpeech2ConformerPreTrainedModel",
1697
+ ]
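An end-to-end sketch combining the classes exported above with the checkpoints already referenced in the docstrings. It assumes the optional `g2p-en` and `soundfile` dependencies are installed; the 22.05 kHz sampling rate is an assumption to verify against the model card.

```python
import soundfile as sf
from transformers import FastSpeech2ConformerTokenizer, FastSpeech2ConformerWithHifiGan

tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
model = FastSpeech2ConformerWithHifiGan.from_pretrained("espnet/fastspeech2_conformer_with_hifigan")

inputs = tokenizer("Hello, my dog is cute.", return_tensors="pt")
outputs = model(inputs["input_ids"], return_dict=True)

# Assumed sampling rate of the released checkpoint — check the model card.
sf.write("speech.wav", outputs.waveform.squeeze().detach().numpy(), samplerate=22050)
```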
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py ADDED
@@ -0,0 +1,188 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FastSpeech2Conformer."""
16
+
17
+ import json
18
+ import os
19
+ from typing import Optional, Tuple
20
+
21
+ import regex
22
+
23
+ from ...tokenization_utils import PreTrainedTokenizer
24
+ from ...utils import logging, requires_backends
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
30
+
31
+
32
+ class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
33
+ """
34
+ Construct a FastSpeech2Conformer tokenizer.
35
+
36
+ Args:
37
+ vocab_file (`str`):
38
+ Path to the vocabulary file.
39
+ bos_token (`str`, *optional*, defaults to `"<sos/eos>"`):
40
+ The begin of sequence token. Note that for FastSpeech2, it is the same as the `eos_token`.
41
+ eos_token (`str`, *optional*, defaults to `"<sos/eos>"`):
42
+ The end of sequence token. Note that for FastSpeech2, it is the same as the `bos_token`.
43
+ pad_token (`str`, *optional*, defaults to `"<blank>"`):
44
+ The token used for padding, for example when batching sequences of different lengths.
45
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
46
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
47
+ token instead.
48
+ should_strip_spaces (`bool`, *optional*, defaults to `False`):
49
+ Whether or not to strip the spaces from the list of tokens.
50
+ """
51
+
52
+ vocab_files_names = VOCAB_FILES_NAMES
53
+ model_input_names = ["input_ids", "attention_mask"]
54
+
55
+ def __init__(
56
+ self,
57
+ vocab_file,
58
+ bos_token="<sos/eos>",
59
+ eos_token="<sos/eos>",
60
+ pad_token="<blank>",
61
+ unk_token="<unk>",
62
+ should_strip_spaces=False,
63
+ **kwargs,
64
+ ):
65
+ requires_backends(self, "g2p_en")
66
+
67
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
68
+ self.encoder = json.load(vocab_handle)
69
+
70
+ import g2p_en
71
+
72
+ self.g2p = g2p_en.G2p()
73
+
74
+ self.decoder = {v: k for k, v in self.encoder.items()}
75
+
76
+ super().__init__(
77
+ bos_token=bos_token,
78
+ eos_token=eos_token,
79
+ unk_token=unk_token,
80
+ pad_token=pad_token,
81
+ should_strip_spaces=should_strip_spaces,
82
+ **kwargs,
83
+ )
84
+
85
+ self.should_strip_spaces = should_strip_spaces
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ return len(self.decoder)
90
+
91
+ def get_vocab(self):
92
+ "Returns vocab as a dict"
93
+ return dict(self.encoder, **self.added_tokens_encoder)
94
+
95
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
96
+ # expand symbols
97
+ text = regex.sub(";", ",", text)
98
+ text = regex.sub(":", ",", text)
99
+ text = regex.sub("-", " ", text)
100
+ text = regex.sub("&", "and", text)
101
+
102
+ # strip unnecessary symbols
103
+ text = regex.sub(r"[\(\)\[\]\<\>\"]+", "", text)
104
+
105
+ # strip whitespaces
106
+ text = regex.sub(r"\s+", " ", text)
107
+
108
+ text = text.upper()
109
+
110
+ return text, kwargs
111
+
112
+ def _tokenize(self, text):
113
+ """Returns a tokenized string."""
114
+ # phonemize
115
+ tokens = self.g2p(text)
116
+
117
+ if self.should_strip_spaces:
118
+ tokens = list(filter(lambda s: s != " ", tokens))
119
+
120
+ tokens.append(self.eos_token)
121
+
122
+ return tokens
123
+
124
+ def _convert_token_to_id(self, token):
125
+ """Converts a token (str) in an id using the vocab."""
126
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
127
+
128
+ def _convert_id_to_token(self, index):
129
+ """Converts an index (integer) in a token (str) using the vocab."""
130
+ return self.decoder.get(index, self.unk_token)
131
+
132
+ # Override since phonemes cannot be converted back to strings
133
+ def decode(self, token_ids, **kwargs):
134
+ logger.warning(
135
+ "Phonemes cannot be reliably converted to a string due to the one-to-many mapping, converting to tokens instead."
136
+ )
137
+ return self.convert_ids_to_tokens(token_ids)
138
+
139
+ # Override since phonemes cannot be converted back to strings
140
+ def convert_tokens_to_string(self, tokens, **kwargs):
141
+ logger.warning(
142
+ "Phonemes cannot be reliably converted to a string due to the one-to-many mapping, returning the tokens."
143
+ )
144
+ return tokens
145
+
146
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ with open(vocab_file, "w", encoding="utf-8") as f:
165
+ f.write(json.dumps(self.get_vocab(), ensure_ascii=False))
166
+
167
+ return (vocab_file,)
168
+
169
+ def __getstate__(self):
170
+ state = self.__dict__.copy()
171
+ state["g2p"] = None
172
+ return state
173
+
174
+ def __setstate__(self, d):
175
+ self.__dict__ = d
176
+
177
+ try:
178
+ import g2p_en
179
+
180
+ self.g2p = g2p_en.G2p()
181
+ except ImportError:
182
+ raise ImportError(
183
+ "You need to install g2p-en to use FastSpeech2ConformerTokenizer. "
184
+ "See https://pypi.org/project/g2p-en/ for installation."
185
+ )
186
+
187
+
188
+ __all__ = ["FastSpeech2ConformerTokenizer"]
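A short illustration of the phoneme-level behaviour documented above: `decode` intentionally returns phoneme tokens rather than reconstructing the input string. It requires the optional `g2p-en` dependency, and the printed tokens are indicative only.

```python
from transformers import FastSpeech2ConformerTokenizer

tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
ids = tokenizer("Hello world")["input_ids"]

# Returns phoneme tokens such as ['HH', 'AH0', 'L', 'OW1', ...] plus <sos/eos>,
# not the original text.
print(tokenizer.decode(ids))
```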
docs/transformers/build/lib/transformers/models/flaubert/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_flaubert import *
22
+ from .modeling_flaubert import *
23
+ from .modeling_tf_flaubert import *
24
+ from .tokenization_flaubert import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/flaubert/configuration_flaubert.py ADDED
@@ -0,0 +1,235 @@
1
+ # coding=utf-8
2
+ # Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Flaubert configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class FlaubertConfig(PretrainedConfig):
29
+ """
30
+ This is the configuration class to store the configuration of a [`FlaubertModel`] or a [`TFFlaubertModel`]. It is
31
+ used to instantiate a FlauBERT model according to the specified arguments, defining the model architecture.
32
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FlauBERT
33
+ [flaubert/flaubert_base_uncased](https://huggingface.co/flaubert/flaubert_base_uncased) architecture.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+ Args:
39
+ pre_norm (`bool`, *optional*, defaults to `False`):
40
+ Whether to apply the layer normalization before or after the feed forward layer following the attention in
41
+ each layer (Vaswani et al., Tensor2Tensor for Neural Machine Translation. 2018)
42
+ layerdrop (`float`, *optional*, defaults to 0.0):
43
+ Probability to drop layers during training (Fan et al., Reducing Transformer Depth on Demand with
44
+ Structured Dropout. ICLR 2020)
45
+ vocab_size (`int`, *optional*, defaults to 30145):
46
+ Vocabulary size of the FlauBERT model. Defines the number of different tokens that can be represented by
47
+ the `inputs_ids` passed when calling [`FlaubertModel`] or [`TFFlaubertModel`].
48
+ emb_dim (`int`, *optional*, defaults to 2048):
49
+ Dimensionality of the encoder layers and the pooler layer.
50
+ n_layer (`int`, *optional*, defaults to 12):
51
+ Number of hidden layers in the Transformer encoder.
52
+ n_head (`int`, *optional*, defaults to 16):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ dropout (`float`, *optional*, defaults to 0.1):
55
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
56
+ attention_dropout (`float`, *optional*, defaults to 0.1):
57
+ The dropout probability for the attention mechanism
58
+ gelu_activation (`bool`, *optional*, defaults to `True`):
59
+ Whether or not to use a *gelu* activation instead of *relu*.
60
+ sinusoidal_embeddings (`bool`, *optional*, defaults to `False`):
61
+ Whether or not to use sinusoidal positional embeddings instead of absolute positional embeddings.
62
+ causal (`bool`, *optional*, defaults to `False`):
63
+ Whether or not the model should behave in a causal manner. Causal models use a triangular attention mask in
64
+ order to only attend to the left-side context instead if a bidirectional context.
65
+ asm (`bool`, *optional*, defaults to `False`):
66
+ Whether or not to use an adaptive log softmax projection layer instead of a linear layer for the prediction
67
+ layer.
68
+ n_langs (`int`, *optional*, defaults to 1):
69
+ The number of languages the model handles. Set to 1 for monolingual models.
70
+ use_lang_emb (`bool`, *optional*, defaults to `True`)
71
+ Whether to use language embeddings. Some models use additional language embeddings, see [the multilingual
72
+ models page](http://huggingface.co/transformers/multilingual.html#xlm-language-embeddings) for information
73
+ on how to use them.
74
+ max_position_embeddings (`int`, *optional*, defaults to 512):
75
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
76
+ just in case (e.g., 512 or 1024 or 2048).
77
+ embed_init_std (`float`, *optional*, defaults to 2048^-0.5):
78
+ The standard deviation of the truncated_normal_initializer for initializing the embedding matrices.
79
+ init_std (`int`, *optional*, defaults to 50257):
80
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices except the
81
+ embedding matrices.
82
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
83
+ The epsilon used by the layer normalization layers.
84
+ bos_index (`int`, *optional*, defaults to 0):
85
+ The index of the beginning of sentence token in the vocabulary.
86
+ eos_index (`int`, *optional*, defaults to 1):
87
+ The index of the end of sentence token in the vocabulary.
88
+ pad_index (`int`, *optional*, defaults to 2):
89
+ The index of the padding token in the vocabulary.
90
+ unk_index (`int`, *optional*, defaults to 3):
91
+ The index of the unknown token in the vocabulary.
92
+ mask_index (`int`, *optional*, defaults to 5):
93
+ The index of the masking token in the vocabulary.
94
+ is_encoder(`bool`, *optional*, defaults to `True`):
95
+ Whether or not the initialized model should be a transformer encoder or decoder as seen in Vaswani et al.
96
+ summary_type (`string`, *optional*, defaults to "first"):
97
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
98
+
99
+ Has to be one of the following options:
100
+
101
+ - `"last"`: Take the last token hidden state (like XLNet).
102
+ - `"first"`: Take the first token hidden state (like BERT).
103
+ - `"mean"`: Take the mean of all tokens hidden states.
104
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
105
+ - `"attn"`: Not implemented now, use multi-head attention.
106
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
107
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
108
+
109
+ Whether or not to add a projection after the vector extraction.
110
+ summary_activation (`str`, *optional*):
111
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
112
+
113
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
114
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
115
+ Used in the sequence classification and multiple choice models.
116
+
117
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
118
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
119
+ Used in the sequence classification and multiple choice models.
120
+
121
+ The dropout ratio to be used after the projection and activation.
122
+ start_n_top (`int`, *optional*, defaults to 5):
123
+ Used in the SQuAD evaluation script.
124
+ end_n_top (`int`, *optional*, defaults to 5):
125
+ Used in the SQuAD evaluation script.
126
+ mask_token_id (`int`, *optional*, defaults to 0):
127
+ Model agnostic parameter to identify masked tokens when generating text in an MLM context.
128
+ lang_id (`int`, *optional*, defaults to 1):
129
+ The ID of the language used by the model. This parameter is used when generating text in a given language.
130
+ """
131
+
132
+ model_type = "flaubert"
133
+ attribute_map = {
134
+ "hidden_size": "emb_dim",
135
+ "num_attention_heads": "n_heads",
136
+ "num_hidden_layers": "n_layers",
137
+ "n_words": "vocab_size", # For backward compatibility
138
+ }
139
+
140
+ def __init__(
141
+ self,
142
+ pre_norm=False,
143
+ layerdrop=0.0,
144
+ vocab_size=30145,
145
+ emb_dim=2048,
146
+ n_layers=12,
147
+ n_heads=16,
148
+ dropout=0.1,
149
+ attention_dropout=0.1,
150
+ gelu_activation=True,
151
+ sinusoidal_embeddings=False,
152
+ causal=False,
153
+ asm=False,
154
+ n_langs=1,
155
+ use_lang_emb=True,
156
+ max_position_embeddings=512,
157
+ embed_init_std=2048**-0.5,
158
+ layer_norm_eps=1e-12,
159
+ init_std=0.02,
160
+ bos_index=0,
161
+ eos_index=1,
162
+ pad_index=2,
163
+ unk_index=3,
164
+ mask_index=5,
165
+ is_encoder=True,
166
+ summary_type="first",
167
+ summary_use_proj=True,
168
+ summary_activation=None,
169
+ summary_proj_to_labels=True,
170
+ summary_first_dropout=0.1,
171
+ start_n_top=5,
172
+ end_n_top=5,
173
+ mask_token_id=0,
174
+ lang_id=0,
175
+ pad_token_id=2,
176
+ bos_token_id=0,
177
+ **kwargs,
178
+ ):
179
+ """Constructs FlaubertConfig."""
180
+ self.pre_norm = pre_norm
181
+ self.layerdrop = layerdrop
182
+ self.vocab_size = vocab_size
183
+ self.emb_dim = emb_dim
184
+ self.n_layers = n_layers
185
+ self.n_heads = n_heads
186
+ self.dropout = dropout
187
+ self.attention_dropout = attention_dropout
188
+ self.gelu_activation = gelu_activation
189
+ self.sinusoidal_embeddings = sinusoidal_embeddings
190
+ self.causal = causal
191
+ self.asm = asm
192
+ self.n_langs = n_langs
193
+ self.use_lang_emb = use_lang_emb
194
+ self.layer_norm_eps = layer_norm_eps
195
+ self.bos_index = bos_index
196
+ self.eos_index = eos_index
197
+ self.pad_index = pad_index
198
+ self.unk_index = unk_index
199
+ self.mask_index = mask_index
200
+ self.is_encoder = is_encoder
201
+ self.max_position_embeddings = max_position_embeddings
202
+ self.embed_init_std = embed_init_std
203
+ self.init_std = init_std
204
+ self.summary_type = summary_type
205
+ self.summary_use_proj = summary_use_proj
206
+ self.summary_activation = summary_activation
207
+ self.summary_proj_to_labels = summary_proj_to_labels
208
+ self.summary_first_dropout = summary_first_dropout
209
+ self.start_n_top = start_n_top
210
+ self.end_n_top = end_n_top
211
+ self.mask_token_id = mask_token_id
212
+ self.lang_id = lang_id
213
+
214
+ if "n_words" in kwargs:
215
+ self.n_words = kwargs["n_words"]
216
+
217
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
218
+
219
+
220
+ class FlaubertOnnxConfig(OnnxConfig):
221
+ @property
222
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
223
+ if self.task == "multiple-choice":
224
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
225
+ else:
226
+ dynamic_axis = {0: "batch", 1: "sequence"}
227
+ return OrderedDict(
228
+ [
229
+ ("input_ids", dynamic_axis),
230
+ ("attention_mask", dynamic_axis),
231
+ ]
232
+ )
233
+
234
+
235
+ __all__ = ["FlaubertConfig", "FlaubertOnnxConfig"]
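A minimal sketch of instantiating a randomly initialised FlauBERT model from this configuration class (no pretrained weights are downloaded); the chosen sizes are arbitrary.

```python
from transformers import FlaubertConfig, FlaubertModel

config = FlaubertConfig(emb_dim=512, n_layers=6, n_heads=8)
model = FlaubertModel(config)

# hidden_size is an alias for emb_dim via the attribute_map defined above.
print(model.config.hidden_size)  # 512
```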
docs/transformers/build/lib/transformers/models/flaubert/modeling_flaubert.py ADDED
@@ -0,0 +1,1739 @@
1
+ # coding=utf-8
2
+ # Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Flaubert model, based on XLM."""
16
+
17
+ import itertools
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Callable, Dict, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from ...activations import gelu, get_activation
28
+ from ...generation import GenerationMixin
29
+ from ...modeling_outputs import (
30
+ BaseModelOutput,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ QuestionAnsweringModelOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
39
+ from ...utils import (
40
+ ModelOutput,
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_flaubert import FlaubertConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "flaubert/flaubert_base_cased"
53
+ _CONFIG_FOR_DOC = "FlaubertConfig"
54
+
55
+
56
+ # Copied from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings
57
+ def create_sinusoidal_embeddings(n_pos, dim, out):
58
+ position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
59
+ out.requires_grad = False
60
+ out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
61
+ out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
62
+ out.detach_()
63
+
64
+
65
+ # Copied from transformers.models.xlm.modeling_xlm.get_masks
66
+ def get_masks(slen, lengths, causal, padding_mask=None):
67
+ """
68
+ Generate hidden states mask, and optionally an attention mask.
69
+ """
70
+ alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
71
+ if padding_mask is not None:
72
+ mask = padding_mask
73
+ else:
74
+ assert lengths.max().item() <= slen
75
+ mask = alen < lengths[:, None]
76
+
77
+ # attention mask is the same as mask, or triangular inferior attention (causal)
78
+ bs = lengths.size(0)
79
+ if causal:
80
+ attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
81
+ else:
82
+ attn_mask = mask
83
+
84
+ # sanity check
85
+ assert mask.size() == (bs, slen)
86
+ assert causal is False or attn_mask.size() == (bs, slen, slen)
87
+
88
+ return mask, attn_mask
89
+
90
+
91
+ # Copied from transformers.models.xlm.modeling_xlm.MultiHeadAttention
92
+ class MultiHeadAttention(nn.Module):
93
+ NEW_ID = itertools.count()
94
+
95
+ def __init__(self, n_heads, dim, config):
96
+ super().__init__()
97
+ self.layer_id = next(MultiHeadAttention.NEW_ID)
98
+ self.dim = dim
99
+ self.n_heads = n_heads
100
+ self.dropout = config.attention_dropout
101
+ assert self.dim % self.n_heads == 0
102
+
103
+ self.q_lin = nn.Linear(dim, dim)
104
+ self.k_lin = nn.Linear(dim, dim)
105
+ self.v_lin = nn.Linear(dim, dim)
106
+ self.out_lin = nn.Linear(dim, dim)
107
+ self.pruned_heads = set()
108
+
109
+ def prune_heads(self, heads):
110
+ attention_head_size = self.dim // self.n_heads
111
+ if len(heads) == 0:
112
+ return
113
+ heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
114
+ # Prune linear layers
115
+ self.q_lin = prune_linear_layer(self.q_lin, index)
116
+ self.k_lin = prune_linear_layer(self.k_lin, index)
117
+ self.v_lin = prune_linear_layer(self.v_lin, index)
118
+ self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
119
+ # Update hyper params
120
+ self.n_heads = self.n_heads - len(heads)
121
+ self.dim = attention_head_size * self.n_heads
122
+ self.pruned_heads = self.pruned_heads.union(heads)
123
+
124
+ def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False):
125
+ """
126
+ Self-attention (if kv is None) or attention over source sentence (provided by kv).
127
+ """
128
+ # Input is (bs, qlen, dim)
129
+ # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
130
+ bs, qlen, dim = input.size()
131
+ if kv is None:
132
+ klen = qlen if cache is None else cache["slen"] + qlen
133
+ else:
134
+ klen = kv.size(1)
135
+ # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
136
+ n_heads = self.n_heads
137
+ dim_per_head = self.dim // n_heads
138
+ mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
139
+
140
+ def shape(x):
141
+ """projection"""
142
+ return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
143
+
144
+ def unshape(x):
145
+ """compute context"""
146
+ return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
147
+
148
+ q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
149
+ if kv is None:
150
+ k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
151
+ v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
152
+ elif cache is None or self.layer_id not in cache:
153
+ k = v = kv
154
+ k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
155
+ v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
156
+
157
+ if cache is not None:
158
+ if self.layer_id in cache:
159
+ if kv is None:
160
+ k_, v_ = cache[self.layer_id]
161
+ k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
162
+ v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
163
+ else:
164
+ k, v = cache[self.layer_id]
165
+ cache[self.layer_id] = (k, v)
166
+
167
+ q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
168
+ scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
169
+ mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
170
+ scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
171
+
172
+ weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
173
+ weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
174
+
175
+ # Mask heads if we want to
176
+ if head_mask is not None:
177
+ weights = weights * head_mask
178
+
179
+ context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
180
+ context = unshape(context) # (bs, qlen, dim)
181
+
182
+ outputs = (self.out_lin(context),)
183
+ if output_attentions:
184
+ outputs = outputs + (weights,)
185
+ return outputs
186
+
187
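+ def _example_multi_head_attention():
+     # Illustrative sketch, not part of the original file: exercising the attention block
+     # above on dummy tensors. SimpleNamespace stands in for a FlaubertConfig; only
+     # `attention_dropout` is read by MultiHeadAttention.
+     from types import SimpleNamespace
+     attn = MultiHeadAttention(n_heads=4, dim=32, config=SimpleNamespace(attention_dropout=0.0))
+     x = torch.randn(2, 5, 32)                  # (bs, qlen, dim)
+     mask = torch.ones(2, 5, dtype=torch.long)  # (bs, klen), 1 = attend, 0 = ignore
+     (context,) = attn(x, mask)
+     return context.shape                       # expected: torch.Size([2, 5, 32])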
+
188
+ # Copied from transformers.models.xlm.modeling_xlm.TransformerFFN
189
+ class TransformerFFN(nn.Module):
190
+ def __init__(self, in_dim, dim_hidden, out_dim, config):
191
+ super().__init__()
192
+ self.dropout = config.dropout
193
+ self.lin1 = nn.Linear(in_dim, dim_hidden)
194
+ self.lin2 = nn.Linear(dim_hidden, out_dim)
195
+ self.act = gelu if config.gelu_activation else nn.functional.relu
196
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
197
+ self.seq_len_dim = 1
198
+
199
+ def forward(self, input):
200
+ return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
201
+
202
+ def ff_chunk(self, input):
203
+ x = self.lin1(input)
204
+ x = self.act(x)
205
+ x = self.lin2(x)
206
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
207
+ return x
208
+
209
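+ def _example_transformer_ffn():
+     # Illustrative sketch, not part of the original file: the FFN above runs through
+     # `apply_chunking_to_forward`, so with a chunk size set the sequence dimension is
+     # processed in slices while the result matches a single full pass.
+     from types import SimpleNamespace
+     cfg = SimpleNamespace(dropout=0.0, gelu_activation=True, chunk_size_feed_forward=2)
+     ffn = TransformerFFN(in_dim=32, dim_hidden=64, out_dim=32, config=cfg).eval()
+     x = torch.randn(1, 6, 32)
+     return torch.allclose(ffn(x), ffn.ff_chunk(x))  # expected: True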
+
210
+ FLAUBERT_START_DOCSTRING = r"""
211
+
212
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
213
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
214
+ etc.)
215
+
216
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
217
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
218
+ and behavior.
219
+
220
+ Parameters:
221
+ config ([`FlaubertConfig`]): Model configuration class with all the parameters of the model.
222
+ Initializing with a config file does not load the weights associated with the model, only the
223
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
224
+ """
225
+
226
+ FLAUBERT_INPUTS_DOCSTRING = r"""
227
+ Args:
228
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
229
+ Indices of input sequence tokens in the vocabulary.
230
+
231
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
232
+ [`PreTrainedTokenizer.__call__`] for details.
233
+
234
+ [What are input IDs?](../glossary#input-ids)
235
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
236
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
237
+
238
+ - 1 for tokens that are **not masked**,
239
+ - 0 for tokens that are **masked**.
240
+
241
+ [What are attention masks?](../glossary#attention-mask)
242
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
243
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
244
+ 1]`:
245
+
246
+ - 0 corresponds to a *sentence A* token,
247
+ - 1 corresponds to a *sentence B* token.
248
+
249
+ [What are token type IDs?](../glossary#token-type-ids)
250
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
251
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
252
+ config.max_position_embeddings - 1]`.
253
+
254
+ [What are position IDs?](../glossary#position-ids)
255
+ lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
256
+ Length of each sentence that can be used to avoid performing attention on padding token indices. You can
257
+ also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in
258
+ `[0, ..., input_ids.size(-1)]`.
259
+ cache (`Dict[str, torch.FloatTensor]`, *optional*):
260
+ Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the
261
+ attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
262
+ decoding. The dictionary object will be modified in-place during the forward pass to add newly computed
263
+ hidden-states.
264
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
265
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
266
+
267
+ - 1 indicates the head is **not masked**,
268
+ - 0 indicates the head is **masked**.
269
+
270
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
271
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
272
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
273
+ model's internal embedding lookup matrix.
274
+ output_attentions (`bool`, *optional*):
275
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
276
+ tensors for more detail.
277
+ output_hidden_states (`bool`, *optional*):
278
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
279
+ more detail.
280
+ return_dict (`bool`, *optional*):
281
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
282
+ """
283
+
284
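+ def _example_inputs():
+     # Illustrative sketch, not part of the original file: the relationship between
+     # `attention_mask` and `lengths` described in the docstring above; either one can be
+     # used to tell the model where padding starts (0 plays the pad index here).
+     input_ids = torch.tensor([[7, 6, 5, 0, 0], [1, 3, 4, 6, 0]])
+     attention_mask = (input_ids != 0).long()
+     lengths = attention_mask.sum(dim=1)
+     return attention_mask, lengths  # lengths expected: tensor([3, 4])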
+
285
289
+ # Copied from transformers.models.xlm.modeling_xlm.XLMPredLayer with XLM->Flaubert
290
+ class FlaubertPredLayer(nn.Module):
291
+ """
292
+ Prediction layer (cross_entropy or adaptive_softmax).
293
+ """
294
+
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ self.asm = config.asm
298
+ self.n_words = config.n_words
299
+ self.pad_index = config.pad_index
300
+ dim = config.emb_dim
301
+
302
+ if config.asm is False:
303
+ self.proj = nn.Linear(dim, config.n_words, bias=True)
304
+ else:
305
+ self.proj = nn.AdaptiveLogSoftmaxWithLoss(
306
+ in_features=dim,
307
+ n_classes=config.n_words,
308
+ cutoffs=config.asm_cutoffs,
309
+ div_value=config.asm_div_value,
310
+ head_bias=True, # default is False
311
+ )
312
+
313
+ def forward(self, x, y=None):
314
+ """Compute the loss, and optionally the scores."""
315
+ outputs = ()
316
+ if self.asm is False:
317
+ scores = self.proj(x)
318
+ outputs = (scores,) + outputs
319
+ if y is not None:
320
+ loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
321
+ outputs = (loss,) + outputs
322
+ else:
323
+ scores = self.proj.log_prob(x)
324
+ outputs = (scores,) + outputs
325
+ if y is not None:
326
+ _, loss = self.proj(x, y)
327
+ outputs = (loss,) + outputs
328
+
329
+ return outputs
330
+
331
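+ def _example_pred_layer():
+     # Illustrative sketch, not part of the original file: with the adaptive softmax
+     # (`asm`) path disabled, the layer above returns `(scores,)` without labels and
+     # `(loss, scores)` with labels. SimpleNamespace stands in for a FlaubertConfig.
+     from types import SimpleNamespace
+     cfg = SimpleNamespace(asm=False, n_words=11, pad_index=0, emb_dim=8)
+     layer = FlaubertPredLayer(cfg)
+     hidden = torch.randn(2, 3, 8)
+     (scores,) = layer(hidden)
+     loss, scores = layer(hidden, torch.randint(0, 11, (2, 3)))
+     return scores.shape, loss.item()  # scores expected: torch.Size([2, 3, 11])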
+
332
+ @dataclass
333
+ # Copied from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert
334
+ class FlaubertSquadHeadOutput(ModelOutput):
335
+ """
336
+ Base class for outputs of question answering models using a [`FlaubertSQuADHead`].
337
+
338
+ Args:
339
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
340
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification
341
+ losses.
342
+ start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
343
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
344
+ start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
345
+ Indices for the top config.start_n_top start token possibilities (beam-search).
346
+ end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
347
+ Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
348
+ (beam-search).
349
+ end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
350
+ Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
351
+ cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
352
+ Log probabilities for the `is_impossible` label of the answers.
353
+
354
+ """
355
+
356
+ loss: Optional[torch.FloatTensor] = None
357
+ start_top_log_probs: Optional[torch.FloatTensor] = None
358
+ start_top_index: Optional[torch.LongTensor] = None
359
+ end_top_log_probs: Optional[torch.FloatTensor] = None
360
+ end_top_index: Optional[torch.LongTensor] = None
361
+ cls_logits: Optional[torch.FloatTensor] = None
362
+
363
+
364
+ # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert
365
+ class FlaubertPoolerStartLogits(nn.Module):
366
+ """
367
+ Compute SQuAD start logits from sequence hidden states.
368
+
369
+ Args:
370
+ config ([`FlaubertConfig`]):
371
+ The config used by the model, will be used to grab the `hidden_size` of the model.
372
+ """
373
+
374
+ def __init__(self, config: FlaubertConfig):
375
+ super().__init__()
376
+ self.dense = nn.Linear(config.hidden_size, 1)
377
+
378
+ def forward(
379
+ self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
380
+ ) -> torch.FloatTensor:
381
+ """
382
+ Args:
383
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
384
+ The final hidden states of the model.
385
+ p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
386
+ Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS). 1.0 means the token
387
+ should be masked.
388
+
389
+ Returns:
390
+ `torch.FloatTensor`: The start logits for SQuAD.
391
+ """
392
+ x = self.dense(hidden_states).squeeze(-1)
393
+
394
+ if p_mask is not None:
395
+ if p_mask.dtype == torch.float16:
396
+ x = x * (1 - p_mask) - 65500 * p_mask
397
+ else:
398
+ x = x * (1 - p_mask) - 1e30 * p_mask
399
+
400
+ return x
401
+
402
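+ def _example_pooler_start_logits():
+     # Illustrative sketch, not part of the original file: masked positions receive a very
+     # large negative logit (about -1e30, or -65500 in float16) so they vanish after the
+     # softmax. SimpleNamespace stands in for a FlaubertConfig.
+     from types import SimpleNamespace
+     pooler = FlaubertPoolerStartLogits(SimpleNamespace(hidden_size=8))
+     hidden = torch.randn(2, 4, 8)
+     p_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0]])
+     logits = pooler(hidden, p_mask=p_mask)
+     return bool(logits[0, 2] < -1e29)  # expected: True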
+
403
+ # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert
404
+ class FlaubertPoolerEndLogits(nn.Module):
405
+ """
406
+ Compute SQuAD end logits from sequence hidden states.
407
+
408
+ Args:
409
+ config ([`FlaubertConfig`]):
410
+ The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
411
+ to use.
412
+ """
413
+
414
+ def __init__(self, config: FlaubertConfig):
415
+ super().__init__()
416
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
417
+ self.activation = nn.Tanh()
418
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
419
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.FloatTensor,
424
+ start_states: Optional[torch.FloatTensor] = None,
425
+ start_positions: Optional[torch.LongTensor] = None,
426
+ p_mask: Optional[torch.FloatTensor] = None,
427
+ ) -> torch.FloatTensor:
428
+ """
429
+ Args:
430
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
431
+ The final hidden states of the model.
432
+ start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
433
+ The hidden states of the first tokens for the labeled span.
434
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
435
+ The position of the first token for the labeled span.
436
+ p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
437
+ Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
438
+ should be masked.
439
+
440
+ <Tip>
441
+
442
+ One of `start_states` or `start_positions` should not be `None`. If both are set, `start_positions` overrides
443
+ `start_states`.
444
+
445
+ </Tip>
446
+
447
+ Returns:
448
+ `torch.FloatTensor`: The end logits for SQuAD.
449
+ """
450
+ assert start_states is not None or start_positions is not None, (
451
+ "One of start_states, start_positions should be not None"
452
+ )
453
+ if start_positions is not None:
454
+ slen, hsz = hidden_states.shape[-2:]
455
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
456
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
457
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
458
+
459
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
460
+ x = self.activation(x)
461
+ x = self.LayerNorm(x)
462
+ x = self.dense_1(x).squeeze(-1)
463
+
464
+ if p_mask is not None:
465
+ if p_mask.dtype == torch.float16:
466
+ x = x * (1 - p_mask) - 65500 * p_mask
467
+ else:
468
+ x = x * (1 - p_mask) - 1e30 * p_mask
469
+
470
+ return x
471
+
472
+
473
+ # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert
474
+ class FlaubertPoolerAnswerClass(nn.Module):
475
+ """
476
+ Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
477
+
478
+ Args:
479
+ config ([`FlaubertConfig`]):
480
+ The config used by the model, will be used to grab the `hidden_size` of the model.
481
+ """
482
+
483
+ def __init__(self, config: FlaubertConfig):
484
+ super().__init__()
485
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
486
+ self.activation = nn.Tanh()
487
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
488
+
489
+ def forward(
490
+ self,
491
+ hidden_states: torch.FloatTensor,
492
+ start_states: Optional[torch.FloatTensor] = None,
493
+ start_positions: Optional[torch.LongTensor] = None,
494
+ cls_index: Optional[torch.LongTensor] = None,
495
+ ) -> torch.FloatTensor:
496
+ """
497
+ Args:
498
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
499
+ The final hidden states of the model.
500
+ start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
501
+ The hidden states of the first tokens for the labeled span.
502
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
503
+ The position of the first token for the labeled span.
504
+ cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
505
+ Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
506
+
507
+ <Tip>
508
+
509
+ One of `start_states` or `start_positions` should not be `None`. If both are set, `start_positions` overrides
510
+ `start_states`.
511
+
512
+ </Tip>
513
+
514
+ Returns:
515
+ `torch.FloatTensor`: The SQuAD 2.0 answer class.
516
+ """
517
+ # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
518
+ hsz = hidden_states.shape[-1]
519
+ assert start_states is not None or start_positions is not None, (
520
+ "One of start_states, start_positions should be not None"
521
+ )
522
+ if start_positions is not None:
523
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
524
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
525
+
526
+ if cls_index is not None:
527
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
528
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
529
+ else:
530
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
531
+
532
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
533
+ x = self.activation(x)
534
+ x = self.dense_1(x).squeeze(-1)
535
+
536
+ return x
537
+
538
+
539
+ # Copied from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert
540
+ class FlaubertSQuADHead(nn.Module):
541
+ r"""
542
+ A SQuAD head inspired by XLNet.
543
+
544
+ Args:
545
+ config ([`FlaubertConfig`]):
546
+ The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
547
+ to use.
548
+ """
549
+
550
+ def __init__(self, config: FlaubertConfig):
551
+ super().__init__()
552
+ self.start_n_top = config.start_n_top
553
+ self.end_n_top = config.end_n_top
554
+
555
+ self.start_logits = FlaubertPoolerStartLogits(config)
556
+ self.end_logits = FlaubertPoolerEndLogits(config)
557
+ self.answer_class = FlaubertPoolerAnswerClass(config)
558
+
559
+ @replace_return_docstrings(output_type=FlaubertSquadHeadOutput, config_class=FlaubertConfig)
560
+ def forward(
561
+ self,
562
+ hidden_states: torch.FloatTensor,
563
+ start_positions: Optional[torch.LongTensor] = None,
564
+ end_positions: Optional[torch.LongTensor] = None,
565
+ cls_index: Optional[torch.LongTensor] = None,
566
+ is_impossible: Optional[torch.LongTensor] = None,
567
+ p_mask: Optional[torch.FloatTensor] = None,
568
+ return_dict: bool = False,
569
+ ) -> Union[FlaubertSquadHeadOutput, Tuple[torch.FloatTensor]]:
570
+ """
571
+ Args:
572
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
573
+ Final hidden states of the model on the sequence tokens.
574
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
575
+ Positions of the first token for the labeled span.
576
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
577
+ Positions of the last token for the labeled span.
578
+ cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
579
+ Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
580
+ is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
581
+ Whether the question has a possible answer in the paragraph or not.
582
+ p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
583
+ Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
584
+ should be masked.
585
+ return_dict (`bool`, *optional*, defaults to `False`):
586
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
587
+
588
+ Returns:
589
+ """
590
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
591
+
592
+ if start_positions is not None and end_positions is not None:
593
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
594
+ for x in (start_positions, end_positions, cls_index, is_impossible):
595
+ if x is not None and x.dim() > 1:
596
+ x.squeeze_(-1)
597
+
598
+ # during training, compute the end logits based on the ground truth of the start position
599
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
600
+
601
+ loss_fct = CrossEntropyLoss()
602
+ start_loss = loss_fct(start_logits, start_positions)
603
+ end_loss = loss_fct(end_logits, end_positions)
604
+ total_loss = (start_loss + end_loss) / 2
605
+
606
+ if cls_index is not None and is_impossible is not None:
607
+ # Predict answerability from the representation of CLS and START
608
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
609
+ loss_fct_cls = nn.BCEWithLogitsLoss()
610
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
611
+
612
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
613
+ total_loss += cls_loss * 0.5
614
+
615
+ return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
616
+
617
+ else:
618
+ # during inference, compute the end logits based on beam search
619
+ bsz, slen, hsz = hidden_states.size()
620
+ start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
621
+
622
+ start_top_log_probs, start_top_index = torch.topk(
623
+ start_log_probs, self.start_n_top, dim=-1
624
+ ) # shape (bsz, start_n_top)
625
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
626
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
627
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
628
+
629
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
630
+ start_states
631
+ ) # shape (bsz, slen, start_n_top, hsz)
632
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
633
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
634
+ end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
635
+
636
+ end_top_log_probs, end_top_index = torch.topk(
637
+ end_log_probs, self.end_n_top, dim=1
638
+ ) # shape (bsz, end_n_top, start_n_top)
639
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
640
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
641
+
642
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
643
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
644
+
645
+ if not return_dict:
646
+ return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
647
+ else:
648
+ return FlaubertSquadHeadOutput(
649
+ start_top_log_probs=start_top_log_probs,
650
+ start_top_index=start_top_index,
651
+ end_top_log_probs=end_top_log_probs,
652
+ end_top_index=end_top_index,
653
+ cls_logits=cls_logits,
654
+ )
655
+
656
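+ def _example_squad_head():
+     # Illustrative sketch, not part of the original file: the head above returns a single
+     # averaged loss during training and beam-search candidates (top start/end indices with
+     # their log-probabilities) at inference. SimpleNamespace stands in for a FlaubertConfig.
+     from types import SimpleNamespace
+     cfg = SimpleNamespace(hidden_size=8, layer_norm_eps=1e-12, start_n_top=2, end_n_top=2)
+     head = FlaubertSQuADHead(cfg)
+     hidden = torch.randn(2, 6, 8)
+     (loss,) = head(hidden, start_positions=torch.tensor([1, 2]), end_positions=torch.tensor([3, 4]))
+     start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = head(hidden)
+     return loss.item(), start_top_index.shape  # start_top_index expected: torch.Size([2, 2])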
+
657
+ # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert
658
+ class FlaubertSequenceSummary(nn.Module):
659
+ r"""
660
+ Compute a single vector summary of a sequence hidden states.
661
+
662
+ Args:
663
+ config ([`FlaubertConfig`]):
664
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
665
+ config class of your model for the default values it uses):
666
+
667
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
668
+
669
+ - `"last"` -- Take the last token hidden state (like XLNet)
670
+ - `"first"` -- Take the first token hidden state (like Bert)
671
+ - `"mean"` -- Take the mean of all tokens hidden states
672
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
673
+ - `"attn"` -- Not implemented now, use multi-head attention
674
+
675
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
676
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
677
+ (otherwise to `config.hidden_size`).
678
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
679
+ another string or `None` will add no activation.
680
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
681
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
682
+ """
683
+
684
+ def __init__(self, config: FlaubertConfig):
685
+ super().__init__()
686
+
687
+ self.summary_type = getattr(config, "summary_type", "last")
688
+ if self.summary_type == "attn":
689
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
690
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
691
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
692
+ raise NotImplementedError
693
+
694
+ self.summary = nn.Identity()
695
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
696
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
697
+ num_classes = config.num_labels
698
+ else:
699
+ num_classes = config.hidden_size
700
+ self.summary = nn.Linear(config.hidden_size, num_classes)
701
+
702
+ activation_string = getattr(config, "summary_activation", None)
703
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
704
+
705
+ self.first_dropout = nn.Identity()
706
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
707
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
708
+
709
+ self.last_dropout = nn.Identity()
710
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
711
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
712
+
713
+ def forward(
714
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
715
+ ) -> torch.FloatTensor:
716
+ """
717
+ Compute a single vector summary of a sequence hidden states.
718
+
719
+ Args:
720
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
721
+ The hidden states of the last layer.
722
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
723
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
724
+
725
+ Returns:
726
+ `torch.FloatTensor`: The summary of the sequence hidden states.
727
+ """
728
+ if self.summary_type == "last":
729
+ output = hidden_states[:, -1]
730
+ elif self.summary_type == "first":
731
+ output = hidden_states[:, 0]
732
+ elif self.summary_type == "mean":
733
+ output = hidden_states.mean(dim=1)
734
+ elif self.summary_type == "cls_index":
735
+ if cls_index is None:
736
+ cls_index = torch.full_like(
737
+ hidden_states[..., :1, :],
738
+ hidden_states.shape[-2] - 1,
739
+ dtype=torch.long,
740
+ )
741
+ else:
742
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
743
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
744
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
745
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
746
+ elif self.summary_type == "attn":
747
+ raise NotImplementedError
748
+
749
+ output = self.first_dropout(output)
750
+ output = self.summary(output)
751
+ output = self.activation(output)
752
+ output = self.last_dropout(output)
753
+
754
+ return output
755
+
756
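+ def _example_sequence_summary():
+     # Illustrative sketch, not part of the original file: with `summary_type="first"` the
+     # summary above is the first-token hidden state, optionally projected. SimpleNamespace
+     # stands in for a FlaubertConfig; only the summary_* fields below are read.
+     from types import SimpleNamespace
+     cfg = SimpleNamespace(
+         summary_type="first",
+         summary_use_proj=True,
+         summary_proj_to_labels=False,
+         num_labels=0,
+         hidden_size=8,
+         summary_activation=None,
+         summary_first_dropout=0.0,
+         summary_last_dropout=0.0,
+     )
+     summary = FlaubertSequenceSummary(cfg)
+     return summary(torch.randn(2, 5, 8)).shape  # expected: torch.Size([2, 8])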
+
757
+ # Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert
758
+ class FlaubertPreTrainedModel(PreTrainedModel):
759
+ """
760
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
761
+ models.
762
+ """
763
+
764
+ config_class = FlaubertConfig
765
+ load_tf_weights = None
766
+ base_model_prefix = "transformer"
767
+
768
+ def __init__(self, *inputs, **kwargs):
769
+ super().__init__(*inputs, **kwargs)
770
+
771
+ @property
772
+ def dummy_inputs(self):
773
+ inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
774
+ attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
775
+ if self.config.use_lang_emb and self.config.n_langs > 1:
776
+ langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
777
+ else:
778
+ langs_list = None
779
+ return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
780
+
781
+ def _init_weights(self, module):
782
+ """Initialize the weights."""
783
+ if isinstance(module, nn.Embedding):
784
+ if self.config is not None and self.config.embed_init_std is not None:
785
+ nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
786
+ if module.padding_idx is not None:
787
+ module.weight.data[module.padding_idx].zero_()
788
+ if isinstance(module, nn.Linear):
789
+ if self.config is not None and self.config.init_std is not None:
790
+ nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
791
+ if module.bias is not None:
792
+ nn.init.constant_(module.bias, 0.0)
793
+ if isinstance(module, nn.LayerNorm):
794
+ module.bias.data.zero_()
795
+ module.weight.data.fill_(1.0)
796
+ if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
797
+ create_sinusoidal_embeddings(
798
+ self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
799
+ )
800
+
801
+
802
+ @add_start_docstrings(
+ "The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAUBERT_START_DOCSTRING,
+ )
+ class FlaubertModel(FlaubertPreTrainedModel):
803
+ def __init__(self, config): # , dico, is_encoder, with_output):
804
+ super().__init__(config)
805
+
806
+ # encoder / decoder, output layer
807
+ self.is_encoder = config.is_encoder
808
+ self.is_decoder = not config.is_encoder
809
+ if self.is_decoder:
810
+ raise NotImplementedError("Currently Flaubert can only be used as an encoder")
811
+ # self.with_output = with_output
812
+ self.causal = config.causal
813
+
814
+ # dictionary / languages
815
+ self.n_langs = config.n_langs
816
+ self.use_lang_emb = config.use_lang_emb
817
+ self.n_words = config.n_words
818
+ self.eos_index = config.eos_index
819
+ self.pad_index = config.pad_index
820
+ # self.dico = dico
821
+ # self.id2lang = config.id2lang
822
+ # self.lang2id = config.lang2id
823
+ # assert len(self.dico) == self.n_words
824
+ # assert len(self.id2lang) == len(self.lang2id) == self.n_langs
825
+
826
+ # model parameters
827
+ self.dim = config.emb_dim # 512 by default
828
+ self.hidden_dim = self.dim * 4 # 2048 by default
829
+ self.n_heads = config.n_heads # 8 by default
830
+ self.n_layers = config.n_layers
831
+ self.dropout = config.dropout
832
+ self.attention_dropout = config.attention_dropout
833
+ assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
834
+
835
+ # embeddings
836
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
837
+ if config.n_langs > 1 and config.use_lang_emb:
838
+ self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
839
+ self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
840
+ self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
841
+
842
+ # transformer layers
843
+ self.attentions = nn.ModuleList()
844
+ self.layer_norm1 = nn.ModuleList()
845
+ self.ffns = nn.ModuleList()
846
+ self.layer_norm2 = nn.ModuleList()
847
+ # if self.is_decoder:
848
+ # self.layer_norm15 = nn.ModuleList()
849
+ # self.encoder_attn = nn.ModuleList()
850
+
851
+ for _ in range(self.n_layers):
852
+ self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
853
+ self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
854
+ # if self.is_decoder:
855
+ # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
856
+ # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
857
+ self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
858
+ self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
859
+
860
+ if hasattr(config, "pruned_heads"):
861
+ pruned_heads = config.pruned_heads.copy().items()
862
+ config.pruned_heads = {}
863
+ for layer, heads in pruned_heads:
864
+ if self.attentions[int(layer)].n_heads == config.n_heads:
865
+ self.prune_heads({int(layer): list(map(int, heads))})
866
+
867
+ # Initialize weights and apply final processing
868
+ self.post_init()
869
+
870
+ self.layerdrop = getattr(config, "layerdrop", 0.0)
871
+ self.pre_norm = getattr(config, "pre_norm", False)
872
+ self.register_buffer(
873
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
874
+ )
875
+
876
+ # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
877
+ def get_input_embeddings(self):
878
+ return self.embeddings
879
+
880
+ # Copied from transformers.models.xlm.modeling_xlm.XLMModel.set_input_embeddings
881
+ def set_input_embeddings(self, new_embeddings):
882
+ self.embeddings = new_embeddings
883
+
884
+ # Copied from transformers.models.xlm.modeling_xlm.XLMModel._prune_heads
885
+ def _prune_heads(self, heads_to_prune):
886
+ """
887
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
888
+ class PreTrainedModel
889
+ """
890
+ for layer, heads in heads_to_prune.items():
891
+ self.attentions[layer].prune_heads(heads)
892
+
893
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
894
+ @add_code_sample_docstrings(
895
+ checkpoint=_CHECKPOINT_FOR_DOC,
896
+ output_type=BaseModelOutput,
897
+ config_class=_CONFIG_FOR_DOC,
898
+ )
899
+ def forward(
900
+ self,
901
+ input_ids: Optional[torch.LongTensor] = None,
902
+ attention_mask: Optional[torch.FloatTensor] = None,
903
+ langs: Optional[torch.Tensor] = None,
904
+ token_type_ids: Optional[torch.LongTensor] = None,
905
+ position_ids: Optional[torch.LongTensor] = None,
906
+ lengths: Optional[torch.LongTensor] = None,
907
+ cache: Optional[Dict[str, torch.FloatTensor]] = None,
908
+ head_mask: Optional[torch.FloatTensor] = None,
909
+ inputs_embeds: Optional[torch.FloatTensor] = None,
910
+ output_attentions: Optional[bool] = None,
911
+ output_hidden_states: Optional[bool] = None,
912
+ return_dict: Optional[bool] = None,
913
+ ) -> Union[Tuple, BaseModelOutput]:
914
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
915
+ output_hidden_states = (
916
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
917
+ )
918
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
919
+
920
+ # removed: src_enc=None, src_len=None
921
+ if input_ids is not None:
922
+ bs, slen = input_ids.size()
923
+ else:
924
+ bs, slen = inputs_embeds.size()[:-1]
925
+
926
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
927
+
928
+ if lengths is None:
929
+ if input_ids is not None:
930
+ lengths = (input_ids != self.pad_index).sum(dim=1).long()
931
+ else:
932
+ lengths = torch.tensor([slen] * bs, device=device)
933
+ # mask = input_ids != self.pad_index
934
+
935
+ # check inputs
936
+ assert lengths.size(0) == bs
937
+ assert lengths.max().item() <= slen
938
+ # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
939
+ # assert (src_enc is None) == (src_len is None)
940
+ # if src_enc is not None:
941
+ # assert self.is_decoder
942
+ # assert src_enc.size(0) == bs
943
+
944
+ # generate masks
945
+ mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
946
+ # if self.is_decoder and src_enc is not None:
947
+ # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
948
+
949
+ # Setting the position ids to the buffer registered in the constructor helps
950
+ # when tracing the model without passing position ids, and solves
951
+ # issues similar to issue #5664
952
+ if position_ids is None:
953
+ if hasattr(self, "position_ids"):
954
+ position_ids = self.position_ids[:, :slen]
955
+ position_ids = position_ids.expand((bs, slen))
956
+ else:
957
+ position_ids = torch.arange(slen, dtype=torch.long, device=device)
958
+ position_ids = position_ids.unsqueeze(0).expand((bs, slen))
959
+ else:
960
+ assert position_ids.size() == (bs, slen) # (slen, bs)
961
+ # position_ids = position_ids.transpose(0, 1)
962
+
963
+ # langs
964
+ if langs is not None:
965
+ assert langs.size() == (bs, slen) # (slen, bs)
966
+ # langs = langs.transpose(0, 1)
967
+
968
+ # Prepare head mask if needed
969
+ head_mask = self.get_head_mask(head_mask, self.config.n_layers)
970
+
971
+ # do not recompute cached elements
972
+ if cache is not None and input_ids is not None:
973
+ _slen = slen - cache["slen"]
974
+ input_ids = input_ids[:, -_slen:]
975
+ position_ids = position_ids[:, -_slen:]
976
+ if langs is not None:
977
+ langs = langs[:, -_slen:]
978
+ mask = mask[:, -_slen:]
979
+ attn_mask = attn_mask[:, -_slen:]
980
+
981
+ # embeddings
982
+ if inputs_embeds is None:
983
+ inputs_embeds = self.embeddings(input_ids)
984
+
985
+ tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
986
+ if langs is not None and self.use_lang_emb and self.config.n_langs > 1:
987
+ tensor = tensor + self.lang_embeddings(langs)
988
+ if token_type_ids is not None:
989
+ tensor = tensor + self.embeddings(token_type_ids)
990
+ tensor = self.layer_norm_emb(tensor)
991
+ tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
992
+ tensor *= mask.unsqueeze(-1).to(tensor.dtype)
993
+
994
+ # transformer layers
995
+ hidden_states = () if output_hidden_states else None
996
+ attentions = () if output_attentions else None
997
+ for i in range(self.n_layers):
998
+ # LayerDrop
999
+ if self.training:
1000
+ dropout_probability = torch.rand([])
1001
+ if dropout_probability < self.layerdrop:
1002
+ continue
1003
+
1004
+ if output_hidden_states:
1005
+ hidden_states = hidden_states + (tensor,)
1006
+
1007
+ # self attention
1008
+ if not self.pre_norm:
1009
+ attn_outputs = self.attentions[i](
1010
+ tensor,
1011
+ attn_mask,
1012
+ cache=cache,
1013
+ head_mask=head_mask[i],
1014
+ output_attentions=output_attentions,
1015
+ )
1016
+ attn = attn_outputs[0]
1017
+ if output_attentions:
1018
+ attentions = attentions + (attn_outputs[1],)
1019
+ attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
1020
+ tensor = tensor + attn
1021
+ tensor = self.layer_norm1[i](tensor)
1022
+ else:
1023
+ tensor_normalized = self.layer_norm1[i](tensor)
1024
+ attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i])
1025
+ attn = attn_outputs[0]
1026
+ if output_attentions:
1027
+ attentions = attentions + (attn_outputs[1],)
1028
+ attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
1029
+ tensor = tensor + attn
1030
+
1031
+ # encoder attention (for decoder only)
1032
+ # if self.is_decoder and src_enc is not None:
1033
+ # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
1034
+ # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
1035
+ # tensor = tensor + attn
1036
+ # tensor = self.layer_norm15[i](tensor)
1037
+
1038
+ # FFN
1039
+ if not self.pre_norm:
1040
+ tensor = tensor + self.ffns[i](tensor)
1041
+ tensor = self.layer_norm2[i](tensor)
1042
+ else:
1043
+ tensor_normalized = self.layer_norm2[i](tensor)
1044
+ tensor = tensor + self.ffns[i](tensor_normalized)
1045
+
1046
+ tensor *= mask.unsqueeze(-1).to(tensor.dtype)
1047
+
1048
+ # Add last hidden state
1049
+ if output_hidden_states:
1050
+ hidden_states = hidden_states + (tensor,)
1051
+
1052
+ # update cache length
1053
+ if cache is not None:
1054
+ cache["slen"] += tensor.size(1)
1055
+
1056
+ # move back sequence length to dimension 0
1057
+ # tensor = tensor.transpose(0, 1)
1058
+
1059
+ if not return_dict:
1060
+ return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
1061
+
1062
+ return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
1063
+
1064
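+ def _example_flaubert_model():
+     # Illustrative sketch, not part of the original file: a tiny, randomly initialised
+     # FlaubertModel. The FlaubertConfig keyword names used here (vocab_size, emb_dim,
+     # n_layers, n_heads) are assumed from the configuration class.
+     config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=4)
+     model = FlaubertModel(config)
+     input_ids = torch.randint(5, 100, (2, 7))  # avoid the special/pad indices
+     outputs = model(input_ids=input_ids)
+     return outputs.last_hidden_state.shape  # expected: torch.Size([2, 7, 32])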
+
1065
+ @add_start_docstrings(
1066
+ """
1067
+ The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input
1068
+ embeddings).
1069
+ """,
1070
+ FLAUBERT_START_DOCSTRING,
1071
+ )
1072
+ # Copied from transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1073
+ class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
1074
+ _tied_weights_keys = ["pred_layer.proj.weight"]
1075
+
1076
+ def __init__(self, config):
1077
+ super().__init__(config)
1078
+ self.transformer = FlaubertModel(config)
1079
+ self.pred_layer = FlaubertPredLayer(config)
1080
+
1081
+ # Initialize weights and apply final processing
1082
+ self.post_init()
1083
+
1084
+ def get_output_embeddings(self):
1085
+ return self.pred_layer.proj
1086
+
1087
+ def set_output_embeddings(self, new_embeddings):
1088
+ self.pred_layer.proj = new_embeddings
1089
+
1090
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
1091
+ # Overwritten -- uses a language id
1092
+
1093
+ mask_token_id = self.config.mask_token_id
1094
+ lang_id = self.config.lang_id
1095
+
1096
+ effective_batch_size = input_ids.shape[0]
1097
+ mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
1098
+ input_ids = torch.cat([input_ids, mask_token], dim=1)
1099
+ if lang_id is not None:
1100
+ langs = torch.full_like(input_ids, lang_id)
1101
+ else:
1102
+ langs = None
1103
+ return {"input_ids": input_ids, "langs": langs}
1104
+
1105
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1106
+ @add_code_sample_docstrings(
1107
+ checkpoint=_CHECKPOINT_FOR_DOC,
1108
+ output_type=MaskedLMOutput,
1109
+ config_class=_CONFIG_FOR_DOC,
1110
+ mask="<special1>",
1111
+ )
1112
+ def forward(
1113
+ self,
1114
+ input_ids: Optional[torch.Tensor] = None,
1115
+ attention_mask: Optional[torch.Tensor] = None,
1116
+ langs: Optional[torch.Tensor] = None,
1117
+ token_type_ids: Optional[torch.Tensor] = None,
1118
+ position_ids: Optional[torch.Tensor] = None,
1119
+ lengths: Optional[torch.Tensor] = None,
1120
+ cache: Optional[Dict[str, torch.Tensor]] = None,
1121
+ head_mask: Optional[torch.Tensor] = None,
1122
+ inputs_embeds: Optional[torch.Tensor] = None,
1123
+ labels: Optional[torch.Tensor] = None,
1124
+ output_attentions: Optional[bool] = None,
1125
+ output_hidden_states: Optional[bool] = None,
1126
+ return_dict: Optional[bool] = None,
1127
+ ) -> Union[Tuple, MaskedLMOutput]:
1128
+ r"""
1129
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1130
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1131
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1132
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1133
+ """
1134
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1135
+
1136
+ transformer_outputs = self.transformer(
1137
+ input_ids,
1138
+ attention_mask=attention_mask,
1139
+ langs=langs,
1140
+ token_type_ids=token_type_ids,
1141
+ position_ids=position_ids,
1142
+ lengths=lengths,
1143
+ cache=cache,
1144
+ head_mask=head_mask,
1145
+ inputs_embeds=inputs_embeds,
1146
+ output_attentions=output_attentions,
1147
+ output_hidden_states=output_hidden_states,
1148
+ return_dict=return_dict,
1149
+ )
1150
+
1151
+ output = transformer_outputs[0]
1152
+ outputs = self.pred_layer(output, labels) # (loss, logits) or (logits,) depending on if labels are provided.
1153
+
1154
+ if not return_dict:
1155
+ return outputs + transformer_outputs[1:]
1156
+
1157
+ return MaskedLMOutput(
1158
+ loss=outputs[0] if labels is not None else None,
1159
+ logits=outputs[0] if labels is None else outputs[1],
1160
+ hidden_states=transformer_outputs.hidden_states,
1161
+ attentions=transformer_outputs.attentions,
1162
+ )
1163
+
1164
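+ def _example_lm_head_generation_inputs():
+     # Illustrative sketch, not part of the original file: `prepare_inputs_for_generation`
+     # above appends one mask token per sequence and, when the config defines `lang_id`,
+     # a matching `langs` tensor. FlaubertConfig keyword names and the `mask_token_id` /
+     # `lang_id` defaults are assumptions here.
+     config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=4)
+     model = FlaubertWithLMHeadModel(config)
+     prepared = model.prepare_inputs_for_generation(torch.randint(5, 100, (2, 3)))
+     return prepared["input_ids"].shape  # expected: torch.Size([2, 4])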
+
1165
+ @add_start_docstrings(
1166
+ """
1167
+ Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
1168
+ e.g. for GLUE tasks.
1169
+ """,
1170
+ FLAUBERT_START_DOCSTRING,
1171
+ )
1172
+ # Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1173
+ class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
1174
+ def __init__(self, config):
1175
+ super().__init__(config)
1176
+ self.num_labels = config.num_labels
1177
+ self.config = config
1178
+
1179
+ self.transformer = FlaubertModel(config)
1180
+ self.sequence_summary = FlaubertSequenceSummary(config)
1181
+
1182
+ # Initialize weights and apply final processing
1183
+ self.post_init()
1184
+
1185
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1186
+ @add_code_sample_docstrings(
1187
+ checkpoint=_CHECKPOINT_FOR_DOC,
1188
+ output_type=SequenceClassifierOutput,
1189
+ config_class=_CONFIG_FOR_DOC,
1190
+ )
1191
+ def forward(
1192
+ self,
1193
+ input_ids: Optional[torch.Tensor] = None,
1194
+ attention_mask: Optional[torch.Tensor] = None,
1195
+ langs: Optional[torch.Tensor] = None,
1196
+ token_type_ids: Optional[torch.Tensor] = None,
1197
+ position_ids: Optional[torch.Tensor] = None,
1198
+ lengths: Optional[torch.Tensor] = None,
1199
+ cache: Optional[Dict[str, torch.Tensor]] = None,
1200
+ head_mask: Optional[torch.Tensor] = None,
1201
+ inputs_embeds: Optional[torch.Tensor] = None,
1202
+ labels: Optional[torch.Tensor] = None,
1203
+ output_attentions: Optional[bool] = None,
1204
+ output_hidden_states: Optional[bool] = None,
1205
+ return_dict: Optional[bool] = None,
1206
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1207
+ r"""
1208
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1209
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1210
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1211
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1212
+ """
1213
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1214
+
1215
+ transformer_outputs = self.transformer(
1216
+ input_ids,
1217
+ attention_mask=attention_mask,
1218
+ langs=langs,
1219
+ token_type_ids=token_type_ids,
1220
+ position_ids=position_ids,
1221
+ lengths=lengths,
1222
+ cache=cache,
1223
+ head_mask=head_mask,
1224
+ inputs_embeds=inputs_embeds,
1225
+ output_attentions=output_attentions,
1226
+ output_hidden_states=output_hidden_states,
1227
+ return_dict=return_dict,
1228
+ )
1229
+
1230
+ output = transformer_outputs[0]
1231
+ logits = self.sequence_summary(output)
1232
+
1233
+ loss = None
1234
+ if labels is not None:
1235
+ if self.config.problem_type is None:
1236
+ if self.num_labels == 1:
1237
+ self.config.problem_type = "regression"
1238
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1239
+ self.config.problem_type = "single_label_classification"
1240
+ else:
1241
+ self.config.problem_type = "multi_label_classification"
1242
+
1243
+ if self.config.problem_type == "regression":
1244
+ loss_fct = MSELoss()
1245
+ if self.num_labels == 1:
1246
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1247
+ else:
1248
+ loss = loss_fct(logits, labels)
1249
+ elif self.config.problem_type == "single_label_classification":
1250
+ loss_fct = CrossEntropyLoss()
1251
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1252
+ elif self.config.problem_type == "multi_label_classification":
1253
+ loss_fct = BCEWithLogitsLoss()
1254
+ loss = loss_fct(logits, labels)
1255
+
1256
+ if not return_dict:
1257
+ output = (logits,) + transformer_outputs[1:]
1258
+ return ((loss,) + output) if loss is not None else output
1259
+
1260
+ return SequenceClassifierOutput(
1261
+ loss=loss,
1262
+ logits=logits,
1263
+ hidden_states=transformer_outputs.hidden_states,
1264
+ attentions=transformer_outputs.attentions,
1265
+ )
1266
+
1267
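+ def _example_sequence_classification():
+     # Illustrative sketch, not part of the original file: with integer labels and
+     # `num_labels > 1` the head above falls into single-label classification and uses a
+     # cross-entropy loss. FlaubertConfig keyword names are assumed; the model is tiny and
+     # randomly initialised.
+     config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=4, num_labels=3)
+     model = FlaubertForSequenceClassification(config)
+     out = model(input_ids=torch.randint(5, 100, (2, 6)), labels=torch.tensor([0, 2]))
+     return out.loss.item(), out.logits.shape  # logits expected: torch.Size([2, 3])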
+
1268
+ @add_start_docstrings(
1269
+ """
1270
+ Flaubert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1271
+ Named-Entity-Recognition (NER) tasks.
1272
+ """,
1273
+ FLAUBERT_START_DOCSTRING,
1274
+ )
1275
+ # Copied from transformers.models.xlm.modeling_xlm.XLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1276
+ class FlaubertForTokenClassification(FlaubertPreTrainedModel):
1277
+ def __init__(self, config):
1278
+ super().__init__(config)
1279
+ self.num_labels = config.num_labels
1280
+
1281
+ self.transformer = FlaubertModel(config)
1282
+ self.dropout = nn.Dropout(config.dropout)
1283
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1284
+
1285
+ # Initialize weights and apply final processing
1286
+ self.post_init()
1287
+
1288
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1289
+ @add_code_sample_docstrings(
1290
+ checkpoint=_CHECKPOINT_FOR_DOC,
1291
+ output_type=TokenClassifierOutput,
1292
+ config_class=_CONFIG_FOR_DOC,
1293
+ )
1294
+ def forward(
1295
+ self,
1296
+ input_ids: Optional[torch.Tensor] = None,
1297
+ attention_mask: Optional[torch.Tensor] = None,
1298
+ langs: Optional[torch.Tensor] = None,
1299
+ token_type_ids: Optional[torch.Tensor] = None,
1300
+ position_ids: Optional[torch.Tensor] = None,
1301
+ lengths: Optional[torch.Tensor] = None,
1302
+ cache: Optional[Dict[str, torch.Tensor]] = None,
1303
+ head_mask: Optional[torch.Tensor] = None,
1304
+ inputs_embeds: Optional[torch.Tensor] = None,
1305
+ labels: Optional[torch.Tensor] = None,
1306
+ output_attentions: Optional[bool] = None,
1307
+ output_hidden_states: Optional[bool] = None,
1308
+ return_dict: Optional[bool] = None,
1309
+ ) -> Union[Tuple, TokenClassifierOutput]:
1310
+ r"""
1311
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1312
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1313
+ """
1314
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1315
+
1316
+ outputs = self.transformer(
1317
+ input_ids,
1318
+ attention_mask=attention_mask,
1319
+ langs=langs,
1320
+ token_type_ids=token_type_ids,
1321
+ position_ids=position_ids,
1322
+ lengths=lengths,
1323
+ cache=cache,
1324
+ head_mask=head_mask,
1325
+ inputs_embeds=inputs_embeds,
1326
+ output_attentions=output_attentions,
1327
+ output_hidden_states=output_hidden_states,
1328
+ return_dict=return_dict,
1329
+ )
1330
+
1331
+ sequence_output = outputs[0]
1332
+
1333
+ sequence_output = self.dropout(sequence_output)
1334
+ logits = self.classifier(sequence_output)
1335
+
1336
+ loss = None
1337
+ if labels is not None:
1338
+ loss_fct = CrossEntropyLoss()
1339
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1340
+
1341
+ if not return_dict:
1342
+ output = (logits,) + outputs[1:]
1343
+ return ((loss,) + output) if loss is not None else output
1344
+
1345
+ return TokenClassifierOutput(
1346
+ loss=loss,
1347
+ logits=logits,
1348
+ hidden_states=outputs.hidden_states,
1349
+ attentions=outputs.attentions,
1350
+ )
1351
+
1352
+
1353
+ @add_start_docstrings(
1354
+ """
1355
+ Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1356
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1357
+ """,
1358
+ FLAUBERT_START_DOCSTRING,
1359
+ )
1360
+ # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1361
+ class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel):
1362
+ def __init__(self, config):
1363
+ super().__init__(config)
1364
+
1365
+ self.transformer = FlaubertModel(config)
1366
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1367
+
1368
+ # Initialize weights and apply final processing
1369
+ self.post_init()
1370
+
1371
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1372
+ @add_code_sample_docstrings(
1373
+ checkpoint=_CHECKPOINT_FOR_DOC,
1374
+ output_type=QuestionAnsweringModelOutput,
1375
+ config_class=_CONFIG_FOR_DOC,
1376
+ )
1377
+ def forward(
1378
+ self,
1379
+ input_ids: Optional[torch.Tensor] = None,
1380
+ attention_mask: Optional[torch.Tensor] = None,
1381
+ langs: Optional[torch.Tensor] = None,
1382
+ token_type_ids: Optional[torch.Tensor] = None,
1383
+ position_ids: Optional[torch.Tensor] = None,
1384
+ lengths: Optional[torch.Tensor] = None,
1385
+ cache: Optional[Dict[str, torch.Tensor]] = None,
1386
+ head_mask: Optional[torch.Tensor] = None,
1387
+ inputs_embeds: Optional[torch.Tensor] = None,
1388
+ start_positions: Optional[torch.Tensor] = None,
1389
+ end_positions: Optional[torch.Tensor] = None,
1390
+ output_attentions: Optional[bool] = None,
1391
+ output_hidden_states: Optional[bool] = None,
1392
+ return_dict: Optional[bool] = None,
1393
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1394
+ r"""
1395
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1396
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1397
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1398
+ are not taken into account for computing the loss.
1399
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1400
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1401
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1402
+ are not taken into account for computing the loss.
1403
+ """
1404
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1405
+
1406
+ transformer_outputs = self.transformer(
1407
+ input_ids,
1408
+ attention_mask=attention_mask,
1409
+ langs=langs,
1410
+ token_type_ids=token_type_ids,
1411
+ position_ids=position_ids,
1412
+ lengths=lengths,
1413
+ cache=cache,
1414
+ head_mask=head_mask,
1415
+ inputs_embeds=inputs_embeds,
1416
+ output_attentions=output_attentions,
1417
+ output_hidden_states=output_hidden_states,
1418
+ return_dict=return_dict,
1419
+ )
1420
+
1421
+ sequence_output = transformer_outputs[0]
1422
+
1423
+ logits = self.qa_outputs(sequence_output)
1424
+ start_logits, end_logits = logits.split(1, dim=-1)
1425
+ start_logits = start_logits.squeeze(-1).contiguous()
1426
+ end_logits = end_logits.squeeze(-1).contiguous()
1427
+
1428
+ total_loss = None
1429
+ if start_positions is not None and end_positions is not None:
1430
+ # If we are on multi-GPU, batch splitting may add a dimension
1431
+ if len(start_positions.size()) > 1:
1432
+ start_positions = start_positions.squeeze(-1)
1433
+ if len(end_positions.size()) > 1:
1434
+ end_positions = end_positions.squeeze(-1)
1435
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1436
+ ignored_index = start_logits.size(1)
1437
+ start_positions = start_positions.clamp(0, ignored_index)
1438
+ end_positions = end_positions.clamp(0, ignored_index)
1439
+
1440
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1441
+ start_loss = loss_fct(start_logits, start_positions)
1442
+ end_loss = loss_fct(end_logits, end_positions)
1443
+ total_loss = (start_loss + end_loss) / 2
1444
+
1445
+ if not return_dict:
1446
+ output = (start_logits, end_logits) + transformer_outputs[1:]
1447
+ return ((total_loss,) + output) if total_loss is not None else output
1448
+
1449
+ return QuestionAnsweringModelOutput(
1450
+ loss=total_loss,
1451
+ start_logits=start_logits,
1452
+ end_logits=end_logits,
1453
+ hidden_states=transformer_outputs.hidden_states,
1454
+ attentions=transformer_outputs.attentions,
1455
+ )
1456
+
1457
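+ def _example_question_answering_simple():
+     # Illustrative sketch, not part of the original file: the simple QA head above yields
+     # per-token start/end logits and, when positions are supplied, an averaged
+     # cross-entropy loss. FlaubertConfig keyword names are assumed; tiny random weights.
+     config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=4, num_labels=2)
+     model = FlaubertForQuestionAnsweringSimple(config)
+     out = model(
+         input_ids=torch.randint(5, 100, (1, 8)),
+         start_positions=torch.tensor([2]),
+         end_positions=torch.tensor([5]),
+     )
+     return out.start_logits.shape, out.loss.item()  # start_logits expected: torch.Size([1, 8])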
+
1458
+ @add_start_docstrings(
1459
+ """
1460
+ Flaubert Model with a beam-search span classification head on top for extractive question-answering tasks like
1461
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1462
+ """,
1463
+ FLAUBERT_START_DOCSTRING,
1464
+ )
1465
+ @dataclass
1466
+ # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringOutput with XLM->Flaubert
1467
+ class FlaubertForQuestionAnsweringOutput(ModelOutput):
1468
+ """
1469
+ Base class for outputs of question answering models using a `SquadHead`.
1470
+
1471
+ Args:
1472
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
1473
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification
1474
+ losses.
1475
+ start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
1476
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
1477
+ start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
1478
+ Indices for the top config.start_n_top start token possibilities (beam-search).
1479
+ end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
1480
+ Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
1481
+ (beam-search).
1482
+ end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
1483
+ Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
1484
+ cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
1485
+ Log probabilities for the `is_impossible` label of the answers.
1486
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1487
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
1488
+ shape `(batch_size, sequence_length, hidden_size)`.
1489
+
1490
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1491
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
1492
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1493
+ sequence_length)`.
1494
+
1495
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1496
+ heads.
1497
+ """
1498
+
1499
+ loss: Optional[torch.FloatTensor] = None
1500
+ start_top_log_probs: Optional[torch.FloatTensor] = None
1501
+ start_top_index: Optional[torch.LongTensor] = None
1502
+ end_top_log_probs: Optional[torch.FloatTensor] = None
1503
+ end_top_index: Optional[torch.LongTensor] = None
1504
+ cls_logits: Optional[torch.FloatTensor] = None
1505
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1506
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1507
+
1508
+
1509
+ # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1510
+ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
1511
+ def __init__(self, config):
1512
+ super().__init__(config)
1513
+
1514
+ self.transformer = FlaubertModel(config)
1515
+ self.qa_outputs = FlaubertSQuADHead(config)
1516
+
1517
+ # Initialize weights and apply final processing
1518
+ self.post_init()
1519
+
1520
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1521
+ @replace_return_docstrings(output_type=FlaubertForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
1522
+ def forward(
1523
+ self,
1524
+ input_ids: Optional[torch.Tensor] = None,
1525
+ attention_mask: Optional[torch.Tensor] = None,
1526
+ langs: Optional[torch.Tensor] = None,
1527
+ token_type_ids: Optional[torch.Tensor] = None,
1528
+ position_ids: Optional[torch.Tensor] = None,
1529
+ lengths: Optional[torch.Tensor] = None,
1530
+ cache: Optional[Dict[str, torch.Tensor]] = None,
1531
+ head_mask: Optional[torch.Tensor] = None,
1532
+ inputs_embeds: Optional[torch.Tensor] = None,
1533
+ start_positions: Optional[torch.Tensor] = None,
1534
+ end_positions: Optional[torch.Tensor] = None,
1535
+ is_impossible: Optional[torch.Tensor] = None,
1536
+ cls_index: Optional[torch.Tensor] = None,
1537
+ p_mask: Optional[torch.Tensor] = None,
1538
+ output_attentions: Optional[bool] = None,
1539
+ output_hidden_states: Optional[bool] = None,
1540
+ return_dict: Optional[bool] = None,
1541
+ ) -> Union[Tuple, FlaubertForQuestionAnsweringOutput]:
1542
+ r"""
1543
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1544
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1545
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1546
+ are not taken into account for computing the loss.
1547
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1548
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1549
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1550
+ are not taken into account for computing the loss.
1551
+ is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1552
+ Labels whether a question has an answer or no answer (SQuAD 2.0).
1553
+ cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1554
+ Labels for position (index) of the classification token to use as input for computing plausibility of the
1555
+ answer.
1556
+ p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1557
+ Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
1558
+ masked. 0.0 means token is not masked.
1559
+
1560
+ Returns:
1561
+
1562
+ Example:
1563
+
1564
+ ```python
1565
+ >>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering
1566
+ >>> import torch
1567
+
1568
+ >>> tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
1569
+ >>> model = FlaubertForQuestionAnswering.from_pretrained("flaubert/flaubert_base_cased")
1570
+
1571
+ >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
1572
+ ... 0
1573
+ ... ) # Batch size 1
1574
+ >>> start_positions = torch.tensor([1])
1575
+ >>> end_positions = torch.tensor([3])
1576
+
1577
+ >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
1578
+ >>> loss = outputs.loss
1579
+ ```"""
1580
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1581
+
1582
+ transformer_outputs = self.transformer(
1583
+ input_ids,
1584
+ attention_mask=attention_mask,
1585
+ langs=langs,
1586
+ token_type_ids=token_type_ids,
1587
+ position_ids=position_ids,
1588
+ lengths=lengths,
1589
+ cache=cache,
1590
+ head_mask=head_mask,
1591
+ inputs_embeds=inputs_embeds,
1592
+ output_attentions=output_attentions,
1593
+ output_hidden_states=output_hidden_states,
1594
+ return_dict=return_dict,
1595
+ )
1596
+
1597
+ output = transformer_outputs[0]
1598
+
1599
+ outputs = self.qa_outputs(
1600
+ output,
1601
+ start_positions=start_positions,
1602
+ end_positions=end_positions,
1603
+ cls_index=cls_index,
1604
+ is_impossible=is_impossible,
1605
+ p_mask=p_mask,
1606
+ return_dict=return_dict,
1607
+ )
1608
+
1609
+ if not return_dict:
1610
+ return outputs + transformer_outputs[1:]
1611
+
1612
+ return FlaubertForQuestionAnsweringOutput(
1613
+ loss=outputs.loss,
1614
+ start_top_log_probs=outputs.start_top_log_probs,
1615
+ start_top_index=outputs.start_top_index,
1616
+ end_top_log_probs=outputs.end_top_log_probs,
1617
+ end_top_index=outputs.end_top_index,
1618
+ cls_logits=outputs.cls_logits,
1619
+ hidden_states=transformer_outputs.hidden_states,
1620
+ attentions=transformer_outputs.attentions,
1621
+ )
1622
+
1623
+
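Unlike the simple head above, the beam-search head returns the `start_n_top` best start positions and, for each of them, `end_n_top` candidate end positions flattened into blocks of `end_n_top`. Assuming that layout, the best span can be scored as follows; the tensors are made up and this is only a sketch, not an API of this file:

```python
import torch

# Hypothetical outputs for batch_size=1, start_n_top=2, end_n_top=2.
start_top_log_probs = torch.tensor([[-0.1, -1.2]])
start_top_index = torch.tensor([[4, 9]])
end_top_log_probs = torch.tensor([[-0.2, -2.0, -0.5, -1.0]])  # (1, start_n_top * end_n_top)
end_top_index = torch.tensor([[6, 8, 11, 12]])

start_n_top, end_n_top = 2, 2
best_score, best_span = float("-inf"), None
for i in range(start_n_top):
    for j in range(end_n_top):
        # End candidates for start candidate i live in the block [i * end_n_top, (i + 1) * end_n_top).
        score = (start_top_log_probs[0, i] + end_top_log_probs[0, i * end_n_top + j]).item()
        if score > best_score:
            best_score = score
            best_span = (start_top_index[0, i].item(), end_top_index[0, i * end_n_top + j].item())
print(best_span, best_score)  # (4, 6) with the values above
```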
1624
+ @add_start_docstrings(
1625
+ """
1626
+ Flaubert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1627
+ softmax) e.g. for RocStories/SWAG tasks.
1628
+ """,
1629
+ FLAUBERT_START_DOCSTRING,
1630
+ )
1631
+ # Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1632
+ class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
1633
+ def __init__(self, config, *inputs, **kwargs):
1634
+ super().__init__(config, *inputs, **kwargs)
1635
+
1636
+ self.transformer = FlaubertModel(config)
1637
+ self.sequence_summary = FlaubertSequenceSummary(config)
1638
+ self.logits_proj = nn.Linear(config.num_labels, 1)
1639
+
1640
+ # Initialize weights and apply final processing
1641
+ self.post_init()
1642
+
1643
+ @add_start_docstrings_to_model_forward(
1644
+ FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1645
+ )
1646
+ @add_code_sample_docstrings(
1647
+ checkpoint=_CHECKPOINT_FOR_DOC,
1648
+ output_type=MultipleChoiceModelOutput,
1649
+ config_class=_CONFIG_FOR_DOC,
1650
+ )
1651
+ def forward(
1652
+ self,
1653
+ input_ids: Optional[torch.Tensor] = None,
1654
+ attention_mask: Optional[torch.Tensor] = None,
1655
+ langs: Optional[torch.Tensor] = None,
1656
+ token_type_ids: Optional[torch.Tensor] = None,
1657
+ position_ids: Optional[torch.Tensor] = None,
1658
+ lengths: Optional[torch.Tensor] = None,
1659
+ cache: Optional[Dict[str, torch.Tensor]] = None,
1660
+ head_mask: Optional[torch.Tensor] = None,
1661
+ inputs_embeds: Optional[torch.Tensor] = None,
1662
+ labels: Optional[torch.Tensor] = None,
1663
+ output_attentions: Optional[bool] = None,
1664
+ output_hidden_states: Optional[bool] = None,
1665
+ return_dict: Optional[bool] = None,
1666
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
1667
+ r"""
1668
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1669
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1670
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1671
+ `input_ids` above)
1672
+ """
1673
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1674
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1675
+
1676
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1677
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1678
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1679
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1680
+ langs = langs.view(-1, langs.size(-1)) if langs is not None else None
1681
+ inputs_embeds = (
1682
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1683
+ if inputs_embeds is not None
1684
+ else None
1685
+ )
1686
+
1687
+ if lengths is not None:
1688
+ logger.warning(
1689
+ "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
1690
+ "attention mask instead."
1691
+ )
1692
+ lengths = None
1693
+
1694
+ transformer_outputs = self.transformer(
1695
+ input_ids=input_ids,
1696
+ attention_mask=attention_mask,
1697
+ langs=langs,
1698
+ token_type_ids=token_type_ids,
1699
+ position_ids=position_ids,
1700
+ lengths=lengths,
1701
+ cache=cache,
1702
+ head_mask=head_mask,
1703
+ inputs_embeds=inputs_embeds,
1704
+ output_attentions=output_attentions,
1705
+ output_hidden_states=output_hidden_states,
1706
+ return_dict=return_dict,
1707
+ )
1708
+ output = transformer_outputs[0]
1709
+ logits = self.sequence_summary(output)
1710
+ logits = self.logits_proj(logits)
1711
+ reshaped_logits = logits.view(-1, num_choices)
1712
+
1713
+ loss = None
1714
+ if labels is not None:
1715
+ loss_fct = CrossEntropyLoss()
1716
+ loss = loss_fct(reshaped_logits, labels)
1717
+
1718
+ if not return_dict:
1719
+ output = (reshaped_logits,) + transformer_outputs[1:]
1720
+ return ((loss,) + output) if loss is not None else output
1721
+
1722
+ return MultipleChoiceModelOutput(
1723
+ loss=loss,
1724
+ logits=reshaped_logits,
1725
+ hidden_states=transformer_outputs.hidden_states,
1726
+ attentions=transformer_outputs.attentions,
1727
+ )
1728
+
1729
+
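The multiple-choice head folds the choice dimension into the batch before calling the transformer and unfolds it again on the logits, which is what the `view` calls above do. A shape walk-through with made-up sizes (sketch only):

```python
import torch

batch_size, num_choices, seq_len = 2, 4, 8
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

flat_input_ids = input_ids.view(-1, input_ids.size(-1))     # (8, 8): choices become batch rows
per_choice_logits = torch.randn(flat_input_ids.size(0), 1)  # stand-in for the logits_proj output
reshaped_logits = per_choice_logits.view(-1, num_choices)   # (2, 4): one score per choice per example
print(flat_input_ids.shape, reshaped_logits.shape)
```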
1730
+ __all__ = [
1731
+ "FlaubertForMultipleChoice",
1732
+ "FlaubertForQuestionAnswering",
1733
+ "FlaubertForQuestionAnsweringSimple",
1734
+ "FlaubertForSequenceClassification",
1735
+ "FlaubertForTokenClassification",
1736
+ "FlaubertModel",
1737
+ "FlaubertWithLMHeadModel",
1738
+ "FlaubertPreTrainedModel",
1739
+ ]
docs/transformers/build/lib/transformers/models/flaubert/modeling_tf_flaubert.py ADDED
@@ -0,0 +1,1344 @@
1
+ # coding=utf-8
2
+ # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ TF 2.0 Flaubert model.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import itertools
22
+ import random
23
+ import warnings
24
+ from dataclasses import dataclass
25
+ from typing import Dict, Optional, Tuple, Union
26
+
27
+ import numpy as np
28
+ import tensorflow as tf
29
+
30
+ from ...activations_tf import get_tf_activation
31
+ from ...modeling_tf_outputs import (
32
+ TFBaseModelOutput,
33
+ TFMultipleChoiceModelOutput,
34
+ TFQuestionAnsweringModelOutput,
35
+ TFSequenceClassifierOutput,
36
+ TFTokenClassifierOutput,
37
+ )
38
+ from ...modeling_tf_utils import (
39
+ TFModelInputType,
40
+ TFMultipleChoiceLoss,
41
+ TFPreTrainedModel,
42
+ TFQuestionAnsweringLoss,
43
+ TFSequenceClassificationLoss,
44
+ TFSequenceSummary,
45
+ TFSharedEmbeddings,
46
+ TFTokenClassificationLoss,
47
+ get_initializer,
48
+ keras,
49
+ keras_serializable,
50
+ unpack_inputs,
51
+ )
52
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
53
+ from ...utils import (
54
+ MULTIPLE_CHOICE_DUMMY_INPUTS,
55
+ ModelOutput,
56
+ add_code_sample_docstrings,
57
+ add_start_docstrings,
58
+ add_start_docstrings_to_model_forward,
59
+ logging,
60
+ )
61
+ from .configuration_flaubert import FlaubertConfig
62
+
63
+
64
+ logger = logging.get_logger(__name__)
65
+
66
+ _CHECKPOINT_FOR_DOC = "flaubert/flaubert_base_cased"
67
+ _CONFIG_FOR_DOC = "FlaubertConfig"
68
+
69
+
70
+ FLAUBERT_START_DOCSTRING = r"""
71
+
72
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
73
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
74
+ etc.)
75
+
76
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
77
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
78
+ behavior.
79
+
80
+ <Tip>
81
+
82
+ TensorFlow models and layers in `transformers` accept two formats as input:
83
+
84
+ - having all inputs as keyword arguments (like PyTorch models), or
85
+ - having all inputs as a list, tuple or dict in the first positional argument.
86
+
87
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
88
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
89
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
90
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
91
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
92
+ positional argument:
93
+
94
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
95
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
96
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
97
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
98
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
99
+
100
+ Note that when creating models and layers with
101
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
102
+ about any of this, as you can just pass inputs like you would to any other Python function!
103
+
104
+ </Tip>
105
+
106
+ Parameters:
107
+ config ([`FlaubertConfig`]): Model configuration class with all the parameters of the model.
108
+ Initializing with a config file does not load the weights associated with the model, only the
109
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
110
+ """
111
+
112
+ FLAUBERT_INPUTS_DOCSTRING = r"""
113
+ Args:
114
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
115
+ Indices of input sequence tokens in the vocabulary.
116
+
117
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
118
+ [`PreTrainedTokenizer.encode`] for details.
119
+
120
+ [What are input IDs?](../glossary#input-ids)
121
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
122
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
123
+
124
+ - `1` for tokens that are **not masked**,
125
+ - `0` for tokens that are **masked**.
126
+
127
+ [What are attention masks?](../glossary#attention-mask)
128
+ langs (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
129
+ A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
130
+ languages ids which can be obtained from the language names by using two conversion mappings provided in
131
+ the configuration of the model (only provided for multilingual models). More precisely, the *language name
132
+ to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
133
+ *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
134
+
135
+ See usage examples detailed in the [multilingual documentation](../multilingual).
136
+ token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
137
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
138
+ 1]`:
139
+
140
+ - `0` corresponds to a *sentence A* token,
141
+ - `1` corresponds to a *sentence B* token.
142
+
143
+ [What are token type IDs?](../glossary#token-type-ids)
144
+ position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
145
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
146
+ config.max_position_embeddings - 1]`.
147
+
148
+ [What are position IDs?](../glossary#position-ids)
149
+ lengths (`tf.Tensor` or `Numpy array` of shape `(batch_size,)`, *optional*):
150
+ Length of each sentence that can be used to avoid performing attention on padding token indices. You can
151
+ also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
152
+ `[0, ..., input_ids.size(-1)]`:
153
+ cache (`Dict[str, tf.Tensor]`, *optional*):
154
+ Dictionary string to `tf.FloatTensor` that contains precomputed hidden states (key and values in the
155
+ attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
156
+ decoding.
157
+
158
+ The dictionary object will be modified in-place during the forward pass to add newly computed
159
+ hidden-states.
160
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
161
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
162
+
163
+ - `1` indicates the head is **not masked**,
164
+ - `0` indicates the head is **masked**.
165
+
166
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
167
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
168
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
169
+ model's internal embedding lookup matrix.
170
+ output_attentions (`bool`, *optional*):
171
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
172
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
173
+ config will be used instead.
174
+ output_hidden_states (`bool`, *optional*):
175
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
176
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
177
+ used instead.
178
+ return_dict (`bool`, *optional*):
179
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
180
+ eager mode, in graph mode the value will always be set to True.
181
+ training (`bool`, *optional*, defaults to `False`):
182
+ Whether or not to use the model in training mode (some modules like dropout modules have different
183
+ behaviors between training and evaluation).
184
+ """
185
+
186
+
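For multilingual checkpoints, the `langs` argument described above is just a tensor of language ids with the same shape as `input_ids`. A sketch of building it by hand (the mapping below is hypothetical; a real checkpoint exposes it as `model.config.lang2id`):

```python
import tensorflow as tf

lang2id = {"en": 0, "fr": 1}  # hypothetical mapping
input_ids = tf.constant([[5, 17, 31, 2]])

# One language id per token, same shape as input_ids.
langs = tf.fill(tf.shape(input_ids), lang2id["fr"])
print(langs.numpy())  # [[1 1 1 1]]
```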
187
+ def get_masks(slen, lengths, causal, padding_mask=None):
188
+ """
189
+ Generate hidden states mask, and optionally an attention mask.
190
+ """
191
+ bs = shape_list(lengths)[0]
192
+ if padding_mask is not None:
193
+ mask = padding_mask
194
+ else:
195
+ # assert lengths.max().item() <= slen
196
+ alen = tf.range(slen, dtype=lengths.dtype)
197
+ mask = alen < tf.expand_dims(lengths, axis=1)
198
+
199
+ # attention mask is the same as mask, or triangular inferior attention (causal)
200
+ if causal:
201
+ attn_mask = tf.less_equal(
202
+ tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1))
203
+ )
204
+ else:
205
+ attn_mask = mask
206
+
207
+ # sanity check
208
+ # assert shape_list(mask) == [bs, slen]
209
+ tf.debugging.assert_equal(shape_list(mask), [bs, slen])
210
+ if causal:
211
+ tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
212
+
213
+ return mask, attn_mask
214
+
215
+
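`get_masks` returns a `(bs, slen)` padding mask and, when `causal=True`, a lower-triangular `(bs, slen, slen)` attention mask. Re-deriving both on toy lengths (standalone sketch, independent of the helper above):

```python
import tensorflow as tf

slen = 5
bs = 2
lengths = tf.constant([3, 5])
alen = tf.range(slen, dtype=lengths.dtype)

mask = alen < tf.expand_dims(lengths, axis=1)  # (2, 5): True on real tokens, False on padding
causal_attn = tf.less_equal(
    tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1))
)  # (2, 5, 5): each query position may only attend to keys at or before it
print(mask.numpy())
print(causal_attn.numpy()[0].astype(int))
```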
216
+ class TFFlaubertPreTrainedModel(TFPreTrainedModel):
217
+ """
218
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
219
+ models.
220
+ """
221
+
222
+ config_class = FlaubertConfig
223
+ base_model_prefix = "transformer"
224
+
225
+ @property
226
+ def dummy_inputs(self):
227
+ # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed
228
+ inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int32)
229
+ attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32)
230
+ if self.config.use_lang_emb and self.config.n_langs > 1:
231
+ return {
232
+ "input_ids": inputs_list,
233
+ "attention_mask": attns_list,
234
+ "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32),
235
+ }
236
+ else:
237
+ return {"input_ids": inputs_list, "attention_mask": attns_list}
238
+
239
+
240
+ @add_start_docstrings(
241
+ "The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
242
+ FLAUBERT_START_DOCSTRING,
243
+ )
244
+ class TFFlaubertModel(TFFlaubertPreTrainedModel):
245
+ def __init__(self, config, *inputs, **kwargs):
246
+ super().__init__(config, *inputs, **kwargs)
247
+ self.transformer = TFFlaubertMainLayer(config, name="transformer")
248
+
249
+ @unpack_inputs
250
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
251
+ @add_code_sample_docstrings(
252
+ checkpoint=_CHECKPOINT_FOR_DOC,
253
+ output_type=TFBaseModelOutput,
254
+ config_class=_CONFIG_FOR_DOC,
255
+ )
256
+ def call(
257
+ self,
258
+ input_ids: np.ndarray | tf.Tensor | None = None,
259
+ attention_mask: np.ndarray | tf.Tensor | None = None,
260
+ langs: np.ndarray | tf.Tensor | None = None,
261
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
262
+ position_ids: np.ndarray | tf.Tensor | None = None,
263
+ lengths: np.ndarray | tf.Tensor | None = None,
264
+ cache: Optional[Dict[str, tf.Tensor]] = None,
265
+ head_mask: np.ndarray | tf.Tensor | None = None,
266
+ inputs_embeds: tf.Tensor | None = None,
267
+ output_attentions: Optional[bool] = None,
268
+ output_hidden_states: Optional[bool] = None,
269
+ return_dict: Optional[bool] = None,
270
+ training: Optional[bool] = False,
271
+ ) -> Union[Tuple, TFBaseModelOutput]:
272
+ outputs = self.transformer(
273
+ input_ids=input_ids,
274
+ attention_mask=attention_mask,
275
+ langs=langs,
276
+ token_type_ids=token_type_ids,
277
+ position_ids=position_ids,
278
+ lengths=lengths,
279
+ cache=cache,
280
+ head_mask=head_mask,
281
+ inputs_embeds=inputs_embeds,
282
+ output_attentions=output_attentions,
283
+ output_hidden_states=output_hidden_states,
284
+ return_dict=return_dict,
285
+ training=training,
286
+ )
287
+
288
+ return outputs
289
+
290
+ def build(self, input_shape=None):
291
+ if self.built:
292
+ return
293
+ self.built = True
294
+ if getattr(self, "transformer", None) is not None:
295
+ with tf.name_scope(self.transformer.name):
296
+ self.transformer.build(None)
297
+
298
+
299
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert
300
+ class TFFlaubertMultiHeadAttention(keras.layers.Layer):
301
+ NEW_ID = itertools.count()
302
+
303
+ def __init__(self, n_heads, dim, config, **kwargs):
304
+ super().__init__(**kwargs)
305
+ self.layer_id = next(TFFlaubertMultiHeadAttention.NEW_ID)
306
+ self.dim = dim
307
+ self.n_heads = n_heads
308
+ self.output_attentions = config.output_attentions
309
+ assert self.dim % self.n_heads == 0
310
+
311
+ self.q_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
312
+ self.k_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin")
313
+ self.v_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin")
314
+ self.out_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
315
+ self.dropout = keras.layers.Dropout(config.attention_dropout)
316
+ self.pruned_heads = set()
317
+ self.dim = dim
318
+
319
+ def prune_heads(self, heads):
320
+ raise NotImplementedError
321
+
322
+ def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
323
+ """
324
+ Self-attention (if kv is None) or attention over source sentence (provided by kv).
325
+ """
326
+ # Input is (bs, qlen, dim)
327
+ # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
328
+ bs, qlen, dim = shape_list(input)
329
+
330
+ if kv is None:
331
+ klen = qlen if cache is None else cache["slen"] + qlen
332
+ else:
333
+ klen = shape_list(kv)[1]
334
+
335
+ # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
336
+ dim_per_head = self.dim // self.n_heads
337
+ mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
338
+
339
+ def shape(x):
340
+ """projection"""
341
+ return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
342
+
343
+ def unshape(x):
344
+ """compute context"""
345
+ return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
346
+
347
+ q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
348
+
349
+ if kv is None:
350
+ k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
351
+ v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
352
+ elif cache is None or self.layer_id not in cache:
353
+ k = v = kv
354
+ k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
355
+ v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
356
+
357
+ if cache is not None:
358
+ if self.layer_id in cache:
359
+ if kv is None:
360
+ k_, v_ = cache[self.layer_id]
361
+ k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
362
+ v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
363
+ else:
364
+ k, v = cache[self.layer_id]
365
+
366
+ cache[self.layer_id] = (k, v)
367
+
368
+ f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype)
369
+ q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head)
370
+ k = tf.cast(k, dtype=q.dtype)
371
+ scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
372
+ mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
373
+ # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
374
+ mask = tf.cast(mask, dtype=scores.dtype)
375
+ scores = scores - 1e30 * (1.0 - mask)
376
+ weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
377
+ weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
378
+
379
+ # Mask heads if we want to
380
+ if head_mask is not None:
381
+ weights = weights * head_mask
382
+
383
+ context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
384
+ context = unshape(context) # (bs, qlen, dim)
385
+ outputs = (self.out_lin(context),)
386
+
387
+ if output_attentions:
388
+ outputs = outputs + (weights,)
389
+
390
+ return outputs
391
+
392
+ def build(self, input_shape=None):
393
+ if self.built:
394
+ return
395
+ self.built = True
396
+ if getattr(self, "q_lin", None) is not None:
397
+ with tf.name_scope(self.q_lin.name):
398
+ self.q_lin.build([None, None, self.dim])
399
+ if getattr(self, "k_lin", None) is not None:
400
+ with tf.name_scope(self.k_lin.name):
401
+ self.k_lin.build([None, None, self.dim])
402
+ if getattr(self, "v_lin", None) is not None:
403
+ with tf.name_scope(self.v_lin.name):
404
+ self.v_lin.build([None, None, self.dim])
405
+ if getattr(self, "out_lin", None) is not None:
406
+ with tf.name_scope(self.out_lin.name):
407
+ self.out_lin.build([None, None, self.dim])
408
+
409
+
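Ignoring masking and caching, the attention layer above is a standard multi-head scaled dot-product: project, split into heads, scale queries by `1/sqrt(dim_per_head)`, softmax over keys, then merge the heads back. The shape flow on random tensors (sketch only; the real layer also adds the `-1e30 * (1 - mask)` term before the softmax):

```python
import tensorflow as tf

bs, qlen, n_heads, dim_per_head = 2, 4, 2, 3
dim = n_heads * dim_per_head
x = tf.random.normal((bs, qlen, dim))

def split_heads(t):
    return tf.transpose(tf.reshape(t, (bs, -1, n_heads, dim_per_head)), perm=(0, 2, 1, 3))

q, k, v = split_heads(x), split_heads(x), split_heads(x)        # (bs, n_heads, qlen, dim_per_head)
q = q * tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))
scores = tf.matmul(q, k, transpose_b=True)                      # (bs, n_heads, qlen, qlen)
weights = tf.nn.softmax(scores, axis=-1)
context = tf.matmul(weights, v)                                 # (bs, n_heads, qlen, dim_per_head)
context = tf.reshape(tf.transpose(context, perm=(0, 2, 1, 3)), (bs, -1, dim))
print(context.shape)  # (2, 4, 6)
```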
410
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMTransformerFFN
411
+ class TFFlaubertTransformerFFN(keras.layers.Layer):
412
+ def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
413
+ super().__init__(**kwargs)
414
+
415
+ self.lin1 = keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
416
+ self.lin2 = keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
417
+ self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
418
+ self.dropout = keras.layers.Dropout(config.dropout)
419
+ self.in_dim = in_dim
420
+ self.dim_hidden = dim_hidden
421
+
422
+ def call(self, input, training=False):
423
+ x = self.lin1(input)
424
+ x = self.act(x)
425
+ x = self.lin2(x)
426
+ x = self.dropout(x, training=training)
427
+
428
+ return x
429
+
430
+ def build(self, input_shape=None):
431
+ if self.built:
432
+ return
433
+ self.built = True
434
+ if getattr(self, "lin1", None) is not None:
435
+ with tf.name_scope(self.lin1.name):
436
+ self.lin1.build([None, None, self.in_dim])
437
+ if getattr(self, "lin2", None) is not None:
438
+ with tf.name_scope(self.lin2.name):
439
+ self.lin2.build([None, None, self.dim_hidden])
440
+
441
+
442
+ @keras_serializable
443
+ class TFFlaubertMainLayer(keras.layers.Layer):
444
+ config_class = FlaubertConfig
445
+
446
+ def __init__(self, config, **kwargs):
447
+ super().__init__(**kwargs)
448
+
449
+ self.config = config
450
+ self.n_heads = config.n_heads
451
+ self.n_langs = config.n_langs
452
+ self.dim = config.emb_dim
453
+ self.hidden_dim = self.dim * 4
454
+ self.n_words = config.n_words
455
+ self.pad_index = config.pad_index
456
+ self.causal = config.causal
457
+ self.n_layers = config.n_layers
458
+ self.use_lang_emb = config.use_lang_emb
459
+ self.layerdrop = getattr(config, "layerdrop", 0.0)
460
+ self.pre_norm = getattr(config, "pre_norm", False)
461
+ self.output_attentions = config.output_attentions
462
+ self.output_hidden_states = config.output_hidden_states
463
+ self.return_dict = config.use_return_dict
464
+ self.max_position_embeddings = config.max_position_embeddings
465
+ self.embed_init_std = config.embed_init_std
466
+ self.dropout = keras.layers.Dropout(config.dropout)
467
+ self.embeddings = TFSharedEmbeddings(
468
+ self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
469
+ )
470
+ self.layer_norm_emb = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb")
471
+ self.attentions = []
472
+ self.layer_norm1 = []
473
+ self.ffns = []
474
+ self.layer_norm2 = []
475
+
476
+ for i in range(self.n_layers):
477
+ self.attentions.append(
478
+ TFFlaubertMultiHeadAttention(self.n_heads, self.dim, config=config, name=f"attentions_._{i}")
479
+ )
480
+ self.layer_norm1.append(
481
+ keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm1_._{i}")
482
+ )
483
+ # if self.is_decoder:
484
+ # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
485
+ # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
486
+ self.ffns.append(
487
+ TFFlaubertTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name=f"ffns_._{i}")
488
+ )
489
+ self.layer_norm2.append(
490
+ keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm2_._{i}")
491
+ )
492
+
493
+ def build(self, input_shape=None):
494
+ with tf.name_scope("position_embeddings"):
495
+ self.position_embeddings = self.add_weight(
496
+ name="embeddings",
497
+ shape=[self.max_position_embeddings, self.dim],
498
+ initializer=get_initializer(self.embed_init_std),
499
+ )
500
+
501
+ if self.n_langs > 1 and self.use_lang_emb:
502
+ with tf.name_scope("lang_embeddings"):
503
+ self.lang_embeddings = self.add_weight(
504
+ name="embeddings",
505
+ shape=[self.n_langs, self.dim],
506
+ initializer=get_initializer(self.embed_init_std),
507
+ )
508
+
509
+ if self.built:
510
+ return
511
+ self.built = True
512
+ if getattr(self, "embeddings", None) is not None:
513
+ with tf.name_scope(self.embeddings.name):
514
+ self.embeddings.build(None)
515
+ if getattr(self, "layer_norm_emb", None) is not None:
516
+ with tf.name_scope(self.layer_norm_emb.name):
517
+ self.layer_norm_emb.build([None, None, self.dim])
518
+ for layer in self.attentions:
519
+ with tf.name_scope(layer.name):
520
+ layer.build(None)
521
+ for layer in self.layer_norm1:
522
+ with tf.name_scope(layer.name):
523
+ layer.build([None, None, self.dim])
524
+ for layer in self.ffns:
525
+ with tf.name_scope(layer.name):
526
+ layer.build(None)
527
+ for layer in self.layer_norm2:
528
+ with tf.name_scope(layer.name):
529
+ layer.build([None, None, self.dim])
530
+
531
+ def get_input_embeddings(self):
532
+ return self.embeddings
533
+
534
+ def set_input_embeddings(self, value):
535
+ self.embeddings.weight = value
536
+ self.embeddings.vocab_size = shape_list(value)[0]
537
+
538
+ @unpack_inputs
539
+ def call(
540
+ self,
541
+ input_ids: np.ndarray | tf.Tensor | None = None,
542
+ attention_mask: np.ndarray | tf.Tensor | None = None,
543
+ langs: np.ndarray | tf.Tensor | None = None,
544
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
545
+ position_ids: np.ndarray | tf.Tensor | None = None,
546
+ lengths: np.ndarray | tf.Tensor | None = None,
547
+ cache: Optional[Dict[str, tf.Tensor]] = None,
548
+ head_mask: np.ndarray | tf.Tensor | None = None,
549
+ inputs_embeds: tf.Tensor | None = None,
550
+ output_attentions: Optional[bool] = None,
551
+ output_hidden_states: Optional[bool] = None,
552
+ return_dict: Optional[bool] = None,
553
+ training: Optional[bool] = False,
554
+ ) -> Union[Tuple, TFBaseModelOutput]:
555
+ # removed: src_enc=None, src_len=None
556
+
557
+ if input_ids is not None and inputs_embeds is not None:
558
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
559
+ elif input_ids is not None:
560
+ bs, slen = shape_list(input_ids)
561
+ elif inputs_embeds is not None:
562
+ bs, slen = shape_list(inputs_embeds)[:2]
563
+ else:
564
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
565
+
566
+ if lengths is None:
567
+ if input_ids is not None:
568
+ lengths = tf.reduce_sum(
569
+ tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1
570
+ )
571
+ else:
572
+ lengths = tf.convert_to_tensor([slen] * bs)
573
+ # mask = input_ids != self.pad_index
574
+
575
+ # check inputs
576
+ # assert shape_list(lengths)[0] == bs
577
+ (
578
+ tf.debugging.assert_equal(shape_list(lengths)[0], bs),
579
+ f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched",
580
+ )
581
+ # assert lengths.max().item() <= slen
582
+ # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
583
+ # assert (src_enc is None) == (src_len is None)
584
+ # if src_enc is not None:
585
+ # assert self.is_decoder
586
+ # assert src_enc.size(0) == bs
587
+
588
+ # generate masks
589
+ mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
590
+ # if self.is_decoder and src_enc is not None:
591
+ # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
592
+
593
+ # position_ids
594
+ if position_ids is None:
595
+ position_ids = tf.expand_dims(tf.range(slen), axis=0)
596
+ position_ids = tf.tile(position_ids, (bs, 1))
597
+
598
+ # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
599
+ (
600
+ tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]),
601
+ f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched",
602
+ )
603
+ # position_ids = position_ids.transpose(0, 1)
604
+
605
+ # langs
606
+ if langs is not None:
607
+ # assert shape_list(langs) == [bs, slen] # (slen, bs)
608
+ (
609
+ tf.debugging.assert_equal(shape_list(langs), [bs, slen]),
610
+ f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched",
611
+ )
612
+ # langs = langs.transpose(0, 1)
613
+
614
+ # Prepare head mask if needed
615
+ # 1.0 in head_mask indicate we keep the head
616
+ # attention_probs has shape bsz x n_heads x N x N
617
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
618
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
619
+ if head_mask is not None:
620
+ raise NotImplementedError
621
+ else:
622
+ head_mask = [None] * self.n_layers
623
+
624
+ # do not recompute cached elements
625
+ if cache is not None and input_ids is not None:
626
+ _slen = slen - cache["slen"]
627
+ input_ids = input_ids[:, -_slen:]
628
+ position_ids = position_ids[:, -_slen:]
629
+ if langs is not None:
630
+ langs = langs[:, -_slen:]
631
+ mask = mask[:, -_slen:]
632
+ attn_mask = attn_mask[:, -_slen:]
633
+
634
+ # embeddings
635
+ if inputs_embeds is None:
636
+ check_embeddings_within_bounds(input_ids, self.embeddings.vocab_size)
637
+ inputs_embeds = self.embeddings(input_ids)
638
+
639
+ tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids)
640
+
641
+ if langs is not None and self.use_lang_emb:
642
+ tensor = tensor + tf.gather(self.lang_embeddings, langs)
643
+ if token_type_ids is not None:
644
+ tensor = tensor + self.embeddings(token_type_ids)
645
+
646
+ tensor = self.layer_norm_emb(tensor)
647
+ tensor = self.dropout(tensor, training=training)
648
+ mask = tf.cast(mask, dtype=tensor.dtype)
649
+ tensor = tensor * tf.expand_dims(mask, axis=-1)
650
+
651
+ # hidden_states and attentions cannot be None in graph mode.
652
+ hidden_states = () if output_hidden_states else None
653
+ attentions = () if output_attentions else None
654
+
655
+ # transformer layers
656
+ for i in range(self.n_layers):
657
+ # LayerDrop
658
+ dropout_probability = random.uniform(0, 1)
659
+
660
+ if training and (dropout_probability < self.layerdrop):
661
+ continue
662
+
663
+ if output_hidden_states:
664
+ hidden_states = hidden_states + (tensor,)
665
+
666
+ # self attention
667
+ if not self.pre_norm:
668
+ attn_outputs = self.attentions[i](
669
+ tensor,
670
+ attn_mask,
671
+ None,
672
+ cache,
673
+ head_mask[i],
674
+ output_attentions,
675
+ training=training,
676
+ )
677
+ attn = attn_outputs[0]
678
+
679
+ if output_attentions:
680
+ attentions = attentions + (attn_outputs[1],)
681
+
682
+ attn = self.dropout(attn, training=training)
683
+ tensor = tensor + attn
684
+ tensor = self.layer_norm1[i](tensor)
685
+ else:
686
+ tensor_normalized = self.layer_norm1[i](tensor)
687
+ attn_outputs = self.attentions[i](
688
+ tensor_normalized,
689
+ attn_mask,
690
+ None,
691
+ cache,
692
+ head_mask[i],
693
+ output_attentions,
694
+ training=training,
695
+ )
696
+ attn = attn_outputs[0]
697
+
698
+ if output_attentions:
699
+ attentions = attentions + (attn_outputs[1],)
700
+
701
+ attn = self.dropout(attn, training=training)
702
+ tensor = tensor + attn
703
+
704
+ # encoder attention (for decoder only)
705
+ # if self.is_decoder and src_enc is not None:
706
+ # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
707
+ # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
708
+ # tensor = tensor + attn
709
+ # tensor = self.layer_norm15[i](tensor)
710
+
711
+ # FFN
712
+ if not self.pre_norm:
713
+ tensor = tensor + self.ffns[i](tensor)
714
+ tensor = self.layer_norm2[i](tensor)
715
+ else:
716
+ tensor_normalized = self.layer_norm2[i](tensor)
717
+ tensor = tensor + self.ffns[i](tensor_normalized)
718
+
719
+ tensor = tensor * tf.expand_dims(mask, axis=-1)
720
+
721
+ # Add last hidden state
722
+ if output_hidden_states:
723
+ hidden_states = hidden_states + (tensor,)
724
+
725
+ # update cache length
726
+ if cache is not None:
727
+ cache["slen"] += shape_list(tensor)[1]  # tf.Tensor has no .size(); use shape_list for the added length
728
+
729
+ # move back sequence length to dimension 0
730
+ # tensor = tensor.transpose(0, 1)
731
+
732
+ if not return_dict:
733
+ return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
734
+
735
+ return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
736
+
737
+
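Two details in the main layer above are easy to miss: each block uses either post-norm (`x = LayerNorm(x + sublayer(x))`) or pre-norm (`x = x + sublayer(LayerNorm(x))`) depending on `config.pre_norm`, and during training whole blocks are skipped with probability `layerdrop`. A toy sketch of the LayerDrop loop, with stand-in layers that only record their index:

```python
import random

layerdrop, training = 0.5, True
layers = [lambda x, i=i: x + [i] for i in range(6)]  # stand-ins for attention + FFN blocks

random.seed(0)
tensor = []
for layer in layers:
    if training and random.uniform(0, 1) < layerdrop:
        continue  # the whole block is dropped for this forward pass
    tensor = layer(tensor)
print(tensor)  # indices of the blocks that actually ran
```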
738
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMPredLayer
739
+ class TFFlaubertPredLayer(keras.layers.Layer):
740
+ """
741
+ Prediction layer (cross_entropy or adaptive_softmax).
742
+ """
743
+
744
+ def __init__(self, config, input_embeddings, **kwargs):
745
+ super().__init__(**kwargs)
746
+
747
+ self.asm = config.asm
748
+ self.n_words = config.n_words
749
+ self.pad_index = config.pad_index
750
+
751
+ if config.asm is False:
752
+ self.input_embeddings = input_embeddings
753
+ else:
754
+ raise NotImplementedError
755
+ # self.proj = nn.AdaptiveLogSoftmaxWithLoss(
756
+ # in_features=dim,
757
+ # n_classes=config.n_words,
758
+ # cutoffs=config.asm_cutoffs,
759
+ # div_value=config.asm_div_value,
760
+ # head_bias=True, # default is False
761
+ # )
762
+
763
+ def build(self, input_shape):
764
+ # The output weights are the same as the input embeddings, but there is an output-only bias for each token.
765
+ self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
766
+
767
+ super().build(input_shape)
768
+
769
+ def get_output_embeddings(self):
770
+ return self.input_embeddings
771
+
772
+ def set_output_embeddings(self, value):
773
+ self.input_embeddings.weight = value
774
+ self.input_embeddings.vocab_size = shape_list(value)[0]
775
+
776
+ def get_bias(self):
777
+ return {"bias": self.bias}
778
+
779
+ def set_bias(self, value):
780
+ self.bias = value["bias"]
781
+ self.vocab_size = shape_list(value["bias"])[0]
782
+
783
+ def call(self, hidden_states):
784
+ hidden_states = self.input_embeddings(hidden_states, mode="linear")
785
+ hidden_states = hidden_states + self.bias
786
+
787
+ return hidden_states
788
+
789
+
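In `mode="linear"`, `TFSharedEmbeddings` multiplies by the transpose of the embedding matrix, so the prediction layer above shares its weights with the input embeddings and only adds a per-token bias. Equivalently, on random tensors (sketch only):

```python
import tensorflow as tf

n_words, dim = 10, 4
embedding_matrix = tf.random.normal((n_words, dim))  # stands in for the shared embedding weights
bias = tf.zeros((n_words,))

hidden_states = tf.random.normal((2, 3, dim))        # (bs, slen, dim)
logits = tf.matmul(hidden_states, embedding_matrix, transpose_b=True) + bias
print(logits.shape)  # (2, 3, 10): one score per vocabulary word
```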
790
+ @dataclass
791
+ class TFFlaubertWithLMHeadModelOutput(ModelOutput):
792
+ """
793
+ Base class for [`TFFlaubertWithLMHeadModel`] outputs.
794
+
795
+ Args:
796
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
797
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
798
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
799
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
800
+ `(batch_size, sequence_length, hidden_size)`.
801
+
802
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
803
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
804
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
805
+ sequence_length)`.
806
+
807
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
808
+ heads.
809
+ """
810
+
811
+ logits: Optional[tf.Tensor] = None
812
+ hidden_states: Tuple[tf.Tensor] | None = None
813
+ attentions: Tuple[tf.Tensor] | None = None
814
+
815
+
816
+ @add_start_docstrings(
817
+ """
818
+ The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input
819
+ embeddings).
820
+ """,
821
+ FLAUBERT_START_DOCSTRING,
822
+ )
823
+ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
824
+ def __init__(self, config, *inputs, **kwargs):
825
+ super().__init__(config, *inputs, **kwargs)
826
+ self.transformer = TFFlaubertMainLayer(config, name="transformer")
827
+ self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
828
+ # Flaubert does not have past caching features
829
+ self.supports_xla_generation = False
830
+
831
+ def get_lm_head(self):
832
+ return self.pred_layer
833
+
834
+ def get_prefix_bias_name(self):
835
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
836
+ return self.name + "/" + self.pred_layer.name
837
+
838
+ def prepare_inputs_for_generation(self, inputs, **kwargs):
839
+ mask_token_id = self.config.mask_token_id
840
+ lang_id = self.config.lang_id
841
+
842
+ effective_batch_size = inputs.shape[0]
843
+ mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
844
+ inputs = tf.concat([inputs, mask_token], axis=1)
845
+
846
+ if lang_id is not None:
847
+ langs = tf.ones_like(inputs) * lang_id
848
+ else:
849
+ langs = None
850
+ return {"input_ids": inputs, "langs": langs}
851
+
852
+ @unpack_inputs
853
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
854
+ @add_code_sample_docstrings(
855
+ checkpoint=_CHECKPOINT_FOR_DOC,
856
+ output_type=TFFlaubertWithLMHeadModelOutput,
857
+ config_class=_CONFIG_FOR_DOC,
858
+ )
859
+ def call(
860
+ self,
861
+ input_ids: np.ndarray | tf.Tensor | None = None,
862
+ attention_mask: np.ndarray | tf.Tensor | None = None,
863
+ langs: np.ndarray | tf.Tensor | None = None,
864
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
865
+ position_ids: np.ndarray | tf.Tensor | None = None,
866
+ lengths: np.ndarray | tf.Tensor | None = None,
867
+ cache: Optional[Dict[str, tf.Tensor]] = None,
868
+ head_mask: np.ndarray | tf.Tensor | None = None,
869
+ inputs_embeds: tf.Tensor | None = None,
870
+ output_attentions: Optional[bool] = None,
871
+ output_hidden_states: Optional[bool] = None,
872
+ return_dict: Optional[bool] = None,
873
+ training: Optional[bool] = False,
874
+ ) -> Union[Tuple, TFFlaubertWithLMHeadModelOutput]:
875
+ transformer_outputs = self.transformer(
876
+ input_ids=input_ids,
877
+ attention_mask=attention_mask,
878
+ langs=langs,
879
+ token_type_ids=token_type_ids,
880
+ position_ids=position_ids,
881
+ lengths=lengths,
882
+ cache=cache,
883
+ head_mask=head_mask,
884
+ inputs_embeds=inputs_embeds,
885
+ output_attentions=output_attentions,
886
+ output_hidden_states=output_hidden_states,
887
+ return_dict=return_dict,
888
+ training=training,
889
+ )
890
+ output = transformer_outputs[0]
891
+ outputs = self.pred_layer(output)
892
+
893
+ if not return_dict:
894
+ return (outputs,) + transformer_outputs[1:]
895
+
896
+ return TFFlaubertWithLMHeadModelOutput(
897
+ logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions
898
+ )
899
+
900
+ def build(self, input_shape=None):
901
+ if self.built:
902
+ return
903
+ self.built = True
904
+ if getattr(self, "transformer", None) is not None:
905
+ with tf.name_scope(self.transformer.name):
906
+ self.transformer.build(None)
907
+ if getattr(self, "pred_layer", None) is not None:
908
+ with tf.name_scope(self.pred_layer.name):
909
+ self.pred_layer.build(None)
910
+
911
+
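Because the LM head is masked-LM style, `prepare_inputs_for_generation` above appends one mask token per row and tags every position with the target language id. In isolation, with hypothetical token and language ids:

```python
import tensorflow as tf

mask_token_id, lang_id = 6, 3  # hypothetical ids
inputs = tf.constant([[14, 27, 9], [5, 11, 2]])

effective_batch_size = inputs.shape[0]
mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
inputs = tf.concat([inputs, mask_token], axis=1)
langs = tf.ones_like(inputs) * lang_id
print(inputs.numpy())  # each row now ends with the mask token id
print(langs.numpy())
```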
912
+ @add_start_docstrings(
913
+ """
914
+ Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
915
+ e.g. for GLUE tasks.
916
+ """,
917
+ FLAUBERT_START_DOCSTRING,
918
+ )
919
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
920
+ class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceClassificationLoss):
921
+ def __init__(self, config, *inputs, **kwargs):
922
+ super().__init__(config, *inputs, **kwargs)
923
+ self.num_labels = config.num_labels
924
+
925
+ self.transformer = TFFlaubertMainLayer(config, name="transformer")
926
+ self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
927
+
928
+ @unpack_inputs
929
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
930
+ @add_code_sample_docstrings(
931
+ checkpoint=_CHECKPOINT_FOR_DOC,
932
+ output_type=TFSequenceClassifierOutput,
933
+ config_class=_CONFIG_FOR_DOC,
934
+ )
935
+ def call(
936
+ self,
937
+ input_ids: TFModelInputType | None = None,
938
+ attention_mask: np.ndarray | tf.Tensor | None = None,
939
+ langs: np.ndarray | tf.Tensor | None = None,
940
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
941
+ position_ids: np.ndarray | tf.Tensor | None = None,
942
+ lengths: np.ndarray | tf.Tensor | None = None,
943
+ cache: Optional[Dict[str, tf.Tensor]] = None,
944
+ head_mask: np.ndarray | tf.Tensor | None = None,
945
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
946
+ output_attentions: Optional[bool] = None,
947
+ output_hidden_states: Optional[bool] = None,
948
+ return_dict: Optional[bool] = None,
949
+ labels: np.ndarray | tf.Tensor | None = None,
950
+ training: bool = False,
951
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
952
+ r"""
953
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
954
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
955
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
956
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
957
+ """
958
+ transformer_outputs = self.transformer(
959
+ input_ids=input_ids,
960
+ attention_mask=attention_mask,
961
+ langs=langs,
962
+ token_type_ids=token_type_ids,
963
+ position_ids=position_ids,
964
+ lengths=lengths,
965
+ cache=cache,
966
+ head_mask=head_mask,
967
+ inputs_embeds=inputs_embeds,
968
+ output_attentions=output_attentions,
969
+ output_hidden_states=output_hidden_states,
970
+ return_dict=return_dict,
971
+ training=training,
972
+ )
973
+ output = transformer_outputs[0]
974
+
975
+ logits = self.sequence_summary(output)
976
+
977
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
978
+
979
+ if not return_dict:
980
+ output = (logits,) + transformer_outputs[1:]
981
+ return ((loss,) + output) if loss is not None else output
982
+
983
+ return TFSequenceClassifierOutput(
984
+ loss=loss,
985
+ logits=logits,
986
+ hidden_states=transformer_outputs.hidden_states,
987
+ attentions=transformer_outputs.attentions,
988
+ )
989
+
990
+ def build(self, input_shape=None):
991
+ if self.built:
992
+ return
993
+ self.built = True
994
+ if getattr(self, "transformer", None) is not None:
995
+ with tf.name_scope(self.transformer.name):
996
+ self.transformer.build(None)
997
+ if getattr(self, "sequence_summary", None) is not None:
998
+ with tf.name_scope(self.sequence_summary.name):
999
+ self.sequence_summary.build(None)
1000
+
1001
+
1002
+ @add_start_docstrings(
1003
+ """
1004
+ Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1005
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1006
+ """,
1007
+ FLAUBERT_START_DOCSTRING,
1008
+ )
1009
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1010
+ class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestionAnsweringLoss):
1011
+ def __init__(self, config, *inputs, **kwargs):
1012
+ super().__init__(config, *inputs, **kwargs)
1013
+ self.transformer = TFFlaubertMainLayer(config, name="transformer")
1014
+ self.qa_outputs = keras.layers.Dense(
1015
+ config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
1016
+ )
1017
+ self.config = config
1018
+
1019
+ @unpack_inputs
1020
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1021
+ @add_code_sample_docstrings(
1022
+ checkpoint=_CHECKPOINT_FOR_DOC,
1023
+ output_type=TFQuestionAnsweringModelOutput,
1024
+ config_class=_CONFIG_FOR_DOC,
1025
+ )
1026
+ def call(
1027
+ self,
1028
+ input_ids: TFModelInputType | None = None,
1029
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1030
+ langs: np.ndarray | tf.Tensor | None = None,
1031
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1032
+ position_ids: np.ndarray | tf.Tensor | None = None,
1033
+ lengths: np.ndarray | tf.Tensor | None = None,
1034
+ cache: Optional[Dict[str, tf.Tensor]] = None,
1035
+ head_mask: np.ndarray | tf.Tensor | None = None,
1036
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1037
+ output_attentions: Optional[bool] = None,
1038
+ output_hidden_states: Optional[bool] = None,
1039
+ return_dict: Optional[bool] = None,
1040
+ start_positions: np.ndarray | tf.Tensor | None = None,
1041
+ end_positions: np.ndarray | tf.Tensor | None = None,
1042
+ training: bool = False,
1043
+ ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
1044
+ r"""
1045
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1046
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1047
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1048
+ are not taken into account for computing the loss.
1049
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1050
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1051
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1052
+ are not taken into account for computing the loss.
1053
+ """
1054
+ transformer_outputs = self.transformer(
1055
+ input_ids=input_ids,
1056
+ attention_mask=attention_mask,
1057
+ langs=langs,
1058
+ token_type_ids=token_type_ids,
1059
+ position_ids=position_ids,
1060
+ lengths=lengths,
1061
+ cache=cache,
1062
+ head_mask=head_mask,
1063
+ inputs_embeds=inputs_embeds,
1064
+ output_attentions=output_attentions,
1065
+ output_hidden_states=output_hidden_states,
1066
+ return_dict=return_dict,
1067
+ training=training,
1068
+ )
1069
+ sequence_output = transformer_outputs[0]
1070
+
1071
+ logits = self.qa_outputs(sequence_output)
1072
+ start_logits, end_logits = tf.split(logits, 2, axis=-1)
1073
+ start_logits = tf.squeeze(start_logits, axis=-1)
1074
+ end_logits = tf.squeeze(end_logits, axis=-1)
1075
+
1076
+ loss = None
1077
+ if start_positions is not None and end_positions is not None:
1078
+ labels = {"start_position": start_positions}
1079
+ labels["end_position"] = end_positions
1080
+ loss = self.hf_compute_loss(labels, (start_logits, end_logits))
1081
+
1082
+ if not return_dict:
1083
+ output = (start_logits, end_logits) + transformer_outputs[1:]
1084
+ return ((loss,) + output) if loss is not None else output
1085
+
1086
+ return TFQuestionAnsweringModelOutput(
1087
+ loss=loss,
1088
+ start_logits=start_logits,
1089
+ end_logits=end_logits,
1090
+ hidden_states=transformer_outputs.hidden_states,
1091
+ attentions=transformer_outputs.attentions,
1092
+ )
1093
+
1094
+ def build(self, input_shape=None):
1095
+ if self.built:
1096
+ return
1097
+ self.built = True
1098
+ if getattr(self, "transformer", None) is not None:
1099
+ with tf.name_scope(self.transformer.name):
1100
+ self.transformer.build(None)
1101
+ if getattr(self, "qa_outputs", None) is not None:
1102
+ with tf.name_scope(self.qa_outputs.name):
1103
+ self.qa_outputs.build([None, None, self.config.hidden_size])
1104
+
1105
+
1106
+ @add_start_docstrings(
1107
+ """
1108
+ Flaubert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1109
+ Named-Entity-Recognition (NER) tasks.
1110
+ """,
1111
+ FLAUBERT_START_DOCSTRING,
1112
+ )
1113
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1114
+ class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassificationLoss):
1115
+ def __init__(self, config, *inputs, **kwargs):
1116
+ super().__init__(config, *inputs, **kwargs)
1117
+ self.num_labels = config.num_labels
1118
+
1119
+ self.transformer = TFFlaubertMainLayer(config, name="transformer")
1120
+ self.dropout = keras.layers.Dropout(config.dropout)
1121
+ self.classifier = keras.layers.Dense(
1122
+ config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
1123
+ )
1124
+ self.config = config
1125
+
1126
+ @unpack_inputs
1127
+ @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1128
+ @add_code_sample_docstrings(
1129
+ checkpoint=_CHECKPOINT_FOR_DOC,
1130
+ output_type=TFTokenClassifierOutput,
1131
+ config_class=_CONFIG_FOR_DOC,
1132
+ )
1133
+ def call(
1134
+ self,
1135
+ input_ids: TFModelInputType | None = None,
1136
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1137
+ langs: np.ndarray | tf.Tensor | None = None,
1138
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1139
+ position_ids: np.ndarray | tf.Tensor | None = None,
1140
+ lengths: np.ndarray | tf.Tensor | None = None,
1141
+ cache: Optional[Dict[str, tf.Tensor]] = None,
1142
+ head_mask: np.ndarray | tf.Tensor | None = None,
1143
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1144
+ output_attentions: Optional[bool] = None,
1145
+ output_hidden_states: Optional[bool] = None,
1146
+ return_dict: Optional[bool] = None,
1147
+ labels: np.ndarray | tf.Tensor | None = None,
1148
+ training: bool = False,
1149
+ ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
1150
+ r"""
1151
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1153
+ """
1154
+ transformer_outputs = self.transformer(
1155
+ input_ids=input_ids,
1156
+ attention_mask=attention_mask,
1157
+ langs=langs,
1158
+ token_type_ids=token_type_ids,
1159
+ position_ids=position_ids,
1160
+ lengths=lengths,
1161
+ cache=cache,
1162
+ head_mask=head_mask,
1163
+ inputs_embeds=inputs_embeds,
1164
+ output_attentions=output_attentions,
1165
+ output_hidden_states=output_hidden_states,
1166
+ return_dict=return_dict,
1167
+ training=training,
1168
+ )
1169
+ sequence_output = transformer_outputs[0]
1170
+
1171
+ sequence_output = self.dropout(sequence_output, training=training)
1172
+ logits = self.classifier(sequence_output)
1173
+
1174
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
1175
+
1176
+ if not return_dict:
1177
+ output = (logits,) + transformer_outputs[1:]
1178
+ return ((loss,) + output) if loss is not None else output
1179
+
1180
+ return TFTokenClassifierOutput(
1181
+ loss=loss,
1182
+ logits=logits,
1183
+ hidden_states=transformer_outputs.hidden_states,
1184
+ attentions=transformer_outputs.attentions,
1185
+ )
1186
+
1187
+ def build(self, input_shape=None):
1188
+ if self.built:
1189
+ return
1190
+ self.built = True
1191
+ if getattr(self, "transformer", None) is not None:
1192
+ with tf.name_scope(self.transformer.name):
1193
+ self.transformer.build(None)
1194
+ if getattr(self, "classifier", None) is not None:
1195
+ with tf.name_scope(self.classifier.name):
1196
+ self.classifier.build([None, None, self.config.hidden_size])
1197
+
1198
+
1199
+ @add_start_docstrings(
1200
+ """
1201
+ Flaubert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1202
+ softmax) e.g. for RocStories/SWAG tasks.
1203
+ """,
1204
+ FLAUBERT_START_DOCSTRING,
1205
+ )
1206
+ # Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
1207
+ class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLoss):
1208
+ def __init__(self, config, *inputs, **kwargs):
1209
+ super().__init__(config, *inputs, **kwargs)
1210
+
1211
+ self.transformer = TFFlaubertMainLayer(config, name="transformer")
1212
+ self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
1213
+ self.logits_proj = keras.layers.Dense(
1214
+ 1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
1215
+ )
1216
+ self.config = config
1217
+
1218
+ @property
1219
+ def dummy_inputs(self):
1220
+ """
1221
+ Dummy inputs to build the network.
1222
+
1223
+ Returns:
1224
+ Dict[str, tf.Tensor] with dummy inputs
1225
+ """
1226
+ # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed
1227
+ if self.config.use_lang_emb and self.config.n_langs > 1:
1228
+ return {
1229
+ "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
1230
+ "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
1231
+ }
1232
+ else:
1233
+ return {
1234
+ "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
1235
+ }
1236
+
1237
+ @unpack_inputs
1238
+ @add_start_docstrings_to_model_forward(
1239
+ FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1240
+ )
1241
+ @add_code_sample_docstrings(
1242
+ checkpoint=_CHECKPOINT_FOR_DOC,
1243
+ output_type=TFMultipleChoiceModelOutput,
1244
+ config_class=_CONFIG_FOR_DOC,
1245
+ )
1246
+ def call(
1247
+ self,
1248
+ input_ids: TFModelInputType | None = None,
1249
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1250
+ langs: np.ndarray | tf.Tensor | None = None,
1251
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1252
+ position_ids: np.ndarray | tf.Tensor | None = None,
1253
+ lengths: np.ndarray | tf.Tensor | None = None,
1254
+ cache: Optional[Dict[str, tf.Tensor]] = None,
1255
+ head_mask: np.ndarray | tf.Tensor | None = None,
1256
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1257
+ output_attentions: Optional[bool] = None,
1258
+ output_hidden_states: Optional[bool] = None,
1259
+ return_dict: Optional[bool] = None,
1260
+ labels: np.ndarray | tf.Tensor | None = None,
1261
+ training: bool = False,
1262
+ ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
1263
+ if input_ids is not None:
1264
+ num_choices = shape_list(input_ids)[1]
1265
+ seq_length = shape_list(input_ids)[2]
1266
+ else:
1267
+ num_choices = shape_list(inputs_embeds)[1]
1268
+ seq_length = shape_list(inputs_embeds)[2]
1269
+
1270
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
1271
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
1272
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
1273
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
1274
+ flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
1275
+ flat_inputs_embeds = (
1276
+ tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
1277
+ if inputs_embeds is not None
1278
+ else None
1279
+ )
1280
+
1281
+ if lengths is not None:
1282
+ logger.warning(
1283
+ "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
1284
+ "attention mask instead.",
1285
+ )
1286
+ lengths = None
1287
+
1288
+ transformer_outputs = self.transformer(
1289
+ flat_input_ids,
1290
+ flat_attention_mask,
1291
+ flat_langs,
1292
+ flat_token_type_ids,
1293
+ flat_position_ids,
1294
+ lengths,
1295
+ cache,
1296
+ head_mask,
1297
+ flat_inputs_embeds,
1298
+ output_attentions,
1299
+ output_hidden_states,
1300
+ return_dict=return_dict,
1301
+ training=training,
1302
+ )
1303
+ output = transformer_outputs[0]
1304
+ logits = self.sequence_summary(output)
1305
+ logits = self.logits_proj(logits)
1306
+ reshaped_logits = tf.reshape(logits, (-1, num_choices))
1307
+
1308
+ loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
1309
+
1310
+ if not return_dict:
1311
+ output = (reshaped_logits,) + transformer_outputs[1:]
1312
+ return ((loss,) + output) if loss is not None else output
1313
+
1314
+ return TFMultipleChoiceModelOutput(
1315
+ loss=loss,
1316
+ logits=reshaped_logits,
1317
+ hidden_states=transformer_outputs.hidden_states,
1318
+ attentions=transformer_outputs.attentions,
1319
+ )
1320
+
1321
+ def build(self, input_shape=None):
1322
+ if self.built:
1323
+ return
1324
+ self.built = True
1325
+ if getattr(self, "transformer", None) is not None:
1326
+ with tf.name_scope(self.transformer.name):
1327
+ self.transformer.build(None)
1328
+ if getattr(self, "sequence_summary", None) is not None:
1329
+ with tf.name_scope(self.sequence_summary.name):
1330
+ self.sequence_summary.build(None)
1331
+ if getattr(self, "logits_proj", None) is not None:
1332
+ with tf.name_scope(self.logits_proj.name):
1333
+ self.logits_proj.build([None, None, self.config.num_labels])
1334
+
1335
+
1336
+ __all__ = [
1337
+ "TFFlaubertForMultipleChoice",
1338
+ "TFFlaubertForQuestionAnsweringSimple",
1339
+ "TFFlaubertForSequenceClassification",
1340
+ "TFFlaubertForTokenClassification",
1341
+ "TFFlaubertModel",
1342
+ "TFFlaubertPreTrainedModel",
1343
+ "TFFlaubertWithLMHeadModel",
1344
+ ]
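As a quick, hedged usage sketch of the task heads defined above (the `flaubert/flaubert_base_cased` checkpoint name is an assumption; any compatible Flaubert checkpoint should behave the same way), the sequence classification head can be exercised like this:

```python
# Rough usage sketch for the TF Flaubert heads above; the checkpoint name is an assumption.
import tensorflow as tf
from transformers import FlaubertTokenizer, TFFlaubertForSequenceClassification

tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertForSequenceClassification.from_pretrained(
    "flaubert/flaubert_base_cased", num_labels=2  # classification head is freshly initialized
)

inputs = tokenizer("Le camembert est délicieux.", return_tensors="tf")
outputs = model(**inputs)                      # TFSequenceClassifierOutput when return_dict=True
predicted_class = int(tf.argmax(outputs.logits, axis=-1)[0])
```

The same call signature (optional `langs`, `lengths`, `cache`, ...) applies to the question answering, token classification, and multiple choice heads, since all of them forward their inputs to the shared `TFFlaubertMainLayer`.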
docs/transformers/build/lib/transformers/models/flaubert/tokenization_flaubert.py ADDED
@@ -0,0 +1,568 @@
1
+ # coding=utf-8
2
+ # Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Flaubert."""
16
+
17
+ import json
18
+ import os
19
+ import re
20
+ import unicodedata
21
+ from typing import List, Optional, Tuple
22
+
23
+ from ...tokenization_utils import PreTrainedTokenizer
24
+ from ...utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {
30
+ "vocab_file": "vocab.json",
31
+ "merges_file": "merges.txt",
32
+ }
33
+
34
+
35
+ def convert_to_unicode(text):
36
+ """
37
+ Converts `text` to Unicode (if it's not already), assuming UTF-8 input.
38
+ """
39
+
40
+ def ensure_text(s, encoding="utf-8", errors="strict"):
41
+ if isinstance(s, bytes):
42
+ return s.decode(encoding, errors)
43
+ elif isinstance(s, str):
44
+ return s
45
+ else:
46
+ raise TypeError(f"not expecting type '{type(s)}'")
47
+
48
+ return ensure_text(text, encoding="utf-8", errors="ignore")
49
+
50
+
51
+ # Copied from transformers.models.xlm.tokenization_xlm.get_pairs
52
+ def get_pairs(word):
53
+ """
54
+ Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
55
+ strings)
56
+ """
57
+ pairs = set()
58
+ prev_char = word[0]
59
+ for char in word[1:]:
60
+ pairs.add((prev_char, char))
61
+ prev_char = char
62
+ return pairs
63
+
64
+
65
+ # Copied from transformers.models.xlm.tokenization_xlm.replace_unicode_punct
66
+ def replace_unicode_punct(text):
67
+ """
68
+ Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
69
+ """
70
+ text = text.replace(",", ",")
71
+ text = re.sub(r"。\s*", ". ", text)
72
+ text = text.replace("、", ",")
73
+ text = text.replace("”", '"')
74
+ text = text.replace("“", '"')
75
+ text = text.replace("∶", ":")
76
+ text = text.replace(":", ":")
77
+ text = text.replace("?", "?")
78
+ text = text.replace("《", '"')
79
+ text = text.replace("》", '"')
80
+ text = text.replace(")", ")")
81
+ text = text.replace("!", "!")
82
+ text = text.replace("(", "(")
83
+ text = text.replace(";", ";")
84
+ text = text.replace("1", "1")
85
+ text = text.replace("」", '"')
86
+ text = text.replace("「", '"')
87
+ text = text.replace("0", "0")
88
+ text = text.replace("3", "3")
89
+ text = text.replace("2", "2")
90
+ text = text.replace("5", "5")
91
+ text = text.replace("6", "6")
92
+ text = text.replace("9", "9")
93
+ text = text.replace("7", "7")
94
+ text = text.replace("8", "8")
95
+ text = text.replace("4", "4")
96
+ text = re.sub(r"．\s*", ". ", text)
97
+ text = text.replace("~", "~")
98
+ text = text.replace("’", "'")
99
+ text = text.replace("…", "...")
100
+ text = text.replace("━", "-")
101
+ text = text.replace("〈", "<")
102
+ text = text.replace("〉", ">")
103
+ text = text.replace("【", "[")
104
+ text = text.replace("】", "]")
105
+ text = text.replace("%", "%")
106
+ return text
107
+
108
+
109
+ # Copied from transformers.models.xlm.tokenization_xlm.remove_non_printing_char
110
+ def remove_non_printing_char(text):
111
+ """
112
+ Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
113
+ """
114
+ output = []
115
+ for char in text:
116
+ cat = unicodedata.category(char)
117
+ if cat.startswith("C"):
118
+ continue
119
+ output.append(char)
120
+ return "".join(output)
121
+
122
+
123
+ class FlaubertTokenizer(PreTrainedTokenizer):
124
+ """
125
+ Construct a Flaubert tokenizer. Based on Byte-Pair Encoding. The tokenization process is the following:
126
+
127
+ - Moses preprocessing and tokenization.
128
+ - Normalizing all input text.
129
+ - The argument `special_tokens` and the function `set_special_tokens` can be used to add additional symbols (like
130
+ "__classify__") to a vocabulary.
131
+ - The argument `do_lowercase` controls lower casing (automatically set for pretrained vocabularies).
132
+
133
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
134
+ this superclass for more information regarding those methods.
135
+
136
+ Args:
137
+ vocab_file (`str`):
138
+ Vocabulary file.
139
+ merges_file (`str`):
140
+ Merges file.
141
+ do_lowercase (`bool`, *optional*, defaults to `False`):
142
+ Controls lower casing.
143
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
144
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
145
+ token instead.
146
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
147
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
148
+
149
+ <Tip>
150
+
151
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
152
+ sequence. The token used is the `cls_token`.
153
+
154
+ </Tip>
155
+
156
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
157
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
158
+ sequence classification or for a text and a question for question answering. It is also used as the last
159
+ token of a sequence built with special tokens.
160
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
161
+ The token used for padding, for example when batching sequences of different lengths.
162
+ cls_token (`str`, *optional*, defaults to `"</s>"`):
163
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
164
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
165
+ mask_token (`str`, *optional*, defaults to `"<special1>"`):
166
+ The token used for masking values. This is the token used when training this model with masked language
167
+ modeling. This is the token which the model will try to predict.
168
+ additional_special_tokens (`List[str]`, *optional*, defaults to `['<special0>', '<special1>', '<special2>', '<special3>', '<special4>', '<special5>', '<special6>', '<special7>', '<special8>', '<special9>']`):
169
+ List of additional special tokens.
170
+ lang2id (`Dict[str, int]`, *optional*):
171
+ Dictionary mapping languages string identifiers to their IDs.
172
+ id2lang (`Dict[int, str]`, *optional*):
173
+ Dictionary mapping language IDs to their string identifiers.
174
+ """
175
+
176
+ vocab_files_names = VOCAB_FILES_NAMES
177
+
178
+ def __init__(
179
+ self,
180
+ vocab_file,
181
+ merges_file,
182
+ do_lowercase=False,
183
+ unk_token="<unk>",
184
+ bos_token="<s>",
185
+ sep_token="</s>",
186
+ pad_token="<pad>",
187
+ cls_token="</s>",
188
+ mask_token="<special1>",
189
+ additional_special_tokens=[
190
+ "<special0>",
191
+ "<special1>",
192
+ "<special2>",
193
+ "<special3>",
194
+ "<special4>",
195
+ "<special5>",
196
+ "<special6>",
197
+ "<special7>",
198
+ "<special8>",
199
+ "<special9>",
200
+ ],
201
+ lang2id=None,
202
+ id2lang=None,
203
+ **kwargs,
204
+ ):
205
+ do_lowercase_and_remove_accent = kwargs.pop("do_lowercase_and_remove_accent", None)
206
+ if do_lowercase_and_remove_accent is not None:
207
+ logger.warning(
208
+ "`do_lowercase_and_remove_accent` is passed as a keyword argument, but this won't do anything."
209
+ " `FlaubertTokenizer` will always set it to `False`."
210
+ )
211
+ # always `False`
212
+ self.do_lowercase_and_remove_accent = False
213
+
214
+ self.do_lowercase = do_lowercase
215
+
216
+ try:
217
+ import sacremoses
218
+ except ImportError:
219
+ raise ImportError(
220
+ "You need to install sacremoses to use FlaubertTokenizer. "
221
+ "See https://pypi.org/project/sacremoses/ for installation."
222
+ )
223
+
224
+ self.sm = sacremoses
225
+
226
+ # cache of sm.MosesPunctNormalizer instance
227
+ self.cache_moses_punct_normalizer = {}
228
+ # cache of sm.MosesTokenizer instance
229
+ self.cache_moses_tokenizer = {}
230
+ self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
231
+ self.lang2id = lang2id
232
+ self.id2lang = id2lang
233
+ if lang2id is not None and id2lang is not None:
234
+ assert len(lang2id) == len(id2lang)
235
+
236
+ self.ja_word_tokenizer = None
237
+ self.zh_word_tokenizer = None
238
+
239
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
240
+ self.encoder = json.load(vocab_handle)
241
+ self.decoder = {v: k for k, v in self.encoder.items()}
242
+ with open(merges_file, encoding="utf-8") as merges_handle:
243
+ merges = merges_handle.read().split("\n")[:-1]
244
+ merges = [tuple(merge.split()[:2]) for merge in merges]
245
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
246
+ self.cache = {}
247
+
248
+ super().__init__(
249
+ do_lowercase=do_lowercase,
250
+ unk_token=unk_token,
251
+ bos_token=bos_token,
252
+ sep_token=sep_token,
253
+ pad_token=pad_token,
254
+ cls_token=cls_token,
255
+ mask_token=mask_token,
256
+ additional_special_tokens=additional_special_tokens,
257
+ lang2id=lang2id,
258
+ id2lang=id2lang,
259
+ **kwargs,
260
+ )
261
+
262
+ @property
263
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.do_lower_case
264
+ def do_lower_case(self):
265
+ return self.do_lowercase_and_remove_accent
266
+
267
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_punct_norm
268
+ def moses_punct_norm(self, text, lang):
269
+ if lang not in self.cache_moses_punct_normalizer:
270
+ punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
271
+ self.cache_moses_punct_normalizer[lang] = punct_normalizer
272
+ else:
273
+ punct_normalizer = self.cache_moses_punct_normalizer[lang]
274
+ return punct_normalizer.normalize(text)
275
+
276
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_tokenize
277
+ def moses_tokenize(self, text, lang):
278
+ if lang not in self.cache_moses_tokenizer:
279
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
280
+ self.cache_moses_tokenizer[lang] = moses_tokenizer
281
+ else:
282
+ moses_tokenizer = self.cache_moses_tokenizer[lang]
283
+ return moses_tokenizer.tokenize(text, return_str=False, escape=False)
284
+
285
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_pipeline
286
+ def moses_pipeline(self, text, lang):
287
+ text = replace_unicode_punct(text)
288
+ text = self.moses_punct_norm(text, lang)
289
+ text = remove_non_printing_char(text)
290
+ return text
291
+
292
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.ja_tokenize
293
+ def ja_tokenize(self, text):
294
+ if self.ja_word_tokenizer is None:
295
+ try:
296
+ import Mykytea
297
+
298
+ self.ja_word_tokenizer = Mykytea.Mykytea(
299
+ f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
300
+ )
301
+ except (AttributeError, ImportError):
302
+ logger.error(
303
+ "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
304
+ " (https://github.com/chezou/Mykytea-python) with the following steps"
305
+ )
306
+ logger.error("1. git clone [email protected]:neubig/kytea.git && cd kytea")
307
+ logger.error("2. autoreconf -i")
308
+ logger.error("3. ./configure --prefix=$HOME/local")
309
+ logger.error("4. make && make install")
310
+ logger.error("5. pip install kytea")
311
+ raise
312
+ return list(self.ja_word_tokenizer.getWS(text))
313
+
314
+ @property
315
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.vocab_size
316
+ def vocab_size(self):
317
+ return len(self.encoder)
318
+
319
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_vocab
320
+ def get_vocab(self):
321
+ return dict(self.encoder, **self.added_tokens_encoder)
322
+
323
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.bpe
324
+ def bpe(self, token):
325
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
326
+ if token in self.cache:
327
+ return self.cache[token]
328
+ pairs = get_pairs(word)
329
+
330
+ if not pairs:
331
+ return token + "</w>"
332
+
333
+ while True:
334
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
335
+ if bigram not in self.bpe_ranks:
336
+ break
337
+ first, second = bigram
338
+ new_word = []
339
+ i = 0
340
+ while i < len(word):
341
+ try:
342
+ j = word.index(first, i)
343
+ except ValueError:
344
+ new_word.extend(word[i:])
345
+ break
346
+ else:
347
+ new_word.extend(word[i:j])
348
+ i = j
349
+
350
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
351
+ new_word.append(first + second)
352
+ i += 2
353
+ else:
354
+ new_word.append(word[i])
355
+ i += 1
356
+ new_word = tuple(new_word)
357
+ word = new_word
358
+ if len(word) == 1:
359
+ break
360
+ else:
361
+ pairs = get_pairs(word)
362
+ word = " ".join(word)
363
+ if word == "\n </w>":
364
+ word = "\n</w>"
365
+ self.cache[token] = word
366
+ return word
367
+
368
+ def preprocess_text(self, text):
369
+ text = text.replace("``", '"').replace("''", '"')
370
+ text = convert_to_unicode(text)
371
+ text = unicodedata.normalize("NFC", text)
372
+
373
+ if self.do_lowercase:
374
+ text = text.lower()
375
+
376
+ return text
377
+
378
+ def _tokenize(self, text, bypass_tokenizer=False):
379
+ """
380
+ Tokenize a string given language code using Moses.
381
+
382
+ Details of tokenization:
383
+
384
+ - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
385
+ - Install with `pip install sacremoses`
386
+
387
+ Args:
388
+ - bypass_tokenizer (`bool`, *optional*, defaults to `False`): Allow users to preprocess and tokenize the
389
+ sentences externally. If `True`, only BPE is applied.
390
+
391
+ Returns:
392
+ List of tokens.
393
+ """
394
+ lang = "fr"
395
+ if lang and self.lang2id and lang not in self.lang2id:
396
+ logger.error(
397
+ "Supplied language code not found in lang2id mapping. Please check that your language is supported by"
398
+ " the loaded pretrained model."
399
+ )
400
+
401
+ if bypass_tokenizer:
402
+ text = text.split()
403
+ else:
404
+ text = self.preprocess_text(text)
405
+ text = self.moses_pipeline(text, lang=lang)
406
+ text = self.moses_tokenize(text, lang=lang)
407
+
408
+ split_tokens = []
409
+ for token in text:
410
+ if token:
411
+ split_tokens.extend(list(self.bpe(token).split(" ")))
412
+
413
+ return split_tokens
414
+
415
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_token_to_id
416
+ def _convert_token_to_id(self, token):
417
+ """Converts a token (str) in an id using the vocab."""
418
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
419
+
420
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_id_to_token
421
+ def _convert_id_to_token(self, index):
422
+ """Converts an index (integer) in a token (str) using the vocab."""
423
+ return self.decoder.get(index, self.unk_token)
424
+
425
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.convert_tokens_to_string
426
+ def convert_tokens_to_string(self, tokens):
427
+ """Converts a sequence of tokens (string) in a single string."""
428
+ out_string = "".join(tokens).replace("</w>", " ").strip()
429
+ return out_string
430
+
431
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.build_inputs_with_special_tokens
432
+ def build_inputs_with_special_tokens(
433
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
434
+ ) -> List[int]:
435
+ """
436
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
437
+ adding special tokens. An XLM sequence has the following format:
438
+
439
+ - single sequence: `<s> X </s>`
440
+ - pair of sequences: `<s> A </s> B </s>`
441
+
442
+ Args:
443
+ token_ids_0 (`List[int]`):
444
+ List of IDs to which the special tokens will be added.
445
+ token_ids_1 (`List[int]`, *optional*):
446
+ Optional second list of IDs for sequence pairs.
447
+
448
+ Returns:
449
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
450
+
451
+ """
452
+ bos = [self.bos_token_id]
453
+ sep = [self.sep_token_id]
454
+
455
+ if token_ids_1 is None:
456
+ return bos + token_ids_0 + sep
457
+ return bos + token_ids_0 + sep + token_ids_1 + sep
458
+
459
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_special_tokens_mask
460
+ def get_special_tokens_mask(
461
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
462
+ ) -> List[int]:
463
+ """
464
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
465
+ special tokens using the tokenizer `prepare_for_model` method.
466
+
467
+ Args:
468
+ token_ids_0 (`List[int]`):
469
+ List of IDs.
470
+ token_ids_1 (`List[int]`, *optional*):
471
+ Optional second list of IDs for sequence pairs.
472
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
473
+ Whether or not the token list is already formatted with special tokens for the model.
474
+
475
+ Returns:
476
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
477
+ """
478
+
479
+ if already_has_special_tokens:
480
+ return super().get_special_tokens_mask(
481
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
482
+ )
483
+
484
+ if token_ids_1 is not None:
485
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
486
+ return [1] + ([0] * len(token_ids_0)) + [1]
487
+
488
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.create_token_type_ids_from_sequences
489
+ def create_token_type_ids_from_sequences(
490
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
491
+ ) -> List[int]:
492
+ """
493
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence
494
+ pair mask has the following format:
495
+
496
+ ```
497
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
498
+ | first sequence | second sequence |
499
+ ```
500
+
501
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
502
+
503
+ Args:
504
+ token_ids_0 (`List[int]`):
505
+ List of IDs.
506
+ token_ids_1 (`List[int]`, *optional*):
507
+ Optional second list of IDs for sequence pairs.
508
+
509
+ Returns:
510
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
511
+ """
512
+ sep = [self.sep_token_id]
513
+ cls = [self.cls_token_id]
514
+ if token_ids_1 is None:
515
+ return len(cls + token_ids_0 + sep) * [0]
516
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
517
+
518
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.save_vocabulary
519
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
520
+ if not os.path.isdir(save_directory):
521
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
522
+ return
523
+ vocab_file = os.path.join(
524
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
525
+ )
526
+ merge_file = os.path.join(
527
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
528
+ )
529
+
530
+ with open(vocab_file, "w", encoding="utf-8") as f:
531
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
532
+
533
+ index = 0
534
+ with open(merge_file, "w", encoding="utf-8") as writer:
535
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
536
+ if index != token_index:
537
+ logger.warning(
538
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
539
+ " Please check that the tokenizer is not corrupted!"
540
+ )
541
+ index = token_index
542
+ writer.write(" ".join(bpe_tokens) + "\n")
543
+ index += 1
544
+
545
+ return vocab_file, merge_file
546
+
547
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__getstate__
548
+ def __getstate__(self):
549
+ state = self.__dict__.copy()
550
+ state["sm"] = None
551
+ return state
552
+
553
+ # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__setstate__
554
+ def __setstate__(self, d):
555
+ self.__dict__ = d
556
+
557
+ try:
558
+ import sacremoses
559
+ except ImportError:
560
+ raise ImportError(
561
+ "You need to install sacremoses to use XLMTokenizer. "
562
+ "See https://pypi.org/project/sacremoses/ for installation."
563
+ )
564
+
565
+ self.sm = sacremoses
566
+
567
+
568
+ __all__ = ["FlaubertTokenizer"]
docs/transformers/build/lib/transformers/models/flava/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_flava import *
22
+ from .feature_extraction_flava import *
23
+ from .image_processing_flava import *
24
+ from .image_processing_flava_fast import *
25
+ from .modeling_flava import *
26
+ from .processing_flava import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
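A rough sketch of what the `_LazyModule` indirection above buys: submodules such as `configuration_flava` are only imported when one of their exported names is first accessed (the class names come from the `TYPE_CHECKING` block):

```python
# Rough sketch: attribute access on the lazy package triggers the real submodule import.
import sys
import transformers.models.flava as flava_pkg

loaded_before = "transformers.models.flava.configuration_flava" in sys.modules
config_cls = flava_pkg.FlavaConfig      # first access imports configuration_flava
loaded_after = "transformers.models.flava.configuration_flava" in sys.modules
print(loaded_before, loaded_after, config_cls.__module__)
```

This keeps `import transformers` cheap even though the FLAVA package pulls in sizeable vision and multimodal modeling code.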
docs/transformers/build/lib/transformers/models/flava/configuration_flava.py ADDED
@@ -0,0 +1,701 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """FLAVA model configurations"""
16
+
17
+ from typing import Any, Dict
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class FlavaImageConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an
29
+ FLAVA model according to the specified arguments, defining the model architecture.
30
+
31
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
32
+ [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimensionality of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
53
+ The dropout ratio for the attention probabilities.
54
+ initializer_range (`float`, *optional*, defaults to 0.02):
55
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
56
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
57
+ The epsilon used by the layer normalization layers.
58
+ image_size (`int`, *optional*, defaults to 224):
59
+ The size (resolution) of each image.
60
+ patch_size (`int`, *optional*, defaults to 16):
61
+ The size (resolution) of each patch.
62
+ num_channels (`int`, *optional*, defaults to 3):
63
+ The number of input channels.
64
+ qkv_bias (`bool`, *optional*, defaults to `True`):
65
+ Whether to add a bias to the queries, keys and values.
66
+ mask_token (`bool`, *optional*, defaults to `True`):
67
+ Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss for FLAVA.
68
+ vocab_size (`int`, *optional*, defaults to 8192):
69
+ Vocabulary size of the [`FlavaImageCodebook`] used in conjunction with [`FlavaImageModel`] for MIM (Masked
70
+ Image Modeling) loss for FLAVA.
71
+
72
+ Example:
73
+
74
+ ```python
75
+ >>> from transformers import FlavaImageConfig, FlavaImageModel
76
+
77
+ >>> # Initializing a FlavaImageModel with style configuration
78
+ >>> configuration = FlavaImageConfig()
79
+
80
+ >>> # Initializing a FlavaImageModel model (with random weights) from the style configuration
81
+ >>> model = FlavaImageModel(configuration)
82
+
83
+ >>> # Accessing the model configuration
84
+ >>> configuration = model.config
85
+ ```"""
86
+
87
+ model_type = "flava_image_model"
88
+ base_config_key = "image_config"
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_size: int = 768,
93
+ num_hidden_layers: int = 12,
94
+ num_attention_heads: int = 12,
95
+ intermediate_size: int = 3072,
96
+ hidden_act: int = "gelu",
97
+ hidden_dropout_prob: float = 0.0,
98
+ attention_probs_dropout_prob: float = 0.0,
99
+ initializer_range: float = 0.02,
100
+ layer_norm_eps: float = 1e-12,
101
+ image_size: int = 224,
102
+ patch_size: int = 16,
103
+ num_channels: int = 3,
104
+ qkv_bias: bool = True,
105
+ mask_token: bool = True,
106
+ vocab_size: int = 8192,
107
+ **kwargs,
108
+ ):
109
+ super().__init__(**kwargs)
110
+
111
+ self.hidden_size = hidden_size
112
+ self.num_hidden_layers = num_hidden_layers
113
+ self.num_attention_heads = num_attention_heads
114
+ self.intermediate_size = intermediate_size
115
+ self.hidden_act = hidden_act
116
+ self.hidden_dropout_prob = hidden_dropout_prob
117
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
118
+ self.initializer_range = initializer_range
119
+ self.layer_norm_eps = layer_norm_eps
120
+ self.image_size = image_size
121
+ self.patch_size = patch_size
122
+ self.num_channels = num_channels
123
+ self.qkv_bias = qkv_bias
124
+ self.mask_token = mask_token
125
+ self.vocab_size = vocab_size
126
+
127
+
128
+ class FlavaTextConfig(PretrainedConfig):
129
+ r"""
130
+ This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an
131
+ FLAVA model according to the specified arguments, defining the model architecture.
132
+
133
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
134
+ [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
135
+
136
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
137
+ documentation from [`PretrainedConfig`] for more information.
138
+
139
+
140
+ Args:
141
+ vocab_size (`int`, *optional*, defaults to 30522):
142
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
143
+ `inputs_ids` passed when calling [`FlavaTextModel`].
144
+ type_vocab_size (`int`, *optional*, defaults to 2):
145
+ The vocabulary size of the `token_type_ids` passed when calling [`FlavaTextModel`]. Note that even though
146
+ the text encoder allows 2 token types, for text-only pretraining and fine-tuning only 1 is
147
+ used, similar to RoBERTa.
148
+ max_position_embeddings (`int`, *optional*, defaults to 512):
149
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
150
+ just in case (e.g., 512 or 1024 or 2048). For VL, max_length passed to model is 77.
151
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
152
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
153
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
154
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
155
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
156
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
157
+ hidden_size (`int`, *optional*, defaults to 768):
158
+ Dimensionality of the encoder layers and the pooler layer.
159
+ num_hidden_layers (`int`, *optional*, defaults to 12):
160
+ Number of hidden layers in the Transformer encoder.
161
+ num_attention_heads (`int`, *optional*, defaults to 12):
162
+ Number of attention heads for each attention layer in the Transformer encoder.
163
+ intermediate_size (`int`, *optional*, defaults to 3072):
164
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
165
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
166
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
167
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
168
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
169
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
170
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
171
+ The dropout ratio for the attention probabilities.
172
+ initializer_range (`float`, *optional*, defaults to 0.02):
173
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
174
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
175
+ The epsilon used by the layer normalization layers.
176
+ image_size (`int`, *optional*, defaults to 224):
177
+ The size (resolution) of each image.
178
+ patch_size (`int`, *optional*, defaults to 16):
179
+ The size (resolution) of each patch.
180
+ num_channels (`int`, *optional*, defaults to 3):
181
+ The number of input channels.
182
+ qkv_bias (`bool`, *optional*, defaults to `True`):
183
+ Whether to add a bias to the queries, keys and values.
184
+
185
+ Example:
186
+
187
+ ```python
188
+ >>> from transformers import FlavaTextConfig, FlavaTextModel
189
+
190
+ >>> # Initializing a FlavaTextModel with style configuration
191
+ >>> configuration = FlavaTextConfig()
192
+
193
+ >>> # Initializing a FlavaTextModel model (with random weights) from the style configuration
194
+ >>> model = FlavaTextModel(configuration)
195
+
196
+ >>> # Accessing the model configuration
197
+ >>> configuration = model.config
198
+ ```"""
199
+
200
+ model_type = "flava_text_model"
201
+ base_config_key = "text_config"
202
+
203
+ def __init__(
204
+ self,
205
+ vocab_size: int = 30522,
206
+ type_vocab_size: int = 2,
207
+ max_position_embeddings: int = 512,
208
+ position_embedding_type: str = "absolute",
209
+ hidden_size: int = 768,
210
+ num_hidden_layers: int = 12,
211
+ num_attention_heads: int = 12,
212
+ intermediate_size: int = 3072,
213
+ hidden_act: str = "gelu",
214
+ hidden_dropout_prob: float = 0.0,
215
+ attention_probs_dropout_prob: float = 0.0,
216
+ initializer_range: float = 0.02,
217
+ layer_norm_eps: float = 1e-12,
218
+ pad_token_id: int = 0,
219
+ qkv_bias: bool = True,
220
+ **kwargs,
221
+ ):
222
+ super().__init__(**kwargs)
223
+
224
+ self.vocab_size = vocab_size
225
+ self.type_vocab_size = type_vocab_size
226
+ self.max_position_embeddings = max_position_embeddings
227
+ self.position_embedding_type = position_embedding_type
228
+ self.hidden_size = hidden_size
229
+ self.num_hidden_layers = num_hidden_layers
230
+ self.num_attention_heads = num_attention_heads
231
+ self.intermediate_size = intermediate_size
232
+ self.hidden_act = hidden_act
233
+ self.hidden_dropout_prob = hidden_dropout_prob
234
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
235
+ self.initializer_range = initializer_range
236
+ self.layer_norm_eps = layer_norm_eps
237
+ self.qkv_bias = qkv_bias
238
+ self.pad_token_id = pad_token_id
239
+
240
+
241
+ class FlavaMultimodalConfig(PretrainedConfig):
242
+ r"""
243
+ This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate
244
+ an FLAVA model according to the specified arguments, defining the model architecture.
245
+
246
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
247
+ [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
248
+
249
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
250
+ documentation from [`PretrainedConfig`] for more information.
251
+
252
+
253
+ Args:
254
+ hidden_size (`int`, *optional*, defaults to 768):
255
+ Dimensionality of the encoder layers and the pooler layer.
256
+ num_hidden_layers (`int`, *optional*, defaults to 6):
257
+ Number of hidden layers in the Transformer encoder.
258
+ num_attention_heads (`int`, *optional*, defaults to 12):
259
+ Number of attention heads for each attention layer in the Transformer encoder.
260
+ intermediate_size (`int`, *optional*, defaults to 3072):
261
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
262
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
263
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
264
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
265
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
266
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
267
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
268
+ The dropout ratio for the attention probabilities.
269
+ initializer_range (`float`, *optional*, defaults to 0.02):
270
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
271
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
272
+ The epsilon used by the layer normalization layers.
273
+ qkv_bias (`bool`, *optional*, defaults to `True`):
274
+ Whether to add a bias to the queries, keys and values.
275
+ use_cls_token (`bool`, *optional*, defaults to `True`):
276
+ Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
277
+
278
+
279
+ Example:
280
+
281
+ ```python
282
+ >>> from transformers import FlavaMultimodalConfig, FlavaMultimodalModel
283
+
284
+ >>> # Initializing a FlavaMultimodalModel with style configuration
285
+ >>> configuration = FlavaMultimodalConfig()
286
+
287
+ >>> # Initializing a FlavaMultimodalModel model (with random weights) from the style configuration
288
+ >>> model = FlavaMultimodalModel(configuration)
289
+
290
+ >>> # Accessing the model configuration
291
+ >>> configuration = model.config
292
+ ```"""
293
+
294
+ model_type = "flava_multimodal_model"
295
+ base_config_key = "multimodal_config"
296
+
297
+ def __init__(
298
+ self,
299
+ hidden_size: int = 768,
300
+ num_hidden_layers: int = 6,
301
+ num_attention_heads: int = 12,
302
+ intermediate_size: int = 3072,
303
+ hidden_act: int = "gelu",
304
+ hidden_dropout_prob: float = 0.0,
305
+ attention_probs_dropout_prob: float = 0.0,
306
+ initializer_range: float = 0.02,
307
+ layer_norm_eps: float = 1e-12,
308
+ qkv_bias: bool = True,
309
+ use_cls_token: bool = True,
310
+ **kwargs,
311
+ ):
312
+ super().__init__(**kwargs)
313
+
314
+ self.hidden_size = hidden_size
315
+ self.num_hidden_layers = num_hidden_layers
316
+ self.num_attention_heads = num_attention_heads
317
+ self.intermediate_size = intermediate_size
318
+ self.hidden_act = hidden_act
319
+ self.hidden_dropout_prob = hidden_dropout_prob
320
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
321
+ self.initializer_range = initializer_range
322
+ self.layer_norm_eps = layer_norm_eps
323
+ self.qkv_bias = qkv_bias
324
+ self.use_cls_token = use_cls_token
325
+
326
+
327
+ class FlavaImageCodebookConfig(PretrainedConfig):
328
+ model_type = "flava_image_codebook"
329
+ base_config_key = "image_codebook_config"
330
+
331
+ r"""
332
+ [`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
333
+ is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture.
334
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
335
+ [facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.
336
+
337
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
338
+ documentation from [`PretrainedConfig`] for more information.
339
+
340
+ Args:
341
+ num_groups (`int`, *optional*, defaults to 4):
342
+ Number of groups to be created. This parameter currently does not affect the model and is only used for some
343
+ internal calculations and estimates.
344
+ input_channels (`int`, *optional*, defaults to 3):
345
+ Number of channels in the image to be passed.
346
+ num_blocks_per_group (`int`, *optional*, defaults to 2):
347
+ Number of conv-based blocks per group.
348
+ hidden_size (`int`, *optional*, defaults to 256):
349
+ Size of hidden dim for the blocks.
350
+ vocab_size (`int`, *optional*, defaults to 8192):
351
+ Size of the output vocabulary for the codebook.
352
+ freeze (`bool`, defaults to `True`):
353
+ Whether to freeze the weights of the model.
354
+ initializer_range (`float`, *optional*, defaults to 0.02):
355
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
356
+ kwargs (*optional*):
357
+ Dictionary of keyword arguments.
358
+
359
+ Example:
360
+
361
+ ```python
362
+ >>> from transformers import FlavaImageCodebookConfig, FlavaImageCodebook
363
+
364
+ >>> # Initializing a FlavaImageCodebook with style configuration
365
+ >>> configuration = FlavaImageCodebookConfig()
366
+
367
+ >>> # Initializing a FlavaImageCodebook model (with random weights) from the style configuration
368
+ >>> model = FlavaImageCodebook(configuration)
369
+ >>> # Accessing the model configuration
370
+ >>> configuration = model.config
371
+ ```
372
+ """
373
+
374
+ def __init__(
375
+ self,
376
+ num_groups: int = 4,
377
+ input_channels: int = 3,
378
+ num_blocks_per_group: int = 2,
379
+ hidden_size: int = 256,
380
+ vocab_size: int = 8192,
381
+ freeze: bool = True,
382
+ initializer_range: float = 0.02,
383
+ **kwargs,
384
+ ):
385
+ super().__init__(**kwargs)
386
+ self.num_groups = num_groups
387
+ self.input_channels = input_channels
388
+ self.num_blocks_per_group = num_blocks_per_group
389
+ self.hidden_size = hidden_size
390
+ self.vocab_size = vocab_size
391
+ self.freeze = freeze
392
+ self.initializer_range = initializer_range
393
+
394
+
395
+ class FlavaConfig(PretrainedConfig):
396
+ r"""
397
+ [`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
398
+ instantiate a FLAVA model according to the specified arguments, defining the text model, image model, image codebook
399
+ and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to
400
+ that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
401
+
402
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
403
+ documentation from [`PretrainedConfig`] for more information.
404
+
405
+ Args:
406
+ text_config (`dict`, *optional*):
407
+ Dictionary of configuration options used to initialize [`FlavaTextConfig`].
408
+ image_config (`dict`, *optional*):
409
+ Dictionary of configuration options used to initialize [`FlavaImageConfig`].
410
+ multimodal_config (`dict`, *optional*):
411
+ Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
+ image_codebook_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FlavaImageCodebookConfig`].
412
+ hidden_size (`int`, *optional*, defaults to 768):
413
+ Dimensionality of the encoder layers and the pooler layer.
414
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
415
+ The epsilon used by the layer normalization layers.
416
+ projection_dim (`int`, *optional*, defaults to 768):
417
+ Dimensionality of text and image projection layers.
418
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
419
+ The initial value of the *logit_scale* parameter. Default is used as per the original FLAVA/CLIP
420
+ implementation.
421
+ initializer_range (`float`, *optional*, defaults to 0.02):
422
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
423
+ ce_ignore_index (`int`, *optional*, defaults to -100):
424
+ Cross entropy index to ignore.
425
+ mim_weight (`float`, *optional*, defaults to 1.0):
426
+ Weight to be assigned to MIM (Masked Image Modeling) unimodal loss.
427
+ mlm_weight (`float`, *optional*, defaults to 1.0):
428
+ Weight to be assigned to MLM (Masked Language Modeling) unimodal loss.
429
+ global_contrastive_weight (`float`, *optional*, defaults to 1.0):
430
+ Weight to be assigned to global contrastive cross-alignment loss.
431
+ itm_weight (`float`, *optional*, defaults to 1.0):
432
+ Weight to be assigned to image-text matching multimodal loss.
433
+ mmm_image_weight (`float`, *optional*, defaults to 1.0):
434
+ Weight to be assigned to MMM loss's image part.
435
+ mmm_text_weight (`float`, *optional*, defaults to 1.0):
436
+ Weight to be assigned to MMM loss's text part.
437
+ global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
438
+ Whether to use global backpropagation through all workers in contrastive loss.
439
+ skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
440
+ Whether to skip running unmasked multimodal encoder whose outputs are not used by FLAVA losses.
441
+ return_loss (`bool`, *optional*, defaults to `True`):
442
+ Whether to return the loss or not.
443
+
444
+ kwargs (*optional*):
445
+ Dictionary of keyword arguments.
446
+
447
+ Example:
448
+
449
+ ```python
450
+ >>> from transformers import FlavaConfig, FlavaModel, FlavaForPreTraining
451
+
452
+ >>> # Initializing a FlavaConfig with style configuration
453
+ >>> configuration = FlavaConfig()
454
+
455
+ >>> # Initializing a FlavaModel and FlavaForPreTraining model (with random weights) from the style configuration
456
+ >>> model = FlavaModel(configuration)
457
+ >>> model_pre = FlavaForPreTraining(configuration)
458
+
459
+ >>> # Accessing the model configuration
460
+ >>> configuration = model.config
461
+ >>> configuration_pre = model_pre.config
462
+ ```
463
+ """
464
+
465
+ model_type = "flava"
466
+ sub_configs = {
467
+ "text_config": FlavaTextConfig,
468
+ "image_config": FlavaImageConfig,
469
+ "multimodal_config": FlavaMultimodalConfig,
470
+ "image_codebook_config": FlavaImageCodebookConfig,
471
+ }
472
+
473
+ def __init__(
474
+ self,
475
+ image_config: Dict[str, Any] = None,
476
+ text_config: Dict[str, Any] = None,
477
+ multimodal_config: Dict[str, Any] = None,
478
+ image_codebook_config: Dict[str, Any] = None,
479
+ hidden_size: int = 768,
480
+ layer_norm_eps: float = 1e-12,
481
+ projection_dim: int = 768,
482
+ init_codebook: bool = True,
483
+ logit_scale_init_value: float = 2.6592,
484
+ initializer_range: float = 0.02,
485
+ ce_ignore_index: int = -100,
486
+ mim_weight: float = 1.0,
487
+ mlm_weight: float = 1.0,
488
+ global_contrastive_weight: float = 1.0,
489
+ itm_weight: float = 1.0,
490
+ mmm_image_weight: float = 1.0,
491
+ mmm_text_weight: float = 1.0,
492
+ global_backprop_contrastive: bool = True,
493
+ skip_unmasked_multimodal_encoder: bool = True,
494
+ return_loss: bool = True,
495
+ **kwargs,
496
+ ):
497
+ # If `_config_dict` exist, we use them for the backward compatibility.
498
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
499
+ # of confusion!).
500
+ text_config_dict = kwargs.pop("text_config_dict", None)
501
+ image_config_dict = kwargs.pop("image_config_dict", None)
502
+ multimodal_config_dict = kwargs.pop("multimodal_config_dict", None)
503
+ image_codebook_config_dict = kwargs.pop("image_codebook_config_dict", None)
504
+
505
+ super().__init__(**kwargs)
506
+
507
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
508
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
509
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
510
+ if text_config_dict is not None:
511
+ if text_config is None:
512
+ text_config = {}
513
+
514
+ # This is the complete result when using `text_config_dict`.
515
+ _text_config_dict = FlavaTextConfig(**text_config_dict).to_dict()
516
+
517
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
518
+ for key, value in _text_config_dict.items():
519
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
520
+ # If specified in `text_config_dict`
521
+ if key in text_config_dict:
522
+ message = (
523
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
524
+ f'The value `text_config_dict["{key}"]` will be used instead.'
525
+ )
526
+ # If inferred from default argument values (just to be super careful)
527
+ else:
528
+ message = (
529
+ f"`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The "
530
+ f'value `text_config["{key}"]` will be overridden.'
531
+ )
532
+ logger.info(message)
533
+
534
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
535
+ text_config.update(_text_config_dict)
536
+
537
+ if image_config_dict is not None:
538
+ if image_config is None:
539
+ image_config = {}
540
+
541
+ # This is the complete result when using `image_config_dict`.
542
+ _image_config_dict = FlavaImageConfig(**image_config_dict).to_dict()
543
+ # convert keys to string instead of integer
544
+ if "id2label" in _image_config_dict:
545
+ _image_config_dict["id2label"] = {
546
+ str(key): value for key, value in _image_config_dict["id2label"].items()
547
+ }
548
+
549
+ # Give a warning if the values exist in both `_image_config_dict` and `image_config` but being different.
550
+ for key, value in _image_config_dict.items():
551
+ if key in image_config and value != image_config[key] and key not in ["transformers_version"]:
552
+ # If specified in `image_config_dict`
553
+ if key in image_config_dict:
554
+ message = (
555
+ f"`{key}` is found in both `image_config_dict` and `image_config` but with different "
556
+ f'values. The value `image_config_dict["{key}"]` will be used instead.'
557
+ )
558
+ # If inferred from default argument values (just to be super careful)
559
+ else:
560
+ message = (
561
+ f"`image_config_dict` is provided which will be used to initialize `FlavaImageConfig`. "
562
+ f'The value `image_config["{key}"]` will be overridden.'
563
+ )
564
+ logger.info(message)
565
+
566
+ # Update all values in `image_config` with the ones in `_image_config_dict`.
567
+ image_config.update(_image_config_dict)
568
+
569
+ if multimodal_config_dict is not None:
570
+ if multimodal_config is None:
571
+ multimodal_config = {}
572
+
573
+ # This is the complete result when using `multimodal_config_dict`.
574
+ _multimodal_config_dict = FlavaMultimodalConfig(**multimodal_config_dict).to_dict()
575
+
576
+ # Give a warning if the values exist in both `_multimodal_config_dict` and `multimodal_config` but being
577
+ # different.
578
+ for key, value in _multimodal_config_dict.items():
579
+ if (
580
+ key in multimodal_config
581
+ and value != multimodal_config[key]
582
+ and key not in ["transformers_version"]
583
+ ):
584
+ # If specified in `multimodal_config_dict`
585
+ if key in multimodal_config_dict:
586
+ message = (
587
+ f"`{key}` is found in both `multimodal_config_dict` and `multimodal_config` but with "
588
+ f'different values. The value `multimodal_config_dict["{key}"]` will be used instead.'
589
+ )
590
+ # If inferred from default argument values (just to be super careful)
591
+ else:
592
+ message = (
593
+ f"`multimodal_config_dict` is provided which will be used to initialize "
594
+ f'`FlavaMultimodalConfig`. The value `multimodal_config["{key}"]` will be overridden.'
595
+ )
596
+ logger.info(message)
597
+
598
+ # Update all values in `multimodal_config` with the ones in `_multimodal_config_dict`.
599
+ multimodal_config.update(_multimodal_config_dict)
600
+
601
+ if image_codebook_config_dict is not None:
602
+ if image_codebook_config is None:
603
+ image_codebook_config = {}
604
+
605
+ # This is the complete result when using `image_codebook_config_dict`.
606
+ _image_codebook_config_dict = FlavaImageCodebookConfig(**image_codebook_config_dict).to_dict()
607
+
608
+ # Give a warning if the values exist in both `_image_codebook_config_dict` and `image_codebook_config` but
609
+ # being different.
610
+ for key, value in _image_codebook_config_dict.items():
611
+ if (
612
+ key in image_codebook_config
613
+ and value != image_codebook_config[key]
614
+ and key not in ["transformers_version"]
615
+ ):
616
+ # If specified in `image_codebook_config_dict`
617
+ if key in image_codebook_config_dict:
618
+ message = (
619
+ f"`{key}` is found in both `image_codebook_config_dict` and `image_codebook_config` but "
620
+ f'with different values. The value `image_codebook_config_dict["{key}"]` will be used '
621
+ "instead."
622
+ )
623
+ # If inferred from default argument values (just to be super careful)
624
+ else:
625
+ message = (
626
+ f"`image_codebook_config_dict` is provided which will be used to initialize "
627
+ f'`FlavaImageCodebookConfig`. The value `image_codebook_config["{key}"]` will be overridden.'
628
+ )
629
+ logger.info(message)
630
+
631
+ # Update all values in `image_codebook_config` with the ones in `_image_codebook_config_dict`.
632
+ image_codebook_config.update(_image_codebook_config_dict)
633
+
634
+ if image_config is None:
635
+ image_config = {}
636
+ logger.info("`image_config` is `None`. Initializing the `FlavaImageConfig` with default values.")
637
+
638
+ if text_config is None:
639
+ text_config = {}
640
+ logger.info("`text_config` is `None`. Initializing the `FlavaTextConfig` with default values.")
641
+
642
+ if multimodal_config is None:
643
+ multimodal_config = {}
644
+ logger.info("`multimodal_config` is `None`. Initializing the `FlavaMultimodalConfig` with default values.")
645
+
646
+ if image_codebook_config is None:
647
+ image_codebook_config = {}
648
+ logger.info(
649
+ "`image_codebook_config` is `None`. Initializing the `FlavaImageCodebookConfig` with default values."
650
+ )
651
+
652
+ self.image_config = FlavaImageConfig(**image_config)
653
+ self.text_config = FlavaTextConfig(**text_config)
654
+ self.multimodal_config = FlavaMultimodalConfig(**multimodal_config)
655
+ self.image_codebook_config = FlavaImageCodebookConfig(**image_codebook_config)
656
+ self.projection_dim = projection_dim
657
+ self.init_codebook = init_codebook
658
+
659
+ self.hidden_size = hidden_size
660
+ self.layer_norm_eps = layer_norm_eps
661
+ self.initializer_range = initializer_range
662
+ self.logit_scale_init_value = logit_scale_init_value
663
+ self.initializer_factor = 1.0
664
+ self.ce_ignore_index = ce_ignore_index
665
+ self.mim_weight = mim_weight
666
+ self.mlm_weight = mlm_weight
667
+ self.global_contrastive_weight = global_contrastive_weight
668
+ self.itm_weight = itm_weight
669
+ self.mmm_image_weight = mmm_image_weight
670
+ self.mmm_text_weight = mmm_text_weight
671
+ self.global_backprop_contrastive = global_backprop_contrastive
672
+ self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder
673
+ self.return_loss = return_loss
674
+
675
+ @classmethod
676
+ def from_configs(
677
+ cls,
678
+ image_config: FlavaImageConfig,
679
+ text_config: FlavaTextConfig,
680
+ multimodal_config: FlavaMultimodalConfig,
681
+ image_codebook_config: FlavaImageCodebookConfig,
682
+ **kwargs,
683
+ ):
684
+ r"""
685
+ Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model
686
+ configuration, flava multimodal model and flava codebook model configuration.
687
+
688
+ Returns:
689
+ [`FlavaConfig`]: An instance of a configuration object
690
+ """
691
+
692
+ return cls(
693
+ image_config=image_config.to_dict(),
694
+ text_config=text_config.to_dict(),
695
+ multimodal_config=multimodal_config.to_dict(),
696
+ image_codebook_config=image_codebook_config.to_dict(),
697
+ **kwargs,
698
+ )
699
+
700
+
701
+ __all__ = ["FlavaConfig", "FlavaImageCodebookConfig", "FlavaImageConfig", "FlavaMultimodalConfig", "FlavaTextConfig"]
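For quick reference, a minimal usage sketch (not part of the file above) of composing a full FLAVA configuration from customized sub-configurations via the `from_configs` classmethod defined at the end of this file; the specific argument values are only illustrative:

```python
from transformers import (
    FlavaConfig,
    FlavaImageCodebookConfig,
    FlavaImageConfig,
    FlavaModel,
    FlavaMultimodalConfig,
    FlavaTextConfig,
)

# Compose a combined FLAVA config from individually customized sub-configs.
config = FlavaConfig.from_configs(
    image_config=FlavaImageConfig(),
    text_config=FlavaTextConfig(vocab_size=30522),
    multimodal_config=FlavaMultimodalConfig(num_hidden_layers=6),
    image_codebook_config=FlavaImageCodebookConfig(),
)

# The resulting config can be used to instantiate a randomly initialized model.
model = FlavaModel(config)
```

Equivalently, plain dictionaries can be passed to `FlavaConfig(image_config=..., text_config=..., ...)`; `from_configs` simply converts each sub-config with `to_dict()` before forwarding it.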
docs/transformers/build/lib/transformers/models/flava/convert_dalle_to_flava_codebook.py ADDED
@@ -0,0 +1,102 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import torch
20
+
21
+ from transformers import FlavaImageCodebook, FlavaImageCodebookConfig
22
+
23
+
24
+ def rreplace(s, old, new, occurrence):
25
+ li = s.rsplit(old, occurrence)
26
+ return new.join(li)
27
+
28
+
29
+ def count_parameters(state_dict):
30
+ # encoder.embeddings are double copied in original FLAVA
31
+ return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
32
+
33
+
34
+ def upgrade_state_dict(state_dict):
35
+ upgrade = {}
36
+
37
+ group_keys = ["group_1", "group_2", "group_3", "group_4"]
38
+ for key, value in state_dict.items():
39
+ for group_key in group_keys:
40
+ if group_key in key:
41
+ key = key.replace(f"{group_key}.", f"{group_key}.group.")
42
+
43
+ if "res_path" in key:
44
+ key = key.replace("res_path.", "res_path.path.")
45
+
46
+ if key.endswith(".w"):
47
+ key = rreplace(key, ".w", ".weight", 1)
48
+ if key.endswith(".b"):
49
+ key = rreplace(key, ".b", ".bias", 1)
50
+
51
+ upgrade[key] = value.float()
52
+
53
+ return upgrade
54
+
55
+
56
+ @torch.no_grad()
57
+ def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, save_checkpoint=True):
58
+ """
59
+ Copy/paste/tweak model's weights to transformers design.
60
+ """
61
+ from dall_e import Encoder
62
+
63
+ encoder = Encoder()
64
+ if os.path.exists(checkpoint_path):
65
+ ckpt = torch.load(checkpoint_path, weights_only=True)
66
+ else:
67
+ ckpt = torch.hub.load_state_dict_from_url(checkpoint_path)
68
+
69
+ if isinstance(ckpt, Encoder):
70
+ ckpt = ckpt.state_dict()
71
+ encoder.load_state_dict(ckpt)
72
+
73
+ if config_path is not None:
74
+ config = FlavaImageCodebookConfig.from_pretrained(config_path)
75
+ else:
76
+ config = FlavaImageCodebookConfig()
77
+
78
+ hf_model = FlavaImageCodebook(config).eval()
79
+ state_dict = encoder.state_dict()
80
+
81
+ hf_state_dict = upgrade_state_dict(state_dict)
82
+ hf_model.load_state_dict(hf_state_dict)
83
+ hf_state_dict = hf_model.state_dict()
84
+ hf_count = count_parameters(hf_state_dict)
85
+ state_dict_count = count_parameters(state_dict)
86
+
87
+ assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
88
+
89
+ if save_checkpoint:
90
+ hf_model.save_pretrained(pytorch_dump_folder_path)
91
+ else:
92
+ return hf_state_dict
93
+
94
+
95
+ if __name__ == "__main__":
96
+ parser = argparse.ArgumentParser()
97
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
98
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
99
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
100
+ args = parser.parse_args()
101
+
102
+ convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
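A hedged usage sketch for a codebook converted with the script above; the local folder path is hypothetical, and the 112x112 input resolution follows the codebook defaults of `FlavaImageProcessor`:

```python
import torch

from transformers import FlavaImageCodebook

# Load the folder written via --pytorch_dump_folder_path (hypothetical local path).
codebook = FlavaImageCodebook.from_pretrained("./flava_codebook").eval()

# The codebook maps (codebook-preprocessed) pixels to discrete token ids, which
# FLAVA uses as targets for masked image modeling.
codebook_pixel_values = torch.rand(1, 3, 112, 112)
with torch.no_grad():
    token_ids = codebook.get_codebook_indices(codebook_pixel_values)
print(token_ids.shape)  # one token id per position of the codebook grid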
docs/transformers/build/lib/transformers/models/flava/convert_flava_original_pytorch_to_hf.py ADDED
@@ -0,0 +1,99 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import torch
20
+
21
+ from transformers import FlavaConfig, FlavaForPreTraining
22
+ from transformers.models.flava.convert_dalle_to_flava_codebook import convert_dalle_checkpoint
23
+
24
+
25
+ def count_parameters(state_dict):
26
+ # encoder.embeddings are double copied in original FLAVA
27
+ return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
28
+
29
+
30
+ def upgrade_state_dict(state_dict, codebook_state_dict):
31
+ upgrade = {}
32
+
33
+ for key, value in state_dict.items():
34
+ if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key:
35
+ continue
36
+
37
+ key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head")
38
+ key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head")
39
+ key = key.replace("heads.cmd.itm_head.cls", "itm_head")
40
+ key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler")
41
+ key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale")
42
+ key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head")
43
+ key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head")
44
+ key = key.replace("mm_text_projection", "flava.text_to_mm_projection")
45
+ key = key.replace("mm_image_projection", "flava.image_to_mm_projection")
46
+ key = key.replace("image_encoder.module", "flava.image_model")
47
+ key = key.replace("text_encoder.module", "flava.text_model")
48
+ key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token")
49
+ key = key.replace("mm_encoder.module", "flava.multimodal_model")
50
+ key = key.replace("text_projection", "flava.text_projection")
51
+ key = key.replace("image_projection", "flava.image_projection")
52
+
53
+ upgrade[key] = value.float()
54
+
55
+ for key, value in codebook_state_dict.items():
56
+ upgrade[f"image_codebook.{key}"] = value
57
+
58
+ return upgrade
59
+
60
+
61
+ @torch.no_grad()
62
+ def convert_flava_checkpoint(checkpoint_path, codebook_path, pytorch_dump_folder_path, config_path=None):
63
+ """
64
+ Copy/paste/tweak model's weights to transformers design.
65
+ """
66
+ if config_path is not None:
67
+ config = FlavaConfig.from_pretrained(config_path)
68
+ else:
69
+ config = FlavaConfig()
70
+
71
+ hf_model = FlavaForPreTraining(config).eval()
72
+
73
+ codebook_state_dict = convert_dalle_checkpoint(codebook_path, None, save_checkpoint=False)
74
+
75
+ if os.path.exists(checkpoint_path):
76
+ state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
77
+ else:
78
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")
79
+
80
+ hf_state_dict = upgrade_state_dict(state_dict, codebook_state_dict)
81
+ hf_model.load_state_dict(hf_state_dict)
82
+ hf_state_dict = hf_model.state_dict()
83
+ hf_count = count_parameters(hf_state_dict)
84
+ state_dict_count = count_parameters(state_dict) + count_parameters(codebook_state_dict)
85
+
86
+ assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
87
+
88
+ hf_model.save_pretrained(pytorch_dump_folder_path)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
94
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
95
+ parser.add_argument("--codebook_path", default=None, type=str, help="Path to flava codebook checkpoint")
96
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
97
+ args = parser.parse_args()
98
+
99
+ convert_flava_checkpoint(args.checkpoint_path, args.codebook_path, args.pytorch_dump_folder_path, args.config_path)
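A short inference sketch for a converted checkpoint; `facebook/flava-full` is the published Hub id (a locally converted `--pytorch_dump_folder_path` can be loaded the same way), and the COCO image URL is only an example:

```python
import requests
import torch
from PIL import Image

from transformers import FlavaModel, FlavaProcessor

model = FlavaModel.from_pretrained("facebook/flava-full").eval()
processor = FlavaProcessor.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=["a photo of two cats"], images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# Unimodal embeddings are exposed separately from the multimodal ones.
print(outputs.image_embeddings.shape, outputs.text_embeddings.shape)
```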
docs/transformers/build/lib/transformers/models/flava/feature_extraction_flava.py ADDED
@@ -0,0 +1,38 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for FLAVA."""
16
+
17
+ import warnings
18
+
19
+ from ...utils import logging
20
+ from ...utils.import_utils import requires
21
+ from .image_processing_flava import FlavaImageProcessor
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ @requires(backends=("vision",))
28
+ class FlavaFeatureExtractor(FlavaImageProcessor):
29
+ def __init__(self, *args, **kwargs) -> None:
30
+ warnings.warn(
31
+ "The class FlavaFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
32
+ " use FlavaImageProcessor instead.",
33
+ FutureWarning,
34
+ )
35
+ super().__init__(*args, **kwargs)
36
+
37
+
38
+ __all__ = ["FlavaFeatureExtractor"]
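A minimal sketch of the migration the deprecation warning points to, using `FlavaImageProcessor` as the drop-in replacement; the blank image is only a placeholder input:

```python
import numpy as np
from PIL import Image

from transformers import FlavaImageProcessor

image_processor = FlavaImageProcessor.from_pretrained("facebook/flava-full")

# Any PIL image, numpy array, or torch tensor works; a blank image is used here.
image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
inputs = image_processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```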
docs/transformers/build/lib/transformers/models/flava/image_processing_flava.py ADDED
@@ -0,0 +1,705 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Flava."""
16
+
17
+ import math
18
+ import random
19
+ from functools import lru_cache
20
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+
24
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
25
+ from ...image_transforms import resize, to_channel_dimension_format
26
+ from ...image_utils import (
27
+ OPENAI_CLIP_MEAN,
28
+ OPENAI_CLIP_STD,
29
+ ChannelDimension,
30
+ ImageInput,
31
+ PILImageResampling,
32
+ infer_channel_dimension_format,
33
+ is_scaled_image,
34
+ make_list_of_images,
35
+ to_numpy_array,
36
+ valid_images,
37
+ validate_preprocess_arguments,
38
+ )
39
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
40
+ from ...utils.import_utils import requires
41
+
42
+
43
+ if is_vision_available():
44
+ import PIL
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ # These values are taken from CLIP
51
+ FLAVA_IMAGE_MEAN = OPENAI_CLIP_MEAN
52
+ FLAVA_IMAGE_STD = OPENAI_CLIP_STD
53
+ FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
54
+ FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
55
+ LOGIT_LAPLACE_EPS: float = 0.1
56
+
57
+
58
+ # Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
59
+ class FlavaMaskingGenerator:
60
+ def __init__(
61
+ self,
62
+ input_size: Union[int, Tuple[int, int]] = 14,
63
+ total_mask_patches: int = 75,
64
+ mask_group_max_patches: Optional[int] = None,
65
+ mask_group_min_patches: int = 16,
66
+ mask_group_min_aspect_ratio: Optional[float] = 0.3,
67
+ mask_group_max_aspect_ratio: Optional[float] = None,
68
+ ):
69
+ if not isinstance(input_size, tuple):
70
+ input_size = (input_size,) * 2
71
+ self.height, self.width = input_size
72
+
73
+ self.num_patches = self.height * self.width
74
+ self.total_mask_patches = total_mask_patches
75
+
76
+ self.mask_group_min_patches = mask_group_min_patches
77
+ self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
78
+
79
+ mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
80
+ self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
81
+
82
+ def __repr__(self):
83
+ repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
84
+ self.height,
85
+ self.width,
86
+ self.mask_group_min_patches,
87
+ self.mask_group_max_patches,
88
+ self.total_mask_patches,
89
+ self.log_aspect_ratio[0],
90
+ self.log_aspect_ratio[1],
91
+ )
92
+ return repr_str
93
+
94
+ def get_shape(self):
95
+ return self.height, self.width
96
+
97
+ def _mask(self, mask, max_mask_patches):
98
+ delta = 0
99
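+ # Rejection-sample one rectangular block of patches: draw an area and an aspect
+ # ratio, and keep the block only if it fits inside the grid and adds at most
+ # `max_mask_patches` newly masked patches (up to 10 attempts).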
+ for _attempt in range(10):
100
+ target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
101
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
102
+ height = int(round(math.sqrt(target_area * aspect_ratio)))
103
+ width = int(round(math.sqrt(target_area / aspect_ratio)))
104
+ if width < self.width and height < self.height:
105
+ top = random.randint(0, self.height - height)
106
+ left = random.randint(0, self.width - width)
107
+
108
+ num_masked = mask[top : top + height, left : left + width].sum()
109
+ # Overlap
110
+ if 0 < height * width - num_masked <= max_mask_patches:
111
+ for i in range(top, top + height):
112
+ for j in range(left, left + width):
113
+ if mask[i, j] == 0:
114
+ mask[i, j] = 1
115
+ delta += 1
116
+
117
+ if delta > 0:
118
+ break
119
+ return delta
120
+
121
+ def __call__(self):
122
+ mask = np.zeros(shape=self.get_shape(), dtype=int)
123
+ mask_count = 0
124
+ while mask_count < self.total_mask_patches:
125
+ max_mask_patches = self.total_mask_patches - mask_count
126
+ max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
127
+
128
+ delta = self._mask(mask, max_mask_patches)
129
+ if delta == 0:
130
+ break
131
+ else:
132
+ mask_count += delta
133
+
134
+ return mask
135
+
136
+
137
+ @requires(backends=("vision",))
138
+ class FlavaImageProcessor(BaseImageProcessor):
139
+ r"""
140
+ Constructs a Flava image processor.
141
+
142
+ Args:
143
+ do_resize (`bool`, *optional*, defaults to `True`):
144
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
145
+ `do_resize` parameter in `preprocess`.
146
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
147
+ Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`.
148
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
149
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in
150
+ `preprocess`.
151
+ do_center_crop (`bool`, *optional*, defaults to `True`):
152
+ Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`.
153
+ crop_size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
154
+ Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the
155
+ `crop_size` parameter in `preprocess`.
156
+ do_rescale (`bool`, *optional*, defaults to `True`):
157
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
158
+ parameter in `preprocess`.
159
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
160
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
161
+ `preprocess`.
162
+ do_normalize (`bool`, *optional*, defaults to `True`):
163
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`.
164
+ image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
165
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
166
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
167
+ image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
168
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
169
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
170
+ return_image_mask (`bool`, *optional*, defaults to `False`):
171
+ Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
172
+ input_size_patches (`int`, *optional*, defaults to 14):
173
+ Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
174
+ by the `input_size_patches` parameter in `preprocess`.
175
+ total_mask_patches (`int`, *optional*, defaults to 75):
176
+ Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
177
+ `preprocess`.
178
+ mask_group_min_patches (`int`, *optional*, defaults to 16):
179
+ Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
180
+ parameter in `preprocess`.
181
+ mask_group_max_patches (`int`, *optional*):
182
+ Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
183
+ parameter in `preprocess`.
184
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
185
+ Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
186
+ in `preprocess`.
187
+ mask_group_max_aspect_ratio (`float`, *optional*):
188
+ Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
189
+ in `preprocess`.
190
+ codebook_do_resize (`bool`, *optional*, defaults to `True`):
191
+ Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
192
+ `codebook_do_resize` parameter in `preprocess`.
193
+ codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
194
+ Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
195
+ `preprocess`.
196
+ codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
197
+ Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
198
+ parameter in `preprocess`.
199
+ codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
200
+ Whether to crop the input for codebook at the center. If the input size is smaller than
201
+ `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
202
+ overridden by the `codebook_do_center_crop` parameter in `preprocess`.
203
+ codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
204
+ Desired output size for codebook input when applying center-cropping. Can be overridden by the
205
+ `codebook_crop_size` parameter in `preprocess`.
206
+ codebook_do_rescale (`bool`, *optional*, defaults to `True`):
207
+ Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
208
+ overridden by the `codebook_do_rescale` parameter in `preprocess`.
209
+ codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
210
+ Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
211
+ `codebook_rescale_factor` parameter in `preprocess`.
212
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
213
+ Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
214
+ `codebook_do_map_pixels` parameter in `preprocess`.
215
+ codebook_do_normalize (`bool`, *optional*, defaults to `True`):
216
+ Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
217
+ be overridden by the `codebook_do_normalize` parameter in `preprocess`.
218
+ codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
219
+ The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
220
+ by the `codebook_image_mean` parameter in `preprocess`.
221
+ codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
222
+ The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
223
+ be overridden by the `codebook_image_std` parameter in `preprocess`.
224
+ """
225
+
226
+ model_input_names = ["pixel_values"]
227
+
228
+ def __init__(
229
+ self,
230
+ do_resize: bool = True,
231
+ size: Dict[str, int] = None,
232
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
233
+ do_center_crop: bool = True,
234
+ crop_size: Dict[str, int] = None,
235
+ do_rescale: bool = True,
236
+ rescale_factor: Union[int, float] = 1 / 255,
237
+ do_normalize: bool = True,
238
+ image_mean: Optional[Union[float, Iterable[float]]] = None,
239
+ image_std: Optional[Union[float, Iterable[float]]] = None,
240
+ # Mask related params
241
+ return_image_mask: bool = False,
242
+ input_size_patches: int = 14,
243
+ total_mask_patches: int = 75,
244
+ mask_group_min_patches: int = 16,
245
+ mask_group_max_patches: Optional[int] = None,
246
+ mask_group_min_aspect_ratio: float = 0.3,
247
+ mask_group_max_aspect_ratio: Optional[float] = None,
248
+ # Codebook related params
249
+ return_codebook_pixels: bool = False,
250
+ codebook_do_resize: bool = True,
251
+ codebook_size: Optional[Dict[str, int]] = None,
252
+ codebook_resample: PILImageResampling = PILImageResampling.LANCZOS,
253
+ codebook_do_center_crop: bool = True,
254
+ codebook_crop_size: Optional[Dict[str, int]] = None,
255
+ codebook_do_rescale: bool = True,
256
+ codebook_rescale_factor: Union[int, float] = 1 / 255,
257
+ codebook_do_map_pixels: bool = True,
258
+ codebook_do_normalize: bool = True,
259
+ codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
260
+ codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
261
+ **kwargs,
262
+ ) -> None:
263
+ super().__init__(**kwargs)
264
+ size = size if size is not None else {"height": 224, "width": 224}
265
+ size = get_size_dict(size)
266
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
267
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
268
+
269
+ codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
270
+ codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
271
+ codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
272
+ codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
273
+
274
+ self.do_resize = do_resize
275
+ self.size = size
276
+ self.resample = resample
277
+ self.do_rescale = do_rescale
278
+ self.rescale_factor = rescale_factor
279
+ self.do_center_crop = do_center_crop
280
+ self.crop_size = crop_size
281
+ self.do_normalize = do_normalize
282
+ self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
283
+ self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD
284
+
285
+ self.return_image_mask = return_image_mask
286
+ self.input_size_patches = input_size_patches
287
+ self.total_mask_patches = total_mask_patches
288
+ self.mask_group_min_patches = mask_group_min_patches
289
+ self.mask_group_max_patches = mask_group_max_patches
290
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
291
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
292
+
293
+ self.return_codebook_pixels = return_codebook_pixels
294
+ self.codebook_do_resize = codebook_do_resize
295
+ self.codebook_size = codebook_size
296
+ self.codebook_resample = codebook_resample
297
+ self.codebook_do_center_crop = codebook_do_center_crop
298
+ self.codebook_crop_size = codebook_crop_size
299
+ self.codebook_do_rescale = codebook_do_rescale
300
+ self.codebook_rescale_factor = codebook_rescale_factor
301
+ self.codebook_do_map_pixels = codebook_do_map_pixels
302
+ self.codebook_do_normalize = codebook_do_normalize
303
+ self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
305
+ self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
306
+
307
+ @classmethod
308
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
309
+ """
310
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
311
+ created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
312
+ """
313
+ image_processor_dict = image_processor_dict.copy()
314
+ if "codebook_size" in kwargs:
315
+ image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
316
+ if "codebook_crop_size" in kwargs:
317
+ image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
318
+ return super().from_dict(image_processor_dict, **kwargs)
319
+
320
+ @lru_cache()
321
+ def masking_generator(
322
+ self,
323
+ input_size_patches,
324
+ total_mask_patches,
325
+ mask_group_min_patches,
326
+ mask_group_max_patches,
327
+ mask_group_min_aspect_ratio,
328
+ mask_group_max_aspect_ratio,
329
+ ) -> FlavaMaskingGenerator:
330
+ return FlavaMaskingGenerator(
331
+ input_size=input_size_patches,
332
+ total_mask_patches=total_mask_patches,
333
+ mask_group_min_patches=mask_group_min_patches,
334
+ mask_group_max_patches=mask_group_max_patches,
335
+ mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
336
+ mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
337
+ )
338
+
339
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
340
+ def resize(
341
+ self,
342
+ image: np.ndarray,
343
+ size: Dict[str, int],
344
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
345
+ data_format: Optional[Union[str, ChannelDimension]] = None,
346
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
347
+ **kwargs,
348
+ ) -> np.ndarray:
349
+ """
350
+ Resize an image to `(size["height"], size["width"])`.
351
+
352
+ Args:
353
+ image (`np.ndarray`):
354
+ Image to resize.
355
+ size (`Dict[str, int]`):
356
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
357
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
358
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
359
+ data_format (`ChannelDimension` or `str`, *optional*):
360
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
361
+ image is used. Can be one of:
362
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
363
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
364
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
365
+ input_data_format (`ChannelDimension` or `str`, *optional*):
366
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
367
+ from the input image. Can be one of:
368
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
369
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
370
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
371
+
372
+ Returns:
373
+ `np.ndarray`: The resized image.
374
+ """
375
+ size = get_size_dict(size)
376
+ if "height" not in size or "width" not in size:
377
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
378
+ output_size = (size["height"], size["width"])
379
+ return resize(
380
+ image,
381
+ size=output_size,
382
+ resample=resample,
383
+ data_format=data_format,
384
+ input_data_format=input_data_format,
385
+ **kwargs,
386
+ )
387
+
388
+ def map_pixels(self, image: np.ndarray) -> np.ndarray:
389
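+ # Map pixel values from [0, 1] into [eps, 1 - eps] with eps = LOGIT_LAPLACE_EPS,
+ # the logit-Laplace input transform used for the image codebook.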
+ return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
390
+
391
+ def _preprocess_image(
392
+ self,
393
+ image: ImageInput,
394
+ do_resize: Optional[bool] = None,
395
+ size: Dict[str, int] = None,
396
+ resample: PILImageResampling = None,
397
+ do_center_crop: Optional[bool] = None,
398
+ crop_size: Dict[str, int] = None,
399
+ do_rescale: Optional[bool] = None,
400
+ rescale_factor: Optional[float] = None,
401
+ do_normalize: Optional[bool] = None,
402
+ image_mean: Optional[Union[float, List[float]]] = None,
403
+ image_std: Optional[Union[float, List[float]]] = None,
404
+ do_map_pixels: Optional[bool] = None,
405
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
406
+ input_data_format: Optional[ChannelDimension] = None,
407
+ ) -> np.ndarray:
408
+ """Preprocesses a single image."""
409
+
410
+ validate_preprocess_arguments(
411
+ do_rescale=do_rescale,
412
+ rescale_factor=rescale_factor,
413
+ do_normalize=do_normalize,
414
+ image_mean=image_mean,
415
+ image_std=image_std,
416
+ do_center_crop=do_center_crop,
417
+ crop_size=crop_size,
418
+ do_resize=do_resize,
419
+ size=size,
420
+ resample=resample,
421
+ )
422
+
423
+ # All transformations expect numpy arrays.
424
+ image = to_numpy_array(image)
425
+
426
+ if do_rescale and is_scaled_image(image):
427
+ logger.warning_once(
428
+ "It looks like you are trying to rescale already rescaled images. If the input"
429
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
430
+ )
431
+
432
+ if input_data_format is None:
433
+ # We assume that all images have the same channel dimension format.
434
+ input_data_format = infer_channel_dimension_format(image)
435
+
436
+ if do_resize:
437
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
438
+
439
+ if do_center_crop:
440
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
441
+
442
+ if do_rescale:
443
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
444
+
445
+ if do_normalize:
446
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
447
+
448
+ if do_map_pixels:
449
+ image = self.map_pixels(image)
450
+
451
+ if data_format is not None:
452
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
453
+ return image
454
+
455
+ @filter_out_non_signature_kwargs()
456
+ def preprocess(
457
+ self,
458
+ images: ImageInput,
459
+ do_resize: Optional[bool] = None,
460
+ size: Dict[str, int] = None,
461
+ resample: PILImageResampling = None,
462
+ do_center_crop: Optional[bool] = None,
463
+ crop_size: Optional[Dict[str, int]] = None,
464
+ do_rescale: Optional[bool] = None,
465
+ rescale_factor: Optional[float] = None,
466
+ do_normalize: Optional[bool] = None,
467
+ image_mean: Optional[Union[float, List[float]]] = None,
468
+ image_std: Optional[Union[float, List[float]]] = None,
469
+ # Mask related params
470
+ return_image_mask: Optional[bool] = None,
471
+ input_size_patches: Optional[int] = None,
472
+ total_mask_patches: Optional[int] = None,
473
+ mask_group_min_patches: Optional[int] = None,
474
+ mask_group_max_patches: Optional[int] = None,
475
+ mask_group_min_aspect_ratio: Optional[float] = None,
476
+ mask_group_max_aspect_ratio: Optional[float] = None,
477
+ # Codebook related params
478
+ return_codebook_pixels: Optional[bool] = None,
479
+ codebook_do_resize: Optional[bool] = None,
480
+ codebook_size: Optional[Dict[str, int]] = None,
481
+ codebook_resample: Optional[int] = None,
482
+ codebook_do_center_crop: Optional[bool] = None,
483
+ codebook_crop_size: Optional[Dict[str, int]] = None,
484
+ codebook_do_rescale: Optional[bool] = None,
485
+ codebook_rescale_factor: Optional[float] = None,
486
+ codebook_do_map_pixels: Optional[bool] = None,
487
+ codebook_do_normalize: Optional[bool] = None,
488
+ codebook_image_mean: Optional[Iterable[float]] = None,
489
+ codebook_image_std: Optional[Iterable[float]] = None,
490
+ return_tensors: Optional[Union[str, TensorType]] = None,
491
+ data_format: ChannelDimension = ChannelDimension.FIRST,
492
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
493
+ ) -> BatchFeature:
494
+ """
495
+ Preprocess an image or batch of images.
496
+
497
+ Args:
498
+ images (`ImageInput`):
499
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
500
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
501
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
502
+ Whether to resize the image.
503
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
504
+ Size of the image.
505
+ resample (`int`, *optional*, defaults to `self.resample`):
506
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
507
+ has an effect if `do_resize` is set to `True`.
508
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
509
+ Whether to center crop the image.
510
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
511
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
512
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
513
+ Whether to rescale the image values between [0 - 1].
514
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
515
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
516
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
517
+ Whether to normalize the image.
518
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
519
+ Image mean.
520
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
521
+ Image standard deviation.
522
+ return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`):
523
+ Whether to return the image mask.
524
+ input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`):
525
+ Number of patches in the image in the height and width direction.
526
+ total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`):
527
+ Total number of patches that should be masked.
528
+ mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`):
529
+ Minimum number of patches that should be masked per masking group.
530
+ mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`):
531
+ Maximum number of patches that should be masked per masking group.
532
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`):
533
+ Minimum aspect ratio of the mask window.
534
+ mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`):
535
+ Maximum aspect ratio of the mask window.
536
+ return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`):
537
+ Whether to return the codebook pixels.
538
+ codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`):
539
+ Whether to resize the codebook pixels.
540
+ codebook_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_size`):
541
+ Size of the codebook pixels.
542
+ codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`):
543
+ Resampling filter to use if resizing the codebook pixels. This can be one of the enum
544
+ `PILImageResampling`. Only has an effect if `codebook_do_resize` is set to `True`.
545
+ codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`):
546
+ Whether to center crop the codebook pixels.
547
+ codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`):
548
+ Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set
549
+ to `True`.
550
+ codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`):
551
+ Whether to rescale the codebook pixels values between [0 - 1].
552
+ codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`):
553
+ Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`.
554
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`):
555
+ Whether to map the codebook pixels values.
556
+ codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`):
557
+ Whether to normalize the codebook pixels.
558
+ codebook_image_mean (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_mean`):
559
+ Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`.
560
+ codebook_image_std (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_std`):
561
+ Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is
562
+ set to `True`.
563
+ return_tensors (`str` or `TensorType`, *optional*):
564
+ The type of tensors to return. Can be one of:
565
+ - Unset: Return a list of `np.ndarray`.
566
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
567
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
568
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
569
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
570
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
571
+ The channel dimension format for the output image. Can be one of:
572
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
573
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
574
+ input_data_format (`ChannelDimension` or `str`, *optional*):
575
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
576
+ from the input image. Can be one of:
577
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
578
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
579
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
580
+ """
581
+ do_resize = do_resize if do_resize is not None else self.do_resize
582
+ size = size if size is not None else self.size
583
+ size = get_size_dict(size)
584
+ resample = resample if resample is not None else self.resample
585
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
586
+ crop_size = crop_size if crop_size is not None else self.crop_size
587
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
588
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
589
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
590
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
591
+ image_mean = image_mean if image_mean is not None else self.image_mean
592
+ image_std = image_std if image_std is not None else self.image_std
593
+
594
+ return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask
595
+ input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches
596
+ total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches
597
+ mask_group_min_patches = (
598
+ mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches
599
+ )
600
+ mask_group_max_patches = (
601
+ mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches
602
+ )
603
+ mask_group_min_aspect_ratio = (
604
+ mask_group_min_aspect_ratio
605
+ if mask_group_min_aspect_ratio is not None
606
+ else self.mask_group_min_aspect_ratio
607
+ )
608
+ mask_group_max_aspect_ratio = (
609
+ mask_group_max_aspect_ratio
610
+ if mask_group_max_aspect_ratio is not None
611
+ else self.mask_group_max_aspect_ratio
612
+ )
613
+
614
+ return_codebook_pixels = (
615
+ return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels
616
+ )
617
+ codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
618
+ codebook_size = codebook_size if codebook_size is not None else self.codebook_size
619
+ codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
620
+ codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
621
+ codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
622
+ codebook_rescale_factor = (
623
+ codebook_rescale_factor if codebook_rescale_factor is not None else self.codebook_rescale_factor
624
+ )
625
+ codebook_do_center_crop = (
626
+ codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
627
+ )
628
+ codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
629
+ codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
630
+ codebook_do_map_pixels = (
631
+ codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
632
+ )
633
+ codebook_do_normalize = (
634
+ codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize
635
+ )
636
+ codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
637
+ codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std
638
+
639
+ images = make_list_of_images(images)
640
+
641
+ if not valid_images(images):
642
+ raise ValueError(
643
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
644
+ "torch.Tensor, tf.Tensor or jax.ndarray."
645
+ )
646
+
647
+ processed_images = [
648
+ self._preprocess_image(
649
+ image=img,
650
+ do_resize=do_resize,
651
+ size=size,
652
+ resample=resample,
653
+ do_center_crop=do_center_crop,
654
+ crop_size=crop_size,
655
+ do_rescale=do_rescale,
656
+ rescale_factor=rescale_factor,
657
+ do_normalize=do_normalize,
658
+ image_mean=image_mean,
659
+ image_std=image_std,
660
+ do_map_pixels=False,
661
+ data_format=data_format,
662
+ input_data_format=input_data_format,
663
+ )
664
+ for img in images
665
+ ]
666
+ data = {"pixel_values": processed_images}
667
+
668
+ if return_codebook_pixels:
669
+ codebook_images = [
670
+ self._preprocess_image(
671
+ image=img,
672
+ do_resize=codebook_do_resize,
673
+ size=codebook_size,
674
+ resample=codebook_resample,
675
+ do_center_crop=codebook_do_center_crop,
676
+ crop_size=codebook_crop_size,
677
+ do_rescale=codebook_do_rescale,
678
+ rescale_factor=codebook_rescale_factor,
679
+ do_normalize=codebook_do_normalize,
680
+ image_mean=codebook_image_mean,
681
+ image_std=codebook_image_std,
682
+ do_map_pixels=codebook_do_map_pixels,
683
+ data_format=data_format,
684
+ input_data_format=input_data_format,
685
+ )
686
+ for img in images
687
+ ]
688
+ data["codebook_pixel_values"] = codebook_images
689
+
690
+ if return_image_mask:
691
+ mask_generator = self.masking_generator(
692
+ input_size_patches=input_size_patches,
693
+ total_mask_patches=total_mask_patches,
694
+ mask_group_min_patches=mask_group_min_patches,
695
+ mask_group_max_patches=mask_group_max_patches,
696
+ mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
697
+ mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
698
+ )
699
+ masks = [mask_generator() for _ in images]
700
+ data["bool_masked_pos"] = masks
701
+
702
+ return BatchFeature(data=data, tensor_type=return_tensors)
703
+
704
+
705
+ __all__ = ["FlavaImageProcessor"]
docs/transformers/build/lib/transformers/models/flava/image_processing_flava_fast.py ADDED
@@ -0,0 +1,549 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for Flava."""
16
+
17
+ import math
18
+ import random
19
+ from functools import lru_cache
20
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
21
+
22
+ from ...image_processing_utils_fast import (
23
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
24
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
25
+ BaseImageProcessorFast,
26
+ BatchFeature,
27
+ DefaultFastImageProcessorKwargs,
28
+ get_size_dict,
29
+ )
30
+ from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images
31
+ from ...image_utils import ImageInput, PILImageResampling, SizeDict
32
+ from ...processing_utils import Unpack
33
+ from ...utils import (
34
+ TensorType,
35
+ add_start_docstrings,
36
+ is_torch_available,
37
+ is_torchvision_available,
38
+ is_torchvision_v2_available,
39
+ )
40
+ from .image_processing_flava import (
41
+ FLAVA_CODEBOOK_MEAN,
42
+ FLAVA_CODEBOOK_STD,
43
+ FLAVA_IMAGE_MEAN,
44
+ FLAVA_IMAGE_STD,
45
+ LOGIT_LAPLACE_EPS,
46
+ )
47
+
48
+
49
+ if is_torch_available():
50
+ import torch
51
+
52
+ if is_torchvision_available():
53
+ from ...image_utils import pil_torch_interpolation_mapping
54
+
55
+ if is_torchvision_v2_available():
56
+ from torchvision.transforms.v2 import functional as F
57
+ else:
58
+ from torchvision.transforms import functional as F
59
+
60
+
61
+ class FlavaMaskingGenerator:
62
+ def __init__(
63
+ self,
64
+ input_size: Union[int, Tuple[int, int]] = 14,
65
+ total_mask_patches: int = 75,
66
+ mask_group_max_patches: Optional[int] = None,
67
+ mask_group_min_patches: int = 16,
68
+ mask_group_min_aspect_ratio: Optional[float] = 0.3,
69
+ mask_group_max_aspect_ratio: Optional[float] = None,
70
+ ):
71
+ if not isinstance(input_size, tuple):
72
+ input_size = (input_size,) * 2
73
+ self.height, self.width = input_size
74
+
75
+ self.num_patches = self.height * self.width
76
+ self.total_mask_patches = total_mask_patches
77
+
78
+ self.mask_group_min_patches = mask_group_min_patches
79
+ self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
80
+
81
+ mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
82
+ self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
83
+
84
+ def __repr__(self):
85
+ repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
86
+ self.height,
87
+ self.width,
88
+ self.mask_group_min_patches,
89
+ self.mask_group_max_patches,
90
+ self.total_mask_patches,
91
+ self.log_aspect_ratio[0],
92
+ self.log_aspect_ratio[1],
93
+ )
94
+ return repr_str
95
+
96
+ def get_shape(self):
97
+ return self.height, self.width
98
+
99
+ def _mask(self, mask, max_mask_patches):
100
+ delta = 0
101
+ for _attempt in range(10):
102
+ target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
103
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
104
+ height = int(round(math.sqrt(target_area * aspect_ratio)))
105
+ width = int(round(math.sqrt(target_area / aspect_ratio)))
106
+ if width < self.width and height < self.height:
107
+ top = random.randint(0, self.height - height)
108
+ left = random.randint(0, self.width - width)
109
+
110
+ num_masked = mask[top : top + height, left : left + width].sum()
111
+ # Overlap
112
+ if 0 < height * width - num_masked <= max_mask_patches:
113
+ zeros_pos = mask[top : top + height, left : left + width] == 0
114
+ mask[top : top + height, left : left + width][zeros_pos] = 1
115
+ delta += zeros_pos.sum()
116
+
117
+ if delta > 0:
118
+ break
119
+ return delta
120
+
121
+ def __call__(self):
122
+ mask = torch.zeros(self.get_shape(), dtype=torch.int)
123
+ mask_count = 0
124
+ while mask_count < self.total_mask_patches:
125
+ max_mask_patches = self.total_mask_patches - mask_count
126
+ max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
127
+
128
+ delta = self._mask(mask, max_mask_patches)
129
+ if delta == 0:
130
+ break
131
+ else:
132
+ mask_count += delta
133
+
134
+ return mask
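
`FlavaMaskingGenerator` implements block-wise masking: it repeatedly samples rectangles with random area and aspect ratio and flips their still-unmasked patches to 1 until at most `total_mask_patches` patches are covered (it can stop early when no fitting rectangle is found). A minimal sketch of its output with the defaults above:

# Sketch only: uses the FlavaMaskingGenerator class defined above; requires torch.
generator = FlavaMaskingGenerator(input_size=14, total_mask_patches=75, mask_group_min_patches=16)
mask = generator()        # (14, 14) tensor of 0/1 patch flags
print(mask.shape)         # torch.Size([14, 14])
print(int(mask.sum()))    # at most 75, typically close to it
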
135
+
136
+
137
+ class FlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
138
+ # Mask related params
139
+ return_image_mask: Optional[bool]
140
+ input_size_patches: Optional[int]
141
+ total_mask_patches: Optional[int]
142
+ mask_group_min_patches: Optional[int]
143
+ mask_group_max_patches: Optional[int]
144
+ mask_group_min_aspect_ratio: Optional[float]
145
+ mask_group_max_aspect_ratio: Optional[float]
146
+ # Codebook related params
147
+ return_codebook_pixels: Optional[bool]
148
+ codebook_do_resize: Optional[bool]
149
+ codebook_size: Optional[Dict[str, int]]
150
+ codebook_resample: Optional[int]
151
+ codebook_do_center_crop: Optional[bool]
152
+ codebook_crop_size: Optional[Dict[str, int]]
153
+ codebook_do_rescale: Optional[bool]
154
+ codebook_rescale_factor: Optional[Union[int, float]]
155
+ codebook_do_map_pixels: Optional[bool]
156
+ codebook_do_normalize: Optional[bool]
157
+ codebook_image_mean: Optional[Union[float, Iterable[float]]]
158
+ codebook_image_std: Optional[Union[float, Iterable[float]]]
159
+
160
+
161
+ @add_start_docstrings(
162
+ "Constructs a fast Flava image processor.",
163
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
164
+ """
165
+ return_image_mask (`bool`, *optional*, defaults to `False`):
166
+ Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
167
+ input_size_patches (`int`, *optional*, defaults to 14):
168
+ Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
169
+ by the `input_size_patches` parameter in `preprocess`.
170
+ total_mask_patches (`int`, *optional*, defaults to 75):
171
+ Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
172
+ `preprocess`.
173
+ mask_group_min_patches (`int`, *optional*, defaults to 16):
174
+ Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
175
+ parameter in `preprocess`.
176
+ mask_group_max_patches (`int`, *optional*):
177
+ Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
178
+ parameter in `preprocess`.
179
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
180
+ Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
181
+ in `preprocess`.
182
+ mask_group_max_aspect_ratio (`float`, *optional*):
183
+ Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
184
+ in `preprocess`.
185
+ codebook_do_resize (`bool`, *optional*, defaults to `True`):
186
+ Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
187
+ `codebook_do_resize` parameter in `preprocess`.
188
+ codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
189
+ Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
190
+ `preprocess`.
191
+ codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
192
+ Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
193
+ parameter in `preprocess`.
194
+ codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
195
+ Whether to crop the input for codebook at the center. If the input size is smaller than
196
+ `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
197
+ overridden by the `codebook_do_center_crop` parameter in `preprocess`.
198
+ codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
199
+ Desired output size for codebook input when applying center-cropping. Can be overridden by the
200
+ `codebook_crop_size` parameter in `preprocess`.
201
+ codebook_do_rescale (`bool`, *optional*, defaults to `True`):
202
+ Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
203
+ overridden by the `codebook_do_rescale` parameter in `preprocess`.
204
+ codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
205
+ Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
206
+ `codebook_rescale_factor` parameter in `preprocess`.
207
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
208
+ Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
209
+ `codebook_do_map_pixels` parameter in `preprocess`.
210
+ codebook_do_normalize (`bool`, *optional*, defaults to `True`):
211
+ Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
212
+ be overridden by the `codebook_do_normalize` parameter in `preprocess`.
213
+ codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
214
+ The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
215
+ by the `codebook_image_mean` parameter in `preprocess`.
216
+ codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
217
+ The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
218
+ be overridden by the `codebook_image_std` parameter in `preprocess`.
219
+ """,
220
+ )
221
+ class FlavaImageProcessorFast(BaseImageProcessorFast):
222
+ resample = PILImageResampling.BICUBIC
223
+ image_mean = FLAVA_IMAGE_MEAN
224
+ image_std = FLAVA_IMAGE_STD
225
+ size = {"height": 224, "width": 224}
226
+ crop_size = {"height": 224, "width": 224}
227
+ do_resize = True
228
+ do_center_crop = True
229
+ do_rescale = True
230
+ do_normalize = True
231
+
232
+ # Mask related params
233
+ return_image_mask = False
234
+ input_size_patches = 14
235
+ total_mask_patches = 75
236
+ mask_group_min_patches = 16
237
+ mask_group_max_patches = None
238
+ mask_group_min_aspect_ratio = 0.3
239
+ mask_group_max_aspect_ratio = None
240
+ # Codebook related params
241
+ return_codebook_pixels = False
242
+ codebook_do_resize = True
243
+ codebook_size = {"height": 112, "width": 112}
244
+ # LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative
245
+ codebook_resample = PILImageResampling.BICUBIC
246
+ codebook_do_center_crop = True
247
+ codebook_crop_size = {"height": 112, "width": 112}
248
+ codebook_do_rescale = True
249
+ codebook_rescale_factor = 1 / 255
250
+ codebook_do_map_pixels = True
251
+ codebook_do_normalize = True
252
+ codebook_image_mean = FLAVA_CODEBOOK_MEAN
253
+ codebook_image_std = FLAVA_CODEBOOK_STD
254
+ valid_kwargs = FlavaFastImageProcessorKwargs
255
+
256
+ def __init__(self, **kwargs: Unpack[FlavaFastImageProcessorKwargs]):
257
+ super().__init__(**kwargs)
258
+
259
+ @add_start_docstrings(
260
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
261
+ """
262
+ return_image_mask (`bool`, *optional*, defaults to `False`):
263
+ Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
264
+ input_size_patches (`int`, *optional*, defaults to 14):
265
+ Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
266
+ by the `input_size_patches` parameter in `preprocess`.
267
+ total_mask_patches (`int`, *optional*, defaults to 75):
268
+ Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
269
+ `preprocess`.
270
+ mask_group_min_patches (`int`, *optional*, defaults to 16):
271
+ Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
272
+ parameter in `preprocess`.
273
+ mask_group_max_patches (`int`, *optional*):
274
+ Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
275
+ parameter in `preprocess`.
276
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
277
+ Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
278
+ in `preprocess`.
279
+ mask_group_max_aspect_ratio (`float`, *optional*):
280
+ Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
281
+ in `preprocess`.
282
+ codebook_do_resize (`bool`, *optional*, defaults to `True`):
283
+ Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
284
+ `codebook_do_resize` parameter in `preprocess`.
285
+ codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
286
+ Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
287
+ `preprocess`.
288
+ codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
289
+ Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
290
+ parameter in `preprocess`.
291
+ codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
292
+ Whether to crop the input for codebook at the center. If the input size is smaller than
293
+ `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
294
+ overridden by the `codebook_do_center_crop` parameter in `preprocess`.
295
+ codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
296
+ Desired output size for codebook input when applying center-cropping. Can be overridden by the
297
+ `codebook_crop_size` parameter in `preprocess`.
298
+ codebook_do_rescale (`bool`, *optional*, defaults to `True`):
299
+ Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
300
+ overridden by the `codebook_do_rescale` parameter in `preprocess`.
301
+ codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
302
+ Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
303
+ `codebook_rescale_factor` parameter in `preprocess`.
304
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
305
+ Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
306
+ `codebook_do_map_pixels` parameter in `preprocess`.
307
+ codebook_do_normalize (`bool`, *optional*, defaults to `True`):
308
+ Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
309
+ be overridden by the `codebook_do_normalize` parameter in `preprocess`.
310
+ codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
311
+ The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
312
+ by the `codebook_image_mean` parameter in `preprocess`.
313
+ codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
314
+ The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
315
+ be overridden by the `codebook_image_std` parameter in `preprocess`.
316
+ """,
317
+ )
318
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
319
+ return super().preprocess(images, **kwargs)
320
+
321
+ @classmethod
322
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
323
+ """
324
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
325
+ created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
326
+ """
327
+ image_processor_dict = image_processor_dict.copy()
328
+ if "codebook_size" in kwargs:
329
+ image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
330
+ if "codebook_crop_size" in kwargs:
331
+ image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
332
+ return super().from_dict(image_processor_dict, **kwargs)
333
+
334
+ @lru_cache()
335
+ def masking_generator(
336
+ self,
337
+ input_size_patches,
338
+ total_mask_patches,
339
+ mask_group_min_patches,
340
+ mask_group_max_patches,
341
+ mask_group_min_aspect_ratio,
342
+ mask_group_max_aspect_ratio,
343
+ ) -> FlavaMaskingGenerator:
344
+ return FlavaMaskingGenerator(
345
+ input_size=input_size_patches,
346
+ total_mask_patches=total_mask_patches,
347
+ mask_group_min_patches=mask_group_min_patches,
348
+ mask_group_max_patches=mask_group_max_patches,
349
+ mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
350
+ mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
351
+ )
352
+
353
+ def map_pixels(self, image: "torch.Tensor") -> "torch.Tensor":
354
+ return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
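
`map_pixels` is the affine transform referred to in the codebook docstrings: it maps rescaled pixel values from [0, 1] into [eps, 1 - eps] with eps = `LOGIT_LAPLACE_EPS`, so the image codebook never receives exact 0s or 1s. A quick endpoint check (sketch only):

import torch

eps = LOGIT_LAPLACE_EPS            # imported at the top of this file
x = torch.tensor([0.0, 0.5, 1.0])
print((1 - 2 * eps) * x + eps)     # tensor([eps, 0.5000, 1 - eps])
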
355
+
356
+ def _further_process_kwargs(
357
+ self,
358
+ size: Optional[SizeDict] = None,
359
+ crop_size: Optional[SizeDict] = None,
360
+ default_to_square: Optional[bool] = None,
361
+ image_mean: Optional[Union[float, list[float]]] = None,
362
+ image_std: Optional[Union[float, list[float]]] = None,
363
+ codebook_size: Optional[SizeDict] = None,
364
+ codebook_crop_size: Optional[SizeDict] = None,
365
+ codebook_image_mean: Optional[Union[float, list[float]]] = None,
366
+ codebook_image_std: Optional[Union[float, list[float]]] = None,
367
+ codebook_resample: Optional[PILImageResampling] = None,
368
+ data_format: Optional[ChannelDimension] = None,
369
+ **kwargs,
370
+ ) -> dict:
371
+ """
372
+ Update kwargs that need further processing before being validated
373
+ Can be overridden by subclasses to customize the processing of kwargs.
374
+ """
375
+ if kwargs is None:
376
+ kwargs = {}
377
+ if size is not None:
378
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
379
+ if crop_size is not None:
380
+ crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
381
+ if isinstance(image_mean, list):
382
+ image_mean = tuple(image_mean)
383
+ if isinstance(image_std, list):
384
+ image_std = tuple(image_std)
385
+ if data_format is None:
386
+ data_format = ChannelDimension.FIRST
387
+ if codebook_size is not None:
388
+ codebook_size = SizeDict(**get_size_dict(size=codebook_size, default_to_square=default_to_square))
389
+ if codebook_crop_size is not None:
390
+ codebook_crop_size = SizeDict(**get_size_dict(codebook_crop_size, param_name="codebook_crop_size"))
391
+ if isinstance(codebook_image_mean, list):
392
+ codebook_image_mean = tuple(codebook_image_mean)
393
+ if isinstance(codebook_image_std, list):
394
+ codebook_image_std = tuple(codebook_image_std)
395
+
396
+ kwargs["size"] = size
397
+ kwargs["crop_size"] = crop_size
398
+ kwargs["default_to_square"] = default_to_square
399
+ kwargs["image_mean"] = image_mean
400
+ kwargs["image_std"] = image_std
401
+ kwargs["codebook_size"] = codebook_size
402
+ kwargs["codebook_crop_size"] = codebook_crop_size
403
+ kwargs["codebook_image_mean"] = codebook_image_mean
404
+ kwargs["codebook_image_std"] = codebook_image_std
405
+ kwargs["data_format"] = data_format
406
+ kwargs["codebook_interpolation"] = (
407
+ pil_torch_interpolation_mapping[codebook_resample]
408
+ if isinstance(codebook_resample, (PILImageResampling, int))
409
+ else codebook_resample
410
+ )
411
+
412
+ return kwargs
413
+
414
+ def _preprocess_image(
415
+ self,
416
+ images: list["torch.Tensor"],
417
+ do_resize: bool,
418
+ size: SizeDict,
419
+ interpolation: Optional["F.InterpolationMode"],
420
+ do_center_crop: bool,
421
+ crop_size: SizeDict,
422
+ do_rescale: bool,
423
+ rescale_factor: float,
424
+ do_normalize: bool,
425
+ do_map_pixels: bool,
426
+ image_mean: Optional[Union[float, list[float]]],
427
+ image_std: Optional[Union[float, list[float]]],
428
+ return_tensors: Optional[Union[str, TensorType]],
429
+ ) -> "torch.Tensor":
430
+ # Group images by size for batched resizing
431
+ grouped_images, grouped_images_index = group_images_by_shape(images)
432
+ resized_images_grouped = {}
433
+ for shape, stacked_images in grouped_images.items():
434
+ if do_resize:
435
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
436
+ resized_images_grouped[shape] = stacked_images
437
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
438
+
439
+ # Group images by size for further processing
440
+ # Needed in case do_resize is False, or resize returns images with different sizes
441
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images)
442
+ processed_images_grouped = {}
443
+ for shape, stacked_images in grouped_images.items():
444
+ if do_center_crop:
445
+ stacked_images = self.center_crop(stacked_images, crop_size)
446
+ # Fused rescale and normalize
447
+ stacked_images = self.rescale_and_normalize(
448
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
449
+ )
450
+ if do_map_pixels:
451
+ stacked_images = self.map_pixels(image=stacked_images)
452
+ processed_images_grouped[shape] = stacked_images
453
+
454
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
455
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
456
+
457
+ return processed_images
458
+
459
+ def _preprocess(
460
+ self,
461
+ images: list["torch.Tensor"],
462
+ do_resize: bool,
463
+ size: SizeDict,
464
+ interpolation: Optional["F.InterpolationMode"],
465
+ do_center_crop: bool,
466
+ crop_size: SizeDict,
467
+ do_rescale: bool,
468
+ rescale_factor: float,
469
+ do_normalize: bool,
470
+ image_mean: Optional[Union[float, list[float]]],
471
+ image_std: Optional[Union[float, list[float]]],
472
+ # Mask related params
473
+ return_image_mask: Optional[bool],
474
+ input_size_patches: Optional[int],
475
+ total_mask_patches: Optional[int],
476
+ mask_group_min_patches: Optional[int],
477
+ mask_group_max_patches: Optional[int],
478
+ mask_group_min_aspect_ratio: Optional[float],
479
+ mask_group_max_aspect_ratio: Optional[float],
480
+ # Codebook related params
481
+ return_codebook_pixels: Optional[bool],
482
+ codebook_do_resize: Optional[bool],
483
+ codebook_size: Optional[SizeDict],
484
+ codebook_interpolation: Optional["F.InterpolationMode"],
485
+ codebook_do_center_crop: Optional[bool],
486
+ codebook_crop_size: Optional[SizeDict],
487
+ codebook_do_rescale: Optional[bool],
488
+ codebook_rescale_factor: Optional[float],
489
+ codebook_do_map_pixels: Optional[bool],
490
+ codebook_do_normalize: Optional[bool],
491
+ codebook_image_mean: Optional[Union[float, list[float]]],
492
+ codebook_image_std: Optional[Union[float, list[float]]],
493
+ return_tensors: Optional[Union[str, TensorType]],
494
+ **kwargs,
495
+ ) -> BatchFeature:
496
+ processed_images = self._preprocess_image(
497
+ images=images,
498
+ do_resize=do_resize,
499
+ size=size,
500
+ interpolation=interpolation,
501
+ do_center_crop=do_center_crop,
502
+ crop_size=crop_size,
503
+ do_rescale=do_rescale,
504
+ rescale_factor=rescale_factor,
505
+ do_normalize=do_normalize,
506
+ do_map_pixels=False,
507
+ image_mean=image_mean,
508
+ image_std=image_std,
509
+ return_tensors=return_tensors,
510
+ )
511
+ data = {
512
+ "pixel_values": processed_images,
513
+ }
514
+
515
+ if return_codebook_pixels:
516
+ codebook_processed_images = self._preprocess_image(
517
+ images=images,
518
+ do_resize=codebook_do_resize,
519
+ size=codebook_size,
520
+ interpolation=codebook_interpolation,
521
+ do_center_crop=codebook_do_center_crop,
522
+ crop_size=codebook_crop_size,
523
+ do_rescale=codebook_do_rescale,
524
+ rescale_factor=codebook_rescale_factor,
525
+ do_normalize=codebook_do_normalize,
526
+ do_map_pixels=codebook_do_map_pixels,
527
+ image_mean=codebook_image_mean,
528
+ image_std=codebook_image_std,
529
+ return_tensors=return_tensors,
530
+ )
531
+ data["codebook_pixel_values"] = codebook_processed_images
532
+
533
+ if return_image_mask:
534
+ mask_generator = self.masking_generator(
535
+ input_size_patches=input_size_patches,
536
+ total_mask_patches=total_mask_patches,
537
+ mask_group_min_patches=mask_group_min_patches,
538
+ mask_group_max_patches=mask_group_max_patches,
539
+ mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
540
+ mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
541
+ )
542
+ masks = [mask_generator() for _ in range(len(images))]
543
+ masks = torch.stack(masks, dim=0) if return_tensors else masks
544
+ data["bool_masked_pos"] = masks
545
+
546
+ return BatchFeature(data=data, tensor_type=return_tensors)
547
+
548
+
549
+ __all__ = ["FlavaImageProcessorFast"]
docs/transformers/build/lib/transformers/models/flava/modeling_flava.py ADDED
@@ -0,0 +1,2127 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch FLAVA model."""
16
+
17
+ import collections
18
+ import math
19
+ from collections import OrderedDict
20
+ from dataclasses import dataclass
21
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
29
+ from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
30
+ from ...utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ logging,
36
+ replace_return_docstrings,
37
+ torch_int,
38
+ )
39
+ from .configuration_flava import (
40
+ FlavaConfig,
41
+ FlavaImageCodebookConfig,
42
+ FlavaImageConfig,
43
+ FlavaMultimodalConfig,
44
+ FlavaTextConfig,
45
+ )
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CHECKPOINT_FOR_DOC = "facebook/flava-full"
51
+
52
+ # Codebook docstring
53
+ _CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
54
+ _CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig"
55
+ _CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig"
56
+ _CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig"
57
+ _EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]
58
+
59
+
60
+ LOGIT_SCALE_CLAMP_MIN = 0
61
+ LOGIT_SCALE_CLAMP_MAX = 4.6052
62
+
63
+ FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]
64
+
65
+
66
+ @dataclass
67
+ class FlavaModelOutput(ModelOutput):
68
+ """
69
+ Output from FlavaModel containing embeddings and outputs from individual encoders.
70
+
71
+ Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
72
+ transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
73
+ `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
74
+
75
+ Args:
76
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
77
+ The image embeddings which are basically the pooled output of [`FlavaImageModel`].
78
+ image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
79
+ The output of the [`FlavaImageModel`].
80
+ text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
81
+ The text embeddings which are basically the pooled output of [`FlavaTextModel`].
82
+ text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
83
+ The output of the [`FlavaTextModel`].
84
+ multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
85
+ The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
86
+ multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
87
+ The output of the [`FlavaMultimodalModel`].
88
+ """
89
+
90
+ image_embeddings: Optional[torch.FloatTensor] = None
91
+ image_output: Optional[BaseModelOutputWithPooling] = None
92
+ text_embeddings: Optional[torch.FloatTensor] = None
93
+ text_output: Optional[BaseModelOutputWithPooling] = None
94
+ multimodal_embeddings: Optional[torch.FloatTensor] = None
95
+ multimodal_output: Optional[BaseModelOutputWithPooling] = None
96
+
97
+ def to_tuple(self) -> Tuple[Any]:
98
+ return tuple(
99
+ self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
100
+ for k in self.keys()
101
+ )
102
+
103
+
104
+ @dataclass
105
+ class FlavaLosses(ModelOutput):
106
+ """Class representing pretraining losses from FLAVA model
107
+
108
+ Args:
109
+ mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.:
110
+ Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
111
+ mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.:
112
+ Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
113
+ itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.:
114
+ Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
115
+ masked pairs in FLAVA.
116
+ global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.:
117
+ Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
118
+ data. This is calculated on unmasked images and texts.
119
+ mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.:
120
+ Masked Multimodal Modeling loss's image component calculated on paired image-text data.
121
+ mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.:
122
+ Masked Multimodal Modeling loss's text component calculated on paired image-text data.
123
+ """
124
+
125
+ mim: Optional[torch.FloatTensor] = None
126
+ mlm: Optional[torch.FloatTensor] = None
127
+ itm: Optional[torch.FloatTensor] = None
128
+ global_contrastive: Optional[torch.FloatTensor] = None
129
+ mmm_image: Optional[torch.FloatTensor] = None
130
+ mmm_text: Optional[torch.FloatTensor] = None
131
+
132
+ def all_none(self) -> bool:
133
+ all_none = True
134
+ for v in self.values():
135
+ if v is not None:
136
+ all_none = False
137
+ break
138
+ return all_none
139
+
140
+
141
+ @dataclass
142
+ class FlavaForPreTrainingOutput(ModelOutput):
143
+ """
144
+ Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.
145
+
146
+ Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
147
+ transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
148
+ `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
149
+
150
+ Args:
151
+ loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
152
+ Total loss calculated for this model.
153
+ loss_info (`FlavaLosses`):
154
+ Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
155
+ the keys.
156
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
157
+ The image embeddings which are basically the pooled output of [`FlavaImageModel`].
158
+ image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
159
+ The output of the [`FlavaImageModel`].
160
+ text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
161
+ The text embeddings which are basically the pooled output of [`FlavaTextModel`].
162
+ text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
163
+ The output of the [`FlavaTextModel`].
164
+ multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
165
+ The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
166
+ multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
167
+ The output of the [`FlavaMultimodalModel`].
168
+
169
+ image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
170
+ The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
171
+ to create masked images.
172
+ image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
173
+ The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
174
+ text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
175
+ The text embeddings which are basically the pooled output of [`FlavaTextModel`].
176
+ text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
177
+ The output of the [`FlavaTextModel`].
178
+ multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present):
179
+ The multimodal masked embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
180
+ multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
181
+ The output of the [`FlavaMultimodalModel`].
182
+
183
+ mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
184
+ The logits for MIM unimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened output is
185
+ returned when `bool_masked_pos` has some of the patches masked.
186
+ mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
187
+ The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
188
+ the tokens masked.
189
+ itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
190
+ The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
191
+ mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape`(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
192
+ The logits for MMM image multimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened
193
+ output is returned when `bool_masked_pos` has some of the patches masked.
194
+ mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
195
+ The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
196
+ some of the tokens masked.
197
+ contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
198
+ The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
199
+ `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
200
+ scores. This is calculated on unmasked images and texts.
201
+ contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
202
+ The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
203
+ `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
204
+ texts.
205
+ """
206
+
207
+ loss: Optional[torch.FloatTensor] = None
208
+ loss_info: FlavaLosses = None
209
+ image_embeddings: Optional[torch.FloatTensor] = None
210
+ image_output: Optional[BaseModelOutputWithPooling] = None
211
+ text_embeddings: Optional[torch.FloatTensor] = None
212
+ text_output: Optional[BaseModelOutputWithPooling] = None
213
+ multimodal_embeddings: Optional[torch.FloatTensor] = None
214
+ multimodal_output: Optional[BaseModelOutputWithPooling] = None
215
+ image_masked_embeddings: Optional[torch.FloatTensor] = None
216
+ image_masked_output: Optional[BaseModelOutputWithPooling] = None
217
+ text_masked_embeddings: Optional[torch.FloatTensor] = None
218
+ text_masked_output: Optional[BaseModelOutputWithPooling] = None
219
+ multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
220
+ multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
221
+ mim_logits: Optional[torch.FloatTensor] = None
222
+ mlm_logits: Optional[torch.FloatTensor] = None
223
+ itm_logits: Optional[torch.FloatTensor] = None
224
+ contrastive_logits_per_image: Optional[torch.FloatTensor] = None
225
+ contrastive_logits_per_text: Optional[torch.FloatTensor] = None
226
+ mmm_image_logits: Optional[torch.FloatTensor] = None
227
+ mmm_text_logits: Optional[torch.FloatTensor] = None
228
+
229
+ def to_tuple(self) -> Tuple[Any]:
230
+ transformer_outputs = [
231
+ "text_output",
232
+ "image_output",
233
+ "multimodal_output",
234
+ "text_masked_output",
235
+ "image_masked_output",
236
+ "multimodal_masked_output",
237
+ ]
238
+ return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
239
+
240
+
241
+ # Based on timm implementation, which can be found here:
242
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
243
+ class FlavaImageEmbeddings(nn.Module):
244
+ """
245
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
246
+ """
247
+
248
+ def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
249
+ super().__init__()
250
+
251
+ use_mask_token = use_mask_token or config.mask_token
252
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
253
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
254
+ self.patch_embeddings = PatchEmbeddings(
255
+ image_size=config.image_size,
256
+ patch_size=config.patch_size,
257
+ num_channels=config.num_channels,
258
+ embed_dim=config.hidden_size,
259
+ )
260
+ num_patches = self.patch_embeddings.num_patches
261
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
262
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
263
+ self.patch_size = config.patch_size
264
+ self.config = config
265
+
266
+ # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
267
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
268
+ """
269
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
270
+ images. This method is also adapted to support torch.jit tracing.
271
+
272
+ Adapted from:
273
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
274
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
275
+ """
276
+
277
+ num_patches = embeddings.shape[1] - 1
278
+ num_positions = self.position_embeddings.shape[1] - 1
279
+
280
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
281
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
282
+ return self.position_embeddings
283
+
284
+ class_pos_embed = self.position_embeddings[:, :1]
285
+ patch_pos_embed = self.position_embeddings[:, 1:]
286
+
287
+ dim = embeddings.shape[-1]
288
+
289
+ new_height = height // self.patch_size
290
+ new_width = width // self.patch_size
291
+
292
+ sqrt_num_positions = torch_int(num_positions**0.5)
293
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
294
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
295
+
296
+ patch_pos_embed = nn.functional.interpolate(
297
+ patch_pos_embed,
298
+ size=(new_height, new_width),
299
+ mode="bicubic",
300
+ align_corners=False,
301
+ )
302
+
303
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
304
+
305
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
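
A numeric sketch of the interpolation above for the default 224x224 / 16x16-patch configuration: the stored table holds 1 + 14*14 = 197 positions; a 336x336 input needs a 21x21 patch grid, so the 14x14 block is bicubically resized to 441 positions and the class position is re-attached.

import torch

# Sketch only: standalone rehearsal of the reshape/interpolate steps above.
pos = torch.randn(1, 197, 768)                      # [CLS] + 14*14 learned positions
patch = pos[:, 1:].reshape(1, 14, 14, 768).permute(0, 3, 1, 2)
patch = torch.nn.functional.interpolate(patch, size=(21, 21), mode="bicubic", align_corners=False)
patch = patch.permute(0, 2, 3, 1).view(1, -1, 768)
print(torch.cat((pos[:, :1], patch), dim=1).shape)  # torch.Size([1, 442, 768])
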
306
+
307
+ def forward(
308
+ self,
309
+ pixel_values: torch.Tensor,
310
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
311
+ interpolate_pos_encoding: bool = False,
312
+ ) -> torch.Tensor:
313
+ batch_size, num_channels, height, width = pixel_values.shape
314
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
315
+
316
+ batch_size, seq_len, _ = embeddings.size()
317
+ if bool_masked_pos is not None:
318
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
319
+ # flatten a (batch, height, width) mask to (batch, height * width)
320
+ if bool_masked_pos.dim() == 3:
321
+ bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
322
+ # replace the masked visual tokens by mask_tokens
323
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
324
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
325
+
326
+ # add the [CLS] token to the embedded patch tokens
327
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
328
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
329
+
330
+ # add positional encoding to each token
331
+ if interpolate_pos_encoding:
332
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
333
+ else:
334
+ embeddings = embeddings + self.position_embeddings
335
+
336
+ embeddings = self.dropout(embeddings)
337
+
338
+ return embeddings
339
+
340
+
341
+ # Based on timm implementation, which can be found here:
342
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
343
+ class PatchEmbeddings(nn.Module):
344
+ """
345
+ Image to Patch Embedding.
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ image_size: int = 224,
351
+ patch_size: Union[int, Tuple[int, int]] = 16,
352
+ num_channels: int = 3,
353
+ embed_dim: int = 768,
354
+ ):
355
+ super().__init__()
356
+ if not isinstance(image_size, collections.abc.Iterable):
357
+ image_size = (image_size, image_size)
358
+ if not isinstance(patch_size, collections.abc.Iterable):
359
+ patch_size = (patch_size, patch_size)
360
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
361
+ self.image_size = image_size
362
+ self.patch_size = patch_size
363
+ self.num_patches = num_patches
364
+
365
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
366
+
367
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
368
+ batch_size, num_channels, height, width = pixel_values.shape
369
+ if not interpolate_pos_encoding:
370
+ if height != self.image_size[0] or width != self.image_size[1]:
371
+ raise ValueError(
372
+ f"Input image size ({height}*{width}) doesn't match model"
373
+ f" ({self.image_size[0]}*{self.image_size[1]})."
374
+ )
375
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
376
+ return x
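
With the defaults (224x224 images, 16x16 patches, hidden size 768) the projection above yields 14*14 = 196 patch tokens per image. A standalone shape check (sketch only):

import torch
from torch import nn

projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)   # mirrors PatchEmbeddings.projection
x = projection(torch.randn(2, 3, 224, 224)).flatten(2).transpose(1, 2)
print(x.shape)                                              # torch.Size([2, 196, 768])
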
377
+
378
+
379
+ class FlavaTextEmbeddings(nn.Module):
380
+ """Construct the embeddings from word, position and token_type embeddings."""
381
+
382
+ def __init__(self, config):
383
+ super().__init__()
384
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
385
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
386
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
387
+
388
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
389
+ # any TensorFlow checkpoint file
390
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
391
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
392
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
393
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
394
+ self.register_buffer(
395
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
396
+ )
397
+ self.register_buffer(
398
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
399
+ )
400
+
401
+ def forward(
402
+ self,
403
+ input_ids: Optional[torch.Tensor] = None,
404
+ token_type_ids: Optional[torch.Tensor] = None,
405
+ position_ids: Optional[torch.Tensor] = None,
406
+ ):
407
+ input_shape = input_ids.size()
408
+ seq_length = input_shape[1]
409
+
410
+ if position_ids is None:
411
+ position_ids = self.position_ids[:, :seq_length]
412
+
413
+ # Set token_type_ids to the registered buffer (all zeros) from the constructor, which is the usual case
414
+ # when it is auto-generated; the registered buffer lets users trace the model without passing token_type_ids and
415
+ # solves issue #5664
416
+ if token_type_ids is None:
417
+ if hasattr(self, "token_type_ids"):
418
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
419
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
420
+ token_type_ids = buffered_token_type_ids_expanded
421
+ else:
422
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
423
+
424
+ inputs_embeds = self.word_embeddings(input_ids)
425
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
426
+
427
+ embeddings = inputs_embeds + token_type_embeddings
428
+ if self.position_embedding_type == "absolute":
429
+ position_embeddings = self.position_embeddings(position_ids)
430
+ embeddings += position_embeddings
431
+ embeddings = self.LayerNorm(embeddings)
432
+ embeddings = self.dropout(embeddings)
433
+ return embeddings
434
+
435
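For reference, a compact sketch of the BERT-style embedding sum computed by `FlavaTextEmbeddings` above; the vocabulary and sequence values are illustrative and dropout is omitted.

```python
import torch
from torch import nn

vocab_size, hidden_size, max_positions = 30522, 768, 512
word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
position_embeddings = nn.Embedding(max_positions, hidden_size)
token_type_embeddings = nn.Embedding(2, hidden_size)
layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)

input_ids = torch.tensor([[101, 2023, 2003, 102]])
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)
token_type_ids = torch.zeros_like(input_ids)

embeddings = layer_norm(
    word_embeddings(input_ids) + token_type_embeddings(token_type_ids) + position_embeddings(position_ids)
)
print(embeddings.shape)  # torch.Size([1, 4, 768])
```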
+
436
+ class FlavaSelfAttention(nn.Module):
437
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
438
+ super().__init__()
439
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
440
+ raise ValueError(
441
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
442
+ f"heads {config.num_attention_heads}."
443
+ )
444
+
445
+ self.num_attention_heads = config.num_attention_heads
446
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
447
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
448
+
449
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
450
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
451
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
452
+
453
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
454
+
455
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
456
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
457
+ x = x.view(*new_x_shape)
458
+ return x.permute(0, 2, 1, 3)
459
+
460
+ def forward(
461
+ self,
462
+ hidden_states: torch.Tensor,
463
+ attention_mask: Optional[torch.Tensor] = None,
464
+ head_mask: Optional[torch.Tensor] = None,
465
+ output_attentions: bool = False,
466
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
467
+ mixed_query_layer = self.query(hidden_states)
468
+
469
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
470
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
471
+ query_layer = self.transpose_for_scores(mixed_query_layer)
472
+
473
+ # Take the dot product between "query" and "key" to get the raw attention scores.
474
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
475
+
476
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
477
+ if attention_mask is not None:
478
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
479
+ attention_scores = attention_scores + attention_mask
480
+
481
+ # Normalize the attention scores to probabilities.
482
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
483
+
484
+ # This is actually dropping out entire tokens to attend to, which might
485
+ # seem a bit unusual, but is taken from the original Transformer paper.
486
+ attention_probs = self.dropout(attention_probs)
487
+
488
+ # Mask heads if we want to
489
+ if head_mask is not None:
490
+ attention_probs = attention_probs * head_mask
491
+
492
+ context_layer = torch.matmul(attention_probs, value_layer)
493
+
494
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
495
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
496
+ context_layer = context_layer.view(*new_context_layer_shape)
497
+
498
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
499
+
500
+ return outputs
501
+
502
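A shape walk-through of the scaled dot-product attention implemented in `FlavaSelfAttention.forward` above, with illustrative sizes:

```python
import math
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 12, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)   # layout produced by transpose_for_scores
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)  # (2, 12, 5, 5)
probs = scores.softmax(dim=-1)
context = torch.matmul(probs, value)                                       # (2, 12, 5, 64)
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)
print(context.shape)  # torch.Size([2, 5, 768])
```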
+
503
+ class FlavaSelfOutput(nn.Module):
504
+ """
505
+ The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
506
+ models), due to the layernorm applied before each block.
507
+ """
508
+
509
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
510
+ super().__init__()
511
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
512
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
513
+
514
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
515
+ hidden_states = self.dense(hidden_states)
516
+ hidden_states = self.dropout(hidden_states)
517
+
518
+ return hidden_states
519
+
520
+
521
+ class FlavaAttention(nn.Module):
522
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
523
+ super().__init__()
524
+ self.attention = FlavaSelfAttention(config)
525
+ self.output = FlavaSelfOutput(config)
526
+ self.pruned_heads = set()
527
+
528
+ def prune_heads(self, heads: Set[int]) -> None:
529
+ if len(heads) == 0:
530
+ return
531
+ heads, index = find_pruneable_heads_and_indices(
532
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
533
+ )
534
+
535
+ # Prune linear layers
536
+ self.attention.query = prune_linear_layer(self.attention.query, index)
537
+ self.attention.key = prune_linear_layer(self.attention.key, index)
538
+ self.attention.value = prune_linear_layer(self.attention.value, index)
539
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
540
+
541
+ # Update hyper params and store pruned heads
542
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
543
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
544
+ self.pruned_heads = self.pruned_heads.union(heads)
545
+
546
+ def forward(
547
+ self,
548
+ hidden_states: torch.Tensor,
549
+ attention_mask: Optional[torch.Tensor] = None,
550
+ head_mask: Optional[torch.Tensor] = None,
551
+ output_attentions: bool = False,
552
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
553
+ self_outputs = self.attention(
554
+ hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
555
+ )
556
+
557
+ attention_output = self.output(self_outputs[0], hidden_states)
558
+
559
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
560
+ return outputs
561
+
562
+
563
+ class FlavaIntermediate(nn.Module):
564
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
565
+ super().__init__()
566
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
567
+ if isinstance(config.hidden_act, str):
568
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
569
+ else:
570
+ self.intermediate_act_fn = config.hidden_act
571
+
572
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate.forward
573
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
574
+ hidden_states = self.dense(hidden_states)
575
+ hidden_states = self.intermediate_act_fn(hidden_states)
576
+
577
+ return hidden_states
578
+
579
+
580
+ class FlavaOutput(nn.Module):
581
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
582
+ super().__init__()
583
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
584
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
585
+
586
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput.forward
587
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
588
+ hidden_states = self.dense(hidden_states)
589
+ hidden_states = self.dropout(hidden_states)
590
+
591
+ hidden_states = hidden_states + input_tensor
592
+
593
+ return hidden_states
594
+
595
+
596
+ class FlavaLayer(nn.Module):
597
+ """This corresponds to the Block class in the timm implementation."""
598
+
599
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
600
+ super().__init__()
601
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
602
+ self.seq_len_dim = 1
603
+ self.attention = FlavaAttention(config)
604
+ self.intermediate = FlavaIntermediate(config)
605
+ self.output = FlavaOutput(config)
606
+
607
+ # TODO: Check fp32 layer norm possibility
608
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
609
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
610
+
611
+ def forward(
612
+ self,
613
+ hidden_states: torch.Tensor,
614
+ attention_mask: Optional[torch.Tensor] = None,
615
+ head_mask: Optional[torch.Tensor] = None,
616
+ output_attentions: bool = False,
617
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
618
+ self_attention_outputs = self.attention(
619
+ self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
620
+ attention_mask=attention_mask,
621
+ head_mask=head_mask,
622
+ output_attentions=output_attentions,
623
+ )
624
+ attention_output = self_attention_outputs[0]
625
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
626
+
627
+ # first residual connection
628
+ hidden_states = attention_output + hidden_states
629
+
630
+ # in ViT, layernorm is also applied after self-attention
631
+ layer_output = self.layernorm_after(hidden_states)
632
+ layer_output = self.intermediate(layer_output)
633
+
634
+ # second residual connection is done here
635
+ layer_output = self.output(layer_output, hidden_states)
636
+
637
+ outputs = (layer_output,) + outputs
638
+
639
+ return outputs
640
+
641
+
642
+ class FlavaEncoder(nn.Module):
643
+ def __init__(self, config: FlavaConfig) -> None:
644
+ super().__init__()
645
+ self.config = config
646
+ self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
647
+ self.gradient_checkpointing = False
648
+
649
+ def forward(
650
+ self,
651
+ hidden_states: torch.Tensor,
652
+ attention_mask: Optional[torch.Tensor] = None,
653
+ head_mask: Optional[torch.Tensor] = None,
654
+ output_attentions: bool = False,
655
+ output_hidden_states: bool = False,
656
+ return_dict: bool = True,
657
+ ) -> Union[tuple, BaseModelOutput]:
658
+ all_hidden_states = () if output_hidden_states else None
659
+ all_self_attentions = () if output_attentions else None
660
+
661
+ for i, layer_module in enumerate(self.layer):
662
+ if output_hidden_states:
663
+ all_hidden_states = all_hidden_states + (hidden_states,)
664
+
665
+ layer_head_mask = head_mask[i] if head_mask is not None else None
666
+
667
+ if self.gradient_checkpointing and self.training:
668
+ layer_outputs = self._gradient_checkpointing_func(
669
+ layer_module.__call__,
670
+ hidden_states,
671
+ attention_mask,
672
+ layer_head_mask,
673
+ output_attentions,
674
+ )
675
+ else:
676
+ layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
677
+
678
+ hidden_states = layer_outputs[0]
679
+
680
+ if output_attentions:
681
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
682
+
683
+ if output_hidden_states:
684
+ all_hidden_states = all_hidden_states + (hidden_states,)
685
+
686
+ if not return_dict:
687
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
688
+ return BaseModelOutput(
689
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
690
+ )
691
+
692
+
693
+ class FlavaPooler(nn.Module):
694
+ def __init__(self, config: FlavaPossibleConfigs):
695
+ super().__init__()
696
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
697
+ self.activation = nn.Tanh()
698
+
699
+ def forward(self, hidden_states: torch.Tensor):
700
+ # We "pool" the model by simply taking the hidden state corresponding
701
+ # to the first token.
702
+ first_token_tensor = hidden_states[:, 0]
703
+ pooled_output = self.dense(first_token_tensor)
704
+ pooled_output = self.activation(pooled_output)
705
+ return pooled_output
706
+
707
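`FlavaPooler` above reduces the whole sequence to the first ([CLS]) token; a tiny sketch with illustrative sizes:

```python
import torch
from torch import nn

hidden_states = torch.randn(2, 197, 768)        # (batch, 1 + num_patches, hidden_size)
dense, activation = nn.Linear(768, 768), nn.Tanh()
pooled_output = activation(dense(hidden_states[:, 0]))
print(pooled_output.shape)                      # torch.Size([2, 768])
```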
+
708
+ FLAVA_START_DOCSTRING = r"""
709
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
710
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
711
+ behavior.
712
+
713
+ Parameters:
714
+ config ([`{config}`]): Model configuration class with all the parameters of the model.
715
+ Initializing with a config file does not load the weights associated with the model, only the
716
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
717
+ """
718
+
719
+ FLAVA_INPUTS_DOCSTRING_COMMON = r"""
720
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
721
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
722
+ - 1 for tokens that are **not masked**,
723
+ - 0 for tokens that are **masked**.
724
+ [What are attention masks?](../glossary#attention-mask)
725
+
726
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
727
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
728
+
729
+ - 1 indicates the head is **not masked**,
730
+ - 0 indicates the head is **masked**.
731
+
732
+ output_attentions (`bool`, *optional*):
733
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
734
+ tensors for more detail.
735
+ output_hidden_states (`bool`, *optional*):
736
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
737
+ more detail.
738
+
739
+ return_dict (`bool`, *optional*):
740
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
741
+ """
742
+
743
+ FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
744
+ Args:
745
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
746
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
747
+ [`FlavaImageProcessor.__call__`] for details.
748
+
749
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
750
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
751
+
752
+ interpolate_pos_encoding (`bool`, *optional*):
753
+ Whether to interpolate the pre-trained position encodings.
754
+ """
755
+
756
+ FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
757
+
758
+ FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
759
+ Args:
760
+ input_ids (`torch.LongTensor` of shape `({0})`):
761
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
762
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
763
+ IDs?](../glossary#input-ids)
764
+
765
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
766
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
767
+ 1]`:
768
+ - 0 corresponds to a *sentence A* token,
769
+ - 1 corresponds to a *sentence B* token.
770
+ [What are token type IDs?](../glossary#token-type-ids)
771
+ """
772
+
773
+ FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
774
+
775
+ FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
776
+ r"""
777
+ Args:
778
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
779
+ The concatenated hidden states of unimodal encoders.
780
+ """
781
+ + FLAVA_INPUTS_DOCSTRING_COMMON
782
+ )
783
+
784
+ FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r"""
785
+ Args:
786
+ skip_multimodal_encoder (*bool*, *optional*):
787
+ Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
788
+ """
789
+
790
+ FLAVA_MODEL_INPUTS_DOCSTRING = (
791
+ FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
792
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
793
+ + FLAVA_INPUTS_DOCSTRING_COMMON
794
+ + FLAVA_MODEL_INPUTS_DOCSTRING_BASE
795
+ )
796
+
797
+
798
+ FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
799
+ r"""
800
+ Args:
801
+ input_ids_masked (`torch.LongTensor` of shape `({0})`):
802
+ Indices of input sequence tokens in the vocabulary. These ones are the masked version of the original task
803
+ to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
804
+ [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
805
+ [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
806
+
807
+ """
808
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
809
+ + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
810
+ + r"""
811
+ image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*):
812
+ Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
813
+ in `[0, 1]`:
814
+ - 1 for tokens that are **not masked**,
815
+ - 0 for tokens that are **masked**.
816
+ [What are attention masks?](../glossary#attention-mask)
817
+
818
+ skip_unmasked_multimodal_encoder (*bool*, *optional*):
819
+ Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
820
+ multimodal embeddings or outputs as of now.
821
+
822
+ mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
823
+ Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction).
824
+ Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
825
+ indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
826
+ ..., text_config.vocab_size - 1]`.
827
+
828
+ mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
829
+ Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
830
+ image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
831
+ computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
832
+ generated automatically using the image codebook assigned to the model. By default, it uses
833
+ [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels.
834
+
835
+ itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
836
+ Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
837
+ The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
838
+
839
+ return_loss (`bool`, *optional*, defaults to `None`):
840
+ Whether to return calculated loss or not.
841
+ """
842
+ + FLAVA_INPUTS_DOCSTRING_COMMON
843
+ )
844
+
845
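The `-100` convention described for `mlm_labels` and `mim_labels` above relies on the standard `ignore_index` of `torch.nn.CrossEntropyLoss`; a minimal illustration (the vocabulary size and labels are made up):

```python
import torch
from torch import nn

logits = torch.randn(1, 4, 30522)              # (batch_size, text_seq_len, vocab_size)
labels = torch.tensor([[-100, 17, -100, 9]])   # only positions 1 and 3 contribute to the loss
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
print(loss.item())
```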
+ FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r"""
846
+ Parameters:
847
+ image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise, it will
848
+ be initialized using the image_codebook_config defined in the config.
849
+ """
850
+
851
+
852
+ class FlavaPreTrainedModel(PreTrainedModel):
853
+ """
854
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
855
+ models.
856
+ """
857
+
858
+ config_class = FlavaConfig
859
+ base_model_prefix = "flava"
860
+ supports_gradient_checkpointing = True
861
+
862
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
863
+ """Initialize the weights"""
864
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
865
+ # Slightly different from the TF version which uses truncated_normal for initialization
866
+ # cf https://github.com/pytorch/pytorch/pull/5617
867
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
868
+ if module.bias is not None:
869
+ module.bias.data.zero_()
870
+ elif isinstance(module, nn.Embedding):
871
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
872
+ if module.padding_idx is not None:
873
+ module.weight.data[module.padding_idx].zero_()
874
+ elif isinstance(module, nn.LayerNorm):
875
+ module.bias.data.zero_()
876
+ module.weight.data.fill_(1.0)
877
+ elif isinstance(module, FlavaMaskedPredictionHead):
878
+ module.bias.data.zero_()
879
+ elif isinstance(module, FlavaImageEmbeddings):
880
+ module.cls_token.data.zero_()
881
+ module.position_embeddings.data.zero_()
882
+ if module.mask_token is not None:
883
+ module.mask_token.data.zero_()
884
+ elif isinstance(module, FlavaMultimodalModel):
885
+ if module.use_cls_token:
886
+ module.cls_token.data.zero_()
887
+ elif isinstance(module, FlavaModel):
888
+ module.logit_scale.data.fill_(self.config.logit_scale_init_value)
889
+
890
+
891
+ @add_start_docstrings(
892
+ "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
893
+ FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"),
894
+ )
895
+ class FlavaImageModel(FlavaPreTrainedModel):
896
+ config_class = FlavaImageConfig
897
+ # This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
898
+ base_model_prefix = "flava.image_model"
899
+ main_input_name = "pixel_values"
900
+
901
+ def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
902
+ super().__init__(config)
903
+
904
+ self.config = config
905
+
906
+ self.embeddings = FlavaImageEmbeddings(config)
907
+ self.encoder = FlavaEncoder(config)
908
+
909
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
910
+ self.pooler = FlavaPooler(config) if add_pooling_layer else None
911
+
912
+ self.post_init()
913
+
914
+ def get_input_embeddings(self) -> nn.Module:
915
+ return self.embeddings.patch_embeddings
916
+
917
+ def set_input_embeddings(self, value: nn.Module):
918
+ self.embeddings.patch_embeddings = value
919
+
920
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
921
+ """
922
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
923
+ class PreTrainedModel
924
+ """
925
+ for layer, heads in heads_to_prune.items():
926
+ self.encoder.layer[layer].attention.prune_heads(heads)
927
+
928
+ @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
929
+ @add_code_sample_docstrings(
930
+ checkpoint=_CHECKPOINT_FOR_DOC,
931
+ output_type=BaseModelOutputWithPooling,
932
+ config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
933
+ modality="vision",
934
+ expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
935
+ )
936
+ def forward(
937
+ self,
938
+ pixel_values: Optional[torch.Tensor] = None,
939
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
940
+ interpolate_pos_encoding: Optional[bool] = None,
941
+ attention_mask: Optional[torch.Tensor] = None,
942
+ head_mask: Optional[torch.Tensor] = None,
943
+ output_attentions: Optional[bool] = None,
944
+ output_hidden_states: Optional[bool] = None,
945
+ return_dict: Optional[bool] = None,
946
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
947
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
948
+ output_hidden_states = (
949
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
950
+ )
951
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
952
+
953
+ if pixel_values is None:
954
+ raise ValueError("You have to specify pixel_values")
955
+
956
+ # Prepare head mask if needed
957
+ # 1.0 in head_mask indicate we keep the head
958
+ # attention_probs has shape bsz x n_heads x N x N
959
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
960
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
961
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
962
+
963
+ embedding_output = self.embeddings(
964
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
965
+ )
966
+
967
+ encoder_outputs = self.encoder(
968
+ embedding_output,
969
+ attention_mask=attention_mask,
970
+ head_mask=head_mask,
971
+ output_attentions=output_attentions,
972
+ output_hidden_states=output_hidden_states,
973
+ return_dict=return_dict,
974
+ )
975
+ sequence_output = encoder_outputs[0]
976
+ sequence_output = self.layernorm(sequence_output)
977
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
978
+
979
+ if not return_dict:
980
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
981
+
982
+ return BaseModelOutputWithPooling(
983
+ last_hidden_state=sequence_output,
984
+ pooler_output=pooled_output,
985
+ hidden_states=encoder_outputs.hidden_states,
986
+ attentions=encoder_outputs.attentions,
987
+ )
988
+
989
+
990
+ @add_start_docstrings(
991
+ "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
992
+ FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"),
993
+ )
994
+ class FlavaTextModel(FlavaPreTrainedModel):
995
+ config_class = FlavaTextConfig
996
+ # This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
997
+ base_model_prefix = "flava.text_model"
998
+
999
+ def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
1000
+ super().__init__(config)
1001
+ self.config = config
1002
+
1003
+ self.embeddings = FlavaTextEmbeddings(config)
1004
+ self.encoder = FlavaEncoder(config)
1005
+
1006
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1007
+ self.pooler = FlavaPooler(config) if add_pooling_layer else None
1008
+
1009
+ self.post_init()
1010
+
1011
+ def get_input_embeddings(self) -> nn.Module:
1012
+ return self.embeddings.word_embeddings
1013
+
1014
+ def set_input_embeddings(self, value: nn.Module):
1015
+ self.embeddings.word_embeddings = value
1016
+
1017
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
1018
+ """
1019
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1020
+ class PreTrainedModel
1021
+ """
1022
+ for layer, heads in heads_to_prune.items():
1023
+ self.encoder.layer[layer].attention.prune_heads(heads)
1024
+
1025
+ @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
1026
+ @add_code_sample_docstrings(
1027
+ checkpoint=_CHECKPOINT_FOR_DOC,
1028
+ output_type=BaseModelOutputWithPooling,
1029
+ config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
1030
+ )
1031
+ def forward(
1032
+ self,
1033
+ input_ids: Optional[torch.Tensor] = None,
1034
+ attention_mask: Optional[torch.Tensor] = None,
1035
+ token_type_ids: Optional[torch.Tensor] = None,
1036
+ position_ids: Optional[torch.Tensor] = None,
1037
+ head_mask: Optional[torch.Tensor] = None,
1038
+ output_attentions: Optional[bool] = None,
1039
+ output_hidden_states: Optional[bool] = None,
1040
+ return_dict: Optional[bool] = None,
1041
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
1042
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1043
+ output_hidden_states = (
1044
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1045
+ )
1046
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1047
+
1048
+ if input_ids is None:
1049
+ raise ValueError("You have to specify input_ids")
1050
+
1051
+ input_shape = input_ids.size()
1052
+
1053
+ if attention_mask is None:
1054
+ attention_mask = torch.ones(input_shape, device=input_ids.device)
1055
+
1056
+ # Prepare head mask if needed
1057
+ # 1.0 in head_mask indicate we keep the head
1058
+ # attention_probs has shape bsz x n_heads x N x N
1059
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1060
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1061
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1062
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
1063
+ attention_mask, input_shape, input_ids.device
1064
+ )
1065
+
1066
+ embedding_output = self.embeddings(
1067
+ input_ids=input_ids,
1068
+ token_type_ids=token_type_ids,
1069
+ position_ids=position_ids,
1070
+ )
1071
+
1072
+ encoder_outputs = self.encoder(
1073
+ embedding_output,
1074
+ attention_mask=extended_attention_mask,
1075
+ head_mask=head_mask,
1076
+ output_attentions=output_attentions,
1077
+ output_hidden_states=output_hidden_states,
1078
+ return_dict=return_dict,
1079
+ )
1080
+ sequence_output = encoder_outputs[0]
1081
+ sequence_output = self.layernorm(sequence_output)
1082
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1083
+
1084
+ if not return_dict:
1085
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1086
+
1087
+ return BaseModelOutputWithPooling(
1088
+ last_hidden_state=sequence_output,
1089
+ pooler_output=pooled_output,
1090
+ hidden_states=encoder_outputs.hidden_states,
1091
+ attentions=encoder_outputs.attentions,
1092
+ )
1093
+
1094
+
1095
+ @add_start_docstrings(
1096
+ "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
1097
+ FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"),
1098
+ )
1099
+ class FlavaMultimodalModel(FlavaPreTrainedModel):
1100
+ config_class = FlavaMultimodalConfig
1101
+ # This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints.
1102
+ base_model_prefix = "flava.multimodal_model"
1103
+ main_input_name = "hidden_states"
1104
+
1105
+ def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
1106
+ super().__init__(config)
1107
+ self.config = config
1108
+ self.use_cls_token = self.config.use_cls_token
1109
+ if self.use_cls_token:
1110
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
1111
+
1112
+ self.encoder = FlavaEncoder(config)
1113
+
1114
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1115
+ self.pooler = FlavaPooler(config) if add_pooling_layer else None
1116
+
1117
+ self.post_init()
1118
+
1119
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
1120
+ """
1121
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1122
+ class PreTrainedModel
1123
+ """
1124
+ for layer, heads in heads_to_prune.items():
1125
+ self.encoder.layer[layer].attention.prune_heads(heads)
1126
+
1127
+ @add_start_docstrings_to_model_forward(
1128
+ FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
1129
+ )
1130
+ @add_code_sample_docstrings(
1131
+ checkpoint=_CHECKPOINT_FOR_DOC,
1132
+ output_type=BaseModelOutputWithPooling,
1133
+ config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
1134
+ )
1135
+ def forward(
1136
+ self,
1137
+ hidden_states: torch.Tensor,
1138
+ attention_mask: Optional[torch.Tensor] = None,
1139
+ head_mask: Optional[torch.Tensor] = None,
1140
+ output_attentions: Optional[bool] = None,
1141
+ output_hidden_states: Optional[bool] = None,
1142
+ return_dict: Optional[bool] = None,
1143
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
1144
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1145
+ output_hidden_states = (
1146
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1147
+ )
1148
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1149
+
1150
+ batch_size, seq_length, _ = hidden_states.size()
1151
+
1152
+ if self.use_cls_token:
1153
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1154
+ hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
1155
+ seq_length += 1
1156
+
1157
+ if attention_mask is None:
1158
+ attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
1159
+
1160
+ # Prepare head mask if needed
1161
+ # 1.0 in head_mask indicate we keep the head
1162
+ # attention_probs has shape bsz x n_heads x N x N
1163
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1164
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1165
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1166
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
1167
+ attention_mask, (batch_size, seq_length), hidden_states.device
1168
+ )
1169
+
1170
+ encoder_outputs = self.encoder(
1171
+ hidden_states,
1172
+ attention_mask=extended_attention_mask,
1173
+ head_mask=head_mask,
1174
+ output_attentions=output_attentions,
1175
+ output_hidden_states=output_hidden_states,
1176
+ return_dict=return_dict,
1177
+ )
1178
+ sequence_output = encoder_outputs[0]
1179
+ sequence_output = self.layernorm(sequence_output)
1180
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1181
+
1182
+ if not return_dict:
1183
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1184
+
1185
+ return BaseModelOutputWithPooling(
1186
+ last_hidden_state=sequence_output,
1187
+ pooler_output=pooled_output,
1188
+ hidden_states=encoder_outputs.hidden_states,
1189
+ attentions=encoder_outputs.attentions,
1190
+ )
1191
+
1192
+
1193
+ @add_start_docstrings(
1194
+ "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
1195
+ FLAVA_START_DOCSTRING.format(config="FlavaConfig"),
1196
+ )
1197
+ class FlavaModel(FlavaPreTrainedModel):
1198
+ config_class = FlavaConfig
1199
+
1200
+ def __init__(self, config: FlavaConfig):
1201
+ super().__init__(config)
1202
+
1203
+ if not isinstance(config.text_config, FlavaTextConfig):
1204
+ raise TypeError(
1205
+ "config.text_config is expected to be of type FlavaTextConfig but is of type"
1206
+ f" {type(config.text_config)}."
1207
+ )
1208
+
1209
+ if not isinstance(config.image_config, FlavaImageConfig):
1210
+ raise TypeError(
1211
+ "config.image_config is expected to be of type FlavaImageConfig but is of type"
1212
+ f" {type(config.image_config)}."
1213
+ )
1214
+
1215
+ if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
1216
+ raise TypeError(
1217
+ "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
1218
+ + f"is of type {type(config.multimodal_config)}."
1219
+ )
1220
+
1221
+ text_config = config.text_config
1222
+ image_config = config.image_config
1223
+ multimodal_config = config.multimodal_config
1224
+
1225
+ self.projection_dim = config.projection_dim
1226
+ self.text_hidden_size = text_config.hidden_size
1227
+ self.image_hidden_size = image_config.hidden_size
1228
+ self.mm_hidden_size = multimodal_config.hidden_size
1229
+
1230
+ self.text_model = FlavaTextModel(text_config)
1231
+ self.image_model = FlavaImageModel(image_config)
1232
+ self.multimodal_model = FlavaMultimodalModel(multimodal_config)
1233
+
1234
+ self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
1235
+ self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
1236
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1237
+
1238
+ self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
1239
+ self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
1240
+ # Initialize weights and apply final processing
1241
+ self.post_init()
1242
+
1243
+ @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
1244
+ def get_text_features(
1245
+ self,
1246
+ input_ids: Optional[torch.Tensor] = None,
1247
+ attention_mask: Optional[torch.Tensor] = None,
1248
+ token_type_ids: Optional[torch.Tensor] = None,
1249
+ position_ids: Optional[torch.Tensor] = None,
1250
+ output_attentions: Optional[bool] = None,
1251
+ output_hidden_states: Optional[bool] = None,
1252
+ return_dict: Optional[bool] = None,
1253
+ ) -> torch.FloatTensor:
1254
+ r"""
1255
+ Returns:
1256
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1257
+ applying the projection layer to the pooled output of [`FlavaTextModel`].
1258
+
1259
+ Examples:
1260
+
1261
+ ```python
1262
+ >>> from transformers import AutoProcessor, FlavaModel
1263
+
1264
+ >>> model = FlavaModel.from_pretrained("{0}")
1265
+ >>> processor = AutoProcessor.from_pretrained("{0}")
1266
+
1267
+ >>> inputs = processor(
1268
+ ... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
1269
+ ... )
1270
+ >>> text_features = model.get_text_features(**inputs)
1271
+ ```""".format(_CHECKPOINT_FOR_DOC)
1272
+ text_outputs = self.text_model(
1273
+ input_ids=input_ids,
1274
+ attention_mask=attention_mask,
1275
+ token_type_ids=token_type_ids,
1276
+ position_ids=position_ids,
1277
+ output_attentions=output_attentions,
1278
+ output_hidden_states=output_hidden_states,
1279
+ return_dict=return_dict,
1280
+ )
1281
+
1282
+ pooled_output = text_outputs[0] # last_hidden_state
1283
+ text_features = self.text_projection(pooled_output)
1284
+
1285
+ return text_features
1286
+
1287
+ @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
1288
+ def get_image_features(
1289
+ self,
1290
+ pixel_values: Optional[torch.Tensor] = None,
1291
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
1292
+ interpolate_pos_encoding: Optional[bool] = None,
1293
+ attention_mask: Optional[torch.Tensor] = None,
1294
+ head_mask: Optional[torch.Tensor] = None,
1295
+ output_attentions: Optional[bool] = None,
1296
+ output_hidden_states: Optional[bool] = None,
1297
+ return_dict: Optional[bool] = None,
1298
+ ) -> torch.FloatTensor:
1299
+ r"""
1300
+ Returns:
1301
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1302
+ applying the projection layer to the pooled output of [`FlavaImageModel`].
1303
+
1304
+ Examples:
1305
+
1306
+ ```python
1307
+ >>> from PIL import Image
1308
+ >>> import requests
1309
+ >>> from transformers import AutoProcessor, FlavaModel
1310
+
1311
+ >>> model = FlavaModel.from_pretrained("{0}")
1312
+ >>> processor = AutoProcessor.from_pretrained("{0}")
1313
+
1314
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1315
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1316
+
1317
+ >>> inputs = processor(images=image, return_tensors="pt")
1318
+
1319
+ >>> image_features = model.get_image_features(**inputs)
1320
+ ```""".format(_CHECKPOINT_FOR_DOC)
1321
+ image_outputs = self.image_model(
1322
+ pixel_values=pixel_values,
1323
+ bool_masked_pos=bool_masked_pos,
1324
+ attention_mask=attention_mask,
1325
+ head_mask=head_mask,
1326
+ output_attentions=output_attentions,
1327
+ output_hidden_states=output_hidden_states,
1328
+ interpolate_pos_encoding=interpolate_pos_encoding,
1329
+ return_dict=return_dict,
1330
+ )
1331
+
1332
+ pooled_output = image_outputs[0] # last_hidden_state
1333
+ image_features = self.image_projection(pooled_output)
1334
+
1335
+ return image_features
1336
+
1337
+ @add_start_docstrings_to_model_forward(
1338
+ FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
1339
+ )
1340
+ @replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig)
1341
+ def forward(
1342
+ self,
1343
+ input_ids: Optional[torch.LongTensor] = None,
1344
+ pixel_values: Optional[torch.FloatTensor] = None,
1345
+ attention_mask: Optional[torch.Tensor] = None,
1346
+ token_type_ids: Optional[torch.Tensor] = None,
1347
+ bool_masked_pos: Optional[torch.Tensor] = None,
1348
+ position_ids: Optional[torch.LongTensor] = None,
1349
+ image_attention_mask: Optional[torch.Tensor] = None,
1350
+ skip_multimodal_encoder: Optional[bool] = None,
1351
+ output_attentions: Optional[bool] = None,
1352
+ output_hidden_states: bool = True,
1353
+ return_dict: Optional[bool] = None,
1354
+ ) -> Union[Tuple, FlavaOutput]:
1355
+ r"""
1356
+ Returns:
1357
+
1358
+ Examples:
1359
+
1360
+ ```python
1361
+ >>> from PIL import Image
1362
+ >>> import requests
1363
+ >>> from transformers import AutoProcessor, FlavaModel
1364
+
1365
+ >>> model = FlavaModel.from_pretrained("facebook/flava-full")
1366
+ >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
1367
+
1368
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1369
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1370
+
1371
+ >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
1372
+
1373
+ >>> outputs = model(**inputs)
1374
+
1375
+ >>> image_embeddings = outputs.image_embeddings
1376
+ >>> text_embeddings = outputs.text_embeddings
1377
+ >>> multimodal_embeddings = outputs.multimodal_embeddings
1378
+
1379
+ >>> outputs.image_embeddings.shape
1380
+ torch.Size([1, 197, 768])
1381
+
1382
+ >>> text_embeddings.shape
1383
+ torch.Size([1, 7, 768])
1384
+
1385
+ >>> multimodal_embeddings.shape
1386
+ torch.Size([1, 205, 768])
1387
+ ```
1388
+ """
1389
+
1390
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1391
+ if not output_hidden_states:
1392
+ raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`")
1393
+ image_embeddings = None
1394
+ image_states = None
1395
+ image_mm_projection = None
1396
+ image_output = None
1397
+ if pixel_values is not None:
1398
+ image_output = self.image_model(
1399
+ pixel_values=pixel_values,
1400
+ bool_masked_pos=bool_masked_pos,
1401
+ attention_mask=image_attention_mask,
1402
+ output_attentions=output_attentions,
1403
+ output_hidden_states=output_hidden_states,
1404
+ return_dict=return_dict,
1405
+ )
1406
+ image_embeddings, image_states = image_output[0], image_output[2]
1407
+ # Note that these states don't use final layernorm in the transformer model
1408
+ image_mm_projection = self.image_to_mm_projection(image_states[-1])
1409
+
1410
+ text_embeddings = None
1411
+ text_states = None
1412
+ text_mm_projection = None
1413
+ text_output = None
1414
+ if input_ids is not None:
1415
+ text_output = self.text_model(
1416
+ input_ids=input_ids,
1417
+ attention_mask=attention_mask,
1418
+ position_ids=position_ids,
1419
+ token_type_ids=token_type_ids,
1420
+ output_attentions=output_attentions,
1421
+ output_hidden_states=output_hidden_states,
1422
+ return_dict=return_dict,
1423
+ )
1424
+
1425
+ text_embeddings, text_states = text_output[0], text_output[2]
1426
+ # Note that these states don't use final layernorm in the transformer model
1427
+ text_mm_projection = self.text_to_mm_projection(text_states[-1])
1428
+
1429
+ multimodal_embeddings = None
1430
+ multimodal_output = None
1431
+ if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
1432
+ if attention_mask is not None:
1433
+ batch_size, seq_len, _ = image_mm_projection.shape
1434
+ if self.multimodal_model.use_cls_token:
1435
+ seq_len += 1
1436
+ attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
1437
+ attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)
1438
+ else:
1439
+ attention_multimodal = None
1440
+ multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
1441
+ multimodal_output = self.multimodal_model(
1442
+ multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict
1443
+ )
1444
+ multimodal_embeddings = multimodal_output[0]
1445
+
1446
+ if not return_dict:
1447
+ return (
1448
+ image_embeddings,
1449
+ image_output,
1450
+ text_embeddings,
1451
+ text_output,
1452
+ multimodal_embeddings,
1453
+ multimodal_output,
1454
+ )
1455
+
1456
+ return FlavaModelOutput(
1457
+ image_embeddings=image_embeddings,
1458
+ image_output=image_output,
1459
+ text_embeddings=text_embeddings,
1460
+ text_output=text_output,
1461
+ multimodal_embeddings=multimodal_embeddings,
1462
+ multimodal_output=multimodal_output,
1463
+ )
1464
+
1465
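A quick sanity check of the sequence lengths in the `FlavaModel` docstring example above: the multimodal input is the concatenation of the projected image states and projected text states, plus one multimodal [CLS] token when `use_cls_token` is enabled.

```python
image_tokens, text_tokens, multimodal_cls_token = 197, 7, 1   # values from the docstring example
assert image_tokens + text_tokens + multimodal_cls_token == 205
```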
+
1466
+ class FlavaImageCodebookResPath(nn.Module):
1467
+ def __init__(self, in_size: int, out_size: int, **kwargs):
1468
+ super().__init__()
1469
+ hid_size = out_size // 4
1470
+
1471
+ path = OrderedDict()
1472
+ path["relu_1"] = nn.ReLU()
1473
+ path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
1474
+ path["relu_2"] = nn.ReLU()
1475
+ path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
1476
+ path["relu_3"] = nn.ReLU()
1477
+ path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
1478
+ path["relu_4"] = nn.ReLU()
1479
+ path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)
1480
+
1481
+ self.path = nn.Sequential(path)
1482
+
1483
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1484
+ return self.path(x)
1485
+
1486
+
1487
+ class FlavaImageCodebookBlock(nn.Module):
1488
+ def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
1489
+ super().__init__()
1490
+
1491
+ self.post_gain = 1 / (num_layers**2)
1492
+
1493
+ if in_size != out_size:
1494
+ self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
1495
+ else:
1496
+ self.id_path = nn.Identity()
1497
+
1498
+ self.res_path = FlavaImageCodebookResPath(in_size, out_size)
1499
+
1500
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1501
+ return self.id_path(x) + self.post_gain * self.res_path(x)
1502
+
1503
+
1504
+ class FlavaImageCodebookLayerGroup(nn.Module):
1505
+ def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
1506
+ super().__init__()
1507
+ blocks = OrderedDict()
1508
+ for i in range(num_blocks):
1509
+ if i == 0:
1510
+ blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
1511
+ else:
1512
+ blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)
1513
+
1514
+ if use_pool:
1515
+ blocks["pool"] = nn.MaxPool2d(kernel_size=2)
1516
+
1517
+ self.group = nn.Sequential(blocks)
1518
+
1519
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1520
+ return self.group(x)
1521
+
1522
+
1523
+ # Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42
1524
+ @add_start_docstrings(
1525
+ """
1526
+ The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used
1527
+ to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
1528
+ `get_codebook_indices` to get image tokens for an image.
1529
+ """,
1530
+ FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"),
1531
+ )
1532
+ class FlavaImageCodebook(FlavaPreTrainedModel):
1533
+ base_model_prefix = ""
1534
+ config_class = FlavaImageCodebookConfig
1535
+ main_input_name = "pixel_values"
1536
+ supports_gradient_checkpointing = False
1537
+
1538
+ def __init__(
1539
+ self,
1540
+ config: FlavaImageCodebookConfig,
1541
+ **kwargs: Any,
1542
+ ):
1543
+ super().__init__(config)
1544
+
1545
+ self.config = config
1546
+ self.num_groups = config.num_groups
1547
+ self.input_channels = config.input_channels
1548
+ self.num_blocks_per_group = config.num_blocks_per_group
1549
+ self.hidden_size = config.hidden_size
1550
+ self.vocab_size = config.vocab_size
1551
+
1552
+ num_layers = self.num_groups * self.num_blocks_per_group
1553
+
1554
+ output_blocks = OrderedDict()
1555
+ output_blocks["relu"] = nn.ReLU()
1556
+ output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)
1557
+
1558
+ blocks = OrderedDict()
1559
+ blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
1560
+ blocks["group_1"] = FlavaImageCodebookLayerGroup(
1561
+ self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
1562
+ )
1563
+ blocks["group_2"] = FlavaImageCodebookLayerGroup(
1564
+ self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
1565
+ )
1566
+ blocks["group_3"] = FlavaImageCodebookLayerGroup(
1567
+ self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
1568
+ )
1569
+ blocks["group_4"] = FlavaImageCodebookLayerGroup(
1570
+ self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
1571
+ )
1572
+ blocks["output"] = nn.Sequential(output_blocks)
1573
+
1574
+ self.blocks = nn.Sequential(blocks)
1575
+
1576
+ self.post_init()
1577
+
1578
+ if self.config.freeze:
1579
+ for param in self.parameters():
1580
+ param.requires_grad = False
1581
+
1582
+ def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
1583
+ """
1584
+ Args:
1585
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1586
+ Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
1587
+ `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
1588
+
1589
+ Examples:
1590
+ ```python
1591
+ >>> from PIL import Image
1592
+ >>> import requests
1593
+ >>> from transformers import AutoImageProcessor, FlavaImageCodebook
1594
+
1595
+ >>> model = FlavaImageCodebook.from_pretrained("{0}")
1596
+ >>> image_processor = AutoImageProcessor.from_pretrained("{0}")
1597
+
1598
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1599
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1600
+
1601
+ >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
1602
+ >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
1603
+
1604
+ >>> outputs = model.get_codebook_indices(**inputs)
1605
+ ```
1606
+ """.format(_CHECKPOINT_FOR_CODEBOOK_DOC)
1607
+ z_logits = self.blocks(pixel_values)
1608
+ return torch.argmax(z_logits, dim=1)
1609
+
1610
+ def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
1611
+ z_logits = self.blocks(pixel_values)
1612
+ return nn.Softmax(dim=1)(z_logits)
1613
+
1614
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
1615
+ """
1616
+ Args:
1617
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1618
+ Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
1619
+ `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
1620
+
1621
+ Examples:
1622
+
1623
+ ```python
1624
+ >>> from PIL import Image
1625
+ >>> import requests
1626
+ >>> from transformers import AutoImageProcessor, FlavaImageCodebook
1627
+
1628
+ >>> model = FlavaImageCodebook.from_pretrained("{0}")
1629
+ >>> image_processor = AutoImageProcessor.from_pretrained("{0}")
1630
+
1631
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1632
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1633
+
1634
+ >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
1635
+ >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
1636
+
1637
+ >>> outputs = model(**inputs)
1638
+ >>> print(outputs.shape)
1639
+ (1, 196)
1640
+ ```
1641
+ """.format(_CHECKPOINT_FOR_CODEBOOK_DOC)
1642
+ if len(pixel_values.shape) != 4:
1643
+ raise ValueError(f"input shape {pixel_values.shape} is not 4d")
1644
+ if pixel_values.shape[1] != self.input_channels:
1645
+ raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}")
1646
+ return self.blocks(pixel_values)
1647
+
1648
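A standalone sketch of what `get_codebook_indices` above computes: an argmax over the codebook vocabulary at every spatial position of the encoder output. The sizes below are illustrative, not read from the real config.

```python
import torch

vocab_size, grid = 8192, 14
z_logits = torch.randn(1, vocab_size, grid, grid)   # stands in for self.blocks(pixel_values)
codebook_indices = torch.argmax(z_logits, dim=1)    # (1, 14, 14) discrete image tokens used as MIM labels
print(codebook_indices.shape)
```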
+
1649
+ class FlavaPredictionHeadTransform(nn.Module):
1650
+ def __init__(self, config):
1651
+ super().__init__()
1652
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1653
+ if isinstance(config.hidden_act, str):
1654
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1655
+ else:
1656
+ self.transform_act_fn = config.hidden_act
1657
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1658
+
1659
+ def forward(self, hidden_states):
1660
+ hidden_states = self.dense(hidden_states)
1661
+ hidden_states = self.transform_act_fn(hidden_states)
1662
+ hidden_states = self.LayerNorm(hidden_states)
1663
+ return hidden_states
1664
+
1665
+
1666
+ class FlavaMaskedPredictionHead(nn.Module):
1667
+ def __init__(self, config, weight=None):
1668
+ super().__init__()
1669
+ self.config = config
1670
+ self.transform = FlavaPredictionHeadTransform(config)
1671
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1672
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1673
+ if weight is not None:
1674
+ self.decoder.weight = weight
1675
+
1676
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1677
+ self.decoder.bias = self.bias
1678
+
1679
+ def _tie_weights(self):
1680
+ self.decoder.bias = self.bias
1681
+
1682
+ def forward(self, x):
1683
+ x = self.transform(x)
1684
+ x = self.decoder(x)
1685
+ return x
1686
+
1687
+
1688
+ class FlavaITMHead(nn.Module):
1689
+ def __init__(self, config):
1690
+ super().__init__()
1691
+ self.config = config
1692
+ self.pooler = FlavaPooler(config)
1693
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
1694
+
1695
+ def forward(self, x):
1696
+ x = self.pooler(x)
1697
+ x = self.seq_relationship(x)
1698
+ return x
1699
+
1700
+
1701
+ class FlavaGlobalContrastiveHead(nn.Module):
1702
+ def __init__(self, config):
1703
+ super().__init__()
1704
+ self.config = config
1705
+ self.global_backprop_contrastive = config.global_backprop_contrastive
1706
+
1707
+ def forward(self, image_embeddings, text_embeddings, logit_scale):
1708
+ temperature = torch.exp(logit_scale)
1709
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
1710
+ labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
1711
+ image_embeddings_all = [image_embeddings]
1712
+ text_embeddings_all = [text_embeddings]
1713
+ else:
1714
+ local_batch_size = image_embeddings.size(0)
1715
+ world_size = torch.distributed.get_world_size()
1716
+
1717
+ if self.global_backprop_contrastive:
1718
+ # `torch.distributed.nn.functional.all_gather` does backprop on all active workers
1719
+ # whereas `torch.distributed.all_gather` only backpropagates on the current worker.
1720
+ image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
1721
+ text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
1722
+ else:
1723
+ image_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
1724
+ text_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
1725
+ torch.distributed.all_gather(image_embeddings_all, image_embeddings)
1726
+ torch.distributed.all_gather(text_embeddings_all, text_embeddings)
1727
+
1728
+ labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
1729
+ local_batch_size, device=image_embeddings.device
1730
+ )
1731
+
1732
+ image_embeddings_all = torch.cat(image_embeddings_all)
1733
+ text_embeddings_all = torch.cat(text_embeddings_all)
1734
+
1735
+ logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
1736
+ logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
1737
+
1738
+ return logits_per_image, logits_per_text, labels
1739
+
1740
+
1741
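Before the pretraining model, a hedged single-process sketch of how the contrastive logits and labels produced by the head above feed a symmetric cross-entropy. The batch size, embedding dimension, and logit-scale value below are illustrative assumptions; the distributed branch only changes how embeddings and labels are gathered.

```python
# Hedged single-process sketch of consuming the global contrastive head's outputs.
# Shapes and the logit-scale value are illustrative; without torch.distributed the
# head uses the local batch, so labels are simply arange(batch_size).
import torch
import torch.nn as nn

batch_size, projection_dim = 4, 768
image_emb = nn.functional.normalize(torch.randn(batch_size, projection_dim), dim=-1)
text_emb = nn.functional.normalize(torch.randn(batch_size, projection_dim), dim=-1)
logit_scale = torch.tensor(2.6593)  # illustrative value; the model clamps its own scale

temperature = torch.exp(logit_scale)
logits_per_image = image_emb @ text_emb.t() * temperature
logits_per_text = text_emb @ image_emb.t() * temperature
labels = torch.arange(batch_size)

# Symmetric cross-entropy, matching how the pretraining model combines the two directions.
loss = (
    nn.functional.cross_entropy(logits_per_image, labels)
    + nn.functional.cross_entropy(logits_per_text, labels)
) / 2
```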
+ @add_start_docstrings(
1742
+ """
1743
+ The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
1744
+ """,
1745
+ FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
1746
+ )
1747
+ class FlavaForPreTraining(FlavaPreTrainedModel):
1748
+ # These keys are tied to each head's decoder.bias
1749
+ _tied_weights_keys = [
1750
+ "mmm_text_head.decoder.bias",
1751
+ "mmm_image_head.decoder.bias",
1752
+ "mlm_head.decoder.bias",
1753
+ "mim_head.decoder.bias",
1754
+ ]
1755
+
1756
+ def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
1757
+ super().__init__(config)
1758
+ self.flava = FlavaModel(config)
1759
+
1760
+ self.image_codebook = image_codebook
1761
+ if self.image_codebook is None and config.init_codebook:
1762
+ self.image_codebook = FlavaImageCodebook(config.image_codebook_config)
1763
+
1764
+ # Leverage text and image encoder configs to create the masked
1765
+ # head since it has the right vocab
1766
+ self.mim_head = FlavaMaskedPredictionHead(config.image_config)
1767
+ self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
1768
+ self.itm_head = FlavaITMHead(config)
1769
+ self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
1770
+ self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
1771
+ self.global_contrastive_head = FlavaGlobalContrastiveHead(config)
1772
+
1773
+ self.image_vocab_size = config.image_config.vocab_size
1774
+ self.text_vocab_size = config.text_config.vocab_size
1775
+ self.mlm_weight = config.mlm_weight
1776
+ self.mim_weight = config.mim_weight
1777
+ self.global_contrastive_weight = config.global_contrastive_weight
1778
+ self.ce_ignore_index = config.ce_ignore_index
1779
+ self.itm_weight = config.itm_weight
1780
+ self.mmm_image_weight = config.mmm_image_weight
1781
+ self.mmm_text_weight = config.mmm_text_weight
1782
+ self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
1783
+
1784
+ self.post_init()
1785
+
1786
+ def _resize_to_2d(self, x: torch.Tensor):
1787
+ if x.dim() > 2:
1788
+ x = x.view(x.size(0), -1)
1789
+ return x
1790
+
1791
+ @add_start_docstrings_to_model_forward(
1792
+ FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
1793
+ )
1794
+ @replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig)
1795
+ def forward(
1796
+ self,
1797
+ input_ids: Optional[torch.LongTensor] = None,
1798
+ input_ids_masked: Optional[torch.LongTensor] = None,
1799
+ pixel_values: Optional[torch.FloatTensor] = None,
1800
+ codebook_pixel_values: Optional[torch.FloatTensor] = None,
1801
+ attention_mask: Optional[torch.Tensor] = None,
1802
+ token_type_ids: Optional[torch.Tensor] = None,
1803
+ bool_masked_pos: Optional[torch.Tensor] = None,
1804
+ position_ids: Optional[torch.LongTensor] = None,
1805
+ image_attention_mask: Optional[torch.Tensor] = None,
1806
+ skip_unmasked_multimodal_encoder: Optional[bool] = None,
1807
+ mlm_labels: Optional[torch.Tensor] = None,
1808
+ mim_labels: Optional[torch.Tensor] = None,
1809
+ itm_labels: Optional[torch.Tensor] = None,
1810
+ output_attentions: Optional[bool] = None,
1811
+ output_hidden_states: bool = True,
1812
+ return_dict: Optional[bool] = None,
1813
+ return_loss: Optional[bool] = None,
1814
+ ) -> Union[Tuple[torch.Tensor], FlavaForPreTrainingOutput]:
1815
+ """
1816
+ Examples:
1817
+ ```python
1818
+ >>> from PIL import Image
1819
+ >>> import requests
1820
+ >>> from transformers import FlavaForPreTraining, AutoProcessor
1821
+
1822
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1823
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1824
+
1825
+ >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
1826
+ >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
1827
+
1828
+ >>> text = ["a photo of a cat"]
1829
+
1830
+ >>> inputs = processor(
1831
+ ... images=[image],
1832
+ ... text=text,
1833
+ ... return_masks=True,
1834
+ ... return_codebook_pixels=True,
1835
+ ... padding=True,
1836
+ ... max_length=77,
1837
+ ... return_tensors="pt",
1838
+ ... )
1839
+
1840
+
1841
+ >>> output = model(**inputs)
1842
+ ```
1843
+
1844
+ Return:
1845
+
1846
+ """
1847
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1848
+ return_loss = return_loss if return_loss is not None else self.config.return_loss
1849
+
1850
+ skip_unmasked_multimodal_encoder = (
1851
+ skip_unmasked_multimodal_encoder
1852
+ if skip_unmasked_multimodal_encoder is not None
1853
+ else self.skip_unmasked_multimodal_encoder
1854
+ )
1855
+
1856
+ if input_ids_masked is None and input_ids is not None:
1857
+ logger.warning(
1858
+ "`input_ids_masked` isn't passed, which means the MLM loss won't be calculated correctly. Setting it to"
1859
+ " `input_ids` so that the model can work. Please pass it if this is unintentional. This is usually OKAY if"
1860
+ " you are doing inference on unmasked text..."
1861
+ )
1862
+ input_ids_masked = input_ids
1863
+
1864
+ flava_output = self.flava(
1865
+ input_ids=input_ids,
1866
+ pixel_values=pixel_values,
1867
+ attention_mask=attention_mask,
1868
+ token_type_ids=token_type_ids,
1869
+ position_ids=position_ids,
1870
+ image_attention_mask=image_attention_mask,
1871
+ # Don't need unmasked multimodal embedding for anything so skip it
1872
+ # NOTE: ITM uses masked version
1873
+ skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
1874
+ output_attentions=output_attentions,
1875
+ output_hidden_states=output_hidden_states,
1876
+ # Pass true to have deterministic outputs
1877
+ return_dict=True,
1878
+ )
1879
+
1880
+ flava_masked_output = self.flava(
1881
+ input_ids=input_ids_masked,
1882
+ pixel_values=pixel_values,
1883
+ attention_mask=attention_mask,
1884
+ token_type_ids=token_type_ids,
1885
+ image_attention_mask=image_attention_mask,
1886
+ bool_masked_pos=bool_masked_pos,
1887
+ output_attentions=output_attentions,
1888
+ output_hidden_states=output_hidden_states,
1889
+ return_dict=True,
1890
+ )
1891
+
1892
+ pos_mask = None
1893
+
1894
+ image_embeddings = flava_output.image_embeddings
1895
+ text_embeddings = flava_output.text_embeddings
1896
+ image_masked_embeddings = flava_masked_output.image_embeddings
1897
+ text_masked_embeddings = flava_masked_output.text_embeddings
1898
+ multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
1899
+
1900
+ total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
1901
+ mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
1902
+ itm_logits = logits_per_image = logits_per_text = None
1903
+
1904
+ # Calculate mim_labels if necessary from the image_codebook
1905
+ if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
1906
+ if mim_labels is None and return_loss:
1907
+ if self.image_codebook is None:
1908
+ raise RuntimeError(
1909
+ "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
1910
+ " have been passed. Reinstantiate the model with `init_codebook` set to True or "
1911
+ "pass in your custom `mim_labels`"
1912
+ )
1913
+ if codebook_pixel_values is None:
1914
+ raise ValueError(
1915
+ "`codebook_pixel_values` is required to generate `mim_labels` if loss is expected. "
1916
+ "Call `AutoProcessor` with `return_codebook_pixels` set to True"
1917
+ )
1918
+ mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values)
1919
+ # Unimodal MIM Loss
1920
+ # If multimodal embeddings are present, we will calculate MMM loss
1921
+ if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
1922
+ sequence_for_image = image_masked_embeddings
1923
+
1924
+ if mim_labels is not None:
1925
+ mim_labels = self._resize_to_2d(mim_labels)
1926
+ bool_masked_pos = self._resize_to_2d(bool_masked_pos)
1927
+ mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
1928
+
1929
+ sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
1930
+ masked_tokens = mim_labels.ne(self.ce_ignore_index)
1931
+ mim_labels_filtered = mim_labels[masked_tokens]
1932
+ sequence_for_image = sequence_for_image[masked_tokens, :]
1933
+ mim_logits = self.mim_head(sequence_for_image)
1934
+ if return_loss:
1935
+ mim_loss = nn.functional.cross_entropy(
1936
+ mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
1937
+ )
1938
+ mim_loss *= self.mim_weight
1939
+ else:
1940
+ mim_logits = self.mim_head(sequence_for_image)
1941
+
1942
+ # Unimodal MLM Loss
1943
+ if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
1944
+ sequence_for_text = text_masked_embeddings
1945
+ if mlm_labels is not None:
1946
+ mlm_labels = self._resize_to_2d(mlm_labels)
1947
+ sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
1948
+ masked_tokens = mlm_labels.ne(self.ce_ignore_index)
1949
+ mlm_labels_filtered = mlm_labels[masked_tokens]
1950
+ sequence_for_text = sequence_for_text[masked_tokens, :]
1951
+ mlm_logits = self.mlm_head(sequence_for_text)
1952
+ if return_loss:
1953
+ mlm_loss = nn.functional.cross_entropy(
1954
+ mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
1955
+ )
1956
+ mlm_loss *= self.mlm_weight
1957
+ else:
1958
+ mlm_logits = self.mlm_head(sequence_for_text)
1959
+
1960
+ # ITM Loss
1961
+ if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
1962
+ itm_logits = self.itm_head(multimodal_masked_embeddings)
1963
+
1964
+ if itm_labels is not None:
1965
+ pos_pairs = itm_labels.ne(0)
1966
+ pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
1967
+ if return_loss:
1968
+ itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
1969
+ itm_loss *= self.itm_weight
1970
+
1971
+ if multimodal_masked_embeddings is not None:
1972
+ multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]
1973
+
1974
+ if mlm_labels is not None:
1975
+ mlm_labels = mlm_labels[pos_mask]
1976
+
1977
+ if mim_labels is not None:
1978
+ mim_labels = mim_labels[pos_mask]
1979
+ bool_masked_pos = bool_masked_pos[pos_mask]
1980
+
1981
+ # MMM Image Loss
1982
+ if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
1983
+ sequence_for_image = multimodal_masked_embeddings
1984
+ end_index = image_masked_embeddings.size(1) - 1
1985
+ sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
1986
+
1987
+ if mim_labels is not None:
1988
+ mim_labels = self._resize_to_2d(mim_labels)
1989
+ bool_masked_pos = self._resize_to_2d(bool_masked_pos)
1990
+ mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
1991
+
1992
+ masked_tokens = mim_labels.ne(self.ce_ignore_index)
1993
+ mim_labels_filtered = mim_labels[masked_tokens]
1994
+ sequence_for_image = sequence_for_image[masked_tokens, :]
1995
+ mmm_image_logits = self.mmm_image_head(sequence_for_image)
1996
+ if return_loss:
1997
+ mmm_image_loss = nn.functional.cross_entropy(
1998
+ mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
1999
+ )
2000
+ mmm_image_loss *= self.mmm_image_weight
2001
+ else:
2002
+ mmm_image_logits = self.mmm_image_head(sequence_for_image)
2003
+
2004
+ # MMM Text Loss
2005
+ if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
2006
+ sequence_for_text = multimodal_masked_embeddings
2007
+ sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
2008
+
2009
+ if mlm_labels is not None:
2010
+ mlm_labels = self._resize_to_2d(mlm_labels)
2011
+ masked_tokens = mlm_labels.ne(self.ce_ignore_index)
2012
+ mlm_labels_filtered = mlm_labels[masked_tokens]
2013
+ sequence_for_text = sequence_for_text[masked_tokens, :]
2014
+ mmm_text_logits = self.mmm_text_head(sequence_for_text)
2015
+ if return_loss:
2016
+ mmm_text_loss = nn.functional.cross_entropy(
2017
+ mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
2018
+ )
2019
+ mmm_text_loss *= self.mmm_text_weight
2020
+ else:
2021
+ mmm_text_logits = self.mmm_text_head(sequence_for_text)
2022
+
2023
+ # Global Contrastive Loss
2024
+ if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
2025
+ text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
2026
+ text_embedding = nn.functional.normalize(text_embedding, dim=-1)
2027
+
2028
+ image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
2029
+ image_embedding = nn.functional.normalize(image_embedding, dim=-1)
2030
+
2031
+ self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
2032
+
2033
+ logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
2034
+ image_embedding, text_embedding, self.flava.logit_scale
2035
+ )
2036
+
2037
+ # Apply ITM negative mask if any
2038
+ if pos_mask is not None:
2039
+ logits_per_image = logits_per_image[pos_mask]
2040
+ logits_per_text = logits_per_text[pos_mask]
2041
+ gc_labels = gc_labels[pos_mask]
2042
+
2043
+ if return_loss:
2044
+ gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
2045
+ gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
2046
+ gc_loss = (gc_loss_image + gc_loss_text) / 2
2047
+ gc_loss *= self.global_contrastive_weight
2048
+
2049
+ flava_losses = FlavaLosses(
2050
+ mim=mim_loss,
2051
+ mlm=mlm_loss,
2052
+ itm=itm_loss,
2053
+ global_contrastive=gc_loss,
2054
+ mmm_image=mmm_image_loss,
2055
+ mmm_text=mmm_text_loss,
2056
+ )
2057
+
2058
+ if return_loss and not flava_losses.all_none():
2059
+ total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())
2060
+
2061
+ if not return_dict:
2062
+ output = (
2063
+ image_embeddings,
2064
+ flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
2065
+ text_embeddings,
2066
+ flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
2067
+ flava_output.multimodal_embeddings,
2068
+ flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
2069
+ image_masked_embeddings,
2070
+ flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
2071
+ text_masked_embeddings,
2072
+ flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
2073
+ multimodal_masked_embeddings,
2074
+ flava_masked_output.multimodal_output.to_tuple()
2075
+ if flava_masked_output.multimodal_output is not None
2076
+ else None,
2077
+ mim_logits,
2078
+ mlm_logits,
2079
+ itm_logits,
2080
+ logits_per_image,
2081
+ logits_per_text,
2082
+ mmm_image_logits,
2083
+ mmm_text_logits,
2084
+ )
2085
+ if return_loss and not flava_losses.all_none():
2086
+ output = (
2087
+ total_loss,
2088
+ flava_losses,
2089
+ ) + output
2090
+
2091
+ # Filter out None as Transformers by default won't handle it
2092
+ return tuple(x for x in output if x is not None)
2093
+
2094
+ return FlavaForPreTrainingOutput(
2095
+ loss=total_loss,
2096
+ loss_info=flava_losses,
2097
+ image_embeddings=image_embeddings,
2098
+ image_output=flava_output.image_output,
2099
+ text_embeddings=text_embeddings,
2100
+ text_output=flava_output.text_output,
2101
+ multimodal_embeddings=flava_output.multimodal_embeddings,
2102
+ multimodal_output=flava_output.multimodal_output,
2103
+ image_masked_embeddings=image_masked_embeddings,
2104
+ image_masked_output=flava_masked_output.image_output,
2105
+ text_masked_embeddings=text_masked_embeddings,
2106
+ text_masked_output=flava_masked_output.text_output,
2107
+ multimodal_masked_embeddings=multimodal_masked_embeddings,
2108
+ multimodal_masked_output=flava_masked_output.multimodal_output,
2109
+ mim_logits=mim_logits,
2110
+ mlm_logits=mlm_logits,
2111
+ itm_logits=itm_logits,
2112
+ contrastive_logits_per_image=logits_per_image,
2113
+ contrastive_logits_per_text=logits_per_text,
2114
+ mmm_image_logits=mmm_image_logits,
2115
+ mmm_text_logits=mmm_text_logits,
2116
+ )
2117
+
2118
+
2119
+ __all__ = [
2120
+ "FlavaForPreTraining",
2121
+ "FlavaImageCodebook",
2122
+ "FlavaImageModel",
2123
+ "FlavaModel",
2124
+ "FlavaMultimodalModel",
2125
+ "FlavaPreTrainedModel",
2126
+ "FlavaTextModel",
2127
+ ]
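A hedged end-to-end sketch of the pretraining forward defined above. Checkpoint and processor names follow the docstring example; `return_image_mask=True` is assumed to emit `bool_masked_pos`, and since no `input_ids_masked`, `mlm_labels`, or `itm_labels` are passed, the model falls back as described in the warning above and only a subset of the losses is populated.

```python
# Hedged end-to-end sketch of the pretraining forward pass.
# Assumptions: `return_image_mask=True` yields `bool_masked_pos`; no explicit
# `input_ids_masked`/`mlm_labels`/`itm_labels` are passed, so only some losses are set.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, FlavaForPreTraining

model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
processor = AutoProcessor.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    images=[image],
    text=["a photo of a cat"],
    return_image_mask=True,       # image patch mask used for the MIM/MMM objectives
    return_codebook_pixels=True,  # lets the model derive `mim_labels` via its codebook
    padding=True,
    max_length=77,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model(**inputs, return_loss=True)

# `loss` is the weighted sum of the individual objectives; `loss_info` holds each term.
print(outputs.loss, outputs.loss_info)
```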
docs/transformers/build/lib/transformers/models/flava/processing_flava.py ADDED
@@ -0,0 +1,168 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for FLAVA
17
+ """
18
+
19
+ import warnings
20
+ from typing import List, Optional, Union
21
+
22
+ from ...image_utils import ImageInput
23
+ from ...processing_utils import ProcessorMixin
24
+ from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
25
+ from ...utils import TensorType
26
+
27
+
28
+ class FlavaProcessor(ProcessorMixin):
29
+ r"""
30
+ Constructs a FLAVA processor which wraps a FLAVA image processor and a FLAVA tokenizer into a single processor.
31
+
32
+ [`FlavaProcessor`] offers all the functionalities of [`FlavaImageProcessor`] and [`BertTokenizerFast`]. See the
33
+ [`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information.
34
+
35
+ Args:
36
+ image_processor ([`FlavaImageProcessor`], *optional*): The image processor is a required input.
37
+ tokenizer ([`BertTokenizerFast`], *optional*): The tokenizer is a required input.
38
+ """
39
+
40
+ attributes = ["image_processor", "tokenizer"]
41
+ image_processor_class = "FlavaImageProcessor"
42
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
43
+
44
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
45
+ feature_extractor = None
46
+ if "feature_extractor" in kwargs:
47
+ warnings.warn(
48
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
49
+ " instead.",
50
+ FutureWarning,
51
+ )
52
+ feature_extractor = kwargs.pop("feature_extractor")
53
+
54
+ image_processor = image_processor if image_processor is not None else feature_extractor
55
+ if image_processor is None:
56
+ raise ValueError("You need to specify an `image_processor`.")
57
+ if tokenizer is None:
58
+ raise ValueError("You need to specify a `tokenizer`.")
59
+
60
+ super().__init__(image_processor, tokenizer)
61
+ self.current_processor = self.image_processor
62
+
63
+ def __call__(
64
+ self,
65
+ images: Optional[ImageInput] = None,
66
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
67
+ add_special_tokens: bool = True,
68
+ padding: Union[bool, str, PaddingStrategy] = False,
69
+ truncation: Union[bool, str, TruncationStrategy] = False,
70
+ max_length: Optional[int] = None,
71
+ stride: int = 0,
72
+ pad_to_multiple_of: Optional[int] = None,
73
+ return_image_mask: Optional[bool] = None,
74
+ return_codebook_pixels: Optional[bool] = None,
75
+ return_token_type_ids: Optional[bool] = None,
76
+ return_attention_mask: Optional[bool] = None,
77
+ return_overflowing_tokens: bool = False,
78
+ return_special_tokens_mask: bool = False,
79
+ return_offsets_mapping: bool = False,
80
+ return_length: bool = False,
81
+ verbose: bool = True,
82
+ return_tensors: Optional[Union[str, TensorType]] = None,
83
+ **kwargs,
84
+ ):
85
+ """
86
+ This method uses [`FlavaImageProcessor.__call__`] method to prepare image(s) for the model, and
87
+ [`BertTokenizerFast.__call__`] to prepare text for the model.
88
+
89
+ Please refer to the docstring of the above two methods for more information.
90
+ """
91
+
92
+ if text is None and images is None:
93
+ raise ValueError("You have to specify either text or images. They cannot both be None.")
94
+
95
+ if text is not None:
96
+ encoding = self.tokenizer(
97
+ text=text,
98
+ add_special_tokens=add_special_tokens,
99
+ padding=padding,
100
+ truncation=truncation,
101
+ max_length=max_length,
102
+ stride=stride,
103
+ pad_to_multiple_of=pad_to_multiple_of,
104
+ return_token_type_ids=return_token_type_ids,
105
+ return_attention_mask=return_attention_mask,
106
+ return_overflowing_tokens=return_overflowing_tokens,
107
+ return_special_tokens_mask=return_special_tokens_mask,
108
+ return_offsets_mapping=return_offsets_mapping,
109
+ return_length=return_length,
110
+ verbose=verbose,
111
+ return_tensors=return_tensors,
112
+ **kwargs,
113
+ )
114
+ if images is not None:
115
+ image_features = self.image_processor(
116
+ images,
117
+ return_image_mask=return_image_mask,
118
+ return_codebook_pixels=return_codebook_pixels,
119
+ return_tensors=return_tensors,
120
+ **kwargs,
121
+ )
122
+
123
+ if text is not None and images is not None:
124
+ encoding.update(image_features)
125
+ return encoding
126
+ elif text is not None:
127
+ return encoding
128
+ else:
129
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
130
+
131
+ def batch_decode(self, *args, **kwargs):
132
+ """
133
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
134
+ refer to the docstring of this method for more information.
135
+ """
136
+ return self.tokenizer.batch_decode(*args, **kwargs)
137
+
138
+ def decode(self, *args, **kwargs):
139
+ """
140
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
141
+ the docstring of this method for more information.
142
+ """
143
+ return self.tokenizer.decode(*args, **kwargs)
144
+
145
+ @property
146
+ def model_input_names(self):
147
+ tokenizer_input_names = self.tokenizer.model_input_names
148
+ image_processor_input_names = self.image_processor.model_input_names
149
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
150
+
151
+ @property
152
+ def feature_extractor_class(self):
153
+ warnings.warn(
154
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
155
+ FutureWarning,
156
+ )
157
+ return self.image_processor_class
158
+
159
+ @property
160
+ def feature_extractor(self):
161
+ warnings.warn(
162
+ "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
163
+ FutureWarning,
164
+ )
165
+ return self.image_processor
166
+
167
+
168
+ __all__ = ["FlavaProcessor"]
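A hedged usage sketch for the processor above. The checkpoint name follows the docstrings elsewhere in this diff; the exact keys in the returned batch depend on the tokenizer and image processor configuration.

```python
# Hedged usage sketch for FlavaProcessor; the checkpoint name follows the docstrings above.
import requests
from PIL import Image
from transformers import FlavaProcessor

processor = FlavaProcessor.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Text-only, image-only, or joint calls are all valid; a joint call merges both outputs.
batch = processor(images=[image], text=["a photo of a cat"], padding=True, return_tensors="pt")
print(sorted(batch.keys()))  # tokenizer keys plus image processor keys, e.g. pixel_values
```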