Add files using upload-large-folder tool
- docs/transformers/build/lib/transformers/models/roc_bert/tokenization_roc_bert.py +1122 -0
- docs/transformers/build/lib/transformers/models/roformer/modeling_flax_roformer.py +1091 -0
- docs/transformers/build/lib/transformers/models/roformer/modeling_tf_roformer.py +1547 -0
- docs/transformers/build/lib/transformers/models/roformer/tokenization_roformer.py +540 -0
- docs/transformers/build/lib/transformers/models/roformer/tokenization_roformer_fast.py +180 -0
- docs/transformers/build/lib/transformers/models/roformer/tokenization_utils.py +68 -0
- docs/transformers/build/lib/transformers/models/rt_detr/__init__.py +33 -0
- docs/transformers/build/lib/transformers/models/rt_detr/configuration_rt_detr.py +364 -0
- docs/transformers/build/lib/transformers/models/rt_detr/configuration_rt_detr_resnet.py +114 -0
- docs/transformers/build/lib/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py +782 -0
- docs/transformers/build/lib/transformers/models/rt_detr/image_processing_rt_detr.py +1102 -0
- docs/transformers/build/lib/transformers/models/rt_detr/image_processing_rt_detr_fast.py +608 -0
- docs/transformers/build/lib/transformers/models/rt_detr/modeling_rt_detr.py +0 -0
- docs/transformers/build/lib/transformers/models/rt_detr/modeling_rt_detr_resnet.py +440 -0
- docs/transformers/build/lib/transformers/models/rt_detr/modular_rt_detr.py +410 -0
- docs/transformers/build/lib/transformers/models/rt_detr_v2/__init__.py +29 -0
- docs/transformers/build/lib/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py +363 -0
- docs/transformers/build/lib/transformers/models/rt_detr_v2/modular_rt_detr_v2.py +628 -0
- docs/transformers/build/lib/transformers/models/rwkv/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/rwkv/configuration_rwkv.py +120 -0
docs/transformers/build/lib/transformers/models/roc_bert/tokenization_roc_bert.py
ADDED
@@ -0,0 +1,1122 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 WeChatAI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for RoCBert."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import itertools
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import unicodedata
|
| 22 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
| 25 |
+
from ...tokenization_utils_base import (
|
| 26 |
+
ENCODE_KWARGS_DOCSTRING,
|
| 27 |
+
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
|
| 28 |
+
BatchEncoding,
|
| 29 |
+
EncodedInput,
|
| 30 |
+
EncodedInputPair,
|
| 31 |
+
PaddingStrategy,
|
| 32 |
+
PreTokenizedInput,
|
| 33 |
+
PreTokenizedInputPair,
|
| 34 |
+
TensorType,
|
| 35 |
+
TextInput,
|
| 36 |
+
TextInputPair,
|
| 37 |
+
TruncationStrategy,
|
| 38 |
+
)
|
| 39 |
+
from ...utils import add_end_docstrings, logging
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
VOCAB_FILES_NAMES = {
|
| 45 |
+
"vocab_file": "vocab.txt",
|
| 46 |
+
"word_shape_file": "word_shape.json",
|
| 47 |
+
"word_pronunciation_file": "word_pronunciation.json",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Copied from transformers.models.bert.tokenization_bert.load_vocab
|
| 52 |
+
def load_vocab(vocab_file):
|
| 53 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 54 |
+
vocab = collections.OrderedDict()
|
| 55 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 56 |
+
tokens = reader.readlines()
|
| 57 |
+
for index, token in enumerate(tokens):
|
| 58 |
+
token = token.rstrip("\n")
|
| 59 |
+
vocab[token] = index
|
| 60 |
+
return vocab
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
|
| 64 |
+
def whitespace_tokenize(text):
|
| 65 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
| 66 |
+
text = text.strip()
|
| 67 |
+
if not text:
|
| 68 |
+
return []
|
| 69 |
+
tokens = text.split()
|
| 70 |
+
return tokens
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class RoCBertTokenizer(PreTrainedTokenizer):
|
| 74 |
+
r"""
|
| 75 |
+
Args:
|
| 76 |
+
Construct a RoCBert tokenizer. Based on WordPiece. This tokenizer inherits from [`PreTrainedTokenizer`] which
|
| 77 |
+
contains most of the main methods. Users should refer to this superclass for more information regarding those
|
| 78 |
+
methods.
|
| 79 |
+
vocab_file (`str`):
|
| 80 |
+
File containing the vocabulary.
|
| 81 |
+
word_shape_file (`str`):
|
| 82 |
+
File containing the word => shape info.
|
| 83 |
+
word_pronunciation_file (`str`):
|
| 84 |
+
File containing the word => pronunciation info.
|
| 85 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
Whether or not to lowercase the input when tokenizing.
|
| 87 |
+
do_basic_tokenize (`bool`, *optional*, defaults to `True`):
|
| 88 |
+
Whether or not to do basic tokenization before WordPiece.
|
| 89 |
+
never_split (`Iterable`, *optional*):
|
| 90 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
| 91 |
+
`do_basic_tokenize=True`
|
| 92 |
+
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
|
| 93 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 94 |
+
token instead.
|
| 95 |
+
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 96 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 97 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 98 |
+
token of a sequence built with special tokens.
|
| 99 |
+
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
|
| 100 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 101 |
+
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 102 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 103 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 104 |
+
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
| 105 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 106 |
+
modeling. This is the token which the model will try to predict.
|
| 107 |
+
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
| 108 |
+
Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this
|
| 109 |
+
[issue](https://github.com/huggingface/transformers/issues/328)).
|
| 110 |
+
strip_accents (`bool`, *optional*):
|
| 111 |
+
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
| 112 |
+
value for `lowercase` (as in the original BERT).
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
vocab_file,
|
| 120 |
+
word_shape_file,
|
| 121 |
+
word_pronunciation_file,
|
| 122 |
+
do_lower_case=True,
|
| 123 |
+
do_basic_tokenize=True,
|
| 124 |
+
never_split=None,
|
| 125 |
+
unk_token="[UNK]",
|
| 126 |
+
sep_token="[SEP]",
|
| 127 |
+
pad_token="[PAD]",
|
| 128 |
+
cls_token="[CLS]",
|
| 129 |
+
mask_token="[MASK]",
|
| 130 |
+
tokenize_chinese_chars=True,
|
| 131 |
+
strip_accents=None,
|
| 132 |
+
**kwargs,
|
| 133 |
+
):
|
| 134 |
+
for cur_file in [vocab_file, word_shape_file, word_pronunciation_file]:
|
| 135 |
+
if cur_file is None or not os.path.isfile(cur_file):
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google "
|
| 138 |
+
"pretrained model use `tokenizer = RoCBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
self.vocab = load_vocab(vocab_file)
|
| 142 |
+
|
| 143 |
+
with open(word_shape_file, "r", encoding="utf8") as in_file:
|
| 144 |
+
self.word_shape = json.load(in_file)
|
| 145 |
+
|
| 146 |
+
with open(word_pronunciation_file, "r", encoding="utf8") as in_file:
|
| 147 |
+
self.word_pronunciation = json.load(in_file)
|
| 148 |
+
|
| 149 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 150 |
+
|
| 151 |
+
self.do_basic_tokenize = do_basic_tokenize
|
| 152 |
+
if do_basic_tokenize:
|
| 153 |
+
self.basic_tokenizer = RoCBertBasicTokenizer(
|
| 154 |
+
do_lower_case=do_lower_case,
|
| 155 |
+
never_split=never_split,
|
| 156 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
| 157 |
+
strip_accents=strip_accents,
|
| 158 |
+
)
|
| 159 |
+
self.wordpiece_tokenizer = RoCBertWordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
|
| 160 |
+
super().__init__(
|
| 161 |
+
do_lower_case=do_lower_case,
|
| 162 |
+
do_basic_tokenize=do_basic_tokenize,
|
| 163 |
+
never_split=never_split,
|
| 164 |
+
unk_token=unk_token,
|
| 165 |
+
sep_token=sep_token,
|
| 166 |
+
pad_token=pad_token,
|
| 167 |
+
cls_token=cls_token,
|
| 168 |
+
mask_token=mask_token,
|
| 169 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
| 170 |
+
strip_accents=strip_accents,
|
| 171 |
+
**kwargs,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
@property
|
| 175 |
+
def do_lower_case(self):
|
| 176 |
+
return self.basic_tokenizer.do_lower_case
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def vocab_size(self):
|
| 180 |
+
return len(self.vocab)
|
| 181 |
+
|
| 182 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
|
| 183 |
+
def get_vocab(self):
|
| 184 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 185 |
+
|
| 186 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
| 187 |
+
def _tokenize(self, text, split_special_tokens=False):
|
| 188 |
+
split_tokens = []
|
| 189 |
+
if self.do_basic_tokenize:
|
| 190 |
+
for token in self.basic_tokenizer.tokenize(
|
| 191 |
+
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
| 192 |
+
):
|
| 193 |
+
# If the token is part of the never_split set
|
| 194 |
+
if token in self.basic_tokenizer.never_split:
|
| 195 |
+
split_tokens.append(token)
|
| 196 |
+
else:
|
| 197 |
+
split_tokens += self.wordpiece_tokenizer.tokenize(token)
|
| 198 |
+
else:
|
| 199 |
+
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
| 200 |
+
return split_tokens
|
| 201 |
+
|
| 202 |
+
def _encode_plus(
|
| 203 |
+
self,
|
| 204 |
+
text: Union[TextInput, PreTokenizedInput, EncodedInput],
|
| 205 |
+
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
|
| 206 |
+
add_special_tokens: bool = True,
|
| 207 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 208 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 209 |
+
max_length: Optional[int] = None,
|
| 210 |
+
stride: int = 0,
|
| 211 |
+
is_split_into_words: bool = False,
|
| 212 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 213 |
+
padding_side: Optional[str] = None,
|
| 214 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 215 |
+
return_token_type_ids: Optional[bool] = None,
|
| 216 |
+
return_attention_mask: Optional[bool] = None,
|
| 217 |
+
return_overflowing_tokens: bool = False,
|
| 218 |
+
return_special_tokens_mask: bool = False,
|
| 219 |
+
return_offsets_mapping: bool = False,
|
| 220 |
+
return_length: bool = False,
|
| 221 |
+
verbose: bool = True,
|
| 222 |
+
**kwargs,
|
| 223 |
+
) -> BatchEncoding:
|
| 224 |
+
def get_input_ids(text):
|
| 225 |
+
if isinstance(text, str):
|
| 226 |
+
tokens = self.tokenize(text, **kwargs)
|
| 227 |
+
tokens_ids = self.convert_tokens_to_ids(tokens)
|
| 228 |
+
tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens)
|
| 229 |
+
tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens)
|
| 230 |
+
return tokens_ids, tokens_shape_ids, tokens_proun_ids
|
| 231 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
| 232 |
+
if is_split_into_words:
|
| 233 |
+
tokens = list(
|
| 234 |
+
itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
|
| 235 |
+
)
|
| 236 |
+
tokens_ids = self.convert_tokens_to_ids(tokens)
|
| 237 |
+
tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens)
|
| 238 |
+
tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens)
|
| 239 |
+
return tokens_ids, tokens_shape_ids, tokens_proun_ids
|
| 240 |
+
else:
|
| 241 |
+
tokens_ids = self.convert_tokens_to_ids(text)
|
| 242 |
+
tokens_shape_ids = self.convert_tokens_to_shape_ids(text)
|
| 243 |
+
tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(text)
|
| 244 |
+
return tokens_ids, tokens_shape_ids, tokens_proun_ids
|
| 245 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
| 246 |
+
return text, [0] * len(text), [0] * len(text) # shape and proun id is pad_value
|
| 247 |
+
else:
|
| 248 |
+
if is_split_into_words:
|
| 249 |
+
raise ValueError(
|
| 250 |
+
f"Input {text} is not valid. Should be a string or a list/tuple of strings when"
|
| 251 |
+
" `is_split_into_words=True`."
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError(
|
| 255 |
+
f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of"
|
| 256 |
+
" integers."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if return_offsets_mapping:
|
| 260 |
+
raise NotImplementedError(
|
| 261 |
+
"return_offset_mapping is not available when using Python tokenizers. "
|
| 262 |
+
"To use this feature, change your tokenizer to one deriving from "
|
| 263 |
+
"transformers.PreTrainedTokenizerFast. "
|
| 264 |
+
"More information on available tokenizers at "
|
| 265 |
+
"https://github.com/huggingface/transformers/pull/2674"
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
first_ids, first_shape_ids, first_proun_ids = get_input_ids(text)
|
| 269 |
+
if text_pair is not None:
|
| 270 |
+
second_ids, second_shape_ids, second_proun_ids = get_input_ids(text_pair)
|
| 271 |
+
else:
|
| 272 |
+
second_ids, second_shape_ids, second_proun_ids = None, None, None
|
| 273 |
+
|
| 274 |
+
return self.prepare_for_model(
|
| 275 |
+
first_ids,
|
| 276 |
+
first_shape_ids,
|
| 277 |
+
first_proun_ids,
|
| 278 |
+
pair_ids=second_ids,
|
| 279 |
+
pair_shape_ids=second_shape_ids,
|
| 280 |
+
pair_pronunciation_ids=second_proun_ids,
|
| 281 |
+
add_special_tokens=add_special_tokens,
|
| 282 |
+
padding=padding_strategy.value,
|
| 283 |
+
truncation=truncation_strategy.value,
|
| 284 |
+
max_length=max_length,
|
| 285 |
+
stride=stride,
|
| 286 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 287 |
+
padding_side=padding_side,
|
| 288 |
+
return_tensors=return_tensors,
|
| 289 |
+
prepend_batch_axis=True,
|
| 290 |
+
return_attention_mask=return_attention_mask,
|
| 291 |
+
return_token_type_ids=return_token_type_ids,
|
| 292 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 293 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 294 |
+
return_length=return_length,
|
| 295 |
+
verbose=verbose,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
| 299 |
+
def prepare_for_model(
|
| 300 |
+
self,
|
| 301 |
+
ids: List[int],
|
| 302 |
+
shape_ids: List[int],
|
| 303 |
+
pronunciation_ids: List[int],
|
| 304 |
+
pair_ids: Optional[List[int]] = None,
|
| 305 |
+
pair_shape_ids: Optional[List[int]] = None,
|
| 306 |
+
pair_pronunciation_ids: Optional[List[int]] = None,
|
| 307 |
+
add_special_tokens: bool = True,
|
| 308 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
| 309 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
| 310 |
+
max_length: Optional[int] = None,
|
| 311 |
+
stride: int = 0,
|
| 312 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 313 |
+
padding_side: Optional[str] = None,
|
| 314 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 315 |
+
return_token_type_ids: Optional[bool] = None,
|
| 316 |
+
return_attention_mask: Optional[bool] = None,
|
| 317 |
+
return_overflowing_tokens: bool = False,
|
| 318 |
+
return_special_tokens_mask: bool = False,
|
| 319 |
+
return_offsets_mapping: bool = False,
|
| 320 |
+
return_length: bool = False,
|
| 321 |
+
verbose: bool = True,
|
| 322 |
+
prepend_batch_axis: bool = False,
|
| 323 |
+
**kwargs,
|
| 324 |
+
) -> BatchEncoding:
|
| 325 |
+
"""
|
| 326 |
+
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
|
| 327 |
+
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
| 328 |
+
manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
|
| 329 |
+
different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
|
| 330 |
+
overflowing tokens. Such a combination of arguments will raise an error.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
ids (`List[int]`):
|
| 334 |
+
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
|
| 335 |
+
`convert_tokens_to_id` methods.
|
| 336 |
+
shape_ids (`List[int]`):
|
| 337 |
+
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
|
| 338 |
+
`convert_token_to_shape_id` methods.
|
| 339 |
+
pronunciation_ids (`List[int]`):
|
| 340 |
+
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
|
| 341 |
+
`convert_token_to_pronunciation_id` methods.
|
| 342 |
+
pair_ids (`List[int]`, *optional*):
|
| 343 |
+
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
|
| 344 |
+
and `convert_tokens_to_id` methods.
|
| 345 |
+
pair_shape_ids (`List[int]`, *optional*):
|
| 346 |
+
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
|
| 347 |
+
and `convert_token_to_shape_id` methods.
|
| 348 |
+
pair_pronunciation_ids (`List[int]`, *optional*):
|
| 349 |
+
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
|
| 350 |
+
and `convert_token_to_pronunciation_id` methods.
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
| 354 |
+
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
| 355 |
+
padding=padding,
|
| 356 |
+
truncation=truncation,
|
| 357 |
+
max_length=max_length,
|
| 358 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 359 |
+
verbose=verbose,
|
| 360 |
+
**kwargs,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
pair = bool(pair_ids is not None)
|
| 364 |
+
len_ids = len(ids)
|
| 365 |
+
len_pair_ids = len(pair_ids) if pair else 0
|
| 366 |
+
|
| 367 |
+
if return_token_type_ids and not add_special_tokens:
|
| 368 |
+
raise ValueError(
|
| 369 |
+
"Asking to return token_type_ids while setting add_special_tokens to False "
|
| 370 |
+
"results in an undefined behavior. Please set add_special_tokens to True or "
|
| 371 |
+
"set return_token_type_ids to None."
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
if (
|
| 375 |
+
return_overflowing_tokens
|
| 376 |
+
and truncation_strategy == TruncationStrategy.LONGEST_FIRST
|
| 377 |
+
and pair_ids is not None
|
| 378 |
+
):
|
| 379 |
+
raise ValueError(
|
| 380 |
+
"Not possible to return overflowing tokens for pair of sequences with the "
|
| 381 |
+
"`longest_first`. Please select another truncation strategy than `longest_first`, "
|
| 382 |
+
"for instance `only_second` or `only_first`."
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Load from model defaults
|
| 386 |
+
if return_token_type_ids is None:
|
| 387 |
+
return_token_type_ids = "token_type_ids" in self.model_input_names
|
| 388 |
+
if return_attention_mask is None:
|
| 389 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
| 390 |
+
|
| 391 |
+
encoded_inputs = {}
|
| 392 |
+
|
| 393 |
+
# Compute the total size of the returned encodings
|
| 394 |
+
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
| 395 |
+
|
| 396 |
+
# Truncation: Handle max sequence length
|
| 397 |
+
overflowing_tokens = []
|
| 398 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
| 399 |
+
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
| 400 |
+
ids,
|
| 401 |
+
pair_ids=pair_ids,
|
| 402 |
+
num_tokens_to_remove=total_len - max_length,
|
| 403 |
+
truncation_strategy=truncation_strategy,
|
| 404 |
+
stride=stride,
|
| 405 |
+
)
|
| 406 |
+
shape_ids, pair_shape_ids, _ = self.truncate_sequences(
|
| 407 |
+
shape_ids,
|
| 408 |
+
pair_ids=pair_shape_ids,
|
| 409 |
+
num_tokens_to_remove=total_len - max_length,
|
| 410 |
+
truncation_strategy=truncation_strategy,
|
| 411 |
+
stride=stride,
|
| 412 |
+
)
|
| 413 |
+
pronunciation_ids, pair_pronunciation_ids, _ = self.truncate_sequences(
|
| 414 |
+
pronunciation_ids,
|
| 415 |
+
pair_ids=pair_pronunciation_ids,
|
| 416 |
+
num_tokens_to_remove=total_len - max_length,
|
| 417 |
+
truncation_strategy=truncation_strategy,
|
| 418 |
+
stride=stride,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
if return_overflowing_tokens:
|
| 422 |
+
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
| 423 |
+
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
| 424 |
+
|
| 425 |
+
# Add special tokens
|
| 426 |
+
if add_special_tokens:
|
| 427 |
+
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
| 428 |
+
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
| 429 |
+
input_shape_ids = self.build_inputs_with_special_tokens(
|
| 430 |
+
shape_ids, pair_shape_ids, self.word_shape["[UNK]"], self.word_shape["[UNK]"]
|
| 431 |
+
)
|
| 432 |
+
input_pronunciation_ids = self.build_inputs_with_special_tokens(
|
| 433 |
+
pronunciation_ids,
|
| 434 |
+
pair_pronunciation_ids,
|
| 435 |
+
self.word_pronunciation["[UNK]"],
|
| 436 |
+
self.word_pronunciation["[UNK]"],
|
| 437 |
+
)
|
| 438 |
+
else:
|
| 439 |
+
sequence = ids + pair_ids if pair_ids else ids
|
| 440 |
+
token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair_ids else [])
|
| 441 |
+
input_shape_ids = shape_ids + pair_shape_ids if pair_shape_ids else shape_ids
|
| 442 |
+
input_pronunciation_ids = (
|
| 443 |
+
pronunciation_ids + pair_pronunciation_ids if pair_pronunciation_ids else pronunciation_ids
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Build output dictionary
|
| 447 |
+
encoded_inputs["input_ids"] = sequence
|
| 448 |
+
encoded_inputs["input_shape_ids"] = input_shape_ids
|
| 449 |
+
encoded_inputs["input_pronunciation_ids"] = input_pronunciation_ids
|
| 450 |
+
if return_token_type_ids:
|
| 451 |
+
encoded_inputs["token_type_ids"] = token_type_ids
|
| 452 |
+
if return_special_tokens_mask:
|
| 453 |
+
if add_special_tokens:
|
| 454 |
+
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
| 455 |
+
else:
|
| 456 |
+
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
|
| 457 |
+
|
| 458 |
+
# Check lengths
|
| 459 |
+
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
| 460 |
+
|
| 461 |
+
# Padding
|
| 462 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
| 463 |
+
encoded_inputs = self.pad(
|
| 464 |
+
encoded_inputs,
|
| 465 |
+
max_length=max_length,
|
| 466 |
+
padding=padding_strategy.value,
|
| 467 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 468 |
+
padding_side=padding_side,
|
| 469 |
+
return_attention_mask=return_attention_mask,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
if return_length:
|
| 473 |
+
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
|
| 474 |
+
|
| 475 |
+
batch_outputs = BatchEncoding(
|
| 476 |
+
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
return batch_outputs
|
| 480 |
+
|
| 481 |
+
def _pad(
|
| 482 |
+
self,
|
| 483 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
| 484 |
+
max_length: Optional[int] = None,
|
| 485 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 486 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 487 |
+
padding_side: Optional[str] = None,
|
| 488 |
+
return_attention_mask: Optional[bool] = None,
|
| 489 |
+
) -> dict:
|
| 490 |
+
# Load from model defaults
|
| 491 |
+
if return_attention_mask is None:
|
| 492 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
| 493 |
+
|
| 494 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
| 495 |
+
|
| 496 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
| 497 |
+
max_length = len(required_input)
|
| 498 |
+
|
| 499 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 500 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 501 |
+
|
| 502 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
| 503 |
+
|
| 504 |
+
# Initialize attention mask if not present.
|
| 505 |
+
if return_attention_mask and "attention_mask" not in encoded_inputs:
|
| 506 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
| 507 |
+
|
| 508 |
+
if needs_to_be_padded:
|
| 509 |
+
difference = max_length - len(required_input)
|
| 510 |
+
padding_side = padding_side if padding_side is not None else self.padding_side
|
| 511 |
+
|
| 512 |
+
if padding_side == "right":
|
| 513 |
+
if return_attention_mask:
|
| 514 |
+
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
|
| 515 |
+
if "token_type_ids" in encoded_inputs:
|
| 516 |
+
encoded_inputs["token_type_ids"] = (
|
| 517 |
+
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
| 518 |
+
)
|
| 519 |
+
if "special_tokens_mask" in encoded_inputs:
|
| 520 |
+
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
| 521 |
+
for key in ["input_shape_ids", "input_pronunciation_ids"]:
|
| 522 |
+
if key in encoded_inputs:
|
| 523 |
+
encoded_inputs[key] = encoded_inputs[key] + [self.pad_token_id] * difference
|
| 524 |
+
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
| 525 |
+
elif padding_side == "left":
|
| 526 |
+
if return_attention_mask:
|
| 527 |
+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
| 528 |
+
if "token_type_ids" in encoded_inputs:
|
| 529 |
+
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
| 530 |
+
"token_type_ids"
|
| 531 |
+
]
|
| 532 |
+
if "special_tokens_mask" in encoded_inputs:
|
| 533 |
+
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
| 534 |
+
for key in ["input_shape_ids", "input_pronunciation_ids"]:
|
| 535 |
+
if key in encoded_inputs:
|
| 536 |
+
encoded_inputs[key] = [self.pad_token_id] * difference + encoded_inputs[key]
|
| 537 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
| 538 |
+
else:
|
| 539 |
+
raise ValueError("Invalid padding strategy:" + str(padding_side))
|
| 540 |
+
|
| 541 |
+
return encoded_inputs
|
| 542 |
+
|
| 543 |
+
def _batch_encode_plus(
|
| 544 |
+
self,
|
| 545 |
+
batch_text_or_text_pairs: Union[
|
| 546 |
+
List[TextInput],
|
| 547 |
+
List[TextInputPair],
|
| 548 |
+
List[PreTokenizedInput],
|
| 549 |
+
List[PreTokenizedInputPair],
|
| 550 |
+
List[EncodedInput],
|
| 551 |
+
List[EncodedInputPair],
|
| 552 |
+
],
|
| 553 |
+
add_special_tokens: bool = True,
|
| 554 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 555 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 556 |
+
max_length: Optional[int] = None,
|
| 557 |
+
stride: int = 0,
|
| 558 |
+
is_split_into_words: bool = False,
|
| 559 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 560 |
+
padding_side: Optional[str] = None,
|
| 561 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 562 |
+
return_token_type_ids: Optional[bool] = None,
|
| 563 |
+
return_attention_mask: Optional[bool] = None,
|
| 564 |
+
return_overflowing_tokens: bool = False,
|
| 565 |
+
return_special_tokens_mask: bool = False,
|
| 566 |
+
return_offsets_mapping: bool = False,
|
| 567 |
+
return_length: bool = False,
|
| 568 |
+
verbose: bool = True,
|
| 569 |
+
**kwargs,
|
| 570 |
+
) -> BatchEncoding:
|
| 571 |
+
def get_input_ids(text):
|
| 572 |
+
if isinstance(text, str):
|
| 573 |
+
tokens = self.tokenize(text, **kwargs)
|
| 574 |
+
tokens_ids = self.convert_tokens_to_ids(tokens)
|
| 575 |
+
tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens)
|
| 576 |
+
tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens)
|
| 577 |
+
return tokens_ids, tokens_shape_ids, tokens_proun_ids
|
| 578 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
| 579 |
+
if is_split_into_words:
|
| 580 |
+
tokens = list(
|
| 581 |
+
itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
|
| 582 |
+
)
|
| 583 |
+
tokens_ids = self.convert_tokens_to_ids(tokens)
|
| 584 |
+
tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens)
|
| 585 |
+
tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens)
|
| 586 |
+
return tokens_ids, tokens_shape_ids, tokens_proun_ids
|
| 587 |
+
else:
|
| 588 |
+
tokens_ids = self.convert_tokens_to_ids(text)
|
| 589 |
+
tokens_shape_ids = self.convert_tokens_to_shape_ids(text)
|
| 590 |
+
tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(text)
|
| 591 |
+
return tokens_ids, tokens_shape_ids, tokens_proun_ids
|
| 592 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
| 593 |
+
return text, [0] * len(text), [0] * len(text) # shape and proun id is pad_value
|
| 594 |
+
else:
|
| 595 |
+
raise ValueError(
|
| 596 |
+
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
if return_offsets_mapping:
|
| 600 |
+
raise NotImplementedError(
|
| 601 |
+
"return_offset_mapping is not available when using Python tokenizers. "
|
| 602 |
+
"To use this feature, change your tokenizer to one deriving from "
|
| 603 |
+
"transformers.PreTrainedTokenizerFast."
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
input_ids = []
|
| 607 |
+
input_shape_ids = []
|
| 608 |
+
input_pronunciation_ids = []
|
| 609 |
+
for ids_or_pair_ids in batch_text_or_text_pairs:
|
| 610 |
+
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 611 |
+
ids, pair_ids = ids_or_pair_ids, None
|
| 612 |
+
elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
|
| 613 |
+
ids, pair_ids = ids_or_pair_ids, None
|
| 614 |
+
else:
|
| 615 |
+
ids, pair_ids = ids_or_pair_ids
|
| 616 |
+
|
| 617 |
+
first_ids, first_shape_ids, first_proun_ids = get_input_ids(ids)
|
| 618 |
+
if pair_ids is not None:
|
| 619 |
+
second_ids, second_shape_ids, second_proun_ids = get_input_ids(pair_ids)
|
| 620 |
+
else:
|
| 621 |
+
second_ids, second_shape_ids, second_proun_ids = None, None, None
|
| 622 |
+
|
| 623 |
+
input_ids.append((first_ids, second_ids))
|
| 624 |
+
input_shape_ids.append((first_shape_ids, second_shape_ids))
|
| 625 |
+
input_pronunciation_ids.append((first_proun_ids, second_proun_ids))
|
| 626 |
+
|
| 627 |
+
batch_outputs = self._batch_prepare_for_model(
|
| 628 |
+
input_ids,
|
| 629 |
+
batch_shape_ids_pairs=input_shape_ids,
|
| 630 |
+
batch_pronunciation_ids_pairs=input_pronunciation_ids,
|
| 631 |
+
add_special_tokens=add_special_tokens,
|
| 632 |
+
padding_strategy=padding_strategy,
|
| 633 |
+
truncation_strategy=truncation_strategy,
|
| 634 |
+
max_length=max_length,
|
| 635 |
+
stride=stride,
|
| 636 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 637 |
+
padding_side=padding_side,
|
| 638 |
+
return_attention_mask=return_attention_mask,
|
| 639 |
+
return_token_type_ids=return_token_type_ids,
|
| 640 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 641 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 642 |
+
return_length=return_length,
|
| 643 |
+
return_tensors=return_tensors,
|
| 644 |
+
verbose=verbose,
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
return BatchEncoding(batch_outputs)
|
| 648 |
+
|
| 649 |
+
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
| 650 |
+
def _batch_prepare_for_model(
|
| 651 |
+
self,
|
| 652 |
+
batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
|
| 653 |
+
batch_shape_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
|
| 654 |
+
batch_pronunciation_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
|
| 655 |
+
add_special_tokens: bool = True,
|
| 656 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 657 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 658 |
+
max_length: Optional[int] = None,
|
| 659 |
+
stride: int = 0,
|
| 660 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 661 |
+
padding_side: Optional[str] = None,
|
| 662 |
+
return_tensors: Optional[str] = None,
|
| 663 |
+
return_token_type_ids: Optional[bool] = None,
|
| 664 |
+
return_attention_mask: Optional[bool] = None,
|
| 665 |
+
return_overflowing_tokens: bool = False,
|
| 666 |
+
return_special_tokens_mask: bool = False,
|
| 667 |
+
return_length: bool = False,
|
| 668 |
+
verbose: bool = True,
|
| 669 |
+
) -> BatchEncoding:
|
| 670 |
+
"""
|
| 671 |
+
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
|
| 672 |
+
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
| 673 |
+
manages a moving window (with user defined stride) for overflowing tokens
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
batch_ids_pairs: list of tokenized input ids or input ids pairs
|
| 677 |
+
batch_shape_ids_pairs: list of tokenized input shape ids or input shape ids pairs
|
| 678 |
+
batch_pronunciation_ids_pairs: list of tokenized input pronunciation ids or input pronunciation ids pairs
|
| 679 |
+
"""
|
| 680 |
+
|
| 681 |
+
batch_outputs = {}
|
| 682 |
+
for i, (first_ids, second_ids) in enumerate(batch_ids_pairs):
|
| 683 |
+
first_shape_ids, second_shape_ids = batch_shape_ids_pairs[i]
|
| 684 |
+
first_pronunciation_ids, second_pronunciation_ids = batch_pronunciation_ids_pairs[i]
|
| 685 |
+
outputs = self.prepare_for_model(
|
| 686 |
+
first_ids,
|
| 687 |
+
first_shape_ids,
|
| 688 |
+
first_pronunciation_ids,
|
| 689 |
+
pair_ids=second_ids,
|
| 690 |
+
pair_shape_ids=second_shape_ids,
|
| 691 |
+
pair_pronunciation_ids=second_pronunciation_ids,
|
| 692 |
+
add_special_tokens=add_special_tokens,
|
| 693 |
+
padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
|
| 694 |
+
truncation=truncation_strategy.value,
|
| 695 |
+
max_length=max_length,
|
| 696 |
+
stride=stride,
|
| 697 |
+
pad_to_multiple_of=None, # we pad in batch afterward
|
| 698 |
+
padding_side=None, # we pad in batch afterward
|
| 699 |
+
return_attention_mask=False, # we pad in batch afterward
|
| 700 |
+
return_token_type_ids=return_token_type_ids,
|
| 701 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 702 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 703 |
+
return_length=return_length,
|
| 704 |
+
return_tensors=None, # We convert the whole batch to tensors at the end
|
| 705 |
+
prepend_batch_axis=False,
|
| 706 |
+
verbose=verbose,
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
for key, value in outputs.items():
|
| 710 |
+
if key not in batch_outputs:
|
| 711 |
+
batch_outputs[key] = []
|
| 712 |
+
batch_outputs[key].append(value)
|
| 713 |
+
|
| 714 |
+
batch_outputs = self.pad(
|
| 715 |
+
batch_outputs,
|
| 716 |
+
padding=padding_strategy.value,
|
| 717 |
+
max_length=max_length,
|
| 718 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 719 |
+
padding_side=padding_side,
|
| 720 |
+
return_attention_mask=return_attention_mask,
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
| 724 |
+
|
| 725 |
+
return batch_outputs
|
| 726 |
+
|
| 727 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
|
| 728 |
+
def _convert_token_to_id(self, token):
|
| 729 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 730 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 731 |
+
|
| 732 |
+
def _convert_token_to_shape_id(self, token):
|
| 733 |
+
"""Converts a token (str) in an shape_id using the shape vocab."""
|
| 734 |
+
return self.word_shape.get(token, self.word_shape.get(self.unk_token))
|
| 735 |
+
|
| 736 |
+
def convert_tokens_to_shape_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
| 737 |
+
if tokens is None:
|
| 738 |
+
return None
|
| 739 |
+
|
| 740 |
+
ids = []
|
| 741 |
+
for token in tokens:
|
| 742 |
+
ids.append(self._convert_token_to_shape_id(token))
|
| 743 |
+
return ids
|
| 744 |
+
|
| 745 |
+
def _convert_token_to_pronunciation_id(self, token):
|
| 746 |
+
"""Converts a token (str) in an shape_id using the shape vocab."""
|
| 747 |
+
return self.word_pronunciation.get(token, self.word_pronunciation.get(self.unk_token))
|
| 748 |
+
|
| 749 |
+
def convert_tokens_to_pronunciation_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
| 750 |
+
if tokens is None:
|
| 751 |
+
return None
|
| 752 |
+
|
| 753 |
+
ids = []
|
| 754 |
+
for token in tokens:
|
| 755 |
+
ids.append(self._convert_token_to_pronunciation_id(token))
|
| 756 |
+
return ids
|
| 757 |
+
|
| 758 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
|
| 759 |
+
def _convert_id_to_token(self, index):
|
| 760 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 761 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 762 |
+
|
| 763 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
|
| 764 |
+
def convert_tokens_to_string(self, tokens):
|
| 765 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 766 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 767 |
+
return out_string
|
| 768 |
+
|
| 769 |
+
def build_inputs_with_special_tokens(
|
| 770 |
+
self,
|
| 771 |
+
token_ids_0: List[int],
|
| 772 |
+
token_ids_1: Optional[List[int]] = None,
|
| 773 |
+
cls_token_id: Optional[int] = None,
|
| 774 |
+
sep_token_id: Optional[int] = None,
|
| 775 |
+
) -> List[int]:
|
| 776 |
+
"""
|
| 777 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 778 |
+
adding special tokens. A BERT sequence has the following format:
|
| 779 |
+
|
| 780 |
+
- single sequence: `[CLS] X [SEP]`
|
| 781 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
| 782 |
+
|
| 783 |
+
Args:
|
| 784 |
+
token_ids_0 (`List[int]`):
|
| 785 |
+
List of IDs to which the special tokens will be added.
|
| 786 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 787 |
+
Optional second list of IDs for sequence pairs.
|
| 788 |
+
|
| 789 |
+
Returns:
|
| 790 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 791 |
+
"""
|
| 792 |
+
cls = [self.cls_token_id] if cls_token_id is None else [cls_token_id]
|
| 793 |
+
sep = [self.sep_token_id] if sep_token_id is None else [sep_token_id]
|
| 794 |
+
if token_ids_1 is None:
|
| 795 |
+
return cls + token_ids_0 + sep
|
| 796 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 797 |
+
|
| 798 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
|
| 799 |
+
def get_special_tokens_mask(
|
| 800 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 801 |
+
) -> List[int]:
|
| 802 |
+
"""
|
| 803 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 804 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 805 |
+
|
| 806 |
+
Args:
|
| 807 |
+
token_ids_0 (`List[int]`):
|
| 808 |
+
List of IDs.
|
| 809 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 810 |
+
Optional second list of IDs for sequence pairs.
|
| 811 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 812 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 813 |
+
|
| 814 |
+
Returns:
|
| 815 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
if already_has_special_tokens:
|
| 819 |
+
return super().get_special_tokens_mask(
|
| 820 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
if token_ids_1 is not None:
|
| 824 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 825 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 826 |
+
|
| 827 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences
|
| 828 |
+
def create_token_type_ids_from_sequences(
|
| 829 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 830 |
+
) -> List[int]:
|
| 831 |
+
"""
|
| 832 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
|
| 833 |
+
pair mask has the following format:
|
| 834 |
+
|
| 835 |
+
```
|
| 836 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 837 |
+
| first sequence | second sequence |
|
| 838 |
+
```
|
| 839 |
+
|
| 840 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 841 |
+
|
| 842 |
+
Args:
|
| 843 |
+
token_ids_0 (`List[int]`):
|
| 844 |
+
List of IDs.
|
| 845 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 846 |
+
Optional second list of IDs for sequence pairs.
|
| 847 |
+
|
| 848 |
+
Returns:
|
| 849 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 850 |
+
"""
|
| 851 |
+
sep = [self.sep_token_id]
|
| 852 |
+
cls = [self.cls_token_id]
|
| 853 |
+
if token_ids_1 is None:
|
| 854 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 855 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 856 |
+
|
| 857 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str, str]:
|
| 858 |
+
index = 0
|
| 859 |
+
if os.path.isdir(save_directory):
|
| 860 |
+
vocab_file = os.path.join(
|
| 861 |
+
save_directory,
|
| 862 |
+
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"],
|
| 863 |
+
)
|
| 864 |
+
word_shape_file = os.path.join(
|
| 865 |
+
save_directory,
|
| 866 |
+
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["word_shape_file"],
|
| 867 |
+
)
|
| 868 |
+
word_pronunciation_file = os.path.join(
|
| 869 |
+
save_directory,
|
| 870 |
+
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["word_pronunciation_file"],
|
| 871 |
+
)
|
| 872 |
+
else:
|
| 873 |
+
raise ValueError(
|
| 874 |
+
f"Can't find a directory at path '{save_directory}'. To load the vocabulary from a Google "
|
| 875 |
+
"pretrained model use `tokenizer = RoCBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 879 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 880 |
+
if index != token_index:
|
| 881 |
+
logger.warning(
|
| 882 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
| 883 |
+
" Please check that the vocabulary is not corrupted!"
|
| 884 |
+
)
|
| 885 |
+
index = token_index
|
| 886 |
+
writer.write(token + "\n")
|
| 887 |
+
index += 1
|
| 888 |
+
|
| 889 |
+
with open(word_shape_file, "w", encoding="utf8") as writer:
|
| 890 |
+
json.dump(self.word_shape, writer, ensure_ascii=False, indent=4, separators=(", ", ": "))
|
| 891 |
+
|
| 892 |
+
with open(word_pronunciation_file, "w", encoding="utf8") as writer:
|
| 893 |
+
json.dump(self.word_pronunciation, writer, ensure_ascii=False, indent=4, separators=(", ", ": "))
|
| 894 |
+
|
| 895 |
+
return (
|
| 896 |
+
vocab_file,
|
| 897 |
+
word_shape_file,
|
| 898 |
+
word_pronunciation_file,
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer with BasicTokenizer->RoCBertBasicTokenizer
|
| 903 |
+
class RoCBertBasicTokenizer:
|
| 904 |
+
"""
|
| 905 |
+
Constructs a RoCBertBasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
|
| 906 |
+
|
| 907 |
+
Args:
|
| 908 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 909 |
+
Whether or not to lowercase the input when tokenizing.
|
| 910 |
+
never_split (`Iterable`, *optional*):
|
| 911 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
| 912 |
+
`do_basic_tokenize=True`
|
| 913 |
+
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
| 914 |
+
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

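# Illustrative sketch (not part of the original module): how the basic tokenizer above combines
# its steps. The input string and expected output are assumptions based on the methods defined
# in this class, not an official example.
#
#     basic_tokenizer = RoCBertBasicTokenizer(do_lower_case=True)
#     basic_tokenizer.tokenize("HÉLLO, 世界!")
#     # 1. _clean_text keeps the text, _tokenize_chinese_chars pads each CJK character with spaces,
#     # 2. lowercasing plus accent stripping turns "HÉLLO," into "hello,",
#     # 3. _run_split_on_punc separates the punctuation,
#     # giving ["hello", ",", "世", "界", "!"]
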
# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer
class RoCBertWordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens

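# Illustrative sketch (not part of the original module): the greedy longest-match-first loop
# above on a toy vocabulary. The vocabulary contents are assumptions chosen for the example.
#
#     vocab = {"un", "##aff", "##able", "aff"}
#     wordpiece = RoCBertWordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
#     wordpiece.tokenize("unaffable")    # -> ["un", "##aff", "##able"]
#     wordpiece.tokenize("unknowable")   # -> ["[UNK]"], no combination of pieces covers the word
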
__all__ = ["RoCBertTokenizer"]
docs/transformers/build/lib/transformers/models/roformer/modeling_flax_roformer.py
ADDED
@@ -0,0 +1,1091 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax RoFormer model."""

from typing import Callable, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_roformer import RoFormerConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base"
_CONFIG_FOR_DOC = "RoFormerConfig"


ROFORMER_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`RoFormerConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

ROFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
def create_sinusoidal_positions(n_pos, dim):
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    sentinel = dim // 2 + dim % 2
    out = np.zeros_like(position_enc)
    out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
    out[:, sentinel:] = np.cos(position_enc[:, 1::2])

    return jnp.array(out)

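# Illustrative sketch (not part of the original module): shape and layout of the table built by
# create_sinusoidal_positions above. The concrete numbers are assumptions for the example.
#
#     table = create_sinusoidal_positions(n_pos=512, dim=64)
#     table.shape    # (512, 64): one row per position
#     # columns [0:32] hold the sin terms and columns [32:64] the cos terms, which is exactly the
#     # split that apply_rotary_position_embeddings undoes with jnp.split(sinusoidal_pos, 2, axis=-1)
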
class FlaxRoFormerEmbeddings(nn.Module):
    """Construct the embeddings from word and token_type embeddings."""

    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, attention_mask, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxRoFormerSelfAttention(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`"
                f": {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.rotary_value = self.config.rotary_value

    def __call__(
        self,
        hidden_states,
        attention_mask,
        sinusoidal_pos,
        layer_head_mask,
        deterministic=True,
        output_attentions: bool = False,
    ):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        if sinusoidal_pos is not None:
            if self.rotary_value:
                query_states, key_states, value_states = self.apply_rotary_position_embeddings(
                    sinusoidal_pos, query_states, key_states, value_states
                )
            else:
                query_states, key_states = self.apply_rotary_position_embeddings(
                    sinusoidal_pos, query_states, key_states
                )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

    @staticmethod
    def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
        sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
        sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
        cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)

        def rotate_layer(layer, sin_pos, cos_pos):
            rotate_half_layer = jnp.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape)
            rotary_matrix_cos = jnp.einsum("bslh,...sh->bslh", layer, cos_pos)
            rotary_matrix_sin = jnp.einsum("bslh,...sh->bslh", rotate_half_layer, sin_pos)
            return rotary_matrix_cos + rotary_matrix_sin

        query_layer = rotate_layer(query_layer, sin_pos, cos_pos)
        key_layer = rotate_layer(key_layer, sin_pos, cos_pos)
        if value_layer is not None:
            value_layer = rotate_layer(value_layer, sin_pos, cos_pos)
            return query_layer, key_layer, value_layer
        return query_layer, key_layer

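# Illustrative sketch (not part of the original module): what apply_rotary_position_embeddings
# does for a single feature pair. Names and shapes below are assumptions for the example.
#
#     # query_layer has shape (batch, seq_len, num_heads, head_dim) and sinusoidal_pos has shape
#     # (seq_len, head_dim); the sin half is repeated into even/odd slots as sin_pos and the cos
#     # half as cos_pos. For every feature pair (x1, x2) at a given position the update is
#     #     x1' = x1 * cos - x2 * sin
#     #     x2' = x2 * cos + x1 * sin
#     # i.e. the standard RoPE 2-D rotation; the two einsum calls implement it in vectorized form.
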
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->RoFormer
class FlaxRoFormerSelfOutput(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class FlaxRoFormerAttention(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxRoFormerSelfAttention(self.config, dtype=self.dtype)
        self.output = FlaxRoFormerSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        sinusoidal_pos,
        layer_head_mask,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            sinusoidal_pos,
            layer_head_mask=layer_head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->RoFormer
class FlaxRoFormerIntermediate(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->RoFormer
class FlaxRoFormerOutput(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


class FlaxRoFormerLayer(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxRoFormerAttention(self.config, dtype=self.dtype)
        self.intermediate = FlaxRoFormerIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxRoFormerOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        sinusoidal_pos,
        layer_head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            sinusoidal_pos,
            layer_head_mask=layer_head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


class FlaxRoFormerLayerCollection(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxRoFormerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        sinusoidal_pos,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f" {head_mask.shape[0]}."
                )

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                sinusoidal_pos,
                layer_head_mask=head_mask[i] if head_mask is not None else None,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states,)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxRoFormerEncoder(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embed_positions = create_sinusoidal_positions(
            self.config.max_position_embeddings, self.config.hidden_size // self.config.num_attention_heads
        )
        self.layer = FlaxRoFormerLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        sinusoidal_pos = self.embed_positions[: hidden_states.shape[1], :]

        return self.layer(
            hidden_states,
            attention_mask,
            sinusoidal_pos,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->RoFormer
class FlaxRoFormerPredictionHeadTransform(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.LayerNorm(hidden_states)


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->RoFormer
class FlaxRoFormerLMPredictionHead(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.transform = FlaxRoFormerPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.transform(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        bias = jnp.asarray(self.bias, self.dtype)
        hidden_states += bias
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->RoFormer
class FlaxRoFormerOnlyMLMHead(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxRoFormerLMPredictionHead(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
        return hidden_states


class FlaxRoFormerClassificationHead(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.out_proj = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states

class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RoFormerConfig
    base_model_prefix = "roformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: RoFormerConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        head_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(head_mask, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )


class FlaxRoFormerModule(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embeddings = FlaxRoFormerEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxRoFormerEncoder(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(input_ids, token_type_ids, attention_mask, deterministic=deterministic)
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if not return_dict:
            return (hidden_states,) + outputs[1:]

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
    ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerModel(FlaxRoFormerPreTrainedModel):
    module_class = FlaxRoFormerModule


append_call_sample_docstring(FlaxRoFormerModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)

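# Illustrative sketch (not part of the original module): minimal forward pass with the bare model
# defined above, using the documentation checkpoint referenced by _CHECKPOINT_FOR_DOC. The example
# text is an assumption; from_pt=True may be needed if the checkpoint only ships PyTorch weights.
#
#     from transformers import AutoTokenizer, FlaxRoFormerModel
#
#     tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
#     model = FlaxRoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
#     inputs = tokenizer("今天天气非常好。", return_tensors="np")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape   # (batch_size, sequence_length, hidden_size)
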
class FlaxRoFormerForMaskedLMModule(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
        self.cls = FlaxRoFormerOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.roformer(
            input_ids,
            attention_mask,
            token_type_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.roformer.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
class FlaxRoFormerForMaskedLM(FlaxRoFormerPreTrainedModel):
    module_class = FlaxRoFormerForMaskedLMModule


append_call_sample_docstring(
    FlaxRoFormerForMaskedLM,
    _CHECKPOINT_FOR_DOC,
    FlaxMaskedLMOutput,
    _CONFIG_FOR_DOC,
    mask="<mask>",
)

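# Illustrative sketch (not part of the original module): masked-token scoring with the MLM head
# defined above. The input sentence is an assumption; the mask string should match
# tokenizer.mask_token for the checkpoint being used.
#
#     import jax.numpy as jnp
#     from transformers import AutoTokenizer, FlaxRoFormerForMaskedLM
#
#     tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
#     model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
#     text = f"今天天气非常{tokenizer.mask_token}。"
#     inputs = tokenizer(text, return_tensors="np")
#     logits = model(**inputs).logits                      # (batch, seq_len, vocab_size)
#     mask_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
#     predicted_id = int(jnp.argmax(logits[0, mask_index]))
#     tokenizer.decode([predicted_id])                     # most likely token for the masked slot
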
class FlaxRoFormerForSequenceClassificationModule(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
        self.classifier = FlaxRoFormerClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.roformer(
            input_ids,
            attention_mask,
            token_type_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output, deterministic=deterministic)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForSequenceClassification(FlaxRoFormerPreTrainedModel):
    module_class = FlaxRoFormerForSequenceClassificationModule


append_call_sample_docstring(
    FlaxRoFormerForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxRoFormerForMultipleChoiceModule(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1])

        # Model
        outputs = self.roformer(
            input_ids,
            attention_mask,
            token_type_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Equivalent to sequence_summary call in the PyTorch implementation
        hidden_states = outputs[0]
        pooled_output = hidden_states[:, -1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)

        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForMultipleChoice(FlaxRoFormerPreTrainedModel):
    module_class = FlaxRoFormerForMultipleChoiceModule


overwrite_call_docstring(
    FlaxRoFormerForMultipleChoice, ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxRoFormerForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxRoFormerForTokenClassificationModule(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.roformer(
            input_ids,
            attention_mask,
            token_type_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForTokenClassification(FlaxRoFormerPreTrainedModel):
    module_class = FlaxRoFormerForTokenClassificationModule


append_call_sample_docstring(
    FlaxRoFormerForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxRoFormerForQuestionAnsweringModule(nn.Module):
    config: RoFormerConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.roformer(
            input_ids,
            attention_mask,
            token_type_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForQuestionAnswering(FlaxRoFormerPreTrainedModel):
    module_class = FlaxRoFormerForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxRoFormerForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)

|
| 1084 |
+
"FlaxRoFormerForMaskedLM",
|
| 1085 |
+
"FlaxRoFormerForMultipleChoice",
|
| 1086 |
+
"FlaxRoFormerForQuestionAnswering",
|
| 1087 |
+
"FlaxRoFormerForSequenceClassification",
|
| 1088 |
+
"FlaxRoFormerForTokenClassification",
|
| 1089 |
+
"FlaxRoFormerModel",
|
| 1090 |
+
"FlaxRoFormerPreTrainedModel",
|
| 1091 |
+
]
|
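The question-answering module above projects each token's hidden state to `config.num_labels` values and splits them into start and end logits with `jnp.split`. A minimal usage sketch follows; it is illustrative only and assumes the `junnyu/roformer_chinese_base` checkpoint referenced by `_CHECKPOINT_FOR_DOC` is reachable and that Flax weights can be loaded for it (otherwise `from_pt=True` would be needed).

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxRoFormerForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = FlaxRoFormerForQuestionAnswering.from_pretrained("junnyu/roformer_chinese_base")

inputs = tokenizer("今天天气非常好。", return_tensors="np")
outputs = model(**inputs)

# Both logits tensors have shape [batch_size, sequence_length]; the argmax positions
# bound the predicted answer span.
start_index = int(jnp.argmax(outputs.start_logits, axis=-1)[0])
end_index = int(jnp.argmax(outputs.end_logits, axis=-1)[0])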
docs/transformers/build/lib/transformers/models/roformer/modeling_tf_roformer.py
ADDED @@ -0,0 +1,1547 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 RoFormer model."""

from __future__ import annotations

import math
from typing import Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFCausalLMOutput,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_roformer import RoFormerConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base"
_CONFIG_FOR_DOC = "RoFormerConfig"


class TFRoFormerSinusoidalPositionalEmbedding(keras.layers.Layer):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)

        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")

        self.embedding_dim = embedding_dim
        self.num_positions = num_positions

    def build(self, input_shape: tf.TensorShape):
        """
        Build shared token embedding layer. Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """

        weight = self._init_weight(self.num_positions, self.embedding_dim)

        self.weight = self.add_weight(
            name="embeddings",
            shape=[self.num_positions, self.embedding_dim],
        )
        weight = tf.cast(weight, dtype=self.weight.dtype)

        self.weight.assign(weight)

        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        table = np.zeros_like(position_enc)
        # index 0 is all zero
        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        # convert to tensor
        table = tf.convert_to_tensor(table)
        tf.stop_gradient(table)
        return table

    def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return tf.gather(self.weight, positions)

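# Editorial sketch (not part of the original file): `_init_weight` lays out the table with the sine
# features in the first half of the last axis and the cosine features in the second half (they are
# not interleaved). The helper below only illustrates that layout on toy sizes; row 0 is all zeros
# in the sine half and all ones in the cosine half because sin(0) = 0 and cos(0) = 1.
def _sinusoidal_table_sketch(n_pos: int = 4, dim: int = 6) -> np.ndarray:
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    table = np.zeros_like(position_enc)
    table[:, : dim // 2] = np.sin(position_enc[:, 0::2])  # first half: sin(pos / 10000^(2i/dim))
    table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])  # second half: cos of the same angles
    return table  # shape [n_pos, dim]
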
class TFRoFormerEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings


class TFRoFormerSelfAttention(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.rotary_value = config.rotary_value
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        sinusoidal_pos: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        if sinusoidal_pos is not None:
            if self.rotary_value:
                query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
                    sinusoidal_pos, query_layer, key_layer, value_layer
                )
            else:
                query_layer, key_layer = self.apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFRoFormerModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        return outputs

    @staticmethod
    def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
        # https://kexue.fm/archives/8265
        # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
        # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
        sin, cos = tf.split(sinusoidal_pos, num_or_size_splits=2, axis=-1)
        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        sin_pos = tf.repeat(sin, 2, axis=-1)
        cos_pos = tf.repeat(cos, 2, axis=-1)
        # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
        rotate_half_query_layer = tf.stack([-query_layer[..., 1::2], query_layer[..., ::2]], axis=-1)
        rotate_half_query_layer = tf.reshape(rotate_half_query_layer, shape_list(query_layer))
        query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
        # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
        rotate_half_key_layer = tf.stack([-key_layer[..., 1::2], key_layer[..., ::2]], axis=-1)
        rotate_half_key_layer = tf.reshape(rotate_half_key_layer, shape_list(key_layer))
        key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
        if value_layer is not None:
            # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
            rotate_half_value_layer = tf.stack([-value_layer[..., 1::2], value_layer[..., ::2]], axis=-1)
            rotate_half_value_layer = tf.reshape(rotate_half_value_layer, shape_list(value_layer))
            value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
            return query_layer, key_layer, value_layer
        return query_layer, key_layer

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])

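# Editorial sketch (not part of the original file): the rotary update above is a plain 2-D rotation
# of each (even, odd) feature pair by position-dependent angles, so it leaves the norm of every
# query/key vector unchanged. The uncalled helper below checks that property on random tensors;
# the shapes mirror the comments in `apply_rotary_position_embeddings`.
def _rotary_norm_check_sketch(seq_len: int = 8, head_dim: int = 4) -> bool:
    # sin/cos table with sin in the first half and cos in the second half of the last axis
    table = TFRoFormerSinusoidalPositionalEmbedding._init_weight(seq_len, head_dim)
    sinusoidal_pos = tf.cast(table, tf.float32)[None, None, :, :]  # [1, 1, seq_len, head_dim]
    query = tf.random.normal((2, 3, seq_len, head_dim))  # [batch, num_heads, seq_len, head_dim]
    key = tf.random.normal((2, 3, seq_len, head_dim))
    rotated_query, rotated_key = TFRoFormerSelfAttention.apply_rotary_position_embeddings(
        sinusoidal_pos, query, key
    )
    return bool(
        np.allclose(tf.norm(query, axis=-1), tf.norm(rotated_query, axis=-1), atol=1e-5)
        and np.allclose(tf.norm(key, axis=-1), tf.norm(rotated_key, axis=-1), atol=1e-5)
    )
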
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->RoFormer
class TFRoFormerSelfOutput(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


class TFRoFormerAttention(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFRoFormerSelfAttention(config, name="self")
        self.dense_output = TFRoFormerSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        sinusoidal_pos: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            sinusoidal_pos=sinusoidal_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)


# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RoFormer
class TFRoFormerIntermediate(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RoFormer
class TFRoFormerOutput(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


class TFRoFormerLayer(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFRoFormerAttention(config, name="attention")
        self.intermediate = TFRoFormerIntermediate(config, name="intermediate")
        self.roformer_output = TFRoFormerOutput(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        sinusoidal_pos: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            sinusoidal_pos=sinusoidal_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.roformer_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "roformer_output", None) is not None:
            with tf.name_scope(self.roformer_output.name):
                self.roformer_output.build(None)


class TFRoFormerEncoder(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_positions = TFRoFormerSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.hidden_size // config.num_attention_heads,
            name="embed_positions",
        )
        self.layer = [TFRoFormerLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
        sinusoidal_pos = self.embed_positions(shape_list(hidden_states)[:-1])[None, None, :, :]

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                sinusoidal_pos=sinusoidal_pos,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)


class TFRoFormerPredictionHeadTransform(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.embedding_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        if isinstance(config.hidden_act, str):
            self.transform_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.transform_act_fn = config.hidden_act

        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(inputs=hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])


class TFRoFormerLMPredictionHead(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size

        self.transform = TFRoFormerPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape=None):
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        if self.built:
            return
        self.built = True
        if getattr(self, "transform", None) is not None:
            with tf.name_scope(self.transform.name):
                self.transform.build(None)

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    def set_output_embeddings(self, value: tf.Variable):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.transform(hidden_states=hidden_states)
        seq_length = shape_list(hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states


# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->RoFormer
class TFRoFormerMLMHead(keras.layers.Layer):
    def __init__(self, config: RoFormerConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.predictions = TFRoFormerLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        prediction_scores = self.predictions(hidden_states=sequence_output)

        return prediction_scores

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)


@keras_serializable
class TFRoFormerMainLayer(keras.layers.Layer):
    config_class = RoFormerConfig

    def __init__(self, config: RoFormerConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFRoFormerEmbeddings(config, name="embeddings")
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project")

        self.encoder = TFRoFormerEncoder(config, name="encoder")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            training=training,
        )
        if hasattr(self, "embeddings_project"):
            embedding_output = self.embeddings_project(embedding_output, training=training)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "embeddings_project", None) is not None:
            with tf.name_scope(self.embeddings_project.name):
                self.embeddings_project.build([None, None, self.config.embedding_size])

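# Editorial sketch (not part of the original file): the `call` above turns the 0/1 padding mask of
# shape [batch_size, seq_len] into an additive bias of shape [batch_size, 1, 1, seq_len] that is
# 0.0 for visible tokens and -10000.0 for padded ones, so it can simply be added to the raw
# attention scores before the softmax. The uncalled helper below reproduces just that step.
def _extended_attention_mask_sketch(attention_mask: tf.Tensor) -> tf.Tensor:
    batch_size, seq_len = shape_list(attention_mask)
    extended = tf.reshape(attention_mask, (batch_size, 1, 1, seq_len))
    extended = tf.cast(extended, tf.float32)
    return (1.0 - extended) * -10000.0

# For example, a mask of [[1, 1, 1, 0]] yields 0.0 for the first three key positions and -10000.0
# for the padded one, broadcast over all heads and query positions.
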
class TFRoFormerPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RoFormerConfig
    base_model_prefix = "roformer"


ROFORMER_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

ROFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
    ROFORMER_START_DOCSTRING,
)
class TFRoFormerModel(TFRoFormerPreTrainedModel):
    def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.roformer = TFRoFormerMainLayer(config, name="roformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        outputs = self.roformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "roformer", None) is not None:
            with tf.name_scope(self.roformer.name):
                self.roformer.build(None)

@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
class TFRoFormerForMaskedLM(TFRoFormerPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if config.is_decoder:
            logger.warning(
                "If you want to use `TFRoFormerForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.roformer = TFRoFormerMainLayer(config, name="roformer")
        self.mlm = TFRoFormerMLMHead(config, input_embeddings=self.roformer.embeddings, name="mlm___cls")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.mlm.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        outputs = self.roformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "roformer", None) is not None:
            with tf.name_scope(self.roformer.name):
                self.roformer.build(None)
        if getattr(self, "mlm", None) is not None:
            with tf.name_scope(self.mlm.name):
                self.mlm.build(None)


@add_start_docstrings(
    """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
)
class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
|
| 1051 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1052 |
+
|
| 1053 |
+
if not config.is_decoder:
|
| 1054 |
+
logger.warning("If you want to use `TFRoFormerForCausalLM` as a standalone, add `is_decoder=True.`")
|
| 1055 |
+
|
| 1056 |
+
self.roformer = TFRoFormerMainLayer(config, name="roformer")
|
| 1057 |
+
self.mlm = TFRoFormerMLMHead(config, input_embeddings=self.roformer.embeddings, name="mlm___cls")
|
| 1058 |
+
|
| 1059 |
+
def get_lm_head(self) -> keras.layers.Layer:
|
| 1060 |
+
return self.mlm.predictions
|
| 1061 |
+
|
| 1062 |
+
@unpack_inputs
|
| 1063 |
+
@add_code_sample_docstrings(
|
| 1064 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1065 |
+
output_type=TFCausalLMOutput,
|
| 1066 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1067 |
+
)
|
| 1068 |
+
def call(
|
| 1069 |
+
self,
|
| 1070 |
+
input_ids: TFModelInputType | None = None,
|
| 1071 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1072 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1073 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1074 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1075 |
+
output_attentions: Optional[bool] = None,
|
| 1076 |
+
output_hidden_states: Optional[bool] = None,
|
| 1077 |
+
return_dict: Optional[bool] = None,
|
| 1078 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1079 |
+
training: Optional[bool] = False,
|
| 1080 |
+
) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
|
| 1081 |
+
r"""
|
| 1082 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1083 |
+
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
|
| 1084 |
+
config.vocab_size - 1]`.
|
| 1085 |
+
"""
|
| 1086 |
+
outputs = self.roformer(
|
| 1087 |
+
input_ids=input_ids,
|
| 1088 |
+
attention_mask=attention_mask,
|
| 1089 |
+
token_type_ids=token_type_ids,
|
| 1090 |
+
head_mask=head_mask,
|
| 1091 |
+
inputs_embeds=inputs_embeds,
|
| 1092 |
+
output_attentions=output_attentions,
|
| 1093 |
+
output_hidden_states=output_hidden_states,
|
| 1094 |
+
return_dict=return_dict,
|
| 1095 |
+
training=training,
|
| 1096 |
+
)
|
| 1097 |
+
sequence_output = outputs[0]
|
| 1098 |
+
logits = self.mlm(sequence_output=sequence_output, training=training)
|
| 1099 |
+
loss = None
|
| 1100 |
+
|
| 1101 |
+
if labels is not None:
|
| 1102 |
+
# shift labels to the left and cut last logit token
|
| 1103 |
+
shifted_logits = logits[:, :-1]
|
| 1104 |
+
labels = labels[:, 1:]
|
| 1105 |
+
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
|
| 1106 |
+
|
| 1107 |
+
if not return_dict:
|
| 1108 |
+
output = (logits,) + outputs[2:]
|
| 1109 |
+
return ((loss,) + output) if loss is not None else output
|
| 1110 |
+
|
| 1111 |
+
return TFCausalLMOutput(
|
| 1112 |
+
loss=loss,
|
| 1113 |
+
logits=logits,
|
| 1114 |
+
hidden_states=outputs.hidden_states,
|
| 1115 |
+
attentions=outputs.attentions,
|
| 1116 |
+
)
|
| 1117 |
+
|
| 1118 |
+
def build(self, input_shape=None):
|
| 1119 |
+
if self.built:
|
| 1120 |
+
return
|
| 1121 |
+
self.built = True
|
| 1122 |
+
if getattr(self, "roformer", None) is not None:
|
| 1123 |
+
with tf.name_scope(self.roformer.name):
|
| 1124 |
+
self.roformer.build(None)
|
| 1125 |
+
if getattr(self, "mlm", None) is not None:
|
| 1126 |
+
with tf.name_scope(self.mlm.name):
|
| 1127 |
+
self.mlm.build(None)
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
class TFRoFormerClassificationHead(keras.layers.Layer):
|
| 1131 |
+
"""Head for sentence-level classification tasks."""
|
| 1132 |
+
|
| 1133 |
+
def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
|
| 1134 |
+
super().__init__(*inputs, **kwargs)
|
| 1135 |
+
|
| 1136 |
+
self.dense = keras.layers.Dense(
|
| 1137 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 1138 |
+
)
|
| 1139 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 1140 |
+
self.out_proj = keras.layers.Dense(
|
| 1141 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
|
| 1142 |
+
)
|
| 1143 |
+
|
| 1144 |
+
if isinstance(config.hidden_act, str):
|
| 1145 |
+
self.classifier_act_fn = get_tf_activation(config.hidden_act)
|
| 1146 |
+
else:
|
| 1147 |
+
self.classifier_act_fn = config.hidden_act
|
| 1148 |
+
self.config = config
|
| 1149 |
+
|
| 1150 |
+
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 1151 |
+
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
|
| 1152 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 1153 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 1154 |
+
hidden_states = self.classifier_act_fn(hidden_states)
|
| 1155 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 1156 |
+
hidden_states = self.out_proj(hidden_states)
|
| 1157 |
+
|
| 1158 |
+
return hidden_states
|
| 1159 |
+
|
| 1160 |
+
def build(self, input_shape=None):
|
| 1161 |
+
if self.built:
|
| 1162 |
+
return
|
| 1163 |
+
self.built = True
|
| 1164 |
+
if getattr(self, "dense", None) is not None:
|
| 1165 |
+
with tf.name_scope(self.dense.name):
|
| 1166 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 1167 |
+
if getattr(self, "out_proj", None) is not None:
|
| 1168 |
+
with tf.name_scope(self.out_proj.name):
|
| 1169 |
+
self.out_proj.build([None, None, self.config.hidden_size])
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
+
@add_start_docstrings(
|
| 1173 |
+
"""
|
| 1174 |
+
RoFormer Model transformer with a sequence classification/regression head on top e.g., for GLUE tasks.
|
| 1175 |
+
""",
|
| 1176 |
+
ROFORMER_START_DOCSTRING,
|
| 1177 |
+
)
|
| 1178 |
+
class TFRoFormerForSequenceClassification(TFRoFormerPreTrainedModel, TFSequenceClassificationLoss):
|
| 1179 |
+
def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
|
| 1180 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1181 |
+
|
| 1182 |
+
self.num_labels = config.num_labels
|
| 1183 |
+
|
| 1184 |
+
self.roformer = TFRoFormerMainLayer(config, name="roformer")
|
| 1185 |
+
self.classifier = TFRoFormerClassificationHead(config, name="classifier")
|
| 1186 |
+
|
| 1187 |
+
@unpack_inputs
|
| 1188 |
+
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1189 |
+
@add_code_sample_docstrings(
|
| 1190 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1191 |
+
output_type=TFSequenceClassifierOutput,
|
| 1192 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1193 |
+
)
|
| 1194 |
+
def call(
|
| 1195 |
+
self,
|
| 1196 |
+
input_ids: TFModelInputType | None = None,
|
| 1197 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1198 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1199 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1200 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1201 |
+
output_attentions: Optional[bool] = None,
|
| 1202 |
+
output_hidden_states: Optional[bool] = None,
|
| 1203 |
+
return_dict: Optional[bool] = None,
|
| 1204 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1205 |
+
training: Optional[bool] = False,
|
| 1206 |
+
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
| 1207 |
+
r"""
|
| 1208 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1209 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1210 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1211 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1212 |
+
"""
|
| 1213 |
+
outputs = self.roformer(
|
| 1214 |
+
input_ids=input_ids,
|
| 1215 |
+
attention_mask=attention_mask,
|
| 1216 |
+
token_type_ids=token_type_ids,
|
| 1217 |
+
head_mask=head_mask,
|
| 1218 |
+
inputs_embeds=inputs_embeds,
|
| 1219 |
+
output_attentions=output_attentions,
|
| 1220 |
+
output_hidden_states=output_hidden_states,
|
| 1221 |
+
return_dict=return_dict,
|
| 1222 |
+
training=training,
|
| 1223 |
+
)
|
| 1224 |
+
logits = self.classifier(hidden_states=outputs[0], training=training)
|
| 1225 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1226 |
+
|
| 1227 |
+
if not return_dict:
|
| 1228 |
+
output = (logits,) + outputs[1:]
|
| 1229 |
+
|
| 1230 |
+
return ((loss,) + output) if loss is not None else output
|
| 1231 |
+
|
| 1232 |
+
return TFSequenceClassifierOutput(
|
| 1233 |
+
loss=loss,
|
| 1234 |
+
logits=logits,
|
| 1235 |
+
hidden_states=outputs.hidden_states,
|
| 1236 |
+
attentions=outputs.attentions,
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
def build(self, input_shape=None):
|
| 1240 |
+
if self.built:
|
| 1241 |
+
return
|
| 1242 |
+
self.built = True
|
| 1243 |
+
if getattr(self, "roformer", None) is not None:
|
| 1244 |
+
with tf.name_scope(self.roformer.name):
|
| 1245 |
+
self.roformer.build(None)
|
| 1246 |
+
if getattr(self, "classifier", None) is not None:
|
| 1247 |
+
with tf.name_scope(self.classifier.name):
|
| 1248 |
+
self.classifier.build(None)
|
| 1249 |
+
|
| 1250 |
+
|
| 1251 |
+
@add_start_docstrings(
|
| 1252 |
+
"""
|
| 1253 |
+
RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
| 1254 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
| 1255 |
+
""",
|
| 1256 |
+
ROFORMER_START_DOCSTRING,
|
| 1257 |
+
)
|
| 1258 |
+
class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLoss):
|
| 1259 |
+
def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
|
| 1260 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1261 |
+
|
| 1262 |
+
self.roformer = TFRoFormerMainLayer(config, name="roformer")
|
| 1263 |
+
self.sequence_summary = TFSequenceSummary(config, config.initializer_range, name="sequence_summary")
|
| 1264 |
+
self.classifier = keras.layers.Dense(
|
| 1265 |
+
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1266 |
+
)
|
| 1267 |
+
self.config = config
|
| 1268 |
+
|
| 1269 |
+
@unpack_inputs
|
| 1270 |
+
@add_start_docstrings_to_model_forward(
|
| 1271 |
+
ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
| 1272 |
+
)
|
| 1273 |
+
@add_code_sample_docstrings(
|
| 1274 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1275 |
+
output_type=TFMultipleChoiceModelOutput,
|
| 1276 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1277 |
+
)
|
| 1278 |
+
def call(
|
| 1279 |
+
self,
|
| 1280 |
+
input_ids: TFModelInputType | None = None,
|
| 1281 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1282 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1283 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1284 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1285 |
+
output_attentions: Optional[bool] = None,
|
| 1286 |
+
output_hidden_states: Optional[bool] = None,
|
| 1287 |
+
return_dict: Optional[bool] = None,
|
| 1288 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1289 |
+
training: Optional[bool] = False,
|
| 1290 |
+
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
|
| 1291 |
+
r"""
|
| 1292 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1293 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
|
| 1294 |
+
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
|
| 1295 |
+
"""
|
| 1296 |
+
if input_ids is not None:
|
| 1297 |
+
num_choices = shape_list(input_ids)[1]
|
| 1298 |
+
seq_length = shape_list(input_ids)[2]
|
| 1299 |
+
else:
|
| 1300 |
+
num_choices = shape_list(inputs_embeds)[1]
|
| 1301 |
+
seq_length = shape_list(inputs_embeds)[2]
|
| 1302 |
+
|
| 1303 |
+
flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
|
| 1304 |
+
flat_attention_mask = (
|
| 1305 |
+
tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
|
| 1306 |
+
)
|
| 1307 |
+
flat_token_type_ids = (
|
| 1308 |
+
tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
|
| 1309 |
+
)
|
| 1310 |
+
flat_inputs_embeds = (
|
| 1311 |
+
tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
|
| 1312 |
+
if inputs_embeds is not None
|
| 1313 |
+
else None
|
| 1314 |
+
)
|
| 1315 |
+
outputs = self.roformer(
|
| 1316 |
+
input_ids=flat_input_ids,
|
| 1317 |
+
attention_mask=flat_attention_mask,
|
| 1318 |
+
token_type_ids=flat_token_type_ids,
|
| 1319 |
+
head_mask=head_mask,
|
| 1320 |
+
inputs_embeds=flat_inputs_embeds,
|
| 1321 |
+
output_attentions=output_attentions,
|
| 1322 |
+
output_hidden_states=output_hidden_states,
|
| 1323 |
+
return_dict=return_dict,
|
| 1324 |
+
training=training,
|
| 1325 |
+
)
|
| 1326 |
+
logits = self.sequence_summary(inputs=outputs[0], training=training)
|
| 1327 |
+
logits = self.classifier(inputs=logits)
|
| 1328 |
+
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
|
| 1329 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
|
| 1330 |
+
|
| 1331 |
+
if not return_dict:
|
| 1332 |
+
output = (reshaped_logits,) + outputs[1:]
|
| 1333 |
+
|
| 1334 |
+
return ((loss,) + output) if loss is not None else output
|
| 1335 |
+
|
| 1336 |
+
return TFMultipleChoiceModelOutput(
|
| 1337 |
+
loss=loss,
|
| 1338 |
+
logits=reshaped_logits,
|
| 1339 |
+
hidden_states=outputs.hidden_states,
|
| 1340 |
+
attentions=outputs.attentions,
|
| 1341 |
+
)
|
| 1342 |
+
|
| 1343 |
+
def build(self, input_shape=None):
|
| 1344 |
+
if self.built:
|
| 1345 |
+
return
|
| 1346 |
+
self.built = True
|
| 1347 |
+
if getattr(self, "roformer", None) is not None:
|
| 1348 |
+
with tf.name_scope(self.roformer.name):
|
| 1349 |
+
self.roformer.build(None)
|
| 1350 |
+
if getattr(self, "sequence_summary", None) is not None:
|
| 1351 |
+
with tf.name_scope(self.sequence_summary.name):
|
| 1352 |
+
self.sequence_summary.build(None)
|
| 1353 |
+
if getattr(self, "classifier", None) is not None:
|
| 1354 |
+
with tf.name_scope(self.classifier.name):
|
| 1355 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1356 |
+
|
| 1357 |
+
|
| 1358 |
+
@add_start_docstrings(
|
| 1359 |
+
"""
|
| 1360 |
+
RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1361 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1362 |
+
""",
|
| 1363 |
+
ROFORMER_START_DOCSTRING,
|
| 1364 |
+
)
|
| 1365 |
+
class TFRoFormerForTokenClassification(TFRoFormerPreTrainedModel, TFTokenClassificationLoss):
|
| 1366 |
+
def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
|
| 1367 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1368 |
+
|
| 1369 |
+
self.num_labels = config.num_labels
|
| 1370 |
+
|
| 1371 |
+
self.roformer = TFRoFormerMainLayer(config, name="roformer")
|
| 1372 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 1373 |
+
self.classifier = keras.layers.Dense(
|
| 1374 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1375 |
+
)
|
| 1376 |
+
self.config = config
|
| 1377 |
+
|
| 1378 |
+
@unpack_inputs
|
| 1379 |
+
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1380 |
+
@add_code_sample_docstrings(
|
| 1381 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1382 |
+
output_type=TFTokenClassifierOutput,
|
| 1383 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1384 |
+
)
|
| 1385 |
+
def call(
|
| 1386 |
+
self,
|
| 1387 |
+
input_ids: TFModelInputType | None = None,
|
| 1388 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1389 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1390 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1391 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1392 |
+
output_attentions: Optional[bool] = None,
|
| 1393 |
+
output_hidden_states: Optional[bool] = None,
|
| 1394 |
+
return_dict: Optional[bool] = None,
|
| 1395 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1396 |
+
training: Optional[bool] = False,
|
| 1397 |
+
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
|
| 1398 |
+
r"""
|
| 1399 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1400 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1401 |
+
"""
|
| 1402 |
+
outputs = self.roformer(
|
| 1403 |
+
input_ids=input_ids,
|
| 1404 |
+
attention_mask=attention_mask,
|
| 1405 |
+
token_type_ids=token_type_ids,
|
| 1406 |
+
head_mask=head_mask,
|
| 1407 |
+
inputs_embeds=inputs_embeds,
|
| 1408 |
+
output_attentions=output_attentions,
|
| 1409 |
+
output_hidden_states=output_hidden_states,
|
| 1410 |
+
return_dict=return_dict,
|
| 1411 |
+
training=training,
|
| 1412 |
+
)
|
| 1413 |
+
sequence_output = outputs[0]
|
| 1414 |
+
sequence_output = self.dropout(inputs=sequence_output, training=training)
|
| 1415 |
+
logits = self.classifier(inputs=sequence_output)
|
| 1416 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1417 |
+
|
| 1418 |
+
if not return_dict:
|
| 1419 |
+
output = (logits,) + outputs[1:]
|
| 1420 |
+
return ((loss,) + output) if loss is not None else output
|
| 1421 |
+
|
| 1422 |
+
return TFTokenClassifierOutput(
|
| 1423 |
+
loss=loss,
|
| 1424 |
+
logits=logits,
|
| 1425 |
+
hidden_states=outputs.hidden_states,
|
| 1426 |
+
attentions=outputs.attentions,
|
| 1427 |
+
)
|
| 1428 |
+
|
| 1429 |
+
def build(self, input_shape=None):
|
| 1430 |
+
if self.built:
|
| 1431 |
+
return
|
| 1432 |
+
self.built = True
|
| 1433 |
+
if getattr(self, "roformer", None) is not None:
|
| 1434 |
+
with tf.name_scope(self.roformer.name):
|
| 1435 |
+
self.roformer.build(None)
|
| 1436 |
+
if getattr(self, "classifier", None) is not None:
|
| 1437 |
+
with tf.name_scope(self.classifier.name):
|
| 1438 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1439 |
+
|
| 1440 |
+
|
| 1441 |
+
@add_start_docstrings(
|
| 1442 |
+
"""
|
| 1443 |
+
RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1444 |
+
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1445 |
+
""",
|
| 1446 |
+
ROFORMER_START_DOCSTRING,
|
| 1447 |
+
)
|
| 1448 |
+
class TFRoFormerForQuestionAnswering(TFRoFormerPreTrainedModel, TFQuestionAnsweringLoss):
|
| 1449 |
+
def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
|
| 1450 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1451 |
+
|
| 1452 |
+
self.num_labels = config.num_labels
|
| 1453 |
+
|
| 1454 |
+
self.roformer = TFRoFormerMainLayer(config, name="roformer")
|
| 1455 |
+
self.qa_outputs = keras.layers.Dense(
|
| 1456 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
| 1457 |
+
)
|
| 1458 |
+
self.config = config
|
| 1459 |
+
|
| 1460 |
+
@unpack_inputs
|
| 1461 |
+
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1462 |
+
@add_code_sample_docstrings(
|
| 1463 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1464 |
+
output_type=TFQuestionAnsweringModelOutput,
|
| 1465 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1466 |
+
)
|
| 1467 |
+
def call(
|
| 1468 |
+
self,
|
| 1469 |
+
input_ids: TFModelInputType | None = None,
|
| 1470 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1471 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1472 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1473 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1474 |
+
output_attentions: Optional[bool] = None,
|
| 1475 |
+
output_hidden_states: Optional[bool] = None,
|
| 1476 |
+
return_dict: Optional[bool] = None,
|
| 1477 |
+
start_positions: np.ndarray | tf.Tensor | None = None,
|
| 1478 |
+
end_positions: np.ndarray | tf.Tensor | None = None,
|
| 1479 |
+
training: Optional[bool] = False,
|
| 1480 |
+
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
|
| 1481 |
+
r"""
|
| 1482 |
+
start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1483 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1484 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1485 |
+
are not taken into account for computing the loss.
|
| 1486 |
+
end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1487 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1488 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1489 |
+
are not taken into account for computing the loss.
|
| 1490 |
+
"""
|
| 1491 |
+
outputs = self.roformer(
|
| 1492 |
+
input_ids=input_ids,
|
| 1493 |
+
attention_mask=attention_mask,
|
| 1494 |
+
token_type_ids=token_type_ids,
|
| 1495 |
+
head_mask=head_mask,
|
| 1496 |
+
inputs_embeds=inputs_embeds,
|
| 1497 |
+
output_attentions=output_attentions,
|
| 1498 |
+
output_hidden_states=output_hidden_states,
|
| 1499 |
+
return_dict=return_dict,
|
| 1500 |
+
training=training,
|
| 1501 |
+
)
|
| 1502 |
+
sequence_output = outputs[0]
|
| 1503 |
+
logits = self.qa_outputs(inputs=sequence_output)
|
| 1504 |
+
start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
|
| 1505 |
+
start_logits = tf.squeeze(input=start_logits, axis=-1)
|
| 1506 |
+
end_logits = tf.squeeze(input=end_logits, axis=-1)
|
| 1507 |
+
loss = None
|
| 1508 |
+
|
| 1509 |
+
if start_positions is not None and end_positions is not None:
|
| 1510 |
+
labels = {"start_position": start_positions, "end_position": end_positions}
|
| 1511 |
+
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
|
| 1512 |
+
|
| 1513 |
+
if not return_dict:
|
| 1514 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 1515 |
+
return ((loss,) + output) if loss is not None else output
|
| 1516 |
+
|
| 1517 |
+
return TFQuestionAnsweringModelOutput(
|
| 1518 |
+
loss=loss,
|
| 1519 |
+
start_logits=start_logits,
|
| 1520 |
+
end_logits=end_logits,
|
| 1521 |
+
hidden_states=outputs.hidden_states,
|
| 1522 |
+
attentions=outputs.attentions,
|
| 1523 |
+
)
|
| 1524 |
+
|
| 1525 |
+
def build(self, input_shape=None):
|
| 1526 |
+
if self.built:
|
| 1527 |
+
return
|
| 1528 |
+
self.built = True
|
| 1529 |
+
if getattr(self, "roformer", None) is not None:
|
| 1530 |
+
with tf.name_scope(self.roformer.name):
|
| 1531 |
+
self.roformer.build(None)
|
| 1532 |
+
if getattr(self, "qa_outputs", None) is not None:
|
| 1533 |
+
with tf.name_scope(self.qa_outputs.name):
|
| 1534 |
+
self.qa_outputs.build([None, None, self.config.hidden_size])
|
| 1535 |
+
|
| 1536 |
+
|
| 1537 |
+
__all__ = [
|
| 1538 |
+
"TFRoFormerForCausalLM",
|
| 1539 |
+
"TFRoFormerForMaskedLM",
|
| 1540 |
+
"TFRoFormerForMultipleChoice",
|
| 1541 |
+
"TFRoFormerForQuestionAnswering",
|
| 1542 |
+
"TFRoFormerForSequenceClassification",
|
| 1543 |
+
"TFRoFormerForTokenClassification",
|
| 1544 |
+
"TFRoFormerLayer",
|
| 1545 |
+
"TFRoFormerModel",
|
| 1546 |
+
"TFRoFormerPreTrainedModel",
|
| 1547 |
+
]
|
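The classes above close out the TF RoFormer modeling file. As a minimal, hedged usage sketch (not part of the file itself), assuming TensorFlow is installed and the `junnyu/roformer_chinese_base` checkpoint referenced in the tokenizer docstrings further down is reachable, the masked-LM head could be exercised roughly like this:

```python
# Minimal sketch of using TFRoFormerForMaskedLM from the file above.
# Assumption: the "junnyu/roformer_chinese_base" checkpoint (taken from the
# tokenizer docstring example below) is available; pass from_pt=True to
# from_pretrained if the checkpoint only ships PyTorch weights.
import tensorflow as tf

from transformers import RoFormerTokenizer, TFRoFormerForMaskedLM

tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = TFRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")

# Mask one token and let the MLM head fill it in.
inputs = tokenizer("今天天气非常[MASK]。", return_tensors="tf")
outputs = model(**inputs)  # TFMaskedLMOutput in eager mode (return_dict defaults to True)

mask_pos = tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0]
predicted_id = int(tf.argmax(outputs.logits[0, mask_pos]))
print(tokenizer.convert_ids_to_tokens([predicted_id]))
```

The same call pattern applies to the sequence classification, token classification, and question answering heads defined above, with the head-specific output fields swapped in.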
docs/transformers/build/lib/transformers/models/roformer/tokenization_roformer.py
ADDED
@@ -0,0 +1,540 @@
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for RoFormer."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import os
|
| 19 |
+
import unicodedata
|
| 20 |
+
from typing import List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
| 23 |
+
from ...utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Copied from transformers.models.bert.tokenization_bert.load_vocab
|
| 32 |
+
def load_vocab(vocab_file):
|
| 33 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 34 |
+
vocab = collections.OrderedDict()
|
| 35 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 36 |
+
tokens = reader.readlines()
|
| 37 |
+
for index, token in enumerate(tokens):
|
| 38 |
+
token = token.rstrip("\n")
|
| 39 |
+
vocab[token] = index
|
| 40 |
+
return vocab
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
|
| 44 |
+
def whitespace_tokenize(text):
|
| 45 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
| 46 |
+
text = text.strip()
|
| 47 |
+
if not text:
|
| 48 |
+
return []
|
| 49 |
+
tokens = text.split()
|
| 50 |
+
return tokens
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
|
| 54 |
+
class BasicTokenizer:
|
| 55 |
+
"""
|
| 56 |
+
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 60 |
+
Whether or not to lowercase the input when tokenizing.
|
| 61 |
+
never_split (`Iterable`, *optional*):
|
| 62 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
| 63 |
+
`do_basic_tokenize=True`
|
| 64 |
+
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
| 65 |
+
Whether or not to tokenize Chinese characters.
|
| 66 |
+
|
| 67 |
+
This should likely be deactivated for Japanese (see this
|
| 68 |
+
[issue](https://github.com/huggingface/transformers/issues/328)).
|
| 69 |
+
strip_accents (`bool`, *optional*):
|
| 70 |
+
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
| 71 |
+
value for `lowercase` (as in the original BERT).
|
| 72 |
+
do_split_on_punc (`bool`, *optional*, defaults to `True`):
|
| 73 |
+
In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
|
| 74 |
+
the full context of the words, such as contractions.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
do_lower_case=True,
|
| 80 |
+
never_split=None,
|
| 81 |
+
tokenize_chinese_chars=True,
|
| 82 |
+
strip_accents=None,
|
| 83 |
+
do_split_on_punc=True,
|
| 84 |
+
):
|
| 85 |
+
if never_split is None:
|
| 86 |
+
never_split = []
|
| 87 |
+
self.do_lower_case = do_lower_case
|
| 88 |
+
self.never_split = set(never_split)
|
| 89 |
+
self.tokenize_chinese_chars = tokenize_chinese_chars
|
| 90 |
+
self.strip_accents = strip_accents
|
| 91 |
+
self.do_split_on_punc = do_split_on_punc
|
| 92 |
+
|
| 93 |
+
def tokenize(self, text, never_split=None):
|
| 94 |
+
"""
|
| 95 |
+
Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
never_split (`List[str]`, *optional*)
|
| 99 |
+
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
|
| 100 |
+
[`PreTrainedTokenizer.tokenize`]) List of token not to split.
|
| 101 |
+
"""
|
| 102 |
+
# union() returns a new set by concatenating the two sets.
|
| 103 |
+
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
|
| 104 |
+
text = self._clean_text(text)
|
| 105 |
+
|
| 106 |
+
# This was added on November 1st, 2018 for the multilingual and Chinese
|
| 107 |
+
# models. This is also applied to the English models now, but it doesn't
|
| 108 |
+
# matter since the English models were not trained on any Chinese data
|
| 109 |
+
# and generally don't have any Chinese data in them (there are Chinese
|
| 110 |
+
# characters in the vocabulary because Wikipedia does have some Chinese
|
| 111 |
+
# words in the English Wikipedia.).
|
| 112 |
+
if self.tokenize_chinese_chars:
|
| 113 |
+
text = self._tokenize_chinese_chars(text)
|
| 114 |
+
# prevents treating the same character with different unicode codepoints as different characters
|
| 115 |
+
unicode_normalized_text = unicodedata.normalize("NFC", text)
|
| 116 |
+
orig_tokens = whitespace_tokenize(unicode_normalized_text)
|
| 117 |
+
split_tokens = []
|
| 118 |
+
for token in orig_tokens:
|
| 119 |
+
if token not in never_split:
|
| 120 |
+
if self.do_lower_case:
|
| 121 |
+
token = token.lower()
|
| 122 |
+
if self.strip_accents is not False:
|
| 123 |
+
token = self._run_strip_accents(token)
|
| 124 |
+
elif self.strip_accents:
|
| 125 |
+
token = self._run_strip_accents(token)
|
| 126 |
+
split_tokens.extend(self._run_split_on_punc(token, never_split))
|
| 127 |
+
|
| 128 |
+
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
| 129 |
+
return output_tokens
|
| 130 |
+
|
| 131 |
+
def _run_strip_accents(self, text):
|
| 132 |
+
"""Strips accents from a piece of text."""
|
| 133 |
+
text = unicodedata.normalize("NFD", text)
|
| 134 |
+
output = []
|
| 135 |
+
for char in text:
|
| 136 |
+
cat = unicodedata.category(char)
|
| 137 |
+
if cat == "Mn":
|
| 138 |
+
continue
|
| 139 |
+
output.append(char)
|
| 140 |
+
return "".join(output)
|
| 141 |
+
|
| 142 |
+
def _run_split_on_punc(self, text, never_split=None):
|
| 143 |
+
"""Splits punctuation on a piece of text."""
|
| 144 |
+
if not self.do_split_on_punc or (never_split is not None and text in never_split):
|
| 145 |
+
return [text]
|
| 146 |
+
chars = list(text)
|
| 147 |
+
i = 0
|
| 148 |
+
start_new_word = True
|
| 149 |
+
output = []
|
| 150 |
+
while i < len(chars):
|
| 151 |
+
char = chars[i]
|
| 152 |
+
if _is_punctuation(char):
|
| 153 |
+
output.append([char])
|
| 154 |
+
start_new_word = True
|
| 155 |
+
else:
|
| 156 |
+
if start_new_word:
|
| 157 |
+
output.append([])
|
| 158 |
+
start_new_word = False
|
| 159 |
+
output[-1].append(char)
|
| 160 |
+
i += 1
|
| 161 |
+
|
| 162 |
+
return ["".join(x) for x in output]
|
| 163 |
+
|
| 164 |
+
def _tokenize_chinese_chars(self, text):
|
| 165 |
+
"""Adds whitespace around any CJK character."""
|
| 166 |
+
output = []
|
| 167 |
+
for char in text:
|
| 168 |
+
cp = ord(char)
|
| 169 |
+
if self._is_chinese_char(cp):
|
| 170 |
+
output.append(" ")
|
| 171 |
+
output.append(char)
|
| 172 |
+
output.append(" ")
|
| 173 |
+
else:
|
| 174 |
+
output.append(char)
|
| 175 |
+
return "".join(output)
|
| 176 |
+
|
| 177 |
+
def _is_chinese_char(self, cp):
|
| 178 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
| 179 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
| 180 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
| 181 |
+
#
|
| 182 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
| 183 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
| 184 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
| 185 |
+
# space-separated words, so they are not treated specially and handled
|
| 186 |
+
# like all of the other languages.
|
| 187 |
+
if (
|
| 188 |
+
(cp >= 0x4E00 and cp <= 0x9FFF)
|
| 189 |
+
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
| 190 |
+
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
| 191 |
+
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
| 192 |
+
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
| 193 |
+
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
| 194 |
+
or (cp >= 0xF900 and cp <= 0xFAFF)
|
| 195 |
+
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
| 196 |
+
): #
|
| 197 |
+
return True
|
| 198 |
+
|
| 199 |
+
return False
|
| 200 |
+
|
| 201 |
+
def _clean_text(self, text):
|
| 202 |
+
"""Performs invalid character removal and whitespace cleanup on text."""
|
| 203 |
+
output = []
|
| 204 |
+
for char in text:
|
| 205 |
+
cp = ord(char)
|
| 206 |
+
if cp == 0 or cp == 0xFFFD or _is_control(char):
|
| 207 |
+
continue
|
| 208 |
+
if _is_whitespace(char):
|
| 209 |
+
output.append(" ")
|
| 210 |
+
else:
|
| 211 |
+
output.append(char)
|
| 212 |
+
return "".join(output)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
|
| 216 |
+
class WordpieceTokenizer:
|
| 217 |
+
"""Runs WordPiece tokenization."""
|
| 218 |
+
|
| 219 |
+
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
|
| 220 |
+
self.vocab = vocab
|
| 221 |
+
self.unk_token = unk_token
|
| 222 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
| 223 |
+
|
| 224 |
+
def tokenize(self, text):
|
| 225 |
+
"""
|
| 226 |
+
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
|
| 227 |
+
tokenization using the given vocabulary.
|
| 228 |
+
|
| 229 |
+
For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
text: A single token or whitespace separated tokens. This should have
|
| 233 |
+
already been passed through *BasicTokenizer*.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
A list of wordpiece tokens.
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
output_tokens = []
|
| 240 |
+
for token in whitespace_tokenize(text):
|
| 241 |
+
chars = list(token)
|
| 242 |
+
if len(chars) > self.max_input_chars_per_word:
|
| 243 |
+
output_tokens.append(self.unk_token)
|
| 244 |
+
continue
|
| 245 |
+
|
| 246 |
+
is_bad = False
|
| 247 |
+
start = 0
|
| 248 |
+
sub_tokens = []
|
| 249 |
+
while start < len(chars):
|
| 250 |
+
end = len(chars)
|
| 251 |
+
cur_substr = None
|
| 252 |
+
while start < end:
|
| 253 |
+
substr = "".join(chars[start:end])
|
| 254 |
+
if start > 0:
|
| 255 |
+
substr = "##" + substr
|
| 256 |
+
if substr in self.vocab:
|
| 257 |
+
cur_substr = substr
|
| 258 |
+
break
|
| 259 |
+
end -= 1
|
| 260 |
+
if cur_substr is None:
|
| 261 |
+
is_bad = True
|
| 262 |
+
break
|
| 263 |
+
sub_tokens.append(cur_substr)
|
| 264 |
+
start = end
|
| 265 |
+
|
| 266 |
+
if is_bad:
|
| 267 |
+
output_tokens.append(self.unk_token)
|
| 268 |
+
else:
|
| 269 |
+
output_tokens.extend(sub_tokens)
|
| 270 |
+
return output_tokens
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class RoFormerTokenizer(PreTrainedTokenizer):
|
| 274 |
+
r"""
|
| 275 |
+
Construct a RoFormer tokenizer. Based on [Rust Jieba](https://pypi.org/project/rjieba/).
|
| 276 |
+
|
| 277 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 278 |
+
this superclass for more information regarding those methods.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
vocab_file (`str`):
|
| 282 |
+
File containing the vocabulary.
|
| 283 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 284 |
+
Whether or not to lowercase the input when tokenizing.
|
| 285 |
+
do_basic_tokenize (`bool`, *optional*, defaults to `True`):
|
| 286 |
+
Whether or not to do basic tokenization before WordPiece.
|
| 287 |
+
never_split (`Iterable`, *optional*):
|
| 288 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
| 289 |
+
`do_basic_tokenize=True`
|
| 290 |
+
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
|
| 291 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 292 |
+
token instead.
|
| 293 |
+
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 294 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 295 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 296 |
+
token of a sequence built with special tokens.
|
| 297 |
+
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
|
| 298 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 299 |
+
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 300 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 301 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 302 |
+
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
| 303 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 304 |
+
modeling. This is the token which the model will try to predict.
|
| 305 |
+
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
| 306 |
+
Whether or not to tokenize Chinese characters.
|
| 307 |
+
|
| 308 |
+
This should likely be deactivated for Japanese (see this
|
| 309 |
+
[issue](https://github.com/huggingface/transformers/issues/328)).
|
| 310 |
+
strip_accents (`bool`, *optional*):
|
| 311 |
+
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
| 312 |
+
value for `lowercase` (as in the original BERT).
|
| 313 |
+
|
| 314 |
+
Example:
|
| 315 |
+
|
| 316 |
+
```python
|
| 317 |
+
>>> from transformers import RoFormerTokenizer
|
| 318 |
+
|
| 319 |
+
>>> tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
|
| 320 |
+
>>> tokenizer.tokenize("今天天气非常好。")
|
| 321 |
+
['今', '天', '天', '气', '非常', '好', '。']
|
| 322 |
+
```"""
|
| 323 |
+
|
| 324 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 325 |
+
|
| 326 |
+
def __init__(
|
| 327 |
+
self,
|
| 328 |
+
vocab_file,
|
| 329 |
+
do_lower_case=True,
|
| 330 |
+
do_basic_tokenize=True,
|
| 331 |
+
never_split=None,
|
| 332 |
+
unk_token="[UNK]",
|
| 333 |
+
sep_token="[SEP]",
|
| 334 |
+
pad_token="[PAD]",
|
| 335 |
+
cls_token="[CLS]",
|
| 336 |
+
mask_token="[MASK]",
|
| 337 |
+
tokenize_chinese_chars=True,
|
| 338 |
+
strip_accents=None,
|
| 339 |
+
**kwargs,
|
| 340 |
+
):
|
| 341 |
+
if not os.path.isfile(vocab_file):
|
| 342 |
+
raise ValueError(
|
| 343 |
+
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
|
| 344 |
+
" model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
| 345 |
+
)
|
| 346 |
+
self.vocab = load_vocab(vocab_file)
|
| 347 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 348 |
+
self.do_basic_tokenize = do_basic_tokenize
|
| 349 |
+
if do_basic_tokenize:
|
| 350 |
+
self.basic_tokenizer = BasicTokenizer(
|
| 351 |
+
do_lower_case=do_lower_case,
|
| 352 |
+
never_split=never_split,
|
| 353 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
| 354 |
+
strip_accents=strip_accents,
|
| 355 |
+
)
|
| 356 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
|
| 357 |
+
try:
|
| 358 |
+
import rjieba
|
| 359 |
+
except ImportError:
|
| 360 |
+
raise ImportError(
|
| 361 |
+
"You need to install rjieba to use RoFormerTokenizer. "
|
| 362 |
+
"See https://pypi.org/project/rjieba/ for installation."
|
| 363 |
+
)
|
| 364 |
+
self.jieba = rjieba
|
| 365 |
+
|
| 366 |
+
super().__init__(
|
| 367 |
+
do_lower_case=do_lower_case,
|
| 368 |
+
do_basic_tokenize=do_basic_tokenize,
|
| 369 |
+
never_split=never_split,
|
| 370 |
+
unk_token=unk_token,
|
| 371 |
+
sep_token=sep_token,
|
| 372 |
+
pad_token=pad_token,
|
| 373 |
+
cls_token=cls_token,
|
| 374 |
+
mask_token=mask_token,
|
| 375 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
| 376 |
+
strip_accents=strip_accents,
|
| 377 |
+
**kwargs,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
@property
|
| 381 |
+
def do_lower_case(self):
|
| 382 |
+
return self.basic_tokenizer.do_lower_case
|
| 383 |
+
|
| 384 |
+
@property
|
| 385 |
+
def vocab_size(self):
|
| 386 |
+
return len(self.vocab)
|
| 387 |
+
|
| 388 |
+
def __getstate__(self):
|
| 389 |
+
state = self.__dict__.copy()
|
| 390 |
+
state["jieba"] = None
|
| 391 |
+
return state
|
| 392 |
+
|
| 393 |
+
def __setstate__(self, d):
|
| 394 |
+
self.__dict__ = d
|
| 395 |
+
import rjieba
|
| 396 |
+
|
| 397 |
+
self.jieba = rjieba
|
| 398 |
+
|
| 399 |
+
def get_vocab(self):
|
| 400 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 401 |
+
|
| 402 |
+
def _tokenize(self, text, use_jieba=True):
|
| 403 |
+
split_tokens = []
|
| 404 |
+
if use_jieba:
|
| 405 |
+
for wholword in self.jieba.cut(text, False):
|
| 406 |
+
if wholword in self.vocab:
|
| 407 |
+
split_tokens.append(wholword)
|
| 408 |
+
else:
|
| 409 |
+
# use bert tokenizer to _tokenize
|
| 410 |
+
char_list = self._tokenize(wholword, use_jieba=False)
|
| 411 |
+
split_tokens.extend(char_list)
|
| 412 |
+
else:
|
| 413 |
+
if self.do_basic_tokenize:
|
| 414 |
+
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
| 415 |
+
# If the token is part of the never_split set
|
| 416 |
+
if token in self.basic_tokenizer.never_split:
|
| 417 |
+
split_tokens.append(token)
|
| 418 |
+
else:
|
| 419 |
+
split_tokens += self.wordpiece_tokenizer.tokenize(token)
|
| 420 |
+
else:
|
| 421 |
+
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
| 422 |
+
return split_tokens
|
| 423 |
+
|
| 424 |
+
def _convert_token_to_id(self, token):
|
| 425 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 426 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 427 |
+
|
| 428 |
+
def _convert_id_to_token(self, index):
|
| 429 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 430 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 431 |
+
|
| 432 |
+
def convert_tokens_to_string(self, tokens):
|
| 433 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 434 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 435 |
+
return out_string
|
| 436 |
+
|
| 437 |
+
def build_inputs_with_special_tokens(
|
| 438 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 439 |
+
) -> List[int]:
|
| 440 |
+
"""
|
| 441 |
+
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
|
| 442 |
+
adding special tokens. A RoFormer sequence has the following format:
|
| 443 |
+
|
| 444 |
+
- single sequence: `[CLS] X [SEP]`
|
| 445 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
| 446 |
+
|
| 447 |
+
Args:
|
| 448 |
+
token_ids_0 (`List[int]`):
|
| 449 |
+
List of IDs to which the special tokens will be added.
|
| 450 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 451 |
+
Optional second list of IDs for sequence pairs.
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 455 |
+
"""
|
| 456 |
+
if token_ids_1 is None:
|
| 457 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 458 |
+
cls = [self.cls_token_id]
|
| 459 |
+
sep = [self.sep_token_id]
|
| 460 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 461 |
+
|
| 462 |
+
def get_special_tokens_mask(
|
| 463 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 464 |
+
) -> List[int]:
|
| 465 |
+
"""
|
| 466 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 467 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
token_ids_0 (`List[int]`):
|
| 471 |
+
List of IDs.
|
| 472 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 473 |
+
Optional second list of IDs for sequence pairs.
|
| 474 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 475 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 476 |
+
|
| 477 |
+
Returns:
|
| 478 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 479 |
+
"""
|
| 480 |
+
|
| 481 |
+
if already_has_special_tokens:
|
| 482 |
+
return super().get_special_tokens_mask(
|
| 483 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if token_ids_1 is not None:
|
| 487 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 488 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 489 |
+
|
| 490 |
+
def create_token_type_ids_from_sequences(
|
| 491 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 492 |
+
) -> List[int]:
|
| 493 |
+
"""
|
| 494 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer
|
| 495 |
+
sequence pair mask has the following format:
|
| 496 |
+
|
| 497 |
+
```
|
| 498 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 499 |
+
| first sequence | second sequence |
|
| 500 |
+
```
|
| 501 |
+
|
| 502 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
token_ids_0 (`List[int]`):
|
| 506 |
+
List of IDs.
|
| 507 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 508 |
+
Optional second list of IDs for sequence pairs.
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 512 |
+
"""
|
| 513 |
+
sep = [self.sep_token_id]
|
| 514 |
+
cls = [self.cls_token_id]
|
| 515 |
+
if token_ids_1 is None:
|
| 516 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 517 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 518 |
+
|
| 519 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 520 |
+
index = 0
|
| 521 |
+
if os.path.isdir(save_directory):
|
| 522 |
+
vocab_file = os.path.join(
|
| 523 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 524 |
+
)
|
| 525 |
+
else:
|
| 526 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
| 527 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 528 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 529 |
+
if index != token_index:
|
| 530 |
+
logger.warning(
|
| 531 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
| 532 |
+
" Please check that the vocabulary is not corrupted!"
|
| 533 |
+
)
|
| 534 |
+
index = token_index
|
| 535 |
+
writer.write(token + "\n")
|
| 536 |
+
index += 1
|
| 537 |
+
return (vocab_file,)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
__all__ = ["RoFormerTokenizer"]
|
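End of the slow RoFormer tokenizer. As a hedged illustration (not part of the file), assuming `rjieba` is installed and the `junnyu/roformer_chinese_base` vocabulary from the class docstring example loads, the pair-encoding helpers defined above produce the `[CLS] A [SEP] B [SEP]` layout and the 0/1 token type mask described in their docstrings:

```python
# Sketch of the special-token layout and token type ids documented above.
# Assumption: rjieba is installed and the "junnyu/roformer_chinese_base"
# vocabulary from the class docstring example is available.
from transformers import RoFormerTokenizer

tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")

ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("今天天气非常好。"))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("明天也一样吗?"))

input_ids = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b)

print(tokenizer.convert_ids_to_tokens(input_ids))  # ['[CLS]', ..., '[SEP]', ..., '[SEP]']
print(token_type_ids)  # 0s over "[CLS] A [SEP]", then 1s over "B [SEP]"
```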
docs/transformers/build/lib/transformers/models/roformer/tokenization_roformer_fast.py
ADDED
@@ -0,0 +1,180 @@
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for RoFormer."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
from typing import List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
from tokenizers import normalizers
|
| 21 |
+
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
|
| 22 |
+
|
| 23 |
+
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
from .tokenization_roformer import RoFormerTokenizer
|
| 26 |
+
from .tokenization_utils import JiebaPreTokenizer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class RoFormerTokenizerFast(PreTrainedTokenizerFast):
|
| 35 |
+
r"""
|
| 36 |
+
Construct a "fast" RoFormer tokenizer (backed by HuggingFace's *tokenizers* library).
|
| 37 |
+
|
| 38 |
+
[`RoFormerTokenizerFast`] is almost identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
|
| 39 |
+
punctuation splitting and wordpiece. There are some difference between them when tokenizing Chinese.
|
| 40 |
+
|
| 41 |
+
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
| 42 |
+
refer to this superclass for more information regarding those methods.
|
| 43 |
+
|
| 44 |
+
Example:
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
>>> from transformers import RoFormerTokenizerFast
|
| 48 |
+
|
| 49 |
+
>>> tokenizer = RoFormerTokenizerFast.from_pretrained("junnyu/roformer_chinese_base")
|
| 50 |
+
>>> tokenizer.tokenize("今天天气非常好。")
|
| 51 |
+
['今', '天', '天', '气', '非常', '好', '。']
|
| 52 |
+
```"""
|
| 53 |
+
|
| 54 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 55 |
+
slow_tokenizer_class = RoFormerTokenizer
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
vocab_file=None,
|
| 60 |
+
tokenizer_file=None,
|
| 61 |
+
do_lower_case=True,
|
| 62 |
+
unk_token="[UNK]",
|
| 63 |
+
sep_token="[SEP]",
|
| 64 |
+
pad_token="[PAD]",
|
| 65 |
+
cls_token="[CLS]",
|
| 66 |
+
mask_token="[MASK]",
|
| 67 |
+
tokenize_chinese_chars=True,
|
| 68 |
+
strip_accents=None,
|
| 69 |
+
**kwargs,
|
| 70 |
+
):
|
| 71 |
+
super().__init__(
|
| 72 |
+
vocab_file,
|
| 73 |
+
tokenizer_file=tokenizer_file,
|
| 74 |
+
do_lower_case=do_lower_case,
|
| 75 |
+
unk_token=unk_token,
|
| 76 |
+
sep_token=sep_token,
|
| 77 |
+
pad_token=pad_token,
|
| 78 |
+
cls_token=cls_token,
|
| 79 |
+
mask_token=mask_token,
|
| 80 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
| 81 |
+
strip_accents=strip_accents,
|
| 82 |
+
**kwargs,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
| 86 |
+
if (
|
| 87 |
+
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
|
| 88 |
+
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
|
| 89 |
+
):
|
| 90 |
+
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
|
| 91 |
+
normalizer_state["lowercase"] = do_lower_case
|
| 92 |
+
normalizer_state["strip_accents"] = strip_accents
|
| 93 |
+
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
|
| 94 |
+
|
| 95 |
+
# Make sure we correctly set the custom PreTokenizer
|
| 96 |
+
vocab = self.backend_tokenizer.get_vocab()
|
| 97 |
+
self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
| 98 |
+
|
| 99 |
+
self.do_lower_case = do_lower_case
|
| 100 |
+
|
| 101 |
+
def __getstate__(self):
|
| 102 |
+
state = self.__dict__.copy()
|
| 103 |
+
state["_tokenizer"].pre_tokenizer = BertPreTokenizer()
|
| 104 |
+
return state
|
| 105 |
+
|
| 106 |
+
def __setstate__(self, d):
|
| 107 |
+
self.__dict__ = d
|
| 108 |
+
vocab = self.__dict__["_tokenizer"].get_vocab()
|
| 109 |
+
self.__dict__["_tokenizer"].pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
| 110 |
+
|
| 111 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 112 |
+
"""
|
| 113 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 114 |
+
adding special tokens. A RoFormer sequence has the following format:
|
| 115 |
+
|
| 116 |
+
- single sequence: `[CLS] X [SEP]`
|
| 117 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
token_ids_0 (`List[int]`):
|
| 121 |
+
List of IDs to which the special tokens will be added.
|
| 122 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 123 |
+
Optional second list of IDs for sequence pairs.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 127 |
+
"""
|
| 128 |
+
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 129 |
+
|
| 130 |
+
if token_ids_1 is not None:
|
| 131 |
+
output += token_ids_1 + [self.sep_token_id]
|
| 132 |
+
|
| 133 |
+
return output
|
| 134 |
+
|
| 135 |
+
def create_token_type_ids_from_sequences(
|
| 136 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 137 |
+
) -> List[int]:
|
| 138 |
+
"""
|
| 139 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer
|
| 140 |
+
sequence pair mask has the following format:
|
| 141 |
+
|
| 142 |
+
```
|
| 143 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 144 |
+
| first sequence | second sequence |
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
token_ids_0 (`List[int]`):
|
| 151 |
+
List of IDs.
|
| 152 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 153 |
+
Optional second list of IDs for sequence pairs.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 157 |
+
"""
|
| 158 |
+
sep = [self.sep_token_id]
|
| 159 |
+
cls = [self.cls_token_id]
|
| 160 |
+
if token_ids_1 is None:
|
| 161 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 162 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 163 |
+
|
| 164 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 165 |
+
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
| 166 |
+
return tuple(files)
|
| 167 |
+
|
| 168 |
+
def save_pretrained(
|
| 169 |
+
self,
|
| 170 |
+
save_directory,
|
| 171 |
+
legacy_format=None,
|
| 172 |
+
filename_prefix=None,
|
| 173 |
+
push_to_hub=False,
|
| 174 |
+
**kwargs,
|
| 175 |
+
):
|
| 176 |
+
self.backend_tokenizer.pre_tokenizer = BertPreTokenizer()
|
| 177 |
+
return super().save_pretrained(save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
__all__ = ["RoFormerTokenizerFast"]
|
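Note on the file above: `tokenizers` cannot serialize a `PreTokenizer.custom(...)` wrapper into `tokenizer.json`, which is why the class swaps in `BertPreTokenizer` before pickling (`__getstate__`) and before `save_pretrained`, then re-attaches the Jieba pre-tokenizer in `__setstate__`. A minimal sketch of the same round-trip, assuming `rjieba` is installed and that a `./roformer_tok/tokenizer.json` saved by this class exists (the path is illustrative, not part of the diff):

```python
# Sketch: re-attach the custom pre-tokenizer after loading a serialized tokenizer.
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer

from transformers.models.roformer.tokenization_utils import JiebaPreTokenizer

tok = Tokenizer.from_file("./roformer_tok/tokenizer.json")  # was saved with BertPreTokenizer
tok.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(tok.get_vocab()))  # restore Jieba splitting

# Before saving again, swap back, since custom pre-tokenizers are not JSON-serializable:
tok.pre_tokenizer = BertPreTokenizer()
tok.save("./roformer_tok/tokenizer.json")
```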
docs/transformers/build/lib/transformers/models/roformer/tokenization_utils.py
ADDED
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization utils for RoFormer."""

from typing import List

from tokenizers import NormalizedString, PreTokenizedString, normalizers


class JiebaPreTokenizer:
    def __init__(self, vocab) -> None:
        self.vocab = vocab
        self.normalizers = normalizers.BertNormalizer(
            clean_text=False,
            handle_chinese_chars=True,
            strip_accents=False,
            lowercase=False,
        )
        try:
            import rjieba
        except ImportError:
            raise ImportError(
                "You need to install rjieba to use RoFormerTokenizer. "
                "See https://pypi.org/project/rjieba/ for installation."
            )
        self.jieba = rjieba

    def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        splits = []

        # Slicing normalized_string here is slow (~6s) but keeps offsets aligned, so test_alignment_methods passes
        for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
            if token in self.vocab:
                splits.append(normalized_string[start:end])
            else:
                token_list = self.normalizers.normalize_str(token).split()
                for token in token_list:
                    if token:
                        end = start + len(token)
                        splits.append(normalized_string[start:end])
                        start = end

        # Faster variant (~300ms), but it loses alignment, so test_alignment_methods fails:
        # for token in self.jieba.cut(str(normalized_string), False):
        #     if token in self.vocab:
        #         splits.append(NormalizedString(token))
        #     else:
        #         token_list = self.normalizers.normalize_str(token).split()
        #         for token in token_list:
        #             if token:
        #                 splits.append(NormalizedString(token))

        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.jieba_split)
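For context on the protocol used above: `PreTokenizedString.split` calls the provided function with each segment's index and its `NormalizedString`, so a class only needs a `pre_tokenize` method (plus a split callback) to be wrapped by `PreTokenizer.custom`. A minimal, self-contained sketch with whitespace splitting standing in for Jieba, so it runs without `rjieba` (the class name and example string are illustrative):

```python
# Sketch of the custom pre-tokenizer protocol, with whitespace splitting.
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer


class WhitespacePreTokenizer:
    def split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        splits, start = [], 0
        text = str(normalized_string)
        for piece in text.split(" "):
            end = start + len(piece)
            if piece:
                # Slice the NormalizedString so character offsets stay aligned
                splits.append(normalized_string[start:end])
            start = end + 1  # skip the space
        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.split)


pre_tok = PreTokenizer.custom(WhitespacePreTokenizer())
print(pre_tok.pre_tokenize_str("hello world"))  # [('hello', (0, 5)), ('world', (6, 11))]
```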
docs/transformers/build/lib/transformers/models/rt_detr/__init__.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_rt_detr import *
    from .configuration_rt_detr_resnet import *
    from .image_processing_rt_detr import *
    from .image_processing_rt_detr_fast import *
    from .modeling_rt_detr import *
    from .modeling_rt_detr_resnet import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/rt_detr/configuration_rt_detr.py
ADDED
@@ -0,0 +1,364 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RT-DETR model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING
from .configuration_rt_detr_resnet import RTDetrResNetConfig


logger = logging.get_logger(__name__)


class RTDetrConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
    RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the RT-DETR
    [checkpoing/todo](https://huggingface.co/checkpoing/todo) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        initializer_range (`float`, *optional*, defaults to 0.01):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_bias_prior_prob (`float`, *optional*):
            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
            If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch normalization layers.
        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
            Whether to freeze the batch normalization layers in the backbone.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        encoder_hidden_dim (`int`, *optional*, defaults to 256):
            Dimension of the layers in hybrid encoder.
        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
            Multi level features input for encoder.
        feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
            Strides used in each feature map.
        encoder_layers (`int`, *optional*, defaults to 1):
            Total number of layers to be used by the encoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in encoder.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        dropout (`float`, *optional*, defaults to 0.0):
            The ratio for all dropout layers.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
            Indexes of the projected layers to be used in the encoder.
        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
            The temperature parameter used to create the positional encodings.
        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        activation_function (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        eval_size (`Tuple[int, int]`, *optional*):
            Height and width used to compute the effective height and width of the position embeddings after taking
            into account the stride.
        normalize_before (`bool`, *optional*, defaults to `False`):
            Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
            feed-forward modules.
        hidden_expansion (`float`, *optional*, defaults to 1.0):
            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
        d_model (`int`, *optional*, defaults to 256):
            Dimension of the layers excluding the hybrid encoder.
        num_queries (`int`, *optional*, defaults to 300):
            Number of object queries.
        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
            Multi level features dimension for the decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
        num_feature_levels (`int`, *optional*, defaults to 3):
            The number of input feature levels.
        decoder_n_points (`int`, *optional*, defaults to 4):
            The number of sampled keys in each feature level for each attention head in the decoder.
        decoder_layers (`int`, *optional*, defaults to 6):
            Number of decoder layers.
        decoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_denoising (`int`, *optional*, defaults to 100):
            The total number of denoising tasks or queries to be used for contrastive denoising.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            The fraction of denoising labels to which random noise should be added.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale or magnitude of noise to be added to the bounding boxes.
        learn_initial_query (`bool`, *optional*, defaults to `False`):
            Indicates whether the initial query embeddings for the decoder should be learned during training.
        anchor_image_size (`Tuple[int, int]`, *optional*):
            Height and width of the input image used during evaluation to generate the bounding box anchors. If
            `None`, anchors are generated automatically.
        disable_custom_kernels (`bool`, *optional*, defaults to `True`):
            Whether to disable custom kernels.
        with_box_refine (`bool`, *optional*, defaults to `True`):
            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
            based on the predictions from the previous layer.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether the architecture has an encoder decoder structure.
        matcher_alpha (`float`, *optional*, defaults to 0.25):
            Parameter alpha used by the Hungarian Matcher.
        matcher_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used by the Hungarian Matcher.
        matcher_class_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the class loss used by the Hungarian Matcher.
        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
            The relative weight of the bounding box loss used by the Hungarian Matcher.
        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the giou loss used by the Hungarian Matcher.
        use_focal_loss (`bool`, *optional*, defaults to `True`):
            Parameter informing if focal loss should be used.
        auxiliary_loss (`bool`, *optional*, defaults to `True`):
            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
        focal_loss_alpha (`float`, *optional*, defaults to 0.75):
            Parameter alpha used to compute the focal loss.
        focal_loss_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used to compute the focal loss.
        weight_loss_vfl (`float`, *optional*, defaults to 1.0):
            Relative weight of the varifocal loss in the object detection loss.
        weight_loss_bbox (`float`, *optional*, defaults to 5.0):
            Relative weight of the L1 bounding box loss in the object detection loss.
        weight_loss_giou (`float`, *optional*, defaults to 2.0):
            Relative weight of the generalized IoU loss in the object detection loss.
        eos_coefficient (`float`, *optional*, defaults to 0.0001):
            Relative classification weight of the 'no-object' class in the object detection loss.

    Examples:

    ```python
    >>> from transformers import RTDetrConfig, RTDetrModel

    >>> # Initializing a RT-DETR configuration
    >>> configuration = RTDetrConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = RTDetrModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "rt_detr"
    layer_types = ["basic", "bottleneck"]
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "encoder_attention_heads",
    }

    def __init__(
        self,
        initializer_range=0.01,
        initializer_bias_prior_prob=None,
        layer_norm_eps=1e-5,
        batch_norm_eps=1e-5,
        # backbone
        backbone_config=None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        freeze_backbone_batch_norms=True,
        backbone_kwargs=None,
        # encoder HybridEncoder
        encoder_hidden_dim=256,
        encoder_in_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        encoder_layers=1,
        encoder_ffn_dim=1024,
        encoder_attention_heads=8,
        dropout=0.0,
        activation_dropout=0.0,
        encode_proj_layers=[2],
        positional_encoding_temperature=10000,
        encoder_activation_function="gelu",
        activation_function="silu",
        eval_size=None,
        normalize_before=False,
        hidden_expansion=1.0,
        # decoder RTDetrTransformer
        d_model=256,
        num_queries=300,
        decoder_in_channels=[256, 256, 256],
        decoder_ffn_dim=1024,
        num_feature_levels=3,
        decoder_n_points=4,
        decoder_layers=6,
        decoder_attention_heads=8,
        decoder_activation_function="relu",
        attention_dropout=0.0,
        num_denoising=100,
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learn_initial_query=False,
        anchor_image_size=None,
        disable_custom_kernels=True,
        with_box_refine=True,
        is_encoder_decoder=True,
        # Loss
        matcher_alpha=0.25,
        matcher_gamma=2.0,
        matcher_class_cost=2.0,
        matcher_bbox_cost=5.0,
        matcher_giou_cost=2.0,
        use_focal_loss=True,
        auxiliary_loss=True,
        focal_loss_alpha=0.75,
        focal_loss_gamma=2.0,
        weight_loss_vfl=1.0,
        weight_loss_bbox=5.0,
        weight_loss_giou=2.0,
        eos_coefficient=1e-4,
        **kwargs,
    ):
        self.initializer_range = initializer_range
        self.initializer_bias_prior_prob = initializer_bias_prior_prob
        self.layer_norm_eps = layer_norm_eps
        self.batch_norm_eps = batch_norm_eps
        # backbone
        if backbone_config is None and backbone is None:
            logger.info(
                "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
            )
            backbone_config = RTDetrResNetConfig(
                num_channels=3,
                embedding_size=64,
                hidden_sizes=[256, 512, 1024, 2048],
                depths=[3, 4, 6, 3],
                layer_type="bottleneck",
                hidden_act="relu",
                downsample_in_first_stage=False,
                downsample_in_bottleneck=False,
                out_features=None,
                out_indices=[2, 3, 4],
            )
        elif isinstance(backbone_config, dict):
            backbone_model_type = backbone_config.pop("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
        self.backbone_kwargs = backbone_kwargs
        # encoder
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_in_channels = encoder_in_channels
        self.feat_strides = feat_strides
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.encode_proj_layers = encode_proj_layers
        self.encoder_layers = encoder_layers
        self.positional_encoding_temperature = positional_encoding_temperature
        self.eval_size = eval_size
        self.normalize_before = normalize_before
        self.encoder_activation_function = encoder_activation_function
        self.activation_function = activation_function
        self.hidden_expansion = hidden_expansion
        # decoder
        self.d_model = d_model
        self.num_queries = num_queries
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_in_channels = decoder_in_channels
        self.num_feature_levels = num_feature_levels
        self.decoder_n_points = decoder_n_points
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_activation_function = decoder_activation_function
        self.attention_dropout = attention_dropout
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learn_initial_query = learn_initial_query
        self.anchor_image_size = anchor_image_size
        self.auxiliary_loss = auxiliary_loss
        self.disable_custom_kernels = disable_custom_kernels
        self.with_box_refine = with_box_refine
        # Loss
        self.matcher_alpha = matcher_alpha
        self.matcher_gamma = matcher_gamma
        self.matcher_class_cost = matcher_class_cost
        self.matcher_bbox_cost = matcher_bbox_cost
        self.matcher_giou_cost = matcher_giou_cost
        self.use_focal_loss = use_focal_loss
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.weight_loss_vfl = weight_loss_vfl
        self.weight_loss_bbox = weight_loss_bbox
        self.weight_loss_giou = weight_loss_giou
        self.eos_coefficient = eos_coefficient
        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads

    @property
    def hidden_size(self) -> int:
        return self.d_model

    @classmethod
    def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
        """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
        configuration.

        Args:
            backbone_config ([`PretrainedConfig`]):
                The backbone configuration.

        Returns:
            [`RTDetrConfig`]: An instance of a configuration object
        """
        return cls(
            backbone_config=backbone_config,
            **kwargs,
        )


__all__ = ["RTDetrConfig"]
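As a usage note for the file above, `from_backbone_configs` simply forwards the backbone config plus any extra keyword arguments to the constructor. A minimal sketch of pairing `RTDetrConfig` with a smaller ResNet-18-style backbone; the hyperparameter values mirror the `rtdetr_r18vd` settings in the conversion script later in this diff and are illustrative only:

```python
# Sketch: build an RT-DETR config around a "basic"-layer (ResNet-18-style) backbone.
from transformers import RTDetrConfig, RTDetrResNetConfig

backbone_config = RTDetrResNetConfig(
    hidden_sizes=[64, 128, 256, 512],
    depths=[2, 2, 2, 2],
    layer_type="basic",
    out_indices=[2, 3, 4],
)
config = RTDetrConfig.from_backbone_configs(
    backbone_config=backbone_config,
    encoder_in_channels=[128, 256, 512],
    hidden_expansion=0.5,
    decoder_layers=3,
)
print(config.backbone_config.layer_type)  # "basic"
```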
docs/transformers/build/lib/transformers/models/rt_detr/configuration_rt_detr_resnet.py
ADDED
@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RT-DETR ResNet model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)


class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to
    instantiate a ResNet model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the ResNet
    [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embedding_size (`int`, *optional*, defaults to 64):
            Dimensionality (hidden size) for the embedding layer.
        hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
            Dimensionality (hidden size) at each stage.
        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
            Depth (number of layers) for each stage.
        layer_type (`str`, *optional*, defaults to `"bottleneck"`):
            The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
            `"bottleneck"` (used for larger models like resnet-50 and above).
        hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
            are supported.
        downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
            If `True`, the first stage will downsample the inputs using a `stride` of 2.
        downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
            If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:
    ```python
    >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone

    >>> # Initializing a ResNet resnet-50 style configuration
    >>> configuration = RTDetrResNetConfig()

    >>> # Initializing a model (with random weights) from the resnet-50 style configuration
    >>> model = RTDetrResnetBackbone(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "rt_detr_resnet"
    layer_types = ["basic", "bottleneck"]

    def __init__(
        self,
        num_channels=3,
        embedding_size=64,
        hidden_sizes=[256, 512, 1024, 2048],
        depths=[3, 4, 6, 3],
        layer_type="bottleneck",
        hidden_act="relu",
        downsample_in_first_stage=False,
        downsample_in_bottleneck=False,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if layer_type not in self.layer_types:
            raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
        self.num_channels = num_channels
        self.embedding_size = embedding_size
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.layer_type = layer_type
        self.hidden_act = hidden_act
        self.downsample_in_first_stage = downsample_in_first_stage
        self.downsample_in_bottleneck = downsample_in_bottleneck
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )


__all__ = ["RTDetrResNetConfig"]
docs/transformers/build/lib/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py
ADDED
@@ -0,0 +1,782 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RT Detr checkpoints with Timm backbone"""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from transformers import RTDetrConfig, RTDetrForObjectDetection, RTDetrImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_rt_detr_config(model_name: str) -> RTDetrConfig:
    config = RTDetrConfig()

    config.num_labels = 80
    repo_id = "huggingface/label-files"
    filename = "coco-detection-mmdet-id2label.json"
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    if model_name == "rtdetr_r18vd":
        config.backbone_config.hidden_sizes = [64, 128, 256, 512]
        config.backbone_config.depths = [2, 2, 2, 2]
        config.backbone_config.layer_type = "basic"
        config.encoder_in_channels = [128, 256, 512]
        config.hidden_expansion = 0.5
        config.decoder_layers = 3
    elif model_name == "rtdetr_r34vd":
        config.backbone_config.hidden_sizes = [64, 128, 256, 512]
        config.backbone_config.depths = [3, 4, 6, 3]
        config.backbone_config.layer_type = "basic"
        config.encoder_in_channels = [128, 256, 512]
        config.hidden_expansion = 0.5
        config.decoder_layers = 4
    elif model_name == "rtdetr_r50vd_m":
        pass
    elif model_name == "rtdetr_r50vd":
        pass
    elif model_name == "rtdetr_r101vd":
        config.backbone_config.depths = [3, 4, 23, 3]
        config.encoder_ffn_dim = 2048
        config.encoder_hidden_dim = 384
        config.decoder_in_channels = [384, 384, 384]
    elif model_name == "rtdetr_r18vd_coco_o365":
        config.backbone_config.hidden_sizes = [64, 128, 256, 512]
        config.backbone_config.depths = [2, 2, 2, 2]
        config.backbone_config.layer_type = "basic"
        config.encoder_in_channels = [128, 256, 512]
        config.hidden_expansion = 0.5
        config.decoder_layers = 3
    elif model_name == "rtdetr_r50vd_coco_o365":
        pass
    elif model_name == "rtdetr_r101vd_coco_o365":
        config.backbone_config.depths = [3, 4, 23, 3]
        config.encoder_ffn_dim = 2048
        config.encoder_hidden_dim = 384
        config.decoder_in_channels = [384, 384, 384]

    return config


def create_rename_keys(config):
    # here we list all keys to be renamed (original name on the left, our name on the right)
    rename_keys = []

    # stem
    # fmt: off
    last_key = ["weight", "bias", "running_mean", "running_var"]

    for level in range(3):
        rename_keys.append((f"backbone.conv1.conv1_{level+1}.conv.weight", f"model.backbone.model.embedder.embedder.{level}.convolution.weight"))
        for last in last_key:
            rename_keys.append((f"backbone.conv1.conv1_{level+1}.norm.{last}", f"model.backbone.model.embedder.embedder.{level}.normalization.{last}"))

    for stage_idx in range(len(config.backbone_config.depths)):
        for layer_idx in range(config.backbone_config.depths[stage_idx]):
            # shortcut
            if layer_idx == 0:
                if stage_idx == 0:
                    rename_keys.append(
                        (
                            f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.weight",
                            f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.convolution.weight",
                        )
                    )
                    for last in last_key:
                        rename_keys.append(
                            (
                                f"backbone.res_layers.{stage_idx}.blocks.0.short.norm.{last}",
                                f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.normalization.{last}",
                            )
                        )
                else:
                    rename_keys.append(
                        (
                            f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.conv.weight",
                            f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.convolution.weight",
                        )
                    )
                    for last in last_key:
                        rename_keys.append(
                            (
                                f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.norm.{last}",
                                f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.normalization.{last}",
                            )
                        )

            rename_keys.append(
                (
                    f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.conv.weight",
                    f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.convolution.weight",
                )
            )
            for last in last_key:
                rename_keys.append((
                    f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.norm.{last}",
                    f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.{last}",
                ))

            rename_keys.append(
                (
                    f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.conv.weight",
                    f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.convolution.weight",
                )
            )
            for last in last_key:
                rename_keys.append((
                    f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.norm.{last}",
                    f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.{last}",
                ))

            # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/nn/backbone/presnet.py#L171
            if config.backbone_config.layer_type != "basic":
                rename_keys.append(
                    (
                        f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.conv.weight",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.convolution.weight",
                    )
                )
                for last in last_key:
                    rename_keys.append((
                        f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.norm.{last}",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.normalization.{last}",
                    ))
    # fmt: on

    for i in range(config.encoder_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.self_attn.out_proj.weight",
                f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.weight",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.self_attn.out_proj.bias",
                f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.bias",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.linear1.weight",
                f"model.encoder.encoder.{i}.layers.0.fc1.weight",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.linear1.bias",
                f"model.encoder.encoder.{i}.layers.0.fc1.bias",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.linear2.weight",
                f"model.encoder.encoder.{i}.layers.0.fc2.weight",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.linear2.bias",
                f"model.encoder.encoder.{i}.layers.0.fc2.bias",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.norm1.weight",
                f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.weight",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.norm1.bias",
                f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.bias",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.norm2.weight",
                f"model.encoder.encoder.{i}.layers.0.final_layer_norm.weight",
            )
        )
        rename_keys.append(
            (
                f"encoder.encoder.{i}.layers.0.norm2.bias",
                f"model.encoder.encoder.{i}.layers.0.final_layer_norm.bias",
            )
        )

    for j in range(0, 3):
        rename_keys.append((f"encoder.input_proj.{j}.0.weight", f"model.encoder_input_proj.{j}.0.weight"))
        for last in last_key:
            rename_keys.append((f"encoder.input_proj.{j}.1.{last}", f"model.encoder_input_proj.{j}.1.{last}"))

    block_levels = 3 if config.backbone_config.layer_type != "basic" else 4

    for i in range(len(config.encoder_in_channels) - 1):
        # encoder layers: hybridencoder parts
        for j in range(1, block_levels):
            rename_keys.append(
                (f"encoder.fpn_blocks.{i}.conv{j}.conv.weight", f"model.encoder.fpn_blocks.{i}.conv{j}.conv.weight")
            )
            for last in last_key:
                rename_keys.append(
                    (
                        f"encoder.fpn_blocks.{i}.conv{j}.norm.{last}",
                        f"model.encoder.fpn_blocks.{i}.conv{j}.norm.{last}",
                    )
                )

        rename_keys.append((f"encoder.lateral_convs.{i}.conv.weight", f"model.encoder.lateral_convs.{i}.conv.weight"))
        for last in last_key:
            rename_keys.append(
                (f"encoder.lateral_convs.{i}.norm.{last}", f"model.encoder.lateral_convs.{i}.norm.{last}")
            )

        for j in range(3):
            for k in range(1, 3):
                rename_keys.append(
                    (
                        f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
                        f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
                    )
                )
                for last in last_key:
                    rename_keys.append(
                        (
                            f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
                            f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
                        )
                    )

        for j in range(1, block_levels):
            rename_keys.append(
                (f"encoder.pan_blocks.{i}.conv{j}.conv.weight", f"model.encoder.pan_blocks.{i}.conv{j}.conv.weight")
            )
            for last in last_key:
                rename_keys.append(
                    (
                        f"encoder.pan_blocks.{i}.conv{j}.norm.{last}",
                        f"model.encoder.pan_blocks.{i}.conv{j}.norm.{last}",
                    )
                )

        for j in range(3):
            for k in range(1, 3):
                rename_keys.append(
                    (
                        f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
                        f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
                    )
                )
                for last in last_key:
                    rename_keys.append(
                        (
                            f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
                            f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
                        )
                    )

        rename_keys.append(
            (f"encoder.downsample_convs.{i}.conv.weight", f"model.encoder.downsample_convs.{i}.conv.weight")
        )
        for last in last_key:
            rename_keys.append(
                (f"encoder.downsample_convs.{i}.norm.{last}", f"model.encoder.downsample_convs.{i}.norm.{last}")
            )

    for i in range(config.decoder_layers):
        # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.self_attn.out_proj.weight",
                f"model.decoder.layers.{i}.self_attn.out_proj.weight",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.self_attn.out_proj.bias",
                f"model.decoder.layers.{i}.self_attn.out_proj.bias",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.weight",
                f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.bias",
                f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.attention_weights.weight",
                f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.attention_weights.bias",
                f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.value_proj.weight",
                f"model.decoder.layers.{i}.encoder_attn.value_proj.weight",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.value_proj.bias",
                f"model.decoder.layers.{i}.encoder_attn.value_proj.bias",
            )
        )
        rename_keys.append(
            (
                f"decoder.decoder.layers.{i}.cross_attn.output_proj.weight",
                f"model.decoder.layers.{i}.encoder_attn.output_proj.weight",
            )
        )
| 370 |
+
rename_keys.append(
|
| 371 |
+
(
|
| 372 |
+
f"decoder.decoder.layers.{i}.cross_attn.output_proj.bias",
|
| 373 |
+
f"model.decoder.layers.{i}.encoder_attn.output_proj.bias",
|
| 374 |
+
)
|
| 375 |
+
)
|
| 376 |
+
rename_keys.append(
|
| 377 |
+
(f"decoder.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight")
|
| 378 |
+
)
|
| 379 |
+
rename_keys.append(
|
| 380 |
+
(f"decoder.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias")
|
| 381 |
+
)
|
| 382 |
+
rename_keys.append(
|
| 383 |
+
(f"decoder.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight")
|
| 384 |
+
)
|
| 385 |
+
rename_keys.append(
|
| 386 |
+
(f"decoder.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias")
|
| 387 |
+
)
|
| 388 |
+
rename_keys.append((f"decoder.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight"))
|
| 389 |
+
rename_keys.append((f"decoder.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias"))
|
| 390 |
+
rename_keys.append((f"decoder.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight"))
|
| 391 |
+
rename_keys.append((f"decoder.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias"))
|
| 392 |
+
rename_keys.append(
|
| 393 |
+
(f"decoder.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight")
|
| 394 |
+
)
|
| 395 |
+
rename_keys.append(
|
| 396 |
+
(f"decoder.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias")
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
for i in range(config.decoder_layers):
|
| 400 |
+
# decoder + class and bounding box heads
|
| 401 |
+
rename_keys.append(
|
| 402 |
+
(
|
| 403 |
+
f"decoder.dec_score_head.{i}.weight",
|
| 404 |
+
f"model.decoder.class_embed.{i}.weight",
|
| 405 |
+
)
|
| 406 |
+
)
|
| 407 |
+
rename_keys.append(
|
| 408 |
+
(
|
| 409 |
+
f"decoder.dec_score_head.{i}.bias",
|
| 410 |
+
f"model.decoder.class_embed.{i}.bias",
|
| 411 |
+
)
|
| 412 |
+
)
|
| 413 |
+
rename_keys.append(
|
| 414 |
+
(
|
| 415 |
+
f"decoder.dec_bbox_head.{i}.layers.0.weight",
|
| 416 |
+
f"model.decoder.bbox_embed.{i}.layers.0.weight",
|
| 417 |
+
)
|
| 418 |
+
)
|
| 419 |
+
rename_keys.append(
|
| 420 |
+
(
|
| 421 |
+
f"decoder.dec_bbox_head.{i}.layers.0.bias",
|
| 422 |
+
f"model.decoder.bbox_embed.{i}.layers.0.bias",
|
| 423 |
+
)
|
| 424 |
+
)
|
| 425 |
+
rename_keys.append(
|
| 426 |
+
(
|
| 427 |
+
f"decoder.dec_bbox_head.{i}.layers.1.weight",
|
| 428 |
+
f"model.decoder.bbox_embed.{i}.layers.1.weight",
|
| 429 |
+
)
|
| 430 |
+
)
|
| 431 |
+
rename_keys.append(
|
| 432 |
+
(
|
| 433 |
+
f"decoder.dec_bbox_head.{i}.layers.1.bias",
|
| 434 |
+
f"model.decoder.bbox_embed.{i}.layers.1.bias",
|
| 435 |
+
)
|
| 436 |
+
)
|
| 437 |
+
rename_keys.append(
|
| 438 |
+
(
|
| 439 |
+
f"decoder.dec_bbox_head.{i}.layers.2.weight",
|
| 440 |
+
f"model.decoder.bbox_embed.{i}.layers.2.weight",
|
| 441 |
+
)
|
| 442 |
+
)
|
| 443 |
+
rename_keys.append(
|
| 444 |
+
(
|
| 445 |
+
f"decoder.dec_bbox_head.{i}.layers.2.bias",
|
| 446 |
+
f"model.decoder.bbox_embed.{i}.layers.2.bias",
|
| 447 |
+
)
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# decoder projection
|
| 451 |
+
for i in range(len(config.decoder_in_channels)):
|
| 452 |
+
rename_keys.append(
|
| 453 |
+
(
|
| 454 |
+
f"decoder.input_proj.{i}.conv.weight",
|
| 455 |
+
f"model.decoder_input_proj.{i}.0.weight",
|
| 456 |
+
)
|
| 457 |
+
)
|
| 458 |
+
for last in last_key:
|
| 459 |
+
rename_keys.append(
|
| 460 |
+
(
|
| 461 |
+
f"decoder.input_proj.{i}.norm.{last}",
|
| 462 |
+
f"model.decoder_input_proj.{i}.1.{last}",
|
| 463 |
+
)
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
| 467 |
+
rename_keys.extend(
|
| 468 |
+
[
|
| 469 |
+
("decoder.denoising_class_embed.weight", "model.denoising_class_embed.weight"),
|
| 470 |
+
("decoder.query_pos_head.layers.0.weight", "model.decoder.query_pos_head.layers.0.weight"),
|
| 471 |
+
("decoder.query_pos_head.layers.0.bias", "model.decoder.query_pos_head.layers.0.bias"),
|
| 472 |
+
("decoder.query_pos_head.layers.1.weight", "model.decoder.query_pos_head.layers.1.weight"),
|
| 473 |
+
("decoder.query_pos_head.layers.1.bias", "model.decoder.query_pos_head.layers.1.bias"),
|
| 474 |
+
("decoder.enc_output.0.weight", "model.enc_output.0.weight"),
|
| 475 |
+
("decoder.enc_output.0.bias", "model.enc_output.0.bias"),
|
| 476 |
+
("decoder.enc_output.1.weight", "model.enc_output.1.weight"),
|
| 477 |
+
("decoder.enc_output.1.bias", "model.enc_output.1.bias"),
|
| 478 |
+
("decoder.enc_score_head.weight", "model.enc_score_head.weight"),
|
| 479 |
+
("decoder.enc_score_head.bias", "model.enc_score_head.bias"),
|
| 480 |
+
("decoder.enc_bbox_head.layers.0.weight", "model.enc_bbox_head.layers.0.weight"),
|
| 481 |
+
("decoder.enc_bbox_head.layers.0.bias", "model.enc_bbox_head.layers.0.bias"),
|
| 482 |
+
("decoder.enc_bbox_head.layers.1.weight", "model.enc_bbox_head.layers.1.weight"),
|
| 483 |
+
("decoder.enc_bbox_head.layers.1.bias", "model.enc_bbox_head.layers.1.bias"),
|
| 484 |
+
("decoder.enc_bbox_head.layers.2.weight", "model.enc_bbox_head.layers.2.weight"),
|
| 485 |
+
("decoder.enc_bbox_head.layers.2.bias", "model.enc_bbox_head.layers.2.bias"),
|
| 486 |
+
]
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
return rename_keys
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def rename_key(state_dict, old, new):
|
| 493 |
+
try:
|
| 494 |
+
val = state_dict.pop(old)
|
| 495 |
+
state_dict[new] = val
|
| 496 |
+
except Exception:
|
| 497 |
+
pass
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def read_in_q_k_v(state_dict, config):
|
| 501 |
+
prefix = ""
|
| 502 |
+
encoder_hidden_dim = config.encoder_hidden_dim
|
| 503 |
+
|
| 504 |
+
# first: transformer encoder
|
| 505 |
+
for i in range(config.encoder_layers):
|
| 506 |
+
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
|
| 507 |
+
in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight")
|
| 508 |
+
in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias")
|
| 509 |
+
# next, add query, keys and values (in that order) to the state dict
|
| 510 |
+
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[
|
| 511 |
+
:encoder_hidden_dim, :
|
| 512 |
+
]
|
| 513 |
+
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim]
|
| 514 |
+
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[
|
| 515 |
+
encoder_hidden_dim : 2 * encoder_hidden_dim, :
|
| 516 |
+
]
|
| 517 |
+
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[
|
| 518 |
+
encoder_hidden_dim : 2 * encoder_hidden_dim
|
| 519 |
+
]
|
| 520 |
+
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[
|
| 521 |
+
-encoder_hidden_dim:, :
|
| 522 |
+
]
|
| 523 |
+
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:]
|
| 524 |
+
# next: transformer decoder (which is a bit more complex because it also includes cross-attention)
|
| 525 |
+
for i in range(config.decoder_layers):
|
| 526 |
+
# read in weights + bias of input projection layer of self-attention
|
| 527 |
+
in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight")
|
| 528 |
+
in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias")
|
| 529 |
+
# next, add query, keys and values (in that order) to the state dict
|
| 530 |
+
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
|
| 531 |
+
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
|
| 532 |
+
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
|
| 533 |
+
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
|
| 534 |
+
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
|
| 535 |
+
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
# We will verify our results on an image of cute cats
|
| 539 |
+
def prepare_img():
|
| 540 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 541 |
+
im = Image.open(requests.get(url, stream=True).raw)
|
| 542 |
+
|
| 543 |
+
return im
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@torch.no_grad()
|
| 547 |
+
def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id):
|
| 548 |
+
"""
|
| 549 |
+
Copy/paste/tweak model's weights to our RTDETR structure.
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
# load default config
|
| 553 |
+
config = get_rt_detr_config(model_name)
|
| 554 |
+
|
| 555 |
+
# load original model from torch hub
|
| 556 |
+
model_name_to_checkpoint_url = {
|
| 557 |
+
"rtdetr_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth",
|
| 558 |
+
"rtdetr_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth",
|
| 559 |
+
"rtdetr_r50vd_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth",
|
| 560 |
+
"rtdetr_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth",
|
| 561 |
+
"rtdetr_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth",
|
| 562 |
+
"rtdetr_r18vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth",
|
| 563 |
+
"rtdetr_r50vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth",
|
| 564 |
+
"rtdetr_r101vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_2x_coco_objects365_from_paddle.pth",
|
| 565 |
+
}
|
| 566 |
+
logger.info(f"Converting model {model_name}...")
|
| 567 |
+
state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[
|
| 568 |
+
"ema"
|
| 569 |
+
]["module"]
|
| 570 |
+
|
| 571 |
+
# rename keys
|
| 572 |
+
for src, dest in create_rename_keys(config):
|
| 573 |
+
rename_key(state_dict, src, dest)
|
| 574 |
+
# query, key and value matrices need special treatment
|
| 575 |
+
read_in_q_k_v(state_dict, config)
|
| 576 |
+
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
| 577 |
+
for key in state_dict.copy().keys():
|
| 578 |
+
if key.endswith("num_batches_tracked"):
|
| 579 |
+
del state_dict[key]
|
| 580 |
+
# for two_stage
|
| 581 |
+
if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key):
|
| 582 |
+
state_dict[key.split("model.decoder.")[-1]] = state_dict[key]
|
| 583 |
+
|
| 584 |
+
# finally, create HuggingFace model and load state dict
|
| 585 |
+
model = RTDetrForObjectDetection(config)
|
| 586 |
+
model.load_state_dict(state_dict)
|
| 587 |
+
model.eval()
|
| 588 |
+
|
| 589 |
+
# load image processor
|
| 590 |
+
image_processor = RTDetrImageProcessor()
|
| 591 |
+
|
| 592 |
+
# prepare image
|
| 593 |
+
img = prepare_img()
|
| 594 |
+
|
| 595 |
+
# preprocess image
|
| 596 |
+
transformations = transforms.Compose(
|
| 597 |
+
[
|
| 598 |
+
transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR),
|
| 599 |
+
transforms.ToTensor(),
|
| 600 |
+
]
|
| 601 |
+
)
|
| 602 |
+
original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension
|
| 603 |
+
|
| 604 |
+
encoding = image_processor(images=img, return_tensors="pt")
|
| 605 |
+
pixel_values = encoding["pixel_values"]
|
| 606 |
+
|
| 607 |
+
assert torch.allclose(original_pixel_values, pixel_values)
|
| 608 |
+
|
| 609 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 610 |
+
model.to(device)
|
| 611 |
+
pixel_values = pixel_values.to(device)
|
| 612 |
+
|
| 613 |
+
# Pass image by the model
|
| 614 |
+
outputs = model(pixel_values)
|
| 615 |
+
|
| 616 |
+
if model_name == "rtdetr_r18vd":
|
| 617 |
+
expected_slice_logits = torch.tensor(
|
| 618 |
+
[
|
| 619 |
+
[-4.3364253, -6.465683, -3.6130402],
|
| 620 |
+
[-4.083815, -6.4039373, -6.97881],
|
| 621 |
+
[-4.192215, -7.3410473, -6.9027247],
|
| 622 |
+
]
|
| 623 |
+
)
|
| 624 |
+
expected_slice_boxes = torch.tensor(
|
| 625 |
+
[
|
| 626 |
+
[0.16868353, 0.19833282, 0.21182671],
|
| 627 |
+
[0.25559652, 0.55121744, 0.47988364],
|
| 628 |
+
[0.7698693, 0.4124569, 0.46036878],
|
| 629 |
+
]
|
| 630 |
+
)
|
| 631 |
+
elif model_name == "rtdetr_r34vd":
|
| 632 |
+
expected_slice_logits = torch.tensor(
|
| 633 |
+
[
|
| 634 |
+
[-4.3727384, -4.7921476, -5.7299604],
|
| 635 |
+
[-4.840536, -8.455345, -4.1745796],
|
| 636 |
+
[-4.1277084, -5.2154565, -5.7852697],
|
| 637 |
+
]
|
| 638 |
+
)
|
| 639 |
+
expected_slice_boxes = torch.tensor(
|
| 640 |
+
[
|
| 641 |
+
[0.258278, 0.5497808, 0.4732004],
|
| 642 |
+
[0.16889669, 0.19890057, 0.21138911],
|
| 643 |
+
[0.76632994, 0.4147879, 0.46851268],
|
| 644 |
+
]
|
| 645 |
+
)
|
| 646 |
+
elif model_name == "rtdetr_r50vd_m":
|
| 647 |
+
expected_slice_logits = torch.tensor(
|
| 648 |
+
[
|
| 649 |
+
[-4.319764, -6.1349025, -6.094794],
|
| 650 |
+
[-5.1056995, -7.744766, -4.803956],
|
| 651 |
+
[-4.7685347, -7.9278393, -4.5751696],
|
| 652 |
+
]
|
| 653 |
+
)
|
| 654 |
+
expected_slice_boxes = torch.tensor(
|
| 655 |
+
[
|
| 656 |
+
[0.2582739, 0.55071366, 0.47660282],
|
| 657 |
+
[0.16811174, 0.19954777, 0.21292639],
|
| 658 |
+
[0.54986024, 0.2752091, 0.0561416],
|
| 659 |
+
]
|
| 660 |
+
)
|
| 661 |
+
elif model_name == "rtdetr_r50vd":
|
| 662 |
+
expected_slice_logits = torch.tensor(
|
| 663 |
+
[
|
| 664 |
+
[-4.6476398, -5.001154, -4.9785104],
|
| 665 |
+
[-4.1593494, -4.7038546, -5.946485],
|
| 666 |
+
[-4.4374595, -4.658361, -6.2352347],
|
| 667 |
+
]
|
| 668 |
+
)
|
| 669 |
+
expected_slice_boxes = torch.tensor(
|
| 670 |
+
[
|
| 671 |
+
[0.16880608, 0.19992264, 0.21225442],
|
| 672 |
+
[0.76837635, 0.4122631, 0.46368608],
|
| 673 |
+
[0.2595386, 0.5483334, 0.4777486],
|
| 674 |
+
]
|
| 675 |
+
)
|
| 676 |
+
elif model_name == "rtdetr_r101vd":
|
| 677 |
+
expected_slice_logits = torch.tensor(
|
| 678 |
+
[
|
| 679 |
+
[-4.6162, -4.9189, -4.6656],
|
| 680 |
+
[-4.4701, -4.4997, -4.9659],
|
| 681 |
+
[-5.6641, -7.9000, -5.0725],
|
| 682 |
+
]
|
| 683 |
+
)
|
| 684 |
+
expected_slice_boxes = torch.tensor(
|
| 685 |
+
[
|
| 686 |
+
[0.7707, 0.4124, 0.4585],
|
| 687 |
+
[0.2589, 0.5492, 0.4735],
|
| 688 |
+
[0.1688, 0.1993, 0.2108],
|
| 689 |
+
]
|
| 690 |
+
)
|
| 691 |
+
elif model_name == "rtdetr_r18vd_coco_o365":
|
| 692 |
+
expected_slice_logits = torch.tensor(
|
| 693 |
+
[
|
| 694 |
+
[-4.8726, -5.9066, -5.2450],
|
| 695 |
+
[-4.8157, -6.8764, -5.1656],
|
| 696 |
+
[-4.7492, -5.7006, -5.1333],
|
| 697 |
+
]
|
| 698 |
+
)
|
| 699 |
+
expected_slice_boxes = torch.tensor(
|
| 700 |
+
[
|
| 701 |
+
[0.2552, 0.5501, 0.4773],
|
| 702 |
+
[0.1685, 0.1986, 0.2104],
|
| 703 |
+
[0.7692, 0.4141, 0.4620],
|
| 704 |
+
]
|
| 705 |
+
)
|
| 706 |
+
elif model_name == "rtdetr_r50vd_coco_o365":
|
| 707 |
+
expected_slice_logits = torch.tensor(
|
| 708 |
+
[
|
| 709 |
+
[-4.6491, -3.9252, -5.3163],
|
| 710 |
+
[-4.1386, -5.0348, -3.9016],
|
| 711 |
+
[-4.4778, -4.5423, -5.7356],
|
| 712 |
+
]
|
| 713 |
+
)
|
| 714 |
+
expected_slice_boxes = torch.tensor(
|
| 715 |
+
[
|
| 716 |
+
[0.2583, 0.5492, 0.4747],
|
| 717 |
+
[0.5501, 0.2754, 0.0574],
|
| 718 |
+
[0.7693, 0.4137, 0.4613],
|
| 719 |
+
]
|
| 720 |
+
)
|
| 721 |
+
elif model_name == "rtdetr_r101vd_coco_o365":
|
| 722 |
+
expected_slice_logits = torch.tensor(
|
| 723 |
+
[
|
| 724 |
+
[-4.5152, -5.6811, -5.7311],
|
| 725 |
+
[-4.5358, -7.2422, -5.0941],
|
| 726 |
+
[-4.6919, -5.5834, -6.0145],
|
| 727 |
+
]
|
| 728 |
+
)
|
| 729 |
+
expected_slice_boxes = torch.tensor(
|
| 730 |
+
[
|
| 731 |
+
[0.7703, 0.4140, 0.4583],
|
| 732 |
+
[0.1686, 0.1991, 0.2107],
|
| 733 |
+
[0.2570, 0.5496, 0.4750],
|
| 734 |
+
]
|
| 735 |
+
)
|
| 736 |
+
else:
|
| 737 |
+
raise ValueError(f"Unknown rt_detr_name: {model_name}")
|
| 738 |
+
|
| 739 |
+
assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4)
|
| 740 |
+
assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3)
|
| 741 |
+
|
| 742 |
+
if pytorch_dump_folder_path is not None:
|
| 743 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 744 |
+
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
| 745 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 746 |
+
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
| 747 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 748 |
+
|
| 749 |
+
if push_to_hub:
|
| 750 |
+
# Upload model, image processor and config to the hub
|
| 751 |
+
logger.info("Uploading PyTorch model and image processor to the hub...")
|
| 752 |
+
config.push_to_hub(
|
| 753 |
+
repo_id=repo_id, commit_message="Add config from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py"
|
| 754 |
+
)
|
| 755 |
+
model.push_to_hub(
|
| 756 |
+
repo_id=repo_id, commit_message="Add model from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py"
|
| 757 |
+
)
|
| 758 |
+
image_processor.push_to_hub(
|
| 759 |
+
repo_id=repo_id,
|
| 760 |
+
commit_message="Add image processor from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py",
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
if __name__ == "__main__":
|
| 765 |
+
parser = argparse.ArgumentParser()
|
| 766 |
+
parser.add_argument(
|
| 767 |
+
"--model_name",
|
| 768 |
+
default="rtdetr_r50vd",
|
| 769 |
+
type=str,
|
| 770 |
+
help="model_name of the checkpoint you'd like to convert.",
|
| 771 |
+
)
|
| 772 |
+
parser.add_argument(
|
| 773 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 774 |
+
)
|
| 775 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
|
| 776 |
+
parser.add_argument(
|
| 777 |
+
"--repo_id",
|
| 778 |
+
type=str,
|
| 779 |
+
help="repo_id where the model will be pushed to.",
|
| 780 |
+
)
|
| 781 |
+
args = parser.parse_args()
|
| 782 |
+
convert_rt_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id)
|
docs/transformers/build/lib/transformers/models/rt_detr/image_processing_rt_detr.py
ADDED
|
@@ -0,0 +1,1102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for RT-DETR."""
|
| 16 |
+
|
| 17 |
+
import pathlib
|
| 18 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from ...feature_extraction_utils import BatchFeature
|
| 23 |
+
from ...image_processing_utils import BaseImageProcessor, get_size_dict
|
| 24 |
+
from ...image_transforms import (
|
| 25 |
+
PaddingMode,
|
| 26 |
+
center_to_corners_format,
|
| 27 |
+
corners_to_center_format,
|
| 28 |
+
pad,
|
| 29 |
+
rescale,
|
| 30 |
+
resize,
|
| 31 |
+
to_channel_dimension_format,
|
| 32 |
+
)
|
| 33 |
+
from ...image_utils import (
|
| 34 |
+
IMAGENET_DEFAULT_MEAN,
|
| 35 |
+
IMAGENET_DEFAULT_STD,
|
| 36 |
+
AnnotationFormat,
|
| 37 |
+
AnnotationType,
|
| 38 |
+
ChannelDimension,
|
| 39 |
+
ImageInput,
|
| 40 |
+
PILImageResampling,
|
| 41 |
+
get_image_size,
|
| 42 |
+
infer_channel_dimension_format,
|
| 43 |
+
is_scaled_image,
|
| 44 |
+
make_list_of_images,
|
| 45 |
+
to_numpy_array,
|
| 46 |
+
valid_images,
|
| 47 |
+
validate_annotations,
|
| 48 |
+
validate_preprocess_arguments,
|
| 49 |
+
)
|
| 50 |
+
from ...utils import (
|
| 51 |
+
filter_out_non_signature_kwargs,
|
| 52 |
+
is_flax_available,
|
| 53 |
+
is_jax_tensor,
|
| 54 |
+
is_tf_available,
|
| 55 |
+
is_tf_tensor,
|
| 56 |
+
is_torch_available,
|
| 57 |
+
is_torch_tensor,
|
| 58 |
+
logging,
|
| 59 |
+
requires_backends,
|
| 60 |
+
)
|
| 61 |
+
from ...utils.generic import TensorType
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if is_torch_available():
|
| 65 |
+
import torch
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 69 |
+
|
| 70 |
+
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
|
| 74 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
|
| 75 |
+
"""
|
| 76 |
+
Computes the output image size given the input image size and the desired output size.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
image_size (`Tuple[int, int]`):
|
| 80 |
+
The input image size.
|
| 81 |
+
size (`int`):
|
| 82 |
+
The desired output size.
|
| 83 |
+
max_size (`int`, *optional*):
|
| 84 |
+
The maximum allowed output size.
|
| 85 |
+
"""
|
| 86 |
+
height, width = image_size
|
| 87 |
+
raw_size = None
|
| 88 |
+
if max_size is not None:
|
| 89 |
+
min_original_size = float(min((height, width)))
|
| 90 |
+
max_original_size = float(max((height, width)))
|
| 91 |
+
if max_original_size / min_original_size * size > max_size:
|
| 92 |
+
raw_size = max_size * min_original_size / max_original_size
|
| 93 |
+
size = int(round(raw_size))
|
| 94 |
+
|
| 95 |
+
if (height <= width and height == size) or (width <= height and width == size):
|
| 96 |
+
oh, ow = height, width
|
| 97 |
+
elif width < height:
|
| 98 |
+
ow = size
|
| 99 |
+
if max_size is not None and raw_size is not None:
|
| 100 |
+
oh = int(raw_size * height / width)
|
| 101 |
+
else:
|
| 102 |
+
oh = int(size * height / width)
|
| 103 |
+
else:
|
| 104 |
+
oh = size
|
| 105 |
+
if max_size is not None and raw_size is not None:
|
| 106 |
+
ow = int(raw_size * width / height)
|
| 107 |
+
else:
|
| 108 |
+
ow = int(size * width / height)
|
| 109 |
+
|
| 110 |
+
return (oh, ow)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
|
| 114 |
+
def get_resize_output_image_size(
|
| 115 |
+
input_image: np.ndarray,
|
| 116 |
+
size: Union[int, Tuple[int, int], List[int]],
|
| 117 |
+
max_size: Optional[int] = None,
|
| 118 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 119 |
+
) -> Tuple[int, int]:
|
| 120 |
+
"""
|
| 121 |
+
Computes the output image size given the input image size and the desired output size. If the desired output size
|
| 122 |
+
is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
|
| 123 |
+
image size is computed by keeping the aspect ratio of the input image size.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
input_image (`np.ndarray`):
|
| 127 |
+
The image to resize.
|
| 128 |
+
size (`int` or `Tuple[int, int]` or `List[int]`):
|
| 129 |
+
The desired output size.
|
| 130 |
+
max_size (`int`, *optional*):
|
| 131 |
+
The maximum allowed output size.
|
| 132 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 133 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
| 134 |
+
"""
|
| 135 |
+
image_size = get_image_size(input_image, input_data_format)
|
| 136 |
+
if isinstance(size, (list, tuple)):
|
| 137 |
+
return size
|
| 138 |
+
|
| 139 |
+
return get_size_with_aspect_ratio(image_size, size, max_size)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
|
| 143 |
+
def get_image_size_for_max_height_width(
|
| 144 |
+
input_image: np.ndarray,
|
| 145 |
+
max_height: int,
|
| 146 |
+
max_width: int,
|
| 147 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 148 |
+
) -> Tuple[int, int]:
|
| 149 |
+
"""
|
| 150 |
+
Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
|
| 151 |
+
Important, even if image_height < max_height and image_width < max_width, the image will be resized
|
| 152 |
+
to at least one of the edges be equal to max_height or max_width.
|
| 153 |
+
For example:
|
| 154 |
+
- input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
|
| 155 |
+
- input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
|
| 156 |
+
Args:
|
| 157 |
+
input_image (`np.ndarray`):
|
| 158 |
+
The image to resize.
|
| 159 |
+
max_height (`int`):
|
| 160 |
+
The maximum allowed height.
|
| 161 |
+
max_width (`int`):
|
| 162 |
+
The maximum allowed width.
|
| 163 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 164 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
| 165 |
+
"""
|
| 166 |
+
image_size = get_image_size(input_image, input_data_format)
|
| 167 |
+
height, width = image_size
|
| 168 |
+
height_scale = max_height / height
|
| 169 |
+
width_scale = max_width / width
|
| 170 |
+
min_scale = min(height_scale, width_scale)
|
| 171 |
+
new_height = int(height * min_scale)
|
| 172 |
+
new_width = int(width * min_scale)
|
| 173 |
+
return new_height, new_width
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
|
| 177 |
+
def get_numpy_to_framework_fn(arr) -> Callable:
|
| 178 |
+
"""
|
| 179 |
+
Returns a function that converts a numpy array to the framework of the input array.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
arr (`np.ndarray`): The array to convert.
|
| 183 |
+
"""
|
| 184 |
+
if isinstance(arr, np.ndarray):
|
| 185 |
+
return np.array
|
| 186 |
+
if is_tf_available() and is_tf_tensor(arr):
|
| 187 |
+
import tensorflow as tf
|
| 188 |
+
|
| 189 |
+
return tf.convert_to_tensor
|
| 190 |
+
if is_torch_available() and is_torch_tensor(arr):
|
| 191 |
+
import torch
|
| 192 |
+
|
| 193 |
+
return torch.tensor
|
| 194 |
+
if is_flax_available() and is_jax_tensor(arr):
|
| 195 |
+
import jax.numpy as jnp
|
| 196 |
+
|
| 197 |
+
return jnp.array
|
| 198 |
+
raise ValueError(f"Cannot convert arrays of type {type(arr)}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# Copied from transformers.models.detr.image_processing_detr.safe_squeeze
|
| 202 |
+
def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
|
| 203 |
+
"""
|
| 204 |
+
Squeezes an array, but only if the axis specified has dim 1.
|
| 205 |
+
"""
|
| 206 |
+
if axis is None:
|
| 207 |
+
return arr.squeeze()
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
return arr.squeeze(axis=axis)
|
| 211 |
+
except ValueError:
|
| 212 |
+
return arr
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Copied from transformers.models.detr.image_processing_detr.normalize_annotation
|
| 216 |
+
def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
|
| 217 |
+
image_height, image_width = image_size
|
| 218 |
+
norm_annotation = {}
|
| 219 |
+
for key, value in annotation.items():
|
| 220 |
+
if key == "boxes":
|
| 221 |
+
boxes = value
|
| 222 |
+
boxes = corners_to_center_format(boxes)
|
| 223 |
+
boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
|
| 224 |
+
norm_annotation[key] = boxes
|
| 225 |
+
else:
|
| 226 |
+
norm_annotation[key] = value
|
| 227 |
+
return norm_annotation
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# Copied from transformers.models.detr.image_processing_detr.max_across_indices
|
| 231 |
+
def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
| 232 |
+
"""
|
| 233 |
+
Return the maximum value across all indices of an iterable of values.
|
| 234 |
+
"""
|
| 235 |
+
return [max(values_i) for values_i in zip(*values)]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
| 239 |
+
def get_max_height_width(
|
| 240 |
+
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
"""
|
| 243 |
+
Get the maximum height and width across all images in a batch.
|
| 244 |
+
"""
|
| 245 |
+
if input_data_format is None:
|
| 246 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 247 |
+
|
| 248 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 249 |
+
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
| 250 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 251 |
+
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
| 252 |
+
else:
|
| 253 |
+
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
| 254 |
+
return (max_height, max_width)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
| 258 |
+
def make_pixel_mask(
|
| 259 |
+
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
| 260 |
+
) -> np.ndarray:
|
| 261 |
+
"""
|
| 262 |
+
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
image (`np.ndarray`):
|
| 266 |
+
Image to make the pixel mask for.
|
| 267 |
+
output_size (`Tuple[int, int]`):
|
| 268 |
+
Output size of the mask.
|
| 269 |
+
"""
|
| 270 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 271 |
+
mask = np.zeros(output_size, dtype=np.int64)
|
| 272 |
+
mask[:input_height, :input_width] = 1
|
| 273 |
+
return mask
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def prepare_coco_detection_annotation(
|
| 277 |
+
image,
|
| 278 |
+
target,
|
| 279 |
+
return_segmentation_masks: bool = False,
|
| 280 |
+
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
Convert the target in COCO format into the format expected by RTDETR.
|
| 284 |
+
"""
|
| 285 |
+
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
| 286 |
+
|
| 287 |
+
image_id = target["image_id"]
|
| 288 |
+
image_id = np.asarray([image_id], dtype=np.int64)
|
| 289 |
+
|
| 290 |
+
# Get all COCO annotations for the given image.
|
| 291 |
+
annotations = target["annotations"]
|
| 292 |
+
annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
|
| 293 |
+
|
| 294 |
+
classes = [obj["category_id"] for obj in annotations]
|
| 295 |
+
classes = np.asarray(classes, dtype=np.int64)
|
| 296 |
+
|
| 297 |
+
# for conversion to coco api
|
| 298 |
+
area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
|
| 299 |
+
iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64)
|
| 300 |
+
|
| 301 |
+
boxes = [obj["bbox"] for obj in annotations]
|
| 302 |
+
# guard against no boxes via resizing
|
| 303 |
+
boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
|
| 304 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 305 |
+
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
|
| 306 |
+
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
|
| 307 |
+
|
| 308 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
| 309 |
+
|
| 310 |
+
new_target = {}
|
| 311 |
+
new_target["image_id"] = image_id
|
| 312 |
+
new_target["class_labels"] = classes[keep]
|
| 313 |
+
new_target["boxes"] = boxes[keep]
|
| 314 |
+
new_target["area"] = area[keep]
|
| 315 |
+
new_target["iscrowd"] = iscrowd[keep]
|
| 316 |
+
new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
|
| 317 |
+
|
| 318 |
+
if annotations and "keypoints" in annotations[0]:
|
| 319 |
+
keypoints = [obj["keypoints"] for obj in annotations]
|
| 320 |
+
# Converting the filtered keypoints list to a numpy array
|
| 321 |
+
keypoints = np.asarray(keypoints, dtype=np.float32)
|
| 322 |
+
# Apply the keep mask here to filter the relevant annotations
|
| 323 |
+
keypoints = keypoints[keep]
|
| 324 |
+
num_keypoints = keypoints.shape[0]
|
| 325 |
+
keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
|
| 326 |
+
new_target["keypoints"] = keypoints
|
| 327 |
+
|
| 328 |
+
return new_target
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# Copied from transformers.models.detr.image_processing_detr.resize_annotation
|
| 332 |
+
def resize_annotation(
|
| 333 |
+
annotation: Dict[str, Any],
|
| 334 |
+
orig_size: Tuple[int, int],
|
| 335 |
+
target_size: Tuple[int, int],
|
| 336 |
+
threshold: float = 0.5,
|
| 337 |
+
resample: PILImageResampling = PILImageResampling.NEAREST,
|
| 338 |
+
):
|
| 339 |
+
"""
|
| 340 |
+
Resizes an annotation to a target size.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
annotation (`Dict[str, Any]`):
|
| 344 |
+
The annotation dictionary.
|
| 345 |
+
orig_size (`Tuple[int, int]`):
|
| 346 |
+
The original size of the input image.
|
| 347 |
+
target_size (`Tuple[int, int]`):
|
| 348 |
+
The target size of the image, as returned by the preprocessing `resize` step.
|
| 349 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 350 |
+
The threshold used to binarize the segmentation masks.
|
| 351 |
+
resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
|
| 352 |
+
The resampling filter to use when resizing the masks.
|
| 353 |
+
"""
|
| 354 |
+
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
|
| 355 |
+
ratio_height, ratio_width = ratios
|
| 356 |
+
|
| 357 |
+
new_annotation = {}
|
| 358 |
+
new_annotation["size"] = target_size
|
| 359 |
+
|
| 360 |
+
for key, value in annotation.items():
|
| 361 |
+
if key == "boxes":
|
| 362 |
+
boxes = value
|
| 363 |
+
scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
|
| 364 |
+
new_annotation["boxes"] = scaled_boxes
|
| 365 |
+
elif key == "area":
|
| 366 |
+
area = value
|
| 367 |
+
scaled_area = area * (ratio_width * ratio_height)
|
| 368 |
+
new_annotation["area"] = scaled_area
|
| 369 |
+
elif key == "masks":
|
| 370 |
+
masks = value[:, None]
|
| 371 |
+
masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
|
| 372 |
+
masks = masks.astype(np.float32)
|
| 373 |
+
masks = masks[:, 0] > threshold
|
| 374 |
+
new_annotation["masks"] = masks
|
| 375 |
+
elif key == "size":
|
| 376 |
+
new_annotation["size"] = target_size
|
| 377 |
+
else:
|
| 378 |
+
new_annotation[key] = value
|
| 379 |
+
|
| 380 |
+
return new_annotation
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class RTDetrImageProcessor(BaseImageProcessor):
|
| 384 |
+
r"""
|
| 385 |
+
Constructs a RT-DETR image processor.
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
| 389 |
+
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
| 390 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 391 |
+
Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
|
| 392 |
+
overridden by the `do_resize` parameter in the `preprocess` method.
|
| 393 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 640, "width": 640}`):
|
| 394 |
+
Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
|
| 395 |
+
in the `preprocess` method. Available options are:
|
| 396 |
+
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
|
| 397 |
+
Do NOT keep the aspect ratio.
|
| 398 |
+
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
|
| 399 |
+
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
|
| 400 |
+
less or equal to `longest_edge`.
|
| 401 |
+
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
|
| 402 |
+
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
|
| 403 |
+
`max_width`.
|
| 404 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 405 |
+
Resampling filter to use if resizing the image.
|
| 406 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 407 |
+
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
| 408 |
+
`do_rescale` parameter in the `preprocess` method.
|
| 409 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 410 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 411 |
+
`preprocess` method.
|
| 412 |
+
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
|
| 413 |
+
`preprocess` method.
|
| 414 |
+
do_normalize (`bool`, *optional*, defaults to `False`):
|
| 415 |
+
Whether to normalize the image.
|
| 416 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 417 |
+
Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
|
| 418 |
+
channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 419 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 420 |
+
Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
|
| 421 |
+
for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 422 |
+
do_convert_annotations (`bool`, *optional*, defaults to `True`):
|
| 423 |
+
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
|
| 424 |
+
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
|
| 425 |
+
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
|
| 426 |
+
do_pad (`bool`, *optional*, defaults to `False`):
|
| 427 |
+
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
|
| 428 |
+
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
|
| 429 |
+
If `pad_size` is provided, the image will be padded to the specified dimensions.
|
| 430 |
+
Otherwise, the image will be padded to the maximum height and width of the batch.
|
| 431 |
+
pad_size (`Dict[str, int]`, *optional*):
|
| 432 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 433 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 434 |
+
height and width in the batch.
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
model_input_names = ["pixel_values", "pixel_mask"]
|
| 438 |
+
|
| 439 |
+
def __init__(
|
| 440 |
+
self,
|
| 441 |
+
format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
|
| 442 |
+
do_resize: bool = True,
|
| 443 |
+
size: Dict[str, int] = None,
|
| 444 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 445 |
+
do_rescale: bool = True,
|
| 446 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 447 |
+
do_normalize: bool = False,
|
| 448 |
+
image_mean: Union[float, List[float]] = None,
|
| 449 |
+
image_std: Union[float, List[float]] = None,
|
| 450 |
+
do_convert_annotations: bool = True,
|
| 451 |
+
do_pad: bool = False,
|
| 452 |
+
pad_size: Optional[Dict[str, int]] = None,
|
| 453 |
+
**kwargs,
|
| 454 |
+
) -> None:
|
| 455 |
+
size = size if size is not None else {"height": 640, "width": 640}
|
| 456 |
+
size = get_size_dict(size, default_to_square=False)
|
| 457 |
+
|
| 458 |
+
if do_convert_annotations is None:
|
| 459 |
+
do_convert_annotations = do_normalize
|
| 460 |
+
|
| 461 |
+
super().__init__(**kwargs)
|
| 462 |
+
self.format = format
|
| 463 |
+
self.do_resize = do_resize
|
| 464 |
+
self.size = size
|
| 465 |
+
self.resample = resample
|
| 466 |
+
self.do_rescale = do_rescale
|
| 467 |
+
self.rescale_factor = rescale_factor
|
| 468 |
+
self.do_normalize = do_normalize
|
| 469 |
+
self.do_convert_annotations = do_convert_annotations
|
| 470 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 471 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 472 |
+
self.do_pad = do_pad
|
| 473 |
+
self.pad_size = pad_size
|
| 474 |
+
|
| 475 |
+
def prepare_annotation(
|
| 476 |
+
self,
|
| 477 |
+
image: np.ndarray,
|
| 478 |
+
target: Dict,
|
| 479 |
+
format: Optional[AnnotationFormat] = None,
|
| 480 |
+
return_segmentation_masks: Optional[bool] = None,
|
| 481 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 482 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 483 |
+
) -> Dict:
|
| 484 |
+
"""
|
| 485 |
+
Prepare an annotation for feeding into RTDETR model.
|
| 486 |
+
"""
|
| 487 |
+
format = format if format is not None else self.format
|
| 488 |
+
|
| 489 |
+
if format == AnnotationFormat.COCO_DETECTION:
|
| 490 |
+
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
| 491 |
+
target = prepare_coco_detection_annotation(
|
| 492 |
+
image, target, return_segmentation_masks, input_data_format=input_data_format
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
raise ValueError(f"Format {format} is not supported.")
|
| 496 |
+
return target
|
| 497 |
+
|
| 498 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
|
| 499 |
+
def resize(
|
| 500 |
+
self,
|
| 501 |
+
image: np.ndarray,
|
| 502 |
+
size: Dict[str, int],
|
| 503 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 504 |
+
data_format: Optional[ChannelDimension] = None,
|
| 505 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 506 |
+
**kwargs,
|
| 507 |
+
) -> np.ndarray:
|
| 508 |
+
"""
|
| 509 |
+
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                        Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                        the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                        less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                        aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                        `max_width`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use if resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` parameter is deprecated and will be removed in v4.26. "
                "Please specify in `size['longest_edge'] instead`.",
            )
            max_size = kwargs.pop("max_size")
        else:
            max_size = None
        size = get_size_dict(size, max_size=max_size, default_to_square=False)
        if "shortest_edge" in size and "longest_edge" in size:
            new_size = get_resize_output_image_size(
                image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
            )
        elif "max_height" in size and "max_width" in size:
            new_size = get_image_size_for_max_height_width(
                image, size["max_height"], size["max_width"], input_data_format=input_data_format
            )
        elif "height" in size and "width" in size:
            new_size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )
        image = resize(
            image,
            size=new_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return image

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
    def resize_annotation(
        self,
        annotation,
        orig_size,
        size,
        resample: PILImageResampling = PILImageResampling.NEAREST,
    ) -> Dict:
        """
        Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
        to this number.
        """
        return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
    def rescale(
        self,
        image: np.ndarray,
        rescale_factor: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Rescale the image by the given factor. image = image * rescale_factor.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            rescale_factor (`float`):
                The value to use for rescaling.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. If unset, is inferred from the input image. Can be
                one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
    def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
        """
        Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
        `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
        """
        return normalize_annotation(annotation, image_size=image_size)

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
    def _update_annotation_for_padded_image(
        self,
        annotation: Dict,
        input_image_size: Tuple[int, int],
        output_image_size: Tuple[int, int],
        padding,
        update_bboxes,
    ) -> Dict:
        """
        Update the annotation for a padded image.
        """
        new_annotation = {}
        new_annotation["size"] = output_image_size

        for key, value in annotation.items():
            if key == "masks":
                masks = value
                masks = pad(
                    masks,
                    padding,
                    mode=PaddingMode.CONSTANT,
                    constant_values=0,
                    input_data_format=ChannelDimension.FIRST,
                )
                masks = safe_squeeze(masks, 1)
                new_annotation["masks"] = masks
            elif key == "boxes" and update_bboxes:
                boxes = value
                boxes *= np.asarray(
                    [
                        input_image_size[1] / output_image_size[1],
                        input_image_size[0] / output_image_size[0],
                        input_image_size[1] / output_image_size[1],
                        input_image_size[0] / output_image_size[0],
                    ]
                )
                new_annotation["boxes"] = boxes
            elif key == "size":
                new_annotation["size"] = output_image_size
            else:
                new_annotation[key] = value
        return new_annotation

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
    def _pad_image(
        self,
        image: np.ndarray,
        output_size: Tuple[int, int],
        annotation: Optional[Dict[str, Any]] = None,
        constant_values: Union[float, Iterable[float]] = 0,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        update_bboxes: bool = True,
    ) -> np.ndarray:
        """
        Pad an image with zeros to the given size.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = output_size

        pad_bottom = output_height - input_height
        pad_right = output_width - input_width
        padding = ((0, pad_bottom), (0, pad_right))
        padded_image = pad(
            image,
            padding,
            mode=PaddingMode.CONSTANT,
            constant_values=constant_values,
            data_format=data_format,
            input_data_format=input_data_format,
        )
        if annotation is not None:
            annotation = self._update_annotation_for_padded_image(
                annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
            )
        return padded_image, annotation

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
    def pad(
        self,
        images: List[np.ndarray],
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
        constant_values: Union[float, Iterable[float]] = 0,
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        update_bboxes: bool = True,
        pad_size: Optional[Dict[str, int]] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
        in the batch and optionally returns their corresponding pixel mask.

        Args:
            images (List[`np.ndarray`]):
                Images to pad.
            annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
                Annotations to transform according to the padding that is applied to the images.
            constant_values (`float` or `Iterable[float]`, *optional*):
                The value to use for the padding if `mode` is `"constant"`.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to return a pixel mask.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
            update_bboxes (`bool`, *optional*, defaults to `True`):
                Whether to update the bounding boxes in the annotations to match the padded images. If the
                bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)`
                format, the bounding boxes will not be updated.
            pad_size (`Dict[str, int]`, *optional*):
                The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
                provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
                height and width in the batch.
        """
        pad_size = pad_size if pad_size is not None else self.pad_size
        if pad_size is not None:
            padded_size = (pad_size["height"], pad_size["width"])
        else:
            padded_size = get_max_height_width(images, input_data_format=input_data_format)

        annotation_list = annotations if annotations is not None else [None] * len(images)
        padded_images = []
        padded_annotations = []
        for image, annotation in zip(images, annotation_list):
            padded_image, padded_annotation = self._pad_image(
                image,
                padded_size,
                annotation,
                constant_values=constant_values,
                data_format=data_format,
                input_data_format=input_data_format,
                update_bboxes=update_bboxes,
            )
            padded_images.append(padded_image)
            padded_annotations.append(padded_annotation)

        data = {"pixel_values": padded_images}

        if return_pixel_mask:
            masks = [
                make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
                for image in images
            ]
            data["pixel_mask"] = masks

        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        if annotations is not None:
            encoded_inputs["labels"] = [
                BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
            ]

        return encoded_inputs

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample=None,  # PILImageResampling
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[Union[int, float]] = None,
        do_normalize: Optional[bool] = None,
        do_convert_annotations: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        format: Optional[Union[str, AnnotationFormat]] = None,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_size: Optional[Dict[str, int]] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or a batch of images so that it can be used by the model.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
                from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
                List of annotations associated with the image or batch of images. If annotation is for object
                detection, the annotations should be a dictionary with the following keys:
                - "image_id" (`int`): The image id.
                - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
                  dictionary. An image can have no annotations, in which case the list should be empty.
                If annotation is for segmentation, the annotations should be a dictionary with the following keys:
                - "image_id" (`int`): The image id.
                - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
                  An image can have no segments, in which case the list should be empty.
                - "file_name" (`str`): The file name of the image.
            return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
                Whether to return segmentation masks.
            masks_path (`str` or `pathlib.Path`, *optional*):
                Path to the directory containing the segmentation masks.
            do_resize (`bool`, *optional*, defaults to self.do_resize):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to self.size):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                        Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                        the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                        less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                        aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                        `max_width`.
            resample (`PILImageResampling`, *optional*, defaults to self.resample):
                Resampling filter to use when resizing the image.
            do_rescale (`bool`, *optional*, defaults to self.do_rescale):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
                Rescale factor to use when rescaling the image.
            do_normalize (`bool`, *optional*, defaults to self.do_normalize):
                Whether to normalize the image.
            do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
                Whether to convert the annotations to the format expected by the model. Converts the bounding
                boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
                and in relative coordinates.
            image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
                Mean to use when normalizing the image.
            image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
                Standard deviation to use when normalizing the image.
            do_pad (`bool`, *optional*, defaults to self.do_pad):
                Whether to pad the image. If `True`, padding will be applied to the bottom and right of
                the image with zeros. If `pad_size` is provided, the image will be padded to the specified
                dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
            format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
                Format of the annotations.
            return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
                Type of tensors to return. If `None`, will return the list of images.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            pad_size (`Dict[str, int]`, *optional*):
                The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
                provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
                height and width in the batch.
        """
        do_resize = self.do_resize if do_resize is None else do_resize
        size = self.size if size is None else size
        size = get_size_dict(size=size, default_to_square=True)
        resample = self.resample if resample is None else resample
        do_rescale = self.do_rescale if do_rescale is None else do_rescale
        rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
        do_normalize = self.do_normalize if do_normalize is None else do_normalize
        image_mean = self.image_mean if image_mean is None else image_mean
        image_std = self.image_std if image_std is None else image_std
        do_convert_annotations = (
            self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
        )
        do_pad = self.do_pad if do_pad is None else do_pad
        pad_size = self.pad_size if pad_size is None else pad_size
        format = self.format if format is None else format

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if annotations is not None and isinstance(annotations, dict):
            annotations = [annotations]

        if annotations is not None and len(images) != len(annotations):
            raise ValueError(
                f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
            )

        format = AnnotationFormat(format)
        if annotations is not None:
            validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)

        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
        if annotations is not None:
            prepared_images = []
            prepared_annotations = []
            for image, target in zip(images, annotations):
                target = self.prepare_annotation(
                    image,
                    target,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=input_data_format,
                )
                prepared_images.append(image)
                prepared_annotations.append(target)
            images = prepared_images
            annotations = prepared_annotations
            del prepared_images, prepared_annotations

        # transformations
        if do_resize:
            if annotations is not None:
                resized_images, resized_annotations = [], []
                for image, target in zip(images, annotations):
                    orig_size = get_image_size(image, input_data_format)
                    resized_image = self.resize(
                        image, size=size, resample=resample, input_data_format=input_data_format
                    )
                    resized_annotation = self.resize_annotation(
                        target, orig_size, get_image_size(resized_image, input_data_format)
                    )
                    resized_images.append(resized_image)
                    resized_annotations.append(resized_annotation)
                images = resized_images
                annotations = resized_annotations
                del resized_images, resized_annotations
            else:
                images = [
                    self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
                    for image in images
                ]

        if do_rescale:
            images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

        if do_normalize:
            images = [
                self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
            ]

        if do_convert_annotations and annotations is not None:
            annotations = [
                self.normalize_annotation(annotation, get_image_size(image, input_data_format))
                for annotation, image in zip(annotations, images)
            ]

        if do_pad:
            # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
            encoded_inputs = self.pad(
                images,
                annotations=annotations,
                return_pixel_mask=True,
                data_format=data_format,
                input_data_format=input_data_format,
                update_bboxes=do_convert_annotations,
                return_tensors=return_tensors,
                pad_size=pad_size,
            )
        else:
            images = [
                to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                for image in images
            ]
            encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
            if annotations is not None:
                encoded_inputs["labels"] = [
                    BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
                ]

        return encoded_inputs

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: Union[TensorType, List[Tuple]] = None,
        use_focal_loss: bool = True,
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            use_focal_loss (`bool` defaults to `True`):
                Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
                to compute the scores of each detection, otherwise, a softmax function is used.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        requires_backends(self, ["torch"])
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
        # convert from relative cxcywh to absolute xyxy
        boxes = center_to_corners_format(out_bbox)
        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )
            if isinstance(target_sizes, List):
                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        num_top_queries = out_logits.shape[1]
        num_classes = out_logits.shape[2]

        if use_focal_loss:
            scores = torch.nn.functional.sigmoid(out_logits)
            scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
            labels = index % num_classes
            index = index // num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        else:
            scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > num_top_queries:
                scores, index = torch.topk(scores, num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        results = []
        for score, label, box in zip(scores, labels, boxes):
            results.append(
                {
                    "scores": score[score > threshold],
                    "labels": label[score > threshold],
                    "boxes": box[score > threshold],
                }
            )

        return results


__all__ = ["RTDetrImageProcessor"]
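The slow processor above is typically driven through its `preprocess`/`__call__` and `post_process_object_detection` methods. Below is a minimal usage sketch, assuming the public `PekingU/rtdetr_r50vd` checkpoint and a local `example.jpg` image; both names are illustrative and not part of this file.

import torch
from PIL import Image
from transformers import AutoImageProcessor, RTDetrForObjectDetection

image = Image.open("example.jpg")  # any RGB image
processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

# Resize to 640x640 and rescale pixel values to [0, 1], as configured by the checkpoint.
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Map the relative (cx, cy, w, h) predictions back to absolute (x_min, y_min, x_max, y_max)
# in the original image, keeping detections above the score threshold.
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")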
docs/transformers/build/lib/transformers/models/rt_detr/image_processing_rt_detr_fast.py
ADDED
@@ -0,0 +1,608 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_rt_detr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    SizeDict,
    add_start_docstrings,
    get_image_size_for_max_height_width,
    get_max_height_width,
    safe_squeeze,
)
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    AnnotationFormat,
    AnnotationType,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    validate_annotations,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    requires_backends,
)
from ...utils.import_utils import requires
from .image_processing_rt_detr import get_size_with_aspect_ratio


if is_torch_available():
    import torch


if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
    from torchvision.transforms import functional as F


class RTDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    format: Optional[Union[str, AnnotationFormat]]
    do_convert_annotations: Optional[bool]
    do_pad: Optional[bool]
    pad_size: Optional[Dict[str, int]]
    return_segmentation_masks: Optional[bool]


SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)


def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by RT-DETR.
    """
    image_height, image_width = image.size()[-2:]

    image_id = target["image_id"]
    image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)

    # Get all COCO annotations for the given image.
    annotations = target["annotations"]
    classes = []
    area = []
    boxes = []
    keypoints = []
    for obj in annotations:
        if "iscrowd" not in obj or obj["iscrowd"] == 0:
            classes.append(obj["category_id"])
            area.append(obj["area"])
            boxes.append(obj["bbox"])
            if "keypoints" in obj:
                keypoints.append(obj["keypoints"])

    classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
    area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
    iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
    # guard against no boxes via resizing
    boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

    new_target = {
        "image_id": image_id,
        "class_labels": classes[keep],
        "boxes": boxes[keep],
        "area": area[keep],
        "iscrowd": iscrowd[keep],
        "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
    }

    if keypoints:
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    return new_target


@add_start_docstrings(
    "Constructs a fast RTDetr image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
    format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
        Data format of the annotations. One of "coco_detection" or "coco_panoptic".
    do_convert_annotations (`bool`, *optional*, defaults to `True`):
        Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the
        bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
        Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
    do_pad (`bool`, *optional*, defaults to `True`):
        Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
        method. If `True`, padding will be applied to the bottom and right of the image with zeros.
        If `pad_size` is provided, the image will be padded to the specified dimensions.
        Otherwise, the image will be padded to the maximum height and width of the batch.
    pad_size (`Dict[str, int]`, *optional*):
        The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
        provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
        height and width in the batch.
    return_segmentation_masks (`bool`, *optional*, defaults to `False`):
        Whether to return segmentation masks.
    """,
)
@requires(backends=("torchvision", "torch"))
class RTDetrImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    format = AnnotationFormat.COCO_DETECTION
    do_resize = True
    do_rescale = True
    do_normalize = False
    do_pad = False
    size = {"height": 640, "width": 640}
    default_to_square = False
    model_input_names = ["pixel_values", "pixel_mask"]
    valid_kwargs = RTDetrFastImageProcessorKwargs
    do_convert_annotations = True

    def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
        # Backwards compatibility
        do_convert_annotations = kwargs.get("do_convert_annotations", None)
        do_normalize = kwargs.get("do_normalize", None)
        if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
            self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize

        super().__init__(**kwargs)

    def prepare_annotation(
        self,
        image: torch.Tensor,
        target: Dict,
        format: Optional[AnnotationFormat] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Dict:
        """
        Prepare an annotation for feeding into RT_DETR model.
        """
        format = format if format is not None else self.format

        if format == AnnotationFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
        return target

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        interpolation: "F.InterpolationMode" = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                        Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                        the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                        less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                        aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                        `max_width`.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                Resampling filter to use if resizing the image.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if size.shortest_edge and size.longest_edge:
            # Resize the image so that the shortest edge or the longest edge is of the given size
            # while maintaining the aspect ratio of the original image.
            new_size = get_size_with_aspect_ratio(
                image.size()[-2:],
                size["shortest_edge"],
                size["longest_edge"],
            )
        elif size.max_height and size.max_width:
            new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
        elif size.height and size.width:
            new_size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )

        image = F.resize(
            image,
            size=new_size,
            interpolation=interpolation,
            **kwargs,
        )
        return image

    def resize_annotation(
        self,
        annotation: Dict[str, Any],
        orig_size: Tuple[int, int],
        target_size: Tuple[int, int],
        threshold: float = 0.5,
        interpolation: "F.InterpolationMode" = None,
    ):
        """
        Resizes an annotation to a target size.

        Args:
            annotation (`Dict[str, Any]`):
                The annotation dictionary.
            orig_size (`Tuple[int, int]`):
                The original size of the input image.
            target_size (`Tuple[int, int]`):
                The target size of the image, as returned by the preprocessing `resize` step.
            threshold (`float`, *optional*, defaults to 0.5):
                The threshold used to binarize the segmentation masks.
            resample (`InterpolationMode`, defaults to `InterpolationMode.NEAREST`):
                The resampling filter to use when resizing the masks.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST
        ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]

        new_annotation = {}
        new_annotation["size"] = target_size

        for key, value in annotation.items():
            if key == "boxes":
                boxes = value
                scaled_boxes = boxes * torch.as_tensor(
                    [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
                )
                new_annotation["boxes"] = scaled_boxes
            elif key == "area":
                area = value
                scaled_area = area * (ratio_width * ratio_height)
                new_annotation["area"] = scaled_area
            elif key == "masks":
                masks = value[:, None]
                masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
                masks = torch.stack(masks).to(torch.float32)
                masks = masks[:, 0] > threshold
                new_annotation["masks"] = masks
            elif key == "size":
                new_annotation["size"] = target_size
            else:
                new_annotation[key] = value

        return new_annotation

    def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
        image_height, image_width = image_size
        norm_annotation = {}
        for key, value in annotation.items():
            if key == "boxes":
                boxes = value
                boxes = corners_to_center_format(boxes)
                boxes /= torch.as_tensor(
                    [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
                )
                norm_annotation[key] = boxes
            else:
                norm_annotation[key] = value
        return norm_annotation

    def _update_annotation_for_padded_image(
        self,
        annotation: Dict,
        input_image_size: Tuple[int, int],
        output_image_size: Tuple[int, int],
        padding,
        update_bboxes,
    ) -> Dict:
        """
        Update the annotation for a padded image.
        """
        new_annotation = {}
        new_annotation["size"] = output_image_size
        ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))

        for key, value in annotation.items():
            if key == "masks":
                masks = value
                masks = F.pad(
                    masks,
                    padding,
                    fill=0,
                )
                masks = safe_squeeze(masks, 1)
                new_annotation["masks"] = masks
            elif key == "boxes" and update_bboxes:
                boxes = value
                boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
                new_annotation["boxes"] = boxes
            elif key == "size":
                new_annotation["size"] = output_image_size
            else:
                new_annotation[key] = value
        return new_annotation

    def pad(
        self,
        image: torch.Tensor,
        padded_size: Tuple[int, int],
        annotation: Optional[Dict[str, Any]] = None,
        update_bboxes: bool = True,
        fill: int = 0,
    ):
        original_size = image.size()[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]
        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )
        if original_size != padded_size:
            padding = [0, 0, padding_right, padding_bottom]
            image = F.pad(image, padding, fill=fill)
            if annotation is not None:
                annotation = self._update_annotation_for_padded_image(
                    annotation, original_size, padded_size, padding, update_bboxes
                )

        # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
        pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
        pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask, annotation

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
            List of annotations associated with the image or batch of images. If annotation is for object
            detection, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
              dictionary. An image can have no annotations, in which case the list should be empty.
            If annotation is for segmentation, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
              An image can have no segments, in which case the list should be empty.
            - "file_name" (`str`): The file name of the image.
        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
        do_convert_annotations (`bool`, *optional*, defaults to `True`):
            Controls whether to convert the annotations to the format expected by the DETR model. Converts the
            bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
            Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
            method. If `True`, padding will be applied to the bottom and right of the image with zeros.
            If `pad_size` is provided, the image will be padded to the specified dimensions.
            Otherwise, the image will be padded to the maximum height and width of the batch.
        pad_size (`Dict[str, int]`, *optional*):
            The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
            provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
            height and width in the batch.
        return_segmentation_masks (`bool`, *optional*, defaults to `False`):
            Whether to return segmentation masks.
        masks_path (`str` or `pathlib.Path`, *optional*):
            Path to the directory containing the segmentation masks.
        """,
    )
    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
    ) -> BatchFeature:
        return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
        return_segmentation_masks: bool,
        masks_path: Optional[Union[str, pathlib.Path]],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        do_convert_annotations: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        do_pad: bool,
        pad_size: Optional[Dict[str, int]],
        format: Optional[Union[str, AnnotationFormat]],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        """
        Preprocess an image or a batch of images so that it can be used by the model.
        """

        if annotations is not None and isinstance(annotations, dict):
            annotations = [annotations]

        if annotations is not None and len(images) != len(annotations):
            raise ValueError(
                f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
            )

        format = AnnotationFormat(format)
        if annotations is not None:
            validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)

        data = {}
        processed_images = []
        processed_annotations = []
        pixel_masks = []  # Initialize pixel_masks here
        for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
            # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
            if annotations is not None:
                annotation = self.prepare_annotation(
                    image,
                    annotation,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=ChannelDimension.FIRST,
                )

            if do_resize:
                resized_image = self.resize(image, size=size, interpolation=interpolation)
                if annotations is not None:
                    annotation = self.resize_annotation(
                        annotation,
                        orig_size=image.size()[-2:],
                        target_size=resized_image.size()[-2:],
                    )
                image = resized_image
            # Fused rescale and normalize
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
            if do_convert_annotations and annotations is not None:
                annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

            processed_images.append(image)
            processed_annotations.append(annotation)
        images = processed_images
        annotations = processed_annotations if annotations is not None else None

        if do_pad:
            # depends on all resized image shapes so we need another loop
            if pad_size is not None:
                padded_size = (pad_size["height"], pad_size["width"])
            else:
                padded_size = get_max_height_width(images)

            padded_images = []
            padded_annotations = []
            for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
                # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
                if padded_size == image.size()[-2:]:
                    padded_images.append(image)
                    pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
                    padded_annotations.append(annotation)
                    continue
                image, pixel_mask, annotation = self.pad(
                    image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
                )
                padded_images.append(image)
                padded_annotations.append(annotation)
                pixel_masks.append(pixel_mask)
            images = padded_images
            annotations = padded_annotations if annotations is not None else None
            data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})

        data.update({"pixel_values": torch.stack(images, dim=0)})
        encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
        if annotations is not None:
            encoded_inputs["labels"] = [
                BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
            ]
        return encoded_inputs

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: Union[TensorType, List[Tuple]] = None,
        use_focal_loss: bool = True,
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            use_focal_loss (`bool` defaults to `True`):
                Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
                to compute the scores of each detection, otherwise, a softmax function is used.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        requires_backends(self, ["torch"])
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
        # convert from relative cxcywh to absolute xyxy
        boxes = center_to_corners_format(out_bbox)
        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )
            if isinstance(target_sizes, List):
                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        num_top_queries = out_logits.shape[1]
        num_classes = out_logits.shape[2]

        if use_focal_loss:
            scores = torch.nn.functional.sigmoid(out_logits)
            scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
            labels = index % num_classes
            index = index // num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        else:
            scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > num_top_queries:
                scores, index = torch.topk(scores, num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        results = []
        for score, label, box in zip(scores, labels, boxes):
            results.append(
                {
                    "scores": score[score > threshold],
                    "labels": label[score > threshold],
                    "boxes": box[score > threshold],
                }
            )

        return results
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
__all__ = ["RTDetrImageProcessorFast"]
|
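The post-processing above pairs with an RT-DETR detection model at inference time: the processor resizes and batches the images, the model predicts relative `(cx, cy, w, h)` boxes, and `post_process_object_detection` rescales them to absolute `(x_min, y_min, x_max, y_max)` boxes. A minimal usage sketch; the checkpoint name `PekingU/rtdetr_r50vd` and the image path are assumptions, not taken from this diff:

```python
import torch
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast

# Assumed checkpoint; any RT-DETR object-detection checkpoint should work the same way.
processor = RTDetrImageProcessorFast.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

image = Image.open("example.jpg")  # placeholder path
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes maps the relative boxes back to the original (height, width) of the image.
results = processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=[(image.height, image.width)]
)
print(results[0]["scores"], results[0]["labels"], results[0]["boxes"])
```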
docs/transformers/build/lib/transformers/models/rt_detr/modeling_rt_detr.py
ADDED
The diff for this file is too large to render.
See raw diff
docs/transformers/build/lib/transformers/models/rt_detr/modeling_rt_detr_resnet.py
ADDED
@@ -0,0 +1,440 @@
# coding=utf-8
# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch RTDetr specific ResNet model. The main difference from the Hugging Face ResNet model is that this RTDetrResNet model forces the use of a shortcut at the first layer in the resnet-18/34 models.
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
"""

import math
from typing import Optional

from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_outputs import (
    BackboneOutput,
    BaseModelOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_rt_detr_resnet import RTDetrResNetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "RTDetrResNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/resnet-50"
_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]


# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer
class RTDetrResNetConvLayer(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
    ):
        super().__init__()
        self.convolution = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
        )
        self.normalization = nn.BatchNorm2d(out_channels)
        self.activation = ACT2FN[activation] if activation is not None else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = self.convolution(input)
        hidden_state = self.normalization(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state


class RTDetrResNetEmbeddings(nn.Module):
    """
    ResNet Embeddings (stem) composed of a deep aggressive convolution.
    """

    def __init__(self, config: RTDetrResNetConfig):
        super().__init__()
        self.embedder = nn.Sequential(
            *[
                RTDetrResNetConvLayer(
                    config.num_channels,
                    config.embedding_size // 2,
                    kernel_size=3,
                    stride=2,
                    activation=config.hidden_act,
                ),
                RTDetrResNetConvLayer(
                    config.embedding_size // 2,
                    config.embedding_size // 2,
                    kernel_size=3,
                    stride=1,
                    activation=config.hidden_act,
                ),
                RTDetrResNetConvLayer(
                    config.embedding_size // 2,
                    config.embedding_size,
                    kernel_size=3,
                    stride=1,
                    activation=config.hidden_act,
                ),
            ]
        )
        self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.num_channels = config.num_channels

    def forward(self, pixel_values: Tensor) -> Tensor:
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embedding = self.embedder(pixel_values)
        embedding = self.pooler(embedding)
        return embedding


# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetShortCut
class RTDetrResNetShortCut(nn.Module):
    """
    ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        self.normalization = nn.BatchNorm2d(out_channels)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = self.convolution(input)
        hidden_state = self.normalization(hidden_state)
        return hidden_state


class RTDetrResNetBasicLayer(nn.Module):
    """
    A classic ResNet's residual layer composed by two `3x3` convolutions.
    See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
    """

    def __init__(
        self,
        config: RTDetrResNetConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        should_apply_shortcut: bool = False,
    ):
        super().__init__()
        if in_channels != out_channels:
            self.shortcut = (
                nn.Sequential(
                    *[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
                )
                if should_apply_shortcut
                else nn.Identity()
            )
        else:
            self.shortcut = (
                RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
                if should_apply_shortcut
                else nn.Identity()
            )
        self.layer = nn.Sequential(
            RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
            RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
        )
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        residual = hidden_state
        hidden_state = self.layer(hidden_state)
        residual = self.shortcut(residual)
        hidden_state += residual
        hidden_state = self.activation(hidden_state)
        return hidden_state


class RTDetrResNetBottleNeckLayer(nn.Module):
    """
    A classic RTDetrResNet's bottleneck layer composed by three `3x3` convolutions.

    The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
    convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
    `downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
    """

    def __init__(
        self,
        config: RTDetrResNetConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
    ):
        super().__init__()
        reduction = 4
        should_apply_shortcut = in_channels != out_channels or stride != 1
        reduces_channels = out_channels // reduction
        if stride == 2:
            self.shortcut = nn.Sequential(
                *[
                    nn.AvgPool2d(2, 2, 0, ceil_mode=True),
                    RTDetrResNetShortCut(in_channels, out_channels, stride=1)
                    if should_apply_shortcut
                    else nn.Identity(),
                ]
            )
        else:
            self.shortcut = (
                RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
                if should_apply_shortcut
                else nn.Identity()
            )
        self.layer = nn.Sequential(
            RTDetrResNetConvLayer(
                in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
            ),
            RTDetrResNetConvLayer(
                reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
            ),
            RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
        )
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        residual = hidden_state
        hidden_state = self.layer(hidden_state)
        residual = self.shortcut(residual)
        hidden_state += residual
        hidden_state = self.activation(hidden_state)
        return hidden_state


class RTDetrResNetStage(nn.Module):
    """
    A RTDetrResNet stage composed by stacked layers.
    """

    def __init__(
        self,
        config: RTDetrResNetConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        depth: int = 2,
    ):
        super().__init__()

        layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer

        if config.layer_type == "bottleneck":
            first_layer = layer(
                config,
                in_channels,
                out_channels,
                stride=stride,
            )
        else:
            first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True)
        self.layers = nn.Sequential(
            first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)]
        )

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet
class RTDetrResNetEncoder(nn.Module):
    def __init__(self, config: RTDetrResNetConfig):
        super().__init__()
        self.stages = nn.ModuleList([])
        # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
        self.stages.append(
            RTDetrResNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=2 if config.downsample_in_first_stage else 1,
                depth=config.depths[0],
            )
        )
        in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
        for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
            self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth))

    def forward(
        self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> BaseModelOutputWithNoAttention:
        hidden_states = () if output_hidden_states else None

        for stage_module in self.stages:
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state,)

            hidden_state = stage_module(hidden_state)

        if output_hidden_states:
            hidden_states = hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_state,
            hidden_states=hidden_states,
        )


# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet
class RTDetrResNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RTDetrResNetConfig
    base_model_prefix = "resnet"
    main_input_name = "pixel_values"
    _no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"]

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)


RTDETR_RESNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RTDetrResNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

RTDETR_RESNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`RTDetrImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    """
    ResNet backbone, to be used with frameworks like RTDETR.
    """,
    RTDETR_RESNET_START_DOCSTRING,
)
class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [config.embedding_size] + config.hidden_sizes
        self.embedder = RTDetrResNetEmbeddings(config)
        self.encoder = RTDetrResNetEncoder(config)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(RTDETR_RESNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
        >>> import torch

        >>> config = RTDetrResNetConfig()
        >>> model = RTDetrResNetBackbone(config)

        >>> pixel_values = torch.randn(1, 3, 224, 224)

        >>> with torch.no_grad():
        ...     outputs = model(pixel_values)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 2048, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        embedding_output = self.embedder(pixel_values)

        outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)

        hidden_states = outputs.hidden_states

        feature_maps = ()
        for idx, stage in enumerate(self.stage_names):
            if stage in self.out_features:
                feature_maps += (hidden_states[idx],)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )


__all__ = [
    "RTDetrResNetBackbone",
    "RTDetrResNetPreTrainedModel",
]
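Since `RTDetrResNetBackbone.forward` only keeps the hidden states whose stage name appears in `self.out_features`, the set of returned feature maps is controlled entirely from the config. A small sketch under the assumption that the stage names follow the usual `"stem"`, `"stage1"` ... `"stage4"` convention of `BackboneMixin` (adjust the names if they differ):

```python
import torch
from transformers import RTDetrResNetBackbone, RTDetrResNetConfig

# Assumed stage naming; only the requested stages are returned as feature maps.
config = RTDetrResNetConfig(out_features=["stage2", "stage3", "stage4"])
model = RTDetrResNetBackbone(config)

with torch.no_grad():
    outputs = model(torch.randn(1, 3, 224, 224))

# One feature map per requested stage, each with a smaller spatial resolution than the last.
print([tuple(fm.shape) for fm in outputs.feature_maps])
```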
docs/transformers/build/lib/transformers/models/rt_detr/modular_rt_detr.py
ADDED
@@ -0,0 +1,410 @@
import pathlib
from typing import Dict, List, Optional, Tuple, Union

from transformers.models.detr.image_processing_detr_fast import (
    DetrFastImageProcessorKwargs,
    DetrImageProcessorFast,
)

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
    BaseImageProcessorFast,
    SizeDict,
    add_start_docstrings,
    get_max_height_width,
)
from ...image_transforms import center_to_corners_format
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    AnnotationFormat,
    AnnotationType,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    validate_annotations,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
    requires_backends,
)


if is_torch_available():
    import torch


if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
    from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)

SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)


def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by RT-DETR.
    """
    image_height, image_width = image.size()[-2:]

    image_id = target["image_id"]
    image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)

    # Get all COCO annotations for the given image.
    annotations = target["annotations"]
    classes = []
    area = []
    boxes = []
    keypoints = []
    for obj in annotations:
        if "iscrowd" not in obj or obj["iscrowd"] == 0:
            classes.append(obj["category_id"])
            area.append(obj["area"])
            boxes.append(obj["bbox"])
            if "keypoints" in obj:
                keypoints.append(obj["keypoints"])

    classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
    area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
    iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
    # guard against no boxes via resizing
    boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

    new_target = {
        "image_id": image_id,
        "class_labels": classes[keep],
        "boxes": boxes[keep],
        "area": area[keep],
        "iscrowd": iscrowd[keep],
        "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
    }

    if keypoints:
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    return new_target


class RTDetrFastImageProcessorKwargs(DetrFastImageProcessorKwargs):
    pass


class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    format = AnnotationFormat.COCO_DETECTION
    do_convert_annotations = True
    do_resize = True
    do_rescale = True
    do_normalize = False
    do_pad = False
    size = {"height": 640, "width": 640}
    default_to_square = False
    model_input_names = ["pixel_values", "pixel_mask"]
    valid_kwargs = RTDetrFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
        # Backwards compatibility
        do_convert_annotations = kwargs.get("do_convert_annotations", None)
        do_normalize = kwargs.get("do_normalize", None)
        if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
            self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize

        BaseImageProcessorFast.__init__(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
            List of annotations associated with the image or batch of images. If annotation is for object
            detection, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
              dictionary. An image can have no annotations, in which case the list should be empty.
            If annotation is for segmentation, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
              An image can have no segments, in which case the list should be empty.
            - "file_name" (`str`): The file name of the image.
        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
        do_convert_annotations (`bool`, *optional*, defaults to `True`):
            Controls whether to convert the annotations to the format expected by the DETR model. Converts the
            bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
            Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
            method. If `True`, padding will be applied to the bottom and right of the image with zeros.
            If `pad_size` is provided, the image will be padded to the specified dimensions.
            Otherwise, the image will be padded to the maximum height and width of the batch.
        pad_size (`Dict[str, int]`, *optional*):
            The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
            provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
            height and width in the batch.
        return_segmentation_masks (`bool`, *optional*, defaults to `False`):
            Whether to return segmentation masks.
        masks_path (`str` or `pathlib.Path`, *optional*):
            Path to the directory containing the segmentation masks.
        """,
    )
    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
    ) -> BatchFeature:
        return BaseImageProcessorFast().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)

    def prepare_annotation(
        self,
        image: torch.Tensor,
        target: Dict,
        format: Optional[AnnotationFormat] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Dict:
        format = format if format is not None else self.format

        if format == AnnotationFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
        return target

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
        return_segmentation_masks: bool,
        masks_path: Optional[Union[str, pathlib.Path]],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        do_convert_annotations: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        do_pad: bool,
        pad_size: Optional[Dict[str, int]],
        format: Optional[Union[str, AnnotationFormat]],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        """
        Preprocess an image or a batch of images so that it can be used by the model.
        """

        if annotations is not None and isinstance(annotations, dict):
            annotations = [annotations]

        if annotations is not None and len(images) != len(annotations):
            raise ValueError(
                f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
            )

        format = AnnotationFormat(format)
        if annotations is not None:
            validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)

        data = {}
        processed_images = []
        processed_annotations = []
        pixel_masks = []  # Initialize pixel_masks here
        for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
            # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
            if annotations is not None:
                annotation = self.prepare_annotation(
                    image,
                    annotation,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=ChannelDimension.FIRST,
                )

            if do_resize:
                resized_image = self.resize(image, size=size, interpolation=interpolation)
                if annotations is not None:
                    annotation = self.resize_annotation(
                        annotation,
                        orig_size=image.size()[-2:],
                        target_size=resized_image.size()[-2:],
                    )
                image = resized_image
            # Fused rescale and normalize
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
            if do_convert_annotations and annotations is not None:
                annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

            processed_images.append(image)
            processed_annotations.append(annotation)
        images = processed_images
        annotations = processed_annotations if annotations is not None else None

        if do_pad:
            # depends on all resized image shapes so we need another loop
            if pad_size is not None:
                padded_size = (pad_size["height"], pad_size["width"])
            else:
                padded_size = get_max_height_width(images)

            padded_images = []
            padded_annotations = []
            for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
                # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
                if padded_size == image.size()[-2:]:
                    padded_images.append(image)
                    pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
                    padded_annotations.append(annotation)
                    continue
                image, pixel_mask, annotation = self.pad(
                    image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
                )
                padded_images.append(image)
                padded_annotations.append(annotation)
                pixel_masks.append(pixel_mask)
            images = padded_images
            annotations = padded_annotations if annotations is not None else None
            data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})

        data.update({"pixel_values": torch.stack(images, dim=0)})
        encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
        if annotations is not None:
            encoded_inputs["labels"] = [
                BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
            ]
        return encoded_inputs

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: Union[TensorType, List[Tuple]] = None,
        use_focal_loss: bool = True,
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            use_focal_loss (`bool` defaults to `True`):
                Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
                to compute the scores of each detection, otherwise, a softmax function is used.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        requires_backends(self, ["torch"])
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
        # convert from relative cxcywh to absolute xyxy
        boxes = center_to_corners_format(out_bbox)
        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )
            if isinstance(target_sizes, List):
                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        num_top_queries = out_logits.shape[1]
        num_classes = out_logits.shape[2]

        if use_focal_loss:
            scores = torch.nn.functional.sigmoid(out_logits)
            scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
            labels = index % num_classes
            index = index // num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        else:
            scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > num_top_queries:
                scores, index = torch.topk(scores, num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        results = []
        for score, label, box in zip(scores, labels, boxes):
            results.append(
                {
                    "scores": score[score > threshold],
                    "labels": label[score > threshold],
                    "boxes": box[score > threshold],
                }
            )

        return results

    def from_dict():
        raise NotImplementedError("No need to override this method for RT-DETR yet.")

    def post_process():
        raise NotImplementedError("Post-processing is not implemented for RT-DETR yet.")

    def post_process_segmentation():
        raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")

    def post_process_instance():
        raise NotImplementedError("Instance post-processing is not implemented for RT-DETR yet.")

    def post_process_panoptic():
        raise NotImplementedError("Panoptic post-processing is not implemented for RT-DETR yet.")

    def post_process_instance_segmentation():
        raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")

    def post_process_semantic_segmentation():
        raise NotImplementedError("Semantic segmentation post-processing is not implemented for RT-DETR yet.")

    def post_process_panoptic_segmentation():
        raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.")


__all__ = ["RTDetrImageProcessorFast"]
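`prepare_coco_detection_annotation` above converts COCO `[x, y, width, height]` boxes to `[x_min, y_min, x_max, y_max]`, clips them to the image, and drops degenerate boxes. A small self-contained sketch of just that box handling; the numbers are made up:

```python
import torch

# Made-up COCO-style boxes [x, y, width, height]; the second one has zero width.
boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0], [5.0, 5.0, 0.0, 10.0]])
image_height, image_width = 480, 640

boxes[:, 2:] += boxes[:, :2]  # -> [x_min, y_min, x_max, y_max]
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

# Keep only boxes with positive width and height, as in the function above.
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
print(boxes[keep])  # tensor([[10., 20., 40., 60.]]) -- the degenerate box is dropped
```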
docs/transformers/build/lib/transformers/models/rt_detr_v2/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_rt_detr_v2 import *
    from .modeling_rt_detr_v2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
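The `_LazyModule` indirection above makes importing the `rt_detr_v2` package cheap: the configuration and modeling submodules are only loaded when one of their attributes is first accessed. A minimal sketch; the class names are the ones imported by the conversion script later in this diff, and loading the modeling class requires a working `torch` install:

```python
import transformers.models.rt_detr_v2 as rt_detr_v2

# Attribute access triggers the lazy import of the corresponding submodule.
config_cls = rt_detr_v2.RTDetrV2Config
model_cls = rt_detr_v2.RTDetrV2ForObjectDetection
print(config_cls.__name__, model_cls.__name__)
```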
docs/transformers/build/lib/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py
ADDED
@@ -0,0 +1,363 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RT Detr V2 checkpoints with Timm backbone"""

import argparse
import json
import re
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from transformers import RTDetrImageProcessor, RTDetrV2Config, RTDetrV2ForObjectDetection
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_rt_detr_v2_config(model_name: str) -> RTDetrV2Config:
    config = RTDetrV2Config()

    config.num_labels = 80
    repo_id = "huggingface/label-files"
    filename = "coco-detection-mmdet-id2label.json"
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    if model_name == "rtdetr_v2_r18vd":
        config.backbone_config.hidden_sizes = [64, 128, 256, 512]
        config.backbone_config.depths = [2, 2, 2, 2]
        config.backbone_config.layer_type = "basic"
        config.encoder_in_channels = [128, 256, 512]
        config.hidden_expansion = 0.5
        config.decoder_layers = 3
    elif model_name == "rtdetr_v2_r34vd":
        config.backbone_config.hidden_sizes = [64, 128, 256, 512]
        config.backbone_config.depths = [3, 4, 6, 3]
        config.backbone_config.layer_type = "basic"
        config.encoder_in_channels = [128, 256, 512]
        config.hidden_expansion = 0.5
        config.decoder_layers = 4
    # TODO: check this not working
    elif model_name == "rtdetr_v2_r50vd_m":
        config.hidden_expansion = 0.5
    elif model_name == "rtdetr_v2_r50vd":
        pass
    elif model_name == "rtdetr_v2_r101vd":
        config.backbone_config.depths = [3, 4, 23, 3]
        config.encoder_ffn_dim = 2048
        config.encoder_hidden_dim = 384
        config.decoder_in_channels = [384, 384, 384]

    return config

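# Illustrative sketch, not part of the original script: each entry of ORIGINAL_TO_CONVERTED_KEY_MAPPING
# defined just below pairs a regex over original checkpoint keys with its replacement; applying a single
# entry with re.sub looks like this (the key string is a made-up example):
#
#   import re
#   old_key = "backbone.res_layers.1.blocks.0.branch2a.conv.weight"
#   new_key = re.sub(
#       r"backbone.res_layers.(\d+).blocks.(\d+).branch2a.conv.weight",
#       r"model.backbone.model.encoder.stages.\1.layers.\2.layer.0.convolution.weight",
#       old_key,
#   )
#   # -> "model.backbone.model.encoder.stages.1.layers.0.layer.0.convolution.weight"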
# Define a mapping from original keys to converted keys using regex
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"backbone.conv1.conv1_1.conv.weight": r"model.backbone.model.embedder.embedder.0.convolution.weight",
    r"backbone.conv1.conv1_1.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.embedder.0.normalization.\1",
    r"backbone.conv1.conv1_2.conv.weight": r"model.backbone.model.embedder.embedder.1.convolution.weight",
    r"backbone.conv1.conv1_2.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.embedder.1.normalization.\1",
    r"backbone.conv1.conv1_3.conv.weight": r"model.backbone.model.embedder.embedder.2.convolution.weight",
    r"backbone.conv1.conv1_3.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.embedder.2.normalization.\1",
    r"backbone.res_layers.(\d+).blocks.(\d+).branch2a.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.0.convolution.weight",
    r"backbone.res_layers.(\d+).blocks.(\d+).branch2a.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.0.normalization.\3",
    r"backbone.res_layers.(\d+).blocks.(\d+).branch2b.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.1.convolution.weight",
    r"backbone.res_layers.(\d+).blocks.(\d+).branch2b.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.1.normalization.\3",
    r"backbone.res_layers.(\d+).blocks.(\d+).branch2c.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.2.convolution.weight",
    r"backbone.res_layers.(\d+).blocks.(\d+).branch2c.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.2.normalization.\3",
    r"encoder.encoder.(\d+).layers.0.self_attn.out_proj.weight": r"model.encoder.encoder.\1.layers.0.self_attn.out_proj.weight",
    r"encoder.encoder.(\d+).layers.0.self_attn.out_proj.bias": r"model.encoder.encoder.\1.layers.0.self_attn.out_proj.bias",
    r"encoder.encoder.(\d+).layers.0.linear1.weight": r"model.encoder.encoder.\1.layers.0.fc1.weight",
    r"encoder.encoder.(\d+).layers.0.linear1.bias": r"model.encoder.encoder.\1.layers.0.fc1.bias",
    r"encoder.encoder.(\d+).layers.0.linear2.weight": r"model.encoder.encoder.\1.layers.0.fc2.weight",
    r"encoder.encoder.(\d+).layers.0.linear2.bias": r"model.encoder.encoder.\1.layers.0.fc2.bias",
    r"encoder.encoder.(\d+).layers.0.norm1.weight": r"model.encoder.encoder.\1.layers.0.self_attn_layer_norm.weight",
    r"encoder.encoder.(\d+).layers.0.norm1.bias": r"model.encoder.encoder.\1.layers.0.self_attn_layer_norm.bias",
    r"encoder.encoder.(\d+).layers.0.norm2.weight": r"model.encoder.encoder.\1.layers.0.final_layer_norm.weight",
    r"encoder.encoder.(\d+).layers.0.norm2.bias": r"model.encoder.encoder.\1.layers.0.final_layer_norm.bias",
    r"encoder.input_proj.(\d+).conv.weight": r"model.encoder_input_proj.\1.0.weight",
    r"encoder.input_proj.(\d+).norm.(.*)": r"model.encoder_input_proj.\1.1.\2",
    r"encoder.fpn_blocks.(\d+).conv(\d+).conv.weight": r"model.encoder.fpn_blocks.\1.conv\2.conv.weight",
    # r"encoder.fpn_blocks.(\d+).conv(\d+).norm.(.*)": r"model.encoder.fpn_blocks.\1.conv\2.norm.\3",
    r"encoder.fpn_blocks.(\d+).conv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv\2.norm.\3",
    r"encoder.lateral_convs.(\d+).conv.weight": r"model.encoder.lateral_convs.\1.conv.weight",
    r"encoder.lateral_convs.(\d+).norm.(.*)": r"model.encoder.lateral_convs.\1.norm.\2",
    r"encoder.fpn_blocks.(\d+).bottlenecks.(\d+).conv(\d+).conv.weight": r"model.encoder.fpn_blocks.\1.bottlenecks.\2.conv\3.conv.weight",
    r"encoder.fpn_blocks.(\d+).bottlenecks.(\d+).conv(\d+).norm.(\w+)": r"model.encoder.fpn_blocks.\1.bottlenecks.\2.conv\3.norm.\4",
    r"encoder.pan_blocks.(\d+).conv(\d+).conv.weight": r"model.encoder.pan_blocks.\1.conv\2.conv.weight",
    r"encoder.pan_blocks.(\d+).conv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv\2.norm.\3",
    r"encoder.pan_blocks.(\d+).bottlenecks.(\d+).conv(\d+).conv.weight": r"model.encoder.pan_blocks.\1.bottlenecks.\2.conv\3.conv.weight",
    r"encoder.pan_blocks.(\d+).bottlenecks.(\d+).conv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.bottlenecks.\2.conv\3.norm.\4",
    r"encoder.downsample_convs.(\d+).conv.weight": r"model.encoder.downsample_convs.\1.conv.weight",
    r"encoder.downsample_convs.(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.downsample_convs.\1.norm.\2",
    r"decoder.decoder.layers.(\d+).self_attn.out_proj.weight": r"model.decoder.layers.\1.self_attn.out_proj.weight",
    r"decoder.decoder.layers.(\d+).self_attn.out_proj.bias": r"model.decoder.layers.\1.self_attn.out_proj.bias",
    r"decoder.decoder.layers.(\d+).cross_attn.sampling_offsets.weight": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.weight",
    r"decoder.decoder.layers.(\d+).cross_attn.sampling_offsets.bias": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.bias",
    r"decoder.decoder.layers.(\d+).cross_attn.attention_weights.weight": r"model.decoder.layers.\1.encoder_attn.attention_weights.weight",
    r"decoder.decoder.layers.(\d+).cross_attn.attention_weights.bias": r"model.decoder.layers.\1.encoder_attn.attention_weights.bias",
    r"decoder.decoder.layers.(\d+).cross_attn.value_proj.weight": r"model.decoder.layers.\1.encoder_attn.value_proj.weight",
    r"decoder.decoder.layers.(\d+).cross_attn.value_proj.bias": r"model.decoder.layers.\1.encoder_attn.value_proj.bias",
    r"decoder.decoder.layers.(\d+).cross_attn.output_proj.weight": r"model.decoder.layers.\1.encoder_attn.output_proj.weight",
|
| 123 |
+
r"decoder.decoder.layers.(\d+).cross_attn.output_proj.bias": r"model.decoder.layers.\1.encoder_attn.output_proj.bias",
|
| 124 |
+
r"decoder.decoder.layers.(\d+).norm1.weight": r"model.decoder.layers.\1.self_attn_layer_norm.weight",
|
| 125 |
+
r"decoder.decoder.layers.(\d+).norm1.bias": r"model.decoder.layers.\1.self_attn_layer_norm.bias",
|
| 126 |
+
r"decoder.decoder.layers.(\d+).norm2.weight": r"model.decoder.layers.\1.encoder_attn_layer_norm.weight",
|
| 127 |
+
r"decoder.decoder.layers.(\d+).norm2.bias": r"model.decoder.layers.\1.encoder_attn_layer_norm.bias",
|
| 128 |
+
r"decoder.decoder.layers.(\d+).linear1.weight": r"model.decoder.layers.\1.fc1.weight",
|
| 129 |
+
r"decoder.decoder.layers.(\d+).linear1.bias": r"model.decoder.layers.\1.fc1.bias",
|
| 130 |
+
r"decoder.decoder.layers.(\d+).linear2.weight": r"model.decoder.layers.\1.fc2.weight",
|
| 131 |
+
r"decoder.decoder.layers.(\d+).linear2.bias": r"model.decoder.layers.\1.fc2.bias",
|
| 132 |
+
r"decoder.decoder.layers.(\d+).norm3.weight": r"model.decoder.layers.\1.final_layer_norm.weight",
|
| 133 |
+
r"decoder.decoder.layers.(\d+).norm3.bias": r"model.decoder.layers.\1.final_layer_norm.bias",
|
| 134 |
+
r"decoder.decoder.layers.(\d+).cross_attn.num_points_scale": r"model.decoder.layers.\1.encoder_attn.n_points_scale",
|
| 135 |
+
r"decoder.dec_score_head.(\d+).weight": r"model.decoder.class_embed.\1.weight",
|
| 136 |
+
r"decoder.dec_score_head.(\d+).bias": r"model.decoder.class_embed.\1.bias",
|
| 137 |
+
r"decoder.dec_bbox_head.(\d+).layers.(\d+).(weight|bias)": r"model.decoder.bbox_embed.\1.layers.\2.\3",
|
| 138 |
+
r"decoder.denoising_class_embed.weight": r"model.denoising_class_embed.weight",
|
| 139 |
+
r"decoder.query_pos_head.layers.0.weight": r"model.decoder.query_pos_head.layers.0.weight",
|
| 140 |
+
r"decoder.query_pos_head.layers.0.bias": r"model.decoder.query_pos_head.layers.0.bias",
|
| 141 |
+
r"decoder.query_pos_head.layers.1.weight": r"model.decoder.query_pos_head.layers.1.weight",
|
| 142 |
+
r"decoder.query_pos_head.layers.1.bias": r"model.decoder.query_pos_head.layers.1.bias",
|
| 143 |
+
r"decoder.enc_output.proj.weight": r"model.enc_output.0.weight",
|
| 144 |
+
r"decoder.enc_output.proj.bias": r"model.enc_output.0.bias",
|
| 145 |
+
r"decoder.enc_output.norm.weight": r"model.enc_output.1.weight",
|
| 146 |
+
r"decoder.enc_output.norm.bias": r"model.enc_output.1.bias",
|
| 147 |
+
r"decoder.enc_score_head.weight": r"model.enc_score_head.weight",
|
| 148 |
+
r"decoder.enc_score_head.bias": r"model.enc_score_head.bias",
|
| 149 |
+
r"decoder.enc_bbox_head.layers.(\d+).(weight|bias)": r"model.enc_bbox_head.layers.\1.\2",
|
| 150 |
+
r"backbone.res_layers.0.blocks.0.short.conv.weight": r"model.backbone.model.encoder.stages.0.layers.0.shortcut.convolution.weight",
|
| 151 |
+
r"backbone.res_layers.0.blocks.0.short.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.0.layers.0.shortcut.normalization.\1",
|
| 152 |
+
r"backbone.res_layers.(\d+).blocks.0.short.conv.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.convolution.weight",
|
| 153 |
+
r"backbone.res_layers.(\d+).blocks.0.short.conv.norm.(\w+)": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.normalization.\2",
|
| 154 |
+
# Mapping for subsequent blocks in other stages
|
| 155 |
+
r"backbone.res_layers.(\d+).blocks.0.short.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.convolution.weight",
|
| 156 |
+
r"backbone.res_layers.(\d+).blocks.0.short.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.normalization.\2",
|
| 157 |
+
r"decoder.input_proj.(\d+).conv.weight": r"model.decoder_input_proj.\1.0.weight",
|
| 158 |
+
r"decoder.input_proj.(\d+).norm.(.*)": r"model.decoder_input_proj.\1.1.\2",
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
    # Use the mapping to rename keys
    for original_key, converted_key in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        for key in list(state_dict_keys.keys()):
            new_key = re.sub(original_key, converted_key, key)
            if new_key != key:
                state_dict_keys[new_key] = state_dict_keys.pop(key)

    return state_dict_keys


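# Illustrative use of the regex-based renaming above (a sketch; the key is one real checkpoint key from
# ORIGINAL_TO_CONVERTED_KEY_MAPPING, the value is just a placeholder):
_toy_state_dict = {"decoder.dec_score_head.2.weight": 0}
assert "model.decoder.class_embed.2.weight" in convert_old_keys_to_new_keys(_toy_state_dict)

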
def read_in_q_k_v(state_dict, config):
    prefix = ""
    encoder_hidden_dim = config.encoder_hidden_dim

    # first: transformer encoder
    for i in range(config.encoder_layers):
        # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[
            :encoder_hidden_dim, :
        ]
        state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim]
        state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[
            encoder_hidden_dim : 2 * encoder_hidden_dim, :
        ]
        state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[
            encoder_hidden_dim : 2 * encoder_hidden_dim
        ]
        state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[
            -encoder_hidden_dim:, :
        ]
        state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:]
    # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
    for i in range(config.decoder_layers):
        # read in weights + bias of input projection layer of self-attention
        in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]


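# Toy illustration of the fused-projection split convention used in read_in_q_k_v above (hypothetical
# sizes, not real checkpoint tensors): rows [0:h] are the query, [h:2h] the key, and [-h:] the value weights.
_h = 4
_fused_weight = torch.arange(3 * _h * _h, dtype=torch.float32).reshape(3 * _h, _h)
_q_w, _k_w, _v_w = _fused_weight[:_h, :], _fused_weight[_h : 2 * _h, :], _fused_weight[-_h:, :]
assert torch.equal(torch.cat([_q_w, _k_w, _v_w], dim=0), _fused_weight)

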
# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)

    return im


@torch.no_grad()
|
| 220 |
+
def write_model_and_image_processor(model_name, output_dir, push_to_hub, repo_id):
|
| 221 |
+
"""
|
| 222 |
+
Copy/paste/tweak model's weights to our RTDETR structure.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
# load default config
|
| 226 |
+
config = get_rt_detr_v2_config(model_name)
|
| 227 |
+
|
| 228 |
+
# load original model from torch hub
|
| 229 |
+
model_name_to_checkpoint_url = {
|
| 230 |
+
"rtdetr_v2_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth",
|
| 231 |
+
"rtdetr_v2_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r34vd_120e_coco_ema.pth",
|
| 232 |
+
"rtdetr_v2_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_6x_coco_ema.pth",
|
| 233 |
+
"rtdetr_v2_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r101vd_6x_coco_from_paddle.pth",
|
| 234 |
+
}
|
| 235 |
+
logger.info(f"Converting model {model_name}...")
|
| 236 |
+
state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[
|
| 237 |
+
"ema"
|
| 238 |
+
]["module"]
|
| 239 |
+
# rename keys
|
| 240 |
+
state_dict = convert_old_keys_to_new_keys(state_dict)
|
| 241 |
+
for key in state_dict.copy().keys():
|
| 242 |
+
if key.endswith("num_batches_tracked"):
|
| 243 |
+
del state_dict[key]
|
| 244 |
+
# query, key and value matrices need special treatment
|
| 245 |
+
read_in_q_k_v(state_dict, config)
|
| 246 |
+
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
| 247 |
+
for key in state_dict.copy().keys():
|
| 248 |
+
if key.endswith("num_batches_tracked"):
|
| 249 |
+
del state_dict[key]
|
| 250 |
+
# for two_stage
|
| 251 |
+
if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key):
|
| 252 |
+
state_dict[key.split("model.decoder.")[-1]] = state_dict[key]
|
| 253 |
+
|
| 254 |
+
# no need in ckpt
|
| 255 |
+
del state_dict["decoder.anchors"]
|
| 256 |
+
del state_dict["decoder.valid_mask"]
|
| 257 |
+
# finally, create HuggingFace model and load state dict
|
| 258 |
+
model = RTDetrV2ForObjectDetection(config)
|
| 259 |
+
model.load_state_dict(state_dict)
|
| 260 |
+
model.eval()
|
| 261 |
+
|
| 262 |
+
# load image processor
|
| 263 |
+
image_processor = RTDetrImageProcessor()
|
| 264 |
+
|
| 265 |
+
# prepare image
|
| 266 |
+
img = prepare_img()
|
| 267 |
+
|
| 268 |
+
# preprocess image
|
| 269 |
+
transformations = transforms.Compose(
|
| 270 |
+
[
|
| 271 |
+
transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR),
|
| 272 |
+
transforms.ToTensor(),
|
| 273 |
+
]
|
| 274 |
+
)
|
| 275 |
+
original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension
|
| 276 |
+
|
| 277 |
+
encoding = image_processor(images=img, return_tensors="pt")
|
| 278 |
+
pixel_values = encoding["pixel_values"]
|
| 279 |
+
|
| 280 |
+
assert torch.allclose(original_pixel_values, pixel_values)
|
| 281 |
+
|
| 282 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 283 |
+
model.to(device)
|
| 284 |
+
pixel_values = pixel_values.to(device)
|
| 285 |
+
|
| 286 |
+
# Pass the image through the model
|
| 287 |
+
with torch.no_grad():
|
| 288 |
+
outputs = model(pixel_values)
|
| 289 |
+
|
| 290 |
+
if model_name == "rtdetr_v2_r18vd":
|
| 291 |
+
expected_slice_logits = torch.tensor(
|
| 292 |
+
[[-3.7045, -5.1913, -6.1787], [-4.0106, -9.3450, -5.2043], [-4.1287, -4.7463, -5.8634]]
|
| 293 |
+
)
|
| 294 |
+
expected_slice_boxes = torch.tensor(
|
| 295 |
+
[[0.2582, 0.5497, 0.4764], [0.1684, 0.1985, 0.2120], [0.7665, 0.4146, 0.4669]]
|
| 296 |
+
)
|
| 297 |
+
elif model_name == "rtdetr_v2_r34vd":
|
| 298 |
+
expected_slice_logits = torch.tensor(
|
| 299 |
+
[[-4.6108, -5.9453, -3.8505], [-3.8702, -6.1136, -5.5677], [-3.7790, -6.4538, -5.9449]]
|
| 300 |
+
)
|
| 301 |
+
expected_slice_boxes = torch.tensor(
|
| 302 |
+
[[0.1691, 0.1984, 0.2118], [0.2594, 0.5506, 0.4736], [0.7669, 0.4136, 0.4654]]
|
| 303 |
+
)
|
| 304 |
+
elif model_name == "rtdetr_v2_r50vd":
|
| 305 |
+
expected_slice_logits = torch.tensor(
|
| 306 |
+
[[-4.7881, -4.6754, -6.1624], [-5.4441, -6.6486, -4.3840], [-3.5455, -4.9318, -6.3544]]
|
| 307 |
+
)
|
| 308 |
+
expected_slice_boxes = torch.tensor(
|
| 309 |
+
[[0.2588, 0.5487, 0.4747], [0.5497, 0.2760, 0.0573], [0.7688, 0.4133, 0.4634]]
|
| 310 |
+
)
|
| 311 |
+
elif model_name == "rtdetr_v2_r101vd":
|
| 312 |
+
expected_slice_logits = torch.tensor(
|
| 313 |
+
[[-4.6162, -4.9189, -4.6656], [-4.4701, -4.4997, -4.9659], [-5.6641, -7.9000, -5.0725]]
|
| 314 |
+
)
|
| 315 |
+
expected_slice_boxes = torch.tensor(
|
| 316 |
+
[[0.7707, 0.4124, 0.4585], [0.2589, 0.5492, 0.4735], [0.1688, 0.1993, 0.2108]]
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
raise ValueError(f"Unknown rt_detr_v2_name: {model_name}")
|
| 320 |
+
assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4)
|
| 321 |
+
assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3)
|
| 322 |
+
|
| 323 |
+
if output_dir is not None:
|
| 324 |
+
Path(output_dir).mkdir(exist_ok=True)
|
| 325 |
+
print(f"Saving model {model_name} to {output_dir}")
|
| 326 |
+
model.save_pretrained(output_dir)
|
| 327 |
+
print(f"Saving image processor to {output_dir}")
|
| 328 |
+
image_processor.save_pretrained(output_dir)
|
| 329 |
+
|
| 330 |
+
if push_to_hub:
|
| 331 |
+
# Upload model, image processor and config to the hub
|
| 332 |
+
logger.info("Uploading PyTorch model and image processor to the hub...")
|
| 333 |
+
config.push_to_hub(
|
| 334 |
+
repo_id=repo_id,
|
| 335 |
+
commit_message="Add config from convert_rt_detr_v2_original_pytorch_checkpoint_to_pytorch.py",
|
| 336 |
+
)
|
| 337 |
+
model.push_to_hub(
|
| 338 |
+
repo_id=repo_id,
|
| 339 |
+
commit_message="Add model from convert_rt_detr_v2_original_pytorch_checkpoint_to_pytorch.py",
|
| 340 |
+
)
|
| 341 |
+
image_processor.push_to_hub(
|
| 342 |
+
repo_id=repo_id,
|
| 343 |
+
commit_message="Add image processor from convert_rt_detr_v2_original_pytorch_checkpoint_to_pytorch.py",
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="rtdetr_v2_r18vd",
        type=str,
        help="model_name of the checkpoint you'd like to convert.",
    )
    parser.add_argument("--output_dir", default=None, type=str, help="Location to write HF model and image processor")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
    parser.add_argument(
        "--repo_id",
        type=str,
        help="repo_id where the model will be pushed to.",
    )
    args = parser.parse_args()
    write_model_and_image_processor(args.model_name, args.output_dir, args.push_to_hub, args.repo_id)
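
# Hypothetical command-line invocation (output directory and repo id are placeholders, not from the source):
#   python convert_rt_detr_v2_weights_to_hf.py --model_name rtdetr_v2_r50vd \
#       --output_dir ./rtdetr_v2_r50vd_converted --push_to_hub --repo_id <your-username>/rtdetr_v2_r50vd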
|
docs/transformers/build/lib/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
ADDED
|
@@ -0,0 +1,628 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import warnings
|
| 16 |
+
from functools import partial
|
| 17 |
+
from typing import List, Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from torch import Tensor, nn
|
| 22 |
+
|
| 23 |
+
from ...configuration_utils import PretrainedConfig
|
| 24 |
+
from ...utils import is_torchdynamo_compiling, logging
|
| 25 |
+
from ...utils.backbone_utils import (
|
| 26 |
+
verify_backbone_config_arguments,
|
| 27 |
+
)
|
| 28 |
+
from ..auto import CONFIG_MAPPING
|
| 29 |
+
from ..rt_detr.modeling_rt_detr import (
|
| 30 |
+
RTDetrDecoder,
|
| 31 |
+
RTDetrDecoderLayer,
|
| 32 |
+
RTDetrForObjectDetection,
|
| 33 |
+
RTDetrMLPPredictionHead,
|
| 34 |
+
RTDetrModel,
|
| 35 |
+
RTDetrPreTrainedModel,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RTDetrV2Config(PretrainedConfig):
|
| 43 |
+
r"""
|
| 44 |
+
This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a
|
| 45 |
+
RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 46 |
+
with the defaults will yield a similar configuration to that of the RT-DETR architecture.
|
| 47 |
+
|
| 48 |
+
e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)
|
| 49 |
+
|
| 50 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 51 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
initializer_range (`float`, *optional*, defaults to 0.01):
|
| 55 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 56 |
+
initializer_bias_prior_prob (`float`, *optional*):
|
| 57 |
+
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
|
| 58 |
+
If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
|
| 59 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 60 |
+
The epsilon used by the layer normalization layers.
|
| 61 |
+
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 62 |
+
The epsilon used by the batch normalization layers.
|
| 63 |
+
backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
|
| 64 |
+
The configuration of the backbone model.
|
| 65 |
+
backbone (`str`, *optional*):
|
| 66 |
+
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
| 67 |
+
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
| 68 |
+
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
| 69 |
+
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
| 70 |
+
Whether to use pretrained weights for the backbone.
|
| 71 |
+
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
| 72 |
+
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
| 73 |
+
library.
|
| 74 |
+
freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
|
| 75 |
+
Whether to freeze the batch normalization layers in the backbone.
|
| 76 |
+
backbone_kwargs (`dict`, *optional*):
|
| 77 |
+
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
| 78 |
+
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
| 79 |
+
encoder_hidden_dim (`int`, *optional*, defaults to 256):
|
| 80 |
+
Dimension of the layers in hybrid encoder.
|
| 81 |
+
encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
|
| 82 |
+
Multi level features input for encoder.
|
| 83 |
+
feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
|
| 84 |
+
Strides used in each feature map.
|
| 85 |
+
encoder_layers (`int`, *optional*, defaults to 1):
|
| 86 |
+
Total number of layers to be used by the encoder.
|
| 87 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 1024):
|
| 88 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 89 |
+
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 90 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 91 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 92 |
+
The ratio for all dropout layers.
|
| 93 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 94 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 95 |
+
encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
|
| 96 |
+
Indexes of the projected layers to be used in the encoder.
|
| 97 |
+
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
|
| 98 |
+
The temperature parameter used to create the positional encodings.
|
| 99 |
+
encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
|
| 100 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 101 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 102 |
+
activation_function (`str`, *optional*, defaults to `"silu"`):
|
| 103 |
+
The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
|
| 104 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 105 |
+
eval_size (`Tuple[int, int]`, *optional*):
|
| 106 |
+
Height and width used to compute the effective height and width of the position embeddings after taking
|
| 107 |
+
into account the stride.
|
| 108 |
+
normalize_before (`bool`, *optional*, defaults to `False`):
|
| 109 |
+
Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
|
| 110 |
+
feed-forward modules.
|
| 111 |
+
hidden_expansion (`float`, *optional*, defaults to 1.0):
|
| 112 |
+
Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
|
| 113 |
+
d_model (`int`, *optional*, defaults to 256):
|
| 114 |
+
Dimension of the layers excluding the hybrid encoder.
|
| 115 |
+
num_queries (`int`, *optional*, defaults to 300):
|
| 116 |
+
Number of object queries.
|
| 117 |
+
decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
|
| 118 |
+
Multi level features dimension for decoder
|
| 119 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 1024):
|
| 120 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 121 |
+
num_feature_levels (`int`, *optional*, defaults to 3):
|
| 122 |
+
The number of input feature levels.
|
| 123 |
+
decoder_n_points (`int`, *optional*, defaults to 4):
|
| 124 |
+
The number of sampled keys in each feature level for each attention head in the decoder.
|
| 125 |
+
decoder_layers (`int`, *optional*, defaults to 6):
|
| 126 |
+
Number of decoder layers.
|
| 127 |
+
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 128 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 129 |
+
decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
|
| 130 |
+
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
|
| 131 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 132 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 133 |
+
The dropout ratio for the attention probabilities.
|
| 134 |
+
num_denoising (`int`, *optional*, defaults to 100):
|
| 135 |
+
The total number of denoising tasks or queries to be used for contrastive denoising.
|
| 136 |
+
label_noise_ratio (`float`, *optional*, defaults to 0.5):
|
| 137 |
+
The fraction of denoising labels to which random noise should be added.
|
| 138 |
+
box_noise_scale (`float`, *optional*, defaults to 1.0):
|
| 139 |
+
Scale or magnitude of noise to be added to the bounding boxes.
|
| 140 |
+
learn_initial_query (`bool`, *optional*, defaults to `False`):
|
| 141 |
+
Indicates whether the initial query embeddings for the decoder should be learned during training
|
| 142 |
+
anchor_image_size (`Tuple[int, int]`, *optional*):
|
| 143 |
+
Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
|
| 144 |
+
with_box_refine (`bool`, *optional*, defaults to `True`):
|
| 145 |
+
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
|
| 146 |
+
based on the predictions from the previous layer.
|
| 147 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 148 |
+
Whether the architecture has an encoder decoder structure.
|
| 149 |
+
matcher_alpha (`float`, *optional*, defaults to 0.25):
|
| 150 |
+
Parameter alpha used by the Hungarian Matcher.
|
| 151 |
+
matcher_gamma (`float`, *optional*, defaults to 2.0):
|
| 152 |
+
Parameter gamma used by the Hungarian Matcher.
|
| 153 |
+
matcher_class_cost (`float`, *optional*, defaults to 2.0):
|
| 154 |
+
The relative weight of the class loss used by the Hungarian Matcher.
|
| 155 |
+
matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
|
| 156 |
+
The relative weight of the bounding box loss used by the Hungarian Matcher.
|
| 157 |
+
matcher_giou_cost (`float`, *optional*, defaults to 2.0):
|
| 158 |
+
The relative weight of the giou loss used by the Hungarian Matcher.
|
| 159 |
+
use_focal_loss (`bool`, *optional*, defaults to `True`):
|
| 160 |
+
Parameter informing if focal loss should be used.
|
| 161 |
+
auxiliary_loss (`bool`, *optional*, defaults to `True`):
|
| 162 |
+
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 163 |
+
focal_loss_alpha (`float`, *optional*, defaults to 0.75):
|
| 164 |
+
Parameter alpha used to compute the focal loss.
|
| 165 |
+
focal_loss_gamma (`float`, *optional*, defaults to 2.0):
|
| 166 |
+
Parameter gamma used to compute the focal loss.
|
| 167 |
+
weight_loss_vfl (`float`, *optional*, defaults to 1.0):
|
| 168 |
+
Relative weight of the varifocal loss in the object detection loss.
|
| 169 |
+
weight_loss_bbox (`float`, *optional*, defaults to 5.0):
|
| 170 |
+
Relative weight of the L1 bounding box loss in the object detection loss.
|
| 171 |
+
weight_loss_giou (`float`, *optional*, defaults to 2.0):
|
| 172 |
+
Relative weight of the generalized IoU loss in the object detection loss.
|
| 173 |
+
eos_coefficient (`float`, *optional*, defaults to 0.0001):
|
| 174 |
+
Relative classification weight of the 'no-object' class in the object detection loss.
|
| 175 |
+
decoder_n_levels (`int`, *optional*, defaults to 3):
|
| 176 |
+
The number of feature levels used by the decoder.
|
| 177 |
+
decoder_offset_scale (`float`, *optional*, defaults to 0.5):
|
| 178 |
+
Scaling factor applied to the attention offsets in the decoder.
|
| 179 |
+
decoder_method (`str`, *optional*, defaults to `"default"`):
|
| 180 |
+
The method to use for the decoder: `"default"` or `"discrete"`.
|
| 181 |
+
|
| 182 |
+
Examples:
|
| 183 |
+
|
| 184 |
+
```python
|
| 185 |
+
>>> from transformers import RTDetrV2Config, RTDetrV2Model
|
| 186 |
+
|
| 187 |
+
>>> # Initializing a RT-DETR configuration
|
| 188 |
+
>>> configuration = RTDetrV2Config()
|
| 189 |
+
|
| 190 |
+
>>> # Initializing a model (with random weights) from the configuration
|
| 191 |
+
>>> model = RTDetrV2Model(configuration)
|
| 192 |
+
|
| 193 |
+
>>> # Accessing the model configuration
|
| 194 |
+
>>> configuration = model.config
|
| 195 |
+
```
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
model_type = "rt_detr_v2"
|
| 199 |
+
layer_types = ["basic", "bottleneck"]
|
| 200 |
+
attribute_map = {
|
| 201 |
+
"hidden_size": "d_model",
|
| 202 |
+
"num_attention_heads": "encoder_attention_heads",
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
initializer_range=0.01,
|
| 208 |
+
initializer_bias_prior_prob=None,
|
| 209 |
+
layer_norm_eps=1e-5,
|
| 210 |
+
batch_norm_eps=1e-5,
|
| 211 |
+
# backbone
|
| 212 |
+
backbone_config=None,
|
| 213 |
+
backbone=None,
|
| 214 |
+
use_pretrained_backbone=False,
|
| 215 |
+
use_timm_backbone=False,
|
| 216 |
+
freeze_backbone_batch_norms=True,
|
| 217 |
+
backbone_kwargs=None,
|
| 218 |
+
# encoder HybridEncoder
|
| 219 |
+
encoder_hidden_dim=256,
|
| 220 |
+
encoder_in_channels=[512, 1024, 2048],
|
| 221 |
+
feat_strides=[8, 16, 32],
|
| 222 |
+
encoder_layers=1,
|
| 223 |
+
encoder_ffn_dim=1024,
|
| 224 |
+
encoder_attention_heads=8,
|
| 225 |
+
dropout=0.0,
|
| 226 |
+
activation_dropout=0.0,
|
| 227 |
+
encode_proj_layers=[2],
|
| 228 |
+
positional_encoding_temperature=10000,
|
| 229 |
+
encoder_activation_function="gelu",
|
| 230 |
+
activation_function="silu",
|
| 231 |
+
eval_size=None,
|
| 232 |
+
normalize_before=False,
|
| 233 |
+
hidden_expansion=1.0,
|
| 234 |
+
# decoder RTDetrV2Transformer
|
| 235 |
+
d_model=256,
|
| 236 |
+
num_queries=300,
|
| 237 |
+
decoder_in_channels=[256, 256, 256],
|
| 238 |
+
decoder_ffn_dim=1024,
|
| 239 |
+
num_feature_levels=3,
|
| 240 |
+
decoder_n_points=4,
|
| 241 |
+
decoder_layers=6,
|
| 242 |
+
decoder_attention_heads=8,
|
| 243 |
+
decoder_activation_function="relu",
|
| 244 |
+
attention_dropout=0.0,
|
| 245 |
+
num_denoising=100,
|
| 246 |
+
label_noise_ratio=0.5,
|
| 247 |
+
box_noise_scale=1.0,
|
| 248 |
+
learn_initial_query=False,
|
| 249 |
+
anchor_image_size=None,
|
| 250 |
+
with_box_refine=True,
|
| 251 |
+
is_encoder_decoder=True,
|
| 252 |
+
# Loss
|
| 253 |
+
matcher_alpha=0.25,
|
| 254 |
+
matcher_gamma=2.0,
|
| 255 |
+
matcher_class_cost=2.0,
|
| 256 |
+
matcher_bbox_cost=5.0,
|
| 257 |
+
matcher_giou_cost=2.0,
|
| 258 |
+
use_focal_loss=True,
|
| 259 |
+
auxiliary_loss=True,
|
| 260 |
+
focal_loss_alpha=0.75,
|
| 261 |
+
focal_loss_gamma=2.0,
|
| 262 |
+
weight_loss_vfl=1.0,
|
| 263 |
+
weight_loss_bbox=5.0,
|
| 264 |
+
weight_loss_giou=2.0,
|
| 265 |
+
eos_coefficient=1e-4,
|
| 266 |
+
decoder_n_levels=3, # default value
|
| 267 |
+
decoder_offset_scale=0.5, # default value
|
| 268 |
+
decoder_method="default",
|
| 269 |
+
**kwargs,
|
| 270 |
+
):
|
| 271 |
+
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
| 272 |
+
self.initializer_range = initializer_range
|
| 273 |
+
self.initializer_bias_prior_prob = initializer_bias_prior_prob
|
| 274 |
+
self.layer_norm_eps = layer_norm_eps
|
| 275 |
+
self.batch_norm_eps = batch_norm_eps
|
| 276 |
+
# backbone
|
| 277 |
+
if backbone_config is None and backbone is None:
|
| 278 |
+
logger.info(
|
| 279 |
+
"`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone."
|
| 280 |
+
)
|
| 281 |
+
backbone_model_type = "rt_detr_resnet"
|
| 282 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 283 |
+
# this will map it to RTDetrResNetConfig
|
| 284 |
+
# note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1
|
| 285 |
+
# and we would need to create RTDetrV2ResNetModel
|
| 286 |
+
backbone_config = config_class(
|
| 287 |
+
num_channels=3,
|
| 288 |
+
embedding_size=64,
|
| 289 |
+
hidden_sizes=[256, 512, 1024, 2048],
|
| 290 |
+
depths=[3, 4, 6, 3],
|
| 291 |
+
layer_type="bottleneck",
|
| 292 |
+
hidden_act="relu",
|
| 293 |
+
downsample_in_first_stage=False,
|
| 294 |
+
downsample_in_bottleneck=False,
|
| 295 |
+
out_features=None,
|
| 296 |
+
out_indices=[2, 3, 4],
|
| 297 |
+
)
|
| 298 |
+
elif isinstance(backbone_config, dict):
|
| 299 |
+
backbone_model_type = backbone_config.pop("model_type")
|
| 300 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 301 |
+
backbone_config = config_class.from_dict(backbone_config)
|
| 302 |
+
|
| 303 |
+
verify_backbone_config_arguments(
|
| 304 |
+
use_timm_backbone=use_timm_backbone,
|
| 305 |
+
use_pretrained_backbone=use_pretrained_backbone,
|
| 306 |
+
backbone=backbone,
|
| 307 |
+
backbone_config=backbone_config,
|
| 308 |
+
backbone_kwargs=backbone_kwargs,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
self.backbone_config = backbone_config
|
| 312 |
+
self.backbone = backbone
|
| 313 |
+
self.use_pretrained_backbone = use_pretrained_backbone
|
| 314 |
+
self.use_timm_backbone = use_timm_backbone
|
| 315 |
+
self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
|
| 316 |
+
self.backbone_kwargs = backbone_kwargs
|
| 317 |
+
# encoder
|
| 318 |
+
self.encoder_hidden_dim = encoder_hidden_dim
|
| 319 |
+
self.encoder_in_channels = encoder_in_channels
|
| 320 |
+
self.feat_strides = feat_strides
|
| 321 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 322 |
+
self.dropout = dropout
|
| 323 |
+
self.activation_dropout = activation_dropout
|
| 324 |
+
self.encode_proj_layers = encode_proj_layers
|
| 325 |
+
self.encoder_layers = encoder_layers
|
| 326 |
+
self.positional_encoding_temperature = positional_encoding_temperature
|
| 327 |
+
self.eval_size = eval_size
|
| 328 |
+
self.normalize_before = normalize_before
|
| 329 |
+
self.encoder_activation_function = encoder_activation_function
|
| 330 |
+
self.activation_function = activation_function
|
| 331 |
+
self.hidden_expansion = hidden_expansion
|
| 332 |
+
self.num_queries = num_queries
|
| 333 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 334 |
+
self.decoder_in_channels = decoder_in_channels
|
| 335 |
+
self.num_feature_levels = num_feature_levels
|
| 336 |
+
self.decoder_n_points = decoder_n_points
|
| 337 |
+
self.decoder_layers = decoder_layers
|
| 338 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 339 |
+
self.decoder_activation_function = decoder_activation_function
|
| 340 |
+
self.attention_dropout = attention_dropout
|
| 341 |
+
self.num_denoising = num_denoising
|
| 342 |
+
self.label_noise_ratio = label_noise_ratio
|
| 343 |
+
self.box_noise_scale = box_noise_scale
|
| 344 |
+
self.learn_initial_query = learn_initial_query
|
| 345 |
+
self.anchor_image_size = anchor_image_size
|
| 346 |
+
self.auxiliary_loss = auxiliary_loss
|
| 347 |
+
self.with_box_refine = with_box_refine
|
| 348 |
+
# Loss
|
| 349 |
+
self.matcher_alpha = matcher_alpha
|
| 350 |
+
self.matcher_gamma = matcher_gamma
|
| 351 |
+
self.matcher_class_cost = matcher_class_cost
|
| 352 |
+
self.matcher_bbox_cost = matcher_bbox_cost
|
| 353 |
+
self.matcher_giou_cost = matcher_giou_cost
|
| 354 |
+
self.use_focal_loss = use_focal_loss
|
| 355 |
+
self.focal_loss_alpha = focal_loss_alpha
|
| 356 |
+
self.focal_loss_gamma = focal_loss_gamma
|
| 357 |
+
self.weight_loss_vfl = weight_loss_vfl
|
| 358 |
+
self.weight_loss_bbox = weight_loss_bbox
|
| 359 |
+
self.weight_loss_giou = weight_loss_giou
|
| 360 |
+
self.eos_coefficient = eos_coefficient
|
| 361 |
+
|
| 362 |
+
if not hasattr(self, "d_model"):
|
| 363 |
+
self.d_model = d_model
|
| 364 |
+
|
| 365 |
+
if not hasattr(self, "encoder_attention_heads"):
|
| 366 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 367 |
+
# add the new attributes with the given values or defaults
|
| 368 |
+
self.decoder_n_levels = decoder_n_levels
|
| 369 |
+
self.decoder_offset_scale = decoder_offset_scale
|
| 370 |
+
self.decoder_method = decoder_method
|
| 371 |
+
|
| 372 |
+
@classmethod
|
| 373 |
+
def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
|
| 374 |
+
"""Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
| 375 |
+
configuration.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
backbone_config ([`PretrainedConfig`]):
|
| 379 |
+
The backbone configuration.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
[`RTDetrV2Config`]: An instance of a configuration object
|
| 383 |
+
"""
|
| 384 |
+
return cls(
|
| 385 |
+
backbone_config=backbone_config,
|
| 386 |
+
**kwargs,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def multi_scale_deformable_attention_v2(
|
| 391 |
+
value: Tensor,
|
| 392 |
+
value_spatial_shapes: Tensor,
|
| 393 |
+
sampling_locations: Tensor,
|
| 394 |
+
attention_weights: Tensor,
|
| 395 |
+
num_points_list: List[int],
|
| 396 |
+
method="default",
|
| 397 |
+
) -> Tensor:
|
| 398 |
+
batch_size, _, num_heads, hidden_dim = value.shape
|
| 399 |
+
_, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
|
| 400 |
+
value_list = (
|
| 401 |
+
value.permute(0, 2, 3, 1)
|
| 402 |
+
.flatten(0, 1)
|
| 403 |
+
.split([height * width for height, width in value_spatial_shapes], dim=-1)
|
| 404 |
+
)
|
| 405 |
+
# sampling_offsets [8, 480, 8, 12, 2]
|
| 406 |
+
if method == "default":
|
| 407 |
+
sampling_grids = 2 * sampling_locations - 1
|
| 408 |
+
elif method == "discrete":
|
| 409 |
+
sampling_grids = sampling_locations
|
| 410 |
+
sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
| 411 |
+
sampling_grids = sampling_grids.split(num_points_list, dim=-2)
|
| 412 |
+
sampling_value_list = []
|
| 413 |
+
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
| 414 |
+
# batch_size, height*width, num_heads, hidden_dim
|
| 415 |
+
# -> batch_size, height*width, num_heads*hidden_dim
|
| 416 |
+
# -> batch_size, num_heads*hidden_dim, height*width
|
| 417 |
+
# -> batch_size*num_heads, hidden_dim, height, width
|
| 418 |
+
value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
|
| 419 |
+
# batch_size, num_queries, num_heads, num_points, 2
|
| 420 |
+
# -> batch_size, num_heads, num_queries, num_points, 2
|
| 421 |
+
# -> batch_size*num_heads, num_queries, num_points, 2
|
| 422 |
+
sampling_grid_l_ = sampling_grids[level_id]
|
| 423 |
+
# batch_size*num_heads, hidden_dim, num_queries, num_points
|
| 424 |
+
if method == "default":
|
| 425 |
+
sampling_value_l_ = nn.functional.grid_sample(
|
| 426 |
+
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
|
| 427 |
+
)
|
| 428 |
+
elif method == "discrete":
|
| 429 |
+
sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
|
| 430 |
+
torch.int64
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Separate clamping for x and y coordinates
|
| 434 |
+
sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
|
| 435 |
+
sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)
|
| 436 |
+
|
| 437 |
+
# Combine the clamped coordinates
|
| 438 |
+
sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
|
| 439 |
+
sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
|
| 440 |
+
sampling_idx = (
|
| 441 |
+
torch.arange(sampling_coord.shape[0], device=value.device)
|
| 442 |
+
.unsqueeze(-1)
|
| 443 |
+
.repeat(1, sampling_coord.shape[1])
|
| 444 |
+
)
|
| 445 |
+
sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
|
| 446 |
+
sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
|
| 447 |
+
batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
|
| 448 |
+
)
|
| 449 |
+
sampling_value_list.append(sampling_value_l_)
|
| 450 |
+
# (batch_size, num_queries, num_heads, num_levels, num_points)
|
| 451 |
+
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
|
| 452 |
+
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
|
| 453 |
+
attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
|
| 454 |
+
batch_size * num_heads, 1, num_queries, sum(num_points_list)
|
| 455 |
+
)
|
| 456 |
+
output = (
|
| 457 |
+
(torch.concat(sampling_value_list, dim=-1) * attention_weights)
|
| 458 |
+
.sum(-1)
|
| 459 |
+
.view(batch_size, num_heads * hidden_dim, num_queries)
|
| 460 |
+
)
|
| 461 |
+
return output.transpose(1, 2).contiguous()
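
# Illustrative shape check for multi_scale_deformable_attention_v2 with toy sizes (a sketch, not from the
# source): two feature levels of 4x4 and 2x2, 2 heads with 8 dims per head, 3 queries, 4 points per level.
_batch, _heads, _head_dim, _queries, _points = 1, 2, 8, 3, 4
_spatial_shapes_list = [(4, 4), (2, 2)]
_seq_len = sum(h * w for h, w in _spatial_shapes_list)  # 16 + 4 = 20 encoder positions
_value = torch.randn(_batch, _seq_len, _heads, _head_dim)
_num_points_list = [_points] * len(_spatial_shapes_list)
_sampling_locations = torch.rand(_batch, _queries, _heads, sum(_num_points_list), 2)
_attention_weights = torch.softmax(torch.randn(_batch, _queries, _heads, sum(_num_points_list)), dim=-1)
_out = multi_scale_deformable_attention_v2(
    _value, _spatial_shapes_list, _sampling_locations, _attention_weights, _num_points_list
)
assert _out.shape == (_batch, _queries, _heads * _head_dim)  # (1, 3, 16)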
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# the main change
|
| 465 |
+
class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
| 466 |
+
"""
|
| 467 |
+
RTDetrV2 version of multiscale deformable attention, extending the base implementation
|
| 468 |
+
with improved offset handling and initialization.
|
| 469 |
+
"""
|
| 470 |
+
|
| 471 |
+
def __init__(self, config: RTDetrV2Config):
|
| 472 |
+
super().__init__()
|
| 473 |
+
num_heads = config.decoder_attention_heads
|
| 474 |
+
n_points = config.decoder_n_points
|
| 475 |
+
|
| 476 |
+
if config.d_model % num_heads != 0:
|
| 477 |
+
raise ValueError(
|
| 478 |
+
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
|
| 479 |
+
)
|
| 480 |
+
dim_per_head = config.d_model // num_heads
|
| 481 |
+
# check if dim_per_head is power of 2
|
| 482 |
+
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
|
| 483 |
+
warnings.warn(
|
| 484 |
+
"You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the"
|
| 485 |
+
" dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
|
| 486 |
+
" implementation."
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
self.im2col_step = 64
|
| 490 |
+
|
| 491 |
+
self.d_model = config.d_model
|
| 492 |
+
|
| 493 |
+
# V2-specific attributes
|
| 494 |
+
self.n_levels = config.decoder_n_levels
|
| 495 |
+
self.n_heads = num_heads
|
| 496 |
+
self.n_points = n_points
|
| 497 |
+
|
| 498 |
+
self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
|
| 499 |
+
self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
|
| 500 |
+
self.value_proj = nn.Linear(config.d_model, config.d_model)
|
| 501 |
+
self.output_proj = nn.Linear(config.d_model, config.d_model)
|
| 502 |
+
|
| 503 |
+
self.offset_scale = config.decoder_offset_scale
|
| 504 |
+
self.method = config.decoder_method
|
| 505 |
+
|
| 506 |
+
# Initialize n_points list and scale
|
| 507 |
+
n_points_list = [self.n_points for _ in range(self.n_levels)]
|
| 508 |
+
self.n_points_list = n_points_list
|
| 509 |
+
n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
|
| 510 |
+
self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))
|
| 511 |
+
|
| 512 |
+
def forward(
|
| 513 |
+
self,
|
| 514 |
+
hidden_states: torch.Tensor,
|
| 515 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 516 |
+
encoder_hidden_states=None,
|
| 517 |
+
encoder_attention_mask=None,
|
| 518 |
+
position_embeddings: Optional[torch.Tensor] = None,
|
| 519 |
+
reference_points=None,
|
| 520 |
+
spatial_shapes=None,
|
| 521 |
+
spatial_shapes_list=None,
|
| 522 |
+
level_start_index=None,
|
| 523 |
+
output_attentions: bool = False,
|
| 524 |
+
):
|
| 525 |
+
# Process inputs up to sampling locations calculation using parent class logic
|
| 526 |
+
if position_embeddings is not None:
|
| 527 |
+
hidden_states = hidden_states + position_embeddings
|
| 528 |
+
|
| 529 |
+
batch_size, num_queries, _ = hidden_states.shape
|
| 530 |
+
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
| 531 |
+
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
| 532 |
+
raise ValueError(
|
| 533 |
+
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
value = self.value_proj(encoder_hidden_states)
|
| 537 |
+
if attention_mask is not None:
|
| 538 |
+
value = value.masked_fill(~attention_mask[..., None], float(0))
|
| 539 |
+
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
|
| 540 |
+
|
| 541 |
+
# V2-specific sampling offsets shape
|
| 542 |
+
sampling_offsets = self.sampling_offsets(hidden_states).view(
|
| 543 |
+
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
attention_weights = self.attention_weights(hidden_states).view(
|
| 547 |
+
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
|
| 548 |
+
)
|
| 549 |
+
attention_weights = F.softmax(attention_weights, -1)
|
| 550 |
+
|
| 551 |
+
# V2-specific sampling locations calculation
|
| 552 |
+
if reference_points.shape[-1] == 2:
|
| 553 |
+
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
|
| 554 |
+
sampling_locations = (
|
| 555 |
+
reference_points[:, :, None, :, None, :]
|
| 556 |
+
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
| 557 |
+
)
|
| 558 |
+
elif reference_points.shape[-1] == 4:
|
| 559 |
+
n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
|
| 560 |
+
offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
|
| 561 |
+
sampling_locations = reference_points[:, :, None, :, :2] + offset
|
| 562 |
+
else:
|
| 563 |
+
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
| 564 |
+
|
| 565 |
+
# V2-specific attention implementation choice
|
| 566 |
+
output = multi_scale_deformable_attention_v2(
|
| 567 |
+
value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
output = self.output_proj(output)
|
| 571 |
+
return output, attention_weights
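
# Hypothetical instantiation of the module above using the default RTDetrV2Config defined in this file
# (a sketch): with 8 heads, 3 levels and 4 points, sampling_offsets emits 8 * 3 * 4 * 2 = 192 values per query.
_cfg = RTDetrV2Config()
_attn = RTDetrV2MultiscaleDeformableAttention(_cfg)
assert _attn.sampling_offsets.out_features == (
    _cfg.decoder_attention_heads * _cfg.decoder_n_levels * _cfg.decoder_n_points * 2
)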
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
|
| 575 |
+
def __init__(self, config: RTDetrV2Config):
|
| 576 |
+
# initialize parent class
|
| 577 |
+
super().__init__(config)
|
| 578 |
+
# override only the encoder attention module with v2 version
|
| 579 |
+
self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
|
| 583 |
+
pass
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class RTDetrV2Decoder(RTDetrDecoder):
|
| 587 |
+
def __init__(self, config: RTDetrV2Config):
|
| 588 |
+
super().__init__(config)
|
| 589 |
+
self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class RTDetrV2Model(RTDetrModel):
|
| 593 |
+
def __init__(self, config: RTDetrV2Config):
|
| 594 |
+
super().__init__(config)
|
| 595 |
+
# decoder
|
| 596 |
+
self.decoder = RTDetrV2Decoder(config)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead):
|
| 600 |
+
pass
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class RTDetrV2ForObjectDetection(RTDetrForObjectDetection, RTDetrV2PreTrainedModel):
|
| 604 |
+
def __init__(self, config: RTDetrV2Config):
|
| 605 |
+
RTDetrV2PreTrainedModel.__init__(config)
|
| 606 |
+
# RTDETR encoder-decoder model
|
| 607 |
+
self.model = RTDetrV2Model(config)
|
| 608 |
+
|
| 609 |
+
# Detection heads on top
|
| 610 |
+
class_embed = partial(nn.Linear, config.d_model, config.num_labels)
|
| 611 |
+
bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
|
| 612 |
+
|
| 613 |
+
self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)])
|
| 614 |
+
self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)])
|
| 615 |
+
|
| 616 |
+
self.model.decoder.class_embed = self.class_embed
|
| 617 |
+
self.model.decoder.bbox_embed = self.bbox_embed
|
| 618 |
+
|
| 619 |
+
# Initialize weights and apply final processing
|
| 620 |
+
self.post_init()
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
__all__ = [
|
| 624 |
+
"RTDetrV2Config",
|
| 625 |
+
"RTDetrV2Model",
|
| 626 |
+
"RTDetrV2PreTrainedModel",
|
| 627 |
+
"RTDetrV2ForObjectDetection",
|
| 628 |
+
]
|
docs/transformers/build/lib/transformers/models/rwkv/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_rwkv import *
    from .modeling_rwkv import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
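
# Hypothetical usage of the lazy module set up above; the symbol is resolved on first attribute access
# (a sketch, assuming the package is importable as transformers.models.rwkv):
#
#     from transformers.models.rwkv import RwkvConfig
#     config = RwkvConfig()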
docs/transformers/build/lib/transformers/models/rwkv/configuration_rwkv.py
ADDED
|
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RWKV configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class RwkvConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+    the defaults will yield a similar configuration to that of the RWKV-4
+    [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50277):
+            Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`RwkvModel`].
+        context_length (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model can be used with in a single forward pass (using it in RNN
+            mode lets you use any sequence length).
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the embeddings and hidden states.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the model.
+        attention_hidden_size (`int`, *optional*):
+            Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
+        intermediate_size (`int`, *optional*):
+            Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
+            as GPTNeoX.
+        eos_token_id (`int`, *optional*, defaults to 0):
+            The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
+            GPTNeoX.
+        rescale_every (`int`, *optional*, defaults to 6):
+            At inference, the hidden states (and weights of the corresponding output layers) are divided by 2 every
+            `rescale_every` layer. If set to 0 or a negative number, no rescale is done.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether or not to tie the word embeddings with the input token embeddings.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last state.
+
+    Example:
+
+    ```python
+    >>> from transformers import RwkvConfig, RwkvModel
+
+    >>> # Initializing a Rwkv configuration
+    >>> configuration = RwkvConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = RwkvModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "rwkv"
+    attribute_map = {"max_position_embeddings": "context_length"}
+
+    def __init__(
+        self,
+        vocab_size=50277,
+        context_length=1024,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        attention_hidden_size=None,
+        intermediate_size=None,
+        layer_norm_epsilon=1e-5,
+        bos_token_id=0,
+        eos_token_id=0,
+        rescale_every=6,
+        tie_word_embeddings=False,
+        use_cache=True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.context_length = context_length
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.rescale_every = rescale_every
+        self.use_cache = use_cache
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
+        )
+
+
+__all__ = ["RwkvConfig"]
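Two of the arguments above are derived rather than stored as given: `attention_hidden_size` falls back to `hidden_size`, and `intermediate_size` falls back to `4 * hidden_size`; `attribute_map` additionally exposes `context_length` under the common `max_position_embeddings` name. A small usage sketch (argument values chosen arbitrarily for illustration):

```python
# Sketch: check the derived defaults and the attribute_map alias of RwkvConfig.
from transformers import RwkvConfig

config = RwkvConfig(hidden_size=768, num_hidden_layers=12, context_length=2048)

assert config.attention_hidden_size == 768        # defaults to hidden_size
assert config.intermediate_size == 4 * 768        # defaults to 4 * hidden_size
print(config.max_position_embeddings)             # 2048, aliased to context_length
```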