Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +3 -0
- docs/resources/grpo_countdown.png +3 -0
- docs/resources/grpo_geoqa.png +3 -0
- docs/resources/grpo_openr1_multimodal.png +3 -0
- docs/transformers/build/lib/transformers/models/depth_anything/convert_distill_any_depth_to_hf.py +246 -0
- docs/transformers/build/lib/transformers/models/depth_anything/modeling_depth_anything.py +469 -0
- docs/transformers/build/lib/transformers/models/depth_pro/configuration_depth_pro.py +205 -0
- docs/transformers/build/lib/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py +254 -0
- docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro.py +392 -0
- docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro_fast.py +189 -0
- docs/transformers/build/lib/transformers/models/depth_pro/modeling_depth_pro.py +1218 -0
- docs/transformers/build/lib/transformers/models/detr/__init__.py +31 -0
- docs/transformers/build/lib/transformers/models/detr/configuration_detr.py +289 -0
- docs/transformers/build/lib/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py +277 -0
- docs/transformers/build/lib/transformers/models/detr/convert_detr_to_pytorch.py +385 -0
- docs/transformers/build/lib/transformers/models/detr/feature_extraction_detr.py +48 -0
- docs/transformers/build/lib/transformers/models/detr/image_processing_detr_fast.py +1312 -0
- docs/transformers/build/lib/transformers/models/detr/modeling_detr.py +1815 -0
- docs/transformers/build/lib/transformers/models/dialogpt/__init__.py +0 -0
- docs/transformers/build/lib/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py +46 -0
- docs/transformers/build/lib/transformers/models/diffllama/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/diffllama/configuration_diffllama.py +199 -0
- docs/transformers/build/lib/transformers/models/esm/openfold_utils/rigid_utils.py +1242 -0
- docs/transformers/build/lib/transformers/models/falcon/configuration_falcon.py +211 -0
- docs/transformers/build/lib/transformers/models/falcon/convert_custom_code_checkpoint.py +74 -0
- docs/transformers/build/lib/transformers/models/falcon/modeling_falcon.py +1566 -0
- docs/transformers/build/lib/transformers/models/falcon_mamba/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/falcon_mamba/configuration_falcon_mamba.py +162 -0
- docs/transformers/build/lib/transformers/models/falcon_mamba/modeling_falcon_mamba.py +873 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +480 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch.py +210 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_hifigan.py +134 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_model_with_hifigan.py +102 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +1697 -0
- docs/transformers/build/lib/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +188 -0
- docs/transformers/build/lib/transformers/models/flaubert/__init__.py +29 -0
- docs/transformers/build/lib/transformers/models/flaubert/configuration_flaubert.py +235 -0
- docs/transformers/build/lib/transformers/models/flaubert/modeling_flaubert.py +1739 -0
- docs/transformers/build/lib/transformers/models/flaubert/modeling_tf_flaubert.py +1344 -0
- docs/transformers/build/lib/transformers/models/flaubert/tokenization_flaubert.py +568 -0
- docs/transformers/build/lib/transformers/models/flava/__init__.py +31 -0
- docs/transformers/build/lib/transformers/models/flava/configuration_flava.py +701 -0
- docs/transformers/build/lib/transformers/models/flava/convert_dalle_to_flava_codebook.py +102 -0
- docs/transformers/build/lib/transformers/models/flava/convert_flava_original_pytorch_to_hf.py +99 -0
- docs/transformers/build/lib/transformers/models/flava/feature_extraction_flava.py +38 -0
- docs/transformers/build/lib/transformers/models/flava/image_processing_flava.py +705 -0
- docs/transformers/build/lib/transformers/models/flava/image_processing_flava_fast.py +549 -0
- docs/transformers/build/lib/transformers/models/flava/modeling_flava.py +2127 -0
- docs/transformers/build/lib/transformers/models/flava/processing_flava.py +168 -0
.gitattributes
CHANGED
@@ -48,3 +48,6 @@ wandb/offline-run-20250624_115955-iye05c18/run-iye05c18.wandb filter=lfs diff=lfs merge=lfs -text
 wandb/offline-run-20250721_000454-up3efnok/run-up3efnok.wandb filter=lfs diff=lfs merge=lfs -text
 wandb/offline-run-20250722_003110-femxkckf/run-femxkckf.wandb filter=lfs diff=lfs merge=lfs -text
 seamless_interaction/assets/banner.gif filter=lfs diff=lfs merge=lfs -text
+docs/resources/grpo_countdown.png filter=lfs diff=lfs merge=lfs -text
+docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
+docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
docs/resources/grpo_countdown.png
ADDED (Git LFS object)

docs/resources/grpo_geoqa.png
ADDED (Git LFS object)

docs/resources/grpo_openr1_multimodal.png
ADDED (Git LFS object)
docs/transformers/build/lib/transformers/models/depth_anything/convert_distill_any_depth_to_hf.py
ADDED
|
@@ -0,0 +1,246 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Distill Any Depth checkpoints from the original repository. URL:
https://github.com/Westlake-AGI-Lab/Distill-Any-Depth"""

import argparse
import re
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

from transformers import DepthAnythingConfig, DepthAnythingForDepthEstimation, Dinov2Config, DPTImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"(backbone|pretrained)\.cls_token": r"backbone.embeddings.cls_token",
    r"(backbone|pretrained)\.mask_token": r"backbone.embeddings.mask_token",
    r"(backbone|pretrained)\.pos_embed": r"backbone.embeddings.position_embeddings",
    r"(backbone|pretrained)\.patch_embed\.proj\.(weight|bias)": r"backbone.embeddings.patch_embeddings.projection.\2",
    r"(backbone|pretrained)\.norm\.(weight|bias)": r"backbone.layernorm.\2",
    r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.attn\.proj\.(weight|bias)": r"backbone.encoder.layer.\4.attention.output.dense.\5",
    r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.ls(1|2)\.gamma": r"backbone.encoder.layer.\4.layer_scale\5.lambda1",
    r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.mlp\.fc(1|2)\.(weight|bias)": r"backbone.encoder.layer.\4.mlp.fc\5.\6",
    r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.norm(1|2)\.(weight|bias)": r"backbone.encoder.layer.\4.norm\5.\6",
    r"depth_head\.projects\.(\d+)\.(weight|bias)": r"neck.reassemble_stage.layers.\1.projection.\2",
    r"depth_head\.resize_layers\.(?!2)(\d+)\.(weight|bias)": r"neck.reassemble_stage.layers.\1.resize.\2",
    r"depth_head\.scratch\.layer(\d+)_rn\.weight": lambda m: f"neck.convs.{int(m[1]) - 1}.weight",
    r"depth_head\.scratch\.output_conv(\d+)(?:\.(\d+))?\.(weight|bias)": lambda m: (
        f"head.conv{int(m[1]) + (int(m[2]) // 2 if m[2] else 0)}.{m[3]}" if m[1] == "2" else f"head.conv{m[1]}.{m[3]}"
    ),
    r"depth_head\.scratch\.refinenet(\d+)\.out_conv\.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{3 - (int(m[1]) - 1)}.projection.{m[2]}",
    r"depth_head\.scratch\.refinenet(\d+)\.resConfUnit(\d+)\.conv(\d+)\.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{3 - (int(m[1]) - 1)}.residual_layer{m[2]}.convolution{m[3]}.{m[4]}",
}


def get_dpt_config(model_name):
    if "small" in model_name:
        out_indices = [3, 6, 9, 12]
        backbone_config = Dinov2Config.from_pretrained(
            "facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
        )
        fusion_hidden_size = 64
        neck_hidden_sizes = [48, 96, 192, 384]
    elif "base" in model_name:
        out_indices = [3, 6, 9, 12]
        backbone_config = Dinov2Config.from_pretrained(
            "facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
        )
        fusion_hidden_size = 128
        neck_hidden_sizes = [96, 192, 384, 768]
    elif "large" in model_name:
        out_indices = [5, 12, 18, 24]
        backbone_config = Dinov2Config.from_pretrained(
            "facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
        )
        fusion_hidden_size = 256
        neck_hidden_sizes = [256, 512, 1024, 1024]
    else:
        raise NotImplementedError(f"Model not supported: {model_name}")

    depth_estimation_type = "relative"
    max_depth = None

    config = DepthAnythingConfig(
        reassemble_hidden_size=backbone_config.hidden_size,
        patch_size=backbone_config.patch_size,
        backbone_config=backbone_config,
        fusion_hidden_size=fusion_hidden_size,
        neck_hidden_sizes=neck_hidden_sizes,
        depth_estimation_type=depth_estimation_type,
        max_depth=max_depth,
    )

    return config


def convert_key_pattern(key, mapping):
    for pattern, replacement in mapping.items():
        match = re.fullmatch(pattern, key)
        if match:
            if callable(replacement):
                return replacement(match)
            return re.sub(pattern, replacement, key)
    return None


def convert_keys(state_dict, config):
    new_state_dict = {}
    qkv_pattern = r"(backbone|pretrained)(\.blocks(\.\d+)?)?\.(\d+)\.attn\.qkv\.(weight|bias)"
    qkv_keys = [k for k in list(state_dict.keys()) if re.match(qkv_pattern, k)]
    for old_key in qkv_keys:
        value = state_dict.pop(old_key)
        match = re.match(qkv_pattern, old_key)
        _, _, _, layer, attr = match.groups()
        hidden_size = config.backbone_config.hidden_size
        q = value[:hidden_size]
        k = value[hidden_size : hidden_size * 2]
        v = value[-hidden_size:]

        for proj, tensor in zip(["query", "key", "value"], [q, k, v]):
            new_key = f"backbone.encoder.layer.{layer}.attention.attention.{proj}.{attr}"
            new_state_dict[new_key] = tensor

    for old_key in list(state_dict.keys()):
        value = state_dict.pop(old_key)
        new_key = convert_key_pattern(old_key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)

        new_state_dict[new_key] = value

    return new_state_dict


def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    return Image.open(requests.get(url, stream=True).raw)


name_to_checkpoint = {
    "distill-any-depth-small": "small/model.safetensors",
    "distill-any-depth-base": "base/model.safetensors",
    "distill-any-depth-large": "large/model.safetensors",
}


@torch.no_grad()
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits):
    config = get_dpt_config(model_name)

    repo_id = "xingyang1/Distill-Any-Depth"
    filepath = hf_hub_download(repo_id=repo_id, filename=name_to_checkpoint[model_name])
    state_dict = load_file(filepath)

    converted_state_dict = convert_keys(state_dict, config)

    model = DepthAnythingForDepthEstimation(config)
    model.load_state_dict(converted_state_dict)
    model.eval()

    processor = DPTImageProcessor(
        do_resize=True,
        size={"height": 518, "width": 518},
        ensure_multiple_of=14,
        keep_aspect_ratio=True,
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
    )

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    pixel_values = processor(image, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(pixel_values)
        predicted_depth = outputs.predicted_depth

    print("Shape of predicted depth:", predicted_depth.shape)
    print("First values:", predicted_depth[0, :3, :3])

    if verify_logits:
        print("Verifying logits...")
        expected_shape = torch.Size([1, 518, 686])

        if model_name == "distill-any-depth-small":
            expected_slice = torch.tensor(
                [[2.5653, 2.5249, 2.5570], [2.4897, 2.5235, 2.5355], [2.5255, 2.5261, 2.5422]]
            )
        elif model_name == "distill-any-depth-base":
            expected_slice = torch.tensor(
                [[4.8976, 4.9075, 4.9403], [4.8872, 4.8906, 4.9448], [4.8712, 4.8898, 4.8838]]
            )
        elif model_name == "distill-any-depth-large":
            expected_slice = torch.tensor(
                [[55.1067, 51.1828, 51.6803], [51.9098, 50.7529, 51.4494], [50.1745, 50.5491, 50.8818]]
            )
        else:
            raise ValueError("Not supported")

        assert predicted_depth.shape == torch.Size(expected_shape)
        assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4)
        print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model and processor to hub...")
        model.push_to_hub(repo_id=f"{model_name.title()}-hf")
        processor.push_to_hub(repo_id=f"{model_name.title()}-hf")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="distill-any-depth-small",
        type=str,
        choices=name_to_checkpoint.keys(),
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model to the hub after conversion.",
    )
    parser.add_argument(
        "--verify_logits",
        action="store_true",
        required=False,
        help="Whether to verify the logits after conversion.",
    )

    args = parser.parse_args()
    convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits)
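The conversion script above is normally driven through the argparse flags defined at the bottom of the file, but its entry point can also be called directly. A minimal sketch follows; the import path and the output directory are assumptions for illustration and are not part of this diff.

```python
# Hedged usage sketch: calling the conversion entry point defined above.
# The module import path and output directory are hypothetical.
from convert_distill_any_depth_to_hf import convert_dpt_checkpoint

convert_dpt_checkpoint(
    model_name="distill-any-depth-small",                     # must be a key of name_to_checkpoint
    pytorch_dump_folder_path="./distill-any-depth-small-hf",  # hypothetical output directory
    push_to_hub=False,
    verify_logits=True,                                        # checks the expected slice for the small model
)
```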
docs/transformers/build/lib/transformers/models/depth_anything/modeling_depth_anything.py
ADDED
|
@@ -0,0 +1,469 @@
# coding=utf-8
# Copyright 2024 TikTok and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Depth Anything model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...file_utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils.backbone_utils import load_backbone
from .configuration_depth_anything import DepthAnythingConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "DepthAnythingConfig"


DEPTH_ANYTHING_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DepthAnythingConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DEPTH_ANYTHING_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
            for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""


class DepthAnythingReassembleLayer(nn.Module):
    def __init__(self, config, channels, factor):
        super().__init__()
        self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)

        # up/down sampling depending on factor
        if factor > 1:
            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
        elif factor == 1:
            self.resize = nn.Identity()
        elif factor < 1:
            # so should downsample
            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)

    # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward
    def forward(self, hidden_state):
        hidden_state = self.projection(hidden_state)
        hidden_state = self.resize(hidden_state)

        return hidden_state


class DepthAnythingReassembleStage(nn.Module):
    """
    This class reassembles the hidden states of the backbone into image-like feature representations at various
    resolutions.

    This happens in 3 stages:
    1. Take the patch embeddings and reshape them to image-like feature representations.
    2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
    3. Resizing the spatial dimensions (height, width).

    Args:
        config (`[DepthAnythingConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.layers = nn.ModuleList()
        for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors):
            self.layers.append(DepthAnythingReassembleLayer(config, channels=channels, factor=factor))

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
                List of hidden states from the backbone.
        """
        out = []

        for i, hidden_state in enumerate(hidden_states):
            # reshape to (batch_size, num_channels, height, width)
            hidden_state = hidden_state[:, 1:]
            batch_size, _, num_channels = hidden_state.shape
            hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
            hidden_state = self.layers[i](hidden_state)
            out.append(hidden_state)

        return out


class DepthAnythingPreActResidualLayer(nn.Module):
    """
    ResidualConvUnit, pre-activate residual unit.

    Args:
        config (`[DepthAnythingConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.activation1 = nn.ReLU()
        self.convolution1 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )

        self.activation2 = nn.ReLU()
        self.convolution2 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state
        hidden_state = self.activation1(hidden_state)
        hidden_state = self.convolution1(hidden_state)
        hidden_state = self.activation2(hidden_state)
        hidden_state = self.convolution2(hidden_state)

        return hidden_state + residual


class DepthAnythingFeatureFusionLayer(nn.Module):
    """Feature fusion layer, merges feature maps from different stages.

    Args:
        config (`[DepthAnythingConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)

        self.residual_layer1 = DepthAnythingPreActResidualLayer(config)
        self.residual_layer2 = DepthAnythingPreActResidualLayer(config)

    def forward(self, hidden_state, residual=None, size=None):
        if residual is not None:
            if hidden_state.shape != residual.shape:
                residual = nn.functional.interpolate(
                    residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        hidden_state = self.residual_layer2(hidden_state)

        modifier = {"scale_factor": 2} if size is None else {"size": size}

        hidden_state = nn.functional.interpolate(
            hidden_state,
            **modifier,
            mode="bilinear",
            align_corners=True,
        )
        hidden_state = self.projection(hidden_state)

        return hidden_state


class DepthAnythingFeatureFusionStage(nn.Module):
    # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage.__init__ with DPT->DepthAnything
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(len(config.neck_hidden_sizes)):
            self.layers.append(DepthAnythingFeatureFusionLayer(config))

    def forward(self, hidden_states, size=None):
        # reversing the hidden_states, we start from the last
        hidden_states = hidden_states[::-1]

        fused_hidden_states = []
        fused_hidden_state = None

        for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
            size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None

            if fused_hidden_state is None:
                # first layer only uses the last hidden_state
                fused_hidden_state = layer(hidden_state, size=size)
            else:
                fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)

            fused_hidden_states.append(fused_hidden_state)

        return fused_hidden_states


# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
# avoiding sdpa and flash_attn_2 support, it's done in the backend
class DepthAnythingPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DepthAnythingConfig
    base_model_prefix = "depth_anything"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class DepthAnythingNeck(nn.Module):
    """
    DepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
    input and produces another list of tensors as output. For DepthAnything, it includes 2 stages:

    * DepthAnythingReassembleStage
    * DepthAnythingFeatureFusionStage.

    Args:
        config (dict): config dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.reassemble_stage = DepthAnythingReassembleStage(config)

        self.convs = nn.ModuleList()
        for channel in config.neck_hidden_sizes:
            self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))

        # fusion
        self.fusion_stage = DepthAnythingFeatureFusionStage(config)

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
                List of hidden states from the backbone.
        """
        if not isinstance(hidden_states, (tuple, list)):
            raise TypeError("hidden_states should be a tuple or list of tensors")

        if len(hidden_states) != len(self.config.neck_hidden_sizes):
            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")

        # postprocess hidden states
        hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)

        features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]

        # fusion blocks
        output = self.fusion_stage(features)

        return output


class DepthAnythingDepthEstimationHead(nn.Module):
    """
    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
    the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's
    supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation
    type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining.
    """

    def __init__(self, config):
        super().__init__()

        self.head_in_index = config.head_in_index
        self.patch_size = config.patch_size

        features = config.fusion_hidden_size
        self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1)
        self.activation1 = nn.ReLU()
        self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0)
        if config.depth_estimation_type == "relative":
            self.activation2 = nn.ReLU()
        elif config.depth_estimation_type == "metric":
            self.activation2 = nn.Sigmoid()
        else:
            raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}")
        self.max_depth = config.max_depth

    def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor:
        hidden_states = hidden_states[self.head_in_index]

        predicted_depth = self.conv1(hidden_states)
        predicted_depth = nn.functional.interpolate(
            predicted_depth,
            (int(patch_height * self.patch_size), int(patch_width * self.patch_size)),
            mode="bilinear",
            align_corners=True,
        )
        predicted_depth = self.conv2(predicted_depth)
        predicted_depth = self.activation1(predicted_depth)
        predicted_depth = self.conv3(predicted_depth)
        predicted_depth = self.activation2(predicted_depth) * self.max_depth
        predicted_depth = predicted_depth.squeeze(dim=1)  # shape (batch_size, height, width)

        return predicted_depth


@add_start_docstrings(
    """
    Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
    """,
    DEPTH_ANYTHING_START_DOCSTRING,
)
class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
    _no_split_modules = ["DPTViTEmbeddings"]

    def __init__(self, config):
        super().__init__(config)

        self.backbone = load_backbone(config)
        self.neck = DepthAnythingNeck(config)
        self.head = DepthAnythingDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEPTH_ANYTHING_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
        >>> model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # interpolate to original size
        >>> post_processed_output = image_processor.post_process_depth_estimation(
        ...     outputs,
        ...     target_sizes=[(image.height, image.width)],
        ... )

        >>> # visualize the prediction
        >>> predicted_depth = post_processed_output[0]["predicted_depth"]
        >>> depth = predicted_depth * 255 / predicted_depth.max()
        >>> depth = depth.detach().cpu().numpy()
        >>> depth = Image.fromarray(depth.astype("uint8"))
        ```"""
        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        outputs = self.backbone.forward_with_filtered_kwargs(
            pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
        )
        hidden_states = outputs.feature_maps

        _, _, height, width = pixel_values.shape
        patch_size = self.config.patch_size
        patch_height = height // patch_size
        patch_width = width // patch_size

        hidden_states = self.neck(hidden_states, patch_height, patch_width)

        predicted_depth = self.head(hidden_states, patch_height, patch_width)

        if not return_dict:
            if output_hidden_states:
                output = (predicted_depth,) + outputs[1:]
            else:
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = ["DepthAnythingForDepthEstimation", "DepthAnythingPreTrainedModel"]
docs/transformers/build/lib/transformers/models/depth_pro/configuration_depth_pro.py
ADDED
|
@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DepthPro model configuration"""

from copy import deepcopy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)


class DepthProConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DepthProModel`]. It is used to instantiate a
    DepthPro model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the DepthPro
    [apple/DepthPro](https://huggingface.co/apple/DepthPro) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        fusion_hidden_size (`int`, *optional*, defaults to 256):
            The number of channels before fusion.
        patch_size (`int`, *optional*, defaults to 384):
            The size (resolution) of each patch. This is also the image_size for backbone model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        intermediate_hook_ids (`List[int]`, *optional*, defaults to `[11, 5]`):
            Indices of the intermediate hidden states from the patch encoder to use for fusion.
        intermediate_feature_dims (`List[int]`, *optional*, defaults to `[256, 256]`):
            Hidden state dimensions during upsampling for each intermediate hidden state in `intermediate_hook_ids`.
        scaled_images_ratios (`List[float]`, *optional*, defaults to `[0.25, 0.5, 1]`):
            Ratios of scaled images to be used by the patch encoder.
        scaled_images_overlap_ratios (`List[float]`, *optional*, defaults to `[0.0, 0.5, 0.25]`):
            Overlap ratios between patches for each scaled image in `scaled_images_ratios`.
        scaled_images_feature_dims (`List[int]`, *optional*, defaults to `[1024, 1024, 512]`):
            Hidden state dimensions during upsampling for each scaled image in `scaled_images_ratios`.
        merge_padding_value (`int`, *optional*, defaults to 3):
            When merging smaller patches back to the image size, overlapping sections of this size are removed.
        use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
            Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
        use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the pre-activate residual units of the fusion blocks.
        use_fov_model (`bool`, *optional*, defaults to `False`):
            Whether to use `DepthProFovModel` to generate the field of view.
        num_fov_head_layers (`int`, *optional*, defaults to 2):
            Number of convolution layers in the head of `DepthProFovModel`.
        image_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
            The configuration of the image encoder model, which is loaded using the [`AutoModel`] API.
            By default, Dinov2 model is used as backbone.
        patch_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
            The configuration of the patch encoder model, which is loaded using the [`AutoModel`] API.
            By default, Dinov2 model is used as backbone.
        fov_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
            The configuration of the fov encoder model, which is loaded using the [`AutoModel`] API.
            By default, Dinov2 model is used as backbone.

    Example:

    ```python
    >>> from transformers import DepthProConfig, DepthProModel

    >>> # Initializing a DepthPro apple/DepthPro style configuration
    >>> configuration = DepthProConfig()

    >>> # Initializing a model (with random weights) from the apple/DepthPro style configuration
    >>> model = DepthProModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "depth_pro"
    sub_configs = {"image_model_config": AutoConfig, "patch_model_config": AutoConfig, "fov_model_config": AutoConfig}

    def __init__(
        self,
        fusion_hidden_size=256,
        patch_size=384,
        initializer_range=0.02,
        intermediate_hook_ids=[11, 5],
        intermediate_feature_dims=[256, 256],
        scaled_images_ratios=[0.25, 0.5, 1],
        scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
        scaled_images_feature_dims=[1024, 1024, 512],
        merge_padding_value=3,
        use_batch_norm_in_fusion_residual=False,
        use_bias_in_fusion_residual=True,
        use_fov_model=False,
        num_fov_head_layers=2,
        image_model_config=None,
        patch_model_config=None,
        fov_model_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # scaled_images_ratios is sorted
        if scaled_images_ratios != sorted(scaled_images_ratios):
            raise ValueError(
                f"Values in scaled_images_ratios={scaled_images_ratios} should be sorted from low to high"
            )

        # scaled_images_ratios, scaled_images_overlap_ratios, scaled_images_feature_dims should be consistent
        if not (len(scaled_images_ratios) == len(scaled_images_overlap_ratios) == len(scaled_images_feature_dims)):
            raise ValueError(
                f"len(scaled_images_ratios)={len(scaled_images_ratios)} and "
                f"len(scaled_images_overlap_ratios)={len(scaled_images_overlap_ratios)} and "
                f"len(scaled_images_feature_dims)={len(scaled_images_feature_dims)}, "
                f"should match in config."
            )

        # intermediate_hook_ids, intermediate_feature_dims should be consistent
        if not (len(intermediate_hook_ids) == len(intermediate_feature_dims)):
            raise ValueError(
                f"len(intermediate_hook_ids)={len(intermediate_hook_ids)} and "
                f"len(intermediate_feature_dims)={len(intermediate_feature_dims)}, "
                f"should match in config."
            )

        # fusion_hidden_size should be consistent with num_fov_head_layers
        if fusion_hidden_size // 2**num_fov_head_layers == 0:
            raise ValueError(
                f"fusion_hidden_size={fusion_hidden_size} should be consistent with num_fov_head_layers={num_fov_head_layers} "
                "i.e fusion_hidden_size // 2**num_fov_head_layers > 0"
            )

        self.fusion_hidden_size = fusion_hidden_size
        self.patch_size = patch_size
        self.initializer_range = initializer_range
        self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
        self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
        self.use_fov_model = use_fov_model
        self.num_fov_head_layers = num_fov_head_layers
        self.intermediate_hook_ids = intermediate_hook_ids
        self.intermediate_feature_dims = intermediate_feature_dims
        self.scaled_images_ratios = scaled_images_ratios
        self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
        self.scaled_images_feature_dims = scaled_images_feature_dims
        self.merge_padding_value = merge_padding_value
        self.image_model_config = image_model_config
        self.patch_model_config = patch_model_config
        self.fov_model_config = fov_model_config

        for sub_config_key in self.sub_configs.keys():
            sub_config = getattr(self, sub_config_key)

            if sub_config is None:
                sub_config = CONFIG_MAPPING["dinov2"](image_size=patch_size)
                logger.info(
                    f"`{sub_config_key}` is `None`. Initializing `{sub_config_key}` with the `Dinov2Config` "
                    f"with default values except `{sub_config_key}.image_size` is set to `config.patch_size`."
                )
            elif isinstance(sub_config, dict):
                sub_config = deepcopy(sub_config)
                if "model_type" not in sub_config:
                    raise KeyError(
                        f"The `model_type` key is missing in the `{sub_config_key}` dictionary. Please provide the model type."
                    )
                elif sub_config["model_type"] not in CONFIG_MAPPING:
                    raise ValueError(
                        f"The model type `{sub_config['model_type']}` in `{sub_config_key}` is not supported. Please provide a valid model type."
                    )
                image_size = sub_config.get("image_size")
                if image_size != patch_size:
                    logger.info(
                        f"The `image_size` in `{sub_config_key}` is set to `{image_size}`, "
                        f"but it does not match the required `patch_size` of `{patch_size}`. "
                        f"Updating `image_size` to `{patch_size}` for consistency. "
                        f"Ensure that `image_size` aligns with `patch_size` in the configuration."
                    )
                    sub_config.update({"image_size": patch_size})
                sub_config = CONFIG_MAPPING[sub_config["model_type"]](**sub_config)
            elif isinstance(sub_config, PretrainedConfig):
                sub_config = sub_config
                image_size = getattr(sub_config, "image_size", None)
                if image_size != patch_size:
                    raise ValueError(
                        f"`config.{sub_config_key}.image_size={image_size}` should match `config.patch_size={patch_size}`."
                    )
            else:
                raise TypeError(
                    f"Invalid type for `sub_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(sub_config)}."
                )

            setattr(self, sub_config_key, sub_config)


__all__ = ["DepthProConfig"]
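The validation loop in `DepthProConfig.__init__` above means a sub-config passed as a plain dict must carry a `model_type` and is materialized into a concrete config object whose `image_size` is aligned with `patch_size`. A minimal sketch of that behaviour follows, assuming the `transformers` build shown here exposes `DepthProConfig` (the conversion script in the next section imports it from `transformers`).

```python
# Minimal sketch: a dict sub-config is resolved through CONFIG_MAPPING and its
# image_size is aligned with patch_size, per DepthProConfig.__init__ above.
from transformers import DepthProConfig

config = DepthProConfig(
    patch_size=384,
    patch_model_config={"model_type": "dinov2", "image_size": 384},
)
print(type(config.patch_model_config).__name__)  # Dinov2Config
print(config.patch_model_config.image_size)      # 384
```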
docs/transformers/build/lib/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
ADDED
|
@@ -0,0 +1,254 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gc
import os

import regex as re
import torch
from huggingface_hub import hf_hub_download

from transformers import (
    DepthProConfig,
    DepthProForDepthEstimation,
    DepthProImageProcessorFast,
)


# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {

    # encoder
    r"encoder.(patch|image)_encoder.cls_token": r"depth_pro.encoder.\1_encoder.model.embeddings.cls_token",
    r"encoder.(patch|image)_encoder.pos_embed": r"depth_pro.encoder.\1_encoder.model.embeddings.position_embeddings",
    r"encoder.(patch|image)_encoder.patch_embed.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.embeddings.patch_embeddings.projection.\2",
    r"encoder.(patch|image)_encoder.blocks.(\d+).norm(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.norm\3.\4",
    r"encoder.(patch|image)_encoder.blocks.(\d+).attn.qkv.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.attention.attention.(query|key|value).\3",
    r"encoder.(patch|image)_encoder.blocks.(\d+).attn.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.attention.output.dense.\3",
    r"encoder.(patch|image)_encoder.blocks.(\d+).ls(\d+).gamma": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.layer_scale\3.lambda1",
    r"encoder.(patch|image)_encoder.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.mlp.fc\3.\4",
    r"encoder.(patch|image)_encoder.norm.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.layernorm.\2",
    r"encoder.fuse_lowres.(weight|bias)": r"depth_pro.neck.fuse_image_with_low_res.\1",

    # fov
    r"fov.encoder.0.cls_token": r"fov_model.fov_encoder.model.embeddings.cls_token",
    r"fov.encoder.0.pos_embed": r"fov_model.fov_encoder.model.embeddings.position_embeddings",
    r"fov.encoder.0.patch_embed.proj.(weight|bias)": r"fov_model.fov_encoder.model.embeddings.patch_embeddings.projection.\1",
    r"fov.encoder.0.blocks.(\d+).norm(\d+).(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.norm\2.\3",
    r"fov.encoder.0.blocks.(\d+).attn.qkv.(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.attention.attention.(query|key|value).\2",
    r"fov.encoder.0.blocks.(\d+).attn.proj.(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.attention.output.dense.\2",
    r"fov.encoder.0.blocks.(\d+).ls(\d+).gamma": r"fov_model.fov_encoder.model.encoder.layer.\1.layer_scale\2.lambda1",
    r"fov.encoder.0.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.mlp.fc\2.\3",
    r"fov.encoder.0.norm.(weight|bias)": r"fov_model.fov_encoder.model.layernorm.\1",
    r"fov.downsample.0.(weight|bias)": r"fov_model.conv.\1",
    r"fov.encoder.1.(weight|bias)": r"fov_model.fov_encoder.neck.\1",
    r"fov.head.(\d+).(weight|bias)": r"fov_model.head.layers.\1.\2",

    # head
    r"head.(\d+).(weight|bias)": r"head.layers.\1.\2",

    # upsamples
    r"encoder.upsample_lowres.(weight|bias)": r"depth_pro.neck.feature_upsample.image_block.layers.0.\1",
    r"encoder.upsample_latent(\d+).(\d+).(weight|bias)": lambda match: (
        f"depth_pro.neck.feature_upsample.intermediate.{1-int(match.group(1))}.layers.{match.group(2)}.{match.group(3)}"
    ),
    r"encoder.upsample(\d+).(\d+).(weight|bias)": lambda match: (
        f"depth_pro.neck.feature_upsample.scaled_images.{2-int(match.group(1))}.layers.{match.group(2)}.{match.group(3)}"
    ),

    # projections between encoder and fusion
    r"decoder.convs.(\d+).weight": lambda match: (
        f"depth_pro.neck.feature_projection.projections.{4-int(match.group(1))}.weight"
    ),

    # fusion stage
    r"decoder.fusions.([1234]).resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
        f"fusion_stage.intermediate.{4-int(match.group(1))}.residual_layer{match.group(2)}.convolution{(int(match.group(3))+1)//2}.{match.group(4)}"
    ),
    r"decoder.fusions.0.resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
        f"fusion_stage.final.residual_layer{match.group(1)}.convolution{(int(match.group(2))+1)//2}.{match.group(3)}"
    ),
    r"decoder.fusions.([1234]).out_conv.(weight|bias)": lambda match: (
        f"fusion_stage.intermediate.{4-int(match.group(1))}.projection.{match.group(2)}"
    ),
    r"decoder.fusions.0.out_conv.(weight|bias)": lambda match: (
        f"fusion_stage.final.projection.{match.group(1)}"
    ),
    r"decoder.fusions.(\d+).deconv.(weight|bias)": lambda match: (
        f"fusion_stage.intermediate.{4-int(match.group(1))}.deconv.{match.group(2)}"
    ),
}
# fmt: on


def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
    output_dict = {}
    if state_dict_keys is not None:
        old_text = "\n".join(state_dict_keys)
        new_text = old_text
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
            if replacement is None:
                new_text = re.sub(pattern, "", new_text)  # an empty line
                continue
            new_text = re.sub(pattern, replacement, new_text)
        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
    return output_dict


def get_qkv_state_dict(key, parameter):
    """
    new key which looks like this
        xxxx.(q|k|v).xxx    (m, n)

    is converted to
        xxxx.q.xxxx         (m//3, n)
        xxxx.k.xxxx         (m//3, n)
        xxxx.v.xxxx         (m//3, n)
    """
    qkv_state_dict = {}
    placeholder = re.search(r"(\(.*?\))", key).group(1)  # finds "(query|key|value)"
    replacements_keys = placeholder[1:-1].split("|")  # creates ['query', 'key', 'value']
    replacements_vals = torch.split(
        parameter, split_size_or_sections=parameter.size(0) // len(replacements_keys), dim=0
    )
    for replacement_key, replacement_val in zip(replacements_keys, replacements_vals):
        qkv_state_dict[key.replace(placeholder, replacement_key)] = replacement_val
    return qkv_state_dict
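

# Editor's note: a minimal illustration of the qkv split performed by `get_qkv_state_dict`
# above. The key and tensor shape below are made up for the example; only the placeholder
# syntax "(query|key|value)" produced by the key mapping matters.
#
#     key = "depth_pro.encoder.patch_encoder.model.encoder.layer.0.attention.attention.(query|key|value).weight"
#     fused = torch.randn(3 * 1024, 1024)            # fused qkv projection from the original checkpoint
#     split = get_qkv_state_dict(key, fused)
#     sorted(split)                                   # ...attention.key.weight, ...query.weight, ...value.weight
#     {k: tuple(v.shape) for k, v in split.items()}   # each chunk is (1024, 1024)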


def write_model(
    hf_repo_id: str,
    output_dir: str,
    safe_serialization: bool = True,
):
    os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------------------------------
    # Create and save config
    # ------------------------------------------------------------

    # create config
    backbone_config = {
        "model_type": "dinov2",
        "num_hidden_layers": 24,
        "patch_size": 16,
        "hidden_size": 1024,
        "num_attention_heads": 16,
        "image_size": 384,
        "use_mask_token": False,
    }
    config = DepthProConfig(
        # original implementation uses same config for all 3 models
        image_model_config=backbone_config,
        patch_model_config=backbone_config,
        fov_model_config=backbone_config,
        use_fov_model=True,
    )

    # save config
    config.save_pretrained(output_dir)
    print("Model config saved successfully...")

    # ------------------------------------------------------------
    # Convert weights
    # ------------------------------------------------------------

    # download and load state_dict from hf repo
    file_path = hf_hub_download(hf_repo_id, "depth_pro.pt")
    loaded = torch.load(file_path, weights_only=True)

    print("Converting model...")
    all_keys = list(loaded.keys())
    new_keys = convert_old_keys_to_new_keys(all_keys)

    state_dict = {}
    for key in all_keys:
        new_key = new_keys[key]
        current_parameter = loaded.pop(key)

        if "qkv" in key:
            qkv_state_dict = get_qkv_state_dict(new_key, current_parameter)
            state_dict.update(qkv_state_dict)
        else:
            state_dict[new_key] = current_parameter

    print("Loading the checkpoint in a DepthPro model.")
    model = DepthProForDepthEstimation(config)
    model.load_state_dict(state_dict, strict=True, assign=True)
    print("Checkpoint loaded successfully.")

    print("Saving the model.")
    model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    del state_dict, model

    # Safety check: reload the converted model
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    model = DepthProForDepthEstimation.from_pretrained(output_dir, device_map="auto")
    print("Model reloaded successfully.")
    return model


def write_image_processor(output_dir: str):
    image_processor = DepthProImageProcessorFast()
    image_processor.save_pretrained(output_dir)
    return image_processor


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hf_repo_id",
        default="apple/DepthPro",
        help="Location of official weights from apple on HF",
    )
    parser.add_argument(
        "--output_dir",
        default="apple_DepthPro",
        help="Location to write the converted model and processor",
    )
    parser.add_argument(
        "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
    )
    parser.add_argument(
        "--push_to_hub",
        action=argparse.BooleanOptionalAction,
        help="Whether or not to push the converted model to the huggingface hub.",
    )
    parser.add_argument(
        "--hub_repo_id",
        default="apple/DepthPro-hf",
        help="Huggingface hub repo to write the converted model and processor",
    )
    args = parser.parse_args()

    model = write_model(
        hf_repo_id=args.hf_repo_id,
        output_dir=args.output_dir,
        safe_serialization=args.safe_serialization,
    )

    image_processor = write_image_processor(
        output_dir=args.output_dir,
    )

    if args.push_to_hub:
        print("Pushing to hub...")
        model.push_to_hub(args.hub_repo_id)
        image_processor.push_to_hub(args.hub_repo_id)


if __name__ == "__main__":
    main()
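As a usage note, the converter is a plain CLI script; a hypothetical invocation and a check that the converted folder loads back cleanly might look like the sketch below. The output directory matches the script's default `--output_dir`; treat it as a placeholder.

# python convert_depth_pro_weights_to_hf.py --hf_repo_id apple/DepthPro --output_dir apple_DepthPro
from transformers import DepthProForDepthEstimation, DepthProImageProcessorFast

model = DepthProForDepthEstimation.from_pretrained("apple_DepthPro")
processor = DepthProImageProcessorFast.from_pretrained("apple_DepthPro")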
docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro.py
ADDED
@@ -0,0 +1,392 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DepthPro."""

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np

from ...utils.import_utils import requires


if TYPE_CHECKING:
    from .modeling_depth_pro import DepthProDepthEstimatorOutput

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    is_torch_available,
    make_list_of_images,
    pil_torch_interpolation_mapping,
    to_numpy_array,
    valid_images,
)
from ...utils import (
    TensorType,
    filter_out_non_signature_kwargs,
    logging,
    requires_backends,
)


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


@requires(backends=("torchvision", "torch"))
class DepthProImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DepthPro image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 1536, "width": 1536}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized images.
        """
        requires_backends(self, "torch")

        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])

        # we use torch interpolation instead of image.resize because DepthProImageProcessor
        # rescales, then normalizes, which may cause some values to become negative, before resizing the image.
        # image.resize expects all values to be in range [0, 1] or [0, 255] and throws an exception otherwise,
        # however pytorch interpolation works with negative values.
        # relevant issue here: https://github.com/huggingface/transformers/issues/34920
        # input should be (B, C, H, W)
        image_tensor = torch.from_numpy(image).unsqueeze(0)
        resized_image = torch.nn.functional.interpolate(
            input=image_tensor,
            size=output_size,
            mode=pil_torch_interpolation_mapping[resample].value,
        )
        resized_image = resized_image.squeeze(0).numpy()
        return resized_image

    def _validate_input_arguments(
        self,
        do_resize: bool,
        size: Dict[str, int],
        resample: PILImageResampling,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, List[float]],
        image_std: Union[float, List[float]],
        data_format: Union[str, ChannelDimension],
    ):
        if do_resize and None in (size, resample):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and None in (image_mean, image_std):
            raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
                resizing.
            resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
                `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
                an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use if `do_normalize` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        self._validate_input_arguments(
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            data_format=data_format,
        )

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        all_images = []
        for image in images:
            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            # depth-pro rescales and normalizes the image before resizing it
            # uses torch interpolation which requires ChannelDimension.FIRST
            if do_resize:
                image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                image = self.resize(image=image, size=size, resample=resample)
                image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
            else:
                image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

            all_images.append(image)

        data = {"pixel_values": all_images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def post_process_depth_estimation(
        self,
        outputs: "DepthProDepthEstimatorOutput",
        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
    ) -> Dict[str, List[TensorType]]:
        """
        Post-processes the raw depth predictions from the model to generate the final depth predictions,
        which are calibrated using the field of view if provided and resized to the specified target sizes
        if provided.

        Args:
            outputs ([`DepthProDepthEstimatorOutput`]):
                Raw outputs of the model.
            target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
                Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
                or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
                is performed.

        Returns:
            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.

        Raises:
            `ValueError`:
                If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth
        fov = outputs.field_of_view

        batch_size = len(predicted_depth)

        if target_sizes is not None and batch_size != len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
            )

        results = []
        fov = [None] * batch_size if fov is None else fov
        target_sizes = [None] * batch_size if target_sizes is None else target_sizes
        for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
            focal_length = None
            if target_size is not None:
                # scale image w.r.t fov
                if fov_value is not None:
                    width = target_size[1]
                    focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
                    depth = depth * width / focal_length

                # interpolate
                depth = torch.nn.functional.interpolate(
                    # input should be (B, C, H, W)
                    input=depth.unsqueeze(0).unsqueeze(1),
                    size=target_size,
                    mode=pil_torch_interpolation_mapping[self.resample].value,
                ).squeeze()

            # inverse the depth
            depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)

            results.append(
                {
                    "predicted_depth": depth,
                    "field_of_view": fov_value,
                    "focal_length": focal_length,
                }
            )

        return results


__all__ = ["DepthProImageProcessor"]
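A minimal end-to-end sketch of this processor, assuming a local RGB image file (placeholder name) and the converted checkpoint the conversion script above pushes to the hub (`apple/DepthPro-hf`):

import torch
from PIL import Image
from transformers import DepthProForDepthEstimation, DepthProImageProcessor

image = Image.open("example.jpg")  # placeholder path
processor = DepthProImageProcessor()
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# post_process_depth_estimation inverts the raw prediction and, when a field of view is
# predicted, rescales it by the derived focal length (see the method above)
post = processor.post_process_depth_estimation(outputs, target_sizes=[(image.height, image.width)])
depth = post[0]["predicted_depth"]  # (height, width) depth map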
docs/transformers/build/lib/transformers/models/depth_pro/image_processing_depth_pro_fast.py
ADDED
@@ -0,0 +1,189 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for DepthPro."""

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from ...image_processing_base import BatchFeature
from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BaseImageProcessorFast,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    PILImageResampling,
    SizeDict,
)
from ...utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
    requires_backends,
)
from ...utils.import_utils import requires


if TYPE_CHECKING:
    from .modeling_depth_pro import DepthProDepthEstimatorOutput

logger = logging.get_logger(__name__)


if is_torch_available():
    import torch


if is_torchvision_available():
    from ...image_utils import pil_torch_interpolation_mapping

    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


@add_start_docstrings(
    "Constructs a fast DepthPro image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
@requires(backends=("torchvision", "torch"))
class DepthProImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 1536, "width": 1536}
    do_resize = True
    do_rescale = True
    do_normalize = True

    # DepthPro resizes image after rescaling and normalizing,
    # which makes it different from BaseImageProcessorFast._preprocess
    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        # Group images by size for batched scaling
        grouped_images, grouped_images_index = group_images_by_shape(images)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            if do_resize:
                stacked_images = self.resize(
                    image=stacked_images,
                    size=size,
                    interpolation=interpolation,
                    antialias=False,
                )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    # Copied from transformers.models.depth_pro.image_processing_depth_pro.DepthProImageProcessor.post_process_depth_estimation
    def post_process_depth_estimation(
        self,
        outputs: "DepthProDepthEstimatorOutput",
        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
    ) -> Dict[str, List[TensorType]]:
        """
        Post-processes the raw depth predictions from the model to generate the final depth predictions,
        which are calibrated using the field of view if provided and resized to the specified target sizes
        if provided.

        Args:
            outputs ([`DepthProDepthEstimatorOutput`]):
                Raw outputs of the model.
            target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
                Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
                or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
                is performed.

        Returns:
            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.

        Raises:
            `ValueError`:
                If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth
        fov = outputs.field_of_view

        batch_size = len(predicted_depth)

        if target_sizes is not None and batch_size != len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
            )

        results = []
        fov = [None] * batch_size if fov is None else fov
        target_sizes = [None] * batch_size if target_sizes is None else target_sizes
        for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
            focal_length = None
            if target_size is not None:
                # scale image w.r.t fov
                if fov_value is not None:
                    width = target_size[1]
                    focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
                    depth = depth * width / focal_length

                # interpolate
                depth = torch.nn.functional.interpolate(
                    # input should be (B, C, H, W)
                    input=depth.unsqueeze(0).unsqueeze(1),
                    size=target_size,
                    mode=pil_torch_interpolation_mapping[self.resample].value,
                ).squeeze()

            # inverse the depth
            depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)

            results.append(
                {
                    "predicted_depth": depth,
                    "field_of_view": fov_value,
                    "focal_length": focal_length,
                }
            )

        return results


__all__ = ["DepthProImageProcessorFast"]
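The fast variant is intended as a drop-in replacement for the slow processor in the previous sketch; the notable behavioral point, mirrored from `_preprocess` above, is that it stays on torch tensors and resizes after rescaling and normalizing. A minimal sketch, reusing the `image` placeholder from the earlier example:

from transformers import DepthProImageProcessorFast

fast_processor = DepthProImageProcessorFast()
inputs = fast_processor(images=image, return_tensors="pt")  # pixel_values of shape (1, 3, 1536, 1536)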
docs/transformers/build/lib/transformers/models/depth_pro/modeling_depth_pro.py
ADDED
@@ -0,0 +1,1218 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Apple Research Team Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DepthPro model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch import nn
|
| 24 |
+
|
| 25 |
+
from ...modeling_outputs import BaseModelOutput
|
| 26 |
+
from ...modeling_utils import PreTrainedModel
|
| 27 |
+
from ...utils import (
|
| 28 |
+
ModelOutput,
|
| 29 |
+
add_start_docstrings,
|
| 30 |
+
add_start_docstrings_to_model_forward,
|
| 31 |
+
logging,
|
| 32 |
+
replace_return_docstrings,
|
| 33 |
+
torch_int,
|
| 34 |
+
)
|
| 35 |
+
from ..auto import AutoModel
|
| 36 |
+
from .configuration_depth_pro import DepthProConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class DepthProOutput(ModelOutput):
|
| 44 |
+
"""
|
| 45 |
+
Base class for DepthPro's outputs.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`):
|
| 49 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 50 |
+
features (`Union[torch.FloatTensor, List[torch.FloatTensor]]`, *optional*):
|
| 51 |
+
Features from encoders. Can be a single feature or a list of features.
|
| 52 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 53 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 54 |
+
one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
|
| 55 |
+
|
| 56 |
+
Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
|
| 57 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 58 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
|
| 59 |
+
sequence_length)`.
|
| 60 |
+
|
| 61 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 62 |
+
heads.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 66 |
+
features: Union[torch.FloatTensor, List[torch.FloatTensor]] = None
|
| 67 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 68 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class DepthProDepthEstimatorOutput(ModelOutput):
|
| 73 |
+
"""
|
| 74 |
+
Base class for DepthProForDepthEstimation's output.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 78 |
+
Classification (or regression if config.num_labels==1) loss.
|
| 79 |
+
predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
|
| 80 |
+
Predicted depth for each pixel.
|
| 81 |
+
field_of_view (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned when `use_fov_model` is provided):
|
| 82 |
+
Field of View Scaler.
|
| 83 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 84 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 85 |
+
one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
|
| 86 |
+
|
| 87 |
+
Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
|
| 88 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 89 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
|
| 90 |
+
sequence_length)`.
|
| 91 |
+
|
| 92 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 93 |
+
heads.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
loss: Optional[torch.FloatTensor] = None
|
| 97 |
+
predicted_depth: Optional[torch.FloatTensor] = None
|
| 98 |
+
field_of_view: Optional[torch.FloatTensor] = None
|
| 99 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 100 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def split_to_patches(pixel_values: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
    """Creates patches from a batch of images."""
|
| 105 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 106 |
+
|
| 107 |
+
if height == width == patch_size:
|
| 108 |
+
        # no need to create patches if the scaled image already matches the patch size
|
| 109 |
+
return pixel_values
|
| 110 |
+
|
| 111 |
+
stride = torch_int(patch_size * (1 - overlap_ratio))
|
| 112 |
+
|
| 113 |
+
patches = F.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
|
| 114 |
+
patches = patches.permute(2, 0, 1)
|
| 115 |
+
patches = patches.reshape(-1, num_channels, patch_size, patch_size)
|
| 116 |
+
|
| 117 |
+
return patches
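# Illustrative sketch (not part of the original code): with a hypothetical patch_size of 384 and an
# overlap_ratio of 0.25, the stride is 384 * (1 - 0.25) = 288, so a 1536x1536 image is tiled into a
# 5x5 grid of overlapping patches.
#
#     >>> import torch
#     >>> pixel_values = torch.randn(1, 3, 1536, 1536)
#     >>> split_to_patches(pixel_values, patch_size=384, overlap_ratio=0.25).shape
#     torch.Size([25, 3, 384, 384])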
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def reshape_features(hidden_states: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
"""Discard class token and reshape 1D feature map to a 2D grid."""
|
| 122 |
+
n_samples, seq_len, hidden_size = hidden_states.shape
|
| 123 |
+
size = torch_int(seq_len**0.5)
|
| 124 |
+
|
| 125 |
+
hidden_states = hidden_states[:, -(size**2) :, :] # remove special tokens if there are any
|
| 126 |
+
hidden_states = hidden_states.reshape(n_samples, size, size, hidden_size)
|
| 127 |
+
hidden_states = hidden_states.permute(0, 3, 1, 2)
|
| 128 |
+
|
| 129 |
+
return hidden_states
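# Illustrative sketch (assumed ViT-style shapes): a hidden state with a leading class token, e.g.
# (25, 577, 1024), keeps its last 24 * 24 = 576 tokens and is reshaped into a channels-first grid.
#
#     >>> import torch
#     >>> hidden_states = torch.randn(25, 577, 1024)
#     >>> reshape_features(hidden_states).shape
#     torch.Size([25, 1024, 24, 24])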
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def merge_patches(patches: torch.Tensor, batch_size: int, padding: int) -> torch.Tensor:
|
| 133 |
+
"""Merges smaller patches into image-like feature map."""
|
| 134 |
+
n_patches, hidden_size, out_size, out_size = patches.shape
|
| 135 |
+
n_patches_per_batch = n_patches // batch_size
|
| 136 |
+
sqrt_n_patches_per_batch = torch_int(n_patches_per_batch**0.5)
|
| 137 |
+
new_out_size = sqrt_n_patches_per_batch * out_size
|
| 138 |
+
|
| 139 |
+
if n_patches == batch_size:
|
| 140 |
+
# merge only if the patches were created from scaled image
|
| 141 |
+
# patches are not created when scaled image size is equal to patch size
|
| 142 |
+
return patches
|
| 143 |
+
|
| 144 |
+
if n_patches_per_batch < 4:
|
| 145 |
+
        # for each batch, at least 4 small patches are required to
|
| 146 |
+
# recreate a large square patch from merging them and later padding is applied
|
| 147 |
+
# 3 x (8x8) patches becomes 1 x ( 8x8 ) patch (extra patch ignored, no padding)
|
| 148 |
+
# 4 x (8x8) patches becomes 1 x (16x16) patch (padding later)
|
| 149 |
+
# 5 x (8x8) patches becomes 1 x (16x16) patch (extra patch ignored, padding later)
|
| 150 |
+
# 9 x (8x8) patches becomes 1 x (24x24) patch (padding later)
|
| 151 |
+
# thus the following code only rearranges the patches and removes extra ones
|
| 152 |
+
padding = 0
|
| 153 |
+
|
| 154 |
+
# make sure padding is not large enough to remove more than half of the patch
|
| 155 |
+
padding = min(out_size // 4, padding)
|
| 156 |
+
|
| 157 |
+
if padding == 0:
|
| 158 |
+
# faster when no padding is required
|
| 159 |
+
merged = patches.reshape(n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
|
| 160 |
+
merged = merged.permute(1, 2, 0, 3, 4)
|
| 161 |
+
merged = merged[:, :, : sqrt_n_patches_per_batch**2, :, :]
|
| 162 |
+
merged = merged.reshape(
|
| 163 |
+
batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size
|
| 164 |
+
)
|
| 165 |
+
merged = merged.permute(0, 1, 2, 4, 3, 5)
|
| 166 |
+
merged = merged.reshape(batch_size, hidden_size, new_out_size, new_out_size)
|
| 167 |
+
else:
|
| 168 |
+
# padding example:
|
| 169 |
+
# let out_size = 8, new_out_size = 32, padding = 2
|
| 170 |
+
# each patch is separated by "|"
|
| 171 |
+
# and padding is applied to the merging edges of each patch
|
| 172 |
+
# 00 01 02 03 04 05 06 07 | 08 09 10 11 12 13 14 15 | 16 17 18 19 20 21 22 23 | 24 25 26 27 28 29 30 31
|
| 173 |
+
# 00 01 02 03 04 05 -- -- | -- -- 10 11 12 13 -- -- | -- -- 18 19 20 21 -- -- | -- -- 26 27 28 29 30 31
|
| 174 |
+
i = 0
|
| 175 |
+
boxes = []
|
| 176 |
+
for h in range(sqrt_n_patches_per_batch):
|
| 177 |
+
boxes_in_row = []
|
| 178 |
+
for w in range(sqrt_n_patches_per_batch):
|
| 179 |
+
box = patches[batch_size * i : batch_size * (i + 1)]
|
| 180 |
+
|
| 181 |
+
# collect paddings
|
| 182 |
+
paddings = [0, 0, 0, 0]
|
| 183 |
+
if h != 0:
|
| 184 |
+
# remove pad from height if box is not at top border
|
| 185 |
+
paddings[0] = padding
|
| 186 |
+
if w != 0:
|
| 187 |
+
# remove pad from width if box is not at left border
|
| 188 |
+
paddings[2] = padding
|
| 189 |
+
if h != sqrt_n_patches_per_batch - 1:
|
| 190 |
+
# remove pad from height if box is not at bottom border
|
| 191 |
+
paddings[1] = padding
|
| 192 |
+
if w != sqrt_n_patches_per_batch - 1:
|
| 193 |
+
# remove pad from width if box is not at right border
|
| 194 |
+
paddings[3] = padding
|
| 195 |
+
|
| 196 |
+
# remove paddings
|
| 197 |
+
_, _, box_h, box_w = box.shape
|
| 198 |
+
pad_top, pad_bottom, pad_left, pad_right = paddings
|
| 199 |
+
box = box[:, :, pad_top : box_h - pad_bottom, pad_left : box_w - pad_right]
|
| 200 |
+
|
| 201 |
+
boxes_in_row.append(box)
|
| 202 |
+
i += 1
|
| 203 |
+
boxes_in_row = torch.cat(boxes_in_row, dim=-1)
|
| 204 |
+
boxes.append(boxes_in_row)
|
| 205 |
+
merged = torch.cat(boxes, dim=-2)
|
| 206 |
+
|
| 207 |
+
return merged
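# Illustrative sketch: without padding, 25 patches of spatial size 24x24 belonging to a single image
# (batch_size=1) are rearranged into a 5x5 grid, i.e. one 120x120 feature map per image.
#
#     >>> import torch
#     >>> patches = torch.randn(25, 1024, 24, 24)
#     >>> merge_patches(patches, batch_size=1, padding=0).shape
#     torch.Size([1, 1024, 120, 120])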
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def reconstruct_feature_maps(
|
| 211 |
+
hidden_state: torch.Tensor, batch_size: int, padding: int, output_size: Tuple[float, float]
|
| 212 |
+
) -> torch.Tensor:
|
| 213 |
+
"""
|
| 214 |
+
    Reconstructs feature maps from the hidden state produced by any of the encoders. Converts the hidden state of shape
|
| 215 |
+
`(n_patches_per_batch * batch_size, seq_len, hidden_size)` to feature maps of shape
|
| 216 |
+
`(batch_size, hidden_size, output_size[0], output_size[1])`.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
hidden_state (torch.Tensor): Input tensor of shape `(n_patches_per_batch * batch_size, seq_len, hidden_size)`
|
| 220 |
+
representing the encoded patches.
|
| 221 |
+
batch_size (int): The number of samples in a batch.
|
| 222 |
+
padding (int): The amount of padding to be removed when merging patches.
|
| 223 |
+
output_size (Tuple[float, float]): The desired output size for the feature maps, specified as `(height, width)`.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
torch.Tensor: Reconstructed feature maps of shape `(batch_size, hidden_size, output_size[0], output_size[1])`.
|
| 227 |
+
"""
|
| 228 |
+
# reshape back to image like
|
| 229 |
+
features = reshape_features(hidden_state)
|
| 230 |
+
|
| 231 |
+
# merge all patches in a batch to create one large patch per batch
|
| 232 |
+
features = merge_patches(
|
| 233 |
+
features,
|
| 234 |
+
batch_size=batch_size,
|
| 235 |
+
padding=padding,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# interpolate patches to base size
|
| 239 |
+
features = F.interpolate(
|
| 240 |
+
features,
|
| 241 |
+
size=output_size,
|
| 242 |
+
mode="bilinear",
|
| 243 |
+
align_corners=False,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return features
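# Illustrative sketch chaining the helpers above: 25 encoded patches from one image are reshaped,
# merged into a 5x5 grid, and finally interpolated to the requested output size.
#
#     >>> import torch
#     >>> hidden_state = torch.randn(25, 577, 1024)
#     >>> reconstruct_feature_maps(hidden_state, batch_size=1, padding=0, output_size=(96, 96)).shape
#     torch.Size([1, 1024, 96, 96])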
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class DepthProPatchEncoder(nn.Module):
|
| 250 |
+
def __init__(self, config: DepthProConfig):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.config = config
|
| 253 |
+
|
| 254 |
+
self.intermediate_hook_ids = config.intermediate_hook_ids
|
| 255 |
+
self.intermediate_feature_dims = config.intermediate_feature_dims
|
| 256 |
+
self.scaled_images_ratios = config.scaled_images_ratios
|
| 257 |
+
self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
|
| 258 |
+
self.scaled_images_feature_dims = config.scaled_images_feature_dims
|
| 259 |
+
self.merge_padding_value = config.merge_padding_value
|
| 260 |
+
|
| 261 |
+
self.n_scaled_images = len(config.scaled_images_ratios)
|
| 262 |
+
self.n_intermediate_hooks = len(config.intermediate_hook_ids)
|
| 263 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 264 |
+
|
| 265 |
+
self.model = AutoModel.from_config(config.patch_model_config)
|
| 266 |
+
|
| 267 |
+
def forward(
|
| 268 |
+
self,
|
| 269 |
+
pixel_values: torch.Tensor,
|
| 270 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 271 |
+
) -> List[torch.Tensor]:
|
| 272 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 273 |
+
|
| 274 |
+
if min(self.scaled_images_ratios) * min(height, width) < self.config.patch_size:
|
| 275 |
+
raise ValueError(
|
| 276 |
+
f"Image size {height}x{width} is too small to be scaled "
|
| 277 |
+
f"with scaled_images_ratios={self.scaled_images_ratios} "
|
| 278 |
+
f"when patch_size={self.config.patch_size}."
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# STEP 1: create 3-level image
|
| 282 |
+
|
| 283 |
+
scaled_images = []
|
| 284 |
+
for ratio in self.scaled_images_ratios:
|
| 285 |
+
scaled_images.append(
|
| 286 |
+
F.interpolate(
|
| 287 |
+
pixel_values,
|
| 288 |
+
scale_factor=ratio,
|
| 289 |
+
mode="bilinear",
|
| 290 |
+
align_corners=False,
|
| 291 |
+
)
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# STEP 2: create patches
|
| 295 |
+
|
| 296 |
+
for i in range(self.n_scaled_images):
|
| 297 |
+
scaled_images[i] = split_to_patches(
|
| 298 |
+
scaled_images[i],
|
| 299 |
+
patch_size=self.config.patch_size,
|
| 300 |
+
overlap_ratio=self.scaled_images_overlap_ratios[i],
|
| 301 |
+
)
|
| 302 |
+
n_patches_per_scaled_image = [len(i) for i in scaled_images]
|
| 303 |
+
patches = torch.cat(scaled_images[::-1], dim=0) # -1 as patch encoder expects high res patches first
|
| 304 |
+
|
| 305 |
+
# STEP 3: apply patch encoder
|
| 306 |
+
|
| 307 |
+
encodings = self.model(
|
| 308 |
+
# each patch is processed as a separate batch
|
| 309 |
+
patches,
|
| 310 |
+
head_mask=head_mask,
|
| 311 |
+
# required for intermediate features
|
| 312 |
+
output_hidden_states=self.n_intermediate_hooks > 0,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
scaled_images_last_hidden_state = torch.split_with_sizes(encodings[0], n_patches_per_scaled_image[::-1])
|
| 316 |
+
# -1 (reverse list) as patch encoder returns high res patches first, we need low res first
|
| 317 |
+
scaled_images_last_hidden_state = scaled_images_last_hidden_state[::-1]
|
| 318 |
+
|
| 319 |
+
# calculate base height and width
|
| 320 |
+
# base height and width are the dimensions of the lowest resolution features
|
| 321 |
+
exponent_value = torch_int(math.log2(width / self.out_size))
|
| 322 |
+
base_height = height // 2**exponent_value
|
| 323 |
+
base_width = width // 2**exponent_value
|
| 324 |
+
|
| 325 |
+
# STEP 4: get patch features (high_res, med_res, low_res) - (3-5) in diagram
|
| 326 |
+
|
| 327 |
+
scaled_images_features = []
|
| 328 |
+
for i in range(self.n_scaled_images):
|
| 329 |
+
hidden_state = scaled_images_last_hidden_state[i]
|
| 330 |
+
batch_size = batch_size
|
| 331 |
+
padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[i]))
|
| 332 |
+
output_height = base_height * 2**i
|
| 333 |
+
output_width = base_width * 2**i
|
| 334 |
+
features = reconstruct_feature_maps(
|
| 335 |
+
hidden_state,
|
| 336 |
+
batch_size=batch_size,
|
| 337 |
+
padding=padding,
|
| 338 |
+
output_size=(output_height, output_width),
|
| 339 |
+
)
|
| 340 |
+
scaled_images_features.append(features)
|
| 341 |
+
|
| 342 |
+
# STEP 5: get intermediate features - (1-2) in diagram
|
| 343 |
+
|
| 344 |
+
intermediate_features = []
|
| 345 |
+
for i in range(self.n_intermediate_hooks):
|
| 346 |
+
# +1 to correct index position as hidden_states contain embedding output as well
|
| 347 |
+
hidden_state = encodings[2][self.intermediate_hook_ids[i] + 1]
|
| 348 |
+
padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[-1]))
|
| 349 |
+
output_height = base_height * 2 ** (self.n_scaled_images - 1)
|
| 350 |
+
output_width = base_width * 2 ** (self.n_scaled_images - 1)
|
| 351 |
+
features = reconstruct_feature_maps(
|
| 352 |
+
hidden_state,
|
| 353 |
+
batch_size=batch_size,
|
| 354 |
+
padding=padding,
|
| 355 |
+
output_size=(output_height, output_width),
|
| 356 |
+
)
|
| 357 |
+
intermediate_features.append(features)
|
| 358 |
+
|
| 359 |
+
# STEP 7: combine all features
|
| 360 |
+
features = [*scaled_images_features, *intermediate_features]
|
| 361 |
+
|
| 362 |
+
return features
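# Illustrative note (shapes depend on the configuration; `config` and `pixel_values` below are
# placeholders): the returned list holds one feature map per scaled image followed by one per
# intermediate hook, with spatial size doubling from the lowest image scale upwards.
#
#     >>> encoder = DepthProPatchEncoder(config)  # doctest: +SKIP
#     >>> features = encoder(pixel_values)  # doctest: +SKIP
#     >>> len(features) == len(config.scaled_images_ratios) + len(config.intermediate_hook_ids)  # doctest: +SKIP
#     True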
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class DepthProImageEncoder(nn.Module):
|
| 366 |
+
def __init__(self, config: DepthProConfig):
|
| 367 |
+
super().__init__()
|
| 368 |
+
self.config = config
|
| 369 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 370 |
+
|
| 371 |
+
self.model = AutoModel.from_config(config.image_model_config)
|
| 372 |
+
|
| 373 |
+
def forward(
|
| 374 |
+
self,
|
| 375 |
+
pixel_values: torch.Tensor,
|
| 376 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 377 |
+
output_attentions: bool = False,
|
| 378 |
+
output_hidden_states: bool = False,
|
| 379 |
+
return_dict: bool = True,
|
| 380 |
+
) -> Union[tuple, DepthProOutput]:
|
| 381 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 382 |
+
|
| 383 |
+
# scale the image for image_encoder
|
| 384 |
+
size = self.config.image_model_config.image_size
|
| 385 |
+
pixel_values = F.interpolate(
|
| 386 |
+
pixel_values,
|
| 387 |
+
size=(size, size),
|
| 388 |
+
mode="bilinear",
|
| 389 |
+
align_corners=False,
|
| 390 |
+
)
|
| 391 |
+
encodings = self.model(
|
| 392 |
+
pixel_values=pixel_values,
|
| 393 |
+
head_mask=head_mask,
|
| 394 |
+
output_attentions=output_attentions,
|
| 395 |
+
output_hidden_states=output_hidden_states,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# calculate base height and width
|
| 399 |
+
# base height and width are the dimensions of the lowest resolution features
|
| 400 |
+
exponent_value = torch_int(math.log2(width / self.out_size))
|
| 401 |
+
base_height = height // 2**exponent_value
|
| 402 |
+
base_width = width // 2**exponent_value
|
| 403 |
+
|
| 404 |
+
features = reconstruct_feature_maps(
|
| 405 |
+
encodings[0],
|
| 406 |
+
batch_size=batch_size,
|
| 407 |
+
padding=0,
|
| 408 |
+
output_size=(base_height, base_width),
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
if not return_dict:
|
| 412 |
+
            return (encodings[0], features) + encodings[2:]  # encodings[2:] skips the last_hidden_state and pooler output
|
| 413 |
+
|
| 414 |
+
return DepthProOutput(
|
| 415 |
+
last_hidden_state=encodings.last_hidden_state,
|
| 416 |
+
features=features,
|
| 417 |
+
hidden_states=encodings.hidden_states,
|
| 418 |
+
attentions=encodings.attentions,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class DepthProEncoder(nn.Module):
|
| 423 |
+
def __init__(self, config: DepthProConfig):
|
| 424 |
+
super().__init__()
|
| 425 |
+
self.config = config
|
| 426 |
+
self.intermediate_hook_ids = config.intermediate_hook_ids
|
| 427 |
+
self.intermediate_feature_dims = config.intermediate_feature_dims
|
| 428 |
+
self.scaled_images_ratios = config.scaled_images_ratios
|
| 429 |
+
self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
|
| 430 |
+
self.scaled_images_feature_dims = config.scaled_images_feature_dims
|
| 431 |
+
self.merge_padding_value = config.merge_padding_value
|
| 432 |
+
|
| 433 |
+
self.n_scaled_images = len(self.scaled_images_ratios)
|
| 434 |
+
self.n_intermediate_hooks = len(self.intermediate_hook_ids)
|
| 435 |
+
|
| 436 |
+
self.patch_encoder = DepthProPatchEncoder(config)
|
| 437 |
+
self.image_encoder = DepthProImageEncoder(config)
|
| 438 |
+
|
| 439 |
+
def forward(
|
| 440 |
+
self,
|
| 441 |
+
pixel_values: torch.Tensor,
|
| 442 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 443 |
+
output_attentions: bool = False,
|
| 444 |
+
output_hidden_states: bool = False,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
) -> Union[tuple, DepthProOutput]:
|
| 447 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 448 |
+
|
| 449 |
+
patch_features = self.patch_encoder(
|
| 450 |
+
pixel_values,
|
| 451 |
+
head_mask=head_mask,
|
| 452 |
+
)
|
| 453 |
+
image_encodings = self.image_encoder(
|
| 454 |
+
pixel_values,
|
| 455 |
+
head_mask=head_mask,
|
| 456 |
+
output_attentions=output_attentions,
|
| 457 |
+
output_hidden_states=output_hidden_states,
|
| 458 |
+
return_dict=return_dict,
|
| 459 |
+
)
|
| 460 |
+
image_features = image_encodings[1] # index 1 contains features
|
| 461 |
+
|
| 462 |
+
features = [image_features, *patch_features]
|
| 463 |
+
|
| 464 |
+
if not return_dict:
|
| 465 |
+
return (image_encodings[0], features) + image_encodings[2:]
|
| 466 |
+
|
| 467 |
+
return DepthProOutput(
|
| 468 |
+
last_hidden_state=image_encodings.last_hidden_state,
|
| 469 |
+
features=features,
|
| 470 |
+
hidden_states=image_encodings.hidden_states,
|
| 471 |
+
attentions=image_encodings.attentions,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class DepthProFeatureUpsampleBlock(nn.Module):
|
| 476 |
+
def __init__(
|
| 477 |
+
self,
|
| 478 |
+
config: DepthProConfig,
|
| 479 |
+
input_dims: int,
|
| 480 |
+
intermediate_dims: int,
|
| 481 |
+
output_dims: int,
|
| 482 |
+
n_upsample_layers: int,
|
| 483 |
+
use_proj: bool = True,
|
| 484 |
+
bias: bool = False,
|
| 485 |
+
):
|
| 486 |
+
super().__init__()
|
| 487 |
+
self.config = config
|
| 488 |
+
self.layers = nn.ModuleList()
|
| 489 |
+
|
| 490 |
+
# create first projection layer
|
| 491 |
+
if use_proj:
|
| 492 |
+
proj = nn.Conv2d(
|
| 493 |
+
in_channels=input_dims,
|
| 494 |
+
out_channels=intermediate_dims,
|
| 495 |
+
kernel_size=1,
|
| 496 |
+
stride=1,
|
| 497 |
+
padding=0,
|
| 498 |
+
bias=bias,
|
| 499 |
+
)
|
| 500 |
+
self.layers.append(proj)
|
| 501 |
+
|
| 502 |
+
# create following upsample layers
|
| 503 |
+
for i in range(n_upsample_layers):
|
| 504 |
+
in_channels = intermediate_dims if i == 0 else output_dims
|
| 505 |
+
layer = nn.ConvTranspose2d(
|
| 506 |
+
in_channels=in_channels,
|
| 507 |
+
out_channels=output_dims,
|
| 508 |
+
kernel_size=2,
|
| 509 |
+
stride=2,
|
| 510 |
+
padding=0,
|
| 511 |
+
bias=bias,
|
| 512 |
+
)
|
| 513 |
+
self.layers.append(layer)
|
| 514 |
+
|
| 515 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 516 |
+
for layer in self.layers:
|
| 517 |
+
features = layer(features)
|
| 518 |
+
return features
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class DepthProFeatureUpsample(nn.Module):
|
| 522 |
+
def __init__(self, config: DepthProConfig):
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.config = config
|
| 525 |
+
self.n_scaled_images = len(self.config.scaled_images_ratios)
|
| 526 |
+
self.n_intermediate_hooks = len(self.config.intermediate_hook_ids)
|
| 527 |
+
|
| 528 |
+
# for image_features
|
| 529 |
+
self.image_block = DepthProFeatureUpsampleBlock(
|
| 530 |
+
config=config,
|
| 531 |
+
input_dims=config.image_model_config.hidden_size,
|
| 532 |
+
intermediate_dims=config.image_model_config.hidden_size,
|
| 533 |
+
output_dims=config.scaled_images_feature_dims[0],
|
| 534 |
+
n_upsample_layers=1,
|
| 535 |
+
use_proj=False,
|
| 536 |
+
bias=True,
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
# for scaled_images_features
|
| 540 |
+
self.scaled_images = nn.ModuleList()
|
| 541 |
+
for i, feature_dims in enumerate(config.scaled_images_feature_dims):
|
| 542 |
+
block = DepthProFeatureUpsampleBlock(
|
| 543 |
+
config=config,
|
| 544 |
+
input_dims=config.patch_model_config.hidden_size,
|
| 545 |
+
intermediate_dims=feature_dims,
|
| 546 |
+
output_dims=feature_dims,
|
| 547 |
+
n_upsample_layers=1,
|
| 548 |
+
)
|
| 549 |
+
self.scaled_images.append(block)
|
| 550 |
+
|
| 551 |
+
# for intermediate_features
|
| 552 |
+
self.intermediate = nn.ModuleList()
|
| 553 |
+
for i, feature_dims in enumerate(config.intermediate_feature_dims):
|
| 554 |
+
intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims
|
| 555 |
+
block = DepthProFeatureUpsampleBlock(
|
| 556 |
+
config=config,
|
| 557 |
+
input_dims=config.patch_model_config.hidden_size,
|
| 558 |
+
intermediate_dims=intermediate_dims,
|
| 559 |
+
output_dims=feature_dims,
|
| 560 |
+
n_upsample_layers=2 + i,
|
| 561 |
+
)
|
| 562 |
+
self.intermediate.append(block)
|
| 563 |
+
|
| 564 |
+
def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 565 |
+
features[0] = self.image_block(features[0])
|
| 566 |
+
|
| 567 |
+
for i in range(self.n_scaled_images):
|
| 568 |
+
features[i + 1] = self.scaled_images[i](features[i + 1])
|
| 569 |
+
|
| 570 |
+
for i in range(self.n_intermediate_hooks):
|
| 571 |
+
features[self.n_scaled_images + i + 1] = self.intermediate[i](features[self.n_scaled_images + i + 1])
|
| 572 |
+
|
| 573 |
+
return features
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
class DepthProFeatureProjection(nn.Module):
|
| 577 |
+
def __init__(self, config: DepthProConfig):
|
| 578 |
+
super().__init__()
|
| 579 |
+
self.config = config
|
| 580 |
+
|
| 581 |
+
combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
|
| 582 |
+
self.projections = nn.ModuleList()
|
| 583 |
+
for i, in_channels in enumerate(combined_feature_dims):
|
| 584 |
+
if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
|
| 585 |
+
# projection for last layer can be ignored if input and output channels already match
|
| 586 |
+
self.projections.append(nn.Identity())
|
| 587 |
+
else:
|
| 588 |
+
self.projections.append(
|
| 589 |
+
nn.Conv2d(
|
| 590 |
+
in_channels=in_channels,
|
| 591 |
+
out_channels=config.fusion_hidden_size,
|
| 592 |
+
kernel_size=3,
|
| 593 |
+
stride=1,
|
| 594 |
+
padding=1,
|
| 595 |
+
bias=False,
|
| 596 |
+
)
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 600 |
+
projected_features = []
|
| 601 |
+
for i, projection in enumerate(self.projections):
|
| 602 |
+
upsampled_feature = projection(features[i])
|
| 603 |
+
projected_features.append(upsampled_feature)
|
| 604 |
+
return projected_features
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class DepthProNeck(nn.Module):
|
| 608 |
+
def __init__(self, config: DepthProConfig):
|
| 609 |
+
super().__init__()
|
| 610 |
+
self.config = config
|
| 611 |
+
|
| 612 |
+
self.feature_upsample = DepthProFeatureUpsample(config)
|
| 613 |
+
self.fuse_image_with_low_res = nn.Conv2d(
|
| 614 |
+
in_channels=config.scaled_images_feature_dims[0] * 2,
|
| 615 |
+
out_channels=config.scaled_images_feature_dims[0],
|
| 616 |
+
kernel_size=1,
|
| 617 |
+
stride=1,
|
| 618 |
+
padding=0,
|
| 619 |
+
bias=True,
|
| 620 |
+
)
|
| 621 |
+
self.feature_projection = DepthProFeatureProjection(config)
|
| 622 |
+
|
| 623 |
+
def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 624 |
+
features = self.feature_upsample(features)
|
| 625 |
+
# global features = low res features + image features
|
| 626 |
+
global_features = torch.cat((features[1], features[0]), dim=1)
|
| 627 |
+
global_features = self.fuse_image_with_low_res(global_features)
|
| 628 |
+
features = [global_features, *features[2:]]
|
| 629 |
+
features = self.feature_projection(features)
|
| 630 |
+
return features
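# Illustrative note (`config` and `features` are placeholders): the neck fuses the image features
# (features[0]) with the lowest-resolution patch features (features[1]) into a single global map, so
# the output list is one entry shorter than the input and, after projection, every entry has
# `config.fusion_hidden_size` channels.
#
#     >>> fused = DepthProNeck(config)(features)  # doctest: +SKIP
#     >>> len(fused) == len(features) - 1  # doctest: +SKIP
#     True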
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
# General docstring
|
| 634 |
+
_CONFIG_FOR_DOC = "DepthProConfig"
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
DEPTH_PRO_START_DOCSTRING = r"""
|
| 638 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 639 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 640 |
+
behavior.
|
| 641 |
+
|
| 642 |
+
Parameters:
|
| 643 |
+
config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
|
| 644 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 645 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 646 |
+
"""
|
| 647 |
+
|
| 648 |
+
DEPTH_PRO_INPUTS_DOCSTRING = r"""
|
| 649 |
+
Args:
|
| 650 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 651 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
|
| 652 |
+
for details.
|
| 653 |
+
|
| 654 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 655 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 656 |
+
|
| 657 |
+
- 1 indicates the head is **not masked**,
|
| 658 |
+
- 0 indicates the head is **masked**.
|
| 659 |
+
|
| 660 |
+
output_attentions (`bool`, *optional*):
|
| 661 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 662 |
+
tensors for more detail.
|
| 663 |
+
output_hidden_states (`bool`, *optional*):
|
| 664 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 665 |
+
more detail.
|
| 666 |
+
return_dict (`bool`, *optional*):
|
| 667 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
| 668 |
+
"""
|
| 669 |
+
|
| 670 |
+
DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING = r"""
|
| 671 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 672 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 673 |
+
behavior.
|
| 674 |
+
|
| 675 |
+
Parameters:
|
| 676 |
+
config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
|
| 677 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 678 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 679 |
+
use_fov_model (`bool`, *optional*, defaults to `True`):
|
| 680 |
+
Whether to use `DepthProFovModel` to generate the field of view.
|
| 681 |
+
"""
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class DepthProPreTrainedModel(PreTrainedModel):
|
| 685 |
+
"""
|
| 686 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 687 |
+
models.
|
| 688 |
+
"""
|
| 689 |
+
|
| 690 |
+
config_class = DepthProConfig
|
| 691 |
+
base_model_prefix = "depth_pro"
|
| 692 |
+
main_input_name = "pixel_values"
|
| 693 |
+
supports_gradient_checkpointing = True
|
| 694 |
+
_supports_sdpa = True
|
| 695 |
+
_no_split_modules = ["DepthProPreActResidualLayer"]
|
| 696 |
+
_keys_to_ignore_on_load_unexpected = ["fov_model.*"]
|
| 697 |
+
|
| 698 |
+
def _init_weights(self, module):
|
| 699 |
+
"""Initialize the weights"""
|
| 700 |
+
if isinstance(module, nn.Linear):
|
| 701 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 702 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 703 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 704 |
+
if module.bias is not None:
|
| 705 |
+
module.bias.data.zero_()
|
| 706 |
+
elif isinstance(module, nn.LayerNorm):
|
| 707 |
+
module.bias.data.zero_()
|
| 708 |
+
module.weight.data.fill_(1.0)
|
| 709 |
+
elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
| 710 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 711 |
+
if module.bias is not None:
|
| 712 |
+
module.bias.data.zero_()
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
@add_start_docstrings(
|
| 716 |
+
"The bare DepthPro Model transformer outputting raw hidden-states without any specific head on top.",
|
| 717 |
+
DEPTH_PRO_START_DOCSTRING,
|
| 718 |
+
)
|
| 719 |
+
class DepthProModel(DepthProPreTrainedModel):
|
| 720 |
+
def __init__(self, config):
|
| 721 |
+
super().__init__(config)
|
| 722 |
+
self.config = config
|
| 723 |
+
self.encoder = DepthProEncoder(config)
|
| 724 |
+
self.neck = DepthProNeck(config)
|
| 725 |
+
# Initialize weights and apply final processing
|
| 726 |
+
self.post_init()
|
| 727 |
+
|
| 728 |
+
def get_input_embeddings(self):
|
| 729 |
+
return self.encoder.image_encoder.model.get_input_embeddings()
|
| 730 |
+
|
| 731 |
+
@add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
|
| 732 |
+
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 733 |
+
def forward(
|
| 734 |
+
self,
|
| 735 |
+
pixel_values: torch.FloatTensor,
|
| 736 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 737 |
+
output_attentions: Optional[bool] = None,
|
| 738 |
+
output_hidden_states: Optional[bool] = None,
|
| 739 |
+
return_dict: Optional[bool] = None,
|
| 740 |
+
) -> Union[Tuple, DepthProOutput]:
|
| 741 |
+
r"""
|
| 742 |
+
Returns:
|
| 743 |
+
|
| 744 |
+
Examples:
|
| 745 |
+
|
| 746 |
+
```python
|
| 747 |
+
>>> import torch
|
| 748 |
+
>>> from PIL import Image
|
| 749 |
+
>>> import requests
|
| 750 |
+
>>> from transformers import AutoProcessor, DepthProModel
|
| 751 |
+
|
| 752 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
| 753 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 754 |
+
|
| 755 |
+
>>> checkpoint = "apple/DepthPro-hf"
|
| 756 |
+
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
| 757 |
+
>>> model = DepthProModel.from_pretrained(checkpoint)
|
| 758 |
+
|
| 759 |
+
>>> # prepare image for the model
|
| 760 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 761 |
+
|
| 762 |
+
>>> with torch.no_grad():
|
| 763 |
+
... output = model(**inputs)
|
| 764 |
+
|
| 765 |
+
>>> output.last_hidden_state.shape
|
| 766 |
+
torch.Size([1, 35, 577, 1024])
|
| 767 |
+
```"""
|
| 768 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 769 |
+
output_hidden_states = (
|
| 770 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 771 |
+
)
|
| 772 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 773 |
+
|
| 774 |
+
encodings = self.encoder(
|
| 775 |
+
pixel_values,
|
| 776 |
+
head_mask=head_mask,
|
| 777 |
+
output_attentions=output_attentions,
|
| 778 |
+
output_hidden_states=output_hidden_states,
|
| 779 |
+
return_dict=return_dict,
|
| 780 |
+
)
|
| 781 |
+
features = encodings[1] # index 1 contains features
|
| 782 |
+
features = self.neck(features)
|
| 783 |
+
|
| 784 |
+
if not return_dict:
|
| 785 |
+
return (encodings[0], features) + encodings[2:]
|
| 786 |
+
|
| 787 |
+
return DepthProOutput(
|
| 788 |
+
last_hidden_state=encodings.last_hidden_state,
|
| 789 |
+
features=features,
|
| 790 |
+
hidden_states=encodings.hidden_states,
|
| 791 |
+
attentions=encodings.attentions,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
# Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer DPT->DepthPro
|
| 796 |
+
class DepthProPreActResidualLayer(nn.Module):
|
| 797 |
+
"""
|
| 798 |
+
ResidualConvUnit, pre-activate residual unit.
|
| 799 |
+
|
| 800 |
+
Args:
|
| 801 |
+
config (`[DepthProConfig]`):
|
| 802 |
+
Model configuration class defining the model architecture.
|
| 803 |
+
"""
|
| 804 |
+
|
| 805 |
+
def __init__(self, config):
|
| 806 |
+
super().__init__()
|
| 807 |
+
|
| 808 |
+
self.use_batch_norm = config.use_batch_norm_in_fusion_residual
|
| 809 |
+
use_bias_in_fusion_residual = (
|
| 810 |
+
config.use_bias_in_fusion_residual
|
| 811 |
+
if config.use_bias_in_fusion_residual is not None
|
| 812 |
+
else not self.use_batch_norm
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
self.activation1 = nn.ReLU()
|
| 816 |
+
self.convolution1 = nn.Conv2d(
|
| 817 |
+
config.fusion_hidden_size,
|
| 818 |
+
config.fusion_hidden_size,
|
| 819 |
+
kernel_size=3,
|
| 820 |
+
stride=1,
|
| 821 |
+
padding=1,
|
| 822 |
+
bias=use_bias_in_fusion_residual,
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
self.activation2 = nn.ReLU()
|
| 826 |
+
self.convolution2 = nn.Conv2d(
|
| 827 |
+
config.fusion_hidden_size,
|
| 828 |
+
config.fusion_hidden_size,
|
| 829 |
+
kernel_size=3,
|
| 830 |
+
stride=1,
|
| 831 |
+
padding=1,
|
| 832 |
+
bias=use_bias_in_fusion_residual,
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
if self.use_batch_norm:
|
| 836 |
+
self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
|
| 837 |
+
self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
|
| 838 |
+
|
| 839 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 840 |
+
residual = hidden_state
|
| 841 |
+
hidden_state = self.activation1(hidden_state)
|
| 842 |
+
|
| 843 |
+
hidden_state = self.convolution1(hidden_state)
|
| 844 |
+
|
| 845 |
+
if self.use_batch_norm:
|
| 846 |
+
hidden_state = self.batch_norm1(hidden_state)
|
| 847 |
+
|
| 848 |
+
hidden_state = self.activation2(hidden_state)
|
| 849 |
+
hidden_state = self.convolution2(hidden_state)
|
| 850 |
+
|
| 851 |
+
if self.use_batch_norm:
|
| 852 |
+
hidden_state = self.batch_norm2(hidden_state)
|
| 853 |
+
|
| 854 |
+
return hidden_state + residual
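# Illustrative note (`config` is a placeholder): the pre-activation residual unit preserves the input
# shape, so it can be stacked freely inside the fusion layers.
#
#     >>> layer = DepthProPreActResidualLayer(config)  # doctest: +SKIP
#     >>> out = layer(torch.randn(1, config.fusion_hidden_size, 24, 24))  # doctest: +SKIP
#     >>> out.shape == (1, config.fusion_hidden_size, 24, 24)  # doctest: +SKIP
#     True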
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
# Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer
|
| 858 |
+
# except it uses deconv and skip_add and needs no interpolation
|
| 859 |
+
class DepthProFeatureFusionLayer(nn.Module):
|
| 860 |
+
def __init__(self, config: DepthProConfig, use_deconv: bool = True):
|
| 861 |
+
super().__init__()
|
| 862 |
+
self.config = config
|
| 863 |
+
self.use_deconv = use_deconv
|
| 864 |
+
|
| 865 |
+
self.residual_layer1 = DepthProPreActResidualLayer(config)
|
| 866 |
+
self.residual_layer2 = DepthProPreActResidualLayer(config)
|
| 867 |
+
|
| 868 |
+
if self.use_deconv:
|
| 869 |
+
self.deconv = nn.ConvTranspose2d(
|
| 870 |
+
in_channels=config.fusion_hidden_size,
|
| 871 |
+
out_channels=config.fusion_hidden_size,
|
| 872 |
+
kernel_size=2,
|
| 873 |
+
stride=2,
|
| 874 |
+
padding=0,
|
| 875 |
+
bias=False,
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
|
| 879 |
+
|
| 880 |
+
def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 881 |
+
if residual is not None:
|
| 882 |
+
residual = self.residual_layer1(residual)
|
| 883 |
+
hidden_state = hidden_state + residual
|
| 884 |
+
|
| 885 |
+
hidden_state = self.residual_layer2(hidden_state)
|
| 886 |
+
if self.use_deconv:
|
| 887 |
+
hidden_state = self.deconv(hidden_state)
|
| 888 |
+
hidden_state = self.projection(hidden_state)
|
| 889 |
+
|
| 890 |
+
return hidden_state
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
# Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
|
| 894 |
+
# with deconv and reversed layers
|
| 895 |
+
class DepthProFeatureFusionStage(nn.Module):
|
| 896 |
+
def __init__(self, config):
|
| 897 |
+
super().__init__()
|
| 898 |
+
self.config = config
|
| 899 |
+
|
| 900 |
+
self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
|
| 901 |
+
self.intermediate = nn.ModuleList()
|
| 902 |
+
for _ in range(self.num_layers - 1):
|
| 903 |
+
self.intermediate.append(DepthProFeatureFusionLayer(config))
|
| 904 |
+
|
| 905 |
+
        # final layer does not require deconvolution
|
| 906 |
+
self.final = DepthProFeatureFusionLayer(config, use_deconv=False)
|
| 907 |
+
|
| 908 |
+
def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 909 |
+
if self.num_layers != len(hidden_states):
|
| 910 |
+
            raise ValueError(
                f"num_layers={self.num_layers} in DepthProFeatureFusionStage "
                f"does not match len(hidden_states)={len(hidden_states)}"
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
fused_hidden_states = []
|
| 916 |
+
fused_hidden_state = None
|
| 917 |
+
for hidden_state, layer in zip(hidden_states[:-1], self.intermediate):
|
| 918 |
+
if fused_hidden_state is None:
|
| 919 |
+
                # the first fusion layer has no previous fused state, so it only uses its own hidden_state
|
| 920 |
+
fused_hidden_state = layer(hidden_state)
|
| 921 |
+
else:
|
| 922 |
+
fused_hidden_state = layer(fused_hidden_state, hidden_state)
|
| 923 |
+
fused_hidden_states.append(fused_hidden_state)
|
| 924 |
+
|
| 925 |
+
hidden_state = hidden_states[-1]
|
| 926 |
+
fused_hidden_state = self.final(fused_hidden_state, hidden_state)
|
| 927 |
+
fused_hidden_states.append(fused_hidden_state)
|
| 928 |
+
|
| 929 |
+
return fused_hidden_states
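# Illustrative note (`config` and `hidden_states` are placeholders): every intermediate fusion layer
# upsamples its output 2x through the transposed convolution, while the final layer keeps the
# resolution; the last fused map is therefore 2 ** (num_layers - 1) times larger than the first input
# map along each spatial dimension.
#
#     >>> fused = DepthProFeatureFusionStage(config)(hidden_states)  # doctest: +SKIP
#     >>> fused[-1].shape[-1] == hidden_states[0].shape[-1] * 2 ** (len(hidden_states) - 1)  # doctest: +SKIP
#     True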
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
class DepthProFovEncoder(nn.Module):
|
| 933 |
+
def __init__(self, config: DepthProConfig):
|
| 934 |
+
super().__init__()
|
| 935 |
+
self.config = config
|
| 936 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 937 |
+
|
| 938 |
+
self.model = AutoModel.from_config(config.fov_model_config)
|
| 939 |
+
self.neck = nn.Linear(config.fov_model_config.hidden_size, config.fusion_hidden_size // 2)
|
| 940 |
+
|
| 941 |
+
def forward(
|
| 942 |
+
self,
|
| 943 |
+
pixel_values: torch.Tensor,
|
| 944 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 945 |
+
) -> torch.Tensor:
|
| 946 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 947 |
+
|
| 948 |
+
# scale the image for fov_encoder
|
| 949 |
+
size = self.config.fov_model_config.image_size
|
| 950 |
+
pixel_values = F.interpolate(
|
| 951 |
+
pixel_values,
|
| 952 |
+
size=(size, size),
|
| 953 |
+
mode="bilinear",
|
| 954 |
+
align_corners=False,
|
| 955 |
+
)
|
| 956 |
+
encodings = self.model(
|
| 957 |
+
pixel_values=pixel_values,
|
| 958 |
+
head_mask=head_mask,
|
| 959 |
+
)
|
| 960 |
+
hidden_state = encodings[0]
|
| 961 |
+
hidden_state = self.neck(hidden_state)
|
| 962 |
+
|
| 963 |
+
# calculate base height and width
|
| 964 |
+
# base height and width are the dimensions of the lowest resolution features
|
| 965 |
+
exponent_value = torch_int(math.log2(width / self.out_size))
|
| 966 |
+
base_height = height // 2**exponent_value
|
| 967 |
+
base_width = width // 2**exponent_value
|
| 968 |
+
|
| 969 |
+
features = reconstruct_feature_maps(
|
| 970 |
+
hidden_state,
|
| 971 |
+
batch_size=batch_size,
|
| 972 |
+
padding=0,
|
| 973 |
+
output_size=(base_height, base_width),
|
| 974 |
+
)
|
| 975 |
+
|
| 976 |
+
return features
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
class DepthProFovHead(nn.Module):
|
| 980 |
+
def __init__(self, config: DepthProConfig):
|
| 981 |
+
super().__init__()
|
| 982 |
+
self.config = config
|
| 983 |
+
self.fusion_hidden_size = config.fusion_hidden_size
|
| 984 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 985 |
+
|
| 986 |
+
# create initial head layers
|
| 987 |
+
self.layers = nn.ModuleList()
|
| 988 |
+
for i in range(config.num_fov_head_layers):
|
| 989 |
+
self.layers.append(
|
| 990 |
+
nn.Conv2d(
|
| 991 |
+
math.ceil(self.fusion_hidden_size / 2 ** (i + 1)),
|
| 992 |
+
math.ceil(self.fusion_hidden_size / 2 ** (i + 2)),
|
| 993 |
+
kernel_size=3,
|
| 994 |
+
stride=2,
|
| 995 |
+
padding=1,
|
| 996 |
+
)
|
| 997 |
+
)
|
| 998 |
+
self.layers.append(nn.ReLU(True))
|
| 999 |
+
# calculate expected shapes to finally generate a scalar output from final head layer
|
| 1000 |
+
final_in_channels = math.ceil(self.fusion_hidden_size / 2 ** (config.num_fov_head_layers + 1))
|
| 1001 |
+
final_kernel_size = torch_int((self.out_size - 1) / 2**config.num_fov_head_layers + 1)
|
| 1002 |
+
self.layers.append(
|
| 1003 |
+
nn.Conv2d(
|
| 1004 |
+
in_channels=final_in_channels, out_channels=1, kernel_size=final_kernel_size, stride=1, padding=0
|
| 1005 |
+
)
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 1009 |
+
features = F.interpolate(
|
| 1010 |
+
features,
|
| 1011 |
+
size=(self.out_size, self.out_size),
|
| 1012 |
+
mode="bilinear",
|
| 1013 |
+
align_corners=False,
|
| 1014 |
+
)
|
| 1015 |
+
for layer in self.layers:
|
| 1016 |
+
features = layer(features)
|
| 1017 |
+
return features
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
class DepthProFovModel(nn.Module):
|
| 1021 |
+
def __init__(self, config: DepthProConfig):
|
| 1022 |
+
super().__init__()
|
| 1023 |
+
self.config = config
|
| 1024 |
+
self.fusion_hidden_size = config.fusion_hidden_size
|
| 1025 |
+
|
| 1026 |
+
self.fov_encoder = DepthProFovEncoder(config)
|
| 1027 |
+
self.conv = nn.Conv2d(
|
| 1028 |
+
self.fusion_hidden_size, self.fusion_hidden_size // 2, kernel_size=3, stride=2, padding=1
|
| 1029 |
+
)
|
| 1030 |
+
self.activation = nn.ReLU(inplace=True)
|
| 1031 |
+
self.head = DepthProFovHead(config)
|
| 1032 |
+
|
| 1033 |
+
def forward(
|
| 1034 |
+
self,
|
| 1035 |
+
pixel_values: torch.Tensor,
|
| 1036 |
+
global_features: torch.Tensor,
|
| 1037 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1038 |
+
) -> torch.Tensor:
|
| 1039 |
+
fov_features = self.fov_encoder(pixel_values, head_mask)
|
| 1040 |
+
|
| 1041 |
+
global_features = self.conv(global_features)
|
| 1042 |
+
global_features = self.activation(global_features)
|
| 1043 |
+
|
| 1044 |
+
fov_features = fov_features + global_features
|
| 1045 |
+
fov_output = self.head(fov_features)
|
| 1046 |
+
fov_output = fov_output.flatten()
|
| 1047 |
+
|
| 1048 |
+
return fov_output
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
class DepthProDepthEstimationHead(nn.Module):
|
| 1052 |
+
"""
|
| 1053 |
+
The DepthProDepthEstimationHead module serves as the output head for depth estimation tasks.
|
| 1054 |
+
This module comprises a sequence of convolutional and transposed convolutional layers
|
| 1055 |
+
that process the feature map from the fusion to produce a single-channel depth map.
|
| 1056 |
+
Key operations include dimensionality reduction and upsampling to match the input resolution.
|
| 1057 |
+
"""
|
| 1058 |
+
|
| 1059 |
+
def __init__(self, config):
|
| 1060 |
+
super().__init__()
|
| 1061 |
+
self.config = config
|
| 1062 |
+
|
| 1063 |
+
features = config.fusion_hidden_size
|
| 1064 |
+
self.layers = nn.ModuleList(
|
| 1065 |
+
[
|
| 1066 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
| 1067 |
+
nn.ConvTranspose2d(
|
| 1068 |
+
in_channels=features // 2,
|
| 1069 |
+
out_channels=features // 2,
|
| 1070 |
+
kernel_size=2,
|
| 1071 |
+
stride=2,
|
| 1072 |
+
padding=0,
|
| 1073 |
+
bias=True,
|
| 1074 |
+
),
|
| 1075 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
| 1076 |
+
nn.ReLU(True),
|
| 1077 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
| 1078 |
+
nn.ReLU(),
|
| 1079 |
+
]
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1083 |
+
for layer in self.layers:
|
| 1084 |
+
hidden_states = layer(hidden_states)
|
| 1085 |
+
|
| 1086 |
+
predicted_depth = hidden_states.squeeze(dim=1)
|
| 1087 |
+
return predicted_depth
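# Illustrative sketch (`config` is a placeholder): the head halves the channel dimension, upsamples
# 2x with a transposed convolution, and collapses to a single channel, so a
# (batch, fusion_hidden_size, H, W) input yields a (batch, 2*H, 2*W) depth map.
#
#     >>> head = DepthProDepthEstimationHead(config)  # doctest: +SKIP
#     >>> head(torch.randn(1, config.fusion_hidden_size, 192, 192)).shape  # doctest: +SKIP
#     torch.Size([1, 384, 384])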
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
@add_start_docstrings(
|
| 1091 |
+
"""
|
| 1092 |
+
DepthPro Model with a depth estimation head on top (consisting of 3 convolutional layers).
|
| 1093 |
+
""",
|
| 1094 |
+
DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING,
|
| 1095 |
+
)
|
| 1096 |
+
class DepthProForDepthEstimation(DepthProPreTrainedModel):
|
| 1097 |
+
def __init__(self, config, use_fov_model=None):
|
| 1098 |
+
super().__init__(config)
|
| 1099 |
+
self.config = config
|
| 1100 |
+
self.use_fov_model = use_fov_model if use_fov_model is not None else self.config.use_fov_model
|
| 1101 |
+
|
| 1102 |
+
# dinov2 (vit) like encoders
|
| 1103 |
+
self.depth_pro = DepthProModel(config)
|
| 1104 |
+
|
| 1105 |
+
# dpt (vit) like fusion stage
|
| 1106 |
+
self.fusion_stage = DepthProFeatureFusionStage(config)
|
| 1107 |
+
|
| 1108 |
+
# depth estimation head
|
| 1109 |
+
self.head = DepthProDepthEstimationHead(config)
|
| 1110 |
+
|
| 1111 |
+
# dinov2 (vit) like encoder
|
| 1112 |
+
self.fov_model = DepthProFovModel(config) if self.use_fov_model else None
|
| 1113 |
+
|
| 1114 |
+
# Initialize weights and apply final processing
|
| 1115 |
+
self.post_init()
|
| 1116 |
+
|
| 1117 |
+
@add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
|
| 1118 |
+
@replace_return_docstrings(output_type=DepthProDepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
|
| 1119 |
+
def forward(
|
| 1120 |
+
self,
|
| 1121 |
+
pixel_values: torch.FloatTensor,
|
| 1122 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1123 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1124 |
+
output_attentions: Optional[bool] = None,
|
| 1125 |
+
output_hidden_states: Optional[bool] = None,
|
| 1126 |
+
return_dict: Optional[bool] = None,
|
| 1127 |
+
) -> Union[Tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
|
| 1128 |
+
r"""
|
| 1129 |
+
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 1130 |
+
Ground truth depth estimation maps for computing the loss.
|
| 1131 |
+
|
| 1132 |
+
Returns:
|
| 1133 |
+
|
| 1134 |
+
Examples:
|
| 1135 |
+
|
| 1136 |
+
```python
|
| 1137 |
+
>>> from transformers import AutoImageProcessor, DepthProForDepthEstimation
|
| 1138 |
+
>>> import torch
|
| 1139 |
+
>>> from PIL import Image
|
| 1140 |
+
>>> import requests
|
| 1141 |
+
|
| 1142 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1143 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1144 |
+
|
| 1145 |
+
>>> checkpoint = "apple/DepthPro-hf"
|
| 1146 |
+
>>> processor = AutoImageProcessor.from_pretrained(checkpoint)
|
| 1147 |
+
>>> model = DepthProForDepthEstimation.from_pretrained(checkpoint)
|
| 1148 |
+
|
| 1149 |
+
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1150 |
+
>>> model.to(device)
|
| 1151 |
+
|
| 1152 |
+
>>> # prepare image for the model
|
| 1153 |
+
>>> inputs = processor(images=image, return_tensors="pt").to(device)
|
| 1154 |
+
|
| 1155 |
+
>>> with torch.no_grad():
|
| 1156 |
+
... outputs = model(**inputs)
|
| 1157 |
+
|
| 1158 |
+
>>> # interpolate to original size
|
| 1159 |
+
>>> post_processed_output = processor.post_process_depth_estimation(
|
| 1160 |
+
... outputs, target_sizes=[(image.height, image.width)],
|
| 1161 |
+
... )
|
| 1162 |
+
|
| 1163 |
+
>>> # get the field of view (fov) predictions
|
| 1164 |
+
>>> field_of_view = post_processed_output[0]["field_of_view"]
|
| 1165 |
+
>>> focal_length = post_processed_output[0]["focal_length"]
|
| 1166 |
+
|
| 1167 |
+
>>> # visualize the prediction
|
| 1168 |
+
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
| 1169 |
+
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
| 1170 |
+
>>> depth = depth.detach().cpu().numpy()
|
| 1171 |
+
>>> depth = Image.fromarray(depth.astype("uint8"))
|
| 1172 |
+
```"""
|
| 1173 |
+
loss = None
|
| 1174 |
+
if labels is not None:
|
| 1175 |
+
raise NotImplementedError("Training is not implemented yet")
|
| 1176 |
+
|
| 1177 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1178 |
+
output_hidden_states = (
|
| 1179 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1180 |
+
)
|
| 1181 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1182 |
+
|
| 1183 |
+
depth_pro_outputs = self.depth_pro(
|
| 1184 |
+
pixel_values=pixel_values,
|
| 1185 |
+
head_mask=head_mask,
|
| 1186 |
+
output_attentions=output_attentions,
|
| 1187 |
+
output_hidden_states=output_hidden_states,
|
| 1188 |
+
return_dict=True,
|
| 1189 |
+
)
|
| 1190 |
+
features = depth_pro_outputs.features
|
| 1191 |
+
fused_hidden_states = self.fusion_stage(features)
|
| 1192 |
+
predicted_depth = self.head(fused_hidden_states[-1])
|
| 1193 |
+
|
| 1194 |
+
if self.use_fov_model:
|
| 1195 |
+
# frozen features from encoder are used
|
| 1196 |
+
features_for_fov = features[0].detach()
|
| 1197 |
+
fov = self.fov_model(
|
| 1198 |
+
pixel_values=pixel_values,
|
| 1199 |
+
global_features=features_for_fov,
|
| 1200 |
+
head_mask=head_mask,
|
| 1201 |
+
)
|
| 1202 |
+
else:
|
| 1203 |
+
fov = None
|
| 1204 |
+
|
| 1205 |
+
if not return_dict:
|
| 1206 |
+
outputs = [loss, predicted_depth, fov, depth_pro_outputs.hidden_states, depth_pro_outputs.attentions]
|
| 1207 |
+
return tuple(v for v in outputs if v is not None)
|
| 1208 |
+
|
| 1209 |
+
return DepthProDepthEstimatorOutput(
|
| 1210 |
+
loss=loss,
|
| 1211 |
+
predicted_depth=predicted_depth,
|
| 1212 |
+
field_of_view=fov,
|
| 1213 |
+
hidden_states=depth_pro_outputs.hidden_states,
|
| 1214 |
+
attentions=depth_pro_outputs.attentions,
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
__all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]
|
docs/transformers/build/lib/transformers/models/detr/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import TYPE_CHECKING
|
| 16 |
+
|
| 17 |
+
from ...utils import _LazyModule
|
| 18 |
+
from ...utils.import_utils import define_import_structure
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from .configuration_detr import *
|
| 23 |
+
from .feature_extraction_detr import *
|
| 24 |
+
from .image_processing_detr import *
|
| 25 |
+
from .image_processing_detr_fast import *
|
| 26 |
+
from .modeling_detr import *
|
| 27 |
+
else:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
_file = globals()["__file__"]
|
| 31 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
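# Illustrative note (not part of the original file): with the `_LazyModule` indirection above, the
# submodules listed in the `TYPE_CHECKING` block are only imported when first accessed, e.g.:
#
#     >>> from transformers.models.detr import DetrConfig  # resolved lazily on first access
#     >>> config = DetrConfig()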
|
docs/transformers/build/lib/transformers/models/detr/configuration_detr.py
ADDED
|
@@ -0,0 +1,289 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Facebook AI Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DETR model configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from typing import Mapping
|
| 19 |
+
|
| 20 |
+
from packaging import version
|
| 21 |
+
|
| 22 |
+
from ...configuration_utils import PretrainedConfig
|
| 23 |
+
from ...onnx import OnnxConfig
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
from ...utils.backbone_utils import verify_backbone_config_arguments
|
| 26 |
+
from ..auto import CONFIG_MAPPING
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DetrConfig(PretrainedConfig):
|
| 33 |
+
r"""
|
| 34 |
+
This is the configuration class to store the configuration of a [`DetrModel`]. It is used to instantiate a DETR
|
| 35 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 36 |
+
defaults will yield a similar configuration to that of the DETR
|
| 37 |
+
[facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) architecture.
|
| 38 |
+
|
| 39 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 40 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
use_timm_backbone (`bool`, *optional*, defaults to `True`):
|
| 44 |
+
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
|
| 45 |
+
API.
|
| 46 |
+
backbone_config (`PretrainedConfig` or `dict`, *optional*):
|
| 47 |
+
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
|
| 48 |
+
case it will default to `ResNetConfig()`.
|
| 49 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 50 |
+
The number of input channels.
|
| 51 |
+
num_queries (`int`, *optional*, defaults to 100):
|
| 52 |
+
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
| 53 |
+
detect in a single image. For COCO, we recommend 100 queries.
|
| 54 |
+
d_model (`int`, *optional*, defaults to 256):
|
| 55 |
+
            Dimension of the layers. This general dimension parameter defines the size of components such as the encoder layers and the projection parameters in the decoder layers.
|
| 56 |
+
encoder_layers (`int`, *optional*, defaults to 6):
|
| 57 |
+
Number of encoder layers.
|
| 58 |
+
decoder_layers (`int`, *optional*, defaults to 6):
|
| 59 |
+
Number of decoder layers.
|
| 60 |
+
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 61 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 62 |
+
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 63 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 64 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
| 65 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 66 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
| 67 |
+
            Dimension of the "intermediate" (often named feed-forward) layer in encoder.
|
| 68 |
+
activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
|
| 69 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 70 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 71 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 72 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 73 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 74 |
+
The dropout ratio for the attention probabilities.
|
| 75 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 76 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 77 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
| 78 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 79 |
+
init_xavier_std (`float`, *optional*, defaults to 1):
|
| 80 |
+
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
|
| 81 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
| 82 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 83 |
+
for more details.
|
| 84 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
| 85 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 86 |
+
for more details.
|
| 87 |
+
auxiliary_loss (`bool`, *optional*, defaults to `False`):
|
| 88 |
+
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 89 |
+
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
|
| 90 |
+
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
|
| 91 |
+
backbone (`str`, *optional*, defaults to `"resnet50"`):
|
| 92 |
+
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
| 93 |
+
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
| 94 |
+
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
| 95 |
+
use_pretrained_backbone (`bool`, *optional*, `True`):
|
| 96 |
+
Whether to use pretrained weights for the backbone.
|
| 97 |
+
backbone_kwargs (`dict`, *optional*):
|
| 98 |
+
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
| 99 |
+
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
| 100 |
+
dilation (`bool`, *optional*, defaults to `False`):
|
| 101 |
+
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
| 102 |
+
`use_timm_backbone` = `True`.
|
| 103 |
+
class_cost (`float`, *optional*, defaults to 1):
|
| 104 |
+
Relative weight of the classification error in the Hungarian matching cost.
|
| 105 |
+
bbox_cost (`float`, *optional*, defaults to 5):
|
| 106 |
+
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
|
| 107 |
+
giou_cost (`float`, *optional*, defaults to 2):
|
| 108 |
+
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
|
| 109 |
+
mask_loss_coefficient (`float`, *optional*, defaults to 1):
|
| 110 |
+
Relative weight of the Focal loss in the panoptic segmentation loss.
|
| 111 |
+
dice_loss_coefficient (`float`, *optional*, defaults to 1):
|
| 112 |
+
Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
|
| 113 |
+
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
|
| 114 |
+
Relative weight of the L1 bounding box loss in the object detection loss.
|
| 115 |
+
giou_loss_coefficient (`float`, *optional*, defaults to 2):
|
| 116 |
+
Relative weight of the generalized IoU loss in the object detection loss.
|
| 117 |
+
eos_coefficient (`float`, *optional*, defaults to 0.1):
|
| 118 |
+
Relative classification weight of the 'no-object' class in the object detection loss.
|
| 119 |
+
|
| 120 |
+
Examples:
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
>>> from transformers import DetrConfig, DetrModel
|
| 124 |
+
|
| 125 |
+
>>> # Initializing a DETR facebook/detr-resnet-50 style configuration
|
| 126 |
+
>>> configuration = DetrConfig()
|
| 127 |
+
|
| 128 |
+
>>> # Initializing a model (with random weights) from the facebook/detr-resnet-50 style configuration
|
| 129 |
+
>>> model = DetrModel(configuration)
|
| 130 |
+
|
| 131 |
+
>>> # Accessing the model configuration
|
| 132 |
+
>>> configuration = model.config
|
| 133 |
+
```"""
|
| 134 |
+
|
| 135 |
+
model_type = "detr"
|
| 136 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 137 |
+
attribute_map = {
|
| 138 |
+
"hidden_size": "d_model",
|
| 139 |
+
"num_attention_heads": "encoder_attention_heads",
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
use_timm_backbone=True,
|
| 145 |
+
backbone_config=None,
|
| 146 |
+
num_channels=3,
|
| 147 |
+
num_queries=100,
|
| 148 |
+
encoder_layers=6,
|
| 149 |
+
encoder_ffn_dim=2048,
|
| 150 |
+
encoder_attention_heads=8,
|
| 151 |
+
decoder_layers=6,
|
| 152 |
+
decoder_ffn_dim=2048,
|
| 153 |
+
decoder_attention_heads=8,
|
| 154 |
+
encoder_layerdrop=0.0,
|
| 155 |
+
decoder_layerdrop=0.0,
|
| 156 |
+
is_encoder_decoder=True,
|
| 157 |
+
activation_function="relu",
|
| 158 |
+
d_model=256,
|
| 159 |
+
dropout=0.1,
|
| 160 |
+
attention_dropout=0.0,
|
| 161 |
+
activation_dropout=0.0,
|
| 162 |
+
init_std=0.02,
|
| 163 |
+
init_xavier_std=1.0,
|
| 164 |
+
auxiliary_loss=False,
|
| 165 |
+
position_embedding_type="sine",
|
| 166 |
+
backbone="resnet50",
|
| 167 |
+
use_pretrained_backbone=True,
|
| 168 |
+
backbone_kwargs=None,
|
| 169 |
+
dilation=False,
|
| 170 |
+
class_cost=1,
|
| 171 |
+
bbox_cost=5,
|
| 172 |
+
giou_cost=2,
|
| 173 |
+
mask_loss_coefficient=1,
|
| 174 |
+
dice_loss_coefficient=1,
|
| 175 |
+
bbox_loss_coefficient=5,
|
| 176 |
+
giou_loss_coefficient=2,
|
| 177 |
+
eos_coefficient=0.1,
|
| 178 |
+
**kwargs,
|
| 179 |
+
):
|
| 180 |
+
# We default to values which were previously hard-coded in the model. This enables configurability of the config
|
| 181 |
+
# while keeping the default behavior the same.
|
| 182 |
+
if use_timm_backbone and backbone_kwargs is None:
|
| 183 |
+
backbone_kwargs = {}
|
| 184 |
+
if dilation:
|
| 185 |
+
backbone_kwargs["output_stride"] = 16
|
| 186 |
+
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
|
| 187 |
+
backbone_kwargs["in_chans"] = num_channels
|
| 188 |
+
# Backwards compatibility
|
| 189 |
+
elif not use_timm_backbone and backbone in (None, "resnet50"):
|
| 190 |
+
if backbone_config is None:
|
| 191 |
+
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
| 192 |
+
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
|
| 193 |
+
elif isinstance(backbone_config, dict):
|
| 194 |
+
backbone_model_type = backbone_config.get("model_type")
|
| 195 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 196 |
+
backbone_config = config_class.from_dict(backbone_config)
|
| 197 |
+
backbone = None
|
| 198 |
+
# set timm attributes to None
|
| 199 |
+
dilation = None
|
| 200 |
+
|
| 201 |
+
verify_backbone_config_arguments(
|
| 202 |
+
use_timm_backbone=use_timm_backbone,
|
| 203 |
+
use_pretrained_backbone=use_pretrained_backbone,
|
| 204 |
+
backbone=backbone,
|
| 205 |
+
backbone_config=backbone_config,
|
| 206 |
+
backbone_kwargs=backbone_kwargs,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self.use_timm_backbone = use_timm_backbone
|
| 210 |
+
self.backbone_config = backbone_config
|
| 211 |
+
self.num_channels = num_channels
|
| 212 |
+
self.num_queries = num_queries
|
| 213 |
+
self.d_model = d_model
|
| 214 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 215 |
+
self.encoder_layers = encoder_layers
|
| 216 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 217 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 218 |
+
self.decoder_layers = decoder_layers
|
| 219 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 220 |
+
self.dropout = dropout
|
| 221 |
+
self.attention_dropout = attention_dropout
|
| 222 |
+
self.activation_dropout = activation_dropout
|
| 223 |
+
self.activation_function = activation_function
|
| 224 |
+
self.init_std = init_std
|
| 225 |
+
self.init_xavier_std = init_xavier_std
|
| 226 |
+
self.encoder_layerdrop = encoder_layerdrop
|
| 227 |
+
self.decoder_layerdrop = decoder_layerdrop
|
| 228 |
+
self.num_hidden_layers = encoder_layers
|
| 229 |
+
self.auxiliary_loss = auxiliary_loss
|
| 230 |
+
self.position_embedding_type = position_embedding_type
|
| 231 |
+
self.backbone = backbone
|
| 232 |
+
self.use_pretrained_backbone = use_pretrained_backbone
|
| 233 |
+
self.backbone_kwargs = backbone_kwargs
|
| 234 |
+
self.dilation = dilation
|
| 235 |
+
# Hungarian matcher
|
| 236 |
+
self.class_cost = class_cost
|
| 237 |
+
self.bbox_cost = bbox_cost
|
| 238 |
+
self.giou_cost = giou_cost
|
| 239 |
+
# Loss coefficients
|
| 240 |
+
self.mask_loss_coefficient = mask_loss_coefficient
|
| 241 |
+
self.dice_loss_coefficient = dice_loss_coefficient
|
| 242 |
+
self.bbox_loss_coefficient = bbox_loss_coefficient
|
| 243 |
+
self.giou_loss_coefficient = giou_loss_coefficient
|
| 244 |
+
self.eos_coefficient = eos_coefficient
|
| 245 |
+
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
| 246 |
+
|
| 247 |
+
@property
|
| 248 |
+
def num_attention_heads(self) -> int:
|
| 249 |
+
return self.encoder_attention_heads
|
| 250 |
+
|
| 251 |
+
@property
|
| 252 |
+
def hidden_size(self) -> int:
|
| 253 |
+
return self.d_model
|
| 254 |
+
|
| 255 |
+
@classmethod
|
| 256 |
+
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
|
| 257 |
+
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
backbone_config ([`PretrainedConfig`]):
|
| 261 |
+
The backbone configuration.
|
| 262 |
+
Returns:
|
| 263 |
+
[`DetrConfig`]: An instance of a configuration object
|
| 264 |
+
"""
|
| 265 |
+
return cls(backbone_config=backbone_config, **kwargs)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class DetrOnnxConfig(OnnxConfig):
|
| 269 |
+
torch_onnx_minimum_version = version.parse("1.11")
|
| 270 |
+
|
| 271 |
+
@property
|
| 272 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 273 |
+
return OrderedDict(
|
| 274 |
+
[
|
| 275 |
+
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
| 276 |
+
("pixel_mask", {0: "batch"}),
|
| 277 |
+
]
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
@property
|
| 281 |
+
def atol_for_validation(self) -> float:
|
| 282 |
+
return 1e-5
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def default_onnx_opset(self) -> int:
|
| 286 |
+
return 12
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
__all__ = ["DetrConfig", "DetrOnnxConfig"]
|
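A minimal usage sketch for the non-timm path of the config above; the `ResNetConfig` import and the printed values are assumptions based on the class definition, not part of the diff:

```python
from transformers import DetrConfig, ResNetConfig

# Assumed example: build an explicit Transformers backbone config instead of relying on timm.
backbone_config = ResNetConfig(out_features=["stage4"])

# Either pass it directly, or go through the classmethod defined above.
config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)
config = DetrConfig.from_backbone_config(backbone_config, use_timm_backbone=False)

# attribute_map aliases hidden_size -> d_model and num_attention_heads -> encoder_attention_heads.
print(config.hidden_size, config.num_attention_heads)  # 256 8 with the defaults above
```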
docs/transformers/build/lib/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py
ADDED
@@ -0,0 +1,277 @@

# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DETR checkpoints with timm backbone."""

import argparse
import json
from collections import OrderedDict
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
for i in range(6):
    # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
    rename_keys.append(
        (f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
    )
    rename_keys.append(
        (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
    )
    rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
    rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
    rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
    rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
    rename_keys.append(
        (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
    )
    rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
    rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
    rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
    # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
    )
    rename_keys.append(
        (
            f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight",
            f"decoder.layers.{i}.encoder_attn.out_proj.weight",
        )
    )
    rename_keys.append(
        (
            f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias",
            f"decoder.layers.{i}.encoder_attn.out_proj.bias",
        )
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
    rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
    rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
    rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
    rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))

# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
rename_keys.extend(
    [
        ("input_proj.weight", "input_projection.weight"),
        ("input_proj.bias", "input_projection.bias"),
        ("query_embed.weight", "query_position_embeddings.weight"),
        ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
        ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
        ("class_embed.weight", "class_labels_classifier.weight"),
        ("class_embed.bias", "class_labels_classifier.bias"),
        ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
        ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
        ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
        ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
        ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
        ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
    ]
)


def rename_key(state_dict, old, new):
    val = state_dict.pop(old)
    state_dict[new] = val


def rename_backbone_keys(state_dict):
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if "backbone.0.body" in key:
            new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value

    return new_state_dict


def read_in_q_k_v(state_dict, is_panoptic=False):
    prefix = ""
    if is_panoptic:
        prefix = "detr."

    # first: transformer encoder
    for i in range(6):
        # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
    # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
    for i in range(6):
        # read in weights + bias of input projection layer of self-attention
        in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
        # read in weights + bias of input projection layer of cross-attention
        in_proj_weight_cross_attn = state_dict.pop(
            f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight"
        )
        in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias")
        # next, add query, keys and values (in that order) of cross-attention to the state dict
        state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
        state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
        state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
        state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
        state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
        state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)

    return im


@torch.no_grad()
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DETR structure.
    """

    # load default config
    config = DetrConfig()
    # set backbone and dilation attributes
    if "resnet101" in model_name:
        config.backbone = "resnet101"
    if "dc5" in model_name:
        config.dilation = True
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        repo_id = "huggingface/label-files"
        filename = "coco-detection-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    # load image processor
    format = "coco_panoptic" if is_panoptic else "coco_detection"
    image_processor = DetrImageProcessor(format=format)

    # prepare image
    img = prepare_img()
    encoding = image_processor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    logger.info(f"Converting model {model_name}...")

    # load original model from torch hub
    detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
    state_dict = detr.state_dict()
    # rename keys
    for src, dest in rename_keys:
        if is_panoptic:
            src = "detr." + src
        rename_key(state_dict, src, dest)
    state_dict = rename_backbone_keys(state_dict)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "detr.model." if is_panoptic else "model."
    for key in state_dict.copy().keys():
        if is_panoptic:
            if (
                key.startswith("detr")
                and not key.startswith("class_labels_classifier")
                and not key.startswith("bbox_predictor")
            ):
                val = state_dict.pop(key)
                state_dict["detr.model" + key[4:]] = val
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                val = state_dict.pop(key)
                state_dict["detr." + key] = val
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
                continue
            else:
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
        else:
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
    # finally, create HuggingFace model and load state dict
    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()
    # verify our conversion
    original_outputs = detr(pixel_values)
    outputs = model(pixel_values)
    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)

    # Save model and image processor
    logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
    image_processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
    )
    args = parser.parse_args()
    convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
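A hedged usage sketch for the timm-backbone converter above; it assumes the script's `convert_detr_checkpoint` function is in scope, and the output folder is a placeholder path:

```python
# Roughly equivalent to:
#   python convert_detr_original_pytorch_checkpoint_to_pytorch.py \
#       --model_name detr_resnet50 --pytorch_dump_folder_path ./detr-resnet-50-hf
# Direct call to the function defined in the script above; "./detr-resnet-50-hf" is a placeholder.
convert_detr_checkpoint(model_name="detr_resnet50", pytorch_dump_folder_path="./detr-resnet-50-hf")
```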
docs/transformers/build/lib/transformers/models/detr/convert_detr_to_pytorch.py
ADDED
@@ -0,0 +1,385 @@

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DETR checkpoints with native (Transformers) backbone."""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor, ResNetConfig
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_detr_config(model_name):
    # initialize config
    if "resnet-50" in model_name:
        backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
    elif "resnet-101" in model_name:
        backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
    else:
        raise ValueError("Model name should include either resnet50 or resnet101")

    config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)

    # set label attributes
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        repo_id = "huggingface/label-files"
        filename = "coco-detection-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    return config, is_panoptic


def create_rename_keys(config):
    # here we list all keys to be renamed (original name on the left, our name on the right)
    rename_keys = []

    # stem
    # fmt: off
    rename_keys.append(("backbone.0.body.conv1.weight", "backbone.conv_encoder.model.embedder.embedder.convolution.weight"))
    rename_keys.append(("backbone.0.body.bn1.weight", "backbone.conv_encoder.model.embedder.embedder.normalization.weight"))
    rename_keys.append(("backbone.0.body.bn1.bias", "backbone.conv_encoder.model.embedder.embedder.normalization.bias"))
    rename_keys.append(("backbone.0.body.bn1.running_mean", "backbone.conv_encoder.model.embedder.embedder.normalization.running_mean"))
    rename_keys.append(("backbone.0.body.bn1.running_var", "backbone.conv_encoder.model.embedder.embedder.normalization.running_var"))
    # stages
    for stage_idx in range(len(config.backbone_config.depths)):
        for layer_idx in range(config.backbone_config.depths[stage_idx]):
            # shortcut
            if layer_idx == 0:
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.0.weight",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.weight",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.bias",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_mean",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_var",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
                    )
                )
            # 3 convs
            for i in range(3):
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv{i+1}.weight",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.weight",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.bias",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_mean",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_var",
                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
                    )
                )
    # fmt: on

    for i in range(config.encoder_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append(
            (
                f"transformer.encoder.layers.{i}.self_attn.out_proj.weight",
                f"encoder.layers.{i}.self_attn.out_proj.weight",
            )
        )
        rename_keys.append(
            (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
        )
        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
        rename_keys.append(
            (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
        )
        rename_keys.append(
            (f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias")
        )
        rename_keys.append(
            (f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight")
        )
        rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
        # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
        rename_keys.append(
            (
                f"transformer.decoder.layers.{i}.self_attn.out_proj.weight",
                f"decoder.layers.{i}.self_attn.out_proj.weight",
            )
        )
        rename_keys.append(
            (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
        )
        rename_keys.append(
            (
                f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight",
                f"decoder.layers.{i}.encoder_attn.out_proj.weight",
            )
        )
        rename_keys.append(
            (
                f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias",
                f"decoder.layers.{i}.encoder_attn.out_proj.bias",
            )
        )
        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
        rename_keys.append(
            (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
        )
        rename_keys.append(
            (f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias")
        )
        rename_keys.append(
            (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
        )
        rename_keys.append(
            (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
        )
        rename_keys.append(
            (f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight")
        )
        rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))

    # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
    rename_keys.extend(
        [
            ("input_proj.weight", "input_projection.weight"),
            ("input_proj.bias", "input_projection.bias"),
            ("query_embed.weight", "query_position_embeddings.weight"),
            ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
            ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
            ("class_embed.weight", "class_labels_classifier.weight"),
            ("class_embed.bias", "class_labels_classifier.bias"),
            ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
            ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
            ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
            ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
            ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
            ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
        ]
    )

    return rename_keys


def rename_key(state_dict, old, new):
    val = state_dict.pop(old)
    state_dict[new] = val


def read_in_q_k_v(state_dict, is_panoptic=False):
    prefix = ""
    if is_panoptic:
        prefix = "detr."

    # first: transformer encoder
    for i in range(6):
        # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
    # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
    for i in range(6):
        # read in weights + bias of input projection layer of self-attention
        in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
        # read in weights + bias of input projection layer of cross-attention
        in_proj_weight_cross_attn = state_dict.pop(
            f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight"
        )
        in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias")
        # next, add query, keys and values (in that order) of cross-attention to the state dict
        state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
        state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
        state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
        state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
        state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
        state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)

    return im


@torch.no_grad()
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our DETR structure.
    """

    # load default config
    config, is_panoptic = get_detr_config(model_name)

    # load original model from torch hub
    model_name_to_original_name = {
        "detr-resnet-50": "detr_resnet50",
        "detr-resnet-101": "detr_resnet101",
    }
    logger.info(f"Converting model {model_name}...")
    detr = torch.hub.load("facebookresearch/detr", model_name_to_original_name[model_name], pretrained=True).eval()
    state_dict = detr.state_dict()
    # rename keys
    for src, dest in create_rename_keys(config):
        if is_panoptic:
            src = "detr." + src
        rename_key(state_dict, src, dest)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "detr.model." if is_panoptic else "model."
    for key in state_dict.copy().keys():
        if is_panoptic:
            if (
                key.startswith("detr")
                and not key.startswith("class_labels_classifier")
                and not key.startswith("bbox_predictor")
            ):
                val = state_dict.pop(key)
                state_dict["detr.model" + key[4:]] = val
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                val = state_dict.pop(key)
                state_dict["detr." + key] = val
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
                continue
            else:
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
        else:
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
                val = state_dict.pop(key)
                state_dict[prefix + key] = val

    # finally, create HuggingFace model and load state dict
    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()

    # verify our conversion on an image
    format = "coco_panoptic" if is_panoptic else "coco_detection"
    processor = DetrImageProcessor(format=format)

    encoding = processor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    original_outputs = detr(pixel_values)
    outputs = model(pixel_values)

    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        # Save model and image processor
        logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Upload model and image processor to the hub
        logger.info("Uploading PyTorch model and image processor to the hub...")
        model.push_to_hub(f"nielsr/{model_name}")
        processor.push_to_hub(f"nielsr/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name",
        default="detr-resnet-50",
        type=str,
        choices=["detr-resnet-50", "detr-resnet-101"],
        help="Name of the DETR model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
    args = parser.parse_args()
    convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
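A similar sketch for the native-backbone converter above; `model_name` is restricted to the two choices in the argparse block, the dump path is a placeholder, and the call assumes the script's `convert_detr_checkpoint` function is in scope:

```python
# Roughly equivalent to:
#   python convert_detr_to_pytorch.py --model_name detr-resnet-50 \
#       --pytorch_dump_folder_path ./detr-resnet-50-native
# Set push_to_hub=True only if you actually want to upload under the hard-coded "nielsr/" namespace.
convert_detr_checkpoint("detr-resnet-50", pytorch_dump_folder_path="./detr-resnet-50-native", push_to_hub=False)
```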
docs/transformers/build/lib/transformers/models/detr/feature_extraction_detr.py
ADDED
@@ -0,0 +1,48 @@

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for DETR."""

import warnings

from ...image_transforms import rgb_to_id as _rgb_to_id
from ...utils import logging
from ...utils.import_utils import requires
from .image_processing_detr import DetrImageProcessor


logger = logging.get_logger(__name__)


def rgb_to_id(x):
    warnings.warn(
        "rgb_to_id has moved and will not be importable from this module from v5. "
        "Please import from transformers.image_transforms instead.",
        FutureWarning,
    )
    return _rgb_to_id(x)


@requires(backends=("vision",))
class DetrFeatureExtractor(DetrImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "The class DetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
            " Please use DetrImageProcessor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)


__all__ = ["DetrFeatureExtractor"]
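Since `DetrFeatureExtractor` only emits a `FutureWarning` and otherwise delegates to `DetrImageProcessor`, migration is a rename. A minimal sketch, where the input file name is a placeholder:

```python
from PIL import Image
from transformers import DetrImageProcessor

image = Image.open("example.jpg")  # placeholder input image

# Drop-in replacement for DetrFeatureExtractor(format="coco_detection"), without the deprecation warning
processor = DetrImageProcessor(format="coco_detection")
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)
```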
docs/transformers/build/lib/transformers/models/detr/image_processing_detr_fast.py
ADDED
@@ -0,0 +1,1312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fast Image processor class for DETR."""
|
| 16 |
+
|
| 17 |
+
import io
|
| 18 |
+
import pathlib
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
from ...image_processing_utils import BatchFeature, get_size_dict
|
| 23 |
+
from ...image_processing_utils_fast import (
|
| 24 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
| 25 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
| 26 |
+
BaseImageProcessorFast,
|
| 27 |
+
DefaultFastImageProcessorKwargs,
|
| 28 |
+
SizeDict,
|
| 29 |
+
get_image_size_for_max_height_width,
|
| 30 |
+
get_max_height_width,
|
| 31 |
+
safe_squeeze,
|
| 32 |
+
)
|
| 33 |
+
from ...image_transforms import (
|
| 34 |
+
center_to_corners_format,
|
| 35 |
+
corners_to_center_format,
|
| 36 |
+
id_to_rgb,
|
| 37 |
+
)
|
| 38 |
+
from ...image_utils import (
|
| 39 |
+
IMAGENET_DEFAULT_MEAN,
|
| 40 |
+
IMAGENET_DEFAULT_STD,
|
| 41 |
+
AnnotationFormat,
|
| 42 |
+
AnnotationType,
|
| 43 |
+
ChannelDimension,
|
| 44 |
+
ImageInput,
|
| 45 |
+
PILImageResampling,
|
| 46 |
+
get_image_size,
|
| 47 |
+
validate_annotations,
|
| 48 |
+
)
|
| 49 |
+
from ...processing_utils import Unpack
|
| 50 |
+
from ...utils import (
|
| 51 |
+
TensorType,
|
| 52 |
+
add_start_docstrings,
|
| 53 |
+
is_torch_available,
|
| 54 |
+
is_torchvision_available,
|
| 55 |
+
is_torchvision_v2_available,
|
| 56 |
+
is_vision_available,
|
| 57 |
+
logging,
|
| 58 |
+
)
|
| 59 |
+
from ...utils.import_utils import requires
|
| 60 |
+
from .image_processing_detr import (
|
| 61 |
+
compute_segments,
|
| 62 |
+
convert_segmentation_to_rle,
|
| 63 |
+
get_size_with_aspect_ratio,
|
| 64 |
+
remove_low_and_no_objects,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if is_torch_available():
|
| 69 |
+
import torch
|
| 70 |
+
from torch import nn
|
| 71 |
+
|
| 72 |
+
if is_vision_available():
|
| 73 |
+
import PIL
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if is_torchvision_v2_available():
|
| 77 |
+
from torchvision.io import read_image
|
| 78 |
+
from torchvision.transforms.v2 import functional as F
|
| 79 |
+
elif is_torchvision_available():
|
| 80 |
+
from torchvision.io import read_image
|
| 81 |
+
from torchvision.transforms import functional as F
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
logger = logging.get_logger(__name__)
|
| 85 |
+
|
| 86 |
+
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
|
| 90 |
+
def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor:
|
| 91 |
+
"""
|
| 92 |
+
Convert a COCO polygon annotation to a mask.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
segmentations (`List[List[float]]`):
|
| 96 |
+
List of polygons, each polygon represented by a list of x-y coordinates.
|
| 97 |
+
height (`int`):
|
| 98 |
+
Height of the mask.
|
| 99 |
+
width (`int`):
|
| 100 |
+
Width of the mask.
|
| 101 |
+
"""
|
| 102 |
+
try:
|
| 103 |
+
from pycocotools import mask as coco_mask
|
| 104 |
+
except ImportError:
|
| 105 |
+
raise ImportError("Pycocotools is not installed in your environment.")
|
| 106 |
+
|
| 107 |
+
masks = []
|
| 108 |
+
for polygons in segmentations:
|
| 109 |
+
rles = coco_mask.frPyObjects(polygons, height, width)
|
| 110 |
+
mask = coco_mask.decode(rles)
|
| 111 |
+
if len(mask.shape) < 3:
|
| 112 |
+
mask = mask[..., None]
|
| 113 |
+
mask = torch.as_tensor(mask, dtype=torch.uint8, device=device)
|
| 114 |
+
mask = torch.any(mask, axis=2)
|
| 115 |
+
masks.append(mask)
|
| 116 |
+
if masks:
|
| 117 |
+
masks = torch.stack(masks, axis=0)
|
| 118 |
+
else:
|
| 119 |
+
masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device)
|
| 120 |
+
|
| 121 |
+
return masks
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
|
| 125 |
+
def prepare_coco_detection_annotation(
|
| 126 |
+
image,
|
| 127 |
+
target,
|
| 128 |
+
return_segmentation_masks: bool = False,
|
| 129 |
+
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Convert the target in COCO format into the format expected by DETR.
|
| 133 |
+
"""
|
| 134 |
+
image_height, image_width = image.size()[-2:]
|
| 135 |
+
|
| 136 |
+
image_id = target["image_id"]
|
| 137 |
+
image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
|
| 138 |
+
|
| 139 |
+
# Get all COCO annotations for the given image.
|
| 140 |
+
annotations = target["annotations"]
|
| 141 |
+
classes = []
|
| 142 |
+
area = []
|
| 143 |
+
boxes = []
|
| 144 |
+
keypoints = []
|
| 145 |
+
for obj in annotations:
|
| 146 |
+
if "iscrowd" not in obj or obj["iscrowd"] == 0:
|
| 147 |
+
classes.append(obj["category_id"])
|
| 148 |
+
area.append(obj["area"])
|
| 149 |
+
boxes.append(obj["bbox"])
|
| 150 |
+
if "keypoints" in obj:
|
| 151 |
+
keypoints.append(obj["keypoints"])
|
| 152 |
+
|
| 153 |
+
classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
|
| 154 |
+
area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
|
| 155 |
+
iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
|
| 156 |
+
# guard against no boxes via resizing
|
| 157 |
+
boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
|
| 158 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 159 |
+
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
|
| 160 |
+
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
|
| 161 |
+
|
| 162 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
| 163 |
+
|
| 164 |
+
new_target = {
|
| 165 |
+
"image_id": image_id,
|
| 166 |
+
"class_labels": classes[keep],
|
| 167 |
+
"boxes": boxes[keep],
|
| 168 |
+
"area": area[keep],
|
| 169 |
+
"iscrowd": iscrowd[keep],
|
| 170 |
+
"orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if keypoints:
|
| 174 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
|
| 175 |
+
# Apply the keep mask here to filter the relevant annotations
|
| 176 |
+
keypoints = keypoints[keep]
|
| 177 |
+
num_keypoints = keypoints.shape[0]
|
| 178 |
+
keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
|
| 179 |
+
new_target["keypoints"] = keypoints
|
| 180 |
+
|
| 181 |
+
if return_segmentation_masks:
|
| 182 |
+
segmentation_masks = [obj["segmentation"] for obj in annotations]
|
| 183 |
+
masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device)
|
| 184 |
+
new_target["masks"] = masks[keep]
|
| 185 |
+
|
| 186 |
+
return new_target
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
|
| 190 |
+
"""
|
| 191 |
+
Compute the bounding boxes around the provided panoptic segmentation masks.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
masks: masks in format `[number_masks, height, width]` where N is the number of masks
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
|
| 198 |
+
"""
|
| 199 |
+
if masks.numel() == 0:
|
| 200 |
+
return torch.zeros((0, 4), device=masks.device)
|
| 201 |
+
|
| 202 |
+
h, w = masks.shape[-2:]
|
| 203 |
+
y = torch.arange(0, h, dtype=torch.float32, device=masks.device)
|
| 204 |
+
x = torch.arange(0, w, dtype=torch.float32, device=masks.device)
|
| 205 |
+
# see https://github.com/pytorch/pytorch/issues/50276
|
| 206 |
+
y, x = torch.meshgrid(y, x, indexing="ij")
|
| 207 |
+
|
| 208 |
+
x_mask = masks * torch.unsqueeze(x, 0)
|
| 209 |
+
x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0]
|
| 210 |
+
x_min = (
|
| 211 |
+
torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
y_mask = masks * torch.unsqueeze(y, 0)
|
| 215 |
+
y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0]
|
| 216 |
+
y_min = (
|
| 217 |
+
torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
|
| 224 |
+
# Copyright (c) 2018, Alexander Kirillov
|
| 225 |
+
# All rights reserved.
|
| 226 |
+
def rgb_to_id(color):
|
| 227 |
+
"""
|
| 228 |
+
Converts RGB color to unique ID.
|
| 229 |
+
"""
|
| 230 |
+
if isinstance(color, torch.Tensor) and len(color.shape) == 3:
|
| 231 |
+
if color.dtype == torch.uint8:
|
| 232 |
+
color = color.to(torch.int32)
|
| 233 |
+
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
|
| 234 |
+
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def prepare_coco_panoptic_annotation(
|
| 238 |
+
image: torch.Tensor,
|
| 239 |
+
target: Dict,
|
| 240 |
+
masks_path: Union[str, pathlib.Path],
|
| 241 |
+
return_masks: bool = True,
|
| 242 |
+
input_data_format: Union[ChannelDimension, str] = None,
|
| 243 |
+
) -> Dict:
|
| 244 |
+
"""
|
| 245 |
+
Prepare a coco panoptic annotation for DETR.
|
| 246 |
+
"""
|
| 247 |
+
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
| 248 |
+
annotation_path = pathlib.Path(masks_path) / target["file_name"]
|
| 249 |
+
|
| 250 |
+
new_target = {}
|
| 251 |
+
new_target["image_id"] = torch.as_tensor(
|
| 252 |
+
[target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device
|
| 253 |
+
)
|
| 254 |
+
new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
|
| 255 |
+
new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
|
| 256 |
+
|
| 257 |
+
if "segments_info" in target:
|
| 258 |
+
masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
|
| 259 |
+
masks = rgb_to_id(masks)
|
| 260 |
+
|
| 261 |
+
ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
|
| 262 |
+
masks = masks == ids[:, None, None]
|
| 263 |
+
masks = masks.to(torch.bool)
|
| 264 |
+
if return_masks:
|
| 265 |
+
new_target["masks"] = masks
|
| 266 |
+
new_target["boxes"] = masks_to_boxes(masks)
|
| 267 |
+
new_target["class_labels"] = torch.as_tensor(
|
| 268 |
+
[segment_info["category_id"] for segment_info in target["segments_info"]],
|
| 269 |
+
dtype=torch.int64,
|
| 270 |
+
device=image.device,
|
| 271 |
+
)
|
| 272 |
+
new_target["iscrowd"] = torch.as_tensor(
|
| 273 |
+
[segment_info["iscrowd"] for segment_info in target["segments_info"]],
|
| 274 |
+
dtype=torch.int64,
|
| 275 |
+
device=image.device,
|
| 276 |
+
)
|
| 277 |
+
new_target["area"] = torch.as_tensor(
|
| 278 |
+
[segment_info["area"] for segment_info in target["segments_info"]],
|
| 279 |
+
dtype=torch.float32,
|
| 280 |
+
device=image.device,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
return new_target
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class DetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
| 287 |
+
format: Optional[Union[str, AnnotationFormat]]
|
| 288 |
+
do_convert_annotations: Optional[bool]
|
| 289 |
+
do_pad: Optional[bool]
|
| 290 |
+
pad_size: Optional[Dict[str, int]]
|
| 291 |
+
return_segmentation_masks: Optional[bool]
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@add_start_docstrings(
|
| 295 |
+
"Constructs a fast Detr image processor.",
|
| 296 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
| 297 |
+
"""
|
| 298 |
+
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
| 299 |
+
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
| 300 |
+
do_convert_annotations (`bool`, *optional*, defaults to `True`):
|
| 301 |
+
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
|
| 302 |
+
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
|
| 303 |
+
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
|
| 304 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
| 305 |
+
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
|
| 306 |
+
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
|
| 307 |
+
If `pad_size` is provided, the image will be padded to the specified dimensions.
|
| 308 |
+
Otherwise, the image will be padded to the maximum height and width of the batch.
|
| 309 |
+
pad_size (`Dict[str, int]`, *optional*):
|
| 310 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 311 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 312 |
+
height and width in the batch.
|
| 313 |
+
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
|
| 314 |
+
Whether to return segmentation masks.
|
| 315 |
+
""",
|
| 316 |
+
)
|
| 317 |
+
@requires(backends=("torchvision", "torch"))
|
| 318 |
+
class DetrImageProcessorFast(BaseImageProcessorFast):
|
| 319 |
+
resample = PILImageResampling.BILINEAR
|
| 320 |
+
image_mean = IMAGENET_DEFAULT_MEAN
|
| 321 |
+
image_std = IMAGENET_DEFAULT_STD
|
| 322 |
+
format = AnnotationFormat.COCO_DETECTION
|
| 323 |
+
do_resize = True
|
| 324 |
+
do_rescale = True
|
| 325 |
+
do_normalize = True
|
| 326 |
+
do_pad = True
|
| 327 |
+
size = {"shortest_edge": 800, "longest_edge": 1333}
|
| 328 |
+
default_to_square = False
|
| 329 |
+
model_input_names = ["pixel_values", "pixel_mask"]
|
| 330 |
+
valid_kwargs = DetrFastImageProcessorKwargs
|
| 331 |
+
|
| 332 |
+
def __init__(self, **kwargs: Unpack[DetrFastImageProcessorKwargs]) -> None:
|
| 333 |
+
if "pad_and_return_pixel_mask" in kwargs:
|
| 334 |
+
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
|
| 335 |
+
|
| 336 |
+
size = kwargs.pop("size", None)
|
| 337 |
+
if "max_size" in kwargs:
|
| 338 |
+
logger.warning_once(
|
| 339 |
+
"The `max_size` parameter is deprecated and will be removed in v4.26. "
|
| 340 |
+
"Please specify in `size['longest_edge'] instead`.",
|
| 341 |
+
)
|
| 342 |
+
max_size = kwargs.pop("max_size")
|
| 343 |
+
else:
|
| 344 |
+
max_size = None if size is None else 1333
|
| 345 |
+
|
| 346 |
+
size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
|
| 347 |
+
self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
| 348 |
+
|
| 349 |
+
# Backwards compatibility
|
| 350 |
+
do_convert_annotations = kwargs.get("do_convert_annotations", None)
|
| 351 |
+
do_normalize = kwargs.get("do_normalize", None)
|
| 352 |
+
if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
|
| 353 |
+
self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
|
| 354 |
+
|
| 355 |
+
super().__init__(**kwargs)
|
| 356 |
+
|
| 357 |
+
@classmethod
|
| 358 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 359 |
+
"""
|
| 360 |
+
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
|
| 361 |
+
created using from_dict and kwargs e.g. `DetrImageProcessorFast.from_pretrained(checkpoint, size=600,
|
| 362 |
+
max_size=800)`
|
| 363 |
+
"""
|
| 364 |
+
image_processor_dict = image_processor_dict.copy()
|
| 365 |
+
if "max_size" in kwargs:
|
| 366 |
+
image_processor_dict["max_size"] = kwargs.pop("max_size")
|
| 367 |
+
if "pad_and_return_pixel_mask" in kwargs:
|
| 368 |
+
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
|
| 369 |
+
return super().from_dict(image_processor_dict, **kwargs)
|
| 370 |
+
|
| 371 |
+
def prepare_annotation(
|
| 372 |
+
self,
|
| 373 |
+
image: torch.Tensor,
|
| 374 |
+
target: Dict,
|
| 375 |
+
format: Optional[AnnotationFormat] = None,
|
| 376 |
+
return_segmentation_masks: Optional[bool] = None,
|
| 377 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 378 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 379 |
+
) -> Dict:
|
| 380 |
+
"""
|
| 381 |
+
Prepare an annotation for feeding into DETR model.
|
| 382 |
+
"""
|
| 383 |
+
format = format if format is not None else self.format
|
| 384 |
+
|
| 385 |
+
if format == AnnotationFormat.COCO_DETECTION:
|
| 386 |
+
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
| 387 |
+
target = prepare_coco_detection_annotation(
|
| 388 |
+
image, target, return_segmentation_masks, input_data_format=input_data_format
|
| 389 |
+
)
|
| 390 |
+
elif format == AnnotationFormat.COCO_PANOPTIC:
|
| 391 |
+
return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
|
| 392 |
+
target = prepare_coco_panoptic_annotation(
|
| 393 |
+
image,
|
| 394 |
+
target,
|
| 395 |
+
masks_path=masks_path,
|
| 396 |
+
return_masks=return_segmentation_masks,
|
| 397 |
+
input_data_format=input_data_format,
|
| 398 |
+
)
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Format {format} is not supported.")
|
| 401 |
+
return target
|
| 402 |
+
|
| 403 |
+
def resize(
|
| 404 |
+
self,
|
| 405 |
+
image: torch.Tensor,
|
| 406 |
+
size: SizeDict,
|
| 407 |
+
interpolation: "F.InterpolationMode" = None,
|
| 408 |
+
**kwargs,
|
| 409 |
+
) -> torch.Tensor:
|
| 410 |
+
"""
|
| 411 |
+
Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
|
| 412 |
+
int, smaller edge of the image will be matched to this number.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
image (`torch.Tensor`):
|
| 416 |
+
Image to resize.
|
| 417 |
+
size (`SizeDict`):
|
| 418 |
+
Size of the image's `(height, width)` dimensions after resizing. Available options are:
|
| 419 |
+
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
|
| 420 |
+
Do NOT keep the aspect ratio.
|
| 421 |
+
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
|
| 422 |
+
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
|
| 423 |
+
less or equal to `longest_edge`.
|
| 424 |
+
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
|
| 425 |
+
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
|
| 426 |
+
`max_width`.
|
| 427 |
+
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
| 428 |
+
Resampling filter to use if resizing the image.
|
| 429 |
+
"""
|
| 430 |
+
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
|
| 431 |
+
if size.shortest_edge and size.longest_edge:
|
| 432 |
+
# Resize the image so that the shortest edge or the longest edge is of the given size
|
| 433 |
+
# while maintaining the aspect ratio of the original image.
|
| 434 |
+
new_size = get_size_with_aspect_ratio(
|
| 435 |
+
image.size()[-2:],
|
| 436 |
+
size["shortest_edge"],
|
| 437 |
+
size["longest_edge"],
|
| 438 |
+
)
|
| 439 |
+
elif size.max_height and size.max_width:
|
| 440 |
+
new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
|
| 441 |
+
elif size.height and size.width:
|
| 442 |
+
new_size = (size["height"], size["width"])
|
| 443 |
+
else:
|
| 444 |
+
raise ValueError(
|
| 445 |
+
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
|
| 446 |
+
f" {size.keys()}."
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
image = F.resize(
|
| 450 |
+
image,
|
| 451 |
+
size=new_size,
|
| 452 |
+
interpolation=interpolation,
|
| 453 |
+
**kwargs,
|
| 454 |
+
)
|
| 455 |
+
return image
|
| 456 |
+
|
| 457 |
+
def resize_annotation(
|
| 458 |
+
self,
|
| 459 |
+
annotation: Dict[str, Any],
|
| 460 |
+
orig_size: Tuple[int, int],
|
| 461 |
+
target_size: Tuple[int, int],
|
| 462 |
+
threshold: float = 0.5,
|
| 463 |
+
interpolation: "F.InterpolationMode" = None,
|
| 464 |
+
):
|
| 465 |
+
"""
|
| 466 |
+
Resizes an annotation to a target size.
|
| 467 |
+
|
| 468 |
+
Args:
|
| 469 |
+
annotation (`Dict[str, Any]`):
|
| 470 |
+
The annotation dictionary.
|
| 471 |
+
orig_size (`Tuple[int, int]`):
|
| 472 |
+
The original size of the input image.
|
| 473 |
+
target_size (`Tuple[int, int]`):
|
| 474 |
+
The target size of the image, as returned by the preprocessing `resize` step.
|
| 475 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 476 |
+
The threshold used to binarize the segmentation masks.
|
| 477 |
+
resample (`InterpolationMode`, defaults to `InterpolationMode.NEAREST`):
|
| 478 |
+
The resampling filter to use when resizing the masks.
|
| 479 |
+
"""
|
| 480 |
+
interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST
|
| 481 |
+
ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
|
| 482 |
+
|
| 483 |
+
new_annotation = {}
|
| 484 |
+
new_annotation["size"] = target_size
|
| 485 |
+
|
| 486 |
+
for key, value in annotation.items():
|
| 487 |
+
if key == "boxes":
|
| 488 |
+
boxes = value
|
| 489 |
+
scaled_boxes = boxes * torch.as_tensor(
|
| 490 |
+
[ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
|
| 491 |
+
)
|
| 492 |
+
new_annotation["boxes"] = scaled_boxes
|
| 493 |
+
elif key == "area":
|
| 494 |
+
area = value
|
| 495 |
+
scaled_area = area * (ratio_width * ratio_height)
|
| 496 |
+
new_annotation["area"] = scaled_area
|
| 497 |
+
elif key == "masks":
|
| 498 |
+
masks = value[:, None]
|
| 499 |
+
masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
|
| 500 |
+
masks = torch.stack(masks).to(torch.float32)
|
| 501 |
+
masks = masks[:, 0] > threshold
|
| 502 |
+
new_annotation["masks"] = masks
|
| 503 |
+
elif key == "size":
|
| 504 |
+
new_annotation["size"] = target_size
|
| 505 |
+
else:
|
| 506 |
+
new_annotation[key] = value
|
| 507 |
+
|
| 508 |
+
return new_annotation
|
| 509 |
+
|
| 510 |
+
def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
|
| 511 |
+
image_height, image_width = image_size
|
| 512 |
+
norm_annotation = {}
|
| 513 |
+
for key, value in annotation.items():
|
| 514 |
+
if key == "boxes":
|
| 515 |
+
boxes = value
|
| 516 |
+
boxes = corners_to_center_format(boxes)
|
| 517 |
+
boxes /= torch.as_tensor(
|
| 518 |
+
[image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
|
| 519 |
+
)
|
| 520 |
+
norm_annotation[key] = boxes
|
| 521 |
+
else:
|
| 522 |
+
norm_annotation[key] = value
|
| 523 |
+
return norm_annotation
|
| 524 |
+
|
| 525 |
+
def _update_annotation_for_padded_image(
|
| 526 |
+
self,
|
| 527 |
+
annotation: Dict,
|
| 528 |
+
input_image_size: Tuple[int, int],
|
| 529 |
+
output_image_size: Tuple[int, int],
|
| 530 |
+
padding,
|
| 531 |
+
update_bboxes,
|
| 532 |
+
) -> Dict:
|
| 533 |
+
"""
|
| 534 |
+
Update the annotation for a padded image.
|
| 535 |
+
"""
|
| 536 |
+
new_annotation = {}
|
| 537 |
+
new_annotation["size"] = output_image_size
|
| 538 |
+
ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))
|
| 539 |
+
|
| 540 |
+
for key, value in annotation.items():
|
| 541 |
+
if key == "masks":
|
| 542 |
+
masks = value
|
| 543 |
+
masks = F.pad(
|
| 544 |
+
masks,
|
| 545 |
+
padding,
|
| 546 |
+
fill=0,
|
| 547 |
+
)
|
| 548 |
+
masks = safe_squeeze(masks, 1)
|
| 549 |
+
new_annotation["masks"] = masks
|
| 550 |
+
elif key == "boxes" and update_bboxes:
|
| 551 |
+
boxes = value
|
| 552 |
+
boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
|
| 553 |
+
new_annotation["boxes"] = boxes
|
| 554 |
+
elif key == "size":
|
| 555 |
+
new_annotation["size"] = output_image_size
|
| 556 |
+
else:
|
| 557 |
+
new_annotation[key] = value
|
| 558 |
+
return new_annotation
|
| 559 |
+
|
| 560 |
+
def pad(
|
| 561 |
+
self,
|
| 562 |
+
image: torch.Tensor,
|
| 563 |
+
padded_size: Tuple[int, int],
|
| 564 |
+
annotation: Optional[Dict[str, Any]] = None,
|
| 565 |
+
update_bboxes: bool = True,
|
| 566 |
+
fill: int = 0,
|
| 567 |
+
):
|
| 568 |
+
original_size = image.size()[-2:]
|
| 569 |
+
padding_bottom = padded_size[0] - original_size[0]
|
| 570 |
+
padding_right = padded_size[1] - original_size[1]
|
| 571 |
+
if padding_bottom < 0 or padding_right < 0:
|
| 572 |
+
raise ValueError(
|
| 573 |
+
f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
|
| 574 |
+
f"original size. Got padded size: {padded_size}, original size: {original_size}."
|
| 575 |
+
)
|
| 576 |
+
if original_size != padded_size:
|
| 577 |
+
padding = [0, 0, padding_right, padding_bottom]
|
| 578 |
+
image = F.pad(image, padding, fill=fill)
|
| 579 |
+
if annotation is not None:
|
| 580 |
+
annotation = self._update_annotation_for_padded_image(
|
| 581 |
+
annotation, original_size, padded_size, padding, update_bboxes
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
| 585 |
+
pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
|
| 586 |
+
pixel_mask[: original_size[0], : original_size[1]] = 1
|
| 587 |
+
|
| 588 |
+
return image, pixel_mask, annotation
|
| 589 |
+
|
| 590 |
+
@add_start_docstrings(
|
| 591 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
| 592 |
+
"""
|
| 593 |
+
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
| 594 |
+
List of annotations associated with the image or batch of images. If annotation is for object
|
| 595 |
+
detection, the annotations should be a dictionary with the following keys:
|
| 596 |
+
- "image_id" (`int`): The image id.
|
| 597 |
+
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
| 598 |
+
dictionary. An image can have no annotations, in which case the list should be empty.
|
| 599 |
+
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
|
| 600 |
+
- "image_id" (`int`): The image id.
|
| 601 |
+
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
| 602 |
+
An image can have no segments, in which case the list should be empty.
|
| 603 |
+
- "file_name" (`str`): The file name of the image.
|
| 604 |
+
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
| 605 |
+
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
| 606 |
+
do_convert_annotations (`bool`, *optional*, defaults to `True`):
|
| 607 |
+
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
|
| 608 |
+
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
|
| 609 |
+
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
|
| 610 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
| 611 |
+
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
|
| 612 |
+
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
|
| 613 |
+
If `pad_size` is provided, the image will be padded to the specified dimensions.
|
| 614 |
+
Otherwise, the image will be padded to the maximum height and width of the batch.
|
| 615 |
+
pad_size (`Dict[str, int]`, *optional*):
|
| 616 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 617 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 618 |
+
height and width in the batch.
|
| 619 |
+
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
|
| 620 |
+
Whether to return segmentation masks.
|
| 621 |
+
masks_path (`str` or `pathlib.Path`, *optional*):
|
| 622 |
+
Path to the directory containing the segmentation masks.
|
| 623 |
+
""",
|
| 624 |
+
)
|
| 625 |
+
def preprocess(
|
| 626 |
+
self,
|
| 627 |
+
images: ImageInput,
|
| 628 |
+
annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
|
| 629 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 630 |
+
**kwargs: Unpack[DetrFastImageProcessorKwargs],
|
| 631 |
+
) -> BatchFeature:
|
| 632 |
+
if "pad_and_return_pixel_mask" in kwargs:
|
| 633 |
+
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
|
| 634 |
+
logger.warning_once(
|
| 635 |
+
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
|
| 636 |
+
"use `do_pad` instead."
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
if "max_size" in kwargs:
|
| 640 |
+
logger.warning_once(
|
| 641 |
+
"The `max_size` argument is deprecated and will be removed in a future version, use"
|
| 642 |
+
" `size['longest_edge']` instead."
|
| 643 |
+
)
|
| 644 |
+
kwargs["size"] = kwargs.pop("max_size")
|
| 645 |
+
|
| 646 |
+
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
| 647 |
+
|
| 648 |
+
def _preprocess(
|
| 649 |
+
self,
|
| 650 |
+
images: List["torch.Tensor"],
|
| 651 |
+
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
| 652 |
+
return_segmentation_masks: bool,
|
| 653 |
+
masks_path: Optional[Union[str, pathlib.Path]],
|
| 654 |
+
do_resize: bool,
|
| 655 |
+
size: SizeDict,
|
| 656 |
+
interpolation: Optional["F.InterpolationMode"],
|
| 657 |
+
do_center_crop: bool,
|
| 658 |
+
crop_size: SizeDict,
|
| 659 |
+
do_rescale: bool,
|
| 660 |
+
rescale_factor: float,
|
| 661 |
+
do_normalize: bool,
|
| 662 |
+
do_convert_annotations: bool,
|
| 663 |
+
image_mean: Optional[Union[float, List[float]]],
|
| 664 |
+
image_std: Optional[Union[float, List[float]]],
|
| 665 |
+
do_pad: bool,
|
| 666 |
+
pad_size: Optional[Dict[str, int]],
|
| 667 |
+
format: Optional[Union[str, AnnotationFormat]],
|
| 668 |
+
return_tensors: Optional[Union[str, TensorType]],
|
| 669 |
+
) -> BatchFeature:
|
| 670 |
+
"""
|
| 671 |
+
Preprocess an image or a batch of images so that it can be used by the model.
|
| 672 |
+
"""
|
| 673 |
+
if annotations is not None and isinstance(annotations, dict):
|
| 674 |
+
annotations = [annotations]
|
| 675 |
+
|
| 676 |
+
if annotations is not None and len(images) != len(annotations):
|
| 677 |
+
raise ValueError(
|
| 678 |
+
f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
format = AnnotationFormat(format)
|
| 682 |
+
if annotations is not None:
|
| 683 |
+
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
|
| 684 |
+
|
| 685 |
+
if (
|
| 686 |
+
masks_path is not None
|
| 687 |
+
and format == AnnotationFormat.COCO_PANOPTIC
|
| 688 |
+
and not isinstance(masks_path, (pathlib.Path, str))
|
| 689 |
+
):
|
| 690 |
+
raise ValueError(
|
| 691 |
+
"The path to the directory containing the mask PNG files should be provided as a"
|
| 692 |
+
f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
data = {}
|
| 696 |
+
|
| 697 |
+
processed_images = []
|
| 698 |
+
processed_annotations = []
|
| 699 |
+
pixel_masks = [] # Initialize pixel_masks here
|
| 700 |
+
for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
|
| 701 |
+
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
| 702 |
+
if annotations is not None:
|
| 703 |
+
annotation = self.prepare_annotation(
|
| 704 |
+
image,
|
| 705 |
+
annotation,
|
| 706 |
+
format,
|
| 707 |
+
return_segmentation_masks=return_segmentation_masks,
|
| 708 |
+
masks_path=masks_path,
|
| 709 |
+
input_data_format=ChannelDimension.FIRST,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
if do_resize:
|
| 713 |
+
resized_image = self.resize(image, size=size, interpolation=interpolation)
|
| 714 |
+
if annotations is not None:
|
| 715 |
+
annotation = self.resize_annotation(
|
| 716 |
+
annotation,
|
| 717 |
+
orig_size=image.size()[-2:],
|
| 718 |
+
target_size=resized_image.size()[-2:],
|
| 719 |
+
)
|
| 720 |
+
image = resized_image
|
| 721 |
+
# Fused rescale and normalize
|
| 722 |
+
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
| 723 |
+
if do_convert_annotations and annotations is not None:
|
| 724 |
+
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
| 725 |
+
|
| 726 |
+
processed_images.append(image)
|
| 727 |
+
processed_annotations.append(annotation)
|
| 728 |
+
images = processed_images
|
| 729 |
+
annotations = processed_annotations if annotations is not None else None
|
| 730 |
+
|
| 731 |
+
if do_pad:
|
| 732 |
+
# depends on all resized image shapes so we need another loop
|
| 733 |
+
if pad_size is not None:
|
| 734 |
+
padded_size = (pad_size["height"], pad_size["width"])
|
| 735 |
+
else:
|
| 736 |
+
padded_size = get_max_height_width(images)
|
| 737 |
+
|
| 738 |
+
padded_images = []
|
| 739 |
+
padded_annotations = []
|
| 740 |
+
for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
|
| 741 |
+
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
|
| 742 |
+
if padded_size == image.size()[-2:]:
|
| 743 |
+
padded_images.append(image)
|
| 744 |
+
pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
|
| 745 |
+
padded_annotations.append(annotation)
|
| 746 |
+
continue
|
| 747 |
+
image, pixel_mask, annotation = self.pad(
|
| 748 |
+
image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
|
| 749 |
+
)
|
| 750 |
+
padded_images.append(image)
|
| 751 |
+
padded_annotations.append(annotation)
|
| 752 |
+
pixel_masks.append(pixel_mask)
|
| 753 |
+
images = padded_images
|
| 754 |
+
annotations = padded_annotations if annotations is not None else None
|
| 755 |
+
data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
|
| 756 |
+
|
| 757 |
+
data.update({"pixel_values": torch.stack(images, dim=0)})
|
| 758 |
+
encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
|
| 759 |
+
if annotations is not None:
|
| 760 |
+
encoded_inputs["labels"] = [
|
| 761 |
+
BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
|
| 762 |
+
]
|
| 763 |
+
return encoded_inputs
|
| 764 |
+
|
| 765 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process
|
| 766 |
+
def post_process(self, outputs, target_sizes):
|
| 767 |
+
"""
|
| 768 |
+
Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
| 769 |
+
bottom_right_x, bottom_right_y) format. Only supports PyTorch.
|
| 770 |
+
|
| 771 |
+
Args:
|
| 772 |
+
outputs ([`DetrObjectDetectionOutput`]):
|
| 773 |
+
Raw outputs of the model.
|
| 774 |
+
target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 775 |
+
Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
|
| 776 |
+
original image size (before any data augmentation). For visualization, this should be the image size
|
| 777 |
+
after data augment, but before padding.
|
| 778 |
+
Returns:
|
| 779 |
+
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
| 780 |
+
in the batch as predicted by the model.
|
| 781 |
+
"""
|
| 782 |
+
logger.warning_once(
|
| 783 |
+
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
|
| 784 |
+
" `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
| 788 |
+
|
| 789 |
+
if len(out_logits) != len(target_sizes):
|
| 790 |
+
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
| 791 |
+
if target_sizes.shape[1] != 2:
|
| 792 |
+
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
| 793 |
+
|
| 794 |
+
prob = nn.functional.softmax(out_logits, -1)
|
| 795 |
+
scores, labels = prob[..., :-1].max(-1)
|
| 796 |
+
|
| 797 |
+
# convert to [x0, y0, x1, y1] format
|
| 798 |
+
boxes = center_to_corners_format(out_bbox)
|
| 799 |
+
# and from relative [0, 1] to absolute [0, height] coordinates
|
| 800 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 801 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
| 802 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 803 |
+
|
| 804 |
+
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
|
| 805 |
+
return results
|
| 806 |
+
|
| 807 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_segmentation
|
| 808 |
+
def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
|
| 809 |
+
"""
|
| 810 |
+
Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
|
| 811 |
+
|
| 812 |
+
Args:
|
| 813 |
+
outputs ([`DetrSegmentationOutput`]):
|
| 814 |
+
Raw outputs of the model.
|
| 815 |
+
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
|
| 816 |
+
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
|
| 817 |
+
threshold (`float`, *optional*, defaults to 0.9):
|
| 818 |
+
Threshold to use to filter out queries.
|
| 819 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 820 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 821 |
+
Returns:
|
| 822 |
+
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
|
| 823 |
+
in the batch as predicted by the model.
|
| 824 |
+
"""
|
| 825 |
+
logger.warning_once(
|
| 826 |
+
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
|
| 827 |
+
" `post_process_semantic_segmentation`.",
|
| 828 |
+
)
|
| 829 |
+
out_logits, raw_masks = outputs.logits, outputs.pred_masks
|
| 830 |
+
empty_label = out_logits.shape[-1] - 1
|
| 831 |
+
preds = []
|
| 832 |
+
|
| 833 |
+
def to_tuple(tup):
|
| 834 |
+
if isinstance(tup, tuple):
|
| 835 |
+
return tup
|
| 836 |
+
return tuple(tup.tolist())
|
| 837 |
+
|
| 838 |
+
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
|
| 839 |
+
# we filter empty queries and detection below threshold
|
| 840 |
+
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
| 841 |
+
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
| 842 |
+
cur_scores = cur_scores[keep]
|
| 843 |
+
cur_labels = cur_labels[keep]
|
| 844 |
+
cur_masks = cur_masks[keep]
|
| 845 |
+
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
| 846 |
+
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
|
| 847 |
+
|
| 848 |
+
predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
|
| 849 |
+
preds.append(predictions)
|
| 850 |
+
return preds
|
| 851 |
+
|
| 852 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance
|
| 853 |
+
def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
| 854 |
+
"""
|
| 855 |
+
Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
|
| 856 |
+
PyTorch.
|
| 857 |
+
|
| 858 |
+
Args:
|
| 859 |
+
results (`List[Dict]`):
|
| 860 |
+
Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
|
| 861 |
+
outputs ([`DetrSegmentationOutput`]):
|
| 862 |
+
Raw outputs of the model.
|
| 863 |
+
orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 864 |
+
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
| 865 |
+
image size (before any data augmentation).
|
| 866 |
+
max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 867 |
+
Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
|
| 868 |
+
original image size (before any data augmentation).
|
| 869 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 870 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 871 |
+
Returns:
|
| 872 |
+
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
|
| 873 |
+
image in the batch as predicted by the model.
|
| 874 |
+
"""
|
| 875 |
+
logger.warning_once(
|
| 876 |
+
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
|
| 877 |
+
" `post_process_instance_segmentation`.",
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
if len(orig_target_sizes) != len(max_target_sizes):
|
| 881 |
+
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
|
| 882 |
+
max_h, max_w = max_target_sizes.max(0)[0].tolist()
|
| 883 |
+
outputs_masks = outputs.pred_masks.squeeze(2)
|
| 884 |
+
outputs_masks = nn.functional.interpolate(
|
| 885 |
+
outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
|
| 886 |
+
)
|
| 887 |
+
outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
|
| 888 |
+
|
| 889 |
+
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
|
| 890 |
+
img_h, img_w = t[0], t[1]
|
| 891 |
+
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
|
| 892 |
+
results[i]["masks"] = nn.functional.interpolate(
|
| 893 |
+
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
|
| 894 |
+
).byte()
|
| 895 |
+
|
| 896 |
+
return results
|
| 897 |
+
|
| 898 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic
|
| 899 |
+
def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
|
| 900 |
+
"""
|
| 901 |
+
Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
|
| 902 |
+
|
| 903 |
+
Args:
|
| 904 |
+
outputs ([`DetrSegmentationOutput`]):
|
| 905 |
+
Raw outputs of the model.
|
| 906 |
+
processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
|
| 907 |
+
Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
|
| 908 |
+
augmentation but before batching.
|
| 909 |
+
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
|
| 910 |
+
Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
|
| 911 |
+
If left to None, it will default to the `processed_sizes`.
|
| 912 |
+
is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
|
| 913 |
+
Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
|
| 914 |
+
If not set, defaults to the `is_thing_map` of COCO panoptic.
|
| 915 |
+
threshold (`float`, *optional*, defaults to 0.85):
|
| 916 |
+
Threshold to use to filter out queries.
|
| 917 |
+
Returns:
|
| 918 |
+
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
|
| 919 |
+
an image in the batch as predicted by the model.
|
| 920 |
+
"""
|
| 921 |
+
logger.warning_once(
|
| 922 |
+
"`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
|
| 923 |
+
" `post_process_panoptic_segmentation`.",
|
| 924 |
+
)
|
| 925 |
+
if target_sizes is None:
|
| 926 |
+
target_sizes = processed_sizes
|
| 927 |
+
if len(processed_sizes) != len(target_sizes):
|
| 928 |
+
raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
|
| 929 |
+
|
| 930 |
+
if is_thing_map is None:
|
| 931 |
+
# default to is_thing_map of COCO panoptic
|
| 932 |
+
is_thing_map = {i: i <= 90 for i in range(201)}
|
| 933 |
+
|
| 934 |
+
out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
|
| 935 |
+
if not len(out_logits) == len(raw_masks) == len(target_sizes):
|
| 936 |
+
raise ValueError(
|
| 937 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
|
| 938 |
+
)
|
| 939 |
+
empty_label = out_logits.shape[-1] - 1
|
| 940 |
+
preds = []
|
| 941 |
+
|
| 942 |
+
def to_tuple(tup):
|
| 943 |
+
if isinstance(tup, tuple):
|
| 944 |
+
return tup
|
| 945 |
+
return tuple(tup.tolist())
|
| 946 |
+
|
| 947 |
+
for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
|
| 948 |
+
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
| 949 |
+
):
|
| 950 |
+
# we filter empty queries and detection below threshold
|
| 951 |
+
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
| 952 |
+
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
| 953 |
+
cur_scores = cur_scores[keep]
|
| 954 |
+
cur_labels = cur_labels[keep]
|
| 955 |
+
cur_masks = cur_masks[keep]
|
| 956 |
+
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
| 957 |
+
cur_boxes = center_to_corners_format(cur_boxes[keep])
|
| 958 |
+
|
| 959 |
+
h, w = cur_masks.shape[-2:]
|
| 960 |
+
if len(cur_boxes) != len(cur_labels):
|
| 961 |
+
raise ValueError("Not as many boxes as there are classes")
|
| 962 |
+
|
| 963 |
+
# It may be that we have several predicted masks for the same stuff class.
|
| 964 |
+
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
|
| 965 |
+
cur_masks = cur_masks.flatten(1)
|
| 966 |
+
stuff_equiv_classes = defaultdict(lambda: [])
|
| 967 |
+
for k, label in enumerate(cur_labels):
|
| 968 |
+
if not is_thing_map[label.item()]:
|
| 969 |
+
stuff_equiv_classes[label.item()].append(k)
|
| 970 |
+
|
| 971 |
+
def get_ids_area(masks, scores, dedup=False):
|
| 972 |
+
# This helper function creates the final panoptic segmentation image
|
| 973 |
+
# It also returns the area of the masks that appears on the image
|
| 974 |
+
|
| 975 |
+
m_id = masks.transpose(0, 1).softmax(-1)
|
| 976 |
+
|
| 977 |
+
if m_id.shape[-1] == 0:
|
| 978 |
+
# We didn't detect any mask :(
|
| 979 |
+
m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
|
| 980 |
+
else:
|
| 981 |
+
m_id = m_id.argmax(-1).view(h, w)
|
| 982 |
+
|
| 983 |
+
if dedup:
|
| 984 |
+
# Merge the masks corresponding to the same stuff class
|
| 985 |
+
for equiv in stuff_equiv_classes.values():
|
| 986 |
+
if len(equiv) > 1:
|
| 987 |
+
for eq_id in equiv:
|
| 988 |
+
m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
|
| 989 |
+
|
| 990 |
+
final_h, final_w = to_tuple(target_size)
|
| 991 |
+
|
| 992 |
+
seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
|
| 993 |
+
seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)
|
| 994 |
+
|
| 995 |
+
np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
|
| 996 |
+
np_seg_img = np_seg_img.view(final_h, final_w, 3)
|
| 997 |
+
np_seg_img = np_seg_img.numpy()
|
| 998 |
+
|
| 999 |
+
m_id = torch.from_numpy(rgb_to_id(np_seg_img))
|
| 1000 |
+
|
| 1001 |
+
area = []
|
| 1002 |
+
for i in range(len(scores)):
|
| 1003 |
+
area.append(m_id.eq(i).sum().item())
|
| 1004 |
+
return area, seg_img
|
| 1005 |
+
|
| 1006 |
+
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
| 1007 |
+
if cur_labels.numel() > 0:
|
| 1008 |
+
# We know filter empty masks as long as we find some
|
| 1009 |
+
while True:
|
| 1010 |
+
filtered_small = torch.as_tensor(
|
| 1011 |
+
[area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
|
| 1012 |
+
)
|
| 1013 |
+
if filtered_small.any().item():
|
| 1014 |
+
cur_scores = cur_scores[~filtered_small]
|
| 1015 |
+
cur_labels = cur_labels[~filtered_small]
|
| 1016 |
+
cur_masks = cur_masks[~filtered_small]
|
| 1017 |
+
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
| 1018 |
+
else:
|
| 1019 |
+
break
|
| 1020 |
+
|
| 1021 |
+
else:
|
| 1022 |
+
cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
|
| 1023 |
+
|
| 1024 |
+
segments_info = []
|
| 1025 |
+
for i, a in enumerate(area):
|
| 1026 |
+
cat = cur_labels[i].item()
|
| 1027 |
+
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
|
| 1028 |
+
del cur_labels
|
| 1029 |
+
|
| 1030 |
+
with io.BytesIO() as out:
|
| 1031 |
+
seg_img.save(out, format="PNG")
|
| 1032 |
+
predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
|
| 1033 |
+
preds.append(predictions)
|
| 1034 |
+
return preds
|
| 1035 |
+
|
| 1036 |
+
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_object_detection
    def post_process_object_detection(
        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes

        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        prob = nn.functional.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(out_bbox)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b in zip(scores, labels, boxes):
            score = s[s > threshold]
            label = l[s > threshold]
            box = b[s > threshold]
            results.append({"scores": score, "labels": label, "boxes": box})

        return results

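    # A minimal usage sketch for the detection post-processing above (illustrative only; it assumes a
    # `DetrForObjectDetection` model bound to `model`, an instance of this image processor bound to
    # `image_processor`, and a PIL `image` -- none of these names are defined in this file):
    #
    # >>> inputs = image_processor(images=image, return_tensors="pt")
    # >>> outputs = model(**inputs)
    # >>> results = image_processor.post_process_object_detection(
    # ...     outputs, threshold=0.7, target_sizes=[(image.height, image.width)]
    # ... )[0]
    # >>> # `results` is a dict with "scores", "labels" and "boxes" in absolute (x0, y0, x1, y1) pixel coordinates
    # >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    # ...     print(model.config.id2label[label.item()], round(score.item(), 2), [round(c) for c in box.tolist()])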
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_semantic_segmentation
    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple[int, int]] = None):
        """
        Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`DetrForSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple[int, int]]`, *optional*):
                A list of tuples (`Tuple[int, int]`) containing the target size (height, width) of each image in the
                batch. If unset, predictions will not be resized.
        Returns:
            `List[torch.Tensor]`:
                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
                corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
                `torch.Tensor` corresponds to a semantic class id.
        """
        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]

        # Remove the null class `[..., :-1]`
        masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
        masks_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]

        # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
        segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
        batch_size = class_queries_logits.shape[0]

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if batch_size != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            semantic_segmentation = []
            for idx in range(batch_size):
                resized_logits = nn.functional.interpolate(
                    segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = segmentation.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation

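    # Sketch of how the semantic map above is typically consumed (illustrative; assumes the same
    # `model`, `image_processor` and `image` names as in the detection sketch earlier):
    #
    # >>> outputs = model(**image_processor(images=image, return_tensors="pt"))
    # >>> seg_maps = image_processor.post_process_semantic_segmentation(
    # ...     outputs, target_sizes=[(image.height, image.width)]
    # ... )
    # >>> seg_maps[0].shape  # (height, width) tensor of per-pixel semantic class ids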
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance_segmentation
    def post_process_instance_segmentation(
        self,
        outputs,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
        overlap_mask_area_threshold: float = 0.8,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
        return_coco_annotation: Optional[bool] = False,
    ) -> List[Dict]:
        """
        Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.

        Args:
            outputs ([`DetrForSegmentation`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                The probability score threshold to keep predicted instance masks.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                The overlap mask area threshold to merge or discard small disconnected parts within each binary
                instance mask.
            target_sizes (`List[Tuple]`, *optional*):
                List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
                final size (height, width) of each prediction. If unset, predictions will not be resized.
            return_coco_annotation (`bool`, *optional*):
                Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
                format.
        Returns:
            `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
            - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
              `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
              `True`. Set to `None` if no mask is found above `threshold`.
            - **segments_info** -- A dictionary that contains additional information on each segment.
                - **id** -- An integer representing the `segment_id`.
                - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
                - **score** -- Prediction score of segment with `segment_id`.
        """
        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]

        batch_size = class_queries_logits.shape[0]
        num_labels = class_queries_logits.shape[-1] - 1

        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]

        # Predicted label and score of each query (batch_size, num_queries)
        pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)

        # Loop over items in the batch
        results: List[Dict[str, TensorType]] = []

        for i in range(batch_size):
            mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
                mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
            )

            # No mask found
            if mask_probs_item.shape[0] <= 0:
                height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
                segmentation = torch.zeros((height, width)) - 1
                results.append({"segmentation": segmentation, "segments_info": []})
                continue

            # Get segmentation map and segment information of batch item
            target_size = target_sizes[i] if target_sizes is not None else None
            segmentation, segments = compute_segments(
                mask_probs=mask_probs_item,
                pred_scores=pred_scores_item,
                pred_labels=pred_labels_item,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                label_ids_to_fuse=[],
                target_size=target_size,
            )

            # Return segmentation map in run-length encoding (RLE) format
            if return_coco_annotation:
                segmentation = convert_segmentation_to_rle(segmentation)

            results.append({"segmentation": segmentation, "segments_info": segments})
        return results

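    # Instance post-processing sketch (illustrative; same assumed `model` / `image_processor` / `image`
    # names as in the earlier sketches):
    #
    # >>> outputs = model(**image_processor(images=image, return_tensors="pt"))
    # >>> result = image_processor.post_process_instance_segmentation(
    # ...     outputs, threshold=0.5, target_sizes=[(image.height, image.width)]
    # ... )[0]
    # >>> result["segmentation"]   # (height, width) tensor of segment_ids (RLE if return_coco_annotation=True)
    # >>> result["segments_info"]  # list of {"id", "label_id", "score"} dicts, one per kept instance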
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic_segmentation
    def post_process_panoptic_segmentation(
        self,
        outputs,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
        overlap_mask_area_threshold: float = 0.8,
        label_ids_to_fuse: Optional[Set[int]] = None,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
    ) -> List[Dict]:
        """
        Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
        PyTorch.

        Args:
            outputs ([`DetrForSegmentation`]):
                The outputs from [`DetrForSegmentation`].
            threshold (`float`, *optional*, defaults to 0.5):
                The probability score threshold to keep predicted instance masks.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                The overlap mask area threshold to merge or discard small disconnected parts within each binary
                instance mask.
            label_ids_to_fuse (`Set[int]`, *optional*):
                The labels in this set will have all their instances be fused together. For instance we could say
                there can only be one sky in an image, but several persons, so the label ID for sky would be in that
                set, but not the one for person.
            target_sizes (`List[Tuple]`, *optional*):
                List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
                final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
        Returns:
            `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
              `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized to
              the corresponding `target_sizes` entry.
            - **segments_info** -- A dictionary that contains additional information on each segment.
                - **id** -- an integer representing the `segment_id`.
                - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
                - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
                  Multiple instances of the same class / label were fused and assigned a single `segment_id`.
                - **score** -- Prediction score of segment with `segment_id`.
        """

        if label_ids_to_fuse is None:
            logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
            label_ids_to_fuse = set()

        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]

        batch_size = class_queries_logits.shape[0]
        num_labels = class_queries_logits.shape[-1] - 1

        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]

        # Predicted label and score of each query (batch_size, num_queries)
        pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)

        # Loop over items in the batch
        results: List[Dict[str, TensorType]] = []

        for i in range(batch_size):
            mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
                mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
            )

            # No mask found
            if mask_probs_item.shape[0] <= 0:
                height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
                segmentation = torch.zeros((height, width)) - 1
                results.append({"segmentation": segmentation, "segments_info": []})
                continue

            # Get segmentation map and segment information of batch item
            target_size = target_sizes[i] if target_sizes is not None else None
            segmentation, segments = compute_segments(
                mask_probs=mask_probs_item,
                pred_scores=pred_scores_item,
                pred_labels=pred_labels_item,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                label_ids_to_fuse=label_ids_to_fuse,
                target_size=target_size,
            )

            results.append({"segmentation": segmentation, "segments_info": segments})
        return results


__all__ = ["DetrImageProcessorFast"]
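# Panoptic post-processing sketch (illustrative; assumes a `DetrForSegmentation` checkpoint loaded as
# `model`, an instance of the image processor above bound to `image_processor`, and a PIL `image`):
#
# >>> outputs = model(**image_processor(images=image, return_tensors="pt"))
# >>> panoptic = image_processor.post_process_panoptic_segmentation(
# ...     outputs, target_sizes=[(image.height, image.width)]
# ... )[0]  # with `label_ids_to_fuse` unset, no instances are fused (a warning is logged)
# >>> panoptic["segmentation"]   # (height, width) tensor of segment_ids
# >>> panoptic["segments_info"]  # per-segment dicts with "id", "label_id", "was_fused", "score"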
docs/transformers/build/lib/transformers/models/detr/modeling_detr.py
ADDED
|
@@ -0,0 +1,1815 @@
# coding=utf-8
# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DETR model."""

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_timm_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)
from ...utils.backbone_utils import load_backbone
from .configuration_detr import DetrConfig


if is_timm_available():
    from timm import create_model


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DetrConfig"
_CHECKPOINT_FOR_DOC = "facebook/detr-resnet-50"

@dataclass
|
| 52 |
+
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
| 53 |
+
"""
|
| 54 |
+
Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
|
| 55 |
+
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
|
| 56 |
+
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 60 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 61 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 62 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 63 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 64 |
+
plus the initial embedding outputs.
|
| 65 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 66 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 67 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 68 |
+
the self-attention heads.
|
| 69 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
|
| 70 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 71 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 72 |
+
used to compute the weighted average in the cross-attention heads.
|
| 73 |
+
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
| 74 |
+
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
| 75 |
+
layernorm.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dataclass
|
| 82 |
+
class DetrModelOutput(Seq2SeqModelOutput):
|
| 83 |
+
"""
|
| 84 |
+
Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
|
| 85 |
+
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
|
| 86 |
+
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 90 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 91 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 92 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 93 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
|
| 94 |
+
layer plus the initial embedding outputs.
|
| 95 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 96 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 97 |
+
sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
|
| 98 |
+
weighted average in the self-attention heads.
|
| 99 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 100 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 101 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 102 |
+
used to compute the weighted average in the cross-attention heads.
|
| 103 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 104 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 105 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 106 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 107 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
|
| 108 |
+
layer plus the initial embedding outputs.
|
| 109 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 110 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 111 |
+
sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
|
| 112 |
+
weighted average in the self-attention heads.
|
| 113 |
+
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
| 114 |
+
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
| 115 |
+
layernorm.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
|
| 122 |
+
class DetrObjectDetectionOutput(ModelOutput):
|
| 123 |
+
"""
|
| 124 |
+
Output type of [`DetrForObjectDetection`].
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
| 128 |
+
Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
|
| 129 |
+
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
|
| 130 |
+
scale-invariant IoU loss.
|
| 131 |
+
loss_dict (`Dict`, *optional*):
|
| 132 |
+
A dictionary containing the individual losses. Useful for logging.
|
| 133 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
|
| 134 |
+
Classification logits (including no-object) for all queries.
|
| 135 |
+
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
| 136 |
+
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
| 137 |
+
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
| 138 |
+
possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
|
| 139 |
+
unnormalized bounding boxes.
|
| 140 |
+
auxiliary_outputs (`list[Dict]`, *optional*):
|
| 141 |
+
Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
|
| 142 |
+
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
|
| 143 |
+
`pred_boxes`) for each decoder layer.
|
| 144 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 145 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 146 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 147 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 148 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
|
| 149 |
+
layer plus the initial embedding outputs.
|
| 150 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 151 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 152 |
+
sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
|
| 153 |
+
weighted average in the self-attention heads.
|
| 154 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 155 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 156 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 157 |
+
used to compute the weighted average in the cross-attention heads.
|
| 158 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 159 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 160 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 161 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 162 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
|
| 163 |
+
layer plus the initial embedding outputs.
|
| 164 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 165 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 166 |
+
sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
|
| 167 |
+
weighted average in the self-attention heads.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
loss: Optional[torch.FloatTensor] = None
|
| 171 |
+
loss_dict: Optional[Dict] = None
|
| 172 |
+
logits: Optional[torch.FloatTensor] = None
|
| 173 |
+
pred_boxes: Optional[torch.FloatTensor] = None
|
| 174 |
+
auxiliary_outputs: Optional[List[Dict]] = None
|
| 175 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 176 |
+
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 177 |
+
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 178 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 179 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 180 |
+
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 181 |
+
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@dataclass
|
| 185 |
+
class DetrSegmentationOutput(ModelOutput):
|
| 186 |
+
"""
|
| 187 |
+
Output type of [`DetrForSegmentation`].
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
| 191 |
+
Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
|
| 192 |
+
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
|
| 193 |
+
scale-invariant IoU loss.
|
| 194 |
+
loss_dict (`Dict`, *optional*):
|
| 195 |
+
A dictionary containing the individual losses. Useful for logging.
|
| 196 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
|
| 197 |
+
Classification logits (including no-object) for all queries.
|
| 198 |
+
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
| 199 |
+
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
| 200 |
+
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
| 201 |
+
possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
|
| 202 |
+
unnormalized bounding boxes.
|
| 203 |
+
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
|
| 204 |
+
Segmentation masks logits for all queries. See also
|
| 205 |
+
[`~DetrImageProcessor.post_process_semantic_segmentation`] or
|
| 206 |
+
[`~DetrImageProcessor.post_process_instance_segmentation`]
|
| 207 |
+
[`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
|
| 208 |
+
segmentation masks respectively.
|
| 209 |
+
auxiliary_outputs (`list[Dict]`, *optional*):
|
| 210 |
+
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
|
| 211 |
+
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
|
| 212 |
+
`pred_boxes`) for each decoder layer.
|
| 213 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 214 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 215 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 216 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 217 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
|
| 218 |
+
layer plus the initial embedding outputs.
|
| 219 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 220 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 221 |
+
sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
|
| 222 |
+
weighted average in the self-attention heads.
|
| 223 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 224 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 225 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 226 |
+
used to compute the weighted average in the cross-attention heads.
|
| 227 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 228 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 229 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 230 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 231 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
|
| 232 |
+
layer plus the initial embedding outputs.
|
| 233 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 234 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 235 |
+
sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
|
| 236 |
+
weighted average in the self-attention heads.
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
loss: Optional[torch.FloatTensor] = None
|
| 240 |
+
loss_dict: Optional[Dict] = None
|
| 241 |
+
logits: Optional[torch.FloatTensor] = None
|
| 242 |
+
pred_boxes: Optional[torch.FloatTensor] = None
|
| 243 |
+
pred_masks: Optional[torch.FloatTensor] = None
|
| 244 |
+
auxiliary_outputs: Optional[List[Dict]] = None
|
| 245 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 246 |
+
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 247 |
+
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 248 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 249 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 250 |
+
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 251 |
+
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# BELOW: utilities copied from
# https://github.com/facebookresearch/detr/blob/master/backbone.py
class DetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it user-friendly
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        epsilon = 1e-5
        scale = weight * (running_var + epsilon).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias


def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            new_module = DetrFrozenBatchNorm2d(module.num_features)

            if not module.weight.device == torch.device("meta"):
                new_module.weight.data.copy_(module.weight)
                new_module.bias.data.copy_(module.bias)
                new_module.running_mean.data.copy_(module.running_mean)
                new_module.running_var.data.copy_(module.running_var)

            model._modules[name] = new_module

        if len(list(module.children())) > 0:
            replace_batch_norm(module)

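# Sketch of what the two utilities above do together (illustrative; `torchvision` is used only as a
# convenient source of a BatchNorm-based backbone and is not imported by this module):
#
# >>> import torchvision
# >>> backbone = torchvision.models.resnet50()
# >>> replace_batch_norm(backbone)            # every nn.BatchNorm2d becomes a DetrFrozenBatchNorm2d
# >>> isinstance(backbone.bn1, DetrFrozenBatchNorm2d)
# True
# >>> # The frozen layer computes y = (x - running_mean) / sqrt(running_var + 1e-5) * weight + bias
# >>> # with all four buffers fixed, so batch statistics are never updated during fine-tuning.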
class DetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
        if config.use_timm_backbone:
            # We default to values which were previously hard-coded. This enables configurability from the config
            # using backbone arguments, while keeping the default behavior the same.
            requires_backends(self, ["timm"])
            kwargs = getattr(config, "backbone_kwargs", {})
            kwargs = {} if kwargs is None else kwargs.copy()
            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
            num_channels = kwargs.pop("in_chans", config.num_channels)
            if config.dilation:
                kwargs["output_stride"] = kwargs.get("output_stride", 16)
            backbone = create_model(
                config.backbone,
                pretrained=config.use_pretrained_backbone,
                features_only=True,
                out_indices=out_indices,
                in_chans=num_channels,
                **kwargs,
            )
        else:
            backbone = load_backbone(config)

        # replace batch norm by frozen batch norm
        with torch.no_grad():
            replace_batch_norm(backbone)
        self.model = backbone
        self.intermediate_channel_sizes = (
            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
        )

        backbone_model_type = None
        if config.backbone is not None:
            backbone_model_type = config.backbone
        elif config.backbone_config is not None:
            backbone_model_type = config.backbone_config.model_type
        else:
            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")

        if "resnet" in backbone_model_type:
            for name, parameter in self.model.named_parameters():
                if config.use_timm_backbone:
                    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                        parameter.requires_grad_(False)
                else:
                    if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
                        parameter.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # send pixel_values through the model to get list of feature maps
        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps

        out = []
        for feature_map in features:
            # downsample pixel_mask to match shape of corresponding feature_map
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
            out.append((feature_map, mask))
        return out


class DetrConvModel(nn.Module):
    """
    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
    """

    def __init__(self, conv_encoder, position_embedding):
        super().__init__()
        self.conv_encoder = conv_encoder
        self.position_embedding = position_embedding

    def forward(self, pixel_values, pixel_mask):
        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
        out = self.conv_encoder(pixel_values, pixel_mask)
        pos = []
        for feature_map, mask in out:
            # position encoding
            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))

        return out, pos

class DetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, pixel_values, pixel_mask):
        if pixel_mask is None:
            raise ValueError("No pixel mask provided")
        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class DetrLearnedPositionEmbedding(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.row_embeddings = nn.Embedding(50, embedding_dim)
        self.column_embeddings = nn.Embedding(50, embedding_dim)

    def forward(self, pixel_values, pixel_mask=None):
        height, width = pixel_values.shape[-2:]
        width_values = torch.arange(width, device=pixel_values.device)
        height_values = torch.arange(height, device=pixel_values.device)
        x_emb = self.column_embeddings(width_values)
        y_emb = self.row_embeddings(height_values)
        pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)
        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(config):
    n_steps = config.d_model // 2
    if config.position_embedding_type == "sine":
        # TODO find a better way of exposing other arguments
        position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
    elif config.position_embedding_type == "learned":
        position_embedding = DetrLearnedPositionEmbedding(n_steps)
    else:
        raise ValueError(f"Not supported {config.position_embedding_type}")

    return position_embedding

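# Shape sketch for the position encodings above (illustrative): `build_position_encoding` passes
# `config.d_model // 2` as the per-axis embedding dimension, and the sine variant concatenates the
# y- and x-embeddings back to `d_model` channels, so for a config with d_model=256:
#
# >>> position_embedding = DetrSinePositionEmbedding(embedding_dim=128, normalize=True)
# >>> feature_map = torch.zeros(2, 2048, 25, 34)        # (batch, channels, height, width)
# >>> mask = torch.ones(2, 25, 34, dtype=torch.bool)    # valid-pixel mask from the backbone
# >>> position_embedding(feature_map, mask).shape
# torch.Size([2, 256, 25, 34])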
class DetrAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
        return tensor if object_queries is None else tensor + object_queries

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        key_value_states: Optional[torch.Tensor] = None,
        spatial_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size, target_len, embed_dim = hidden_states.size()

        # add position embeddings to the hidden states before projecting to queries and keys
        if object_queries is not None:
            hidden_states_original = hidden_states
            hidden_states = self.with_pos_embed(hidden_states, object_queries)

        # add key-value position embeddings to the key value states
        if spatial_position_embeddings is not None:
            key_value_states_original = key_value_states
            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        source_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, target_len, source_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


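# Illustrative sketch (not part of the original file): runs DetrAttention standalone as a
# self-attention layer on random inputs. The shapes (batch of 2, 100 positions, 256 channels)
# are assumptions chosen to match the DETR defaults.
def _example_detr_attention():
    import torch

    attention = DetrAttention(embed_dim=256, num_heads=8)
    hidden_states = torch.randn(2, 100, 256)
    object_queries = torch.randn(2, 100, 256)  # added to queries and keys only, not to values
    attn_output, attn_weights = attention(
        hidden_states=hidden_states,
        object_queries=object_queries,
        output_attentions=True,
    )
    return attn_output.shape, attn_weights.shape  # (2, 100, 256) and (2, 8, 100, 100)

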
class DetrEncoderLayer(nn.Module):
    def __init__(self, config: DetrConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        object_queries: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            object_queries=object_queries,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)

        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class DetrDecoderLayer(nn.Module):
    def __init__(self, config: DetrConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = DetrAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                object_queries that are added to the hidden states
                in the cross-attention layer.
            query_position_embeddings (`torch.FloatTensor`, *optional*):
                position embeddings that are added to the queries and keys
                in the self-attention layer.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            object_queries=query_position_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                object_queries=query_position_embeddings,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                spatial_position_embeddings=object_queries,
                output_attentions=output_attentions,
            )

            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


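# Illustrative sketch (not part of the original file): exercises a single decoder layer with both
# self-attention over 100 object queries and cross-attention over a flattened feature map of 850
# positions. All shapes and the use of `DetrConfig()` defaults are assumptions for this example.
def _example_detr_decoder_layer():
    import torch
    from transformers import DetrConfig

    config = DetrConfig()  # d_model=256, decoder_attention_heads=8 by default
    layer = DetrDecoderLayer(config)
    queries = torch.zeros(2, 100, 256)                    # decoder inputs start as zeros in DETR
    query_position_embeddings = torch.randn(2, 100, 256)  # learned object query embeddings
    encoder_hidden_states = torch.randn(2, 850, 256)      # flattened encoder output
    object_queries = torch.randn(2, 850, 256)             # spatial position embeddings of the feature map
    (hidden_states,) = layer(
        queries,
        query_position_embeddings=query_position_embeddings,
        encoder_hidden_states=encoder_hidden_states,
        object_queries=object_queries,
    )
    return hidden_states.shape  # torch.Size([2, 100, 256])

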
class DetrPreTrainedModel(PreTrainedModel):
    config_class = DetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std

        if isinstance(module, DetrMHAttentionMap):
            nn.init.zeros_(module.k_linear.bias)
            nn.init.zeros_(module.q_linear.bias)
            nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
            nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
        elif isinstance(module, DetrLearnedPositionEmbedding):
            nn.init.uniform_(module.row_embeddings.weight)
            nn.init.uniform_(module.column_embeddings.weight)
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


DETR_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DetrConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DETR_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it.

            Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.

        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).

            [What are attention masks?](../glossary#attention-mask)

        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class DetrEncoder(DetrPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`DetrEncoderLayer`].

    The encoder updates the flattened feature map through multiple self-attention layers.

    Small tweak for DETR:

    - object_queries are added to the forward pass.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])

        # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        object_queries=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)

            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Object queries that are added to the queries in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                # we add object_queries as extra input to the encoder_layer
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    object_queries=object_queries,
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


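# Illustrative sketch (not part of the original file): runs the standalone encoder on a dummy
# flattened feature map. The shapes and the `DetrConfig()` defaults are assumptions; no backbone
# is involved at this stage.
def _example_detr_encoder():
    import torch
    from transformers import DetrConfig

    config = DetrConfig()  # encoder_layers=6, d_model=256 by default
    encoder = DetrEncoder(config)
    flattened_features = torch.randn(2, 850, 256)  # (batch, height*width, d_model)
    object_queries = torch.randn(2, 850, 256)      # sine position embeddings, flattened the same way
    outputs = encoder(inputs_embeds=flattened_features, object_queries=object_queries)
    return outputs.last_hidden_state.shape  # torch.Size([2, 850, 256])

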
class DetrDecoder(DetrPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some small tweaks for DETR:

    - object_queries and query_position_embeddings are added to the forward pass.
    - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop

        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        # in DETR, the decoder uses layernorm after the last decoder layer output
        self.layernorm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        object_queries=None,
        query_position_embeddings=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:

                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:

                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Object queries that are added to the queries and keys in each cross-attention layer.
            query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the values and keys in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            input_shape = inputs_embeds.size()[:-1]

        combined_attention_mask = None

        if attention_mask is not None and combined_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # optional intermediate hidden states
        intermediate = () if self.config.auxiliary_loss else None

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    combined_attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=combined_attention_mask,
                    object_queries=object_queries,
                    query_position_embeddings=query_position_embeddings,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if self.config.auxiliary_loss:
                hidden_states = self.layernorm(hidden_states)
                intermediate += (hidden_states,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # finally, apply layernorm
        hidden_states = self.layernorm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # stack intermediate decoder activations
        if self.config.auxiliary_loss:
            intermediate = torch.stack(intermediate)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
                if v is not None
            )
        return DetrDecoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
            intermediate_hidden_states=intermediate,
        )


@add_start_docstrings(
    """
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    """,
    DETR_START_DOCSTRING,
)
class DetrModel(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # Create backbone + positional encoding
        backbone = DetrConvEncoder(config)
        object_queries = build_position_encoding(config)
        self.backbone = DetrConvModel(backbone, object_queries)

        # Create projection layer
        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)

        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def freeze_backbone(self):
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(True)

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DetrModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, send pixel_values + pixel_mask through the backbone to obtain the features
        # pixel_values should be of shape (batch_size, num_channels, height, width)
        # pixel_mask should be of shape (batch_size, height, width)
        features, object_queries_list = self.backbone(pixel_values, pixel_mask)

        # get final feature map and downsampled mask
        feature_map, mask = features[-1]

        if mask is None:
            raise ValueError("Backbone does not return downsampled pixel mask")

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        projected_feature_map = self.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=object_queries,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
        query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        queries = torch.zeros_like(query_position_embeddings)

        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return DetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        )


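# Illustrative sketch (not part of the original file): freezes the convolutional backbone and
# reuses a precomputed encoder output on a second call, which skips the transformer encoder
# (the backbone itself still runs, since its mask and position embeddings are needed). Loading
# "facebook/detr-resnet-50" is an assumption and requires network access plus the timm library;
# the random input tensor stands in for a preprocessed image.
def _example_detr_model_encoder_reuse():
    import torch

    model = DetrModel.from_pretrained("facebook/detr-resnet-50")
    model.freeze_backbone()  # backbone parameters no longer receive gradients

    pixel_values = torch.randn(1, 3, 800, 1066)
    with torch.no_grad():
        outputs = model(pixel_values)
        reused = model(pixel_values, encoder_outputs=(outputs.encoder_last_hidden_state,))
    return reused.last_hidden_state.shape  # torch.Size([1, 100, 256])

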
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


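# Illustrative sketch (not part of the original file): the 3-layer MLP maps each of the decoder's
# query embeddings to 4 box coordinates; the sigmoid applied by the caller normalizes them to
# [0, 1]. The shapes below are assumptions matching the DETR defaults.
def _example_mlp_prediction_head():
    import torch

    bbox_head = DetrMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    decoder_output = torch.randn(2, 100, 256)
    boxes = bbox_head(decoder_output).sigmoid()  # normalized (center_x, center_y, width, height)
    return boxes.shape  # torch.Size([2, 100, 4])

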
@add_start_docstrings(
    """
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    """,
    DETR_START_DOCSTRING,
)
class DetrForObjectDetection(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR encoder-decoder model
        self.model = DetrModel(config)

        # Object detection heads
        self.class_labels_classifier = nn.Linear(
            config.d_model, config.num_labels + 1
        )  # We add one for the "no object" class
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
        r"""
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # First, send images through the DETR base model to obtain encoder + decoder outputs
        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # class logits + predicted bounding boxes
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes) + auxiliary_outputs + outputs
            else:
                output = (logits, pred_boxes) + outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


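# Illustrative sketch (not part of the original file): computes the bipartite-matching loss by
# passing `labels` in the format described in the docstring above. Loading
# "facebook/detr-resnet-50" is an assumption and requires network access (and scipy for the
# Hungarian matcher); the random image and its two dummy boxes are made up for the example.
def _example_object_detection_loss():
    import torch

    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    pixel_values = torch.randn(1, 3, 800, 1066)
    labels = [
        {
            "class_labels": torch.tensor([1, 17]),  # COCO category ids
            "boxes": torch.tensor(
                [[0.5, 0.5, 0.2, 0.3], [0.3, 0.6, 0.1, 0.1]]  # normalized (cx, cy, w, h)
            ),
        }
    ]
    outputs = model(pixel_values, labels=labels)
    return outputs.loss  # scalar tensor; call outputs.loss.backward() during training

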
@add_start_docstrings(
    """
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
    such as COCO panoptic.

    """,
    DETR_START_DOCSTRING,
)
class DetrForSegmentation(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # object detection model
        self.detr = DetrForObjectDetection(config)

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
        intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes

        self.mask_head = DetrMaskHeadSmallConv(
            hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
        )

        self.bbox_attention = DetrMHAttentionMap(
            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
        )
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrSegmentationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DetrSegmentationOutput]:
        r"""
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
            dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
            bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
            should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.

        Returns:

        Examples:

        ```python
        >>> import io
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> import numpy

        >>> from transformers import AutoImageProcessor, DetrForSegmentation
        >>> from transformers.image_transforms import rgb_to_id

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
        >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
        >>> # Segmentation results are returned as a list of dictionaries
        >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])

        >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
        >>> panoptic_seg = result[0]["segmentation"]
        >>> # Get prediction score and segment_id to class_id mapping of each segment
        >>> panoptic_segments_info = result[0]["segments_info"]
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, get list of feature maps and position embeddings
        features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        feature_map, mask = features[-1]
        batch_size, num_channels, height, width = feature_map.shape
        projected_feature_map = self.detr.model.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.detr.model.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=object_queries,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        queries = torch.zeros_like(query_position_embeddings)

        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.detr.model.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Sixth, compute logits, pred_boxes and pred_masks
        logits = self.detr.class_labels_classifier(sequence_output)
        pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

        memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
        mask = flattened_mask.view(batch_size, height, width)

        # FIXME h_boxes takes the last one computed, keep this in mind
        # important: we need to reverse the mask, since in the original implementation the mask works reversed
        # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
        bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)

        seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])

        pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
                outputs_class = self.detr.class_labels_classifier(intermediate)
                outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
            )

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
            else:
                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            pred_masks=pred_masks,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


def _expand(tensor, length: int):
|
| 1698 |
+
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
|
| 1699 |
+
|
| 1700 |
+
|
| 1701 |
+
# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
|
| 1702 |
+
class DetrMaskHeadSmallConv(nn.Module):
|
| 1703 |
+
"""
|
| 1704 |
+
Simple convolutional head, using group norm. Upsampling is done using a FPN approach
|
| 1705 |
+
"""
|
| 1706 |
+
|
| 1707 |
+
def __init__(self, dim, fpn_dims, context_dim):
|
| 1708 |
+
super().__init__()
|
| 1709 |
+
|
| 1710 |
+
if dim % 8 != 0:
|
| 1711 |
+
raise ValueError(
|
| 1712 |
+
"The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
|
| 1713 |
+
" GroupNorm is set to 8"
|
| 1714 |
+
)
|
| 1715 |
+
|
| 1716 |
+
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
|
| 1717 |
+
|
| 1718 |
+
self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
|
| 1719 |
+
self.gn1 = nn.GroupNorm(8, dim)
|
| 1720 |
+
self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
|
| 1721 |
+
self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
|
| 1722 |
+
self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
|
| 1723 |
+
self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
|
| 1724 |
+
self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
|
| 1725 |
+
self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
|
| 1726 |
+
self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
|
| 1727 |
+
self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
|
| 1728 |
+
self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
|
| 1729 |
+
|
| 1730 |
+
self.dim = dim
|
| 1731 |
+
|
| 1732 |
+
self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
|
| 1733 |
+
self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
|
| 1734 |
+
self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
|
| 1735 |
+
|
| 1736 |
+
for m in self.modules():
|
| 1737 |
+
if isinstance(m, nn.Conv2d):
|
| 1738 |
+
nn.init.kaiming_uniform_(m.weight, a=1)
|
| 1739 |
+
nn.init.constant_(m.bias, 0)
|
| 1740 |
+
|
| 1741 |
+
def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
|
| 1742 |
+
# here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
|
| 1743 |
+
# the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
|
| 1744 |
+
# We expand the projected feature map to match the number of heads.
|
| 1745 |
+
x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
|
| 1746 |
+
|
| 1747 |
+
x = self.lay1(x)
|
| 1748 |
+
x = self.gn1(x)
|
| 1749 |
+
x = nn.functional.relu(x)
|
| 1750 |
+
x = self.lay2(x)
|
| 1751 |
+
x = self.gn2(x)
|
| 1752 |
+
x = nn.functional.relu(x)
|
| 1753 |
+
|
| 1754 |
+
cur_fpn = self.adapter1(fpns[0])
|
| 1755 |
+
if cur_fpn.size(0) != x.size(0):
|
| 1756 |
+
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
| 1757 |
+
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
| 1758 |
+
x = self.lay3(x)
|
| 1759 |
+
x = self.gn3(x)
|
| 1760 |
+
x = nn.functional.relu(x)
|
| 1761 |
+
|
| 1762 |
+
cur_fpn = self.adapter2(fpns[1])
|
| 1763 |
+
if cur_fpn.size(0) != x.size(0):
|
| 1764 |
+
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
| 1765 |
+
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
| 1766 |
+
x = self.lay4(x)
|
| 1767 |
+
x = self.gn4(x)
|
| 1768 |
+
x = nn.functional.relu(x)
|
| 1769 |
+
|
| 1770 |
+
cur_fpn = self.adapter3(fpns[2])
|
| 1771 |
+
if cur_fpn.size(0) != x.size(0):
|
| 1772 |
+
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
| 1773 |
+
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
| 1774 |
+
x = self.lay5(x)
|
| 1775 |
+
x = self.gn5(x)
|
| 1776 |
+
x = nn.functional.relu(x)
|
| 1777 |
+
|
| 1778 |
+
x = self.out_lay(x)
|
| 1779 |
+
return x
|
| 1780 |
+
|
| 1781 |
+
|
| 1782 |
+
class DetrMHAttentionMap(nn.Module):
|
| 1783 |
+
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
|
| 1784 |
+
|
| 1785 |
+
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
|
| 1786 |
+
super().__init__()
|
| 1787 |
+
self.num_heads = num_heads
|
| 1788 |
+
self.hidden_dim = hidden_dim
|
| 1789 |
+
self.dropout = nn.Dropout(dropout)
|
| 1790 |
+
|
| 1791 |
+
self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 1792 |
+
self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 1793 |
+
|
| 1794 |
+
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
|
| 1795 |
+
|
| 1796 |
+
def forward(self, q, k, mask: Optional[Tensor] = None):
|
| 1797 |
+
q = self.q_linear(q)
|
| 1798 |
+
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
|
| 1799 |
+
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
|
| 1800 |
+
keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
|
| 1801 |
+
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
|
| 1802 |
+
|
| 1803 |
+
if mask is not None:
|
| 1804 |
+
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
|
| 1805 |
+
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
|
| 1806 |
+
weights = self.dropout(weights)
|
| 1807 |
+
return weights
|
| 1808 |
+
|
| 1809 |
+
|
| 1810 |
+
__all__ = [
|
| 1811 |
+
"DetrForObjectDetection",
|
| 1812 |
+
"DetrForSegmentation",
|
| 1813 |
+
"DetrModel",
|
| 1814 |
+
"DetrPreTrainedModel",
|
| 1815 |
+
]
|
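The segmentation forward pass above chains three pieces: the bbox attention map, the small convolutional mask head, and the FPN features from the backbone. As a quick orientation, here is a minimal shape-level sketch of that chain using the two modules defined in this file with random tensors; the sizes, the import path, and the use of freshly initialized weights are illustrative assumptions, not part of this diff.

```python
import torch
from transformers.models.detr.modeling_detr import DetrMaskHeadSmallConv, DetrMHAttentionMap

# Illustrative sizes only: d_model=256, 8 attention heads, 100 queries, a /32 feature map of 25x34.
batch, num_queries, d_model, num_heads = 2, 100, 256, 8
h, w = 25, 34

bbox_attention = DetrMHAttentionMap(query_dim=d_model, hidden_dim=d_model, num_heads=num_heads)
mask_head = DetrMaskHeadSmallConv(d_model + num_heads, fpn_dims=[1024, 512, 256], context_dim=d_model)

sequence_output = torch.randn(batch, num_queries, d_model)        # decoder output
memory = torch.randn(batch, d_model, h, w)                        # encoder output reshaped to 2D
projected_feature_map = torch.randn(batch, d_model, h, w)

bbox_mask = bbox_attention(sequence_output, memory)               # (batch, queries, heads, h, w)
fpns = [
    torch.randn(batch, 1024, 2 * h, 2 * w),                       # backbone features at /16, /8, /4
    torch.randn(batch, 512, 4 * h, 4 * w),
    torch.randn(batch, 256, 8 * h, 8 * w),
]
seg_masks = mask_head(projected_feature_map, bbox_mask, fpns)
print(seg_masks.shape)                                            # (batch * queries, 1, 8 * h, 8 * w)
```

In `DetrForSegmentation.forward` the same call pattern is used: `memory` comes from the encoder output, `fpns` are the backbone feature maps in reverse order (`features[2]`, `features[1]`, `features[0]`), and the result is reshaped to `(batch_size, num_queries, height, width)` as `pred_masks`.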
docs/transformers/build/lib/transformers/models/dialogpt/__init__.py
ADDED
|
File without changes
|
docs/transformers/build/lib/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,46 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch

from transformers.utils import WEIGHTS_NAME


DIALOGPT_MODELS = ["small", "medium", "large"]

OLD_KEY = "lm_head.decoder.weight"
NEW_KEY = "lm_head.weight"


def convert_dialogpt_checkpoint(checkpoint_path: str, pytorch_dump_folder_path: str):
    d = torch.load(checkpoint_path, weights_only=True)
    d[NEW_KEY] = d.pop(OLD_KEY)
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    torch.save(d, os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dialogpt_path", default=".", type=str)
    args = parser.parse_args()
    for MODEL in DIALOGPT_MODELS:
        checkpoint_path = os.path.join(args.dialogpt_path, f"{MODEL}_ft.pkl")
        pytorch_dump_folder_path = f"./DialoGPT-{MODEL}"
        convert_dialogpt_checkpoint(
            checkpoint_path,
            pytorch_dump_folder_path,
        )
|
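For reference, the script above can either be run as-is (it then looks for `small_ft.pkl`, `medium_ft.pkl` and `large_ft.pkl` under `--dialogpt_path` and writes one `./DialoGPT-{size}` folder per size), or the helper can be called directly. A minimal sketch, assuming the module is importable from the installed package and that a local `medium_ft.pkl` dump exists:

```python
# Hypothetical paths; the checkpoint file is the original DialoGPT ``*_ft.pkl`` dump.
from transformers.models.dialogpt.convert_dialogpt_original_pytorch_checkpoint_to_pytorch import (
    convert_dialogpt_checkpoint,
)

convert_dialogpt_checkpoint(
    checkpoint_path="medium_ft.pkl",
    pytorch_dump_folder_path="./DialoGPT-medium",  # will contain the renamed state dict under WEIGHTS_NAME
)
```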
docs/transformers/build/lib/transformers/models/diffllama/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_diffllama import *
    from .modeling_diffllama import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/diffllama/configuration_diffllama.py
ADDED
|
@@ -0,0 +1,199 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on Llama implementations in this library and Microsoft's
|
| 5 |
+
# Differential Transformer implementations.
|
| 6 |
+
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
"""DiffLlama model configuration"""
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...modeling_rope_utils import rope_config_validation
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DiffLlamaConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate an DiffLlama
|
| 27 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
|
| 28 |
+
will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
|
| 29 |
+
|
| 30 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 31 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 36 |
+
Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
|
| 37 |
+
`inputs_ids` passed when calling [`DiffLlamaModel`]
|
| 38 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
| 39 |
+
Dimension of the hidden representations.
|
| 40 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
| 41 |
+
Dimension of the MLP representations.
|
| 42 |
+
num_hidden_layers (`int`, *optional*, defaults to 16):
|
| 43 |
+
Number of hidden layers in the Transformer decoder.
|
| 44 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 45 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 46 |
+
num_key_value_heads (`int`, *optional*):
|
| 47 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 48 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 49 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 50 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 51 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 52 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 53 |
+
`num_attention_heads`.
|
| 54 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 55 |
+
The non-linear activation function (function or string) in the decoder.
|
| 56 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 57 |
+
The maximum sequence length that this model might ever be used with.
|
| 58 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 59 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 60 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 61 |
+
The epsilon used by the rms normalization layers.
|
| 62 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 63 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 64 |
+
relevant if `config.is_decoder=True`.
|
| 65 |
+
pad_token_id (`int`, *optional*):
|
| 66 |
+
Padding token id.
|
| 67 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 68 |
+
Beginning of stream token id.
|
| 69 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 70 |
+
End of stream token id.
|
| 71 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 72 |
+
Whether to tie weight embeddings
|
| 73 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 74 |
+
The base period of the RoPE embeddings.
|
| 75 |
+
rope_scaling (`Dict`, *optional*):
|
| 76 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 77 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 78 |
+
accordingly.
|
| 79 |
+
Expected contents:
|
| 80 |
+
`rope_type` (`str`):
|
| 81 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 82 |
+
'diffllama3'], with 'default' being the original RoPE implementation.
|
| 83 |
+
`factor` (`float`, *optional*):
|
| 84 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 85 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 86 |
+
original maximum pre-trained length.
|
| 87 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 88 |
+
Used with 'dynamic', 'longrope' and 'diffllama3'. The original max position embeddings used during
|
| 89 |
+
pretraining.
|
| 90 |
+
`attention_factor` (`float`, *optional*):
|
| 91 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 92 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 93 |
+
`factor` field to infer the suggested value.
|
| 94 |
+
`beta_fast` (`float`, *optional*):
|
| 95 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 96 |
+
ramp function. If unspecified, it defaults to 32.
|
| 97 |
+
`beta_slow` (`float`, *optional*):
|
| 98 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 99 |
+
ramp function. If unspecified, it defaults to 1.
|
| 100 |
+
`short_factor` (`List[float]`, *optional*):
|
| 101 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 102 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 103 |
+
size divided by the number of attention heads divided by 2
|
| 104 |
+
`long_factor` (`List[float]`, *optional*):
|
| 105 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
|
| 106 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 107 |
+
size divided by the number of attention heads divided by 2
|
| 108 |
+
`low_freq_factor` (`float`, *optional*):
|
| 109 |
+
Only used with 'diffllama3'. Scaling factor applied to low frequency components of the RoPE
|
| 110 |
+
`high_freq_factor` (`float`, *optional*):
|
| 111 |
+
Only used with 'diffllama3'. Scaling factor applied to high frequency components of the RoPE
|
| 112 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 113 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 114 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 115 |
+
The dropout ratio for the attention probabilities.
|
| 116 |
+
lambda_std_dev (`float`, *optional*, defaults to 0.1):
|
| 117 |
+
The standard deviation for initialization of parameter lambda in attention layer.
|
| 118 |
+
head_dim (`int`, *optional*):
|
| 119 |
+
The attention head dimension. If None, it will default to hidden_size // num_heads
|
| 120 |
+
|
| 121 |
+
```python
|
| 122 |
+
>>> from transformers import DiffLlamaModel, DiffLlamaConfig
|
| 123 |
+
|
| 124 |
+
>>> # Initializing a DiffLlama diffllama-7b style configuration
|
| 125 |
+
>>> configuration = DiffLlamaConfig()
|
| 126 |
+
|
| 127 |
+
>>> # Initializing a model from the diffllama-7b style configuration
|
| 128 |
+
>>> model = DiffLlamaModel(configuration)
|
| 129 |
+
|
| 130 |
+
>>> # Accessing the model configuration
|
| 131 |
+
>>> configuration = model.config
|
| 132 |
+
```"""
|
| 133 |
+
|
| 134 |
+
model_type = "diffllama"
|
| 135 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 136 |
+
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
vocab_size=32000,
|
| 140 |
+
hidden_size=2048,
|
| 141 |
+
intermediate_size=8192,
|
| 142 |
+
num_hidden_layers=16,
|
| 143 |
+
num_attention_heads=32,
|
| 144 |
+
num_key_value_heads=None,
|
| 145 |
+
hidden_act="silu",
|
| 146 |
+
max_position_embeddings=2048,
|
| 147 |
+
initializer_range=0.02,
|
| 148 |
+
rms_norm_eps=1e-5,
|
| 149 |
+
use_cache=True,
|
| 150 |
+
pad_token_id=None,
|
| 151 |
+
bos_token_id=1,
|
| 152 |
+
eos_token_id=2,
|
| 153 |
+
tie_word_embeddings=False,
|
| 154 |
+
rope_theta=10000.0,
|
| 155 |
+
rope_scaling=None,
|
| 156 |
+
attention_bias=False,
|
| 157 |
+
attention_dropout=0.0,
|
| 158 |
+
lambda_std_dev=0.1,
|
| 159 |
+
head_dim=None,
|
| 160 |
+
**kwargs,
|
| 161 |
+
):
|
| 162 |
+
self.vocab_size = vocab_size
|
| 163 |
+
self.max_position_embeddings = max_position_embeddings
|
| 164 |
+
self.hidden_size = hidden_size
|
| 165 |
+
self.intermediate_size = intermediate_size
|
| 166 |
+
self.num_hidden_layers = num_hidden_layers
|
| 167 |
+
self.num_attention_heads = num_attention_heads
|
| 168 |
+
|
| 169 |
+
# for backward compatibility
|
| 170 |
+
if num_key_value_heads is None:
|
| 171 |
+
num_key_value_heads = num_attention_heads
|
| 172 |
+
|
| 173 |
+
self.num_key_value_heads = num_key_value_heads
|
| 174 |
+
self.hidden_act = hidden_act
|
| 175 |
+
self.initializer_range = initializer_range
|
| 176 |
+
self.rms_norm_eps = rms_norm_eps
|
| 177 |
+
self.use_cache = use_cache
|
| 178 |
+
self.rope_theta = rope_theta
|
| 179 |
+
self.rope_scaling = rope_scaling
|
| 180 |
+
self.attention_bias = attention_bias
|
| 181 |
+
self.attention_dropout = attention_dropout
|
| 182 |
+
self.lambda_std_dev = lambda_std_dev
|
| 183 |
+
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
| 184 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 185 |
+
# BC: if there is a 'type' field, copy it to 'rope_type'.
|
| 186 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 187 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 188 |
+
rope_config_validation(self)
|
| 189 |
+
|
| 190 |
+
super().__init__(
|
| 191 |
+
pad_token_id=pad_token_id,
|
| 192 |
+
bos_token_id=bos_token_id,
|
| 193 |
+
eos_token_id=eos_token_id,
|
| 194 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 195 |
+
**kwargs,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
__all__ = ["DiffLlamaConfig"]
|
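To make the `num_key_value_heads` discussion in the docstring concrete, the sketch below builds three configurations that only differ in that field; the top-level import assumes a transformers build that exposes `DiffLlamaConfig`.

```python
from transformers import DiffLlamaConfig

# MHA: one key/value head per query head (the default when the field is omitted).
mha = DiffLlamaConfig(num_attention_heads=32)
# GQA: groups of 32 // 8 = 4 query heads share each key/value head.
gqa = DiffLlamaConfig(num_attention_heads=32, num_key_value_heads=8)
# MQA: every query head shares a single key/value head.
mqa = DiffLlamaConfig(num_attention_heads=32, num_key_value_heads=1)

print(mha.num_key_value_heads, gqa.num_key_value_heads, mqa.num_key_value_heads)  # 32 8 1
```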
docs/transformers/build/lib/transformers/models/esm/openfold_utils/rigid_utils.py
ADDED
|
@@ -0,0 +1,1242 @@
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from functools import lru_cache
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 26 |
+
"""
|
| 27 |
+
Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
a: [*, 3, 3] left multiplicand
|
| 31 |
+
b: [*, 3, 3] right multiplicand
|
| 32 |
+
Returns:
|
| 33 |
+
The product ab
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def row_mul(i: int) -> torch.Tensor:
|
| 37 |
+
return torch.stack(
|
| 38 |
+
[
|
| 39 |
+
a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
|
| 40 |
+
a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1],
|
| 41 |
+
a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2],
|
| 42 |
+
],
|
| 43 |
+
dim=-1,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
return torch.stack(
|
| 47 |
+
[
|
| 48 |
+
row_mul(0),
|
| 49 |
+
row_mul(1),
|
| 50 |
+
row_mul(2),
|
| 51 |
+
],
|
| 52 |
+
dim=-2,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 57 |
+
"""
|
| 58 |
+
Applies a rotation to a vector. Written out by hand to avoid AMP downcasting.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
r: [*, 3, 3] rotation matrices
|
| 62 |
+
t: [*, 3] coordinate tensors
|
| 63 |
+
Returns:
|
| 64 |
+
[*, 3] rotated coordinates
|
| 65 |
+
"""
|
| 66 |
+
x, y, z = torch.unbind(t, dim=-1)
|
| 67 |
+
return torch.stack(
|
| 68 |
+
[
|
| 69 |
+
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
|
| 70 |
+
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
|
| 71 |
+
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
|
| 72 |
+
],
|
| 73 |
+
dim=-1,
|
| 74 |
+
)
|
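The two helpers above are just unrolled versions of a batched matrix product and matrix-vector product; writing them out element-wise keeps the arithmetic in full precision under autocast. A quick sanity sketch (illustrative only; the import path assumes the installed transformers copy of this module):

```python
import torch
from transformers.models.esm.openfold_utils.rigid_utils import rot_matmul, rot_vec_mul

a = torch.randn(5, 3, 3)
b = torch.randn(5, 3, 3)
t = torch.randn(5, 3)

# Same results as the fused ops, up to floating-point noise.
assert torch.allclose(rot_matmul(a, b), a @ b, atol=1e-5)
assert torch.allclose(rot_vec_mul(a, t), (a @ t.unsqueeze(-1)).squeeze(-1), atol=1e-5)
```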
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@lru_cache(maxsize=None)
|
| 78 |
+
def identity_rot_mats(
|
| 79 |
+
batch_dims: Tuple[int, ...],
|
| 80 |
+
dtype: Optional[torch.dtype] = None,
|
| 81 |
+
device: Optional[torch.device] = None,
|
| 82 |
+
requires_grad: bool = True,
|
| 83 |
+
) -> torch.Tensor:
|
| 84 |
+
rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
|
| 85 |
+
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
|
| 86 |
+
rots = rots.expand(*batch_dims, -1, -1)
|
| 87 |
+
rots = rots.contiguous()
|
| 88 |
+
|
| 89 |
+
return rots
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@lru_cache(maxsize=None)
|
| 93 |
+
def identity_trans(
|
| 94 |
+
batch_dims: Tuple[int, ...],
|
| 95 |
+
dtype: Optional[torch.dtype] = None,
|
| 96 |
+
device: Optional[torch.device] = None,
|
| 97 |
+
requires_grad: bool = True,
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
trans = torch.zeros((*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad)
|
| 100 |
+
return trans
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@lru_cache(maxsize=None)
|
| 104 |
+
def identity_quats(
|
| 105 |
+
batch_dims: Tuple[int, ...],
|
| 106 |
+
dtype: Optional[torch.dtype] = None,
|
| 107 |
+
device: Optional[torch.device] = None,
|
| 108 |
+
requires_grad: bool = True,
|
| 109 |
+
) -> torch.Tensor:
|
| 110 |
+
quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
quat[..., 0] = 1
|
| 114 |
+
|
| 115 |
+
return quat
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
_quat_elements: List[str] = ["a", "b", "c", "d"]
|
| 119 |
+
_qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
|
| 120 |
+
_qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
|
| 124 |
+
mat = np.zeros((4, 4))
|
| 125 |
+
for key, value in pairs:
|
| 126 |
+
ind = _qtr_ind_dict[key]
|
| 127 |
+
mat[ind // 4][ind % 4] = value
|
| 128 |
+
|
| 129 |
+
return mat
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
_QTR_MAT = np.zeros((4, 4, 3, 3))
|
| 133 |
+
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
|
| 134 |
+
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
|
| 135 |
+
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
|
| 136 |
+
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
|
| 137 |
+
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
|
| 138 |
+
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
|
| 139 |
+
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
|
| 140 |
+
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
|
| 141 |
+
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
"""
|
| 146 |
+
Converts a quaternion to a rotation matrix.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
quat: [*, 4] quaternions
|
| 150 |
+
Returns:
|
| 151 |
+
[*, 3, 3] rotation matrices
|
| 152 |
+
"""
|
| 153 |
+
# [*, 4, 4]
|
| 154 |
+
quat = quat[..., None] * quat[..., None, :]
|
| 155 |
+
|
| 156 |
+
# [4, 4, 3, 3]
|
| 157 |
+
mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
|
| 158 |
+
|
| 159 |
+
# [*, 4, 4, 3, 3]
|
| 160 |
+
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
|
| 161 |
+
quat = quat[..., None, None] * shaped_qtr_mat
|
| 162 |
+
|
| 163 |
+
# [*, 3, 3]
|
| 164 |
+
return torch.sum(quat, dim=(-3, -4))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
|
| 168 |
+
if rot.shape[-2:] != (3, 3):
|
| 169 |
+
raise ValueError("Input rotation is incorrectly shaped")
|
| 170 |
+
|
| 171 |
+
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]
|
| 172 |
+
|
| 173 |
+
k = [
|
| 174 |
+
[
|
| 175 |
+
xx + yy + zz,
|
| 176 |
+
zy - yz,
|
| 177 |
+
xz - zx,
|
| 178 |
+
yx - xy,
|
| 179 |
+
],
|
| 180 |
+
[
|
| 181 |
+
zy - yz,
|
| 182 |
+
xx - yy - zz,
|
| 183 |
+
xy + yx,
|
| 184 |
+
xz + zx,
|
| 185 |
+
],
|
| 186 |
+
[
|
| 187 |
+
xz - zx,
|
| 188 |
+
xy + yx,
|
| 189 |
+
yy - xx - zz,
|
| 190 |
+
yz + zy,
|
| 191 |
+
],
|
| 192 |
+
[
|
| 193 |
+
yx - xy,
|
| 194 |
+
xz + zx,
|
| 195 |
+
yz + zy,
|
| 196 |
+
zz - xx - yy,
|
| 197 |
+
],
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
_, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
|
| 201 |
+
return vectors[..., -1]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
|
| 205 |
+
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]
|
| 206 |
+
|
| 207 |
+
_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]
|
| 208 |
+
|
| 209 |
+
_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]
|
| 210 |
+
|
| 211 |
+
_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]
|
| 212 |
+
|
| 213 |
+
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
|
| 214 |
+
|
| 215 |
+
_CACHED_QUATS: Dict[str, np.ndarray] = {
|
| 216 |
+
"_QTR_MAT": _QTR_MAT,
|
| 217 |
+
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
|
| 218 |
+
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@lru_cache(maxsize=None)
|
| 223 |
+
def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
| 224 |
+
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
|
| 228 |
+
"""Multiply a quaternion by another quaternion."""
|
| 229 |
+
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
|
| 230 |
+
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
|
| 231 |
+
return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
| 235 |
+
"""Multiply a quaternion by a pure-vector quaternion."""
|
| 236 |
+
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
|
| 237 |
+
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
|
| 238 |
+
return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
|
| 242 |
+
return rot_mat.transpose(-1, -2)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def invert_quat(quat: torch.Tensor) -> torch.Tensor:
|
| 246 |
+
quat_prime = quat.clone()
|
| 247 |
+
quat_prime[..., 1:] *= -1
|
| 248 |
+
inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
|
| 249 |
+
return inv
|
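A small round-trip sketch for the quaternion helpers above (again assuming the module is importable from the installed package; note that `rot_to_quat` recovers the quaternion via an eigendecomposition, so the sign of its result is arbitrary):

```python
import torch
from transformers.models.esm.openfold_utils.rigid_utils import (
    invert_quat, quat_to_rot, rot_to_quat, rot_vec_mul,
)

q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])        # identity quaternion, batch of one
r = quat_to_rot(q)                               # -> [1, 3, 3] identity matrix
v = torch.tensor([[1.0, 2.0, 3.0]])
assert torch.allclose(rot_vec_mul(r, v), v)      # rotating by the identity is a no-op

q_back = rot_to_quat(r)                          # equals q up to an overall sign
assert torch.allclose(q_back.abs(), q.abs(), atol=1e-5)

assert torch.allclose(invert_quat(q), q)         # the identity is its own inverse
```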
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class Rotation:
|
| 253 |
+
"""
|
| 254 |
+
A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix
|
| 255 |
+
or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the
|
| 256 |
+
underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the
|
| 257 |
+
behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(
|
| 261 |
+
self,
|
| 262 |
+
rot_mats: Optional[torch.Tensor] = None,
|
| 263 |
+
quats: Optional[torch.Tensor] = None,
|
| 264 |
+
normalize_quats: bool = True,
|
| 265 |
+
):
|
| 266 |
+
"""
|
| 267 |
+
Args:
|
| 268 |
+
rot_mats:
|
| 269 |
+
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats
|
| 270 |
+
quats:
|
| 271 |
+
A [*, 4] quaternion. Mutually exclusive with rot_mats. If normalize_quats is not True, must be a unit
|
| 272 |
+
quaternion
|
| 273 |
+
normalize_quats:
|
| 274 |
+
If quats is specified, whether to normalize quats
|
| 275 |
+
"""
|
| 276 |
+
if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None):
|
| 277 |
+
raise ValueError("Exactly one input argument must be specified")
|
| 278 |
+
|
| 279 |
+
if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4):
|
| 280 |
+
raise ValueError("Incorrectly shaped rotation matrix or quaternion")
|
| 281 |
+
|
| 282 |
+
# Force full-precision
|
| 283 |
+
if quats is not None:
|
| 284 |
+
quats = quats.to(dtype=torch.float32)
|
| 285 |
+
if rot_mats is not None:
|
| 286 |
+
rot_mats = rot_mats.to(dtype=torch.float32)
|
| 287 |
+
|
| 288 |
+
if quats is not None and normalize_quats:
|
| 289 |
+
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
|
| 290 |
+
|
| 291 |
+
self._rot_mats = rot_mats
|
| 292 |
+
self._quats = quats
|
| 293 |
+
|
| 294 |
+
@staticmethod
|
| 295 |
+
def identity(
|
| 296 |
+
shape,
|
| 297 |
+
dtype: Optional[torch.dtype] = None,
|
| 298 |
+
device: Optional[torch.device] = None,
|
| 299 |
+
requires_grad: bool = True,
|
| 300 |
+
fmt: str = "quat",
|
| 301 |
+
) -> Rotation:
|
| 302 |
+
"""
|
| 303 |
+
Returns an identity Rotation.
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
shape:
|
| 307 |
+
The "shape" of the resulting Rotation object. See documentation for the shape property
|
| 308 |
+
dtype:
|
| 309 |
+
The torch dtype for the rotation
|
| 310 |
+
device:
|
| 311 |
+
The torch device for the new rotation
|
| 312 |
+
requires_grad:
|
| 313 |
+
Whether the underlying tensors in the new rotation object should require gradient computation
|
| 314 |
+
fmt:
|
| 315 |
+
One of "quat" or "rot_mat". Determines the underlying format of the new object's rotation
|
| 316 |
+
Returns:
|
| 317 |
+
A new identity rotation
|
| 318 |
+
"""
|
| 319 |
+
if fmt == "rot_mat":
|
| 320 |
+
rot_mats = identity_rot_mats(
|
| 321 |
+
shape,
|
| 322 |
+
dtype,
|
| 323 |
+
device,
|
| 324 |
+
requires_grad,
|
| 325 |
+
)
|
| 326 |
+
return Rotation(rot_mats=rot_mats, quats=None)
|
| 327 |
+
elif fmt == "quat":
|
| 328 |
+
quats = identity_quats(shape, dtype, device, requires_grad)
|
| 329 |
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
| 330 |
+
else:
|
| 331 |
+
raise ValueError(f"Invalid format: f{fmt}")
|
| 332 |
+
|
| 333 |
+
# Magic methods
|
| 334 |
+
|
| 335 |
+
def __getitem__(self, index: Any) -> Rotation:
|
| 336 |
+
"""
|
| 337 |
+
Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape
|
| 338 |
+
property.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
index:
|
| 342 |
+
A torch index. E.g. (1, 3, 2), or (slice(None,))
|
| 343 |
+
Returns:
|
| 344 |
+
The indexed rotation
|
| 345 |
+
"""
|
| 346 |
+
if type(index) is not tuple:
|
| 347 |
+
index = (index,)
|
| 348 |
+
|
| 349 |
+
if self._rot_mats is not None:
|
| 350 |
+
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
|
| 351 |
+
return Rotation(rot_mats=rot_mats)
|
| 352 |
+
elif self._quats is not None:
|
| 353 |
+
quats = self._quats[index + (slice(None),)]
|
| 354 |
+
return Rotation(quats=quats, normalize_quats=False)
|
| 355 |
+
else:
|
| 356 |
+
raise ValueError("Both rotations are None")
|
| 357 |
+
|
| 358 |
+
def __mul__(self, right: torch.Tensor) -> Rotation:
|
| 359 |
+
"""
|
| 360 |
+
Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
right:
|
| 364 |
+
The tensor multiplicand
|
| 365 |
+
Returns:
|
| 366 |
+
The product
|
| 367 |
+
"""
|
| 368 |
+
if not (isinstance(right, torch.Tensor)):
|
| 369 |
+
raise TypeError("The other multiplicand must be a Tensor")
|
| 370 |
+
|
| 371 |
+
if self._rot_mats is not None:
|
| 372 |
+
rot_mats = self._rot_mats * right[..., None, None]
|
| 373 |
+
return Rotation(rot_mats=rot_mats, quats=None)
|
| 374 |
+
elif self._quats is not None:
|
| 375 |
+
quats = self._quats * right[..., None]
|
| 376 |
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
| 377 |
+
else:
|
| 378 |
+
raise ValueError("Both rotations are None")
|
| 379 |
+
|
| 380 |
+
def __rmul__(self, left: torch.Tensor) -> Rotation:
|
| 381 |
+
"""
|
| 382 |
+
Reverse pointwise multiplication of the rotation with a tensor.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
left:
|
| 386 |
+
The left multiplicand
|
| 387 |
+
Returns:
|
| 388 |
+
The product
|
| 389 |
+
"""
|
| 390 |
+
return self.__mul__(left)
|
| 391 |
+
|
| 392 |
+
# Properties
|
| 393 |
+
|
| 394 |
+
@property
|
| 395 |
+
def shape(self) -> torch.Size:
|
| 396 |
+
"""
|
| 397 |
+
Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the
|
| 398 |
+
underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix
|
| 399 |
+
tensor, for example, the resulting shape would be [10].
|
| 400 |
+
|
| 401 |
+
Returns:
|
| 402 |
+
The virtual shape of the rotation object
|
| 403 |
+
"""
|
| 404 |
+
if self._rot_mats is not None:
|
| 405 |
+
return self._rot_mats.shape[:-2]
|
| 406 |
+
elif self._quats is not None:
|
| 407 |
+
return self._quats.shape[:-1]
|
| 408 |
+
else:
|
| 409 |
+
raise ValueError("Both rotations are None")
|
| 410 |
+
|
| 411 |
+
@property
|
| 412 |
+
def dtype(self) -> torch.dtype:
|
| 413 |
+
"""
|
| 414 |
+
Returns the dtype of the underlying rotation.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
The dtype of the underlying rotation
|
| 418 |
+
"""
|
| 419 |
+
if self._rot_mats is not None:
|
| 420 |
+
return self._rot_mats.dtype
|
| 421 |
+
elif self._quats is not None:
|
| 422 |
+
return self._quats.dtype
|
| 423 |
+
else:
|
| 424 |
+
raise ValueError("Both rotations are None")
|
| 425 |
+
|
| 426 |
+
@property
|
| 427 |
+
def device(self) -> torch.device:
|
| 428 |
+
"""
|
| 429 |
+
The device of the underlying rotation
|
| 430 |
+
|
| 431 |
+
Returns:
|
| 432 |
+
The device of the underlying rotation
|
| 433 |
+
"""
|
| 434 |
+
if self._rot_mats is not None:
|
| 435 |
+
return self._rot_mats.device
|
| 436 |
+
elif self._quats is not None:
|
| 437 |
+
return self._quats.device
|
| 438 |
+
else:
|
| 439 |
+
raise ValueError("Both rotations are None")
|
| 440 |
+
|
| 441 |
+
@property
|
| 442 |
+
def requires_grad(self) -> bool:
|
| 443 |
+
"""
|
| 444 |
+
Returns the requires_grad property of the underlying rotation
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
The requires_grad property of the underlying tensor
|
| 448 |
+
"""
|
| 449 |
+
if self._rot_mats is not None:
|
| 450 |
+
return self._rot_mats.requires_grad
|
| 451 |
+
elif self._quats is not None:
|
| 452 |
+
return self._quats.requires_grad
|
| 453 |
+
else:
|
| 454 |
+
raise ValueError("Both rotations are None")
|
| 455 |
+
|
| 456 |
+
def get_rot_mats(self) -> torch.Tensor:
|
| 457 |
+
"""
|
| 458 |
+
Returns the underlying rotation as a rotation matrix tensor.
|
| 459 |
+
|
| 460 |
+
Returns:
|
| 461 |
+
The rotation as a rotation matrix tensor
|
| 462 |
+
"""
|
| 463 |
+
if self._rot_mats is not None:
|
| 464 |
+
return self._rot_mats
|
| 465 |
+
elif self._quats is not None:
|
| 466 |
+
return quat_to_rot(self._quats)
|
| 467 |
+
else:
|
| 468 |
+
raise ValueError("Both rotations are None")
|
| 469 |
+
|
| 470 |
+
def get_quats(self) -> torch.Tensor:
|
| 471 |
+
"""
|
| 472 |
+
Returns the underlying rotation as a quaternion tensor.
|
| 473 |
+
|
| 474 |
+
Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.
|
| 475 |
+
|
| 476 |
+
Returns:
|
| 477 |
+
The rotation as a quaternion tensor.
|
| 478 |
+
"""
|
| 479 |
+
if self._rot_mats is not None:
|
| 480 |
+
return rot_to_quat(self._rot_mats)
|
| 481 |
+
elif self._quats is not None:
|
| 482 |
+
return self._quats
|
| 483 |
+
else:
|
| 484 |
+
raise ValueError("Both rotations are None")
|
| 485 |
+
|
| 486 |
+
def get_cur_rot(self) -> torch.Tensor:
|
| 487 |
+
"""
|
| 488 |
+
Return the underlying rotation in its current form
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
The stored rotation
|
| 492 |
+
"""
|
| 493 |
+
if self._rot_mats is not None:
|
| 494 |
+
return self._rot_mats
|
| 495 |
+
elif self._quats is not None:
|
| 496 |
+
return self._quats
|
| 497 |
+
else:
|
| 498 |
+
raise ValueError("Both rotations are None")
|
| 499 |
+
|
| 500 |
+
# Rotation functions
|
| 501 |
+
|
| 502 |
+
def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation:
|
| 503 |
+
"""
|
| 504 |
+
Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion
|
| 505 |
+
update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the
|
| 506 |
+
desired (not necessarily unit) quaternion update.
|
| 507 |
+
|
| 508 |
+
Args:
|
| 509 |
+
q_update_vec:
|
| 510 |
+
A [*, 3] quaternion update tensor
|
| 511 |
+
normalize_quats:
|
| 512 |
+
Whether to normalize the output quaternion
|
| 513 |
+
Returns:
|
| 514 |
+
An updated Rotation
|
| 515 |
+
"""
|
| 516 |
+
quats = self.get_quats()
|
| 517 |
+
new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
|
| 518 |
+
return Rotation(
|
| 519 |
+
rot_mats=None,
|
| 520 |
+
quats=new_quats,
|
| 521 |
+
normalize_quats=normalize_quats,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
def compose_r(self, r: Rotation) -> Rotation:
|
| 525 |
+
"""
|
| 526 |
+
Compose the rotation matrices of the current Rotation object with those of another.
|
| 527 |
+
|
| 528 |
+
Args:
|
| 529 |
+
r:
|
| 530 |
+
An update rotation object
|
| 531 |
+
Returns:
|
| 532 |
+
An updated rotation object
|
| 533 |
+
"""
|
| 534 |
+
r1 = self.get_rot_mats()
|
| 535 |
+
r2 = r.get_rot_mats()
|
| 536 |
+
new_rot_mats = rot_matmul(r1, r2)
|
| 537 |
+
return Rotation(rot_mats=new_rot_mats, quats=None)
|
| 538 |
+
|
| 539 |
+
def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
|
| 540 |
+
"""
|
| 541 |
+
Compose the quaternions of the current Rotation object with those of another.
|
| 542 |
+
|
| 543 |
+
Depending on whether either Rotation was initialized with quaternions, this function may call
|
| 544 |
+
torch.linalg.eigh.
|
| 545 |
+
|
| 546 |
+
Args:
|
| 547 |
+
r:
|
| 548 |
+
An update rotation object
|
| 549 |
+
Returns:
|
| 550 |
+
An updated rotation object
|
| 551 |
+
"""
|
| 552 |
+
q1 = self.get_quats()
|
| 553 |
+
q2 = r.get_quats()
|
| 554 |
+
new_quats = quat_multiply(q1, q2)
|
| 555 |
+
return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
|
| 556 |
+
|
| 557 |
+
def apply(self, pts: torch.Tensor) -> torch.Tensor:
|
| 558 |
+
"""
|
| 559 |
+
Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
pts:
|
| 563 |
+
A [*, 3] set of points
|
| 564 |
+
Returns:
|
| 565 |
+
[*, 3] rotated points
|
| 566 |
+
"""
|
| 567 |
+
rot_mats = self.get_rot_mats()
|
| 568 |
+
return rot_vec_mul(rot_mats, pts)
|
| 569 |
+
|
    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        The inverse of the apply() method.

        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] inverse-rotated points
        """
        rot_mats = self.get_rot_mats()
        inv_rot_mats = invert_rot_mat(rot_mats)
        return rot_vec_mul(inv_rot_mats, pts)

    def invert(self) -> Rotation:
        """
        Returns the inverse of the current Rotation.

        Returns:
            The inverse of the current Rotation
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self, dim: int) -> Rotation:
        """
        Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed Rotation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        if self._rot_mats is not None:
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
        """
        Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().

        Note that the output of this operation is always a rotation matrix, regardless of the format of input
        rotations.

        Args:
            rs:
                A list of rotation objects
            dim:
                The dimension along which the rotations should be concatenated
        Returns:
            A concatenated Rotation object in rotation matrix format
        """
        rot_mats = torch.cat(
            [r.get_rot_mats() for r in rs],
            dim=dim if dim >= 0 else dim - 2,
        )

        return Rotation(rot_mats=rot_mats, quats=None)

    def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
        """
        Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
        be used e.g. to sum out a one-hot batch dimension.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rotation
        Returns:
            The transformed Rotation object
        """
        if self._rot_mats is not None:
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def cuda(self) -> Rotation:
        """
        Analogous to the cuda() method of torch Tensors

        Returns:
            A copy of the Rotation in CUDA memory
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
        elif self._quats is not None:
            return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> Rotation:
        """
        Analogous to the to() method of torch Tensors

        Args:
            device:
                A torch device
            dtype:
                A torch dtype
        Returns:
            A copy of the Rotation using the new device and dtype
        """
        if self._rot_mats is not None:
            return Rotation(
                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
                quats=None,
            )
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.to(device=device, dtype=dtype),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    def detach(self) -> Rotation:
        """
        Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.

        Returns:
            A copy of the Rotation whose underlying Tensor has been detached from its torch graph
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.detach(),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")


class Rigid:
    """
    A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a
    [*, 3] translation. Designed to behave approximately like a single torch tensor with the shape of the shared batch
    dimensions of its component parts.
    """

    def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
        """
        Args:
            rots: A [*, 3, 3] rotation tensor
            trans: A corresponding [*, 3] translation tensor
        """
        # (we need device, dtype, etc. from at least one input)

        batch_dims, dtype, device, requires_grad = None, None, None, None
        if trans is not None:
            batch_dims = trans.shape[:-1]
            dtype = trans.dtype
            device = trans.device
            requires_grad = trans.requires_grad
        elif rots is not None:
            batch_dims = rots.shape
            dtype = rots.dtype
            device = rots.device
            requires_grad = rots.requires_grad
        else:
            raise ValueError("At least one input argument must be specified")

        if rots is None:
            rots = Rotation.identity(
                batch_dims,
                dtype,
                device,
                requires_grad,
            )
        elif trans is None:
            trans = identity_trans(
                batch_dims,
                dtype,
                device,
                requires_grad,
            )

        assert rots is not None
        assert trans is not None

        if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
            raise ValueError("Rots and trans incompatible")

        # Force full precision. Happens to the rotations automatically.
        trans = trans.to(dtype=torch.float32)

        self._rots = rots
        self._trans = trans

    @staticmethod
    def identity(
        shape: Tuple[int, ...],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rigid:
        """
        Constructs an identity transformation.

        Args:
            shape:
                The desired shape
            dtype:
                The dtype of both internal tensors
            device:
                The device of both internal tensors
            requires_grad:
                Whether grad should be enabled for the internal tensors
        Returns:
            The identity transformation
        """
        return Rigid(
            Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
            identity_trans(shape, dtype, device, requires_grad),
        )

    def __getitem__(self, index: Any) -> Rigid:
        """
        Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
        both the rotation and the translation.

        E.g.::

            r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
            t = Rigid(r, torch.rand(10, 10, 3))
            indexed = t[3, 4:6]
            assert(indexed.shape == (2,))
            assert(indexed.get_rots().shape == (2,))
            assert(indexed.get_trans().shape == (2, 3))

        Args:
            index: A standard torch tensor index. E.g. 8, (10, None, 3),
                or (3, slice(0, 1, None))
        Returns:
            The indexed tensor
        """
        if type(index) is not tuple:
            index = (index,)

        return Rigid(
            self._rots[index],
            self._trans[index + (slice(None),)],
        )

    def __mul__(self, right: torch.Tensor) -> Rigid:
        """
        Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.

        Args:
            right:
                The tensor multiplicand
        Returns:
            The product
        """
        if not (isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        new_rots = self._rots * right
        new_trans = self._trans * right[..., None]

        return Rigid(new_rots, new_trans)

    def __rmul__(self, left: torch.Tensor) -> Rigid:
        """
        Reverse pointwise multiplication of the transformation with a tensor.

        Args:
            left:
                The left multiplicand
        Returns:
            The product
        """
        return self.__mul__(left)

    @property
    def shape(self) -> torch.Size:
        """
        Returns the shape of the shared dimensions of the rotation and the translation.

        Returns:
            The shape of the transformation
        """
        return self._trans.shape[:-1]

    @property
    def device(self) -> torch.device:
        """
        Returns the device on which the Rigid's tensors are located.

        Returns:
            The device on which the Rigid's tensors are located
        """
        return self._trans.device

    def get_rots(self) -> Rotation:
        """
        Getter for the rotation.

        Returns:
            The rotation object
        """
        return self._rots

    def get_trans(self) -> torch.Tensor:
        """
        Getter for the translation.

        Returns:
            The stored translation
        """
        return self._trans

    def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
        """
        Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
        represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.

        Args:
            q_vec: The quaternion update vector.
        Returns:
            The composed transformation.
        """
        q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
        new_rots = self._rots.compose_q_update_vec(q_vec)

        trans_update = self._rots.apply(t_vec)
        new_translation = self._trans + trans_update

        return Rigid(new_rots, new_translation)

    def compose(self, r: Rigid) -> Rigid:
        """
        Composes the current rigid object with another.

        Args:
            r:
                Another Rigid object
        Returns:
            The composition of the two transformations
        """
        new_rot = self._rots.compose_r(r._rots)
        new_trans = self._rots.apply(r._trans) + self._trans
        return Rigid(new_rot, new_trans)

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Applies the transformation to a coordinate tensor.

        Args:
            pts: A [*, 3] coordinate tensor.
        Returns:
            The transformed points.
        """
        rotated = self._rots.apply(pts)
        return rotated + self._trans

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Applies the inverse of the transformation to a coordinate tensor.

        Args:
            pts: A [*, 3] coordinate tensor
        Returns:
            The transformed points.
        """
        pts = pts - self._trans
        return self._rots.invert_apply(pts)

    def invert(self) -> Rigid:
        """
        Inverts the transformation.

        Returns:
            The inverse transformation.
        """
        rot_inv = self._rots.invert()
        trn_inv = rot_inv.apply(self._trans)

        return Rigid(rot_inv, -1 * trn_inv)

    def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
        """
        Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
        translation/rotation dimensions respectively.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rigid
        Returns:
            The transformed Rigid object
        """
        new_rots = self._rots.map_tensor_fn(fn)
        new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1)

        return Rigid(new_rots, new_trans)

    def to_tensor_4x4(self) -> torch.Tensor:
        """
        Converts a transformation to a homogeneous transformation tensor.

        Returns:
            A [*, 4, 4] homogeneous transformation tensor
        """
        tensor = self._trans.new_zeros((*self.shape, 4, 4))
        tensor[..., :3, :3] = self._rots.get_rot_mats()
        tensor[..., :3, 3] = self._trans
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor_4x4(t: torch.Tensor) -> Rigid:
        """
        Constructs a transformation from a homogeneous transformation tensor.

        Args:
            t: [*, 4, 4] homogeneous transformation tensor
        Returns:
            T object with shape [*]
        """
        if t.shape[-2:] != (4, 4):
            raise ValueError("Incorrectly shaped input tensor")

        rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
        trans = t[..., :3, 3]

        return Rigid(rots, trans)

    def to_tensor_7(self) -> torch.Tensor:
        """
        Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the
        translation.

        Returns:
            A [*, 7] tensor representation of the transformation
        """
        tensor = self._trans.new_zeros((*self.shape, 7))
        tensor[..., :4] = self._rots.get_quats()
        tensor[..., 4:] = self._trans

        return tensor

    @staticmethod
    def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
        if t.shape[-1] != 7:
            raise ValueError("Incorrectly shaped input tensor")

        quats, trans = t[..., :4], t[..., 4:]

        rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats)

        return Rigid(rots, trans)

    @staticmethod
    def from_3_points(
        p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8
    ) -> Rigid:
        """
        Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm.

        Args:
            p_neg_x_axis: [*, 3] coordinates
            origin: [*, 3] coordinates used as frame origins
            p_xy_plane: [*, 3] coordinates
            eps: Small epsilon value
        Returns:
            A transformation object of shape [*]
        """
        p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
        origin_unbound = torch.unbind(origin, dim=-1)
        p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)

        e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]

        denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
        e0 = [c / denom for c in e0]
        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
        denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
        e1 = [c / denom for c in e1]
        e2 = [
            e0[1] * e1[2] - e0[2] * e1[1],
            e0[2] * e1[0] - e0[0] * e1[2],
            e0[0] * e1[1] - e0[1] * e1[0],
        ]

        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
        rots = rots.reshape(rots.shape[:-1] + (3, 3))

        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))

    def unsqueeze(self, dim: int) -> Rigid:
        """
        Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed transformation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")
        rots = self._rots.unsqueeze(dim)
        trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)

        return Rigid(rots, trans)

    @staticmethod
    def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
        """
        Concatenates transformations along a new dimension.

        Args:
            ts:
                A list of T objects
            dim:
                The dimension along which the transformations should be concatenated
        Returns:
            A concatenated transformation object
        """
        rots = Rotation.cat([t._rots for t in ts], dim)
        trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1)

        return Rigid(rots, trans)

    def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
        """
        Applies a Rotation -> Rotation function to the stored rotation object.

        Args:
            fn: A function of type Rotation -> Rotation
        Returns:
            A transformation object with a transformed rotation.
        """
        return Rigid(fn(self._rots), self._trans)

    def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
        """
        Applies a Tensor -> Tensor function to the stored translation.

        Args:
            fn:
                A function of type Tensor -> Tensor to be applied to the translation
        Returns:
            A transformation object with a transformed translation.
        """
        return Rigid(self._rots, fn(self._trans))

    def scale_translation(self, trans_scale_factor: float) -> Rigid:
        """
        Scales the translation by a constant factor.

        Args:
            trans_scale_factor:
                The constant factor
        Returns:
            A transformation object with a scaled translation.
        """
        return self.apply_trans_fn(lambda t: t * trans_scale_factor)

    def stop_rot_gradient(self) -> Rigid:
        """
        Detaches the underlying rotation object

        Returns:
            A transformation object with detached rotations
        """
        return self.apply_rot_fn(lambda r: r.detach())

    @staticmethod
    def make_transform_from_reference(
        n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
    ) -> Rigid:
        """
        Returns a transformation object from reference coordinates.

        Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard
        way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
        need to take care of such cases in your code.

        Args:
            n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
            ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
            c_xyz: A [*, 3] tensor of carbon xyz coordinates.
        Returns:
            A transformation object. After applying the translation and rotation to the reference backbone, the
            coordinates will be approximately equal to the input coordinates.
        """
        translation = -1 * ca_xyz
        n_xyz = n_xyz + translation
        c_xyz = c_xyz + translation

        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + c_x**2 + c_y**2)
        sin_c1 = -c_y / norm
        cos_c1 = c_x / norm

        c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        c1_rots[..., 0, 0] = cos_c1
        c1_rots[..., 0, 1] = -1 * sin_c1
        c1_rots[..., 1, 0] = sin_c1
        c1_rots[..., 1, 1] = cos_c1
        c1_rots[..., 2, 2] = 1

        norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
        sin_c2 = c_z / norm
        cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm

        c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        c2_rots[..., 0, 0] = cos_c2
        c2_rots[..., 0, 2] = sin_c2
        c2_rots[..., 1, 1] = 1
        c2_rots[..., 2, 0] = -1 * sin_c2
        c2_rots[..., 2, 2] = cos_c2

        c_rots = rot_matmul(c2_rots, c1_rots)
        n_xyz = rot_vec_mul(c_rots, n_xyz)

        _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + n_y**2 + n_z**2)
        sin_n = -n_z / norm
        cos_n = n_y / norm

        n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        n_rots[..., 0, 0] = 1
        n_rots[..., 1, 1] = cos_n
        n_rots[..., 1, 2] = -1 * sin_n
        n_rots[..., 2, 1] = sin_n
        n_rots[..., 2, 2] = cos_n

        rots = rot_matmul(n_rots, c_rots)

        rots = rots.transpose(-1, -2)
        translation = -1 * translation

        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, translation)

    def cuda(self) -> Rigid:
        """
        Moves the transformation object to GPU memory

        Returns:
            A version of the transformation on GPU
        """
        return Rigid(self._rots.cuda(), self._trans.cuda())
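Not part of the diff: a minimal usage sketch of the `Rotation`/`Rigid` API defined above. It assumes the module is importable as `transformers.models.esm.openfold_utils.rigid_utils`; adjust the import to wherever this file lives in your checkout.

```python
import torch

from transformers.models.esm.openfold_utils.rigid_utils import Rigid

# Build a batch of 8 rigid frames from three backbone-like points (Gram-Schmidt, "algorithm 21" above).
n, ca, c = torch.randn(8, 3), torch.randn(8, 3), torch.randn(8, 3)
frames = Rigid.from_3_points(p_neg_x_axis=n, origin=ca, p_xy_plane=c)  # frames.shape == (8,)

pts = torch.randn(8, 3)
moved = frames.apply(pts)               # rotate, then translate
recovered = frames.invert_apply(moved)  # inverse of apply()
assert torch.allclose(recovered, pts, atol=1e-4)

# Frames round-trip through the tensor representations defined above.
same_frames = Rigid.from_tensor_4x4(frames.to_tensor_4x4())
```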
docs/transformers/build/lib/transformers/models/falcon/configuration_falcon.py
ADDED
@@ -0,0 +1,211 @@
# coding=utf-8
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class FalconConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 65024):
            Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`FalconModel`]
        hidden_size (`int`, *optional*, defaults to 4544):
            Dimension of the hidden representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 71):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_ln_in_parallel_attn (`int`, *optional*):
            Set to 2 if separate layer norms are to be used for the MLP and the attention output when using parallel
            attention, otherwise, 1.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions (not used by all models). Only relevant if
            `config.is_decoder=True`.
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for MLP layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for attention layers.
        num_kv_heads (`int`, *optional*):
            Number of key-value heads to use per attention layer. If unset, defaults to the same value as
            `num_attention_heads`.
        alibi (`bool`, *optional*, defaults to `False`):
            Whether to use ALiBi positional biases during self-attention.
        new_decoder_architecture (`bool`, *optional*, defaults to `False`):
            Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
            arguments are ignored, as the new decoder always uses parallel attention.
        multi_query (`bool`, *optional*, defaults to `True`):
            Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
        parallel_attn (`bool`, *optional*, defaults to `True`):
            Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
            instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
        bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias on Linear layers.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
            Falcon models with RoPE support up to 2048 tokens.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        bos_token_id (`int`, *optional*, defaults to 11):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 11):
            The id of the "end-of-sequence" token.
        ffn_hidden_size (`int`, *optional*):
            The hidden size of the feedforward layer in the Transformer decoder.
            defaults to 4x hidden dim
        activation (`str`, *optional*, defaults to `"gelu"`):
            The activation function used in the feedforward layer.

    Example:

    ```python
    >>> from transformers import FalconModel, FalconConfig

    >>> # Initializing a small (2-layer) Falcon configuration
    >>> configuration = FalconConfig(num_hidden_layers=2)

    >>> # Initializing a model from the small configuration
    >>> model = FalconModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "falcon"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=65024,
        hidden_size=4544,
        num_hidden_layers=32,
        num_attention_heads=71,
        num_ln_in_parallel_attn=None,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        num_kv_heads=None,
        alibi=False,
        new_decoder_architecture=False,
        multi_query=True,
        parallel_attn=True,
        bias=False,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        rope_scaling=None,
        bos_token_id=11,
        eos_token_id=11,
        ffn_hidden_size=None,
        activation="gelu",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
        self.alibi = alibi
        self.new_decoder_architecture = new_decoder_architecture
        self.multi_query = multi_query  # Ignored when new_decoder_architecture is True
        self.parallel_attn = parallel_attn
        self.bias = bias
        self.num_ln_in_parallel_attn = num_ln_in_parallel_attn
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.activation = activation
        if ffn_hidden_size is None:
            self.ffn_hidden_size = hidden_size * 4
        else:
            self.ffn_hidden_size = ffn_hidden_size

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.num_attention_heads

    @property
    def rotary(self):
        return not self.alibi


__all__ = ["FalconConfig"]
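Not part of the diff: a short sketch of how the derived values above fall out of the defaults, assuming `FalconConfig` is importable from the installed `transformers` package.

```python
from transformers import FalconConfig

config = FalconConfig()                     # falcon-7b-like defaults
assert config.head_dim == 4544 // 71        # == 64, hidden_size / num_attention_heads
assert config.rotary                        # RoPE is used whenever alibi=False
assert config.ffn_hidden_size == 4 * 4544   # ffn_hidden_size defaults to 4x hidden_size
assert config.num_kv_heads == 71            # falls back to num_attention_heads when unset

# rope_scaling is stored as-is; e.g. a linear-scaling dict following the docstring above:
scaled = FalconConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
```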
docs/transformers/build/lib/transformers/models/falcon/convert_custom_code_checkpoint.py
ADDED
@@ -0,0 +1,74 @@
import json
from argparse import ArgumentParser
from pathlib import Path


"""
This script converts Falcon custom code checkpoints to modern Falcon checkpoints that use code in the Transformers
library. After conversion, performance (especially for generation) should improve and the checkpoint can be loaded
without needing trust_remote_code=True.
"""

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        required=True,
        help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
    )
    args = parser.parse_args()

    if not args.checkpoint_dir.is_dir():
        raise ValueError("--checkpoint_dir argument should be a directory!")

    if (
        not (args.checkpoint_dir / "configuration_RW.py").is_file()
        or not (args.checkpoint_dir / "modelling_RW.py").is_file()
    ):
        raise ValueError(
            "The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
        )
    (args.checkpoint_dir / "configuration_RW.py").unlink()
    (args.checkpoint_dir / "modelling_RW.py").unlink()

    config = args.checkpoint_dir / "config.json"
    text = config.read_text()
    text = text.replace("RWForCausalLM", "FalconForCausalLM")
    text = text.replace("RefinedWebModel", "falcon")
    text = text.replace("RefinedWeb", "falcon")
    json_config = json.loads(text)
    del json_config["auto_map"]

    if "n_head" in json_config:
        json_config["num_attention_heads"] = json_config.pop("n_head")
    if "n_layer" in json_config:
        json_config["num_hidden_layers"] = json_config.pop("n_layer")
    if "n_head_kv" in json_config:
        json_config["num_kv_heads"] = json_config.pop("n_head_kv")
        json_config["new_decoder_architecture"] = True
    else:
        json_config["new_decoder_architecture"] = False
    bos_token_id = json_config.get("bos_token_id", 1)
    eos_token_id = json_config.get("eos_token_id", 2)
    config.unlink()
    config.write_text(json.dumps(json_config, indent=2, sort_keys=True))

    tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
    if tokenizer_config.is_file():
        text = tokenizer_config.read_text()
        json_config = json.loads(text)
        if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
            json_config["model_input_names"] = ["input_ids", "attention_mask"]
            tokenizer_config.unlink()
            tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))

    generation_config_path = args.checkpoint_dir / "generation_config.json"
    generation_dict = {
        "_from_model_config": True,
        "bos_token_id": bos_token_id,
        "eos_token_id": eos_token_id,
        "transformers_version": "4.33.0.dev0",
    }
    generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
    print("Done! Please double-check that the new checkpoint works as expected.")
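Not part of the diff: a sketch of the config.json rewrite the script performs, using a hypothetical input for illustration only (the real script also rewrites `model_type` via the `RefinedWeb*` string replacements).

```python
# Hypothetical custom-code config before conversion:
before = {
    "architectures": ["RWForCausalLM"],
    "auto_map": {"AutoConfig": "configuration_RW.RWConfig"},
    "n_head": 71,
    "n_layer": 32,
    "hidden_size": 4544,
}

# Roughly what the script writes back out:
after = {
    "architectures": ["FalconForCausalLM"],
    "num_attention_heads": 71,          # renamed from n_head
    "num_hidden_layers": 32,            # renamed from n_layer
    "hidden_size": 4544,
    "new_decoder_architecture": False,  # no n_head_kv key was present; auto_map is dropped
}
```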
docs/transformers/build/lib/transformers/models/falcon/modeling_falcon.py
ADDED
@@ -0,0 +1,1566 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Falcon model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
| 24 |
+
from torch.nn import functional as F
|
| 25 |
+
|
| 26 |
+
from ...activations import get_activation
|
| 27 |
+
from ...cache_utils import Cache, DynamicCache, StaticCache
|
| 28 |
+
from ...generation import GenerationMixin
|
| 29 |
+
from ...modeling_attn_mask_utils import (
|
| 30 |
+
AttentionMaskConverter,
|
| 31 |
+
)
|
| 32 |
+
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
| 33 |
+
from ...modeling_outputs import (
|
| 34 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 35 |
+
CausalLMOutputWithCrossAttentions,
|
| 36 |
+
QuestionAnsweringModelOutput,
|
| 37 |
+
SequenceClassifierOutputWithPast,
|
| 38 |
+
TokenClassifierOutput,
|
| 39 |
+
)
|
| 40 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 41 |
+
from ...modeling_utils import PreTrainedModel
|
| 42 |
+
from ...utils import (
|
| 43 |
+
add_code_sample_docstrings,
|
| 44 |
+
add_start_docstrings,
|
| 45 |
+
add_start_docstrings_to_model_forward,
|
| 46 |
+
logging,
|
| 47 |
+
)
|
| 48 |
+
from .configuration_falcon import FalconConfig
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if TYPE_CHECKING:
|
| 52 |
+
from ...configuration_utils import PretrainedConfig
|
| 53 |
+
|
| 54 |
+
if is_flash_attn_available():
|
| 55 |
+
from ...modeling_flash_attention_utils import _flash_attention_forward
|
| 56 |
+
|
| 57 |
+
logger = logging.get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
_CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
|
| 61 |
+
_CONFIG_FOR_DOC = "FalconConfig"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
|
| 65 |
+
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
|
| 66 |
+
class FalconLinear(nn.Linear):
|
| 67 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 68 |
+
hidden_states = input @ self.weight.T
|
| 69 |
+
if self.bias is None:
|
| 70 |
+
return hidden_states
|
| 71 |
+
return hidden_states + self.bias
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 75 |
+
def rotate_half(x):
|
| 76 |
+
"""Rotates half the hidden dims of the input."""
|
| 77 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 78 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 79 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 83 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 84 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
q (`torch.Tensor`): The query tensor.
|
| 88 |
+
k (`torch.Tensor`): The key tensor.
|
| 89 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 90 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 91 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 92 |
+
Deprecated and unused.
|
| 93 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 94 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 95 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 96 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 97 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 98 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 99 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 100 |
+
Returns:
|
| 101 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 102 |
+
"""
|
| 103 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 104 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 105 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 106 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 107 |
+
return q_embed, k_embed
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
|
| 111 |
+
class FalconRotaryEmbedding(nn.Module):
|
| 112 |
+
def __init__(self, config: FalconConfig, device=None):
|
| 113 |
+
super().__init__()
|
| 114 |
+
# BC: "rope_type" was originally "type"
|
| 115 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 116 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 117 |
+
else:
|
| 118 |
+
self.rope_type = "default"
|
| 119 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 120 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 121 |
+
|
| 122 |
+
self.config = config
|
| 123 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 124 |
+
|
| 125 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 126 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 127 |
+
self.original_inv_freq = self.inv_freq
|
| 128 |
+
|
| 129 |
+
@torch.no_grad()
|
| 130 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 131 |
+
def forward(self, x, position_ids):
|
| 132 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 133 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 134 |
+
|
| 135 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 136 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 137 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 138 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 139 |
+
cos = emb.cos() * self.attention_scaling
|
| 140 |
+
sin = emb.sin() * self.attention_scaling
|
| 141 |
+
|
| 142 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcasted correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None].bfloat16() * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
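
# Illustrative sketch (not part of the upstream file): shows what `build_alibi_tensor` produces for a
# toy batch. With num_heads=4 the geometric slopes are 2**-2, 2**-4, 2**-6, 2**-8, and the bias grows
# linearly with the (left-padding-aware) token position.
def _example_build_alibi_tensor():
    attention_mask = torch.tensor([[0, 1, 1, 1]])  # one left-padded sequence of length 4
    alibi = build_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float32)
    # Shape is (batch_size * num_heads, 1, seq_length); padded positions contribute a zero bias.
    assert alibi.shape == (4, 1, 4)
    return alibi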


# Copied from transformers.models.bloom.modeling_bloom.dropout_add
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Dropout add function

    Args:
        x (`torch.tensor`):
            input tensor
        residual (`torch.tensor`):
            residual tensor
        prob (`float`):
            dropout probability
        training (`bool`):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out
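
# Illustrative sketch (not part of the upstream file): `dropout_add` is just dropout followed by a
# residual sum; in eval mode (training=False) it reduces to `residual + x`.
def _example_dropout_add():
    x = torch.ones(2, 3)
    residual = torch.zeros(2, 3)
    out = dropout_add(x, residual, prob=0.5, training=False)
    assert torch.equal(out, x)  # no dropout applied outside of training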
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class FalconAttention(nn.Module):
|
| 194 |
+
def __init__(self, config: FalconConfig, layer_idx=None):
|
| 195 |
+
super().__init__()
|
| 196 |
+
|
| 197 |
+
self.config = config
|
| 198 |
+
self.hidden_size = config.hidden_size
|
| 199 |
+
self.num_heads = config.num_attention_heads
|
| 200 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 201 |
+
self.split_size = self.hidden_size
|
| 202 |
+
self.hidden_dropout = config.hidden_dropout
|
| 203 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 204 |
+
self.rope_theta = config.rope_theta
|
| 205 |
+
self.is_causal = True
|
| 206 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 207 |
+
self.layer_idx = layer_idx
|
| 208 |
+
if layer_idx is None:
|
| 209 |
+
logger.warning_once(
|
| 210 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
| 211 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
| 212 |
+
"when creating this class."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if self.head_dim * self.num_heads != self.hidden_size:
|
| 216 |
+
raise ValueError(
|
| 217 |
+
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
| 218 |
+
f" {self.num_heads})."
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Layer-wise attention scaling
|
| 222 |
+
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
| 223 |
+
self.beta = self.inv_norm_factor
|
| 224 |
+
if config.new_decoder_architecture:
|
| 225 |
+
qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
|
| 226 |
+
elif config.multi_query:
|
| 227 |
+
qkv_out_dim = self.hidden_size + 2 * self.head_dim
|
| 228 |
+
else:
|
| 229 |
+
qkv_out_dim = 3 * self.hidden_size
|
| 230 |
+
self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
|
| 231 |
+
self.new_decoder_architecture = config.new_decoder_architecture
|
| 232 |
+
self.multi_query = config.multi_query
|
| 233 |
+
self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
|
| 234 |
+
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
| 235 |
+
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
| 236 |
+
|
| 237 |
+
# TODO (raushan): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
|
| 238 |
+
if config.rotary:
|
| 239 |
+
self.rotary_emb = FalconRotaryEmbedding(config=self.config)
|
| 240 |
+
|
| 241 |
+
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 242 |
+
"""
|
| 243 |
+
Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
|
| 250 |
+
value: [batch_size, seq_length, num_heads, head_dim]
|
| 251 |
+
"""
|
| 252 |
+
if self.new_decoder_architecture:
|
| 253 |
+
batch, seq_len, _ = fused_qkv.shape
|
| 254 |
+
qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
|
| 255 |
+
query = qkv[:, :, :, :-2]
|
| 256 |
+
key = qkv[:, :, :, [-2]]
|
| 257 |
+
value = qkv[:, :, :, [-1]]
|
| 258 |
+
key = torch.broadcast_to(key, query.shape)
|
| 259 |
+
value = torch.broadcast_to(value, query.shape)
|
| 260 |
+
|
| 261 |
+
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
|
| 262 |
+
return query, key, value
|
| 263 |
+
elif not self.multi_query:
|
| 264 |
+
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
|
| 265 |
+
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
|
| 266 |
+
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
|
| 267 |
+
else:
|
| 268 |
+
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
|
| 269 |
+
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
|
| 270 |
+
return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
|
| 271 |
+
|
| 272 |
+
# Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
|
| 273 |
+
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
|
| 274 |
+
"""
|
| 275 |
+
Merge heads together over the last dimension
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
|
| 282 |
+
"""
|
| 283 |
+
# What we want to achieve is:
|
| 284 |
+
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
|
| 285 |
+
batch_size_and_num_heads, seq_length, _ = x.shape
|
| 286 |
+
batch_size = batch_size_and_num_heads // self.num_heads
|
| 287 |
+
|
| 288 |
+
# First view to decompose the batch size
|
| 289 |
+
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
|
| 290 |
+
x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
|
| 291 |
+
|
| 292 |
+
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
|
| 293 |
+
x = x.permute(0, 2, 1, 3)
|
| 294 |
+
|
| 295 |
+
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
|
| 296 |
+
return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
|
| 297 |
+
|
| 298 |
+
def forward(
|
| 299 |
+
self,
|
| 300 |
+
hidden_states: torch.Tensor,
|
| 301 |
+
alibi: Optional[torch.Tensor],
|
| 302 |
+
attention_mask: torch.Tensor,
|
| 303 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 304 |
+
layer_past: Optional[Cache] = None,
|
| 305 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 306 |
+
use_cache: bool = False,
|
| 307 |
+
output_attentions: bool = False,
|
| 308 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 309 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 310 |
+
):
|
| 311 |
+
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
| 312 |
+
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
| 313 |
+
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
| 314 |
+
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
| 315 |
+
|
| 316 |
+
batch_size, query_length, _, _ = query_layer.shape
|
| 317 |
+
|
| 318 |
+
query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
|
| 319 |
+
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
| 320 |
+
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
| 321 |
+
|
| 322 |
+
if alibi is None:
|
| 323 |
+
cos, sin = position_embeddings
|
| 324 |
+
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
|
| 325 |
+
|
| 326 |
+
if layer_past is not None:
|
| 327 |
+
cache_kwargs = {"cache_position": cache_position}
|
| 328 |
+
if alibi is None:
|
| 329 |
+
cache_kwargs.update({"sin": sin, "cos": cos})
|
| 330 |
+
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
|
| 331 |
+
|
| 332 |
+
kv_length = key_layer.shape[-2]
|
| 333 |
+
if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
|
| 334 |
+
# For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
|
| 335 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 336 |
+
query_layer = query_layer.contiguous()
|
| 337 |
+
key_layer = key_layer.contiguous()
|
| 338 |
+
value_layer = value_layer.contiguous()
|
| 339 |
+
|
| 340 |
+
if attention_mask is not None:
|
| 341 |
+
attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
|
| 342 |
+
|
| 343 |
+
if alibi is None:
|
| 344 |
+
if self._use_sdpa and not output_attentions:
|
| 345 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
|
| 346 |
+
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
|
| 347 |
+
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
|
| 348 |
+
# create a causal mask in case query_length == 1.
|
| 349 |
+
is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
|
| 350 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 351 |
+
query_layer,
|
| 352 |
+
key_layer,
|
| 353 |
+
value_layer,
|
| 354 |
+
attn_mask=attention_mask,
|
| 355 |
+
dropout_p=0.0,
|
| 356 |
+
is_causal=is_causal,
|
| 357 |
+
)
|
| 358 |
+
attention_scores = None
|
| 359 |
+
else:
|
| 360 |
+
attention_scores = query_layer @ key_layer.transpose(-1, -2)
|
| 361 |
+
attention_scores /= math.sqrt(self.head_dim)
|
| 362 |
+
|
| 363 |
+
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
|
| 364 |
+
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
|
| 365 |
+
attn_output = attention_scores @ value_layer
|
| 366 |
+
|
| 367 |
+
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
|
| 368 |
+
attn_output = attn_output.permute(0, 2, 1, 3)
|
| 369 |
+
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
|
| 370 |
+
|
| 371 |
+
attn_output = self.dense(attn_output)
|
| 372 |
+
|
| 373 |
+
if output_attentions:
|
| 374 |
+
return attn_output, layer_past, attention_scores
|
| 375 |
+
else:
|
| 376 |
+
return attn_output, layer_past
|
| 377 |
+
|
| 378 |
+
else:
|
| 379 |
+
if self._use_sdpa and not output_attentions and head_mask is None:
|
| 380 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
|
| 381 |
+
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
|
| 382 |
+
is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
|
| 383 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 384 |
+
query_layer,
|
| 385 |
+
key_layer,
|
| 386 |
+
value_layer,
|
| 387 |
+
attn_mask=attention_mask,
|
| 388 |
+
dropout_p=self.attention_dropout.p if self.training else 0.0,
|
| 389 |
+
is_causal=is_causal,
|
| 390 |
+
)
|
| 391 |
+
attn_output = attn_output.transpose(1, 2)
|
| 392 |
+
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
|
| 393 |
+
|
| 394 |
+
attn_output = self.dense(attn_output)
|
| 395 |
+
else:
|
| 396 |
+
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
| 397 |
+
|
| 398 |
+
# change view to [batch_size, num_heads, q_length, kv_length]
|
| 399 |
+
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
|
| 400 |
+
|
| 401 |
+
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
| 402 |
+
input_dtype = attention_scores.dtype
|
| 403 |
+
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
| 404 |
+
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
|
| 405 |
+
attention_scores = attention_scores.to(torch.float32)
|
| 406 |
+
|
| 407 |
+
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
|
| 408 |
+
attention_logits *= self.inv_norm_factor
|
| 409 |
+
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
|
| 410 |
+
# [batch_size, num_heads, q_length, kv_length]
|
| 411 |
+
attention_probs = self.attention_dropout(attention_probs)
|
| 412 |
+
|
| 413 |
+
if head_mask is not None:
|
| 414 |
+
attention_probs = attention_probs * head_mask
|
| 415 |
+
|
| 416 |
+
# change view [batch_size, num_heads, q_length, kv_length]
|
| 417 |
+
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
|
| 418 |
+
|
| 419 |
+
# matmul: [batch_size * num_heads, q_length, head_dim]
|
| 420 |
+
attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
|
| 421 |
+
|
| 422 |
+
# change view [batch_size, q_length, num_heads * head_dim]
|
| 423 |
+
attn_output = self._merge_heads(attn_output)
|
| 424 |
+
|
| 425 |
+
attn_output = self.dense(attn_output)
|
| 426 |
+
|
| 427 |
+
if output_attentions:
|
| 428 |
+
return attn_output, layer_past, attention_probs
|
| 429 |
+
else:
|
| 430 |
+
return attn_output, layer_past
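
# Illustrative sketch (not part of the upstream file): the no-alibi branch of `FalconAttention.forward`
# above either calls `torch.nn.functional.scaled_dot_product_attention` or computes the same thing
# eagerly (Q @ K^T / sqrt(head_dim), softmax, then @ V). The two paths agree numerically:
def _example_sdpa_vs_eager_attention():
    q = torch.randn(1, 2, 4, 8)  # [batch_size, num_heads, query_length, head_dim]
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    causal_bias = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
    eager_out = F.softmax(q @ k.transpose(-1, -2) / math.sqrt(8) + causal_bias, dim=-1) @ v
    assert torch.allclose(sdpa_out, eager_out, atol=1e-5)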
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class FalconFlashAttention2(FalconAttention):
|
| 434 |
+
"""
|
| 435 |
+
Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays
|
| 436 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 437 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 438 |
+
"""
|
| 439 |
+
|
| 440 |
+
def __init__(self, *args, **kwargs):
|
| 441 |
+
super().__init__(*args, **kwargs)
|
| 442 |
+
|
| 443 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 444 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 445 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 446 |
+
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
| 447 |
+
|
| 448 |
+
def forward(
|
| 449 |
+
self,
|
| 450 |
+
hidden_states: torch.Tensor,
|
| 451 |
+
alibi: Optional[torch.Tensor],
|
| 452 |
+
attention_mask: torch.Tensor,
|
| 453 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 454 |
+
layer_past: Optional[Cache] = None,
|
| 455 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 456 |
+
use_cache: bool = False,
|
| 457 |
+
output_attentions: bool = False,
|
| 458 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 459 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 460 |
+
):
|
| 461 |
+
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
| 462 |
+
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
| 463 |
+
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
| 464 |
+
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
| 465 |
+
|
| 466 |
+
batch_size, query_length, _, _ = query_layer.shape
|
| 467 |
+
|
| 468 |
+
query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
|
| 469 |
+
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
| 470 |
+
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
| 471 |
+
|
| 472 |
+
if alibi is None:
|
| 473 |
+
cos, sin = position_embeddings
|
| 474 |
+
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
|
| 475 |
+
|
| 476 |
+
if layer_past is not None:
|
| 477 |
+
cache_kwargs = {"cache_position": cache_position}
|
| 478 |
+
if alibi is None:
|
| 479 |
+
cache_kwargs.update({"sin": sin, "cos": cos})
|
| 480 |
+
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
|
| 481 |
+
|
| 482 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 483 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 484 |
+
query_layer = query_layer.transpose(1, 2)
|
| 485 |
+
key_layer = key_layer.transpose(1, 2)
|
| 486 |
+
value_layer = value_layer.transpose(1, 2)
|
| 487 |
+
|
| 488 |
+
if alibi is not None:
|
| 489 |
+
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
| 490 |
+
|
| 491 |
+
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
| 492 |
+
|
| 493 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 494 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 495 |
+
# cast them back in float16 just to be sure everything works as expected.
|
| 496 |
+
input_dtype = query_layer.dtype
|
| 497 |
+
if input_dtype == torch.float32:
|
| 498 |
+
if torch.is_autocast_enabled():
|
| 499 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 500 |
+
# Handle the case where the model is quantized
|
| 501 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 502 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 503 |
+
else:
|
| 504 |
+
target_dtype = self.query_key_value.weight.dtype
|
| 505 |
+
|
| 506 |
+
logger.warning_once(
|
| 507 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 508 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 509 |
+
f" {target_dtype}."
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
query_layer = query_layer.to(target_dtype)
|
| 513 |
+
key_layer = key_layer.to(target_dtype)
|
| 514 |
+
value_layer = value_layer.to(target_dtype)
|
| 515 |
+
|
| 516 |
+
attn_output = _flash_attention_forward(
|
| 517 |
+
query_layer,
|
| 518 |
+
key_layer,
|
| 519 |
+
value_layer,
|
| 520 |
+
attention_mask,
|
| 521 |
+
query_length,
|
| 522 |
+
position_ids=position_ids,
|
| 523 |
+
dropout=attn_dropout,
|
| 524 |
+
is_causal=self.is_causal,
|
| 525 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
|
| 529 |
+
attn_output = self.dense(attn_weights)
|
| 530 |
+
|
| 531 |
+
if not output_attentions:
|
| 532 |
+
attn_weights = None
|
| 533 |
+
|
| 534 |
+
return attn_output, layer_past, attn_weights
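
# Illustrative sketch (not part of the upstream file): the dtype guard in `FalconFlashAttention2.forward`
# casts silently-upcast float32 activations back to the dtype the kernel expects, here taken from a
# linear layer's weight to mirror the `self.query_key_value.weight.dtype` fallback above.
def _example_flash_attention_dtype_guard():
    proj = nn.Linear(8, 8).to(torch.float16)
    hidden_states = torch.randn(1, 4, 8, dtype=torch.float32)  # e.g. upcast by a float32 layer norm
    if hidden_states.dtype == torch.float32:
        hidden_states = hidden_states.to(proj.weight.dtype)
    assert hidden_states.dtype == torch.float16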


class FalconMLP(nn.Module):
    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = FalconLinear(hidden_size, config.ffn_hidden_size, bias=config.bias)
        self.act = get_activation(config.activation)
        self.dense_4h_to_h = FalconLinear(config.ffn_hidden_size, hidden_size, bias=config.bias)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(self.dense_h_to_4h(x))
        x = self.dense_4h_to_h(x)
        return x
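
# Illustrative sketch (not part of the upstream file): FalconMLP is a plain up-projection / activation /
# down-projection block. The tiny FalconConfig below is an arbitrary example; it assumes the default
# behaviour where `ffn_hidden_size` falls back to 4 * hidden_size when left unset.
def _example_falcon_mlp():
    config = FalconConfig(hidden_size=16, num_attention_heads=4, num_hidden_layers=1, vocab_size=64)
    mlp = FalconMLP(config)
    hidden_states = torch.randn(2, 5, config.hidden_size)
    assert mlp(hidden_states).shape == (2, 5, config.hidden_size)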


FALCON_ATTENTION_CLASSES = {
    "eager": FalconAttention,
    "sdpa": FalconAttention,  # FalconAttention originally implemented both a forward with & without SDPA
    "flash_attention_2": FalconFlashAttention2,
}
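
# Illustrative sketch (not part of the upstream file): the mapping above is indexed with
# `config._attn_implementation`, which users typically control through `attn_implementation` at load
# time. The checkpoint name below is only an example.
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", attn_implementation="sdpa")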
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class FalconDecoderLayer(nn.Module):
|
| 561 |
+
def __init__(self, config: FalconConfig, layer_idx=None):
|
| 562 |
+
super().__init__()
|
| 563 |
+
hidden_size = config.hidden_size
|
| 564 |
+
self.num_heads = config.num_attention_heads
|
| 565 |
+
|
| 566 |
+
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
| 567 |
+
self.mlp = FalconMLP(config)
|
| 568 |
+
self.hidden_dropout = config.hidden_dropout
|
| 569 |
+
self.config = config
|
| 570 |
+
|
| 571 |
+
if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture:
|
| 572 |
+
config.num_ln_in_parallel_attn = 2
|
| 573 |
+
|
| 574 |
+
if not config.parallel_attn:
|
| 575 |
+
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 576 |
+
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 577 |
+
else:
|
| 578 |
+
if config.num_ln_in_parallel_attn == 2:
|
| 579 |
+
# The layer norm before self-attention
|
| 580 |
+
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 581 |
+
# The layer norm before the MLP
|
| 582 |
+
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 583 |
+
else:
|
| 584 |
+
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 585 |
+
|
| 586 |
+
def forward(
|
| 587 |
+
self,
|
| 588 |
+
hidden_states: torch.Tensor,
|
| 589 |
+
alibi: Optional[torch.Tensor],
|
| 590 |
+
attention_mask: torch.Tensor,
|
| 591 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 592 |
+
layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 593 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 594 |
+
use_cache: bool = False,
|
| 595 |
+
output_attentions: bool = False,
|
| 596 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 597 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 598 |
+
**kwargs,
|
| 599 |
+
):
|
| 600 |
+
residual = hidden_states
|
| 601 |
+
|
| 602 |
+
if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
|
| 603 |
+
attention_layernorm_out = self.ln_attn(hidden_states)
|
| 604 |
+
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
| 605 |
+
else:
|
| 606 |
+
attention_layernorm_out = self.input_layernorm(hidden_states)
|
| 607 |
+
|
| 608 |
+
# Self attention.
|
| 609 |
+
attn_outputs = self.self_attention(
|
| 610 |
+
attention_layernorm_out,
|
| 611 |
+
layer_past=layer_past,
|
| 612 |
+
attention_mask=attention_mask,
|
| 613 |
+
position_ids=position_ids,
|
| 614 |
+
alibi=alibi,
|
| 615 |
+
head_mask=head_mask,
|
| 616 |
+
use_cache=use_cache,
|
| 617 |
+
output_attentions=output_attentions,
|
| 618 |
+
cache_position=cache_position,
|
| 619 |
+
position_embeddings=position_embeddings,
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
attention_output = attn_outputs[0]
|
| 623 |
+
|
| 624 |
+
if not self.config.new_decoder_architecture:
|
| 625 |
+
if self.config.parallel_attn:
|
| 626 |
+
mlp_layernorm_out = attention_layernorm_out
|
| 627 |
+
else:
|
| 628 |
+
residual = dropout_add(
|
| 629 |
+
attention_output, residual, self.config.attention_dropout, training=self.training
|
| 630 |
+
)
|
| 631 |
+
mlp_layernorm_out = self.post_attention_layernorm(residual)
|
| 632 |
+
|
| 633 |
+
if (
|
| 634 |
+
self.config.new_decoder_architecture
|
| 635 |
+
and self.config.parallel_attn
|
| 636 |
+
and self.config.num_ln_in_parallel_attn == 1
|
| 637 |
+
):
|
| 638 |
+
mlp_layernorm_out = attention_layernorm_out
|
| 639 |
+
|
| 640 |
+
outputs = attn_outputs[1:]
|
| 641 |
+
|
| 642 |
+
# MLP.
|
| 643 |
+
mlp_output = self.mlp(mlp_layernorm_out)
|
| 644 |
+
|
| 645 |
+
if self.config.new_decoder_architecture or self.config.parallel_attn:
|
| 646 |
+
mlp_output += attention_output
|
| 647 |
+
|
| 648 |
+
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
|
| 649 |
+
|
| 650 |
+
if use_cache:
|
| 651 |
+
outputs = (output,) + outputs
|
| 652 |
+
else:
|
| 653 |
+
outputs = (output,) + outputs[1:]
|
| 654 |
+
|
| 655 |
+
return outputs # hidden_states, past_kv, attentions
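
# Illustrative sketch (not part of the upstream file): with `parallel_attn=True` and a single layer norm
# (the classic Falcon-7B layout), the decoder layer above reduces to a parallel residual block:
# out = x + dropout(attn(ln(x)) + mlp(ln(x))). The identity placeholders below just mimic that wiring.
def _example_parallel_residual_wiring():
    hidden_states = torch.randn(1, 4, 8)
    ln = nn.LayerNorm(8)
    normed = ln(hidden_states)
    attention_output = normed.clone()  # stand-in for self_attention(...)
    mlp_output = normed.clone()        # stand-in for mlp(...)
    output = dropout_add(mlp_output + attention_output, hidden_states, prob=0.0, training=False)
    assert output.shape == hidden_states.shape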
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
FALCON_START_DOCSTRING = r"""
|
| 659 |
+
|
| 660 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 661 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
|
| 662 |
+
|
| 663 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 664 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 665 |
+
and behavior.
|
| 666 |
+
|
| 667 |
+
Parameters:
|
| 668 |
+
config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
|
| 669 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 670 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 671 |
+
"""
|
| 672 |
+
|
| 673 |
+
FALCON_INPUTS_DOCSTRING = r"""
|
| 674 |
+
Args:
|
| 675 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
| 676 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
|
| 677 |
+
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
|
| 678 |
+
|
| 679 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
| 680 |
+
`input_ids`.
|
| 681 |
+
|
| 682 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 683 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 684 |
+
|
| 685 |
+
[What are input IDs?](../glossary#input-ids)
|
| 686 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 687 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 688 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 689 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 690 |
+
|
| 691 |
+
Two formats are allowed:
|
| 692 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 693 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 694 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 695 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 696 |
+
cache format.
|
| 697 |
+
|
| 698 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 699 |
+
legacy cache format will be returned.
|
| 700 |
+
|
| 701 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 702 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 703 |
+
of shape `(batch_size, sequence_length)`.
|
| 704 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 705 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 706 |
+
|
| 707 |
+
- 1 for tokens that are **not masked**,
|
| 708 |
+
- 0 for tokens that are **masked**.
|
| 709 |
+
|
| 710 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 711 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 712 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 713 |
+
config.n_positions - 1]`.
|
| 714 |
+
|
| 715 |
+
[What are position IDs?](../glossary#position-ids)
|
| 716 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 717 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 718 |
+
|
| 719 |
+
- 1 indicates the head is **not masked**,
|
| 720 |
+
- 0 indicates the head is **masked**.
|
| 721 |
+
|
| 722 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 723 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 724 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 725 |
+
model's internal embedding lookup matrix.
|
| 726 |
+
|
| 727 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
| 728 |
+
`past_key_values`).
|
| 729 |
+
use_cache (`bool`, *optional*):
|
| 730 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 731 |
+
`past_key_values`).
|
| 732 |
+
output_attentions (`bool`, *optional*):
|
| 733 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 734 |
+
tensors for more detail.
|
| 735 |
+
output_hidden_states (`bool`, *optional*):
|
| 736 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 737 |
+
more detail.
|
| 738 |
+
return_dict (`bool`, *optional*):
|
| 739 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
| 740 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 741 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 742 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 743 |
+
the complete sequence length.
|
| 744 |
+
"""
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class FalconPreTrainedModel(PreTrainedModel):
|
| 748 |
+
"""
|
| 749 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 750 |
+
models.
|
| 751 |
+
"""
|
| 752 |
+
|
| 753 |
+
config_class = FalconConfig
|
| 754 |
+
base_model_prefix = "transformer"
|
| 755 |
+
supports_gradient_checkpointing = True
|
| 756 |
+
_no_split_modules = ["FalconDecoderLayer"]
|
| 757 |
+
_supports_flash_attn_2 = True
|
| 758 |
+
_supports_sdpa = True
|
| 759 |
+
_supports_cache_class = True
|
| 760 |
+
_supports_quantized_cache = True
|
| 761 |
+
_supports_static_cache = True
|
| 762 |
+
|
| 763 |
+
def __init__(self, *inputs, **kwargs):
|
| 764 |
+
super().__init__(*inputs, **kwargs)
|
| 765 |
+
|
| 766 |
+
def _init_weights(self, module: nn.Module):
|
| 767 |
+
"""Initialize the weights."""
|
| 768 |
+
if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
|
| 769 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 770 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 771 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 772 |
+
if module.bias is not None:
|
| 773 |
+
module.bias.data.zero_()
|
| 774 |
+
elif isinstance(module, nn.Embedding):
|
| 775 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 776 |
+
if module.padding_idx is not None:
|
| 777 |
+
module.weight.data[module.padding_idx].zero_()
|
| 778 |
+
elif isinstance(module, LayerNorm):
|
| 779 |
+
module.bias.data.zero_()
|
| 780 |
+
module.weight.data.fill_(1.0)
|
| 781 |
+
|
| 782 |
+
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
|
| 783 |
+
@classmethod
|
| 784 |
+
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
|
| 785 |
+
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
|
| 786 |
+
if _is_bettertransformer:
|
| 787 |
+
return config
|
| 788 |
+
|
| 789 |
+
if not hard_check_only:
|
| 790 |
+
config._attn_implementation = "sdpa"
|
| 791 |
+
return config
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
@add_start_docstrings(
|
| 795 |
+
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
|
| 796 |
+
FALCON_START_DOCSTRING,
|
| 797 |
+
)
|
| 798 |
+
class FalconModel(FalconPreTrainedModel):
|
| 799 |
+
def __init__(self, config: FalconConfig):
|
| 800 |
+
super().__init__(config)
|
| 801 |
+
|
| 802 |
+
self.embed_dim = config.hidden_size
|
| 803 |
+
self.num_heads = config.num_attention_heads
|
| 804 |
+
self.use_alibi = config.alibi
|
| 805 |
+
|
| 806 |
+
# Embedding + LN Embedding
|
| 807 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
|
| 808 |
+
|
| 809 |
+
# Transformer blocks
|
| 810 |
+
self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
| 811 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 812 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 813 |
+
|
| 814 |
+
# Final Layer Norm
|
| 815 |
+
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
| 816 |
+
|
| 817 |
+
self.rotary_emb = FalconRotaryEmbedding(config=config)
|
| 818 |
+
|
| 819 |
+
self.gradient_checkpointing = False
|
| 820 |
+
|
| 821 |
+
# Initialize weights and apply final processing
|
| 822 |
+
self.post_init()
|
| 823 |
+
|
| 824 |
+
def get_input_embeddings(self):
|
| 825 |
+
return self.word_embeddings
|
| 826 |
+
|
| 827 |
+
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
| 828 |
+
self.word_embeddings = new_embeddings
|
| 829 |
+
|
| 830 |
+
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
|
| 831 |
+
@add_code_sample_docstrings(
|
| 832 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 833 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
| 834 |
+
config_class=_CONFIG_FOR_DOC,
|
| 835 |
+
)
|
| 836 |
+
def forward(
|
| 837 |
+
self,
|
| 838 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 839 |
+
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
| 840 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 841 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 842 |
+
head_mask: Optional[torch.LongTensor] = None,
|
| 843 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
| 844 |
+
use_cache: Optional[bool] = None,
|
| 845 |
+
output_attentions: Optional[bool] = None,
|
| 846 |
+
output_hidden_states: Optional[bool] = None,
|
| 847 |
+
return_dict: Optional[bool] = None,
|
| 848 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 849 |
+
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
| 850 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 851 |
+
output_hidden_states = (
|
| 852 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 853 |
+
)
|
| 854 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 855 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 856 |
+
|
| 857 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 858 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 859 |
+
|
| 860 |
+
if self.gradient_checkpointing and self.training:
|
| 861 |
+
if use_cache:
|
| 862 |
+
logger.warning_once(
|
| 863 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 864 |
+
)
|
| 865 |
+
use_cache = False
|
| 866 |
+
|
| 867 |
+
if inputs_embeds is None:
|
| 868 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 869 |
+
|
| 870 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 871 |
+
return_legacy_cache = False
|
| 872 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 873 |
+
return_legacy_cache = True
|
| 874 |
+
if past_key_values is None:
|
| 875 |
+
past_key_values = DynamicCache()
|
| 876 |
+
else:
|
| 877 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 878 |
+
logger.warning_once(
|
| 879 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 880 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 881 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
# Compute alibi tensor: check build_alibi_tensor documentation
|
| 885 |
+
alibi = None
|
| 886 |
+
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 887 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 888 |
+
if self.use_alibi:
|
| 889 |
+
mask = (
|
| 890 |
+
torch.ones(
|
| 891 |
+
(batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
|
| 892 |
+
)
|
| 893 |
+
if attention_mask is None
|
| 894 |
+
else attention_mask
|
| 895 |
+
)
|
| 896 |
+
alibi = build_alibi_tensor(mask, self.num_heads, dtype=inputs_embeds.dtype)
|
| 897 |
+
|
| 898 |
+
if cache_position is None:
|
| 899 |
+
cache_position = torch.arange(
|
| 900 |
+
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
if position_ids is None:
|
| 904 |
+
position_ids = cache_position.unsqueeze(0)
|
| 905 |
+
|
| 906 |
+
causal_mask = self._update_causal_mask(
|
| 907 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, head_mask, alibi
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
# Prepare head mask if needed
|
| 911 |
+
# 1.0 in head_mask indicate we keep the head
|
| 912 |
+
# attention_probs has shape batch_size x num_heads x N x N
|
| 913 |
+
# head_mask has shape n_layer x batch x num_heads x N x N
|
| 914 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 915 |
+
hidden_states = inputs_embeds
|
| 916 |
+
|
| 917 |
+
# create position embeddings to be shared across the decoder layers
|
| 918 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 919 |
+
|
| 920 |
+
next_decoder_cache = None
|
| 921 |
+
all_self_attentions = () if output_attentions else None
|
| 922 |
+
all_hidden_states = () if output_hidden_states else None
|
| 923 |
+
|
| 924 |
+
for i, block in enumerate(self.h):
|
| 925 |
+
if output_hidden_states:
|
| 926 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 927 |
+
|
| 928 |
+
if self.gradient_checkpointing and self.training:
|
| 929 |
+
outputs = self._gradient_checkpointing_func(
|
| 930 |
+
block.__call__,
|
| 931 |
+
hidden_states,
|
| 932 |
+
alibi,
|
| 933 |
+
causal_mask,
|
| 934 |
+
position_ids,
|
| 935 |
+
head_mask[i],
|
| 936 |
+
past_key_values,
|
| 937 |
+
use_cache,
|
| 938 |
+
output_attentions,
|
| 939 |
+
cache_position,
|
| 940 |
+
position_embeddings,
|
| 941 |
+
)
|
| 942 |
+
else:
|
| 943 |
+
outputs = block(
|
| 944 |
+
hidden_states,
|
| 945 |
+
layer_past=past_key_values,
|
| 946 |
+
attention_mask=causal_mask,
|
| 947 |
+
position_ids=position_ids,
|
| 948 |
+
head_mask=head_mask[i],
|
| 949 |
+
use_cache=use_cache,
|
| 950 |
+
output_attentions=output_attentions,
|
| 951 |
+
alibi=alibi,
|
| 952 |
+
cache_position=cache_position,
|
| 953 |
+
position_embeddings=position_embeddings,
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
hidden_states = outputs[0]
|
| 957 |
+
if use_cache is True:
|
| 958 |
+
next_decoder_cache = outputs[1]
|
| 959 |
+
|
| 960 |
+
if output_attentions:
|
| 961 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
| 962 |
+
|
| 963 |
+
# Add last hidden state
|
| 964 |
+
hidden_states = self.ln_f(hidden_states)
|
| 965 |
+
|
| 966 |
+
if output_hidden_states:
|
| 967 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 968 |
+
|
| 969 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 970 |
+
if return_legacy_cache:
|
| 971 |
+
next_cache = next_cache.to_legacy_cache()
|
| 972 |
+
|
| 973 |
+
if not return_dict:
|
| 974 |
+
return tuple(
|
| 975 |
+
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 979 |
+
last_hidden_state=hidden_states,
|
| 980 |
+
past_key_values=next_cache,
|
| 981 |
+
hidden_states=all_hidden_states,
|
| 982 |
+
attentions=all_self_attentions,
|
| 983 |
+
)
|
| 984 |
+
|
| 985 |
+
def _update_causal_mask(
|
| 986 |
+
self,
|
| 987 |
+
attention_mask: torch.Tensor,
|
| 988 |
+
input_tensor: torch.Tensor,
|
| 989 |
+
cache_position: torch.Tensor,
|
| 990 |
+
past_key_values: Cache,
|
| 991 |
+
output_attentions: bool,
|
| 992 |
+
head_mask: torch.Tensor,
|
| 993 |
+
alibi: torch.Tensor,
|
| 994 |
+
):
|
| 995 |
+
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
| 996 |
+
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
| 997 |
+
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
| 998 |
+
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
| 999 |
+
|
| 1000 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1001 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1002 |
+
return attention_mask
|
| 1003 |
+
return None
|
| 1004 |
+
|
| 1005 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1006 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1007 |
+
# to infer the attention mask.
|
| 1008 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1009 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1010 |
+
|
| 1011 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1012 |
+
if (
|
| 1013 |
+
self.config._attn_implementation == "sdpa"
|
| 1014 |
+
and not using_static_cache
|
| 1015 |
+
and not output_attentions
|
| 1016 |
+
and head_mask is None
|
| 1017 |
+
and alibi is None
|
| 1018 |
+
):
|
| 1019 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1020 |
+
attention_mask,
|
| 1021 |
+
inputs_embeds=input_tensor,
|
| 1022 |
+
past_key_values_length=past_seen_tokens,
|
| 1023 |
+
is_training=self.training,
|
| 1024 |
+
):
|
| 1025 |
+
return None
|
| 1026 |
+
|
| 1027 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 1028 |
+
min_dtype = torch.finfo(dtype).min
|
| 1029 |
+
batch_size, sequence_length, _ = input_tensor.shape
|
| 1030 |
+
if using_static_cache:
|
| 1031 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 1032 |
+
else:
|
| 1033 |
+
target_length = (
|
| 1034 |
+
attention_mask.shape[-1]
|
| 1035 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1036 |
+
else past_seen_tokens + sequence_length
|
| 1037 |
+
)
|
| 1038 |
+
|
| 1039 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1040 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 1041 |
+
attention_mask,
|
| 1042 |
+
sequence_length=sequence_length,
|
| 1043 |
+
target_length=target_length,
|
| 1044 |
+
dtype=dtype,
|
| 1045 |
+
device=device,
|
| 1046 |
+
cache_position=cache_position,
|
| 1047 |
+
batch_size=input_tensor.shape[0],
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
# We take care to integrate alibi bias in the causal_mask here
|
| 1051 |
+
if head_mask is None and alibi is not None:
|
| 1052 |
+
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
| 1053 |
+
causal_mask = torch.masked_fill(
|
| 1054 |
+
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
| 1055 |
+
causal_mask < -1,
|
| 1056 |
+
min_dtype,
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
if (
|
| 1060 |
+
self.config._attn_implementation == "sdpa"
|
| 1061 |
+
and attention_mask is not None
|
| 1062 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 1063 |
+
and not output_attentions
|
| 1064 |
+
):
|
| 1065 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1066 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1067 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1068 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 1069 |
+
|
| 1070 |
+
return causal_mask
|
| 1071 |
+
|
| 1072 |
+
@staticmethod
|
| 1073 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
| 1074 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1075 |
+
attention_mask: torch.Tensor,
|
| 1076 |
+
sequence_length: int,
|
| 1077 |
+
target_length: int,
|
| 1078 |
+
dtype: torch.dtype,
|
| 1079 |
+
device: torch.device,
|
| 1080 |
+
cache_position: torch.Tensor,
|
| 1081 |
+
batch_size: int,
|
| 1082 |
+
**kwargs,
|
| 1083 |
+
):
|
| 1084 |
+
"""
|
| 1085 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 1086 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 1087 |
+
|
| 1088 |
+
Args:
|
| 1089 |
+
attention_mask (`torch.Tensor`):
|
| 1090 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 1091 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 1092 |
+
sequence_length (`int`):
|
| 1093 |
+
The sequence length being processed.
|
| 1094 |
+
target_length (`int`):
|
| 1095 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 1096 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 1097 |
+
dtype (`torch.dtype`):
|
| 1098 |
+
The dtype to use for the 4D attention mask.
|
| 1099 |
+
device (`torch.device`):
|
| 1100 |
+
The device to place the 4D attention mask on.
|
| 1101 |
+
cache_position (`torch.Tensor`):
|
| 1102 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 1103 |
+
batch_size (`torch.Tensor`):
|
| 1104 |
+
Batch size.
|
| 1105 |
+
"""
|
| 1106 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 1107 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 1108 |
+
causal_mask = attention_mask
|
| 1109 |
+
else:
|
| 1110 |
+
min_dtype = torch.finfo(dtype).min
|
| 1111 |
+
causal_mask = torch.full(
|
| 1112 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 1113 |
+
)
|
| 1114 |
+
if sequence_length != 1:
|
| 1115 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 1116 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 1117 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 1118 |
+
if attention_mask is not None:
|
| 1119 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 1120 |
+
mask_length = attention_mask.shape[-1]
|
| 1121 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 1122 |
+
causal_mask.device
|
| 1123 |
+
)
|
| 1124 |
+
padding_mask = padding_mask == 0
|
| 1125 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 1126 |
+
padding_mask, min_dtype
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
return causal_mask
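
# Illustrative sketch (not part of the upstream file): the static helper above turns a 2D padding mask
# into the additive 4D mask used by the attention layers, with masked positions set to the dtype's most
# negative value. Sizes below are arbitrary example values.
def _example_prepare_4d_causal_mask():
    attention_mask = torch.tensor([[1, 1, 1, 0]])  # one right-padded sequence of length 4
    causal_mask = FalconModel._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=4,
        target_length=4,
        dtype=torch.float32,
        device=torch.device("cpu"),
        cache_position=torch.arange(4),
        batch_size=1,
    )
    assert causal_mask.shape == (1, 1, 4, 4)  # (batch_size, 1, query_length, key_value_length)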
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
@add_start_docstrings(
|
| 1133 |
+
"The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
|
| 1134 |
+
FALCON_START_DOCSTRING,
|
| 1135 |
+
)
|
| 1136 |
+
class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
|
| 1137 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1138 |
+
|
| 1139 |
+
def __init__(self, config: FalconConfig):
|
| 1140 |
+
super().__init__(config)
|
| 1141 |
+
self.transformer = FalconModel(config)
|
| 1142 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1143 |
+
|
| 1144 |
+
# Initialize weights and apply final processing
|
| 1145 |
+
self.post_init()
|
| 1146 |
+
|
| 1147 |
+
def get_output_embeddings(self):
|
| 1148 |
+
return self.lm_head
|
| 1149 |
+
|
| 1150 |
+
def set_output_embeddings(self, new_embeddings: torch.Tensor):
|
| 1151 |
+
self.lm_head = new_embeddings
|
| 1152 |
+
|
| 1153 |
+
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
|
| 1154 |
+
@add_code_sample_docstrings(
|
| 1155 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1156 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
| 1157 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1158 |
+
)
|
| 1159 |
+
def forward(
|
| 1160 |
+
self,
|
| 1161 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1162 |
+
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
| 1163 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1164 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1165 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1166 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1167 |
+
labels: Optional[torch.Tensor] = None,
|
| 1168 |
+
use_cache: Optional[bool] = None,
|
| 1169 |
+
output_attentions: Optional[bool] = None,
|
| 1170 |
+
output_hidden_states: Optional[bool] = None,
|
| 1171 |
+
return_dict: Optional[bool] = None,
|
| 1172 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1173 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1174 |
+
**kwargs,
|
| 1175 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
| 1176 |
+
r"""
|
| 1177 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1178 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 1179 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 1180 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 1181 |
+
|
| 1182 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 1183 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 1184 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 1185 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 1186 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 1187 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 1188 |
+
"""
|
| 1189 |
+
|
| 1190 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1191 |
+
|
| 1192 |
+
transformer_outputs = self.transformer(
|
| 1193 |
+
input_ids,
|
| 1194 |
+
past_key_values=past_key_values,
|
| 1195 |
+
attention_mask=attention_mask,
|
| 1196 |
+
position_ids=position_ids,
|
| 1197 |
+
head_mask=head_mask,
|
| 1198 |
+
inputs_embeds=inputs_embeds,
|
| 1199 |
+
use_cache=use_cache,
|
| 1200 |
+
output_attentions=output_attentions,
|
| 1201 |
+
output_hidden_states=output_hidden_states,
|
| 1202 |
+
return_dict=return_dict,
|
| 1203 |
+
cache_position=cache_position,
|
| 1204 |
+
)
|
| 1205 |
+
hidden_states = transformer_outputs[0]
|
| 1206 |
+
|
| 1207 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1208 |
+
lm_logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1209 |
+
|
| 1210 |
+
loss = None
|
| 1211 |
+
if labels is not None:
|
| 1212 |
+
loss = self.loss_function(
|
| 1213 |
+
lm_logits,
|
| 1214 |
+
labels,
|
| 1215 |
+
vocab_size=self.config.vocab_size,
|
| 1216 |
+
**kwargs,
|
| 1217 |
+
)
|
| 1218 |
+
|
| 1219 |
+
if not return_dict:
|
| 1220 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
| 1221 |
+
return ((loss,) + output) if loss is not None else output
|
| 1222 |
+
|
| 1223 |
+
return CausalLMOutputWithCrossAttentions(
|
| 1224 |
+
loss=loss,
|
| 1225 |
+
logits=lm_logits,
|
| 1226 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 1227 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 1228 |
+
attentions=transformer_outputs.attentions,
|
| 1229 |
+
)
|
| 1230 |
+
|
| 1231 |
+
def _reorder_cache(
|
| 1232 |
+
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
| 1233 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
| 1234 |
+
"""
|
| 1235 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
| 1236 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
| 1237 |
+
beam_idx at every generation step.
|
| 1238 |
+
|
| 1239 |
+
Output shares the same memory storage as `past`.
|
| 1240 |
+
"""
|
| 1241 |
+
|
| 1242 |
+
# Get a copy of `beam_idx` on all the devices where we need those indices.
|
| 1243 |
+
device_to_beam_idx = {
|
| 1244 |
+
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
|
| 1245 |
+
}
|
| 1246 |
+
reordered_past = tuple(
|
| 1247 |
+
(
|
| 1248 |
+
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
| 1249 |
+
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
| 1250 |
+
)
|
| 1251 |
+
for layer_past in past
|
| 1252 |
+
)
|
| 1253 |
+
return reordered_past
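
# Illustrative sketch (not part of the upstream file): end-to-end generation with the class above. The
# checkpoint name is only an example; any Falcon checkpoint on the Hub works the same way.
#
#     from transformers import AutoTokenizer, FalconForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
#     model = FalconForCausalLM.from_pretrained("tiiuae/falcon-7b")
#     inputs = tokenizer("The Falcon architecture uses", return_tensors="pt")
#     generated = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(generated[0], skip_special_tokens=True))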
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
@add_start_docstrings(
    """
    The Falcon Model transformer with a sequence classification head on top (linear layer).

    [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    FALCON_START_DOCSTRING,
)
class FalconForSequenceClassification(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = FalconModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


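The pooling step in `FalconForSequenceClassification.forward` picks, per row, the rightmost token that is not `pad_token_id` via the `argmax` trick on `token_indices * non_pad_mask`. A small self-contained sketch of just that selection (the ids, pad id and label count below are made up for illustration):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],     # right-padded: last real token at index 2
                          [0, 0, 8, 9, 10]])   # left-padded: last real token at index 4
logits = torch.randn(2, 5, 3)                  # [batch, seq_len, num_labels]

non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)       # tensor([2, 4])

pooled_logits = logits[torch.arange(input_ids.shape[0]), last_non_pad_token]
print(pooled_logits.shape)      # torch.Size([2, 3]) -> one logit vector per sequence
```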
@add_start_docstrings(
    """
    Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    FALCON_START_DOCSTRING,
)
class FalconForTokenClassification(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = FalconModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FALCON_START_DOCSTRING,
)
class FalconForQuestionAnswering(FalconPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = FalconModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "FalconForCausalLM",
    "FalconModel",
    "FalconPreTrainedModel",
    "FalconForSequenceClassification",
    "FalconForTokenClassification",
    "FalconForQuestionAnswering",
]
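Downstream of `FalconForQuestionAnswering`, the returned `start_logits`/`end_logits` are typically turned into an answer span by taking the argmax of each head. A toy sketch of that post-processing only (random logits stand in for real model outputs; real pipelines also restrict candidates to the context tokens):

```python
import torch

torch.manual_seed(0)
seq_len = 12
start_logits = torch.randn(1, seq_len)   # stand-in for QuestionAnsweringModelOutput.start_logits
end_logits = torch.randn(1, seq_len)     # stand-in for QuestionAnsweringModelOutput.end_logits

start_index = int(start_logits.argmax(dim=-1))
end_index = int(end_logits.argmax(dim=-1))
# keep only well-formed spans (start before end)
if end_index < start_index:
    start_index, end_index = end_index, start_index
print(f"predicted answer token span: [{start_index}, {end_index}]")
```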
docs/transformers/build/lib/transformers/models/falcon_mamba/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_falcon_mamba import *
    from .modeling_falcon_mamba import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
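The `_LazyModule` indirection above means the heavy submodules are only imported the first time one of their attributes is touched; from a user's point of view the public names still resolve normally. A short usage sketch, assuming a recent `transformers` install that ships this module:

```python
# Both imports resolve through the lazy-module machinery; the modeling code is
# only loaded when the corresponding attribute is actually accessed.
from transformers import FalconMambaConfig            # re-exported at the top level
from transformers.models.falcon_mamba import configuration_falcon_mamba

config = FalconMambaConfig()
print(config.model_type)  # "falcon_mamba"
```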
docs/transformers/build/lib/transformers/models/falcon_mamba/configuration_falcon_mamba.py
ADDED
@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FALCONMAMBA configuration"""

import math

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class FalconMambaConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the FALCON_MAMBA
    [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50280):
            Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`FalconMambaModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the model.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            The id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 0):
            The id of the end of sentence token in the vocabulary.
        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
        use_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the convolution layer of the mixer block.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.1):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model.
        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`.
        time_step_scale (`float`, *optional*, defaults to 1.0):
            Scale used to scale `dt_proj.bias`.
        time_step_min (`float`, *optional*, defaults to 0.001):
            Minimum `time_step` used to bound `dt_proj.bias`.
        time_step_max (`float`, *optional*, defaults to 0.1):
            Maximum `time_step` used to bound `dt_proj.bias`.
        time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
            Init scheme used for `dt_proj.weight`. Should be one of `["random", "uniform"]`.
        time_step_floor (`float`, *optional*, defaults to 0.0001):
            Minimum clamping value of the `dt_proj.bias` layer initialization.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
            Whether or not to rescale `out_proj` weights when initializing.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the cache should be used.
        use_mambapy (`bool`, *optional*, defaults to `False`):
            Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
        mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
            The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
    Example:

    ```python
    >>> from transformers import FalconMambaConfig, FalconMambaModel

    >>> # Initializing a FalconMamba configuration
    >>> configuration = FalconMambaConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = FalconMambaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "falcon_mamba"

    def __init__(
        self,
        vocab_size=50280,
        hidden_size=768,
        state_size=16,
        num_hidden_layers=32,
        layer_norm_epsilon=1e-5,
        pad_token_id=0,
        bos_token_id=0,
        eos_token_id=0,
        expand=2,
        conv_kernel=4,
        use_bias=False,
        use_conv_bias=True,
        hidden_act="silu",
        initializer_range=0.1,
        residual_in_fp32=True,
        time_step_rank="auto",
        time_step_scale=1.0,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_init_scheme="random",
        time_step_floor=1e-4,
        rescale_prenorm_residual=False,
        use_cache=True,
        use_mambapy=False,
        mixer_rms_eps=1e-6,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.layer_norm_epsilon = layer_norm_epsilon
        self.conv_kernel = conv_kernel
        self.expand = expand
        self.intermediate_size = int(expand * self.hidden_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
        self.time_step_scale = time_step_scale
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_init_scheme = time_step_init_scheme
        self.time_step_floor = time_step_floor
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache
        self.use_mambapy = use_mambapy
        self.mixer_rms_eps = mixer_rms_eps

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)


__all__ = ["FalconMambaConfig"]
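Two fields in `FalconMambaConfig.__init__` are derived rather than stored verbatim: `intermediate_size = int(expand * hidden_size)` and, when `time_step_rank="auto"`, `time_step_rank = math.ceil(hidden_size / 16)`. A quick check with small, made-up sizes (assuming a `transformers` version that includes FalconMamba):

```python
from transformers import FalconMambaConfig

config = FalconMambaConfig(hidden_size=256, expand=2, time_step_rank="auto")
print(config.intermediate_size)  # 512 -> int(2 * 256)
print(config.time_step_rank)     # 16  -> math.ceil(256 / 16)
```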
docs/transformers/build/lib/transformers/models/falcon_mamba/modeling_falcon_mamba.py
ADDED
@@ -0,0 +1,873 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch FALCONMAMBA model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import MambaCache
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from .configuration_falcon_mamba import FalconMambaConfig


logger = logging.get_logger(__name__)

if is_mambapy_available():
    from mambapy.pscan import pscan
else:
    pscan = None

if is_mamba_ssm_available():
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update

    from ...kernels.falcon_mamba import mamba_inner_fn
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

_CHECKPOINT_FOR_DOC = "tiiuae/falcon-mamba-7b"
_CONFIG_FOR_DOC = "FalconMambaConfig"


def rms_forward(hidden_states, variance_epsilon=1e-6):
    """
    Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
    leverage this in order to multiply the final result with the RMSNorm weight

    Args:
        hidden_states (`torch.Tensor`):
            Hidden states to normalize
        variance_epsilon (`float`):
            The eps value to add in the square root scaling factor
    """
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)


class FalconMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see FalconMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between FalconMamba and the linear time invariant S4,
    and is why FalconMamba is called **selective** state spaces)
    """

    def __init__(self, config: FalconMambaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        self.use_mambapy = config.use_mambapy

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
        self.register_buffer(
            "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
        )
        self.register_buffer(
            "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
        )
        self.rms_eps = config.mixer_rms_eps

        if not is_fast_path_available:
            if self.use_mambapy:
                if is_mambapy_available():
                    logger.warning_once(
                        "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                        " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
                        " https://github.com/Dao-AILab/causal-conv1d"
                    )
                else:
                    raise ImportError(
                        "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
                    )
            else:
                logger.warning_once(
                    "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                    " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
                    " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
                )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                b_rms_weight=self.b_c_rms,
                c_rms_weight=self.b_c_rms,
                dt_rms_weight=self.dt_rms,
                b_c_dt_rms_eps=self.rms_eps,
            )

        else:
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 2. Convolution sequence transformation
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_position[0] > 0:
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 3. State Space Model sequence transformation
            # 3.a. input varying initialization of time_step, B and C
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )

            B = rms_forward(B, variance_epsilon=self.rms_eps)
            C = rms_forward(C, variance_epsilon=self.rms_eps)
            time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)

            # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
            # at the price of a small overhead.
            if hasattr(self.config, "_pre_quantization_dtype"):
                discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
            else:
                discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
            if cache_params is not None and cache_position[0] > 0:
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(self.layer_idx, ssm_state)

            # 4. Final linear projection
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states

    def slow_forward(
        self,
        input_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 2. Convolution sequence transformation
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            ssm_state = ssm_state.to(hidden_states.device)
            # use `cache_position.shape[0]` to check whether we are in prefill
            # stage, it's equivalent to check `cache_position[0] == 0`, which
            # breaks dynamo fullgraph constraints
            if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
                conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))

                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
                hidden_states = self.act(
                    self.conv1d(hidden_states)[..., :seq_len]
                )  # [batch, intermediate_size, seq_len]
            else:
                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
                conv_state = conv_state.to(self.conv1d.weight.device)
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = (
                    self.act(hidden_states).to(dtype).unsqueeze(-1)
                )  # [batch, intermediate_size, 1] : decoding
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )

        B = rms_forward(B, variance_epsilon=self.rms_eps)
        C = rms_forward(C, variance_epsilon=self.rms_eps)
        time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)

        discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
            1, 2
        )  # [batch, intermediate_size, seq_len]

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
        discrete_A = torch.exp(
            A[None, :, None, :] * discrete_time_step[:, :, :, None]
        )  # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_B = (
            discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        )  # [batch, intermediate_size, seq_len, ssm_state_size]
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        if self.use_mambapy and self.training and cache_params is None:
            hs = pscan(
                discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
            )  # [batch, seq_len, intermediate_size, ssm_state_size]
            scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)  # [batch, intermediate_size, seq_len]
            scan_output = scan_output + hidden_states * self.D[None, :, None]
            scan_output = scan_output * self.act(gate)
        else:
            scan_outputs = []
            for i in range(seq_len):
                ssm_state = (
                    discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
                )  # [batch, intermediate_size, ssm_state]
                scan_output = torch.matmul(
                    ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
                )  # [batch, intermediate_size, 1]
                scan_outputs.append(scan_output[:, :, 0])
            scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
            scan_output = scan_output + (hidden_states * self.D[None, :, None])
            scan_output = scan_output * self.act(gate)

            if cache_params is not None:
                cache_params.update_ssm_state(self.layer_idx, ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
        return contextualized_states

    # Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward
    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


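The sequential branch of `slow_forward` above is just the discretized SSM recurrence h_t = exp(Δ_t A) · h_{t-1} + (Δ_t B_t) · x_t followed by the readout y_t = C_t · h_t + D · x_t. A stripped-down sketch of that loop on a single layer with toy shapes and random values (the silu gate multiplication is omitted; nothing here comes from the diff itself):

```python
import torch

torch.manual_seed(0)
batch, d_inner, d_state, seq_len = 1, 4, 3, 6

x = torch.randn(batch, d_inner, seq_len)      # stand-in for the conv/activation output
delta = torch.rand(batch, d_inner, seq_len)   # stand-in for softplus(dt_proj(...))
A = -torch.rand(d_inner, d_state)             # corresponds to -exp(A_log)
B = torch.randn(batch, seq_len, d_state)      # input-dependent B
C = torch.randn(batch, seq_len, d_state)      # input-dependent C
D = torch.ones(d_inner)                       # skip connection weight

h = torch.zeros(batch, d_inner, d_state)
ys = []
for t in range(seq_len):
    discrete_A = torch.exp(A[None] * delta[:, :, t, None])            # [batch, d_inner, d_state]
    discrete_Bx = delta[:, :, t, None] * B[:, None, t, :] * x[:, :, t, None]
    h = discrete_A * h + discrete_Bx                                   # state update
    y = (h @ C[:, t, :].unsqueeze(-1)).squeeze(-1) + D * x[:, :, t]    # readout + skip
    ys.append(y)
out = torch.stack(ys, dim=-1)
print(out.shape)  # torch.Size([1, 4, 6]) -> [batch, d_inner, seq_len]
```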
# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
class FalconMambaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        FalconMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"

    # Ignore copy
    def forward(self, hidden_states):
        return self.weight.to(hidden_states.device) * rms_forward(
            hidden_states, variance_epsilon=self.variance_epsilon
        )


# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = FalconMambaMixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
        )
        hidden_states = residual + hidden_states
        return hidden_states


# Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba
class FalconMambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FalconMambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["FalconMambaBlock", "FalconMambaMixer"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, FalconMambaMixer):
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_proj.bias.copy_(inv_dt)
            module.dt_proj.bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(self.config.num_hidden_layers)


@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaOutput(ModelOutput):
    """
    Class for the FALCONMAMBA model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


FALCONMAMBA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`FalconMambaConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FALCONMAMBA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 586 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 587 |
+
|
| 588 |
+
[What are input IDs?](../glossary#input-ids)
|
| 589 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 590 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 591 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 592 |
+
model's internal embedding lookup matrix.
|
| 593 |
+
cache_params (`MambaCache`, *optional*):
|
| 594 |
+
If passed along, the model uses the previous state in all the blocks (which will give the output for the
|
| 595 |
+
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
|
| 596 |
+
use_cache (`bool`, *optional*):
|
| 597 |
+
If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
|
| 598 |
+
output_hidden_states (`bool`, *optional*):
|
| 599 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 600 |
+
more detail.
|
| 601 |
+
return_dict (`bool`, *optional*):
|
| 602 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 603 |
+
"""
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
@add_start_docstrings(
|
| 607 |
+
"The bare FALCONMAMBA Model transformer outputting raw hidden-states without any specific head on top.",
|
| 608 |
+
FALCONMAMBA_START_DOCSTRING,
|
| 609 |
+
)
|
| 610 |
+
class FalconMambaModel(FalconMambaPreTrainedModel):
|
| 611 |
+
def __init__(self, config):
|
| 612 |
+
super().__init__(config)
|
| 613 |
+
|
| 614 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 615 |
+
self.layers = nn.ModuleList(
|
| 616 |
+
[FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
self.gradient_checkpointing = False
|
| 620 |
+
self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 621 |
+
# Initialize weights and apply final processing
|
| 622 |
+
self.post_init()
|
| 623 |
+
|
| 624 |
+
def get_input_embeddings(self):
|
| 625 |
+
return self.embeddings
|
| 626 |
+
|
| 627 |
+
def set_input_embeddings(self, new_embeddings):
|
| 628 |
+
self.embeddings = new_embeddings
|
| 629 |
+
|
| 630 |
+
@add_start_docstrings_to_model_forward(FALCONMAMBA_INPUTS_DOCSTRING)
|
| 631 |
+
@add_code_sample_docstrings(
|
| 632 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 633 |
+
output_type=FalconMambaOutput,
|
| 634 |
+
config_class=_CONFIG_FOR_DOC,
|
| 635 |
+
)
|
| 636 |
+
def forward(
|
| 637 |
+
self,
|
| 638 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 639 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
| 640 |
+
cache_params: Optional[MambaCache] = None,
|
| 641 |
+
use_cache: Optional[bool] = None,
|
| 642 |
+
output_hidden_states: Optional[bool] = None,
|
| 643 |
+
return_dict: Optional[bool] = None,
|
| 644 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 645 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 646 |
+
) -> Union[Tuple, FalconMambaOutput]:
|
| 647 |
+
output_hidden_states = (
|
| 648 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 649 |
+
)
|
| 650 |
+
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
| 651 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 652 |
+
|
| 653 |
+
if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
|
| 654 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 655 |
+
|
| 656 |
+
if inputs_embeds is None:
|
| 657 |
+
inputs_embeds = self.embeddings(input_ids)
|
| 658 |
+
|
| 659 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 660 |
+
use_cache = False
|
| 661 |
+
|
| 662 |
+
if use_cache:
|
| 663 |
+
if cache_params is None:
|
| 664 |
+
cache_params = MambaCache(
|
| 665 |
+
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
| 666 |
+
)
|
| 667 |
+
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
|
| 668 |
+
elif cache_position is None:
|
| 669 |
+
# cases when we do manual forward instead of using `model.generate` which will initiate
|
| 670 |
+
# `cache_position` and makes sure it is not None, throw error here instead of doing some
|
| 671 |
+
# hack to conjecture the current cache position
|
| 672 |
+
raise ValueError(
|
| 673 |
+
"You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
|
| 674 |
+
"you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
|
| 675 |
+
"be initialized for you automatically"
|
| 676 |
+
)
|
| 677 |
+
else:
|
| 678 |
+
cache_params = None
|
| 679 |
+
hidden_states = inputs_embeds
|
| 680 |
+
all_hidden_states = () if output_hidden_states else None
|
| 681 |
+
for mixer_block in self.layers:
|
| 682 |
+
if self.gradient_checkpointing and self.training:
|
| 683 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 684 |
+
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
|
| 685 |
+
)
|
| 686 |
+
else:
|
| 687 |
+
hidden_states = mixer_block(
|
| 688 |
+
hidden_states,
|
| 689 |
+
cache_params=cache_params,
|
| 690 |
+
cache_position=cache_position,
|
| 691 |
+
attention_mask=attention_mask,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
if output_hidden_states:
|
| 695 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 696 |
+
|
| 697 |
+
hidden_states = self.norm_f(hidden_states)
|
| 698 |
+
|
| 699 |
+
if output_hidden_states:
|
| 700 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 701 |
+
|
| 702 |
+
if not return_dict:
|
| 703 |
+
return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
|
| 704 |
+
|
| 705 |
+
return FalconMambaOutput(
|
| 706 |
+
last_hidden_state=hidden_states,
|
| 707 |
+
cache_params=cache_params if use_cache else None,
|
| 708 |
+
hidden_states=all_hidden_states,
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
@add_start_docstrings(
|
| 713 |
+
"""
|
| 714 |
+
The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 715 |
+
embeddings).
|
| 716 |
+
""",
|
| 717 |
+
FALCONMAMBA_START_DOCSTRING,
|
| 718 |
+
)
|
| 719 |
+
# Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache
|
| 720 |
+
class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
| 721 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 722 |
+
|
| 723 |
+
def __init__(self, config):
|
| 724 |
+
super().__init__(config)
|
| 725 |
+
self.backbone = FalconMambaModel(config)
|
| 726 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 727 |
+
# Initialize weights and apply final processing
|
| 728 |
+
self.post_init()
|
| 729 |
+
|
| 730 |
+
def get_output_embeddings(self):
|
| 731 |
+
return self.lm_head
|
| 732 |
+
|
| 733 |
+
def set_output_embeddings(self, new_embeddings):
|
| 734 |
+
self.lm_head = new_embeddings
|
| 735 |
+
|
| 736 |
+
def get_input_embeddings(self):
|
| 737 |
+
return self.backbone.get_input_embeddings()
|
| 738 |
+
|
| 739 |
+
def set_input_embeddings(self, new_embeddings):
|
| 740 |
+
return self.backbone.set_input_embeddings(new_embeddings)
|
| 741 |
+
|
| 742 |
+
def _update_model_kwargs_for_generation(
|
| 743 |
+
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
|
| 744 |
+
) -> Dict[str, Any]:
|
| 745 |
+
model_kwargs["cache_params"] = outputs.get("cache_params", None)
|
| 746 |
+
if (
|
| 747 |
+
model_kwargs.get("use_cache", True)
|
| 748 |
+
and "cache_position" in model_kwargs
|
| 749 |
+
and model_kwargs["cache_position"] is not None
|
| 750 |
+
):
|
| 751 |
+
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
| 752 |
+
|
| 753 |
+
if "attention_mask" in model_kwargs:
|
| 754 |
+
attention_mask = model_kwargs["attention_mask"]
|
| 755 |
+
model_kwargs["attention_mask"] = torch.cat(
|
| 756 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
return model_kwargs
|
| 760 |
+
|
| 761 |
+
def prepare_inputs_for_generation(
|
| 762 |
+
self,
|
| 763 |
+
input_ids,
|
| 764 |
+
inputs_embeds=None,
|
| 765 |
+
use_cache=None,
|
| 766 |
+
cache_params: Optional[MambaCache] = None,
|
| 767 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 768 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 769 |
+
**kwargs,
|
| 770 |
+
):
|
| 771 |
+
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
|
| 772 |
+
|
| 773 |
+
if use_cache:
|
| 774 |
+
# `cache_position` should have been initialized in `generate`
|
| 775 |
+
if cache_position is None:
|
| 776 |
+
raise ValueError(
|
| 777 |
+
"`cache_position` should not be None as it should have been initialized in "
|
| 778 |
+
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
|
| 779 |
+
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
|
| 780 |
+
)
|
| 781 |
+
if cache_position[0] > 0:
|
| 782 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 783 |
+
|
| 784 |
+
if attention_mask is not None:
|
| 785 |
+
attention_mask = None
|
| 786 |
+
|
| 787 |
+
else:
|
| 788 |
+
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
| 789 |
+
# considering padding will be applied when input length is shorter, and truncation
|
| 790 |
+
# will be applied when it is longer, so it will be equivalent to always have it match
|
| 791 |
+
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
| 792 |
+
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
|
| 793 |
+
|
| 794 |
+
if inputs_embeds is not None and cache_params is None:
|
| 795 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 796 |
+
else:
|
| 797 |
+
model_inputs = {"input_ids": input_ids.contiguous()}
|
| 798 |
+
|
| 799 |
+
model_inputs.update(
|
| 800 |
+
{
|
| 801 |
+
"cache_params": cache_params,
|
| 802 |
+
"use_cache": use_cache,
|
| 803 |
+
"cache_position": cache_position,
|
| 804 |
+
"attention_mask": attention_mask,
|
| 805 |
+
}
|
| 806 |
+
)
|
| 807 |
+
return model_inputs
|
| 808 |
+
|
| 809 |
+
@add_start_docstrings_to_model_forward(FALCONMAMBA_INPUTS_DOCSTRING)
|
| 810 |
+
@add_code_sample_docstrings(
|
| 811 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 812 |
+
output_type=FalconMambaCausalLMOutput,
|
| 813 |
+
config_class=_CONFIG_FOR_DOC,
|
| 814 |
+
)
|
| 815 |
+
def forward(
|
| 816 |
+
self,
|
| 817 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 818 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 819 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 820 |
+
cache_params: Optional[MambaCache] = None,
|
| 821 |
+
labels: Optional[torch.LongTensor] = None,
|
| 822 |
+
output_hidden_states: Optional[bool] = None,
|
| 823 |
+
return_dict: Optional[bool] = None,
|
| 824 |
+
use_cache: Optional[bool] = None,
|
| 825 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 826 |
+
**kwargs, # for now we need this for generation
|
| 827 |
+
) -> Union[Tuple, FalconMambaCausalLMOutput]:
|
| 828 |
+
r"""
|
| 829 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 830 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 831 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 832 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 833 |
+
"""
|
| 834 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 835 |
+
|
| 836 |
+
falcon_mamba_outputs = self.backbone(
|
| 837 |
+
input_ids,
|
| 838 |
+
cache_params=cache_params,
|
| 839 |
+
inputs_embeds=inputs_embeds,
|
| 840 |
+
output_hidden_states=output_hidden_states,
|
| 841 |
+
return_dict=return_dict,
|
| 842 |
+
use_cache=use_cache,
|
| 843 |
+
cache_position=cache_position,
|
| 844 |
+
attention_mask=attention_mask,
|
| 845 |
+
)
|
| 846 |
+
hidden_states = falcon_mamba_outputs[0]
|
| 847 |
+
|
| 848 |
+
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
|
| 849 |
+
|
| 850 |
+
loss = None
|
| 851 |
+
if labels is not None:
|
| 852 |
+
# move labels to correct device to enable model parallelism
|
| 853 |
+
labels = labels.to(logits.device)
|
| 854 |
+
# Shift so that tokens < n predict n
|
| 855 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 856 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 857 |
+
# Flatten the tokens
|
| 858 |
+
loss_fct = CrossEntropyLoss()
|
| 859 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 860 |
+
|
| 861 |
+
if not return_dict:
|
| 862 |
+
output = (logits,) + falcon_mamba_outputs[1:]
|
| 863 |
+
return ((loss,) + output) if loss is not None else output
|
| 864 |
+
|
| 865 |
+
return FalconMambaCausalLMOutput(
|
| 866 |
+
loss=loss,
|
| 867 |
+
logits=logits,
|
| 868 |
+
cache_params=falcon_mamba_outputs.cache_params,
|
| 869 |
+
hidden_states=falcon_mamba_outputs.hidden_states,
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"]
|
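For orientation, here is a minimal usage sketch of the FalconMamba classes defined above, assuming a transformers build that contains this file; the checkpoint id `tiiuae/falcon-mamba-7b`, the prompt, and the generation length are illustrative placeholders rather than values taken from this diff.

import torch
from transformers import AutoTokenizer, FalconMambaForCausalLM

# Hypothetical checkpoint id; any FalconMamba checkpoint would work the same way.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", torch_dtype=torch.bfloat16)

inputs = tokenizer("The FalconMamba architecture is", return_tensors="pt")
# `generate` exercises `prepare_inputs_for_generation` and `_update_model_kwargs_for_generation`
# above: `cache_params` and `cache_position` are threaded between steps instead of `past_key_values`.
output_ids = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))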
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_fastspeech2_conformer import *
    from .modeling_fastspeech2_conformer import *
    from .tokenization_fastspeech2_conformer import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
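As a small sketch of what this lazy `__init__` provides: importing a public symbol only materializes the real submodule on first attribute access. The class used below is defined in the sibling configuration module of this folder.

from transformers.models.fastspeech2_conformer import FastSpeech2ConformerConfig

config = FastSpeech2ConformerConfig()  # first access triggers the actual import through _LazyModule
print(config.model_type)  # "fastspeech2_conformer"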
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py
ADDED
@@ -0,0 +1,480 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FastSpeech2Conformer model configuration"""

from typing import Dict

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class FastSpeech2ConformerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FastSpeech2ConformerModel`]. It is used to
    instantiate a FastSpeech2Conformer model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the
    FastSpeech2Conformer [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 384):
            The dimensionality of the hidden layers.
        vocab_size (`int`, *optional*, defaults to 78):
            The size of the vocabulary.
        num_mel_bins (`int`, *optional*, defaults to 80):
            The number of mel filters used in the filter bank.
        encoder_num_attention_heads (`int`, *optional*, defaults to 2):
            The number of attention heads in the encoder.
        encoder_layers (`int`, *optional*, defaults to 4):
            The number of layers in the encoder.
        encoder_linear_units (`int`, *optional*, defaults to 1536):
            The number of units in the linear layer of the encoder.
        decoder_layers (`int`, *optional*, defaults to 4):
            The number of layers in the decoder.
        decoder_num_attention_heads (`int`, *optional*, defaults to 2):
            The number of attention heads in the decoder.
        decoder_linear_units (`int`, *optional*, defaults to 1536):
            The number of units in the linear layer of the decoder.
        speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
            The number of layers in the post-net of the speech decoder.
        speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
            The number of units in the post-net layers of the speech decoder.
        speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
            The kernel size in the post-net of the speech decoder.
        positionwise_conv_kernel_size (`int`, *optional*, defaults to 3):
            The size of the convolution kernel used in the position-wise layer.
        encoder_normalize_before (`bool`, *optional*, defaults to `False`):
            Specifies whether to normalize before encoder layers.
        decoder_normalize_before (`bool`, *optional*, defaults to `False`):
            Specifies whether to normalize before decoder layers.
        encoder_concat_after (`bool`, *optional*, defaults to `False`):
            Specifies whether to concatenate after encoder layers.
        decoder_concat_after (`bool`, *optional*, defaults to `False`):
            Specifies whether to concatenate after decoder layers.
        reduction_factor (`int`, *optional*, defaults to 1):
            The factor by which the speech frame rate is reduced.
        speaking_speed (`float`, *optional*, defaults to 1.0):
            The speed of the speech produced.
        use_macaron_style_in_conformer (`bool`, *optional*, defaults to `True`):
            Specifies whether to use macaron style in the conformer.
        use_cnn_in_conformer (`bool`, *optional*, defaults to `True`):
            Specifies whether to use convolutional neural networks in the conformer.
        encoder_kernel_size (`int`, *optional*, defaults to 7):
            The kernel size used in the encoder.
        decoder_kernel_size (`int`, *optional*, defaults to 31):
            The kernel size used in the decoder.
        duration_predictor_layers (`int`, *optional*, defaults to 2):
            The number of layers in the duration predictor.
        duration_predictor_channels (`int`, *optional*, defaults to 256):
            The number of channels in the duration predictor.
        duration_predictor_kernel_size (`int`, *optional*, defaults to 3):
            The kernel size used in the duration predictor.
        energy_predictor_layers (`int`, *optional*, defaults to 2):
            The number of layers in the energy predictor.
        energy_predictor_channels (`int`, *optional*, defaults to 256):
            The number of channels in the energy predictor.
        energy_predictor_kernel_size (`int`, *optional*, defaults to 3):
            The kernel size used in the energy predictor.
        energy_predictor_dropout (`float`, *optional*, defaults to 0.5):
            The dropout rate in the energy predictor.
        energy_embed_kernel_size (`int`, *optional*, defaults to 1):
            The kernel size used in the energy embed layer.
        energy_embed_dropout (`float`, *optional*, defaults to 0.0):
            The dropout rate in the energy embed layer.
        stop_gradient_from_energy_predictor (`bool`, *optional*, defaults to `False`):
            Specifies whether to stop gradients from the energy predictor.
        pitch_predictor_layers (`int`, *optional*, defaults to 5):
            The number of layers in the pitch predictor.
        pitch_predictor_channels (`int`, *optional*, defaults to 256):
            The number of channels in the pitch predictor.
        pitch_predictor_kernel_size (`int`, *optional*, defaults to 5):
            The kernel size used in the pitch predictor.
        pitch_predictor_dropout (`float`, *optional*, defaults to 0.5):
            The dropout rate in the pitch predictor.
        pitch_embed_kernel_size (`int`, *optional*, defaults to 1):
            The kernel size used in the pitch embed layer.
        pitch_embed_dropout (`float`, *optional*, defaults to 0.0):
            The dropout rate in the pitch embed layer.
        stop_gradient_from_pitch_predictor (`bool`, *optional*, defaults to `True`):
            Specifies whether to stop gradients from the pitch predictor.
        encoder_dropout_rate (`float`, *optional*, defaults to 0.2):
            The dropout rate in the encoder.
        encoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
            The positional dropout rate in the encoder.
        encoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
            The attention dropout rate in the encoder.
        decoder_dropout_rate (`float`, *optional*, defaults to 0.2):
            The dropout rate in the decoder.
        decoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
            The positional dropout rate in the decoder.
        decoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
            The attention dropout rate in the decoder.
        duration_predictor_dropout_rate (`float`, *optional*, defaults to 0.2):
            The dropout rate in the duration predictor.
        speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
            The dropout rate in the speech decoder postnet.
        max_source_positions (`int`, *optional*, defaults to 5000):
            if `"relative"` position embeddings are used, defines the maximum source input positions.
        use_masking (`bool`, *optional*, defaults to `True`):
            Specifies whether to use masking in the model.
        use_weighted_masking (`bool`, *optional*, defaults to `False`):
            Specifies whether to use weighted masking in the model.
        num_speakers (`int`, *optional*):
            Number of speakers. If set to > 1, assume that the speaker ids will be provided as the input and use
            speaker id embedding layer.
        num_languages (`int`, *optional*):
            Number of languages. If set to > 1, assume that the language ids will be provided as the input and use the
            language id embedding layer.
        speaker_embed_dim (`int`, *optional*):
            Speaker embedding dimension. If set to > 0, assume that speaker_embedding will be provided as the input.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Specifies whether the model is an encoder-decoder.

    Example:

    ```python
    >>> from transformers import FastSpeech2ConformerModel, FastSpeech2ConformerConfig

    >>> # Initializing a FastSpeech2Conformer style configuration
    >>> configuration = FastSpeech2ConformerConfig()

    >>> # Initializing a model from the FastSpeech2Conformer style configuration
    >>> model = FastSpeech2ConformerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "fastspeech2_conformer"
    base_config_key = "model_config"
    attribute_map = {"num_hidden_layers": "encoder_layers", "num_attention_heads": "encoder_num_attention_heads"}

    def __init__(
        self,
        hidden_size=384,
        vocab_size=78,
        num_mel_bins=80,
        encoder_num_attention_heads=2,
        encoder_layers=4,
        encoder_linear_units=1536,
        decoder_layers=4,
        decoder_num_attention_heads=2,
        decoder_linear_units=1536,
        speech_decoder_postnet_layers=5,
        speech_decoder_postnet_units=256,
        speech_decoder_postnet_kernel=5,
        positionwise_conv_kernel_size=3,
        encoder_normalize_before=False,
        decoder_normalize_before=False,
        encoder_concat_after=False,
        decoder_concat_after=False,
        reduction_factor=1,
        speaking_speed=1.0,
        use_macaron_style_in_conformer=True,
        use_cnn_in_conformer=True,
        encoder_kernel_size=7,
        decoder_kernel_size=31,
        duration_predictor_layers=2,
        duration_predictor_channels=256,
        duration_predictor_kernel_size=3,
        energy_predictor_layers=2,
        energy_predictor_channels=256,
        energy_predictor_kernel_size=3,
        energy_predictor_dropout=0.5,
        energy_embed_kernel_size=1,
        energy_embed_dropout=0.0,
        stop_gradient_from_energy_predictor=False,
        pitch_predictor_layers=5,
        pitch_predictor_channels=256,
        pitch_predictor_kernel_size=5,
        pitch_predictor_dropout=0.5,
        pitch_embed_kernel_size=1,
        pitch_embed_dropout=0.0,
        stop_gradient_from_pitch_predictor=True,
        encoder_dropout_rate=0.2,
        encoder_positional_dropout_rate=0.2,
        encoder_attention_dropout_rate=0.2,
        decoder_dropout_rate=0.2,
        decoder_positional_dropout_rate=0.2,
        decoder_attention_dropout_rate=0.2,
        duration_predictor_dropout_rate=0.2,
        speech_decoder_postnet_dropout=0.5,
        max_source_positions=5000,
        use_masking=True,
        use_weighted_masking=False,
        num_speakers=None,
        num_languages=None,
        speaker_embed_dim=None,
        is_encoder_decoder=True,
        **kwargs,
    ):
        if positionwise_conv_kernel_size % 2 == 0:
            raise ValueError(
                f"positionwise_conv_kernel_size must be odd, but got {positionwise_conv_kernel_size} instead."
            )
        if encoder_kernel_size % 2 == 0:
            raise ValueError(f"encoder_kernel_size must be odd, but got {encoder_kernel_size} instead.")
        if decoder_kernel_size % 2 == 0:
            raise ValueError(f"decoder_kernel_size must be odd, but got {decoder_kernel_size} instead.")
        if duration_predictor_kernel_size % 2 == 0:
            raise ValueError(
                f"duration_predictor_kernel_size must be odd, but got {duration_predictor_kernel_size} instead."
            )
        if energy_predictor_kernel_size % 2 == 0:
            raise ValueError(
                f"energy_predictor_kernel_size must be odd, but got {energy_predictor_kernel_size} instead."
            )
        if energy_embed_kernel_size % 2 == 0:
            raise ValueError(f"energy_embed_kernel_size must be odd, but got {energy_embed_kernel_size} instead.")
        if pitch_predictor_kernel_size % 2 == 0:
            raise ValueError(
                f"pitch_predictor_kernel_size must be odd, but got {pitch_predictor_kernel_size} instead."
            )
        if pitch_embed_kernel_size % 2 == 0:
            raise ValueError(f"pitch_embed_kernel_size must be odd, but got {pitch_embed_kernel_size} instead.")
        if hidden_size % encoder_num_attention_heads != 0:
            raise ValueError("The hidden_size must be evenly divisible by encoder_num_attention_heads.")
        if hidden_size % decoder_num_attention_heads != 0:
            raise ValueError("The hidden_size must be evenly divisible by decoder_num_attention_heads.")
        if use_masking and use_weighted_masking:
            raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.")

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_mel_bins = num_mel_bins
        self.encoder_config = {
            "num_attention_heads": encoder_num_attention_heads,
            "layers": encoder_layers,
            "kernel_size": encoder_kernel_size,
            "attention_dropout_rate": encoder_attention_dropout_rate,
            "dropout_rate": encoder_dropout_rate,
            "positional_dropout_rate": encoder_positional_dropout_rate,
            "linear_units": encoder_linear_units,
            "normalize_before": encoder_normalize_before,
            "concat_after": encoder_concat_after,
        }
        self.decoder_config = {
            "num_attention_heads": decoder_num_attention_heads,
            "layers": decoder_layers,
            "kernel_size": decoder_kernel_size,
            "attention_dropout_rate": decoder_attention_dropout_rate,
            "dropout_rate": decoder_dropout_rate,
            "positional_dropout_rate": decoder_positional_dropout_rate,
            "linear_units": decoder_linear_units,
            "normalize_before": decoder_normalize_before,
            "concat_after": decoder_concat_after,
        }
        self.encoder_num_attention_heads = encoder_num_attention_heads
        self.encoder_layers = encoder_layers
        self.duration_predictor_channels = duration_predictor_channels
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_layers = duration_predictor_layers
        self.energy_embed_dropout = energy_embed_dropout
        self.energy_embed_kernel_size = energy_embed_kernel_size
        self.energy_predictor_channels = energy_predictor_channels
        self.energy_predictor_dropout = energy_predictor_dropout
        self.energy_predictor_kernel_size = energy_predictor_kernel_size
        self.energy_predictor_layers = energy_predictor_layers
        self.pitch_embed_dropout = pitch_embed_dropout
        self.pitch_embed_kernel_size = pitch_embed_kernel_size
        self.pitch_predictor_channels = pitch_predictor_channels
        self.pitch_predictor_dropout = pitch_predictor_dropout
        self.pitch_predictor_kernel_size = pitch_predictor_kernel_size
        self.pitch_predictor_layers = pitch_predictor_layers
        self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
        self.speech_decoder_postnet_units = speech_decoder_postnet_units
        self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout
        self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel
        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
        self.reduction_factor = reduction_factor
        self.speaking_speed = speaking_speed
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.max_source_positions = max_source_positions
        self.use_cnn_in_conformer = use_cnn_in_conformer
        self.use_macaron_style_in_conformer = use_macaron_style_in_conformer
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.speaker_embed_dim = speaker_embed_dim
        self.duration_predictor_dropout_rate = duration_predictor_dropout_rate
        self.is_encoder_decoder = is_encoder_decoder

        super().__init__(
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )


class FastSpeech2ConformerHifiGanConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FastSpeech2ConformerHifiGanModel`]. It is used to
    instantiate a FastSpeech2Conformer HiFi-GAN vocoder model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
    FastSpeech2Conformer
    [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        model_in_dim (`int`, *optional*, defaults to 80):
            The number of frequency bins in the input log-mel spectrogram.
        upsample_initial_channel (`int`, *optional*, defaults to 512):
            The number of input channels into the upsampling network.
        upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 2, 2]`):
            A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The
            length of *upsample_rates* defines the number of convolutional layers and has to match the length of
            *upsample_kernel_sizes*.
        upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[16, 16, 4, 4]`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The
            length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of
            *upsample_rates*.
        resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
            A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field
            fusion (MRF) module.
        resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
            multi-receptive field fusion (MRF) module.
        initializer_range (`float`, *optional*, defaults to 0.01):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        leaky_relu_slope (`float`, *optional*, defaults to 0.1):
            The angle of the negative slope used by the leaky ReLU activation.
        normalize_before (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance.

    Example:

    ```python
    >>> from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig

    >>> # Initializing a FastSpeech2ConformerHifiGan configuration
    >>> configuration = FastSpeech2ConformerHifiGanConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = FastSpeech2ConformerHifiGan(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "hifigan"
    base_config_key = "vocoder_config"

    def __init__(
        self,
        model_in_dim=80,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        initializer_range=0.01,
        leaky_relu_slope=0.1,
        normalize_before=True,
        **kwargs,
    ):
        self.model_in_dim = model_in_dim
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.initializer_range = initializer_range
        self.leaky_relu_slope = leaky_relu_slope
        self.normalize_before = normalize_before
        super().__init__(**kwargs)


class FastSpeech2ConformerWithHifiGanConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`FastSpeech2ConformerWithHifiGan`]. It is used to
    instantiate a `FastSpeech2ConformerWithHifiGanModel` model according to the specified sub-models configurations,
    defining the model architecture.

    Instantiating a configuration with the defaults will yield a similar configuration to that of the
    FastSpeech2ConformerModel [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer) and
    FastSpeech2ConformerHifiGan
    [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architectures.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        model_config (`typing.Dict`, *optional*):
            Configuration of the text-to-speech model.
        vocoder_config (`typing.Dict`, *optional*):
            Configuration of the vocoder model.
        model_config ([`FastSpeech2ConformerConfig`], *optional*):
            Configuration of the text-to-speech model.
        vocoder_config ([`FastSpeech2ConformerHiFiGanConfig`], *optional*):
            Configuration of the vocoder model.

    Example:

    ```python
    >>> from transformers import (
    ...     FastSpeech2ConformerConfig,
    ...     FastSpeech2ConformerHifiGanConfig,
    ...     FastSpeech2ConformerWithHifiGanConfig,
    ...     FastSpeech2ConformerWithHifiGan,
    ... )

    >>> # Initializing FastSpeech2ConformerWithHifiGan sub-modules configurations.
    >>> model_config = FastSpeech2ConformerConfig()
    >>> vocoder_config = FastSpeech2ConformerHifiGanConfig()

    >>> # Initializing a FastSpeech2ConformerWithHifiGan module style configuration
    >>> configuration = FastSpeech2ConformerWithHifiGanConfig(model_config.to_dict(), vocoder_config.to_dict())

    >>> # Initializing a model (with random weights)
    >>> model = FastSpeech2ConformerWithHifiGan(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "fastspeech2_conformer_with_hifigan"
    sub_configs = {"model_config": FastSpeech2ConformerConfig, "vocoder_config": FastSpeech2ConformerHifiGanConfig}

    def __init__(
        self,
        model_config: Dict = None,
        vocoder_config: Dict = None,
        **kwargs,
    ):
        if model_config is None:
            model_config = {}
            logger.info("model_config is None. initializing the model with default values.")

        if vocoder_config is None:
            vocoder_config = {}
            logger.info("vocoder_config is None. initializing the coarse model with default values.")

        self.model_config = FastSpeech2ConformerConfig(**model_config)
        self.vocoder_config = FastSpeech2ConformerHifiGanConfig(**vocoder_config)

        super().__init__(**kwargs)


__all__ = ["FastSpeech2ConformerConfig", "FastSpeech2ConformerHifiGanConfig", "FastSpeech2ConformerWithHifiGanConfig"]
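A brief sketch of how the odd-kernel-size validation in `FastSpeech2ConformerConfig.__init__` behaves; the even value 8 below is an assumed input chosen only to trigger the check.

from transformers import FastSpeech2ConformerConfig

FastSpeech2ConformerConfig(encoder_kernel_size=7)  # odd kernel sizes pass validation

try:
    FastSpeech2ConformerConfig(encoder_kernel_size=8)  # even kernel size, assumed for illustration
except ValueError as err:
    print(err)  # "encoder_kernel_size must be odd, but got 8 instead."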
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch.py
ADDED
@@ -0,0 +1,210 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert FastSpeech2Conformer checkpoint."""

import argparse
import json
import re
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
import yaml

from transformers import (
    FastSpeech2ConformerConfig,
    FastSpeech2ConformerModel,
    FastSpeech2ConformerTokenizer,
    logging,
)


logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.FastSpeech2Conformer")

CONFIG_MAPPING = {
    "adim": "hidden_size",
    "aheads": "num_attention_heads",
    "conformer_dec_kernel_size": "decoder_kernel_size",
    "conformer_enc_kernel_size": "encoder_kernel_size",
    "decoder_normalize_before": "decoder_normalize_before",
    "dlayers": "decoder_layers",
    "dunits": "decoder_linear_units",
    "duration_predictor_chans": "duration_predictor_channels",
    "duration_predictor_kernel_size": "duration_predictor_kernel_size",
    "duration_predictor_layers": "duration_predictor_layers",
    "elayers": "encoder_layers",
    "encoder_normalize_before": "encoder_normalize_before",
    "energy_embed_dropout": "energy_embed_dropout",
    "energy_embed_kernel_size": "energy_embed_kernel_size",
    "energy_predictor_chans": "energy_predictor_channels",
    "energy_predictor_dropout": "energy_predictor_dropout",
    "energy_predictor_kernel_size": "energy_predictor_kernel_size",
    "energy_predictor_layers": "energy_predictor_layers",
    "eunits": "encoder_linear_units",
    "pitch_embed_dropout": "pitch_embed_dropout",
    "pitch_embed_kernel_size": "pitch_embed_kernel_size",
    "pitch_predictor_chans": "pitch_predictor_channels",
    "pitch_predictor_dropout": "pitch_predictor_dropout",
    "pitch_predictor_kernel_size": "pitch_predictor_kernel_size",
    "pitch_predictor_layers": "pitch_predictor_layers",
    "positionwise_conv_kernel_size": "positionwise_conv_kernel_size",
    "postnet_chans": "speech_decoder_postnet_units",
    "postnet_filts": "speech_decoder_postnet_kernel",
    "postnet_layers": "speech_decoder_postnet_layers",
    "reduction_factor": "reduction_factor",
    "stop_gradient_from_energy_predictor": "stop_gradient_from_energy_predictor",
    "stop_gradient_from_pitch_predictor": "stop_gradient_from_pitch_predictor",
    "transformer_dec_attn_dropout_rate": "decoder_attention_dropout_rate",
    "transformer_dec_dropout_rate": "decoder_dropout_rate",
    "transformer_dec_positional_dropout_rate": "decoder_positional_dropout_rate",
    "transformer_enc_attn_dropout_rate": "encoder_attention_dropout_rate",
    "transformer_enc_dropout_rate": "encoder_dropout_rate",
    "transformer_enc_positional_dropout_rate": "encoder_positional_dropout_rate",
    "use_cnn_in_conformer": "use_cnn_in_conformer",
    "use_macaron_style_in_conformer": "use_macaron_style_in_conformer",
    "use_masking": "use_masking",
    "use_weighted_masking": "use_weighted_masking",
    "idim": "input_dim",
    "odim": "num_mel_bins",
    "spk_embed_dim": "speaker_embed_dim",
    "langs": "num_languages",
    "spks": "num_speakers",
}


def remap_model_yaml_config(yaml_config_path):
    with Path(yaml_config_path).open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
        args = argparse.Namespace(**args)

    remapped_config = {}

    model_params = args.tts_conf["text2mel_params"]
    # espnet_config_key -> hf_config_key, any keys not included are ignored
    for espnet_config_key, hf_config_key in CONFIG_MAPPING.items():
        if espnet_config_key in model_params:
            remapped_config[hf_config_key] = model_params[espnet_config_key]

    return remapped_config, args.g2p, args.token_list


def convert_espnet_state_dict_to_hf(state_dict):
    new_state_dict = {}
    for key in state_dict:
        if "tts.generator.text2mel." in key:
            new_key = key.replace("tts.generator.text2mel.", "")
            if "postnet" in key:
                new_key = new_key.replace("postnet.postnet", "speech_decoder_postnet.layers")
                new_key = new_key.replace(".0.weight", ".conv.weight")
                new_key = new_key.replace(".1.weight", ".batch_norm.weight")
                new_key = new_key.replace(".1.bias", ".batch_norm.bias")
                new_key = new_key.replace(".1.running_mean", ".batch_norm.running_mean")
                new_key = new_key.replace(".1.running_var", ".batch_norm.running_var")
                new_key = new_key.replace(".1.num_batches_tracked", ".batch_norm.num_batches_tracked")
            if "feat_out" in key:
                if "weight" in key:
                    new_key = "speech_decoder_postnet.feat_out.weight"
                if "bias" in key:
                    new_key = "speech_decoder_postnet.feat_out.bias"
            if "encoder.embed.0.weight" in key:
                new_key = new_key.replace("0.", "")
            if "w_1" in key:
                new_key = new_key.replace("w_1", "conv1")
            if "w_2" in key:
                new_key = new_key.replace("w_2", "conv2")
            if "predictor.conv" in key:
                new_key = new_key.replace(".conv", ".conv_layers")
                pattern = r"(\d)\.(\d)"
                replacement = (
                    r"\1.conv" if ("2.weight" not in new_key) and ("2.bias" not in new_key) else r"\1.layer_norm"
                )
                new_key = re.sub(pattern, replacement, new_key)
            if "pitch_embed" in key or "energy_embed" in key:
                new_key = new_key.replace("0", "conv")
            if "encoders" in key:
                new_key = new_key.replace("encoders", "conformer_layers")
                new_key = new_key.replace("norm_final", "final_layer_norm")
                new_key = new_key.replace("norm_mha", "self_attn_layer_norm")
                new_key = new_key.replace("norm_ff_macaron", "ff_macaron_layer_norm")
                new_key = new_key.replace("norm_ff", "ff_layer_norm")
                new_key = new_key.replace("norm_conv", "conv_layer_norm")
            if "lid_emb" in key:
                new_key = new_key.replace("lid_emb", "language_id_embedding")
            if "sid_emb" in key:
                new_key = new_key.replace("sid_emb", "speaker_id_embedding")

            new_state_dict[new_key] = state_dict[key]

    return new_state_dict


@torch.no_grad()
def convert_FastSpeech2ConformerModel_checkpoint(
    checkpoint_path,
    yaml_config_path,
    pytorch_dump_folder_path,
    repo_id=None,
):
    model_params, tokenizer_name, vocab = remap_model_yaml_config(yaml_config_path)
    config = FastSpeech2ConformerConfig(**model_params)

    # Prepare the model
    model = FastSpeech2ConformerModel(config)

    espnet_checkpoint = torch.load(checkpoint_path, weights_only=True)
    hf_compatible_state_dict = convert_espnet_state_dict_to_hf(espnet_checkpoint)

    model.load_state_dict(hf_compatible_state_dict)

    model.save_pretrained(pytorch_dump_folder_path)

    # Prepare the tokenizer
    with TemporaryDirectory() as tempdir:
        vocab = {token: id for id, token in enumerate(vocab)}
        vocab_file = Path(tempdir) / "vocab.json"
        with open(vocab_file, "w") as f:
            json.dump(vocab, f)
        should_strip_spaces = "no_space" in tokenizer_name
        tokenizer = FastSpeech2ConformerTokenizer(str(vocab_file), should_strip_spaces=should_strip_spaces)

    tokenizer.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument(
        "--yaml_config_path", required=True, default=None, type=str, help="Path to config.yaml of model to convert"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    args = parser.parse_args()
    convert_FastSpeech2ConformerModel_checkpoint(
        args.checkpoint_path,
        args.yaml_config_path,
        args.pytorch_dump_folder_path,
        args.push_to_hub,
    )
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_hifigan.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert FastSpeech2Conformer HiFi-GAN checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig, logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logging.set_verbosity_info()
|
| 27 |
+
logger = logging.get_logger("transformers.models.FastSpeech2Conformer")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_weights(checkpoint, hf_model, config):
|
| 31 |
+
vocoder_key_prefix = "tts.generator.vocoder."
|
| 32 |
+
checkpoint = {k.replace(vocoder_key_prefix, ""): v for k, v in checkpoint.items() if vocoder_key_prefix in k}
|
| 33 |
+
|
| 34 |
+
hf_model.apply_weight_norm()
|
| 35 |
+
|
| 36 |
+
hf_model.conv_pre.weight_g.data = checkpoint["input_conv.weight_g"]
|
| 37 |
+
hf_model.conv_pre.weight_v.data = checkpoint["input_conv.weight_v"]
|
| 38 |
+
hf_model.conv_pre.bias.data = checkpoint["input_conv.bias"]
|
| 39 |
+
|
| 40 |
+
for i in range(len(config.upsample_rates)):
|
| 41 |
+
hf_model.upsampler[i].weight_g.data = checkpoint[f"upsamples.{i}.1.weight_g"]
|
| 42 |
+
hf_model.upsampler[i].weight_v.data = checkpoint[f"upsamples.{i}.1.weight_v"]
|
| 43 |
+
hf_model.upsampler[i].bias.data = checkpoint[f"upsamples.{i}.1.bias"]
|
| 44 |
+
|
| 45 |
+
for i in range(len(config.upsample_rates) * len(config.resblock_kernel_sizes)):
|
| 46 |
+
for j in range(len(config.resblock_dilation_sizes)):
|
| 47 |
+
hf_model.resblocks[i].convs1[j].weight_g.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_g"]
|
| 48 |
+
hf_model.resblocks[i].convs1[j].weight_v.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_v"]
|
| 49 |
+
hf_model.resblocks[i].convs1[j].bias.data = checkpoint[f"blocks.{i}.convs1.{j}.1.bias"]
|
| 50 |
+
|
| 51 |
+
hf_model.resblocks[i].convs2[j].weight_g.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_g"]
|
| 52 |
+
hf_model.resblocks[i].convs2[j].weight_v.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_v"]
|
| 53 |
+
hf_model.resblocks[i].convs2[j].bias.data = checkpoint[f"blocks.{i}.convs2.{j}.1.bias"]
|
| 54 |
+
|
| 55 |
+
hf_model.conv_post.weight_g.data = checkpoint["output_conv.1.weight_g"]
|
| 56 |
+
hf_model.conv_post.weight_v.data = checkpoint["output_conv.1.weight_v"]
|
| 57 |
+
hf_model.conv_post.bias.data = checkpoint["output_conv.1.bias"]
|
| 58 |
+
|
| 59 |
+
hf_model.remove_weight_norm()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def remap_hifigan_yaml_config(yaml_config_path):
|
| 63 |
+
with Path(yaml_config_path).open("r", encoding="utf-8") as f:
|
| 64 |
+
args = yaml.safe_load(f)
|
| 65 |
+
args = argparse.Namespace(**args)
|
| 66 |
+
|
| 67 |
+
vocoder_type = args.tts_conf["vocoder_type"]
|
| 68 |
+
if vocoder_type != "hifigan_generator":
|
| 69 |
+
raise TypeError(f"Vocoder config must be for `hifigan_generator`, but got {vocoder_type}")
|
| 70 |
+
|
| 71 |
+
remapped_dict = {}
|
| 72 |
+
vocoder_params = args.tts_conf["vocoder_params"]
|
| 73 |
+
|
| 74 |
+
# espnet_config_key -> hf_config_key
|
| 75 |
+
key_mappings = {
|
| 76 |
+
"channels": "upsample_initial_channel",
|
| 77 |
+
"in_channels": "model_in_dim",
|
| 78 |
+
"resblock_dilations": "resblock_dilation_sizes",
|
| 79 |
+
"resblock_kernel_sizes": "resblock_kernel_sizes",
|
| 80 |
+
"upsample_kernel_sizes": "upsample_kernel_sizes",
|
| 81 |
+
"upsample_scales": "upsample_rates",
|
| 82 |
+
}
|
| 83 |
+
for espnet_config_key, hf_config_key in key_mappings.items():
|
| 84 |
+
remapped_dict[hf_config_key] = vocoder_params[espnet_config_key]
|
| 85 |
+
remapped_dict["sampling_rate"] = args.tts_conf["sampling_rate"]
|
| 86 |
+
remapped_dict["normalize_before"] = False
|
| 87 |
+
remapped_dict["leaky_relu_slope"] = vocoder_params["nonlinear_activation_params"]["negative_slope"]
|
| 88 |
+
|
| 89 |
+
return remapped_dict
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@torch.no_grad()
|
| 93 |
+
def convert_hifigan_checkpoint(
|
| 94 |
+
checkpoint_path,
|
| 95 |
+
pytorch_dump_folder_path,
|
| 96 |
+
yaml_config_path=None,
|
| 97 |
+
repo_id=None,
|
| 98 |
+
):
|
| 99 |
+
if yaml_config_path is not None:
|
| 100 |
+
config_kwargs = remap_hifigan_yaml_config(yaml_config_path)
|
| 101 |
+
config = FastSpeech2ConformerHifiGanConfig(**config_kwargs)
|
| 102 |
+
else:
|
| 103 |
+
config = FastSpeech2ConformerHifiGanConfig()
|
| 104 |
+
|
| 105 |
+
model = FastSpeech2ConformerHifiGan(config)
|
| 106 |
+
|
| 107 |
+
orig_checkpoint = torch.load(checkpoint_path, weights_only=True)
|
| 108 |
+
load_weights(orig_checkpoint, model, config)
|
| 109 |
+
|
| 110 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 111 |
+
|
| 112 |
+
if repo_id:
|
| 113 |
+
print("Pushing to the hub...")
|
| 114 |
+
model.push_to_hub(repo_id)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
parser = argparse.ArgumentParser()
|
| 119 |
+
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
| 120 |
+
parser.add_argument("--yaml_config_path", default=None, type=str, help="Path to config.yaml of model to convert")
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
args = parser.parse_args()
|
| 129 |
+
convert_hifigan_checkpoint(
|
| 130 |
+
args.checkpoint_path,
|
| 131 |
+
args.pytorch_dump_folder_path,
|
| 132 |
+
args.yaml_config_path,
|
| 133 |
+
args.push_to_hub,
|
| 134 |
+
)
|
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/convert_model_with_hifigan.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert FastSpeech2Conformer checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from transformers import (
|
| 22 |
+
FastSpeech2ConformerConfig,
|
| 23 |
+
FastSpeech2ConformerHifiGan,
|
| 24 |
+
FastSpeech2ConformerHifiGanConfig,
|
| 25 |
+
FastSpeech2ConformerModel,
|
| 26 |
+
FastSpeech2ConformerWithHifiGan,
|
| 27 |
+
FastSpeech2ConformerWithHifiGanConfig,
|
| 28 |
+
logging,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from .convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch import (
|
| 32 |
+
convert_espnet_state_dict_to_hf,
|
| 33 |
+
remap_model_yaml_config,
|
| 34 |
+
)
|
| 35 |
+
from .convert_hifigan import load_weights, remap_hifigan_yaml_config
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logging.set_verbosity_info()
|
| 39 |
+
logger = logging.get_logger("transformers.models.FastSpeech2Conformer")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def convert_FastSpeech2ConformerWithHifiGan_checkpoint(
|
| 43 |
+
checkpoint_path,
|
| 44 |
+
yaml_config_path,
|
| 45 |
+
pytorch_dump_folder_path,
|
| 46 |
+
repo_id=None,
|
| 47 |
+
):
|
| 48 |
+
# Prepare the model
|
| 49 |
+
model_params, *_ = remap_model_yaml_config(yaml_config_path)
|
| 50 |
+
model_config = FastSpeech2ConformerConfig(**model_params)
|
| 51 |
+
|
| 52 |
+
model = FastSpeech2ConformerModel(model_config)
|
| 53 |
+
|
| 54 |
+
espnet_checkpoint = torch.load(checkpoint_path, weights_only=True)
|
| 55 |
+
hf_compatible_state_dict = convert_espnet_state_dict_to_hf(espnet_checkpoint)
|
| 56 |
+
model.load_state_dict(hf_compatible_state_dict)
|
| 57 |
+
|
| 58 |
+
# Prepare the vocoder
|
| 59 |
+
config_kwargs = remap_hifigan_yaml_config(yaml_config_path)
|
| 60 |
+
vocoder_config = FastSpeech2ConformerHifiGanConfig(**config_kwargs)
|
| 61 |
+
|
| 62 |
+
vocoder = FastSpeech2ConformerHifiGan(vocoder_config)
|
| 63 |
+
load_weights(espnet_checkpoint, vocoder, vocoder_config)
|
| 64 |
+
|
| 65 |
+
# Prepare the model + vocoder
|
| 66 |
+
config = FastSpeech2ConformerWithHifiGanConfig.from_sub_model_configs(model_config, vocoder_config)
|
| 67 |
+
with_hifigan_model = FastSpeech2ConformerWithHifiGan(config)
|
| 68 |
+
with_hifigan_model.model = model
|
| 69 |
+
with_hifigan_model.vocoder = vocoder
|
| 70 |
+
|
| 71 |
+
with_hifigan_model.save_pretrained(pytorch_dump_folder_path)
|
| 72 |
+
|
| 73 |
+
if repo_id:
|
| 74 |
+
print("Pushing to the hub...")
|
| 75 |
+
with_hifigan_model.push_to_hub(repo_id)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
parser = argparse.ArgumentParser()
|
| 80 |
+
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--yaml_config_path", required=True, default=None, type=str, help="Path to config.yaml of model to convert"
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--pytorch_dump_folder_path",
|
| 86 |
+
required=True,
|
| 87 |
+
default=None,
|
| 88 |
+
type=str,
|
| 89 |
+
help="Path to the output `FastSpeech2ConformerModel` PyTorch model.",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
|
| 97 |
+
convert_FastSpeech2ConformerWithHifiGan_checkpoint(
|
| 98 |
+
args.checkpoint_path,
|
| 99 |
+
args.yaml_config_path,
|
| 100 |
+
args.pytorch_dump_folder_path,
|
| 101 |
+
args.push_to_hub,
|
| 102 |
+
)
|
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
ADDED
|
@@ -0,0 +1,1697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Espnet authors, IMS Toucan authors, and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch FastSpeech2Conformer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
from ...modeling_outputs import BaseModelOutput
|
| 25 |
+
from ...modeling_utils import PreTrainedModel
|
| 26 |
+
from ...utils import ModelOutput, add_start_docstrings, logging, replace_return_docstrings
|
| 27 |
+
from .configuration_fastspeech2_conformer import (
|
| 28 |
+
FastSpeech2ConformerConfig,
|
| 29 |
+
FastSpeech2ConformerHifiGanConfig,
|
| 30 |
+
FastSpeech2ConformerWithHifiGanConfig,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class FastSpeech2ConformerModelOutput(ModelOutput):
|
| 39 |
+
"""
|
| 40 |
+
Output type of [`FastSpeech2ConformerModel`].
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 44 |
+
Spectrogram generation loss.
|
| 45 |
+
spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
|
| 46 |
+
The predicted spectrogram.
|
| 47 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 48 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 49 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 50 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 51 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 52 |
+
|
| 53 |
+
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
| 54 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 55 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 56 |
+
sequence_length)`.
|
| 57 |
+
|
| 58 |
+
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
| 59 |
+
self-attention heads.
|
| 60 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 61 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 62 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 63 |
+
|
| 64 |
+
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
| 65 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 66 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 67 |
+
sequence_length)`.
|
| 68 |
+
|
| 69 |
+
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
| 70 |
+
self-attention heads.
|
| 71 |
+
duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
|
| 72 |
+
Outputs of the duration predictor.
|
| 73 |
+
pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
|
| 74 |
+
Outputs of the pitch predictor.
|
| 75 |
+
energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
|
| 76 |
+
Outputs of the energy predictor.
|
| 77 |
+
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
loss: Optional[torch.FloatTensor] = None
|
| 81 |
+
spectrogram: Optional[torch.FloatTensor] = None
|
| 82 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 83 |
+
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 84 |
+
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 85 |
+
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 86 |
+
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 87 |
+
duration_outputs: Optional[torch.LongTensor] = None
|
| 88 |
+
pitch_outputs: Optional[torch.FloatTensor] = None
|
| 89 |
+
energy_outputs: Optional[torch.FloatTensor] = None
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
|
| 93 |
+
class FastSpeech2ConformerWithHifiGanOutput(FastSpeech2ConformerModelOutput):
|
| 94 |
+
"""
|
| 95 |
+
Output type of [`FastSpeech2ConformerWithHifiGan`].
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
waveform (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
|
| 99 |
+
Speech output as a result of passing the predicted mel spectrogram through the vocoder.
|
| 100 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 101 |
+
Spectrogram generation loss.
|
| 102 |
+
spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
|
| 103 |
+
The predicted spectrogram.
|
| 104 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 105 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 106 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 107 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 108 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 109 |
+
|
| 110 |
+
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
| 111 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 112 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 113 |
+
sequence_length)`.
|
| 114 |
+
|
| 115 |
+
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
| 116 |
+
self-attention heads.
|
| 117 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 118 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 119 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 120 |
+
|
| 121 |
+
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
| 122 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 123 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 124 |
+
sequence_length)`.
|
| 125 |
+
|
| 126 |
+
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
| 127 |
+
self-attention heads.
|
| 128 |
+
duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
|
| 129 |
+
Outputs of the duration predictor.
|
| 130 |
+
pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
|
| 131 |
+
Outputs of the pitch predictor.
|
| 132 |
+
energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
|
| 133 |
+
Outputs of the energy predictor.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
waveform: Optional[torch.FloatTensor] = None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
_CONFIG_FOR_DOC = "FastSpeech2ConformerConfig"
|
| 140 |
+
|
| 141 |
+
FASTSPEECH2_CONFORMER_START_DOCSTRING = r"""
|
| 142 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 143 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 144 |
+
etc.)
|
| 145 |
+
|
| 146 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 147 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 148 |
+
and behavior.
|
| 149 |
+
|
| 150 |
+
Parameters:
|
| 151 |
+
config ([`FastSpeech2ConformerConfig`]):
|
| 152 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 153 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 154 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
HIFIGAN_START_DOCSTRING = r"""
|
| 159 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 160 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 161 |
+
etc.)
|
| 162 |
+
|
| 163 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 164 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 165 |
+
and behavior.
|
| 166 |
+
|
| 167 |
+
Parameters:
|
| 168 |
+
config ([`FastSpeech2ConformerConfig`]):
|
| 169 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 170 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 171 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
FASTSPEECH2_CONFORMER_WITH_HIFIGAN_START_DOCSTRING = r"""
|
| 175 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 176 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 177 |
+
etc.)
|
| 178 |
+
|
| 179 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 180 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 181 |
+
and behavior.
|
| 182 |
+
|
| 183 |
+
Parameters:
|
| 184 |
+
config ([`FastSpeech2ConformerWithHifiGanConfig`]):
|
| 185 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 186 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 187 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def length_regulator(encoded_embeddings, duration_labels, speaking_speed=1.0):
|
| 192 |
+
"""
|
| 193 |
+
Length regulator for feed-forward Transformer.
|
| 194 |
+
|
| 195 |
+
This is the length regulator module described in `FastSpeech: Fast, Robust and Controllable Text to Speech`
|
| 196 |
+
https://arxiv.org/pdf/1905.09263.pdf. The length regulator expands char or phoneme-level embedding features to
|
| 197 |
+
frame-level by repeating each feature based on the corresponding predicted durations.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
encoded_embeddings (`torch.Tensor` of shape `(batch_size, max_text_length, embedding_dim)`):
|
| 201 |
+
Batch of sequences of char or phoneme embeddings.
|
| 202 |
+
duration_labels (`torch.LongTensor` of shape `(batch_size, time)`):
|
| 203 |
+
Batch of durations of each frame.
|
| 204 |
+
speaking_speed (`float`, *optional*, defaults to 1.0):
|
| 205 |
+
Value to control speed of speech.
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
`torch.Tensor`:
|
| 209 |
+
Replicated input tensor based on durations (batch_size, time*, embedding_dim).
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
if speaking_speed <= 0:
|
| 213 |
+
raise ValueError("`speaking_speed` must be greater than 0.")
|
| 214 |
+
elif speaking_speed != 1.0:
|
| 215 |
+
duration_labels = torch.round(duration_labels.float() * speaking_speed).long()
|
| 216 |
+
|
| 217 |
+
if duration_labels.sum() == 0:
|
| 218 |
+
duration_labels[duration_labels.sum(dim=1).eq(0)] = 1
|
| 219 |
+
|
| 220 |
+
# Calculate the maximum length needed
|
| 221 |
+
max_len = torch.sum(duration_labels, dim=1).max()
|
| 222 |
+
|
| 223 |
+
# Create a padded tensor to hold the results
|
| 224 |
+
hidden_states = torch.zeros(
|
| 225 |
+
(encoded_embeddings.size(0), max_len, encoded_embeddings.size(2)),
|
| 226 |
+
dtype=torch.float,
|
| 227 |
+
device=encoded_embeddings.device,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Loop through the batch and fill in the data
|
| 231 |
+
for i, (encoded_embedding, target_duration) in enumerate(zip(encoded_embeddings, duration_labels)):
|
| 232 |
+
repeated = torch.repeat_interleave(encoded_embedding, target_duration, dim=0)
|
| 233 |
+
hidden_states[i, : repeated.size(0)] = repeated
|
| 234 |
+
|
| 235 |
+
return hidden_states
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class FastSpeech2ConformerDurationPredictor(nn.Module):
|
| 239 |
+
"""
|
| 240 |
+
Duration predictor module.
|
| 241 |
+
|
| 242 |
+
This is a module of duration predictor described in the paper 'FastSpeech: Fast, Robust and Controllable Text to
|
| 243 |
+
Speech' https://arxiv.org/pdf/1905.09263.pdf The duration predictor predicts a duration of each frame in log domain
|
| 244 |
+
from the hidden embeddings of encoder.
|
| 245 |
+
|
| 246 |
+
Note:
|
| 247 |
+
The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`, the
|
| 248 |
+
outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
|
| 249 |
+
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(self, config: FastSpeech2ConformerConfig):
|
| 253 |
+
super().__init__()
|
| 254 |
+
|
| 255 |
+
self.conv_layers = nn.ModuleList()
|
| 256 |
+
self.log_domain_offset = 1.0
|
| 257 |
+
|
| 258 |
+
for layer_idx in range(config.duration_predictor_layers):
|
| 259 |
+
num_chans = config.duration_predictor_channels
|
| 260 |
+
input_channels = config.hidden_size if layer_idx == 0 else num_chans
|
| 261 |
+
layer = FastSpeech2ConformerPredictorLayer(
|
| 262 |
+
input_channels,
|
| 263 |
+
num_chans,
|
| 264 |
+
config.duration_predictor_kernel_size,
|
| 265 |
+
config.duration_predictor_dropout_rate,
|
| 266 |
+
)
|
| 267 |
+
self.conv_layers.append(layer)
|
| 268 |
+
self.linear = nn.Linear(config.duration_predictor_channels, 1)
|
| 269 |
+
|
| 270 |
+
def forward(self, encoder_hidden_states):
|
| 271 |
+
"""
|
| 272 |
+
Args:
|
| 273 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
|
| 274 |
+
Batch of input sequences.
|
| 275 |
+
padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
|
| 276 |
+
Batch of masks indicating padded part.
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
`torch.Tensor`: Batch of predicted durations in log domain `(batch_size, max_text_length)`.
|
| 280 |
+
|
| 281 |
+
"""
|
| 282 |
+
# (batch_size, input_dim, max_text_length)
|
| 283 |
+
hidden_states = encoder_hidden_states.transpose(1, -1)
|
| 284 |
+
for layer in self.conv_layers:
|
| 285 |
+
hidden_states = layer(hidden_states)
|
| 286 |
+
|
| 287 |
+
# NOTE: calculate in log domain, (batch_size, max_text_length)
|
| 288 |
+
hidden_states = self.linear(hidden_states.transpose(1, -1)).squeeze(-1)
|
| 289 |
+
|
| 290 |
+
if not self.training:
|
| 291 |
+
# NOTE: calculate in linear domain
|
| 292 |
+
hidden_states = torch.clamp(torch.round(hidden_states.exp() - self.log_domain_offset), min=0).long()
|
| 293 |
+
|
| 294 |
+
return hidden_states
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5BatchNormConvLayer
|
| 298 |
+
class FastSpeech2ConformerBatchNormConvLayer(nn.Module):
|
| 299 |
+
def __init__(self, config, layer_id=0):
|
| 300 |
+
super().__init__()
|
| 301 |
+
|
| 302 |
+
if layer_id == 0:
|
| 303 |
+
in_conv_dim = config.num_mel_bins
|
| 304 |
+
else:
|
| 305 |
+
in_conv_dim = config.speech_decoder_postnet_units
|
| 306 |
+
|
| 307 |
+
if layer_id == config.speech_decoder_postnet_layers - 1:
|
| 308 |
+
out_conv_dim = config.num_mel_bins
|
| 309 |
+
else:
|
| 310 |
+
out_conv_dim = config.speech_decoder_postnet_units
|
| 311 |
+
|
| 312 |
+
self.conv = nn.Conv1d(
|
| 313 |
+
in_conv_dim,
|
| 314 |
+
out_conv_dim,
|
| 315 |
+
kernel_size=config.speech_decoder_postnet_kernel,
|
| 316 |
+
stride=1,
|
| 317 |
+
padding=(config.speech_decoder_postnet_kernel - 1) // 2,
|
| 318 |
+
bias=False,
|
| 319 |
+
)
|
| 320 |
+
self.batch_norm = nn.BatchNorm1d(out_conv_dim)
|
| 321 |
+
|
| 322 |
+
if layer_id < config.speech_decoder_postnet_layers - 1:
|
| 323 |
+
self.activation = nn.Tanh()
|
| 324 |
+
else:
|
| 325 |
+
self.activation = None
|
| 326 |
+
|
| 327 |
+
self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
|
| 328 |
+
|
| 329 |
+
def forward(self, hidden_states):
|
| 330 |
+
hidden_states = self.conv(hidden_states)
|
| 331 |
+
hidden_states = self.batch_norm(hidden_states)
|
| 332 |
+
if self.activation is not None:
|
| 333 |
+
hidden_states = self.activation(hidden_states)
|
| 334 |
+
hidden_states = self.dropout(hidden_states)
|
| 335 |
+
return hidden_states
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class FastSpeech2ConformerSpeechDecoderPostnet(nn.Module):
|
| 339 |
+
def __init__(self, config):
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.config = config
|
| 342 |
+
self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
|
| 343 |
+
self.layers = nn.ModuleList(
|
| 344 |
+
[FastSpeech2ConformerBatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 348 |
+
outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
|
| 349 |
+
layer_output = outputs_before_postnet.transpose(1, 2)
|
| 350 |
+
for layer in self.layers:
|
| 351 |
+
layer_output = layer(layer_output)
|
| 352 |
+
outputs_after_postnet = outputs_before_postnet + layer_output.transpose(1, 2)
|
| 353 |
+
return outputs_before_postnet, outputs_after_postnet
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class FastSpeech2ConformerPredictorLayer(nn.Module):
|
| 357 |
+
def __init__(self, input_channels, num_chans, kernel_size, dropout_rate):
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.conv = nn.Conv1d(
|
| 360 |
+
input_channels,
|
| 361 |
+
num_chans,
|
| 362 |
+
kernel_size,
|
| 363 |
+
stride=1,
|
| 364 |
+
padding=(kernel_size - 1) // 2,
|
| 365 |
+
)
|
| 366 |
+
self.activation = nn.ReLU()
|
| 367 |
+
self.layer_norm = nn.LayerNorm(num_chans)
|
| 368 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 369 |
+
|
| 370 |
+
def forward(self, hidden_states):
|
| 371 |
+
hidden_states = self.conv(hidden_states)
|
| 372 |
+
hidden_states = self.activation(hidden_states)
|
| 373 |
+
|
| 374 |
+
# Perform layer norm on dimension 1
|
| 375 |
+
hidden_states = hidden_states.transpose(1, -1)
|
| 376 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 377 |
+
hidden_states = hidden_states.transpose(1, -1)
|
| 378 |
+
|
| 379 |
+
hidden_states = self.dropout(hidden_states)
|
| 380 |
+
|
| 381 |
+
return hidden_states
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class FastSpeech2ConformerVariancePredictor(nn.Module):
|
| 385 |
+
def __init__(
|
| 386 |
+
self,
|
| 387 |
+
config: FastSpeech2ConformerConfig,
|
| 388 |
+
num_layers=2,
|
| 389 |
+
num_chans=384,
|
| 390 |
+
kernel_size=3,
|
| 391 |
+
dropout_rate=0.5,
|
| 392 |
+
):
|
| 393 |
+
"""
|
| 394 |
+
Initilize variance predictor module.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
input_dim (`int`): Input dimension.
|
| 398 |
+
num_layers (`int`, *optional*, defaults to 2): Number of convolutional layers.
|
| 399 |
+
num_chans (`int`, *optional*, defaults to 384): Number of channels of convolutional layers.
|
| 400 |
+
kernel_size (`int`, *optional*, defaults to 3): Kernel size of convolutional layers.
|
| 401 |
+
dropout_rate (`float`, *optional*, defaults to 0.5): Dropout rate.
|
| 402 |
+
"""
|
| 403 |
+
super().__init__()
|
| 404 |
+
self.conv_layers = nn.ModuleList()
|
| 405 |
+
for idx in range(num_layers):
|
| 406 |
+
input_channels = config.hidden_size if idx == 0 else num_chans
|
| 407 |
+
layer = FastSpeech2ConformerPredictorLayer(input_channels, num_chans, kernel_size, dropout_rate)
|
| 408 |
+
self.conv_layers.append(layer)
|
| 409 |
+
self.linear = nn.Linear(num_chans, 1)
|
| 410 |
+
|
| 411 |
+
def forward(self, encoder_hidden_states, padding_masks=None):
|
| 412 |
+
"""
|
| 413 |
+
Calculate forward propagation.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
|
| 417 |
+
Batch of input sequences.
|
| 418 |
+
padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
|
| 419 |
+
Batch of masks indicating padded part.
|
| 420 |
+
|
| 421 |
+
Returns:
|
| 422 |
+
Tensor: Batch of predicted sequences `(batch_size, max_text_length, 1)`.
|
| 423 |
+
"""
|
| 424 |
+
# (batch_size, input_dim, max_text_length)
|
| 425 |
+
hidden_states = encoder_hidden_states.transpose(1, -1)
|
| 426 |
+
for layer in self.conv_layers:
|
| 427 |
+
hidden_states = layer(hidden_states)
|
| 428 |
+
|
| 429 |
+
hidden_states = self.linear(hidden_states.transpose(1, 2))
|
| 430 |
+
|
| 431 |
+
if padding_masks is not None:
|
| 432 |
+
hidden_states = hidden_states.masked_fill(padding_masks, 0.0)
|
| 433 |
+
|
| 434 |
+
return hidden_states
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class FastSpeech2ConformerVarianceEmbedding(nn.Module):
|
| 438 |
+
def __init__(
|
| 439 |
+
self,
|
| 440 |
+
in_channels=1,
|
| 441 |
+
out_channels=384,
|
| 442 |
+
kernel_size=1,
|
| 443 |
+
padding=0,
|
| 444 |
+
dropout_rate=0.0,
|
| 445 |
+
):
|
| 446 |
+
super().__init__()
|
| 447 |
+
self.conv = nn.Conv1d(
|
| 448 |
+
in_channels=in_channels,
|
| 449 |
+
out_channels=out_channels,
|
| 450 |
+
kernel_size=kernel_size,
|
| 451 |
+
padding=padding,
|
| 452 |
+
)
|
| 453 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 454 |
+
|
| 455 |
+
def forward(self, hidden_states):
|
| 456 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 457 |
+
hidden_states = self.conv(hidden_states)
|
| 458 |
+
hidden_states = self.dropout(hidden_states)
|
| 459 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 460 |
+
return hidden_states
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class FastSpeech2ConformerAttention(nn.Module):
|
| 464 |
+
"""
|
| 465 |
+
Multi-Head attention layer with relative position encoding. Details can be found in
|
| 466 |
+
https://github.com/espnet/espnet/pull/2816. Paper: https://arxiv.org/abs/1901.02860.
|
| 467 |
+
"""
|
| 468 |
+
|
| 469 |
+
def __init__(self, config: FastSpeech2ConformerConfig, module_config):
|
| 470 |
+
"""Construct an FastSpeech2ConformerAttention object."""
|
| 471 |
+
super().__init__()
|
| 472 |
+
# We assume d_v always equals dim_key
|
| 473 |
+
self.num_heads = module_config["num_attention_heads"]
|
| 474 |
+
self.hidden_size = config.hidden_size
|
| 475 |
+
self.dim_key = self.hidden_size // self.num_heads
|
| 476 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 477 |
+
self.linear_q = nn.Linear(self.hidden_size, self.hidden_size)
|
| 478 |
+
self.linear_k = nn.Linear(self.hidden_size, self.hidden_size)
|
| 479 |
+
self.linear_v = nn.Linear(self.hidden_size, self.hidden_size)
|
| 480 |
+
self.linear_out = nn.Linear(self.hidden_size, self.hidden_size)
|
| 481 |
+
self.dropout = nn.Dropout(p=module_config["attention_dropout_rate"])
|
| 482 |
+
|
| 483 |
+
# linear transformation for positional encoding
|
| 484 |
+
self.linear_pos = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
| 485 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 486 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 487 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
|
| 488 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
|
| 489 |
+
|
| 490 |
+
def shift_relative_position_tensor(self, pos_tensor):
|
| 491 |
+
"""
|
| 492 |
+
Args:
|
| 493 |
+
pos_tensor (torch.Tensor of shape (batch_size, head, time1, 2*time1-1)): Input tensor.
|
| 494 |
+
"""
|
| 495 |
+
zero_pad = torch.zeros((*pos_tensor.size()[:3], 1), device=pos_tensor.device, dtype=pos_tensor.dtype)
|
| 496 |
+
pos_tensor_padded = torch.cat([zero_pad, pos_tensor], dim=-1)
|
| 497 |
+
|
| 498 |
+
pos_tensor_padded = pos_tensor_padded.view(*pos_tensor.size()[:2], pos_tensor.size(3) + 1, pos_tensor.size(2))
|
| 499 |
+
# only keep the positions from 0 to time2
|
| 500 |
+
pos_tensor = pos_tensor_padded[:, :, 1:].view_as(pos_tensor)[:, :, :, : pos_tensor.size(-1) // 2 + 1]
|
| 501 |
+
|
| 502 |
+
return pos_tensor
|
| 503 |
+
|
| 504 |
+
def forward(
|
| 505 |
+
self,
|
| 506 |
+
hidden_states: torch.Tensor,
|
| 507 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 508 |
+
pos_emb: Optional[torch.Tensor] = None,
|
| 509 |
+
output_attentions: Optional[torch.Tensor] = False,
|
| 510 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 511 |
+
"""
|
| 512 |
+
Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
| 513 |
+
|
| 514 |
+
Args:
|
| 515 |
+
hidden_states (`torch.Tensor` of shape `(batch, time2, size)`): Values of the hidden states
|
| 516 |
+
attention_mask (`torch.Tensor` of shape `(batch, time1, time2)`): Mask tensor.
|
| 517 |
+
pos_emb (`torch.Tensor` of shape `(batch, 2*time1-1, size)`): Positional embedding tensor.
|
| 518 |
+
output_attentions (`bool`, *optional*):
|
| 519 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 520 |
+
returned tensors for more detail.
|
| 521 |
+
Returns:
|
| 522 |
+
`torch.Tensor`: Output tensor of shape `(batch, time1, d_model)`.
|
| 523 |
+
"""
|
| 524 |
+
bsz, q_len, _ = hidden_states.size()
|
| 525 |
+
query_states = self.linear_q(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
|
| 526 |
+
key_states = self.linear_k(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
|
| 527 |
+
value_states = self.linear_v(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
|
| 528 |
+
|
| 529 |
+
bsz_pos = pos_emb.size(0)
|
| 530 |
+
pos_encoding = self.linear_pos(pos_emb).view(bsz_pos, -1, self.num_heads, self.head_dim)
|
| 531 |
+
|
| 532 |
+
# (batch_size, head, time1, dim_key)
|
| 533 |
+
query_with_bias_u = (query_states + self.pos_bias_u).transpose(1, 2)
|
| 534 |
+
# (batch_size, head, time1, dim_key)
|
| 535 |
+
query_with_bias_v = (query_states + self.pos_bias_v).transpose(1, 2)
|
| 536 |
+
|
| 537 |
+
# compute attention score
|
| 538 |
+
# first compute matrix a and matrix c
|
| 539 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 540 |
+
# (batch_size, head, time1, time2)
|
| 541 |
+
matrix_ac = torch.matmul(query_with_bias_u, key_states.permute(0, 2, 3, 1))
|
| 542 |
+
|
| 543 |
+
# compute matrix b and matrix d
|
| 544 |
+
# (batch_size, head, time1, 2*time1-1)
|
| 545 |
+
matrix_bd = torch.matmul(query_with_bias_v, pos_encoding.permute(0, 2, 3, 1))
|
| 546 |
+
matrix_bd = self.shift_relative_position_tensor(matrix_bd)
|
| 547 |
+
|
| 548 |
+
# (batch_size, head, time1, time2)
|
| 549 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.dim_key)
|
| 550 |
+
|
| 551 |
+
# Forward attention
|
| 552 |
+
if attention_mask is not None:
|
| 553 |
+
expected_size = (bsz, 1, q_len)
|
| 554 |
+
if attention_mask.size() != expected_size:
|
| 555 |
+
raise ValueError(f"Attention mask should be of size {expected_size}, but is {attention_mask.size()}")
|
| 556 |
+
attention_mask = attention_mask.unsqueeze(1).eq(0)
|
| 557 |
+
min_value = float(torch.finfo(scores.dtype).min)
|
| 558 |
+
scores = scores.masked_fill(attention_mask, min_value)
|
| 559 |
+
attn_weights = torch.softmax(scores, dim=-1).masked_fill(attention_mask, 0.0)
|
| 560 |
+
else:
|
| 561 |
+
attn_weights = torch.softmax(scores, dim=-1)
|
| 562 |
+
|
| 563 |
+
attn_weights = self.dropout(attn_weights)
|
| 564 |
+
attn_output = torch.matmul(attn_weights, value_states.transpose(1, 2))
|
| 565 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
|
| 566 |
+
|
| 567 |
+
attn_output = self.linear_out(attn_output)
|
| 568 |
+
|
| 569 |
+
if not output_attentions:
|
| 570 |
+
attn_weights = None
|
| 571 |
+
|
| 572 |
+
return attn_output, attn_weights
|
| 573 |
+
|
| 574 |
+
|
class FastSpeech2ConformerConvolutionModule(nn.Module):
    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
        super().__init__()
        # kernel_size should be an odd number for 'SAME' padding
        channels = config.hidden_size
        kernel_size = module_config["kernel_size"]
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=True
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, hidden_states):
        """
        Compute convolution module.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.

        Returns:
            `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.

        """
        # exchange the temporal dimension and the feature dimension
        hidden_states = hidden_states.transpose(1, 2)

        # GLU mechanism, (batch_size, 2*channel, dim)
        hidden_states = self.pointwise_conv1(hidden_states)
        # (batch_size, channel, dim)
        hidden_states = nn.functional.glu(hidden_states, dim=1)

        # 1D Depthwise Conv
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.norm(hidden_states)

        hidden_states = hidden_states * torch.sigmoid(hidden_states)

        hidden_states = self.pointwise_conv2(hidden_states)

        return hidden_states.transpose(1, 2)

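# The convolution module above keeps the channel count constant end to end:
# pointwise_conv1 doubles the channels so GLU can halve them again, the depthwise
# conv uses 'SAME'-style padding (kernel_size - 1) // 2 so the time axis is
# preserved, and a Swish activation (x * sigmoid(x)) precedes the final pointwise
# conv. A minimal shape check of that bookkeeping, assuming only torch; the sizes
# below are hypothetical and not tied to any FastSpeech2Conformer checkpoint.
import torch
from torch import nn


def _conv_module_shape_sketch(batch=2, time=7, channels=8, kernel_size=3):
    hidden = torch.randn(batch, time, channels).transpose(1, 2)        # (B, C, T)
    hidden = nn.Conv1d(channels, 2 * channels, kernel_size=1)(hidden)  # (B, 2C, T)
    hidden = nn.functional.glu(hidden, dim=1)                          # GLU halves channels: (B, C, T)
    depthwise = nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size - 1) // 2, groups=channels)
    hidden = depthwise(hidden)                                         # time length preserved
    return hidden.transpose(1, 2).shape                                # torch.Size([B, T, C])


# e.g. _conv_module_shape_sketch() == torch.Size([2, 7, 8])
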
class FastSpeech2ConformerEncoderLayer(nn.Module):
    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
        super().__init__()

        # self-attention module definition
        self.self_attn = FastSpeech2ConformerAttention(config, module_config)

        # feed-forward module definition
        self.feed_forward = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)

        self.macaron_style = config.use_macaron_style_in_conformer
        if self.macaron_style:
            self.feed_forward_macaron = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
            self.ff_macaron_layer_norm = nn.LayerNorm(config.hidden_size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0

        # convolution module definition
        self.use_cnn_module = config.use_cnn_in_conformer
        if self.use_cnn_module:
            self.conv_module = FastSpeech2ConformerConvolutionModule(config, module_config)
            self.conv_layer_norm = nn.LayerNorm(config.hidden_size)
            self.final_layer_norm = nn.LayerNorm(config.hidden_size)

        self.ff_layer_norm = nn.LayerNorm(config.hidden_size)

        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)

        self.dropout = nn.Dropout(module_config["dropout_rate"])
        self.size = config.hidden_size
        self.normalize_before = module_config["normalize_before"]
        self.concat_after = module_config["concat_after"]
        if self.concat_after:
            self.concat_linear = nn.Linear(config.hidden_size + config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[torch.Tensor] = False,
    ):
        """
        Compute encoded features.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, time, size)`): Input tensor.
            pos_emb (`torch.Tensor` of shape `(1, time, size)`): Positional embeddings tensor.
            attention_mask (`torch.Tensor` of shape `(batch, time)`): Attention mask tensor for the input.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        Returns:
            `torch.Tensor`: Output tensor of shape `(batch, time, size)`.

        """
        # whether to use macaron style
        if self.macaron_style:
            residual = hidden_states
            if self.normalize_before:
                hidden_states = self.ff_macaron_layer_norm(hidden_states)
            hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden_states))
            if not self.normalize_before:
                hidden_states = self.ff_macaron_layer_norm(hidden_states)

        # multi-headed self-attention module
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        attention_output, attention_scores = self.self_attn(
            hidden_states, attention_mask=attention_mask, pos_emb=pos_emb, output_attentions=output_attentions
        )

        if self.concat_after:
            x_concat = torch.cat((hidden_states, attention_output), dim=-1)
            hidden_states = self.concat_linear(x_concat)
            hidden_states = residual + hidden_states
        else:
            hidden_states = self.dropout(attention_output)
            hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # convolution module
        if self.use_cnn_module:
            residual = hidden_states
            if self.normalize_before:
                hidden_states = self.conv_layer_norm(hidden_states)
            hidden_states = self.conv_module(hidden_states)
            hidden_states = self.dropout(hidden_states)
            hidden_states = residual + hidden_states
            if not self.normalize_before:
                hidden_states = self.conv_layer_norm(hidden_states)

        # feed forward module
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.ff_layer_norm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + self.ff_scale * hidden_states
        if not self.normalize_before:
            hidden_states = self.ff_layer_norm(hidden_states)

        if self.conv_module is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_scores,)

        return outputs

class FastSpeech2ConformerMultiLayeredConv1d(nn.Module):
    """
    Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed to replace positionwise feed-forward network in Transformer
    block, which is introduced in 'FastSpeech: Fast, Robust and Controllable Text to Speech'
    https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
        """
        Initialize FastSpeech2ConformerMultiLayeredConv1d module.

        Args:
            input_channels (`int`): Number of input channels.
            hidden_channels (`int`): Number of hidden channels.
            kernel_size (`int`): Kernel size of conv1d.
            dropout_rate (`float`): Dropout rate.
        """
        super().__init__()
        input_channels = config.hidden_size
        hidden_channels = module_config["linear_units"]
        kernel_size = config.positionwise_conv_kernel_size
        self.conv1 = nn.Conv1d(input_channels, hidden_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
        self.conv2 = nn.Conv1d(hidden_channels, input_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
        self.dropout = nn.Dropout(module_config["dropout_rate"])

    def forward(self, hidden_states):
        """
        Calculate forward propagation.

        Args:
            hidden_states (torch.Tensor): Batch of input tensors (batch_size, time, input_channels).

        Returns:
            torch.Tensor: Batch of output tensors (batch_size, time, input_channels).
        """
        hidden_states = hidden_states.transpose(-1, 1)
        hidden_states = self.conv1(hidden_states)
        hidden_states = torch.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = hidden_states.transpose(-1, 1)
        return hidden_states

class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
    """
    Relative positional encoding module (new implementation). Details can be found in
    https://github.com/espnet/espnet/pull/2816. See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        config (`FastSpeech2ConformerConfig`):
            FastSpeech2ConformerConfig instance.
        module_config (`dict`):
            Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
    """

    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
        """
        Construct a PositionalEncoding object.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.input_scale = math.sqrt(self.embed_dim)
        self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
        self.pos_enc = None
        self.max_len = 5000
        self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pos_enc(self, x):
        """Reset the positional encodings."""
        if self.pos_enc is not None:
            # self.pos_enc contains both positive and negative parts
            # the length of self.pos_enc is 2 * input_len - 1
            if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
                if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
                    self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position of the
        # key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pos_enc_positive = torch.zeros(x.size(1), self.embed_dim)
        pos_enc_negative = torch.zeros(x.size(1), self.embed_dim)
        position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embed_dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.embed_dim)
        )
        pos_enc_positive[:, 0::2] = torch.sin(position * div_term)
        pos_enc_positive[:, 1::2] = torch.cos(position * div_term)
        pos_enc_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pos_enc_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pos_enc_positive = torch.flip(pos_enc_positive, [0]).unsqueeze(0)
        pos_enc_negative = pos_enc_negative[1:].unsqueeze(0)
        pos_enc = torch.cat([pos_enc_positive, pos_enc_negative], dim=1)
        self.pos_enc = pos_enc.to(device=x.device, dtype=x.dtype)

    def forward(self, feature_representation):
        """
        Args:
            feature_representation (`torch.Tensor` of shape (batch_size, time, `*`)):
                Input tensor.

        Returns:
            `torch.Tensor`: Encoded tensor (batch_size, time, `*`).
        """
        self.extend_pos_enc(feature_representation)
        hidden_states = feature_representation * self.input_scale
        center_idx = self.pos_enc.size(1) // 2
        pos_emb = self.pos_enc[:, center_idx - hidden_states.size(1) + 1 : center_idx + hidden_states.size(1)]
        return self.dropout(hidden_states), self.dropout(pos_emb)

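# For an input of length T, `extend_pos_enc` stores 2 * max_len - 1 sinusoidal
# vectors (positive offsets reversed, then negative offsets) and `forward` slices
# the window [center_idx - T + 1, center_idx + T), i.e. exactly the 2 * T - 1
# relative offsets T-1, ..., 0, ..., -(T-1). A small index-only sketch of that
# slicing, with plain integers standing in for the embedding rows (illustrative
# values, not taken from the original file):
def _relative_window_sketch(time=4, max_len=10):
    # stand-in for the length dimension of self.pos_enc:
    # offsets max_len-1 ... 1, 0, -1 ... -(max_len-1)
    offsets = list(range(max_len - 1, 0, -1)) + [0] + [-k for k in range(1, max_len)]
    center_idx = len(offsets) // 2
    return offsets[center_idx - time + 1 : center_idx + time]


# e.g. _relative_window_sketch(time=4) == [3, 2, 1, 0, -1, -2, -3]  (length 2*4 - 1)
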
| 851 |
+
class FastSpeech2ConformerEncoder(nn.Module):
|
| 852 |
+
"""
|
| 853 |
+
FastSpeech2ConformerEncoder encoder module.
|
| 854 |
+
|
| 855 |
+
Args:
|
| 856 |
+
config (`FastSpeech2ConformerConfig`):
|
| 857 |
+
FastSpeech2ConformerConfig instance.
|
| 858 |
+
module_config (`dict`):
|
| 859 |
+
Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
|
| 860 |
+
use_encoder_input_layer (`bool`, *optional*, defaults to `False`):
|
| 861 |
+
Input layer type.
|
| 862 |
+
"""
|
| 863 |
+
|
| 864 |
+
def __init__(
|
| 865 |
+
self,
|
| 866 |
+
config: FastSpeech2ConformerConfig,
|
| 867 |
+
module_config,
|
| 868 |
+
use_encoder_input_layer=False,
|
| 869 |
+
):
|
| 870 |
+
super().__init__()
|
| 871 |
+
|
| 872 |
+
self.embed = None
|
| 873 |
+
if use_encoder_input_layer:
|
| 874 |
+
self.embed = nn.Embedding(
|
| 875 |
+
num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=0
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
self.pos_enc = FastSpeech2ConformerRelPositionalEncoding(config, module_config)
|
| 879 |
+
|
| 880 |
+
self.conformer_layers = nn.ModuleList(
|
| 881 |
+
[FastSpeech2ConformerEncoderLayer(config, module_config) for _ in range(module_config["layers"])]
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
def forward(
|
| 885 |
+
self,
|
| 886 |
+
input_tensor: torch.LongTensor,
|
| 887 |
+
attention_mask: Optional[bool] = None,
|
| 888 |
+
output_hidden_states: Optional[bool] = None,
|
| 889 |
+
output_attentions: Optional[bool] = False,
|
| 890 |
+
return_dict: Optional[bool] = None,
|
| 891 |
+
):
|
| 892 |
+
"""
|
| 893 |
+
Args:
|
| 894 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 895 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
| 896 |
+
provide it.
|
| 897 |
+
|
| 898 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 899 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 900 |
+
|
| 901 |
+
[What are input IDs?](../glossary#input-ids)
|
| 902 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 903 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 904 |
+
|
| 905 |
+
- 1 for tokens that are **not masked**,
|
| 906 |
+
- 0 for tokens that are **masked**.
|
| 907 |
+
|
| 908 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 909 |
+
output_hidden_states (`bool`, *optional*):
|
| 910 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 911 |
+
for more detail.
|
| 912 |
+
output_attentions (`bool`, *optional*):
|
| 913 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 914 |
+
returned tensors for more detail.
|
| 915 |
+
return_dict (`bool`, *optional*):
|
| 916 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 917 |
+
Returns:
|
| 918 |
+
`torch.Tensor`:
|
| 919 |
+
Output tensor of shape `(batch, time, attention_dim)`.
|
| 920 |
+
"""
|
| 921 |
+
feature_representation = input_tensor
|
| 922 |
+
if self.embed is not None:
|
| 923 |
+
feature_representation = self.embed(feature_representation)
|
| 924 |
+
|
| 925 |
+
hidden_states, pos_emb = self.pos_enc(feature_representation)
|
| 926 |
+
|
| 927 |
+
all_hidden_states = () if output_hidden_states else None
|
| 928 |
+
all_self_attentions = () if output_attentions else None
|
| 929 |
+
|
| 930 |
+
for conformer_layer in self.conformer_layers:
|
| 931 |
+
if output_hidden_states:
|
| 932 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 933 |
+
|
| 934 |
+
layer_outputs = conformer_layer(hidden_states, pos_emb, attention_mask, output_attentions)
|
| 935 |
+
hidden_states = layer_outputs[0]
|
| 936 |
+
|
| 937 |
+
if output_attentions:
|
| 938 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 939 |
+
|
| 940 |
+
# Add last layer
|
| 941 |
+
if output_hidden_states:
|
| 942 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 943 |
+
|
| 944 |
+
if not return_dict:
|
| 945 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 946 |
+
return BaseModelOutput(
|
| 947 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
class FastSpeech2ConformerLoss(nn.Module):
|
| 952 |
+
def __init__(self, config: FastSpeech2ConformerConfig):
|
| 953 |
+
super().__init__()
|
| 954 |
+
|
| 955 |
+
use_masking = config.use_masking
|
| 956 |
+
use_weighted_masking = config.use_weighted_masking
|
| 957 |
+
|
| 958 |
+
if use_masking and use_weighted_masking:
|
| 959 |
+
raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.")
|
| 960 |
+
|
| 961 |
+
self.use_masking = use_masking
|
| 962 |
+
self.use_weighted_masking = use_weighted_masking
|
| 963 |
+
|
| 964 |
+
# define criterions
|
| 965 |
+
reduction = "none" if self.use_weighted_masking else "mean"
|
| 966 |
+
self.l1_criterion = nn.L1Loss(reduction=reduction)
|
| 967 |
+
self.mse_criterion = nn.MSELoss(reduction=reduction)
|
| 968 |
+
self.duration_criterion = nn.MSELoss(reduction=reduction)
|
| 969 |
+
self.log_domain_offset = 1.0
|
| 970 |
+
|
| 971 |
+
def forward(
|
| 972 |
+
self,
|
| 973 |
+
outputs_after_postnet,
|
| 974 |
+
outputs_before_postnet,
|
| 975 |
+
duration_outputs,
|
| 976 |
+
pitch_outputs,
|
| 977 |
+
energy_outputs,
|
| 978 |
+
spectrogram_labels,
|
| 979 |
+
duration_labels,
|
| 980 |
+
pitch_labels,
|
| 981 |
+
energy_labels,
|
| 982 |
+
duration_mask,
|
| 983 |
+
spectrogram_mask,
|
| 984 |
+
):
|
| 985 |
+
"""
|
| 986 |
+
Args:
|
| 987 |
+
outputs_after_postnet (`torch.Tensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`):
|
| 988 |
+
Batch of outputs after postnet.
|
| 989 |
+
outputs_before_postnet (`torch.Tensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`):
|
| 990 |
+
Batch of outputs before postnet.
|
| 991 |
+
duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length)`):
|
| 992 |
+
Batch of outputs of duration predictor.
|
| 993 |
+
pitch_outputs (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
|
| 994 |
+
Batch of outputs of pitch predictor.
|
| 995 |
+
energy_outputs (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
|
| 996 |
+
Batch of outputs of energy predictor.
|
| 997 |
+
spectrogram_labels (`torch.Tensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`):
|
| 998 |
+
Batch of target features.
|
| 999 |
+
duration_labels (`torch.LongTensor` of shape `(batch_size, max_text_length)`): Batch of durations.
|
| 1000 |
+
pitch_labels (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
|
| 1001 |
+
Batch of target token-averaged pitch.
|
| 1002 |
+
energy_labels (`torch.Tensor` of shape `(batch_size, max_text_length, 1)`):
|
| 1003 |
+
Batch of target token-averaged energy.
|
| 1004 |
+
duration_mask (`torch.LongTensor`):
|
| 1005 |
+
Mask used to discern which values the duration loss should be calculated for.
|
| 1006 |
+
spectrogram_mask (`torch.LongTensor`):
|
| 1007 |
+
Mask used to discern which values the spectrogam loss should be calculated for.
|
| 1008 |
+
|
| 1009 |
+
Returns:
|
| 1010 |
+
`tuple(torch.FloatTensor)`: Tuple of tensors containing, in order, the L1 loss value, duration predictor
|
| 1011 |
+
loss value, pitch predictor loss value, and energy predictor loss value.
|
| 1012 |
+
|
| 1013 |
+
"""
|
| 1014 |
+
pitch_and_energy_masks = duration_mask.unsqueeze(-1)
|
| 1015 |
+
|
| 1016 |
+
# apply mask to remove padded part
|
| 1017 |
+
if self.use_masking:
|
| 1018 |
+
outputs_before_postnet = outputs_before_postnet.masked_select(spectrogram_mask)
|
| 1019 |
+
if outputs_after_postnet is not None:
|
| 1020 |
+
outputs_after_postnet = outputs_after_postnet.masked_select(spectrogram_mask)
|
| 1021 |
+
spectrogram_labels = spectrogram_labels.masked_select(spectrogram_mask)
|
| 1022 |
+
duration_outputs = duration_outputs.masked_select(duration_mask)
|
| 1023 |
+
duration_labels = duration_labels.masked_select(duration_mask)
|
| 1024 |
+
pitch_outputs = pitch_outputs.masked_select(pitch_and_energy_masks)
|
| 1025 |
+
energy_outputs = energy_outputs.masked_select(pitch_and_energy_masks)
|
| 1026 |
+
pitch_labels = pitch_labels.masked_select(pitch_and_energy_masks)
|
| 1027 |
+
energy_labels = energy_labels.masked_select(pitch_and_energy_masks)
|
| 1028 |
+
|
| 1029 |
+
# calculate loss
|
| 1030 |
+
l1_loss = self.l1_criterion(outputs_before_postnet, spectrogram_labels)
|
| 1031 |
+
if outputs_after_postnet is not None:
|
| 1032 |
+
l1_loss = l1_loss + self.l1_criterion(outputs_after_postnet, spectrogram_labels)
|
| 1033 |
+
duration_labels = torch.log(duration_labels.float() + self.log_domain_offset)
|
| 1034 |
+
duration_loss = self.duration_criterion(duration_outputs, duration_labels)
|
| 1035 |
+
pitch_loss = self.mse_criterion(pitch_outputs, pitch_labels)
|
| 1036 |
+
energy_loss = self.mse_criterion(energy_outputs, energy_labels)
|
| 1037 |
+
|
| 1038 |
+
# make weighted mask and apply it
|
| 1039 |
+
if self.use_weighted_masking:
|
| 1040 |
+
spectrogram_mask = nn.functional.pad(
|
| 1041 |
+
spectrogram_mask.transpose(1, 2),
|
| 1042 |
+
[0, spectrogram_labels.size(1) - spectrogram_mask.size(1), 0, 0, 0, 0],
|
| 1043 |
+
value=False,
|
| 1044 |
+
).transpose(1, 2)
|
| 1045 |
+
|
| 1046 |
+
out_weights = spectrogram_mask.float() / spectrogram_mask.sum(dim=1, keepdim=True).float()
|
| 1047 |
+
out_weights /= spectrogram_labels.size(0) * spectrogram_labels.size(2)
|
| 1048 |
+
duration_weights = duration_mask.float() / duration_mask.sum(dim=1, keepdim=True).float()
|
| 1049 |
+
duration_weights /= duration_labels.size(0)
|
| 1050 |
+
|
| 1051 |
+
# apply weight
|
| 1052 |
+
l1_loss = l1_loss.mul(out_weights).masked_select(spectrogram_mask).sum()
|
| 1053 |
+
duration_loss = duration_loss.mul(duration_weights).masked_select(duration_mask).sum()
|
| 1054 |
+
pitch_weights = duration_weights.unsqueeze(-1)
|
| 1055 |
+
pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_and_energy_masks).sum()
|
| 1056 |
+
energy_loss = energy_loss.mul(pitch_weights).masked_select(pitch_and_energy_masks).sum()
|
| 1057 |
+
|
| 1058 |
+
return l1_loss + duration_loss + pitch_loss + energy_loss
|
| 1059 |
+
|
| 1060 |
+
|
class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FastSpeech2ConformerConfig
    base_model_prefix = "fastspeech2_conformer"

    main_input_name = "input_ids"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.LayerNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-key, b=key)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_()
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, FastSpeech2ConformerAttention):
            nn.init.xavier_uniform_(module.pos_bias_u)
            nn.init.xavier_uniform_(module.pos_bias_v)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, FastSpeech2ConformerEncoder):
            module.gradient_checkpointing = value

| 1095 |
+
@add_start_docstrings(
|
| 1096 |
+
"""FastSpeech2Conformer Model.""",
|
| 1097 |
+
FASTSPEECH2_CONFORMER_START_DOCSTRING,
|
| 1098 |
+
)
|
| 1099 |
+
class FastSpeech2ConformerModel(FastSpeech2ConformerPreTrainedModel):
|
| 1100 |
+
"""
|
| 1101 |
+
FastSpeech 2 module.
|
| 1102 |
+
|
| 1103 |
+
This is a module of FastSpeech 2 described in 'FastSpeech 2: Fast and High-Quality End-to-End Text to Speech'
|
| 1104 |
+
https://arxiv.org/abs/2006.04558. Instead of quantized pitch and energy, we use token-averaged value introduced in
|
| 1105 |
+
FastPitch: Parallel Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers instead of regular
|
| 1106 |
+
Transformers.
|
| 1107 |
+
"""
|
| 1108 |
+
|
| 1109 |
+
def __init__(self, config: FastSpeech2ConformerConfig):
|
| 1110 |
+
super().__init__(config)
|
| 1111 |
+
self.config = config
|
| 1112 |
+
|
| 1113 |
+
# store hyperparameters
|
| 1114 |
+
self.vocab_size = config.vocab_size
|
| 1115 |
+
self.num_mel_bins = config.num_mel_bins
|
| 1116 |
+
self.hidden_size = config.hidden_size
|
| 1117 |
+
self.reduction_factor = config.reduction_factor
|
| 1118 |
+
self.stop_gradient_from_pitch_predictor = config.stop_gradient_from_pitch_predictor
|
| 1119 |
+
self.stop_gradient_from_energy_predictor = config.stop_gradient_from_energy_predictor
|
| 1120 |
+
|
| 1121 |
+
self.multilingual_model = config.num_languages is not None and config.num_languages > 1
|
| 1122 |
+
if self.multilingual_model:
|
| 1123 |
+
self.language_id_embedding = torch.nn.Embedding(config.num_languages, self.hidden_size)
|
| 1124 |
+
|
| 1125 |
+
self.multispeaker_model = config.num_speakers is not None and config.num_speakers > 1
|
| 1126 |
+
if self.multispeaker_model:
|
| 1127 |
+
self.speaker_id_embedding = torch.nn.Embedding(config.num_speakers, config.hidden_size)
|
| 1128 |
+
|
| 1129 |
+
self.speaker_embed_dim = config.speaker_embed_dim
|
| 1130 |
+
if self.speaker_embed_dim:
|
| 1131 |
+
self.projection = nn.Linear(config.hidden_size + self.speaker_embed_dim, config.hidden_size)
|
| 1132 |
+
|
| 1133 |
+
self.encoder = FastSpeech2ConformerEncoder(config, config.encoder_config, use_encoder_input_layer=True)
|
| 1134 |
+
|
| 1135 |
+
self.duration_predictor = FastSpeech2ConformerDurationPredictor(config)
|
| 1136 |
+
|
| 1137 |
+
self.pitch_predictor = FastSpeech2ConformerVariancePredictor(
|
| 1138 |
+
config,
|
| 1139 |
+
num_layers=config.pitch_predictor_layers,
|
| 1140 |
+
num_chans=config.pitch_predictor_channels,
|
| 1141 |
+
kernel_size=config.pitch_predictor_kernel_size,
|
| 1142 |
+
dropout_rate=config.pitch_predictor_dropout,
|
| 1143 |
+
)
|
| 1144 |
+
# continuous pitch + FastPitch style avg
|
| 1145 |
+
self.pitch_embed = FastSpeech2ConformerVarianceEmbedding(
|
| 1146 |
+
out_channels=self.hidden_size,
|
| 1147 |
+
kernel_size=config.pitch_embed_kernel_size,
|
| 1148 |
+
padding=(config.pitch_embed_kernel_size - 1) // 2,
|
| 1149 |
+
dropout_rate=config.pitch_embed_dropout,
|
| 1150 |
+
)
|
| 1151 |
+
|
| 1152 |
+
self.energy_predictor = FastSpeech2ConformerVariancePredictor(
|
| 1153 |
+
config,
|
| 1154 |
+
num_layers=config.energy_predictor_layers,
|
| 1155 |
+
num_chans=config.energy_predictor_channels,
|
| 1156 |
+
kernel_size=config.energy_predictor_kernel_size,
|
| 1157 |
+
dropout_rate=config.energy_predictor_dropout,
|
| 1158 |
+
)
|
| 1159 |
+
# continuous energy + FastPitch style avg
|
| 1160 |
+
self.energy_embed = FastSpeech2ConformerVarianceEmbedding(
|
| 1161 |
+
out_channels=self.hidden_size,
|
| 1162 |
+
kernel_size=config.energy_embed_kernel_size,
|
| 1163 |
+
padding=(config.energy_embed_kernel_size - 1) // 2,
|
| 1164 |
+
dropout_rate=config.energy_embed_dropout,
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
# The decoder is an encoder
|
| 1168 |
+
self.decoder = FastSpeech2ConformerEncoder(config, config.decoder_config, use_encoder_input_layer=False)
|
| 1169 |
+
|
| 1170 |
+
self.speech_decoder_postnet = FastSpeech2ConformerSpeechDecoderPostnet(config)
|
| 1171 |
+
|
| 1172 |
+
self.criterion = FastSpeech2ConformerLoss(config)
|
| 1173 |
+
|
| 1174 |
+
self.post_init()
|
| 1175 |
+
|
| 1176 |
+
@replace_return_docstrings(output_type=FastSpeech2ConformerModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 1177 |
+
def forward(
|
| 1178 |
+
self,
|
| 1179 |
+
input_ids: torch.LongTensor,
|
| 1180 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1181 |
+
spectrogram_labels: Optional[torch.FloatTensor] = None,
|
| 1182 |
+
duration_labels: Optional[torch.LongTensor] = None,
|
| 1183 |
+
pitch_labels: Optional[torch.FloatTensor] = None,
|
| 1184 |
+
energy_labels: Optional[torch.FloatTensor] = None,
|
| 1185 |
+
speaker_ids: Optional[torch.LongTensor] = None,
|
| 1186 |
+
lang_ids: Optional[torch.LongTensor] = None,
|
| 1187 |
+
speaker_embedding: Optional[torch.FloatTensor] = None,
|
| 1188 |
+
return_dict: Optional[bool] = None,
|
| 1189 |
+
output_attentions: Optional[bool] = None,
|
| 1190 |
+
output_hidden_states: Optional[bool] = None,
|
| 1191 |
+
) -> Union[Tuple, FastSpeech2ConformerModelOutput]:
|
| 1192 |
+
"""
|
| 1193 |
+
Args:
|
| 1194 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1195 |
+
Input sequence of text vectors.
|
| 1196 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to `None`):
|
| 1197 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
|
| 1198 |
+
`[0, 1]`: 0 for tokens that are **masked**, 1 for tokens that are **not masked**.
|
| 1199 |
+
spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
|
| 1200 |
+
Batch of padded target features.
|
| 1201 |
+
duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
|
| 1202 |
+
Batch of padded durations.
|
| 1203 |
+
pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
|
| 1204 |
+
Batch of padded token-averaged pitch.
|
| 1205 |
+
energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
|
| 1206 |
+
Batch of padded token-averaged energy.
|
| 1207 |
+
speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
|
| 1208 |
+
Speaker ids used to condition features of speech output by the model.
|
| 1209 |
+
lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
|
| 1210 |
+
Language ids used to condition features of speech output by the model.
|
| 1211 |
+
speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
|
| 1212 |
+
Embedding containing conditioning signals for the features of the speech.
|
| 1213 |
+
return_dict (`bool`, *optional*, defaults to `None`):
|
| 1214 |
+
Whether or not to return a [`FastSpeech2ConformerModelOutput`] instead of a plain tuple.
|
| 1215 |
+
output_attentions (`bool`, *optional*, defaults to `None`):
|
| 1216 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1217 |
+
returned tensors for more detail.
|
| 1218 |
+
output_hidden_states (`bool`, *optional*, defaults to `None`):
|
| 1219 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 1220 |
+
for more detail.
|
| 1221 |
+
|
| 1222 |
+
Returns:
|
| 1223 |
+
|
| 1224 |
+
Example:
|
| 1225 |
+
|
| 1226 |
+
```python
|
| 1227 |
+
>>> from transformers import (
|
| 1228 |
+
... FastSpeech2ConformerTokenizer,
|
| 1229 |
+
... FastSpeech2ConformerModel,
|
| 1230 |
+
... FastSpeech2ConformerHifiGan,
|
| 1231 |
+
... )
|
| 1232 |
+
|
| 1233 |
+
>>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
|
| 1234 |
+
>>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
|
| 1235 |
+
>>> input_ids = inputs["input_ids"]
|
| 1236 |
+
|
| 1237 |
+
>>> model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer")
|
| 1238 |
+
>>> output_dict = model(input_ids, return_dict=True)
|
| 1239 |
+
>>> spectrogram = output_dict["spectrogram"]
|
| 1240 |
+
|
| 1241 |
+
>>> vocoder = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan")
|
| 1242 |
+
>>> waveform = vocoder(spectrogram)
|
| 1243 |
+
>>> print(waveform.shape)
|
| 1244 |
+
torch.Size([1, 49664])
|
| 1245 |
+
```
|
| 1246 |
+
"""
|
| 1247 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1248 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1249 |
+
output_hidden_states = (
|
| 1250 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
if attention_mask is None:
|
| 1254 |
+
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
| 1255 |
+
|
| 1256 |
+
has_missing_labels = (
|
| 1257 |
+
spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
|
| 1258 |
+
)
|
| 1259 |
+
if self.training and has_missing_labels:
|
| 1260 |
+
raise ValueError("All labels must be provided to run in training mode.")
|
| 1261 |
+
|
| 1262 |
+
# forward encoder
|
| 1263 |
+
text_masks = attention_mask.unsqueeze(-2)
|
| 1264 |
+
|
| 1265 |
+
encoder_outputs = self.encoder(
|
| 1266 |
+
input_ids,
|
| 1267 |
+
text_masks,
|
| 1268 |
+
output_hidden_states=output_hidden_states,
|
| 1269 |
+
output_attentions=output_attentions,
|
| 1270 |
+
return_dict=return_dict,
|
| 1271 |
+
)
|
| 1272 |
+
hidden_states = encoder_outputs[0]
|
| 1273 |
+
|
| 1274 |
+
# Integrate with language id, speaker id, and speaker embedding
|
| 1275 |
+
if self.multispeaker_model and speaker_ids is not None:
|
| 1276 |
+
speaker_id_embeddings = self.speaker_id_embedding(speaker_ids.view(-1))
|
| 1277 |
+
hidden_states = hidden_states + speaker_id_embeddings.unsqueeze(1)
|
| 1278 |
+
|
| 1279 |
+
if self.multilingual_model and lang_ids is not None:
|
| 1280 |
+
language_id_embbedings = self.language_id_embedding(lang_ids.view(-1))
|
| 1281 |
+
hidden_states = hidden_states + language_id_embbedings.unsqueeze(1)
|
| 1282 |
+
|
| 1283 |
+
if self.speaker_embed_dim is not None and speaker_embedding is not None:
|
| 1284 |
+
embeddings_expanded = (
|
| 1285 |
+
nn.functional.normalize(speaker_embedding).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
|
| 1286 |
+
)
|
| 1287 |
+
hidden_states = self.projection(torch.cat([hidden_states, embeddings_expanded], dim=-1))
|
| 1288 |
+
|
| 1289 |
+
# forward duration predictor and variance predictors
|
| 1290 |
+
duration_mask = ~attention_mask.bool()
|
| 1291 |
+
|
| 1292 |
+
if self.stop_gradient_from_pitch_predictor:
|
| 1293 |
+
pitch_predictions = self.pitch_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
|
| 1294 |
+
else:
|
| 1295 |
+
pitch_predictions = self.pitch_predictor(hidden_states, duration_mask.unsqueeze(-1))
|
| 1296 |
+
|
| 1297 |
+
if self.stop_gradient_from_energy_predictor:
|
| 1298 |
+
energy_predictions = self.energy_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
|
| 1299 |
+
else:
|
| 1300 |
+
energy_predictions = self.energy_predictor(hidden_states, duration_mask.unsqueeze(-1))
|
| 1301 |
+
|
| 1302 |
+
duration_predictions = self.duration_predictor(hidden_states)
|
| 1303 |
+
duration_predictions = duration_predictions.masked_fill(duration_mask, 0.0)
|
| 1304 |
+
|
| 1305 |
+
if not self.training:
|
| 1306 |
+
# use prediction in inference
|
| 1307 |
+
embedded_pitch_curve = self.pitch_embed(pitch_predictions)
|
| 1308 |
+
embedded_energy_curve = self.energy_embed(energy_predictions)
|
| 1309 |
+
hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
|
| 1310 |
+
hidden_states = length_regulator(hidden_states, duration_predictions, self.config.speaking_speed)
|
| 1311 |
+
else:
|
| 1312 |
+
# use groundtruth in training
|
| 1313 |
+
embedded_pitch_curve = self.pitch_embed(pitch_labels)
|
| 1314 |
+
embedded_energy_curve = self.energy_embed(energy_labels)
|
| 1315 |
+
hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
|
| 1316 |
+
hidden_states = length_regulator(hidden_states, duration_labels)
|
| 1317 |
+
|
| 1318 |
+
# forward decoder
|
| 1319 |
+
if not self.training:
|
| 1320 |
+
hidden_mask = None
|
| 1321 |
+
else:
|
| 1322 |
+
spectrogram_mask = (spectrogram_labels != -100).any(dim=-1)
|
| 1323 |
+
spectrogram_mask = spectrogram_mask.int()
|
| 1324 |
+
if self.reduction_factor > 1:
|
| 1325 |
+
length_dim = spectrogram_mask.shape[1] - spectrogram_mask.shape[1] % self.reduction_factor
|
| 1326 |
+
spectrogram_mask = spectrogram_mask[:, :, :length_dim]
|
| 1327 |
+
hidden_mask = spectrogram_mask.unsqueeze(-2)
|
| 1328 |
+
|
| 1329 |
+
decoder_outputs = self.decoder(
|
| 1330 |
+
hidden_states,
|
| 1331 |
+
hidden_mask,
|
| 1332 |
+
output_hidden_states=output_hidden_states,
|
| 1333 |
+
output_attentions=output_attentions,
|
| 1334 |
+
return_dict=return_dict,
|
| 1335 |
+
)
|
| 1336 |
+
|
| 1337 |
+
outputs_before_postnet, outputs_after_postnet = self.speech_decoder_postnet(decoder_outputs[0])
|
| 1338 |
+
|
| 1339 |
+
loss = None
|
| 1340 |
+
if self.training:
|
| 1341 |
+
# calculate loss
|
| 1342 |
+
loss_duration_mask = ~duration_mask
|
| 1343 |
+
loss_spectrogram_mask = spectrogram_mask.unsqueeze(-1).bool()
|
| 1344 |
+
loss = self.criterion(
|
| 1345 |
+
outputs_after_postnet=outputs_after_postnet,
|
| 1346 |
+
outputs_before_postnet=outputs_before_postnet,
|
| 1347 |
+
duration_outputs=duration_predictions,
|
| 1348 |
+
pitch_outputs=pitch_predictions,
|
| 1349 |
+
energy_outputs=energy_predictions,
|
| 1350 |
+
spectrogram_labels=spectrogram_labels,
|
| 1351 |
+
duration_labels=duration_labels,
|
| 1352 |
+
pitch_labels=pitch_labels,
|
| 1353 |
+
energy_labels=energy_labels,
|
| 1354 |
+
duration_mask=loss_duration_mask,
|
| 1355 |
+
spectrogram_mask=loss_spectrogram_mask,
|
| 1356 |
+
)
|
| 1357 |
+
|
| 1358 |
+
if not return_dict:
|
| 1359 |
+
postnet_outputs = (outputs_after_postnet,)
|
| 1360 |
+
audio_feature_predictions = (
|
| 1361 |
+
duration_predictions,
|
| 1362 |
+
pitch_predictions,
|
| 1363 |
+
energy_predictions,
|
| 1364 |
+
)
|
| 1365 |
+
outputs = postnet_outputs + encoder_outputs + decoder_outputs[1:] + audio_feature_predictions
|
| 1366 |
+
return ((loss,) + outputs) if loss is not None else outputs
|
| 1367 |
+
|
| 1368 |
+
return FastSpeech2ConformerModelOutput(
|
| 1369 |
+
loss=loss,
|
| 1370 |
+
spectrogram=outputs_after_postnet,
|
| 1371 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
| 1372 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 1373 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 1374 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 1375 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 1376 |
+
duration_outputs=duration_predictions,
|
| 1377 |
+
pitch_outputs=pitch_predictions,
|
| 1378 |
+
energy_outputs=energy_predictions,
|
| 1379 |
+
)
|
| 1380 |
+
|
| 1381 |
+
|
# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
class HifiGanResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation[i],
                    padding=self.get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=self.get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )

    def get_padding(self, kernel_size, dilation=1):
        return (kernel_size * dilation - dilation) // 2

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self.convs1:
            weight_norm(layer)
        for layer in self.convs2:
            weight_norm(layer)

    def remove_weight_norm(self):
        for layer in self.convs1:
            nn.utils.remove_weight_norm(layer)
        for layer in self.convs2:
            nn.utils.remove_weight_norm(layer)

    def forward(self, hidden_states):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            residual = hidden_states
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv1(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv2(hidden_states)
            hidden_states = hidden_states + residual
        return hidden_states

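# `get_padding(kernel_size, dilation) = (kernel_size * dilation - dilation) // 2`
# is the usual 'SAME' padding for an odd kernel with dilation: the effective
# receptive field is dilation * (kernel_size - 1) + 1, so (field - 1) / 2 samples
# are padded on each side and the sequence length is unchanged. A quick,
# self-contained check of that invariant with hypothetical sizes (assumes only torch):
import torch
from torch import nn


def _same_padding_check(kernel_size=3, dilation=5, time=32, channels=4):
    padding = (kernel_size * dilation - dilation) // 2
    conv = nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=dilation, padding=padding)
    return conv(torch.randn(1, channels, time)).shape[-1] == time  # True


# e.g. _same_padding_check() is True for any odd kernel_size
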
| 1445 |
+
@add_start_docstrings(
|
| 1446 |
+
"""HiFi-GAN vocoder.""",
|
| 1447 |
+
HIFIGAN_START_DOCSTRING,
|
| 1448 |
+
)
|
| 1449 |
+
# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5HifiGan with SpeechT5->FastSpeech2Conformer
|
| 1450 |
+
class FastSpeech2ConformerHifiGan(PreTrainedModel):
|
| 1451 |
+
config_class = FastSpeech2ConformerHifiGanConfig
|
| 1452 |
+
main_input_name = "spectrogram"
|
| 1453 |
+
|
| 1454 |
+
def __init__(self, config: FastSpeech2ConformerHifiGanConfig):
|
| 1455 |
+
super().__init__(config)
|
| 1456 |
+
self.num_kernels = len(config.resblock_kernel_sizes)
|
| 1457 |
+
self.num_upsamples = len(config.upsample_rates)
|
| 1458 |
+
self.conv_pre = nn.Conv1d(
|
| 1459 |
+
config.model_in_dim,
|
| 1460 |
+
config.upsample_initial_channel,
|
| 1461 |
+
kernel_size=7,
|
| 1462 |
+
stride=1,
|
| 1463 |
+
padding=3,
|
| 1464 |
+
)
|
| 1465 |
+
|
| 1466 |
+
self.upsampler = nn.ModuleList()
|
| 1467 |
+
for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
|
| 1468 |
+
self.upsampler.append(
|
| 1469 |
+
nn.ConvTranspose1d(
|
| 1470 |
+
config.upsample_initial_channel // (2**i),
|
| 1471 |
+
config.upsample_initial_channel // (2 ** (i + 1)),
|
| 1472 |
+
kernel_size=kernel_size,
|
| 1473 |
+
stride=upsample_rate,
|
| 1474 |
+
padding=(kernel_size - upsample_rate) // 2,
|
| 1475 |
+
)
|
| 1476 |
+
)
|
| 1477 |
+
|
| 1478 |
+
self.resblocks = nn.ModuleList()
|
| 1479 |
+
for i in range(len(self.upsampler)):
|
| 1480 |
+
channels = config.upsample_initial_channel // (2 ** (i + 1))
|
| 1481 |
+
for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
|
| 1482 |
+
self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
|
| 1483 |
+
|
| 1484 |
+
self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
|
| 1485 |
+
|
| 1486 |
+
self.register_buffer("mean", torch.zeros(config.model_in_dim))
|
| 1487 |
+
self.register_buffer("scale", torch.ones(config.model_in_dim))
|
| 1488 |
+
|
| 1489 |
+
# Initialize weights and apply final processing
|
| 1490 |
+
self.post_init()
|
| 1491 |
+
|
| 1492 |
+
def _init_weights(self, module):
|
| 1493 |
+
"""Initialize the weights."""
|
| 1494 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
| 1495 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1496 |
+
if module.bias is not None:
|
| 1497 |
+
module.bias.data.zero_()
|
| 1498 |
+
|
| 1499 |
+
def apply_weight_norm(self):
|
| 1500 |
+
weight_norm = nn.utils.weight_norm
|
| 1501 |
+
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
| 1502 |
+
weight_norm = nn.utils.parametrizations.weight_norm
|
| 1503 |
+
|
| 1504 |
+
weight_norm(self.conv_pre)
|
| 1505 |
+
for layer in self.upsampler:
|
| 1506 |
+
weight_norm(layer)
|
| 1507 |
+
for layer in self.resblocks:
|
| 1508 |
+
layer.apply_weight_norm()
|
| 1509 |
+
weight_norm(self.conv_post)
|
| 1510 |
+
|
| 1511 |
+
def remove_weight_norm(self):
|
| 1512 |
+
nn.utils.remove_weight_norm(self.conv_pre)
|
| 1513 |
+
for layer in self.upsampler:
|
| 1514 |
+
nn.utils.remove_weight_norm(layer)
|
| 1515 |
+
for layer in self.resblocks:
|
| 1516 |
+
layer.remove_weight_norm()
|
| 1517 |
+
nn.utils.remove_weight_norm(self.conv_post)
|
| 1518 |
+
|
| 1519 |
+
def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
|
| 1520 |
+
r"""
|
| 1521 |
+
Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
|
| 1522 |
+
of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
|
| 1523 |
+
waveform.
|
| 1524 |
+
|
| 1525 |
+
Args:
|
| 1526 |
+
spectrogram (`torch.FloatTensor`):
|
| 1527 |
+
Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
|
| 1528 |
+
config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
|
| 1529 |
+
|
| 1530 |
+
Returns:
|
| 1531 |
+
`torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
|
| 1532 |
+
shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
|
| 1533 |
+
"""
|
| 1534 |
+
if self.config.normalize_before:
|
| 1535 |
+
spectrogram = (spectrogram - self.mean) / self.scale
|
| 1536 |
+
|
| 1537 |
+
is_batched = spectrogram.dim() == 3
|
| 1538 |
+
if not is_batched:
|
| 1539 |
+
spectrogram = spectrogram.unsqueeze(0)
|
| 1540 |
+
|
| 1541 |
+
hidden_states = spectrogram.transpose(2, 1)
|
| 1542 |
+
|
| 1543 |
+
hidden_states = self.conv_pre(hidden_states)
|
| 1544 |
+
for i in range(self.num_upsamples):
|
| 1545 |
+
hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
|
| 1546 |
+
hidden_states = self.upsampler[i](hidden_states)
|
| 1547 |
+
|
| 1548 |
+
res_state = self.resblocks[i * self.num_kernels](hidden_states)
|
| 1549 |
+
for j in range(1, self.num_kernels):
|
| 1550 |
+
res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
|
| 1551 |
+
hidden_states = res_state / self.num_kernels
|
| 1552 |
+
|
| 1553 |
+
hidden_states = nn.functional.leaky_relu(hidden_states)
|
| 1554 |
+
hidden_states = self.conv_post(hidden_states)
|
| 1555 |
+
hidden_states = torch.tanh(hidden_states)
|
| 1556 |
+
|
| 1557 |
+
if not is_batched:
|
| 1558 |
+
# remove batch dim and collapse tensor to 1-d audio waveform
|
| 1559 |
+
waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
|
| 1560 |
+
else:
|
| 1561 |
+
# remove seq-len dim since this collapses to 1
|
| 1562 |
+
waveform = hidden_states.squeeze(1)
|
| 1563 |
+
|
| 1564 |
+
return waveform
|
| 1565 |
+
|
| 1566 |
+
|
| 1567 |
+
@add_start_docstrings(
|
| 1568 |
+
"The FastSpeech2ConformerModel with a FastSpeech2ConformerHifiGan vocoder head that performs text-to-speech (waveform).",
|
| 1569 |
+
FASTSPEECH2_CONFORMER_WITH_HIFIGAN_START_DOCSTRING,
|
| 1570 |
+
)
|
| 1571 |
+
class FastSpeech2ConformerWithHifiGan(PreTrainedModel):
|
| 1572 |
+
config_class = FastSpeech2ConformerWithHifiGanConfig
|
| 1573 |
+
|
| 1574 |
+
def __init__(self, config: FastSpeech2ConformerWithHifiGanConfig):
|
| 1575 |
+
super().__init__(config)
|
| 1576 |
+
|
| 1577 |
+
self.model = FastSpeech2ConformerModel(config.model_config)
|
| 1578 |
+
self.vocoder = FastSpeech2ConformerHifiGan(config.vocoder_config)
|
| 1579 |
+
|
| 1580 |
+
self.config = config
|
| 1581 |
+
|
| 1582 |
+
@replace_return_docstrings(
|
| 1583 |
+
output_type=FastSpeech2ConformerWithHifiGanOutput, config_class=FastSpeech2ConformerWithHifiGanConfig
|
| 1584 |
+
)
|
| 1585 |
+
def forward(
|
| 1586 |
+
self,
|
| 1587 |
+
input_ids: torch.LongTensor,
|
| 1588 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1589 |
+
spectrogram_labels: Optional[torch.FloatTensor] = None,
|
| 1590 |
+
duration_labels: Optional[torch.LongTensor] = None,
|
| 1591 |
+
pitch_labels: Optional[torch.FloatTensor] = None,
|
| 1592 |
+
energy_labels: Optional[torch.FloatTensor] = None,
|
| 1593 |
+
speaker_ids: Optional[torch.LongTensor] = None,
|
| 1594 |
+
lang_ids: Optional[torch.LongTensor] = None,
|
| 1595 |
+
speaker_embedding: Optional[torch.FloatTensor] = None,
|
| 1596 |
+
return_dict: Optional[bool] = None,
|
| 1597 |
+
output_attentions: Optional[bool] = None,
|
| 1598 |
+
output_hidden_states: Optional[bool] = None,
|
| 1599 |
+
) -> Union[Tuple, FastSpeech2ConformerModelOutput]:
|
| 1600 |
+
"""
|
| 1601 |
+
Args:
|
| 1602 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1603 |
+
Input sequence of text vectors.
|
| 1604 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to `None`):
|
| 1605 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
|
| 1606 |
+
`[0, 1]`: 0 for tokens that are **masked**, 1 for tokens that are **not masked**.
|
| 1607 |
+
spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
|
| 1608 |
+
Batch of padded target features.
|
| 1609 |
+
duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
|
| 1610 |
+
Batch of padded durations.
|
| 1611 |
+
pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
|
| 1612 |
+
Batch of padded token-averaged pitch.
|
| 1613 |
+
energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
|
| 1614 |
+
Batch of padded token-averaged energy.
|
| 1615 |
+
speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
|
| 1616 |
+
Speaker ids used to condition features of speech output by the model.
|
| 1617 |
+
lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
|
| 1618 |
+
Language ids used to condition features of speech output by the model.
|
| 1619 |
+
speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
|
| 1620 |
+
Embedding containing conditioning signals for the features of the speech.
|
| 1621 |
+
return_dict (`bool`, *optional*, defaults to `None`):
|
| 1622 |
+
Whether or not to return a [`FastSpeech2ConformerModelOutput`] instead of a plain tuple.
|
| 1623 |
+
output_attentions (`bool`, *optional*, defaults to `None`):
|
| 1624 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1625 |
+
returned tensors for more detail.
|
| 1626 |
+
output_hidden_states (`bool`, *optional*, defaults to `None`):
|
| 1627 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 1628 |
+
for more detail.
|
| 1629 |
+
|
| 1630 |
+
Returns:
|
| 1631 |
+
|
| 1632 |
+
Example:
|
| 1633 |
+
|
| 1634 |
+
```python
|
| 1635 |
+
>>> from transformers import (
|
| 1636 |
+
... FastSpeech2ConformerTokenizer,
|
| 1637 |
+
... FastSpeech2ConformerWithHifiGan,
|
| 1638 |
+
... )
|
| 1639 |
+
|
| 1640 |
+
>>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
|
| 1641 |
+
>>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
|
| 1642 |
+
>>> input_ids = inputs["input_ids"]
|
| 1643 |
+
|
| 1644 |
+
>>> model = FastSpeech2ConformerWithHifiGan.from_pretrained("espnet/fastspeech2_conformer_with_hifigan")
|
| 1645 |
+
>>> output_dict = model(input_ids, return_dict=True)
|
| 1646 |
+
>>> waveform = output_dict["waveform"]
|
| 1647 |
+
>>> print(waveform.shape)
|
| 1648 |
+
torch.Size([1, 49664])
|
| 1649 |
+
```
|
| 1650 |
+
"""
|
| 1651 |
+
return_dict = return_dict if return_dict is not None else self.config.model_config.use_return_dict
|
| 1652 |
+
output_attentions = (
|
| 1653 |
+
output_attentions if output_attentions is not None else self.config.model_config.output_attentions
|
| 1654 |
+
)
|
| 1655 |
+
output_hidden_states = (
|
| 1656 |
+
output_hidden_states if output_hidden_states is not None else self.config.model_config.output_hidden_states
|
| 1657 |
+
)
|
| 1658 |
+
|
| 1659 |
+
model_outputs = self.model(
|
| 1660 |
+
input_ids,
|
| 1661 |
+
attention_mask,
|
| 1662 |
+
spectrogram_labels=spectrogram_labels,
|
| 1663 |
+
duration_labels=duration_labels,
|
| 1664 |
+
pitch_labels=pitch_labels,
|
| 1665 |
+
energy_labels=energy_labels,
|
| 1666 |
+
speaker_ids=speaker_ids,
|
| 1667 |
+
lang_ids=lang_ids,
|
| 1668 |
+
speaker_embedding=speaker_embedding,
|
| 1669 |
+
return_dict=return_dict,
|
| 1670 |
+
output_attentions=output_attentions,
|
| 1671 |
+
output_hidden_states=output_hidden_states,
|
| 1672 |
+
)
|
| 1673 |
+
|
| 1674 |
+
if not return_dict:
|
| 1675 |
+
has_missing_labels = (
|
| 1676 |
+
spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
|
| 1677 |
+
)
|
| 1678 |
+
if has_missing_labels:
|
| 1679 |
+
spectrogram = model_outputs[0]
|
| 1680 |
+
else:
|
| 1681 |
+
spectrogram = model_outputs[1]
|
| 1682 |
+
else:
|
| 1683 |
+
spectrogram = model_outputs["spectrogram"]
|
| 1684 |
+
waveform = self.vocoder(spectrogram)
|
| 1685 |
+
|
| 1686 |
+
if not return_dict:
|
| 1687 |
+
return model_outputs + (waveform,)
|
| 1688 |
+
|
| 1689 |
+
return FastSpeech2ConformerWithHifiGanOutput(waveform=waveform, **model_outputs)
|
| 1690 |
+
|
| 1691 |
+
|
__all__ = [
    "FastSpeech2ConformerWithHifiGan",
    "FastSpeech2ConformerHifiGan",
    "FastSpeech2ConformerModel",
    "FastSpeech2ConformerPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py
ADDED
@@ -0,0 +1,188 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for FastSpeech2Conformer."""

import json
import os
from typing import Optional, Tuple

import regex

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging, requires_backends


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}


class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
    """
    Construct a FastSpeech2Conformer tokenizer.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        bos_token (`str`, *optional*, defaults to `"<sos/eos>"`):
            The begin of sequence token. Note that for FastSpeech2, it is the same as the `eos_token`.
        eos_token (`str`, *optional*, defaults to `"<sos/eos>"`):
            The end of sequence token. Note that for FastSpeech2, it is the same as the `bos_token`.
        pad_token (`str`, *optional*, defaults to `"<blank>"`):
            The token used for padding, for example when batching sequences of different lengths.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        should_strip_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to strip the spaces from the list of tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token="<sos/eos>",
        eos_token="<sos/eos>",
        pad_token="<blank>",
        unk_token="<unk>",
        should_strip_spaces=False,
        **kwargs,
    ):
        requires_backends(self, "g2p_en")

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        import g2p_en

        self.g2p = g2p_en.G2p()

        self.decoder = {v: k for k, v in self.encoder.items()}

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            should_strip_spaces=should_strip_spaces,
            **kwargs,
        )

        self.should_strip_spaces = should_strip_spaces

    @property
    def vocab_size(self):
        return len(self.decoder)

    def get_vocab(self):
        "Returns vocab as a dict"
        return dict(self.encoder, **self.added_tokens_encoder)

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        # expand symbols
        text = regex.sub(";", ",", text)
        text = regex.sub(":", ",", text)
        text = regex.sub("-", " ", text)
        text = regex.sub("&", "and", text)

        # strip unnecessary symbols
        text = regex.sub(r"[\(\)\[\]\<\>\"]+", "", text)

        # strip whitespaces
        text = regex.sub(r"\s+", " ", text)

        text = text.upper()

        return text, kwargs

    def _tokenize(self, text):
        """Returns a tokenized string."""
        # phonemize
        tokens = self.g2p(text)

        if self.should_strip_spaces:
            tokens = list(filter(lambda s: s != " ", tokens))

        tokens.append(self.eos_token)

        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    # Override since phonemes cannot be converted back to strings
    def decode(self, token_ids, **kwargs):
        logger.warning(
            "Phonemes cannot be reliably converted to a string due to the one-many mapping, converting to tokens instead."
        )
        return self.convert_ids_to_tokens(token_ids)

    # Override since phonemes cannot be converted back to strings
    def convert_tokens_to_string(self, tokens, **kwargs):
        logger.warning(
            "Phonemes cannot be reliably converted to a string due to the one-many mapping, returning the tokens."
        )
        return tokens

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.get_vocab(), ensure_ascii=False))

        return (vocab_file,)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["g2p"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        try:
            import g2p_en

            self.g2p = g2p_en.G2p()
        except ImportError:
            raise ImportError(
                "You need to install g2p-en to use FastSpeech2ConformerTokenizer. "
                "See https://pypi.org/project/g2p-en/ for installation."
            )


__all__ = ["FastSpeech2ConformerTokenizer"]
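Because the tokenizer above phonemizes with `g2p_en`, `decode` deliberately returns phoneme tokens rather than reconstructing text. A minimal sketch of that round trip; the checkpoint name is an assumption, and the optional `g2p_en` backend must be installed.

```python
from transformers import FastSpeech2ConformerTokenizer

# Assumed repo; any checkpoint shipping the vocab.json this tokenizer expects works.
tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")

ids = tokenizer("Hello world")["input_ids"]
print(ids)                    # phoneme ids, ending with the <sos/eos> token id
print(tokenizer.decode(ids))  # phoneme tokens, not text (one-to-many mapping, see warning above)
```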
docs/transformers/build/lib/transformers/models/flaubert/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_flaubert import *
    from .modeling_flaubert import *
    from .modeling_tf_flaubert import *
    from .tokenization_flaubert import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
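The `_LazyModule` indirection above keeps the package import cheap: the configuration, modeling, and tokenization submodules are only imported when one of their attributes is first accessed. A short illustrative sketch of that behavior, assuming a transformers install that contains this package.

```python
# Importing the package itself does not pull in torch-backed modules yet.
import transformers.models.flaubert as flaubert

config_cls = flaubert.FlaubertConfig  # first access triggers import of configuration_flaubert
print(config_cls.model_type)          # "flaubert"
```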
docs/transformers/build/lib/transformers/models/flaubert/configuration_flaubert.py
ADDED
@@ -0,0 +1,235 @@
# coding=utf-8
# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flaubert configuration"""

from collections import OrderedDict
from typing import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class FlaubertConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`FlaubertModel`] or a [`TFFlaubertModel`]. It is
    used to instantiate a FlauBERT model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the FlauBERT
    [flaubert/flaubert_base_uncased](https://huggingface.co/flaubert/flaubert_base_uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        pre_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply the layer normalization before or after the feed forward layer following the attention in
            each layer (Vaswani et al., Tensor2Tensor for Neural Machine Translation. 2018)
        layerdrop (`float`, *optional*, defaults to 0.0):
            Probability to drop layers during training (Fan et al., Reducing Transformer Depth on Demand with
            Structured Dropout. ICLR 2020)
        vocab_size (`int`, *optional*, defaults to 30145):
            Vocabulary size of the FlauBERT model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`FlaubertModel`] or [`TFFlaubertModel`].
        emb_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the encoder layers and the pooler layer.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the attention mechanism
        gelu_activation (`bool`, *optional*, defaults to `True`):
            Whether or not to use a *gelu* activation instead of *relu*.
        sinusoidal_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to use sinusoidal positional embeddings instead of absolute positional embeddings.
        causal (`bool`, *optional*, defaults to `False`):
            Whether or not the model should behave in a causal manner. Causal models use a triangular attention mask in
            order to only attend to the left-side context instead if a bidirectional context.
        asm (`bool`, *optional*, defaults to `False`):
            Whether or not to use an adaptive log softmax projection layer instead of a linear layer for the prediction
            layer.
        n_langs (`int`, *optional*, defaults to 1):
            The number of languages the model handles. Set to 1 for monolingual models.
        use_lang_emb (`bool`, *optional*, defaults to `True`)
            Whether to use language embeddings. Some models use additional language embeddings, see [the multilingual
            models page](http://huggingface.co/transformers/multilingual.html#xlm-language-embeddings) for information
            on how to use them.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        embed_init_std (`float`, *optional*, defaults to 2048^-0.5):
            The standard deviation of the truncated_normal_initializer for initializing the embedding matrices.
        init_std (`int`, *optional*, defaults to 50257):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices except the
            embedding matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        bos_index (`int`, *optional*, defaults to 0):
            The index of the beginning of sentence token in the vocabulary.
        eos_index (`int`, *optional*, defaults to 1):
            The index of the end of sentence token in the vocabulary.
        pad_index (`int`, *optional*, defaults to 2):
            The index of the padding token in the vocabulary.
        unk_index (`int`, *optional*, defaults to 3):
            The index of the unknown token in the vocabulary.
        mask_index (`int`, *optional*, defaults to 5):
            The index of the masking token in the vocabulary.
        is_encoder(`bool`, *optional*, defaults to `True`):
            Whether or not the initialized model should be a transformer encoder or decoder as seen in Vaswani et al.
        summary_type (`string`, *optional*, defaults to "first"):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            Has to be one of the following options:

            - `"last"`: Take the last token hidden state (like XLNet).
            - `"first"`: Take the first token hidden state (like BERT).
            - `"mean"`: Take the mean of all tokens hidden states.
            - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
            - `"attn"`: Not implemented now, use multi-head attention.
        summary_use_proj (`bool`, *optional*, defaults to `True`):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            Whether or not to add a projection after the vector extraction.
        summary_activation (`str`, *optional*):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
        summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
            Used in the sequence classification and multiple choice models.

            Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
        summary_first_dropout (`float`, *optional*, defaults to 0.1):
            Used in the sequence classification and multiple choice models.

            The dropout ratio to be used after the projection and activation.
        start_n_top (`int`, *optional*, defaults to 5):
            Used in the SQuAD evaluation script.
        end_n_top (`int`, *optional*, defaults to 5):
            Used in the SQuAD evaluation script.
        mask_token_id (`int`, *optional*, defaults to 0):
            Model agnostic parameter to identify masked tokens when generating text in an MLM context.
        lang_id (`int`, *optional*, defaults to 1):
            The ID of the language used by the model. This parameter is used when generating text in a given language.
    """

    model_type = "flaubert"
    attribute_map = {
        "hidden_size": "emb_dim",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
        "n_words": "vocab_size",  # For backward compatibility
    }

    def __init__(
        self,
        pre_norm=False,
        layerdrop=0.0,
        vocab_size=30145,
        emb_dim=2048,
        n_layers=12,
        n_heads=16,
        dropout=0.1,
        attention_dropout=0.1,
        gelu_activation=True,
        sinusoidal_embeddings=False,
        causal=False,
        asm=False,
        n_langs=1,
        use_lang_emb=True,
        max_position_embeddings=512,
        embed_init_std=2048**-0.5,
        layer_norm_eps=1e-12,
        init_std=0.02,
        bos_index=0,
        eos_index=1,
        pad_index=2,
        unk_index=3,
        mask_index=5,
        is_encoder=True,
        summary_type="first",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        start_n_top=5,
        end_n_top=5,
        mask_token_id=0,
        lang_id=0,
        pad_token_id=2,
        bos_token_id=0,
        **kwargs,
    ):
        """Constructs FlaubertConfig."""
        self.pre_norm = pre_norm
        self.layerdrop = layerdrop
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.gelu_activation = gelu_activation
        self.sinusoidal_embeddings = sinusoidal_embeddings
        self.causal = causal
        self.asm = asm
        self.n_langs = n_langs
        self.use_lang_emb = use_lang_emb
        self.layer_norm_eps = layer_norm_eps
        self.bos_index = bos_index
        self.eos_index = eos_index
        self.pad_index = pad_index
        self.unk_index = unk_index
        self.mask_index = mask_index
        self.is_encoder = is_encoder
        self.max_position_embeddings = max_position_embeddings
        self.embed_init_std = embed_init_std
        self.init_std = init_std
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_proj_to_labels = summary_proj_to_labels
        self.summary_first_dropout = summary_first_dropout
        self.start_n_top = start_n_top
        self.end_n_top = end_n_top
        self.mask_token_id = mask_token_id
        self.lang_id = lang_id

        if "n_words" in kwargs:
            self.n_words = kwargs["n_words"]

        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)


class FlaubertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )


__all__ = ["FlaubertConfig", "FlaubertOnnxConfig"]
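The `attribute_map` in the class above aliases the XLM-style names (`emb_dim`, `n_layers`, `n_heads`) onto the generic attribute names used elsewhere in the library. A minimal sketch, using only what the class definition shows; the values are arbitrary and do not correspond to a released checkpoint.

```python
from transformers import FlaubertConfig

# Small illustrative configuration (arbitrary sizes, not a pretrained checkpoint).
config = FlaubertConfig(emb_dim=512, n_layers=6, n_heads=8)

# attribute_map exposes the generic names on top of the native ones:
print(config.hidden_size)          # 512 -> emb_dim
print(config.num_hidden_layers)    # 6   -> n_layers
print(config.num_attention_heads)  # 8   -> n_heads
```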
docs/transformers/build/lib/transformers/models/flaubert/modeling_flaubert.py
ADDED
@@ -0,0 +1,1739 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Flaubert model, based on XLM."""
|
| 16 |
+
|
| 17 |
+
import itertools
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Callable, Dict, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 26 |
+
|
| 27 |
+
from ...activations import gelu, get_activation
|
| 28 |
+
from ...generation import GenerationMixin
|
| 29 |
+
from ...modeling_outputs import (
|
| 30 |
+
BaseModelOutput,
|
| 31 |
+
MaskedLMOutput,
|
| 32 |
+
MultipleChoiceModelOutput,
|
| 33 |
+
QuestionAnsweringModelOutput,
|
| 34 |
+
SequenceClassifierOutput,
|
| 35 |
+
TokenClassifierOutput,
|
| 36 |
+
)
|
| 37 |
+
from ...modeling_utils import PreTrainedModel
|
| 38 |
+
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 39 |
+
from ...utils import (
|
| 40 |
+
ModelOutput,
|
| 41 |
+
add_code_sample_docstrings,
|
| 42 |
+
add_start_docstrings,
|
| 43 |
+
add_start_docstrings_to_model_forward,
|
| 44 |
+
logging,
|
| 45 |
+
replace_return_docstrings,
|
| 46 |
+
)
|
| 47 |
+
from .configuration_flaubert import FlaubertConfig
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
_CHECKPOINT_FOR_DOC = "flaubert/flaubert_base_cased"
|
| 53 |
+
_CONFIG_FOR_DOC = "FlaubertConfig"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Copied from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings
|
| 57 |
+
def create_sinusoidal_embeddings(n_pos, dim, out):
|
| 58 |
+
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
|
| 59 |
+
out.requires_grad = False
|
| 60 |
+
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
|
| 61 |
+
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
|
| 62 |
+
out.detach_()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Copied from transformers.models.xlm.modeling_xlm.get_masks
|
| 66 |
+
def get_masks(slen, lengths, causal, padding_mask=None):
|
| 67 |
+
"""
|
| 68 |
+
Generate hidden states mask, and optionally an attention mask.
|
| 69 |
+
"""
|
| 70 |
+
alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
|
| 71 |
+
if padding_mask is not None:
|
| 72 |
+
mask = padding_mask
|
| 73 |
+
else:
|
| 74 |
+
assert lengths.max().item() <= slen
|
| 75 |
+
mask = alen < lengths[:, None]
|
| 76 |
+
|
| 77 |
+
# attention mask is the same as mask, or triangular inferior attention (causal)
|
| 78 |
+
bs = lengths.size(0)
|
| 79 |
+
if causal:
|
| 80 |
+
attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
|
| 81 |
+
else:
|
| 82 |
+
attn_mask = mask
|
| 83 |
+
|
| 84 |
+
# sanity check
|
| 85 |
+
assert mask.size() == (bs, slen)
|
| 86 |
+
assert causal is False or attn_mask.size() == (bs, slen, slen)
|
| 87 |
+
|
| 88 |
+
return mask, attn_mask
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Copied from transformers.models.xlm.modeling_xlm.MultiHeadAttention
|
| 92 |
+
class MultiHeadAttention(nn.Module):
|
| 93 |
+
NEW_ID = itertools.count()
|
| 94 |
+
|
| 95 |
+
def __init__(self, n_heads, dim, config):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.layer_id = next(MultiHeadAttention.NEW_ID)
|
| 98 |
+
self.dim = dim
|
| 99 |
+
self.n_heads = n_heads
|
| 100 |
+
self.dropout = config.attention_dropout
|
| 101 |
+
assert self.dim % self.n_heads == 0
|
| 102 |
+
|
| 103 |
+
self.q_lin = nn.Linear(dim, dim)
|
| 104 |
+
self.k_lin = nn.Linear(dim, dim)
|
| 105 |
+
self.v_lin = nn.Linear(dim, dim)
|
| 106 |
+
self.out_lin = nn.Linear(dim, dim)
|
| 107 |
+
self.pruned_heads = set()
|
| 108 |
+
|
| 109 |
+
def prune_heads(self, heads):
|
| 110 |
+
attention_head_size = self.dim // self.n_heads
|
| 111 |
+
if len(heads) == 0:
|
| 112 |
+
return
|
| 113 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
|
| 114 |
+
# Prune linear layers
|
| 115 |
+
self.q_lin = prune_linear_layer(self.q_lin, index)
|
| 116 |
+
self.k_lin = prune_linear_layer(self.k_lin, index)
|
| 117 |
+
self.v_lin = prune_linear_layer(self.v_lin, index)
|
| 118 |
+
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
|
| 119 |
+
# Update hyper params
|
| 120 |
+
self.n_heads = self.n_heads - len(heads)
|
| 121 |
+
self.dim = attention_head_size * self.n_heads
|
| 122 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 123 |
+
|
| 124 |
+
def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False):
|
| 125 |
+
"""
|
| 126 |
+
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
| 127 |
+
"""
|
| 128 |
+
# Input is (bs, qlen, dim)
|
| 129 |
+
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
| 130 |
+
bs, qlen, dim = input.size()
|
| 131 |
+
if kv is None:
|
| 132 |
+
klen = qlen if cache is None else cache["slen"] + qlen
|
| 133 |
+
else:
|
| 134 |
+
klen = kv.size(1)
|
| 135 |
+
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
|
| 136 |
+
n_heads = self.n_heads
|
| 137 |
+
dim_per_head = self.dim // n_heads
|
| 138 |
+
mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
|
| 139 |
+
|
| 140 |
+
def shape(x):
|
| 141 |
+
"""projection"""
|
| 142 |
+
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
|
| 143 |
+
|
| 144 |
+
def unshape(x):
|
| 145 |
+
"""compute context"""
|
| 146 |
+
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
|
| 147 |
+
|
| 148 |
+
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
| 149 |
+
if kv is None:
|
| 150 |
+
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
| 151 |
+
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
| 152 |
+
elif cache is None or self.layer_id not in cache:
|
| 153 |
+
k = v = kv
|
| 154 |
+
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
|
| 155 |
+
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
|
| 156 |
+
|
| 157 |
+
if cache is not None:
|
| 158 |
+
if self.layer_id in cache:
|
| 159 |
+
if kv is None:
|
| 160 |
+
k_, v_ = cache[self.layer_id]
|
| 161 |
+
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
|
| 162 |
+
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
|
| 163 |
+
else:
|
| 164 |
+
k, v = cache[self.layer_id]
|
| 165 |
+
cache[self.layer_id] = (k, v)
|
| 166 |
+
|
| 167 |
+
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
|
| 168 |
+
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
|
| 169 |
+
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
|
| 170 |
+
scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
|
| 171 |
+
|
| 172 |
+
weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
|
| 173 |
+
weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
|
| 174 |
+
|
| 175 |
+
# Mask heads if we want to
|
| 176 |
+
if head_mask is not None:
|
| 177 |
+
weights = weights * head_mask
|
| 178 |
+
|
| 179 |
+
context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
|
| 180 |
+
context = unshape(context) # (bs, qlen, dim)
|
| 181 |
+
|
| 182 |
+
outputs = (self.out_lin(context),)
|
| 183 |
+
if output_attentions:
|
| 184 |
+
outputs = outputs + (weights,)
|
| 185 |
+
return outputs
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Copied from transformers.models.xlm.modeling_xlm.TransformerFFN
|
| 189 |
+
class TransformerFFN(nn.Module):
|
| 190 |
+
def __init__(self, in_dim, dim_hidden, out_dim, config):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.dropout = config.dropout
|
| 193 |
+
self.lin1 = nn.Linear(in_dim, dim_hidden)
|
| 194 |
+
self.lin2 = nn.Linear(dim_hidden, out_dim)
|
| 195 |
+
self.act = gelu if config.gelu_activation else nn.functional.relu
|
| 196 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 197 |
+
self.seq_len_dim = 1
|
| 198 |
+
|
| 199 |
+
def forward(self, input):
|
| 200 |
+
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
|
| 201 |
+
|
| 202 |
+
def ff_chunk(self, input):
|
| 203 |
+
x = self.lin1(input)
|
| 204 |
+
x = self.act(x)
|
| 205 |
+
x = self.lin2(x)
|
| 206 |
+
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
FLAUBERT_START_DOCSTRING = r"""
|
| 211 |
+
|
| 212 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 213 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 214 |
+
etc.)
|
| 215 |
+
|
| 216 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 217 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 218 |
+
and behavior.
|
| 219 |
+
|
| 220 |
+
Parameters:
|
| 221 |
+
config ([`FlaubertConfig`]): Model configuration class with all the parameters of the model.
|
| 222 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 223 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
FLAUBERT_INPUTS_DOCSTRING = r"""
|
| 227 |
+
Args:
|
| 228 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 229 |
+
Indices of input sequence tokens in the vocabulary.
|
| 230 |
+
|
| 231 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 232 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 233 |
+
|
| 234 |
+
[What are input IDs?](../glossary#input-ids)
|
| 235 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 236 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 237 |
+
|
| 238 |
+
- 1 for tokens that are **not masked**,
|
| 239 |
+
- 0 for tokens that are **masked**.
|
| 240 |
+
|
| 241 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 242 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 243 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 244 |
+
1]`:
|
| 245 |
+
|
| 246 |
+
- 0 corresponds to a *sentence A* token,
|
| 247 |
+
- 1 corresponds to a *sentence B* token.
|
| 248 |
+
|
| 249 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 250 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 251 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 252 |
+
config.max_position_embeddings - 1]`.
|
| 253 |
+
|
| 254 |
+
[What are position IDs?](../glossary#position-ids)
|
| 255 |
+
lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 256 |
+
Length of each sentence that can be used to avoid performing attention on padding token indices. You can
|
| 257 |
+
also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in
|
| 258 |
+
`[0, ..., input_ids.size(-1)]`:
|
| 259 |
+
cache (`Dict[str, torch.FloatTensor]`, *optional*):
|
| 260 |
+
Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the
|
| 261 |
+
attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
|
| 262 |
+
decoding. The dictionary object will be modified in-place during the forward pass to add newly computed
|
| 263 |
+
hidden-states.
|
| 264 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 265 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 266 |
+
|
| 267 |
+
- 1 indicates the head is **not masked**,
|
| 268 |
+
- 0 indicates the head is **masked**.
|
| 269 |
+
|
| 270 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 271 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 272 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 273 |
+
model's internal embedding lookup matrix.
|
| 274 |
+
output_attentions (`bool`, *optional*):
|
| 275 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 276 |
+
tensors for more detail.
|
| 277 |
+
output_hidden_states (`bool`, *optional*):
|
| 278 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 279 |
+
more detail.
|
| 280 |
+
return_dict (`bool`, *optional*):
|
| 281 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@add_start_docstrings(
|
| 286 |
+
"The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 287 |
+
FLAUBERT_START_DOCSTRING,
|
| 288 |
+
)
|
| 289 |
+
# Copied from transformers.models.xlm.modeling_xlm.XLMPredLayer with XLM->Flaubert
|
| 290 |
+
class FlaubertPredLayer(nn.Module):
|
| 291 |
+
"""
|
| 292 |
+
Prediction layer (cross_entropy or adaptive_softmax).
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
def __init__(self, config):
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.asm = config.asm
|
| 298 |
+
self.n_words = config.n_words
|
| 299 |
+
self.pad_index = config.pad_index
|
| 300 |
+
dim = config.emb_dim
|
| 301 |
+
|
| 302 |
+
if config.asm is False:
|
| 303 |
+
self.proj = nn.Linear(dim, config.n_words, bias=True)
|
| 304 |
+
else:
|
| 305 |
+
self.proj = nn.AdaptiveLogSoftmaxWithLoss(
|
| 306 |
+
in_features=dim,
|
| 307 |
+
n_classes=config.n_words,
|
| 308 |
+
cutoffs=config.asm_cutoffs,
|
| 309 |
+
div_value=config.asm_div_value,
|
| 310 |
+
head_bias=True, # default is False
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def forward(self, x, y=None):
|
| 314 |
+
"""Compute the loss, and optionally the scores."""
|
| 315 |
+
outputs = ()
|
| 316 |
+
if self.asm is False:
|
| 317 |
+
scores = self.proj(x)
|
| 318 |
+
outputs = (scores,) + outputs
|
| 319 |
+
if y is not None:
|
| 320 |
+
loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
|
| 321 |
+
outputs = (loss,) + outputs
|
| 322 |
+
else:
|
| 323 |
+
scores = self.proj.log_prob(x)
|
| 324 |
+
outputs = (scores,) + outputs
|
| 325 |
+
if y is not None:
|
| 326 |
+
_, loss = self.proj(x, y)
|
| 327 |
+
outputs = (loss,) + outputs
|
| 328 |
+
|
| 329 |
+
return outputs
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@dataclass
|
| 333 |
+
# Copied from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert
|
| 334 |
+
class FlaubertSquadHeadOutput(ModelOutput):
|
| 335 |
+
"""
|
| 336 |
+
Base class for outputs of question answering models using a [`~modeling_utils.FlaubertSQuADHead`].
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
|
| 340 |
+
Classification loss as the sum of start token, end token (and is_impossible if provided) classification
|
| 341 |
+
losses.
|
| 342 |
+
start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
|
| 343 |
+
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
| 344 |
+
start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
|
| 345 |
+
Indices for the top config.start_n_top start token possibilities (beam-search).
|
| 346 |
+
end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
|
| 347 |
+
Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
|
| 348 |
+
(beam-search).
|
| 349 |
+
end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
|
| 350 |
+
Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
|
| 351 |
+
cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
|
| 352 |
+
Log probabilities for the `is_impossible` label of the answers.
|
| 353 |
+
|
| 354 |
+
"""
|
| 355 |
+
|
| 356 |
+
loss: Optional[torch.FloatTensor] = None
|
| 357 |
+
start_top_log_probs: Optional[torch.FloatTensor] = None
|
| 358 |
+
start_top_index: Optional[torch.LongTensor] = None
|
| 359 |
+
end_top_log_probs: Optional[torch.FloatTensor] = None
|
| 360 |
+
end_top_index: Optional[torch.LongTensor] = None
|
| 361 |
+
cls_logits: Optional[torch.FloatTensor] = None
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert
|
| 365 |
+
class FlaubertPoolerStartLogits(nn.Module):
|
| 366 |
+
"""
|
| 367 |
+
Compute SQuAD start logits from sequence hidden states.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
config ([`FlaubertConfig`]):
|
| 371 |
+
The config used by the model, will be used to grab the `hidden_size` of the model.
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
def __init__(self, config: FlaubertConfig):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.dense = nn.Linear(config.hidden_size, 1)
|
| 377 |
+
|
| 378 |
+
def forward(
|
| 379 |
+
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
|
| 380 |
+
) -> torch.FloatTensor:
|
| 381 |
+
"""
|
| 382 |
+
Args:
|
| 383 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
|
| 384 |
+
The final hidden states of the model.
|
| 385 |
+
p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
|
| 386 |
+
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
|
| 387 |
+
should be masked.
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
`torch.FloatTensor`: The start logits for SQuAD.
|
| 391 |
+
"""
|
| 392 |
+
x = self.dense(hidden_states).squeeze(-1)
|
| 393 |
+
|
| 394 |
+
if p_mask is not None:
|
| 395 |
+
if p_mask.dtype == torch.float16:
|
| 396 |
+
x = x * (1 - p_mask) - 65500 * p_mask
|
| 397 |
+
else:
|
| 398 |
+
x = x * (1 - p_mask) - 1e30 * p_mask
|
| 399 |
+
|
| 400 |
+
return x
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert
|
| 404 |
+
class FlaubertPoolerEndLogits(nn.Module):
|
| 405 |
+
"""
|
| 406 |
+
Compute SQuAD end logits from sequence hidden states.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
config ([`FlaubertConfig`]):
|
| 410 |
+
The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
|
| 411 |
+
to use.
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
def __init__(self, config: FlaubertConfig):
|
| 415 |
+
super().__init__()
|
| 416 |
+
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
|
| 417 |
+
self.activation = nn.Tanh()
|
| 418 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 419 |
+
self.dense_1 = nn.Linear(config.hidden_size, 1)
|
| 420 |
+
|
| 421 |
+
def forward(
|
| 422 |
+
self,
|
| 423 |
+
hidden_states: torch.FloatTensor,
|
| 424 |
+
start_states: Optional[torch.FloatTensor] = None,
|
| 425 |
+
start_positions: Optional[torch.LongTensor] = None,
|
| 426 |
+
p_mask: Optional[torch.FloatTensor] = None,
|
| 427 |
+
) -> torch.FloatTensor:
|
| 428 |
+
"""
|
| 429 |
+
Args:
|
| 430 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
|
| 431 |
+
The final hidden states of the model.
|
| 432 |
+
start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
|
| 433 |
+
The hidden states of the first tokens for the labeled span.
|
| 434 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 435 |
+
The position of the first token for the labeled span.
|
| 436 |
+
p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
|
| 437 |
+
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
|
| 438 |
+
should be masked.
|
| 439 |
+
|
| 440 |
+
<Tip>
|
| 441 |
+
|
| 442 |
+
One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
|
| 443 |
+
`start_states`.
|
| 444 |
+
|
| 445 |
+
</Tip>
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
`torch.FloatTensor`: The end logits for SQuAD.
|
| 449 |
+
"""
|
| 450 |
+
assert start_states is not None or start_positions is not None, (
|
| 451 |
+
"One of start_states, start_positions should be not None"
|
| 452 |
+
)
|
| 453 |
+
if start_positions is not None:
|
| 454 |
+
slen, hsz = hidden_states.shape[-2:]
|
| 455 |
+
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
| 456 |
+
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
| 457 |
+
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
| 458 |
+
|
| 459 |
+
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
|
| 460 |
+
x = self.activation(x)
|
| 461 |
+
x = self.LayerNorm(x)
|
| 462 |
+
x = self.dense_1(x).squeeze(-1)
|
| 463 |
+
|
| 464 |
+
if p_mask is not None:
|
| 465 |
+
if p_mask.dtype == torch.float16:
|
| 466 |
+
x = x * (1 - p_mask) - 65500 * p_mask
|
| 467 |
+
else:
|
| 468 |
+
x = x * (1 - p_mask) - 1e30 * p_mask
|
| 469 |
+
|
| 470 |
+
return x
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert
class FlaubertPoolerAnswerClass(nn.Module):
    """
    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.

    Args:
        config ([`FlaubertConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config: FlaubertConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.

        <Tip>

        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
        `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The SQuAD 2.0 answer class.
        """
        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


# Copied from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert
class FlaubertSQuADHead(nn.Module):
    r"""
    A SQuAD head inspired by XLNet.

    Args:
        config ([`FlaubertConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config: FlaubertConfig):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = FlaubertPoolerStartLogits(config)
        self.end_logits = FlaubertPoolerEndLogits(config)
        self.answer_class = FlaubertPoolerAnswerClass(config)

    @replace_return_docstrings(output_type=FlaubertSquadHeadOutput, config_class=FlaubertConfig)
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
        is_impossible: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = False,
    ) -> Union[FlaubertSquadHeadOutput, Tuple[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                Final hidden states of the model on the sequence tokens.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Positions of the first token for the labeled span.
            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Positions of the last token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
            is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Whether the question has a possible answer in the paragraph or not.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
        """
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            if not return_dict:
                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
            else:
                return FlaubertSquadHeadOutput(
                    start_top_log_probs=start_top_log_probs,
                    start_top_index=start_top_index,
                    end_top_log_probs=end_top_log_probs,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )


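# Illustrative sketch (not part of the modeling code): at inference time the head above returns
# `start_top_index` of shape (bsz, start_n_top) and `end_top_index` of shape (bsz, start_n_top * end_n_top).
# A downstream decoder would typically pair every kept start with its end candidates and rank spans by summed
# log-probabilities; the flattened-index layout assumed here follows the `.view()` calls above.
def _example_rank_candidate_spans(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, end_n_top):
    start_n_top = start_top_index.size(1)
    spans = []
    for i in range(start_n_top):  # kept start positions (first batch element only)
        for j in range(end_n_top):  # end candidates for that start
            k = j * start_n_top + i  # assumed flattened layout of (end_n_top, start_n_top)
            spans.append(
                (
                    start_top_index[0, i].item(),
                    end_top_index[0, k].item(),
                    (start_top_log_probs[0, i] + end_top_log_probs[0, k]).item(),
                )
            )
    return sorted(spans, key=lambda s: s[2], reverse=True)

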
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert
class FlaubertSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`FlaubertConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: FlaubertConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


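# Illustrative sketch (not part of the modeling code): a minimal run of the summary module above with a
# hypothetical ad-hoc config, showing that `summary_type="first"` simply picks the first token's hidden state
# before the optional projection, activation and dropout. Only torch and the class above are assumed.
def _example_sequence_summary():
    from types import SimpleNamespace

    import torch

    config = SimpleNamespace(
        summary_type="first",
        summary_use_proj=True,
        summary_proj_to_labels=False,
        summary_activation="tanh",
        summary_first_dropout=0.0,
        summary_last_dropout=0.0,
        hidden_size=8,
        num_labels=0,
    )
    summary = FlaubertSequenceSummary(config)
    hidden_states = torch.randn(2, 5, config.hidden_size)
    return summary(hidden_states)  # shape (2, hidden_size)

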
# Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert
class FlaubertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FlaubertConfig
    load_tf_weights = None
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    @property
    def dummy_inputs(self):
        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        if self.config.use_lang_emb and self.config.n_langs > 1:
            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        else:
            langs_list = None
        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Embedding):
            if self.config is not None and self.config.embed_init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, nn.Linear):
            if self.config is not None and self.config.init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(
                self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
            )


class FlaubertModel(FlaubertPreTrainedModel):
    def __init__(self, config):  # , dico, is_encoder, with_output):
        super().__init__(config)

        # encoder / decoder, output layer
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently Flaubert can only be used as an encoder")
        # self.with_output = with_output
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.use_lang_emb = config.use_lang_emb
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index
        # self.dico = dico
        # self.id2lang = config.id2lang
        # self.lang2id = config.lang2id
        # assert len(self.dico) == self.n_words
        # assert len(self.id2lang) == len(self.lang2id) == self.n_langs

        # model parameters
        self.dim = config.emb_dim  # 512 by default
        self.hidden_dim = self.dim * 4  # 2048 by default
        self.n_heads = config.n_heads  # 8 by default
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()
        # if self.is_decoder:
        #     self.layer_norm15 = nn.ModuleList()
        #     self.encoder_attn = nn.ModuleList()

        for _ in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            # if self.is_decoder:
            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.attentions[int(layer)].n_heads == config.n_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})

        # Initialize weights and apply final processing
        self.post_init()

        self.layerdrop = getattr(config, "layerdrop", 0.0)
        self.pre_norm = getattr(config, "pre_norm", False)
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embeddings

    # Copied from transformers.models.xlm.modeling_xlm.XLMModel.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    # Copied from transformers.models.xlm.modeling_xlm.XLMModel._prune_heads
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        lengths: Optional[torch.LongTensor] = None,
        cache: Optional[Dict[str, torch.FloatTensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # removed: src_enc=None, src_len=None
        if input_ids is not None:
            bs, slen = input_ids.size()
        else:
            bs, slen = inputs_embeds.size()[:-1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if lengths is None:
            if input_ids is not None:
                lengths = (input_ids != self.pad_index).sum(dim=1).long()
            else:
                lengths = torch.tensor([slen] * bs, device=device)
        # mask = input_ids != self.pad_index

        # check inputs
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)
        # if src_enc is not None:
        #     assert self.is_decoder
        #     assert src_enc.size(0) == bs

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # Setting the position-ids to the registered buffer in constructor, it helps
        # when tracing the model without passing position-ids, solves
        # issues similar to issue #5664
        if position_ids is None:
            if hasattr(self, "position_ids"):
                position_ids = self.position_ids[:, :slen]
                position_ids = position_ids.expand((bs, slen))
            else:
                position_ids = torch.arange(slen, dtype=torch.long, device=device)
                position_ids = position_ids.unsqueeze(0).expand((bs, slen))
        else:
            assert position_ids.size() == (bs, slen)  # (slen, bs)
            # position_ids = position_ids.transpose(0, 1)

        # langs
        if langs is not None:
            assert langs.size() == (bs, slen)  # (slen, bs)
            # langs = langs.transpose(0, 1)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

        # do not recompute cached elements
        if cache is not None and input_ids is not None:
            _slen = slen - cache["slen"]
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
        if langs is not None and self.use_lang_emb and self.config.n_langs > 1:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None
        for i in range(self.n_layers):
            # LayerDrop
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            if output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            if not self.pre_norm:
                attn_outputs = self.attentions[i](
                    tensor,
                    attn_mask,
                    cache=cache,
                    head_mask=head_mask[i],
                    output_attentions=output_attentions,
                )
                attn = attn_outputs[0]
                if output_attentions:
                    attentions = attentions + (attn_outputs[1],)
                attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
                tensor = tensor + attn
                tensor = self.layer_norm1[i](tensor)
            else:
                tensor_normalized = self.layer_norm1[i](tensor)
                attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i])
                attn = attn_outputs[0]
                if output_attentions:
                    attentions = attentions + (attn_outputs[1],)
                attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
                tensor = tensor + attn

            # encoder attention (for decoder only)
            # if self.is_decoder and src_enc is not None:
            #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
            #     attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
            #     tensor = tensor + attn
            #     tensor = self.layer_norm15[i](tensor)

            # FFN
            if not self.pre_norm:
                tensor = tensor + self.ffns[i](tensor)
                tensor = self.layer_norm2[i](tensor)
            else:
                tensor_normalized = self.layer_norm2[i](tensor)
                tensor = tensor + self.ffns[i](tensor_normalized)

            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Add last hidden state
        if output_hidden_states:
            hidden_states = hidden_states + (tensor,)

        # update cache length
        if cache is not None:
            cache["slen"] += tensor.size(1)

        # move back sequence length to dimension 0
        # tensor = tensor.transpose(0, 1)

        if not return_dict:
            return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)

        return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)


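# Illustrative usage sketch (not part of the modeling code): encoding a French sentence with the base encoder
# defined above. The checkpoint name "flaubert/flaubert_base_cased" is an assumption about what is available
# on the Hugging Face Hub; any FlauBERT checkpoint with a matching tokenizer would do.
def _example_flaubert_model_usage():
    import torch

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    model = FlaubertModel.from_pretrained("flaubert/flaubert_base_cased")
    inputs = tokenizer("Le camembert est délicieux.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state  # shape (1, sequence_length, hidden_size)

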
@add_start_docstrings(
    """
    The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    FLAUBERT_START_DOCSTRING,
)
# Copied from transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["pred_layer.proj.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = FlaubertModel(config)
        self.pred_layer = FlaubertPredLayer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.pred_layer.proj

    def set_output_embeddings(self, new_embeddings):
        self.pred_layer.proj = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Overwritten -- uses a language id

        mask_token_id = self.config.mask_token_id
        lang_id = self.config.lang_id

        effective_batch_size = input_ids.shape[0]
        mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, mask_token], dim=1)
        if lang_id is not None:
            langs = torch.full_like(input_ids, lang_id)
        else:
            langs = None
        return {"input_ids": input_ids, "langs": langs}

    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<special1>",
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        output = transformer_outputs[0]
        outputs = self.pred_layer(output, labels)  # (loss, logits) or (logits,) depending on if labels are provided.

        if not return_dict:
            return outputs + transformer_outputs[1:]

        return MaskedLMOutput(
            loss=outputs[0] if labels is not None else None,
            logits=outputs[0] if labels is None else outputs[1],
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


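# Illustrative sketch (not part of the modeling code): `prepare_inputs_for_generation` above appends a
# `<special1>` mask token and, when `config.lang_id` is set, a matching `langs` tensor, so the next-token
# distribution is read from the logits at that appended position. A hypothetical manual version of the same
# idea, assuming the "flaubert/flaubert_base_cased" checkpoint:
def _example_lm_head_mask_prediction():
    import torch

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    model = FlaubertWithLMHeadModel.from_pretrained("flaubert/flaubert_base_cased")
    input_ids = tokenizer("Le ciel est", return_tensors="pt").input_ids
    prepared = model.prepare_inputs_for_generation(input_ids)
    with torch.no_grad():
        logits = model(**prepared).logits
    # Prediction for the appended mask position (the last token of the prepared input).
    return logits[0, -1].argmax(-1)

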
@add_start_docstrings(
    """
    Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
    e.g. for GLUE tasks.
    """,
    FLAUBERT_START_DOCSTRING,
)
# Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.transformer = FlaubertModel(config)
        self.sequence_summary = FlaubertSequenceSummary(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        output = transformer_outputs[0]
        logits = self.sequence_summary(output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


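# Illustrative sketch (not part of the modeling code): how `config.problem_type` is inferred above when it is
# left unset. The helper name and arguments are hypothetical; the branches mirror the selection logic used in
# the forward pass.
def _example_problem_type_selection(num_labels, labels):
    import torch

    if num_labels == 1:
        return "regression"  # MSELoss on a single score
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"  # CrossEntropyLoss over num_labels classes
    return "multi_label_classification"  # BCEWithLogitsLoss on per-label float targets

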
@add_start_docstrings(
    """
    Flaubert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    FLAUBERT_START_DOCSTRING,
)
# Copied from transformers.models.xlm.modeling_xlm.XLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertForTokenClassification(FlaubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = FlaubertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


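# Illustrative sketch (not part of the modeling code): the token-level loss above flattens the
# (batch, seq_len, num_labels) logits and the (batch, seq_len) labels so CrossEntropyLoss sees one prediction
# per token; positions that should not contribute (e.g. padding) can be set to -100. Shapes are hypothetical.
def _example_token_classification_loss():
    import torch

    logits = torch.randn(2, 6, 5)  # (batch, seq_len, num_labels)
    labels = torch.randint(0, 5, (2, 6))  # (batch, seq_len)
    labels[:, -1] = -100  # e.g. mask padding positions out of the loss
    loss_fct = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

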
@add_start_docstrings(
    """
    Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FLAUBERT_START_DOCSTRING,
)
# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = FlaubertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + transformer_outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


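# Illustrative sketch (not part of the modeling code): turning the start/end logits produced above into an
# answer string. Greedy argmax decoding and the checkpoint name "my-org/flaubert-finetuned-squad-fr" are
# assumptions (a QA-fine-tuned checkpoint is required for sensible output); production code usually also
# enforces start <= end and ranks the top-k span combinations.
def _example_decode_answer(question, context):
    import torch

    from transformers import AutoTokenizer

    checkpoint = "my-org/flaubert-finetuned-squad-fr"  # hypothetical fine-tuned checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = FlaubertForQuestionAnsweringSimple.from_pretrained(checkpoint)
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    start = int(outputs.start_logits.argmax(-1))
    end = int(outputs.end_logits.argmax(-1))
    return tokenizer.decode(inputs.input_ids[0, start : end + 1])

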
@add_start_docstrings(
    """
    Flaubert Model with a beam-search span classification head on top for extractive question-answering tasks like
    SQuAD (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FLAUBERT_START_DOCSTRING,
)
@dataclass
# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringOutput with XLM->Flaubert
class FlaubertForQuestionAnsweringOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a `SquadHead`.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
            losses.
        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
            (beam-search).
        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the `is_impossible` label of the answers.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
|
| 1510 |
+
class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
|
| 1511 |
+
def __init__(self, config):
|
| 1512 |
+
super().__init__(config)
|
| 1513 |
+
|
| 1514 |
+
self.transformer = FlaubertModel(config)
|
| 1515 |
+
self.qa_outputs = FlaubertSQuADHead(config)
|
| 1516 |
+
|
| 1517 |
+
# Initialize weights and apply final processing
|
| 1518 |
+
self.post_init()
|
| 1519 |
+
|
| 1520 |
+
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1521 |
+
@replace_return_docstrings(output_type=FlaubertForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
|
| 1522 |
+
def forward(
|
| 1523 |
+
self,
|
| 1524 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1525 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1526 |
+
langs: Optional[torch.Tensor] = None,
|
| 1527 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1528 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1529 |
+
lengths: Optional[torch.Tensor] = None,
|
| 1530 |
+
cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 1531 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1532 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1533 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 1534 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 1535 |
+
is_impossible: Optional[torch.Tensor] = None,
|
| 1536 |
+
cls_index: Optional[torch.Tensor] = None,
|
| 1537 |
+
p_mask: Optional[torch.Tensor] = None,
|
| 1538 |
+
output_attentions: Optional[bool] = None,
|
| 1539 |
+
output_hidden_states: Optional[bool] = None,
|
| 1540 |
+
return_dict: Optional[bool] = None,
|
| 1541 |
+
) -> Union[Tuple, FlaubertForQuestionAnsweringOutput]:
|
| 1542 |
+
r"""
|
| 1543 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1544 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1545 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1546 |
+
are not taken into account for computing the loss.
|
| 1547 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1548 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1549 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1550 |
+
are not taken into account for computing the loss.
|
| 1551 |
+
is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1552 |
+
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
| 1553 |
+
cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1554 |
+
Labels for position (index) of the classification token to use as input for computing plausibility of the
|
| 1555 |
+
answer.
|
| 1556 |
+
p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1557 |
+
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
|
| 1558 |
+
masked. 0.0 mean token is not masked.
|
| 1559 |
+
|
| 1560 |
+
Returns:
|
| 1561 |
+
|
| 1562 |
+
Example:
|
| 1563 |
+
|
| 1564 |
+
```python
|
| 1565 |
+
>>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering
|
| 1566 |
+
>>> import torch
|
| 1567 |
+
|
| 1568 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
|
| 1569 |
+
>>> model = FlaubertForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048")
|
| 1570 |
+
|
| 1571 |
+
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
|
| 1572 |
+
... 0
|
| 1573 |
+
... ) # Batch size 1
|
| 1574 |
+
>>> start_positions = torch.tensor([1])
|
| 1575 |
+
>>> end_positions = torch.tensor([3])
|
| 1576 |
+
|
| 1577 |
+
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
| 1578 |
+
>>> loss = outputs.loss
|
| 1579 |
+
```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        output = transformer_outputs[0]

        outputs = self.qa_outputs(
            output,
            start_positions=start_positions,
            end_positions=end_positions,
            cls_index=cls_index,
            is_impossible=is_impossible,
            p_mask=p_mask,
            return_dict=return_dict,
        )

        if not return_dict:
            return outputs + transformer_outputs[1:]

        return FlaubertForQuestionAnsweringOutput(
            loss=outputs.loss,
            start_top_log_probs=outputs.start_top_log_probs,
            start_top_index=outputs.start_top_index,
            end_top_log_probs=outputs.end_top_log_probs,
            end_top_index=outputs.end_top_index,
            cls_logits=outputs.cls_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

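# Illustrative usage sketch (not part of the upstream module): when `start_positions` and
# `end_positions` are omitted, the forward pass above returns beam-style candidates
# (`start_top_index`, `end_top_index`, `*_top_log_probs`). The greedy decoding below is a
# deliberate simplification; with the base checkpoint the QA head is freshly initialized, so
# the decoded span only becomes meaningful after fine-tuning on SQuAD-style data.
import torch
from transformers import AutoTokenizer, FlaubertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = FlaubertForQuestionAnswering.from_pretrained("flaubert/flaubert_base_cased")

question, context = "Qui est Victor Hugo ?", "Victor Hugo est un écrivain français."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Keep the single highest-scoring start and end candidate.
start = int(outputs.start_top_index[0, outputs.start_top_log_probs[0].argmax()])
end = int(outputs.end_top_index[0, outputs.end_top_log_probs[0].argmax()])
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])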
@add_start_docstrings(
    """
    Flaubert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    FLAUBERT_START_DOCSTRING,
)
# Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.transformer = FlaubertModel(config)
        self.sequence_summary = FlaubertSequenceSummary(config)
        self.logits_proj = nn.Linear(config.num_labels, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        langs = langs.view(-1, langs.size(-1)) if langs is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        if lengths is not None:
            logger.warning(
                "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
                "attention mask instead."
            )
            lengths = None

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        output = transformer_outputs[0]
        logits = self.sequence_summary(output)
        logits = self.logits_proj(logits)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

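# Illustrative usage sketch (not part of the upstream module): multiple-choice inputs are
# shaped `(batch_size, num_choices, seq_len)`; the forward pass above flattens them to
# `(batch_size * num_choices, seq_len)` and returns one logit per choice. The prompt and
# choices below are made-up examples.
import torch
from transformers import AutoTokenizer, FlaubertForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = FlaubertForMultipleChoice.from_pretrained("flaubert/flaubert_base_cased")

prompt = "Il pleut, donc je prends"
choices = ["un parapluie.", "des lunettes de soleil."]
encoded = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)

# Add the num_choices dimension expected by the model: (1, num_choices, seq_len).
inputs = {key: tensor.unsqueeze(0) for key, tensor in encoded.items()}
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, num_choices)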
__all__ = [
    "FlaubertForMultipleChoice",
    "FlaubertForQuestionAnswering",
    "FlaubertForQuestionAnsweringSimple",
    "FlaubertForSequenceClassification",
    "FlaubertForTokenClassification",
    "FlaubertModel",
    "FlaubertWithLMHeadModel",
    "FlaubertPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/flaubert/modeling_tf_flaubert.py
ADDED
@@ -0,0 +1,1344 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
TF 2.0 Flaubert model.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import itertools
|
| 22 |
+
import random
|
| 23 |
+
import warnings
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import Dict, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import tensorflow as tf
|
| 29 |
+
|
| 30 |
+
from ...activations_tf import get_tf_activation
|
| 31 |
+
from ...modeling_tf_outputs import (
|
| 32 |
+
TFBaseModelOutput,
|
| 33 |
+
TFMultipleChoiceModelOutput,
|
| 34 |
+
TFQuestionAnsweringModelOutput,
|
| 35 |
+
TFSequenceClassifierOutput,
|
| 36 |
+
TFTokenClassifierOutput,
|
| 37 |
+
)
|
| 38 |
+
from ...modeling_tf_utils import (
|
| 39 |
+
TFModelInputType,
|
| 40 |
+
TFMultipleChoiceLoss,
|
| 41 |
+
TFPreTrainedModel,
|
| 42 |
+
TFQuestionAnsweringLoss,
|
| 43 |
+
TFSequenceClassificationLoss,
|
| 44 |
+
TFSequenceSummary,
|
| 45 |
+
TFSharedEmbeddings,
|
| 46 |
+
TFTokenClassificationLoss,
|
| 47 |
+
get_initializer,
|
| 48 |
+
keras,
|
| 49 |
+
keras_serializable,
|
| 50 |
+
unpack_inputs,
|
| 51 |
+
)
|
| 52 |
+
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
|
| 53 |
+
from ...utils import (
|
| 54 |
+
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
| 55 |
+
ModelOutput,
|
| 56 |
+
add_code_sample_docstrings,
|
| 57 |
+
add_start_docstrings,
|
| 58 |
+
add_start_docstrings_to_model_forward,
|
| 59 |
+
logging,
|
| 60 |
+
)
|
| 61 |
+
from .configuration_flaubert import FlaubertConfig
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
logger = logging.get_logger(__name__)
|
| 65 |
+
|
| 66 |
+
_CHECKPOINT_FOR_DOC = "flaubert/flaubert_base_cased"
|
| 67 |
+
_CONFIG_FOR_DOC = "FlaubertConfig"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
FLAUBERT_START_DOCSTRING = r"""
|
| 71 |
+
|
| 72 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 73 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 74 |
+
etc.)
|
| 75 |
+
|
| 76 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 77 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 78 |
+
behavior.
|
| 79 |
+
|
| 80 |
+
<Tip>
|
| 81 |
+
|
| 82 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 83 |
+
|
| 84 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 85 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 86 |
+
|
| 87 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 88 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 89 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 90 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 91 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 92 |
+
positional argument:
|
| 93 |
+
|
| 94 |
+
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
|
| 95 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 96 |
+
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
|
| 97 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 98 |
+
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
|
| 99 |
+
|
| 100 |
+
Note that when creating models and layers with
|
| 101 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 102 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 103 |
+
|
| 104 |
+
</Tip>
|
| 105 |
+
|
| 106 |
+
Parameters:
|
| 107 |
+
config ([`FlaubertConfig`]): Model configuration class with all the parameters of the model.
|
| 108 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 109 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
FLAUBERT_INPUTS_DOCSTRING = r"""
|
| 113 |
+
Args:
|
| 114 |
+
input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 115 |
+
Indices of input sequence tokens in the vocabulary.
|
| 116 |
+
|
| 117 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
|
| 118 |
+
[`PreTrainedTokenizer.encode`] for details.
|
| 119 |
+
|
| 120 |
+
[What are input IDs?](../glossary#input-ids)
|
| 121 |
+
attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 122 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 123 |
+
|
| 124 |
+
- `1` for tokens that are **not masked**,
|
| 125 |
+
- `0` for tokens that are **masked**.
|
| 126 |
+
|
| 127 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 128 |
+
langs (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 129 |
+
A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
|
| 130 |
+
languages ids which can be obtained from the language names by using two conversion mappings provided in
|
| 131 |
+
the configuration of the model (only provided for multilingual models). More precisely, the *language name
|
| 132 |
+
to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
|
| 133 |
+
*language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
|
| 134 |
+
|
| 135 |
+
See usage examples detailed in the [multilingual documentation](../multilingual).
|
| 136 |
+
token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 137 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 138 |
+
1]`:
|
| 139 |
+
|
| 140 |
+
- `0` corresponds to a *sentence A* token,
|
| 141 |
+
- `1` corresponds to a *sentence B* token.
|
| 142 |
+
|
| 143 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 144 |
+
position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 145 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 146 |
+
config.max_position_embeddings - 1]`.
|
| 147 |
+
|
| 148 |
+
[What are position IDs?](../glossary#position-ids)
|
| 149 |
+
lengths (`tf.Tensor` or `Numpy array` of shape `(batch_size,)`, *optional*):
|
| 150 |
+
Length of each sentence that can be used to avoid performing attention on padding token indices. You can
|
| 151 |
+
also use *attention_mask* for the same result (see above); kept here for compatibility. Indices selected in
|
| 152 |
+
`[0, ..., input_ids.size(-1)]`:
|
| 153 |
+
cache (`Dict[str, tf.Tensor]`, *optional*):
|
| 154 |
+
Dictionary string to `tf.FloatTensor` that contains precomputed hidden states (key and values in the
|
| 155 |
+
attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
|
| 156 |
+
decoding.
|
| 157 |
+
|
| 158 |
+
The dictionary object will be modified in-place during the forward pass to add newly computed
|
| 159 |
+
hidden-states.
|
| 160 |
+
head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 161 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 162 |
+
|
| 163 |
+
- `1` indicates the head is **not masked**,
|
| 164 |
+
- `0` indicates the head is **masked**.
|
| 165 |
+
|
| 166 |
+
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 167 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 168 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 169 |
+
model's internal embedding lookup matrix.
|
| 170 |
+
output_attentions (`bool`, *optional*):
|
| 171 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 172 |
+
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
| 173 |
+
config will be used instead.
|
| 174 |
+
output_hidden_states (`bool`, *optional*):
|
| 175 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 176 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 177 |
+
used instead.
|
| 178 |
+
return_dict (`bool`, *optional*):
|
| 179 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
|
| 180 |
+
eager mode, in graph mode the value will always be set to True.
|
| 181 |
+
training (`bool`, *optional*, defaults to `False`):
|
| 182 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 183 |
+
behaviors between training and evaluation).
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
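# Illustrative sketch (not part of the upstream file): the three equivalent input formats
# described in the docstring above, using the documented checkpoint. Depending on which
# weights are published for that checkpoint, `from_pt=True` may be needed when loading.
from transformers import AutoTokenizer, TFFlaubertModel

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertModel.from_pretrained("flaubert/flaubert_base_cased")
encoded = tokenizer("Le camembert est délicieux.", return_tensors="tf")

out_kwargs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
out_list = model([encoded["input_ids"], encoded["attention_mask"]])
out_dict = model({"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]})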
|
| 187 |
+
def get_masks(slen, lengths, causal, padding_mask=None):
|
| 188 |
+
"""
|
| 189 |
+
Generate hidden states mask, and optionally an attention mask.
|
| 190 |
+
"""
|
| 191 |
+
bs = shape_list(lengths)[0]
|
| 192 |
+
if padding_mask is not None:
|
| 193 |
+
mask = padding_mask
|
| 194 |
+
else:
|
| 195 |
+
# assert lengths.max().item() <= slen
|
| 196 |
+
alen = tf.range(slen, dtype=lengths.dtype)
|
| 197 |
+
mask = alen < tf.expand_dims(lengths, axis=1)
|
| 198 |
+
|
| 199 |
+
# attention mask is the same as mask, or triangular inferior attention (causal)
|
| 200 |
+
if causal:
|
| 201 |
+
attn_mask = tf.less_equal(
|
| 202 |
+
tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1))
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
attn_mask = mask
|
| 206 |
+
|
| 207 |
+
# sanity check
|
| 208 |
+
# assert shape_list(mask) == [bs, slen]
|
| 209 |
+
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
| 210 |
+
if causal:
|
| 211 |
+
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
| 212 |
+
|
| 213 |
+
return mask, attn_mask
|
| 214 |
+
|
| 215 |
+
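# Illustrative sketch (not part of the upstream file): what `get_masks` computes for a toy
# batch. `mask` marks the real (non-padding) positions of each sentence; with `causal=True`
# the attention mask additionally forbids attending to future positions.
import tensorflow as tf

lengths = tf.constant([3, 5])  # two sentences of length 3 and 5, padded to slen = 5
slen = 5
alen = tf.range(slen, dtype=lengths.dtype)
mask = alen < tf.expand_dims(lengths, axis=1)  # (2, 5) padding mask
causal_attn_mask = tf.less_equal(
    tf.tile(tf.reshape(alen, (1, 1, slen)), (2, slen, 1)), tf.reshape(alen, (1, slen, 1))
)  # (2, 5, 5) lower-triangular mask, identical for every sentence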
|
| 216 |
+
class TFFlaubertPreTrainedModel(TFPreTrainedModel):
|
| 217 |
+
"""
|
| 218 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 219 |
+
models.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
config_class = FlaubertConfig
|
| 223 |
+
base_model_prefix = "transformer"
|
| 224 |
+
|
| 225 |
+
@property
|
| 226 |
+
def dummy_inputs(self):
|
| 227 |
+
# Sometimes Flaubert has language embeddings so don't forget to build them as well if needed
|
| 228 |
+
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int32)
|
| 229 |
+
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32)
|
| 230 |
+
if self.config.use_lang_emb and self.config.n_langs > 1:
|
| 231 |
+
return {
|
| 232 |
+
"input_ids": inputs_list,
|
| 233 |
+
"attention_mask": attns_list,
|
| 234 |
+
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32),
|
| 235 |
+
}
|
| 236 |
+
else:
|
| 237 |
+
return {"input_ids": inputs_list, "attention_mask": attns_list}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@add_start_docstrings(
|
| 241 |
+
"The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 242 |
+
FLAUBERT_START_DOCSTRING,
|
| 243 |
+
)
|
| 244 |
+
class TFFlaubertModel(TFFlaubertPreTrainedModel):
|
| 245 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 246 |
+
super().__init__(config, *inputs, **kwargs)
|
| 247 |
+
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
| 248 |
+
|
| 249 |
+
@unpack_inputs
|
| 250 |
+
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
|
| 251 |
+
@add_code_sample_docstrings(
|
| 252 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 253 |
+
output_type=TFBaseModelOutput,
|
| 254 |
+
config_class=_CONFIG_FOR_DOC,
|
| 255 |
+
)
|
| 256 |
+
def call(
|
| 257 |
+
self,
|
| 258 |
+
input_ids: np.ndarray | tf.Tensor | None = None,
|
| 259 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 260 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 261 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 262 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 263 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 264 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 265 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 266 |
+
inputs_embeds: tf.Tensor | None = None,
|
| 267 |
+
output_attentions: Optional[bool] = None,
|
| 268 |
+
output_hidden_states: Optional[bool] = None,
|
| 269 |
+
return_dict: Optional[bool] = None,
|
| 270 |
+
training: Optional[bool] = False,
|
| 271 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 272 |
+
outputs = self.transformer(
|
| 273 |
+
input_ids=input_ids,
|
| 274 |
+
attention_mask=attention_mask,
|
| 275 |
+
langs=langs,
|
| 276 |
+
token_type_ids=token_type_ids,
|
| 277 |
+
position_ids=position_ids,
|
| 278 |
+
lengths=lengths,
|
| 279 |
+
cache=cache,
|
| 280 |
+
head_mask=head_mask,
|
| 281 |
+
inputs_embeds=inputs_embeds,
|
| 282 |
+
output_attentions=output_attentions,
|
| 283 |
+
output_hidden_states=output_hidden_states,
|
| 284 |
+
return_dict=return_dict,
|
| 285 |
+
training=training,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
return outputs
|
| 289 |
+
|
| 290 |
+
def build(self, input_shape=None):
|
| 291 |
+
if self.built:
|
| 292 |
+
return
|
| 293 |
+
self.built = True
|
| 294 |
+
if getattr(self, "transformer", None) is not None:
|
| 295 |
+
with tf.name_scope(self.transformer.name):
|
| 296 |
+
self.transformer.build(None)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert
|
| 300 |
+
class TFFlaubertMultiHeadAttention(keras.layers.Layer):
|
| 301 |
+
NEW_ID = itertools.count()
|
| 302 |
+
|
| 303 |
+
def __init__(self, n_heads, dim, config, **kwargs):
|
| 304 |
+
super().__init__(**kwargs)
|
| 305 |
+
self.layer_id = next(TFFlaubertMultiHeadAttention.NEW_ID)
|
| 306 |
+
self.dim = dim
|
| 307 |
+
self.n_heads = n_heads
|
| 308 |
+
self.output_attentions = config.output_attentions
|
| 309 |
+
assert self.dim % self.n_heads == 0
|
| 310 |
+
|
| 311 |
+
self.q_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
|
| 312 |
+
self.k_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin")
|
| 313 |
+
self.v_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin")
|
| 314 |
+
self.out_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
|
| 315 |
+
self.dropout = keras.layers.Dropout(config.attention_dropout)
|
| 316 |
+
self.pruned_heads = set()
|
| 317 |
+
self.dim = dim
|
| 318 |
+
|
| 319 |
+
def prune_heads(self, heads):
|
| 320 |
+
raise NotImplementedError
|
| 321 |
+
|
| 322 |
+
def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
|
| 323 |
+
"""
|
| 324 |
+
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
| 325 |
+
"""
|
| 326 |
+
# Input is (bs, qlen, dim)
|
| 327 |
+
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
| 328 |
+
bs, qlen, dim = shape_list(input)
|
| 329 |
+
|
| 330 |
+
if kv is None:
|
| 331 |
+
klen = qlen if cache is None else cache["slen"] + qlen
|
| 332 |
+
else:
|
| 333 |
+
klen = shape_list(kv)[1]
|
| 334 |
+
|
| 335 |
+
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
|
| 336 |
+
dim_per_head = self.dim // self.n_heads
|
| 337 |
+
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
|
| 338 |
+
|
| 339 |
+
def shape(x):
|
| 340 |
+
"""projection"""
|
| 341 |
+
return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
|
| 342 |
+
|
| 343 |
+
def unshape(x):
|
| 344 |
+
"""compute context"""
|
| 345 |
+
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
|
| 346 |
+
|
| 347 |
+
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
| 348 |
+
|
| 349 |
+
if kv is None:
|
| 350 |
+
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
| 351 |
+
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
| 352 |
+
elif cache is None or self.layer_id not in cache:
|
| 353 |
+
k = v = kv
|
| 354 |
+
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
|
| 355 |
+
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
|
| 356 |
+
|
| 357 |
+
if cache is not None:
|
| 358 |
+
if self.layer_id in cache:
|
| 359 |
+
if kv is None:
|
| 360 |
+
k_, v_ = cache[self.layer_id]
|
| 361 |
+
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
|
| 362 |
+
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
|
| 363 |
+
else:
|
| 364 |
+
k, v = cache[self.layer_id]
|
| 365 |
+
|
| 366 |
+
cache[self.layer_id] = (k, v)
|
| 367 |
+
|
| 368 |
+
f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype)
|
| 369 |
+
q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head)
|
| 370 |
+
k = tf.cast(k, dtype=q.dtype)
|
| 371 |
+
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
|
| 372 |
+
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
|
| 373 |
+
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
|
| 374 |
+
mask = tf.cast(mask, dtype=scores.dtype)
|
| 375 |
+
scores = scores - 1e30 * (1.0 - mask)
|
| 376 |
+
weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
|
| 377 |
+
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
|
| 378 |
+
|
| 379 |
+
# Mask heads if we want to
|
| 380 |
+
if head_mask is not None:
|
| 381 |
+
weights = weights * head_mask
|
| 382 |
+
|
| 383 |
+
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
|
| 384 |
+
context = unshape(context) # (bs, qlen, dim)
|
| 385 |
+
outputs = (self.out_lin(context),)
|
| 386 |
+
|
| 387 |
+
if output_attentions:
|
| 388 |
+
outputs = outputs + (weights,)
|
| 389 |
+
|
| 390 |
+
return outputs
|
| 391 |
+
|
| 392 |
+
def build(self, input_shape=None):
|
| 393 |
+
if self.built:
|
| 394 |
+
return
|
| 395 |
+
self.built = True
|
| 396 |
+
if getattr(self, "q_lin", None) is not None:
|
| 397 |
+
with tf.name_scope(self.q_lin.name):
|
| 398 |
+
self.q_lin.build([None, None, self.dim])
|
| 399 |
+
if getattr(self, "k_lin", None) is not None:
|
| 400 |
+
with tf.name_scope(self.k_lin.name):
|
| 401 |
+
self.k_lin.build([None, None, self.dim])
|
| 402 |
+
if getattr(self, "v_lin", None) is not None:
|
| 403 |
+
with tf.name_scope(self.v_lin.name):
|
| 404 |
+
self.v_lin.build([None, None, self.dim])
|
| 405 |
+
if getattr(self, "out_lin", None) is not None:
|
| 406 |
+
with tf.name_scope(self.out_lin.name):
|
| 407 |
+
self.out_lin.build([None, None, self.dim])
|
| 408 |
+
|
| 409 |
+
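# Illustrative sketch (not part of the upstream file): the core of the attention `call` above
# for a single head with dummy shapes. Masked positions receive a large negative additive
# penalty before the softmax, mirroring the `scores - 1e30 * (1.0 - mask)` trick.
import tensorflow as tf

q = tf.random.normal((1, 4, 8))  # (batch, query_len, dim_per_head)
k = tf.random.normal((1, 4, 8))
v = tf.random.normal((1, 4, 8))
mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])  # last position is padding

scores = tf.matmul(q * tf.math.rsqrt(8.0), k, transpose_b=True)  # (1, 4, 4)
scores = scores - 1e30 * (1.0 - mask[:, tf.newaxis, :])  # broadcast the mask over query_len
weights = tf.nn.softmax(scores, axis=-1)
context = tf.matmul(weights, v)  # (1, 4, 8)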
|
| 410 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMTransformerFFN
|
| 411 |
+
class TFFlaubertTransformerFFN(keras.layers.Layer):
|
| 412 |
+
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
|
| 413 |
+
super().__init__(**kwargs)
|
| 414 |
+
|
| 415 |
+
self.lin1 = keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
|
| 416 |
+
self.lin2 = keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
|
| 417 |
+
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
|
| 418 |
+
self.dropout = keras.layers.Dropout(config.dropout)
|
| 419 |
+
self.in_dim = in_dim
|
| 420 |
+
self.dim_hidden = dim_hidden
|
| 421 |
+
|
| 422 |
+
def call(self, input, training=False):
|
| 423 |
+
x = self.lin1(input)
|
| 424 |
+
x = self.act(x)
|
| 425 |
+
x = self.lin2(x)
|
| 426 |
+
x = self.dropout(x, training=training)
|
| 427 |
+
|
| 428 |
+
return x
|
| 429 |
+
|
| 430 |
+
def build(self, input_shape=None):
|
| 431 |
+
if self.built:
|
| 432 |
+
return
|
| 433 |
+
self.built = True
|
| 434 |
+
if getattr(self, "lin1", None) is not None:
|
| 435 |
+
with tf.name_scope(self.lin1.name):
|
| 436 |
+
self.lin1.build([None, None, self.in_dim])
|
| 437 |
+
if getattr(self, "lin2", None) is not None:
|
| 438 |
+
with tf.name_scope(self.lin2.name):
|
| 439 |
+
self.lin2.build([None, None, self.dim_hidden])
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@keras_serializable
|
| 443 |
+
class TFFlaubertMainLayer(keras.layers.Layer):
|
| 444 |
+
config_class = FlaubertConfig
|
| 445 |
+
|
| 446 |
+
def __init__(self, config, **kwargs):
|
| 447 |
+
super().__init__(**kwargs)
|
| 448 |
+
|
| 449 |
+
self.config = config
|
| 450 |
+
self.n_heads = config.n_heads
|
| 451 |
+
self.n_langs = config.n_langs
|
| 452 |
+
self.dim = config.emb_dim
|
| 453 |
+
self.hidden_dim = self.dim * 4
|
| 454 |
+
self.n_words = config.n_words
|
| 455 |
+
self.pad_index = config.pad_index
|
| 456 |
+
self.causal = config.causal
|
| 457 |
+
self.n_layers = config.n_layers
|
| 458 |
+
self.use_lang_emb = config.use_lang_emb
|
| 459 |
+
self.layerdrop = getattr(config, "layerdrop", 0.0)
|
| 460 |
+
self.pre_norm = getattr(config, "pre_norm", False)
|
| 461 |
+
self.output_attentions = config.output_attentions
|
| 462 |
+
self.output_hidden_states = config.output_hidden_states
|
| 463 |
+
self.return_dict = config.use_return_dict
|
| 464 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 465 |
+
self.embed_init_std = config.embed_init_std
|
| 466 |
+
self.dropout = keras.layers.Dropout(config.dropout)
|
| 467 |
+
self.embeddings = TFSharedEmbeddings(
|
| 468 |
+
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
|
| 469 |
+
)
|
| 470 |
+
self.layer_norm_emb = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb")
|
| 471 |
+
self.attentions = []
|
| 472 |
+
self.layer_norm1 = []
|
| 473 |
+
self.ffns = []
|
| 474 |
+
self.layer_norm2 = []
|
| 475 |
+
|
| 476 |
+
for i in range(self.n_layers):
|
| 477 |
+
self.attentions.append(
|
| 478 |
+
TFFlaubertMultiHeadAttention(self.n_heads, self.dim, config=config, name=f"attentions_._{i}")
|
| 479 |
+
)
|
| 480 |
+
self.layer_norm1.append(
|
| 481 |
+
keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm1_._{i}")
|
| 482 |
+
)
|
| 483 |
+
# if self.is_decoder:
|
| 484 |
+
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
|
| 485 |
+
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
|
| 486 |
+
self.ffns.append(
|
| 487 |
+
TFFlaubertTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name=f"ffns_._{i}")
|
| 488 |
+
)
|
| 489 |
+
self.layer_norm2.append(
|
| 490 |
+
keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm2_._{i}")
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
def build(self, input_shape=None):
|
| 494 |
+
with tf.name_scope("position_embeddings"):
|
| 495 |
+
self.position_embeddings = self.add_weight(
|
| 496 |
+
name="embeddings",
|
| 497 |
+
shape=[self.max_position_embeddings, self.dim],
|
| 498 |
+
initializer=get_initializer(self.embed_init_std),
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
if self.n_langs > 1 and self.use_lang_emb:
|
| 502 |
+
with tf.name_scope("lang_embeddings"):
|
| 503 |
+
self.lang_embeddings = self.add_weight(
|
| 504 |
+
name="embeddings",
|
| 505 |
+
shape=[self.n_langs, self.dim],
|
| 506 |
+
initializer=get_initializer(self.embed_init_std),
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
if self.built:
|
| 510 |
+
return
|
| 511 |
+
self.built = True
|
| 512 |
+
if getattr(self, "embeddings", None) is not None:
|
| 513 |
+
with tf.name_scope(self.embeddings.name):
|
| 514 |
+
self.embeddings.build(None)
|
| 515 |
+
if getattr(self, "layer_norm_emb", None) is not None:
|
| 516 |
+
with tf.name_scope(self.layer_norm_emb.name):
|
| 517 |
+
self.layer_norm_emb.build([None, None, self.dim])
|
| 518 |
+
for layer in self.attentions:
|
| 519 |
+
with tf.name_scope(layer.name):
|
| 520 |
+
layer.build(None)
|
| 521 |
+
for layer in self.layer_norm1:
|
| 522 |
+
with tf.name_scope(layer.name):
|
| 523 |
+
layer.build([None, None, self.dim])
|
| 524 |
+
for layer in self.ffns:
|
| 525 |
+
with tf.name_scope(layer.name):
|
| 526 |
+
layer.build(None)
|
| 527 |
+
for layer in self.layer_norm2:
|
| 528 |
+
with tf.name_scope(layer.name):
|
| 529 |
+
layer.build([None, None, self.dim])
|
| 530 |
+
|
| 531 |
+
def get_input_embeddings(self):
|
| 532 |
+
return self.embeddings
|
| 533 |
+
|
| 534 |
+
def set_input_embeddings(self, value):
|
| 535 |
+
self.embeddings.weight = value
|
| 536 |
+
self.embeddings.vocab_size = shape_list(value)[0]
|
| 537 |
+
|
| 538 |
+
@unpack_inputs
|
| 539 |
+
def call(
|
| 540 |
+
self,
|
| 541 |
+
input_ids: np.ndarray | tf.Tensor | None = None,
|
| 542 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 543 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 544 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 545 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 546 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 547 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 548 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 549 |
+
inputs_embeds: tf.Tensor | None = None,
|
| 550 |
+
output_attentions: Optional[bool] = None,
|
| 551 |
+
output_hidden_states: Optional[bool] = None,
|
| 552 |
+
return_dict: Optional[bool] = None,
|
| 553 |
+
training: Optional[bool] = False,
|
| 554 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 555 |
+
# removed: src_enc=None, src_len=None
|
| 556 |
+
|
| 557 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 558 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 559 |
+
elif input_ids is not None:
|
| 560 |
+
bs, slen = shape_list(input_ids)
|
| 561 |
+
elif inputs_embeds is not None:
|
| 562 |
+
bs, slen = shape_list(inputs_embeds)[:2]
|
| 563 |
+
else:
|
| 564 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 565 |
+
|
| 566 |
+
if lengths is None:
|
| 567 |
+
if input_ids is not None:
|
| 568 |
+
lengths = tf.reduce_sum(
|
| 569 |
+
tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1
|
| 570 |
+
)
|
| 571 |
+
else:
|
| 572 |
+
lengths = tf.convert_to_tensor([slen] * bs)
|
| 573 |
+
# mask = input_ids != self.pad_index
|
| 574 |
+
|
| 575 |
+
# check inputs
|
| 576 |
+
# assert shape_list(lengths)[0] == bs
|
| 577 |
+
(
|
| 578 |
+
tf.debugging.assert_equal(shape_list(lengths)[0], bs),
|
| 579 |
+
f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched",
|
| 580 |
+
)
|
| 581 |
+
# assert lengths.max().item() <= slen
|
| 582 |
+
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
| 583 |
+
# assert (src_enc is None) == (src_len is None)
|
| 584 |
+
# if src_enc is not None:
|
| 585 |
+
# assert self.is_decoder
|
| 586 |
+
# assert src_enc.size(0) == bs
|
| 587 |
+
|
| 588 |
+
# generate masks
|
| 589 |
+
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
|
| 590 |
+
# if self.is_decoder and src_enc is not None:
|
| 591 |
+
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
| 592 |
+
|
| 593 |
+
# position_ids
|
| 594 |
+
if position_ids is None:
|
| 595 |
+
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
| 596 |
+
position_ids = tf.tile(position_ids, (bs, 1))
|
| 597 |
+
|
| 598 |
+
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
| 599 |
+
(
|
| 600 |
+
tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]),
|
| 601 |
+
f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched",
|
| 602 |
+
)
|
| 603 |
+
# position_ids = position_ids.transpose(0, 1)
|
| 604 |
+
|
| 605 |
+
# langs
|
| 606 |
+
if langs is not None:
|
| 607 |
+
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
| 608 |
+
(
|
| 609 |
+
tf.debugging.assert_equal(shape_list(langs), [bs, slen]),
|
| 610 |
+
f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched",
|
| 611 |
+
)
|
| 612 |
+
# langs = langs.transpose(0, 1)
|
| 613 |
+
|
| 614 |
+
# Prepare head mask if needed
|
| 615 |
+
# 1.0 in head_mask indicate we keep the head
|
| 616 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 617 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 618 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
|
| 619 |
+
if head_mask is not None:
|
| 620 |
+
raise NotImplementedError
|
| 621 |
+
else:
|
| 622 |
+
head_mask = [None] * self.n_layers
|
| 623 |
+
|
| 624 |
+
# do not recompute cached elements
|
| 625 |
+
if cache is not None and input_ids is not None:
|
| 626 |
+
_slen = slen - cache["slen"]
|
| 627 |
+
input_ids = input_ids[:, -_slen:]
|
| 628 |
+
position_ids = position_ids[:, -_slen:]
|
| 629 |
+
if langs is not None:
|
| 630 |
+
langs = langs[:, -_slen:]
|
| 631 |
+
mask = mask[:, -_slen:]
|
| 632 |
+
attn_mask = attn_mask[:, -_slen:]
|
| 633 |
+
|
| 634 |
+
# embeddings
|
| 635 |
+
if inputs_embeds is None:
|
| 636 |
+
check_embeddings_within_bounds(input_ids, self.embeddings.vocab_size)
|
| 637 |
+
inputs_embeds = self.embeddings(input_ids)
|
| 638 |
+
|
| 639 |
+
tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids)
|
| 640 |
+
|
| 641 |
+
if langs is not None and self.use_lang_emb:
|
| 642 |
+
tensor = tensor + tf.gather(self.lang_embeddings, langs)
|
| 643 |
+
if token_type_ids is not None:
|
| 644 |
+
tensor = tensor + self.embeddings(token_type_ids)
|
| 645 |
+
|
| 646 |
+
tensor = self.layer_norm_emb(tensor)
|
| 647 |
+
tensor = self.dropout(tensor, training=training)
|
| 648 |
+
mask = tf.cast(mask, dtype=tensor.dtype)
|
| 649 |
+
tensor = tensor * tf.expand_dims(mask, axis=-1)
|
| 650 |
+
|
| 651 |
+
# hidden_states and attentions cannot be None in graph mode.
|
| 652 |
+
hidden_states = () if output_hidden_states else None
|
| 653 |
+
attentions = () if output_attentions else None
|
| 654 |
+
|
| 655 |
+
# transformer layers
|
| 656 |
+
for i in range(self.n_layers):
|
| 657 |
+
# LayerDrop
|
| 658 |
+
dropout_probability = random.uniform(0, 1)
|
| 659 |
+
|
| 660 |
+
if training and (dropout_probability < self.layerdrop):
|
| 661 |
+
continue
|
| 662 |
+
|
| 663 |
+
if output_hidden_states:
|
| 664 |
+
hidden_states = hidden_states + (tensor,)
|
| 665 |
+
|
| 666 |
+
# self attention
|
| 667 |
+
if not self.pre_norm:
|
| 668 |
+
attn_outputs = self.attentions[i](
|
| 669 |
+
tensor,
|
| 670 |
+
attn_mask,
|
| 671 |
+
None,
|
| 672 |
+
cache,
|
| 673 |
+
head_mask[i],
|
| 674 |
+
output_attentions,
|
| 675 |
+
training=training,
|
| 676 |
+
)
|
| 677 |
+
attn = attn_outputs[0]
|
| 678 |
+
|
| 679 |
+
if output_attentions:
|
| 680 |
+
attentions = attentions + (attn_outputs[1],)
|
| 681 |
+
|
| 682 |
+
attn = self.dropout(attn, training=training)
|
| 683 |
+
tensor = tensor + attn
|
| 684 |
+
tensor = self.layer_norm1[i](tensor)
|
| 685 |
+
else:
|
| 686 |
+
tensor_normalized = self.layer_norm1[i](tensor)
|
| 687 |
+
attn_outputs = self.attentions[i](
|
| 688 |
+
tensor_normalized,
|
| 689 |
+
attn_mask,
|
| 690 |
+
None,
|
| 691 |
+
cache,
|
| 692 |
+
head_mask[i],
|
| 693 |
+
output_attentions,
|
| 694 |
+
training=training,
|
| 695 |
+
)
|
| 696 |
+
attn = attn_outputs[0]
|
| 697 |
+
|
| 698 |
+
if output_attentions:
|
| 699 |
+
attentions = attentions + (attn_outputs[1],)
|
| 700 |
+
|
| 701 |
+
attn = self.dropout(attn, training=training)
|
| 702 |
+
tensor = tensor + attn
|
| 703 |
+
|
| 704 |
+
# encoder attention (for decoder only)
|
| 705 |
+
# if self.is_decoder and src_enc is not None:
|
| 706 |
+
# attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
|
| 707 |
+
# attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
|
| 708 |
+
# tensor = tensor + attn
|
| 709 |
+
# tensor = self.layer_norm15[i](tensor)
|
| 710 |
+
|
| 711 |
+
# FFN
|
| 712 |
+
if not self.pre_norm:
|
| 713 |
+
tensor = tensor + self.ffns[i](tensor)
|
| 714 |
+
tensor = self.layer_norm2[i](tensor)
|
| 715 |
+
else:
|
| 716 |
+
tensor_normalized = self.layer_norm2[i](tensor)
|
| 717 |
+
tensor = tensor + self.ffns[i](tensor_normalized)
|
| 718 |
+
|
| 719 |
+
tensor = tensor * tf.expand_dims(mask, axis=-1)
|
| 720 |
+
|
| 721 |
+
# Add last hidden state
|
| 722 |
+
if output_hidden_states:
|
| 723 |
+
hidden_states = hidden_states + (tensor,)
|
| 724 |
+
|
| 725 |
+
# update cache length
|
| 726 |
+
if cache is not None:
|
| 727 |
+
cache["slen"] += tensor.size(1)
|
| 728 |
+
|
| 729 |
+
# move back sequence length to dimension 0
|
| 730 |
+
# tensor = tensor.transpose(0, 1)
|
| 731 |
+
|
| 732 |
+
if not return_dict:
|
| 733 |
+
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
|
| 734 |
+
|
| 735 |
+
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
|
| 736 |
+
|
| 737 |
+
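# Illustrative sketch (not part of the upstream file): the two residual orderings handled in
# the layer loop above. With `pre_norm=False` (the default) each sub-block is followed by the
# residual addition and then layer norm; with `pre_norm=True` the input of the sub-block is
# normalized instead.
def post_norm_block(x, sublayer, layer_norm):
    return layer_norm(x + sublayer(x))


def pre_norm_block(x, sublayer, layer_norm):
    return x + sublayer(layer_norm(x))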
|
| 738 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMPredLayer
|
| 739 |
+
class TFFlaubertPredLayer(keras.layers.Layer):
|
| 740 |
+
"""
|
| 741 |
+
Prediction layer (cross_entropy or adaptive_softmax).
|
| 742 |
+
"""
|
| 743 |
+
|
| 744 |
+
def __init__(self, config, input_embeddings, **kwargs):
|
| 745 |
+
super().__init__(**kwargs)
|
| 746 |
+
|
| 747 |
+
self.asm = config.asm
|
| 748 |
+
self.n_words = config.n_words
|
| 749 |
+
self.pad_index = config.pad_index
|
| 750 |
+
|
| 751 |
+
if config.asm is False:
|
| 752 |
+
self.input_embeddings = input_embeddings
|
| 753 |
+
else:
|
| 754 |
+
raise NotImplementedError
|
| 755 |
+
# self.proj = nn.AdaptiveLogSoftmaxWithLoss(
|
| 756 |
+
# in_features=dim,
|
| 757 |
+
# n_classes=config.n_words,
|
| 758 |
+
# cutoffs=config.asm_cutoffs,
|
| 759 |
+
# div_value=config.asm_div_value,
|
| 760 |
+
# head_bias=True, # default is False
|
| 761 |
+
# )
|
| 762 |
+
|
| 763 |
+
def build(self, input_shape):
|
| 764 |
+
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
|
| 765 |
+
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
|
| 766 |
+
|
| 767 |
+
super().build(input_shape)
|
| 768 |
+
|
| 769 |
+
def get_output_embeddings(self):
|
| 770 |
+
return self.input_embeddings
|
| 771 |
+
|
| 772 |
+
def set_output_embeddings(self, value):
|
| 773 |
+
self.input_embeddings.weight = value
|
| 774 |
+
self.input_embeddings.vocab_size = shape_list(value)[0]
|
| 775 |
+
|
| 776 |
+
def get_bias(self):
|
| 777 |
+
return {"bias": self.bias}
|
| 778 |
+
|
| 779 |
+
def set_bias(self, value):
|
| 780 |
+
self.bias = value["bias"]
|
| 781 |
+
self.vocab_size = shape_list(value["bias"])[0]
|
| 782 |
+
|
| 783 |
+
def call(self, hidden_states):
|
| 784 |
+
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
| 785 |
+
hidden_states = hidden_states + self.bias
|
| 786 |
+
|
| 787 |
+
return hidden_states
|
| 788 |
+
|
| 789 |
+
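# Illustrative sketch (not part of the upstream file): with `asm=False`, the prediction layer
# reuses the input embedding matrix as the output projection (weight tying) and adds a
# per-token bias, i.e. logits = hidden_states @ embeddings^T + bias. Sizes below are made up.
import tensorflow as tf

hidden_states = tf.random.normal((2, 7, 512))  # (batch, seq_len, emb_dim)
embedding_matrix = tf.random.normal((30000, 512))  # (vocab_size, emb_dim)
bias = tf.zeros((30000,))

logits = tf.einsum("bsd,vd->bsv", hidden_states, embedding_matrix) + bias  # (2, 7, 30000)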
|
| 790 |
+
@dataclass
|
| 791 |
+
class TFFlaubertWithLMHeadModelOutput(ModelOutput):
|
| 792 |
+
"""
|
| 793 |
+
Base class for [`TFFlaubertWithLMHeadModel`] outputs.
|
| 794 |
+
|
| 795 |
+
Args:
|
| 796 |
+
logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 797 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 798 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 799 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 800 |
+
`(batch_size, sequence_length, hidden_size)`.
|
| 801 |
+
|
| 802 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 803 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 804 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 805 |
+
sequence_length)`.
|
| 806 |
+
|
| 807 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 808 |
+
heads.
|
| 809 |
+
"""
|
| 810 |
+
|
| 811 |
+
logits: Optional[tf.Tensor] = None
|
| 812 |
+
hidden_states: Tuple[tf.Tensor] | None = None
|
| 813 |
+
attentions: Tuple[tf.Tensor] | None = None
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
@add_start_docstrings(
|
| 817 |
+
"""
|
| 818 |
+
The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 819 |
+
embeddings).
|
| 820 |
+
""",
|
| 821 |
+
FLAUBERT_START_DOCSTRING,
|
| 822 |
+
)
|
| 823 |
+
class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
|
| 824 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 825 |
+
super().__init__(config, *inputs, **kwargs)
|
| 826 |
+
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
| 827 |
+
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
|
| 828 |
+
# Flaubert does not have past caching features
|
| 829 |
+
self.supports_xla_generation = False
|
| 830 |
+
|
| 831 |
+
def get_lm_head(self):
|
| 832 |
+
return self.pred_layer
|
| 833 |
+
|
| 834 |
+
def get_prefix_bias_name(self):
|
| 835 |
+
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
| 836 |
+
return self.name + "/" + self.pred_layer.name
|
| 837 |
+
|
| 838 |
+
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
| 839 |
+
mask_token_id = self.config.mask_token_id
|
| 840 |
+
lang_id = self.config.lang_id
|
| 841 |
+
|
| 842 |
+
effective_batch_size = inputs.shape[0]
|
| 843 |
+
mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
|
| 844 |
+
inputs = tf.concat([inputs, mask_token], axis=1)
|
| 845 |
+
|
| 846 |
+
if lang_id is not None:
|
| 847 |
+
langs = tf.ones_like(inputs) * lang_id
|
| 848 |
+
else:
|
| 849 |
+
langs = None
|
| 850 |
+
return {"input_ids": inputs, "langs": langs}
|
| 851 |
+
|
| 852 |
+
@unpack_inputs
|
| 853 |
+
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
|
| 854 |
+
@add_code_sample_docstrings(
|
| 855 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 856 |
+
output_type=TFFlaubertWithLMHeadModelOutput,
|
| 857 |
+
config_class=_CONFIG_FOR_DOC,
|
| 858 |
+
)
|
| 859 |
+
def call(
|
| 860 |
+
self,
|
| 861 |
+
input_ids: np.ndarray | tf.Tensor | None = None,
|
| 862 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 863 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 864 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 865 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 866 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 867 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 868 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 869 |
+
inputs_embeds: tf.Tensor | None = None,
|
| 870 |
+
output_attentions: Optional[bool] = None,
|
| 871 |
+
output_hidden_states: Optional[bool] = None,
|
| 872 |
+
return_dict: Optional[bool] = None,
|
| 873 |
+
training: Optional[bool] = False,
|
| 874 |
+
) -> Union[Tuple, TFFlaubertWithLMHeadModelOutput]:
|
| 875 |
+
transformer_outputs = self.transformer(
|
| 876 |
+
input_ids=input_ids,
|
| 877 |
+
attention_mask=attention_mask,
|
| 878 |
+
langs=langs,
|
| 879 |
+
token_type_ids=token_type_ids,
|
| 880 |
+
position_ids=position_ids,
|
| 881 |
+
lengths=lengths,
|
| 882 |
+
cache=cache,
|
| 883 |
+
head_mask=head_mask,
|
| 884 |
+
inputs_embeds=inputs_embeds,
|
| 885 |
+
output_attentions=output_attentions,
|
| 886 |
+
output_hidden_states=output_hidden_states,
|
| 887 |
+
return_dict=return_dict,
|
| 888 |
+
training=training,
|
| 889 |
+
)
|
| 890 |
+
output = transformer_outputs[0]
|
| 891 |
+
outputs = self.pred_layer(output)
|
| 892 |
+
|
| 893 |
+
if not return_dict:
|
| 894 |
+
return (outputs,) + transformer_outputs[1:]
|
| 895 |
+
|
| 896 |
+
return TFFlaubertWithLMHeadModelOutput(
|
| 897 |
+
logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
def build(self, input_shape=None):
|
| 901 |
+
if self.built:
|
| 902 |
+
return
|
| 903 |
+
self.built = True
|
| 904 |
+
if getattr(self, "transformer", None) is not None:
|
| 905 |
+
with tf.name_scope(self.transformer.name):
|
| 906 |
+
self.transformer.build(None)
|
| 907 |
+
if getattr(self, "pred_layer", None) is not None:
|
| 908 |
+
with tf.name_scope(self.pred_layer.name):
|
| 909 |
+
self.pred_layer.build(None)
|
| 910 |
+
|
| 911 |
+
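# Illustrative sketch (not part of the upstream file): `prepare_inputs_for_generation` above
# appends the configured mask token id to every sequence and, when the config defines a
# `lang_id`, builds a matching `langs` tensor. Rough equivalent with dummy ids:
import tensorflow as tf

inputs = tf.constant([[5, 6, 7], [8, 9, 10]])  # (batch, seq_len)
mask_token_id, lang_id = 0, 1  # dummy values; the real ones come from the model config

effective_batch_size = inputs.shape[0]
mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
inputs = tf.concat([inputs, mask_token], axis=1)  # one extra position per sequence
langs = tf.ones_like(inputs) * lang_id  # same shape as the extended inputs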
|
| 912 |
+
@add_start_docstrings(
|
| 913 |
+
"""
|
| 914 |
+
Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
|
| 915 |
+
e.g. for GLUE tasks.
|
| 916 |
+
""",
|
| 917 |
+
FLAUBERT_START_DOCSTRING,
|
| 918 |
+
)
|
| 919 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
|
| 920 |
+
class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceClassificationLoss):
|
| 921 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 922 |
+
super().__init__(config, *inputs, **kwargs)
|
| 923 |
+
self.num_labels = config.num_labels
|
| 924 |
+
|
| 925 |
+
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
| 926 |
+
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
|
| 927 |
+
|
| 928 |
+
@unpack_inputs
|
| 929 |
+
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 930 |
+
@add_code_sample_docstrings(
|
| 931 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 932 |
+
output_type=TFSequenceClassifierOutput,
|
| 933 |
+
config_class=_CONFIG_FOR_DOC,
|
| 934 |
+
)
|
| 935 |
+
def call(
|
| 936 |
+
self,
|
| 937 |
+
input_ids: TFModelInputType | None = None,
|
| 938 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 939 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 940 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 941 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 942 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 943 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 944 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 945 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 946 |
+
output_attentions: Optional[bool] = None,
|
| 947 |
+
output_hidden_states: Optional[bool] = None,
|
| 948 |
+
return_dict: Optional[bool] = None,
|
| 949 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 950 |
+
training: bool = False,
|
| 951 |
+
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
| 952 |
+
r"""
|
| 953 |
+
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 954 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 955 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
|
| 956 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 957 |
+
"""
|
| 958 |
+
transformer_outputs = self.transformer(
|
| 959 |
+
input_ids=input_ids,
|
| 960 |
+
attention_mask=attention_mask,
|
| 961 |
+
langs=langs,
|
| 962 |
+
token_type_ids=token_type_ids,
|
| 963 |
+
position_ids=position_ids,
|
| 964 |
+
lengths=lengths,
|
| 965 |
+
cache=cache,
|
| 966 |
+
head_mask=head_mask,
|
| 967 |
+
inputs_embeds=inputs_embeds,
|
| 968 |
+
output_attentions=output_attentions,
|
| 969 |
+
output_hidden_states=output_hidden_states,
|
| 970 |
+
return_dict=return_dict,
|
| 971 |
+
training=training,
|
| 972 |
+
)
|
| 973 |
+
output = transformer_outputs[0]
|
| 974 |
+
|
| 975 |
+
logits = self.sequence_summary(output)
|
| 976 |
+
|
| 977 |
+
loss = None if labels is None else self.hf_compute_loss(labels, logits)
|
| 978 |
+
|
| 979 |
+
if not return_dict:
|
| 980 |
+
output = (logits,) + transformer_outputs[1:]
|
| 981 |
+
return ((loss,) + output) if loss is not None else output
|
| 982 |
+
|
| 983 |
+
return TFSequenceClassifierOutput(
|
| 984 |
+
loss=loss,
|
| 985 |
+
logits=logits,
|
| 986 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 987 |
+
attentions=transformer_outputs.attentions,
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
def build(self, input_shape=None):
|
| 991 |
+
if self.built:
|
| 992 |
+
return
|
| 993 |
+
self.built = True
|
| 994 |
+
if getattr(self, "transformer", None) is not None:
|
| 995 |
+
with tf.name_scope(self.transformer.name):
|
| 996 |
+
self.transformer.build(None)
|
| 997 |
+
if getattr(self, "sequence_summary", None) is not None:
|
| 998 |
+
with tf.name_scope(self.sequence_summary.name):
|
| 999 |
+
self.sequence_summary.build(None)
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
@add_start_docstrings(
|
| 1003 |
+
"""
|
| 1004 |
+
Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1005 |
+
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1006 |
+
""",
|
| 1007 |
+
FLAUBERT_START_DOCSTRING,
|
| 1008 |
+
)
|
| 1009 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
|
| 1010 |
+
class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestionAnsweringLoss):
|
| 1011 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 1012 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1013 |
+
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
| 1014 |
+
self.qa_outputs = keras.layers.Dense(
|
| 1015 |
+
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
|
| 1016 |
+
)
|
| 1017 |
+
self.config = config
|
| 1018 |
+
|
| 1019 |
+
@unpack_inputs
|
| 1020 |
+
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1021 |
+
@add_code_sample_docstrings(
|
| 1022 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1023 |
+
output_type=TFQuestionAnsweringModelOutput,
|
| 1024 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1025 |
+
)
|
| 1026 |
+
def call(
|
| 1027 |
+
self,
|
| 1028 |
+
input_ids: TFModelInputType | None = None,
|
| 1029 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1030 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 1031 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1032 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1033 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 1034 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 1035 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1036 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1037 |
+
output_attentions: Optional[bool] = None,
|
| 1038 |
+
output_hidden_states: Optional[bool] = None,
|
| 1039 |
+
return_dict: Optional[bool] = None,
|
| 1040 |
+
start_positions: np.ndarray | tf.Tensor | None = None,
|
| 1041 |
+
end_positions: np.ndarray | tf.Tensor | None = None,
|
| 1042 |
+
training: bool = False,
|
| 1043 |
+
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
|
| 1044 |
+
r"""
|
| 1045 |
+
start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 1046 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1047 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
| 1048 |
+
are not taken into account for computing the loss.
|
| 1049 |
+
end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 1050 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1051 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
| 1052 |
+
are not taken into account for computing the loss.
|
| 1053 |
+
"""
|
| 1054 |
+
transformer_outputs = self.transformer(
|
| 1055 |
+
input_ids=input_ids,
|
| 1056 |
+
attention_mask=attention_mask,
|
| 1057 |
+
langs=langs,
|
| 1058 |
+
token_type_ids=token_type_ids,
|
| 1059 |
+
position_ids=position_ids,
|
| 1060 |
+
lengths=lengths,
|
| 1061 |
+
cache=cache,
|
| 1062 |
+
head_mask=head_mask,
|
| 1063 |
+
inputs_embeds=inputs_embeds,
|
| 1064 |
+
output_attentions=output_attentions,
|
| 1065 |
+
output_hidden_states=output_hidden_states,
|
| 1066 |
+
return_dict=return_dict,
|
| 1067 |
+
training=training,
|
| 1068 |
+
)
|
| 1069 |
+
sequence_output = transformer_outputs[0]
|
| 1070 |
+
|
| 1071 |
+
logits = self.qa_outputs(sequence_output)
|
| 1072 |
+
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
| 1073 |
+
start_logits = tf.squeeze(start_logits, axis=-1)
|
| 1074 |
+
end_logits = tf.squeeze(end_logits, axis=-1)
|
| 1075 |
+
|
| 1076 |
+
loss = None
|
| 1077 |
+
if start_positions is not None and end_positions is not None:
|
| 1078 |
+
labels = {"start_position": start_positions}
|
| 1079 |
+
labels["end_position"] = end_positions
|
| 1080 |
+
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
|
| 1081 |
+
|
| 1082 |
+
if not return_dict:
|
| 1083 |
+
output = (start_logits, end_logits) + transformer_outputs[1:]
|
| 1084 |
+
return ((loss,) + output) if loss is not None else output
|
| 1085 |
+
|
| 1086 |
+
return TFQuestionAnsweringModelOutput(
|
| 1087 |
+
loss=loss,
|
| 1088 |
+
start_logits=start_logits,
|
| 1089 |
+
end_logits=end_logits,
|
| 1090 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 1091 |
+
attentions=transformer_outputs.attentions,
|
| 1092 |
+
)
|
| 1093 |
+
|
| 1094 |
+
def build(self, input_shape=None):
|
| 1095 |
+
if self.built:
|
| 1096 |
+
return
|
| 1097 |
+
self.built = True
|
| 1098 |
+
if getattr(self, "transformer", None) is not None:
|
| 1099 |
+
with tf.name_scope(self.transformer.name):
|
| 1100 |
+
self.transformer.build(None)
|
| 1101 |
+
if getattr(self, "qa_outputs", None) is not None:
|
| 1102 |
+
with tf.name_scope(self.qa_outputs.name):
|
| 1103 |
+
self.qa_outputs.build([None, None, self.config.hidden_size])
|
| 1104 |
+
|
| 1105 |
+
|
| 1106 |
+
@add_start_docstrings(
|
| 1107 |
+
"""
|
| 1108 |
+
Flaubert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1109 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1110 |
+
""",
|
| 1111 |
+
FLAUBERT_START_DOCSTRING,
|
| 1112 |
+
)
|
| 1113 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
|
| 1114 |
+
class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassificationLoss):
|
| 1115 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 1116 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1117 |
+
self.num_labels = config.num_labels
|
| 1118 |
+
|
| 1119 |
+
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
| 1120 |
+
self.dropout = keras.layers.Dropout(config.dropout)
|
| 1121 |
+
self.classifier = keras.layers.Dense(
|
| 1122 |
+
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
|
| 1123 |
+
)
|
| 1124 |
+
self.config = config
|
| 1125 |
+
|
| 1126 |
+
@unpack_inputs
|
| 1127 |
+
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1128 |
+
@add_code_sample_docstrings(
|
| 1129 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1130 |
+
output_type=TFTokenClassifierOutput,
|
| 1131 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1132 |
+
)
|
| 1133 |
+
def call(
|
| 1134 |
+
self,
|
| 1135 |
+
input_ids: TFModelInputType | None = None,
|
| 1136 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1137 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 1138 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1139 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1140 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 1141 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 1142 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1143 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1144 |
+
output_attentions: Optional[bool] = None,
|
| 1145 |
+
output_hidden_states: Optional[bool] = None,
|
| 1146 |
+
return_dict: Optional[bool] = None,
|
| 1147 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1148 |
+
training: bool = False,
|
| 1149 |
+
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
|
| 1150 |
+
r"""
|
| 1151 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1152 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1153 |
+
"""
|
| 1154 |
+
transformer_outputs = self.transformer(
|
| 1155 |
+
input_ids=input_ids,
|
| 1156 |
+
attention_mask=attention_mask,
|
| 1157 |
+
langs=langs,
|
| 1158 |
+
token_type_ids=token_type_ids,
|
| 1159 |
+
position_ids=position_ids,
|
| 1160 |
+
lengths=lengths,
|
| 1161 |
+
cache=cache,
|
| 1162 |
+
head_mask=head_mask,
|
| 1163 |
+
inputs_embeds=inputs_embeds,
|
| 1164 |
+
output_attentions=output_attentions,
|
| 1165 |
+
output_hidden_states=output_hidden_states,
|
| 1166 |
+
return_dict=return_dict,
|
| 1167 |
+
training=training,
|
| 1168 |
+
)
|
| 1169 |
+
sequence_output = transformer_outputs[0]
|
| 1170 |
+
|
| 1171 |
+
sequence_output = self.dropout(sequence_output, training=training)
|
| 1172 |
+
logits = self.classifier(sequence_output)
|
| 1173 |
+
|
| 1174 |
+
loss = None if labels is None else self.hf_compute_loss(labels, logits)
|
| 1175 |
+
|
| 1176 |
+
if not return_dict:
|
| 1177 |
+
output = (logits,) + transformer_outputs[1:]
|
| 1178 |
+
return ((loss,) + output) if loss is not None else output
|
| 1179 |
+
|
| 1180 |
+
return TFTokenClassifierOutput(
|
| 1181 |
+
loss=loss,
|
| 1182 |
+
logits=logits,
|
| 1183 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 1184 |
+
attentions=transformer_outputs.attentions,
|
| 1185 |
+
)
|
| 1186 |
+
|
| 1187 |
+
def build(self, input_shape=None):
|
| 1188 |
+
if self.built:
|
| 1189 |
+
return
|
| 1190 |
+
self.built = True
|
| 1191 |
+
if getattr(self, "transformer", None) is not None:
|
| 1192 |
+
with tf.name_scope(self.transformer.name):
|
| 1193 |
+
self.transformer.build(None)
|
| 1194 |
+
if getattr(self, "classifier", None) is not None:
|
| 1195 |
+
with tf.name_scope(self.classifier.name):
|
| 1196 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
@add_start_docstrings(
|
| 1200 |
+
"""
|
| 1201 |
+
Flaubert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
| 1202 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
| 1203 |
+
""",
|
| 1204 |
+
FLAUBERT_START_DOCSTRING,
|
| 1205 |
+
)
|
| 1206 |
+
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
|
| 1207 |
+
class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLoss):
|
| 1208 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 1209 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1210 |
+
|
| 1211 |
+
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
| 1212 |
+
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
|
| 1213 |
+
self.logits_proj = keras.layers.Dense(
|
| 1214 |
+
1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
|
| 1215 |
+
)
|
| 1216 |
+
self.config = config
|
| 1217 |
+
|
| 1218 |
+
@property
|
| 1219 |
+
def dummy_inputs(self):
|
| 1220 |
+
"""
|
| 1221 |
+
Dummy inputs to build the network.
|
| 1222 |
+
|
| 1223 |
+
Returns:
|
| 1224 |
+
tf.Tensor with dummy inputs
|
| 1225 |
+
"""
|
| 1226 |
+
# Sometimes Flaubert has language embeddings so don't forget to build them as well if needed
|
| 1227 |
+
if self.config.use_lang_emb and self.config.n_langs > 1:
|
| 1228 |
+
return {
|
| 1229 |
+
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
|
| 1230 |
+
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
|
| 1231 |
+
}
|
| 1232 |
+
else:
|
| 1233 |
+
return {
|
| 1234 |
+
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
|
| 1235 |
+
}
|
| 1236 |
+
|
| 1237 |
+
@unpack_inputs
|
| 1238 |
+
@add_start_docstrings_to_model_forward(
|
| 1239 |
+
FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
| 1240 |
+
)
|
| 1241 |
+
@add_code_sample_docstrings(
|
| 1242 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1243 |
+
output_type=TFMultipleChoiceModelOutput,
|
| 1244 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1245 |
+
)
|
| 1246 |
+
def call(
|
| 1247 |
+
self,
|
| 1248 |
+
input_ids: TFModelInputType | None = None,
|
| 1249 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1250 |
+
langs: np.ndarray | tf.Tensor | None = None,
|
| 1251 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1252 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1253 |
+
lengths: np.ndarray | tf.Tensor | None = None,
|
| 1254 |
+
cache: Optional[Dict[str, tf.Tensor]] = None,
|
| 1255 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1256 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1257 |
+
output_attentions: Optional[bool] = None,
|
| 1258 |
+
output_hidden_states: Optional[bool] = None,
|
| 1259 |
+
return_dict: Optional[bool] = None,
|
| 1260 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1261 |
+
training: bool = False,
|
| 1262 |
+
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
|
| 1263 |
+
if input_ids is not None:
|
| 1264 |
+
num_choices = shape_list(input_ids)[1]
|
| 1265 |
+
seq_length = shape_list(input_ids)[2]
|
| 1266 |
+
else:
|
| 1267 |
+
num_choices = shape_list(inputs_embeds)[1]
|
| 1268 |
+
seq_length = shape_list(inputs_embeds)[2]
|
| 1269 |
+
|
| 1270 |
+
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
| 1271 |
+
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
| 1272 |
+
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
| 1273 |
+
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
| 1274 |
+
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
|
| 1275 |
+
flat_inputs_embeds = (
|
| 1276 |
+
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
| 1277 |
+
if inputs_embeds is not None
|
| 1278 |
+
else None
|
| 1279 |
+
)
|
| 1280 |
+
|
| 1281 |
+
if lengths is not None:
|
| 1282 |
+
logger.warning(
|
| 1283 |
+
"The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
|
| 1284 |
+
"attention mask instead.",
|
| 1285 |
+
)
|
| 1286 |
+
lengths = None
|
| 1287 |
+
|
| 1288 |
+
transformer_outputs = self.transformer(
|
| 1289 |
+
flat_input_ids,
|
| 1290 |
+
flat_attention_mask,
|
| 1291 |
+
flat_langs,
|
| 1292 |
+
flat_token_type_ids,
|
| 1293 |
+
flat_position_ids,
|
| 1294 |
+
lengths,
|
| 1295 |
+
cache,
|
| 1296 |
+
head_mask,
|
| 1297 |
+
flat_inputs_embeds,
|
| 1298 |
+
output_attentions,
|
| 1299 |
+
output_hidden_states,
|
| 1300 |
+
return_dict=return_dict,
|
| 1301 |
+
training=training,
|
| 1302 |
+
)
|
| 1303 |
+
output = transformer_outputs[0]
|
| 1304 |
+
logits = self.sequence_summary(output)
|
| 1305 |
+
logits = self.logits_proj(logits)
|
| 1306 |
+
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
| 1307 |
+
|
| 1308 |
+
loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
|
| 1309 |
+
|
| 1310 |
+
if not return_dict:
|
| 1311 |
+
output = (reshaped_logits,) + transformer_outputs[1:]
|
| 1312 |
+
return ((loss,) + output) if loss is not None else output
|
| 1313 |
+
|
| 1314 |
+
return TFMultipleChoiceModelOutput(
|
| 1315 |
+
loss=loss,
|
| 1316 |
+
logits=reshaped_logits,
|
| 1317 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 1318 |
+
attentions=transformer_outputs.attentions,
|
| 1319 |
+
)
|
| 1320 |
+
|
| 1321 |
+
def build(self, input_shape=None):
|
| 1322 |
+
if self.built:
|
| 1323 |
+
return
|
| 1324 |
+
self.built = True
|
| 1325 |
+
if getattr(self, "transformer", None) is not None:
|
| 1326 |
+
with tf.name_scope(self.transformer.name):
|
| 1327 |
+
self.transformer.build(None)
|
| 1328 |
+
if getattr(self, "sequence_summary", None) is not None:
|
| 1329 |
+
with tf.name_scope(self.sequence_summary.name):
|
| 1330 |
+
self.sequence_summary.build(None)
|
| 1331 |
+
if getattr(self, "logits_proj", None) is not None:
|
| 1332 |
+
with tf.name_scope(self.logits_proj.name):
|
| 1333 |
+
self.logits_proj.build([None, None, self.config.num_labels])
|
| 1334 |
+
|
| 1335 |
+
|
| 1336 |
+
__all__ = [
|
| 1337 |
+
"TFFlaubertForMultipleChoice",
|
| 1338 |
+
"TFFlaubertForQuestionAnsweringSimple",
|
| 1339 |
+
"TFFlaubertForSequenceClassification",
|
| 1340 |
+
"TFFlaubertForTokenClassification",
|
| 1341 |
+
"TFFlaubertModel",
|
| 1342 |
+
"TFFlaubertPreTrainedModel",
|
| 1343 |
+
"TFFlaubertWithLMHeadModel",
|
| 1344 |
+
]
|
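The task heads above (sequence classification, question answering, token classification, multiple choice) all wrap the same `TFFlaubertMainLayer` and differ only in the projection applied to its output. The snippet below is a minimal sketch of exercising the sequence-classification head end to end; the checkpoint name, example sentence, and label are illustrative assumptions, and `from_pt=True` may be needed if the Hub checkpoint ships only PyTorch weights.

```python
# Sketch only: exercises TFFlaubertForSequenceClassification as defined above.
# The checkpoint name, sentence and label are illustrative assumptions.
import tensorflow as tf
from transformers import FlaubertTokenizer, TFFlaubertForSequenceClassification

tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertForSequenceClassification.from_pretrained(
    "flaubert/flaubert_base_cased", num_labels=2  # add from_pt=True if only PT weights are published
)

inputs = tokenizer("Le camembert est délicieux.", return_tensors="tf")
outputs = model(**inputs, labels=tf.constant([1]))  # passing labels triggers hf_compute_loss
print(float(outputs.loss), outputs.logits.shape)    # scalar loss, logits of shape (1, 2)
```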
docs/transformers/build/lib/transformers/models/flaubert/tokenization_flaubert.py
ADDED
|
@@ -0,0 +1,568 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for Flaubert."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import unicodedata
|
| 21 |
+
from typing import List, Optional, Tuple
|
| 22 |
+
|
| 23 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
VOCAB_FILES_NAMES = {
|
| 30 |
+
"vocab_file": "vocab.json",
|
| 31 |
+
"merges_file": "merges.txt",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def convert_to_unicode(text):
|
| 36 |
+
"""
|
| 37 |
+
Converts `text` to Unicode (if it's not already), assuming UTF-8 input.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def ensure_text(s, encoding="utf-8", errors="strict"):
|
| 41 |
+
if isinstance(s, bytes):
|
| 42 |
+
return s.decode(encoding, errors)
|
| 43 |
+
elif isinstance(s, str):
|
| 44 |
+
return s
|
| 45 |
+
else:
|
| 46 |
+
raise TypeError(f"not expecting type '{type(s)}'")
|
| 47 |
+
|
| 48 |
+
return ensure_text(text, encoding="utf-8", errors="ignore")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Copied from transformers.models.xlm.tokenization_xlm.get_pairs
|
| 52 |
+
def get_pairs(word):
|
| 53 |
+
"""
|
| 54 |
+
Return the set of symbol pairs in a word. A word is represented as a tuple of symbols (symbols being variable-length
|
| 55 |
+
strings)
|
| 56 |
+
"""
|
| 57 |
+
pairs = set()
|
| 58 |
+
prev_char = word[0]
|
| 59 |
+
for char in word[1:]:
|
| 60 |
+
pairs.add((prev_char, char))
|
| 61 |
+
prev_char = char
|
| 62 |
+
return pairs
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Copied from transformers.models.xlm.tokenization_xlm.replace_unicode_punct
|
| 66 |
+
def replace_unicode_punct(text):
|
| 67 |
+
"""
|
| 68 |
+
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
|
| 69 |
+
"""
|
| 70 |
+
text = text.replace(",", ",")
|
| 71 |
+
text = re.sub(r"。\s*", ". ", text)
|
| 72 |
+
text = text.replace("、", ",")
|
| 73 |
+
text = text.replace("”", '"')
|
| 74 |
+
text = text.replace("“", '"')
|
| 75 |
+
text = text.replace("∶", ":")
|
| 76 |
+
text = text.replace(":", ":")
|
| 77 |
+
text = text.replace("?", "?")
|
| 78 |
+
text = text.replace("《", '"')
|
| 79 |
+
text = text.replace("》", '"')
|
| 80 |
+
text = text.replace(")", ")")
|
| 81 |
+
text = text.replace("!", "!")
|
| 82 |
+
text = text.replace("(", "(")
|
| 83 |
+
text = text.replace(";", ";")
|
| 84 |
+
text = text.replace("1", "1")
|
| 85 |
+
text = text.replace("」", '"')
|
| 86 |
+
text = text.replace("「", '"')
|
| 87 |
+
text = text.replace("0", "0")
|
| 88 |
+
text = text.replace("3", "3")
|
| 89 |
+
text = text.replace("2", "2")
|
| 90 |
+
text = text.replace("5", "5")
|
| 91 |
+
text = text.replace("6", "6")
|
| 92 |
+
text = text.replace("9", "9")
|
| 93 |
+
text = text.replace("7", "7")
|
| 94 |
+
text = text.replace("8", "8")
|
| 95 |
+
text = text.replace("4", "4")
|
| 96 |
+
text = re.sub(r".\s*", ". ", text)
|
| 97 |
+
text = text.replace("~", "~")
|
| 98 |
+
text = text.replace("’", "'")
|
| 99 |
+
text = text.replace("…", "...")
|
| 100 |
+
text = text.replace("━", "-")
|
| 101 |
+
text = text.replace("〈", "<")
|
| 102 |
+
text = text.replace("〉", ">")
|
| 103 |
+
text = text.replace("【", "[")
|
| 104 |
+
text = text.replace("】", "]")
|
| 105 |
+
text = text.replace("%", "%")
|
| 106 |
+
return text
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Copied from transformers.models.xlm.tokenization_xlm.remove_non_printing_char
|
| 110 |
+
def remove_non_printing_char(text):
|
| 111 |
+
"""
|
| 112 |
+
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
|
| 113 |
+
"""
|
| 114 |
+
output = []
|
| 115 |
+
for char in text:
|
| 116 |
+
cat = unicodedata.category(char)
|
| 117 |
+
if cat.startswith("C"):
|
| 118 |
+
continue
|
| 119 |
+
output.append(char)
|
| 120 |
+
return "".join(output)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class FlaubertTokenizer(PreTrainedTokenizer):
|
| 124 |
+
"""
|
| 125 |
+
Construct a Flaubert tokenizer. Based on Byte-Pair Encoding. The tokenization process is the following:
|
| 126 |
+
|
| 127 |
+
- Moses preprocessing and tokenization.
|
| 128 |
+
- Normalizing all input text.
|
| 129 |
+
- The argument `special_tokens` and the function `set_special_tokens` can be used to add additional symbols (like
|
| 130 |
+
"__classify__") to a vocabulary.
|
| 131 |
+
- The argument `do_lowercase` controls lower casing (automatically set for pretrained vocabularies).
|
| 132 |
+
|
| 133 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 134 |
+
this superclass for more information regarding those methods.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
vocab_file (`str`):
|
| 138 |
+
Vocabulary file.
|
| 139 |
+
merges_file (`str`):
|
| 140 |
+
Merges file.
|
| 141 |
+
do_lowercase (`bool`, *optional*, defaults to `False`):
|
| 142 |
+
Controls lower casing.
|
| 143 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 144 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 145 |
+
token instead.
|
| 146 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 147 |
+
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
|
| 148 |
+
|
| 149 |
+
<Tip>
|
| 150 |
+
|
| 151 |
+
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
| 152 |
+
sequence. The token used is the `cls_token`.
|
| 153 |
+
|
| 154 |
+
</Tip>
|
| 155 |
+
|
| 156 |
+
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
| 157 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 158 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 159 |
+
token of a sequence built with special tokens.
|
| 160 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 161 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 162 |
+
cls_token (`str`, *optional*, defaults to `"</s>"`):
|
| 163 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 164 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 165 |
+
mask_token (`str`, *optional*, defaults to `"<special1>"`):
|
| 166 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 167 |
+
modeling. This is the token which the model will try to predict.
|
| 168 |
+
additional_special_tokens (`List[str]`, *optional*, defaults to `['<special0>', '<special1>', '<special2>', '<special3>', '<special4>', '<special5>', '<special6>', '<special7>', '<special8>', '<special9>']`):
|
| 169 |
+
List of additional special tokens.
|
| 170 |
+
lang2id (`Dict[str, int]`, *optional*):
|
| 171 |
+
Dictionary mapping languages string identifiers to their IDs.
|
| 172 |
+
id2lang (`Dict[int, str]`, *optional*):
|
| 173 |
+
Dictionary mapping language IDs to their string identifiers.
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 177 |
+
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
vocab_file,
|
| 181 |
+
merges_file,
|
| 182 |
+
do_lowercase=False,
|
| 183 |
+
unk_token="<unk>",
|
| 184 |
+
bos_token="<s>",
|
| 185 |
+
sep_token="</s>",
|
| 186 |
+
pad_token="<pad>",
|
| 187 |
+
cls_token="</s>",
|
| 188 |
+
mask_token="<special1>",
|
| 189 |
+
additional_special_tokens=[
|
| 190 |
+
"<special0>",
|
| 191 |
+
"<special1>",
|
| 192 |
+
"<special2>",
|
| 193 |
+
"<special3>",
|
| 194 |
+
"<special4>",
|
| 195 |
+
"<special5>",
|
| 196 |
+
"<special6>",
|
| 197 |
+
"<special7>",
|
| 198 |
+
"<special8>",
|
| 199 |
+
"<special9>",
|
| 200 |
+
],
|
| 201 |
+
lang2id=None,
|
| 202 |
+
id2lang=None,
|
| 203 |
+
**kwargs,
|
| 204 |
+
):
|
| 205 |
+
do_lowercase_and_remove_accent = kwargs.pop("do_lowercase_and_remove_accent", None)
|
| 206 |
+
if do_lowercase_and_remove_accent is not None:
|
| 207 |
+
logger.warning(
|
| 208 |
+
"`do_lowercase_and_remove_accent` is passed as a keyword argument, but this won't do anything."
|
| 209 |
+
" `FlaubertTokenizer` will always set it to `False`."
|
| 210 |
+
)
|
| 211 |
+
# always `False`
|
| 212 |
+
self.do_lowercase_and_remove_accent = False
|
| 213 |
+
|
| 214 |
+
self.do_lowercase = do_lowercase
|
| 215 |
+
|
| 216 |
+
try:
|
| 217 |
+
import sacremoses
|
| 218 |
+
except ImportError:
|
| 219 |
+
raise ImportError(
|
| 220 |
+
"You need to install sacremoses to use FlaubertTokenizer. "
|
| 221 |
+
"See https://pypi.org/project/sacremoses/ for installation."
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
self.sm = sacremoses
|
| 225 |
+
|
| 226 |
+
# cache of sm.MosesPunctNormalizer instance
|
| 227 |
+
self.cache_moses_punct_normalizer = {}
|
| 228 |
+
# cache of sm.MosesTokenizer instance
|
| 229 |
+
self.cache_moses_tokenizer = {}
|
| 230 |
+
self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
|
| 231 |
+
self.lang2id = lang2id
|
| 232 |
+
self.id2lang = id2lang
|
| 233 |
+
if lang2id is not None and id2lang is not None:
|
| 234 |
+
assert len(lang2id) == len(id2lang)
|
| 235 |
+
|
| 236 |
+
self.ja_word_tokenizer = None
|
| 237 |
+
self.zh_word_tokenizer = None
|
| 238 |
+
|
| 239 |
+
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
| 240 |
+
self.encoder = json.load(vocab_handle)
|
| 241 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 242 |
+
with open(merges_file, encoding="utf-8") as merges_handle:
|
| 243 |
+
merges = merges_handle.read().split("\n")[:-1]
|
| 244 |
+
merges = [tuple(merge.split()[:2]) for merge in merges]
|
| 245 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
| 246 |
+
self.cache = {}
|
| 247 |
+
|
| 248 |
+
super().__init__(
|
| 249 |
+
do_lowercase=do_lowercase,
|
| 250 |
+
unk_token=unk_token,
|
| 251 |
+
bos_token=bos_token,
|
| 252 |
+
sep_token=sep_token,
|
| 253 |
+
pad_token=pad_token,
|
| 254 |
+
cls_token=cls_token,
|
| 255 |
+
mask_token=mask_token,
|
| 256 |
+
additional_special_tokens=additional_special_tokens,
|
| 257 |
+
lang2id=lang2id,
|
| 258 |
+
id2lang=id2lang,
|
| 259 |
+
**kwargs,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
@property
|
| 263 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.do_lower_case
|
| 264 |
+
def do_lower_case(self):
|
| 265 |
+
return self.do_lowercase_and_remove_accent
|
| 266 |
+
|
| 267 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_punct_norm
|
| 268 |
+
def moses_punct_norm(self, text, lang):
|
| 269 |
+
if lang not in self.cache_moses_punct_normalizer:
|
| 270 |
+
punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
|
| 271 |
+
self.cache_moses_punct_normalizer[lang] = punct_normalizer
|
| 272 |
+
else:
|
| 273 |
+
punct_normalizer = self.cache_moses_punct_normalizer[lang]
|
| 274 |
+
return punct_normalizer.normalize(text)
|
| 275 |
+
|
| 276 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_tokenize
|
| 277 |
+
def moses_tokenize(self, text, lang):
|
| 278 |
+
if lang not in self.cache_moses_tokenizer:
|
| 279 |
+
moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
|
| 280 |
+
self.cache_moses_tokenizer[lang] = moses_tokenizer
|
| 281 |
+
else:
|
| 282 |
+
moses_tokenizer = self.cache_moses_tokenizer[lang]
|
| 283 |
+
return moses_tokenizer.tokenize(text, return_str=False, escape=False)
|
| 284 |
+
|
| 285 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_pipeline
|
| 286 |
+
def moses_pipeline(self, text, lang):
|
| 287 |
+
text = replace_unicode_punct(text)
|
| 288 |
+
text = self.moses_punct_norm(text, lang)
|
| 289 |
+
text = remove_non_printing_char(text)
|
| 290 |
+
return text
|
| 291 |
+
|
| 292 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.ja_tokenize
|
| 293 |
+
def ja_tokenize(self, text):
|
| 294 |
+
if self.ja_word_tokenizer is None:
|
| 295 |
+
try:
|
| 296 |
+
import Mykytea
|
| 297 |
+
|
| 298 |
+
self.ja_word_tokenizer = Mykytea.Mykytea(
|
| 299 |
+
f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
|
| 300 |
+
)
|
| 301 |
+
except (AttributeError, ImportError):
|
| 302 |
+
logger.error(
|
| 303 |
+
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
|
| 304 |
+
" (https://github.com/chezou/Mykytea-python) with the following steps"
|
| 305 |
+
)
|
| 306 |
+
logger.error("1. git clone [email protected]:neubig/kytea.git && cd kytea")
|
| 307 |
+
logger.error("2. autoreconf -i")
|
| 308 |
+
logger.error("3. ./configure --prefix=$HOME/local")
|
| 309 |
+
logger.error("4. make && make install")
|
| 310 |
+
logger.error("5. pip install kytea")
|
| 311 |
+
raise
|
| 312 |
+
return list(self.ja_word_tokenizer.getWS(text))
|
| 313 |
+
|
| 314 |
+
@property
|
| 315 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.vocab_size
|
| 316 |
+
def vocab_size(self):
|
| 317 |
+
return len(self.encoder)
|
| 318 |
+
|
| 319 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_vocab
|
| 320 |
+
def get_vocab(self):
|
| 321 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
| 322 |
+
|
| 323 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.bpe
|
| 324 |
+
def bpe(self, token):
|
| 325 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
| 326 |
+
if token in self.cache:
|
| 327 |
+
return self.cache[token]
|
| 328 |
+
pairs = get_pairs(word)
|
| 329 |
+
|
| 330 |
+
if not pairs:
|
| 331 |
+
return token + "</w>"
|
| 332 |
+
|
| 333 |
+
while True:
|
| 334 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
| 335 |
+
if bigram not in self.bpe_ranks:
|
| 336 |
+
break
|
| 337 |
+
first, second = bigram
|
| 338 |
+
new_word = []
|
| 339 |
+
i = 0
|
| 340 |
+
while i < len(word):
|
| 341 |
+
try:
|
| 342 |
+
j = word.index(first, i)
|
| 343 |
+
except ValueError:
|
| 344 |
+
new_word.extend(word[i:])
|
| 345 |
+
break
|
| 346 |
+
else:
|
| 347 |
+
new_word.extend(word[i:j])
|
| 348 |
+
i = j
|
| 349 |
+
|
| 350 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 351 |
+
new_word.append(first + second)
|
| 352 |
+
i += 2
|
| 353 |
+
else:
|
| 354 |
+
new_word.append(word[i])
|
| 355 |
+
i += 1
|
| 356 |
+
new_word = tuple(new_word)
|
| 357 |
+
word = new_word
|
| 358 |
+
if len(word) == 1:
|
| 359 |
+
break
|
| 360 |
+
else:
|
| 361 |
+
pairs = get_pairs(word)
|
| 362 |
+
word = " ".join(word)
|
| 363 |
+
if word == "\n </w>":
|
| 364 |
+
word = "\n</w>"
|
| 365 |
+
self.cache[token] = word
|
| 366 |
+
return word
|
| 367 |
+
|
| 368 |
+
def preprocess_text(self, text):
|
| 369 |
+
text = text.replace("``", '"').replace("''", '"')
|
| 370 |
+
text = convert_to_unicode(text)
|
| 371 |
+
text = unicodedata.normalize("NFC", text)
|
| 372 |
+
|
| 373 |
+
if self.do_lowercase:
|
| 374 |
+
text = text.lower()
|
| 375 |
+
|
| 376 |
+
return text
|
| 377 |
+
|
| 378 |
+
def _tokenize(self, text, bypass_tokenizer=False):
|
| 379 |
+
"""
|
| 380 |
+
Tokenize a string given language code using Moses.
|
| 381 |
+
|
| 382 |
+
Details of tokenization:
|
| 383 |
+
|
| 384 |
+
- [sacremoses](https://github.com/alvations/sacremoses): port of Moses
|
| 385 |
+
- Install with `pip install sacremoses`
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
- bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False)
|
| 389 |
+
(bool). If True, we only apply BPE.
|
| 390 |
+
|
| 391 |
+
Returns:
|
| 392 |
+
List of tokens.
|
| 393 |
+
"""
|
| 394 |
+
lang = "fr"
|
| 395 |
+
if lang and self.lang2id and lang not in self.lang2id:
|
| 396 |
+
logger.error(
|
| 397 |
+
"Supplied language code not found in lang2id mapping. Please check that your language is supported by"
|
| 398 |
+
" the loaded pretrained model."
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
if bypass_tokenizer:
|
| 402 |
+
text = text.split()
|
| 403 |
+
else:
|
| 404 |
+
text = self.preprocess_text(text)
|
| 405 |
+
text = self.moses_pipeline(text, lang=lang)
|
| 406 |
+
text = self.moses_tokenize(text, lang=lang)
|
| 407 |
+
|
| 408 |
+
split_tokens = []
|
| 409 |
+
for token in text:
|
| 410 |
+
if token:
|
| 411 |
+
split_tokens.extend(list(self.bpe(token).split(" ")))
|
| 412 |
+
|
| 413 |
+
return split_tokens
|
| 414 |
+
|
| 415 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_token_to_id
|
| 416 |
+
def _convert_token_to_id(self, token):
|
| 417 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 418 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
| 419 |
+
|
| 420 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_id_to_token
|
| 421 |
+
def _convert_id_to_token(self, index):
|
| 422 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 423 |
+
return self.decoder.get(index, self.unk_token)
|
| 424 |
+
|
| 425 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.convert_tokens_to_string
|
| 426 |
+
def convert_tokens_to_string(self, tokens):
|
| 427 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 428 |
+
out_string = "".join(tokens).replace("</w>", " ").strip()
|
| 429 |
+
return out_string
|
| 430 |
+
|
| 431 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.build_inputs_with_special_tokens
|
| 432 |
+
def build_inputs_with_special_tokens(
|
| 433 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 434 |
+
) -> List[int]:
|
| 435 |
+
"""
|
| 436 |
+
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
|
| 437 |
+
adding special tokens. An XLM sequence has the following format:
|
| 438 |
+
|
| 439 |
+
- single sequence: `<s> X </s>`
|
| 440 |
+
- pair of sequences: `<s> A </s> B </s>`
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
token_ids_0 (`List[int]`):
|
| 444 |
+
List of IDs to which the special tokens will be added.
|
| 445 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 446 |
+
Optional second list of IDs for sequence pairs.
|
| 447 |
+
|
| 448 |
+
Returns:
|
| 449 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 450 |
+
|
| 451 |
+
"""
|
| 452 |
+
bos = [self.bos_token_id]
|
| 453 |
+
sep = [self.sep_token_id]
|
| 454 |
+
|
| 455 |
+
if token_ids_1 is None:
|
| 456 |
+
return bos + token_ids_0 + sep
|
| 457 |
+
return bos + token_ids_0 + sep + token_ids_1 + sep
|
| 458 |
+
|
| 459 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_special_tokens_mask
|
| 460 |
+
def get_special_tokens_mask(
|
| 461 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 462 |
+
) -> List[int]:
|
| 463 |
+
"""
|
| 464 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 465 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 466 |
+
|
| 467 |
+
Args:
|
| 468 |
+
token_ids_0 (`List[int]`):
|
| 469 |
+
List of IDs.
|
| 470 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 471 |
+
Optional second list of IDs for sequence pairs.
|
| 472 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 473 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 474 |
+
|
| 475 |
+
Returns:
|
| 476 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
if already_has_special_tokens:
|
| 480 |
+
return super().get_special_tokens_mask(
|
| 481 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if token_ids_1 is not None:
|
| 485 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 486 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 487 |
+
|
| 488 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.create_token_type_ids_from_sequences
|
| 489 |
+
def create_token_type_ids_from_sequences(
|
| 490 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 491 |
+
) -> List[int]:
|
| 492 |
+
"""
|
| 493 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence
|
| 494 |
+
pair mask has the following format:
|
| 495 |
+
|
| 496 |
+
```
|
| 497 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 498 |
+
| first sequence | second sequence |
|
| 499 |
+
```
|
| 500 |
+
|
| 501 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
token_ids_0 (`List[int]`):
|
| 505 |
+
List of IDs.
|
| 506 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 507 |
+
Optional second list of IDs for sequence pairs.
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 511 |
+
"""
|
| 512 |
+
sep = [self.sep_token_id]
|
| 513 |
+
cls = [self.cls_token_id]
|
| 514 |
+
if token_ids_1 is None:
|
| 515 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 516 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 517 |
+
|
| 518 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.save_vocabulary
|
| 519 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 520 |
+
if not os.path.isdir(save_directory):
|
| 521 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 522 |
+
return
|
| 523 |
+
vocab_file = os.path.join(
|
| 524 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 525 |
+
)
|
| 526 |
+
merge_file = os.path.join(
|
| 527 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 531 |
+
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
| 532 |
+
|
| 533 |
+
index = 0
|
| 534 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
| 535 |
+
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
| 536 |
+
if index != token_index:
|
| 537 |
+
logger.warning(
|
| 538 |
+
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
| 539 |
+
" Please check that the tokenizer is not corrupted!"
|
| 540 |
+
)
|
| 541 |
+
index = token_index
|
| 542 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
| 543 |
+
index += 1
|
| 544 |
+
|
| 545 |
+
return vocab_file, merge_file
|
| 546 |
+
|
| 547 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__getstate__
|
| 548 |
+
def __getstate__(self):
|
| 549 |
+
state = self.__dict__.copy()
|
| 550 |
+
state["sm"] = None
|
| 551 |
+
return state
|
| 552 |
+
|
| 553 |
+
# Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__setstate__
|
| 554 |
+
def __setstate__(self, d):
|
| 555 |
+
self.__dict__ = d
|
| 556 |
+
|
| 557 |
+
try:
|
| 558 |
+
import sacremoses
|
| 559 |
+
except ImportError:
|
| 560 |
+
raise ImportError(
|
| 561 |
+
"You need to install sacremoses to use XLMTokenizer. "
|
| 562 |
+
"See https://pypi.org/project/sacremoses/ for installation."
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
self.sm = sacremoses
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
__all__ = ["FlaubertTokenizer"]
|
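`build_inputs_with_special_tokens` and `create_token_type_ids_from_sequences` above define the `<s> X </s>` / `<s> A </s> B </s>` layout and the 0/1 segment mask. The snippet below is a hedged illustration of that layout; the checkpoint name and the French sentences are assumptions, not part of this file.

```python
# Illustration of the special-token layout documented above; the checkpoint
# name and example sentences are assumptions for demonstration only.
from transformers import FlaubertTokenizer

tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")

single = tokenizer("Bonjour le monde")   # encoded as <s> X </s>
pair = tokenizer("Bonjour", "le monde")  # encoded as <s> A </s> B </s>

print(tokenizer.convert_ids_to_tokens(single["input_ids"]))
print(pair["token_type_ids"])  # expected: 0s over "<s> A </s>", 1s over "B </s>"
```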
docs/transformers/build/lib/transformers/models/flava/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_flava import *
|
| 22 |
+
from .feature_extraction_flava import *
|
| 23 |
+
from .image_processing_flava import *
|
| 24 |
+
from .image_processing_flava_fast import *
|
| 25 |
+
from .modeling_flava import *
|
| 26 |
+
from .processing_flava import *
|
| 27 |
+
else:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
_file = globals()["__file__"]
|
| 31 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
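The `__init__` above relies on `_LazyModule` together with `define_import_structure`, so importing `transformers.models.flava` stays cheap and the heavy submodules are only loaded on first attribute access. A rough sketch of that behaviour, with the prints stated as expectations rather than guarantees:

```python
# Rough sketch of the lazy-import behaviour; module names follow the layout
# shown in this diff, and the printed values are expectations only.
import importlib
import sys

flava = importlib.import_module("transformers.models.flava")
print("configuration_flava loaded eagerly?",
      "transformers.models.flava.configuration_flava" in sys.modules)  # expected: False

_ = flava.FlavaImageConfig  # first attribute access triggers the real import
print("configuration_flava loaded now?",
      "transformers.models.flava.configuration_flava" in sys.modules)  # expected: True
```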
docs/transformers/build/lib/transformers/models/flava/configuration_flava.py
ADDED
|
@@ -0,0 +1,701 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""FLAVA model configurations"""
|
| 16 |
+
|
| 17 |
+
from typing import Any, Dict
|
| 18 |
+
|
| 19 |
+
from ...configuration_utils import PretrainedConfig
|
| 20 |
+
from ...utils import logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class FlavaImageConfig(PretrainedConfig):
|
| 27 |
+
r"""
|
| 28 |
+
This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an
|
| 29 |
+
FLAVA model according to the specified arguments, defining the model architecture.
|
| 30 |
+
|
| 31 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
|
| 32 |
+
[facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
|
| 33 |
+
|
| 34 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 35 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 40 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 41 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 42 |
+
Number of hidden layers in the Transformer encoder.
|
| 43 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 44 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 45 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 46 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 47 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 48 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 49 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 50 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 51 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 52 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 53 |
+
The dropout ratio for the attention probabilities.
|
| 54 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 55 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 56 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 57 |
+
The epsilon used by the layer normalization layers.
|
| 58 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 59 |
+
The size (resolution) of each image.
|
| 60 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 61 |
+
The size (resolution) of each patch.
|
| 62 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 63 |
+
The number of input channels.
|
| 64 |
+
qkv_bias (`bool`, *optional*, defaults to `True`):
|
| 65 |
+
Whether to add a bias to the queries, keys and values.
|
| 66 |
+
mask_token (`bool`, *optional*, defaults to `True`):
|
| 67 |
+
Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss for FLAVA.
|
| 68 |
+
vocab_size (`int`, *optional*, defaults to 8192):
|
| 69 |
+
Vocabulary size of the [`FlavaImageCodebook`] used in conjunction with [`FlavaImageModel`] for MIM (Masked
|
| 70 |
+
Image Modeling) loss for FLAVA.
|
| 71 |
+
|
| 72 |
+
Example:
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
>>> from transformers import FlavaImageConfig, FlavaImageModel
|
| 76 |
+
|
| 77 |
+
>>> # Initializing a FlavaImageModel with style configuration
|
| 78 |
+
>>> configuration = FlavaImageConfig()
|
| 79 |
+
|
| 80 |
+
>>> # Initializing a FlavaImageModel model (with random weights) from the style configuration
|
| 81 |
+
>>> model = FlavaImageModel(configuration)
|
| 82 |
+
|
| 83 |
+
>>> # Accessing the model configuration
|
| 84 |
+
>>> configuration = model.config
|
| 85 |
+
```"""
|
| 86 |
+
|
| 87 |
+
model_type = "flava_image_model"
|
| 88 |
+
base_config_key = "image_config"
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
hidden_size: int = 768,
|
| 93 |
+
num_hidden_layers: int = 12,
|
| 94 |
+
num_attention_heads: int = 12,
|
| 95 |
+
intermediate_size: int = 3072,
|
| 96 |
+
hidden_act: int = "gelu",
|
| 97 |
+
hidden_dropout_prob: float = 0.0,
|
| 98 |
+
attention_probs_dropout_prob: float = 0.0,
|
| 99 |
+
initializer_range: float = 0.02,
|
| 100 |
+
layer_norm_eps: float = 1e-12,
|
| 101 |
+
image_size: int = 224,
|
| 102 |
+
patch_size: int = 16,
|
| 103 |
+
num_channels: int = 3,
|
| 104 |
+
qkv_bias: bool = True,
|
| 105 |
+
mask_token: bool = True,
|
| 106 |
+
vocab_size: int = 8192,
|
| 107 |
+
**kwargs,
|
| 108 |
+
):
|
| 109 |
+
super().__init__(**kwargs)
|
| 110 |
+
|
| 111 |
+
self.hidden_size = hidden_size
|
| 112 |
+
self.num_hidden_layers = num_hidden_layers
|
| 113 |
+
self.num_attention_heads = num_attention_heads
|
| 114 |
+
self.intermediate_size = intermediate_size
|
| 115 |
+
self.hidden_act = hidden_act
|
| 116 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 117 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 118 |
+
self.initializer_range = initializer_range
|
| 119 |
+
self.layer_norm_eps = layer_norm_eps
|
| 120 |
+
self.image_size = image_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_channels = num_channels
|
| 123 |
+
self.qkv_bias = qkv_bias
|
| 124 |
+
self.mask_token = mask_token
|
| 125 |
+
self.vocab_size = vocab_size
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class FlavaTextConfig(PretrainedConfig):
|
| 129 |
+
r"""
|
| 130 |
+
This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an
|
| 131 |
+
FLAVA model according to the specified arguments, defining the model architecture.
|
| 132 |
+
|
| 133 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
|
| 134 |
+
[facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
|
| 135 |
+
|
| 136 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 137 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
| 142 |
+
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
|
| 143 |
+
`inputs_ids` passed when calling [`FlavaTextModel`].
|
| 144 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
| 145 |
+
The vocabulary size of the `token_type_ids` passed when calling [`FlavaTextModel`]. Note that even though
|
| 146 |
+
text encoder allows `token_type_ids`'s value as 2, for text-only pretraining and fine-tuning, only 1 is
|
| 147 |
+
used similar to RoBERTa.
|
| 148 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
| 149 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 150 |
+
just in case (e.g., 512 or 1024 or 2048). For VL, max_length passed to model is 77.
|
| 151 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
| 152 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
| 153 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
| 154 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
| 155 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
| 156 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
| 157 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 158 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 159 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 160 |
+
Number of hidden layers in the Transformer encoder.
|
| 161 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 162 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 163 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 164 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 165 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 166 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 167 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 168 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 169 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 170 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 171 |
+
The dropout ratio for the attention probabilities.
|
| 172 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 173 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 174 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 175 |
+
The epsilon used by the layer normalization layers.
|
| 176 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 177 |
+
The size (resolution) of each image.
|
| 178 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 179 |
+
The size (resolution) of each patch.
|
| 180 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 181 |
+
The number of input channels.
|
| 182 |
+
qkv_bias (`bool`, *optional*, defaults to `True`):
|
| 183 |
+
Whether to add a bias to the queries, keys and values.
|
| 184 |
+
|
| 185 |
+
Example:
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
>>> from transformers import FlavaTextConfig, FlavaTextModel
|
| 189 |
+
|
| 190 |
+
>>> # Initializing a FlavaTextModel with style configuration
|
| 191 |
+
>>> configuration = FlavaTextConfig()
|
| 192 |
+
|
| 193 |
+
>>> # Initializing a FlavaTextModel model (with random weights) from the style configuration
|
| 194 |
+
>>> model = FlavaTextModel(configuration)
|
| 195 |
+
|
| 196 |
+
>>> # Accessing the model configuration
|
| 197 |
+
>>> configuration = model.config
|
| 198 |
+
```"""
|
| 199 |
+
|
| 200 |
+
model_type = "flava_text_model"
|
| 201 |
+
base_config_key = "text_config"
|
| 202 |
+
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
vocab_size: int = 30522,
|
| 206 |
+
type_vocab_size: int = 2,
|
| 207 |
+
max_position_embeddings: int = 512,
|
| 208 |
+
position_embedding_type: str = "absolute",
|
| 209 |
+
hidden_size: int = 768,
|
| 210 |
+
num_hidden_layers: int = 12,
|
| 211 |
+
num_attention_heads: int = 12,
|
| 212 |
+
intermediate_size: int = 3072,
|
| 213 |
+
hidden_act: str = "gelu",
|
| 214 |
+
hidden_dropout_prob: float = 0.0,
|
| 215 |
+
attention_probs_dropout_prob: float = 0.0,
|
| 216 |
+
initializer_range: float = 0.02,
|
| 217 |
+
layer_norm_eps: float = 1e-12,
|
| 218 |
+
pad_token_id: int = 0,
|
| 219 |
+
qkv_bias: bool = True,
|
| 220 |
+
**kwargs,
|
| 221 |
+
):
|
| 222 |
+
super().__init__(**kwargs)
|
| 223 |
+
|
| 224 |
+
self.vocab_size = vocab_size
|
| 225 |
+
self.type_vocab_size = type_vocab_size
|
| 226 |
+
self.max_position_embeddings = max_position_embeddings
|
| 227 |
+
self.position_embedding_type = position_embedding_type
|
| 228 |
+
self.hidden_size = hidden_size
|
| 229 |
+
self.num_hidden_layers = num_hidden_layers
|
| 230 |
+
self.num_attention_heads = num_attention_heads
|
| 231 |
+
self.intermediate_size = intermediate_size
|
| 232 |
+
self.hidden_act = hidden_act
|
| 233 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 234 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 235 |
+
self.initializer_range = initializer_range
|
| 236 |
+
self.layer_norm_eps = layer_norm_eps
|
| 237 |
+
self.qkv_bias = qkv_bias
|
| 238 |
+
self.pad_token_id = pad_token_id
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class FlavaMultimodalConfig(PretrainedConfig):
|
| 242 |
+
r"""
|
| 243 |
+
This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate
|
| 244 |
+
an FLAVA model according to the specified arguments, defining the model architecture.
|
| 245 |
+
|
| 246 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
|
| 247 |
+
[facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
|
| 248 |
+
|
| 249 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 250 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 255 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 256 |
+
num_hidden_layers (`int`, *optional*, defaults to 6):
|
| 257 |
+
Number of hidden layers in the Transformer encoder.
|
| 258 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 259 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 260 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 261 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 262 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 263 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 264 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 265 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 266 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 267 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 268 |
+
The dropout ratio for the attention probabilities.
|
| 269 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 270 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 271 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 272 |
+
The epsilon used by the layer normalization layers.
|
| 273 |
+
qkv_bias (`bool`, *optional*, defaults to `True`):
|
| 274 |
+
Whether to add a bias to the queries, keys and values.
|
| 275 |
+
use_cls_token (`bool`, *optional*, defaults to `True`):
|
| 276 |
+
Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
Example:
|
| 280 |
+
|
| 281 |
+
```python
|
| 282 |
+
>>> from transformers import FlavaMultimodalConfig, FlavaMultimodalModel
|
| 283 |
+
|
| 284 |
+
>>> # Initializing a FlavaMultimodalModel with style configuration
|
| 285 |
+
>>> configuration = FlavaMultimodalConfig()
|
| 286 |
+
|
| 287 |
+
>>> # Initializing a FlavaMultimodalModel model (with random weights) from the style configuration
|
| 288 |
+
>>> model = FlavaMultimodalModel(configuration)
|
| 289 |
+
|
| 290 |
+
>>> # Accessing the model configuration
|
| 291 |
+
>>> configuration = model.config
|
| 292 |
+
```"""
|
| 293 |
+
|
| 294 |
+
model_type = "flava_multimodal_model"
|
| 295 |
+
base_config_key = "multimodal_config"
|
| 296 |
+
|
| 297 |
+
def __init__(
|
| 298 |
+
self,
|
| 299 |
+
hidden_size: int = 768,
|
| 300 |
+
num_hidden_layers: int = 6,
|
| 301 |
+
num_attention_heads: int = 12,
|
| 302 |
+
intermediate_size: int = 3072,
|
| 303 |
+
hidden_act: int = "gelu",
|
| 304 |
+
hidden_dropout_prob: int = 0.0,
|
| 305 |
+
attention_probs_dropout_prob: int = 0.0,
|
| 306 |
+
initializer_range: float = 0.02,
|
| 307 |
+
layer_norm_eps: float = 1e-12,
|
| 308 |
+
qkv_bias: bool = True,
|
| 309 |
+
use_cls_token: bool = True,
|
| 310 |
+
**kwargs,
|
| 311 |
+
):
|
| 312 |
+
super().__init__(**kwargs)
|
| 313 |
+
|
| 314 |
+
self.hidden_size = hidden_size
|
| 315 |
+
self.num_hidden_layers = num_hidden_layers
|
| 316 |
+
self.num_attention_heads = num_attention_heads
|
| 317 |
+
self.intermediate_size = intermediate_size
|
| 318 |
+
self.hidden_act = hidden_act
|
| 319 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 320 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 321 |
+
self.initializer_range = initializer_range
|
| 322 |
+
self.layer_norm_eps = layer_norm_eps
|
| 323 |
+
self.qkv_bias = qkv_bias
|
| 324 |
+
self.use_cls_token = use_cls_token
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class FlavaImageCodebookConfig(PretrainedConfig):
|
| 328 |
+
model_type = "flava_image_codebook"
|
| 329 |
+
base_config_key = "image_codebook_config"
|
| 330 |
+
|
| 331 |
+
r"""
|
| 332 |
+
[`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
|
| 333 |
+
is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture.
|
| 334 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
|
| 335 |
+
[facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.
|
| 336 |
+
|
| 337 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 338 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
num_groups (`int`, *optional*, defaults to 4):
|
| 342 |
+
Number of groups to be created. This parameter as of now doesn't affect the model and is used for some
|
| 343 |
+
internal calculation and estimations.
|
| 344 |
+
input_channels (`int`, *optional*, defaults to 3):
|
| 345 |
+
Number of channels in the image to be passed.
|
| 346 |
+
num_blocks_per_group (`int`, *optional*, defaults to 2):
|
| 347 |
+
Number of conv-based blocks per group.
|
| 348 |
+
hidden_size (`int`, *optional*, defaults to 256):
|
| 349 |
+
Size of hidden dim for the blocks.
|
| 350 |
+
vocab_size (`int`, *optional*, defaults to 8192):
|
| 351 |
+
Size of the output vocabulary for the codebook.
|
| 352 |
+
freeze (`bool`, defaults to `True`):
|
| 353 |
+
Whether to freeze the weights of the model.
|
| 354 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 355 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 356 |
+
kwargs (*optional*):
|
| 357 |
+
Dictionary of keyword arguments.
|
| 358 |
+
|
| 359 |
+
Example:
|
| 360 |
+
|
| 361 |
+
```python
|
| 362 |
+
>>> from transformers import FlavaImageCodebookConfig, FlavaImageCodebook
|
| 363 |
+
|
| 364 |
+
>>> # Initializing a FlavaImageCodebook with style configuration
|
| 365 |
+
>>> configuration = FlavaImageCodebookConfig()
|
| 366 |
+
|
| 367 |
+
>>> # Initializing a FlavaImageCodebook model (with random weights) from the style configuration
|
| 368 |
+
>>> model = FlavaImageCodebook(configuration)
|
| 369 |
+
>>> # Accessing the model configuration
|
| 370 |
+
>>> configuration = model.config
|
| 371 |
+
```
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
def __init__(
|
| 375 |
+
self,
|
| 376 |
+
num_groups: int = 4,
|
| 377 |
+
input_channels: int = 3,
|
| 378 |
+
num_blocks_per_group: int = 2,
|
| 379 |
+
hidden_size: int = 256,
|
| 380 |
+
vocab_size: int = 8192,
|
| 381 |
+
freeze: int = True,
|
| 382 |
+
initializer_range: float = 0.02,
|
| 383 |
+
**kwargs,
|
| 384 |
+
):
|
| 385 |
+
super().__init__(**kwargs)
|
| 386 |
+
self.num_groups = num_groups
|
| 387 |
+
self.input_channels = input_channels
|
| 388 |
+
self.num_blocks_per_group = num_blocks_per_group
|
| 389 |
+
self.hidden_size = hidden_size
|
| 390 |
+
self.vocab_size = vocab_size
|
| 391 |
+
self.freeze = freeze
|
| 392 |
+
self.initializer_range = initializer_range
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class FlavaConfig(PretrainedConfig):
|
| 396 |
+
r"""
|
| 397 |
+
[`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
|
| 398 |
+
instantiate FLAVA model according to the specified arguments, defining the text model, image model, image codebook
|
| 399 |
+
and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to
|
| 400 |
+
that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
|
| 401 |
+
|
| 402 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 403 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
text_config (`dict`, *optional*):
|
| 407 |
+
Dictionary of configuration options used to initialize [`FlavaTextConfig`].
|
| 408 |
+
image_config (`dict`, *optional*):
|
| 409 |
+
Dictionary of configuration options used to initialize [`FlavaImageConfig`].
|
| 410 |
+
multimodal_config (`dict`, *optional*):
|
| 411 |
+
Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
|
| 412 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 413 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 414 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 415 |
+
The epsilon used by the layer normalization layers.
|
| 416 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
| 417 |
+
Dimensionality of text and image projection layers.
|
| 418 |
+
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
| 419 |
+
The initial value of the *logit_scale* parameter. Default is used as per the original FLAVA/CLIP
|
| 420 |
+
implementation.
|
| 421 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 422 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 423 |
+
ce_ignore_index (`int`, *optional*, defaults to -100):
|
| 424 |
+
Cross entropy index to ignore.
|
| 425 |
+
mim_weight (`float`, *optional*, defaults to 1.0):
|
| 426 |
+
Weight to be assigned to MIM (Masked Image Modeling) unimodal loss
|
| 427 |
+
mlm_weight (`float`, *optional*, defaults to 1.0):
|
| 428 |
+
Weight to be assigned to MLM (Masked Language Modeling) unimodal loss
|
| 429 |
+
global_contrastive_weight (`float`, *optional*, defaults to 1.0):
|
| 430 |
+
Weight to be assigned to global contrastive cross-alignment loss.
|
| 431 |
+
itm_weight (`float`, *optional*, defaults to 1.0):
|
| 432 |
+
Weight to be assigned to image-text matching multimodal loss.
|
| 433 |
+
mmm_image_weight (`float`, *optional*, defaults to 1.0):
|
| 434 |
+
Weight to be assigned to MMM loss's image part.
|
| 435 |
+
mmm_text_weight (`float`, *optional*, defaults to 1.0):
|
| 436 |
+
Weight to be assigned to MMM loss's text part.
|
| 437 |
+
global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
|
| 438 |
+
Whether to use global backpropgation through all workers in contrastive loss.
|
| 439 |
+
skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
|
| 440 |
+
Whether to skip running unmasked multimodal encoder whose outputs are not used by FLAVA losses.
|
| 441 |
+
return_loss (`bool`, *optional*, defaults to `True`):
|
| 442 |
+
Whether to return loss or not
|
| 443 |
+
|
| 444 |
+
kwargs (*optional*):
|
| 445 |
+
Dictionary of keyword arguments.
|
| 446 |
+
|
| 447 |
+
Example:
|
| 448 |
+
|
| 449 |
+
```python
|
| 450 |
+
>>> from transformers import FlavaConfig, FlavaModel, FlavaForPreTraining
|
| 451 |
+
|
| 452 |
+
>>> # Initializing a FlavaConfig with style configuration
|
| 453 |
+
>>> configuration = FlavaConfig()
|
| 454 |
+
|
| 455 |
+
>>> # Initializing a FlavaModel and FlavaForPreTraining model (with random weights) from the style configuration
|
| 456 |
+
>>> model = FlavaModel(configuration)
|
| 457 |
+
>>> model_pre = FlavaForPreTraining(configuration)
|
| 458 |
+
|
| 459 |
+
>>> # Accessing the model configuration
|
| 460 |
+
>>> configuration = model.config
|
| 461 |
+
>>> configuration_pre = model_pre.config
|
| 462 |
+
```
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
model_type = "flava"
|
| 466 |
+
sub_configs = {
|
| 467 |
+
"text_config": FlavaTextConfig,
|
| 468 |
+
"image_config": FlavaImageConfig,
|
| 469 |
+
"multimodal_config": FlavaMultimodalConfig,
|
| 470 |
+
"image_codebook_config": FlavaImageCodebookConfig,
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
def __init__(
|
| 474 |
+
self,
|
| 475 |
+
image_config: Dict[str, Any] = None,
|
| 476 |
+
text_config: Dict[str, Any] = None,
|
| 477 |
+
multimodal_config: Dict[str, Any] = None,
|
| 478 |
+
image_codebook_config: Dict[str, Any] = None,
|
| 479 |
+
hidden_size: int = 768,
|
| 480 |
+
layer_norm_eps: float = 1e-12,
|
| 481 |
+
projection_dim: int = 768,
|
| 482 |
+
init_codebook: bool = True,
|
| 483 |
+
logit_scale_init_value: float = 2.6592,
|
| 484 |
+
initializer_range: float = 0.02,
|
| 485 |
+
ce_ignore_index: int = -100,
|
| 486 |
+
mim_weight: float = 1.0,
|
| 487 |
+
mlm_weight: float = 1.0,
|
| 488 |
+
global_contrastive_weight: float = 1.0,
|
| 489 |
+
itm_weight: float = 1.0,
|
| 490 |
+
mmm_image_weight: float = 1.0,
|
| 491 |
+
mmm_text_weight: float = 1.0,
|
| 492 |
+
global_backprop_contrastive: bool = True,
|
| 493 |
+
skip_unmasked_multimodal_encoder: bool = True,
|
| 494 |
+
return_loss: bool = True,
|
| 495 |
+
**kwargs,
|
| 496 |
+
):
|
| 497 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
| 498 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
|
| 499 |
+
# of confusion!).
|
| 500 |
+
text_config_dict = kwargs.pop("text_config_dict", None)
|
| 501 |
+
image_config_dict = kwargs.pop("image_config_dict", None)
|
| 502 |
+
multimodal_config_dict = kwargs.pop("multimodal_config_dict", None)
|
| 503 |
+
image_codebook_config_dict = kwargs.pop("image_codebook_config_dict", None)
|
| 504 |
+
|
| 505 |
+
super().__init__(**kwargs)
|
| 506 |
+
|
| 507 |
+
# Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
|
| 508 |
+
# `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
|
| 509 |
+
# cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
|
| 510 |
+
if text_config_dict is not None:
|
| 511 |
+
if text_config is None:
|
| 512 |
+
text_config = {}
|
| 513 |
+
|
| 514 |
+
# This is the complete result when using `text_config_dict`.
|
| 515 |
+
_text_config_dict = FlavaTextConfig(**text_config_dict).to_dict()
|
| 516 |
+
|
| 517 |
+
# Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
|
| 518 |
+
for key, value in _text_config_dict.items():
|
| 519 |
+
if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
|
| 520 |
+
# If specified in `text_config_dict`
|
| 521 |
+
if key in text_config_dict:
|
| 522 |
+
message = (
|
| 523 |
+
f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
|
| 524 |
+
f'The value `text_config_dict["{key}"]` will be used instead.'
|
| 525 |
+
)
|
| 526 |
+
# If inferred from default argument values (just to be super careful)
|
| 527 |
+
else:
|
| 528 |
+
message = (
|
| 529 |
+
f"`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The "
|
| 530 |
+
f'value `text_config["{key}"]` will be overridden.'
|
| 531 |
+
)
|
| 532 |
+
logger.info(message)
|
| 533 |
+
|
| 534 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
| 535 |
+
text_config.update(_text_config_dict)
|
| 536 |
+
|
| 537 |
+
if image_config_dict is not None:
|
| 538 |
+
if image_config is None:
|
| 539 |
+
image_config = {}
|
| 540 |
+
|
| 541 |
+
# This is the complete result when using `image_config_dict`.
|
| 542 |
+
_image_config_dict = FlavaImageConfig(**image_config_dict).to_dict()
|
| 543 |
+
# convert keys to string instead of integer
|
| 544 |
+
if "id2label" in _image_config_dict:
|
| 545 |
+
_image_config_dict["id2label"] = {
|
| 546 |
+
str(key): value for key, value in _image_config_dict["id2label"].items()
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
# Give a warning if the values exist in both `_image_config_dict` and `image_config` but being different.
|
| 550 |
+
for key, value in _image_config_dict.items():
|
| 551 |
+
if key in image_config and value != image_config[key] and key not in ["transformers_version"]:
|
| 552 |
+
# If specified in `image_config_dict`
|
| 553 |
+
if key in image_config_dict:
|
| 554 |
+
message = (
|
| 555 |
+
f"`{key}` is found in both `image_config_dict` and `image_config` but with different "
|
| 556 |
+
f'values. The value `image_config_dict["{key}"]` will be used instead.'
|
| 557 |
+
)
|
| 558 |
+
# If inferred from default argument values (just to be super careful)
|
| 559 |
+
else:
|
| 560 |
+
message = (
|
| 561 |
+
f"`image_config_dict` is provided which will be used to initialize `FlavaImageConfig`. "
|
| 562 |
+
f'The value `image_config["{key}"]` will be overridden.'
|
| 563 |
+
)
|
| 564 |
+
logger.info(message)
|
| 565 |
+
|
| 566 |
+
# Update all values in `image_config` with the ones in `_image_config_dict`.
|
| 567 |
+
image_config.update(_image_config_dict)
|
| 568 |
+
|
| 569 |
+
if multimodal_config_dict is not None:
|
| 570 |
+
if multimodal_config is None:
|
| 571 |
+
multimodal_config = {}
|
| 572 |
+
|
| 573 |
+
# This is the complete result when using `multimodal_config_dict`.
|
| 574 |
+
_multimodal_config_dict = FlavaMultimodalConfig(**multimodal_config_dict).to_dict()
|
| 575 |
+
|
| 576 |
+
# Give a warning if the values exist in both `_multimodal_config_dict` and `multimodal_config` but being
|
| 577 |
+
# different.
|
| 578 |
+
for key, value in _multimodal_config_dict.items():
|
| 579 |
+
if (
|
| 580 |
+
key in multimodal_config
|
| 581 |
+
and value != multimodal_config[key]
|
| 582 |
+
and key not in ["transformers_version"]
|
| 583 |
+
):
|
| 584 |
+
# If specified in `multimodal_config_dict`
|
| 585 |
+
if key in multimodal_config_dict:
|
| 586 |
+
message = (
|
| 587 |
+
f"`{key}` is found in both `multimodal_config_dict` and `multimodal_config` but with "
|
| 588 |
+
f'different values. The value `multimodal_config_dict["{key}"]` will be used instead.'
|
| 589 |
+
)
|
| 590 |
+
# If inferred from default argument values (just to be super careful)
|
| 591 |
+
else:
|
| 592 |
+
message = (
|
| 593 |
+
f"`multimodal_config_dict` is provided which will be used to initialize "
|
| 594 |
+
f'`FlavaMultimodalConfig`. The value `multimodal_config["{key}"]` will be overridden.'
|
| 595 |
+
)
|
| 596 |
+
logger.info(message)
|
| 597 |
+
|
| 598 |
+
# Update all values in `multimodal_config` with the ones in `_multimodal_config_dict`.
|
| 599 |
+
multimodal_config.update(_multimodal_config_dict)
|
| 600 |
+
|
| 601 |
+
if image_codebook_config_dict is not None:
|
| 602 |
+
if image_codebook_config is None:
|
| 603 |
+
image_codebook_config = {}
|
| 604 |
+
|
| 605 |
+
# This is the complete result when using `image_codebook_config_dict`.
|
| 606 |
+
_image_codebook_config_dict = FlavaImageCodebookConfig(**image_codebook_config_dict).to_dict()
|
| 607 |
+
|
| 608 |
+
# Give a warning if the values exist in both `_image_codebook_config_dict` and `image_codebook_config` but
|
| 609 |
+
# being different.
|
| 610 |
+
for key, value in _image_codebook_config_dict.items():
|
| 611 |
+
if (
|
| 612 |
+
key in image_codebook_config
|
| 613 |
+
and value != image_codebook_config[key]
|
| 614 |
+
and key not in ["transformers_version"]
|
| 615 |
+
):
|
| 616 |
+
# If specified in `image_codebook_config_dict`
|
| 617 |
+
if key in image_codebook_config_dict:
|
| 618 |
+
message = (
|
| 619 |
+
f"`{key}` is found in both `image_codebook_config_dict` and `image_codebook_config` but "
|
| 620 |
+
f'with different values. The value `image_codebook_config_dict["{key}"]` will be used '
|
| 621 |
+
"instead."
|
| 622 |
+
)
|
| 623 |
+
# If inferred from default argument values (just to be super careful)
|
| 624 |
+
else:
|
| 625 |
+
message = (
|
| 626 |
+
f"`image_codebook_config_dict` is provided which will be used to initialize "
|
| 627 |
+
f'`FlavaImageCodebookConfig`. The value `image_codebook_config["{key}"]` will be overridden.'
|
| 628 |
+
)
|
| 629 |
+
logger.info(message)
|
| 630 |
+
|
| 631 |
+
# Update all values in `image_codebook_config` with the ones in `_image_codebook_config_dict`.
|
| 632 |
+
image_codebook_config.update(_image_codebook_config_dict)
|
| 633 |
+
|
| 634 |
+
if image_config is None:
|
| 635 |
+
image_config = {}
|
| 636 |
+
logger.info("`image_config` is `None`. initializing the `FlavaImageConfig` with default values.")
|
| 637 |
+
|
| 638 |
+
if text_config is None:
|
| 639 |
+
text_config = {}
|
| 640 |
+
logger.info("`text_config` is `None`. Initializing the `FlavaTextConfig` with default values.")
|
| 641 |
+
|
| 642 |
+
if multimodal_config is None:
|
| 643 |
+
multimodal_config = {}
|
| 644 |
+
logger.info("`multimodal_config` is `None`. initializing the `FlavaMultimodalConfig` with default values.")
|
| 645 |
+
|
| 646 |
+
if image_codebook_config is None:
|
| 647 |
+
image_codebook_config = {}
|
| 648 |
+
logger.info(
|
| 649 |
+
"`image_codebook_config` is `None`. initializing the `FlavaImageCodebookConfig` with default values."
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
self.image_config = FlavaImageConfig(**image_config)
|
| 653 |
+
self.text_config = FlavaTextConfig(**text_config)
|
| 654 |
+
self.multimodal_config = FlavaMultimodalConfig(**multimodal_config)
|
| 655 |
+
self.image_codebook_config = FlavaImageCodebookConfig(**image_codebook_config)
|
| 656 |
+
self.projection_dim = projection_dim
|
| 657 |
+
self.init_codebook = init_codebook
|
| 658 |
+
|
| 659 |
+
self.hidden_size = hidden_size
|
| 660 |
+
self.layer_norm_eps = layer_norm_eps
|
| 661 |
+
self.initializer_range = initializer_range
|
| 662 |
+
self.logit_scale_init_value = logit_scale_init_value
|
| 663 |
+
self.initializer_factor = 1.0
|
| 664 |
+
self.ce_ignore_index = ce_ignore_index
|
| 665 |
+
self.mim_weight = mim_weight
|
| 666 |
+
self.mlm_weight = mlm_weight
|
| 667 |
+
self.global_contrastive_weight = global_contrastive_weight
|
| 668 |
+
self.itm_weight = itm_weight
|
| 669 |
+
self.mmm_image_weight = mmm_image_weight
|
| 670 |
+
self.mmm_text_weight = mmm_text_weight
|
| 671 |
+
self.global_backprop_contrastive = global_backprop_contrastive
|
| 672 |
+
self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder
|
| 673 |
+
self.return_loss = return_loss
|
| 674 |
+
|
| 675 |
+
@classmethod
|
| 676 |
+
def from_configs(
|
| 677 |
+
cls,
|
| 678 |
+
image_config: FlavaImageConfig,
|
| 679 |
+
text_config: FlavaTextConfig,
|
| 680 |
+
multimodal_config: FlavaMultimodalConfig,
|
| 681 |
+
image_codebook_config: FlavaImageCodebookConfig,
|
| 682 |
+
**kwargs,
|
| 683 |
+
):
|
| 684 |
+
r"""
|
| 685 |
+
Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model
|
| 686 |
+
configuration, flava multimodal model and flava codebook model configuration.
|
| 687 |
+
|
| 688 |
+
Returns:
|
| 689 |
+
[`FlavaConfig`]: An instance of a configuration object
|
| 690 |
+
"""
|
| 691 |
+
|
| 692 |
+
return cls(
|
| 693 |
+
image_config=image_config.to_dict(),
|
| 694 |
+
text_config=text_config.to_dict(),
|
| 695 |
+
multimodal_config=multimodal_config.to_dict(),
|
| 696 |
+
image_codebook_config=image_codebook_config.to_dict(),
|
| 697 |
+
**kwargs,
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
__all__ = ["FlavaConfig", "FlavaImageCodebookConfig", "FlavaImageConfig", "FlavaMultimodalConfig", "FlavaTextConfig"]
|
docs/transformers/build/lib/transformers/models/flava/convert_dalle_to_flava_codebook.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from transformers import FlavaImageCodebook, FlavaImageCodebookConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def rreplace(s, old, new, occurrence):
|
| 25 |
+
li = s.rsplit(old, occurrence)
|
| 26 |
+
return new.join(li)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def count_parameters(state_dict):
|
| 30 |
+
# encoder.embeddings are double copied in original FLAVA
|
| 31 |
+
return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def upgrade_state_dict(state_dict):
|
| 35 |
+
upgrade = {}
|
| 36 |
+
|
| 37 |
+
group_keys = ["group_1", "group_2", "group_3", "group_4"]
|
| 38 |
+
for key, value in state_dict.items():
|
| 39 |
+
for group_key in group_keys:
|
| 40 |
+
if group_key in key:
|
| 41 |
+
key = key.replace(f"{group_key}.", f"{group_key}.group.")
|
| 42 |
+
|
| 43 |
+
if "res_path" in key:
|
| 44 |
+
key = key.replace("res_path.", "res_path.path.")
|
| 45 |
+
|
| 46 |
+
if key.endswith(".w"):
|
| 47 |
+
key = rreplace(key, ".w", ".weight", 1)
|
| 48 |
+
if key.endswith(".b"):
|
| 49 |
+
key = rreplace(key, ".b", ".bias", 1)
|
| 50 |
+
|
| 51 |
+
upgrade[key] = value.float()
|
| 52 |
+
|
| 53 |
+
return upgrade
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
|
| 57 |
+
def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, save_checkpoint=True):
|
| 58 |
+
"""
|
| 59 |
+
Copy/paste/tweak model's weights to transformers design.
|
| 60 |
+
"""
|
| 61 |
+
from dall_e import Encoder
|
| 62 |
+
|
| 63 |
+
encoder = Encoder()
|
| 64 |
+
if os.path.exists(checkpoint_path):
|
| 65 |
+
ckpt = torch.load(checkpoint_path, weights_only=True)
|
| 66 |
+
else:
|
| 67 |
+
ckpt = torch.hub.load_state_dict_from_url(checkpoint_path)
|
| 68 |
+
|
| 69 |
+
if isinstance(ckpt, Encoder):
|
| 70 |
+
ckpt = ckpt.state_dict()
|
| 71 |
+
encoder.load_state_dict(ckpt)
|
| 72 |
+
|
| 73 |
+
if config_path is not None:
|
| 74 |
+
config = FlavaImageCodebookConfig.from_pretrained(config_path)
|
| 75 |
+
else:
|
| 76 |
+
config = FlavaImageCodebookConfig()
|
| 77 |
+
|
| 78 |
+
hf_model = FlavaImageCodebook(config).eval()
|
| 79 |
+
state_dict = encoder.state_dict()
|
| 80 |
+
|
| 81 |
+
hf_state_dict = upgrade_state_dict(state_dict)
|
| 82 |
+
hf_model.load_state_dict(hf_state_dict)
|
| 83 |
+
hf_state_dict = hf_model.state_dict()
|
| 84 |
+
hf_count = count_parameters(hf_state_dict)
|
| 85 |
+
state_dict_count = count_parameters(state_dict)
|
| 86 |
+
|
| 87 |
+
assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
|
| 88 |
+
|
| 89 |
+
if save_checkpoint:
|
| 90 |
+
hf_model.save_pretrained(pytorch_dump_folder_path)
|
| 91 |
+
else:
|
| 92 |
+
return hf_state_dict
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
parser = argparse.ArgumentParser()
|
| 97 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
| 98 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
|
| 99 |
+
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
| 100 |
+
args = parser.parse_args()
|
| 101 |
+
|
| 102 |
+
convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
docs/transformers/build/lib/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from transformers import FlavaConfig, FlavaForPreTraining
|
| 22 |
+
from transformers.models.flava.convert_dalle_to_flava_codebook import convert_dalle_checkpoint
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def count_parameters(state_dict):
|
| 26 |
+
# encoder.embeddings are double copied in original FLAVA
|
| 27 |
+
return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def upgrade_state_dict(state_dict, codebook_state_dict):
|
| 31 |
+
upgrade = {}
|
| 32 |
+
|
| 33 |
+
for key, value in state_dict.items():
|
| 34 |
+
if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key:
|
| 35 |
+
continue
|
| 36 |
+
|
| 37 |
+
key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head")
|
| 38 |
+
key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head")
|
| 39 |
+
key = key.replace("heads.cmd.itm_head.cls", "itm_head")
|
| 40 |
+
key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler")
|
| 41 |
+
key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale")
|
| 42 |
+
key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head")
|
| 43 |
+
key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head")
|
| 44 |
+
key = key.replace("mm_text_projection", "flava.text_to_mm_projection")
|
| 45 |
+
key = key.replace("mm_image_projection", "flava.image_to_mm_projection")
|
| 46 |
+
key = key.replace("image_encoder.module", "flava.image_model")
|
| 47 |
+
key = key.replace("text_encoder.module", "flava.text_model")
|
| 48 |
+
key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token")
|
| 49 |
+
key = key.replace("mm_encoder.module", "flava.multimodal_model")
|
| 50 |
+
key = key.replace("text_projection", "flava.text_projection")
|
| 51 |
+
key = key.replace("image_projection", "flava.image_projection")
|
| 52 |
+
|
| 53 |
+
upgrade[key] = value.float()
|
| 54 |
+
|
| 55 |
+
for key, value in codebook_state_dict.items():
|
| 56 |
+
upgrade[f"image_codebook.{key}"] = value
|
| 57 |
+
|
| 58 |
+
return upgrade
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@torch.no_grad()
|
| 62 |
+
def convert_flava_checkpoint(checkpoint_path, codebook_path, pytorch_dump_folder_path, config_path=None):
|
| 63 |
+
"""
|
| 64 |
+
Copy/paste/tweak model's weights to transformers design.
|
| 65 |
+
"""
|
| 66 |
+
if config_path is not None:
|
| 67 |
+
config = FlavaConfig.from_pretrained(config_path)
|
| 68 |
+
else:
|
| 69 |
+
config = FlavaConfig()
|
| 70 |
+
|
| 71 |
+
hf_model = FlavaForPreTraining(config).eval()
|
| 72 |
+
|
| 73 |
+
codebook_state_dict = convert_dalle_checkpoint(codebook_path, None, save_checkpoint=False)
|
| 74 |
+
|
| 75 |
+
if os.path.exists(checkpoint_path):
|
| 76 |
+
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
| 77 |
+
else:
|
| 78 |
+
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")
|
| 79 |
+
|
| 80 |
+
hf_state_dict = upgrade_state_dict(state_dict, codebook_state_dict)
|
| 81 |
+
hf_model.load_state_dict(hf_state_dict)
|
| 82 |
+
hf_state_dict = hf_model.state_dict()
|
| 83 |
+
hf_count = count_parameters(hf_state_dict)
|
| 84 |
+
state_dict_count = count_parameters(state_dict) + count_parameters(codebook_state_dict)
|
| 85 |
+
|
| 86 |
+
assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
|
| 87 |
+
|
| 88 |
+
hf_model.save_pretrained(pytorch_dump_folder_path)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
parser = argparse.ArgumentParser()
|
| 93 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
| 94 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
|
| 95 |
+
parser.add_argument("--codebook_path", default=None, type=str, help="Path to flava codebook checkpoint")
|
| 96 |
+
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
| 97 |
+
args = parser.parse_args()
|
| 98 |
+
|
| 99 |
+
convert_flava_checkpoint(args.checkpoint_path, args.codebook_path, args.pytorch_dump_folder_path, args.config_path)
|
docs/transformers/build/lib/transformers/models/flava/feature_extraction_flava.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Feature extractor class for FLAVA."""
|
| 16 |
+
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
from ...utils import logging
|
| 20 |
+
from ...utils.import_utils import requires
|
| 21 |
+
from .image_processing_flava import FlavaImageProcessor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@requires(backends=("vision",))
|
| 28 |
+
class FlavaFeatureExtractor(FlavaImageProcessor):
|
| 29 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 30 |
+
warnings.warn(
|
| 31 |
+
"The class FlavaFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
|
| 32 |
+
" use FlavaImageProcessor instead.",
|
| 33 |
+
FutureWarning,
|
| 34 |
+
)
|
| 35 |
+
super().__init__(*args, **kwargs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
__all__ = ["FlavaFeatureExtractor"]
|
docs/transformers/build/lib/transformers/models/flava/image_processing_flava.py
ADDED
|
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for Flava."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import random
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 25 |
+
from ...image_transforms import resize, to_channel_dimension_format
|
| 26 |
+
from ...image_utils import (
|
| 27 |
+
OPENAI_CLIP_MEAN,
|
| 28 |
+
OPENAI_CLIP_STD,
|
| 29 |
+
ChannelDimension,
|
| 30 |
+
ImageInput,
|
| 31 |
+
PILImageResampling,
|
| 32 |
+
infer_channel_dimension_format,
|
| 33 |
+
is_scaled_image,
|
| 34 |
+
make_list_of_images,
|
| 35 |
+
to_numpy_array,
|
| 36 |
+
valid_images,
|
| 37 |
+
validate_preprocess_arguments,
|
| 38 |
+
)
|
| 39 |
+
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
| 40 |
+
from ...utils.import_utils import requires
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if is_vision_available():
|
| 44 |
+
import PIL
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# These values are taken from CLIP
|
| 51 |
+
FLAVA_IMAGE_MEAN = OPENAI_CLIP_MEAN
|
| 52 |
+
FLAVA_IMAGE_STD = OPENAI_CLIP_STD
|
| 53 |
+
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
|
| 54 |
+
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
|
| 55 |
+
LOGIT_LAPLACE_EPS: float = 0.1
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
|
| 59 |
+
class FlavaMaskingGenerator:
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
input_size: Union[int, Tuple[int, int]] = 14,
|
| 63 |
+
total_mask_patches: int = 75,
|
| 64 |
+
mask_group_max_patches: Optional[int] = None,
|
| 65 |
+
mask_group_min_patches: int = 16,
|
| 66 |
+
mask_group_min_aspect_ratio: Optional[float] = 0.3,
|
| 67 |
+
mask_group_max_aspect_ratio: Optional[float] = None,
|
| 68 |
+
):
|
| 69 |
+
if not isinstance(input_size, tuple):
|
| 70 |
+
input_size = (input_size,) * 2
|
| 71 |
+
self.height, self.width = input_size
|
| 72 |
+
|
| 73 |
+
self.num_patches = self.height * self.width
|
| 74 |
+
self.total_mask_patches = total_mask_patches
|
| 75 |
+
|
| 76 |
+
self.mask_group_min_patches = mask_group_min_patches
|
| 77 |
+
self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
|
| 78 |
+
|
| 79 |
+
mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
|
| 80 |
+
self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
|
| 81 |
+
|
| 82 |
+
def __repr__(self):
|
| 83 |
+
repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
|
| 84 |
+
self.height,
|
| 85 |
+
self.width,
|
| 86 |
+
self.mask_group_min_patches,
|
| 87 |
+
self.mask_group_max_patches,
|
| 88 |
+
self.total_mask_patches,
|
| 89 |
+
self.log_aspect_ratio[0],
|
| 90 |
+
self.log_aspect_ratio[1],
|
| 91 |
+
)
|
| 92 |
+
return repr_str
|
| 93 |
+
|
| 94 |
+
def get_shape(self):
|
| 95 |
+
return self.height, self.width
|
| 96 |
+
|
| 97 |
+
def _mask(self, mask, max_mask_patches):
|
| 98 |
+
delta = 0
|
| 99 |
+
for _attempt in range(10):
|
| 100 |
+
target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
|
| 101 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
| 102 |
+
height = int(round(math.sqrt(target_area * aspect_ratio)))
|
| 103 |
+
width = int(round(math.sqrt(target_area / aspect_ratio)))
|
| 104 |
+
if width < self.width and height < self.height:
|
| 105 |
+
top = random.randint(0, self.height - height)
|
| 106 |
+
left = random.randint(0, self.width - width)
|
| 107 |
+
|
| 108 |
+
num_masked = mask[top : top + height, left : left + width].sum()
|
| 109 |
+
# Overlap
|
| 110 |
+
if 0 < height * width - num_masked <= max_mask_patches:
|
| 111 |
+
for i in range(top, top + height):
|
| 112 |
+
for j in range(left, left + width):
|
| 113 |
+
if mask[i, j] == 0:
|
| 114 |
+
mask[i, j] = 1
|
| 115 |
+
delta += 1
|
| 116 |
+
|
| 117 |
+
if delta > 0:
|
| 118 |
+
break
|
| 119 |
+
return delta
|
| 120 |
+
|
| 121 |
+
def __call__(self):
|
| 122 |
+
mask = np.zeros(shape=self.get_shape(), dtype=int)
|
| 123 |
+
mask_count = 0
|
| 124 |
+
while mask_count < self.total_mask_patches:
|
| 125 |
+
max_mask_patches = self.total_mask_patches - mask_count
|
| 126 |
+
max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
|
| 127 |
+
|
| 128 |
+
delta = self._mask(mask, max_mask_patches)
|
| 129 |
+
if delta == 0:
|
| 130 |
+
break
|
| 131 |
+
else:
|
| 132 |
+
mask_count += delta
|
| 133 |
+
|
| 134 |
+
return mask
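To make the generator above easier to picture, here is a minimal, hedged sketch of calling it directly; the argument values simply repeat the defaults from `__init__`, and nothing here is FLAVA-specific API beyond the class itself.

```python
# Draw one random block-wise mask over a 14x14 patch grid.
generator = FlavaMaskingGenerator(
    input_size=14,
    total_mask_patches=75,
    mask_group_min_patches=16,
    mask_group_min_aspect_ratio=0.3,
)
mask = generator()          # 2-D numpy array of 0/1 values
print(mask.shape)           # (14, 14)
print(int(mask.sum()))      # at most 75 masked patches; usually close to 75
```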
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@requires(backends=("vision",))
|
| 138 |
+
class FlavaImageProcessor(BaseImageProcessor):
|
| 139 |
+
r"""
|
| 140 |
+
Constructs a Flava image processor.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 144 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
| 145 |
+
`do_resize` parameter in `preprocess`.
|
| 146 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 147 |
+
Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`.
|
| 148 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 149 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in
|
| 150 |
+
`preprocess`.
|
| 151 |
+
do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 152 |
+
Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`.
|
| 153 |
+
crop_size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 154 |
+
Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the
|
| 155 |
+
`crop_size` parameter in `preprocess`.
|
| 156 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 157 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 158 |
+
parameter in `preprocess`.
|
| 159 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 160 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
|
| 161 |
+
`preprocess`.
|
| 162 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 163 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`.
|
| 164 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_MEAN`):
|
| 165 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 166 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 167 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_STD`):
|
| 168 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 169 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 170 |
+
return_image_mask (`bool`, *optional*, defaults to `False`):
|
| 171 |
+
Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
|
| 172 |
+
input_size_patches (`int`, *optional*, defaults to 14):
|
| 173 |
+
Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
|
| 174 |
+
by the `input_size_patches` parameter in `preprocess`.
|
| 175 |
+
total_mask_patches (`int`, *optional*, defaults to 75):
|
| 176 |
+
Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
|
| 177 |
+
`preprocess`.
|
| 178 |
+
mask_group_min_patches (`int`, *optional*, defaults to 16):
|
| 179 |
+
Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
|
| 180 |
+
parameter in `preprocess`.
|
| 181 |
+
mask_group_max_patches (`int`, *optional*):
|
| 182 |
+
Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
|
| 183 |
+
parameter in `preprocess`.
|
| 184 |
+
mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
|
| 185 |
+
Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
|
| 186 |
+
in `preprocess`.
|
| 187 |
+
mask_group_max_aspect_ratio (`float`, *optional*):
|
| 188 |
+
Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
|
| 189 |
+
in `preprocess`.
|
| 190 |
+
codebook_do_resize (`bool`, *optional*, defaults to `True`):
|
| 191 |
+
Whether to resize the input for codebook to a certain. Can be overridden by the `codebook_do_resize`
|
| 192 |
+
parameter in `preprocess`. `codebook_size`.
|
| 193 |
+
codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
|
| 194 |
+
Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
|
| 195 |
+
`preprocess`.
|
| 196 |
+
codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
|
| 197 |
+
Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
|
| 198 |
+
parameter in `preprocess`.
|
| 199 |
+
codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 200 |
+
Whether to crop the input for codebook at the center. If the input size is smaller than
|
| 201 |
+
`codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
|
| 202 |
+
overridden by the `codebook_do_center_crop` parameter in `preprocess`.
|
| 203 |
+
codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
|
| 204 |
+
Desired output size for codebook input when applying center-cropping. Can be overridden by the
|
| 205 |
+
`codebook_crop_size` parameter in `preprocess`.
|
| 206 |
+
codebook_do_rescale (`bool`, *optional*, defaults to `True`):
|
| 207 |
+
Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
|
| 208 |
+
overridden by the `codebook_do_rescale` parameter in `preprocess`.
|
| 209 |
+
codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 210 |
+
Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
|
| 211 |
+
`codebook_rescale_factor` parameter in `preprocess`.
|
| 212 |
+
codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
|
| 213 |
+
Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
|
| 214 |
+
`codebook_do_map_pixels` parameter in `preprocess`.
|
| 215 |
+
codebook_do_normalize (`bool`, *optional*, defaults to `True`):
|
| 216 |
+
Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
|
| 217 |
+
be overridden by the `codebook_do_normalize` parameter in `preprocess`.
|
| 218 |
+
codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
|
| 219 |
+
The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
|
| 220 |
+
by the `codebook_image_mean` parameter in `preprocess`.
|
| 221 |
+
codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 222 |
+
The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
|
| 223 |
+
be overridden by the `codebook_image_std` parameter in `preprocess`.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
model_input_names = ["pixel_values"]
|
| 227 |
+
|
| 228 |
+
def __init__(
|
| 229 |
+
self,
|
| 230 |
+
do_resize: bool = True,
|
| 231 |
+
size: Dict[str, int] = None,
|
| 232 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 233 |
+
do_center_crop: bool = True,
|
| 234 |
+
crop_size: Dict[str, int] = None,
|
| 235 |
+
do_rescale: bool = True,
|
| 236 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 237 |
+
do_normalize: bool = True,
|
| 238 |
+
image_mean: Optional[Union[float, Iterable[float]]] = None,
|
| 239 |
+
image_std: Optional[Union[float, Iterable[float]]] = None,
|
| 240 |
+
# Mask related params
|
| 241 |
+
return_image_mask: bool = False,
|
| 242 |
+
input_size_patches: int = 14,
|
| 243 |
+
total_mask_patches: int = 75,
|
| 244 |
+
mask_group_min_patches: int = 16,
|
| 245 |
+
mask_group_max_patches: Optional[int] = None,
|
| 246 |
+
mask_group_min_aspect_ratio: float = 0.3,
|
| 247 |
+
mask_group_max_aspect_ratio: Optional[float] = None,
|
| 248 |
+
# Codebook related params
|
| 249 |
+
return_codebook_pixels: bool = False,
|
| 250 |
+
codebook_do_resize: bool = True,
|
| 251 |
+
codebook_size: Optional[Dict[str, int]] = None,
|
| 252 |
+
codebook_resample: int = PILImageResampling.LANCZOS,
|
| 253 |
+
codebook_do_center_crop: bool = True,
|
| 254 |
+
codebook_crop_size: Optional[Dict[str, int]] = None,
|
| 255 |
+
codebook_do_rescale: bool = True,
|
| 256 |
+
codebook_rescale_factor: Union[int, float] = 1 / 255,
|
| 257 |
+
codebook_do_map_pixels: bool = True,
|
| 258 |
+
codebook_do_normalize: bool = True,
|
| 259 |
+
codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
|
| 260 |
+
codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
|
| 261 |
+
**kwargs,
|
| 262 |
+
) -> None:
|
| 263 |
+
super().__init__(**kwargs)
|
| 264 |
+
size = size if size is not None else {"height": 224, "width": 224}
|
| 265 |
+
size = get_size_dict(size)
|
| 266 |
+
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
| 267 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
| 268 |
+
|
| 269 |
+
codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
|
| 270 |
+
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
|
| 271 |
+
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
|
| 272 |
+
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
|
| 273 |
+
|
| 274 |
+
self.do_resize = do_resize
|
| 275 |
+
self.size = size
|
| 276 |
+
self.resample = resample
|
| 277 |
+
self.do_rescale = do_rescale
|
| 278 |
+
self.rescale_factor = rescale_factor
|
| 279 |
+
self.do_center_crop = do_center_crop
|
| 280 |
+
self.crop_size = crop_size
|
| 281 |
+
self.do_normalize = do_normalize
|
| 282 |
+
self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
|
| 283 |
+
self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD
|
| 284 |
+
|
| 285 |
+
self.return_image_mask = return_image_mask
|
| 286 |
+
self.input_size_patches = input_size_patches
|
| 287 |
+
self.total_mask_patches = total_mask_patches
|
| 288 |
+
self.mask_group_min_patches = mask_group_min_patches
|
| 289 |
+
self.mask_group_max_patches = mask_group_max_patches
|
| 290 |
+
self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
|
| 291 |
+
self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
|
| 292 |
+
|
| 293 |
+
self.return_codebook_pixels = return_codebook_pixels
|
| 294 |
+
self.codebook_do_resize = codebook_do_resize
|
| 295 |
+
self.codebook_size = codebook_size
|
| 296 |
+
self.codebook_resample = codebook_resample
|
| 297 |
+
self.codebook_do_center_crop = codebook_do_center_crop
|
| 298 |
+
self.codebook_crop_size = codebook_crop_size
|
| 299 |
+
self.codebook_do_rescale = codebook_do_rescale
|
| 300 |
+
self.codebook_rescale_factor = codebook_rescale_factor
|
| 301 |
+
self.codebook_do_map_pixels = codebook_do_map_pixels
|
| 302 |
+
self.codebook_do_normalize = codebook_do_normalize
|
| 303 |
+
|
| 304 |
+
self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
|
| 305 |
+
self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
|
| 306 |
+
|
| 307 |
+
@classmethod
|
| 308 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 309 |
+
"""
|
| 310 |
+
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
|
| 311 |
+
created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
|
| 312 |
+
"""
|
| 313 |
+
image_processor_dict = image_processor_dict.copy()
|
| 314 |
+
if "codebook_size" in kwargs:
|
| 315 |
+
image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
|
| 316 |
+
if "codebook_crop_size" in kwargs:
|
| 317 |
+
image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
|
| 318 |
+
return super().from_dict(image_processor_dict, **kwargs)
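A hedged usage sketch of the override above; the checkpoint name is a placeholder, and the point is only that a `codebook_size` keyword is routed through `__init__` (and `get_size_dict`) instead of being attached after construction.

```python
# Hypothetical checkpoint name used purely for illustration.
processor = FlavaImageProcessor.from_pretrained(
    "some-org/flava-checkpoint", codebook_size={"height": 160, "width": 160}
)
print(processor.codebook_size)  # {"height": 160, "width": 160}
```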
|
| 319 |
+
|
| 320 |
+
@lru_cache()
|
| 321 |
+
def masking_generator(
|
| 322 |
+
self,
|
| 323 |
+
input_size_patches,
|
| 324 |
+
total_mask_patches,
|
| 325 |
+
mask_group_min_patches,
|
| 326 |
+
mask_group_max_patches,
|
| 327 |
+
mask_group_min_aspect_ratio,
|
| 328 |
+
mask_group_max_aspect_ratio,
|
| 329 |
+
) -> FlavaMaskingGenerator:
|
| 330 |
+
return FlavaMaskingGenerator(
|
| 331 |
+
input_size=input_size_patches,
|
| 332 |
+
total_mask_patches=total_mask_patches,
|
| 333 |
+
mask_group_min_patches=mask_group_min_patches,
|
| 334 |
+
mask_group_max_patches=mask_group_max_patches,
|
| 335 |
+
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
|
| 336 |
+
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
| 340 |
+
def resize(
|
| 341 |
+
self,
|
| 342 |
+
image: np.ndarray,
|
| 343 |
+
size: Dict[str, int],
|
| 344 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 345 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 346 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 347 |
+
**kwargs,
|
| 348 |
+
) -> np.ndarray:
|
| 349 |
+
"""
|
| 350 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
image (`np.ndarray`):
|
| 354 |
+
Image to resize.
|
| 355 |
+
size (`Dict[str, int]`):
|
| 356 |
+
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
| 357 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 358 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
| 359 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 360 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 361 |
+
image is used. Can be one of:
|
| 362 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 363 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 364 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 365 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 366 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 367 |
+
from the input image. Can be one of:
|
| 368 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 369 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 370 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
`np.ndarray`: The resized image.
|
| 374 |
+
"""
|
| 375 |
+
size = get_size_dict(size)
|
| 376 |
+
if "height" not in size or "width" not in size:
|
| 377 |
+
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
| 378 |
+
output_size = (size["height"], size["width"])
|
| 379 |
+
return resize(
|
| 380 |
+
image,
|
| 381 |
+
size=output_size,
|
| 382 |
+
resample=resample,
|
| 383 |
+
data_format=data_format,
|
| 384 |
+
input_data_format=input_data_format,
|
| 385 |
+
**kwargs,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def map_pixels(self, image: np.ndarray) -> np.ndarray:
|
| 389 |
+
return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
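A short worked check of the mapping above: with `LOGIT_LAPLACE_EPS = 0.1`, inputs in `[0, 1]` are squeezed into `[0.1, 0.9]`, keeping codebook inputs strictly inside the interval (a sketch of the arithmetic only, not a claim about why the downstream codebook needs it).

```python
eps = 0.1                                      # LOGIT_LAPLACE_EPS
map_px = lambda x: (1 - 2 * eps) * x + eps
print(map_px(0.0), map_px(0.5), map_px(1.0))   # ~0.1, 0.5, ~0.9
```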
|
| 390 |
+
|
| 391 |
+
def _preprocess_image(
|
| 392 |
+
self,
|
| 393 |
+
image: ImageInput,
|
| 394 |
+
do_resize: Optional[bool] = None,
|
| 395 |
+
size: Dict[str, int] = None,
|
| 396 |
+
resample: PILImageResampling = None,
|
| 397 |
+
do_center_crop: Optional[bool] = None,
|
| 398 |
+
crop_size: Dict[str, int] = None,
|
| 399 |
+
do_rescale: Optional[bool] = None,
|
| 400 |
+
rescale_factor: Optional[float] = None,
|
| 401 |
+
do_normalize: Optional[bool] = None,
|
| 402 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 403 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 404 |
+
do_map_pixels: Optional[bool] = None,
|
| 405 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 406 |
+
input_data_format: Optional[ChannelDimension] = None,
|
| 407 |
+
) -> np.ndarray:
|
| 408 |
+
"""Preprocesses a single image."""
|
| 409 |
+
|
| 410 |
+
validate_preprocess_arguments(
|
| 411 |
+
do_rescale=do_rescale,
|
| 412 |
+
rescale_factor=rescale_factor,
|
| 413 |
+
do_normalize=do_normalize,
|
| 414 |
+
image_mean=image_mean,
|
| 415 |
+
image_std=image_std,
|
| 416 |
+
do_center_crop=do_center_crop,
|
| 417 |
+
crop_size=crop_size,
|
| 418 |
+
do_resize=do_resize,
|
| 419 |
+
size=size,
|
| 420 |
+
resample=resample,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# All transformations expect numpy arrays.
|
| 424 |
+
image = to_numpy_array(image)
|
| 425 |
+
|
| 426 |
+
if do_rescale and is_scaled_image(image):
|
| 427 |
+
logger.warning_once(
|
| 428 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 429 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
if input_data_format is None:
|
| 433 |
+
# We assume that all images have the same channel dimension format.
|
| 434 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 435 |
+
|
| 436 |
+
if do_resize:
|
| 437 |
+
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 438 |
+
|
| 439 |
+
if do_center_crop:
|
| 440 |
+
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
| 441 |
+
|
| 442 |
+
if do_rescale:
|
| 443 |
+
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 444 |
+
|
| 445 |
+
if do_normalize:
|
| 446 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 447 |
+
|
| 448 |
+
if do_map_pixels:
|
| 449 |
+
image = self.map_pixels(image)
|
| 450 |
+
|
| 451 |
+
if data_format is not None:
|
| 452 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 453 |
+
return image
|
| 454 |
+
|
| 455 |
+
@filter_out_non_signature_kwargs()
|
| 456 |
+
def preprocess(
|
| 457 |
+
self,
|
| 458 |
+
images: ImageInput,
|
| 459 |
+
do_resize: Optional[bool] = None,
|
| 460 |
+
size: Dict[str, int] = None,
|
| 461 |
+
resample: PILImageResampling = None,
|
| 462 |
+
do_center_crop: Optional[bool] = None,
|
| 463 |
+
crop_size: Optional[Dict[str, int]] = None,
|
| 464 |
+
do_rescale: Optional[bool] = None,
|
| 465 |
+
rescale_factor: Optional[float] = None,
|
| 466 |
+
do_normalize: Optional[bool] = None,
|
| 467 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 468 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 469 |
+
# Mask related params
|
| 470 |
+
return_image_mask: Optional[bool] = None,
|
| 471 |
+
input_size_patches: Optional[int] = None,
|
| 472 |
+
total_mask_patches: Optional[int] = None,
|
| 473 |
+
mask_group_min_patches: Optional[int] = None,
|
| 474 |
+
mask_group_max_patches: Optional[int] = None,
|
| 475 |
+
mask_group_min_aspect_ratio: Optional[float] = None,
|
| 476 |
+
mask_group_max_aspect_ratio: Optional[float] = None,
|
| 477 |
+
# Codebook related params
|
| 478 |
+
return_codebook_pixels: Optional[bool] = None,
|
| 479 |
+
codebook_do_resize: Optional[bool] = None,
|
| 480 |
+
codebook_size: Optional[Dict[str, int]] = None,
|
| 481 |
+
codebook_resample: Optional[int] = None,
|
| 482 |
+
codebook_do_center_crop: Optional[bool] = None,
|
| 483 |
+
codebook_crop_size: Optional[Dict[str, int]] = None,
|
| 484 |
+
codebook_do_rescale: Optional[bool] = None,
|
| 485 |
+
codebook_rescale_factor: Optional[float] = None,
|
| 486 |
+
codebook_do_map_pixels: Optional[bool] = None,
|
| 487 |
+
codebook_do_normalize: Optional[bool] = None,
|
| 488 |
+
codebook_image_mean: Optional[Iterable[float]] = None,
|
| 489 |
+
codebook_image_std: Optional[Iterable[float]] = None,
|
| 490 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 491 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
| 492 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 493 |
+
) -> BatchFeature:
|
| 494 |
+
"""
|
| 495 |
+
Preprocess an image or batch of images.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
images (`ImageInput`):
|
| 499 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 500 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 501 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 502 |
+
Whether to resize the image.
|
| 503 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 504 |
+
Size of the image.
|
| 505 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 506 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 507 |
+
has an effect if `do_resize` is set to `True`.
|
| 508 |
+
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
| 509 |
+
Whether to center crop the image.
|
| 510 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
| 511 |
+
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
|
| 512 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 513 |
+
Whether to rescale the image values to the `[0, 1]` range.
|
| 514 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 515 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 516 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 517 |
+
Whether to normalize the image.
|
| 518 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 519 |
+
Image mean.
|
| 520 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 521 |
+
Image standard deviation.
|
| 522 |
+
return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`):
|
| 523 |
+
Whether to return the image mask.
|
| 524 |
+
input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`):
|
| 525 |
+
Number of patches in the image in height and width direction.
|
| 526 |
+
total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`):
|
| 527 |
+
Total number of patches that should be masked.
|
| 528 |
+
mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`):
|
| 529 |
+
Minimum number of patches that should be masked.
|
| 530 |
+
mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`):
|
| 531 |
+
Maximum number of patches that should be masked.
|
| 532 |
+
mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`):
|
| 533 |
+
Minimum aspect ratio of the mask window.
|
| 534 |
+
mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`):
|
| 535 |
+
Maximum aspect ratio of the mask window.
|
| 536 |
+
return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`):
|
| 537 |
+
Whether to return the codebook pixels.
|
| 538 |
+
codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`):
|
| 539 |
+
Whether to resize the codebook pixels.
|
| 540 |
+
codebook_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_size`):
|
| 541 |
+
Size of the codebook pixels.
|
| 542 |
+
codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`):
|
| 543 |
+
Resampling filter to use if resizing the codebook pixels. This can be one of the enum
|
| 544 |
+
`PILImageResampling`. Only has an effect if `codebook_do_resize` is set to `True`.
|
| 545 |
+
codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`):
|
| 546 |
+
Whether to center crop the codebook pixels.
|
| 547 |
+
codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`):
|
| 548 |
+
Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set
|
| 549 |
+
to `True`.
|
| 550 |
+
codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`):
|
| 551 |
+
Whether to rescale the codebook pixel values to the `[0, 1]` range.
|
| 552 |
+
codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`):
|
| 553 |
+
Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`.
|
| 554 |
+
codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`):
|
| 555 |
+
Whether to map the codebook pixel values to `(1 - 2e)x + e`.
|
| 556 |
+
codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`):
|
| 557 |
+
Whether to normalize the codebook pixels.
|
| 558 |
+
codebook_image_mean (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_mean`):
|
| 559 |
+
Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`.
|
| 560 |
+
codebook_image_std (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_std`):
|
| 561 |
+
Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is
|
| 562 |
+
set to `True`.
|
| 563 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 564 |
+
The type of tensors to return. Can be one of:
|
| 565 |
+
- Unset: Return a list of `np.ndarray`.
|
| 566 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 567 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 568 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 569 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 570 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 571 |
+
The channel dimension format for the output image. Can be one of:
|
| 572 |
+
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 573 |
+
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 574 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 575 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 576 |
+
from the input image. Can be one of:
|
| 577 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 578 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 579 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 580 |
+
"""
|
| 581 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 582 |
+
size = size if size is not None else self.size
|
| 583 |
+
size = get_size_dict(size)
|
| 584 |
+
resample = resample if resample is not None else self.resample
|
| 585 |
+
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
| 586 |
+
crop_size = crop_size if crop_size is not None else self.crop_size
|
| 587 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
| 588 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 589 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 590 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 591 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 592 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 593 |
+
|
| 594 |
+
return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask
|
| 595 |
+
input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches
|
| 596 |
+
total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches
|
| 597 |
+
mask_group_min_patches = (
|
| 598 |
+
mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches
|
| 599 |
+
)
|
| 600 |
+
mask_group_max_patches = (
|
| 601 |
+
mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches
|
| 602 |
+
)
|
| 603 |
+
mask_group_min_aspect_ratio = (
|
| 604 |
+
mask_group_min_aspect_ratio
|
| 605 |
+
if mask_group_min_aspect_ratio is not None
|
| 606 |
+
else self.mask_group_min_aspect_ratio
|
| 607 |
+
)
|
| 608 |
+
mask_group_max_aspect_ratio = (
|
| 609 |
+
mask_group_max_aspect_ratio
|
| 610 |
+
if mask_group_max_aspect_ratio is not None
|
| 611 |
+
else self.mask_group_max_aspect_ratio
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
return_codebook_pixels = (
|
| 615 |
+
return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels
|
| 616 |
+
)
|
| 617 |
+
codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
|
| 618 |
+
codebook_size = codebook_size if codebook_size is not None else self.codebook_size
|
| 619 |
+
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
|
| 620 |
+
codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
|
| 621 |
+
codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
|
| 622 |
+
codebook_rescale_factor = (
|
| 623 |
+
codebook_rescale_factor if codebook_rescale_factor is not None else self.codebook_rescale_factor
|
| 624 |
+
)
|
| 625 |
+
codebook_do_center_crop = (
|
| 626 |
+
codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
|
| 627 |
+
)
|
| 628 |
+
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
|
| 629 |
+
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
|
| 630 |
+
codebook_do_map_pixels = (
|
| 631 |
+
codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
|
| 632 |
+
)
|
| 633 |
+
codebook_do_normalize = (
|
| 634 |
+
codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize
|
| 635 |
+
)
|
| 636 |
+
codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
|
| 637 |
+
codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std
|
| 638 |
+
|
| 639 |
+
images = make_list_of_images(images)
|
| 640 |
+
|
| 641 |
+
if not valid_images(images):
|
| 642 |
+
raise ValueError(
|
| 643 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 644 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
processed_images = [
|
| 648 |
+
self._preprocess_image(
|
| 649 |
+
image=img,
|
| 650 |
+
do_resize=do_resize,
|
| 651 |
+
size=size,
|
| 652 |
+
resample=resample,
|
| 653 |
+
do_center_crop=do_center_crop,
|
| 654 |
+
crop_size=crop_size,
|
| 655 |
+
do_rescale=do_rescale,
|
| 656 |
+
rescale_factor=rescale_factor,
|
| 657 |
+
do_normalize=do_normalize,
|
| 658 |
+
image_mean=image_mean,
|
| 659 |
+
image_std=image_std,
|
| 660 |
+
do_map_pixels=False,
|
| 661 |
+
data_format=data_format,
|
| 662 |
+
input_data_format=input_data_format,
|
| 663 |
+
)
|
| 664 |
+
for img in images
|
| 665 |
+
]
|
| 666 |
+
data = {"pixel_values": processed_images}
|
| 667 |
+
|
| 668 |
+
if return_codebook_pixels:
|
| 669 |
+
codebook_images = [
|
| 670 |
+
self._preprocess_image(
|
| 671 |
+
image=img,
|
| 672 |
+
do_resize=codebook_do_resize,
|
| 673 |
+
size=codebook_size,
|
| 674 |
+
resample=codebook_resample,
|
| 675 |
+
do_center_crop=codebook_do_center_crop,
|
| 676 |
+
crop_size=codebook_crop_size,
|
| 677 |
+
do_rescale=codebook_do_rescale,
|
| 678 |
+
rescale_factor=codebook_rescale_factor,
|
| 679 |
+
do_normalize=codebook_do_normalize,
|
| 680 |
+
image_mean=codebook_image_mean,
|
| 681 |
+
image_std=codebook_image_std,
|
| 682 |
+
do_map_pixels=codebook_do_map_pixels,
|
| 683 |
+
data_format=data_format,
|
| 684 |
+
input_data_format=input_data_format,
|
| 685 |
+
)
|
| 686 |
+
for img in images
|
| 687 |
+
]
|
| 688 |
+
data["codebook_pixel_values"] = codebook_images
|
| 689 |
+
|
| 690 |
+
if return_image_mask:
|
| 691 |
+
mask_generator = self.masking_generator(
|
| 692 |
+
input_size_patches=input_size_patches,
|
| 693 |
+
total_mask_patches=total_mask_patches,
|
| 694 |
+
mask_group_min_patches=mask_group_min_patches,
|
| 695 |
+
mask_group_max_patches=mask_group_max_patches,
|
| 696 |
+
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
|
| 697 |
+
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
|
| 698 |
+
)
|
| 699 |
+
masks = [mask_generator() for _ in images]
|
| 700 |
+
data["bool_masked_pos"] = masks
|
| 701 |
+
|
| 702 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
__all__ = ["FlavaImageProcessor"]
|
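Before moving on to the fast variant, here is a hedged end-to-end sketch of the slow processor defined above. The random image is a stand-in for real data; the output keys (`pixel_values`, `codebook_pixel_values`, `bool_masked_pos`) and the 224/112 sizes come straight from the `preprocess` implementation and the constructor defaults.

```python
import numpy as np

image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # stand-in RGB image

processor = FlavaImageProcessor()
outputs = processor(
    image,
    return_image_mask=True,       # adds "bool_masked_pos"
    return_codebook_pixels=True,  # adds "codebook_pixel_values"
    return_tensors="np",
)
print(outputs["pixel_values"].shape)           # (1, 3, 224, 224)
print(outputs["codebook_pixel_values"].shape)  # (1, 3, 112, 112)
print(outputs["bool_masked_pos"].shape)        # (1, 14, 14)
```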
docs/transformers/build/lib/transformers/models/flava/image_processing_flava_fast.py
ADDED
|
@@ -0,0 +1,549 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fast Image processor class for Flava."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import random
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
from ...image_processing_utils_fast import (
|
| 23 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
| 24 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
| 25 |
+
BaseImageProcessorFast,
|
| 26 |
+
BatchFeature,
|
| 27 |
+
DefaultFastImageProcessorKwargs,
|
| 28 |
+
get_size_dict,
|
| 29 |
+
)
|
| 30 |
+
from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images
|
| 31 |
+
from ...image_utils import ImageInput, PILImageResampling, SizeDict
|
| 32 |
+
from ...processing_utils import Unpack
|
| 33 |
+
from ...utils import (
|
| 34 |
+
TensorType,
|
| 35 |
+
add_start_docstrings,
|
| 36 |
+
is_torch_available,
|
| 37 |
+
is_torchvision_available,
|
| 38 |
+
is_torchvision_v2_available,
|
| 39 |
+
)
|
| 40 |
+
from .image_processing_flava import (
|
| 41 |
+
FLAVA_CODEBOOK_MEAN,
|
| 42 |
+
FLAVA_CODEBOOK_STD,
|
| 43 |
+
FLAVA_IMAGE_MEAN,
|
| 44 |
+
FLAVA_IMAGE_STD,
|
| 45 |
+
LOGIT_LAPLACE_EPS,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_torch_available():
|
| 50 |
+
import torch
|
| 51 |
+
|
| 52 |
+
if is_torchvision_available():
|
| 53 |
+
from ...image_utils import pil_torch_interpolation_mapping
|
| 54 |
+
|
| 55 |
+
if is_torchvision_v2_available():
|
| 56 |
+
from torchvision.transforms.v2 import functional as F
|
| 57 |
+
else:
|
| 58 |
+
from torchvision.transforms import functional as F
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class FlavaMaskingGenerator:
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
input_size: Union[int, Tuple[int, int]] = 14,
|
| 65 |
+
total_mask_patches: int = 75,
|
| 66 |
+
mask_group_max_patches: Optional[int] = None,
|
| 67 |
+
mask_group_min_patches: int = 16,
|
| 68 |
+
mask_group_min_aspect_ratio: Optional[float] = 0.3,
|
| 69 |
+
mask_group_max_aspect_ratio: Optional[float] = None,
|
| 70 |
+
):
|
| 71 |
+
if not isinstance(input_size, tuple):
|
| 72 |
+
input_size = (input_size,) * 2
|
| 73 |
+
self.height, self.width = input_size
|
| 74 |
+
|
| 75 |
+
self.num_patches = self.height * self.width
|
| 76 |
+
self.total_mask_patches = total_mask_patches
|
| 77 |
+
|
| 78 |
+
self.mask_group_min_patches = mask_group_min_patches
|
| 79 |
+
self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
|
| 80 |
+
|
| 81 |
+
mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
|
| 82 |
+
self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
|
| 83 |
+
|
| 84 |
+
def __repr__(self):
|
| 85 |
+
repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
|
| 86 |
+
self.height,
|
| 87 |
+
self.width,
|
| 88 |
+
self.mask_group_min_patches,
|
| 89 |
+
self.mask_group_max_patches,
|
| 90 |
+
self.total_mask_patches,
|
| 91 |
+
self.log_aspect_ratio[0],
|
| 92 |
+
self.log_aspect_ratio[1],
|
| 93 |
+
)
|
| 94 |
+
return repr_str
|
| 95 |
+
|
| 96 |
+
def get_shape(self):
|
| 97 |
+
return self.height, self.width
|
| 98 |
+
|
| 99 |
+
def _mask(self, mask, max_mask_patches):
|
| 100 |
+
delta = 0
|
| 101 |
+
for _attempt in range(10):
|
| 102 |
+
target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
|
| 103 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
| 104 |
+
height = int(round(math.sqrt(target_area * aspect_ratio)))
|
| 105 |
+
width = int(round(math.sqrt(target_area / aspect_ratio)))
|
| 106 |
+
if width < self.width and height < self.height:
|
| 107 |
+
top = random.randint(0, self.height - height)
|
| 108 |
+
left = random.randint(0, self.width - width)
|
| 109 |
+
|
| 110 |
+
num_masked = mask[top : top + height, left : left + width].sum()
|
| 111 |
+
# Overlap
|
| 112 |
+
if 0 < height * width - num_masked <= max_mask_patches:
|
| 113 |
+
zeros_pos = mask[top : top + height, left : left + width] == 0
|
| 114 |
+
mask[top : top + height, left : left + width][zeros_pos] = 1
|
| 115 |
+
delta += zeros_pos.sum()
|
| 116 |
+
|
| 117 |
+
if delta > 0:
|
| 118 |
+
break
|
| 119 |
+
return delta
|
| 120 |
+
|
| 121 |
+
def __call__(self):
|
| 122 |
+
mask = torch.zeros(self.get_shape(), dtype=torch.int)
|
| 123 |
+
mask_count = 0
|
| 124 |
+
while mask_count < self.total_mask_patches:
|
| 125 |
+
max_mask_patches = self.total_mask_patches - mask_count
|
| 126 |
+
max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
|
| 127 |
+
|
| 128 |
+
delta = self._mask(mask, max_mask_patches)
|
| 129 |
+
if delta == 0:
|
| 130 |
+
break
|
| 131 |
+
else:
|
| 132 |
+
mask_count += delta
|
| 133 |
+
|
| 134 |
+
return mask
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class FlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
| 138 |
+
# Mask related params
|
| 139 |
+
return_image_mask: Optional[bool]
|
| 140 |
+
input_size_patches: Optional[int]
|
| 141 |
+
total_mask_patches: Optional[int]
|
| 142 |
+
mask_group_min_patches: Optional[int]
|
| 143 |
+
mask_group_max_patches: Optional[int]
|
| 144 |
+
mask_group_min_aspect_ratio: Optional[float]
|
| 145 |
+
mask_group_max_aspect_ratio: Optional[float]
|
| 146 |
+
# Codebook related params
|
| 147 |
+
return_codebook_pixels: Optional[bool]
|
| 148 |
+
codebook_do_resize: Optional[bool]
|
| 149 |
+
codebook_size: Optional[bool]
|
| 150 |
+
codebook_resample: Optional[int]
|
| 151 |
+
codebook_do_center_crop: Optional[bool]
|
| 152 |
+
codebook_crop_size: Optional[int]
|
| 153 |
+
codebook_do_rescale: Optional[bool]
|
| 154 |
+
codebook_rescale_factor: Optional[Union[int, float]]
|
| 155 |
+
codebook_do_map_pixels: Optional[bool]
|
| 156 |
+
codebook_do_normalize: Optional[bool]
|
| 157 |
+
codebook_image_mean: Optional[Union[float, Iterable[float]]]
|
| 158 |
+
codebook_image_std: Optional[Union[float, Iterable[float]]]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@add_start_docstrings(
|
| 162 |
+
"Constructs a fast Flava image processor.",
|
| 163 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
| 164 |
+
"""
|
| 165 |
+
return_image_mask (`bool`, *optional*, defaults to `False`):
|
| 166 |
+
Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
|
| 167 |
+
input_size_patches (`int`, *optional*, defaults to 14):
|
| 168 |
+
Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
|
| 169 |
+
by the `input_size_patches` parameter in `preprocess`.
|
| 170 |
+
total_mask_patches (`int`, *optional*, defaults to 75):
|
| 171 |
+
Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
|
| 172 |
+
`preprocess`.
|
| 173 |
+
mask_group_min_patches (`int`, *optional*, defaults to 16):
|
| 174 |
+
Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
|
| 175 |
+
parameter in `preprocess`.
|
| 176 |
+
mask_group_max_patches (`int`, *optional*):
|
| 177 |
+
Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
|
| 178 |
+
parameter in `preprocess`.
|
| 179 |
+
mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
|
| 180 |
+
Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
|
| 181 |
+
in `preprocess`.
|
| 182 |
+
mask_group_max_aspect_ratio (`float`, *optional*):
|
| 183 |
+
Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
|
| 184 |
+
in `preprocess`.
|
| 185 |
+
codebook_do_resize (`bool`, *optional*, defaults to `True`):
|
| 186 |
+
Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the `codebook_do_resize`
|
| 187 |
+
parameter in `preprocess`.
|
| 188 |
+
codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
|
| 189 |
+
Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
|
| 190 |
+
`preprocess`.
|
| 191 |
+
codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 192 |
+
Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
|
| 193 |
+
parameter in `preprocess`.
|
| 194 |
+
codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 195 |
+
Whether to crop the input for codebook at the center. If the input size is smaller than
|
| 196 |
+
`codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
|
| 197 |
+
overridden by the `codebook_do_center_crop` parameter in `preprocess`.
|
| 198 |
+
codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
|
| 199 |
+
Desired output size for codebook input when applying center-cropping. Can be overridden by the
|
| 200 |
+
`codebook_crop_size` parameter in `preprocess`.
|
| 201 |
+
codebook_do_rescale (`bool`, *optional*, defaults to `True`):
|
| 202 |
+
Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
|
| 203 |
+
overridden by the `codebook_do_rescale` parameter in `preprocess`.
|
| 204 |
+
codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 205 |
+
Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
|
| 206 |
+
`codebook_rescale_factor` parameter in `preprocess`.
|
| 207 |
+
codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
|
| 208 |
+
Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
|
| 209 |
+
`codebook_do_map_pixels` parameter in `preprocess`.
|
| 210 |
+
codebook_do_normalize (`bool`, *optional*, defaults to `True`):
|
| 211 |
+
Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
|
| 212 |
+
be overridden by the `codebook_do_normalize` parameter in `preprocess`.
|
| 213 |
+
codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
|
| 214 |
+
The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
|
| 215 |
+
by the `codebook_image_mean` parameter in `preprocess`.
|
| 216 |
+
codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 217 |
+
The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
|
| 218 |
+
be overridden by the `codebook_image_std` parameter in `preprocess`.
|
| 219 |
+
""",
|
| 220 |
+
)
|
| 221 |
+
class FlavaImageProcessorFast(BaseImageProcessorFast):
|
| 222 |
+
resample = PILImageResampling.BICUBIC
|
| 223 |
+
image_mean = FLAVA_IMAGE_MEAN
|
| 224 |
+
image_std = FLAVA_IMAGE_STD
|
| 225 |
+
size = {"height": 224, "width": 224}
|
| 226 |
+
crop_size = {"height": 224, "width": 224}
|
| 227 |
+
do_resize = True
|
| 228 |
+
do_center_crop = True
|
| 229 |
+
do_rescale = True
|
| 230 |
+
do_normalize = True
|
| 231 |
+
|
| 232 |
+
# Mask related params
|
| 233 |
+
return_image_mask = False
|
| 234 |
+
input_size_patches = 14
|
| 235 |
+
total_mask_patches = 75
|
| 236 |
+
mask_group_min_patches = 16
|
| 237 |
+
mask_group_max_patches = None
|
| 238 |
+
mask_group_min_aspect_ratio = 0.3
|
| 239 |
+
mask_group_max_aspect_ratio = None
|
| 240 |
+
# Codebook related params
|
| 241 |
+
return_codebook_pixels = False
|
| 242 |
+
codebook_do_resize = True
|
| 243 |
+
codebook_size = {"height": 112, "width": 112}
|
| 244 |
+
# LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative
|
| 245 |
+
codebook_resample = PILImageResampling.BICUBIC
|
| 246 |
+
codebook_do_center_crop = True
|
| 247 |
+
codebook_crop_size = {"height": 112, "width": 112}
|
| 248 |
+
codebook_do_rescale = True
|
| 249 |
+
codebook_rescale_factor = 1 / 255
|
| 250 |
+
codebook_do_map_pixels = True
|
| 251 |
+
codebook_do_normalize = True
|
| 252 |
+
codebook_image_mean = FLAVA_CODEBOOK_MEAN
|
| 253 |
+
codebook_image_std = FLAVA_CODEBOOK_STD
|
| 254 |
+
valid_kwargs = FlavaFastImageProcessorKwargs
|
| 255 |
+
|
| 256 |
+
def __init__(self, **kwargs: Unpack[FlavaFastImageProcessorKwargs]):
|
| 257 |
+
super().__init__(**kwargs)
|
| 258 |
+
|
| 259 |
+
@add_start_docstrings(
|
| 260 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
| 261 |
+
"""
|
| 262 |
+
return_image_mask (`bool`, *optional*, defaults to `False`):
|
| 263 |
+
Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
|
| 264 |
+
input_size_patches (`int`, *optional*, defaults to 14):
|
| 265 |
+
Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
|
| 266 |
+
by the `input_size_patches` parameter in `preprocess`.
|
| 267 |
+
total_mask_patches (`int`, *optional*, defaults to 75):
|
| 268 |
+
Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
|
| 269 |
+
`preprocess`.
|
| 270 |
+
mask_group_min_patches (`int`, *optional*, defaults to 16):
|
| 271 |
+
Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
|
| 272 |
+
parameter in `preprocess`.
|
| 273 |
+
mask_group_max_patches (`int`, *optional*):
|
| 274 |
+
Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
|
| 275 |
+
parameter in `preprocess`.
|
| 276 |
+
mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
|
| 277 |
+
Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
|
| 278 |
+
in `preprocess`.
|
| 279 |
+
mask_group_max_aspect_ratio (`float`, *optional*):
|
| 280 |
+
Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
|
| 281 |
+
in `preprocess`.
|
| 282 |
+
codebook_do_resize (`bool`, *optional*, defaults to `True`):
|
| 283 |
+
Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the `codebook_do_resize`
|
| 284 |
+
parameter in `preprocess`.
|
| 285 |
+
codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
|
| 286 |
+
Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
|
| 287 |
+
`preprocess`.
|
| 288 |
+
codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 289 |
+
Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
|
| 290 |
+
parameter in `preprocess`.
|
| 291 |
+
codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 292 |
+
Whether to crop the input for codebook at the center. If the input size is smaller than
|
| 293 |
+
`codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
|
| 294 |
+
overridden by the `codebook_do_center_crop` parameter in `preprocess`.
|
| 295 |
+
codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
|
| 296 |
+
Desired output size for codebook input when applying center-cropping. Can be overridden by the
|
| 297 |
+
`codebook_crop_size` parameter in `preprocess`.
|
| 298 |
+
codebook_do_rescale (`bool`, *optional*, defaults to `True`):
|
| 299 |
+
Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
|
| 300 |
+
overridden by the `codebook_do_rescale` parameter in `preprocess`.
|
| 301 |
+
codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 302 |
+
Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
|
| 303 |
+
`codebook_rescale_factor` parameter in `preprocess`.
|
| 304 |
+
codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
|
| 305 |
+
Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
|
| 306 |
+
`codebook_do_map_pixels` parameter in `preprocess`.
|
| 307 |
+
codebook_do_normalize (`bool`, *optional*, defaults to `True`):
|
| 308 |
+
Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
|
| 309 |
+
be overridden by the `codebook_do_normalize` parameter in `preprocess`.
|
| 310 |
+
codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
|
| 311 |
+
The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
|
| 312 |
+
by the `codebook_image_mean` parameter in `preprocess`.
|
| 313 |
+
codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 314 |
+
The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
|
| 315 |
+
be overridden by the `codebook_image_std` parameter in `preprocess`.
|
| 316 |
+
""",
|
| 317 |
+
)
|
| 318 |
+
def preprocess(self, images: ImageInput, **kwargs: Unpack[FlavaFastImageProcessorKwargs]) -> BatchFeature:
|
| 319 |
+
return super().preprocess(images, **kwargs)
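As with the slow processor, a hedged sketch of calling the fast class; torch and torchvision must be installed, the batched tensor is a stand-in, and the keyword names are the ones declared in `FlavaFastImageProcessorKwargs` above.

```python
import torch

pixel_batch = torch.randint(0, 256, (2, 3, 480, 640), dtype=torch.uint8)  # stand-in batch

fast_processor = FlavaImageProcessorFast()
outputs = fast_processor(
    pixel_batch,
    return_image_mask=True,
    return_codebook_pixels=True,
    return_tensors="pt",
)
print(outputs["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
```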
|
| 320 |
+
|
| 321 |
+
@classmethod
|
| 322 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 323 |
+
"""
|
| 324 |
+
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
|
| 325 |
+
created using from_dict and kwargs e.g. `FlavaImageProcessorFast.from_pretrained(checkpoint, codebook_size=600)`
|
| 326 |
+
"""
|
| 327 |
+
image_processor_dict = image_processor_dict.copy()
|
| 328 |
+
if "codebook_size" in kwargs:
|
| 329 |
+
image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
|
| 330 |
+
if "codebook_crop_size" in kwargs:
|
| 331 |
+
image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
|
| 332 |
+
return super().from_dict(image_processor_dict, **kwargs)
|
| 333 |
+
|
| 334 |
+
@lru_cache()
|
| 335 |
+
def masking_generator(
|
| 336 |
+
self,
|
| 337 |
+
input_size_patches,
|
| 338 |
+
total_mask_patches,
|
| 339 |
+
mask_group_min_patches,
|
| 340 |
+
mask_group_max_patches,
|
| 341 |
+
mask_group_min_aspect_ratio,
|
| 342 |
+
mask_group_max_aspect_ratio,
|
| 343 |
+
) -> FlavaMaskingGenerator:
|
| 344 |
+
return FlavaMaskingGenerator(
|
| 345 |
+
input_size=input_size_patches,
|
| 346 |
+
total_mask_patches=total_mask_patches,
|
| 347 |
+
mask_group_min_patches=mask_group_min_patches,
|
| 348 |
+
mask_group_max_patches=mask_group_max_patches,
|
| 349 |
+
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
|
| 350 |
+
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
|
| 351 |
+
)
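# Note on the lru_cache wrapper above: calling masking_generator twice with identical
# hyperparameters returns the same FlavaMaskingGenerator object instead of rebuilding it.
# A self-contained sketch of that caching behaviour (generic function standing in for the
# real generator, illustrative argument values only):
from functools import lru_cache

@lru_cache()
def make_generator(input_size_patches, total_mask_patches):
    return object()  # stands in for a FlavaMaskingGenerator

assert make_generator(14, 75) is make_generator(14, 75)  # cache hit: one shared instance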
|
| 352 |
+
|
| 353 |
+
def map_pixels(self, image: "torch.Tensor") -> "torch.Tensor":
|
| 354 |
+
return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
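# Minimal sketch of the logit-Laplace mapping implemented above, assuming the module-level
# LOGIT_LAPLACE_EPS of 0.1 (defined earlier in this file, not shown here): pixels in [0, 1]
# are squeezed into [eps, 1 - eps], and the mapping is exactly invertible.
import torch

eps = 0.1
image = torch.rand(2, 3, 112, 112)           # already rescaled to [0, 1]
mapped = (1 - 2 * eps) * image + eps         # same formula as map_pixels
recovered = (mapped - eps) / (1 - 2 * eps)   # inverse mapping
assert torch.allclose(recovered, image, atol=1e-6)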
|
| 355 |
+
|
| 356 |
+
def _further_process_kwargs(
|
| 357 |
+
self,
|
| 358 |
+
size: Optional[SizeDict] = None,
|
| 359 |
+
crop_size: Optional[SizeDict] = None,
|
| 360 |
+
default_to_square: Optional[bool] = None,
|
| 361 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 362 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 363 |
+
codebook_size: Optional[SizeDict] = None,
|
| 364 |
+
codebook_crop_size: Optional[SizeDict] = None,
|
| 365 |
+
codebook_image_mean: Optional[Union[float, list[float]]] = None,
|
| 366 |
+
codebook_image_std: Optional[Union[float, list[float]]] = None,
|
| 367 |
+
codebook_resample: Optional[PILImageResampling] = None,
|
| 368 |
+
data_format: Optional[ChannelDimension] = None,
|
| 369 |
+
**kwargs,
|
| 370 |
+
) -> dict:
|
| 371 |
+
"""
|
| 372 |
+
Update kwargs that need further processing before being validated.
|
| 373 |
+
Can be overridden by subclasses to customize the processing of kwargs.
|
| 374 |
+
"""
|
| 375 |
+
if kwargs is None:
|
| 376 |
+
kwargs = {}
|
| 377 |
+
if size is not None:
|
| 378 |
+
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
|
| 379 |
+
if crop_size is not None:
|
| 380 |
+
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
|
| 381 |
+
if isinstance(image_mean, list):
|
| 382 |
+
image_mean = tuple(image_mean)
|
| 383 |
+
if isinstance(image_std, list):
|
| 384 |
+
image_std = tuple(image_std)
|
| 385 |
+
if data_format is None:
|
| 386 |
+
data_format = ChannelDimension.FIRST
|
| 387 |
+
if codebook_size is not None:
|
| 388 |
+
codebook_size = SizeDict(**get_size_dict(size=codebook_size, default_to_square=default_to_square))
|
| 389 |
+
if codebook_crop_size is not None:
|
| 390 |
+
codebook_crop_size = SizeDict(**get_size_dict(codebook_crop_size, param_name="codebook_crop_size"))
|
| 391 |
+
if isinstance(codebook_image_mean, list):
|
| 392 |
+
codebook_image_mean = tuple(codebook_image_mean)
|
| 393 |
+
if isinstance(codebook_image_std, list):
|
| 394 |
+
codebook_image_std = tuple(codebook_image_std)
|
| 395 |
+
|
| 396 |
+
kwargs["size"] = size
|
| 397 |
+
kwargs["crop_size"] = crop_size
|
| 398 |
+
kwargs["default_to_square"] = default_to_square
|
| 399 |
+
kwargs["image_mean"] = image_mean
|
| 400 |
+
kwargs["image_std"] = image_std
|
| 401 |
+
kwargs["codebook_size"] = codebook_size
|
| 402 |
+
kwargs["codebook_crop_size"] = codebook_crop_size
|
| 403 |
+
kwargs["codebook_image_mean"] = codebook_image_mean
|
| 404 |
+
kwargs["codebook_image_std"] = codebook_image_std
|
| 405 |
+
kwargs["data_format"] = data_format
|
| 406 |
+
kwargs["codebook_interpolation"] = (
|
| 407 |
+
pil_torch_interpolation_mapping[codebook_resample]
|
| 408 |
+
if isinstance(codebook_resample, (PILImageResampling, int))
|
| 409 |
+
else codebook_resample
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
return kwargs
|
| 413 |
+
|
| 414 |
+
def _preprocess_image(
|
| 415 |
+
self,
|
| 416 |
+
images: list["torch.Tensor"],
|
| 417 |
+
do_resize: bool,
|
| 418 |
+
size: SizeDict,
|
| 419 |
+
interpolation: Optional["F.InterpolationMode"],
|
| 420 |
+
do_center_crop: bool,
|
| 421 |
+
crop_size: SizeDict,
|
| 422 |
+
do_rescale: bool,
|
| 423 |
+
rescale_factor: float,
|
| 424 |
+
do_normalize: bool,
|
| 425 |
+
do_map_pixels: bool,
|
| 426 |
+
image_mean: Optional[Union[float, list[float]]],
|
| 427 |
+
image_std: Optional[Union[float, list[float]]],
|
| 428 |
+
return_tensors: Optional[Union[str, TensorType]],
|
| 429 |
+
) -> "torch.Tensor":
|
| 430 |
+
# Group images by size for batched resizing
|
| 431 |
+
grouped_images, grouped_images_index = group_images_by_shape(images)
|
| 432 |
+
resized_images_grouped = {}
|
| 433 |
+
for shape, stacked_images in grouped_images.items():
|
| 434 |
+
if do_resize:
|
| 435 |
+
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
| 436 |
+
resized_images_grouped[shape] = stacked_images
|
| 437 |
+
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
| 438 |
+
|
| 439 |
+
# Group images by size for further processing
|
| 440 |
+
# Needed in case do_resize is False, or resize returns images with different sizes
|
| 441 |
+
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
| 442 |
+
processed_images_grouped = {}
|
| 443 |
+
for shape, stacked_images in grouped_images.items():
|
| 444 |
+
if do_center_crop:
|
| 445 |
+
stacked_images = self.center_crop(stacked_images, crop_size)
|
| 446 |
+
# Fused rescale and normalize
|
| 447 |
+
stacked_images = self.rescale_and_normalize(
|
| 448 |
+
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
| 449 |
+
)
|
| 450 |
+
if do_map_pixels:
|
| 451 |
+
stacked_images = self.map_pixels(image=stacked_images)
|
| 452 |
+
processed_images_grouped[shape] = stacked_images
|
| 453 |
+
|
| 454 |
+
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
| 455 |
+
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
| 456 |
+
|
| 457 |
+
return processed_images
|
| 458 |
+
|
| 459 |
+
def _preprocess(
|
| 460 |
+
self,
|
| 461 |
+
images: list["torch.Tensor"],
|
| 462 |
+
do_resize: bool,
|
| 463 |
+
size: SizeDict,
|
| 464 |
+
interpolation: Optional["F.InterpolationMode"],
|
| 465 |
+
do_center_crop: bool,
|
| 466 |
+
crop_size: SizeDict,
|
| 467 |
+
do_rescale: bool,
|
| 468 |
+
rescale_factor: float,
|
| 469 |
+
do_normalize: bool,
|
| 470 |
+
image_mean: Optional[Union[float, list[float]]],
|
| 471 |
+
image_std: Optional[Union[float, list[float]]],
|
| 472 |
+
# Mask related params
|
| 473 |
+
return_image_mask: Optional[bool],
|
| 474 |
+
input_size_patches: Optional[int],
|
| 475 |
+
total_mask_patches: Optional[int],
|
| 476 |
+
mask_group_min_patches: Optional[int],
|
| 477 |
+
mask_group_max_patches: Optional[int],
|
| 478 |
+
mask_group_min_aspect_ratio: Optional[float],
|
| 479 |
+
mask_group_max_aspect_ratio: Optional[float],
|
| 480 |
+
# Codebook related params
|
| 481 |
+
return_codebook_pixels: Optional[bool],
|
| 482 |
+
codebook_do_resize: Optional[bool],
|
| 483 |
+
codebook_size: Optional[SizeDict],
|
| 484 |
+
codebook_interpolation: Optional["F.InterpolationMode"],
|
| 485 |
+
codebook_do_center_crop: Optional[bool],
|
| 486 |
+
codebook_crop_size: Optional[SizeDict],
|
| 487 |
+
codebook_do_rescale: Optional[bool],
|
| 488 |
+
codebook_rescale_factor: Optional[float],
|
| 489 |
+
codebook_do_map_pixels: Optional[bool],
|
| 490 |
+
codebook_do_normalize: Optional[bool],
|
| 491 |
+
codebook_image_mean: Optional[Union[float, list[float]]],
|
| 492 |
+
codebook_image_std: Optional[Union[float, list[float]]],
|
| 493 |
+
return_tensors: Optional[Union[str, TensorType]],
|
| 494 |
+
**kwargs,
|
| 495 |
+
) -> BatchFeature:
|
| 496 |
+
processed_images = self._preprocess_image(
|
| 497 |
+
images=images,
|
| 498 |
+
do_resize=do_resize,
|
| 499 |
+
size=size,
|
| 500 |
+
interpolation=interpolation,
|
| 501 |
+
do_center_crop=do_center_crop,
|
| 502 |
+
crop_size=crop_size,
|
| 503 |
+
do_rescale=do_rescale,
|
| 504 |
+
rescale_factor=rescale_factor,
|
| 505 |
+
do_normalize=do_normalize,
|
| 506 |
+
do_map_pixels=False,
|
| 507 |
+
image_mean=image_mean,
|
| 508 |
+
image_std=image_std,
|
| 509 |
+
return_tensors=return_tensors,
|
| 510 |
+
)
|
| 511 |
+
data = {
|
| 512 |
+
"pixel_values": processed_images,
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
if return_codebook_pixels:
|
| 516 |
+
codebook_processed_images = self._preprocess_image(
|
| 517 |
+
images=images,
|
| 518 |
+
do_resize=codebook_do_resize,
|
| 519 |
+
size=codebook_size,
|
| 520 |
+
interpolation=codebook_interpolation,
|
| 521 |
+
do_center_crop=codebook_do_center_crop,
|
| 522 |
+
crop_size=codebook_crop_size,
|
| 523 |
+
do_rescale=codebook_do_rescale,
|
| 524 |
+
rescale_factor=codebook_rescale_factor,
|
| 525 |
+
do_normalize=codebook_do_normalize,
|
| 526 |
+
do_map_pixels=codebook_do_map_pixels,
|
| 527 |
+
image_mean=codebook_image_mean,
|
| 528 |
+
image_std=codebook_image_std,
|
| 529 |
+
return_tensors=return_tensors,
|
| 530 |
+
)
|
| 531 |
+
data["codebook_pixel_values"] = codebook_processed_images
|
| 532 |
+
|
| 533 |
+
if return_image_mask:
|
| 534 |
+
mask_generator = self.masking_generator(
|
| 535 |
+
input_size_patches=input_size_patches,
|
| 536 |
+
total_mask_patches=total_mask_patches,
|
| 537 |
+
mask_group_min_patches=mask_group_min_patches,
|
| 538 |
+
mask_group_max_patches=mask_group_max_patches,
|
| 539 |
+
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
|
| 540 |
+
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
|
| 541 |
+
)
|
| 542 |
+
masks = [mask_generator() for _ in range(len(images))]
|
| 543 |
+
masks = torch.stack(masks, dim=0) if return_tensors else masks
|
| 544 |
+
data["bool_masked_pos"] = masks
|
| 545 |
+
|
| 546 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
__all__ = ["FlavaImageProcessorFast"]
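# End-to-end sketch of the processor defined in this file. The checkpoint name and image URL
# are the standard placeholders used in the transformers docs; the optional return_* keywords
# are assumed to be forwarded through preprocess as they are for the slow FlavaImageProcessor.
import requests
from PIL import Image
from transformers import FlavaImageProcessorFast

processor = FlavaImageProcessorFast.from_pretrained("facebook/flava-full")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    images=image,
    return_image_mask=True,       # adds "bool_masked_pos"
    return_codebook_pixels=True,  # adds "codebook_pixel_values"
    return_tensors="pt",
)
print(sorted(inputs.keys()))      # ['bool_masked_pos', 'codebook_pixel_values', 'pixel_values']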
|
docs/transformers/build/lib/transformers/models/flava/modeling_flava.py
ADDED
|
@@ -0,0 +1,2127 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch FLAVA model."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import math
|
| 19 |
+
from collections import OrderedDict
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
from ...activations import ACT2FN
|
| 28 |
+
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
| 29 |
+
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
| 30 |
+
from ...utils import (
|
| 31 |
+
ModelOutput,
|
| 32 |
+
add_code_sample_docstrings,
|
| 33 |
+
add_start_docstrings,
|
| 34 |
+
add_start_docstrings_to_model_forward,
|
| 35 |
+
logging,
|
| 36 |
+
replace_return_docstrings,
|
| 37 |
+
torch_int,
|
| 38 |
+
)
|
| 39 |
+
from .configuration_flava import (
|
| 40 |
+
FlavaConfig,
|
| 41 |
+
FlavaImageCodebookConfig,
|
| 42 |
+
FlavaImageConfig,
|
| 43 |
+
FlavaMultimodalConfig,
|
| 44 |
+
FlavaTextConfig,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
_CHECKPOINT_FOR_DOC = "facebook/flava-full"
|
| 51 |
+
|
| 52 |
+
# Codebook docstring
|
| 53 |
+
_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
|
| 54 |
+
_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig"
|
| 55 |
+
_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig"
|
| 56 |
+
_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig"
|
| 57 |
+
_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
LOGIT_SCALE_CLAMP_MIN = 0
|
| 61 |
+
LOGIT_SCALE_CLAMP_MAX = 4.6052
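# The two constants above presumably bound the learned contrastive logit scale, so the
# effective similarity multiplier exp(logit_scale) stays between exp(0) = 1 and
# exp(4.6052), roughly 100 (the same ln(100) upper bound used by CLIP).
import math
print(math.exp(4.6052))  # ~100.0, i.e. exp(LOGIT_SCALE_CLAMP_MAX)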
|
| 62 |
+
|
| 63 |
+
FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class FlavaModelOutput(ModelOutput):
|
| 68 |
+
"""
|
| 69 |
+
Output from FlavaModel containing embeddings and outputs from individual encoders.
|
| 70 |
+
|
| 71 |
+
Note that `image_embeddings` and `text_embeddings` returned are similar to the pooled output returned from a
|
| 72 |
+
transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
|
| 73 |
+
`text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
|
| 77 |
+
The image embeddings which are basically the pooled output of [`FlavaImageModel`].
|
| 78 |
+
image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
|
| 79 |
+
The output of the [`FlavaImageModel`].
|
| 80 |
+
text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
|
| 81 |
+
The text embeddings which are basically the pooled output of [`FlavaTextModel`].
|
| 82 |
+
text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
|
| 83 |
+
The output of the [`FlavaTextModel`].
|
| 84 |
+
multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
|
| 85 |
+
The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
|
| 86 |
+
multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
|
| 87 |
+
The output of the [`FlavaMultimodalModel`].
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
image_embeddings: Optional[torch.FloatTensor] = None
|
| 91 |
+
image_output: Optional[BaseModelOutputWithPooling] = None
|
| 92 |
+
text_embeddings: Optional[torch.FloatTensor] = None
|
| 93 |
+
text_output: Optional[BaseModelOutputWithPooling] = None
|
| 94 |
+
multimodal_embeddings: Optional[torch.FloatTensor] = None
|
| 95 |
+
multimodal_output: Optional[BaseModelOutputWithPooling] = None
|
| 96 |
+
|
| 97 |
+
def to_tuple(self) -> Tuple[Any]:
|
| 98 |
+
return tuple(
|
| 99 |
+
self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
|
| 100 |
+
for k in self.keys()
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dataclass
|
| 105 |
+
class FlavaLosses(ModelOutput):
|
| 106 |
+
"""Class representing pretraining losses from FLAVA model
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.:
|
| 110 |
+
Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
|
| 111 |
+
mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.:
|
| 112 |
+
Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
|
| 113 |
+
itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.:
|
| 114 |
+
Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
|
| 115 |
+
masked pairs in FLAVA.
|
| 116 |
+
global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.:
|
| 117 |
+
Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
|
| 118 |
+
data. This is calculated on unmasked images and texts.
|
| 119 |
+
mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.:
|
| 120 |
+
Masked Multimodal Modeling loss's image component calculated on paired image-text data.
|
| 121 |
+
mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.:
|
| 122 |
+
Masked Multimodal Modeling loss's text component calculated on paired image-text data.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
mim: Optional[torch.FloatTensor] = None
|
| 126 |
+
mlm: Optional[torch.FloatTensor] = None
|
| 127 |
+
itm: Optional[torch.FloatTensor] = None
|
| 128 |
+
global_contrastive: Optional[torch.FloatTensor] = None
|
| 129 |
+
mmm_image: Optional[torch.FloatTensor] = None
|
| 130 |
+
mmm_text: Optional[torch.FloatTensor] = None
|
| 131 |
+
|
| 132 |
+
def all_none(self) -> bool:
|
| 133 |
+
all_none = True
|
| 134 |
+
for v in self.values():
|
| 135 |
+
if v is not None:
|
| 136 |
+
all_none = False
|
| 137 |
+
break
|
| 138 |
+
return all_none
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclass
|
| 142 |
+
class FlavaForPreTrainingOutput(ModelOutput):
|
| 143 |
+
"""
|
| 144 |
+
Output from FlavaForPreTraining containing embeddings and outputs from individual encoders.
|
| 145 |
+
|
| 146 |
+
Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
|
| 147 |
+
transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
|
| 148 |
+
`text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
|
| 152 |
+
Total loss calculated for this model.
|
| 153 |
+
loss_info (`FlavaLosses`):
|
| 154 |
+
Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
|
| 155 |
+
the keys.
|
| 156 |
+
image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
|
| 157 |
+
The image embeddings which are basically the pooled output of [`FlavaImageModel`].
|
| 158 |
+
image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
|
| 159 |
+
The output of the [`FlavaImageModel`].
|
| 160 |
+
text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
|
| 161 |
+
The text embeddings which are basically the pooled output of [`FlavaTextModel`].
|
| 162 |
+
text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
|
| 163 |
+
The output of the [`FlavaTextModel`].
|
| 164 |
+
multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
|
| 165 |
+
The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
|
| 166 |
+
multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
|
| 167 |
+
The output of the [`FlavaMultimodalModel`].
|
| 168 |
+
|
| 169 |
+
image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
|
| 170 |
+
The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
|
| 171 |
+
to create masked images.
|
| 172 |
+
image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
|
| 173 |
+
The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
|
| 174 |
+
text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
|
| 175 |
+
The text embeddings which are basically the pooled output of [`FlavaTextModel`].
|
| 176 |
+
text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
|
| 177 |
+
The output of the [`FlavaTextModel`].
|
| 178 |
+
multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present):
|
| 179 |
+
The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
|
| 180 |
+
multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
|
| 181 |
+
The output of the [`FlavaMultimodalModel`].
|
| 182 |
+
|
| 183 |
+
mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
|
| 184 |
+
The logits for MIM unimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened output is
|
| 185 |
+
returned when `bool_masked_pos` has some of the patches masked.
|
| 186 |
+
mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
|
| 187 |
+
The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
|
| 188 |
+
the tokens masked.
|
| 189 |
+
itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
|
| 190 |
+
The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
|
| 191 |
+
mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape`(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
|
| 192 |
+
The logits for MMM image multimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened
|
| 193 |
+
output is returned when `bool_masked_pos` has some of the patches masked.
|
| 194 |
+
mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
|
| 195 |
+
The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
|
| 196 |
+
some of the tokens masked.
|
| 197 |
+
contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
| 198 |
+
The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
|
| 199 |
+
`image_projection` and `text_projection` layers respectively. This represents the image-text similarity
|
| 200 |
+
scores. This is calculated on unmasked images and texts.
|
| 201 |
+
contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
| 202 |
+
The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
|
| 203 |
+
`text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
|
| 204 |
+
texts.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
loss: Optional[torch.FloatTensor] = None
|
| 208 |
+
loss_info: FlavaLosses = None
|
| 209 |
+
image_embeddings: Optional[torch.FloatTensor] = None
|
| 210 |
+
image_output: Optional[BaseModelOutputWithPooling] = None
|
| 211 |
+
text_embeddings: Optional[torch.FloatTensor] = None
|
| 212 |
+
text_output: Optional[BaseModelOutputWithPooling] = None
|
| 213 |
+
multimodal_embeddings: Optional[torch.FloatTensor] = None
|
| 214 |
+
multimodal_output: Optional[BaseModelOutputWithPooling] = None
|
| 215 |
+
image_masked_embeddings: Optional[torch.FloatTensor] = None
|
| 216 |
+
image_masked_output: Optional[BaseModelOutputWithPooling] = None
|
| 217 |
+
text_masked_embeddings: Optional[torch.FloatTensor] = None
|
| 218 |
+
text_masked_output: Optional[BaseModelOutputWithPooling] = None
|
| 219 |
+
multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
|
| 220 |
+
multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
|
| 221 |
+
mim_logits: Optional[torch.FloatTensor] = None
|
| 222 |
+
mlm_logits: Optional[torch.FloatTensor] = None
|
| 223 |
+
itm_logits: Optional[torch.FloatTensor] = None
|
| 224 |
+
contrastive_logits_per_image: Optional[torch.FloatTensor] = None
|
| 225 |
+
contrastive_logits_per_text: Optional[torch.FloatTensor] = None
|
| 226 |
+
mmm_image_logits: Optional[torch.FloatTensor] = None
|
| 227 |
+
mmm_text_logits: Optional[torch.FloatTensor] = None
|
| 228 |
+
|
| 229 |
+
def to_tuple(self) -> Tuple[Any]:
|
| 230 |
+
transformer_outputs = [
|
| 231 |
+
"text_output",
|
| 232 |
+
"image_output",
|
| 233 |
+
"multimodal_output",
|
| 234 |
+
"text_masked_output",
|
| 235 |
+
"image_masked_output",
|
| 236 |
+
"multimodal_masked_output",
|
| 237 |
+
]
|
| 238 |
+
return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Based on timm implementation, which can be found here:
|
| 242 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
|
| 243 |
+
class FlavaImageEmbeddings(nn.Module):
|
| 244 |
+
"""
|
| 245 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
|
| 251 |
+
use_mask_token = use_mask_token or config.mask_token
|
| 252 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 253 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
| 254 |
+
self.patch_embeddings = PatchEmbeddings(
|
| 255 |
+
image_size=config.image_size,
|
| 256 |
+
patch_size=config.patch_size,
|
| 257 |
+
num_channels=config.num_channels,
|
| 258 |
+
embed_dim=config.hidden_size,
|
| 259 |
+
)
|
| 260 |
+
num_patches = self.patch_embeddings.num_patches
|
| 261 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
| 262 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 263 |
+
self.patch_size = config.patch_size
|
| 264 |
+
self.config = config
|
| 265 |
+
|
| 266 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
| 267 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 268 |
+
"""
|
| 269 |
+
This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
|
| 270 |
+
images. This method is also adapted to support torch.jit tracing.
|
| 271 |
+
|
| 272 |
+
Adapted from:
|
| 273 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 274 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 275 |
+
"""
|
| 276 |
+
|
| 277 |
+
num_patches = embeddings.shape[1] - 1
|
| 278 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 279 |
+
|
| 280 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 281 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 282 |
+
return self.position_embeddings
|
| 283 |
+
|
| 284 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 285 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 286 |
+
|
| 287 |
+
dim = embeddings.shape[-1]
|
| 288 |
+
|
| 289 |
+
new_height = height // self.patch_size
|
| 290 |
+
new_width = width // self.patch_size
|
| 291 |
+
|
| 292 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 293 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 294 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 295 |
+
|
| 296 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 297 |
+
patch_pos_embed,
|
| 298 |
+
size=(new_height, new_width),
|
| 299 |
+
mode="bicubic",
|
| 300 |
+
align_corners=False,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 304 |
+
|
| 305 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
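# Standalone sketch of the resizing performed above: a 14x14 grid of 768-d patch position
# embeddings (the 224/16 case) interpolated bicubically to a 16x16 grid; the CLS position
# embedding is handled separately, exactly as in interpolate_pos_encoding.
import torch
from torch import nn

dim = 768
patch_pos = torch.randn(1, 14 * 14, dim)
grid = patch_pos.reshape(1, 14, 14, dim).permute(0, 3, 1, 2)                              # (1, dim, 14, 14)
resized = nn.functional.interpolate(grid, size=(16, 16), mode="bicubic", align_corners=False)
resized = resized.permute(0, 2, 3, 1).view(1, -1, dim)                                    # (1, 256, dim)
print(resized.shape)  # torch.Size([1, 256, 768])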
|
| 306 |
+
|
| 307 |
+
def forward(
|
| 308 |
+
self,
|
| 309 |
+
pixel_values: torch.Tensor,
|
| 310 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 311 |
+
interpolate_pos_encoding: bool = False,
|
| 312 |
+
) -> torch.Tensor:
|
| 313 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 314 |
+
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 315 |
+
|
| 316 |
+
batch_size, seq_len, _ = embeddings.size()
|
| 317 |
+
if bool_masked_pos is not None:
|
| 318 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
| 319 |
+
# flatten bool_masked_pos from (B, H, W) to (B, H*W)
|
| 320 |
+
if bool_masked_pos.dim() == 3:
|
| 321 |
+
bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
|
| 322 |
+
# replace the masked visual tokens by mask_tokens
|
| 323 |
+
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
| 324 |
+
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
| 325 |
+
|
| 326 |
+
# add the [CLS] token to the embedded patch tokens
|
| 327 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 328 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 329 |
+
|
| 330 |
+
# add positional encoding to each token
|
| 331 |
+
if interpolate_pos_encoding:
|
| 332 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 333 |
+
else:
|
| 334 |
+
embeddings = embeddings + self.position_embeddings
|
| 335 |
+
|
| 336 |
+
embeddings = self.dropout(embeddings)
|
| 337 |
+
|
| 338 |
+
return embeddings
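# Small numeric illustration of the bool_masked_pos branch above: masked positions are
# replaced by the learned mask token, unmasked positions keep their patch embedding
# (random placeholder values and tiny shapes, purely for illustration).
import torch

batch, seq_len, dim = 2, 4, 8
embeddings = torch.randn(batch, seq_len, dim)
mask_token = torch.zeros(1, 1, dim)                       # stands in for self.mask_token
bool_masked_pos = torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0]], dtype=torch.bool)

mask_tokens = mask_token.expand(batch, seq_len, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
masked = embeddings * (1.0 - mask) + mask_tokens * mask
assert torch.equal(masked[0, 0], mask_token[0, 0])        # position (0, 0) was masked
assert torch.equal(masked[0, 1], embeddings[0, 1])        # position (0, 1) was kept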
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Based on timm implementation, which can be found here:
|
| 342 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
|
| 343 |
+
class PatchEmbeddings(nn.Module):
|
| 344 |
+
"""
|
| 345 |
+
Image to Patch Embedding.
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(
|
| 349 |
+
self,
|
| 350 |
+
image_size: int = 224,
|
| 351 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 352 |
+
num_channels: int = 3,
|
| 353 |
+
embed_dim: int = 768,
|
| 354 |
+
):
|
| 355 |
+
super().__init__()
|
| 356 |
+
if not isinstance(image_size, collections.abc.Iterable):
|
| 357 |
+
image_size = (image_size, image_size)
|
| 358 |
+
if not isinstance(patch_size, collections.abc.Iterable):
|
| 359 |
+
patch_size = (patch_size, patch_size)
|
| 360 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 361 |
+
self.image_size = image_size
|
| 362 |
+
self.patch_size = patch_size
|
| 363 |
+
self.num_patches = num_patches
|
| 364 |
+
|
| 365 |
+
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 366 |
+
|
| 367 |
+
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
| 368 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 369 |
+
if not interpolate_pos_encoding:
|
| 370 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 371 |
+
raise ValueError(
|
| 372 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
| 373 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
| 374 |
+
)
|
| 375 |
+
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
| 376 |
+
return x
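# Shape check for the projection above: with the defaults (224x224 input, 16x16 patches,
# 768-d embeddings) the Conv2d produces 14 * 14 = 196 patch tokens, which becomes the
# documented 197-token sequence once the CLS token is prepended in FlavaImageEmbeddings.
import torch
from torch import nn

projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
pixel_values = torch.randn(1, 3, 224, 224)
patches = projection(pixel_values).flatten(2).transpose(1, 2)
print(patches.shape)  # torch.Size([1, 196, 768])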
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class FlavaTextEmbeddings(nn.Module):
|
| 380 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 381 |
+
|
| 382 |
+
def __init__(self, config):
|
| 383 |
+
super().__init__()
|
| 384 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 385 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 386 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 387 |
+
|
| 388 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 389 |
+
# any TensorFlow checkpoint file
|
| 390 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 391 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 392 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 393 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 394 |
+
self.register_buffer(
|
| 395 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 396 |
+
)
|
| 397 |
+
self.register_buffer(
|
| 398 |
+
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
def forward(
|
| 402 |
+
self,
|
| 403 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 404 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 405 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 406 |
+
):
|
| 407 |
+
input_shape = input_ids.size()
|
| 408 |
+
seq_length = input_shape[1]
|
| 409 |
+
|
| 410 |
+
if position_ids is None:
|
| 411 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 412 |
+
|
| 413 |
+
# Setting token_type_ids to the registered buffer (all zeros) covers the common case where they are
|
| 414 |
+
# auto-generated; the registered buffer also lets users trace the model without passing token_type_ids and solves
|
| 415 |
+
# issue #5664.
|
| 416 |
+
if token_type_ids is None:
|
| 417 |
+
if hasattr(self, "token_type_ids"):
|
| 418 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
| 419 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
| 420 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 421 |
+
else:
|
| 422 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 423 |
+
|
| 424 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 425 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 426 |
+
|
| 427 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 428 |
+
if self.position_embedding_type == "absolute":
|
| 429 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 430 |
+
embeddings += position_embeddings
|
| 431 |
+
embeddings = self.LayerNorm(embeddings)
|
| 432 |
+
embeddings = self.dropout(embeddings)
|
| 433 |
+
return embeddings
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class FlavaSelfAttention(nn.Module):
|
| 437 |
+
def __init__(self, config: FlavaPossibleConfigs) -> None:
|
| 438 |
+
super().__init__()
|
| 439 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 440 |
+
raise ValueError(
|
| 441 |
+
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
| 442 |
+
f"heads {config.num_attention_heads}."
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
self.num_attention_heads = config.num_attention_heads
|
| 446 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 447 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 448 |
+
|
| 449 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 450 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 451 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 452 |
+
|
| 453 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 454 |
+
|
| 455 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 456 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 457 |
+
x = x.view(*new_x_shape)
|
| 458 |
+
return x.permute(0, 2, 1, 3)
|
| 459 |
+
|
| 460 |
+
def forward(
|
| 461 |
+
self,
|
| 462 |
+
hidden_states: torch.Tensor,
|
| 463 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 464 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 465 |
+
output_attentions: bool = False,
|
| 466 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 467 |
+
mixed_query_layer = self.query(hidden_states)
|
| 468 |
+
|
| 469 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 470 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 471 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 472 |
+
|
| 473 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 474 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 475 |
+
|
| 476 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 477 |
+
if attention_mask is not None:
|
| 478 |
+
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
|
| 479 |
+
attention_scores = attention_scores + attention_mask
|
| 480 |
+
|
| 481 |
+
# Normalize the attention scores to probabilities.
|
| 482 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 483 |
+
|
| 484 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 485 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 486 |
+
attention_probs = self.dropout(attention_probs)
|
| 487 |
+
|
| 488 |
+
# Mask heads if we want to
|
| 489 |
+
if head_mask is not None:
|
| 490 |
+
attention_probs = attention_probs * head_mask
|
| 491 |
+
|
| 492 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 493 |
+
|
| 494 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 495 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 496 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 497 |
+
|
| 498 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 499 |
+
|
| 500 |
+
return outputs
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class FlavaSelfOutput(nn.Module):
|
| 504 |
+
"""
|
| 505 |
+
The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
|
| 506 |
+
models), due to the layernorm applied before each block.
|
| 507 |
+
"""
|
| 508 |
+
|
| 509 |
+
def __init__(self, config: FlavaPossibleConfigs) -> None:
|
| 510 |
+
super().__init__()
|
| 511 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 512 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 513 |
+
|
| 514 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 515 |
+
hidden_states = self.dense(hidden_states)
|
| 516 |
+
hidden_states = self.dropout(hidden_states)
|
| 517 |
+
|
| 518 |
+
return hidden_states
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class FlavaAttention(nn.Module):
|
| 522 |
+
def __init__(self, config: FlavaPossibleConfigs) -> None:
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.attention = FlavaSelfAttention(config)
|
| 525 |
+
self.output = FlavaSelfOutput(config)
|
| 526 |
+
self.pruned_heads = set()
|
| 527 |
+
|
| 528 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
| 529 |
+
if len(heads) == 0:
|
| 530 |
+
return
|
| 531 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 532 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# Prune linear layers
|
| 536 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 537 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 538 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 539 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 540 |
+
|
| 541 |
+
# Update hyper params and store pruned heads
|
| 542 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
| 543 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
| 544 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 545 |
+
|
| 546 |
+
def forward(
|
| 547 |
+
self,
|
| 548 |
+
hidden_states: torch.Tensor,
|
| 549 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 550 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 551 |
+
output_attentions: bool = False,
|
| 552 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 553 |
+
self_outputs = self.attention(
|
| 554 |
+
hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 558 |
+
|
| 559 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 560 |
+
return outputs
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class FlavaIntermediate(nn.Module):
|
| 564 |
+
def __init__(self, config: FlavaPossibleConfigs) -> None:
|
| 565 |
+
super().__init__()
|
| 566 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 567 |
+
if isinstance(config.hidden_act, str):
|
| 568 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 569 |
+
else:
|
| 570 |
+
self.intermediate_act_fn = config.hidden_act
|
| 571 |
+
|
| 572 |
+
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate.forward
|
| 573 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 574 |
+
hidden_states = self.dense(hidden_states)
|
| 575 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 576 |
+
|
| 577 |
+
return hidden_states
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
class FlavaOutput(nn.Module):
|
| 581 |
+
def __init__(self, config: FlavaPossibleConfigs) -> None:
|
| 582 |
+
super().__init__()
|
| 583 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 584 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 585 |
+
|
| 586 |
+
# Copied from transformers.models.vit.modeling_vit.ViTOutput.forward
|
| 587 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 588 |
+
hidden_states = self.dense(hidden_states)
|
| 589 |
+
hidden_states = self.dropout(hidden_states)
|
| 590 |
+
|
| 591 |
+
hidden_states = hidden_states + input_tensor
|
| 592 |
+
|
| 593 |
+
return hidden_states
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
class FlavaLayer(nn.Module):
|
| 597 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 598 |
+
|
| 599 |
+
def __init__(self, config: FlavaPossibleConfigs) -> None:
|
| 600 |
+
super().__init__()
|
| 601 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 602 |
+
self.seq_len_dim = 1
|
| 603 |
+
self.attention = FlavaAttention(config)
|
| 604 |
+
self.intermediate = FlavaIntermediate(config)
|
| 605 |
+
self.output = FlavaOutput(config)
|
| 606 |
+
|
| 607 |
+
# TODO: Check fp32 layer norm possibility
|
| 608 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 609 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 610 |
+
|
| 611 |
+
def forward(
|
| 612 |
+
self,
|
| 613 |
+
hidden_states: torch.Tensor,
|
| 614 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 615 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 616 |
+
output_attentions: bool = False,
|
| 617 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 618 |
+
self_attention_outputs = self.attention(
|
| 619 |
+
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
|
| 620 |
+
attention_mask=attention_mask,
|
| 621 |
+
head_mask=head_mask,
|
| 622 |
+
output_attentions=output_attentions,
|
| 623 |
+
)
|
| 624 |
+
attention_output = self_attention_outputs[0]
|
| 625 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 626 |
+
|
| 627 |
+
# first residual connection
|
| 628 |
+
hidden_states = attention_output + hidden_states
|
| 629 |
+
|
| 630 |
+
# in ViT, layernorm is also applied after self-attention
|
| 631 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 632 |
+
layer_output = self.intermediate(layer_output)
|
| 633 |
+
|
| 634 |
+
# second residual connection is done here
|
| 635 |
+
layer_output = self.output(layer_output, hidden_states)
|
| 636 |
+
|
| 637 |
+
outputs = (layer_output,) + outputs
|
| 638 |
+
|
| 639 |
+
return outputs
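
# --- Illustrative sketch (not part of the upstream file): the pre-norm block above preserves the token
# shape and returns `(hidden_states, attention_probs)` when attentions are requested. Toy sizes below
# are assumptions for demonstration only; this hypothetical helper is never called by the model.
def _example_flava_layer_shapes():
    config = FlavaImageConfig()  # hidden_size=768, num_attention_heads=12 by default
    layer = FlavaLayer(config)
    tokens = torch.randn(2, 197, config.hidden_size)
    (out,) = layer(tokens)  # layernorm -> attention -> residual -> layernorm -> MLP -> residual
    out_with_attn = layer(tokens, output_attentions=True)
    return out.shape, out_with_attn[1].shape  # (2, 197, 768) and (2, 12, 197, 197)
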
class FlavaEncoder(nn.Module):
|
| 643 |
+
def __init__(self, config: FlavaConfig) -> None:
|
| 644 |
+
super().__init__()
|
| 645 |
+
self.config = config
|
| 646 |
+
self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
|
| 647 |
+
self.gradient_checkpointing = False
|
| 648 |
+
|
| 649 |
+
def forward(
|
| 650 |
+
self,
|
| 651 |
+
hidden_states: torch.Tensor,
|
| 652 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 653 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 654 |
+
output_attentions: bool = False,
|
| 655 |
+
output_hidden_states: bool = False,
|
| 656 |
+
return_dict: bool = True,
|
| 657 |
+
) -> Union[tuple, BaseModelOutput]:
|
| 658 |
+
all_hidden_states = () if output_hidden_states else None
|
| 659 |
+
all_self_attentions = () if output_attentions else None
|
| 660 |
+
|
| 661 |
+
for i, layer_module in enumerate(self.layer):
|
| 662 |
+
if output_hidden_states:
|
| 663 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 664 |
+
|
| 665 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 666 |
+
|
| 667 |
+
if self.gradient_checkpointing and self.training:
|
| 668 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 669 |
+
layer_module.__call__,
|
| 670 |
+
hidden_states,
|
| 671 |
+
attention_mask,
|
| 672 |
+
layer_head_mask,
|
| 673 |
+
output_attentions,
|
| 674 |
+
)
|
| 675 |
+
else:
|
| 676 |
+
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
|
| 677 |
+
|
| 678 |
+
hidden_states = layer_outputs[0]
|
| 679 |
+
|
| 680 |
+
if output_attentions:
|
| 681 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 682 |
+
|
| 683 |
+
if output_hidden_states:
|
| 684 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 685 |
+
|
| 686 |
+
if not return_dict:
|
| 687 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 688 |
+
return BaseModelOutput(
|
| 689 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
|
| 690 |
+
)
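
# --- Illustrative sketch (not part of the upstream file): the encoder's two return modes. The unimodal
# models pass their own configs (e.g. FlavaImageConfig) even though the type hint above says FlavaConfig;
# the sizes below are assumptions for demonstration only and this hypothetical helper is never called.
def _example_flava_encoder_outputs():
    config = FlavaImageConfig(num_hidden_layers=2)
    encoder = FlavaEncoder(config)
    tokens = torch.randn(1, 197, config.hidden_size)
    as_dataclass = encoder(tokens, output_hidden_states=True, return_dict=True)
    as_tuple = encoder(tokens, output_hidden_states=True, return_dict=False)
    # `hidden_states` holds num_hidden_layers + 1 tensors: the input embeddings plus every layer output.
    return len(as_dataclass.hidden_states), as_tuple[0].shape  # 3 and torch.Size([1, 197, 768])
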
class FlavaPooler(nn.Module):
|
| 694 |
+
def __init__(self, config: FlavaPossibleConfigs):
|
| 695 |
+
super().__init__()
|
| 696 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 697 |
+
self.activation = nn.Tanh()
|
| 698 |
+
|
| 699 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 700 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 701 |
+
# to the first token.
|
| 702 |
+
first_token_tensor = hidden_states[:, 0]
|
| 703 |
+
pooled_output = self.dense(first_token_tensor)
|
| 704 |
+
pooled_output = self.activation(pooled_output)
|
| 705 |
+
return pooled_output
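
# --- Illustrative sketch (not part of the upstream file): the pooler simply projects the first ([CLS])
# token through a dense layer and a tanh. Shapes below are assumptions for demonstration only; this
# hypothetical helper is never called by the model.
def _example_flava_pooler_usage():
    config = FlavaImageConfig()
    pooler = FlavaPooler(config)
    sequence_output = torch.randn(2, 197, config.hidden_size)
    pooled = pooler(sequence_output)  # equivalent to tanh(dense(sequence_output[:, 0]))
    return pooled.shape  # torch.Size([2, 768])
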
FLAVA_START_DOCSTRING = r"""
|
| 709 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 710 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
|
| 711 |
+
behavior.
|
| 712 |
+
|
| 713 |
+
Parameters:
|
| 714 |
+
config ([`{config}`]): Model configuration class with all the parameters of the model.
|
| 715 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 716 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
FLAVA_INPUTS_DOCSTRING_COMMON = r"""
|
| 720 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 721 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 722 |
+
- 1 for tokens that are **not masked**,
|
| 723 |
+
- 0 for tokens that are **masked**.
|
| 724 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 725 |
+
|
| 726 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 727 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 728 |
+
|
| 729 |
+
- 1 indicates the head is **not masked**,
|
| 730 |
+
- 0 indicates the head is **masked**.
|
| 731 |
+
|
| 732 |
+
output_attentions (`bool`, *optional*):
|
| 733 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 734 |
+
tensors for more detail.
|
| 735 |
+
output_hidden_states (`bool`, *optional*):
|
| 736 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 737 |
+
more detail.
|
| 738 |
+
|
| 739 |
+
return_dict (`bool`, *optional*):
|
| 740 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 741 |
+
"""
|
| 742 |
+
|
| 743 |
+
FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
|
| 744 |
+
Args:
|
| 745 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 746 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 747 |
+
[`FlavaImageProcessor.__call__`] for details.
|
| 748 |
+
|
| 749 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
|
| 750 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 751 |
+
|
| 752 |
+
interpolate_pos_encoding (`bool`, *optional*):
|
| 753 |
+
Whether to interpolate the pre-trained position encodings.
|
| 754 |
+
"""
|
| 755 |
+
|
| 756 |
+
FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
|
| 757 |
+
|
| 758 |
+
FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
|
| 759 |
+
Args:
|
| 760 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 761 |
+
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
|
| 762 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
|
| 763 |
+
IDs?](../glossary#input-ids)
|
| 764 |
+
|
| 765 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 766 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 767 |
+
1]`:
|
| 768 |
+
- 0 corresponds to a *sentence A* token,
|
| 769 |
+
- 1 corresponds to a *sentence B* token.
|
| 770 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 771 |
+
"""
|
| 772 |
+
|
| 773 |
+
FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
|
| 774 |
+
|
| 775 |
+
FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
|
| 776 |
+
r"""
|
| 777 |
+
Args:
|
| 778 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
|
| 779 |
+
The concatenated hidden states of unimodal encoders.
|
| 780 |
+
"""
|
| 781 |
+
+ FLAVA_INPUTS_DOCSTRING_COMMON
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r"""
|
| 785 |
+
Args:
|
| 786 |
+
skip_multimodal_encoder (`bool`, *optional*):
|
| 787 |
+
Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
|
| 788 |
+
"""
|
| 789 |
+
|
| 790 |
+
FLAVA_MODEL_INPUTS_DOCSTRING = (
|
| 791 |
+
FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
|
| 792 |
+
+ FLAVA_TEXT_INPUTS_DOCSTRING_BASE
|
| 793 |
+
+ FLAVA_INPUTS_DOCSTRING_COMMON
|
| 794 |
+
+ FLAVA_MODEL_INPUTS_DOCSTRING_BASE
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
|
| 799 |
+
r"""
|
| 800 |
+
Args:
|
| 801 |
+
input_ids_masked (`torch.LongTensor` of shape `({0})`):
|
| 802 |
+
Indices of input sequence tokens in the vocabulary. These are the masked versions of the original tokens
|
| 803 |
+
to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
|
| 804 |
+
[`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
|
| 805 |
+
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
| 806 |
+
|
| 807 |
+
"""
|
| 808 |
+
+ FLAVA_TEXT_INPUTS_DOCSTRING_BASE
|
| 809 |
+
+ FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
|
| 810 |
+
+ r"""
|
| 811 |
+
image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*):
|
| 812 |
+
Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
|
| 813 |
+
in `[0, 1]`:
|
| 814 |
+
- 1 for tokens that are **not masked**,
|
| 815 |
+
- 0 for tokens that are **masked**.
|
| 816 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 817 |
+
|
| 818 |
+
skip_unmasked_multimodal_encoder (*bool*, *optional*):
|
| 819 |
+
Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
|
| 820 |
+
multimodal embeddings or outputs as of now.
|
| 821 |
+
|
| 822 |
+
mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
|
| 823 |
+
Labels for computing the masked language modeling (MLM) and multimodal masked modeling (MMM) text losses.
|
| 824 |
+
Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
|
| 825 |
+
indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
|
| 826 |
+
..., text_config.vocab_size - 1]`.
|
| 827 |
+
|
| 828 |
+
mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
|
| 829 |
+
Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
|
| 830 |
+
image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
|
| 831 |
+
computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
|
| 832 |
+
generated automatically using the image codebook assigned to the model. By default, it uses
|
| 833 |
+
[`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels.
|
| 834 |
+
|
| 835 |
+
itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
|
| 836 |
+
Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
|
| 837 |
+
The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
|
| 838 |
+
|
| 839 |
+
return_loss (`bool`, *optional*, defaults to `None`):
|
| 840 |
+
Whether to return calculated loss or not.
|
| 841 |
+
"""
|
| 842 |
+
+ FLAVA_INPUTS_DOCSTRING_COMMON
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r"""
|
| 846 |
+
Parameters:
|
| 847 |
+
image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise, it will
|
| 848 |
+
be initialized using the image_codebook_config defined in the config.
|
| 849 |
+
"""
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
class FlavaPreTrainedModel(PreTrainedModel):
|
| 853 |
+
"""
|
| 854 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 855 |
+
models.
|
| 856 |
+
"""
|
| 857 |
+
|
| 858 |
+
config_class = FlavaConfig
|
| 859 |
+
base_model_prefix = "flava"
|
| 860 |
+
supports_gradient_checkpointing = True
|
| 861 |
+
|
| 862 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
| 863 |
+
"""Initialize the weights"""
|
| 864 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 865 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 866 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 867 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 868 |
+
if module.bias is not None:
|
| 869 |
+
module.bias.data.zero_()
|
| 870 |
+
elif isinstance(module, nn.Embedding):
|
| 871 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 872 |
+
if module.padding_idx is not None:
|
| 873 |
+
module.weight.data[module.padding_idx].zero_()
|
| 874 |
+
elif isinstance(module, nn.LayerNorm):
|
| 875 |
+
module.bias.data.zero_()
|
| 876 |
+
module.weight.data.fill_(1.0)
|
| 877 |
+
elif isinstance(module, FlavaMaskedPredictionHead):
|
| 878 |
+
module.bias.data.zero_()
|
| 879 |
+
elif isinstance(module, FlavaImageEmbeddings):
|
| 880 |
+
module.cls_token.data.zero_()
|
| 881 |
+
module.position_embeddings.data.zero_()
|
| 882 |
+
if module.mask_token is not None:
|
| 883 |
+
module.mask_token.data.zero_()
|
| 884 |
+
elif isinstance(module, FlavaMultimodalModel):
|
| 885 |
+
if module.use_cls_token:
|
| 886 |
+
module.cls_token.data.zero_()
|
| 887 |
+
elif isinstance(module, FlavaModel):
|
| 888 |
+
module.logit_scale.data.fill_(self.config.logit_scale_init_value)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
@add_start_docstrings(
|
| 892 |
+
"The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
|
| 893 |
+
FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"),
|
| 894 |
+
)
|
| 895 |
+
class FlavaImageModel(FlavaPreTrainedModel):
|
| 896 |
+
config_class = FlavaImageConfig
|
| 897 |
+
# This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
|
| 898 |
+
base_model_prefix = "flava.image_model"
|
| 899 |
+
main_input_name = "pixel_values"
|
| 900 |
+
|
| 901 |
+
def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
|
| 902 |
+
super().__init__(config)
|
| 903 |
+
|
| 904 |
+
self.config = config
|
| 905 |
+
|
| 906 |
+
self.embeddings = FlavaImageEmbeddings(config)
|
| 907 |
+
self.encoder = FlavaEncoder(config)
|
| 908 |
+
|
| 909 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 910 |
+
self.pooler = FlavaPooler(config) if add_pooling_layer else None
|
| 911 |
+
|
| 912 |
+
self.post_init()
|
| 913 |
+
|
| 914 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 915 |
+
return self.embeddings.patch_embeddings
|
| 916 |
+
|
| 917 |
+
def set_input_embeddings(self, value: nn.Module):
|
| 918 |
+
self.embeddings.patch_embeddings = value
|
| 919 |
+
|
| 920 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
| 921 |
+
"""
|
| 922 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
|
| 923 |
+
class PreTrainedModel
|
| 924 |
+
"""
|
| 925 |
+
for layer, heads in heads_to_prune.items():
|
| 926 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 927 |
+
|
| 928 |
+
@add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
|
| 929 |
+
@add_code_sample_docstrings(
|
| 930 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 931 |
+
output_type=BaseModelOutputWithPooling,
|
| 932 |
+
config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
|
| 933 |
+
modality="vision",
|
| 934 |
+
expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
|
| 935 |
+
)
|
| 936 |
+
def forward(
|
| 937 |
+
self,
|
| 938 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 939 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 940 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 941 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 942 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 943 |
+
output_attentions: Optional[bool] = None,
|
| 944 |
+
output_hidden_states: Optional[bool] = None,
|
| 945 |
+
return_dict: Optional[bool] = None,
|
| 946 |
+
) -> Union[tuple, BaseModelOutputWithPooling]:
|
| 947 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 948 |
+
output_hidden_states = (
|
| 949 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 950 |
+
)
|
| 951 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 952 |
+
|
| 953 |
+
if pixel_values is None:
|
| 954 |
+
raise ValueError("You have to specify pixel_values")
|
| 955 |
+
|
| 956 |
+
# Prepare head mask if needed
|
| 957 |
+
# 1.0 in head_mask indicates we keep the head
|
| 958 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 959 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 960 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 961 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 962 |
+
|
| 963 |
+
embedding_output = self.embeddings(
|
| 964 |
+
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
encoder_outputs = self.encoder(
|
| 968 |
+
embedding_output,
|
| 969 |
+
attention_mask=attention_mask,
|
| 970 |
+
head_mask=head_mask,
|
| 971 |
+
output_attentions=output_attentions,
|
| 972 |
+
output_hidden_states=output_hidden_states,
|
| 973 |
+
return_dict=return_dict,
|
| 974 |
+
)
|
| 975 |
+
sequence_output = encoder_outputs[0]
|
| 976 |
+
sequence_output = self.layernorm(sequence_output)
|
| 977 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 978 |
+
|
| 979 |
+
if not return_dict:
|
| 980 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 981 |
+
|
| 982 |
+
return BaseModelOutputWithPooling(
|
| 983 |
+
last_hidden_state=sequence_output,
|
| 984 |
+
pooler_output=pooled_output,
|
| 985 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 986 |
+
attentions=encoder_outputs.attentions,
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
@add_start_docstrings(
|
| 991 |
+
"The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
|
| 992 |
+
FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"),
|
| 993 |
+
)
|
| 994 |
+
class FlavaTextModel(FlavaPreTrainedModel):
|
| 995 |
+
config_class = FlavaTextConfig
|
| 996 |
+
# This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
|
| 997 |
+
base_model_prefix = "flava.text_model"
|
| 998 |
+
|
| 999 |
+
def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
|
| 1000 |
+
super().__init__(config)
|
| 1001 |
+
self.config = config
|
| 1002 |
+
|
| 1003 |
+
self.embeddings = FlavaTextEmbeddings(config)
|
| 1004 |
+
self.encoder = FlavaEncoder(config)
|
| 1005 |
+
|
| 1006 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 1007 |
+
self.pooler = FlavaPooler(config) if add_pooling_layer else None
|
| 1008 |
+
|
| 1009 |
+
self.post_init()
|
| 1010 |
+
|
| 1011 |
+
def get_input_embeddings(self) -> PatchEmbeddings:
|
| 1012 |
+
return self.embeddings.word_embeddings
|
| 1013 |
+
|
| 1014 |
+
def set_input_embeddings(self, value: nn.Module):
|
| 1015 |
+
self.embeddings.word_embeddings = value
|
| 1016 |
+
|
| 1017 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
| 1018 |
+
"""
|
| 1019 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
|
| 1020 |
+
class PreTrainedModel
|
| 1021 |
+
"""
|
| 1022 |
+
for layer, heads in heads_to_prune.items():
|
| 1023 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 1024 |
+
|
| 1025 |
+
@add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
|
| 1026 |
+
@add_code_sample_docstrings(
|
| 1027 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1028 |
+
output_type=BaseModelOutputWithPooling,
|
| 1029 |
+
config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
|
| 1030 |
+
)
|
| 1031 |
+
def forward(
|
| 1032 |
+
self,
|
| 1033 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1034 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1035 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1036 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1037 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1038 |
+
output_attentions: Optional[bool] = None,
|
| 1039 |
+
output_hidden_states: Optional[bool] = None,
|
| 1040 |
+
return_dict: Optional[bool] = None,
|
| 1041 |
+
) -> Union[tuple, BaseModelOutputWithPooling]:
|
| 1042 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1043 |
+
output_hidden_states = (
|
| 1044 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1045 |
+
)
|
| 1046 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1047 |
+
|
| 1048 |
+
if input_ids is None:
|
| 1049 |
+
raise ValueError("You have to specify input_ids")
|
| 1050 |
+
|
| 1051 |
+
input_shape = input_ids.size()
|
| 1052 |
+
|
| 1053 |
+
if attention_mask is None:
|
| 1054 |
+
attention_mask = torch.ones(input_shape, device=input_ids.device)
|
| 1055 |
+
|
| 1056 |
+
# Prepare head mask if needed
|
| 1057 |
+
# 1.0 in head_mask indicates we keep the head
|
| 1058 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 1059 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 1060 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 1061 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 1062 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
| 1063 |
+
attention_mask, input_shape, input_ids.device
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
embedding_output = self.embeddings(
|
| 1067 |
+
input_ids=input_ids,
|
| 1068 |
+
token_type_ids=token_type_ids,
|
| 1069 |
+
position_ids=position_ids,
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
encoder_outputs = self.encoder(
|
| 1073 |
+
embedding_output,
|
| 1074 |
+
attention_mask=extended_attention_mask,
|
| 1075 |
+
head_mask=head_mask,
|
| 1076 |
+
output_attentions=output_attentions,
|
| 1077 |
+
output_hidden_states=output_hidden_states,
|
| 1078 |
+
return_dict=return_dict,
|
| 1079 |
+
)
|
| 1080 |
+
sequence_output = encoder_outputs[0]
|
| 1081 |
+
sequence_output = self.layernorm(sequence_output)
|
| 1082 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 1083 |
+
|
| 1084 |
+
if not return_dict:
|
| 1085 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 1086 |
+
|
| 1087 |
+
return BaseModelOutputWithPooling(
|
| 1088 |
+
last_hidden_state=sequence_output,
|
| 1089 |
+
pooler_output=pooled_output,
|
| 1090 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1091 |
+
attentions=encoder_outputs.attentions,
|
| 1092 |
+
)
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
@add_start_docstrings(
|
| 1096 |
+
"The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1097 |
+
FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"),
|
| 1098 |
+
)
|
| 1099 |
+
class FlavaMultimodalModel(FlavaPreTrainedModel):
|
| 1100 |
+
config_class = FlavaMultimodalConfig
|
| 1101 |
+
# This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints.
|
| 1102 |
+
base_model_prefix = "flava.multimodal_model"
|
| 1103 |
+
main_input_name = "hidden_states"
|
| 1104 |
+
|
| 1105 |
+
def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
|
| 1106 |
+
super().__init__(config)
|
| 1107 |
+
self.config = config
|
| 1108 |
+
self.use_cls_token = self.config.use_cls_token
|
| 1109 |
+
if self.use_cls_token:
|
| 1110 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 1111 |
+
|
| 1112 |
+
self.encoder = FlavaEncoder(config)
|
| 1113 |
+
|
| 1114 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 1115 |
+
self.pooler = FlavaPooler(config) if add_pooling_layer else None
|
| 1116 |
+
|
| 1117 |
+
self.post_init()
|
| 1118 |
+
|
| 1119 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
| 1120 |
+
"""
|
| 1121 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
|
| 1122 |
+
class PreTrainedModel
|
| 1123 |
+
"""
|
| 1124 |
+
for layer, heads in heads_to_prune.items():
|
| 1125 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 1126 |
+
|
| 1127 |
+
@add_start_docstrings_to_model_forward(
|
| 1128 |
+
FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
|
| 1129 |
+
)
|
| 1130 |
+
@add_code_sample_docstrings(
|
| 1131 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1132 |
+
output_type=BaseModelOutputWithPooling,
|
| 1133 |
+
config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
|
| 1134 |
+
)
|
| 1135 |
+
def forward(
|
| 1136 |
+
self,
|
| 1137 |
+
hidden_states: torch.Tensor,
|
| 1138 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1139 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1140 |
+
output_attentions: Optional[bool] = None,
|
| 1141 |
+
output_hidden_states: Optional[bool] = None,
|
| 1142 |
+
return_dict: Optional[bool] = None,
|
| 1143 |
+
) -> Union[tuple, BaseModelOutputWithPooling]:
|
| 1144 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1145 |
+
output_hidden_states = (
|
| 1146 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1147 |
+
)
|
| 1148 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1149 |
+
|
| 1150 |
+
batch_size, seq_length, _ = hidden_states.size()
|
| 1151 |
+
|
| 1152 |
+
if self.use_cls_token:
|
| 1153 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 1154 |
+
hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
|
| 1155 |
+
seq_length += 1
|
| 1156 |
+
|
| 1157 |
+
if attention_mask is None:
|
| 1158 |
+
attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
|
| 1159 |
+
|
| 1160 |
+
# Prepare head mask if needed
|
| 1161 |
+
# 1.0 in head_mask indicates we keep the head
|
| 1162 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 1163 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 1164 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 1165 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 1166 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
| 1167 |
+
attention_mask, (batch_size, seq_length), hidden_states.device
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
encoder_outputs = self.encoder(
|
| 1171 |
+
hidden_states,
|
| 1172 |
+
attention_mask=extended_attention_mask,
|
| 1173 |
+
head_mask=head_mask,
|
| 1174 |
+
output_attentions=output_attentions,
|
| 1175 |
+
output_hidden_states=output_hidden_states,
|
| 1176 |
+
return_dict=return_dict,
|
| 1177 |
+
)
|
| 1178 |
+
sequence_output = encoder_outputs[0]
|
| 1179 |
+
sequence_output = self.layernorm(sequence_output)
|
| 1180 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 1181 |
+
|
| 1182 |
+
if not return_dict:
|
| 1183 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 1184 |
+
|
| 1185 |
+
return BaseModelOutputWithPooling(
|
| 1186 |
+
last_hidden_state=sequence_output,
|
| 1187 |
+
pooler_output=pooled_output,
|
| 1188 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1189 |
+
attentions=encoder_outputs.attentions,
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
@add_start_docstrings(
|
| 1194 |
+
"The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1195 |
+
FLAVA_START_DOCSTRING.format(config="FlavaConfig"),
|
| 1196 |
+
)
|
| 1197 |
+
class FlavaModel(FlavaPreTrainedModel):
|
| 1198 |
+
config_class = FlavaConfig
|
| 1199 |
+
|
| 1200 |
+
def __init__(self, config: FlavaConfig):
|
| 1201 |
+
super().__init__(config)
|
| 1202 |
+
|
| 1203 |
+
if not isinstance(config.text_config, FlavaTextConfig):
|
| 1204 |
+
raise TypeError(
|
| 1205 |
+
"config.text_config is expected to be of type FlavaTextConfig but is of type"
|
| 1206 |
+
f" {type(config.text_config)}."
|
| 1207 |
+
)
|
| 1208 |
+
|
| 1209 |
+
if not isinstance(config.image_config, FlavaImageConfig):
|
| 1210 |
+
raise TypeError(
|
| 1211 |
+
"config.image_config is expected to be of type FlavaImageConfig but is of type"
|
| 1212 |
+
f" {type(config.image_config)}."
|
| 1213 |
+
)
|
| 1214 |
+
|
| 1215 |
+
if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
|
| 1216 |
+
raise TypeError(
|
| 1217 |
+
"config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
|
| 1218 |
+
+ f"is of type {type(config.multimodal_config)}."
|
| 1219 |
+
)
|
| 1220 |
+
|
| 1221 |
+
text_config = config.text_config
|
| 1222 |
+
image_config = config.image_config
|
| 1223 |
+
multimodal_config = config.multimodal_config
|
| 1224 |
+
|
| 1225 |
+
self.projection_dim = config.projection_dim
|
| 1226 |
+
self.text_hidden_size = text_config.hidden_size
|
| 1227 |
+
self.image_hidden_size = image_config.hidden_size
|
| 1228 |
+
self.mm_hidden_size = multimodal_config.hidden_size
|
| 1229 |
+
|
| 1230 |
+
self.text_model = FlavaTextModel(text_config)
|
| 1231 |
+
self.image_model = FlavaImageModel(image_config)
|
| 1232 |
+
self.multimodal_model = FlavaMultimodalModel(multimodal_config)
|
| 1233 |
+
|
| 1234 |
+
self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
|
| 1235 |
+
self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
|
| 1236 |
+
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
| 1237 |
+
|
| 1238 |
+
self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
|
| 1239 |
+
self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
|
| 1240 |
+
# Initialize weights and apply final processing
|
| 1241 |
+
self.post_init()
|
| 1242 |
+
|
| 1243 |
+
@add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
|
| 1244 |
+
def get_text_features(
|
| 1245 |
+
self,
|
| 1246 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1247 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1248 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1249 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1250 |
+
output_attentions: Optional[bool] = None,
|
| 1251 |
+
output_hidden_states: Optional[bool] = None,
|
| 1252 |
+
return_dict: Optional[bool] = None,
|
| 1253 |
+
) -> torch.FloatTensor:
|
| 1254 |
+
r"""
|
| 1255 |
+
Returns:
|
| 1256 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
|
| 1257 |
+
applying the projection layer to the pooled output of [`FlavaTextModel`].
|
| 1258 |
+
|
| 1259 |
+
Examples:
|
| 1260 |
+
|
| 1261 |
+
```python
|
| 1262 |
+
>>> from transformers import AutoProcessor, FlavaModel
|
| 1263 |
+
|
| 1264 |
+
>>> model = FlavaModel.from_pretrained("{0}")
|
| 1265 |
+
>>> processor = AutoProcessor.from_pretrained("{0}")
|
| 1266 |
+
|
| 1267 |
+
>>> inputs = processor(
|
| 1268 |
+
... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
|
| 1269 |
+
... )
|
| 1270 |
+
>>> text_features = model.get_text_features(**inputs)
|
| 1271 |
+
```""".format(_CHECKPOINT_FOR_DOC)
|
| 1272 |
+
text_outputs = self.text_model(
|
| 1273 |
+
input_ids=input_ids,
|
| 1274 |
+
attention_mask=attention_mask,
|
| 1275 |
+
token_type_ids=token_type_ids,
|
| 1276 |
+
position_ids=position_ids,
|
| 1277 |
+
output_attentions=output_attentions,
|
| 1278 |
+
output_hidden_states=output_hidden_states,
|
| 1279 |
+
return_dict=return_dict,
|
| 1280 |
+
)
|
| 1281 |
+
|
| 1282 |
+
pooled_output = text_outputs[0] # last_hidden_state
|
| 1283 |
+
text_features = self.text_projection(pooled_output)
|
| 1284 |
+
|
| 1285 |
+
return text_features
|
| 1286 |
+
|
| 1287 |
+
@add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
|
| 1288 |
+
def get_image_features(
|
| 1289 |
+
self,
|
| 1290 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1291 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 1292 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 1293 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1294 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1295 |
+
output_attentions: Optional[bool] = None,
|
| 1296 |
+
output_hidden_states: Optional[bool] = None,
|
| 1297 |
+
return_dict: Optional[bool] = None,
|
| 1298 |
+
) -> torch.FloatTensor:
|
| 1299 |
+
r"""
|
| 1300 |
+
Returns:
|
| 1301 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
|
| 1302 |
+
applying the projection layer to the pooled output of [`FlavaImageModel`].
|
| 1303 |
+
|
| 1304 |
+
Examples:
|
| 1305 |
+
|
| 1306 |
+
```python
|
| 1307 |
+
>>> from PIL import Image
|
| 1308 |
+
>>> import requests
|
| 1309 |
+
>>> from transformers import AutoProcessor, FlavaModel
|
| 1310 |
+
|
| 1311 |
+
>>> model = FlavaModel.from_pretrained("{0}")
|
| 1312 |
+
>>> processor = AutoProcessor.from_pretrained("{0}")
|
| 1313 |
+
|
| 1314 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1315 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1316 |
+
|
| 1317 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 1318 |
+
|
| 1319 |
+
>>> image_features = model.get_image_features(**inputs)
|
| 1320 |
+
```""".format(_CHECKPOINT_FOR_DOC)
|
| 1321 |
+
image_outputs = self.image_model(
|
| 1322 |
+
pixel_values=pixel_values,
|
| 1323 |
+
bool_masked_pos=bool_masked_pos,
|
| 1324 |
+
attention_mask=attention_mask,
|
| 1325 |
+
head_mask=head_mask,
|
| 1326 |
+
output_attentions=output_attentions,
|
| 1327 |
+
output_hidden_states=output_hidden_states,
|
| 1328 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1329 |
+
return_dict=return_dict,
|
| 1330 |
+
)
|
| 1331 |
+
|
| 1332 |
+
pooled_output = image_outputs[0] # last_hidden_state
|
| 1333 |
+
image_features = self.image_projection(pooled_output)
|
| 1334 |
+
|
| 1335 |
+
return image_features
|
| 1336 |
+
|
| 1337 |
+
@add_start_docstrings_to_model_forward(
|
| 1338 |
+
FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
|
| 1339 |
+
)
|
| 1340 |
+
@replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig)
|
| 1341 |
+
def forward(
|
| 1342 |
+
self,
|
| 1343 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1344 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1345 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1346 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1347 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
| 1348 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1349 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
| 1350 |
+
skip_multimodal_encoder: Optional[bool] = None,
|
| 1351 |
+
output_attentions: Optional[bool] = None,
|
| 1352 |
+
output_hidden_states: bool = True,
|
| 1353 |
+
return_dict: Optional[bool] = None,
|
| 1354 |
+
) -> Union[Tuple, FlavaOutput]:
|
| 1355 |
+
r"""
|
| 1356 |
+
Returns:
|
| 1357 |
+
|
| 1358 |
+
Examples:
|
| 1359 |
+
|
| 1360 |
+
```python
|
| 1361 |
+
>>> from PIL import Image
|
| 1362 |
+
>>> import requests
|
| 1363 |
+
>>> from transformers import AutoProcessor, FlavaModel
|
| 1364 |
+
|
| 1365 |
+
>>> model = FlavaModel.from_pretrained("facebook/flava-full")
|
| 1366 |
+
>>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
|
| 1367 |
+
|
| 1368 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1369 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1370 |
+
|
| 1371 |
+
>>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
|
| 1372 |
+
|
| 1373 |
+
>>> outputs = model(**inputs)
|
| 1374 |
+
|
| 1375 |
+
>>> image_embeddings = outputs.image_embeddings
|
| 1376 |
+
>>> text_embeddings = outputs.text_embeddings
|
| 1377 |
+
>>> multimodal_embeddings = outputs.multimodal_embeddings
|
| 1378 |
+
|
| 1379 |
+
>>> outputs.image_embeddings.shape
|
| 1380 |
+
torch.Size([1, 197, 768])
|
| 1381 |
+
|
| 1382 |
+
>>> text_embeddings.shape
|
| 1383 |
+
torch.Size([1, 7, 768])
|
| 1384 |
+
|
| 1385 |
+
>>> multimodal_embeddings.shape
|
| 1386 |
+
torch.Size([1, 205, 768])
|
| 1387 |
+
```
|
| 1388 |
+
"""
|
| 1389 |
+
|
| 1390 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 1391 |
+
if not output_hidden_states:
|
| 1392 |
+
raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`")
|
| 1393 |
+
image_embeddings = None
|
| 1394 |
+
image_states = None
|
| 1395 |
+
image_mm_projection = None
|
| 1396 |
+
image_output = None
|
| 1397 |
+
if pixel_values is not None:
|
| 1398 |
+
image_output = self.image_model(
|
| 1399 |
+
pixel_values=pixel_values,
|
| 1400 |
+
bool_masked_pos=bool_masked_pos,
|
| 1401 |
+
attention_mask=image_attention_mask,
|
| 1402 |
+
output_attentions=output_attentions,
|
| 1403 |
+
output_hidden_states=output_hidden_states,
|
| 1404 |
+
return_dict=return_dict,
|
| 1405 |
+
)
|
| 1406 |
+
image_embeddings, image_states = image_output[0], image_output[2]
|
| 1407 |
+
# Note that these states don't use final layernorm in the transformer model
|
| 1408 |
+
image_mm_projection = self.image_to_mm_projection(image_states[-1])
|
| 1409 |
+
|
| 1410 |
+
text_embeddings = None
|
| 1411 |
+
text_states = None
|
| 1412 |
+
text_mm_projection = None
|
| 1413 |
+
text_output = None
|
| 1414 |
+
if input_ids is not None:
|
| 1415 |
+
text_output = self.text_model(
|
| 1416 |
+
input_ids=input_ids,
|
| 1417 |
+
attention_mask=attention_mask,
|
| 1418 |
+
position_ids=position_ids,
|
| 1419 |
+
token_type_ids=token_type_ids,
|
| 1420 |
+
output_attentions=output_attentions,
|
| 1421 |
+
output_hidden_states=output_hidden_states,
|
| 1422 |
+
return_dict=return_dict,
|
| 1423 |
+
)
|
| 1424 |
+
|
| 1425 |
+
text_embeddings, text_states = text_output[0], text_output[2]
|
| 1426 |
+
# Note that these states don't use final layernorm in the transformer model
|
| 1427 |
+
text_mm_projection = self.text_to_mm_projection(text_states[-1])
|
| 1428 |
+
|
| 1429 |
+
multimodal_embeddings = None
|
| 1430 |
+
multimodal_output = None
|
| 1431 |
+
if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
|
| 1432 |
+
if attention_mask is not None:
|
| 1433 |
+
batch_size, seq_len, _ = image_mm_projection.shape
|
| 1434 |
+
if self.multimodal_model.use_cls_token:
|
| 1435 |
+
seq_len += 1
|
| 1436 |
+
attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
|
| 1437 |
+
attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)
|
| 1438 |
+
else:
|
| 1439 |
+
attention_multimodal = None
|
| 1440 |
+
multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
|
| 1441 |
+
multimodal_output = self.multimodal_model(
|
| 1442 |
+
multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict
|
| 1443 |
+
)
|
| 1444 |
+
multimodal_embeddings = multimodal_output[0]
|
| 1445 |
+
|
| 1446 |
+
if not return_dict:
|
| 1447 |
+
return (
|
| 1448 |
+
image_embeddings,
|
| 1449 |
+
image_output,
|
| 1450 |
+
text_embeddings,
|
| 1451 |
+
text_output,
|
| 1452 |
+
multimodal_embeddings,
|
| 1453 |
+
multimodal_output,
|
| 1454 |
+
)
|
| 1455 |
+
|
| 1456 |
+
return FlavaModelOutput(
|
| 1457 |
+
image_embeddings=image_embeddings,
|
| 1458 |
+
image_output=image_output,
|
| 1459 |
+
text_embeddings=text_embeddings,
|
| 1460 |
+
text_output=text_output,
|
| 1461 |
+
multimodal_embeddings=multimodal_embeddings,
|
| 1462 |
+
multimodal_output=multimodal_output,
|
| 1463 |
+
)
|
| 1464 |
+
|
| 1465 |
+
|
| 1466 |
+
class FlavaImageCodebookResPath(nn.Module):
|
| 1467 |
+
def __init__(self, in_size: int, out_size: int, **kwargs):
|
| 1468 |
+
super().__init__()
|
| 1469 |
+
hid_size = out_size // 4
|
| 1470 |
+
|
| 1471 |
+
path = OrderedDict()
|
| 1472 |
+
path["relu_1"] = nn.ReLU()
|
| 1473 |
+
path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
|
| 1474 |
+
path["relu_2"] = nn.ReLU()
|
| 1475 |
+
path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
|
| 1476 |
+
path["relu_3"] = nn.ReLU()
|
| 1477 |
+
path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
|
| 1478 |
+
path["relu_4"] = nn.ReLU()
|
| 1479 |
+
path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)
|
| 1480 |
+
|
| 1481 |
+
self.path = nn.Sequential(path)
|
| 1482 |
+
|
| 1483 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1484 |
+
return self.path(x)
|
| 1485 |
+
|
| 1486 |
+
|
| 1487 |
+
class FlavaImageCodebookBlock(nn.Module):
|
| 1488 |
+
def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
|
| 1489 |
+
super().__init__()
|
| 1490 |
+
|
| 1491 |
+
self.post_gain = 1 / (num_layers**2)
|
| 1492 |
+
|
| 1493 |
+
if in_size != out_size:
|
| 1494 |
+
self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
|
| 1495 |
+
else:
|
| 1496 |
+
self.id_path = nn.Identity()
|
| 1497 |
+
|
| 1498 |
+
self.res_path = FlavaImageCodebookResPath(in_size, out_size)
|
| 1499 |
+
|
| 1500 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1501 |
+
return self.id_path(x) + self.post_gain * self.res_path(x)
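
# --- Illustrative sketch (not part of the upstream file): the codebook block is a plain residual unit
# whose residual branch is scaled by 1 / num_layers**2. Channel sizes below are assumptions chosen only
# for demonstration; this hypothetical helper is never called by the model.
def _example_flava_codebook_block():
    block = FlavaImageCodebookBlock(in_size=64, out_size=128, num_layers=8)
    feature_map = torch.randn(1, 64, 32, 32)
    out = block(feature_map)  # id_path is a 1x1 conv here because in_size != out_size
    return out.shape  # torch.Size([1, 128, 32, 32])
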
class FlavaImageCodebookLayerGroup(nn.Module):
|
| 1505 |
+
def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
|
| 1506 |
+
super().__init__()
|
| 1507 |
+
blocks = OrderedDict()
|
| 1508 |
+
for i in range(num_blocks):
|
| 1509 |
+
if i == 0:
|
| 1510 |
+
blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
|
| 1511 |
+
else:
|
| 1512 |
+
blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)
|
| 1513 |
+
|
| 1514 |
+
if use_pool:
|
| 1515 |
+
blocks["pool"] = nn.MaxPool2d(kernel_size=2)
|
| 1516 |
+
|
| 1517 |
+
self.group = nn.Sequential(blocks)
|
| 1518 |
+
|
| 1519 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1520 |
+
return self.group(x)
|
| 1521 |
+
|
| 1522 |
+
|
| 1523 |
+
# Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42
|
| 1524 |
+
@add_start_docstrings(
|
| 1525 |
+
"""
|
| 1526 |
+
FLAVA's image codebook model, inspired by DALL-E's original encoder. It outputs raw hidden states and can be used
|
| 1527 |
+
to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
|
| 1528 |
+
`get_codebook_indices` to get image tokens for an image.
|
| 1529 |
+
""",
|
| 1530 |
+
FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"),
|
| 1531 |
+
)
|
| 1532 |
+
class FlavaImageCodebook(FlavaPreTrainedModel):
|
| 1533 |
+
base_model_prefix = ""
|
| 1534 |
+
config_class = FlavaImageCodebookConfig
|
| 1535 |
+
main_input_name = "pixel_values"
|
| 1536 |
+
supports_gradient_checkpointing = False
|
| 1537 |
+
|
| 1538 |
+
def __init__(
|
| 1539 |
+
self,
|
| 1540 |
+
config: FlavaImageCodebookConfig,
|
| 1541 |
+
**kwargs: Any,
|
| 1542 |
+
):
|
| 1543 |
+
super().__init__(config)
|
| 1544 |
+
|
| 1545 |
+
self.config = config
|
| 1546 |
+
self.num_groups = config.num_groups
|
| 1547 |
+
self.input_channels = config.input_channels
|
| 1548 |
+
self.num_blocks_per_group = config.num_blocks_per_group
|
| 1549 |
+
self.hidden_size = config.hidden_size
|
| 1550 |
+
self.vocab_size = config.vocab_size
|
| 1551 |
+
|
| 1552 |
+
num_layers = self.num_groups * self.num_blocks_per_group
|
| 1553 |
+
|
| 1554 |
+
output_blocks = OrderedDict()
|
| 1555 |
+
output_blocks["relu"] = nn.ReLU()
|
| 1556 |
+
output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)
|
| 1557 |
+
|
| 1558 |
+
blocks = OrderedDict()
|
| 1559 |
+
blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
|
| 1560 |
+
blocks["group_1"] = FlavaImageCodebookLayerGroup(
|
| 1561 |
+
self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
|
| 1562 |
+
)
|
| 1563 |
+
blocks["group_2"] = FlavaImageCodebookLayerGroup(
|
| 1564 |
+
self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
|
| 1565 |
+
)
|
| 1566 |
+
blocks["group_3"] = FlavaImageCodebookLayerGroup(
|
| 1567 |
+
self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
|
| 1568 |
+
)
|
| 1569 |
+
blocks["group_4"] = FlavaImageCodebookLayerGroup(
|
| 1570 |
+
self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
|
| 1571 |
+
)
|
| 1572 |
+
blocks["output"] = nn.Sequential(output_blocks)
|
| 1573 |
+
|
| 1574 |
+
self.blocks = nn.Sequential(blocks)
|
| 1575 |
+
|
| 1576 |
+
self.post_init()
|
| 1577 |
+
|
| 1578 |
+
if self.config.freeze:
|
| 1579 |
+
for param in self.parameters():
|
| 1580 |
+
param.requires_grad = False
|
| 1581 |
+
|
| 1582 |
+
def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 1583 |
+
"""
|
| 1584 |
+
Args:
|
| 1585 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1586 |
+
Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
|
| 1587 |
+
`return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
|
| 1588 |
+
|
| 1589 |
+
Examples:
|
| 1590 |
+
```python
|
| 1591 |
+
>>> from PIL import Image
|
| 1592 |
+
>>> import requests
|
| 1593 |
+
>>> from transformers import AutoImageProcessor, FlavaImageCodebook
|
| 1594 |
+
|
| 1595 |
+
>>> model = FlavaImageCodebook.from_pretrained("{0}")
|
| 1596 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("{0}")
|
| 1597 |
+
|
| 1598 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1599 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1600 |
+
|
| 1601 |
+
>>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
|
| 1602 |
+
>>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
|
| 1603 |
+
|
| 1604 |
+
>>> outputs = model.get_codebook_indices(**inputs)
|
| 1605 |
+
```
|
| 1606 |
+
""".format(_CHECKPOINT_FOR_CODEBOOK_DOC)
|
| 1607 |
+
z_logits = self.blocks(pixel_values)
|
| 1608 |
+
return torch.argmax(z_logits, axis=1)
|
| 1609 |
+
|
| 1610 |
+
def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 1611 |
+
z_logits = self.blocks(pixel_values)
|
| 1612 |
+
return nn.Softmax(dim=1)(z_logits)
|
| 1613 |
+
|
| 1614 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
| 1615 |
+
"""
|
| 1616 |
+
Args:
|
| 1617 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1618 |
+
Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
|
| 1619 |
+
`return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
|
| 1620 |
+
|
| 1621 |
+
Examples:
|
| 1622 |
+
|
| 1623 |
+
```python
|
| 1624 |
+
>>> from PIL import Image
|
| 1625 |
+
>>> import requests
|
| 1626 |
+
>>> from transformers import AutoImageProcessor, FlavaImageCodebook
|
| 1627 |
+
|
| 1628 |
+
>>> model = FlavaImageCodebook.from_pretrained("{0}")
|
| 1629 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("{0}")
|
| 1630 |
+
|
| 1631 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1632 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1633 |
+
|
| 1634 |
+
>>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
|
| 1635 |
+
>>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
|
| 1636 |
+
|
| 1637 |
+
>>> outputs = model(**inputs)
|
| 1638 |
+
>>> print(outputs.shape)
|
| 1639 |
+
(1, 196)
|
| 1640 |
+
```
|
| 1641 |
+
""".format(_CHECKPOINT_FOR_CODEBOOK_DOC)
|
| 1642 |
+
if len(pixel_values.shape) != 4:
|
| 1643 |
+
raise ValueError(f"input shape {pixel_values.shape} is not 4d")
|
| 1644 |
+
if pixel_values.shape[1] != self.input_channels:
|
| 1645 |
+
raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}")
|
| 1646 |
+
return self.blocks(pixel_values)
|
| 1647 |
+
|
| 1648 |
+
|
| 1649 |
+
class FlavaPredictionHeadTransform(nn.Module):
|
| 1650 |
+
def __init__(self, config):
|
| 1651 |
+
super().__init__()
|
| 1652 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 1653 |
+
if isinstance(config.hidden_act, str):
|
| 1654 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 1655 |
+
else:
|
| 1656 |
+
self.transform_act_fn = config.hidden_act
|
| 1657 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 1658 |
+
|
| 1659 |
+
def forward(self, hidden_states):
|
| 1660 |
+
hidden_states = self.dense(hidden_states)
|
| 1661 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 1662 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 1663 |
+
return hidden_states
|
| 1664 |
+
|
| 1665 |
+
|
| 1666 |
+
class FlavaMaskedPredictionHead(nn.Module):
|
| 1667 |
+
def __init__(self, config, weight=None):
|
| 1668 |
+
super().__init__()
|
| 1669 |
+
self.config = config
|
| 1670 |
+
self.transform = FlavaPredictionHeadTransform(config)
|
| 1671 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1672 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 1673 |
+
if weight is not None:
|
| 1674 |
+
self.decoder.weight = weight
|
| 1675 |
+
|
| 1676 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 1677 |
+
self.decoder.bias = self.bias
|
| 1678 |
+
|
| 1679 |
+
def _tie_weights(self):
|
| 1680 |
+
self.decoder.bias = self.bias
|
| 1681 |
+
|
| 1682 |
+
def forward(self, x):
|
| 1683 |
+
x = self.transform(x)
|
| 1684 |
+
x = self.decoder(x)
|
| 1685 |
+
return x
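
# --- Illustrative sketch (not part of the upstream file): the masked-prediction head maps hidden states
# to vocabulary logits. The default text config and the shapes below are assumptions for demonstration
# only; this hypothetical helper is never called by the model.
def _example_flava_masked_prediction_head():
    text_config = FlavaTextConfig()  # vocab_size=30522 by default
    head = FlavaMaskedPredictionHead(text_config)
    masked_hidden = torch.randn(2, 7, text_config.hidden_size)
    logits = head(masked_hidden)  # dense -> activation -> LayerNorm -> decoder (with tied bias)
    return logits.shape  # torch.Size([2, 7, 30522])
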
class FlavaITMHead(nn.Module):
|
| 1689 |
+
def __init__(self, config):
|
| 1690 |
+
super().__init__()
|
| 1691 |
+
self.config = config
|
| 1692 |
+
self.pooler = FlavaPooler(config)
|
| 1693 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 1694 |
+
|
| 1695 |
+
def forward(self, x):
|
| 1696 |
+
x = self.pooler(x)
|
| 1697 |
+
x = self.seq_relationship(x)
|
| 1698 |
+
return x
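
# --- Illustrative sketch (not part of the upstream file): the ITM head pools the multimodal sequence
# and produces two logits (no-match / match) per example. The shapes below are assumptions for
# demonstration only; this hypothetical helper is never called by the model.
def _example_flava_itm_head():
    config = FlavaMultimodalConfig()
    head = FlavaITMHead(config)
    multimodal_hidden = torch.randn(4, 205, config.hidden_size)
    itm_logits = head(multimodal_hidden)  # scored from the pooled first token
    return itm_logits.shape  # torch.Size([4, 2])
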
class FlavaGlobalContrastiveHead(nn.Module):
|
| 1702 |
+
def __init__(self, config):
|
| 1703 |
+
super().__init__()
|
| 1704 |
+
self.config = config
|
| 1705 |
+
self.global_backprop_contrastive = config.global_backprop_contrastive
|
| 1706 |
+
|
| 1707 |
+
def forward(self, image_embeddings, text_embeddings, logit_scale):
|
| 1708 |
+
temperature = torch.exp(logit_scale)
|
| 1709 |
+
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
|
| 1710 |
+
labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
|
| 1711 |
+
image_embeddings_all = [image_embeddings]
|
| 1712 |
+
text_embeddings_all = [text_embeddings]
|
| 1713 |
+
else:
|
| 1714 |
+
local_batch_size = image_embeddings.size(0)
|
| 1715 |
+
world_size = torch.distributed.get_world_size()
|
| 1716 |
+
|
| 1717 |
+
if self.global_backprop_contrastive:
|
| 1718 |
+
# `torch.distributed.nn.functional.all_gather` does backprop on all active workers
|
| 1719 |
+
# whereas `torch.distributed.all_gather` only backpropagates on the current worker.
|
| 1720 |
+
image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
|
| 1721 |
+
text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
|
| 1722 |
+
else:
|
| 1723 |
+
image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
|
| 1724 |
+
text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
|
| 1725 |
+
torch.distributed.all_gather(image_embeddings_all, image_embeddings)
|
| 1726 |
+
torch.distributed.all_gather(text_embeddings_all, text_embeddings)
|
| 1727 |
+
|
| 1728 |
+
labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
|
| 1729 |
+
local_batch_size, device=image_embeddings.device
|
| 1730 |
+
)
|
| 1731 |
+
|
| 1732 |
+
image_embeddings_all = torch.cat(image_embeddings_all)
|
| 1733 |
+
text_embeddings_all = torch.cat(text_embeddings_all)
|
| 1734 |
+
|
| 1735 |
+
logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
|
| 1736 |
+
logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
|
| 1737 |
+
|
| 1738 |
+
return logits_per_image, logits_per_text, labels
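
# --- Illustrative sketch (not part of the upstream file): turning the head's logits into the symmetric
# contrastive loss, assuming a single process (torch.distributed not initialized). The batch size,
# random embeddings and helper name are assumptions for demonstration only; never called by the model.
def _example_global_contrastive_loss():
    config = FlavaConfig()
    head = FlavaGlobalContrastiveHead(config)
    image_emb = nn.functional.normalize(torch.randn(8, config.projection_dim), dim=-1)
    text_emb = nn.functional.normalize(torch.randn(8, config.projection_dim), dim=-1)
    logit_scale = torch.tensor(config.logit_scale_init_value)
    logits_per_image, logits_per_text, labels = head(image_emb, text_emb, logit_scale)
    loss_fn = nn.CrossEntropyLoss()
    return (loss_fn(logits_per_image, labels) + loss_fn(logits_per_text, labels)) / 2
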
@add_start_docstrings(
|
| 1742 |
+
"""
|
| 1743 |
+
The FLAVA model for pretraining, which outputs losses, embeddings, logits and transformer outputs.
|
| 1744 |
+
""",
|
| 1745 |
+
FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
|
| 1746 |
+
)
|
| 1747 |
+
class FlavaForPreTraining(FlavaPreTrainedModel):
|
| 1748 |
+
# Those are linked to xxx.bias
|
| 1749 |
+
_tied_weights_keys = [
|
| 1750 |
+
"mmm_text_head.decoder.bias",
|
| 1751 |
+
"mmm_image_head.decoder.bias",
|
| 1752 |
+
"mlm_head.decoder.bias",
|
| 1753 |
+
"mim_head.decoder.bias",
|
| 1754 |
+
]
|
| 1755 |
+
|
| 1756 |
+
def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
|
| 1757 |
+
super().__init__(config)
|
| 1758 |
+
self.flava = FlavaModel(config)
|
| 1759 |
+
|
| 1760 |
+
self.image_codebook = image_codebook
|
| 1761 |
+
if self.image_codebook is None and config.init_codebook:
|
| 1762 |
+
self.image_codebook = FlavaImageCodebook(config.image_codebook_config)
|
| 1763 |
+
|
| 1764 |
+
# Leverage text and image encoder configs to create the masked
|
| 1765 |
+
# head since it has the right vocab
|
| 1766 |
+
self.mim_head = FlavaMaskedPredictionHead(config.image_config)
|
| 1767 |
+
self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
|
| 1768 |
+
self.itm_head = FlavaITMHead(config)
|
| 1769 |
+
self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
|
| 1770 |
+
self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
|
| 1771 |
+
self.global_contrastive_head = FlavaGlobalContrastiveHead(config)
|
| 1772 |
+
|
| 1773 |
+
self.image_vocab_size = config.image_config.vocab_size
|
| 1774 |
+
self.text_vocab_size = config.text_config.vocab_size
|
| 1775 |
+
self.mlm_weight = config.mlm_weight
|
| 1776 |
+
self.mim_weight = config.mim_weight
|
| 1777 |
+
self.global_contrastive_weight = config.global_contrastive_weight
|
| 1778 |
+
self.ce_ignore_index = config.ce_ignore_index
|
| 1779 |
+
self.itm_weight = config.itm_weight
|
| 1780 |
+
self.mmm_image_weight = config.mmm_image_weight
|
| 1781 |
+
self.mmm_text_weight = config.mmm_text_weight
|
| 1782 |
+
self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
|
| 1783 |
+
|
| 1784 |
+
self.post_init()
|
| 1785 |
+
|
| 1786 |
+
def _resize_to_2d(self, x: torch.Tensor):
|
| 1787 |
+
if x.dim() > 2:
|
| 1788 |
+
x = x.view(x.size(0), -1)
|
| 1789 |
+
return x
|
| 1790 |
+
|
| 1791 |
+
@add_start_docstrings_to_model_forward(
|
| 1792 |
+
FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
|
| 1793 |
+
)
|
| 1794 |
+
@replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig)
|
| 1795 |
+
def forward(
|
| 1796 |
+
self,
|
| 1797 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1798 |
+
input_ids_masked: Optional[torch.LongTensor] = None,
|
| 1799 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1800 |
+
codebook_pixel_values: Optional[torch.FloatTensor] = None,
|
| 1801 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1802 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1803 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
| 1804 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1805 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
| 1806 |
+
skip_unmasked_multimodal_encoder: Optional[bool] = None,
|
| 1807 |
+
mlm_labels: Optional[torch.Tensor] = None,
|
| 1808 |
+
mim_labels: Optional[torch.Tensor] = None,
|
| 1809 |
+
itm_labels: Optional[torch.Tensor] = None,
|
| 1810 |
+
output_attentions: Optional[bool] = None,
|
| 1811 |
+
output_hidden_states: bool = True,
|
| 1812 |
+
return_dict: Optional[bool] = None,
|
| 1813 |
+
return_loss: Optional[bool] = None,
|
| 1814 |
+
) -> Union[Tuple[torch.Tensor], FlavaForPreTrainingOutput]:
|
| 1815 |
+
"""
|
| 1816 |
+
Examples:
|
| 1817 |
+
```python
|
| 1818 |
+
>>> from PIL import Image
|
| 1819 |
+
>>> import requests
|
| 1820 |
+
>>> from transformers import FlavaForPreTraining, AutoProcessor
|
| 1821 |
+
|
| 1822 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1823 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1824 |
+
|
| 1825 |
+
>>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
|
| 1826 |
+
>>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
|
| 1827 |
+
|
| 1828 |
+
>>> text = ["a photo of a cat"]
|
| 1829 |
+
|
| 1830 |
+
>>> inputs = processor(
|
| 1831 |
+
... images=[image],
|
| 1832 |
+
... text=text,
|
| 1833 |
+
... return_masks=True,
|
| 1834 |
+
... return_codebook_pixels=True,
|
| 1835 |
+
... padding=True,
|
| 1836 |
+
... max_length=77,
|
| 1837 |
+
... return_tensors="pt",
|
| 1838 |
+
... )
|
| 1839 |
+
|
| 1840 |
+
|
| 1841 |
+
>>> output = model(**inputs)
|
| 1842 |
+
```
|
| 1843 |
+
|
| 1844 |
+
Return:
|
| 1845 |
+
|
| 1846 |
+
"""
|
| 1847 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1848 |
+
return_loss = return_loss if return_loss is not None else self.config.return_loss
|
| 1849 |
+
|
| 1850 |
+
skip_unmasked_multimodal_encoder = (
|
| 1851 |
+
skip_unmasked_multimodal_encoder
|
| 1852 |
+
if skip_unmasked_multimodal_encoder is not None
|
| 1853 |
+
else self.skip_unmasked_multimodal_encoder
|
| 1854 |
+
)
|
| 1855 |
+
|
| 1856 |
+
if input_ids_masked is None and input_ids is not None:
|
| 1857 |
+
logger.warning(
|
| 1858 |
+
"`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to"
|
| 1859 |
+
" `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if"
|
| 1860 |
+
" you are doing inference on unmasked text..."
|
| 1861 |
+
)
|
| 1862 |
+
input_ids_masked = input_ids
|
| 1863 |
+
|
| 1864 |
+
flava_output = self.flava(
|
| 1865 |
+
input_ids=input_ids,
|
| 1866 |
+
pixel_values=pixel_values,
|
| 1867 |
+
attention_mask=attention_mask,
|
| 1868 |
+
token_type_ids=token_type_ids,
|
| 1869 |
+
position_ids=position_ids,
|
| 1870 |
+
image_attention_mask=image_attention_mask,
|
| 1871 |
+
# Don't need unmasked multimodal embedding for anything so skip it
|
| 1872 |
+
# NOTE: ITM uses masked version
|
| 1873 |
+
skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
|
| 1874 |
+
output_attentions=output_attentions,
|
| 1875 |
+
output_hidden_states=output_hidden_states,
|
| 1876 |
+
# Pass true to have deterministic outputs
|
| 1877 |
+
return_dict=True,
|
| 1878 |
+
)
|
| 1879 |
+
|
| 1880 |
+
flava_masked_output = self.flava(
|
| 1881 |
+
input_ids=input_ids_masked,
|
| 1882 |
+
pixel_values=pixel_values,
|
| 1883 |
+
attention_mask=attention_mask,
|
| 1884 |
+
token_type_ids=token_type_ids,
|
| 1885 |
+
image_attention_mask=image_attention_mask,
|
| 1886 |
+
bool_masked_pos=bool_masked_pos,
|
| 1887 |
+
output_attentions=output_attentions,
|
| 1888 |
+
output_hidden_states=output_hidden_states,
|
| 1889 |
+
return_dict=True,
|
| 1890 |
+
)
|
| 1891 |
+
|
| 1892 |
+
pos_mask = None
|
| 1893 |
+
|
| 1894 |
+
image_embeddings = flava_output.image_embeddings
|
| 1895 |
+
text_embeddings = flava_output.text_embeddings
|
| 1896 |
+
image_masked_embeddings = flava_masked_output.image_embeddings
|
| 1897 |
+
text_masked_embeddings = flava_masked_output.text_embeddings
|
| 1898 |
+
multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
|
| 1899 |
+
|
| 1900 |
+
total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
|
| 1901 |
+
mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
|
| 1902 |
+
itm_logits = logits_per_image = logits_per_text = None
|
| 1903 |
+
|
| 1904 |
+
# Calculate mim_labels if necessary from the image_codebook
|
| 1905 |
+
if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
|
| 1906 |
+
if mim_labels is None and return_loss:
|
| 1907 |
+
if self.image_codebook is None:
|
| 1908 |
+
raise RuntimeError(
|
| 1909 |
+
"`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
|
| 1910 |
+
" have been passed. Reinstantiate the model with `init_codebook` set to True or "
|
| 1911 |
+
"pass in your custom `mim_labels`"
|
| 1912 |
+
)
|
| 1913 |
+
if codebook_pixel_values is None:
|
| 1914 |
+
raise ValueError(
|
| 1915 |
+
"`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. "
|
| 1916 |
+
"Call `AutoProcessor` with `return_codebook_pixels` set to True"
|
| 1917 |
+
)
|
| 1918 |
+
mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values)
|
| 1919 |
+
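        # Note: the masked-prediction blocks below all follow the same pattern: labels are
        # flattened to (batch_size, seq_len), positions that were not masked are set to
        # `ce_ignore_index`, and only the masked positions are kept before the prediction
        # head runs, so cross-entropy is computed over masked tokens/patches only.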
        # Unimodal MIM Loss
        # If multimodal embeddings are present, we will calculate MMM loss
        if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
            sequence_for_image = image_masked_embeddings

            if mim_labels is not None:
                mim_labels = self._resize_to_2d(mim_labels)
                bool_masked_pos = self._resize_to_2d(bool_masked_pos)
                mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index

                sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
                masked_tokens = mim_labels.ne(self.ce_ignore_index)
                mim_labels_filtered = mim_labels[masked_tokens]
                sequence_for_image = sequence_for_image[masked_tokens, :]
                mim_logits = self.mim_head(sequence_for_image)
                if return_loss:
                    mim_loss = nn.functional.cross_entropy(
                        mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
                    )
                    mim_loss *= self.mim_weight
            else:
                mim_logits = self.mim_head(sequence_for_image)

        # Unimodal MLM Loss
        if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
            sequence_for_text = text_masked_embeddings
            if mlm_labels is not None:
                mlm_labels = self._resize_to_2d(mlm_labels)
                sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
                masked_tokens = mlm_labels.ne(self.ce_ignore_index)
                mlm_labels_filtered = mlm_labels[masked_tokens]
                sequence_for_text = sequence_for_text[masked_tokens, :]
                mlm_logits = self.mlm_head(sequence_for_text)
                if return_loss:
                    mlm_loss = nn.functional.cross_entropy(
                        mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
                    )
                    mlm_loss *= self.mlm_weight
            else:
                mlm_logits = self.mlm_head(sequence_for_text)

        # ITM Loss
        if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
            itm_logits = self.itm_head(multimodal_masked_embeddings)

            if itm_labels is not None:
                pos_pairs = itm_labels.ne(0)
                pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
                if return_loss:
                    itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
                    itm_loss *= self.itm_weight

                if multimodal_masked_embeddings is not None:
                    multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]

                if mlm_labels is not None:
                    mlm_labels = mlm_labels[pos_mask]

                if mim_labels is not None:
                    mim_labels = mim_labels[pos_mask]
                    bool_masked_pos = bool_masked_pos[pos_mask]

        # MMM Image Loss
        if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
            sequence_for_image = multimodal_masked_embeddings
            end_index = image_masked_embeddings.size(1) - 1
            sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]

            if mim_labels is not None:
                mim_labels = self._resize_to_2d(mim_labels)
                bool_masked_pos = self._resize_to_2d(bool_masked_pos)
                mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index

                masked_tokens = mim_labels.ne(self.ce_ignore_index)
                mim_labels_filtered = mim_labels[masked_tokens]
                sequence_for_image = sequence_for_image[masked_tokens, :]
                mmm_image_logits = self.mmm_image_head(sequence_for_image)
                if return_loss:
                    mmm_image_loss = nn.functional.cross_entropy(
                        mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
                    )
                    mmm_image_loss *= self.mmm_image_weight
            else:
                mmm_image_logits = self.mmm_image_head(sequence_for_image)

        # MMM Text Loss
        if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
            sequence_for_text = multimodal_masked_embeddings
            sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]

            if mlm_labels is not None:
                mlm_labels = self._resize_to_2d(mlm_labels)
                masked_tokens = mlm_labels.ne(self.ce_ignore_index)
                mlm_labels_filtered = mlm_labels[masked_tokens]
                sequence_for_text = sequence_for_text[masked_tokens, :]
                mmm_text_logits = self.mmm_text_head(sequence_for_text)
                if return_loss:
                    mmm_text_loss = nn.functional.cross_entropy(
                        mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
                    )
                    mmm_text_loss *= self.mmm_text_weight
            else:
                mmm_text_logits = self.mmm_text_head(sequence_for_text)

        # Global Contrastive Loss
        if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
            text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
            text_embedding = nn.functional.normalize(text_embedding, dim=-1)

            image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
            image_embedding = nn.functional.normalize(image_embedding, dim=-1)

            self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)

            logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
                image_embedding, text_embedding, self.flava.logit_scale
            )

            # Apply ITM negative mask if any
            if pos_mask is not None:
                logits_per_image = logits_per_image[pos_mask]
                logits_per_text = logits_per_text[pos_mask]
                gc_labels = gc_labels[pos_mask]

            if return_loss:
                gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
                gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
                gc_loss = (gc_loss_image + gc_loss_text) / 2
                gc_loss *= self.global_contrastive_weight

        flava_losses = FlavaLosses(
            mim=mim_loss,
            mlm=mlm_loss,
            itm=itm_loss,
            global_contrastive=gc_loss,
            mmm_image=mmm_image_loss,
            mmm_text=mmm_text_loss,
        )

        if return_loss and not flava_losses.all_none():
            total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())

        if not return_dict:
            output = (
                image_embeddings,
                flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
                text_embeddings,
                flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
                flava_output.multimodal_embeddings,
                flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
                image_masked_embeddings,
                flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
                text_masked_embeddings,
                flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
                multimodal_masked_embeddings,
                flava_masked_output.multimodal_output.to_tuple()
                if flava_masked_output.multimodal_output is not None
                else None,
                mim_logits,
                mlm_logits,
                itm_logits,
                logits_per_image,
                logits_per_text,
                mmm_image_logits,
                mmm_text_logits,
            )
            if return_loss and not flava_losses.all_none():
                output = (
                    total_loss,
                    flava_losses,
                ) + output

            # Filter None as transformer by default won't handle it
            return tuple(x for x in output if x is not None)

        return FlavaForPreTrainingOutput(
            loss=total_loss,
            loss_info=flava_losses,
            image_embeddings=image_embeddings,
            image_output=flava_output.image_output,
            text_embeddings=text_embeddings,
            text_output=flava_output.text_output,
            multimodal_embeddings=flava_output.multimodal_embeddings,
            multimodal_output=flava_output.multimodal_output,
            image_masked_embeddings=image_masked_embeddings,
            image_masked_output=flava_masked_output.image_output,
            text_masked_embeddings=text_masked_embeddings,
            text_masked_output=flava_masked_output.text_output,
            multimodal_masked_embeddings=multimodal_masked_embeddings,
            multimodal_masked_output=flava_masked_output.multimodal_output,
            mim_logits=mim_logits,
            mlm_logits=mlm_logits,
            itm_logits=itm_logits,
            contrastive_logits_per_image=logits_per_image,
            contrastive_logits_per_text=logits_per_text,
            mmm_image_logits=mmm_image_logits,
            mmm_text_logits=mmm_text_logits,
        )


__all__ = [
    "FlavaForPreTraining",
    "FlavaImageCodebook",
    "FlavaImageModel",
    "FlavaModel",
    "FlavaMultimodalModel",
    "FlavaPreTrainedModel",
    "FlavaTextModel",
]
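As a rough usage sketch (editor's addition, not from the upstream sources): the snippet below mirrors the docstring example above, runs `FlavaForPreTraining` in inference mode with `return_loss=False`, and inspects which fields of `FlavaForPreTrainingOutput` are populated. The checkpoint name and processor arguments are taken from the docstring; the comments about which logits get filled assume the default config weights.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, FlavaForPreTraining

model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
processor = AutoProcessor.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    images=[image],
    text=["a photo of a cat"],
    return_codebook_pixels=True,
    padding=True,
    max_length=77,
    return_tensors="pt",
)

with torch.no_grad():
    # return_loss=False skips label generation, so no codebook indices are required here.
    output = model(**inputs, return_loss=False)

# With both modalities present, the ITM and global contrastive heads produce logits;
# the loss fields stay None because no labels or masked inputs were supplied.
print(output.itm_logits.shape)
print(output.contrastive_logits_per_image.shape)
print(output.loss, output.loss_info)
```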
docs/transformers/build/lib/transformers/models/flava/processing_flava.py
ADDED
@@ -0,0 +1,168 @@
# coding=utf-8
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for FLAVA
"""

import warnings
from typing import List, Optional, Union

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType


class FlavaProcessor(ProcessorMixin):
    r"""
    Constructs a FLAVA processor which wraps a FLAVA image processor and a FLAVA tokenizer into a single processor.

    [`FlavaProcessor`] offers all the functionalities of [`FlavaImageProcessor`] and [`BertTokenizerFast`]. See the
    [`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information.

    Args:
        image_processor ([`FlavaImageProcessor`], *optional*): The image processor is a required input.
        tokenizer ([`BertTokenizerFast`], *optional*): The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "FlavaImageProcessor"
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_image_mask: Optional[bool] = None,
        return_codebook_pixels: Optional[bool] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        """
        This method uses [`FlavaImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.
        """

        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        if text is not None:
            encoding = self.tokenizer(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                return_tensors=return_tensors,
                **kwargs,
            )
        if images is not None:
            image_features = self.image_processor(
                images,
                return_image_mask=return_image_mask,
                return_codebook_pixels=return_codebook_pixels,
                return_tensors=return_tensors,
                **kwargs,
            )

        if text is not None and images is not None:
            encoding.update(image_features)
            return encoding
        elif text is not None:
            return encoding
        else:
            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    @property
    def feature_extractor_class(self):
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
            FutureWarning,
        )
        return self.image_processor_class

    @property
    def feature_extractor(self):
        warnings.warn(
            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
            FutureWarning,
        )
        return self.image_processor


__all__ = ["FlavaProcessor"]
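A short usage sketch (editor's addition, not from the upstream sources): it loads the processor from the same checkpoint used in the modeling docstring, passes both images and text, and prints the returned keys so the mapping onto `FlavaForPreTraining.forward` arguments is easy to see. The exact keys depend on the flags passed and on the checkpoint's tokenizer and image-processor configuration.

```python
import requests
from PIL import Image
from transformers import FlavaProcessor

processor = FlavaProcessor.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Text only returns tokenizer features, images only returns image features,
# and both together are merged into a single BatchEncoding.
batch = processor(
    images=[image],
    text=["a photo of a cat"],
    return_image_mask=True,       # ask the image processor for the masked-image-modeling patch mask
    return_codebook_pixels=True,  # ask for the pixels consumed by FlavaImageCodebook
    padding=True,
    return_tensors="pt",
)
print(sorted(batch.keys()))
print(processor.model_input_names)
```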