Student0809 committed on
Commit 3b47bbc · verified · 1 Parent(s): 33b613a

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .dev_scripts/ci_container_test.sh +41 -0
  2. 4JOB_TRAIN.jsonl +0 -0
  3. MANIFEST.in +5 -0
  4. Makefile +25 -0
  5. README.md +423 -0
  6. README_CN.md +413 -0
  7. checkMissing.py +86 -0
  8. clean_transcripts.py +95 -0
  9. count_audios.py +69 -0
  10. count_folders-Copy1.py +122 -0
  11. count_folders.py +122 -0
  12. dialogue_length_distribution.png +0 -0
  13. dialogue_length_ranges.png +0 -0
  14. docs/transformers/build/lib/transformers/models/sam/processing_sam.py +311 -0
  15. docs/transformers/build/lib/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py +396 -0
  16. docs/transformers/build/lib/transformers/models/seamless_m4t/modeling_seamless_m4t.py +0 -0
  17. docs/transformers/build/lib/transformers/models/seamless_m4t/processing_seamless_m4t.py +120 -0
  18. docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py +450 -0
  19. docs/transformers/build/lib/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py +404 -0
  20. docs/transformers/build/lib/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +0 -0
  21. docs/transformers/build/lib/transformers/models/segformer/__init__.py +30 -0
  22. docs/transformers/build/lib/transformers/models/segformer/configuration_segformer.py +171 -0
  23. docs/transformers/build/lib/transformers/models/segformer/convert_segformer_original_to_pytorch.py +387 -0
  24. docs/transformers/build/lib/transformers/models/segformer/feature_extraction_segformer.py +38 -0
  25. docs/transformers/build/lib/transformers/models/segformer/image_processing_segformer.py +484 -0
  26. docs/transformers/build/lib/transformers/models/segformer/modeling_segformer.py +840 -0
  27. docs/transformers/build/lib/transformers/models/segformer/modeling_tf_segformer.py +1045 -0
  28. docs/transformers/build/lib/transformers/models/seggpt/__init__.py +28 -0
  29. docs/transformers/build/lib/transformers/models/seggpt/configuration_seggpt.py +143 -0
  30. docs/transformers/build/lib/transformers/models/seggpt/convert_seggpt_to_hf.py +221 -0
  31. docs/transformers/build/lib/transformers/models/seggpt/image_processing_seggpt.py +618 -0
  32. docs/transformers/build/lib/transformers/models/seggpt/modeling_seggpt.py +1031 -0
  33. docs/transformers/build/lib/transformers/models/sew/__init__.py +27 -0
  34. docs/transformers/build/lib/transformers/models/sew/configuration_sew.py +256 -0
  35. docs/transformers/build/lib/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py +305 -0
  36. docs/transformers/build/lib/transformers/models/sew/modeling_sew.py +1498 -0
  37. docs/transformers/build/lib/transformers/models/sew_d/__init__.py +27 -0
  38. docs/transformers/build/lib/transformers/models/sew_d/configuration_sew_d.py +291 -0
  39. docs/transformers/build/lib/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py +317 -0
  40. docs/transformers/build/lib/transformers/models/sew_d/modeling_sew_d.py +1748 -0
  41. docs/transformers/build/lib/transformers/models/shieldgemma2/__init__.py +28 -0
  42. docs/transformers/build/lib/transformers/models/shieldgemma2/configuration_shieldgemma2.py +120 -0
  43. docs/transformers/build/lib/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py +470 -0
  44. docs/transformers/build/lib/transformers/models/shieldgemma2/modeling_shieldgemma2.py +220 -0
  45. docs/transformers/build/lib/transformers/models/shieldgemma2/processing_shieldgemma2.py +195 -0
  46. docs/transformers/build/lib/transformers/models/siglip/__init__.py +31 -0
  47. docs/transformers/build/lib/transformers/models/siglip/configuration_siglip.py +269 -0
  48. docs/transformers/build/lib/transformers/models/siglip/convert_siglip_to_hf.py +533 -0
  49. docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip.py +244 -0
  50. docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip_fast.py +41 -0
.dev_scripts/ci_container_test.sh ADDED
@@ -0,0 +1,41 @@
+ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
+     # pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+     pip install -r requirements/tests.txt -i https://mirrors.aliyun.com/pypi/simple/
+     git config --global --add safe.directory /ms-swift
+     git config --global user.email tmp
+     git config --global user.name tmp.com
+
+     # linter test
+     # use the internal project for pre-commit due to the network problem
+     if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then
+         pre-commit run -c .pre-commit-config_local.yaml --all-files
+         if [ $? -ne 0 ]; then
+             echo "Linter test failed; please run 'pre-commit run --all-files' to check."
+             echo "From the repository folder:"
+             echo "Run 'pip install -r requirements/tests.txt' to install the test dependencies."
+             echo "Run 'pre-commit install' to install the pre-commit hooks."
+             echo "Finally, run the linter with 'pre-commit run --all-files'."
+             echo "Ensure there is no failure!"
+             exit 1
+         fi
+     fi
+
+     pip install -r requirements/framework.txt -U -i https://mirrors.aliyun.com/pypi/simple/
+     pip install diffusers decord einops -U -i https://mirrors.aliyun.com/pypi/simple/
+     pip install autoawq -U --no-deps
+
+     # test with install
+     pip install .
+     pip install auto_gptq bitsandbytes deepspeed -U -i https://mirrors.aliyun.com/pypi/simple/
+ else
+     echo "Running case in release image, run case directly!"
+ fi
+ # Remove the torch_extensions folder to avoid CI hangs.
+ rm -rf ~/.cache/torch_extensions
+ if [ $# -eq 0 ]; then
+     ci_command="python tests/run.py --subprocess"
+ else
+     ci_command="$@"
+ fi
+ echo "Running case with command: $ci_command"
+ $ci_command
4JOB_TRAIN.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
MANIFEST.in ADDED
@@ -0,0 +1,5 @@
+ recursive-include swift/utils *.py
+ recursive-include swift/llm/dataset/data *.*
+ recursive-include swift/llm/ds_config *.json
+ recursive-include requirements *.txt
+ recursive-include swift/plugin/loss_scale/config *.json
Makefile ADDED
@@ -0,0 +1,25 @@
+ WHL_BUILD_DIR :=package
+ DOC_BUILD_DIR :=docs/build/
+
+ # default rule
+ default: whl docs
+
+ .PHONY: docs
+ docs:
+ 	bash .dev_scripts/build_docs.sh
+
+ .PHONY: linter
+ linter:
+ 	bash .dev_scripts/linter.sh
+
+ .PHONY: test
+ test:
+ 	bash .dev_scripts/citest.sh
+
+ .PHONY: whl
+ whl:
+ 	python setup.py sdist bdist_wheel
+
+ .PHONY: clean
+ clean:
+ 	rm -rf $(WHL_BUILD_DIR) $(DOC_BUILD_DIR)
README.md ADDED
@@ -0,0 +1,423 @@
1
+ # SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
2
+
3
+ <p align="center">
4
+ <br>
5
+ <img src="asset/banner.png"/>
6
+ <br>
7
+ <p>
8
+ <p align="center">
9
+ <a href="https://modelscope.cn/home">ModelScope Community Website</a>
10
+ <br>
11
+ <a href="README_CN.md">中文</a> &nbsp | &nbsp English &nbsp
12
+ </p>
13
+
14
+ <p align="center">
15
+ <img src="https://img.shields.io/badge/python-3.10-5be.svg">
16
+ <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
17
+ <a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
18
+ <a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
19
+ <a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
20
+ <a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
21
+ <a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
22
+ </p>
23
+
24
+ <p align="center">
25
+ <a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
26
+ </p>
27
+
28
+ <p align="center">
29
+ <a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
30
+ </p>
31
+
32
+ ## 📖 Table of Contents
33
+ - [Groups](#-Groups)
34
+ - [Introduction](#-introduction)
35
+ - [News](#-news)
36
+ - [Installation](#%EF%B8%8F-installation)
37
+ - [Quick Start](#-quick-Start)
38
+ - [Usage](#-Usage)
39
+ - [License](#-License)
40
+ - [Citation](#-citation)
41
+
42
+
43
+ ## ☎ Groups
44
+
45
+ You can contact us and communicate with us by adding our group:
46
+
47
+
48
+ [Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
49
+ :-------------------------:|:-------------------------:
50
+ <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
51
+
52
+
53
+ ## 📝 Introduction
54
+ 🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.
55
+
56
+ 🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices.
57
+
58
+ **Why choose ms-swift?**
59
+
60
+ - 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
61
+ - **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets.
62
+ - **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
63
+ - 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
64
+ - **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
65
+ - **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
66
+ - **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
67
+ - 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
68
+ - **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
69
+ - **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
70
+ - 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
71
+ - **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules.
72
+ - **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models.
73
+ - **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training.
74
+
75
+
76
+ ## 🎉 News
77
+ - 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models) .
78
+ - 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
79
+ - 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
80
+ - 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
81
+ - 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
82
+ - 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
83
+ - 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
84
+ - 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
85
+ - 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
86
+ - 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
87
+ <details><summary>More</summary>
88
+
89
+ - 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
90
+ - 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
91
+ - 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
92
+ - 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
93
+ - 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
94
+ </details>
95
+
96
+ ## 🛠️ Installation
97
+ To install using pip:
98
+ ```shell
99
+ pip install ms-swift -U
100
+ ```
101
+
102
+ To install from source:
103
+ ```shell
104
+ # pip install git+https://github.com/modelscope/ms-swift.git
105
+
106
+ git clone https://github.com/modelscope/ms-swift.git
107
+ cd ms-swift
108
+ pip install -e .
109
+ ```
110
+
111
+ Running Environment:
112
+
113
+ | | Range | Recommended | Notes |
114
+ | ------------ |--------------| ----------- | ----------------------------------------- |
115
+ | python | >=3.9 | 3.10 | |
116
+ | cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
117
+ | torch | >=2.0 | | |
118
+ | transformers | >=4.33 | 4.51 | |
119
+ | modelscope | >=1.23 | | |
120
+ | peft | >=0.11,<0.16 | ||
121
+ | trl | >=0.13,<0.18 | 0.17 |RLHF|
122
+ | deepspeed | >=0.14 | 0.14.5 | Training |
123
+ | vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation |
124
+ | lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
125
+ | evalscope | >=0.11 | | Evaluation |
126
+
127
+ For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).
128
+
129
+
130
+ ## 🚀 Quick Start
131
+
132
+ 10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:
133
+
134
+ ### Command Line Interface
135
+
136
+ ```shell
137
+ # 22GB
138
+ CUDA_VISIBLE_DEVICES=0 \
139
+ swift sft \
140
+ --model Qwen/Qwen2.5-7B-Instruct \
141
+ --train_type lora \
142
+ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
143
+ 'AI-ModelScope/alpaca-gpt4-data-en#500' \
144
+ 'swift/self-cognition#500' \
145
+ --torch_dtype bfloat16 \
146
+ --num_train_epochs 1 \
147
+ --per_device_train_batch_size 1 \
148
+ --per_device_eval_batch_size 1 \
149
+ --learning_rate 1e-4 \
150
+ --lora_rank 8 \
151
+ --lora_alpha 32 \
152
+ --target_modules all-linear \
153
+ --gradient_accumulation_steps 16 \
154
+ --eval_steps 50 \
155
+ --save_steps 50 \
156
+ --save_total_limit 2 \
157
+ --logging_steps 5 \
158
+ --max_length 2048 \
159
+ --output_dir output \
160
+ --system 'You are a helpful assistant.' \
161
+ --warmup_ratio 0.05 \
162
+ --dataloader_num_workers 4 \
163
+ --model_author swift \
164
+ --model_name swift-robot
165
+ ```
166
+
167
+ Tips:
168
+
169
+ - If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>`.
170
+ - The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
171
+ - To train with a different model, simply modify `--model <model_id/model_path>`.
172
+ - By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
173
+
174
+ After training is complete, use the following command to infer with the trained weights:
175
+
176
+ - Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`.
177
+
178
+ ```shell
179
+ # Using an interactive command line for inference.
180
+ CUDA_VISIBLE_DEVICES=0 \
181
+ swift infer \
182
+ --adapters output/vx-xxx/checkpoint-xxx \
183
+ --stream true \
184
+ --temperature 0 \
185
+ --max_new_tokens 2048
186
+
187
+ # merge-lora and use vLLM for inference acceleration
188
+ CUDA_VISIBLE_DEVICES=0 \
189
+ swift infer \
190
+ --adapters output/vx-xxx/checkpoint-xxx \
191
+ --stream true \
192
+ --merge_lora true \
193
+ --infer_backend vllm \
194
+ --max_model_len 8192 \
195
+ --temperature 0 \
196
+ --max_new_tokens 2048
197
+ ```
198
+
199
+ Finally, use the following command to push the model to ModelScope:
200
+
201
+ ```shell
202
+ CUDA_VISIBLE_DEVICES=0 \
203
+ swift export \
204
+ --adapters output/vx-xxx/checkpoint-xxx \
205
+ --push_to_hub true \
206
+ --hub_model_id '<your-model-id>' \
207
+ --hub_token '<your-sdk-token>' \
208
+ --use_hf false
209
+ ```
210
+
211
+
212
+ ### Web-UI
213
+ The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio interface technology. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).
214
+
215
+ ```shell
216
+ SWIFT_UI_LANG=en swift web-ui
217
+ ```
218
+
219
+ ![image.png](./docs/resources/web-ui-en.jpg)
220
+
221
+ ### Using Python
222
+
223
+ ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).
224
+
225
+ Training:
226
+
227
+ ```python
228
+ # Retrieve the model and template, and add a trainable LoRA module
229
+ model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
230
+ template = get_template(model.model_meta.template, tokenizer, ...)
231
+ model = Swift.prepare_model(model, lora_config)
232
+
233
+ # Download and load the dataset, and encode the text into tokens
234
+ train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
235
+ train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
236
+ val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
237
+
238
+ # Train the model
239
+ trainer = Seq2SeqTrainer(
240
+ model=model,
241
+ args=training_args,
242
+ data_collator=template.data_collator,
243
+ train_dataset=train_dataset,
244
+ eval_dataset=val_dataset,
245
+ template=template,
246
+ )
247
+ trainer.train()
248
+ ```
249
+ Inference:
250
+
251
+ ```python
252
+ # Perform inference using the native PyTorch engine
253
+ engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
254
+ infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
255
+ request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
256
+
257
+ resp_list = engine.infer([infer_request], request_config)
258
+ print(f'response: {resp_list[0].choices[0].message.content}')
259
+ ```
260
+
261
+ ## ✨ Usage
262
+ Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).
263
+
264
+ - If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path.
265
+ - By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
266
+
267
+ | Useful Links |
268
+ | ------ |
269
+ | [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
270
+ | [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
271
+ | [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
272
+ | [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
273
+
274
+ ### Training
275
+
276
+ Supported Training Methods:
277
+
278
+ | Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
279
+ |------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------|
280
+ | Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
281
+ | Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
282
+ | DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
283
+ | GRPO Training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
284
+ | Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
285
+ | PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
286
+ | KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
287
+ | CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
288
+ | SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
289
+ | ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
290
+ | Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
291
+ | Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
292
+
293
+
294
+
295
+ Pre-training:
296
+ ```shell
297
+ # 8*A100
298
+ NPROC_PER_NODE=8 \
299
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
300
+ swift pt \
301
+ --model Qwen/Qwen2.5-7B \
302
+ --dataset swift/chinese-c4 \
303
+ --streaming true \
304
+ --train_type full \
305
+ --deepspeed zero2 \
306
+ --output_dir output \
307
+ --max_steps 10000 \
308
+ ...
309
+ ```
310
+
311
+ Fine-tuning:
312
+ ```shell
313
+ CUDA_VISIBLE_DEVICES=0 swift sft \
314
+ --model Qwen/Qwen2.5-7B-Instruct \
315
+ --dataset AI-ModelScope/alpaca-gpt4-data-en \
316
+ --train_type lora \
317
+ --output_dir output \
318
+ ...
319
+ ```
320
+
321
+ RLHF:
322
+ ```shell
323
+ CUDA_VISIBLE_DEVICES=0 swift rlhf \
324
+ --rlhf_type dpo \
325
+ --model Qwen/Qwen2.5-7B-Instruct \
326
+ --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
327
+ --train_type lora \
328
+ --output_dir output \
329
+ ...
330
+ ```
331
+
332
+
333
+ ### Inference
334
+ ```shell
335
+ CUDA_VISIBLE_DEVICES=0 swift infer \
336
+ --model Qwen/Qwen2.5-7B-Instruct \
337
+ --stream true \
338
+ --infer_backend pt \
339
+ --max_new_tokens 2048
340
+
341
+ # LoRA
342
+ CUDA_VISIBLE_DEVICES=0 swift infer \
343
+ --model Qwen/Qwen2.5-7B-Instruct \
344
+ --adapters swift/test_lora \
345
+ --stream true \
346
+ --infer_backend pt \
347
+ --temperature 0 \
348
+ --max_new_tokens 2048
349
+ ```
350
+
351
+ ### Interface Inference
352
+ ```shell
353
+ CUDA_VISIBLE_DEVICES=0 swift app \
354
+ --model Qwen/Qwen2.5-7B-Instruct \
355
+ --stream true \
356
+ --infer_backend pt \
357
+ --max_new_tokens 2048
358
+ ```
359
+
360
+ ### Deployment
361
+ ```shell
362
+ CUDA_VISIBLE_DEVICES=0 swift deploy \
363
+ --model Qwen/Qwen2.5-7B-Instruct \
364
+ --infer_backend vllm
365
+ ```
366
+
367
+ ### Sampling
368
+ ```shell
369
+ CUDA_VISIBLE_DEVICES=0 swift sample \
370
+ --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
371
+ --sampler_engine pt \
372
+ --num_return_sequences 5 \
373
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
374
+ ```
375
+
376
+ ### Evaluation
377
+ ```shell
378
+ CUDA_VISIBLE_DEVICES=0 swift eval \
379
+ --model Qwen/Qwen2.5-7B-Instruct \
380
+ --infer_backend lmdeploy \
381
+ --eval_backend OpenCompass \
382
+ --eval_dataset ARC_c
383
+ ```
384
+
385
+ ### Quantization
386
+ ```shell
387
+ CUDA_VISIBLE_DEVICES=0 swift export \
388
+ --model Qwen/Qwen2.5-7B-Instruct \
389
+ --quant_bits 4 --quant_method awq \
390
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
391
+ --output_dir Qwen2.5-7B-Instruct-AWQ
392
+ ```
393
+
394
+ ### Push Model
395
+ ```shell
396
+ swift export \
397
+ --model <model-path> \
398
+ --push_to_hub true \
399
+ --hub_model_id '<model-id>' \
400
+ --hub_token '<sdk-token>'
401
+ ```
402
+
403
+ ## 🏛 License
404
+
405
+ This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License.
406
+
407
+ ## 📎 Citation
408
+
409
+ ```bibtex
410
+ @misc{zhao2024swiftascalablelightweightinfrastructure,
411
+ title={SWIFT: A Scalable lightWeight Infrastructure for Fine-Tuning},
412
+ author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
413
+ year={2024},
414
+ eprint={2408.05517},
415
+ archivePrefix={arXiv},
416
+ primaryClass={cs.CL},
417
+ url={https://arxiv.org/abs/2408.05517},
418
+ }
419
+ ```
420
+
421
+ ## Star History
422
+
423
+ [![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date)
README_CN.md ADDED
@@ -0,0 +1,413 @@
1
+ # SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
2
+
3
+ <p align="center">
4
+ <br>
5
+ <img src="asset/banner.png"/>
6
+ <br>
7
+ <p>
8
+ <p align="center">
9
+ <a href="https://modelscope.cn/home">魔搭社区官网</a>
10
+ <br>
11
+ 中文&nbsp | &nbsp<a href="README.md">English</a>&nbsp
12
+ </p>
13
+
14
+
15
+ <p align="center">
16
+ <img src="https://img.shields.io/badge/python-3.10-5be.svg">
17
+ <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
18
+ <a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
19
+ <a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
20
+ <a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
21
+ <a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
22
+ <a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
23
+ </p>
24
+
25
+ <p align="center">
26
+ <a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
27
+ </p>
28
+
29
+ <p align="center">
30
+ <a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
31
+ </p>
32
+
33
+ ## 📖 目录
34
+ - [用户群](#-用户群)
35
+ - [简介](#-简介)
36
+ - [新闻](#-新闻)
37
+ - [安装](#%EF%B8%8F-安装)
38
+ - [快速开始](#-快速开始)
39
+ - [如何使用](#-如何使用)
40
+ - [License](#-license)
41
+ - [引用](#-引用)
42
+
43
+ ## ☎ 用户群
44
+
45
+ 请扫描下面的二维码来加入我们的交流群:
46
+
47
+ [Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
48
+ :-------------------------:|:-------------------------:
49
+ <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
50
+
51
+ ## 📝 简介
52
+ 🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
53
+
54
+ 🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
55
+
56
+ **为什么选择ms-swift?**
57
+ - 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
58
+ - **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
59
+ - **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
60
+ - 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
61
+ - **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
62
+ - **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
63
+ - **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
64
+ - 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
65
+ - **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
66
+ - **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
67
+ - 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
68
+ - **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
69
+ - **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯文本和多模态模型进行评测。
70
+ - **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
71
+
72
+ ## 🎉 新闻
73
+ - 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
74
+ - 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
75
+ - 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
76
+ - 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
77
+ - 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
78
+ - 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
79
+ - 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
80
+ - 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
81
+ - 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
82
+ - 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
83
+ <details><summary>更多</summary>
84
+
85
+ - 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
86
+ - 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
87
+ - 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
88
+ - 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
89
+ - 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
90
+ </details>
91
+
92
+ ## 🛠️ 安装
93
+ 使用pip进行安装:
94
+ ```shell
95
+ pip install ms-swift -U
96
+ ```
97
+
98
+ 从源代码安装:
99
+ ```shell
100
+ # pip install git+https://github.com/modelscope/ms-swift.git
101
+
102
+ git clone https://github.com/modelscope/ms-swift.git
103
+ cd ms-swift
104
+ pip install -e .
105
+ ```
106
+
107
+ 运行环境:
108
+
109
+ | | 范围 | 推荐 | 备注 |
110
+ | ------ |--------------| ---- | --|
111
+ | python | >=3.9 | 3.10 ||
112
+ | cuda | | cuda12 |使用cpu、npu、mps则无需安装|
113
+ | torch | >=2.0 | ||
114
+ | transformers | >=4.33 | 4.51 ||
115
+ | modelscope | >=1.23 | ||
116
+ | peft | >=0.11,<0.16 | ||
117
+ | trl | >=0.13,<0.18 | 0.17 |RLHF|
118
+ | deepspeed | >=0.14 | 0.14.5 |训练|
119
+ | vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
120
+ | lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
121
+ | evalscope | >=0.11 | |评测|
122
+
123
+ 更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
124
+
125
+
126
+ ## 🚀 快速开始
127
+
128
+ **10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
129
+
130
+ ### 命令行
131
+ ```shell
132
+ # 22GB
133
+ CUDA_VISIBLE_DEVICES=0 \
134
+ swift sft \
135
+ --model Qwen/Qwen2.5-7B-Instruct \
136
+ --train_type lora \
137
+ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
138
+ 'AI-ModelScope/alpaca-gpt4-data-en#500' \
139
+ 'swift/self-cognition#500' \
140
+ --torch_dtype bfloat16 \
141
+ --num_train_epochs 1 \
142
+ --per_device_train_batch_size 1 \
143
+ --per_device_eval_batch_size 1 \
144
+ --learning_rate 1e-4 \
145
+ --lora_rank 8 \
146
+ --lora_alpha 32 \
147
+ --target_modules all-linear \
148
+ --gradient_accumulation_steps 16 \
149
+ --eval_steps 50 \
150
+ --save_steps 50 \
151
+ --save_total_limit 2 \
152
+ --logging_steps 5 \
153
+ --max_length 2048 \
154
+ --output_dir output \
155
+ --system 'You are a helpful assistant.' \
156
+ --warmup_ratio 0.05 \
157
+ --dataloader_num_workers 4 \
158
+ --model_author swift \
159
+ --model_name swift-robot
160
+ ```
161
+
162
+ 小贴士:
163
+ - 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。
164
+ - `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
165
+ - 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
166
+ - 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
167
+
168
+ 训练完成后,使用以下命令对训练后的权重进行推理:
169
+ - 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
170
+
171
+ ```shell
172
+ # 使用交互式命令行进行推理
173
+ CUDA_VISIBLE_DEVICES=0 \
174
+ swift infer \
175
+ --adapters output/vx-xxx/checkpoint-xxx \
176
+ --stream true \
177
+ --temperature 0 \
178
+ --max_new_tokens 2048
179
+
180
+ # merge-lora并使用vLLM进行推理加速
181
+ CUDA_VISIBLE_DEVICES=0 \
182
+ swift infer \
183
+ --adapters output/vx-xxx/checkpoint-xxx \
184
+ --stream true \
185
+ --merge_lora true \
186
+ --infer_backend vllm \
187
+ --max_model_len 8192 \
188
+ --temperature 0 \
189
+ --max_new_tokens 2048
190
+ ```
191
+
192
+ 最后,使用以下命令将模型推送到ModelScope:
193
+ ```shell
194
+ CUDA_VISIBLE_DEVICES=0 \
195
+ swift export \
196
+ --adapters output/vx-xxx/checkpoint-xxx \
197
+ --push_to_hub true \
198
+ --hub_model_id '<your-model-id>' \
199
+ --hub_token '<your-sdk-token>' \
200
+ --use_hf false
201
+ ```
202
+
203
+ ### Web-UI
204
+
205
+ Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
206
+
207
+ ```shell
208
+ swift web-ui
209
+ ```
210
+ ![image.png](./docs/resources/web-ui.jpg)
211
+
212
+ ### 使用Python
213
+ ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
214
+
215
+ 训练:
216
+ ```python
217
+ # 获取模型和template,并加入可训练的LoRA模块
218
+ model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
219
+ template = get_template(model.model_meta.template, tokenizer, ...)
220
+ model = Swift.prepare_model(model, lora_config)
221
+
222
+ # 下载并载入数据集,并将文本encode成tokens
223
+ train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
224
+ train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
225
+ val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
226
+
227
+ # 进行训练
228
+ trainer = Seq2SeqTrainer(
229
+ model=model,
230
+ args=training_args,
231
+ data_collator=template.data_collator,
232
+ train_dataset=train_dataset,
233
+ eval_dataset=val_dataset,
234
+ template=template,
235
+ )
236
+ trainer.train()
237
+ ```
238
+
239
+ 推理:
240
+ ```python
241
+ # 使用原生pytorch引擎进行推理
242
+ engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
243
+ infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
244
+ request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
245
+
246
+ resp_list = engine.infer([infer_request], request_config)
247
+ print(f'response: {resp_list[0].choices[0].message.content}')
248
+ ```
249
+
250
+ ## ✨ 如何使用
251
+
252
+ 这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
253
+
254
+ - 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
255
+ - 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
256
+
257
+ | 常用链接 |
258
+ | ------ |
259
+ | [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
260
+ | [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
261
+ | [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
262
+ | [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
263
+
264
+ ### 训练
265
+ 支持的训练方法:
266
+
267
+ | 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
268
+ | ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
269
+ | 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
270
+ | 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
271
+ | DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
272
+ | GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
273
+ | 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
274
+ | PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
275
+ | KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
276
+ | CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
277
+ | SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
278
+ | ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
279
+ | 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
280
+ | Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
281
+
282
+
283
+ 预训练:
284
+ ```shell
285
+ # 8*A100
286
+ NPROC_PER_NODE=8 \
287
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
288
+ swift pt \
289
+ --model Qwen/Qwen2.5-7B \
290
+ --dataset swift/chinese-c4 \
291
+ --streaming true \
292
+ --train_type full \
293
+ --deepspeed zero2 \
294
+ --output_dir output \
295
+ --max_steps 10000 \
296
+ ...
297
+ ```
298
+
299
+ 微调:
300
+ ```shell
301
+ CUDA_VISIBLE_DEVICES=0 swift sft \
302
+ --model Qwen/Qwen2.5-7B-Instruct \
303
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
304
+ --train_type lora \
305
+ --output_dir output \
306
+ ...
307
+ ```
308
+
309
+ RLHF:
310
+ ```shell
311
+ CUDA_VISIBLE_DEVICES=0 swift rlhf \
312
+ --rlhf_type dpo \
313
+ --model Qwen/Qwen2.5-7B-Instruct \
314
+ --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
315
+ --train_type lora \
316
+ --output_dir output \
317
+ ...
318
+ ```
319
+
320
+
321
+ ### 推理
322
+ ```shell
323
+ CUDA_VISIBLE_DEVICES=0 swift infer \
324
+ --model Qwen/Qwen2.5-7B-Instruct \
325
+ --stream true \
326
+ --infer_backend pt \
327
+ --max_new_tokens 2048
328
+
329
+ # LoRA
330
+ CUDA_VISIBLE_DEVICES=0 swift infer \
331
+ --model Qwen/Qwen2.5-7B-Instruct \
332
+ --adapters swift/test_lora \
333
+ --stream true \
334
+ --infer_backend pt \
335
+ --temperature 0 \
336
+ --max_new_tokens 2048
337
+ ```
338
+
339
+ ### 界面推理
340
+ ```shell
341
+ CUDA_VISIBLE_DEVICES=0 swift app \
342
+ --model Qwen/Qwen2.5-7B-Instruct \
343
+ --stream true \
344
+ --infer_backend pt \
345
+ --max_new_tokens 2048 \
346
+ --lang zh
347
+ ```
348
+
349
+ ### 部署
350
+ ```shell
351
+ CUDA_VISIBLE_DEVICES=0 swift deploy \
352
+ --model Qwen/Qwen2.5-7B-Instruct \
353
+ --infer_backend vllm
354
+ ```
355
+
356
+ ### 采样
357
+ ```shell
358
+ CUDA_VISIBLE_DEVICES=0 swift sample \
359
+ --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
360
+ --sampler_engine pt \
361
+ --num_return_sequences 5 \
362
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
363
+ ```
364
+
365
+ ### 评测
366
+ ```shell
367
+ CUDA_VISIBLE_DEVICES=0 swift eval \
368
+ --model Qwen/Qwen2.5-7B-Instruct \
369
+ --infer_backend lmdeploy \
370
+ --eval_backend OpenCompass \
371
+ --eval_dataset ARC_c
372
+ ```
373
+
374
+ ### 量化
375
+ ```shell
376
+ CUDA_VISIBLE_DEVICES=0 swift export \
377
+ --model Qwen/Qwen2.5-7B-Instruct \
378
+ --quant_bits 4 --quant_method awq \
379
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
380
+ --output_dir Qwen2.5-7B-Instruct-AWQ
381
+ ```
382
+
383
+ ### 推送模型
384
+ ```shell
385
+ swift export \
386
+ --model <model-path> \
387
+ --push_to_hub true \
388
+ --hub_model_id '<model-id>' \
389
+ --hub_token '<sdk-token>'
390
+ ```
391
+
392
+
393
+ ## 🏛 License
394
+
395
+ 本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
396
+
397
+ ## 📎 引用
398
+
399
+ ```bibtex
400
+ @misc{zhao2024swiftascalablelightweightinfrastructure,
401
+ title={SWIFT: A Scalable lightWeight Infrastructure for Fine-Tuning},
402
+ author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
403
+ year={2024},
404
+ eprint={2408.05517},
405
+ archivePrefix={arXiv},
406
+ primaryClass={cs.CL},
407
+ url={https://arxiv.org/abs/2408.05517},
408
+ }
409
+ ```
410
+
411
+ ## Star History
412
+
413
+ [![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date)
checkMissing.py ADDED
@@ -0,0 +1,86 @@
+ import json
+ import torchaudio
+ from tqdm import tqdm
+ import os
+ import sys
+ from collections import defaultdict
+
+ def validate_jsonl_audios(jsonl_path):
+     """Validate the integrity of every audio file referenced in a JSONL file."""
+     stats = defaultdict(int)
+     error_log = []
+     valid_samples = 0
+
+     # First pass: count the total number of lines (for the progress bar)
+     with open(jsonl_path, 'r') as f:
+         total_lines = sum(1 for _ in f)
+
+     # Second pass: the actual validation
+     with open(jsonl_path, 'r') as f:
+         for line_num, line in enumerate(tqdm(f, total=total_lines, desc="验证进度", unit="line")):
+             try:
+                 data = json.loads(line.strip())
+                 if 'audios' not in data or not data['audios']:
+                     stats['no_audio_field'] += 1
+                     continue
+
+                 for audio_path in data['audios']:
+                     # Check that the file exists
+                     if not os.path.exists(audio_path):
+                         stats['missing'] += 1
+                         error_log.append(f"[行{line_num+1}] 缺失文件: {audio_path}")
+                         continue
+
+                     # Check the file size
+                     if os.path.getsize(audio_path) == 0:
+                         stats['zero_size'] += 1
+                         error_log.append(f"[行{line_num+1}] 空文件: {audio_path}")
+                         continue
+
+                     # Validate the audio content
+                     try:
+                         waveform, sr = torchaudio.load(audio_path)
+                         if waveform.numel() == 0:
+                             stats['empty_audio'] += 1
+                             error_log.append(f"[行{line_num+1}] 空音频: {audio_path}")
+                         elif sr not in [8000, 16000, 22050, 44100, 48000]:
+                             stats['abnormal_sr'] += 1
+                             error_log.append(f"[行{line_num+1}] 异常采样率({sr}Hz): {audio_path}")
+                         else:
+                             stats['valid'] += 1
+                     except Exception as e:
+                         stats['corrupted'] += 1
+                         error_type = str(e).split('(')[0]
+                         error_log.append(f"[行{line_num+1}] 损坏文件({error_type}): {audio_path}")
+
+                 valid_samples += 1
+
+             except json.JSONDecodeError:
+                 stats['invalid_json'] += 1
+                 error_log.append(f"[行{line_num+1}] 无效JSON格式")
+
+     # Print the summary report
+     print("\n===== 验证报告 =====")
+     print(f"总行数: {total_lines}")
+     print(f"有效样本: {valid_samples}")
+     print("--- 问题统计 ---")
+     for k, v in sorted(stats.items()):
+         print(f"{k}: {v}")
+
+     # Save the error log
+     if error_log:
+         log_file = f"{os.path.splitext(jsonl_path)[0]}_audio_errors.log"
+         with open(log_file, 'w') as f:
+             f.write("\n".join(error_log))
+         print(f"\n发现 {len(error_log)} 个问题,已保存到 {log_file}")
+
+ if __name__ == "__main__":
+     if len(sys.argv) != 2:
+         print("使用方法: python validate_audio_jsonl.py <input.jsonl>")
+         sys.exit(1)
+
+     if not os.path.exists(sys.argv[1]):
+         print(f"错误: 文件 {sys.argv[1]} 不存在")
+         sys.exit(1)
+
+     validate_jsonl_audios(sys.argv[1])
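Note: checkMissing.py above (and count_audios.py further down in this commit) reads one JSON object per line and only inspects its `audios` field, a list of audio file paths; other keys are ignored. A minimal sketch of a compatible line, using a hypothetical path, looks like this:

```python
import json

# Minimal sketch of one JSONL line for the audio-checking scripts.
# Only the "audios" list is inspected; the path below is hypothetical.
sample = {"audios": ["/data/wavrewardDataset/output_example/clip_0001.wav"]}
print(json.dumps(sample, ensure_ascii=False))  # append this line to a .jsonl file
```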
clean_transcripts.py ADDED
@@ -0,0 +1,95 @@
+ import json
+ import re
+ from typing import List, Dict, Tuple
+
+ def parse_timestamp(timestamp: str) -> int:
+ """Convert timestamp string like '00:15' to seconds."""
+ minutes, seconds = map(int, timestamp.split(':'))
+ return minutes * 60 + seconds
+
+ def extract_time_and_speaker(line: str) -> Tuple[Tuple[int, int], str]:
+ """Extract time range and speaker from a line."""
+ # Extract time range
+ time_match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\] (Speaker [A-Z]):', line)
+ if not time_match:
+ return None, None
+
+ start_time = parse_timestamp(time_match.group(1))
+ end_time = parse_timestamp(time_match.group(2))
+ speaker = time_match.group(3)
+
+ return (start_time, end_time), speaker
+
+ def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
+ """Check if two time ranges overlap."""
+ start1, end1 = range1
+ start2, end2 = range2
+ return not (end1 <= start2 or end2 <= start1)
+
+ def has_same_speaker_overlap(transcript: str) -> bool:
+ """Check if a transcript contains overlapping timestamps for the same speaker."""
+ lines = transcript.split('\n')
+ # Dictionary to store time ranges for each speaker
+ speaker_ranges = {}
+
+ for line in lines:
+ if not line.strip():
+ continue
+
+ time_range, speaker = extract_time_and_speaker(line)
+ if time_range is None or speaker is None:
+ continue
+
+ # Check for overlaps with existing ranges of the same speaker
+ if speaker in speaker_ranges:
+ for existing_range in speaker_ranges[speaker]:
+ if has_overlap(time_range, existing_range):
+ return True
+
+ speaker_ranges[speaker].append(time_range)
+ else:
+ speaker_ranges[speaker] = [time_range]
+
+ return False
+
+ def process_file(input_file: str, output_file: str, delete_file: str):
+ """Process the JSON file and separate entries with same-speaker overlapping timestamps."""
+ with open(input_file, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+
+ if isinstance(data, dict):
+ data = [data]
+
+ cleaned_data = []
+ deleted_data = []
+ removed_count = 0
+
+ for entry in data:
+ if 'model_output' in entry:
+ if not has_same_speaker_overlap(entry['model_output']):
+ cleaned_data.append(entry)
+ else:
+ deleted_data.append(entry)
+ removed_count += 1
+ print(f"Removing entry with key: {entry.get('key', 'unknown')}")
+
+ # Save cleaned data
+ with open(output_file, 'w', encoding='utf-8') as f:
+ json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
+
+ # Save deleted data
+ with open(delete_file, 'w', encoding='utf-8') as f:
+ json.dump(deleted_data, f, ensure_ascii=False, indent=2)
+
+ print(f"\nProcessing Summary:")
+ print(f"Processed {len(data)} entries")
+ print(f"Removed {removed_count} entries with same-speaker overlapping timestamps")
+ print(f"Remaining entries: {len(cleaned_data)}")
+
+ if __name__ == '__main__':
+ input_file = 'silence_overlaps/transcriptions.json'
+ output_file = 'silence_overlaps/cleaned_transcriptions2.json'
+ delete_file = 'silence_overlaps/delete_transcript2.json'
+ process_file(input_file, output_file, delete_file)
+ print(f"\nCleaned transcriptions have been saved to {output_file}")
+ print(f"Deleted entries have been saved to {delete_file}")
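To make the overlap rule above concrete, the following sketch runs `has_same_speaker_overlap` on two hand-written snippets; the transcript text is invented purely for illustration.

```python
# Illustrative transcripts only; the timestamps follow the "[MM:SS - MM:SS] Speaker X:" format
# that extract_time_and_speaker expects.
from clean_transcripts import has_same_speaker_overlap

clean = (
    "[00:00 - 00:05] Speaker A: Hello there.\n"
    "[00:05 - 00:10] Speaker B: Hi, how are you?\n"
    "[00:10 - 00:15] Speaker A: Doing well, thanks."
)
overlapping = (
    "[00:00 - 00:08] Speaker A: Hello there.\n"
    "[00:05 - 00:10] Speaker A: Wait, I was already talking."
)

print(has_same_speaker_overlap(clean))        # False: no speaker overlaps with themselves
print(has_same_speaker_overlap(overlapping))  # True: Speaker A's 0-8s and 5-10s ranges overlap
```

Entries whose `model_output` trips this check are the ones routed to the delete file by `process_file`.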
count_audios.py ADDED
@@ -0,0 +1,69 @@
+ import json
+ import os
+ from collections import Counter
+ from pathlib import Path
+
+ def collect_unique_audio_paths(json_file_path):
+ """
+ Extract all unique audios paths from the JSONL file
+ """
+ audio_set = set()
+ with open(json_file_path, 'r', encoding='utf-8') as f:
+ for line_num, line in enumerate(f, 1):
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ data = json.loads(line)
+ if isinstance(data, dict) and 'audios' in data and data['audios']:
+ for audio_path in data['audios']:
+ audio_set.add(audio_path)
+ except Exception as e:
+ print(f"Error processing line {line_num}: {e}")
+ return audio_set
+
+ def extract_first_subfolder_after_data(audio_path):
+ """
+ Extract the first-level subfolder name that follows 'wavrewardDataset/' in audio_path
+ For example:
+ /.../wavrewardDataset/output_xxx/yyy/file.wav → returns output_xxx
+ """
+ try:
+ path = Path(audio_path)
+ parts = path.parts
+ if "wavrewardDataset" in parts:
+ data_idx = parts.index("wavrewardDataset")
+ if data_idx + 1 < len(parts):
+ return parts[data_idx + 1]
+ return "unknown"
+ except Exception as e:
+ print(f"Path parsing error: {audio_path}, error: {e}")
+ return "error"
+
+ def main():
+ json_file = "all_dataset_train_resampled_16000.jsonl"
+
+ if not os.path.exists(json_file):
+ print(f"File {json_file} does not exist")
+ return
+
+ print(f"Processing file: {json_file}")
+ print("=" * 50)
+
+ # Step 1: collect the deduplicated audio paths
+ unique_audio_paths = collect_unique_audio_paths(json_file)
+ print(f"Number of unique audio files: {len(unique_audio_paths)}")
+
+ # Step 2: count by the first-level subfolder
+ folder_counter = Counter()
+ for audio_path in unique_audio_paths:
+ first_subfolder = extract_first_subfolder_after_data(audio_path)
+ folder_counter[first_subfolder] += 1
+
+ print("\nCounts by first-level subfolder (based on deduplicated paths):")
+ print("-" * 50)
+ for folder, count in sorted(folder_counter.items(), key=lambda x: -x[1]):
+ print(f"{folder}: {count} files")
+
+ if __name__ == "__main__":
+ main()
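For reference, this is what `extract_first_subfolder_after_data` resolves for a couple of sample paths; both paths are invented and only mirror the directory layout the script assumes.

```python
# Invented sample paths that mirror the expected /.../wavrewardDataset/<subfolder>/... layout.
from count_audios import extract_first_subfolder_after_data

inside = "/root/autodl-tmp/wavrewardDataset/output_2000_3000_wrongpause/clip_0001.wav"
outside = "/tmp/other/location/clip_0002.wav"

print(extract_first_subfolder_after_data(inside))   # output_2000_3000_wrongpause
print(extract_first_subfolder_after_data(outside))  # unknown (no "wavrewardDataset" component)
```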
count_folders-Copy1.py ADDED
@@ -0,0 +1,122 @@
+ import json
+ import os
+ from collections import Counter
+ from pathlib import Path
+
+ def count_folder_prefixes(json_file_path):
+ """
+ Count how often each prefix folder of the audio paths in the JSON file occurs
+ Extract up to the /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data level,
+ then classify and count by subfolder
+ Supports JSONL format (one JSON object per line)
+ """
+ folder_counts = Counter()
+
+ # Read the JSON file (JSONL format supported)
+ with open(json_file_path, 'r', encoding='utf-8') as f:
+ for line_num, line in enumerate(f, 1):
+ line = line.strip()
+ if not line: # skip empty lines
+ continue
+
+ try:
+ # Parse the JSON object on each line
+ data = json.loads(line)
+
+ # Process the data
+ if isinstance(data, dict) and 'audios' in data and data['audios']:
+ for audio_path in data['audios']:
+ prefix = extract_folder_prefix(audio_path)
+ if prefix:
+ folder_counts[prefix] += 1
+
+ except json.JSONDecodeError as e:
+ print(f"JSON parse error on line {line_num}: {e}")
+ continue
+ except Exception as e:
+ print(f"Error processing line {line_num}: {e}")
+ continue
+
+ return folder_counts
+
+ def extract_folder_prefix(audio_path):
+ """
+ Extract the prefix of the audio path up to the target level, including the subfolder used for classification
+ For example: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause/xxx.wav
+ is reduced to: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause
+ """
+ try:
+ # Handle the path with a Path object
+ path = Path(audio_path)
+
+ # Locate the target folder within the path
+ parts = path.parts
+ target_folder = "newdataset_10k"
+
+ if target_folder in parts:
+ # Position of the target folder
+ target_index = parts.index(target_folder)
+
+ # Extract up to and including the target folder
+ prefix_parts = parts[:target_index + 1]
+
+ # If there is a subfolder after the target folder, include it as well
+ if target_index + 1 < len(parts):
+ # Include the next subfolder (e.g. output_2000_3000_wrongpause)
+ prefix_parts = parts[:target_index + 2]
+
+ return str(Path(*prefix_parts))
+ else:
+ # If the target folder is not found, return the first few path levels
+ if len(parts) >= 4: # /root/autodl-tmp/xxx/...
+ return str(Path(*parts[:4]))
+ else:
+ return str(path.parent)
+
+ except Exception as e:
+ print(f"Error while processing path: {audio_path}, error: {e}")
+ return None
+
+ def main():
+ # Path to the JSON file
+ json_file = "needtouse/adddata_absolute.jsonl"
+
+ if not os.path.exists(json_file):
+ print(f"File {json_file} does not exist")
+ return
+
+ print(f"Counting file: {json_file}")
+ print("=" * 50)
+
+ # Count folder prefixes
+ folder_counts = count_folder_prefixes(json_file)
+
+ # Output the results
+ if folder_counts:
+ print("Folder prefix counts:")
+ print("-" * 50)
+ for folder, count in sorted(folder_counts.items()):
+ print(f"{folder}: {count} occurrences")
+
+ print(f"\nTotal: {len(folder_counts)} distinct folder prefixes")
+ print(f"Total audio files: {sum(folder_counts.values())}")
+
+ # Count by subfolder type
+ print("\nCounts by subfolder type:")
+ print("-" * 30)
+ subfolder_counts = Counter()
+ for folder in folder_counts.keys():
+ path = Path(folder)
+ if len(path.parts) > 0:
+ # Use the last folder name as the subfolder type
+ subfolder = path.parts[-1]
+ subfolder_counts[subfolder] += folder_counts[folder]
+
+ for subfolder, count in sorted(subfolder_counts.items()):
+ print(f"{subfolder}: {count} occurrences")
+
+ else:
+ print("No audio paths found")
+
+ if __name__ == "__main__":
+ main()
count_folders.py ADDED
@@ -0,0 +1,122 @@
+ import json
+ import os
+ from collections import Counter
+ from pathlib import Path
+
+ def count_folder_prefixes(json_file_path):
+ """
+ Count how often each prefix folder of the audio paths in the JSON file occurs
+ Extract up to the /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data level,
+ then classify and count by subfolder
+ Supports JSONL format (one JSON object per line)
+ """
+ folder_counts = Counter()
+
+ # Read the JSON file (JSONL format supported)
+ with open(json_file_path, 'r', encoding='utf-8') as f:
+ for line_num, line in enumerate(f, 1):
+ line = line.strip()
+ if not line: # skip empty lines
+ continue
+
+ try:
+ # Parse the JSON object on each line
+ data = json.loads(line)
+
+ # Process the data
+ if isinstance(data, dict) and 'audios' in data and data['audios']:
+ for audio_path in data['audios']:
+ prefix = extract_folder_prefix(audio_path)
+ if prefix:
+ folder_counts[prefix] += 1
+
+ except json.JSONDecodeError as e:
+ print(f"JSON parse error on line {line_num}: {e}")
+ continue
+ except Exception as e:
+ print(f"Error processing line {line_num}: {e}")
+ continue
+
+ return folder_counts
+
+ def extract_folder_prefix(audio_path):
+ """
+ Extract the prefix of the audio path up to the target level, including the subfolder used for classification
+ For example: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause/xxx.wav
+ is reduced to: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause
+ """
+ try:
+ # Handle the path with a Path object
+ path = Path(audio_path)
+
+ # Locate the target folder within the path
+ parts = path.parts
+ target_folder = "data"
+
+ if target_folder in parts:
+ # Position of the target folder
+ target_index = parts.index(target_folder)
+
+ # Extract up to and including the target folder
+ prefix_parts = parts[:target_index + 1]
+
+ # If there is a subfolder after the target folder, include it as well
+ if target_index + 1 < len(parts):
+ # Include the next subfolder (e.g. output_2000_3000_wrongpause)
+ prefix_parts = parts[:target_index + 2]
+
+ return str(Path(*prefix_parts))
+ else:
+ # If the target folder is not found, return the first few path levels
+ if len(parts) >= 4: # /root/autodl-tmp/xxx/...
+ return str(Path(*parts[:4]))
+ else:
+ return str(path.parent)
+
+ except Exception as e:
+ print(f"Error while processing path: {audio_path}, error: {e}")
+ return None
+
+ def main():
+ # Path to the JSON file
+ json_file = "dataset_4JOB.jsonl"
+
+ if not os.path.exists(json_file):
+ print(f"File {json_file} does not exist")
+ return
+
+ print(f"Counting file: {json_file}")
+ print("=" * 50)
+
+ # Count folder prefixes
+ folder_counts = count_folder_prefixes(json_file)
+
+ # Output the results
+ if folder_counts:
+ print("Folder prefix counts:")
+ print("-" * 50)
+ for folder, count in sorted(folder_counts.items()):
+ print(f"{folder}: {count} occurrences")
+
+ print(f"\nTotal: {len(folder_counts)} distinct folder prefixes")
+ print(f"Total audio files: {sum(folder_counts.values())}")
+
+ # Count by subfolder type
+ print("\nCounts by subfolder type:")
+ print("-" * 30)
+ subfolder_counts = Counter()
+ for folder in folder_counts.keys():
+ path = Path(folder)
+ if len(path.parts) > 0:
+ # Use the last folder name as the subfolder type
+ subfolder = path.parts[-1]
+ subfolder_counts[subfolder] += folder_counts[folder]
+
+ for subfolder, count in sorted(subfolder_counts.items()):
+ print(f"{subfolder}: {count} occurrences")
+
+ else:
+ print("No audio paths found")
+
+ if __name__ == "__main__":
+ main()
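The docstring of `extract_folder_prefix` already spells out the intended mapping; the sketch below simply replays that example, plus an invented path without a `data` component to show the fallback branch.

```python
# The first path is the docstring example; the second is invented to hit the fallback.
from count_folders import extract_folder_prefix

with_data = "/root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause/xxx.wav"
without_data = "/root/autodl-tmp/somewhere/else/deep/clip.wav"

print(extract_folder_prefix(with_data))
# /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause

print(extract_folder_prefix(without_data))
# /root/autodl-tmp/somewhere  (first four path components, '/' included)
```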
dialogue_length_distribution.png ADDED
dialogue_length_ranges.png ADDED
docs/transformers/build/lib/transformers/models/sam/processing_sam.py ADDED
@@ -0,0 +1,311 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for SAM.
17
+ """
18
+
19
+ from copy import deepcopy
20
+ from typing import List, Optional, Union
21
+
22
+ import numpy as np
23
+
24
+ from ...image_utils import ImageInput, VideoInput
25
+ from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
26
+ from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
27
+ from ...utils import is_tf_available, is_torch_available
28
+
29
+
30
+ if is_torch_available():
31
+ import torch
32
+
33
+ if is_tf_available():
34
+ import tensorflow as tf
35
+
36
+
37
+ class SamImagesKwargs(ImagesKwargs):
38
+ segmentation_maps: Optional[ImageInput]
39
+ input_points: Optional[List[List[float]]]
40
+ input_labels: Optional[List[List[int]]]
41
+ input_boxes: Optional[List[List[List[float]]]]
42
+ point_pad_value: Optional[int]
43
+
44
+
45
+ class SamProcessorKwargs(ProcessingKwargs, total=False):
46
+ images_kwargs: SamImagesKwargs
47
+ _defaults = {
48
+ "images_kwargs": {
49
+ "point_pad_value": -10,
50
+ }
51
+ }
52
+
53
+
54
+ class SamProcessor(ProcessorMixin):
55
+ r"""
56
+ Constructs a SAM processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
57
+ single processor.
58
+
59
+ [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
60
+ [`~SamImageProcessor.__call__`] for more information.
61
+
62
+ Args:
63
+ image_processor (`SamImageProcessor`):
64
+ An instance of [`SamImageProcessor`]. The image processor is a required input.
65
+ """
66
+
67
+ attributes = ["image_processor"]
68
+ image_processor_class = "SamImageProcessor"
69
+ # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
70
+ optional_call_args = [
71
+ "segmentation_maps",
72
+ "input_points",
73
+ "input_labels",
74
+ "input_boxes",
75
+ ]
76
+
77
+ def __init__(self, image_processor):
78
+ super().__init__(image_processor)
79
+ self.target_size = self.image_processor.size["longest_edge"]
80
+
81
+ def __call__(
82
+ self,
83
+ images: Optional[ImageInput] = None,
84
+ # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
85
+ # arguments that may be passed as a positional argument.
86
+ # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
87
+ # or this conversation for more context:
88
+ # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
89
+ # This behavior is only needed for backward compatibility and will be removed in future versions.
90
+ *args, # to be deprecated
91
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
92
+ audio: Optional[AudioInput] = None,
93
+ video: Optional[VideoInput] = None,
94
+ **kwargs,
95
+ ) -> BatchEncoding:
96
+ """
97
+ This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
98
+ points and bounding boxes for the model if they are provided.
99
+ """
100
+ output_kwargs = self._merge_kwargs(
101
+ SamProcessorKwargs,
102
+ tokenizer_init_kwargs={},
103
+ **kwargs,
104
+ **self.prepare_and_validate_optional_call_args(*args),
105
+ )
106
+ input_points = output_kwargs["images_kwargs"].pop("input_points", None)
107
+ input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
108
+ input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
109
+ point_pad_value = output_kwargs["images_kwargs"].pop("point_pad_value", None)
110
+
111
+ encoding_image_processor = self.image_processor(
112
+ images,
113
+ **output_kwargs["images_kwargs"],
114
+ )
115
+
116
+ # pop arguments that are not used in the foward but used nevertheless
117
+ original_sizes = encoding_image_processor["original_sizes"]
118
+
119
+ if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
120
+ original_sizes = original_sizes.numpy()
121
+
122
+ input_points, input_labels, input_boxes = self._check_and_preprocess_points(
123
+ input_points=input_points,
124
+ input_labels=input_labels,
125
+ input_boxes=input_boxes,
126
+ )
127
+
128
+ encoding_image_processor = self._normalize_and_convert(
129
+ encoding_image_processor,
130
+ original_sizes,
131
+ input_points=input_points,
132
+ input_labels=input_labels,
133
+ input_boxes=input_boxes,
134
+ return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
135
+ point_pad_value=point_pad_value,
136
+ )
137
+
138
+ return encoding_image_processor
139
+
140
+ def _normalize_and_convert(
141
+ self,
142
+ encoding_image_processor,
143
+ original_sizes,
144
+ input_points=None,
145
+ input_labels=None,
146
+ input_boxes=None,
147
+ return_tensors="pt",
148
+ point_pad_value=-10,
149
+ ):
150
+ if input_points is not None:
151
+ if len(original_sizes) != len(input_points):
152
+ input_points = [
153
+ self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points
154
+ ]
155
+ else:
156
+ input_points = [
157
+ self._normalize_coordinates(self.target_size, point, original_size)
158
+ for point, original_size in zip(input_points, original_sizes)
159
+ ]
160
+ # check that all arrays have the same shape
161
+ if not all(point.shape == input_points[0].shape for point in input_points):
162
+ if input_labels is not None:
163
+ input_points, input_labels = self._pad_points_and_labels(
164
+ input_points, input_labels, point_pad_value
165
+ )
166
+
167
+ input_points = np.array(input_points)
168
+
169
+ if input_labels is not None:
170
+ input_labels = np.array(input_labels)
171
+
172
+ if input_boxes is not None:
173
+ if len(original_sizes) != len(input_boxes):
174
+ input_boxes = [
175
+ self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True)
176
+ for box in input_boxes
177
+ ]
178
+ else:
179
+ input_boxes = [
180
+ self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True)
181
+ for box, original_size in zip(input_boxes, original_sizes)
182
+ ]
183
+ input_boxes = np.array(input_boxes)
184
+
185
+ if input_boxes is not None:
186
+ if return_tensors == "pt":
187
+ input_boxes = torch.from_numpy(input_boxes)
188
+ # boxes batch size of 1 by default
189
+ input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
190
+ elif return_tensors == "tf":
191
+ input_boxes = tf.convert_to_tensor(input_boxes)
192
+ # boxes batch size of 1 by default
193
+ input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
194
+ encoding_image_processor.update({"input_boxes": input_boxes})
195
+ if input_points is not None:
196
+ if return_tensors == "pt":
197
+ input_points = torch.from_numpy(input_points)
198
+ # point batch size of 1 by default
199
+ input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
200
+ elif return_tensors == "tf":
201
+ input_points = tf.convert_to_tensor(input_points)
202
+ # point batch size of 1 by default
203
+ input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
204
+ encoding_image_processor.update({"input_points": input_points})
205
+ if input_labels is not None:
206
+ if return_tensors == "pt":
207
+ input_labels = torch.from_numpy(input_labels)
208
+ # point batch size of 1 by default
209
+ input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
210
+ elif return_tensors == "tf":
211
+ input_labels = tf.convert_to_tensor(input_labels)
212
+ # point batch size of 1 by default
213
+ input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
214
+ encoding_image_processor.update({"input_labels": input_labels})
215
+
216
+ return encoding_image_processor
217
+
218
+ def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
219
+ r"""
220
+ The method pads the 2D points and labels to the maximum number of points in the batch.
221
+ """
222
+ expected_nb_points = max([point.shape[0] for point in input_points])
223
+ processed_input_points = []
224
+ for i, point in enumerate(input_points):
225
+ if point.shape[0] != expected_nb_points:
226
+ point = np.concatenate(
227
+ [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
228
+ )
229
+ input_labels[i] = np.append(input_labels[i], [point_pad_value])
230
+ processed_input_points.append(point)
231
+ input_points = processed_input_points
232
+ return input_points, input_labels
233
+
234
+ def _normalize_coordinates(
235
+ self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
236
+ ) -> np.ndarray:
237
+ """
238
+ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
239
+ """
240
+ old_h, old_w = original_size
241
+ new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
242
+ coords = deepcopy(coords).astype(float)
243
+
244
+ if is_bounding_box:
245
+ coords = coords.reshape(-1, 2, 2)
246
+
247
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
248
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
249
+
250
+ if is_bounding_box:
251
+ coords = coords.reshape(-1, 4)
252
+
253
+ return coords
254
+
255
+ def _check_and_preprocess_points(
256
+ self,
257
+ input_points=None,
258
+ input_labels=None,
259
+ input_boxes=None,
260
+ ):
261
+ r"""
262
+ Checks and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they
263
+ are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`,
264
+ it is converted to a `numpy.ndarray` and then to a `list`.
265
+ """
266
+ if input_points is not None:
267
+ if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
268
+ input_points = input_points.numpy().tolist()
269
+
270
+ if not isinstance(input_points, list) or not isinstance(input_points[0], list):
271
+ raise ValueError("Input points must be a list of list of floating points.")
272
+ input_points = [np.array(input_point) for input_point in input_points]
273
+ else:
274
+ input_points = None
275
+
276
+ if input_labels is not None:
277
+ if hasattr(input_labels, "numpy"):
278
+ input_labels = input_labels.numpy().tolist()
279
+
280
+ if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
281
+ raise ValueError("Input labels must be a list of list of integers.")
282
+ input_labels = [np.array(label) for label in input_labels]
283
+ else:
284
+ input_labels = None
285
+
286
+ if input_boxes is not None:
287
+ if hasattr(input_boxes, "numpy"):
288
+ input_boxes = input_boxes.numpy().tolist()
289
+
290
+ if (
291
+ not isinstance(input_boxes, list)
292
+ or not isinstance(input_boxes[0], list)
293
+ or not isinstance(input_boxes[0][0], list)
294
+ ):
295
+ raise ValueError("Input boxes must be a list of list of list of floating points.")
296
+ input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
297
+ else:
298
+ input_boxes = None
299
+
300
+ return input_points, input_labels, input_boxes
301
+
302
+ @property
303
+ def model_input_names(self):
304
+ image_processor_input_names = self.image_processor.model_input_names
305
+ return list(dict.fromkeys(image_processor_input_names))
306
+
307
+ def post_process_masks(self, *args, **kwargs):
308
+ return self.image_processor.post_process_masks(*args, **kwargs)
309
+
310
+
311
+ __all__ = ["SamProcessor"]
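A short, hedged usage sketch for the processor defined above: the checkpoint name, image URL and prompt point are placeholders, and running it needs the usual `transformers`/`torch`/`Pillow`/`requests` stack.

```python
# Placeholder checkpoint, image URL and prompt point; input_points follows the
# documented [[[x, y]]] nesting (one point for one image).
import requests
from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(image, input_points=[[[450, 600]]], return_tensors="pt")
outputs = model(**inputs)

# post_process_masks is forwarded to the image processor, as the class above shows.
masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
```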
docs/transformers/build/lib/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py ADDED
@@ -0,0 +1,396 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Converting Meta SeamlessM4T checkpoints from seamless_communication to HF."""
16
+
17
+ import argparse
18
+ import os
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ from accelerate.utils.modeling import find_tied_parameters
23
+ from seamless_communication.models.inference.translator import Translator
24
+
25
+ from transformers import (
26
+ SeamlessM4TConfig,
27
+ SeamlessM4TFeatureExtractor,
28
+ SeamlessM4TModel,
29
+ SeamlessM4TProcessor,
30
+ SeamlessM4TTokenizer,
31
+ )
32
+ from transformers.utils import logging
33
+
34
+
35
+ UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ] # fmt: skip
36
+ VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",] # fmt: skip
37
+ MEDIUM_SUPPORTED_LANGUAGES = ["ace","ace_Latn","acm","acq","aeb","afr","ajp","aka","amh","apc","arb","ars","ary","arz","asm","ast","awa","ayr","azb","azj","bak","bam","ban","bel","bem","ben","bho","bjn","bjn_Latn","bod","bos","bug","bul","cat","ceb","ces","cjk","ckb","crh","cym","dan","deu","dik","dyu","dzo","ell","eng","epo","est","eus","ewe","fao","pes","fij","fin","fon","fra","fur","fuv","gla","gle","glg","grn","guj","hat","hau","heb","hin","hne","hrv","hun","hye","ibo","ilo","ind","isl","ita","jav","jpn","kab","kac","kam","kan","kas","kas_Deva","kat","knc","knc_Latn","kaz","kbp","kea","khm","kik","kin","kir","kmb","kon","kor","kmr","lao","lvs","lij","lim","lin","lit","lmo","ltg","ltz","lua","lug","luo","lus","mag","mai","mal","mar","min","mkd","plt","mlt","mni","khk","mos","mri","zsm","mya","nld","nno","nob","npi","nso","nus","nya","oci","gaz","ory","pag","pan","pap","pol","por","prs","pbt","quy","ron","run","rus","sag","san","sat","scn","shn","sin","slk","slv","smo","sna","snd","som","sot","spa","als","srd","srp","ssw","sun","swe","swh","szl","tam","tat","tel","tgk","tgl","tha","tir","taq","taq_Tfng","tpi","tsn","tso","tuk","tum","tur","twi","tzm","uig","ukr","umb","urd","uzn","vec","vie","war","wol","xho","ydd","yor","yue","cmn","cmn_Hant","zul",] # fmt: skip
38
+ LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",] # fmt: skip
39
+
40
+
41
+ def assert_param_count(model_1, model_2):
42
+ count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
43
+ count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
44
+ assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
45
+
46
+
47
+ def param_count(model):
48
+ return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
49
+
50
+
51
+ def _grab_best_device(use_gpu=True):
52
+ if torch.cuda.device_count() > 0 and use_gpu:
53
+ device = "cuda"
54
+ else:
55
+ device = "cpu"
56
+ return torch.device(device)
57
+
58
+
59
+ logging.set_verbosity_info()
60
+ logger = logging.get_logger(__name__)
61
+
62
+ vocoder_convert_list = [
63
+ ("ups", "hifi_gan.upsampler"),
64
+ ("conv_pre", "hifi_gan.conv_pre"),
65
+ ("resblocks", "hifi_gan.resblocks"),
66
+ ("conv_post", "hifi_gan.conv_post"),
67
+ ("lang", "language_embedding"),
68
+ ("spkr", "speaker_embedding"),
69
+ ("dict.", "unit_embedding."),
70
+ ("dur_predictor.conv1.0", "dur_predictor.conv1"),
71
+ ("dur_predictor.conv2.0", "dur_predictor.conv2"),
72
+ ]
73
+
74
+ # order is important
75
+ wav2vec_convert_list = [
76
+ ("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"),
77
+ ("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
78
+ ("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
79
+ ("speech_encoder.inner.layers", "encoder.layers"),
80
+ ("speech_encoder.inner_layer_norm", "encoder.layer_norm"),
81
+ ("speech_encoder.adaptor_layers", "adapter.layers"),
82
+ ("inner_proj", "intermediate_dense"),
83
+ ("self_attn.output_proj", "self_attn.linear_out"),
84
+ ("output_proj", "output_dense"),
85
+ ("self_attn.k_proj", "self_attn.linear_k"),
86
+ ("self_attn.v_proj", "self_attn.linear_v"),
87
+ ("self_attn.q_proj", "self_attn.linear_q"),
88
+ ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
89
+ ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
90
+ ("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
91
+ ("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
92
+ ("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
93
+ ("conv.depthwise_conv", "conv_module.depthwise_conv"),
94
+ ("conv.batch_norm", "conv_module.batch_norm"),
95
+ ("conv_layer_norm", "conv_module.layer_norm"),
96
+ ("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"),
97
+ ("speech_encoder.proj2", "intermediate_ffn.output_dense"),
98
+ ("speech_encoder.layer_norm", "inner_layer_norm"),
99
+ ]
100
+
101
+ t2u_convert_list = [
102
+ ("t2u_model.final_proj", "lm_head"),
103
+ ("t2u_model.", "model."),
104
+ ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
105
+ ("encoder_decoder_attn", "cross_attention"),
106
+ ("linear_k", "k_proj"),
107
+ ("linear_v", "v_proj"),
108
+ ("linear_q", "q_proj"),
109
+ ("ffn.inner_proj", "ffn.fc1"),
110
+ ("ffn.output_proj", "ffn.fc2"),
111
+ ("output_proj", "out_proj"),
112
+ ("decoder_frontend.embed", "decoder.embed_tokens"),
113
+ ]
114
+
115
+ text_convert_list = [
116
+ ("text_encoder.", ""),
117
+ ("text_decoder.", ""),
118
+ ("text_encoder_frontend.embed", "embed_tokens"),
119
+ ("text_decoder_frontend.embed", "embed_tokens"),
120
+ ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
121
+ ("encoder_decoder_attn", "cross_attention"),
122
+ ("linear_k", "k_proj"),
123
+ ("linear_v", "v_proj"),
124
+ ("linear_q", "q_proj"),
125
+ ("ffn.inner_proj", "ffn.fc1"),
126
+ ("ffn.output_proj", "ffn.fc2"),
127
+ ("output_proj", "out_proj"),
128
+ ("final_proj", "lm_head"),
129
+ ]
130
+
131
+ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
132
+ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
133
+ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub")
134
+
135
+
136
+ def _load_hf_config(model_type="medium"):
137
+ if model_type == "medium":
138
+ kwargs = {
139
+ "vocab_size": 256206,
140
+ "t2u_vocab_size": 10082,
141
+ "hidden_size": 1024,
142
+ "max_position_embeddings": 4096,
143
+ "encoder_layers": 12,
144
+ "decoder_layers": 12,
145
+ "encoder_ffn_dim": 4096,
146
+ "decoder_ffn_dim": 4096,
147
+ "t2u_encoder_layers": 4,
148
+ "t2u_decoder_layers": 4,
149
+ "speech_encoder_layers": 12,
150
+ }
151
+ return SeamlessM4TConfig(**kwargs)
152
+ else:
153
+ return SeamlessM4TConfig()
154
+
155
+
156
+ def _convert_model(
157
+ original_model,
158
+ hf_model,
159
+ convert_list,
160
+ device,
161
+ unwanted_prefix="model.",
162
+ filter_state_dict="speech",
163
+ exclude_state_dict=None,
164
+ ):
165
+ state_dict = original_model.state_dict()
166
+
167
+ # filter func
168
+ if isinstance(filter_state_dict, str):
169
+
170
+ def filter_func(x):
171
+ return filter_state_dict in x[0]
172
+
173
+ else:
174
+
175
+ def filter_func(item):
176
+ if exclude_state_dict is not None and exclude_state_dict in item[0]:
177
+ return False
178
+ for filter_el in filter_state_dict:
179
+ if filter_el in item[0]:
180
+ return True
181
+
182
+ return False
183
+
184
+ state_dict = dict(filter(filter_func, state_dict.items()))
185
+
186
+ for k, v in list(state_dict.items()):
187
+ new_k = k[len(unwanted_prefix) :]
188
+ for old_layer_name, new_layer_name in convert_list:
189
+ if old_layer_name in new_k:
190
+ new_k = new_k.replace(old_layer_name, new_layer_name)
191
+
192
+ # must do it by hand
193
+ if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric():
194
+ new_k = new_k.replace("layer_norm", "final_layer_norm")
195
+
196
+ state_dict[new_k] = state_dict.pop(k)
197
+
198
+ extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
199
+ extra_keys = set(extra_keys)
200
+ missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
201
+ missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k})
202
+ if len(extra_keys) != 0:
203
+ raise ValueError(f"extra keys found: {extra_keys}")
204
+ if len(missing_keys) != 0:
205
+ raise ValueError(f"missing keys: {missing_keys}")
206
+ hf_model.load_state_dict(state_dict, strict=False)
207
+ n_params = param_count(hf_model)
208
+
209
+ logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params")
210
+
211
+ hf_model.eval()
212
+ hf_model.to(device)
213
+ del state_dict
214
+
215
+ return hf_model
216
+
217
+
218
+ def load_model(save_dir, model_type, repo_id):
219
+ """
220
+ Meta SeamlessM4T is made of 8 main components:
221
+ - speech_encoder (#1) and speech_encoder_frontend (#2)
222
+ - t2u_model (#3)
223
+ - text_encoder (#4) and text_encoder_frontend (#5)
224
+ - text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend]
225
+ - final_proj (#7)
226
+ - vocoder (#8)
227
+ """
228
+ device = _grab_best_device()
229
+ if model_type == "medium":
230
+ name = "seamlessM4T_medium"
231
+ else:
232
+ name = "seamlessM4T_large"
233
+
234
+ original_model = Translator(name, "vocoder_36langs", device, torch.float32)
235
+
236
+ ######### TOKENIZER
237
+
238
+ langs = MEDIUM_SUPPORTED_LANGUAGES if model_type == "medium" else LARGE_SUPPORTED_LANGUAGES
239
+ langs = [f"__{lang}__" for lang in langs]
240
+ vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model")
241
+
242
+ save_dir = os.path.join(save_dir, name)
243
+ Path(save_dir).mkdir(exist_ok=True)
244
+
245
+ tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs)
246
+
247
+ sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__")
248
+
249
+ tokenizer.save_pretrained(save_dir)
250
+ tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir)
251
+
252
+ if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"):
253
+ raise ValueError(
254
+ f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}"
255
+ )
256
+
257
+ ####### get language to ids dict
258
+ text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs}
259
+ # offset: vocoder unit vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages)
260
+ t2u_lang_code_to_id = {
261
+ code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES)
262
+ for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES)
263
+ }
264
+ vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)}
265
+
266
+ ######### FE
267
+
268
+ fe = SeamlessM4TFeatureExtractor(language_code=langs)
269
+
270
+ fe.save_pretrained(save_dir)
271
+ fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir)
272
+
273
+ processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer)
274
+ processor.save_pretrained(save_dir)
275
+ processor.push_to_hub(repo_id=repo_id, create_pr=True)
276
+
277
+ processor = SeamlessM4TProcessor.from_pretrained(save_dir)
278
+
279
+ ######## Model
280
+
281
+ # init model
282
+ hf_config = _load_hf_config(model_type)
283
+ hf_model = SeamlessM4TModel(hf_config)
284
+
285
+ hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id)
286
+ hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id)
287
+ hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id)
288
+
289
+ # -1. take care of vocoder
290
+ # similarly to speech T5 must apply and remove weight norm
291
+ hf_model.vocoder.apply_weight_norm()
292
+ hf_model.vocoder = _convert_model(
293
+ original_model,
294
+ hf_model.vocoder,
295
+ vocoder_convert_list,
296
+ device,
297
+ unwanted_prefix="vocoder.code_generator.",
298
+ filter_state_dict="vocoder",
299
+ )
300
+ hf_model.vocoder.remove_weight_norm()
301
+
302
+ # 1. take care of speech encoder
303
+ wav2vec = hf_model.speech_encoder
304
+ hf_model.speech_encoder = _convert_model(
305
+ original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech"
306
+ )
307
+
308
+ # 2. take care of t2u
309
+
310
+ hf_model.t2u_model = _convert_model(
311
+ original_model,
312
+ hf_model.t2u_model,
313
+ t2u_convert_list,
314
+ device,
315
+ unwanted_prefix="model.",
316
+ filter_state_dict="t2u_model",
317
+ )
318
+
319
+ # 3. take care of text encoder
320
+ hf_model.text_encoder = _convert_model(
321
+ original_model,
322
+ hf_model.text_encoder,
323
+ text_convert_list,
324
+ device,
325
+ unwanted_prefix="model.",
326
+ filter_state_dict=["model.text_encoder"],
327
+ exclude_state_dict="t2u_model",
328
+ )
329
+
330
+ # 4. take care of text decoder
331
+ hf_model.text_decoder = _convert_model(
332
+ original_model,
333
+ hf_model.text_decoder,
334
+ text_convert_list,
335
+ device,
336
+ unwanted_prefix="model.",
337
+ filter_state_dict=["model.text_decoder"],
338
+ exclude_state_dict="t2u_model",
339
+ )
340
+
341
+ # 5. take care of final proj
342
+ hf_model.lm_head = _convert_model(
343
+ original_model,
344
+ hf_model.lm_head,
345
+ [("final_proj.", "")],
346
+ device,
347
+ unwanted_prefix="model.",
348
+ filter_state_dict=["model.final_proj"],
349
+ exclude_state_dict="t2u_model",
350
+ )
351
+
352
+ # sanity check
353
+ print(find_tied_parameters(hf_model))
354
+
355
+ count_1 = param_count(hf_model)
356
+ count_2 = param_count(original_model)
357
+
358
+ print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}")
359
+ print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}")
360
+
361
+ del original_model
362
+
363
+ hf_model.generation_config._from_model_config = False
364
+ hf_model.save_pretrained(save_dir)
365
+ hf_model.push_to_hub(repo_id=repo_id, create_pr=True)
366
+ hf_model = SeamlessM4TModel.from_pretrained(save_dir)
367
+
368
+
369
+ if __name__ == "__main__":
370
+ parser = argparse.ArgumentParser()
371
+ # Required parameters
372
+
373
+ parser.add_argument(
374
+ "--model_type",
375
+ default="medium",
376
+ type=str,
377
+ help="Model type.",
378
+ )
379
+
380
+ parser.add_argument(
381
+ "--save_dir",
382
+ default="/home/ubuntu/weights",
383
+ type=str,
384
+ help="Path to the output PyTorch model.",
385
+ )
386
+
387
+ parser.add_argument(
388
+ "--repo_id",
389
+ default="facebook/hf-seamless-m4t-medium",
390
+ type=str,
391
+ help="Repo ID.",
392
+ )
393
+
394
+ args = parser.parse_args()
395
+
396
+ load_model(args.save_dir, args.model_type, args.repo_id)
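For completeness, a hedged sketch of calling the conversion entry point directly instead of via the CLI; it assumes `seamless_communication` and the original checkpoints are available locally, that the script is importable from the working directory, and that the save directory and repo id below are placeholders matching the argparse defaults.

```python
# Placeholder save_dir / repo_id; load_model is the function defined in the script above
# and performs the tokenizer, feature-extractor and weight conversion end to end.
from convert_fairseq2_to_hf import load_model

load_model(
    save_dir="/home/ubuntu/weights",             # same as the script's default --save_dir
    model_type="medium",                         # any other value selects the large variant
    repo_id="facebook/hf-seamless-m4t-medium",   # same as the default --repo_id
)
```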
docs/transformers/build/lib/transformers/models/seamless_m4t/modeling_seamless_m4t.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/seamless_m4t/processing_seamless_m4t.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Audio/Text processor class for SeamlessM4T
17
+ """
18
+
19
+ from ...processing_utils import ProcessorMixin
20
+
21
+
22
+ class SeamlessM4TProcessor(ProcessorMixin):
23
+ r"""
24
+ Constructs a SeamlessM4T processor which wraps a SeamlessM4T feature extractor and a SeamlessM4T tokenizer into a
25
+ single processor.
26
+
27
+ [`SeamlessM4TProcessor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and
28
+ [`SeamlessM4TTokenizerFast`]. See the [`~SeamlessM4TProcessor.__call__`] and [`~SeamlessM4TProcessor.decode`] for
29
+ more information.
30
+
31
+ Args:
32
+ feature_extractor ([`SeamlessM4TFeatureExtractor`]):
33
+ The audio processor is a required input.
34
+ tokenizer ([`SeamlessM4TTokenizerFast`]):
35
+ The tokenizer is a required input.
36
+ """
37
+
38
+ feature_extractor_class = "SeamlessM4TFeatureExtractor"
39
+ tokenizer_class = ("SeamlessM4TTokenizer", "SeamlessM4TTokenizerFast")
40
+
41
+ def __init__(self, feature_extractor, tokenizer):
42
+ super().__init__(feature_extractor, tokenizer)
43
+
44
+ def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwargs):
45
+ """
46
+ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
47
+ and `kwargs` arguments to SeamlessM4TTokenizerFast's [`~SeamlessM4TTokenizerFast.__call__`] if `text` is not
48
+ `None` to encode the text. To prepare the audio(s), this method forwards the `audios` and `kwargs` arguments to
49
+ SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audios` is not `None`. Please refer
50
+ to the docstring of the above two methods for more information.
51
+
52
+ Args:
53
+ text (`str`, `List[str]`, `List[List[str]]`):
54
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
55
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
56
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
57
+ audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
58
+ The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case
59
+ of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
60
+ and T the sample length of the audio.
61
+ src_lang (`str`, *optional*):
62
+ The language code of the input texts/audios. If not specified, the last `src_lang` specified will be
63
+ used.
64
+ tgt_lang (`str`, *optional*):
65
+ The code of the target language. If not specified, the last `tgt_lang` specified will be used.
66
+ kwargs (*optional*):
67
+ Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
68
+ tokenizer.
69
+ Returns:
70
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
71
+
72
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
73
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
74
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
75
+ `None`).
76
+ - **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`.
77
+ """
78
+ sampling_rate = kwargs.pop("sampling_rate", None)
79
+
80
+ if text is None and audios is None:
81
+ raise ValueError("You have to specify either text or audios. Both cannot be none.")
82
+ elif text is not None and audios is not None:
83
+ raise ValueError(
84
+ "Text and audios are mutually exclusive when passed to `SeamlessM4T`. Specify one or another."
85
+ )
86
+ elif text is not None:
87
+ if tgt_lang is not None:
88
+ self.tokenizer.tgt_lang = tgt_lang
89
+ if src_lang is not None:
90
+ self.tokenizer.src_lang = src_lang
91
+ encoding = self.tokenizer(text, **kwargs)
92
+
93
+ return encoding
94
+
95
+ else:
96
+ encoding = self.feature_extractor(audios, sampling_rate=sampling_rate, **kwargs)
97
+ return encoding
98
+
99
+ def batch_decode(self, *args, **kwargs):
100
+ """
101
+ This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
102
+ Please refer to the docstring of this method for more information.
103
+ """
104
+ return self.tokenizer.batch_decode(*args, **kwargs)
105
+
106
+ def decode(self, *args, **kwargs):
107
+ """
108
+ This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please
109
+ refer to the docstring of this method for more information.
110
+ """
111
+ return self.tokenizer.decode(*args, **kwargs)
112
+
113
+ @property
114
+ def model_input_names(self):
115
+ tokenizer_input_names = self.tokenizer.model_input_names
116
+ feature_extractor_input_names = self.feature_extractor.model_input_names
117
+ return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
118
+
119
+
120
+ __all__ = ["SeamlessM4TProcessor"]
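A brief sketch of the two mutually exclusive call modes described in the docstring above; the checkpoint name is the one used by the conversion script earlier in this folder, and the audio array is random noise standing in for a real 16 kHz waveform.

```python
# Random noise stands in for real speech; text and audios cannot be passed together.
import numpy as np
from transformers import SeamlessM4TProcessor

processor = SeamlessM4TProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")

# Text mode: tokenizes with the requested source/target languages.
text_inputs = processor(text="Hello, world!", src_lang="eng", tgt_lang="fra", return_tensors="pt")

# Audio mode: returns input_features from the feature extractor instead.
waveform = np.random.randn(16000).astype(np.float32)
audio_inputs = processor(audios=waveform, sampling_rate=16000, return_tensors="pt")
```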
docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py ADDED
@@ -0,0 +1,450 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Tokenization class for SeamlessM4T."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ from tokenizers import processors
22
+
23
+ from ...tokenization_utils import (
24
+ BatchEncoding,
25
+ PreTokenizedInput,
26
+ TextInput,
27
+ )
28
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
29
+ from ...utils import PaddingStrategy, is_sentencepiece_available, logging
30
+
31
+
32
+ if is_sentencepiece_available():
33
+ from .tokenization_seamless_m4t import SeamlessM4TTokenizer
34
+ else:
35
+ SeamlessM4TTokenizer = None
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
40
+
41
+
42
+ class SeamlessM4TTokenizerFast(PreTrainedTokenizerFast):
43
+ """
44
+ Construct a "fast" SeamlessM4T tokenizer (backed by HuggingFace's *tokenizers* library). Based on
45
+ [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
46
+
47
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
48
+ refer to this superclass for more information regarding those methods.
49
+
50
+ The tokenization method is `<language code> <tokens> <eos>` for source language documents, and `<eos> <language
51
+ code> <tokens> <eos>` for target language documents.
52
+
53
+ Examples:
54
+
55
+ ```python
56
+ >>> from transformers import SeamlessM4TTokenizerFast
57
+
58
+ >>> tokenizer = SeamlessM4TTokenizerFast.from_pretrained(
59
+ ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra"
60
+ ... )
61
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
62
+ >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
63
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
64
+ ```
65
+
66
+ Args:
67
+ vocab_file (`str`, *optional*):
68
+ Path to the vocabulary file.
69
+ tokenizer_file (`str`, *optional*):
70
+ The path to a tokenizer file to use instead of the vocab file.
71
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
72
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
73
+
74
+ <Tip>
75
+
76
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
77
+ sequence. The token used is the `cls_token`.
78
+
79
+ </Tip>
80
+
81
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
82
+ The end of sequence token.
83
+
84
+ <Tip>
85
+
86
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
87
+ The token used is the `sep_token`.
88
+
89
+ </Tip>
90
+
91
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
92
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
93
+ sequence classification or for a text and a question for question answering. It is also used as the last
94
+ token of a sequence built with special tokens.
95
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
96
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
97
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
98
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
99
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
100
+ token instead.
101
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
102
+ The token used for padding, for example when batching sequences of different lengths.
103
+ src_lang (`str`, *optional*, defaults to `"eng"`):
104
+ The language to use as source language for translation.
105
+ tgt_lang (`str`, *optional*, defaults to `"fra"`):
106
+ The language to use as target language for translation.
107
+ additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
108
+ A tuple or a list of additional special tokens.
109
+ """
110
+
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = SeamlessM4TTokenizer
113
+ model_input_names = ["input_ids", "attention_mask"]
114
+
115
+ prefix_tokens: List[int] = []
116
+ suffix_tokens: List[int] = []
117
+
118
+ def __init__(
119
+ self,
120
+ vocab_file=None,
121
+ tokenizer_file=None,
122
+ bos_token="<s>",
123
+ eos_token="</s>",
124
+ sep_token="</s>",
125
+ cls_token="<s>",
126
+ unk_token="<unk>",
127
+ pad_token="<pad>",
128
+ src_lang="eng",
129
+ tgt_lang="fra",
130
+ additional_special_tokens=None,
131
+ **kwargs,
132
+ ):
133
+ super().__init__(
134
+ vocab_file=vocab_file,
135
+ tokenizer_file=tokenizer_file,
136
+ bos_token=bos_token,
137
+ eos_token=eos_token,
138
+ sep_token=sep_token,
139
+ cls_token=cls_token,
140
+ unk_token=unk_token,
141
+ pad_token=pad_token,
142
+ src_lang=src_lang,
143
+ tgt_lang=tgt_lang,
144
+ additional_special_tokens=additional_special_tokens,
145
+ **kwargs,
146
+ )
147
+
148
+ self.vocab_file = vocab_file
149
+ self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang
150
+ self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang
151
+ self.set_src_lang_special_tokens(self._src_lang)
152
+ self.set_tgt_lang_special_tokens(self._tgt_lang)
153
+
154
+ @property
155
+ def can_save_slow_tokenizer(self) -> bool:
156
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
157
+
158
+ @property
159
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang
160
+ def src_lang(self) -> str:
161
+ return self._src_lang
162
+
163
+ @src_lang.setter
164
+ def src_lang(self, new_src_lang: str) -> None:
165
+ if "__" not in new_src_lang:
166
+ self._src_lang = f"__{new_src_lang}__"
167
+ else:
168
+ self._src_lang = new_src_lang
169
+ self.set_src_lang_special_tokens(self._src_lang)
170
+
171
+ @property
172
+ def tgt_lang(self) -> str:
173
+ return self._tgt_lang
174
+
175
+ @tgt_lang.setter
176
+ def tgt_lang(self, new_tgt_lang: str) -> None:
177
+ if "__" not in new_tgt_lang:
178
+ self._tgt_lang = f"__{new_tgt_lang}__"
179
+ else:
180
+ self._tgt_lang = new_tgt_lang
181
+ self.set_tgt_lang_special_tokens(self._tgt_lang)
182
+
183
+ def build_inputs_with_special_tokens(
184
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
185
+ ) -> List[int]:
186
+ """
187
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
188
+ adding special tokens. The special tokens depend on calling set_lang.
189
+
190
+ An SeamlessM4T sequence has the following format, where `X` represents the sequence:
191
+
192
+ - `input_ids` (for encoder) `[src_lang_code] X [eos]`
193
+ - `decoder_input_ids`: (for decoder) `[eos, tgt_lang_code] X [eos]`
194
+
195
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
196
+ separator.
197
+
198
+ Args:
199
+ token_ids_0 (`List[int]`):
200
+ List of IDs to which the special tokens will be added.
201
+ token_ids_1 (`List[int]`, *optional*):
202
+ Optional second list of IDs for sequence pairs.
203
+
204
+ Returns:
205
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
206
+ """
207
+ if token_ids_1 is None:
208
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
209
+ # We don't expect to process pairs, but leave the pair logic for API consistency
210
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
211
+
212
+ # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.create_token_type_ids_from_sequences
213
+ def create_token_type_ids_from_sequences(
214
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
215
+ ) -> List[int]:
216
+ """
217
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. NLLB does not
218
+ make use of token type ids, therefore a list of zeros is returned.
219
+
220
+ Args:
221
+ token_ids_0 (`List[int]`):
222
+ List of IDs.
223
+ token_ids_1 (`List[int]`, *optional*):
224
+ Optional second list of IDs for sequence pairs.
225
+
226
+ Returns:
227
+ `List[int]`: List of zeros.
228
+
229
+ """
230
+
231
+ sep = [self.sep_token_id]
232
+ cls = [self.cls_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(cls + token_ids_0 + sep) * [0]
236
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
237
+
238
+ def _build_translation_inputs(
239
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
240
+ ):
241
+ """Used by translation pipeline, to prepare inputs for the generate function"""
242
+ if src_lang is None or tgt_lang is None:
243
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
244
+ self.src_lang = src_lang
245
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
246
+ if "__" not in tgt_lang:
247
+ tgt_lang = f"__{tgt_lang}__"
248
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
249
+ inputs["forced_bos_token_id"] = tgt_lang_id
250
+ return inputs
251
+
252
+ # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.prepare_seq2seq_batch with "fra_Latn"->"fra", "eng_Latn"->"eng"
253
+ def prepare_seq2seq_batch(
254
+ self,
255
+ src_texts: List[str],
256
+ src_lang: str = "eng",
257
+ tgt_texts: Optional[List[str]] = None,
258
+ tgt_lang: str = "fra",
259
+ **kwargs,
260
+ ) -> BatchEncoding:
261
+ self.src_lang = src_lang
262
+ self.tgt_lang = tgt_lang
263
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
264
+
265
+ # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_input_mode
266
+ def _switch_to_input_mode(self):
267
+ return self.set_src_lang_special_tokens(self.src_lang)
268
+
269
+ # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_target_mode
270
+ def _switch_to_target_mode(self):
271
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
272
+
273
+ def set_src_lang_special_tokens(self, src_lang) -> None:
274
+ """Reset the special tokens to the source lang setting.
275
+ Prefix=[src_lang_code], suffix=[eos].
276
+ """
277
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
278
+
279
+ if self.cur_lang_code == self.unk_token_id:
280
+ logger.warning_once(
281
+ f"`tgt_lang={src_lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
282
+ )
283
+
284
+ self.init_kwargs["src_lang"] = src_lang
285
+
286
+ self.prefix_tokens = [self.cur_lang_code]
287
+ self.suffix_tokens = [self.eos_token_id]
288
+
289
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
290
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
291
+
292
+ self._tokenizer.post_processor = processors.TemplateProcessing(
293
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
294
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
295
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
296
+ )
297
+
298
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
299
+ """Reset the special tokens to the target lang setting.
300
+ Prefix=[eos, tgt_lang_code] and suffix=[eos].
301
+ """
302
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
303
+
304
+ if self.cur_lang_code == self.unk_token_id:
305
+ logger.warning_once(
306
+ f"`tgt_lang={lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
307
+ )
308
+
309
+ self.init_kwargs["tgt_lang"] = lang
310
+
311
+ self.prefix_tokens = [self.eos_token_id, self.cur_lang_code]
312
+ self.suffix_tokens = [self.eos_token_id]
313
+
314
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
315
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
316
+
317
+ self._tokenizer.post_processor = processors.TemplateProcessing(
318
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
319
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
320
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
321
+ )
322
+
323
+ # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.save_vocabulary
324
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
325
+ if not self.can_save_slow_tokenizer:
326
+ raise ValueError(
327
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
328
+ "tokenizer."
329
+ )
330
+
331
+ if not os.path.isdir(save_directory):
332
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
333
+ return
334
+ out_vocab_file = os.path.join(
335
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
336
+ )
337
+
338
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
339
+ copyfile(self.vocab_file, out_vocab_file)
340
+
341
+ return (out_vocab_file,)
342
+
343
+ @classmethod
344
+ def _from_pretrained(
345
+ cls,
346
+ resolved_vocab_files,
347
+ pretrained_model_name_or_path,
348
+ init_configuration,
349
+ *init_inputs,
350
+ token=None,
351
+ cache_dir=None,
352
+ local_files_only=False,
353
+ _commit_hash=None,
354
+ _is_local=False,
355
+ **kwargs,
356
+ ):
357
+ tokenizer = super()._from_pretrained(
358
+ resolved_vocab_files,
359
+ pretrained_model_name_or_path,
360
+ init_configuration,
361
+ *init_inputs,
362
+ token=token,
363
+ cache_dir=cache_dir,
364
+ local_files_only=local_files_only,
365
+ _commit_hash=_commit_hash,
366
+ _is_local=_is_local,
367
+ **kwargs,
368
+ )
369
+
370
+ # ensure also set after from pretrained
371
+ tokenizer.set_src_lang_special_tokens(tokenizer._src_lang)
372
+ tokenizer.set_tgt_lang_special_tokens(tokenizer._tgt_lang)
373
+
374
+ return tokenizer
375
+
376
+ def __call__(
377
+ self,
378
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
379
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
380
+ text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
381
+ text_pair_target: Optional[
382
+ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
383
+ ] = None,
384
+ padding: Union[bool, str, PaddingStrategy] = True,
385
+ pad_to_multiple_of: Optional[int] = 2,
386
+ src_lang: Optional[str] = None,
387
+ tgt_lang: Optional[str] = None,
388
+ **kwargs,
389
+ ):
390
+ """
391
+ Args:
392
+ text (`str`, `List[str]`, `List[List[str]]`, *optional*):
393
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
394
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
395
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
396
+ text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
397
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
398
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
399
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
400
+ text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
401
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
402
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
403
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
404
+ text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
405
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
406
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
407
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
408
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
409
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
410
+ index) among:
411
+
412
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
413
+ sequence is provided).
414
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
415
+ acceptable input length for the model if that argument is not provided.
416
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
417
+ lengths).
418
+ pad_to_multiple_of (`int`, *optional*, defaults to 2):
419
+ If set, will pad the sequence to a multiple of the provided value.
420
+
421
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
422
+ `>= 7.0` (Volta).
423
+ src_lang (`str`, *optional*):
424
+ A string representing the source language. If not specified, the last `src_lang` specified (either
425
+ during initialization or when calling this tokenizer) will be used.
426
+ tgt_lang (`str`, *optional*):
427
+ A string representing the target language. If not specified, the last `tgt_lang` specified (either
428
+ during initialization or when calling this tokenizer) will be used.
429
+ kwargs (*optional*):
430
+ Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizerFast.__call__`].
431
+ """
432
+ if src_lang is not None:
433
+ self.src_lang = src_lang
434
+ if tgt_lang is not None:
435
+ self.tgt_lang = tgt_lang
436
+
437
+ output = super().__call__(
438
+ text=text,
439
+ text_pair=text_pair,
440
+ text_target=text_target,
441
+ text_pair_target=text_pair_target,
442
+ padding=padding,
443
+ pad_to_multiple_of=pad_to_multiple_of,
444
+ **kwargs,
445
+ )
446
+
447
+ return output
448
+
449
+
450
+ __all__ = ["SeamlessM4TTokenizerFast"]
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py ADDED
@@ -0,0 +1,404 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Converting Meta SeamlessM4Tv2 checkpoints from seamless_communication to HF."""
16
+
17
+ import argparse
18
+ import os
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ from accelerate.utils.modeling import find_tied_parameters
23
+ from seamless_communication.inference import Translator
24
+
25
+ from transformers import (
26
+ SeamlessM4TFeatureExtractor,
27
+ SeamlessM4TProcessor,
28
+ SeamlessM4TTokenizer,
29
+ SeamlessM4Tv2Config,
30
+ SeamlessM4Tv2Model,
31
+ )
32
+ from transformers.utils import logging
33
+
34
+
35
+ # fmt: off
36
+ UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ]
37
+ # fmt: on
38
+
39
+ # fmt: off
40
+ VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",]
41
+ # fmt: on
42
+
43
+ # fmt: off
44
+ LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",]
45
+ # fmt: on
46
+
47
+
48
+ def assert_param_count(model_1, model_2):
49
+ count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
50
+ count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
51
+ assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
52
+
53
+
54
+ def param_count(model):
55
+ return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
56
+
57
+
58
+ def _grab_best_device(use_gpu=True):
59
+ if torch.cuda.device_count() > 0 and use_gpu:
60
+ device = "cuda"
61
+ else:
62
+ device = "cpu"
63
+ return torch.device(device)
64
+
65
+
66
+ logging.set_verbosity_info()
67
+ logger = logging.get_logger(__name__)
68
+
69
+ vocoder_convert_list = [
70
+ ("ups", "hifi_gan.upsampler"),
71
+ ("conv_pre", "hifi_gan.conv_pre"),
72
+ ("resblocks", "hifi_gan.resblocks"),
73
+ ("conv_post", "hifi_gan.conv_post"),
74
+ ("lang", "language_embedding"),
75
+ ("spkr", "speaker_embedding"),
76
+ ("dict.", "unit_embedding."),
77
+ ("dur_predictor.conv1.0", "dur_predictor.conv1"),
78
+ ("dur_predictor.conv2.0", "dur_predictor.conv2"),
79
+ ]
80
+
81
+ # order is important
82
+ wav2vec_convert_list = [
83
+ ("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"),
84
+ ("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
85
+ ("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
86
+ ("speech_encoder.inner.layers", "encoder.layers"),
87
+ ("speech_encoder.inner_layer_norm", "encoder.layer_norm"),
88
+ ("speech_encoder.adaptor_layers", "adapter.layers"),
89
+ ("inner_proj", "intermediate_dense"),
90
+ ("self_attn.output_proj", "self_attn.linear_out"),
91
+ ("output_proj", "output_dense"),
92
+ ("self_attn.k_proj", "self_attn.linear_k"),
93
+ ("self_attn.v_proj", "self_attn.linear_v"),
94
+ ("self_attn.q_proj", "self_attn.linear_q"),
95
+ ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
96
+ ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
97
+ ("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"),
98
+ ("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
99
+ ("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
100
+ ("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
101
+ ("conv.depthwise_conv", "conv_module.depthwise_conv"),
102
+ ("conv.batch_norm", "conv_module.batch_norm"),
103
+ ("conv.layer_norm", "conv_module.depthwise_layer_norm"),
104
+ ("conv_layer_norm", "conv_module.layer_norm"),
105
+ ("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"),
106
+ ("speech_encoder.proj2", "intermediate_ffn.output_dense"),
107
+ ("speech_encoder.layer_norm", "inner_layer_norm"),
108
+ ]
109
+
110
+ t2u_convert_list = [
111
+ ("t2u_model.final_proj", "lm_head"),
112
+ ("t2u_model.", "model."),
113
+ ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
114
+ ("encoder_decoder_attn", "cross_attention"),
115
+ ("linear_k", "k_proj"),
116
+ ("linear_v", "v_proj"),
117
+ ("linear_q", "q_proj"),
118
+ ("ffn.inner_proj", "ffn.fc1"),
119
+ ("ffn.output_proj", "ffn.fc2"),
120
+ ("output_proj", "out_proj"),
121
+ ("decoder_frontend.embed_char", "decoder.embed_char"),
122
+ ("decoder_frontend.pos_emb_alpha_char", "decoder.pos_emb_alpha_char"),
123
+ ("decoder_frontend.embed", "decoder.embed_tokens"),
124
+ ("decoder_frontend.pos_emb_alpha", "decoder.pos_emb_alpha"),
125
+ ("conv1d.conv", "conv"),
126
+ ("conv1d_layer_norm", "conv_layer_norm"),
127
+ ("decoder_frontend.variance_adaptor", "decoder"),
128
+ ("duration_predictor.conv1.0", "duration_predictor.conv1"),
129
+ ("duration_predictor.conv2.0", "duration_predictor.conv2"),
130
+ ]
131
+
132
+ text_convert_list = [
133
+ ("text_encoder.", ""),
134
+ ("text_decoder.", ""),
135
+ ("text_encoder_frontend.embed", "embed_tokens"),
136
+ ("text_decoder_frontend.embed", "embed_tokens"),
137
+ ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
138
+ ("encoder_decoder_attn", "cross_attention"),
139
+ ("linear_k", "k_proj"),
140
+ ("linear_v", "v_proj"),
141
+ ("linear_q", "q_proj"),
142
+ ("ffn.inner_proj", "ffn.fc1"),
143
+ ("ffn.output_proj", "ffn.fc2"),
144
+ ("output_proj", "out_proj"),
145
+ ("final_proj", "lm_head"),
146
+ ]
147
+
148
+ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
149
+ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
150
+ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub")
151
+
152
+
153
+ def _load_hf_config():
154
+ return SeamlessM4Tv2Config()
155
+
156
+
157
+ def _convert_model(
158
+ original_model,
159
+ hf_model,
160
+ convert_list,
161
+ device,
162
+ unwanted_prefix="model.",
163
+ filter_state_dict="speech",
164
+ exclude_state_dict=None,
165
+ ):
166
+ state_dict = original_model.state_dict()
167
+
168
+ # filter func
169
+ if isinstance(filter_state_dict, str):
170
+
171
+ def filter_func(x):
172
+ return filter_state_dict in x[0]
173
+
174
+ else:
175
+
176
+ def filter_func(item):
177
+ if exclude_state_dict is not None and exclude_state_dict in item[0]:
178
+ return False
179
+ for filter_el in filter_state_dict:
180
+ if filter_el in item[0]:
181
+ return True
182
+
183
+ return False
184
+
185
+ state_dict = dict(filter(filter_func, state_dict.items()))
186
+
187
+ for k, v in list(state_dict.items()):
188
+ new_k = k[len(unwanted_prefix) :]
189
+ for old_layer_name, new_layer_name in convert_list:
190
+ if old_layer_name in new_k:
191
+ new_k = new_k.replace(old_layer_name, new_layer_name)
192
+
193
+ # must do it by hand
194
+ if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric():
195
+ new_k = new_k.replace("layer_norm", "final_layer_norm")
196
+
197
+ state_dict[new_k] = state_dict.pop(k)
198
+
199
+ extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
200
+ extra_keys = set(extra_keys)
201
+ missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
202
+ missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k})
203
+ if len(extra_keys) != 0:
204
+ raise ValueError(f"extra keys found: {extra_keys}")
205
+ if len(missing_keys) != 0:
206
+ raise ValueError(f"missing keys: {missing_keys}")
207
+ hf_model.load_state_dict(state_dict, strict=False)
208
+ n_params = param_count(hf_model)
209
+
210
+ logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params")
211
+
212
+ hf_model.eval()
213
+ hf_model.to(device)
214
+ del state_dict
215
+
216
+ return hf_model
217
+
218
+
219
+ def load_model(save_dir, model_type, repo_id):
220
+ """
221
+ Meta SeamlessM4Tv2 is made of 8 main components:
222
+ - speech_encoder (#1) and speech_encoder_frontend (#2)
223
+ - t2u_model (#3)
224
+ - text_encoder (#4) and text_encoder_frontend (#5)
225
+ - text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend]
226
+ - final_proj (#7)
227
+ - vocoder (#8)
228
+ """
229
+ device = _grab_best_device()
230
+ name = "seamlessM4T_v2_large"
231
+
232
+ original_model = Translator(name, "vocoder_v2", device, dtype=torch.float32)
233
+
234
+ ######### TOKENIZER
235
+
236
+ langs = LARGE_SUPPORTED_LANGUAGES
237
+ langs = [f"__{lang}__" for lang in langs]
238
+ vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model")
239
+
240
+ save_dir = os.path.join(save_dir, name)
241
+ Path(save_dir).mkdir(exist_ok=True)
242
+
243
+ tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs)
244
+
245
+ sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__")
246
+
247
+ tokenizer.save_pretrained(save_dir)
248
+ tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir)
249
+
250
+ if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"):
251
+ raise ValueError(
252
+ f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}"
253
+ )
254
+
255
+ ####### get language to ids dict
256
+ text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs}
257
+ # offset: vocoder unit vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages)
258
+ t2u_lang_code_to_id = {
259
+ code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES)
260
+ for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES)
261
+ }
262
+ vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)}
263
+
264
+ ######### FE
265
+
266
+ fe = SeamlessM4TFeatureExtractor(language_code=langs)
267
+
268
+ fe.save_pretrained(save_dir)
269
+ fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir)
270
+
271
+ processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer)
272
+ processor.save_pretrained(save_dir)
273
+ processor.push_to_hub(repo_id=repo_id, create_pr=True)
274
+
275
+ processor = SeamlessM4TProcessor.from_pretrained(save_dir)
276
+
277
+ ######## Model
278
+
279
+ # init config
280
+ hf_config = _load_hf_config()
281
+
282
+ ######## get id_to_text and char_to_id from original model tokenizers
283
+ id_to_text = {i: original_model.text_tokenizer.model.index_to_token(i) for i in range(hf_config.vocab_size)}
284
+ char_to_id = {
285
+ original_model.model.t2u_model.decoder_frontend.char_tokenizer.model.index_to_token(i): i for i in range(10904)
286
+ }
287
+
288
+ # init model
289
+ hf_model = SeamlessM4Tv2Model(hf_config)
290
+
291
+ hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id)
292
+ hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id)
293
+ hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id)
294
+ hf_model.generation_config.__setattr__("id_to_text", id_to_text)
295
+ hf_model.generation_config.__setattr__("char_to_id", char_to_id)
296
+
297
+ # -1. take care of vocoder
298
+ # similarly to speech T5 must apply and remove weight norm
299
+ hf_model.vocoder.apply_weight_norm()
300
+ hf_model.vocoder = _convert_model(
301
+ original_model,
302
+ hf_model.vocoder,
303
+ vocoder_convert_list,
304
+ device,
305
+ unwanted_prefix="vocoder.code_generator.",
306
+ filter_state_dict="vocoder",
307
+ )
308
+ hf_model.vocoder.remove_weight_norm()
309
+
310
+ # 1. take care of speech encoder
311
+ wav2vec = hf_model.speech_encoder
312
+ hf_model.speech_encoder = _convert_model(
313
+ original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech"
314
+ )
315
+
316
+ # 2. take care of t2u
317
+
318
+ hf_model.t2u_model = _convert_model(
319
+ original_model,
320
+ hf_model.t2u_model,
321
+ t2u_convert_list,
322
+ device,
323
+ unwanted_prefix="model.",
324
+ filter_state_dict="t2u_model",
325
+ )
326
+
327
+ # 3. take care of text encoder
328
+ hf_model.text_encoder = _convert_model(
329
+ original_model,
330
+ hf_model.text_encoder,
331
+ text_convert_list,
332
+ device,
333
+ unwanted_prefix="model.",
334
+ filter_state_dict=["model.text_encoder"],
335
+ exclude_state_dict="t2u_model",
336
+ )
337
+
338
+ # 4. take care of text decoder
339
+ hf_model.text_decoder = _convert_model(
340
+ original_model,
341
+ hf_model.text_decoder,
342
+ text_convert_list,
343
+ device,
344
+ unwanted_prefix="model.",
345
+ filter_state_dict=["model.text_decoder"],
346
+ exclude_state_dict="t2u_model",
347
+ )
348
+
349
+ # 5. take care of final proj
350
+ hf_model.lm_head = _convert_model(
351
+ original_model,
352
+ hf_model.lm_head,
353
+ [("final_proj.", "")],
354
+ device,
355
+ unwanted_prefix="model.",
356
+ filter_state_dict=["model.final_proj"],
357
+ exclude_state_dict="t2u_model",
358
+ )
359
+
360
+ # sanity check
361
+ print(find_tied_parameters(hf_model))
362
+
363
+ count_1 = param_count(hf_model)
364
+ count_2 = param_count(original_model)
365
+
366
+ print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}")
367
+ print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}")
368
+
369
+ del original_model
370
+
371
+ hf_model.generation_config._from_model_config = False
372
+ hf_model.save_pretrained(save_dir)
373
+ hf_model.push_to_hub(repo_id=repo_id, create_pr=True)
374
+ hf_model = SeamlessM4Tv2Model.from_pretrained(save_dir)
375
+
376
+
377
+ if __name__ == "__main__":
378
+ parser = argparse.ArgumentParser()
379
+ # Required parameters
380
+
381
+ parser.add_argument(
382
+ "--model_type",
383
+ default="large",
384
+ type=str,
385
+ help="Model type.",
386
+ )
387
+
388
+ parser.add_argument(
389
+ "--save_dir",
390
+ default="/home/ubuntu/weights_v2",
391
+ type=str,
392
+ help="Path to the output PyTorch model.",
393
+ )
394
+
395
+ parser.add_argument(
396
+ "--repo_id",
397
+ default="facebook/seamless-m4t-v2-large",
398
+ type=str,
399
+ help="Repo ID.",
400
+ )
401
+
402
+ args = parser.parse_args()
403
+
404
+ load_model(args.save_dir, args.model_type, args.repo_id)
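
As a side note on `_convert_model`, the key mapping is just prefix stripping followed by ordered substring substitutions. A minimal standalone sketch of that step (the example key is illustrative, and only two entries of `wav2vec_convert_list` are reproduced):

```python
# Sketch of the renaming loop inside _convert_model: strip the unwanted prefix,
# then apply the (old, new) substitutions in order.
convert_list = [
    ("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"),
    ("self_attn.k_proj", "self_attn.linear_k"),
]

def rename(key: str, unwanted_prefix: str = "model.") -> str:
    new_key = key[len(unwanted_prefix):]
    for old_name, new_name in convert_list:
        if old_name in new_key:
            new_key = new_key.replace(old_name, new_name)
    return new_key

print(rename("model.speech_encoder_frontend.model_dim_proj.weight"))
# feature_projection.projection.weight
```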
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/segformer/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_segformer import *
22
+ from .feature_extraction_segformer import *
23
+ from .image_processing_segformer import *
24
+ from .modeling_segformer import *
25
+ from .modeling_tf_segformer import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/segformer/configuration_segformer.py ADDED
@@ -0,0 +1,171 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 NVIDIA and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SegFormer model configuration"""
16
+
17
+ import warnings
18
+ from collections import OrderedDict
19
+ from typing import Mapping
20
+
21
+ from packaging import version
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...onnx import OnnxConfig
25
+ from ...utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class SegformerConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`SegformerModel`]. It is used to instantiate an
34
+ SegFormer model according to the specified arguments, defining the model architecture. Instantiating a
35
+ configuration with the defaults will yield a similar configuration to that of the SegFormer
36
+ [nvidia/segformer-b0-finetuned-ade-512-512](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512)
37
+ architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ num_channels (`int`, *optional*, defaults to 3):
44
+ The number of input channels.
45
+ num_encoder_blocks (`int`, *optional*, defaults to 4):
46
+ The number of encoder blocks (i.e. stages in the Mix Transformer encoder).
47
+ depths (`List[int]`, *optional*, defaults to `[2, 2, 2, 2]`):
48
+ The number of layers in each encoder block.
49
+ sr_ratios (`List[int]`, *optional*, defaults to `[8, 4, 2, 1]`):
50
+ Sequence reduction ratios in each encoder block.
51
+ hidden_sizes (`List[int]`, *optional*, defaults to `[32, 64, 160, 256]`):
52
+ Dimension of each of the encoder blocks.
53
+ patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`):
54
+ Patch size before each encoder block.
55
+ strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
56
+ Stride before each encoder block.
57
+ num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`):
58
+ Number of attention heads for each attention layer in each block of the Transformer encoder.
59
+ mlp_ratios (`List[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
60
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
61
+ encoder blocks.
62
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
63
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
64
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
65
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
66
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
67
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
68
+ The dropout ratio for the attention probabilities.
69
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
70
+ The dropout probability before the classification head.
71
+ initializer_range (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
74
+ The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
75
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
76
+ The epsilon used by the layer normalization layers.
77
+ decoder_hidden_size (`int`, *optional*, defaults to 256):
78
+ The dimension of the all-MLP decode head.
79
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
80
+ The index that is ignored by the loss function of the semantic segmentation model.
81
+
82
+ Example:
83
+
84
+ ```python
85
+ >>> from transformers import SegformerModel, SegformerConfig
86
+
87
+ >>> # Initializing a SegFormer nvidia/segformer-b0-finetuned-ade-512-512 style configuration
88
+ >>> configuration = SegformerConfig()
89
+
90
+ >>> # Initializing a model from the nvidia/segformer-b0-finetuned-ade-512-512 style configuration
91
+ >>> model = SegformerModel(configuration)
92
+
93
+ >>> # Accessing the model configuration
94
+ >>> configuration = model.config
95
+ ```"""
96
+
97
+ model_type = "segformer"
98
+
99
+ def __init__(
100
+ self,
101
+ num_channels=3,
102
+ num_encoder_blocks=4,
103
+ depths=[2, 2, 2, 2],
104
+ sr_ratios=[8, 4, 2, 1],
105
+ hidden_sizes=[32, 64, 160, 256],
106
+ patch_sizes=[7, 3, 3, 3],
107
+ strides=[4, 2, 2, 2],
108
+ num_attention_heads=[1, 2, 5, 8],
109
+ mlp_ratios=[4, 4, 4, 4],
110
+ hidden_act="gelu",
111
+ hidden_dropout_prob=0.0,
112
+ attention_probs_dropout_prob=0.0,
113
+ classifier_dropout_prob=0.1,
114
+ initializer_range=0.02,
115
+ drop_path_rate=0.1,
116
+ layer_norm_eps=1e-6,
117
+ decoder_hidden_size=256,
118
+ semantic_loss_ignore_index=255,
119
+ **kwargs,
120
+ ):
121
+ super().__init__(**kwargs)
122
+
123
+ if "reshape_last_stage" in kwargs and kwargs["reshape_last_stage"] is False:
124
+ warnings.warn(
125
+ "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be"
126
+ " removed, as the behaviour will default to that of reshape_last_stage = True.",
127
+ FutureWarning,
128
+ )
129
+
130
+ self.num_channels = num_channels
131
+ self.num_encoder_blocks = num_encoder_blocks
132
+ self.depths = depths
133
+ self.sr_ratios = sr_ratios
134
+ self.hidden_sizes = hidden_sizes
135
+ self.patch_sizes = patch_sizes
136
+ self.strides = strides
137
+ self.mlp_ratios = mlp_ratios
138
+ self.num_attention_heads = num_attention_heads
139
+ self.hidden_act = hidden_act
140
+ self.hidden_dropout_prob = hidden_dropout_prob
141
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
142
+ self.classifier_dropout_prob = classifier_dropout_prob
143
+ self.initializer_range = initializer_range
144
+ self.drop_path_rate = drop_path_rate
145
+ self.layer_norm_eps = layer_norm_eps
146
+ self.decoder_hidden_size = decoder_hidden_size
147
+ self.reshape_last_stage = kwargs.get("reshape_last_stage", True)
148
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
149
+
150
+
151
+ class SegformerOnnxConfig(OnnxConfig):
152
+ torch_onnx_minimum_version = version.parse("1.11")
153
+
154
+ @property
155
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
156
+ return OrderedDict(
157
+ [
158
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
159
+ ]
160
+ )
161
+
162
+ @property
163
+ def atol_for_validation(self) -> float:
164
+ return 1e-4
165
+
166
+ @property
167
+ def default_onnx_opset(self) -> int:
168
+ return 12
169
+
170
+
171
+ __all__ = ["SegformerConfig", "SegformerOnnxConfig"]
docs/transformers/build/lib/transformers/models/segformer/convert_segformer_original_to_pytorch.py ADDED
@@ -0,0 +1,387 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SegFormer checkpoints."""
16
+
17
+ import argparse
18
+ import json
19
+ from collections import OrderedDict
20
+ from pathlib import Path
21
+
22
+ import requests
23
+ import torch
24
+ from huggingface_hub import hf_hub_download
25
+ from PIL import Image
26
+
27
+ from transformers import (
28
+ SegformerConfig,
29
+ SegformerForImageClassification,
30
+ SegformerForSemanticSegmentation,
31
+ SegformerImageProcessor,
32
+ )
33
+ from transformers.utils import logging
34
+
35
+
36
+ logging.set_verbosity_info()
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ def rename_keys(state_dict, encoder_only=False):
41
+ new_state_dict = OrderedDict()
42
+ for key, value in state_dict.items():
43
+ if encoder_only and not key.startswith("head"):
44
+ key = "segformer.encoder." + key
45
+ if key.startswith("backbone"):
46
+ key = key.replace("backbone", "segformer.encoder")
47
+ if "patch_embed" in key:
48
+ # replace for example patch_embed1 by patch_embeddings.0
49
+ idx = key[key.find("patch_embed") + len("patch_embed")]
50
+ key = key.replace(f"patch_embed{idx}", f"patch_embeddings.{int(idx) - 1}")
51
+ if "norm" in key:
52
+ key = key.replace("norm", "layer_norm")
53
+ if "segformer.encoder.layer_norm" in key:
54
+ # replace for example layer_norm1 by layer_norm.0
55
+ idx = key[key.find("segformer.encoder.layer_norm") + len("segformer.encoder.layer_norm")]
56
+ key = key.replace(f"layer_norm{idx}", f"layer_norm.{int(idx) - 1}")
57
+ if "layer_norm1" in key:
58
+ key = key.replace("layer_norm1", "layer_norm_1")
59
+ if "layer_norm2" in key:
60
+ key = key.replace("layer_norm2", "layer_norm_2")
61
+ if "block" in key:
62
+ # replace for example block1 by block.0
63
+ idx = key[key.find("block") + len("block")]
64
+ key = key.replace(f"block{idx}", f"block.{int(idx) - 1}")
65
+ if "attn.q" in key:
66
+ key = key.replace("attn.q", "attention.self.query")
67
+ if "attn.proj" in key:
68
+ key = key.replace("attn.proj", "attention.output.dense")
69
+ if "attn" in key:
70
+ key = key.replace("attn", "attention.self")
71
+ if "fc1" in key:
72
+ key = key.replace("fc1", "dense1")
73
+ if "fc2" in key:
74
+ key = key.replace("fc2", "dense2")
75
+ if "linear_pred" in key:
76
+ key = key.replace("linear_pred", "classifier")
77
+ if "linear_fuse" in key:
78
+ key = key.replace("linear_fuse.conv", "linear_fuse")
79
+ key = key.replace("linear_fuse.bn", "batch_norm")
80
+ if "linear_c" in key:
81
+ # replace for example linear_c4 by linear_c.3
82
+ idx = key[key.find("linear_c") + len("linear_c")]
83
+ key = key.replace(f"linear_c{idx}", f"linear_c.{int(idx) - 1}")
84
+ if key.startswith("head"):
85
+ key = key.replace("head", "classifier")
86
+ new_state_dict[key] = value
87
+
88
+ return new_state_dict
89
+
90
+
91
+ def read_in_k_v(state_dict, config):
92
+ # for each of the encoder blocks:
93
+ for i in range(config.num_encoder_blocks):
94
+ for j in range(config.depths[i]):
95
+ # read in weights + bias of keys and values (which is a single matrix in the original implementation)
96
+ kv_weight = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.weight")
97
+ kv_bias = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.bias")
98
+ # next, add keys and values (in that order) to the state dict
99
+ state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.weight"] = kv_weight[
100
+ : config.hidden_sizes[i], :
101
+ ]
102
+ state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.bias"] = kv_bias[: config.hidden_sizes[i]]
103
+ state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.weight"] = kv_weight[
104
+ config.hidden_sizes[i] :, :
105
+ ]
106
+ state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.bias"] = kv_bias[
107
+ config.hidden_sizes[i] :
108
+ ]
109
+
110
+
111
+ # We will verify our results on a COCO image
112
+ def prepare_img():
113
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
114
+ image = Image.open(requests.get(url, stream=True).raw)
115
+
116
+ return image
117
+
118
+
119
+ @torch.no_grad()
120
+ def convert_segformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path):
121
+ """
122
+ Copy/paste/tweak model's weights to our SegFormer structure.
123
+ """
124
+
125
+ # load default SegFormer configuration
126
+ config = SegformerConfig()
127
+ encoder_only = False
128
+
129
+ # set attributes based on model_name
130
+ repo_id = "huggingface/label-files"
131
+ if "segformer" in model_name:
132
+ size = model_name[len("segformer.") : len("segformer.") + 2]
133
+ if "ade" in model_name:
134
+ config.num_labels = 150
135
+ filename = "ade20k-id2label.json"
136
+ expected_shape = (1, 150, 128, 128)
137
+ elif "city" in model_name:
138
+ config.num_labels = 19
139
+ filename = "cityscapes-id2label.json"
140
+ expected_shape = (1, 19, 128, 128)
141
+ else:
142
+ raise ValueError(f"Model {model_name} not supported")
143
+ elif "mit" in model_name:
144
+ encoder_only = True
145
+ size = model_name[4:6]
146
+ config.num_labels = 1000
147
+ filename = "imagenet-1k-id2label.json"
148
+ expected_shape = (1, 1000)
149
+ else:
150
+ raise ValueError(f"Model {model_name} not supported")
151
+
152
+ # set config attributes
153
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
154
+ id2label = {int(k): v for k, v in id2label.items()}
155
+ config.id2label = id2label
156
+ config.label2id = {v: k for k, v in id2label.items()}
157
+ if size == "b0":
158
+ pass
159
+ elif size == "b1":
160
+ config.hidden_sizes = [64, 128, 320, 512]
161
+ config.decoder_hidden_size = 256
162
+ elif size == "b2":
163
+ config.hidden_sizes = [64, 128, 320, 512]
164
+ config.decoder_hidden_size = 768
165
+ config.depths = [3, 4, 6, 3]
166
+ elif size == "b3":
167
+ config.hidden_sizes = [64, 128, 320, 512]
168
+ config.decoder_hidden_size = 768
169
+ config.depths = [3, 4, 18, 3]
170
+ elif size == "b4":
171
+ config.hidden_sizes = [64, 128, 320, 512]
172
+ config.decoder_hidden_size = 768
173
+ config.depths = [3, 8, 27, 3]
174
+ elif size == "b5":
175
+ config.hidden_sizes = [64, 128, 320, 512]
176
+ config.decoder_hidden_size = 768
177
+ config.depths = [3, 6, 40, 3]
178
+ else:
179
+ raise ValueError(f"Size {size} not supported")
180
+
181
+ # load image processor (only resize + normalize)
182
+ image_processor = SegformerImageProcessor(
183
+ image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
184
+ )
185
+
186
+ # prepare image
187
+ image = prepare_img()
188
+ pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
189
+
190
+ logger.info(f"Converting model {model_name}...")
191
+
192
+ # load original state dict
193
+ if encoder_only:
194
+ state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
195
+ else:
196
+ state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)["state_dict"]
197
+
198
+ # rename keys
199
+ state_dict = rename_keys(state_dict, encoder_only=encoder_only)
200
+ if not encoder_only:
201
+ del state_dict["decode_head.conv_seg.weight"]
202
+ del state_dict["decode_head.conv_seg.bias"]
203
+
204
+ # key and value matrices need special treatment
205
+ read_in_k_v(state_dict, config)
206
+
207
+ # create HuggingFace model and load state dict
208
+ if encoder_only:
209
+ config.reshape_last_stage = False
210
+ model = SegformerForImageClassification(config)
211
+ else:
212
+ model = SegformerForSemanticSegmentation(config)
213
+ model.load_state_dict(state_dict)
214
+ model.eval()
215
+
216
+ # forward pass
217
+ outputs = model(pixel_values)
218
+ logits = outputs.logits
219
+
220
+ # set expected_slice based on model name
221
+ # ADE20k checkpoints
222
+ if model_name == "segformer.b0.512x512.ade.160k":
223
+ expected_slice = torch.tensor(
224
+ [
225
+ [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
226
+ [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
227
+ [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
228
+ ]
229
+ )
230
+ elif model_name == "segformer.b1.512x512.ade.160k":
231
+ expected_slice = torch.tensor(
232
+ [
233
+ [[-7.5820, -8.7231, -8.3215], [-8.0600, -10.3529, -10.0304], [-7.5208, -9.4103, -9.6239]],
234
+ [[-12.6918, -13.8994, -13.7137], [-13.3196, -15.7523, -15.4789], [-12.9343, -14.8757, -14.9689]],
235
+ [[-11.1911, -11.9421, -11.3243], [-11.3342, -13.6839, -13.3581], [-10.3909, -12.1832, -12.4858]],
236
+ ]
237
+ )
238
+ elif model_name == "segformer.b2.512x512.ade.160k":
239
+ expected_slice = torch.tensor(
240
+ [
241
+ [[-11.8173, -14.3850, -16.3128], [-14.5648, -16.5804, -18.6568], [-14.7223, -15.7387, -18.4218]],
242
+ [[-15.7290, -17.9171, -19.4423], [-18.3105, -19.9448, -21.4661], [-17.9296, -18.6497, -20.7910]],
243
+ [[-15.0783, -17.0336, -18.2789], [-16.8771, -18.6870, -20.1612], [-16.2454, -17.1426, -19.5055]],
244
+ ]
245
+ )
246
+ elif model_name == "segformer.b3.512x512.ade.160k":
247
+ expected_slice = torch.tensor(
248
+ [
249
+ [[-9.0878, -10.2081, -10.1891], [-9.3144, -10.7941, -10.9843], [-9.2294, -10.3855, -10.5704]],
250
+ [[-12.2316, -13.9068, -13.6102], [-12.9161, -14.3702, -14.3235], [-12.5233, -13.7174, -13.7932]],
251
+ [[-14.6275, -15.2490, -14.9727], [-14.3400, -15.9687, -16.2827], [-14.1484, -15.4033, -15.8937]],
252
+ ]
253
+ )
254
+ elif model_name == "segformer.b4.512x512.ade.160k":
255
+ expected_slice = torch.tensor(
256
+ [
257
+ [[-12.3144, -13.2447, -14.0802], [-13.3614, -14.5816, -15.6117], [-13.3340, -14.4433, -16.2219]],
258
+ [[-19.2781, -20.4128, -20.7506], [-20.6153, -21.6566, -22.0998], [-19.9800, -21.0430, -22.1494]],
259
+ [[-18.8739, -19.7804, -21.1834], [-20.1233, -21.6765, -23.2944], [-20.0315, -21.2641, -23.6944]],
260
+ ]
261
+ )
262
+ elif model_name == "segformer.b5.640x640.ade.160k":
263
+ expected_slice = torch.tensor(
264
+ [
265
+ [[-9.5524, -12.0835, -11.7348], [-10.5229, -13.6446, -14.5662], [-9.5842, -12.8851, -13.9414]],
266
+ [[-15.3432, -17.5323, -17.0818], [-16.3330, -18.9255, -19.2101], [-15.1340, -17.7848, -18.3971]],
267
+ [[-12.6072, -14.9486, -14.6631], [-13.7629, -17.0907, -17.7745], [-12.7899, -16.1695, -17.1671]],
268
+ ]
269
+ )
270
+ # Cityscapes checkpoints
271
+ elif model_name == "segformer.b0.1024x1024.city.160k":
272
+ expected_slice = torch.tensor(
273
+ [
274
+ [[-11.9295, -13.4057, -14.8106], [-13.3431, -14.8179, -15.3781], [-14.2836, -15.5942, -16.1588]],
275
+ [[-11.4906, -12.8067, -13.6564], [-13.1189, -14.0500, -14.1543], [-13.8748, -14.5136, -14.8789]],
276
+ [[0.5374, 0.1067, -0.4742], [0.1141, -0.2255, -0.7099], [-0.3000, -0.5924, -1.3105]],
277
+ ]
278
+ )
279
+ elif model_name == "segformer.b0.512x1024.city.160k":
280
+ expected_slice = torch.tensor(
281
+ [
282
+ [[-7.8217, -9.8767, -10.1717], [-9.4438, -10.9058, -11.4047], [-9.7939, -12.3495, -12.1079]],
283
+ [[-7.1514, -9.5336, -10.0860], [-9.7776, -11.6822, -11.8439], [-10.1411, -12.7655, -12.8972]],
284
+ [[0.3021, 0.0805, -0.2310], [-0.0328, -0.1605, -0.2714], [-0.1408, -0.5477, -0.6976]],
285
+ ]
286
+ )
287
+ elif model_name == "segformer.b0.640x1280.city.160k":
288
+ expected_slice = torch.tensor(
289
+ [
290
+ [
291
+ [-1.1372e01, -1.2787e01, -1.3477e01],
292
+ [-1.2536e01, -1.4194e01, -1.4409e01],
293
+ [-1.3217e01, -1.4888e01, -1.5327e01],
294
+ ],
295
+ [
296
+ [-1.4791e01, -1.7122e01, -1.8277e01],
297
+ [-1.7163e01, -1.9192e01, -1.9533e01],
298
+ [-1.7897e01, -1.9991e01, -2.0315e01],
299
+ ],
300
+ [
301
+ [7.6723e-01, 4.1921e-01, -7.7878e-02],
302
+ [4.7772e-01, 9.5557e-03, -2.8082e-01],
303
+ [3.6032e-01, -2.4826e-01, -5.1168e-01],
304
+ ],
305
+ ]
306
+ )
307
+ elif model_name == "segformer.b0.768x768.city.160k":
308
+ expected_slice = torch.tensor(
309
+ [
310
+ [[-9.4959, -11.3087, -11.7479], [-11.0025, -12.6540, -12.3319], [-11.4064, -13.0487, -12.9905]],
311
+ [[-9.8905, -11.3084, -12.0854], [-11.1726, -12.7698, -12.9583], [-11.5985, -13.3278, -14.1774]],
312
+ [[0.2213, 0.0192, -0.2466], [-0.1731, -0.4213, -0.4874], [-0.3126, -0.6541, -1.1389]],
313
+ ]
314
+ )
315
+ elif model_name == "segformer.b1.1024x1024.city.160k":
316
+ expected_slice = torch.tensor(
317
+ [
318
+ [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
319
+ [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
320
+ [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
321
+ ]
322
+ )
323
+ elif model_name == "segformer.b2.1024x1024.city.160k":
324
+ expected_slice = torch.tensor(
325
+ [
326
+ [[-16.0976, -16.4856, -17.3962], [-16.6234, -19.0342, -19.7685], [-16.0900, -18.0661, -19.1180]],
327
+ [[-18.4750, -18.8488, -19.5074], [-19.4030, -22.1570, -22.5977], [-19.1191, -20.8486, -22.3783]],
328
+ [[-4.5178, -5.5037, -6.5109], [-5.0884, -7.2174, -8.0334], [-4.4156, -5.8117, -7.2970]],
329
+ ]
330
+ )
331
+ elif model_name == "segformer.b3.1024x1024.city.160k":
332
+ expected_slice = torch.tensor(
333
+ [
334
+ [[-14.2081, -14.4732, -14.1977], [-14.5867, -16.4423, -16.6356], [-13.4441, -14.9685, -16.8696]],
335
+ [[-14.4576, -14.7073, -15.0451], [-15.0816, -17.6237, -17.9873], [-14.4213, -16.0199, -18.5992]],
336
+ [[-4.7349, -4.9588, -5.0966], [-4.3210, -6.9325, -7.2591], [-3.4312, -4.7484, -7.1917]],
337
+ ]
338
+ )
339
+ elif model_name == "segformer.b4.1024x1024.city.160k":
340
+ expected_slice = torch.tensor(
341
+ [
342
+ [[-11.7737, -11.9526, -11.3273], [-13.6692, -14.4574, -13.8878], [-13.8937, -14.6924, -15.9345]],
343
+ [[-14.6706, -14.5330, -14.1306], [-16.1502, -16.8180, -16.4269], [-16.8338, -17.8939, -20.1746]],
344
+ [[1.0491, 0.8289, 1.0310], [1.1044, 0.5219, 0.8055], [1.0899, 0.6926, 0.5590]],
345
+ ]
346
+ )
347
+ elif model_name == "segformer.b5.1024x1024.city.160k":
348
+ expected_slice = torch.tensor(
349
+ [
350
+ [[-12.5641, -13.4777, -13.0684], [-13.9587, -15.8983, -16.6557], [-13.3109, -15.7350, -16.3141]],
351
+ [[-14.7074, -15.4352, -14.5944], [-16.6353, -18.1663, -18.6120], [-15.1702, -18.0329, -18.1547]],
352
+ [[-1.7990, -2.0951, -1.7784], [-2.6397, -3.8245, -3.9686], [-1.5264, -2.8126, -2.9316]],
353
+ ]
354
+ )
355
+ else:
356
+ predicted_class_idx = logits.argmax(-1).item()
357
+ print("Predicted class:", model.config.id2label[predicted_class_idx])
358
+
359
+ # verify logits
360
+ if not encoder_only:
361
+ assert logits.shape == expected_shape
362
+ assert torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-2)
363
+
364
+ # finally, save model and image processor
365
+ logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
366
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
367
+ model.save_pretrained(pytorch_dump_folder_path)
368
+ image_processor.save_pretrained(pytorch_dump_folder_path)
369
+
370
+
371
+ if __name__ == "__main__":
372
+ parser = argparse.ArgumentParser()
373
+
374
+ parser.add_argument(
375
+ "--model_name",
376
+ default="segformer.b0.512x512.ade.160k",
377
+ type=str,
378
+ help="Name of the model you'd like to convert.",
379
+ )
380
+ parser.add_argument(
381
+ "--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)."
382
+ )
383
+ parser.add_argument(
384
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
385
+ )
386
+ args = parser.parse_args()
387
+ convert_segformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path)
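
A short hedged inference sketch with an already-converted checkpoint; the repo id is the one named in the SegFormer config docstring, and a local `pytorch_dump_folder_path` produced by this script should work the same way:

```python
import requests
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# Same COCO test image as prepare_img() above.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"  # or a local dump folder
processor = SegformerImageProcessor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)

with torch.no_grad():
    outputs = model(**processor(images=image, return_tensors="pt"))
print(outputs.logits.shape)  # torch.Size([1, 150, 128, 128]) for the ADE20K head
```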
docs/transformers/build/lib/transformers/models/segformer/feature_extraction_segformer.py ADDED
@@ -0,0 +1,38 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for SegFormer."""
16
+
17
+ import warnings
18
+
19
+ from ...utils import logging
20
+ from ...utils.import_utils import requires
21
+ from .image_processing_segformer import SegformerImageProcessor
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ @requires(backends=("vision",))
28
+ class SegformerFeatureExtractor(SegformerImageProcessor):
29
+ def __init__(self, *args, **kwargs) -> None:
30
+ warnings.warn(
31
+ "The class SegformerFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
32
+ " Please use SegformerImageProcessor instead.",
33
+ FutureWarning,
34
+ )
35
+ super().__init__(*args, **kwargs)
36
+
37
+
38
+ __all__ = ["SegformerFeatureExtractor"]
docs/transformers/build/lib/transformers/models/segformer/image_processing_segformer.py ADDED
@@ -0,0 +1,484 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Segformer."""
16
+
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import resize, to_channel_dimension_format
23
+ from ...image_utils import (
24
+ IMAGENET_DEFAULT_MEAN,
25
+ IMAGENET_DEFAULT_STD,
26
+ ChannelDimension,
27
+ ImageInput,
28
+ PILImageResampling,
29
+ infer_channel_dimension_format,
30
+ is_scaled_image,
31
+ make_list_of_images,
32
+ to_numpy_array,
33
+ valid_images,
34
+ validate_preprocess_arguments,
35
+ )
36
+ from ...utils import (
37
+ TensorType,
38
+ filter_out_non_signature_kwargs,
39
+ is_torch_available,
40
+ is_torch_tensor,
41
+ is_vision_available,
42
+ logging,
43
+ )
44
+ from ...utils.deprecation import deprecate_kwarg
45
+ from ...utils.import_utils import requires
46
+
47
+
48
+ if is_vision_available():
49
+ import PIL.Image
50
+
51
+ if is_torch_available():
52
+ import torch
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ @requires(backends=("vision",))
59
+ class SegformerImageProcessor(BaseImageProcessor):
60
+ r"""
61
+ Constructs a Segformer image processor.
62
+
63
+ Args:
64
+ do_resize (`bool`, *optional*, defaults to `True`):
65
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
66
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
67
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
68
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
69
+ method.
70
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
71
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
72
+ `preprocess` method.
73
+ do_rescale (`bool`, *optional*, defaults to `True`):
74
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
75
+ parameter in the `preprocess` method.
76
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
77
+ Scale factor applied when rescaling the image. Can be overridden by the `rescale_factor` parameter in the
78
+ `preprocess` method.
79
+ do_normalize (`bool`, *optional*, defaults to `True`):
80
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
81
+ method.
82
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
83
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
84
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
85
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
86
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
87
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
88
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
89
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
90
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
91
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
92
+ `preprocess` method.
93
+ """
94
+
95
+ model_input_names = ["pixel_values"]
96
+
97
+ @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
98
+ @filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
99
+ def __init__(
100
+ self,
101
+ do_resize: bool = True,
102
+ size: Dict[str, int] = None,
103
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
104
+ do_rescale: bool = True,
105
+ rescale_factor: Union[int, float] = 1 / 255,
106
+ do_normalize: bool = True,
107
+ image_mean: Optional[Union[float, List[float]]] = None,
108
+ image_std: Optional[Union[float, List[float]]] = None,
109
+ do_reduce_labels: bool = False,
110
+ **kwargs,
111
+ ) -> None:
112
+ super().__init__(**kwargs)
113
+ size = size if size is not None else {"height": 512, "width": 512}
114
+ size = get_size_dict(size)
115
+ self.do_resize = do_resize
116
+ self.size = size
117
+ self.resample = resample
118
+ self.do_rescale = do_rescale
119
+ self.rescale_factor = rescale_factor
120
+ self.do_normalize = do_normalize
121
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
122
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
123
+ self.do_reduce_labels = do_reduce_labels
124
+
125
+ @classmethod
126
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
127
+ """
128
+ Overrides the `from_dict` method from the base class to preserve support for the deprecated `reduce_labels` key in old configs.
129
+ """
130
+ image_processor_dict = image_processor_dict.copy()
131
+ if "reduce_labels" in image_processor_dict:
132
+ image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
133
+ return super().from_dict(image_processor_dict, **kwargs)
134
+
135
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
136
+ def resize(
137
+ self,
138
+ image: np.ndarray,
139
+ size: Dict[str, int],
140
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
141
+ data_format: Optional[Union[str, ChannelDimension]] = None,
142
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
143
+ **kwargs,
144
+ ) -> np.ndarray:
145
+ """
146
+ Resize an image to `(size["height"], size["width"])`.
147
+
148
+ Args:
149
+ image (`np.ndarray`):
150
+ Image to resize.
151
+ size (`Dict[str, int]`):
152
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
153
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
154
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
155
+ data_format (`ChannelDimension` or `str`, *optional*):
156
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
157
+ image is used. Can be one of:
158
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
159
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
160
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
161
+ input_data_format (`ChannelDimension` or `str`, *optional*):
162
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
163
+ from the input image. Can be one of:
164
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
165
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
166
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
167
+
168
+ Returns:
169
+ `np.ndarray`: The resized image.
170
+ """
171
+ size = get_size_dict(size)
172
+ if "height" not in size or "width" not in size:
173
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
174
+ output_size = (size["height"], size["width"])
175
+ return resize(
176
+ image,
177
+ size=output_size,
178
+ resample=resample,
179
+ data_format=data_format,
180
+ input_data_format=input_data_format,
181
+ **kwargs,
182
+ )
183
+
184
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
185
+ def reduce_label(self, label: ImageInput) -> np.ndarray:
186
+ label = to_numpy_array(label)
187
+ # Avoid using underflow conversion
188
+ label[label == 0] = 255
189
+ label = label - 1
190
+ label[label == 254] = 255
191
+ return label
192
+
193
+ def _preprocess(
194
+ self,
195
+ image: ImageInput,
196
+ do_reduce_labels: bool,
197
+ do_resize: bool,
198
+ do_rescale: bool,
199
+ do_normalize: bool,
200
+ size: Optional[Dict[str, int]] = None,
201
+ resample: PILImageResampling = None,
202
+ rescale_factor: Optional[float] = None,
203
+ image_mean: Optional[Union[float, List[float]]] = None,
204
+ image_std: Optional[Union[float, List[float]]] = None,
205
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
206
+ ):
207
+ if do_reduce_labels:
208
+ image = self.reduce_label(image)
209
+
210
+ if do_resize:
211
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
212
+
213
+ if do_rescale:
214
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
215
+
216
+ if do_normalize:
217
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
218
+
219
+ return image
220
+
221
+ def _preprocess_image(
222
+ self,
223
+ image: ImageInput,
224
+ do_resize: Optional[bool] = None,
225
+ size: Dict[str, int] = None,
226
+ resample: PILImageResampling = None,
227
+ do_rescale: Optional[bool] = None,
228
+ rescale_factor: Optional[float] = None,
229
+ do_normalize: Optional[bool] = None,
230
+ image_mean: Optional[Union[float, List[float]]] = None,
231
+ image_std: Optional[Union[float, List[float]]] = None,
232
+ data_format: Optional[Union[str, ChannelDimension]] = None,
233
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
234
+ ) -> np.ndarray:
235
+ """Preprocesses a single image."""
236
+ # All transformations expect numpy arrays.
237
+ image = to_numpy_array(image)
238
+ if do_rescale and is_scaled_image(image):
239
+ logger.warning_once(
240
+ "It looks like you are trying to rescale already rescaled images. If the input"
241
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
242
+ )
243
+ if input_data_format is None:
244
+ input_data_format = infer_channel_dimension_format(image)
245
+ image = self._preprocess(
246
+ image=image,
247
+ do_reduce_labels=False,
248
+ do_resize=do_resize,
249
+ size=size,
250
+ resample=resample,
251
+ do_rescale=do_rescale,
252
+ rescale_factor=rescale_factor,
253
+ do_normalize=do_normalize,
254
+ image_mean=image_mean,
255
+ image_std=image_std,
256
+ input_data_format=input_data_format,
257
+ )
258
+ if data_format is not None:
259
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
260
+ return image
261
+
262
+ def _preprocess_mask(
263
+ self,
264
+ segmentation_map: ImageInput,
265
+ do_reduce_labels: Optional[bool] = None,
266
+ do_resize: Optional[bool] = None,
267
+ size: Dict[str, int] = None,
268
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
269
+ ) -> np.ndarray:
270
+ """Preprocesses a single mask."""
271
+ segmentation_map = to_numpy_array(segmentation_map)
272
+ # Add channel dimension if missing - needed for certain transformations
273
+ if segmentation_map.ndim == 2:
274
+ added_channel_dim = True
275
+ segmentation_map = segmentation_map[None, ...]
276
+ input_data_format = ChannelDimension.FIRST
277
+ else:
278
+ added_channel_dim = False
279
+ if input_data_format is None:
280
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
281
+ # reduce zero label if needed
282
+ segmentation_map = self._preprocess(
283
+ image=segmentation_map,
284
+ do_reduce_labels=do_reduce_labels,
285
+ do_resize=do_resize,
286
+ resample=PILImageResampling.NEAREST,
287
+ size=size,
288
+ do_rescale=False,
289
+ do_normalize=False,
290
+ input_data_format=input_data_format,
291
+ )
292
+ # Remove extra channel dimension if added for processing
293
+ if added_channel_dim:
294
+ segmentation_map = segmentation_map.squeeze(0)
295
+ segmentation_map = segmentation_map.astype(np.int64)
296
+ return segmentation_map
297
+
298
+ def __call__(self, images, segmentation_maps=None, **kwargs):
299
+ """
300
+ Preprocesses a batch of images and optionally segmentation maps.
301
+
302
+ Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
303
+ passed in as positional arguments.
304
+ """
305
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
306
+
307
+ @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
308
+ @filter_out_non_signature_kwargs()
309
+ def preprocess(
310
+ self,
311
+ images: ImageInput,
312
+ segmentation_maps: Optional[ImageInput] = None,
313
+ do_resize: Optional[bool] = None,
314
+ size: Optional[Dict[str, int]] = None,
315
+ resample: PILImageResampling = None,
316
+ do_rescale: Optional[bool] = None,
317
+ rescale_factor: Optional[float] = None,
318
+ do_normalize: Optional[bool] = None,
319
+ image_mean: Optional[Union[float, List[float]]] = None,
320
+ image_std: Optional[Union[float, List[float]]] = None,
321
+ do_reduce_labels: Optional[bool] = None,
322
+ return_tensors: Optional[Union[str, TensorType]] = None,
323
+ data_format: ChannelDimension = ChannelDimension.FIRST,
324
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
325
+ ) -> PIL.Image.Image:
326
+ """
327
+ Preprocess an image or batch of images.
328
+
329
+ Args:
330
+ images (`ImageInput`):
331
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
332
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
333
+ segmentation_maps (`ImageInput`, *optional*):
334
+ Segmentation map to preprocess.
335
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
336
+ Whether to resize the image.
337
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
338
+ Size of the image after `resize` is applied.
339
+ resample (`int`, *optional*, defaults to `self.resample`):
340
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
341
+ has an effect if `do_resize` is set to `True`.
342
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
343
+ Whether to rescale the image values to the [0, 1] range.
344
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
345
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
346
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
347
+ Whether to normalize the image.
348
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
349
+ Image mean.
350
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
351
+ Image standard deviation.
352
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
353
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
354
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
355
+ ADE20k). The background label will be replaced by 255.
356
+ return_tensors (`str` or `TensorType`, *optional*):
357
+ The type of tensors to return. Can be one of:
358
+ - Unset: Return a list of `np.ndarray`.
359
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
360
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
361
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
362
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
363
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
364
+ The channel dimension format for the output image. Can be one of:
365
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
366
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
367
+ input_data_format (`ChannelDimension` or `str`, *optional*):
368
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
369
+ from the input image. Can be one of:
370
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
371
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
372
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
373
+ """
374
+ do_resize = do_resize if do_resize is not None else self.do_resize
375
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
376
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
377
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
378
+ resample = resample if resample is not None else self.resample
379
+ size = size if size is not None else self.size
380
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
381
+ image_mean = image_mean if image_mean is not None else self.image_mean
382
+ image_std = image_std if image_std is not None else self.image_std
383
+
384
+ images = make_list_of_images(images)
385
+
386
+ if segmentation_maps is not None:
387
+ segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
388
+
389
+ if not valid_images(images):
390
+ raise ValueError(
391
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
392
+ "torch.Tensor, tf.Tensor or jax.ndarray."
393
+ )
394
+ validate_preprocess_arguments(
395
+ do_rescale=do_rescale,
396
+ rescale_factor=rescale_factor,
397
+ do_normalize=do_normalize,
398
+ image_mean=image_mean,
399
+ image_std=image_std,
400
+ do_resize=do_resize,
401
+ size=size,
402
+ resample=resample,
403
+ )
404
+
405
+ images = [
406
+ self._preprocess_image(
407
+ image=img,
408
+ do_resize=do_resize,
409
+ resample=resample,
410
+ size=size,
411
+ do_rescale=do_rescale,
412
+ rescale_factor=rescale_factor,
413
+ do_normalize=do_normalize,
414
+ image_mean=image_mean,
415
+ image_std=image_std,
416
+ data_format=data_format,
417
+ input_data_format=input_data_format,
418
+ )
419
+ for img in images
420
+ ]
421
+
422
+ data = {"pixel_values": images}
423
+
424
+ if segmentation_maps is not None:
425
+ segmentation_maps = [
426
+ self._preprocess_mask(
427
+ segmentation_map=segmentation_map,
428
+ do_reduce_labels=do_reduce_labels,
429
+ do_resize=do_resize,
430
+ size=size,
431
+ input_data_format=input_data_format,
432
+ )
433
+ for segmentation_map in segmentation_maps
434
+ ]
435
+ data["labels"] = segmentation_maps
436
+
437
+ return BatchFeature(data=data, tensor_type=return_tensors)
438
+
439
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->Segformer
440
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
441
+ """
442
+ Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
443
+
444
+ Args:
445
+ outputs ([`SegformerForSemanticSegmentation`]):
446
+ Raw outputs of the model.
447
+ target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
448
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
449
+ predictions will not be resized.
450
+
451
+ Returns:
452
+ semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
453
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
454
+ specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
455
+ """
456
+ # TODO: add support for other frameworks
457
+ logits = outputs.logits
458
+
459
+ # Resize logits and compute semantic segmentation maps
460
+ if target_sizes is not None:
461
+ if len(logits) != len(target_sizes):
462
+ raise ValueError(
463
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
464
+ )
465
+
466
+ if is_torch_tensor(target_sizes):
467
+ target_sizes = target_sizes.numpy()
468
+
469
+ semantic_segmentation = []
470
+
471
+ for idx in range(len(logits)):
472
+ resized_logits = torch.nn.functional.interpolate(
473
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
474
+ )
475
+ semantic_map = resized_logits[0].argmax(dim=0)
476
+ semantic_segmentation.append(semantic_map)
477
+ else:
478
+ semantic_segmentation = logits.argmax(dim=1)
479
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
480
+
481
+ return semantic_segmentation
482
+
483
+
484
+ __all__ = ["SegformerImageProcessor"]
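To tie the pieces of this processor together, here is a hedged end-to-end sketch: preprocessing a dummy image and mask, then mapping model logits back to a per-pixel class map with `post_process_semantic_segmentation`. The checkpoint is the one referenced later in this diff and the dummy inputs are illustrative.

```python
# Minimal round-trip sketch for SegformerImageProcessor (assumes torch and the
# nvidia/segformer-b0-finetuned-ade-512-512 checkpoint referenced elsewhere in this diff).
import numpy as np
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor = SegformerImageProcessor(do_reduce_labels=False)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # dummy RGB image
mask = np.random.randint(0, 150, (480, 640), dtype=np.uint8)      # dummy segmentation map

# Images are resized/rescaled/normalized; masks are resized with nearest neighbour and cast to int64.
inputs = processor(images=image, segmentation_maps=mask, return_tensors="pt")
print(inputs["pixel_values"].shape, inputs["labels"].shape)  # (1, 3, 512, 512), (1, 512, 512)

with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"])

# Resize logits back to the original image size and take the per-pixel argmax.
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=[(480, 640)])[0]
print(seg.shape)  # torch.Size([480, 640])
```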
docs/transformers/build/lib/transformers/models/segformer/modeling_segformer.py ADDED
@@ -0,0 +1,840 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 NVIDIA The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch SegFormer model."""
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
29
+ from ...utils import (
30
+ add_code_sample_docstrings,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ logging,
34
+ replace_return_docstrings,
35
+ )
36
+ from .configuration_segformer import SegformerConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ # General docstring
43
+ _CONFIG_FOR_DOC = "SegformerConfig"
44
+
45
+ # Base docstring
46
+ _CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
47
+ _EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
48
+
49
+ # Image classification docstring
50
+ _IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
51
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
52
+
53
+
54
+ class SegFormerImageClassifierOutput(ImageClassifierOutput):
55
+ """
56
+ Base class for outputs of image classification models.
57
+
58
+ Args:
59
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
60
+ Classification (or regression if config.num_labels==1) loss.
61
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
62
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
63
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
64
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
65
+ one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
66
+ called feature maps) of the model at the output of each stage.
67
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
68
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
69
+ sequence_length)`.
70
+
71
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
72
+ heads.
73
+ """
74
+
75
+ loss: Optional[torch.FloatTensor] = None
76
+ logits: Optional[torch.FloatTensor] = None
77
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
78
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
79
+
80
+
81
+ # Copied from transformers.models.beit.modeling_beit.drop_path
82
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
83
+ """
84
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
85
+
86
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
87
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
88
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
89
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
90
+ argument.
91
+ """
92
+ if drop_prob == 0.0 or not training:
93
+ return input
94
+ keep_prob = 1 - drop_prob
95
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
96
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
97
+ random_tensor.floor_() # binarize
98
+ output = input.div(keep_prob) * random_tensor
99
+ return output
100
+
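Because the stochastic-depth docstring above is mostly prose, a small self-contained check of the intended behaviour may help; the function is reproduced here so the snippet runs on its own and is not part of the model code.

```python
# Illustrative check of drop_path semantics (same logic as the implementation above).
import torch

def drop_path(input, drop_prob=0.0, training=False):
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = (keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)).floor_()
    return input.div(keep_prob) * random_tensor

x = torch.ones(4, 3, 2)
print(drop_path(x, drop_prob=0.5, training=False))  # identity at inference time

out = drop_path(x, drop_prob=0.5, training=True)
# Each sample is either zeroed entirely or scaled by 1 / keep_prob = 2.0,
# so the expected value of every element stays 1.0.
print(out[:, 0, 0])
```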
101
+
102
+ # Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Segformer
103
+ class SegformerDropPath(nn.Module):
104
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
105
+
106
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
107
+ super().__init__()
108
+ self.drop_prob = drop_prob
109
+
110
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ return drop_path(hidden_states, self.drop_prob, self.training)
112
+
113
+ def extra_repr(self) -> str:
114
+ return "p={}".format(self.drop_prob)
115
+
116
+
117
+ class SegformerOverlapPatchEmbeddings(nn.Module):
118
+ """Construct the overlapping patch embeddings."""
119
+
120
+ def __init__(self, patch_size, stride, num_channels, hidden_size):
121
+ super().__init__()
122
+ self.proj = nn.Conv2d(
123
+ num_channels,
124
+ hidden_size,
125
+ kernel_size=patch_size,
126
+ stride=stride,
127
+ padding=patch_size // 2,
128
+ )
129
+
130
+ self.layer_norm = nn.LayerNorm(hidden_size)
131
+
132
+ def forward(self, pixel_values):
133
+ embeddings = self.proj(pixel_values)
134
+ _, _, height, width = embeddings.shape
135
+ # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
136
+ # this can be fed to a Transformer layer
137
+ embeddings = embeddings.flatten(2).transpose(1, 2)
138
+ embeddings = self.layer_norm(embeddings)
139
+ return embeddings, height, width
140
+
141
+
142
+ class SegformerEfficientSelfAttention(nn.Module):
143
+ """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
144
+ paper](https://arxiv.org/abs/2102.12122)."""
145
+
146
+ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
147
+ super().__init__()
148
+ self.hidden_size = hidden_size
149
+ self.num_attention_heads = num_attention_heads
150
+
151
+ if self.hidden_size % self.num_attention_heads != 0:
152
+ raise ValueError(
153
+ f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
154
+ f"heads ({self.num_attention_heads})"
155
+ )
156
+
157
+ self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
158
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
159
+
160
+ self.query = nn.Linear(self.hidden_size, self.all_head_size)
161
+ self.key = nn.Linear(self.hidden_size, self.all_head_size)
162
+ self.value = nn.Linear(self.hidden_size, self.all_head_size)
163
+
164
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
165
+
166
+ self.sr_ratio = sequence_reduction_ratio
167
+ if sequence_reduction_ratio > 1:
168
+ self.sr = nn.Conv2d(
169
+ hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
170
+ )
171
+ self.layer_norm = nn.LayerNorm(hidden_size)
172
+
173
+ def transpose_for_scores(self, hidden_states):
174
+ new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
175
+ hidden_states = hidden_states.view(new_shape)
176
+ return hidden_states.permute(0, 2, 1, 3)
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states,
181
+ height,
182
+ width,
183
+ output_attentions=False,
184
+ ):
185
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
186
+
187
+ if self.sr_ratio > 1:
188
+ batch_size, seq_len, num_channels = hidden_states.shape
189
+ # Reshape to (batch_size, num_channels, height, width)
190
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
191
+ # Apply sequence reduction
192
+ hidden_states = self.sr(hidden_states)
193
+ # Reshape back to (batch_size, seq_len, num_channels)
194
+ hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
195
+ hidden_states = self.layer_norm(hidden_states)
196
+
197
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
198
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
199
+
200
+ # Take the dot product between "query" and "key" to get the raw attention scores.
201
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
202
+
203
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
204
+
205
+ # Normalize the attention scores to probabilities.
206
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs = self.dropout(attention_probs)
211
+
212
+ context_layer = torch.matmul(attention_probs, value_layer)
213
+
214
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
215
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
216
+ context_layer = context_layer.view(new_context_layer_shape)
217
+
218
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
219
+
220
+ return outputs
221
+
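To make the sequence-reduction trick concrete, the following standalone sketch mirrors the shape arithmetic of `SegformerEfficientSelfAttention` for one stage; the dimensions are illustrative, not taken from a released checkpoint.

```python
# Shape sketch of the PvT-style sequence reduction used by SegformerEfficientSelfAttention.
import torch
from torch import nn

batch_size, height, width, hidden_size, sr_ratio = 2, 128, 128, 32, 8
seq_len = height * width                                   # 16384 query positions
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Queries keep the full resolution...
queries = nn.Linear(hidden_size, hidden_size)(hidden_states)
print(queries.shape)                                       # (2, 16384, 32)

# ...while keys/values are computed on a spatially downsampled copy.
sr = nn.Conv2d(hidden_size, hidden_size, kernel_size=sr_ratio, stride=sr_ratio)
reduced = hidden_states.permute(0, 2, 1).reshape(batch_size, hidden_size, height, width)
reduced = sr(reduced).reshape(batch_size, hidden_size, -1).permute(0, 2, 1)
print(reduced.shape)                                       # (2, 256, 32): seq_len / sr_ratio**2

# Attention cost therefore scales with seq_len * (seq_len / sr_ratio**2) instead of seq_len**2.
```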
222
+
223
+ class SegformerSelfOutput(nn.Module):
224
+ def __init__(self, config, hidden_size):
225
+ super().__init__()
226
+ self.dense = nn.Linear(hidden_size, hidden_size)
227
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
228
+
229
+ def forward(self, hidden_states, input_tensor):
230
+ hidden_states = self.dense(hidden_states)
231
+ hidden_states = self.dropout(hidden_states)
232
+ return hidden_states
233
+
234
+
235
+ class SegformerAttention(nn.Module):
236
+ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
237
+ super().__init__()
238
+ self.self = SegformerEfficientSelfAttention(
239
+ config=config,
240
+ hidden_size=hidden_size,
241
+ num_attention_heads=num_attention_heads,
242
+ sequence_reduction_ratio=sequence_reduction_ratio,
243
+ )
244
+ self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
245
+ self.pruned_heads = set()
246
+
247
+ def prune_heads(self, heads):
248
+ if len(heads) == 0:
249
+ return
250
+ heads, index = find_pruneable_heads_and_indices(
251
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
252
+ )
253
+
254
+ # Prune linear layers
255
+ self.self.query = prune_linear_layer(self.self.query, index)
256
+ self.self.key = prune_linear_layer(self.self.key, index)
257
+ self.self.value = prune_linear_layer(self.self.value, index)
258
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
259
+
260
+ # Update hyper params and store pruned heads
261
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
262
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
263
+ self.pruned_heads = self.pruned_heads.union(heads)
264
+
265
+ def forward(self, hidden_states, height, width, output_attentions=False):
266
+ self_outputs = self.self(hidden_states, height, width, output_attentions)
267
+
268
+ attention_output = self.output(self_outputs[0], hidden_states)
269
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
270
+ return outputs
271
+
272
+
273
+ class SegformerDWConv(nn.Module):
274
+ def __init__(self, dim=768):
275
+ super().__init__()
276
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
277
+
278
+ def forward(self, hidden_states, height, width):
279
+ batch_size, seq_len, num_channels = hidden_states.shape
280
+ hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
281
+ hidden_states = self.dwconv(hidden_states)
282
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
283
+
284
+ return hidden_states
285
+
286
+
287
+ class SegformerMixFFN(nn.Module):
288
+ def __init__(self, config, in_features, hidden_features=None, out_features=None):
289
+ super().__init__()
290
+ out_features = out_features or in_features
291
+ self.dense1 = nn.Linear(in_features, hidden_features)
292
+ self.dwconv = SegformerDWConv(hidden_features)
293
+ if isinstance(config.hidden_act, str):
294
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
295
+ else:
296
+ self.intermediate_act_fn = config.hidden_act
297
+ self.dense2 = nn.Linear(hidden_features, out_features)
298
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
299
+
300
+ def forward(self, hidden_states, height, width):
301
+ hidden_states = self.dense1(hidden_states)
302
+ hidden_states = self.dwconv(hidden_states, height, width)
303
+ hidden_states = self.intermediate_act_fn(hidden_states)
304
+ hidden_states = self.dropout(hidden_states)
305
+ hidden_states = self.dense2(hidden_states)
306
+ hidden_states = self.dropout(hidden_states)
307
+ return hidden_states
308
+
309
+
310
+ class SegformerLayer(nn.Module):
311
+ """This corresponds to the Block class in the original implementation."""
312
+
313
+ def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
314
+ super().__init__()
315
+ self.layer_norm_1 = nn.LayerNorm(hidden_size)
316
+ self.attention = SegformerAttention(
317
+ config,
318
+ hidden_size=hidden_size,
319
+ num_attention_heads=num_attention_heads,
320
+ sequence_reduction_ratio=sequence_reduction_ratio,
321
+ )
322
+ self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
323
+ self.layer_norm_2 = nn.LayerNorm(hidden_size)
324
+ mlp_hidden_size = int(hidden_size * mlp_ratio)
325
+ self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)
326
+
327
+ def forward(self, hidden_states, height, width, output_attentions=False):
328
+ self_attention_outputs = self.attention(
329
+ self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
330
+ height,
331
+ width,
332
+ output_attentions=output_attentions,
333
+ )
334
+
335
+ attention_output = self_attention_outputs[0]
336
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
337
+
338
+ # first residual connection (with stochastic depth)
339
+ attention_output = self.drop_path(attention_output)
340
+ hidden_states = attention_output + hidden_states
341
+
342
+ mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
343
+
344
+ # second residual connection (with stochastic depth)
345
+ mlp_output = self.drop_path(mlp_output)
346
+ layer_output = mlp_output + hidden_states
347
+
348
+ outputs = (layer_output,) + outputs
349
+
350
+ return outputs
351
+
352
+
353
+ class SegformerEncoder(nn.Module):
354
+ def __init__(self, config):
355
+ super().__init__()
356
+ self.config = config
357
+
358
+ # stochastic depth decay rule
359
+ drop_path_decays = [
360
+ x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
361
+ ]
362
+
363
+ # patch embeddings
364
+ embeddings = []
365
+ for i in range(config.num_encoder_blocks):
366
+ embeddings.append(
367
+ SegformerOverlapPatchEmbeddings(
368
+ patch_size=config.patch_sizes[i],
369
+ stride=config.strides[i],
370
+ num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
371
+ hidden_size=config.hidden_sizes[i],
372
+ )
373
+ )
374
+ self.patch_embeddings = nn.ModuleList(embeddings)
375
+
376
+ # Transformer blocks
377
+ blocks = []
378
+ cur = 0
379
+ for i in range(config.num_encoder_blocks):
380
+ # each block consists of layers
381
+ layers = []
382
+ if i != 0:
383
+ cur += config.depths[i - 1]
384
+ for j in range(config.depths[i]):
385
+ layers.append(
386
+ SegformerLayer(
387
+ config,
388
+ hidden_size=config.hidden_sizes[i],
389
+ num_attention_heads=config.num_attention_heads[i],
390
+ drop_path=drop_path_decays[cur + j],
391
+ sequence_reduction_ratio=config.sr_ratios[i],
392
+ mlp_ratio=config.mlp_ratios[i],
393
+ )
394
+ )
395
+ blocks.append(nn.ModuleList(layers))
396
+
397
+ self.block = nn.ModuleList(blocks)
398
+
399
+ # Layer norms
400
+ self.layer_norm = nn.ModuleList(
401
+ [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
402
+ )
403
+
404
+ def forward(
405
+ self,
406
+ pixel_values: torch.FloatTensor,
407
+ output_attentions: Optional[bool] = False,
408
+ output_hidden_states: Optional[bool] = False,
409
+ return_dict: Optional[bool] = True,
410
+ ) -> Union[Tuple, BaseModelOutput]:
411
+ all_hidden_states = () if output_hidden_states else None
412
+ all_self_attentions = () if output_attentions else None
413
+
414
+ batch_size = pixel_values.shape[0]
415
+
416
+ hidden_states = pixel_values
417
+ for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
418
+ embedding_layer, block_layer, norm_layer = x
419
+ # first, obtain patch embeddings
420
+ hidden_states, height, width = embedding_layer(hidden_states)
421
+ # second, send embeddings through blocks
422
+ for i, blk in enumerate(block_layer):
423
+ layer_outputs = blk(hidden_states, height, width, output_attentions)
424
+ hidden_states = layer_outputs[0]
425
+ if output_attentions:
426
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
427
+ # third, apply layer norm
428
+ hidden_states = norm_layer(hidden_states)
429
+ # fourth, optionally reshape back to (batch_size, num_channels, height, width)
430
+ if idx != len(self.patch_embeddings) - 1 or (
431
+ idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
432
+ ):
433
+ hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
434
+ if output_hidden_states:
435
+ all_hidden_states = all_hidden_states + (hidden_states,)
436
+
437
+ if not return_dict:
438
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
439
+ return BaseModelOutput(
440
+ last_hidden_state=hidden_states,
441
+ hidden_states=all_hidden_states,
442
+ attentions=all_self_attentions,
443
+ )
444
+
445
+
446
+ class SegformerPreTrainedModel(PreTrainedModel):
447
+ """
448
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
449
+ models.
450
+ """
451
+
452
+ config_class = SegformerConfig
453
+ base_model_prefix = "segformer"
454
+ main_input_name = "pixel_values"
455
+
456
+ def _init_weights(self, module):
457
+ """Initialize the weights"""
458
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
459
+ # Slightly different from the TF version which uses truncated_normal for initialization
460
+ # cf https://github.com/pytorch/pytorch/pull/5617
461
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
462
+ if module.bias is not None:
463
+ module.bias.data.zero_()
464
+ elif isinstance(module, nn.Embedding):
465
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
466
+ if module.padding_idx is not None:
467
+ module.weight.data[module.padding_idx].zero_()
468
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
469
+ module.bias.data.zero_()
470
+ module.weight.data.fill_(1.0)
471
+
472
+
473
+ SEGFORMER_START_DOCSTRING = r"""
474
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
475
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
476
+ behavior.
477
+
478
+ Parameters:
479
+ config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
480
+ Initializing with a config file does not load the weights associated with the model, only the
481
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
482
+ """
483
+
484
+ SEGFORMER_INPUTS_DOCSTRING = r"""
485
+
486
+ Args:
487
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
488
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
489
+ [`AutoImageProcessor`]. See [`SegformerImageProcessor.__call__`] for details.
490
+
491
+ output_attentions (`bool`, *optional*):
492
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
493
+ tensors for more detail.
494
+ output_hidden_states (`bool`, *optional*):
495
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
496
+ more detail.
497
+ return_dict (`bool`, *optional*):
498
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
499
+ """
500
+
501
+
502
+ @add_start_docstrings(
503
+ "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
504
+ SEGFORMER_START_DOCSTRING,
505
+ )
506
+ class SegformerModel(SegformerPreTrainedModel):
507
+ def __init__(self, config):
508
+ super().__init__(config)
509
+ self.config = config
510
+
511
+ # hierarchical Transformer encoder
512
+ self.encoder = SegformerEncoder(config)
513
+
514
+ # Initialize weights and apply final processing
515
+ self.post_init()
516
+
517
+ def _prune_heads(self, heads_to_prune):
518
+ """
519
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
520
+ class PreTrainedModel
521
+ """
522
+ for layer, heads in heads_to_prune.items():
523
+ self.encoder.layer[layer].attention.prune_heads(heads)
524
+
525
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
526
+ @add_code_sample_docstrings(
527
+ checkpoint=_CHECKPOINT_FOR_DOC,
528
+ output_type=BaseModelOutput,
529
+ config_class=_CONFIG_FOR_DOC,
530
+ modality="vision",
531
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
532
+ )
533
+ def forward(
534
+ self,
535
+ pixel_values: torch.FloatTensor,
536
+ output_attentions: Optional[bool] = None,
537
+ output_hidden_states: Optional[bool] = None,
538
+ return_dict: Optional[bool] = None,
539
+ ) -> Union[Tuple, BaseModelOutput]:
540
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
541
+ output_hidden_states = (
542
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
543
+ )
544
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
545
+
546
+ encoder_outputs = self.encoder(
547
+ pixel_values,
548
+ output_attentions=output_attentions,
549
+ output_hidden_states=output_hidden_states,
550
+ return_dict=return_dict,
551
+ )
552
+ sequence_output = encoder_outputs[0]
553
+
554
+ if not return_dict:
555
+ return (sequence_output,) + encoder_outputs[1:]
556
+
557
+ return BaseModelOutput(
558
+ last_hidden_state=sequence_output,
559
+ hidden_states=encoder_outputs.hidden_states,
560
+ attentions=encoder_outputs.attentions,
561
+ )
562
+
563
+
564
+ @add_start_docstrings(
565
+ """
566
+ SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
567
+ states) e.g. for ImageNet.
568
+ """,
569
+ SEGFORMER_START_DOCSTRING,
570
+ )
571
+ class SegformerForImageClassification(SegformerPreTrainedModel):
572
+ def __init__(self, config):
573
+ super().__init__(config)
574
+
575
+ self.num_labels = config.num_labels
576
+ self.segformer = SegformerModel(config)
577
+
578
+ # Classifier head
579
+ self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
580
+
581
+ # Initialize weights and apply final processing
582
+ self.post_init()
583
+
584
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
585
+ @add_code_sample_docstrings(
586
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
587
+ output_type=SegFormerImageClassifierOutput,
588
+ config_class=_CONFIG_FOR_DOC,
589
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
590
+ )
591
+ def forward(
592
+ self,
593
+ pixel_values: Optional[torch.FloatTensor] = None,
594
+ labels: Optional[torch.LongTensor] = None,
595
+ output_attentions: Optional[bool] = None,
596
+ output_hidden_states: Optional[bool] = None,
597
+ return_dict: Optional[bool] = None,
598
+ ) -> Union[Tuple, SegFormerImageClassifierOutput]:
599
+ r"""
600
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
601
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
602
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
603
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
604
+ """
605
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
606
+
607
+ outputs = self.segformer(
608
+ pixel_values,
609
+ output_attentions=output_attentions,
610
+ output_hidden_states=output_hidden_states,
611
+ return_dict=return_dict,
612
+ )
613
+
614
+ sequence_output = outputs[0]
615
+
616
+ # convert last hidden states to (batch_size, height*width, hidden_size)
617
+ batch_size = sequence_output.shape[0]
618
+ if self.config.reshape_last_stage:
619
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
620
+ sequence_output = sequence_output.permute(0, 2, 3, 1)
621
+ sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])
622
+
623
+ # global average pooling
624
+ sequence_output = sequence_output.mean(dim=1)
625
+
626
+ logits = self.classifier(sequence_output)
627
+
628
+ loss = None
629
+ if labels is not None:
630
+ if self.config.problem_type is None:
631
+ if self.num_labels == 1:
632
+ self.config.problem_type = "regression"
633
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
634
+ self.config.problem_type = "single_label_classification"
635
+ else:
636
+ self.config.problem_type = "multi_label_classification"
637
+
638
+ if self.config.problem_type == "regression":
639
+ loss_fct = MSELoss()
640
+ if self.num_labels == 1:
641
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
642
+ else:
643
+ loss = loss_fct(logits, labels)
644
+ elif self.config.problem_type == "single_label_classification":
645
+ loss_fct = CrossEntropyLoss()
646
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
647
+ elif self.config.problem_type == "multi_label_classification":
648
+ loss_fct = BCEWithLogitsLoss()
649
+ loss = loss_fct(logits, labels)
650
+ if not return_dict:
651
+ output = (logits,) + outputs[1:]
652
+ return ((loss,) + output) if loss is not None else output
653
+
654
+ return SegFormerImageClassifierOutput(
655
+ loss=loss,
656
+ logits=logits,
657
+ hidden_states=outputs.hidden_states,
658
+ attentions=outputs.attentions,
659
+ )
660
+
661
+
662
+ class SegformerMLP(nn.Module):
663
+ """
664
+ Linear Embedding.
665
+ """
666
+
667
+ def __init__(self, config: SegformerConfig, input_dim):
668
+ super().__init__()
669
+ self.proj = nn.Linear(input_dim, config.decoder_hidden_size)
670
+
671
+ def forward(self, hidden_states: torch.Tensor):
672
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
673
+ hidden_states = self.proj(hidden_states)
674
+ return hidden_states
675
+
676
+
677
+ class SegformerDecodeHead(SegformerPreTrainedModel):
678
+ def __init__(self, config):
679
+ super().__init__(config)
680
+ # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
681
+ mlps = []
682
+ for i in range(config.num_encoder_blocks):
683
+ mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i])
684
+ mlps.append(mlp)
685
+ self.linear_c = nn.ModuleList(mlps)
686
+
687
+ # the following 3 layers implement the ConvModule of the original implementation
688
+ self.linear_fuse = nn.Conv2d(
689
+ in_channels=config.decoder_hidden_size * config.num_encoder_blocks,
690
+ out_channels=config.decoder_hidden_size,
691
+ kernel_size=1,
692
+ bias=False,
693
+ )
694
+ self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
695
+ self.activation = nn.ReLU()
696
+
697
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
698
+ self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
699
+
700
+ self.config = config
701
+
702
+ def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
703
+ batch_size = encoder_hidden_states[-1].shape[0]
704
+
705
+ all_hidden_states = ()
706
+ for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
707
+ if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
708
+ height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
709
+ encoder_hidden_state = (
710
+ encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
711
+ )
712
+
713
+ # unify channel dimension
714
+ height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
715
+ encoder_hidden_state = mlp(encoder_hidden_state)
716
+ encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
717
+ encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
718
+ # upsample
719
+ encoder_hidden_state = nn.functional.interpolate(
720
+ encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False
721
+ )
722
+ all_hidden_states += (encoder_hidden_state,)
723
+
724
+ hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
725
+ hidden_states = self.batch_norm(hidden_states)
726
+ hidden_states = self.activation(hidden_states)
727
+ hidden_states = self.dropout(hidden_states)
728
+
729
+ # logits are of shape (batch_size, num_labels, height/4, width/4)
730
+ logits = self.classifier(hidden_states)
731
+
732
+ return logits
733
+
734
+
735
+ @add_start_docstrings(
736
+ """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
737
+ SEGFORMER_START_DOCSTRING,
738
+ )
739
+ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
740
+ def __init__(self, config):
741
+ super().__init__(config)
742
+ self.segformer = SegformerModel(config)
743
+ self.decode_head = SegformerDecodeHead(config)
744
+
745
+ # Initialize weights and apply final processing
746
+ self.post_init()
747
+
748
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
749
+ @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
750
+ def forward(
751
+ self,
752
+ pixel_values: torch.FloatTensor,
753
+ labels: Optional[torch.LongTensor] = None,
754
+ output_attentions: Optional[bool] = None,
755
+ output_hidden_states: Optional[bool] = None,
756
+ return_dict: Optional[bool] = None,
757
+ ) -> Union[Tuple, SemanticSegmenterOutput]:
758
+ r"""
759
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
760
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
761
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
762
+
763
+ Returns:
764
+
765
+ Examples:
766
+
767
+ ```python
768
+ >>> from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
769
+ >>> from PIL import Image
770
+ >>> import requests
771
+
772
+ >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
773
+ >>> model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
774
+
775
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
776
+ >>> image = Image.open(requests.get(url, stream=True).raw)
777
+
778
+ >>> inputs = image_processor(images=image, return_tensors="pt")
779
+ >>> outputs = model(**inputs)
780
+ >>> logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
781
+ >>> list(logits.shape)
782
+ [1, 150, 128, 128]
783
+ ```"""
784
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
785
+ output_hidden_states = (
786
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
787
+ )
788
+
789
+ if labels is not None and self.config.num_labels < 1:
790
+ raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")
791
+
792
+ outputs = self.segformer(
793
+ pixel_values,
794
+ output_attentions=output_attentions,
795
+ output_hidden_states=True, # we need the intermediate hidden states
796
+ return_dict=return_dict,
797
+ )
798
+
799
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
800
+
801
+ logits = self.decode_head(encoder_hidden_states)
802
+
803
+ loss = None
804
+ if labels is not None:
805
+ # upsample logits to the images' original size
806
+ upsampled_logits = nn.functional.interpolate(
807
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
808
+ )
809
+ if self.config.num_labels > 1:
810
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
811
+ loss = loss_fct(upsampled_logits, labels)
812
+ elif self.config.num_labels == 1:
813
+ valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
814
+ loss_fct = BCEWithLogitsLoss(reduction="none")
815
+ loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
816
+ loss = (loss * valid_mask).mean()
817
+
818
+ if not return_dict:
819
+ if output_hidden_states:
820
+ output = (logits,) + outputs[1:]
821
+ else:
822
+ output = (logits,) + outputs[2:]
823
+ return ((loss,) + output) if loss is not None else output
824
+
825
+ return SemanticSegmenterOutput(
826
+ loss=loss,
827
+ logits=logits,
828
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
829
+ attentions=outputs.attentions,
830
+ )
831
+
832
+
833
+ __all__ = [
834
+ "SegformerDecodeHead",
835
+ "SegformerForImageClassification",
836
+ "SegformerForSemanticSegmentation",
837
+ "SegformerLayer",
838
+ "SegformerModel",
839
+ "SegformerPreTrainedModel",
840
+ ]
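
Since the decode head predicts at 1/4 of the input resolution, inference typically ends with the same bilinear upsampling used in the loss above before taking a per-pixel argmax. A minimal sketch of that post-processing (illustrative, assuming the same `nvidia/segformer-b0-finetuned-ade-512-512` checkpoint as the docstring example):

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (1, num_labels, height/4, width/4)

# illustrative post-processing: upsample to the original image size, mirroring the
# bilinear interpolation used for the loss, then take the class index per pixel
upsampled = torch.nn.functional.interpolate(
    logits, size=image.size[::-1], mode="bilinear", align_corners=False
)
segmentation_map = upsampled.argmax(dim=1)[0]  # (height, width)
```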
docs/transformers/build/lib/transformers/models/segformer/modeling_tf_segformer.py ADDED
@@ -0,0 +1,1045 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TensorFlow SegFormer model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import tensorflow as tf
23
+
24
+ from ...activations_tf import get_tf_activation
25
+ from ...file_utils import (
26
+ add_code_sample_docstrings,
27
+ add_start_docstrings,
28
+ add_start_docstrings_to_model_forward,
29
+ replace_return_docstrings,
30
+ )
31
+ from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput
32
+ from ...modeling_tf_utils import (
33
+ TFPreTrainedModel,
34
+ TFSequenceClassificationLoss,
35
+ keras,
36
+ keras_serializable,
37
+ unpack_inputs,
38
+ )
39
+ from ...tf_utils import shape_list, stable_softmax
40
+ from ...utils import logging
41
+ from .configuration_segformer import SegformerConfig
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ # General docstring
47
+ _CONFIG_FOR_DOC = "SegformerConfig"
48
+
49
+ # Base docstring
50
+ _CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
51
+ _EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
52
+
53
+ # Image classification docstring
54
+ _IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
55
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
56
+
57
+
58
+ # Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer
59
+ class TFSegformerDropPath(keras.layers.Layer):
60
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
61
+ References:
62
+ (1) github.com:rwightman/pytorch-image-models
63
+ """
64
+
65
+ def __init__(self, drop_path: float, **kwargs):
66
+ super().__init__(**kwargs)
67
+ self.drop_path = drop_path
68
+
69
+ def call(self, x: tf.Tensor, training=None):
70
+ if training:
71
+ keep_prob = 1 - self.drop_path
72
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
73
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
74
+ random_tensor = tf.floor(random_tensor)
75
+ return (x / keep_prob) * random_tensor
76
+ return x
77
+
78
+
79
+ class TFSegformerOverlapPatchEmbeddings(keras.layers.Layer):
80
+ """Construct the overlapping patch embeddings."""
81
+
82
+ def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs):
83
+ super().__init__(**kwargs)
84
+ self.padding = keras.layers.ZeroPadding2D(padding=patch_size // 2)
85
+ self.proj = keras.layers.Conv2D(
86
+ filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj"
87
+ )
88
+
89
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
90
+ self.num_channels = num_channels
91
+ self.hidden_size = hidden_size
92
+
93
+ def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
94
+ embeddings = self.proj(self.padding(pixel_values))
95
+ height = shape_list(embeddings)[1]
96
+ width = shape_list(embeddings)[2]
97
+ hidden_dim = shape_list(embeddings)[3]
98
+ # (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels)
99
+ # this can be fed to a Transformer layer
100
+ embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim))
101
+ embeddings = self.layer_norm(embeddings)
102
+ return embeddings, height, width
103
+
104
+ def build(self, input_shape=None):
105
+ if self.built:
106
+ return
107
+ self.built = True
108
+ if getattr(self, "proj", None) is not None:
109
+ with tf.name_scope(self.proj.name):
110
+ self.proj.build([None, None, None, self.num_channels])
111
+ if getattr(self, "layer_norm", None) is not None:
112
+ with tf.name_scope(self.layer_norm.name):
113
+ self.layer_norm.build([None, None, self.hidden_size])
114
+
115
+
116
+ class TFSegformerEfficientSelfAttention(keras.layers.Layer):
117
+ """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
118
+ paper](https://arxiv.org/abs/2102.12122)."""
119
+
120
+ def __init__(
121
+ self,
122
+ config: SegformerConfig,
123
+ hidden_size: int,
124
+ num_attention_heads: int,
125
+ sequence_reduction_ratio: int,
126
+ **kwargs,
127
+ ):
128
+ super().__init__(**kwargs)
129
+ self.hidden_size = hidden_size
130
+ self.num_attention_heads = num_attention_heads
131
+
132
+ if self.hidden_size % self.num_attention_heads != 0:
133
+ raise ValueError(
134
+ f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
135
+ f"heads ({self.num_attention_heads})"
136
+ )
137
+
138
+ self.attention_head_size = self.hidden_size // self.num_attention_heads
139
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
140
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
141
+
142
+ self.query = keras.layers.Dense(self.all_head_size, name="query")
143
+ self.key = keras.layers.Dense(self.all_head_size, name="key")
144
+ self.value = keras.layers.Dense(self.all_head_size, name="value")
145
+
146
+ self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
147
+
148
+ self.sr_ratio = sequence_reduction_ratio
149
+ if sequence_reduction_ratio > 1:
150
+ self.sr = keras.layers.Conv2D(
151
+ filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr"
152
+ )
153
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
154
+
155
+ def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
156
+ # Reshape from [batch_size, seq_length, all_head_size]
157
+ # to [batch_size, seq_length, num_attention_heads, attention_head_size]
158
+ batch_size = shape_list(tensor)[0]
159
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
160
+
161
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size]
162
+ # to [batch_size, num_attention_heads, seq_length, attention_head_size]
163
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
164
+
165
+ def call(
166
+ self,
167
+ hidden_states: tf.Tensor,
168
+ height: int,
169
+ width: int,
170
+ output_attentions: bool = False,
171
+ training: bool = False,
172
+ ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
173
+ batch_size = shape_list(hidden_states)[0]
174
+ num_channels = shape_list(hidden_states)[2]
175
+
176
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
177
+
178
+ if self.sr_ratio > 1:
179
+ # Reshape to (batch_size, height, width, num_channels)
180
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
181
+ # Apply sequence reduction
182
+ hidden_states = self.sr(hidden_states)
183
+ # Reshape back to (batch_size, seq_len, num_channels)
184
+ hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels))
185
+ hidden_states = self.layer_norm(hidden_states)
186
+
187
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
188
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
189
+
190
+ # Take the dot product between "query" and "key" to get the raw attention scores.
191
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
192
+
193
+ scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
194
+ attention_scores = tf.divide(attention_scores, scale)
195
+
196
+ # Normalize the attention scores to probabilities.
197
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
198
+
199
+ # This is actually dropping out entire tokens to attend to, which might
200
+ # seem a bit unusual, but is taken from the original Transformer paper.
201
+ attention_probs = self.dropout(attention_probs, training=training)
202
+
203
+ context_layer = tf.matmul(attention_probs, value_layer)
204
+
205
+ context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
206
+ # (batch_size, seq_len_q, all_head_size)
207
+ context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))
208
+
209
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
210
+ return outputs
211
+
212
+ def build(self, input_shape=None):
213
+ if self.built:
214
+ return
215
+ self.built = True
216
+ if getattr(self, "query", None) is not None:
217
+ with tf.name_scope(self.query.name):
218
+ self.query.build([None, None, self.hidden_size])
219
+ if getattr(self, "key", None) is not None:
220
+ with tf.name_scope(self.key.name):
221
+ self.key.build([None, None, self.hidden_size])
222
+ if getattr(self, "value", None) is not None:
223
+ with tf.name_scope(self.value.name):
224
+ self.value.build([None, None, self.hidden_size])
225
+ if getattr(self, "sr", None) is not None:
226
+ with tf.name_scope(self.sr.name):
227
+ self.sr.build([None, None, None, self.hidden_size])
228
+ if getattr(self, "layer_norm", None) is not None:
229
+ with tf.name_scope(self.layer_norm.name):
230
+ self.layer_norm.build([None, None, self.hidden_size])
231
+
232
+
233
+ class TFSegformerSelfOutput(keras.layers.Layer):
234
+ def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
235
+ super().__init__(**kwargs)
236
+ self.dense = keras.layers.Dense(hidden_size, name="dense")
237
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
238
+ self.hidden_size = hidden_size
239
+
240
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
241
+ hidden_states = self.dense(hidden_states)
242
+ hidden_states = self.dropout(hidden_states, training=training)
243
+ return hidden_states
244
+
245
+ def build(self, input_shape=None):
246
+ if self.built:
247
+ return
248
+ self.built = True
249
+ if getattr(self, "dense", None) is not None:
250
+ with tf.name_scope(self.dense.name):
251
+ self.dense.build([None, None, self.hidden_size])
252
+
253
+
254
+ class TFSegformerAttention(keras.layers.Layer):
255
+ def __init__(
256
+ self,
257
+ config: SegformerConfig,
258
+ hidden_size: int,
259
+ num_attention_heads: int,
260
+ sequence_reduction_ratio: int,
261
+ **kwargs,
262
+ ):
263
+ super().__init__(**kwargs)
264
+ self.self = TFSegformerEfficientSelfAttention(
265
+ config=config,
266
+ hidden_size=hidden_size,
267
+ num_attention_heads=num_attention_heads,
268
+ sequence_reduction_ratio=sequence_reduction_ratio,
269
+ name="self",
270
+ )
271
+ self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output")
272
+
273
+ def call(
274
+ self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False
275
+ ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
276
+ self_outputs = self.self(hidden_states, height, width, output_attentions)
277
+
278
+ attention_output = self.dense_output(self_outputs[0])
279
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
+ return outputs
281
+
282
+ def build(self, input_shape=None):
283
+ if self.built:
284
+ return
285
+ self.built = True
286
+ if getattr(self, "self", None) is not None:
287
+ with tf.name_scope(self.self.name):
288
+ self.self.build(None)
289
+ if getattr(self, "dense_output", None) is not None:
290
+ with tf.name_scope(self.dense_output.name):
291
+ self.dense_output.build(None)
292
+
293
+
294
+ class TFSegformerDWConv(keras.layers.Layer):
295
+ def __init__(self, dim: int = 768, **kwargs):
296
+ super().__init__(**kwargs)
297
+ self.depthwise_convolution = keras.layers.Conv2D(
298
+ filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv"
299
+ )
300
+ self.dim = dim
301
+
302
+ def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
303
+ batch_size = shape_list(hidden_states)[0]
304
+ num_channels = shape_list(hidden_states)[-1]
305
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
306
+ hidden_states = self.depthwise_convolution(hidden_states)
307
+
308
+ new_height = shape_list(hidden_states)[1]
309
+ new_width = shape_list(hidden_states)[2]
310
+ num_channels = shape_list(hidden_states)[3]
311
+ hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels))
312
+ return hidden_states
313
+
314
+ def build(self, input_shape=None):
315
+ if self.built:
316
+ return
317
+ self.built = True
318
+ if getattr(self, "depthwise_convolution", None) is not None:
319
+ with tf.name_scope(self.depthwise_convolution.name):
320
+ self.depthwise_convolution.build([None, None, None, self.dim])
321
+
322
+
323
+ class TFSegformerMixFFN(keras.layers.Layer):
324
+ def __init__(
325
+ self,
326
+ config: SegformerConfig,
327
+ in_features: int,
328
+ hidden_features: Optional[int] = None,
329
+ out_features: Optional[int] = None,
330
+ **kwargs,
331
+ ):
332
+ super().__init__(**kwargs)
333
+ out_features = out_features or in_features
334
+ self.dense1 = keras.layers.Dense(hidden_features, name="dense1")
335
+ self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv")
336
+ if isinstance(config.hidden_act, str):
337
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
338
+ else:
339
+ self.intermediate_act_fn = config.hidden_act
340
+ self.dense2 = keras.layers.Dense(out_features, name="dense2")
341
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
342
+ self.hidden_features = hidden_features
343
+ self.in_features = in_features
344
+
345
+ def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
346
+ hidden_states = self.dense1(hidden_states)
347
+ hidden_states = self.depthwise_convolution(hidden_states, height, width)
348
+ hidden_states = self.intermediate_act_fn(hidden_states)
349
+ hidden_states = self.dropout(hidden_states, training=training)
350
+ hidden_states = self.dense2(hidden_states)
351
+ hidden_states = self.dropout(hidden_states, training=training)
352
+ return hidden_states
353
+
354
+ def build(self, input_shape=None):
355
+ if self.built:
356
+ return
357
+ self.built = True
358
+ if getattr(self, "dense1", None) is not None:
359
+ with tf.name_scope(self.dense1.name):
360
+ self.dense1.build([None, None, self.in_features])
361
+ if getattr(self, "depthwise_convolution", None) is not None:
362
+ with tf.name_scope(self.depthwise_convolution.name):
363
+ self.depthwise_convolution.build(None)
364
+ if getattr(self, "dense2", None) is not None:
365
+ with tf.name_scope(self.dense2.name):
366
+ self.dense2.build([None, None, self.hidden_features])
367
+
368
+
369
+ class TFSegformerLayer(keras.layers.Layer):
370
+ """This corresponds to the Block class in the original implementation."""
371
+
372
+ def __init__(
373
+ self,
374
+ config,
375
+ hidden_size: int,
376
+ num_attention_heads: int,
377
+ drop_path: float,
378
+ sequence_reduction_ratio: int,
379
+ mlp_ratio: int,
380
+ **kwargs,
381
+ ):
382
+ super().__init__(**kwargs)
383
+ self.layer_norm_1 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1")
384
+ self.attention = TFSegformerAttention(
385
+ config,
386
+ hidden_size=hidden_size,
387
+ num_attention_heads=num_attention_heads,
388
+ sequence_reduction_ratio=sequence_reduction_ratio,
389
+ name="attention",
390
+ )
391
+ self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else keras.layers.Activation("linear")
392
+ self.layer_norm_2 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2")
393
+ mlp_hidden_size = int(hidden_size * mlp_ratio)
394
+ self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp")
395
+ self.hidden_size = hidden_size
396
+
397
+ def call(
398
+ self,
399
+ hidden_states: tf.Tensor,
400
+ height: int,
401
+ width: int,
402
+ output_attentions: bool = False,
403
+ training: bool = False,
404
+ ) -> Tuple:
405
+ self_attention_outputs = self.attention(
406
+ self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
407
+ height,
408
+ width,
409
+ output_attentions=output_attentions,
410
+ training=training,
411
+ )
412
+
413
+ attention_output = self_attention_outputs[0]
414
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
415
+
416
+ # first residual connection (with stochastic depth)
417
+ attention_output = self.drop_path(attention_output, training=training)
418
+ hidden_states = attention_output + hidden_states
419
+ mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
420
+
421
+ # second residual connection (with stochastic depth)
422
+ mlp_output = self.drop_path(mlp_output, training=training)
423
+ layer_output = mlp_output + hidden_states
424
+
425
+ outputs = (layer_output,) + outputs
426
+
427
+ return outputs
428
+
429
+ def build(self, input_shape=None):
430
+ if self.built:
431
+ return
432
+ self.built = True
433
+ if getattr(self, "layer_norm_1", None) is not None:
434
+ with tf.name_scope(self.layer_norm_1.name):
435
+ self.layer_norm_1.build([None, None, self.hidden_size])
436
+ if getattr(self, "attention", None) is not None:
437
+ with tf.name_scope(self.attention.name):
438
+ self.attention.build(None)
439
+ if getattr(self, "layer_norm_2", None) is not None:
440
+ with tf.name_scope(self.layer_norm_2.name):
441
+ self.layer_norm_2.build([None, None, self.hidden_size])
442
+ if getattr(self, "mlp", None) is not None:
443
+ with tf.name_scope(self.mlp.name):
444
+ self.mlp.build(None)
445
+
446
+
447
+ class TFSegformerEncoder(keras.layers.Layer):
448
+ def __init__(self, config: SegformerConfig, **kwargs):
449
+ super().__init__(**kwargs)
450
+ self.config = config
451
+
452
+ # stochastic depth decay rule
453
+ drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
454
+
455
+ # patch embeddings
456
+ embeddings = []
457
+ for i in range(config.num_encoder_blocks):
458
+ embeddings.append(
459
+ TFSegformerOverlapPatchEmbeddings(
460
+ patch_size=config.patch_sizes[i],
461
+ stride=config.strides[i],
462
+ num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
463
+ hidden_size=config.hidden_sizes[i],
464
+ name=f"patch_embeddings.{i}",
465
+ )
466
+ )
467
+ self.embeddings = embeddings
468
+
469
+ # Transformer blocks
470
+ blocks = []
471
+ cur = 0
472
+ for i in range(config.num_encoder_blocks):
473
+ # each block consists of layers
474
+ layers = []
475
+ if i != 0:
476
+ cur += config.depths[i - 1]
477
+ for j in range(config.depths[i]):
478
+ layers.append(
479
+ TFSegformerLayer(
480
+ config,
481
+ hidden_size=config.hidden_sizes[i],
482
+ num_attention_heads=config.num_attention_heads[i],
483
+ drop_path=drop_path_decays[cur + j],
484
+ sequence_reduction_ratio=config.sr_ratios[i],
485
+ mlp_ratio=config.mlp_ratios[i],
486
+ name=f"block.{i}.{j}",
487
+ )
488
+ )
489
+ blocks.append(layers)
490
+
491
+ self.block = blocks
492
+
493
+ # Layer norms
494
+ self.layer_norms = [
495
+ keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}")
496
+ for i in range(config.num_encoder_blocks)
497
+ ]
498
+
499
+ def call(
500
+ self,
501
+ pixel_values: tf.Tensor,
502
+ output_attentions: Optional[bool] = False,
503
+ output_hidden_states: Optional[bool] = False,
504
+ return_dict: Optional[bool] = True,
505
+ training: bool = False,
506
+ ) -> Union[Tuple, TFBaseModelOutput]:
507
+ all_hidden_states = () if output_hidden_states else None
508
+ all_self_attentions = () if output_attentions else None
509
+
510
+ batch_size = shape_list(pixel_values)[0]
511
+
512
+ hidden_states = pixel_values
513
+ for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
514
+ embedding_layer, block_layer, norm_layer = x
515
+ # first, obtain patch embeddings
516
+ hidden_states, height, width = embedding_layer(hidden_states)
517
+
518
+ # second, send embeddings through blocks
519
+ # (each block consists of multiple layers i.e., list of layers)
520
+ for i, blk in enumerate(block_layer):
521
+ layer_outputs = blk(
522
+ hidden_states,
523
+ height,
524
+ width,
525
+ output_attentions,
526
+ training=training,
527
+ )
528
+ hidden_states = layer_outputs[0]
529
+ if output_attentions:
530
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
531
+
532
+ # third, apply layer norm
533
+ hidden_states = norm_layer(hidden_states)
534
+
535
+ # fourth, optionally reshape back to (batch_size, height, width, num_channels)
536
+ if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage):
537
+ num_channels = shape_list(hidden_states)[-1]
538
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
539
+
540
+ if output_hidden_states:
541
+ all_hidden_states = all_hidden_states + (hidden_states,)
542
+
543
+ if not return_dict:
544
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
545
+ return TFBaseModelOutput(
546
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
547
+ )
548
+
549
+ def build(self, input_shape=None):
550
+ if self.built:
551
+ return
552
+ self.built = True
553
+ if getattr(self, "layer_norms", None) is not None:
554
+ for layer, shape in zip(self.layer_norms, self.config.hidden_sizes):
555
+ with tf.name_scope(layer.name):
556
+ layer.build([None, None, shape])
557
+ if getattr(self, "block", None) is not None:
558
+ for block in self.block:
559
+ for layer in block:
560
+ with tf.name_scope(layer.name):
561
+ layer.build(None)
562
+ if getattr(self, "embeddings", None) is not None:
563
+ for layer in self.embeddings:
564
+ with tf.name_scope(layer.name):
565
+ layer.build(None)
566
+
567
+
568
+ @keras_serializable
569
+ class TFSegformerMainLayer(keras.layers.Layer):
570
+ config_class = SegformerConfig
571
+
572
+ def __init__(self, config: SegformerConfig, **kwargs):
573
+ super().__init__(**kwargs)
574
+
575
+ self.config = config
576
+ # hierarchical Transformer encoder
577
+ self.encoder = TFSegformerEncoder(config, name="encoder")
578
+
579
+ @unpack_inputs
580
+ def call(
581
+ self,
582
+ pixel_values: tf.Tensor,
583
+ output_attentions: Optional[bool] = None,
584
+ output_hidden_states: Optional[bool] = None,
585
+ return_dict: Optional[bool] = None,
586
+ training: bool = False,
587
+ ) -> Union[Tuple, TFBaseModelOutput]:
588
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
589
+ output_hidden_states = (
590
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
591
+ )
592
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
593
+
594
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
595
+ # So change the input format from `NCHW` to `NHWC`.
596
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
597
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
598
+
599
+ encoder_outputs = self.encoder(
600
+ pixel_values,
601
+ output_attentions=output_attentions,
602
+ output_hidden_states=output_hidden_states,
603
+ return_dict=return_dict,
604
+ training=training,
605
+ )
606
+ sequence_output = encoder_outputs[0]
607
+ # Change to NCHW output format to have uniformity in the modules
608
+ sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])
609
+
610
+ # Change the other hidden state outputs to NCHW as well
611
+ if output_hidden_states:
612
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
613
+
614
+ if not return_dict:
615
+ if tf.greater(len(encoder_outputs[1:]), 0):
616
+ transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0])
617
+ return (sequence_output,) + (transposed_encoder_outputs,)
618
+ else:
619
+ return (sequence_output,) + encoder_outputs[1:]
620
+
621
+ return TFBaseModelOutput(
622
+ last_hidden_state=sequence_output,
623
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
624
+ attentions=encoder_outputs.attentions,
625
+ )
626
+
627
+ def build(self, input_shape=None):
628
+ if self.built:
629
+ return
630
+ self.built = True
631
+ if getattr(self, "encoder", None) is not None:
632
+ with tf.name_scope(self.encoder.name):
633
+ self.encoder.build(None)
634
+
635
+
636
+ class TFSegformerPreTrainedModel(TFPreTrainedModel):
637
+ """
638
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
639
+ models.
640
+ """
641
+
642
+ config_class = SegformerConfig
643
+ base_model_prefix = "segformer"
644
+ main_input_name = "pixel_values"
645
+
646
+ @property
647
+ def input_signature(self):
648
+ return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 512, 512), dtype=tf.float32)}
649
+
650
+
651
+ SEGFORMER_START_DOCSTRING = r"""
652
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
653
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
654
+ etc.)
655
+
656
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
657
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
658
+ behavior.
659
+
660
+ Parameters:
661
+ config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
662
+ Initializing with a config file does not load the weights associated with the model, only the
663
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
664
+ """
665
+
666
+ SEGFORMER_INPUTS_DOCSTRING = r"""
667
+
668
+ Args:
669
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, each example having the shape `(batch_size, num_channels, height, width)`):
670
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
671
+ [`SegformerImageProcessor.__call__`] for details.
672
+
673
+ output_attentions (`bool`, *optional*):
674
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
675
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
676
+ config will be used instead.
677
+
678
+ output_hidden_states (`bool`, *optional*):
679
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
680
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
681
+ used instead.
682
+
683
+ return_dict (`bool`, *optional*):
684
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
685
+ eager mode, in graph mode the value will always be set to True.
686
+
687
+ training (`bool`, *optional*, defaults to `False`):
688
+ Whether or not to use the model in training mode (some modules like dropout modules have different
689
+ behaviors between training and evaluation).
690
+ """
691
+
692
+
693
+ @add_start_docstrings(
694
+ "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
695
+ SEGFORMER_START_DOCSTRING,
696
+ )
697
+ class TFSegformerModel(TFSegformerPreTrainedModel):
698
+ def __init__(self, config: SegformerConfig, *inputs, **kwargs):
699
+ super().__init__(config, *inputs, **kwargs)
700
+ self.config = config
701
+
702
+ # hierarchical Transformer encoder
703
+ self.segformer = TFSegformerMainLayer(config, name="segformer")
704
+
705
+ @unpack_inputs
706
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
707
+ @add_code_sample_docstrings(
708
+ checkpoint=_CHECKPOINT_FOR_DOC,
709
+ output_type=TFBaseModelOutput,
710
+ config_class=_CONFIG_FOR_DOC,
711
+ modality="vision",
712
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
713
+ )
714
+ def call(
715
+ self,
716
+ pixel_values: tf.Tensor,
717
+ output_attentions: Optional[bool] = None,
718
+ output_hidden_states: Optional[bool] = None,
719
+ return_dict: Optional[bool] = None,
720
+ training: bool = False,
721
+ ) -> Union[Tuple, TFBaseModelOutput]:
722
+ outputs = self.segformer(
723
+ pixel_values,
724
+ output_attentions=output_attentions,
725
+ output_hidden_states=output_hidden_states,
726
+ return_dict=return_dict,
727
+ training=training,
728
+ )
729
+ return outputs
730
+
731
+ def build(self, input_shape=None):
732
+ if self.built:
733
+ return
734
+ self.built = True
735
+ if getattr(self, "segformer", None) is not None:
736
+ with tf.name_scope(self.segformer.name):
737
+ self.segformer.build(None)
738
+
739
+
740
+ @add_start_docstrings(
741
+ """
742
+ SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
743
+ states) e.g. for ImageNet.
744
+ """,
745
+ SEGFORMER_START_DOCSTRING,
746
+ )
747
+ class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss):
748
+ def __init__(self, config: SegformerConfig, *inputs, **kwargs):
749
+ super().__init__(config, *inputs, **kwargs)
750
+
751
+ self.num_labels = config.num_labels
752
+ self.segformer = TFSegformerMainLayer(config, name="segformer")
753
+
754
+ # Classifier head
755
+ self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
756
+ self.config = config
757
+
758
+ @unpack_inputs
759
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
760
+ @add_code_sample_docstrings(
761
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
762
+ output_type=TFSequenceClassifierOutput,
763
+ config_class=_CONFIG_FOR_DOC,
764
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
765
+ )
766
+ def call(
767
+ self,
768
+ pixel_values: tf.Tensor | None = None,
769
+ labels: tf.Tensor | None = None,
770
+ output_attentions: Optional[bool] = None,
771
+ output_hidden_states: Optional[bool] = None,
772
+ return_dict: Optional[bool] = None,
773
+ ) -> Union[Tuple, TFSequenceClassifierOutput]:
774
+ outputs = self.segformer(
775
+ pixel_values,
776
+ output_attentions=output_attentions,
777
+ output_hidden_states=output_hidden_states,
778
+ return_dict=return_dict,
779
+ )
780
+
781
+ sequence_output = outputs[0]
782
+
783
+ # convert last hidden states to (batch_size, height*width, hidden_size)
784
+ batch_size = shape_list(sequence_output)[0]
785
+ sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])
786
+ sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1]))
787
+
788
+ # global average pooling
789
+ sequence_output = tf.reduce_mean(sequence_output, axis=1)
790
+
791
+ logits = self.classifier(sequence_output)
792
+
793
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
794
+
795
+ if not return_dict:
796
+ output = (logits,) + outputs[1:]
797
+ return ((loss,) + output) if loss is not None else output
798
+
799
+ return TFSequenceClassifierOutput(
800
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
801
+ )
802
+
803
+ def build(self, input_shape=None):
804
+ if self.built:
805
+ return
806
+ self.built = True
807
+ if getattr(self, "segformer", None) is not None:
808
+ with tf.name_scope(self.segformer.name):
809
+ self.segformer.build(None)
810
+ if getattr(self, "classifier", None) is not None:
811
+ with tf.name_scope(self.classifier.name):
812
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
813
+
814
+
815
+ class TFSegformerMLP(keras.layers.Layer):
816
+ """
817
+ Linear Embedding.
818
+ """
819
+
820
+ def __init__(self, input_dim: int, config: SegformerConfig, **kwargs):
821
+ super().__init__(**kwargs)
822
+ self.proj = keras.layers.Dense(config.decoder_hidden_size, name="proj")
823
+ self.input_dim = input_dim
824
+
825
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
826
+ height = shape_list(hidden_states)[1]
827
+ width = shape_list(hidden_states)[2]
828
+ hidden_dim = shape_list(hidden_states)[-1]
829
+ hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim))
830
+ hidden_states = self.proj(hidden_states)
831
+ return hidden_states
832
+
833
+ def build(self, input_shape=None):
834
+ if self.built:
835
+ return
836
+ self.built = True
837
+ if getattr(self, "proj", None) is not None:
838
+ with tf.name_scope(self.proj.name):
839
+ self.proj.build([None, None, self.input_dim])
840
+
841
+
842
+ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
843
+ def __init__(self, config: SegformerConfig, **kwargs):
844
+ super().__init__(config, **kwargs)
845
+ # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
846
+ mlps = []
847
+ for i in range(config.num_encoder_blocks):
848
+ mlp = TFSegformerMLP(config=config, input_dim=config.hidden_sizes[i], name=f"linear_c.{i}")
849
+ mlps.append(mlp)
850
+ self.mlps = mlps
851
+
852
+ # the following 3 layers implement the ConvModule of the original implementation
853
+ self.linear_fuse = keras.layers.Conv2D(
854
+ filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse"
855
+ )
856
+ self.batch_norm = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="batch_norm")
857
+ self.activation = keras.layers.Activation("relu")
858
+
859
+ self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
860
+ self.classifier = keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier")
861
+
862
+ self.config = config
863
+
864
+ def call(self, encoder_hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
865
+ all_hidden_states = ()
866
+ for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
867
+ if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
868
+ height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
869
+ height = width = tf.cast(height, tf.int32)
870
+ channel_dim = shape_list(encoder_hidden_state)[-1]
871
+ encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
872
+
873
+ # unify channel dimension
874
+ encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
875
+ height, width = shape_list(encoder_hidden_state)[1:3]
876
+ encoder_hidden_state = mlp(encoder_hidden_state)
877
+ channel_dim = shape_list(encoder_hidden_state)[-1]
878
+ encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
879
+
880
+ # upsample
881
+ temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
882
+ upsample_resolution = shape_list(temp_state)[1:-1]
883
+ encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear")
884
+ all_hidden_states += (encoder_hidden_state,)
885
+
886
+ hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
887
+ hidden_states = self.batch_norm(hidden_states, training=training)
888
+ hidden_states = self.activation(hidden_states)
889
+ hidden_states = self.dropout(hidden_states, training=training)
890
+
891
+ # logits of shape (batch_size, height/4, width/4, num_labels)
892
+ logits = self.classifier(hidden_states)
893
+
894
+ return logits
895
+
896
+ def build(self, input_shape=None):
897
+ if self.built:
898
+ return
899
+ self.built = True
900
+ if getattr(self, "linear_fuse", None) is not None:
901
+ with tf.name_scope(self.linear_fuse.name):
902
+ self.linear_fuse.build(
903
+ [None, None, None, self.config.decoder_hidden_size * self.config.num_encoder_blocks]
904
+ )
905
+ if getattr(self, "batch_norm", None) is not None:
906
+ with tf.name_scope(self.batch_norm.name):
907
+ self.batch_norm.build([None, None, None, self.config.decoder_hidden_size])
908
+ if getattr(self, "classifier", None) is not None:
909
+ with tf.name_scope(self.classifier.name):
910
+ self.classifier.build([None, None, None, self.config.decoder_hidden_size])
911
+ if getattr(self, "mlps", None) is not None:
912
+ for layer in self.mlps:
913
+ with tf.name_scope(layer.name):
914
+ layer.build(None)
915
+
916
+
917
+ @add_start_docstrings(
918
+ """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
919
+ SEGFORMER_START_DOCSTRING,
920
+ )
921
+ class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
922
+ def __init__(self, config: SegformerConfig, **kwargs):
923
+ super().__init__(config, **kwargs)
924
+ self.segformer = TFSegformerMainLayer(config, name="segformer")
925
+ self.decode_head = TFSegformerDecodeHead(config, name="decode_head")
926
+
927
+ def hf_compute_loss(self, logits, labels):
928
+ # upsample logits to the images' original size
929
+ # `labels` is of shape (batch_size, height, width)
930
+ label_interp_shape = shape_list(labels)[1:]
931
+
932
+ upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
933
+ # compute weighted loss
934
+ loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
935
+
936
+ def masked_loss(real, pred):
937
+ unmasked_loss = loss_fct(real, pred)
938
+ mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
939
+ masked_loss = unmasked_loss * mask
940
+ # Reduction strategy in a similar spirit to
941
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
942
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
943
+ return tf.reshape(reduced_masked_loss, (1,))
944
+
945
+ return masked_loss(labels, upsampled_logits)
946
+
947
+ @unpack_inputs
948
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
949
+ @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
950
+ def call(
951
+ self,
952
+ pixel_values: tf.Tensor,
953
+ labels: tf.Tensor | None = None,
954
+ output_attentions: Optional[bool] = None,
955
+ output_hidden_states: Optional[bool] = None,
956
+ return_dict: Optional[bool] = None,
957
+ ) -> Union[Tuple, TFSemanticSegmenterOutput]:
958
+ r"""
959
+ labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
960
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
961
+ config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed
962
+ (Cross-Entropy).
963
+
964
+ Returns:
965
+
966
+ Examples:
967
+
968
+ ```python
969
+ >>> from transformers import AutoImageProcessor, TFSegformerForSemanticSegmentation
970
+ >>> from PIL import Image
971
+ >>> import requests
972
+
973
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
974
+ >>> image = Image.open(requests.get(url, stream=True).raw)
975
+
976
+ >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
977
+ >>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
978
+
979
+ >>> inputs = image_processor(images=image, return_tensors="tf")
980
+ >>> outputs = model(**inputs, training=False)
981
+ >>> # logits are of shape (batch_size, num_labels, height/4, width/4)
982
+ >>> logits = outputs.logits
983
+ >>> list(logits.shape)
984
+ [1, 150, 128, 128]
985
+ ```"""
986
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
987
+ output_hidden_states = (
988
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
989
+ )
990
+
991
+ if labels is not None and not self.config.num_labels > 1:
992
+ raise ValueError("The number of labels should be greater than one")
993
+
994
+ outputs = self.segformer(
995
+ pixel_values,
996
+ output_attentions=output_attentions,
997
+ output_hidden_states=True, # we need the intermediate hidden states
998
+ return_dict=return_dict,
999
+ )
1000
+
1001
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
1002
+
1003
+ logits = self.decode_head(encoder_hidden_states)
1004
+
1005
+ loss = None
1006
+ if labels is not None:
1007
+ loss = self.hf_compute_loss(logits=logits, labels=labels)
1008
+
1009
+ # make logits of shape (batch_size, num_labels, height, width) to
1010
+ # keep them consistent across APIs
1011
+ logits = tf.transpose(logits, perm=[0, 3, 1, 2])
1012
+
1013
+ if not return_dict:
1014
+ if output_hidden_states:
1015
+ output = (logits,) + outputs[1:]
1016
+ else:
1017
+ output = (logits,) + outputs[2:]
1018
+ return ((loss,) + output) if loss is not None else output
1019
+
1020
+ return TFSemanticSegmenterOutput(
1021
+ loss=loss,
1022
+ logits=logits,
1023
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1024
+ attentions=outputs.attentions,
1025
+ )
1026
+
1027
+ def build(self, input_shape=None):
1028
+ if self.built:
1029
+ return
1030
+ self.built = True
1031
+ if getattr(self, "segformer", None) is not None:
1032
+ with tf.name_scope(self.segformer.name):
1033
+ self.segformer.build(None)
1034
+ if getattr(self, "decode_head", None) is not None:
1035
+ with tf.name_scope(self.decode_head.name):
1036
+ self.decode_head.build(None)
1037
+
1038
+
1039
+ __all__ = [
1040
+ "TFSegformerDecodeHead",
1041
+ "TFSegformerForImageClassification",
1042
+ "TFSegformerForSemanticSegmentation",
1043
+ "TFSegformerModel",
1044
+ "TFSegformerPreTrainedModel",
1045
+ ]
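
The TF variant returns logits in NCHW for parity with the PyTorch model, while `tf.image.resize` expects channels-last. A minimal sketch of one way to post-process the outputs (illustrative, assuming the same checkpoint as the PyTorch example):

```python
import requests
import tensorflow as tf
from PIL import Image
from transformers import AutoImageProcessor, TFSegformerForSemanticSegmentation

image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = image_processor(images=image, return_tensors="tf")
logits = model(**inputs, training=False).logits  # (1, num_labels, height/4, width/4)

# illustrative post-processing: move channels last for tf.image.resize,
# upsample to the input size, then take the class index per pixel
logits = tf.transpose(logits, perm=[0, 2, 3, 1])
upsampled = tf.image.resize(logits, size=image.size[::-1], method="bilinear")
segmentation_map = tf.argmax(upsampled, axis=-1)[0]  # (height, width)
```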
docs/transformers/build/lib/transformers/models/seggpt/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_seggpt import *
22
+ from .image_processing_seggpt import *
23
+ from .modeling_seggpt import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
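
The `_LazyModule` indirection defers the heavy imports until an attribute is first accessed, but from the caller's side it behaves like a regular eager import. A small illustrative sketch:

```python
# resolved lazily on first attribute access thanks to _LazyModule
from transformers.models.seggpt import SegGptConfig

config = SegGptConfig()
print(config.model_type)  # "seggpt"
```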
docs/transformers/build/lib/transformers/models/seggpt/configuration_seggpt.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SegGpt model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class SegGptConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`SegGptModel`]. It is used to instantiate a SegGPT
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the SegGPT
29
+ [BAAI/seggpt-vit-large](https://huggingface.co/BAAI/seggpt-vit-large) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ hidden_size (`int`, *optional*, defaults to 1024):
36
+ Dimensionality of the encoder layers and the pooler layer.
37
+ num_hidden_layers (`int`, *optional*, defaults to 24):
38
+ Number of hidden layers in the Transformer encoder.
39
+ num_attention_heads (`int`, *optional*, defaults to 16):
40
+ Number of attention heads for each attention layer in the Transformer encoder.
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
44
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
45
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
46
+ initializer_range (`float`, *optional*, defaults to 0.02):
47
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
48
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
49
+ The epsilon used by the layer normalization layers.
50
+ image_size (`List[int]`, *optional*, defaults to `[896, 448]`):
51
+ The size (resolution) of each image.
52
+ patch_size (`int`, *optional*, defaults to 16):
53
+ The size (resolution) of each patch.
54
+ num_channels (`int`, *optional*, defaults to 3):
55
+ The number of input channels.
56
+ qkv_bias (`bool`, *optional*, defaults to `True`):
57
+ Whether to add a bias to the queries, keys and values.
58
+ mlp_dim (`int`, *optional*):
59
+ The dimensionality of the MLP layer in the Transformer encoder. If unset, defaults to
60
+ `hidden_size` * 4.
61
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
62
+ The drop path rate for the dropout layers.
63
+ pretrain_image_size (`int`, *optional*, defaults to 224):
64
+ The pretrained size of the absolute position embeddings.
65
+ decoder_hidden_size (`int`, *optional*, defaults to 64):
66
+ Hidden size for decoder.
67
+ use_relative_position_embeddings (`bool`, *optional*, defaults to `True`):
68
+ Whether to use relative position embeddings in the attention layers.
69
+ merge_index (`int`, *optional*, defaults to 2):
70
+ The index of the encoder layer to merge the embeddings.
71
+ intermediate_hidden_state_indices (`List[int]`, *optional*, defaults to `[5, 11, 17, 23]`):
72
+ The indices of the encoder layers which we store as features for the decoder.
73
+ beta (`float`, *optional*, defaults to 0.01):
74
+ Regularization factor for SegGptLoss (smooth-l1 loss).
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import SegGptConfig, SegGptModel
80
+
81
+ >>> # Initializing a SegGPT seggpt-vit-large style configuration
82
+ >>> configuration = SegGptConfig()
83
+
84
+ >>> # Initializing a model (with random weights) from the seggpt-vit-large style configuration
85
+ >>> model = SegGptModel(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+
91
+ model_type = "seggpt"
92
+
93
+ def __init__(
94
+ self,
95
+ hidden_size=1024,
96
+ num_hidden_layers=24,
97
+ num_attention_heads=16,
98
+ hidden_act="gelu",
99
+ hidden_dropout_prob=0.0,
100
+ initializer_range=0.02,
101
+ layer_norm_eps=1e-6,
102
+ image_size=[896, 448],
103
+ patch_size=16,
104
+ num_channels=3,
105
+ qkv_bias=True,
106
+ mlp_dim=None,
107
+ drop_path_rate=0.1,
108
+ pretrain_image_size=224,
109
+ decoder_hidden_size=64,
110
+ use_relative_position_embeddings=True,
111
+ merge_index=2,
112
+ intermediate_hidden_state_indices=[5, 11, 17, 23],
113
+ beta=0.01,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(**kwargs)
117
+
118
+ if merge_index > min(intermediate_hidden_state_indices):
119
+ raise ValueError(
120
+ f"Merge index must be less than the minimum encoder output index, but got {merge_index=} and {intermediate_hidden_state_indices=}"
121
+ )
122
+ self.hidden_size = hidden_size
123
+ self.num_hidden_layers = num_hidden_layers
124
+ self.num_attention_heads = num_attention_heads
125
+ self.hidden_act = hidden_act
126
+ self.hidden_dropout_prob = hidden_dropout_prob
127
+ self.initializer_range = initializer_range
128
+ self.layer_norm_eps = layer_norm_eps
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.qkv_bias = qkv_bias
133
+ self.drop_path_rate = drop_path_rate
134
+ self.pretrain_image_size = pretrain_image_size
135
+ self.decoder_hidden_size = decoder_hidden_size
136
+ self.use_relative_position_embeddings = use_relative_position_embeddings
137
+ self.merge_index = merge_index
138
+ self.intermediate_hidden_state_indices = intermediate_hidden_state_indices
139
+ self.beta = beta
140
+ self.mlp_dim = int(hidden_size * 4) if mlp_dim is None else mlp_dim
141
+
142
+
143
+ __all__ = ["SegGptConfig"]
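
Note that the constructor validates `merge_index` against `intermediate_hidden_state_indices`: merging must happen before the first encoder layer whose hidden state is stored for the decoder. A short sketch of the check (values chosen for illustration):

```python
from transformers import SegGptConfig

# fine: merge_index (2) is below the smallest stored index (5)
config = SegGptConfig(merge_index=2, intermediate_hidden_state_indices=[5, 11, 17, 23])

# raises ValueError: merge_index must be less than min(intermediate_hidden_state_indices)
try:
    SegGptConfig(merge_index=6, intermediate_hidden_state_indices=[5, 11, 17, 23])
except ValueError as err:
    print(err)
```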
docs/transformers/build/lib/transformers/models/seggpt/convert_seggpt_to_hf.py ADDED
@@ -0,0 +1,221 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SegGPT checkpoints from the original repository.
16
+
17
+ URL: https://github.com/baaivision/Painter/tree/main/SegGPT
18
+ """
19
+
20
+ import argparse
21
+
22
+ import requests
23
+ import torch
24
+ from PIL import Image
25
+
26
+ from transformers import SegGptConfig, SegGptForImageSegmentation, SegGptImageProcessor
27
+ from transformers.utils import logging
28
+
29
+
30
+ logging.set_verbosity_info()
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ # here we list all keys to be renamed (original name on the left, our name on the right)
35
+ def create_rename_keys(config):
36
+ rename_keys = []
37
+
38
+ # fmt: off
39
+
40
+ # rename embedding and its parameters
41
+ rename_keys.append(("patch_embed.proj.weight", "model.embeddings.patch_embeddings.projection.weight"))
42
+ rename_keys.append(("patch_embed.proj.bias", "model.embeddings.patch_embeddings.projection.bias"))
43
+ rename_keys.append(("mask_token", "model.embeddings.mask_token"))
44
+ rename_keys.append(("segment_token_x", "model.embeddings.segment_token_input"))
45
+ rename_keys.append(("segment_token_y", "model.embeddings.segment_token_prompt"))
46
+ rename_keys.append(("type_token_cls", "model.embeddings.type_token_semantic"))
47
+ rename_keys.append(("type_token_ins", "model.embeddings.type_token_instance"))
48
+ rename_keys.append(("pos_embed", "model.embeddings.position_embeddings"))
49
+
50
+ # rename decoder and other
51
+ rename_keys.append(("norm.weight", "model.encoder.layernorm.weight"))
52
+ rename_keys.append(("norm.bias", "model.encoder.layernorm.bias"))
53
+ rename_keys.append(("decoder_embed.weight", "decoder.decoder_embed.weight"))
54
+ rename_keys.append(("decoder_embed.bias", "decoder.decoder_embed.bias"))
55
+ rename_keys.append(("decoder_pred.0.weight", "decoder.decoder_pred.conv.weight"))
56
+ rename_keys.append(("decoder_pred.0.bias", "decoder.decoder_pred.conv.bias"))
57
+ rename_keys.append(("decoder_pred.1.weight", "decoder.decoder_pred.layernorm.weight"))
58
+ rename_keys.append(("decoder_pred.1.bias", "decoder.decoder_pred.layernorm.bias"))
59
+ rename_keys.append(("decoder_pred.3.weight", "decoder.decoder_pred.head.weight"))
60
+ rename_keys.append(("decoder_pred.3.bias", "decoder.decoder_pred.head.bias"))
61
+
62
+ # rename blocks
63
+ for i in range(config.num_hidden_layers):
64
+ rename_keys.append((f"blocks.{i}.attn.qkv.weight", f"model.encoder.layers.{i}.attention.qkv.weight"))
65
+ rename_keys.append((f"blocks.{i}.attn.qkv.bias", f"model.encoder.layers.{i}.attention.qkv.bias"))
66
+ rename_keys.append((f"blocks.{i}.attn.proj.weight", f"model.encoder.layers.{i}.attention.proj.weight"))
67
+ rename_keys.append((f"blocks.{i}.attn.proj.bias", f"model.encoder.layers.{i}.attention.proj.bias"))
68
+ rename_keys.append((f"blocks.{i}.attn.rel_pos_h", f"model.encoder.layers.{i}.attention.rel_pos_h"))
69
+ rename_keys.append((f"blocks.{i}.attn.rel_pos_w", f"model.encoder.layers.{i}.attention.rel_pos_w"))
70
+
71
+ rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"model.encoder.layers.{i}.mlp.lin1.weight"))
72
+ rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"model.encoder.layers.{i}.mlp.lin1.bias"))
73
+ rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"model.encoder.layers.{i}.mlp.lin2.weight"))
74
+ rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"model.encoder.layers.{i}.mlp.lin2.bias"))
75
+
76
+ rename_keys.append((f"blocks.{i}.norm1.weight", f"model.encoder.layers.{i}.layernorm_before.weight"))
77
+ rename_keys.append((f"blocks.{i}.norm1.bias", f"model.encoder.layers.{i}.layernorm_before.bias"))
78
+ rename_keys.append((f"blocks.{i}.norm2.weight", f"model.encoder.layers.{i}.layernorm_after.weight"))
79
+ rename_keys.append((f"blocks.{i}.norm2.bias", f"model.encoder.layers.{i}.layernorm_after.bias"))
80
+
81
+ # fmt: on
82
+
83
+ return rename_keys
84
+
85
+
86
+ def rename_key(dct, old, new):
87
+ val = dct.pop(old)
88
+ dct[new] = val
89
+
90
+
91
+ # We will verify our results on spongebob images
92
+ def prepare_input():
93
+ image_input_url = (
94
+ "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
95
+ )
96
+ image_prompt_url = (
97
+ "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
98
+ )
99
+ mask_prompt_url = (
100
+ "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
101
+ )
102
+
103
+ image_input = Image.open(requests.get(image_input_url, stream=True).raw)
104
+ image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
105
+ mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw)
106
+
107
+ return image_input, image_prompt, mask_prompt
108
+
109
+
110
+ @torch.no_grad()
111
+ def convert_seggpt_checkpoint(args):
112
+ model_name = args.model_name
113
+ pytorch_dump_folder_path = args.pytorch_dump_folder_path
114
+ verify_logits = args.verify_logits
115
+ push_to_hub = args.push_to_hub
116
+
117
+ # Define default SegGPT configuration
118
+ config = SegGptConfig()
119
+
120
+ # Load original checkpoint
121
+ checkpoint_url = "https://huggingface.co/BAAI/SegGpt/resolve/main/seggpt_vit_large.pth"
122
+ original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]
123
+
124
+ # Rename keys
125
+ new_state_dict = original_state_dict.copy()
126
+ rename_keys = create_rename_keys(config)
127
+
128
+ for src, dest in rename_keys:
129
+ rename_key(new_state_dict, src, dest)
130
+
131
+ # Load HF model
132
+ model = SegGptForImageSegmentation(config)
133
+ model.eval()
134
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
135
+ print("Missing keys:", missing_keys)
136
+ print("Unexpected keys:", unexpected_keys)
137
+
138
+ input_img, prompt_img, prompt_mask = prepare_input()
139
+ image_processor = SegGptImageProcessor()
140
+ inputs = image_processor(images=input_img, prompt_images=prompt_img, prompt_masks=prompt_mask, return_tensors="pt")
141
+
142
+ expected_prompt_pixel_values = torch.tensor(
143
+ [
144
+ [[-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965]],
145
+ [[1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583]],
146
+ [[2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088]],
147
+ ]
148
+ )
149
+
150
+ expected_pixel_values = torch.tensor(
151
+ [
152
+ [[1.6324, 1.6153, 1.5810], [1.6153, 1.5982, 1.5810], [1.5810, 1.5639, 1.5639]],
153
+ [[1.2731, 1.2556, 1.2206], [1.2556, 1.2381, 1.2031], [1.2206, 1.2031, 1.1681]],
154
+ [[1.6465, 1.6465, 1.6465], [1.6465, 1.6465, 1.6465], [1.6291, 1.6291, 1.6291]],
155
+ ]
156
+ )
157
+
158
+ expected_prompt_masks = torch.tensor(
159
+ [
160
+ [[-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179]],
161
+ [[-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357]],
162
+ [[-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044]],
163
+ ]
164
+ )
165
+
166
+ assert torch.allclose(inputs.pixel_values[0, :, :3, :3], expected_pixel_values, atol=1e-4)
167
+ assert torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
168
+ assert torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4)
169
+
170
+ torch.manual_seed(2)
171
+ outputs = model(**inputs)
172
+ print(outputs)
173
+
174
+ if verify_logits:
175
+ expected_output = torch.tensor(
176
+ [
177
+ [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
178
+ [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
179
+ [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
180
+ ]
181
+ )
182
+ assert torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_output, atol=1e-4)
183
+ print("Looks good!")
184
+ else:
185
+ print("Converted without verifying logits")
186
+
187
+ if pytorch_dump_folder_path is not None:
188
+ print(f"Saving model and processor for {model_name} to {pytorch_dump_folder_path}")
189
+ model.save_pretrained(pytorch_dump_folder_path)
190
+ image_processor.save_pretrained(pytorch_dump_folder_path)
191
+
192
+ if push_to_hub:
193
+ print(f"Pushing model and processor for {model_name} to hub")
194
+ model.push_to_hub(f"EduardoPacheco/{model_name}")
195
+ image_processor.push_to_hub(f"EduardoPacheco/{model_name}")
196
+
197
+
198
+ if __name__ == "__main__":
199
+ parser = argparse.ArgumentParser()
200
+ # Required parameters
201
+ parser.add_argument(
202
+ "--model_name",
203
+ default="seggpt-vit-large",
204
+ type=str,
205
+ choices=["seggpt-vit-large"],
206
+ help="Name of the SegGpt model you'd like to convert.",
207
+ )
208
+ parser.add_argument(
209
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
210
+ )
211
+ parser.add_argument(
212
+ "--verify_logits",
213
+ action="store_false",
214
+ help="Whether or not to verify the logits against the original implementation.",
215
+ )
216
+ parser.add_argument(
217
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
218
+ )
219
+
220
+ args = parser.parse_args()
221
+ convert_seggpt_checkpoint(args)
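Usage note (not part of the upstream diff): a hedged sketch of driving the conversion entry point programmatically. The module import path is an assumption (it depends on where convert_seggpt_to_hf.py sits on your PYTHONPATH), and the checkpoint download requires network access; the command-line equivalent simply passes the argparse flags defined above.
from argparse import Namespace
# Hypothetical import path; adjust to wherever the script above is located.
from convert_seggpt_to_hf import convert_seggpt_checkpoint
# Mirrors: python convert_seggpt_to_hf.py --model_name seggpt-vit-large \
#              --pytorch_dump_folder_path ./seggpt-vit-large
# Note that --verify_logits is declared with store_false, so logits are verified by default.
args = Namespace(
    model_name="seggpt-vit-large",
    pytorch_dump_folder_path="./seggpt-vit-large",
    verify_logits=True,
    push_to_hub=False,
)
convert_seggpt_checkpoint(args)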
docs/transformers/build/lib/transformers/models/seggpt/image_processing_seggpt.py ADDED
@@ -0,0 +1,618 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for SegGPT."""
16
+
17
+ from typing import Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import resize, to_channel_dimension_format
23
+ from ...image_utils import (
24
+ IMAGENET_DEFAULT_MEAN,
25
+ IMAGENET_DEFAULT_STD,
26
+ ChannelDimension,
27
+ ImageInput,
28
+ PILImageResampling,
29
+ infer_channel_dimension_format,
30
+ is_scaled_image,
31
+ make_list_of_images,
32
+ to_numpy_array,
33
+ valid_images,
34
+ )
35
+ from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
36
+
37
+
38
+ if is_torch_available():
39
+ import torch
40
+
41
+ if is_vision_available():
42
+ pass
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+
48
+ # See https://arxiv.org/pdf/2212.02499.pdf at 3.1 Redefining Output Spaces as "Images" - Semantic Segmentation from PAINTER paper
49
+ # Taken from https://github.com/Abdullah-Meda/Painter/blob/main/Painter/data/coco_semseg/gen_color_coco_panoptic_segm.py#L31
50
+ def build_palette(num_labels: int) -> List[Tuple[int, int, int]]:
51
+ base = int(num_labels ** (1 / 3)) + 1
52
+ margin = 256 // base
53
+
54
+ # we assume that class_idx 0 is the background which is mapped to black
55
+ color_list = [(0, 0, 0)]
56
+ for location in range(num_labels):
57
+ num_seq_r = location // base**2
58
+ num_seq_g = (location % base**2) // base
59
+ num_seq_b = location % base
60
+
61
+ R = 255 - num_seq_r * margin
62
+ G = 255 - num_seq_g * margin
63
+ B = 255 - num_seq_b * margin
64
+
65
+ color_list.append((R, G, B))
66
+
67
+ return color_list
68
+
69
+
70
+ def mask_to_rgb(
71
+ mask: np.ndarray, palette: Optional[List[Tuple[int, int, int]]] = None, data_format: Optional[ChannelDimension] = None
72
+ ) -> np.ndarray:
73
+ data_format = data_format if data_format is not None else ChannelDimension.FIRST
74
+
75
+ if palette is not None:
76
+ height, width = mask.shape
77
+
78
+ rgb_mask = np.zeros((3, height, width), dtype=np.uint8)
79
+
80
+ classes_in_mask = np.unique(mask)
81
+
82
+ for class_idx in classes_in_mask:
83
+ rgb_value = palette[class_idx]
84
+ class_mask = (mask == class_idx).astype(np.uint8)
85
+ class_mask = np.expand_dims(class_mask, axis=-1)
86
+ class_rgb_mask = class_mask * np.array(rgb_value)
87
+ class_rgb_mask = np.moveaxis(class_rgb_mask, -1, 0)
88
+ rgb_mask += class_rgb_mask.astype(np.uint8)
89
+
90
+ rgb_mask = np.clip(rgb_mask, 0, 255).astype(np.uint8)
91
+
92
+ else:
93
+ rgb_mask = np.repeat(mask[None, ...], 3, axis=0)
94
+
95
+ return to_channel_dimension_format(rgb_mask, data_format)
96
+
97
+
98
+ class SegGptImageProcessor(BaseImageProcessor):
99
+ r"""
100
+ Constructs a SegGpt image processor.
101
+
102
+ Args:
103
+ do_resize (`bool`, *optional*, defaults to `True`):
104
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
105
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
106
+ size (`dict`, *optional*, defaults to `{"height": 448, "width": 448}`):
107
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
108
+ method.
109
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
110
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
111
+ `preprocess` method.
112
+ do_rescale (`bool`, *optional*, defaults to `True`):
113
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
114
+ parameter in the `preprocess` method.
115
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
116
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
117
+ `preprocess` method.
118
+ do_normalize (`bool`, *optional*, defaults to `True`):
119
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
120
+ method.
121
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
122
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
123
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
124
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
125
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
126
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
127
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
128
+ Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in the
129
+ `preprocess` method.
130
+ """
131
+
132
+ model_input_names = ["pixel_values"]
133
+
134
+ def __init__(
135
+ self,
136
+ do_resize: bool = True,
137
+ size: Optional[Dict[str, int]] = None,
138
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
139
+ do_rescale: bool = True,
140
+ rescale_factor: Union[int, float] = 1 / 255,
141
+ do_normalize: bool = True,
142
+ image_mean: Optional[Union[float, List[float]]] = None,
143
+ image_std: Optional[Union[float, List[float]]] = None,
144
+ do_convert_rgb: bool = True,
145
+ **kwargs,
146
+ ) -> None:
147
+ super().__init__(**kwargs)
148
+ size = size if size is not None else {"height": 448, "width": 448}
149
+ size = get_size_dict(size)
150
+ self.do_resize = do_resize
151
+ self.do_rescale = do_rescale
152
+ self.do_normalize = do_normalize
153
+ self.size = size
154
+ self.resample = resample
155
+ self.rescale_factor = rescale_factor
156
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
157
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
158
+ self.do_convert_rgb = do_convert_rgb
159
+
160
+ def get_palette(self, num_labels: int) -> List[Tuple[int, int, int]]:
161
+ """Build a palette to map the prompt mask from a single channel to a 3 channel RGB.
162
+
163
+ Args:
164
+ num_labels (`int`):
165
+ Number of classes in the segmentation task (excluding the background).
166
+
167
+ Returns:
168
+ `List[Tuple[int, int, int]]`: Palette to map the prompt mask from a single channel to a 3 channel RGB.
169
+ """
170
+ return build_palette(num_labels)
171
+
172
+ def mask_to_rgb(
173
+ self,
174
+ image: np.ndarray,
175
+ palette: Optional[List[Tuple[int, int, int]]] = None,
176
+ data_format: Optional[Union[str, ChannelDimension]] = None,
177
+ ) -> np.ndarray:
178
+ """Converts a segmentation map to RGB format.
179
+
180
+ Args:
181
+ image (`np.ndarray`):
182
+ Segmentation map with dimensions (height, width) where pixel values represent the class index.
183
+ palette (`List[Tuple[int, int, int]]`, *optional*, defaults to `None`):
184
+ Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel
185
+ dimension.
186
+ data_format (`ChannelDimension` or `str`, *optional*):
187
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
188
+ image is used. Can be one of:
189
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
190
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
191
+
192
+ Returns:
193
+ `np.ndarray`: The mask in RGB format.
194
+ """
195
+ return mask_to_rgb(image, palette=palette, data_format=data_format)
196
+
197
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
198
+ def resize(
199
+ self,
200
+ image: np.ndarray,
201
+ size: Dict[str, int],
202
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
203
+ data_format: Optional[Union[str, ChannelDimension]] = None,
204
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
205
+ **kwargs,
206
+ ) -> np.ndarray:
207
+ """
208
+ Resize an image to `(size["height"], size["width"])`.
209
+
210
+ Args:
211
+ image (`np.ndarray`):
212
+ Image to resize.
213
+ size (`Dict[str, int]`):
214
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
215
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
216
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
217
+ data_format (`ChannelDimension` or `str`, *optional*):
218
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
219
+ image is used. Can be one of:
220
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
221
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
222
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
223
+ input_data_format (`ChannelDimension` or `str`, *optional*):
224
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
225
+ from the input image. Can be one of:
226
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
227
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
228
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
229
+
230
+ Returns:
231
+ `np.ndarray`: The resized image.
232
+ """
233
+ size = get_size_dict(size)
234
+ if "height" not in size or "width" not in size:
235
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
236
+ output_size = (size["height"], size["width"])
237
+ return resize(
238
+ image,
239
+ size=output_size,
240
+ resample=resample,
241
+ data_format=data_format,
242
+ input_data_format=input_data_format,
243
+ **kwargs,
244
+ )
245
+
246
+ def _preprocess_step(
247
+ self,
248
+ images: ImageInput,
249
+ do_resize: Optional[bool] = None,
250
+ size: Dict[str, int] = None,
251
+ resample: PILImageResampling = None,
252
+ do_rescale: Optional[bool] = None,
253
+ rescale_factor: Optional[float] = None,
254
+ do_normalize: Optional[bool] = None,
255
+ image_mean: Optional[Union[float, List[float]]] = None,
256
+ image_std: Optional[Union[float, List[float]]] = None,
257
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
258
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
259
+ do_convert_rgb: Optional[bool] = None,
260
+ num_labels: Optional[int] = None,
261
+ **kwargs,
262
+ ):
263
+ """
264
+ Preprocess an image or batch of images.
265
+
266
+ Args:
267
+ images (`ImageInput`):
268
+ Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
269
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
270
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
271
+ Whether to resize the image.
272
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
273
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
274
+ resizing.
275
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
276
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has
277
+ an effect if `do_resize` is set to `True`.
278
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
279
+ Whether to rescale the image values between [0 - 1].
280
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
281
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
282
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
283
+ Whether to normalize the image.
284
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
285
+ Image mean to use if `do_normalize` is set to `True`.
286
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
287
+ Image standard deviation to use if `do_normalize` is set to `True`.
288
+ return_tensors (`str` or `TensorType`, *optional*):
289
+ The type of tensors to return. Can be one of:
290
+ - Unset: Return a list of `np.ndarray`.
291
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
292
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
293
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
294
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
295
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
296
+ The channel dimension format for the output image. Can be one of:
297
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
298
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
299
+ - Unset: Use the channel dimension format of the input image.
300
+ input_data_format (`ChannelDimension` or `str`, *optional*):
301
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
302
+ from the input image. Can be one of:
303
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
304
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
305
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
306
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
307
+ Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
308
+ to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
309
+ across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
310
+ num_labels: (`int`, *optional*):
311
+ Number of classes in the segmentation task (excluding the background). If specified, a palette will be
312
+ built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
313
+ channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
314
+ through as is if it is already in RGB format or being duplicated across the channel dimension.
315
+ """
316
+ do_resize = do_resize if do_resize is not None else self.do_resize
317
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
318
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
319
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
320
+ resample = resample if resample is not None else self.resample
321
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
322
+ image_mean = image_mean if image_mean is not None else self.image_mean
323
+ image_std = image_std if image_std is not None else self.image_std
324
+
325
+ size = size if size is not None else self.size
326
+ size_dict = get_size_dict(size)
327
+
328
+ # If segmentation map is passed we expect 2D images
329
+ images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3)
330
+
331
+ if not valid_images(images):
332
+ raise ValueError(
333
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
334
+ "torch.Tensor, tf.Tensor or jax.ndarray."
335
+ )
336
+
337
+ if do_resize and size is None:
338
+ raise ValueError("Size must be specified if do_resize is True.")
339
+
340
+ if do_rescale and rescale_factor is None:
341
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
342
+
343
+ if do_normalize and (image_mean is None or image_std is None):
344
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
345
+
346
+ # All transformations expect numpy arrays.
347
+ images = [to_numpy_array(image) for image in images]
348
+
349
+ if do_rescale and is_scaled_image(images[0]):
350
+ logger.warning_once(
351
+ "It looks like you are trying to rescale already rescaled images. If the input"
352
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
353
+ )
354
+
355
+ if input_data_format is None and not do_convert_rgb:
356
+ # We assume that all images have the same channel dimension format.
357
+ input_data_format = infer_channel_dimension_format(images[0])
358
+
359
+ if do_convert_rgb:
360
+ palette = self.get_palette(num_labels) if num_labels is not None else None
361
+ # Since this is the input for the next transformations, its format should be the same as the input_data_format
362
+ images = [
363
+ self.mask_to_rgb(image=image, palette=palette, data_format=ChannelDimension.FIRST) for image in images
364
+ ]
365
+ input_data_format = ChannelDimension.FIRST
366
+
367
+ if do_resize:
368
+ images = [
369
+ self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
370
+ for image in images
371
+ ]
372
+
373
+ if do_rescale:
374
+ images = [
375
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
376
+ for image in images
377
+ ]
378
+
379
+ if do_normalize:
380
+ images = [
381
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
382
+ for image in images
383
+ ]
384
+
385
+ images = [
386
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
387
+ ]
388
+
389
+ return images
390
+
391
+ def preprocess(
392
+ self,
393
+ images: Optional[ImageInput] = None,
394
+ prompt_images: Optional[ImageInput] = None,
395
+ prompt_masks: Optional[ImageInput] = None,
396
+ do_resize: Optional[bool] = None,
397
+ size: Dict[str, int] = None,
398
+ resample: PILImageResampling = None,
399
+ do_rescale: Optional[bool] = None,
400
+ rescale_factor: Optional[float] = None,
401
+ do_normalize: Optional[bool] = None,
402
+ image_mean: Optional[Union[float, List[float]]] = None,
403
+ image_std: Optional[Union[float, List[float]]] = None,
404
+ do_convert_rgb: Optional[bool] = None,
405
+ num_labels: Optional[int] = None,
406
+ return_tensors: Optional[Union[str, TensorType]] = None,
407
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
408
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
409
+ **kwargs,
410
+ ):
411
+ """
412
+ Preprocess an image or batch of images.
413
+
414
+ Args:
415
+ images (`ImageInput`):
416
+ Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
417
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
418
+ prompt_images (`ImageInput`):
419
+ Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
420
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
421
+ prompt_masks (`ImageInput`):
422
+ Prompt mask from the prompt image to _preprocess; it specifies the prompt_masks value in the preprocessed output.
423
+ Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of
424
+ RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps, specifying
425
+ `num_labels` is recommended to build a palette to map the prompt mask from a single channel to
426
+ a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the channel
427
+ dimension.
428
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
429
+ Whether to resize the image.
430
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
431
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
432
+ resizing.
433
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
434
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has
435
+ an effect if `do_resize` is set to `True`. Doesn't apply to prompt mask as it is resized using nearest.
436
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
437
+ Whether to rescale the image values between [0 - 1].
438
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
439
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
440
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
441
+ Whether to normalize the image.
442
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
443
+ Image mean to use if `do_normalize` is set to `True`.
444
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
445
+ Image standard deviation to use if `do_normalize` is set to `True`.
446
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
447
+ Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
448
+ to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
449
+ across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
450
+ num_labels: (`int`, *optional*):
451
+ Number of classes in the segmentation task (excluding the background). If specified, a palette will be
452
+ built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map
453
+ with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
454
+ through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated
455
+ across the channel dimension.
456
+ return_tensors (`str` or `TensorType`, *optional*):
457
+ The type of tensors to return. Can be one of:
458
+ - Unset: Return a list of `np.ndarray`.
459
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
460
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
461
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
462
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
463
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
464
+ The channel dimension format for the output image. Can be one of:
465
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
466
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
467
+ - Unset: Use the channel dimension format of the input image.
468
+ input_data_format (`ChannelDimension` or `str`, *optional*):
469
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
470
+ from the input image. Can be one of:
471
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
472
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
473
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
474
+ """
475
+ if all(v is None for v in [images, prompt_images, prompt_masks]):
476
+ raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.")
477
+
478
+ data = {}
479
+
480
+ if images is not None:
481
+ images = self._preprocess_step(
482
+ images,
483
+ is_mask=False,
484
+ do_resize=do_resize,
485
+ size=size,
486
+ resample=resample,
487
+ do_rescale=do_rescale,
488
+ rescale_factor=rescale_factor,
489
+ do_normalize=do_normalize,
490
+ image_mean=image_mean,
491
+ image_std=image_std,
492
+ do_convert_rgb=False,
493
+ data_format=data_format,
494
+ input_data_format=input_data_format,
495
+ **kwargs,
496
+ )
497
+
498
+ data["pixel_values"] = images
499
+
500
+ if prompt_images is not None:
501
+ prompt_images = self._preprocess_step(
502
+ prompt_images,
503
+ is_mask=False,
504
+ do_resize=do_resize,
505
+ size=size,
506
+ resample=resample,
507
+ do_rescale=do_rescale,
508
+ rescale_factor=rescale_factor,
509
+ do_normalize=do_normalize,
510
+ image_mean=image_mean,
511
+ image_std=image_std,
512
+ do_convert_rgb=False,
513
+ data_format=data_format,
514
+ input_data_format=input_data_format,
515
+ **kwargs,
516
+ )
517
+
518
+ data["prompt_pixel_values"] = prompt_images
519
+
520
+ if prompt_masks is not None:
521
+ prompt_masks = self._preprocess_step(
522
+ prompt_masks,
523
+ do_resize=do_resize,
524
+ size=size,
525
+ resample=PILImageResampling.NEAREST,
526
+ do_rescale=do_rescale,
527
+ rescale_factor=rescale_factor,
528
+ do_normalize=do_normalize,
529
+ image_mean=image_mean,
530
+ image_std=image_std,
531
+ do_convert_rgb=do_convert_rgb,
532
+ num_labels=num_labels,
533
+ data_format=data_format,
534
+ input_data_format=input_data_format,
535
+ **kwargs,
536
+ )
537
+
538
+ data["prompt_masks"] = prompt_masks
539
+
540
+ return BatchFeature(data=data, tensor_type=return_tensors)
541
+
542
+ def post_process_semantic_segmentation(
543
+ self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None, num_labels: Optional[int] = None
544
+ ):
545
+ """
546
+ Converts a [`SegGptImageSegmentationOutput`] into segmentation maps. Only supports
547
+ PyTorch.
548
+
549
+ Args:
550
+ outputs ([`SegGptImageSegmentationOutput`]):
551
+ Raw outputs of the model.
552
+ target_sizes (`List[Tuple[int, int]]`, *optional*):
553
+ List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
554
+ final size (height, width) of each prediction. If left to None, predictions will not be resized.
555
+ num_labels (`int`, *optional*):
556
+ Number of classes in the segmentation task (excluding the background). If specified, a palette will be
557
+ built, assuming that class_idx 0 is the background, to map prediction masks from RGB values to class
558
+ indices. This value should be the same used when preprocessing inputs.
559
+ Returns:
560
+ semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
561
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
562
+ specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
563
+ """
564
+ requires_backends(self, ["torch"])
565
+ # batch_size x num_channels x 2*height x width
566
+ masks = outputs.pred_masks
567
+
568
+ # Predicted mask and prompt are concatenated in the height dimension
569
+ # batch_size x num_channels x height x width
570
+ masks = masks[:, :, masks.shape[2] // 2 :, :]
571
+
572
+ # To unnormalize we need to permute to channel last
573
+ # batch_size x height x width x num_channels
574
+ std = torch.tensor(self.image_std).to(masks.device)
575
+ mean = torch.tensor(self.image_mean).to(masks.device)
576
+
577
+ masks = masks.permute(0, 2, 3, 1) * std + mean
578
+
579
+ # batch_size x num_channels x height x width
580
+ masks = masks.permute(0, 3, 1, 2)
581
+
582
+ # Clip to match with palette if specified
583
+ masks = torch.clip(masks * 255, 0, 255)
584
+
585
+ semantic_segmentation = []
586
+ palette_tensor = None
587
+ palette = self.get_palette(num_labels) if num_labels is not None else None
588
+ if palette is not None:
589
+ palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float)
590
+ _, num_channels, _, _ = masks.shape
591
+ palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels)
592
+
593
+ for idx, mask in enumerate(masks):
594
+ if target_sizes is not None:
595
+ mask = torch.nn.functional.interpolate(
596
+ mask.unsqueeze(0),
597
+ size=target_sizes[idx],
598
+ mode="nearest",
599
+ )[0]
600
+
601
+ if num_labels is not None:
602
+ channels, height, width = mask.shape
603
+ dist = mask.permute(1, 2, 0).view(height, width, 1, channels)
604
+ dist = dist - palette_tensor
605
+ dist = torch.pow(dist, 2)
606
+ dist = torch.sum(dist, dim=-1)
607
+ pred = dist.argmin(dim=-1)
608
+
609
+ else:
610
+ # If no palette is specified SegGpt will try to paint using the mask class idx as RGB
611
+ pred = mask.mean(dim=0).int()
612
+
613
+ semantic_segmentation.append(pred)
614
+
615
+ return semantic_segmentation
616
+
617
+
618
+ __all__ = ["SegGptImageProcessor"]
docs/transformers/build/lib/transformers/models/seggpt/modeling_seggpt.py ADDED
@@ -0,0 +1,1031 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch SegGpt model."""
16
+
17
+ import collections.abc
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import functional as F
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ torch_int,
35
+ )
36
+ from .configuration_seggpt import SegGptConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ # General docstring
42
+ _CONFIG_FOR_DOC = "SegGptConfig"
43
+
44
+ # Base docstring
45
+ _CHECKPOINT_FOR_DOC = "BAAI/seggpt-vit-large"
46
+ _EXPECTED_OUTPUT_SHAPE = [3, 896, 448]
47
+
48
+
49
+ @dataclass
50
+ class SegGptEncoderOutput(ModelOutput):
51
+ """
52
+ Output type of [`SegGptEncoderOutput`].
53
+ Args:
54
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
55
+ Sequence of hidden-states at the output of the last layer of the model.
56
+ hidden_states (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
57
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
58
+ of shape `(batch_size, patch_height, patch_width, hidden_size)`.
59
+ attentions (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
60
+ Tuple of *torch.FloatTensor* (one for each layer) of shape
61
+ `(batch_size, num_heads, seq_len, seq_len)`.
62
+ intermediate_hidden_states (`Tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set):
63
+ Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`.
64
+ Each element in the Tuple corresponds to the output of the layer specified in `config.intermediate_hidden_state_indices`.
65
+ Additionally, each feature passes through a LayerNorm.
66
+ """
67
+
68
+ last_hidden_state: torch.FloatTensor
69
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
70
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
71
+ intermediate_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
72
+
73
+
74
+ @dataclass
75
+ class SegGptImageSegmentationOutput(ModelOutput):
76
+ """
77
+ Output type of [`SegGptImageSegmentationOutput`].
78
+
79
+ Args:
80
+ loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
81
+ The loss value.
82
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
83
+ The predicted masks.
84
+ hidden_states (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
85
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
86
+ of shape `(batch_size, patch_height, patch_width, hidden_size)`.
87
+ attentions (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
88
+ Tuple of `torch.FloatTensor` (one for each layer) of shape
89
+ `(batch_size, num_heads, seq_len, seq_len)`.
90
+ """
91
+
92
+ loss: Optional[torch.FloatTensor] = None
93
+ pred_masks: Optional[torch.FloatTensor] = None
94
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
95
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
96
+
97
+
98
+ # Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->SegGpt
99
+ class SegGptPatchEmbeddings(nn.Module):
100
+ """
101
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
102
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
103
+ Transformer.
104
+ """
105
+
106
+ def __init__(self, config):
107
+ super().__init__()
108
+ image_size, patch_size = config.image_size, config.patch_size
109
+ num_channels, hidden_size = config.num_channels, config.hidden_size
110
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
111
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
112
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
113
+ self.image_size = image_size
114
+ self.patch_size = patch_size
115
+ self.num_channels = num_channels
116
+ self.num_patches = num_patches
117
+
118
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
119
+
120
+ def forward(self, pixel_values):
121
+ batch_size, num_channels, height, width = pixel_values.shape
122
+ if num_channels != self.num_channels:
123
+ raise ValueError(
124
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
125
+ )
126
+ if height != self.image_size[0] or width != self.image_size[1]:
127
+ raise ValueError(
128
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
129
+ )
130
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
131
+ return embeddings
132
+
133
+
134
+ class SegGptEmbeddings(nn.Module):
135
+ """
136
+ Construct the embeddings from patch, position embeddings for input and prompt.
137
+ """
138
+
139
+ def __init__(self, config: SegGptConfig) -> None:
140
+ super().__init__()
141
+
142
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
143
+ self.segment_token_input = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
144
+ self.segment_token_prompt = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
145
+ # token for seg types
146
+ self.type_token_semantic = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
147
+ self.type_token_instance = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
148
+
149
+ self.patch_embeddings = SegGptPatchEmbeddings(config)
150
+
151
+ num_positions = (config.pretrain_image_size // config.patch_size) ** 2 + 1
152
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_positions, config.hidden_size))
153
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
154
+
155
+ def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor:
156
+ patch_pos_embed = self.position_embeddings[:, 1:]
157
+ num_patches = patch_pos_embed.shape[1]
158
+ pretrain_patch_size = torch_int(num_patches**0.5)
159
+
160
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
161
+ if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width:
162
+ patch_pos_embed = F.interpolate(
163
+ patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2),
164
+ size=(height, width),
165
+ mode="bicubic",
166
+ align_corners=False,
167
+ )
168
+
169
+ return patch_pos_embed.permute(0, 2, 3, 1)
170
+ else:
171
+ return patch_pos_embed.reshape(1, height, width, -1)
172
+
173
+ def forward(
174
+ self,
175
+ pixel_values: torch.Tensor,
176
+ prompt_pixel_values: torch.Tensor,
177
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
178
+ embedding_type: Optional[str] = None,
179
+ ) -> torch.Tensor:
180
+ input_embeddings = self.patch_embeddings(pixel_values)
181
+ prompt_embeddings = self.patch_embeddings(prompt_pixel_values)
182
+
183
+ batch_size, patch_height, patch_width, _ = input_embeddings.shape
184
+
185
+ mask_token = self.mask_token.expand(batch_size, patch_height, patch_width, -1)
186
+ # replace the masked visual tokens by mask_token
187
+ w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, patch_height, patch_width, 1)
188
+ prompt_embeddings = prompt_embeddings * (1 - w) + mask_token * w
189
+
190
+ embedding_type = embedding_type if embedding_type is not None else "instance"
191
+
192
+ # add positional encoding to each token
193
+ pos_embed = self.interpolate_pos_encoding(patch_height, patch_width)
194
+
195
+ # add segment token
196
+ input_embeddings = input_embeddings + self.segment_token_input
197
+ prompt_embeddings = prompt_embeddings + self.segment_token_prompt
198
+
199
+ # add position embedding skipping CLS
200
+ input_embeddings = input_embeddings + pos_embed
201
+ prompt_embeddings = prompt_embeddings + pos_embed
202
+
203
+ # add type embedding to each token
204
+ if embedding_type == "semantic":
205
+ type_embedding = self.type_token_semantic
206
+ elif embedding_type == "instance":
207
+ type_embedding = self.type_token_instance
208
+ else:
209
+ raise ValueError(f"Embedding type should be either 'semantic' or 'instance', but got {embedding_type}")
210
+
211
+ input_embeddings = input_embeddings + type_embedding
212
+ prompt_embeddings = prompt_embeddings + type_embedding
213
+
214
+ embeddings = torch.cat((input_embeddings, prompt_embeddings), dim=0)
215
+
216
+ return embeddings
217
+
218
+
219
+ class SegGptAttention(nn.Module):
220
+ """Multi-head Attention block with relative position embeddings."""
221
+
222
+ def __init__(self, config):
223
+ super().__init__()
224
+ image_size, patch_size = config.image_size, config.patch_size
225
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
226
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
227
+
228
+ input_size = (image_size[0] // config.patch_size, image_size[1] // config.patch_size)
229
+ head_dim = config.hidden_size // config.num_attention_heads
230
+
231
+ self.num_attention_heads = config.num_attention_heads
232
+ self.scale = head_dim**-0.5
233
+
234
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
235
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
236
+
237
+ self.use_relative_position_embeddings = config.use_relative_position_embeddings
238
+ if self.use_relative_position_embeddings:
239
+ if input_size is None:
240
+ raise ValueError("Input size must be provided if using relative positional encoding.")
241
+
242
+ # initialize relative positional embeddings
243
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
244
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
245
+
246
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
247
+ """
248
+ Get relative positional embeddings according to the relative positions of
249
+ query and key sizes.
250
+
251
+ Args:
252
+ q_size (int):
253
+ size of the query.
254
+ k_size (int):
255
+ size of key k.
256
+ rel_pos (`torch.Tensor`):
257
+ relative position embeddings (L, channel).
258
+
259
+ Returns:
260
+ Extracted positional embeddings according to relative positions.
261
+ """
262
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
263
+ # Interpolate rel pos.
264
+ rel_pos_resized = F.interpolate(
265
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
266
+ size=max_rel_dist,
267
+ mode="linear",
268
+ )
269
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
270
+
271
+ # Scale the coords with short length if shapes for q and k are different.
272
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
273
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
274
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
275
+
276
+ return rel_pos_resized[relative_coords.long()]
277
+
278
+ def add_decomposed_rel_pos(
279
+ self,
280
+ attn: torch.Tensor,
281
+ query: torch.Tensor,
282
+ rel_pos_h: torch.Tensor,
283
+ rel_pos_w: torch.Tensor,
284
+ q_size: Tuple[int, int],
285
+ k_size: Tuple[int, int],
286
+ ) -> torch.Tensor:
287
+ """
288
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
289
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
290
+
291
+ Args:
292
+ attn (`torch.Tensor`):
293
+ attention map.
294
+ query (`torch.Tensor`):
295
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
296
+ rel_pos_h (`torch.Tensor`):
297
+ relative position embeddings (Lh, channel) for height axis.
298
+ rel_pos_w (`torch.Tensor`):
299
+ relative position embeddings (Lw, channel) for width axis.
300
+ q_size (tuple):
301
+ spatial sequence size of query q with (query_height, query_width).
302
+ k_size (tuple):
303
+ spatial sequence size of key k with (key_height, key_width).
304
+
305
+ Returns:
306
+ attn (`torch.Tensor`):
307
+ attention map with added relative positional embeddings.
308
+ """
309
+ query_height, query_width = q_size
310
+ key_height, key_width = k_size
311
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
312
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
313
+
314
+ batch_size, _, dim = query.shape
315
+ reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
316
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
317
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
318
+ attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
319
+ attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
320
+ attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
321
+ return attn
322
+
323
+ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
324
+ batch_size, height, width, _ = hidden_states.shape
325
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
326
+ qkv = (
327
+ self.qkv(hidden_states)
328
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
329
+ .permute(2, 0, 3, 1, 4)
330
+ )
331
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
332
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
333
+
334
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
335
+
336
+ if self.use_relative_position_embeddings:
337
+ attn_weights = self.add_decomposed_rel_pos(
338
+ attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
339
+ )
340
+
341
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
342
+
343
+ if output_attentions:
344
+ # this operation is a bit awkward, but it's required to
345
+ # make sure that attn_weights keeps its gradient.
346
+ # In order to do so, attn_weights have to reshaped
347
+ # twice and have to be reused in the following
348
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_attention_heads, height * width, -1)
349
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_attention_heads, height * width, -1)
350
+ else:
351
+ attn_weights_reshaped = None
352
+
353
+ attn_output = (attn_weights @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
354
+ attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
355
+
356
+ attn_output = self.proj(attn_output)
357
+
358
+ return (attn_output, attn_weights_reshaped)
359
+
360
+
361
+ # Copied from transformers.models.sam.modeling_sam.SamMLPBlock with SamMLPBlock->SegGptMlp
362
+ class SegGptMlp(nn.Module):
363
+ def __init__(self, config):
364
+ super().__init__()
365
+ self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
366
+ self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
367
+ self.act = ACT2FN[config.hidden_act]
368
+
369
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
370
+ hidden_states = self.lin1(hidden_states)
371
+ hidden_states = self.act(hidden_states)
372
+ hidden_states = self.lin2(hidden_states)
373
+ return hidden_states
374
+
375
+
376
+ # Copied from transformers.models.beit.modeling_beit.drop_path
377
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
378
+ """
379
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
380
+
381
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
382
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
383
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
384
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
385
+ argument.
386
+ """
387
+ if drop_prob == 0.0 or not training:
388
+ return input
389
+ keep_prob = 1 - drop_prob
390
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
391
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
392
+ random_tensor.floor_() # binarize
393
+ output = input.div(keep_prob) * random_tensor
394
+ return output
395
+
396
+
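# Minimal illustration of the stochastic-depth behaviour of drop_path above: it is the
# identity at inference time, while in training each sample is either zeroed (with
# probability drop_prob) or rescaled by 1 / keep_prob so the expected value is preserved.
import torch

x = torch.ones(4, 3)
assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)  # identity in eval mode

torch.manual_seed(0)
out = drop_path(x, drop_prob=0.5, training=True)
print(out)  # each row is either all zeros or all 2.0 (= 1 / keep_prob)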
397
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->SegGpt
398
+ class SegGptDropPath(nn.Module):
399
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
400
+
401
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
402
+ super().__init__()
403
+ self.drop_prob = drop_prob
404
+
405
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
406
+ return drop_path(hidden_states, self.drop_prob, self.training)
407
+
408
+ def extra_repr(self) -> str:
409
+ return "p={}".format(self.drop_prob)
410
+
411
+
412
+ class SegGptLayer(nn.Module):
413
+ def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None:
414
+ super().__init__()
415
+ self.attention = SegGptAttention(config)
416
+ self.mlp = SegGptMlp(config)
417
+ self.drop_path = SegGptDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
418
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
419
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ ensemble_cond: int,
425
+ feature_ensemble: bool = False,
426
+ output_attentions: bool = False,
427
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
428
+ self_attention_outputs = self.attention(
429
+ self.layernorm_before(hidden_states), # in SegGpt, layernorm is applied before self-attention
430
+ output_attentions=output_attentions,
431
+ )
432
+ attention_output = self_attention_outputs[0]
433
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
434
+
435
+ if feature_ensemble and attention_output.shape[0] // 2 >= ensemble_cond:
436
+ prompt, inputs = attention_output.split(attention_output.shape[1] // 2, dim=1)
437
+ if ensemble_cond == 2:
438
+ num_prompts = attention_output.shape[0] // 2
439
+ inputs = inputs.reshape(2, num_prompts, -1)
440
+ inputs = inputs.mean(dim=1, keepdim=True).expand_as(inputs)
441
+ inputs = inputs.reshape(*prompt.shape)
442
+ else:
443
+ inputs = inputs.mean(dim=0, keepdim=True).expand_as(inputs)
444
+ attention_output = torch.cat([prompt, inputs], dim=1)
445
+
446
+ # first residual connection
447
+ hidden_states = self.drop_path(attention_output) + hidden_states
448
+ residual = hidden_states
449
+
450
+ hidden_states = self.layernorm_after(hidden_states)
451
+ hidden_states = self.mlp(hidden_states)
452
+ hidden_states = residual + self.drop_path(hidden_states)
453
+
454
+ outputs = (hidden_states,) + outputs
455
+
456
+ return outputs
457
+
458
+
459
+ class SegGptEncoder(nn.Module):
460
+ def __init__(self, config: SegGptConfig) -> None:
461
+ super().__init__()
462
+ self.config = config
463
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
464
+ self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)])
465
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
466
+ self.gradient_checkpointing = False
467
+
468
+ def forward(
469
+ self,
470
+ hidden_states: torch.Tensor,
471
+ feature_ensemble: bool = False,
472
+ output_attentions: bool = False,
473
+ output_hidden_states: bool = False,
474
+ return_dict: bool = True,
475
+ ) -> Union[tuple, SegGptEncoderOutput]:
476
+ all_hidden_states = () if output_hidden_states else None
477
+ all_self_attentions = () if output_attentions else None
478
+ intermediate_hidden_states = []
479
+
480
+ for i, layer_module in enumerate(self.layers):
481
+ if output_hidden_states:
482
+ all_hidden_states = all_hidden_states + (hidden_states,)
483
+
484
+ # Condition to check if we have the appropriate number of prompts to ensemble
485
+ ensemble_cond = 2 if self.config.merge_index > i else 1
486
+
487
+ if self.gradient_checkpointing and self.training:
488
+ layer_outputs = self._gradient_checkpointing_func(
489
+ layer_module.__call__,
490
+ hidden_states,
491
+ ensemble_cond,
492
+ feature_ensemble,
493
+ output_attentions,
494
+ )
495
+ else:
496
+ layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions)
497
+
498
+ hidden_states = layer_outputs[0]
499
+
500
+ if i == self.config.merge_index:
501
+ hidden_states = (
502
+ hidden_states[: hidden_states.shape[0] // 2] + hidden_states[hidden_states.shape[0] // 2 :]
503
+ ) * 0.5
504
+
505
+ if i in self.config.intermediate_hidden_state_indices:
506
+ intermediate_hidden_states.append(self.layernorm(hidden_states))
507
+
508
+ if output_attentions:
509
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
510
+
511
+ if output_hidden_states:
512
+ all_hidden_states = all_hidden_states + (hidden_states,)
513
+
514
+ if not return_dict:
515
+ return tuple(
516
+ v
517
+ for v in [hidden_states, all_hidden_states, all_self_attentions, intermediate_hidden_states]
518
+ if v is not None
519
+ )
520
+ return SegGptEncoderOutput(
521
+ last_hidden_state=hidden_states,
522
+ hidden_states=all_hidden_states,
523
+ attentions=all_self_attentions,
524
+ intermediate_hidden_states=intermediate_hidden_states,
525
+ )
526
+
527
+
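# Sketch of the batch-merging step applied at `config.merge_index` in SegGptEncoder above:
# the two halves of the batch dimension are averaged element-wise into a single half-sized
# batch. Sizes below are illustrative only.
import torch

hidden_states = torch.arange(4 * 2 * 3, dtype=torch.float32).reshape(4, 2, 3)
merged = (hidden_states[: hidden_states.shape[0] // 2] + hidden_states[hidden_states.shape[0] // 2 :]) * 0.5
print(merged.shape)  # torch.Size([2, 2, 3])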
528
+ # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->SegGpt
529
+ class SegGptLayerNorm(nn.Module):
530
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
531
+ The `data_format` argument specifies the ordering of the dimensions in the inputs: channels_last corresponds to
+ inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape
+ (batch_size, channels, height, width).
533
+ """
534
+
535
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
536
+ super().__init__()
537
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
538
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
539
+ self.eps = eps
540
+ self.data_format = data_format
541
+ if self.data_format not in ["channels_last", "channels_first"]:
542
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
543
+ self.normalized_shape = (normalized_shape,)
544
+
545
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
546
+ if self.data_format == "channels_last":
547
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
548
+ elif self.data_format == "channels_first":
549
+ input_dtype = x.dtype
550
+ x = x.float()
551
+ u = x.mean(1, keepdim=True)
552
+ s = (x - u).pow(2).mean(1, keepdim=True)
553
+ x = (x - u) / torch.sqrt(s + self.eps)
554
+ x = x.to(dtype=input_dtype)
555
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
556
+ return x
557
+
558
+
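# Sketch: the "channels_first" branch of SegGptLayerNorm above matches (up to float
# precision) a standard layer_norm over the channel axis after permuting to channels_last.
import torch

ln = SegGptLayerNorm(normalized_shape=8, data_format="channels_first")
x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)
reference = torch.nn.functional.layer_norm(
    x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps
).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), reference, atol=1e-5))  # True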
559
+ class SegGptDecoderHead(nn.Module):
560
+ def __init__(self, config):
561
+ super().__init__()
562
+ self.conv = nn.Conv2d(
563
+ config.decoder_hidden_size,
564
+ config.decoder_hidden_size,
565
+ kernel_size=3,
566
+ padding=1,
567
+ )
568
+ self.layernorm = SegGptLayerNorm(
569
+ normalized_shape=config.decoder_hidden_size, eps=config.layer_norm_eps, data_format="channels_first"
570
+ )
571
+ self.act_fct = ACT2FN[config.hidden_act]
572
+ self.head = nn.Conv2d(config.decoder_hidden_size, 3, kernel_size=1, bias=True) # decoder to patch
573
+
574
+ def forward(self, hidden_states: torch.FloatTensor):
575
+ hidden_states = self.conv(hidden_states)
576
+ hidden_states = self.layernorm(hidden_states)
577
+ hidden_states = self.act_fct(hidden_states)
578
+ hidden_states = self.head(hidden_states)
579
+
580
+ return hidden_states
581
+
582
+
583
+ class SegGptDecoder(nn.Module):
584
+ def __init__(self, config):
585
+ super().__init__()
586
+ self.decoder_embed = nn.Linear(
587
+ config.hidden_size * len(config.intermediate_hidden_state_indices),
588
+ config.patch_size**2 * config.decoder_hidden_size,
589
+ bias=True,
590
+ )
591
+ self.decoder_pred = SegGptDecoderHead(config)
592
+ self.patch_size = config.patch_size
593
+ self.decoder_hidden_size = config.decoder_hidden_size
594
+ self.config = config
595
+
596
+ def _reshape_hidden_states(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
597
+ batch_size, patch_height, patch_width, _ = hidden_states.shape
598
+ hidden_states = hidden_states.reshape(
599
+ batch_size, patch_height, patch_width, self.patch_size, self.patch_size, self.decoder_hidden_size
600
+ )
601
+ hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
602
+ hidden_states = hidden_states.reshape(
603
+ shape=(batch_size, -1, patch_height * self.patch_size, patch_width * self.patch_size)
604
+ )
605
+
606
+ return hidden_states
607
+
608
+ def forward(self, hidden_states: torch.FloatTensor):
609
+ hidden_states = self.decoder_embed(hidden_states)
610
+ hidden_states = self._reshape_hidden_states(hidden_states)
611
+ hidden_states = self.decoder_pred(hidden_states)
612
+
613
+ return hidden_states
614
+
615
+
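# Sketch of SegGptDecoder._reshape_hidden_states above with concrete (made-up) sizes: each
# spatial token's feature vector is unfolded into a patch_size x patch_size block of pixels.
import torch

batch, patch_h, patch_w, patch_size, decoder_dim = 1, 6, 3, 4, 8
x = torch.randn(batch, patch_h, patch_w, patch_size**2 * decoder_dim)
x = x.reshape(batch, patch_h, patch_w, patch_size, patch_size, decoder_dim)
x = x.permute(0, 5, 1, 3, 2, 4)
x = x.reshape(batch, -1, patch_h * patch_size, patch_w * patch_size)
print(x.shape)  # torch.Size([1, 8, 24, 12]), i.e. (batch, decoder_hidden_size, height, width)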
616
+ class SegGptPreTrainedModel(PreTrainedModel):
617
+ """
618
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
619
+ models.
620
+ """
621
+
622
+ config_class = SegGptConfig
623
+ base_model_prefix = "model"
624
+ main_input_name = "pixel_values"
625
+ supports_gradient_checkpointing = True
626
+ _no_split_modules = ["SegGptEmbeddings", "SegGptLayer"]
627
+
628
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
629
+ """Initialize the weights"""
630
+ std = self.config.initializer_range
631
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
632
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
633
+ # `trunc_normal_cpu` not implemented in `half` issues
634
+ module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=std).to(
635
+ module.weight.dtype
636
+ )
637
+ if module.bias is not None:
638
+ module.bias.data.zero_()
639
+ elif isinstance(module, nn.LayerNorm):
640
+ module.bias.data.zero_()
641
+ module.weight.data.fill_(1.0)
642
+ elif isinstance(module, SegGptAttention):
643
+ module.rel_pos_h.data = nn.init.trunc_normal_(
644
+ module.rel_pos_h.data.to(torch.float32),
645
+ mean=0.0,
646
+ std=std,
647
+ ).to(module.rel_pos_h.dtype)
648
+
649
+ module.rel_pos_w.data = nn.init.trunc_normal_(
650
+ module.rel_pos_w.data.to(torch.float32),
651
+ mean=0.0,
652
+ std=std,
653
+ ).to(module.rel_pos_w.dtype)
654
+
655
+ elif isinstance(module, SegGptEmbeddings):
656
+ module.position_embeddings.data = nn.init.trunc_normal_(
657
+ module.position_embeddings.data.to(torch.float32),
658
+ mean=0.0,
659
+ std=std,
660
+ ).to(module.position_embeddings.dtype)
661
+
662
+ torch.nn.init.normal_(module.mask_token, std=std)
663
+ torch.nn.init.normal_(module.segment_token_input, std=std)
664
+ torch.nn.init.normal_(module.segment_token_prompt, std=std)
665
+ torch.nn.init.normal_(module.type_token_semantic, std=std)
666
+ torch.nn.init.normal_(module.type_token_instance, std=std)
667
+
668
+
669
+ SEGGPT_START_DOCSTRING = r"""
670
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
671
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
672
+ behavior.
673
+
674
+ Parameters:
675
+ config ([`SegGptConfig`]): Model configuration class with all the parameters of the model.
676
+ Initializing with a config file does not load the weights associated with the model, only the
677
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
678
+ """
679
+
680
+ SEGGPT_INPUTS_DOCSTRING = r"""
681
+ Args:
682
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
683
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`]
684
+ for details.
685
+
686
+ prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
687
+ Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
688
+ [`SegGptImageProcessor.__call__`] for details.
689
+
690
+ prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
691
+ Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
692
+ details.
693
+
694
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
695
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
696
+
697
+ feature_ensemble (`bool`, *optional*):
698
+ Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
699
+ if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
700
+ be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
701
+
702
+ embedding_type (`str`, *optional*):
703
+ Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
704
+ instance or semantic.
705
+
706
+ output_attentions (`bool`, *optional*):
707
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
708
+ tensors for more detail.
709
+ output_hidden_states (`bool`, *optional*):
710
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
711
+ more detail.
712
+ return_dict (`bool`, *optional*):
713
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
714
+ """
715
+
716
+
717
+ @add_start_docstrings(
718
+ "The bare SegGpt Model transformer outputting raw hidden-states without any specific head on top.",
719
+ SEGGPT_START_DOCSTRING,
720
+ )
721
+ class SegGptModel(SegGptPreTrainedModel):
722
+ def __init__(self, config: SegGptConfig):
723
+ super().__init__(config)
724
+ self.config = config
725
+
726
+ self.embeddings = SegGptEmbeddings(config)
727
+ self.encoder = SegGptEncoder(config)
728
+
729
+ # Initialize weights and apply final processing
730
+ self.post_init()
731
+
732
+ def get_input_embeddings(self) -> SegGptPatchEmbeddings:
733
+ return self.embeddings.patch_embeddings
734
+
735
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
736
+ """
737
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
738
+ class PreTrainedModel
739
+ """
740
+ for layer, heads in heads_to_prune.items():
741
+ self.encoder.layer[layer].attention.prune_heads(heads)
742
+
743
+ @add_start_docstrings_to_model_forward(SEGGPT_INPUTS_DOCSTRING)
744
+ @replace_return_docstrings(output_type=SegGptEncoderOutput, config_class=_CONFIG_FOR_DOC)
745
+ def forward(
746
+ self,
747
+ pixel_values: torch.Tensor,
748
+ prompt_pixel_values: torch.Tensor,
749
+ prompt_masks: torch.Tensor,
750
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
751
+ feature_ensemble: Optional[bool] = None,
752
+ embedding_type: Optional[str] = None,
753
+ labels: Optional[torch.FloatTensor] = None,
754
+ output_attentions: Optional[bool] = None,
755
+ output_hidden_states: Optional[bool] = None,
756
+ return_dict: Optional[bool] = None,
757
+ ) -> Union[Tuple, SegGptEncoderOutput]:
758
+ r"""
759
+ labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
760
+ Ground truth mask for input images.
761
+
762
+ Returns:
763
+
764
+ Examples:
765
+
766
+ ```python
767
+ >>> from transformers import SegGptImageProcessor, SegGptModel
768
+ >>> from PIL import Image
769
+ >>> import requests
770
+
771
+ >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
772
+ >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
773
+ >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
774
+
775
+ >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
776
+ >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
777
+ >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
778
+
779
+ >>> checkpoint = "BAAI/seggpt-vit-large"
780
+ >>> model = SegGptModel.from_pretrained(checkpoint)
781
+ >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
782
+
783
+ >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
784
+
785
+ >>> outputs = model(**inputs)
786
+ >>> list(outputs.last_hidden_state.shape)
787
+ [1, 56, 28, 1024]
788
+ ```
789
+ """
790
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
791
+ output_hidden_states = (
792
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
793
+ )
794
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
795
+ feature_ensemble = feature_ensemble if feature_ensemble is not None else False
796
+
797
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
798
+ pixel_values = pixel_values.to(expected_dtype)
799
+ prompt_pixel_values = prompt_pixel_values.to(expected_dtype)
800
+
801
+ # Prepare inputs
802
+ pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
803
+ prompt_pixel_values = (
804
+ torch.cat((prompt_masks, prompt_masks), dim=2)
805
+ if labels is None
806
+ else torch.cat((prompt_masks, labels), dim=2)
807
+ )
808
+
809
+ if bool_masked_pos is None and labels is not None:
810
+ logger.warning_once(
811
+ "Labels were provided, but bool_masked_pos were not. It will be set to default value. If you're training the model, make sure to provide a bool_masked_pos."
812
+ )
813
+
814
+ # We concatenate along the height axis so SegGPT can handle the pair as a single image, hence we need to mask
+ # the portion of the mask-prompt pixels that is destined for the prediction, as they don't add any information.
+ # This is only the case for inference. In training, the concatenation of prompt mask and label is masked
+ # and reconstructed together (In-Context Painting).
818
+ if bool_masked_pos is None:
819
+ num_patches = self.embeddings.patch_embeddings.num_patches
820
+ bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
821
+ bool_masked_pos_ones = torch.ones(
822
+ num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
823
+ )
824
+ bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
825
+ bool_masked_pos = bool_masked_pos.unsqueeze(0)
826
+
827
+ embedding_output = self.embeddings(
828
+ pixel_values, prompt_pixel_values, embedding_type=embedding_type, bool_masked_pos=bool_masked_pos
829
+ )
830
+
831
+ encoder_outputs = self.encoder(
832
+ embedding_output,
833
+ feature_ensemble=feature_ensemble,
834
+ output_attentions=output_attentions,
835
+ output_hidden_states=output_hidden_states,
836
+ return_dict=return_dict,
837
+ )
838
+
839
+ return encoder_outputs
840
+
841
+
842
+ def patchify(tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
843
+ batch_size, num_channels, height, width = tensor.shape
844
+ patch_height = height // patch_size
845
+ patch_width = width // patch_size
846
+
847
+ tensor = tensor.reshape(shape=(batch_size, num_channels, patch_height, patch_size, patch_width, patch_size))
848
+ tensor = tensor.permute(0, 2, 4, 3, 5, 1)
849
+ tensor = tensor.reshape(shape=(batch_size, patch_height * patch_width, patch_size**2 * 3))
850
+
851
+ return tensor
852
+
853
+
854
+ def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
855
+ batch_size = tensor.shape[0]
856
+ patch_size = int((tensor.shape[-1] / 3) ** 0.5)
857
+ if patch_height * patch_width != tensor.shape[1]:
858
+ raise ValueError(
859
+ f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
860
+ )
861
+
862
+ tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
863
+ tensor = tensor.permute(0, 5, 1, 3, 2, 4)
864
+ tensor = tensor.reshape(shape=(batch_size, 3, patch_height * patch_size, patch_width * patch_size))
865
+
866
+ return tensor
867
+
868
+
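# Quick sanity sketch: unpatchify inverts patchify for 3-channel tensors whose height and
# width are multiples of patch_size (the sizes below are arbitrary).
import torch

images = torch.randn(2, 3, 32, 48)
patches = patchify(images, patch_size=16)             # (2, 2 * 3, 16 * 16 * 3) = (2, 6, 768)
restored = unpatchify(patches, patch_height=2, patch_width=3)
print(patches.shape, torch.equal(images, restored))   # torch.Size([2, 6, 768]) True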
869
+ class SegGptLoss(nn.Module):
870
+ def __init__(self, config):
871
+ super().__init__()
872
+ self.beta = config.beta
873
+ self.patch_size = config.patch_size
874
+
875
+ def forward(
876
+ self,
877
+ prompt_masks: torch.FloatTensor,
878
+ pred_masks: torch.FloatTensor,
879
+ labels: torch.FloatTensor,
880
+ bool_masked_pos: torch.BoolTensor,
881
+ ):
882
+ """Computes the L1 loss between the predicted masks and the ground truth masks.
883
+
884
+ Args:
885
+ prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
886
+ Pixel values from mask prompt.
887
+
888
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
889
+ Predicted masks.
890
+
891
+ labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
892
+ Ground truth mask for input images.
893
+
894
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
895
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
896
+
897
+ Returns:
898
+ `torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
899
+ """
900
+ ground_truth = torch.cat((prompt_masks, labels), dim=2)
901
+
902
+ mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
903
+ mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)
904
+
905
+ loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
906
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
907
+
908
+ return loss
909
+
910
+
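# Hedged sketch of the masking in SegGptLoss above: with the default inference-style mask
# (zeros for the prompt half, ones for the prediction half) only the lower half of the
# concatenated image contributes to the loss. `SimpleNamespace` stands in for a SegGptConfig.
import torch
from types import SimpleNamespace

loss_fn = SegGptLoss(SimpleNamespace(beta=0.01, patch_size=4))
prompt_masks = torch.zeros(1, 3, 8, 8)
labels = torch.ones(1, 3, 8, 8)
pred_masks = torch.zeros(1, 3, 16, 8)  # prediction over cat(prompt_masks, labels) along height
num_patches = (16 // 4) * (8 // 4)
bool_masked_pos = torch.cat([torch.zeros(num_patches // 2), torch.ones(num_patches // 2)]).bool()[None]
print(loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos))  # ~0.995: smooth-L1 on the masked half only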
911
+ @add_start_docstrings(
912
+ "SegGpt model with a decoder on top for one-shot image segmentation.",
913
+ SEGGPT_START_DOCSTRING,
914
+ )
915
+ class SegGptForImageSegmentation(SegGptPreTrainedModel):
916
+ def __init__(self, config: SegGptConfig):
917
+ super().__init__(config)
918
+ self.config = config
919
+
920
+ self.model = SegGptModel(config)
921
+ self.decoder = SegGptDecoder(config)
922
+
923
+ # Initialize weights and apply final processing
924
+ self.post_init()
925
+
926
+ @add_start_docstrings_to_model_forward(SEGGPT_INPUTS_DOCSTRING)
927
+ @replace_return_docstrings(output_type=SegGptImageSegmentationOutput, config_class=_CONFIG_FOR_DOC)
928
+ def forward(
929
+ self,
930
+ pixel_values: torch.Tensor,
931
+ prompt_pixel_values: torch.Tensor,
932
+ prompt_masks: torch.Tensor,
933
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
934
+ feature_ensemble: Optional[bool] = None,
935
+ embedding_type: Optional[str] = None,
936
+ labels: Optional[torch.FloatTensor] = None,
937
+ output_attentions: Optional[bool] = None,
938
+ output_hidden_states: Optional[bool] = None,
939
+ return_dict: Optional[bool] = None,
940
+ ) -> Union[Tuple, SegGptImageSegmentationOutput]:
941
+ r"""
942
+ labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
943
+ Ground truth mask for input images.
944
+
945
+ Returns:
946
+
947
+ Examples:
948
+
949
+ ```python
950
+ >>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation
951
+ >>> from PIL import Image
952
+ >>> import requests
953
+
954
+ >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
955
+ >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
956
+ >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
957
+
958
+ >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
959
+ >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
960
+ >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
961
+
962
+ >>> checkpoint = "BAAI/seggpt-vit-large"
963
+ >>> model = SegGptForImageSegmentation.from_pretrained(checkpoint)
964
+ >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
965
+
966
+ >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
967
+ >>> outputs = model(**inputs)
968
+ >>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0]
969
+ >>> print(list(result.shape))
970
+ [170, 297]
971
+ ```
972
+ """
973
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
974
+ output_hidden_states = (
975
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
976
+ )
977
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
978
+
979
+ if bool_masked_pos is None:
980
+ num_patches = self.model.embeddings.patch_embeddings.num_patches
981
+ bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
982
+ bool_masked_pos_ones = torch.ones(
983
+ num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
984
+ )
985
+ bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
986
+ bool_masked_pos = bool_masked_pos.unsqueeze(0)
987
+
988
+ outputs = self.model(
989
+ pixel_values=pixel_values,
990
+ prompt_pixel_values=prompt_pixel_values,
991
+ prompt_masks=prompt_masks,
992
+ bool_masked_pos=bool_masked_pos,
993
+ feature_ensemble=feature_ensemble,
994
+ embedding_type=embedding_type,
995
+ labels=labels,
996
+ output_attentions=output_attentions,
997
+ output_hidden_states=output_hidden_states,
998
+ return_dict=return_dict,
999
+ )
1000
+
1001
+ intermediate_hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[-1]
1002
+ intermediate_hidden_states = torch.cat(intermediate_hidden_states, dim=-1)
1003
+ pred_masks = self.decoder(intermediate_hidden_states)
1004
+
1005
+ loss = None
1006
+ if labels is not None:
1007
+ loss_fn = SegGptLoss(self.config)
1008
+ loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)
1009
+
1010
+ if not return_dict:
1011
+ output = (pred_masks,)
1012
+ if output_hidden_states:
1013
+ output = output + (outputs[1],)
1014
+
1015
+ if output_attentions:
1016
+ idx = 2 if output_hidden_states else 1
1017
+ output = output + (outputs[idx],)
1018
+
1019
+ if loss is not None:
1020
+ output = (loss,) + output
1021
+ return output
1022
+
1023
+ return SegGptImageSegmentationOutput(
1024
+ loss=loss,
1025
+ pred_masks=pred_masks,
1026
+ hidden_states=outputs.hidden_states,
1027
+ attentions=outputs.attentions,
1028
+ )
1029
+
1030
+
1031
+ __all__ = ["SegGptModel", "SegGptPreTrainedModel", "SegGptForImageSegmentation"]
docs/transformers/build/lib/transformers/models/sew/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_sew import *
22
+ from .modeling_sew import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/sew/configuration_sew.py ADDED
@@ -0,0 +1,256 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SEW model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class SEWConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`SEWModel`]. It is used to instantiate a SEW model
30
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the SEW
32
+ [asapp/sew-tiny-100k](https://huggingface.co/asapp/sew-tiny-100k) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32):
40
+ Vocabulary size of the SEW model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`SEW`].
42
+ hidden_size (`int`, *optional*, defaults to 768):
43
+ Dimensionality of the encoder layers and the pooler layer.
44
+ num_hidden_layers (`int`, *optional*, defaults to 12):
45
+ Number of hidden layers in the Transformer encoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 12):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ intermediate_size (`int`, *optional*, defaults to 3072):
49
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
50
+ squeeze_factor (`int`, *optional*, defaults to 2):
51
+ Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
52
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
53
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
54
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
55
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
56
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
57
+ activation_dropout (`float`, *optional*, defaults to 0.1):
58
+ The dropout ratio for activations inside the fully connected layer.
59
+ attention_dropout (`float`, *optional*, defaults to 0.1):
60
+ The dropout ratio for the attention probabilities.
61
+ final_dropout (`float`, *optional*, defaults to 0.1):
62
+ The dropout probability for the final projection layer of [`SEWForCTC`].
63
+ layerdrop (`float`, *optional*, defaults to 0.1):
64
+ The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
65
+ details.
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
+ The epsilon used by the layer normalization layers.
70
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
+ convolutional layers.
74
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout probability for output of the feature encoder.
76
+ feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
77
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
80
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
81
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
82
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
83
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
84
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
85
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
86
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
87
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
88
+ *conv_dim*.
89
+ conv_bias (`bool`, *optional*, defaults to `False`):
90
+ Whether the 1D convolutional layers have a bias.
91
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
92
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
93
+ embeddings layer.
94
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
95
+ Number of groups of 1D convolutional positional embeddings layer.
96
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
97
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
98
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
99
+ Recognition](https://arxiv.org/abs/1904.08779).
100
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
101
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
102
+ procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
103
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
104
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
105
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
106
+ mask_time_length (`int`, *optional*, defaults to 10):
107
+ Length of vector span along the time axis.
108
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
109
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+ irrespective of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks''
112
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
113
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
114
+ masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks over
115
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
116
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
117
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
118
+ True`.
119
+ mask_feature_length (`int`, *optional*, defaults to 10):
120
+ Length of vector span along the feature axis.
121
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
122
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
123
+ step, irrespectively of `mask_feature_prob`. Only relevant if
124
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
125
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
126
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
127
+ instance of [`SEWForCTC`].
128
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
129
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
130
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
131
+ of [`SEWForCTC`].
132
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
133
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
134
+ instance of [`Wav2Vec2ForSequenceClassification`].
135
+ classifier_proj_size (`int`, *optional*, defaults to 256):
136
+ Dimensionality of the projection before token mean-pooling for classification.
137
+
138
+ Example:
139
+
140
+ ```python
141
+ >>> from transformers import SEWConfig, SEWModel
142
+
143
+ >>> # Initializing a SEW asapp/sew-tiny-100k style configuration
144
+ >>> configuration = SEWConfig()
145
+
146
+ >>> # Initializing a model (with random weights) from the asapp/sew-tiny-100k style configuration
147
+ >>> model = SEWModel(configuration)
148
+
149
+ >>> # Accessing the model configuration
150
+ >>> configuration = model.config
151
+ ```"""
152
+
153
+ model_type = "sew"
154
+
155
+ def __init__(
156
+ self,
157
+ vocab_size=32,
158
+ hidden_size=768,
159
+ num_hidden_layers=12,
160
+ num_attention_heads=12,
161
+ intermediate_size=3072,
162
+ squeeze_factor=2,
163
+ hidden_act="gelu",
164
+ hidden_dropout=0.1,
165
+ activation_dropout=0.1,
166
+ attention_dropout=0.1,
167
+ feat_proj_dropout=0.0,
168
+ final_dropout=0.1,
169
+ layerdrop=0.1,
170
+ initializer_range=0.02,
171
+ layer_norm_eps=1e-5,
172
+ feat_extract_norm="group",
173
+ feat_extract_activation="gelu",
174
+ conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
175
+ conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1),
176
+ conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1),
177
+ conv_bias=False,
178
+ num_conv_pos_embeddings=128,
179
+ num_conv_pos_embedding_groups=16,
180
+ apply_spec_augment=True,
181
+ mask_time_prob=0.05,
182
+ mask_time_length=10,
183
+ mask_time_min_masks=2,
184
+ mask_feature_prob=0.0,
185
+ mask_feature_length=10,
186
+ mask_feature_min_masks=0,
187
+ ctc_loss_reduction="mean",
188
+ ctc_zero_infinity=False,
189
+ use_weighted_layer_sum=False,
190
+ classifier_proj_size=256,
191
+ pad_token_id=0,
192
+ bos_token_id=1,
193
+ eos_token_id=2,
194
+ **kwargs,
195
+ ):
196
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
197
+ self.hidden_size = hidden_size
198
+ self.feat_extract_norm = feat_extract_norm
199
+ self.feat_extract_activation = feat_extract_activation
200
+ self.conv_dim = list(conv_dim)
201
+ self.conv_stride = list(conv_stride)
202
+ self.conv_kernel = list(conv_kernel)
203
+ self.conv_bias = conv_bias
204
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
205
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
206
+ self.num_feat_extract_layers = len(self.conv_dim)
207
+ self.num_hidden_layers = num_hidden_layers
208
+ self.intermediate_size = intermediate_size
209
+ self.squeeze_factor = squeeze_factor
210
+ self.hidden_act = hidden_act
211
+ self.num_attention_heads = num_attention_heads
212
+ self.hidden_dropout = hidden_dropout
213
+ self.attention_dropout = attention_dropout
214
+ self.activation_dropout = activation_dropout
215
+ self.feat_proj_dropout = feat_proj_dropout
216
+ self.final_dropout = final_dropout
217
+ self.layerdrop = layerdrop
218
+ self.layer_norm_eps = layer_norm_eps
219
+ self.initializer_range = initializer_range
220
+ self.vocab_size = vocab_size
221
+
222
+ if (
223
+ (len(self.conv_stride) != self.num_feat_extract_layers)
224
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
225
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
226
+ ):
227
+ raise ValueError(
228
+ "Configuration for convolutional layers is incorrect. "
229
+ "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
230
+ f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
231
+ f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
232
+ )
233
+
234
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
235
+ self.apply_spec_augment = apply_spec_augment
236
+ self.mask_time_prob = mask_time_prob
237
+ self.mask_time_length = mask_time_length
238
+ self.mask_time_min_masks = mask_time_min_masks
239
+ self.mask_feature_prob = mask_feature_prob
240
+ self.mask_feature_length = mask_feature_length
241
+ self.mask_feature_min_masks = mask_feature_min_masks
242
+
243
+ # ctc loss
244
+ self.ctc_loss_reduction = ctc_loss_reduction
245
+ self.ctc_zero_infinity = ctc_zero_infinity
246
+
247
+ # sequence classification
248
+ self.use_weighted_layer_sum = use_weighted_layer_sum
249
+ self.classifier_proj_size = classifier_proj_size
250
+
251
+ @property
252
+ def inputs_to_logits_ratio(self):
253
+ return functools.reduce(operator.mul, self.conv_stride, 1)
254
+
255
+
256
+ __all__ = ["SEWConfig"]
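# Sketch of the `inputs_to_logits_ratio` property above: with the default conv_stride the
# convolutional feature encoder downsamples the raw waveform by a factor of 5 * 2**6 = 320.
import functools
import operator

conv_stride = (5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)
print(functools.reduce(operator.mul, conv_stride, 1))  # 320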
docs/transformers/build/lib/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,305 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SEW checkpoint."""
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+
21
+ import fairseq
22
+ import torch
23
+ from fairseq.data import Dictionary
24
+
25
+ # Register SEW's fairseq modules
26
+ from sew_asapp import tasks # noqa: F401
27
+
28
+ from transformers import (
29
+ SEWConfig,
30
+ SEWForCTC,
31
+ SEWModel,
32
+ Wav2Vec2CTCTokenizer,
33
+ Wav2Vec2FeatureExtractor,
34
+ Wav2Vec2Processor,
35
+ logging,
36
+ )
37
+
38
+
39
+ logging.set_verbosity_info()
40
+ logger = logging.get_logger(__name__)
41
+
42
+ MAPPING = {
43
+ "post_extract_proj": "feature_projection",
44
+ "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
45
+ "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
46
+ "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
47
+ "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
48
+ "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
49
+ "self_attn_layer_norm": "encoder.layers.*.layer_norm",
50
+ "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
51
+ "fc2": "encoder.layers.*.feed_forward.output_dense",
52
+ "final_layer_norm": "encoder.layers.*.final_layer_norm",
53
+ "encoder.upsample.0": "encoder.upsample.projection",
54
+ "encoder.layer_norm": "encoder.layer_norm",
55
+ "w2v_model.layer_norm": "layer_norm",
56
+ "w2v_encoder.proj": "lm_head",
57
+ "mask_emb": "masked_spec_embed",
58
+ }
59
+
60
+
61
+ def set_recursively(hf_pointer, key, value, full_name, weight_type):
62
+ for attribute in key.split("."):
63
+ hf_pointer = getattr(hf_pointer, attribute)
64
+
65
+ if weight_type is not None:
66
+ hf_shape = getattr(hf_pointer, weight_type).shape
67
+ else:
68
+ hf_shape = hf_pointer.shape
69
+
70
+ assert hf_shape == value.shape, (
71
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
72
+ f" {value.shape} for {full_name}"
73
+ )
74
+
75
+ if weight_type == "weight":
76
+ hf_pointer.weight.data = value
77
+ elif weight_type == "weight_g":
78
+ hf_pointer.weight_g.data = value
79
+ elif weight_type == "weight_v":
80
+ hf_pointer.weight_v.data = value
81
+ elif weight_type == "bias":
82
+ hf_pointer.bias.data = value
83
+ else:
84
+ hf_pointer.data = value
85
+
86
+ logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
87
+
88
+
89
+ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
90
+ unused_weights = []
91
+ fairseq_dict = fairseq_model.state_dict()
92
+
93
+ feature_extractor = hf_model.sew.feature_extractor if is_finetuned else hf_model.feature_extractor
94
+
95
+ for name, value in fairseq_dict.items():
96
+ is_used = False
97
+ if "conv_layers" in name:
98
+ load_conv_layer(
99
+ name,
100
+ value,
101
+ feature_extractor,
102
+ unused_weights,
103
+ hf_model.config.feat_extract_norm == "group",
104
+ )
105
+ is_used = True
106
+ else:
107
+ for key, mapped_key in MAPPING.items():
108
+ mapped_key = "sew." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
109
+
110
+ if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
111
+ is_used = True
112
+ if "*" in mapped_key:
113
+ layer_index = name.split(key)[0].split(".")[-2]
114
+ mapped_key = mapped_key.replace("*", layer_index)
115
+ if "weight_g" in name:
116
+ weight_type = "weight_g"
117
+ elif "weight_v" in name:
118
+ weight_type = "weight_v"
119
+ elif "weight" in name:
120
+ weight_type = "weight"
121
+ elif "bias" in name:
122
+ weight_type = "bias"
123
+ else:
124
+ weight_type = None
125
+ set_recursively(hf_model, mapped_key, value, name, weight_type)
126
+ continue
127
+ if not is_used:
128
+ unused_weights.append(name)
129
+
130
+ logger.warning(f"Unused weights: {unused_weights}")
131
+
132
+
133
+ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
134
+ name = full_name.split("conv_layers.")[-1]
135
+ items = name.split(".")
136
+ layer_id = int(items[0])
137
+ type_id = int(items[1])
138
+
139
+ if type_id == 0:
140
+ if "bias" in name:
141
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
142
+ f"{full_name} has size {value.shape}, but"
143
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
144
+ )
145
+ feature_extractor.conv_layers[layer_id].conv.bias.data = value
146
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
147
+ elif "weight" in name:
148
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
149
+ f"{full_name} has size {value.shape}, but"
150
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
151
+ )
152
+ feature_extractor.conv_layers[layer_id].conv.weight.data = value
153
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
154
+ elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
155
+ if "bias" in name:
156
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
157
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
158
+ " found."
159
+ )
160
+ feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
161
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
162
+ elif "weight" in name:
163
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
164
+ f"{full_name} has size {value.shape}, but"
165
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
166
+ )
167
+ feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
168
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
169
+ else:
170
+ unused_weights.append(full_name)
171
+
172
+
173
+ def convert_config(model, is_finetuned):
174
+ config = SEWConfig()
175
+ if is_finetuned:
176
+ fs_config = model.w2v_encoder.w2v_model.cfg
177
+ else:
178
+ fs_config = model.cfg
179
+
180
+ config.conv_bias = fs_config.conv_bias
181
+ conv_layers = eval(fs_config.conv_feature_layers)
182
+ config.conv_dim = [x[0] for x in conv_layers]
183
+ config.conv_kernel = [x[1] for x in conv_layers]
184
+ config.conv_stride = [x[2] for x in conv_layers]
185
+ config.feat_extract_activation = "gelu"
186
+ config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
187
+ config.final_dropout = 0.0
188
+ config.hidden_act = fs_config.activation_fn.name
189
+ config.hidden_size = fs_config.encoder_embed_dim
190
+ config.initializer_range = 0.02
191
+ config.intermediate_size = fs_config.encoder_ffn_embed_dim
192
+ config.layer_norm_eps = 1e-5
193
+ config.layerdrop = fs_config.encoder_layerdrop
194
+ config.num_attention_heads = fs_config.encoder_attention_heads
195
+ config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
196
+ config.num_conv_pos_embeddings = fs_config.conv_pos
197
+ config.num_feat_extract_layers = len(conv_layers)
198
+ config.num_hidden_layers = fs_config.encoder_layers
199
+ config.squeeze_factor = fs_config.squeeze_factor
200
+
201
+ # take care of any params that are overridden by the Wav2VecCtc model
202
+ if is_finetuned:
203
+ fs_config = model.cfg
204
+ config.final_dropout = fs_config.final_dropout
205
+ config.layerdrop = fs_config.layerdrop
206
+ config.activation_dropout = fs_config.activation_dropout
207
+ config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
208
+ config.attention_dropout = fs_config.attention_dropout
209
+ config.feat_proj_dropout = fs_config.dropout_input
210
+ config.hidden_dropout = fs_config.dropout
211
+ config.mask_feature_length = fs_config.mask_channel_length
212
+ config.mask_feature_prob = fs_config.mask_channel_prob
213
+ config.mask_time_length = fs_config.mask_length
214
+ config.mask_time_prob = fs_config.mask_prob
215
+
216
+ config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
217
+ config.tokenizer_class = "Wav2Vec2CTCTokenizer"
218
+
219
+ return config
220
+
221
+
222
+ @torch.no_grad()
223
+ def convert_sew_checkpoint(
224
+ checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
225
+ ):
226
+ """
227
+ Copy/paste/tweak model's weights to transformers design.
228
+ """
229
+
230
+ if is_finetuned:
231
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
232
+ [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
233
+ )
234
+ else:
235
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
236
+
237
+ if config_path is not None:
238
+ config = SEWConfig.from_pretrained(config_path)
239
+ else:
240
+ config = convert_config(model[0], is_finetuned)
241
+ model = model[0].eval()
242
+
243
+ return_attention_mask = True if config.feat_extract_norm == "layer" else False
244
+ feature_extractor = Wav2Vec2FeatureExtractor(
245
+ feature_size=1,
246
+ sampling_rate=16000,
247
+ padding_value=0,
248
+ do_normalize=True,
249
+ return_attention_mask=return_attention_mask,
250
+ )
251
+
252
+ if is_finetuned:
253
+ if dict_path:
254
+ target_dict = Dictionary.load(dict_path)
255
+
256
+ # important change bos & pad token id since CTC symbol is <pad> and
257
+ # not <s> as in fairseq
258
+ target_dict.indices[target_dict.bos_word] = target_dict.pad_index
259
+ target_dict.indices[target_dict.pad_word] = target_dict.bos_index
260
+ config.bos_token_id = target_dict.pad_index
261
+ config.pad_token_id = target_dict.bos_index
262
+ config.eos_token_id = target_dict.eos_index
263
+ config.vocab_size = len(target_dict.symbols)
264
+ vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
265
+ if not os.path.isdir(pytorch_dump_folder_path):
266
+ logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
267
+ return
268
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
269
+ with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
270
+ json.dump(target_dict.indices, vocab_handle)
271
+ tokenizer = Wav2Vec2CTCTokenizer(
272
+ vocab_path,
273
+ unk_token=target_dict.unk_word,
274
+ pad_token=target_dict.pad_word,
275
+ bos_token=target_dict.bos_word,
276
+ eos_token=target_dict.eos_word,
277
+ word_delimiter_token="|",
278
+ do_lower_case=False,
279
+ )
280
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
281
+ processor.save_pretrained(pytorch_dump_folder_path)
282
+
283
+ hf_model = SEWForCTC(config)
284
+ else:
285
+ hf_model = SEWModel(config)
286
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
287
+
288
+ recursively_load_weights(model, hf_model, is_finetuned)
289
+
290
+ hf_model.save_pretrained(pytorch_dump_folder_path)
291
+
292
+
293
+ if __name__ == "__main__":
294
+ parser = argparse.ArgumentParser()
295
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
296
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
297
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
298
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
299
+ parser.add_argument(
300
+ "--is_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
301
+ )
302
+ args = parser.parse_args()
303
+ convert_sew_checkpoint(
304
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, args.is_finetuned
305
+ )
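A minimal sketch of driving the conversion entry point above from Python rather than through the CLI, assuming the script is importable on the Python path; the checkpoint, dictionary, and output paths are hypothetical placeholders:

# Hypothetical invocation of the converter defined above (all paths are placeholders).
from convert_sew_original_pytorch_checkpoint_to_pytorch import convert_sew_checkpoint

convert_sew_checkpoint(
    checkpoint_path="/path/to/sew_finetuned.pt",      # fairseq checkpoint (placeholder)
    pytorch_dump_folder_path="/path/to/output_dir",   # must already exist as a directory
    config_path=None,                                 # None -> config is derived from the checkpoint
    dict_path="/path/to/dict.ltr.txt",                # fairseq dictionary, needed for the fine-tuned case
    is_finetuned=True,
)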
docs/transformers/build/lib/transformers/models/sew/modeling_sew.py ADDED
@@ -0,0 +1,1498 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch SEW model."""
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from ...activations import ACT2FN
28
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
29
+ from ...integrations.fsdp import is_fsdp_managed_module
30
+ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
31
+ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...utils import (
34
+ add_code_sample_docstrings,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ )
39
+ from .configuration_sew import SEWConfig
40
+
41
+
42
+ if is_flash_attn_available():
43
+ from ...modeling_flash_attention_utils import _flash_attention_forward
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ _HIDDEN_STATES_START_POSITION = 1
50
+
51
+ # General docstring
52
+ _CONFIG_FOR_DOC = "SEWConfig"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 512]
57
+
58
+ # CTC docstring
59
+ _CTC_EXPECTED_OUTPUT = (
60
+ "'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
61
+ )
62
+ _CTC_EXPECTED_LOSS = 0.42
63
+
64
+ # Audio class docstring
65
+ _SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
66
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
67
+ _SEQ_CLASS_EXPECTED_LOSS = 9.52
68
+
69
+
70
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
71
+ def _compute_mask_indices(
72
+ shape: Tuple[int, int],
73
+ mask_prob: float,
74
+ mask_length: int,
75
+ attention_mask: Optional[torch.LongTensor] = None,
76
+ min_masks: int = 0,
77
+ ) -> np.ndarray:
78
+ """
79
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
80
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
81
+ CPU as part of the preprocessing during training.
82
+
83
+ Args:
84
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
85
+ the first element is the batch size and the second element is the length of the axis to span.
86
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
87
+ independently generated mask spans of length `mask_length` is computed by
88
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
89
+ actual percentage will be smaller.
90
+ mask_length: size of the mask
91
+ min_masks: minimum number of masked spans
92
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
93
+ each batch dimension.
94
+ """
95
+ batch_size, sequence_length = shape
96
+
97
+ if mask_length < 1:
98
+ raise ValueError("`mask_length` has to be bigger than 0.")
99
+
100
+ if mask_length > sequence_length:
101
+ raise ValueError(
102
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
103
+ f" and `sequence_length`: {sequence_length}"
104
+ )
105
+
106
+ # epsilon is used for probabilistic rounding
107
+ epsilon = np.random.rand(1).item()
108
+
109
+ def compute_num_masked_span(input_length):
110
+ """Given input length, compute how many spans should be masked"""
111
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
112
+ num_masked_span = max(num_masked_span, min_masks)
113
+
114
+ # make sure num masked span <= sequence_length
115
+ if num_masked_span * mask_length > sequence_length:
116
+ num_masked_span = sequence_length // mask_length
117
+
118
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
119
+ if input_length - (mask_length - 1) < num_masked_span:
120
+ num_masked_span = max(input_length - (mask_length - 1), 0)
121
+
122
+ return num_masked_span
123
+
124
+ # compute number of masked spans in batch
125
+ input_lengths = (
126
+ attention_mask.detach().sum(-1).tolist()
127
+ if attention_mask is not None
128
+ else [sequence_length for _ in range(batch_size)]
129
+ )
130
+
131
+ # SpecAugment mask to fill
132
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
133
+ spec_aug_mask_idxs = []
134
+
135
+ max_num_masked_span = compute_num_masked_span(sequence_length)
136
+
137
+ if max_num_masked_span == 0:
138
+ return spec_aug_mask
139
+
140
+ for input_length in input_lengths:
141
+ # compute num of masked spans for this input
142
+ num_masked_span = compute_num_masked_span(input_length)
143
+
144
+ # get random indices to mask
145
+ spec_aug_mask_idx = np.random.choice(
146
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
147
+ )
148
+
149
+ # pick first sampled index that will serve as a dummy index to pad vector
150
+ # to ensure same dimension for all batches due to probabilistic rounding
151
+ # Picking first sample just pads those vectors twice.
152
+ if len(spec_aug_mask_idx) == 0:
153
+ # this case can only happen if `input_length` is strictly smaller than
154
+ # `sequence_length` in which case the last token has to be a padding
155
+ # token which we can use as a dummy mask id
156
+ dummy_mask_idx = sequence_length - 1
157
+ else:
158
+ dummy_mask_idx = spec_aug_mask_idx[0]
159
+
160
+ spec_aug_mask_idx = np.concatenate(
161
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
162
+ )
163
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
164
+
165
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
166
+
167
+ # expand masked indices to masked spans
168
+ spec_aug_mask_idxs = np.broadcast_to(
169
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
170
+ )
171
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
172
+
173
+ # add offset to the starting indexes so that indexes now create a span
174
+ offsets = np.arange(mask_length)[None, None, :]
175
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
176
+ batch_size, max_num_masked_span * mask_length
177
+ )
178
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
179
+
180
+ # ensure that we cannot have indices larger than sequence_length
181
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
182
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
183
+
184
+ # scatter indices to mask
185
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
186
+
187
+ return spec_aug_mask
188
+
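A quick, illustrative use of the span-masking helper above, assuming the module is importable as `transformers.models.sew.modeling_sew` (the layout shown in this diff); the shape, probability, and span length are arbitrary toy values:

# Toy SpecAugment-style mask: batch of 2 sequences, 100 frames each.
from transformers.models.sew.modeling_sew import _compute_mask_indices

mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.05, mask_length=10, min_masks=1)
print(mask.shape)         # (2, 100) boolean numpy array
print(mask.sum(axis=-1))  # number of masked frames per sequence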
189
+
190
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEW
191
+ class SEWNoLayerNormConvLayer(nn.Module):
192
+ def __init__(self, config, layer_id=0):
193
+ super().__init__()
194
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
195
+ self.out_conv_dim = config.conv_dim[layer_id]
196
+
197
+ self.conv = nn.Conv1d(
198
+ self.in_conv_dim,
199
+ self.out_conv_dim,
200
+ kernel_size=config.conv_kernel[layer_id],
201
+ stride=config.conv_stride[layer_id],
202
+ bias=config.conv_bias,
203
+ )
204
+ self.activation = ACT2FN[config.feat_extract_activation]
205
+
206
+ def forward(self, hidden_states):
207
+ hidden_states = self.conv(hidden_states)
208
+ hidden_states = self.activation(hidden_states)
209
+ return hidden_states
210
+
211
+
212
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEW
213
+ class SEWLayerNormConvLayer(nn.Module):
214
+ def __init__(self, config, layer_id=0):
215
+ super().__init__()
216
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
217
+ self.out_conv_dim = config.conv_dim[layer_id]
218
+
219
+ self.conv = nn.Conv1d(
220
+ self.in_conv_dim,
221
+ self.out_conv_dim,
222
+ kernel_size=config.conv_kernel[layer_id],
223
+ stride=config.conv_stride[layer_id],
224
+ bias=config.conv_bias,
225
+ )
226
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
227
+ self.activation = ACT2FN[config.feat_extract_activation]
228
+
229
+ def forward(self, hidden_states):
230
+ hidden_states = self.conv(hidden_states)
231
+
232
+ hidden_states = hidden_states.transpose(-2, -1)
233
+ hidden_states = self.layer_norm(hidden_states)
234
+ hidden_states = hidden_states.transpose(-2, -1)
235
+
236
+ hidden_states = self.activation(hidden_states)
237
+ return hidden_states
238
+
239
+
240
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEW
241
+ class SEWGroupNormConvLayer(nn.Module):
242
+ def __init__(self, config, layer_id=0):
243
+ super().__init__()
244
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
245
+ self.out_conv_dim = config.conv_dim[layer_id]
246
+
247
+ self.conv = nn.Conv1d(
248
+ self.in_conv_dim,
249
+ self.out_conv_dim,
250
+ kernel_size=config.conv_kernel[layer_id],
251
+ stride=config.conv_stride[layer_id],
252
+ bias=config.conv_bias,
253
+ )
254
+ self.activation = ACT2FN[config.feat_extract_activation]
255
+
256
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
257
+
258
+ def forward(self, hidden_states):
259
+ hidden_states = self.conv(hidden_states)
260
+ hidden_states = self.layer_norm(hidden_states)
261
+ hidden_states = self.activation(hidden_states)
262
+ return hidden_states
263
+
264
+
265
+ class SEWPositionalConvEmbedding(nn.Module):
266
+ def __init__(self, config):
267
+ super().__init__()
268
+ self.conv = nn.Conv1d(
269
+ config.hidden_size,
270
+ config.hidden_size,
271
+ kernel_size=config.num_conv_pos_embeddings,
272
+ padding=config.num_conv_pos_embeddings // 2,
273
+ groups=config.num_conv_pos_embedding_groups,
274
+ stride=config.squeeze_factor,
275
+ )
276
+
277
+ weight_norm = nn.utils.weight_norm
278
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
279
+ weight_norm = nn.utils.parametrizations.weight_norm
280
+
281
+ if is_deepspeed_zero3_enabled():
282
+ import deepspeed
283
+
284
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
285
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
286
+ if hasattr(self.conv, "parametrizations"):
287
+ weight_g = self.conv.parametrizations.weight.original0
288
+ weight_v = self.conv.parametrizations.weight.original1
289
+ else:
290
+ weight_g = self.conv.weight_g
291
+ weight_v = self.conv.weight_v
292
+ deepspeed.zero.register_external_parameter(self, weight_v)
293
+ deepspeed.zero.register_external_parameter(self, weight_g)
294
+ else:
295
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
296
+
297
+ self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
298
+ self.activation = ACT2FN[config.feat_extract_activation]
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.conv(hidden_states)
302
+ hidden_states = self.padding(hidden_states)
303
+ hidden_states = self.activation(hidden_states)
304
+
305
+ return hidden_states
306
+
307
+
308
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW
309
+ class SEWSamePadLayer(nn.Module):
310
+ def __init__(self, num_conv_pos_embeddings):
311
+ super().__init__()
312
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
313
+
314
+ def forward(self, hidden_states):
315
+ if self.num_pad_remove > 0:
316
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
317
+ return hidden_states
318
+
319
+
320
+ class SEWUpsampling(nn.Module):
321
+ def __init__(self, config):
322
+ super().__init__()
323
+ self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
324
+ self.activation = ACT2FN[config.feat_extract_activation]
325
+ self.squeeze_factor = config.squeeze_factor
326
+
327
+ def forward(self, hidden_states):
328
+ hidden_states = self.projection(hidden_states)
329
+ hidden_states = self.activation(hidden_states)
330
+
331
+ if self.squeeze_factor > 1:
332
+ # transform embedding channels to sequence length
333
+ bsz, src_len, src_embed_dim = hidden_states.size()
334
+ tgt_len = src_len * self.squeeze_factor
335
+ tgt_embed_dim = src_embed_dim // self.squeeze_factor
336
+ hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
337
+ hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
338
+
339
+ return hidden_states
340
+
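A shape-only sketch of the reshape performed in `SEWUpsampling.forward` above, with arbitrary sizes and a `squeeze_factor` of 2 chosen purely for illustration: the projection widens the channel dimension by the squeeze factor, and the two reshapes fold that factor back into the time axis.

# After the projection, channels = hidden * squeeze_factor; the reshapes restore the time resolution.
import torch

bsz, src_len, hidden, squeeze_factor = 1, 50, 512, 2
projected = torch.randn(bsz, src_len, hidden * squeeze_factor)  # stand-in for projection + activation output
upsampled = projected.reshape(bsz, src_len, squeeze_factor, hidden).reshape(
    bsz, src_len * squeeze_factor, hidden
)
print(upsampled.shape)  # torch.Size([1, 100, 512])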
341
+
342
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEW
343
+ class SEWFeatureEncoder(nn.Module):
344
+ """Construct the features from raw audio waveform"""
345
+
346
+ def __init__(self, config):
347
+ super().__init__()
348
+
349
+ if config.feat_extract_norm == "group":
350
+ conv_layers = [SEWGroupNormConvLayer(config, layer_id=0)] + [
351
+ SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
352
+ ]
353
+ elif config.feat_extract_norm == "layer":
354
+ conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
355
+ else:
356
+ raise ValueError(
357
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
358
+ )
359
+ self.conv_layers = nn.ModuleList(conv_layers)
360
+ self.gradient_checkpointing = False
361
+ self._requires_grad = True
362
+
363
+ def _freeze_parameters(self):
364
+ for param in self.parameters():
365
+ param.requires_grad = False
366
+ self._requires_grad = False
367
+
368
+ def forward(self, input_values):
369
+ hidden_states = input_values[:, None]
370
+
371
+ # make sure hidden_states require grad for gradient_checkpointing
372
+ if self._requires_grad and self.training:
373
+ hidden_states.requires_grad = True
374
+
375
+ for conv_layer in self.conv_layers:
376
+ if self._requires_grad and self.gradient_checkpointing and self.training:
377
+ hidden_states = self._gradient_checkpointing_func(
378
+ conv_layer.__call__,
379
+ hidden_states,
380
+ )
381
+ else:
382
+ hidden_states = conv_layer(hidden_states)
383
+
384
+ return hidden_states
385
+
386
+
387
+ class SEWFeatureExtractor(SEWFeatureEncoder):
388
+ def __init__(self, config):
389
+ super().__init__(config)
390
+ warnings.warn(
391
+ f"The class `{self.__class__.__name__}` has been deprecated "
392
+ "and will be removed in Transformers v5. "
393
+ f"Use `{self.__class__.__bases__[0].__name__}` instead.",
394
+ FutureWarning,
395
+ )
396
+
397
+
398
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SEW
399
+ class SEWAttention(nn.Module):
400
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
401
+
402
+ def __init__(
403
+ self,
404
+ embed_dim: int,
405
+ num_heads: int,
406
+ dropout: float = 0.0,
407
+ is_decoder: bool = False,
408
+ bias: bool = True,
409
+ is_causal: bool = False,
410
+ config: Optional[SEWConfig] = None,
411
+ ):
412
+ super().__init__()
413
+ self.embed_dim = embed_dim
414
+ self.num_heads = num_heads
415
+ self.dropout = dropout
416
+ self.head_dim = embed_dim // num_heads
417
+ self.config = config
418
+
419
+ if (self.head_dim * num_heads) != self.embed_dim:
420
+ raise ValueError(
421
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
422
+ f" and `num_heads`: {num_heads})."
423
+ )
424
+ self.scaling = self.head_dim**-0.5
425
+ self.is_decoder = is_decoder
426
+ self.is_causal = is_causal
427
+
428
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
429
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
430
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
431
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
432
+
433
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
434
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
435
+
436
+ def forward(
437
+ self,
438
+ hidden_states: torch.Tensor,
439
+ key_value_states: Optional[torch.Tensor] = None,
440
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ layer_head_mask: Optional[torch.Tensor] = None,
443
+ output_attentions: bool = False,
444
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
445
+ """Input shape: Batch x Time x Channel"""
446
+
447
+ # if key_value_states are provided this layer is used as a cross-attention layer
448
+ # for the decoder
449
+ is_cross_attention = key_value_states is not None
450
+
451
+ bsz, tgt_len, _ = hidden_states.size()
452
+
453
+ # get query proj
454
+ query_states = self.q_proj(hidden_states) * self.scaling
455
+ # get key, value proj
456
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
457
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
458
+ # the provided `key_value_states` to support prefix tuning
459
+ if (
460
+ is_cross_attention
461
+ and past_key_value is not None
462
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
463
+ ):
464
+ # reuse k,v, cross_attentions
465
+ key_states = past_key_value[0]
466
+ value_states = past_key_value[1]
467
+ elif is_cross_attention:
468
+ # cross_attentions
469
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
470
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
471
+ elif past_key_value is not None:
472
+ # reuse k, v, self_attention
473
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
474
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
475
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
476
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
477
+ else:
478
+ # self_attention
479
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
480
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
481
+
482
+ if self.is_decoder:
483
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
484
+ # Further calls to cross_attention layer can then reuse all cross-attention
485
+ # key/value_states (first "if" case)
486
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
487
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
488
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
489
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
490
+ past_key_value = (key_states, value_states)
491
+
492
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
493
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
494
+ key_states = key_states.reshape(*proj_shape)
495
+ value_states = value_states.reshape(*proj_shape)
496
+
497
+ src_len = key_states.size(1)
498
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
499
+
500
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
501
+ raise ValueError(
502
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
503
+ f" {attn_weights.size()}"
504
+ )
505
+
506
+ if attention_mask is not None:
507
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
508
+ raise ValueError(
509
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
510
+ )
511
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
512
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
513
+
514
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
515
+
516
+ if layer_head_mask is not None:
517
+ if layer_head_mask.size() != (self.num_heads,):
518
+ raise ValueError(
519
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
520
+ f" {layer_head_mask.size()}"
521
+ )
522
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
523
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
524
+
525
+ if output_attentions:
526
+ # this operation is a bit awkward, but it's required to
527
+ # make sure that attn_weights keeps its gradient.
528
+ # In order to do so, attn_weights have to be reshaped
529
+ # twice and have to be reused in the following
530
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
531
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
532
+ else:
533
+ attn_weights_reshaped = None
534
+
535
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
536
+
537
+ attn_output = torch.bmm(attn_probs, value_states)
538
+
539
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
540
+ raise ValueError(
541
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
542
+ f" {attn_output.size()}"
543
+ )
544
+
545
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
546
+ attn_output = attn_output.transpose(1, 2)
547
+
548
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
549
+ # partitioned across GPUs when using tensor-parallelism.
550
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
551
+
552
+ attn_output = self.out_proj(attn_output)
553
+
554
+ return attn_output, attn_weights_reshaped, past_key_value
555
+
556
+
557
+ # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW
558
+ class SEWFlashAttention2(SEWAttention):
559
+ """
560
+ SEW flash attention module. This module inherits from `SEWAttention`, as the weights of the module stay
561
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
562
+ flash attention and deal with padding tokens in case the input contains any of them.
563
+ """
564
+
565
+ def __init__(self, *args, **kwargs):
566
+ super().__init__(*args, **kwargs)
567
+
568
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
569
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
570
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
571
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
572
+
573
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
574
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
575
+
576
+ def forward(
577
+ self,
578
+ hidden_states: torch.Tensor,
579
+ key_value_states: Optional[torch.Tensor] = None,
580
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
581
+ attention_mask: Optional[torch.Tensor] = None,
582
+ layer_head_mask: Optional[torch.Tensor] = None,
583
+ output_attentions: bool = False,
584
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
585
+ # SEWFlashAttention2 attention does not support output_attentions
586
+ if output_attentions:
587
+ raise ValueError("SEWFlashAttention2 attention does not support output_attentions")
588
+
589
+ # if key_value_states are provided this layer is used as a cross-attention layer
590
+ # for the decoder
591
+ is_cross_attention = key_value_states is not None
592
+
593
+ bsz, q_len, _ = hidden_states.size()
594
+
595
+ # get query proj
596
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
597
+ # get key, value proj
598
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
599
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
600
+ # the provided `key_value_states` to support prefix tuning
601
+ if (
602
+ is_cross_attention
603
+ and past_key_value is not None
604
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
605
+ ):
606
+ # reuse k,v, cross_attentions
607
+ key_states = past_key_value[0].transpose(1, 2)
608
+ value_states = past_key_value[1].transpose(1, 2)
609
+ elif is_cross_attention:
610
+ # cross_attentions
611
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
612
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
613
+ elif past_key_value is not None:
614
+ # reuse k, v, self_attention
615
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
616
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
617
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
618
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
619
+ else:
620
+ # self_attention
621
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
622
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
623
+
624
+ if self.is_decoder:
625
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
626
+ # Further calls to cross_attention layer can then reuse all cross-attention
627
+ # key/value_states (first "if" case)
628
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
629
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
630
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
631
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
632
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
633
+
634
+ kv_seq_len = key_states.shape[-2]
635
+ if past_key_value is not None:
636
+ kv_seq_len += past_key_value[0].shape[-2]
637
+
638
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
639
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
640
+ # cast them back to the correct dtype just to be sure everything works as expected.
641
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
642
+ # in fp32. (LlamaRMSNorm handles it correctly)
643
+
644
+ input_dtype = query_states.dtype
645
+ if input_dtype == torch.float32:
646
+ if torch.is_autocast_enabled():
647
+ target_dtype = torch.get_autocast_gpu_dtype()
648
+ # Handle the case where the model is quantized
649
+ elif hasattr(self.config, "_pre_quantization_dtype"):
650
+ target_dtype = self.config._pre_quantization_dtype
651
+ else:
652
+ target_dtype = self.q_proj.weight.dtype
653
+
654
+ logger.warning_once(
655
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
656
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
657
+ f" {target_dtype}."
658
+ )
659
+
660
+ query_states = query_states.to(target_dtype)
661
+ key_states = key_states.to(target_dtype)
662
+ value_states = value_states.to(target_dtype)
663
+
664
+ attn_output = _flash_attention_forward(
665
+ query_states,
666
+ key_states,
667
+ value_states,
668
+ attention_mask,
669
+ q_len,
670
+ dropout=self.dropout if self.training else 0.0,
671
+ is_causal=self.is_causal,
672
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
673
+ )
674
+
675
+ attn_output = attn_output.reshape(bsz, q_len, -1)
676
+ attn_output = self.out_proj(attn_output)
677
+
678
+ if not output_attentions:
679
+ attn_weights = None
680
+
681
+ return attn_output, attn_weights, past_key_value
682
+
683
+
684
+ class SEWSdpaAttention(SEWAttention):
685
+ # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW
686
+ def forward(
687
+ self,
688
+ hidden_states: torch.Tensor,
689
+ key_value_states: Optional[torch.Tensor] = None,
690
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
691
+ attention_mask: Optional[torch.Tensor] = None,
692
+ layer_head_mask: Optional[torch.Tensor] = None,
693
+ output_attentions: bool = False,
694
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
695
+ """Input shape: Batch x Time x Channel"""
696
+ if output_attentions or layer_head_mask is not None:
697
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
698
+ logger.warning_once(
699
+ "SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
700
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
701
+ )
702
+ return super().forward(
703
+ hidden_states,
704
+ key_value_states=key_value_states,
705
+ past_key_value=past_key_value,
706
+ attention_mask=attention_mask,
707
+ layer_head_mask=layer_head_mask,
708
+ output_attentions=output_attentions,
709
+ )
710
+
711
+ # if key_value_states are provided this layer is used as a cross-attention layer
712
+ # for the decoder
713
+ is_cross_attention = key_value_states is not None
714
+
715
+ bsz, tgt_len, _ = hidden_states.size()
716
+
717
+ # get query proj
718
+ query_states = self.q_proj(hidden_states)
719
+ # get key, value proj
720
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
721
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
722
+ # the provided `key_value_states` to support prefix tuning
723
+ if (
724
+ is_cross_attention
725
+ and past_key_value is not None
726
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
727
+ ):
728
+ # reuse k,v, cross_attentions
729
+ key_states = past_key_value[0]
730
+ value_states = past_key_value[1]
731
+ elif is_cross_attention:
732
+ # cross_attentions
733
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
734
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
735
+ elif past_key_value is not None:
736
+ # reuse k, v, self_attention
737
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
738
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
739
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
740
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
741
+ else:
742
+ # self_attention
743
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
744
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
745
+
746
+ if self.is_decoder:
747
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
748
+ # Further calls to cross_attention layer can then reuse all cross-attention
749
+ # key/value_states (first "if" case)
750
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
751
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
752
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
753
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
754
+ past_key_value = (key_states, value_states)
755
+
756
+ query_states = self._shape(query_states, tgt_len, bsz)
757
+
758
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
759
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
760
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
761
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
762
+
763
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
764
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
765
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
766
+ query_states,
767
+ key_states,
768
+ value_states,
769
+ attn_mask=attention_mask,
770
+ dropout_p=self.dropout if self.training else 0.0,
771
+ is_causal=is_causal,
772
+ )
773
+
774
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
775
+ raise ValueError(
776
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
777
+ f" {attn_output.size()}"
778
+ )
779
+
780
+ attn_output = attn_output.transpose(1, 2)
781
+
782
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
783
+ # partitioned across GPUs when using tensor-parallelism.
784
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
785
+
786
+ attn_output = self.out_proj(attn_output)
787
+
788
+ return attn_output, None, past_key_value
789
+
790
+
791
+ SEW_ATTENTION_CLASSES = {
792
+ "eager": SEWAttention,
793
+ "sdpa": SEWSdpaAttention,
794
+ "flash_attention_2": SEWFlashAttention2,
795
+ }
796
+
797
+
798
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->SEW
799
+ class SEWFeedForward(nn.Module):
800
+ def __init__(self, config):
801
+ super().__init__()
802
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
803
+
804
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
805
+ if isinstance(config.hidden_act, str):
806
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
807
+ else:
808
+ self.intermediate_act_fn = config.hidden_act
809
+
810
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
811
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
812
+
813
+ def forward(self, hidden_states):
814
+ hidden_states = self.intermediate_dense(hidden_states)
815
+ hidden_states = self.intermediate_act_fn(hidden_states)
816
+ hidden_states = self.intermediate_dropout(hidden_states)
817
+
818
+ hidden_states = self.output_dense(hidden_states)
819
+ hidden_states = self.output_dropout(hidden_states)
820
+ return hidden_states
821
+
822
+
823
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->SEW, WAV2VEC2->SEW
824
+ class SEWEncoderLayer(nn.Module):
825
+ def __init__(self, config):
826
+ super().__init__()
827
+ self.attention = SEW_ATTENTION_CLASSES[config._attn_implementation](
828
+ embed_dim=config.hidden_size,
829
+ num_heads=config.num_attention_heads,
830
+ dropout=config.attention_dropout,
831
+ is_decoder=False,
832
+ )
833
+
834
+ self.dropout = nn.Dropout(config.hidden_dropout)
835
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
836
+ self.feed_forward = SEWFeedForward(config)
837
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
838
+
839
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
840
+ attn_residual = hidden_states
841
+ hidden_states, attn_weights, _ = self.attention(
842
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
843
+ )
844
+ hidden_states = self.dropout(hidden_states)
845
+ hidden_states = attn_residual + hidden_states
846
+
847
+ hidden_states = self.layer_norm(hidden_states)
848
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
849
+ hidden_states = self.final_layer_norm(hidden_states)
850
+
851
+ outputs = (hidden_states,)
852
+
853
+ if output_attentions:
854
+ outputs += (attn_weights,)
855
+
856
+ return outputs
857
+
858
+
859
+ class SEWEncoder(nn.Module):
860
+ def __init__(self, config):
861
+ super().__init__()
862
+ self.config = config
863
+ self.pos_conv_embed = SEWPositionalConvEmbedding(config)
864
+ self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
865
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
866
+ self.dropout = nn.Dropout(config.hidden_dropout)
867
+ self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
868
+ self.upsample = SEWUpsampling(config)
869
+ self.gradient_checkpointing = False
870
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states,
875
+ attention_mask=None,
876
+ output_attentions=False,
877
+ output_hidden_states=False,
878
+ return_dict=True,
879
+ ):
880
+ all_hidden_states = () if output_hidden_states else None
881
+ all_self_attentions = () if output_attentions else None
882
+
883
+ if attention_mask is not None:
884
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
885
+ if self._use_flash_attention_2:
886
+ # make sure padded tokens output 0
887
+ hidden_states[~expand_attention_mask] = 0.0
888
+ # 2d mask is passed through the layers
889
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
890
+ else:
891
+ # make sure padded tokens output 0
892
+ hidden_states[~expand_attention_mask] = 0.0
893
+ input_lengths = (attention_mask.long()).sum(-1)
894
+ # apply pooling formula to get real output_lengths
895
+ output_lengths = input_lengths // self.config.squeeze_factor
896
+ max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
897
+ attention_ids = (
898
+ torch.arange(0, max_encoder_length, device=output_lengths.device)
899
+ .view(1, -1)
900
+ .expand(output_lengths.shape[0], -1)
901
+ )
902
+ attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
903
+
904
+ # extend attention_mask
905
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
906
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
907
+ attention_mask = attention_mask.expand(
908
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
909
+ )
910
+
911
+ n_input_timesteps = hidden_states.shape[1]
912
+
913
+ hidden_states = hidden_states.transpose(1, 2)
914
+ position_embeddings = self.pos_conv_embed(hidden_states)
915
+ pooled_hidden_states = self.pool(hidden_states)
916
+ min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
917
+ hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
918
+ hidden_states = hidden_states.transpose(1, 2)
919
+
920
+ hidden_states = self.layer_norm(hidden_states)
921
+ hidden_states = self.dropout(hidden_states)
922
+
923
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
924
+
925
+ for layer in self.layers:
926
+ if output_hidden_states:
927
+ all_hidden_states = all_hidden_states + (hidden_states,)
928
+
929
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
930
+ dropout_probability = torch.rand([])
931
+
932
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
933
+ if not skip_the_layer or synced_gpus:
934
+ # under fsdp or deepspeed zero3 all gpus must run in sync
935
+ if self.gradient_checkpointing and self.training:
936
+ layer_outputs = self._gradient_checkpointing_func(
937
+ layer.__call__,
938
+ hidden_states,
939
+ attention_mask,
940
+ output_attentions,
941
+ )
942
+ else:
943
+ layer_outputs = layer(
944
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
945
+ )
946
+ hidden_states = layer_outputs[0]
947
+
948
+ if skip_the_layer:
949
+ layer_outputs = (None, None)
950
+
951
+ if output_attentions:
952
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
953
+
954
+ if output_hidden_states:
955
+ all_hidden_states = all_hidden_states + (hidden_states,)
956
+
957
+ hidden_states = self.upsample(hidden_states)
958
+ if hidden_states.shape[1] < n_input_timesteps:
959
+ hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
960
+
961
+ if not return_dict:
962
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
963
+ return BaseModelOutput(
964
+ last_hidden_state=hidden_states,
965
+ hidden_states=all_hidden_states,
966
+ attentions=all_self_attentions,
967
+ )
968
+
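A rough sketch of the mask down-sampling that `SEWEncoder.forward` performs above before the transformer layers, using toy lengths and an assumed `squeeze_factor` of 2 (not tied to any real configuration):

# Frame-level mask with 10 valid frames and 6 padding frames, pooled by squeeze_factor = 2.
import torch

squeeze_factor = 2
attention_mask = torch.tensor([[1] * 10 + [0] * 6])
input_lengths = attention_mask.long().sum(-1)                   # tensor([10])
output_lengths = input_lengths // squeeze_factor                # tensor([5])
max_encoder_length = attention_mask.shape[1] // squeeze_factor  # 8 pooled frames
attention_ids = torch.arange(max_encoder_length).view(1, -1).expand(output_lengths.shape[0], -1)
pooled_mask = (attention_ids < output_lengths.view(-1, 1)).long()
print(pooled_mask)  # tensor([[1, 1, 1, 1, 1, 0, 0, 0]])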
969
+
970
+ class SEWPreTrainedModel(PreTrainedModel):
971
+ """
972
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
973
+ models.
974
+ """
975
+
976
+ config_class = SEWConfig
977
+ base_model_prefix = "sew"
978
+ main_input_name = "input_values"
979
+ supports_gradient_checkpointing = True
980
+ _supports_flash_attn_2 = True
981
+ _supports_sdpa = True
982
+
983
+ def _init_weights(self, module):
984
+ """Initialize the weights"""
985
+ if isinstance(module, SEWPositionalConvEmbedding):
986
+ nn.init.normal_(
987
+ module.conv.weight,
988
+ mean=0,
989
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
990
+ )
991
+ nn.init.constant_(module.conv.bias, 0)
992
+ elif isinstance(module, nn.Linear):
993
+ # Slightly different from the TF version which uses truncated_normal for initialization
994
+ # cf https://github.com/pytorch/pytorch/pull/5617
995
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
996
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
997
+ module.bias.data.zero_()
998
+ module.weight.data.fill_(1.0)
999
+ elif isinstance(module, nn.Conv1d):
1000
+ if is_deepspeed_zero3_enabled():
1001
+ import deepspeed
1002
+
1003
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
1004
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
1005
+ nn.init.kaiming_normal_(module.weight.data)
1006
+ else:
1007
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
1008
+ nn.init.kaiming_normal_(module.weight.data)
1009
+ else:
1010
+ nn.init.kaiming_normal_(module.weight.data)
1011
+
1012
+ if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
1013
+ module.bias.data.zero_()
1014
+
1015
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
1016
+ """
1017
+ Computes the output length of the convolutional layers
1018
+ """
1019
+
1020
+ def _conv_out_length(input_length, kernel_size, stride):
1021
+ # 1D convolutional layer output length formula taken
1022
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
1023
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
1024
+
1025
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
1026
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
1027
+
1028
+ return input_lengths
1029
+
1030
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
1031
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1032
+ batch_size = attention_mask.shape[0]
1033
+
1034
+ attention_mask = torch.zeros(
1035
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
1036
+ )
1037
+ # these two operations make sure that all values before the output length indices are attended to
1038
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
1039
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
1040
+ return attention_mask
1041
+
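To make the convolutional length bookkeeping in `_get_feat_extract_output_lengths` above concrete, a minimal sketch of the same `(length - kernel) // stride + 1` recurrence follows; the kernel and stride tuples are placeholders, not the actual `SEWConfig` values:

# Stand-alone version of the conv output-length recurrence, with placeholder kernels/strides.
def conv_output_lengths(input_length, kernels=(10, 3, 3), strides=(5, 2, 2)):
    for kernel_size, stride in zip(kernels, strides):
        input_length = (input_length - kernel_size) // stride + 1
    return input_length

print(conv_output_lengths(16000))  # feature frames produced from a 1 s, 16 kHz waveform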
1042
+
1043
+ SEW_START_DOCSTRING = r"""
1044
+ SEW was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
1045
+ Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
1046
+ Yoav Artzi.
1047
+
1048
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1049
+ library implements for all its models (such as downloading or saving, etc.).
1050
+
1051
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
1052
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
1053
+ behavior.
1054
+
1055
+ Parameters:
1056
+ config ([`SEWConfig`]): Model configuration class with all the parameters of the model.
1057
+ Initializing with a config file does not load the weights associated with the model, only the
1058
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1059
+ """
1060
+
1061
+
1062
+ SEW_INPUTS_DOCSTRING = r"""
1063
+ Args:
1064
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1065
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
1066
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
1067
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
1068
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
1069
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1070
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1071
+ 1]`:
1072
+
1073
+ - 1 for tokens that are **not masked**,
1074
+ - 0 for tokens that are **masked**.
1075
+
1076
+ [What are attention masks?](../glossary#attention-mask)
1077
+
1078
+ output_attentions (`bool`, *optional*):
1079
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1080
+ tensors for more detail.
1081
+ output_hidden_states (`bool`, *optional*):
1082
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1083
+ more detail.
1084
+ return_dict (`bool`, *optional*):
1085
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1086
+ """
1087
+
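As a hedged end-to-end sketch of the input pipeline described in the docstring above, the snippet below runs the checkpoint named in this file's docstring constants on a dummy waveform; it assumes the Hub checkpoint is reachable and that its feature extractor handles padding and normalization as documented:

# Illustrative forward pass through the bare SEW model (requires network access to the Hub).
import torch
from transformers import AutoFeatureExtractor, SEWModel

feature_extractor = AutoFeatureExtractor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
model = SEWModel.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")

waveform = torch.zeros(16000).numpy()  # 1 second of silence at 16 kHz, stand-in for real audio
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, frames, hidden_size)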
1088
+
1089
+ @add_start_docstrings(
1090
+ "The bare SEW Model transformer outputting raw hidden-states without any specific head on top.",
1091
+ SEW_START_DOCSTRING,
1092
+ )
1093
+ class SEWModel(SEWPreTrainedModel):
1094
+ def __init__(self, config: SEWConfig):
1095
+ super().__init__(config)
1096
+ self.config = config
1097
+ self.feature_extractor = SEWFeatureEncoder(config)
1098
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
1099
+
1100
+ self.project_features = config.conv_dim[-1] != config.hidden_size
1101
+ if self.project_features:
1102
+ self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
1103
+ self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
1104
+
1105
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
1106
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
1107
+
1108
+ self.encoder = SEWEncoder(config)
1109
+
1110
+ # Initialize weights and apply final processing
1111
+ self.post_init()
1112
+
1113
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
1114
+ def _mask_hidden_states(
1115
+ self,
1116
+ hidden_states: torch.FloatTensor,
1117
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1118
+ attention_mask: Optional[torch.LongTensor] = None,
1119
+ ):
1120
+ """
1121
+ Masks extracted features along time axis and/or along feature axis according to
1122
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
1123
+ """
1124
+
1125
+ # `config.apply_spec_augment` can set masking to False
1126
+ if not getattr(self.config, "apply_spec_augment", True):
1127
+ return hidden_states
1128
+
1129
+ # generate indices & apply SpecAugment along time axis
1130
+ batch_size, sequence_length, hidden_size = hidden_states.size()
1131
+
1132
+ if mask_time_indices is not None:
1133
+ # apply SpecAugment along time axis with given mask_time_indices
1134
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1135
+ elif self.config.mask_time_prob > 0 and self.training:
1136
+ mask_time_indices = _compute_mask_indices(
1137
+ (batch_size, sequence_length),
1138
+ mask_prob=self.config.mask_time_prob,
1139
+ mask_length=self.config.mask_time_length,
1140
+ attention_mask=attention_mask,
1141
+ min_masks=self.config.mask_time_min_masks,
1142
+ )
1143
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
1144
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1145
+
1146
+ if self.config.mask_feature_prob > 0 and self.training:
1147
+ # generate indices & apply SpecAugment along feature axis
1148
+ mask_feature_indices = _compute_mask_indices(
1149
+ (batch_size, hidden_size),
1150
+ mask_prob=self.config.mask_feature_prob,
1151
+ mask_length=self.config.mask_feature_length,
1152
+ min_masks=self.config.mask_feature_min_masks,
1153
+ )
1154
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
1155
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
1156
+ hidden_states[mask_feature_indices] = 0
1157
+
1158
+ return hidden_states
1159
+
1160
+ @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
1161
+ @add_code_sample_docstrings(
1162
+ checkpoint=_CHECKPOINT_FOR_DOC,
1163
+ output_type=BaseModelOutput,
1164
+ config_class=_CONFIG_FOR_DOC,
1165
+ modality="audio",
1166
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1167
+ )
1168
+ def forward(
1169
+ self,
1170
+ input_values: Optional[torch.Tensor],
1171
+ attention_mask: Optional[torch.Tensor] = None,
1172
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1173
+ output_attentions: Optional[bool] = None,
1174
+ output_hidden_states: Optional[bool] = None,
1175
+ return_dict: Optional[bool] = None,
1176
+ ) -> Union[Tuple, BaseModelOutput]:
1177
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1178
+ output_hidden_states = (
1179
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1180
+ )
1181
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1182
+
1183
+ extract_features = self.feature_extractor(input_values)
1184
+ extract_features = extract_features.transpose(1, 2)
1185
+ extract_features = self.layer_norm(extract_features)
1186
+
1187
+ if self.project_features:
1188
+ extract_features = self.feature_projection(extract_features)
1189
+ hidden_states = self.feature_dropout(extract_features)
1190
+
1191
+ if attention_mask is not None:
1192
+ # compute reduced attention_mask corresponding to feature vectors
1193
+ attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1194
+
1195
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
1196
+
1197
+ encoder_outputs = self.encoder(
1198
+ hidden_states,
1199
+ attention_mask=attention_mask,
1200
+ output_attentions=output_attentions,
1201
+ output_hidden_states=output_hidden_states,
1202
+ return_dict=return_dict,
1203
+ )
1204
+
1205
+ hidden_states = encoder_outputs[0]
1206
+
1207
+ if not return_dict:
1208
+ return (hidden_states,) + encoder_outputs[1:]
1209
+
1210
+ return BaseModelOutput(
1211
+ last_hidden_state=hidden_states,
1212
+ hidden_states=encoder_outputs.hidden_states,
1213
+ attentions=encoder_outputs.attentions,
1214
+ )
1215
+
1216
+
1217
+ @add_start_docstrings(
1218
+ """SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
1219
+ SEW_START_DOCSTRING,
1220
+ )
1221
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
1222
+ class SEWForCTC(SEWPreTrainedModel):
1223
+ def __init__(self, config, target_lang: Optional[str] = None):
1224
+ super().__init__(config)
1225
+
1226
+ self.sew = SEWModel(config)
1227
+ self.dropout = nn.Dropout(config.final_dropout)
1228
+
1229
+ self.target_lang = target_lang
1230
+
1231
+ if config.vocab_size is None:
1232
+ raise ValueError(
1233
+ f"You are trying to instantiate {self.__class__} with a configuration that "
1234
+ "does not define the vocabulary size of the language model head. Please "
1235
+ "instantiate the model as follows: `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
1236
+ "or define `vocab_size` of your model's configuration."
1237
+ )
1238
+ output_hidden_size = (
1239
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
1240
+ )
1241
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
1242
+
1243
+ # Initialize weights and apply final processing
1244
+ self.post_init()
1245
+
1246
+ def tie_weights(self):
1247
+ """
1248
+ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
1249
+ passing `target_lang=...` to `from_pretrained(...)`.
1250
+
1251
+ This method is **not** supposed to be called by the user and is prone to be changed in the future.
1252
+ """
1253
+
1254
+ # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
1255
+ # correctly load adapter layers for SEW so that we do not have to introduce a new API to
1256
+ # [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is
1257
+ # ok to repurpose this function here.
1258
+ target_lang = self.target_lang
1259
+
1260
+ if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
1261
+ raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
1262
+ elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
1263
+ logger.info("By default `target_lang` is set to 'eng'.")
1264
+ elif target_lang is not None:
1265
+ self.load_adapter(target_lang, force_load=True)
1266
+
1267
+ def freeze_feature_extractor(self):
1268
+ """
1269
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1270
+ not be updated during training.
1271
+ """
1272
+ warnings.warn(
1273
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
1274
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
1275
+ FutureWarning,
1276
+ )
1277
+ self.freeze_feature_encoder()
1278
+
1279
+ def freeze_feature_encoder(self):
1280
+ """
1281
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1282
+ not be updated during training.
1283
+ """
1284
+ self.sew.feature_extractor._freeze_parameters()
1285
+
1286
+ def freeze_base_model(self):
1287
+ """
1288
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1289
+ be updated during training. Only the classification head will be updated.
1290
+ """
1291
+ for param in self.sew.parameters():
1292
+ param.requires_grad = False
1293
+
1294
+ @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
1295
+ @add_code_sample_docstrings(
1296
+ checkpoint=_CHECKPOINT_FOR_DOC,
1297
+ output_type=CausalLMOutput,
1298
+ config_class=_CONFIG_FOR_DOC,
1299
+ expected_output=_CTC_EXPECTED_OUTPUT,
1300
+ expected_loss=_CTC_EXPECTED_LOSS,
1301
+ )
1302
+ def forward(
1303
+ self,
1304
+ input_values: Optional[torch.Tensor],
1305
+ attention_mask: Optional[torch.Tensor] = None,
1306
+ output_attentions: Optional[bool] = None,
1307
+ output_hidden_states: Optional[bool] = None,
1308
+ return_dict: Optional[bool] = None,
1309
+ labels: Optional[torch.Tensor] = None,
1310
+ ) -> Union[Tuple, CausalLMOutput]:
1311
+ r"""
1312
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
1313
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
1314
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
1315
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
1316
+ config.vocab_size - 1]`.
1317
+ """
1318
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1319
+
1320
+ if labels is not None and labels.max() >= self.config.vocab_size:
1321
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
1322
+
1323
+ outputs = self.sew(
1324
+ input_values,
1325
+ attention_mask=attention_mask,
1326
+ output_attentions=output_attentions,
1327
+ output_hidden_states=output_hidden_states,
1328
+ return_dict=return_dict,
1329
+ )
1330
+
1331
+ hidden_states = outputs[0]
1332
+ hidden_states = self.dropout(hidden_states)
1333
+
1334
+ logits = self.lm_head(hidden_states)
1335
+
1336
+ loss = None
1337
+ if labels is not None:
1338
+ # retrieve loss input_lengths from attention_mask
1339
+ attention_mask = (
1340
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
1341
+ )
1342
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1343
+
1344
+ # assuming that padded tokens are filled with -100
1345
+ # when not being attended to
1346
+ labels_mask = labels >= 0
1347
+ target_lengths = labels_mask.sum(-1)
1348
+ flattened_targets = labels.masked_select(labels_mask)
1349
+
1350
+ # ctc_loss doesn't support fp16
1351
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
1352
+
1353
+ with torch.backends.cudnn.flags(enabled=False):
1354
+ loss = nn.functional.ctc_loss(
1355
+ log_probs,
1356
+ flattened_targets,
1357
+ input_lengths,
1358
+ target_lengths,
1359
+ blank=self.config.pad_token_id,
1360
+ reduction=self.config.ctc_loss_reduction,
1361
+ zero_infinity=self.config.ctc_zero_infinity,
1362
+ )
1363
+
1364
+ if not return_dict:
1365
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1366
+ return ((loss,) + output) if loss is not None else output
1367
+
1368
+ return CausalLMOutput(
1369
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1370
+ )
1371
+
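A minimal transcription sketch with `SEWForCTC` and greedy CTC decoding; the checkpoint id is an assumption and the silent input is only a stand-in:

```python
import torch
from transformers import AutoProcessor, SEWForCTC

ckpt = "asapp/sew-tiny-100k-ft-ls100h"  # assumed fine-tuned checkpoint id
processor = AutoProcessor.from_pretrained(ckpt)
model = SEWForCTC.from_pretrained(ckpt)

inputs = processor([0.0] * 16000, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# greedy CTC decoding: argmax per frame, then collapse repeats and blanks
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print(transcription)
```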
1372
+
1373
+ @add_start_docstrings(
1374
+ """
1375
+ SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
1376
+ Keyword Spotting.
1377
+ """,
1378
+ SEW_START_DOCSTRING,
1379
+ )
1380
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
1381
+ class SEWForSequenceClassification(SEWPreTrainedModel):
1382
+ def __init__(self, config):
1383
+ super().__init__(config)
1384
+
1385
+ if hasattr(config, "add_adapter") and config.add_adapter:
1386
+ raise ValueError(
1387
+ "Sequence classification does not support the use of SEW adapters (config.add_adapter=True)"
1388
+ )
1389
+ self.sew = SEWModel(config)
1390
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1391
+ if config.use_weighted_layer_sum:
1392
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1393
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
1394
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
1395
+
1396
+ # Initialize weights and apply final processing
1397
+ self.post_init()
1398
+
1399
+ def freeze_feature_extractor(self):
1400
+ """
1401
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1402
+ not be updated during training.
1403
+ """
1404
+ warnings.warn(
1405
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
1406
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
1407
+ FutureWarning,
1408
+ )
1409
+ self.freeze_feature_encoder()
1410
+
1411
+ def freeze_feature_encoder(self):
1412
+ """
1413
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1414
+ not be updated during training.
1415
+ """
1416
+ self.sew.feature_extractor._freeze_parameters()
1417
+
1418
+ def freeze_base_model(self):
1419
+ """
1420
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1421
+ be updated during training. Only the classification head will be updated.
1422
+ """
1423
+ for param in self.sew.parameters():
1424
+ param.requires_grad = False
1425
+
1426
+ @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
1427
+ @add_code_sample_docstrings(
1428
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
1429
+ output_type=SequenceClassifierOutput,
1430
+ config_class=_CONFIG_FOR_DOC,
1431
+ modality="audio",
1432
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1433
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1434
+ )
1435
+ def forward(
1436
+ self,
1437
+ input_values: Optional[torch.Tensor],
1438
+ attention_mask: Optional[torch.Tensor] = None,
1439
+ output_attentions: Optional[bool] = None,
1440
+ output_hidden_states: Optional[bool] = None,
1441
+ return_dict: Optional[bool] = None,
1442
+ labels: Optional[torch.Tensor] = None,
1443
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1444
+ r"""
1445
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1446
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1447
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1448
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1449
+ """
1450
+
1451
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1452
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1453
+
1454
+ outputs = self.sew(
1455
+ input_values,
1456
+ attention_mask=attention_mask,
1457
+ output_attentions=output_attentions,
1458
+ output_hidden_states=output_hidden_states,
1459
+ return_dict=return_dict,
1460
+ )
1461
+
1462
+ if self.config.use_weighted_layer_sum:
1463
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1464
+ hidden_states = torch.stack(hidden_states, dim=1)
1465
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1466
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1467
+ else:
1468
+ hidden_states = outputs[0]
1469
+
1470
+ hidden_states = self.projector(hidden_states)
1471
+ if attention_mask is None:
1472
+ pooled_output = hidden_states.mean(dim=1)
1473
+ else:
1474
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1475
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
1476
+ hidden_states[~expand_padding_mask] = 0.0
1477
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
1478
+
1479
+ logits = self.classifier(pooled_output)
1480
+
1481
+ loss = None
1482
+ if labels is not None:
1483
+ loss_fct = CrossEntropyLoss()
1484
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1485
+
1486
+ if not return_dict:
1487
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1488
+ return ((loss,) + output) if loss is not None else output
1489
+
1490
+ return SequenceClassifierOutput(
1491
+ loss=loss,
1492
+ logits=logits,
1493
+ hidden_states=outputs.hidden_states,
1494
+ attentions=outputs.attentions,
1495
+ )
1496
+
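A minimal audio-classification sketch with `SEWForSequenceClassification`; the checkpoint path is a placeholder, not a real Hub id:

```python
import torch
from transformers import AutoFeatureExtractor, SEWForSequenceClassification

# placeholder path; substitute a SEW checkpoint fine-tuned for audio classification
ckpt = "path/to/sew-audio-classification-checkpoint"
feature_extractor = AutoFeatureExtractor.from_pretrained(ckpt)
model = SEWForSequenceClassification.from_pretrained(ckpt)

inputs = feature_extractor([0.0] * 16000, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_label = model.config.id2label[int(logits.argmax(dim=-1))]
print(predicted_label)
```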
1497
+
1498
+ __all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]
docs/transformers/build/lib/transformers/models/sew_d/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_sew_d import *
22
+ from .modeling_sew_d import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
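A quick check of what the lazy module above provides: the heavy submodules are only imported when a public symbol is first accessed, but names resolve as usual.

```python
import transformers.models.sew_d as sew_d

# accessing the attribute triggers the import of configuration_sew_d
config = sew_d.SEWDConfig()
print(config.model_type)  # "sew-d"
```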
docs/transformers/build/lib/transformers/models/sew_d/configuration_sew_d.py ADDED
@@ -0,0 +1,291 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SEW-D model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class SEWDConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`SEWDModel`]. It is used to instantiate a SEW-D
30
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the SEW-D
32
+ [asapp/sew-d-tiny-100k](https://huggingface.co/asapp/sew-d-tiny-100k) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32):
40
+ Vocabulary size of the SEW-D model. Defines the number of different tokens that can be represented by the
41
+ `input_ids` passed when calling [`SEWDModel`].
42
+ hidden_size (`int`, *optional*, defaults to 768):
43
+ Dimensionality of the encoder layers and the pooler layer.
44
+ num_hidden_layers (`int`, *optional*, defaults to 12):
45
+ Number of hidden layers in the Transformer encoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 12):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ intermediate_size (`int`, *optional*, defaults to 3072):
49
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
50
+ squeeze_factor (`int`, *optional*, defaults to 2):
51
+ Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
52
+ max_position_embeddings (`int`, *optional*, defaults to 512):
53
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
54
+ just in case (e.g., 512 or 1024 or 2048).
55
+ position_buckets (`int`, *optional*, defaults to 256):
56
+ The maximum size of relative position embeddings.
57
+ share_att_key (`bool`, *optional*, defaults to `True`):
58
+ Whether to share attention key with c2p and p2c.
59
+ relative_attention (`bool`, *optional*, defaults to `True`):
60
+ Whether to use relative position encoding.
61
+ pos_att_type (`Tuple[str]`, *optional*, defaults to `("p2c", "c2p")`):
62
+ The type of relative position attention; it can be a combination of `"p2c"` and `"c2p"`, e.g. `("p2c",)` or
+ `("p2c", "c2p")`.
64
+ norm_rel_ebd (`str`, *optional*, defaults to `"layer_norm"`):
65
+ Whether to use layer norm in relative embedding (`"layer_norm"` if yes)
66
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
67
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
68
+ `"relu"`, `"selu"`, `"gelu_python"` and `"gelu_new"` are supported.
69
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
70
+ Deprecated. Not used by the model and will be removed in a future version.
71
+ activation_dropout (`float`, *optional*, defaults to 0.1):
72
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
73
+ attention_dropout (`float`, *optional*, defaults to 0.1):
74
+ The dropout ratio for the attention probabilities.
75
+ final_dropout (`float`, *optional*, defaults to 0.1):
76
+ The dropout probability for the final projection layer of [`SEWDForCTC`].
77
+ initializer_range (`float`, *optional*, defaults to 0.02):
78
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
79
+ layer_norm_eps (`float`, *optional*, defaults to 1e-7):
80
+ The epsilon used by the layer normalization layers in the transformer encoder.
81
+ feature_layer_norm_eps (`float`, *optional*, defaults to 1e-5):
82
+ The epsilon used by the layer normalization after the feature encoder.
83
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
84
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
85
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
86
+ convolutional layers.
87
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
88
+ The dropout probability for output of the feature encoder.
89
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
90
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
91
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
92
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
93
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
94
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
95
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
96
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
97
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
98
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
99
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
100
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
101
+ *conv_dim*.
102
+ conv_bias (`bool`, *optional*, defaults to `False`):
103
+ Whether the 1D convolutional layers have a bias.
104
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
105
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
106
+ embeddings layer.
107
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
108
+ Number of groups of 1D convolutional positional embeddings layer.
109
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
110
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
111
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
112
+ Recognition](https://arxiv.org/abs/1904.08779).
113
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
114
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
115
+ procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
116
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
117
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
118
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
119
+ mask_time_length (`int`, *optional*, defaults to 10):
120
+ Length of vector span along the time axis.
121
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+ irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks`.
125
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
126
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
127
+ masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
129
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
130
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
131
+ True`.
132
+ mask_feature_length (`int`, *optional*, defaults to 10):
133
+ Length of vector span along the feature axis.
134
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespective of `mask_feature_prob`. Only relevant if
+ `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
138
+ diversity_loss_weight (`float`, *optional*, defaults to 0.1):
139
+ The weight of the codebook diversity loss component.
140
+ ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
141
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
142
+ instance of [`SEWDForCTC`].
143
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
144
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
145
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
146
+ of [`SEWDForCTC`].
147
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
148
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
149
+ instance of [`SEWDForSequenceClassification`].
150
+ classifier_proj_size (`int`, *optional*, defaults to 256):
151
+ Dimensionality of the projection before token mean-pooling for classification.
152
+
153
+ Example:
154
+
155
+ ```python
156
+ >>> from transformers import SEWDConfig, SEWDModel
157
+
158
+ >>> # Initializing a SEW-D asapp/sew-d-tiny-100k style configuration
159
+ >>> configuration = SEWDConfig()
160
+
161
+ >>> # Initializing a model (with random weights) from the asapp/sew-d-tiny-100k style configuration
162
+ >>> model = SEWDModel(configuration)
163
+
164
+ >>> # Accessing the model configuration
165
+ >>> configuration = model.config
166
+ ```"""
167
+
168
+ model_type = "sew-d"
169
+
170
+ def __init__(
171
+ self,
172
+ vocab_size=32,
173
+ hidden_size=768,
174
+ num_hidden_layers=12,
175
+ num_attention_heads=12,
176
+ intermediate_size=3072,
177
+ squeeze_factor=2,
178
+ max_position_embeddings=512,
179
+ position_buckets=256,
180
+ share_att_key=True,
181
+ relative_attention=True,
182
+ pos_att_type=("p2c", "c2p"),
183
+ norm_rel_ebd="layer_norm",
184
+ hidden_act="gelu_python",
185
+ hidden_dropout=0.1,
186
+ activation_dropout=0.1,
187
+ attention_dropout=0.1,
188
+ feat_proj_dropout=0.0,
189
+ final_dropout=0.1,
190
+ initializer_range=0.02,
191
+ layer_norm_eps=1e-7,
192
+ feature_layer_norm_eps=1e-5,
193
+ feat_extract_norm="group",
194
+ feat_extract_activation="gelu",
195
+ conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
196
+ conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1),
197
+ conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1),
198
+ conv_bias=False,
199
+ num_conv_pos_embeddings=128,
200
+ num_conv_pos_embedding_groups=16,
201
+ apply_spec_augment=True,
202
+ mask_time_prob=0.05,
203
+ mask_time_length=10,
204
+ mask_time_min_masks=2,
205
+ mask_feature_prob=0.0,
206
+ mask_feature_length=10,
207
+ mask_feature_min_masks=0,
208
+ ctc_loss_reduction="mean",
209
+ ctc_zero_infinity=False,
210
+ use_weighted_layer_sum=False,
211
+ classifier_proj_size=256,
212
+ pad_token_id=0,
213
+ bos_token_id=1,
214
+ eos_token_id=2,
215
+ **kwargs,
216
+ ):
217
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
218
+ self.hidden_size = hidden_size
219
+ self.feat_extract_norm = feat_extract_norm
220
+ self.feat_extract_activation = feat_extract_activation
221
+ self.conv_dim = list(conv_dim)
222
+ self.conv_stride = list(conv_stride)
223
+ self.conv_kernel = list(conv_kernel)
224
+ self.conv_bias = conv_bias
225
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
226
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
227
+ self.num_feat_extract_layers = len(self.conv_dim)
228
+ self.num_hidden_layers = num_hidden_layers
229
+ self.intermediate_size = intermediate_size
230
+ self.squeeze_factor = squeeze_factor
231
+ self.max_position_embeddings = max_position_embeddings
232
+ self.position_buckets = position_buckets
233
+ self.share_att_key = share_att_key
234
+ self.relative_attention = relative_attention
235
+ self.norm_rel_ebd = norm_rel_ebd
236
+ self.pos_att_type = list(pos_att_type)
237
+ self.hidden_act = hidden_act
238
+ self.num_attention_heads = num_attention_heads
239
+ self._hidden_dropout = hidden_dropout
240
+ self.attention_dropout = attention_dropout
241
+ self.activation_dropout = activation_dropout
242
+ self.feat_proj_dropout = feat_proj_dropout
243
+ self.final_dropout = final_dropout
244
+ self.layer_norm_eps = layer_norm_eps
245
+ self.feature_layer_norm_eps = feature_layer_norm_eps
246
+ self.initializer_range = initializer_range
247
+ self.vocab_size = vocab_size
248
+
249
+ if (
250
+ (len(self.conv_stride) != self.num_feat_extract_layers)
251
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
252
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
253
+ ):
254
+ raise ValueError(
255
+ "Configuration for convolutional layers is incorrect. "
256
+ "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
257
+ f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
258
+ f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
259
+ )
260
+
261
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
262
+ self.apply_spec_augment = apply_spec_augment
263
+ self.mask_time_prob = mask_time_prob
264
+ self.mask_time_length = mask_time_length
265
+ self.mask_time_min_masks = mask_time_min_masks
266
+ self.mask_feature_prob = mask_feature_prob
267
+ self.mask_feature_length = mask_feature_length
268
+ self.mask_feature_min_masks = mask_feature_min_masks
269
+
270
+ # ctc loss
271
+ self.ctc_loss_reduction = ctc_loss_reduction
272
+ self.ctc_zero_infinity = ctc_zero_infinity
273
+
274
+ # sequence classification
275
+ self.use_weighted_layer_sum = use_weighted_layer_sum
276
+ self.classifier_proj_size = classifier_proj_size
277
+
278
+ @property
279
+ def inputs_to_logits_ratio(self):
280
+ return functools.reduce(operator.mul, self.conv_stride, 1)
281
+
282
+ def to_dict(self):
283
+ """
284
+ Serializes this instance to a Python dictionary.
285
+ """
286
+ output = super().to_dict()
287
+ output["hidden_dropout"] = output.pop("_hidden_dropout")
288
+ return output
289
+
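A small sketch of the `to_dict` behaviour above: the internally stored `_hidden_dropout` is exposed again under its public name when the config is serialized.

```python
from transformers import SEWDConfig

config = SEWDConfig(hidden_dropout=0.2)
serialized = config.to_dict()
# `to_dict` renames the internal `_hidden_dropout` back to `hidden_dropout`
assert "_hidden_dropout" not in serialized
assert serialized["hidden_dropout"] == 0.2
```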
290
+
291
+ __all__ = ["SEWDConfig"]
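The `inputs_to_logits_ratio` property defined above is just the product of the convolutional strides; with the defaults that works out to 320 input samples per encoder frame (20 ms at 16 kHz):

```python
from transformers import SEWDConfig

config = SEWDConfig()
# (5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1) multiplies out to 5 * 2**6 = 320
assert config.inputs_to_logits_ratio == 320
```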
docs/transformers/build/lib/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,317 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SEW-D checkpoint."""
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+
21
+ import fairseq
22
+ import torch
23
+ from fairseq.data import Dictionary
24
+
25
+ # Register SEW's fairseq modules
26
+ from sew_asapp import tasks # noqa: F401
27
+
28
+ from transformers import (
29
+ SEWDConfig,
30
+ SEWDForCTC,
31
+ SEWDModel,
32
+ Wav2Vec2CTCTokenizer,
33
+ Wav2Vec2FeatureExtractor,
34
+ Wav2Vec2Processor,
35
+ logging,
36
+ )
37
+
38
+
39
+ logging.set_verbosity_info()
40
+ logger = logging.get_logger(__name__)
41
+
42
+ MAPPING = {
43
+ "post_extract_proj": "feature_projection",
44
+ "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
45
+ "attention.self.query_proj": "encoder.encoder.layer.*.attention.self.query_proj",
46
+ "attention.self.key_proj": "encoder.encoder.layer.*.attention.self.key_proj",
47
+ "attention.self.value_proj": "encoder.encoder.layer.*.attention.self.value_proj",
48
+ "attention.output.dense": "encoder.encoder.layer.*.attention.output.dense",
49
+ "attention.output.LayerNorm": "encoder.encoder.layer.*.attention.output.LayerNorm",
50
+ "intermediate.dense": "encoder.encoder.layer.*.intermediate.dense",
51
+ "output.dense": "encoder.encoder.layer.*.output.dense",
52
+ "output.LayerNorm": "encoder.encoder.layer.*.output.LayerNorm",
53
+ "encoder.encoder.rel_embeddings": "encoder.encoder.rel_embeddings",
54
+ "encoder.encoder.LayerNorm": "encoder.encoder.LayerNorm",
55
+ "encoder.upsample.0": "encoder.upsample.projection",
56
+ "encoder.layer_norm": "encoder.layer_norm",
57
+ "w2v_model.layer_norm": "layer_norm",
58
+ "w2v_encoder.proj": "lm_head",
59
+ "mask_emb": "masked_spec_embed",
60
+ }
61
+
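For illustration, this is how `recursively_load_weights` below turns a fairseq parameter name into a Transformers one using this mapping; the fairseq key is a made-up example:

```python
# excerpt of the mapping above, plus a hypothetical fairseq parameter name
mapping_entry = {"attention.self.query_proj": "encoder.encoder.layer.*.attention.self.query_proj"}
name = "encoder.layers.3.attention.self.query_proj.weight"
key = "attention.self.query_proj"

mapped_key = "sew_d." + mapping_entry[key]       # fine-tuned models get the "sew_d." prefix
layer_index = name.split(key)[0].split(".")[-2]  # -> "3"
mapped_key = mapped_key.replace("*", layer_index)
print(mapped_key)  # sew_d.encoder.encoder.layer.3.attention.self.query_proj
```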
62
+
63
+ def set_recursively(hf_pointer, key, value, full_name, weight_type):
64
+ for attribute in key.split("."):
65
+ hf_pointer = getattr(hf_pointer, attribute)
66
+
67
+ if weight_type is not None:
68
+ hf_shape = getattr(hf_pointer, weight_type).shape
69
+ else:
70
+ hf_shape = hf_pointer.shape
71
+
72
+ assert hf_shape == value.shape, (
73
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
74
+ f" {value.shape} for {full_name}"
75
+ )
76
+
77
+ if weight_type == "weight":
78
+ hf_pointer.weight.data = value
79
+ elif weight_type == "weight_g":
80
+ hf_pointer.weight_g.data = value
81
+ elif weight_type == "weight_v":
82
+ hf_pointer.weight_v.data = value
83
+ elif weight_type == "bias":
84
+ hf_pointer.bias.data = value
85
+ else:
86
+ hf_pointer.data = value
87
+
88
+ logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
89
+
90
+
91
+ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
92
+ unused_weights = []
93
+ fairseq_dict = fairseq_model.state_dict()
94
+
95
+ feature_extractor = hf_model.sew_d.feature_extractor if is_finetuned else hf_model.feature_extractor
96
+
97
+ for name, value in fairseq_dict.items():
98
+ is_used = False
99
+ if "conv_layers" in name:
100
+ load_conv_layer(
101
+ name,
102
+ value,
103
+ feature_extractor,
104
+ unused_weights,
105
+ hf_model.config.feat_extract_norm == "group",
106
+ )
107
+ is_used = True
108
+ else:
109
+ for key, mapped_key in MAPPING.items():
110
+ mapped_key = "sew_d." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
111
+
112
+ if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
113
+ is_used = True
114
+ if "*" in mapped_key:
115
+ layer_index = name.split(key)[0].split(".")[-2]
116
+ if not layer_index.isnumeric():
117
+ continue
118
+ mapped_key = mapped_key.replace("*", layer_index)
119
+ if "weight_g" in name:
120
+ weight_type = "weight_g"
121
+ elif "weight_v" in name:
122
+ weight_type = "weight_v"
123
+ elif "weight" in name:
124
+ weight_type = "weight"
125
+ elif "bias" in name:
126
+ weight_type = "bias"
127
+ else:
128
+ weight_type = None
129
+ set_recursively(hf_model, mapped_key, value, name, weight_type)
130
+ continue
131
+ if not is_used:
132
+ unused_weights.append(name)
133
+
134
+ logger.warning(f"Unused weights: {unused_weights}")
135
+
136
+
137
+ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
138
+ name = full_name.split("conv_layers.")[-1]
139
+ items = name.split(".")
140
+ layer_id = int(items[0])
141
+ type_id = int(items[1])
142
+
143
+ if type_id == 0:
144
+ if "bias" in name:
145
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
146
+ f"{full_name} has size {value.shape}, but"
147
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
148
+ )
149
+ feature_extractor.conv_layers[layer_id].conv.bias.data = value
150
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
151
+ elif "weight" in name:
152
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
153
+ f"{full_name} has size {value.shape}, but"
154
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
155
+ )
156
+ feature_extractor.conv_layers[layer_id].conv.weight.data = value
157
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
158
+ elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
159
+ if "bias" in name:
160
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
161
+ f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was"
162
+ " found."
163
+ )
164
+ feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
165
+ logger.info(f"Feat extract layer norm bias of layer {layer_id} was initialized from {full_name}.")
166
+ elif "weight" in name:
167
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
168
+ f"{full_name} has size {value.shape}, but"
169
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
170
+ )
171
+ feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
172
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
173
+ else:
174
+ unused_weights.append(full_name)
175
+
176
+
177
+ def convert_config(model, is_finetuned):
178
+ config = SEWDConfig()
179
+ if is_finetuned:
180
+ fs_config = model.w2v_encoder.w2v_model.cfg
181
+ else:
182
+ fs_config = model.cfg
183
+
184
+ config.conv_bias = fs_config.conv_bias
185
+ conv_layers = eval(fs_config.conv_feature_layers)
186
+ config.conv_dim = [x[0] for x in conv_layers]
187
+ config.conv_kernel = [x[1] for x in conv_layers]
188
+ config.conv_stride = [x[2] for x in conv_layers]
189
+ config.feat_extract_activation = "gelu"
190
+ config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
191
+ config.final_dropout = 0.0
192
+ config.hidden_act = fs_config.activation_fn.name
193
+ config.hidden_size = fs_config.encoder_embed_dim
194
+ config.initializer_range = 0.02
195
+ config.intermediate_size = fs_config.encoder_ffn_embed_dim
196
+ config.layer_norm_eps = 1e-5
197
+ config.layerdrop = fs_config.encoder_layerdrop
198
+ config.num_attention_heads = fs_config.encoder_attention_heads
199
+ config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
200
+ config.num_conv_pos_embeddings = fs_config.conv_pos
201
+ config.num_feat_extract_layers = len(conv_layers)
202
+ config.num_hidden_layers = fs_config.encoder_layers
203
+ config.squeeze_factor = fs_config.squeeze_factor
204
+ # DeBERTa-specific parameters:
205
+ config.max_position_embeddings = fs_config.max_position_embeddings
206
+ config.position_buckets = fs_config.position_buckets
207
+ config.share_att_key = fs_config.share_att_key
208
+ config.relative_attention = fs_config.relative_attention
209
+ config.position_biased_input = fs_config.position_biased_input
210
+ config.pos_att_type = tuple(fs_config.pos_att_type.split("|"))
211
+ config.norm_rel_ebd = fs_config.norm_rel_ebd
212
+
213
+ # take care of any params that are overridden by the Wav2VecCtc model
214
+ if is_finetuned:
215
+ fs_config = model.cfg
216
+ config.final_dropout = fs_config.final_dropout
217
+ config.layerdrop = fs_config.layerdrop
218
+ config.activation_dropout = fs_config.activation_dropout
219
+ config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
220
+ config.attention_dropout = fs_config.attention_dropout
221
+ config.feat_proj_dropout = fs_config.dropout_input
222
+ config.hidden_dropout = fs_config.dropout
223
+ config.mask_feature_length = fs_config.mask_channel_length
224
+ config.mask_feature_prob = fs_config.mask_channel_prob
225
+ config.mask_time_length = fs_config.mask_length
226
+ config.mask_time_prob = fs_config.mask_prob
227
+
228
+ config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
229
+ config.tokenizer_class = "Wav2Vec2CTCTokenizer"
230
+
231
+ return config
232
+
233
+
234
+ @torch.no_grad()
235
+ def convert_sew_checkpoint(
236
+ checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
237
+ ):
238
+ """
239
+ Copy/paste/tweak model's weights to transformers design.
240
+ """
241
+
242
+ if is_finetuned:
243
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
244
+ [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
245
+ )
246
+ else:
247
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
248
+
249
+ if config_path is not None:
250
+ config = SEWDConfig.from_pretrained(config_path)
251
+ else:
252
+ config = convert_config(model[0], is_finetuned)
253
+ model = model[0].eval()
254
+
255
+ return_attention_mask = True if config.feat_extract_norm == "layer" else False
256
+ feature_extractor = Wav2Vec2FeatureExtractor(
257
+ feature_size=1,
258
+ sampling_rate=16000,
259
+ padding_value=0,
260
+ do_normalize=True,
261
+ return_attention_mask=return_attention_mask,
262
+ )
263
+
264
+ if is_finetuned:
265
+ if dict_path:
266
+ target_dict = Dictionary.load(dict_path)
267
+
268
+ # important change bos & pad token id since CTC symbol is <pad> and
269
+ # not <s> as in fairseq
270
+ target_dict.indices[target_dict.bos_word] = target_dict.pad_index
271
+ target_dict.indices[target_dict.pad_word] = target_dict.bos_index
272
+ config.bos_token_id = target_dict.pad_index
273
+ config.pad_token_id = target_dict.bos_index
274
+ config.eos_token_id = target_dict.eos_index
275
+ config.vocab_size = len(target_dict.symbols)
276
+ vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
277
+ if not os.path.isdir(pytorch_dump_folder_path):
278
+ logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
279
+ return
280
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
281
+ with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
282
+ json.dump(target_dict.indices, vocab_handle)
283
+ tokenizer = Wav2Vec2CTCTokenizer(
284
+ vocab_path,
285
+ unk_token=target_dict.unk_word,
286
+ pad_token=target_dict.pad_word,
287
+ bos_token=target_dict.bos_word,
288
+ eos_token=target_dict.eos_word,
289
+ word_delimiter_token="|",
290
+ do_lower_case=False,
291
+ )
292
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
293
+ processor.save_pretrained(pytorch_dump_folder_path)
294
+
295
+ hf_model = SEWDForCTC(config)
296
+ else:
297
+ hf_model = SEWDModel(config)
298
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
299
+
300
+ recursively_load_weights(model, hf_model, is_finetuned)
301
+
302
+ hf_model.save_pretrained(pytorch_dump_folder_path)
303
+
304
+
305
+ if __name__ == "__main__":
306
+ parser = argparse.ArgumentParser()
307
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
308
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
309
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
310
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
311
+ parser.add_argument(
312
+ "--is_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
313
+ )
314
+ args = parser.parse_args()
315
+ convert_sew_checkpoint(
316
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, args.is_finetuned
317
+ )
docs/transformers/build/lib/transformers/models/sew_d/modeling_sew_d.py ADDED
@@ -0,0 +1,1748 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch SEW-D model."""
16
+
17
+ import math
18
+ import warnings
19
+ from collections.abc import Sequence
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss, LayerNorm
27
+
28
+ from ...activations import ACT2FN
29
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
30
+ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...pytorch_utils import softmax_backward_data
33
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
34
+ from .configuration_sew_d import SEWDConfig
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ _HIDDEN_STATES_START_POSITION = 1
40
+
41
+
42
+ # General docstring
43
+ _CONFIG_FOR_DOC = "SEWDConfig"
44
+
45
+ # Base docstring
46
+ _CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
47
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 384]
48
+
49
+ # CTC docstring
50
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
51
+ _CTC_EXPECTED_LOSS = 0.21
52
+
53
+ # Audio class docstring
54
+ _SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
55
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
56
+ _SEQ_CLASS_EXPECTED_LOSS = 3.16
57
+
58
+
59
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
60
+ def _compute_mask_indices(
61
+ shape: Tuple[int, int],
62
+ mask_prob: float,
63
+ mask_length: int,
64
+ attention_mask: Optional[torch.LongTensor] = None,
65
+ min_masks: int = 0,
66
+ ) -> np.ndarray:
67
+ """
68
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
69
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
70
+ CPU as part of the preprocessing during training.
71
+
72
+ Args:
73
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
74
+ the first element is the batch size and the second element is the length of the axis to span.
75
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
76
+ independently generated mask spans of length `mask_length` is computed by
77
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
78
+ actual percentage will be smaller.
79
+ mask_length: size of the mask
80
+ min_masks: minimum number of masked spans
81
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
82
+ each batch dimension.
83
+ """
84
+ batch_size, sequence_length = shape
85
+
86
+ if mask_length < 1:
87
+ raise ValueError("`mask_length` has to be bigger than 0.")
88
+
89
+ if mask_length > sequence_length:
90
+ raise ValueError(
91
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
92
+ f" and `sequence_length`: {sequence_length}`"
93
+ )
94
+
95
+ # epsilon is used for probabilistic rounding
96
+ epsilon = np.random.rand(1).item()
97
+
98
+ def compute_num_masked_span(input_length):
99
+ """Given input length, compute how many spans should be masked"""
100
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
101
+ num_masked_span = max(num_masked_span, min_masks)
102
+
103
+ # make sure num masked span <= sequence_length
104
+ if num_masked_span * mask_length > sequence_length:
105
+ num_masked_span = sequence_length // mask_length
106
+
107
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
108
+ if input_length - (mask_length - 1) < num_masked_span:
109
+ num_masked_span = max(input_length - (mask_length - 1), 0)
110
+
111
+ return num_masked_span
112
+
113
+ # compute number of masked spans in batch
114
+ input_lengths = (
115
+ attention_mask.detach().sum(-1).tolist()
116
+ if attention_mask is not None
117
+ else [sequence_length for _ in range(batch_size)]
118
+ )
119
+
120
+ # SpecAugment mask to fill
121
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
122
+ spec_aug_mask_idxs = []
123
+
124
+ max_num_masked_span = compute_num_masked_span(sequence_length)
125
+
126
+ if max_num_masked_span == 0:
127
+ return spec_aug_mask
128
+
129
+ for input_length in input_lengths:
130
+ # compute num of masked spans for this input
131
+ num_masked_span = compute_num_masked_span(input_length)
132
+
133
+ # get random indices to mask
134
+ spec_aug_mask_idx = np.random.choice(
135
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
136
+ )
137
+
138
+ # pick first sampled index that will serve as a dummy index to pad vector
139
+ # to ensure same dimension for all batches due to probabilistic rounding
140
+ # Picking first sample just pads those vectors twice.
141
+ if len(spec_aug_mask_idx) == 0:
142
+ # this case can only happen if `input_length` is strictly smaller than
143
+ # `sequence_length` in which case the last token has to be a padding
144
+ # token which we can use as a dummy mask id
145
+ dummy_mask_idx = sequence_length - 1
146
+ else:
147
+ dummy_mask_idx = spec_aug_mask_idx[0]
148
+
149
+ spec_aug_mask_idx = np.concatenate(
150
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
151
+ )
152
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
153
+
154
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
155
+
156
+ # expand masked indices to masked spans
157
+ spec_aug_mask_idxs = np.broadcast_to(
158
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
159
+ )
160
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
161
+
162
+ # add offset to the starting indexes so that indexes now create a span
163
+ offsets = np.arange(mask_length)[None, None, :]
164
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
165
+ batch_size, max_num_masked_span * mask_length
166
+ )
167
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
168
+
169
+ # ensure that we cannot have indices larger than sequence_length
170
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
171
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
172
+
173
+ # scatter indices to mask
174
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
175
+
176
+ return spec_aug_mask
177
+
178
+
179
+ def make_log_bucket_position(relative_pos, bucket_size, max_position):
180
+ sign = torch.sign(relative_pos)
181
+ mid = bucket_size // 2
182
+ abs_pos = torch.where(
183
+ (relative_pos < mid) & (relative_pos > -mid),
184
+ torch.tensor(mid - 1).type_as(relative_pos),
185
+ torch.abs(relative_pos),
186
+ )
187
+ log_pos = (
188
+ torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
189
+ )
190
+ bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
191
+ return bucket_pos
192
+
193
+
194
+ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
195
+ """
196
+ Build relative position according to the query and key
197
+
198
+ We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
199
+ \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
200
+ P_k\\)
201
+
202
+ Args:
203
+ query_size (int): the length of query
204
+ key_size (int): the length of key
205
+ bucket_size (int): the size of position bucket
206
+ max_position (int): the maximum allowed absolute position
207
+ device (`torch.device`): the device on which tensors will be created.
208
+
209
+ Return:
210
+ `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
211
+ """
212
+
213
+ q_ids = torch.arange(0, query_size, device=device)
214
+ k_ids = torch.arange(0, key_size, device=device)
215
+ rel_pos_ids = q_ids[:, None] - k_ids[None, :]
216
+ if bucket_size > 0 and max_position > 0:
217
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
218
+ rel_pos_ids = rel_pos_ids.to(torch.long)
219
+ rel_pos_ids = rel_pos_ids[:query_size, :]
220
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
221
+ return rel_pos_ids
222
+
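A small check of the relative-position helper above with bucketing disabled (`bucket_size` and `max_position` left at -1), so each entry is simply query index minus key index:

```python
from transformers.models.sew_d.modeling_sew_d import build_relative_position

rel_pos = build_relative_position(query_size=4, key_size=4, bucket_size=-1, max_position=-1)
print(rel_pos.shape)  # torch.Size([1, 4, 4])
print(rel_pos[0])
# tensor([[ 0, -1, -2, -3],
#         [ 1,  0, -1, -2],
#         [ 2,  1,  0, -1],
#         [ 3,  2,  1,  0]])
```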
223
+
224
+ @torch.jit.script
225
+ # Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
226
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
227
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
228
+
229
+
230
+ @torch.jit.script
231
+ # Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
232
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
233
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
234
+
235
+
236
+ @torch.jit.script
237
+ # Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
238
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
239
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
240
+
241
+
242
+ def get_mask(input, local_context):
243
+ if not isinstance(local_context, DropoutContext):
244
+ dropout = local_context
245
+ mask = None
246
+ else:
247
+ dropout = local_context.dropout
248
+ dropout *= local_context.scale
249
+ mask = local_context.mask if local_context.reuse_mask else None
250
+
251
+ if dropout > 0 and mask is None:
252
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
253
+
254
+ if isinstance(local_context, DropoutContext):
255
+ if local_context.mask is None:
256
+ local_context.mask = mask
257
+
258
+ return mask, dropout
259
+
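+ # Illustrative sketch (hypothetical helper, not from the original module): when
+ # get_mask is given a bare dropout probability instead of a DropoutContext, it
+ # samples a fresh Bernoulli drop mask and returns it with the probability.
+ def _example_get_mask():
+     x = torch.randn(2, 3)
+     mask, dropout = get_mask(x, 0.5)
+     # dropout == 0.5; mask is a bool tensor, True where an element is dropped
+     return mask, dropout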
260
+
261
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEWD
262
+ class SEWDNoLayerNormConvLayer(nn.Module):
263
+ def __init__(self, config, layer_id=0):
264
+ super().__init__()
265
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
266
+ self.out_conv_dim = config.conv_dim[layer_id]
267
+
268
+ self.conv = nn.Conv1d(
269
+ self.in_conv_dim,
270
+ self.out_conv_dim,
271
+ kernel_size=config.conv_kernel[layer_id],
272
+ stride=config.conv_stride[layer_id],
273
+ bias=config.conv_bias,
274
+ )
275
+ self.activation = ACT2FN[config.feat_extract_activation]
276
+
277
+ def forward(self, hidden_states):
278
+ hidden_states = self.conv(hidden_states)
279
+ hidden_states = self.activation(hidden_states)
280
+ return hidden_states
281
+
282
+
283
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEWD
284
+ class SEWDLayerNormConvLayer(nn.Module):
285
+ def __init__(self, config, layer_id=0):
286
+ super().__init__()
287
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
288
+ self.out_conv_dim = config.conv_dim[layer_id]
289
+
290
+ self.conv = nn.Conv1d(
291
+ self.in_conv_dim,
292
+ self.out_conv_dim,
293
+ kernel_size=config.conv_kernel[layer_id],
294
+ stride=config.conv_stride[layer_id],
295
+ bias=config.conv_bias,
296
+ )
297
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
298
+ self.activation = ACT2FN[config.feat_extract_activation]
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.conv(hidden_states)
302
+
303
+ hidden_states = hidden_states.transpose(-2, -1)
304
+ hidden_states = self.layer_norm(hidden_states)
305
+ hidden_states = hidden_states.transpose(-2, -1)
306
+
307
+ hidden_states = self.activation(hidden_states)
308
+ return hidden_states
309
+
310
+
311
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEWD
312
+ class SEWDGroupNormConvLayer(nn.Module):
313
+ def __init__(self, config, layer_id=0):
314
+ super().__init__()
315
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
316
+ self.out_conv_dim = config.conv_dim[layer_id]
317
+
318
+ self.conv = nn.Conv1d(
319
+ self.in_conv_dim,
320
+ self.out_conv_dim,
321
+ kernel_size=config.conv_kernel[layer_id],
322
+ stride=config.conv_stride[layer_id],
323
+ bias=config.conv_bias,
324
+ )
325
+ self.activation = ACT2FN[config.feat_extract_activation]
326
+
327
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
328
+
329
+ def forward(self, hidden_states):
330
+ hidden_states = self.conv(hidden_states)
331
+ hidden_states = self.layer_norm(hidden_states)
332
+ hidden_states = self.activation(hidden_states)
333
+ return hidden_states
334
+
335
+
336
+ # Copied from transformers.models.sew.modeling_sew.SEWPositionalConvEmbedding with SEW->SEWD
337
+ class SEWDPositionalConvEmbedding(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.conv = nn.Conv1d(
341
+ config.hidden_size,
342
+ config.hidden_size,
343
+ kernel_size=config.num_conv_pos_embeddings,
344
+ padding=config.num_conv_pos_embeddings // 2,
345
+ groups=config.num_conv_pos_embedding_groups,
346
+ stride=config.squeeze_factor,
347
+ )
348
+
349
+ weight_norm = nn.utils.weight_norm
350
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
351
+ weight_norm = nn.utils.parametrizations.weight_norm
352
+
353
+ if is_deepspeed_zero3_enabled():
354
+ import deepspeed
355
+
356
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
357
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
358
+ if hasattr(self.conv, "parametrizations"):
359
+ weight_g = self.conv.parametrizations.weight.original0
360
+ weight_v = self.conv.parametrizations.weight.original1
361
+ else:
362
+ weight_g = self.conv.weight_g
363
+ weight_v = self.conv.weight_v
364
+ deepspeed.zero.register_external_parameter(self, weight_v)
365
+ deepspeed.zero.register_external_parameter(self, weight_g)
366
+ else:
367
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
368
+
369
+ self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings)
370
+ self.activation = ACT2FN[config.feat_extract_activation]
371
+
372
+ def forward(self, hidden_states):
373
+ hidden_states = self.conv(hidden_states)
374
+ hidden_states = self.padding(hidden_states)
375
+ hidden_states = self.activation(hidden_states)
376
+
377
+ return hidden_states
378
+
379
+
380
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW
381
+ class SEWDSamePadLayer(nn.Module):
382
+ def __init__(self, num_conv_pos_embeddings):
383
+ super().__init__()
384
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
385
+
386
+ def forward(self, hidden_states):
387
+ if self.num_pad_remove > 0:
388
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
389
+ return hidden_states
390
+
391
+
392
+ # Copied from transformers.models.sew.modeling_sew.SEWUpsampling with SEW->SEWD
393
+ class SEWDUpsampling(nn.Module):
394
+ def __init__(self, config):
395
+ super().__init__()
396
+ self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
397
+ self.activation = ACT2FN[config.feat_extract_activation]
398
+ self.squeeze_factor = config.squeeze_factor
399
+
400
+ def forward(self, hidden_states):
401
+ hidden_states = self.projection(hidden_states)
402
+ hidden_states = self.activation(hidden_states)
403
+
404
+ if self.squeeze_factor > 1:
405
+ # transform embedding channels to sequence length
406
+ bsz, src_len, src_embed_dim = hidden_states.size()
407
+ tgt_len = src_len * self.squeeze_factor
408
+ tgt_embed_dim = src_embed_dim // self.squeeze_factor
409
+ hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
410
+ hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
411
+
412
+ return hidden_states
413
+
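+ # Illustrative note (not from the original module): SEWDUpsampling trades embedding
+ # channels for sequence length. With hypothetical hidden_size=768 and squeeze_factor=2,
+ # an input of shape (batch, seq, 768) is projected to (batch, seq, 1536) and then
+ # reshaped to (batch, 2 * seq, 768), undoing the pooling applied in SEWDEncoder.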
414
+
415
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEWD
416
+ class SEWDFeatureEncoder(nn.Module):
417
+ """Construct the features from raw audio waveform"""
418
+
419
+ def __init__(self, config):
420
+ super().__init__()
421
+
422
+ if config.feat_extract_norm == "group":
423
+ conv_layers = [SEWDGroupNormConvLayer(config, layer_id=0)] + [
424
+ SEWDNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
425
+ ]
426
+ elif config.feat_extract_norm == "layer":
427
+ conv_layers = [SEWDLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
428
+ else:
429
+ raise ValueError(
430
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
431
+ )
432
+ self.conv_layers = nn.ModuleList(conv_layers)
433
+ self.gradient_checkpointing = False
434
+ self._requires_grad = True
435
+
436
+ def _freeze_parameters(self):
437
+ for param in self.parameters():
438
+ param.requires_grad = False
439
+ self._requires_grad = False
440
+
441
+ def forward(self, input_values):
442
+ hidden_states = input_values[:, None]
443
+
444
+ # make sure hidden_states require grad for gradient_checkpointing
445
+ if self._requires_grad and self.training:
446
+ hidden_states.requires_grad = True
447
+
448
+ for conv_layer in self.conv_layers:
449
+ if self._requires_grad and self.gradient_checkpointing and self.training:
450
+ hidden_states = self._gradient_checkpointing_func(
451
+ conv_layer.__call__,
452
+ hidden_states,
453
+ )
454
+ else:
455
+ hidden_states = conv_layer(hidden_states)
456
+
457
+ return hidden_states
458
+
459
+
460
+ class SEWDFeatureExtractor(SEWDFeatureEncoder):
461
+ def __init__(self, config):
462
+ super().__init__(config)
463
+ warnings.warn(
464
+ f"The class `{self.__class__.__name__}` has been depreciated "
465
+ "and will be removed in Transformers v5. "
466
+ f"Use `{self.__class__.__bases__[0].__name__}` instead.",
467
+ FutureWarning,
468
+ )
469
+
470
+
471
+ class ContextPooler(nn.Module):
472
+ def __init__(self, config):
473
+ super().__init__()
474
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
475
+ self.dropout = StableDropout(config.pooler_dropout)
476
+ self.config = config
477
+
478
+ def forward(self, hidden_states):
479
+ # We "pool" the model by simply taking the hidden state corresponding
480
+ # to the first token.
481
+
482
+ context_token = hidden_states[:, 0]
483
+ context_token = self.dropout(context_token)
484
+ pooled_output = self.dense(context_token)
485
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
486
+ return pooled_output
487
+
488
+ @property
489
+ def output_dim(self):
490
+ return self.config.hidden_size
491
+
492
+
493
+ class XSoftmax(torch.autograd.Function):
494
+ """
495
+ Masked Softmax which is optimized for saving memory
496
+
497
+ Args:
498
+ input (`torch.tensor`): The input tensor that will apply softmax.
499
+ mask (`torch.IntTensor`):
500
+ The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
501
+ dim (int): The dimension along which the softmax is applied
502
+
503
+ Example:
504
+
505
+ ```python
506
+ >>> import torch
507
+ >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
508
+
509
+ >>> # Make a tensor
510
+ >>> x = torch.randn([4, 20, 100])
511
+
512
+ >>> # Create a mask
513
+ >>> mask = (x > 0).int()
514
+
515
+ >>> # Specify the dimension to apply softmax
516
+ >>> dim = -1
517
+
518
+ >>> y = XSoftmax.apply(x, mask, dim)
519
+ ```"""
520
+
521
+ @staticmethod
522
+ def forward(ctx, input, mask, dim):
523
+ ctx.dim = dim
524
+ rmask = ~(mask.to(torch.bool))
525
+
526
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
527
+ output = torch.softmax(output, ctx.dim)
528
+ output.masked_fill_(rmask, 0)
529
+ ctx.save_for_backward(output)
530
+ return output
531
+
532
+ @staticmethod
533
+ def backward(ctx, grad_output):
534
+ (output,) = ctx.saved_tensors
535
+ inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
536
+ return inputGrad, None, None
537
+
538
+ @staticmethod
539
+ def symbolic(g, self, mask, dim):
540
+ import torch.onnx.symbolic_helper as sym_help
541
+ from torch.onnx.symbolic_opset9 import masked_fill, softmax
542
+
543
+ mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
544
+ r_mask = g.op(
545
+ "Cast",
546
+ g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
547
+ to_i=sym_help.cast_pytorch_to_onnx["Bool"],
548
+ )
549
+ output = masked_fill(
550
+ g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
551
+ )
552
+ output = softmax(g, output, dim)
553
+ return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
554
+
555
+
556
+ class DropoutContext:
557
+ def __init__(self):
558
+ self.dropout = 0
559
+ self.mask = None
560
+ self.scale = 1
561
+ self.reuse_mask = True
562
+
563
+
564
+ class XDropout(torch.autograd.Function):
565
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
566
+
567
+ @staticmethod
568
+ def forward(ctx, input, local_ctx):
569
+ mask, dropout = get_mask(input, local_ctx)
570
+ ctx.scale = 1.0 / (1 - dropout)
571
+ if dropout > 0:
572
+ ctx.save_for_backward(mask)
573
+ return input.masked_fill(mask, 0) * ctx.scale
574
+ else:
575
+ return input
576
+
577
+ @staticmethod
578
+ def backward(ctx, grad_output):
579
+ if ctx.scale > 1:
580
+ (mask,) = ctx.saved_tensors
581
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
582
+ else:
583
+ return grad_output, None
584
+
585
+ @staticmethod
586
+ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
587
+ from torch.onnx import symbolic_opset12
588
+
589
+ dropout_p = local_ctx
590
+ if isinstance(local_ctx, DropoutContext):
591
+ dropout_p = local_ctx.dropout
592
+ # StableDropout only calls this function when training.
593
+ train = True
594
+ # TODO: We should check if the opset_version being used to export
595
+ # is > 12 here, but there's no good way to do that. As-is, if the
596
+ # opset_version < 12, export will fail with a CheckerError.
597
+ # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
598
+ # if opset_version < 12:
599
+ # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
600
+ return symbolic_opset12.dropout(g, input, dropout_p, train)
601
+
602
+
603
+ class StableDropout(nn.Module):
604
+ """
605
+ Optimized dropout module for stabilizing training
606
+
607
+ Args:
608
+ drop_prob (float): the dropout probability
609
+ """
610
+
611
+ def __init__(self, drop_prob):
612
+ super().__init__()
613
+ self.drop_prob = drop_prob
614
+ self.count = 0
615
+ self.context_stack = None
616
+
617
+ def forward(self, x):
618
+ """
619
+ Call the module
620
+
621
+ Args:
622
+ x (`torch.tensor`): The input tensor to which dropout is applied
623
+ """
624
+ if self.training and self.drop_prob > 0:
625
+ return XDropout.apply(x, self.get_context())
626
+ return x
627
+
628
+ def clear_context(self):
629
+ self.count = 0
630
+ self.context_stack = None
631
+
632
+ def init_context(self, reuse_mask=True, scale=1):
633
+ if self.context_stack is None:
634
+ self.context_stack = []
635
+ self.count = 0
636
+ for c in self.context_stack:
637
+ c.reuse_mask = reuse_mask
638
+ c.scale = scale
639
+
640
+ def get_context(self):
641
+ if self.context_stack is not None:
642
+ if self.count >= len(self.context_stack):
643
+ self.context_stack.append(DropoutContext())
644
+ ctx = self.context_stack[self.count]
645
+ ctx.dropout = self.drop_prob
646
+ self.count += 1
647
+ return ctx
648
+ else:
649
+ return self.drop_prob
650
+
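+ # Illustrative sketch (hypothetical helper, not from the original module): StableDropout
+ # behaves like nn.Dropout -- identity in eval mode, masked and rescaled by 1 / (1 - p)
+ # in training mode (via XDropout).
+ def _example_stable_dropout():
+     drop = StableDropout(0.1)
+     x = torch.randn(2, 4)
+     drop.eval()
+     assert torch.equal(drop(x), x)  # no-op outside training
+     drop.train()
+     return drop(x)  # some elements zeroed, the rest scaled by 1 / 0.9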
651
+
652
+ class SEWDSelfOutput(nn.Module):
653
+ def __init__(self, config):
654
+ super().__init__()
655
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
656
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
657
+ self.dropout = nn.Dropout(config.activation_dropout)
658
+
659
+ def forward(self, hidden_states, input_tensor):
660
+ hidden_states = self.dense(hidden_states)
661
+ hidden_states = self.dropout(hidden_states)
662
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
663
+ return hidden_states
664
+
665
+
666
+ class DisentangledSelfAttention(nn.Module):
667
+ """
668
+ Disentangled self-attention module
669
+
670
+ Parameters:
671
+ config (`DebertaV2Config`):
672
+ A model config class instance with the configuration to build a new model. The schema is similar to
673
+ *BertConfig*; for more details, please refer to [`DebertaV2Config`].
674
+
675
+ """
676
+
677
+ def __init__(self, config):
678
+ super().__init__()
679
+ if config.hidden_size % config.num_attention_heads != 0:
680
+ raise ValueError(
681
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
682
+ f"heads ({config.num_attention_heads})"
683
+ )
684
+ self.num_attention_heads = config.num_attention_heads
685
+ _attention_head_size = config.hidden_size // config.num_attention_heads
686
+ self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
687
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
688
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
689
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
690
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
691
+
692
+ self.share_att_key = getattr(config, "share_att_key", False)
693
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
694
+ self.relative_attention = getattr(config, "relative_attention", False)
695
+
696
+ if self.relative_attention:
697
+ self.position_buckets = getattr(config, "position_buckets", -1)
698
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
699
+ if self.max_relative_positions < 1:
700
+ self.max_relative_positions = config.max_position_embeddings
701
+ self.pos_ebd_size = self.max_relative_positions
702
+ if self.position_buckets > 0:
703
+ self.pos_ebd_size = self.position_buckets
704
+
705
+ self.pos_dropout = StableDropout(config.activation_dropout)
706
+
707
+ if not self.share_att_key:
708
+ if "c2p" in self.pos_att_type:
709
+ self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
710
+ if "p2c" in self.pos_att_type:
711
+ self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
712
+
713
+ self.dropout = StableDropout(config.attention_dropout)
714
+
715
+ def transpose_for_scores(self, x, attention_heads):
716
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
717
+ x = x.view(new_x_shape)
718
+ return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
719
+
720
+ def forward(
721
+ self,
722
+ hidden_states,
723
+ attention_mask,
724
+ output_attentions=False,
725
+ query_states=None,
726
+ relative_pos=None,
727
+ rel_embeddings=None,
728
+ ):
729
+ """
730
+ Call the module
731
+
732
+ Args:
733
+ hidden_states (`torch.FloatTensor`):
734
+ Input states to the module, usually the output of the previous layer; these will be the Q, K and V in
735
+ *Attention(Q,K,V)*
736
+
737
+ attention_mask (`torch.BoolTensor`):
738
+ An attention mask matrix of shape [*B*, *N*, *N*], where *B* is the batch size and *N* is the maximum
+ sequence length; element [i,j] = *1* means the *i*-th token in the input can attend to the *j*-th
+ token.
741
+
742
+ output_attentions (`bool`, *optional*):
743
+ Whether return the attention matrix.
744
+
745
+ query_states (`torch.FloatTensor`, *optional*):
746
+ The *Q* state in *Attention(Q,K,V)*.
747
+
748
+ relative_pos (`torch.LongTensor`):
749
+ The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
750
+ values ranging in [*-max_relative_positions*, *max_relative_positions*].
751
+
752
+ rel_embeddings (`torch.FloatTensor`):
753
+ The embedding of relative distances. It's a tensor of shape [\\(2 \\times
754
+ \\text{max_relative_positions}\\), *hidden_size*].
755
+
756
+
757
+ """
758
+ if query_states is None:
759
+ query_states = hidden_states
760
+ query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
761
+ key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
762
+ value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
763
+
764
+ rel_att = None
765
+ # Take the dot product between "query" and "key" to get the raw attention scores.
766
+ scale_factor = 1
767
+ if "c2p" in self.pos_att_type:
768
+ scale_factor += 1
769
+ if "p2c" in self.pos_att_type:
770
+ scale_factor += 1
771
+ scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
772
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
773
+ if self.relative_attention:
774
+ rel_embeddings = self.pos_dropout(rel_embeddings)
775
+ rel_att = self.disentangled_attention_bias(
776
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
777
+ )
778
+
779
+ if rel_att is not None:
780
+ attention_scores = attention_scores + rel_att
781
782
+ attention_scores = attention_scores.view(
783
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
784
+ )
785
+
786
+ # bsz x height x length x dimension
787
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
788
+ attention_probs = self.dropout(attention_probs)
789
+ context_layer = torch.bmm(
790
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
791
+ )
792
+ context_layer = (
793
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
794
+ .permute(0, 2, 1, 3)
795
+ .contiguous()
796
+ )
797
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
798
+ context_layer = context_layer.view(new_context_layer_shape)
799
+ if output_attentions:
800
+ return (context_layer, attention_probs)
801
+ else:
802
+ return context_layer
803
+
804
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
805
+ if relative_pos is None:
806
+ q = query_layer.size(-2)
807
+ relative_pos = build_relative_position(
808
+ q,
809
+ key_layer.size(-2),
810
+ bucket_size=self.position_buckets,
811
+ max_position=self.max_relative_positions,
812
+ device=query_layer.device,
813
+ )
814
+ if relative_pos.dim() == 2:
815
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
816
+ elif relative_pos.dim() == 3:
817
+ relative_pos = relative_pos.unsqueeze(1)
818
+ # bsz x height x query x key
819
+ elif relative_pos.dim() != 4:
820
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
821
+
822
+ att_span = self.pos_ebd_size
823
+ relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
824
+
825
+ rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
826
+ if self.share_att_key:
827
+ pos_query_layer = self.transpose_for_scores(
828
+ self.query_proj(rel_embeddings), self.num_attention_heads
829
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
830
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
831
+ query_layer.size(0) // self.num_attention_heads, 1, 1
832
+ )
833
+ else:
834
+ if "c2p" in self.pos_att_type:
835
+ pos_key_layer = self.transpose_for_scores(
836
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
837
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
838
+ if "p2c" in self.pos_att_type:
839
+ pos_query_layer = self.transpose_for_scores(
840
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
841
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
842
+
843
+ score = 0
844
+ # content->position
845
+ if "c2p" in self.pos_att_type:
846
+ scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
847
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
848
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
849
+ c2p_att = torch.gather(
850
+ c2p_att,
851
+ dim=-1,
852
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
853
+ )
854
+ score += c2p_att / scale.to(dtype=c2p_att.dtype)
855
+
856
+ # position->content
857
+ if "p2c" in self.pos_att_type:
858
+ scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
859
+ if key_layer.size(-2) != query_layer.size(-2):
860
+ r_pos = build_relative_position(
861
+ key_layer.size(-2),
862
+ key_layer.size(-2),
863
+ bucket_size=self.position_buckets,
864
+ max_position=self.max_relative_positions,
865
+ device=query_layer.device,
866
+ )
867
+ r_pos = r_pos.unsqueeze(0)
868
+ else:
869
+ r_pos = relative_pos
870
+
871
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
872
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
873
+ p2c_att = torch.gather(
874
+ p2c_att,
875
+ dim=-1,
876
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
877
+ ).transpose(-1, -2)
878
+ score += p2c_att / scale.to(dtype=p2c_att.dtype)
879
+
880
+ return score
881
+
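+ # Illustrative note (not from the original module): the content-to-content scores above
+ # are divided by sqrt(head_dim * scale_factor), where scale_factor counts the enabled
+ # attention terms (1 for content-to-content, plus 1 each for "c2p" and "p2c"). With a
+ # hypothetical head_dim of 64 and pos_att_type=["c2p", "p2c"], the divisor is
+ # sqrt(64 * 3) ~= 13.86, and the c2p/p2c bias terms are scaled the same way before
+ # being added to the scores.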
882
+
883
+ class SEWDAttention(nn.Module):
884
+ def __init__(self, config):
885
+ super().__init__()
886
+ self.self = DisentangledSelfAttention(config)
887
+ self.output = SEWDSelfOutput(config)
888
+ self.config = config
889
+
890
+ def forward(
891
+ self,
892
+ hidden_states,
893
+ attention_mask,
894
+ output_attentions=False,
895
+ query_states=None,
896
+ relative_pos=None,
897
+ rel_embeddings=None,
898
+ ):
899
+ self_output = self.self(
900
+ hidden_states,
901
+ attention_mask,
902
+ output_attentions,
903
+ query_states=query_states,
904
+ relative_pos=relative_pos,
905
+ rel_embeddings=rel_embeddings,
906
+ )
907
+ if output_attentions:
908
+ self_output, att_matrix = self_output
909
+ if query_states is None:
910
+ query_states = hidden_states
911
+ attention_output = self.output(self_output, query_states)
912
+
913
+ if output_attentions:
914
+ return (attention_output, att_matrix)
915
+ else:
916
+ return attention_output
917
+
918
+
919
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->SEWD
920
+ class SEWDIntermediate(nn.Module):
921
+ def __init__(self, config):
922
+ super().__init__()
923
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
924
+ if isinstance(config.hidden_act, str):
925
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
926
+ else:
927
+ self.intermediate_act_fn = config.hidden_act
928
+
929
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
930
+ hidden_states = self.dense(hidden_states)
931
+ hidden_states = self.intermediate_act_fn(hidden_states)
932
+ return hidden_states
933
+
934
+
935
+ class SEWDOutput(nn.Module):
936
+ def __init__(self, config):
937
+ super().__init__()
938
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
939
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
940
+ self.dropout = nn.Dropout(config.activation_dropout)
941
+ self.config = config
942
+
943
+ def forward(self, hidden_states, input_tensor):
944
+ hidden_states = self.dense(hidden_states)
945
+ hidden_states = self.dropout(hidden_states)
946
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
947
+ return hidden_states
948
+
949
+
950
+ class SEWDLayer(nn.Module):
951
+ def __init__(self, config):
952
+ super().__init__()
953
+ self.attention = SEWDAttention(config)
954
+ self.intermediate = SEWDIntermediate(config)
955
+ self.output = SEWDOutput(config)
956
+
957
+ def forward(
958
+ self,
959
+ hidden_states,
960
+ attention_mask,
961
+ query_states=None,
962
+ relative_pos=None,
963
+ rel_embeddings=None,
964
+ output_attentions=False,
965
+ ):
966
+ attention_output = self.attention(
967
+ hidden_states,
968
+ attention_mask,
969
+ output_attentions=output_attentions,
970
+ query_states=query_states,
971
+ relative_pos=relative_pos,
972
+ rel_embeddings=rel_embeddings,
973
+ )
974
+ if output_attentions:
975
+ attention_output, att_matrix = attention_output
976
+ intermediate_output = self.intermediate(attention_output)
977
+ layer_output = self.output(intermediate_output, attention_output)
978
+ if output_attentions:
979
+ return (layer_output, att_matrix)
980
+ else:
981
+ return layer_output
982
+
983
+
984
+ class ConvLayer(nn.Module):
985
+ def __init__(self, config):
986
+ super().__init__()
987
+ kernel_size = getattr(config, "conv_kernel_size", 3)
988
+ groups = getattr(config, "conv_groups", 1)
989
+ self.conv_act = getattr(config, "conv_act", "tanh")
990
+ self.conv = nn.Conv1d(
991
+ config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
992
+ )
993
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
994
+ self.dropout = StableDropout(config.hidden_dropout_prob)
995
+ self.config = config
996
+
997
+ def forward(self, hidden_states, residual_states, input_mask):
998
+ out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
999
+ rmask = (1 - input_mask).bool()
1000
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
1001
+ out = ACT2FN[self.conv_act](self.dropout(out))
1002
+
1003
+ layer_norm_input = residual_states + out
1004
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
1005
+
1006
+ if input_mask is None:
1007
+ output_states = output
1008
+ else:
1009
+ if input_mask.dim() != layer_norm_input.dim():
1010
+ if input_mask.dim() == 4:
1011
+ input_mask = input_mask.squeeze(1).squeeze(1)
1012
+ input_mask = input_mask.unsqueeze(2)
1013
+
1014
+ input_mask = input_mask.to(output.dtype)
1015
+ output_states = output * input_mask
1016
+
1017
+ return output_states
1018
+
1019
+
1020
+ class SEWDTransformerEncoder(nn.Module):
1021
+ """Modified BertEncoder with relative position bias support"""
1022
+
1023
+ def __init__(self, config):
1024
+ super().__init__()
1025
+
1026
+ self.layer = nn.ModuleList([SEWDLayer(config) for _ in range(config.num_hidden_layers)])
1027
+ self.relative_attention = getattr(config, "relative_attention", False)
1028
+
1029
+ if self.relative_attention:
1030
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
1031
+ if self.max_relative_positions < 1:
1032
+ self.max_relative_positions = config.max_position_embeddings
1033
+
1034
+ self.position_buckets = getattr(config, "position_buckets", -1)
1035
+ pos_ebd_size = self.max_relative_positions * 2
1036
+
1037
+ if self.position_buckets > 0:
1038
+ pos_ebd_size = self.position_buckets * 2
1039
+
1040
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
1041
+
1042
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
1043
+
1044
+ if "layer_norm" in self.norm_rel_ebd:
1045
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
1046
+
1047
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
1048
+ self.gradient_checkpointing = False
1049
+
1050
+ def get_rel_embedding(self):
1051
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
1052
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
1053
+ rel_embeddings = self.LayerNorm(rel_embeddings)
1054
+ return rel_embeddings
1055
+
1056
+ def get_attention_mask(self, attention_mask):
1057
+ if attention_mask.dim() <= 2:
1058
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
1059
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
1060
+ elif attention_mask.dim() == 3:
1061
+ attention_mask = attention_mask.unsqueeze(1)
1062
+
1063
+ return attention_mask
1064
+
1065
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
1066
+ if self.relative_attention and relative_pos is None:
1067
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
1068
+ relative_pos = build_relative_position(
1069
+ q,
1070
+ hidden_states.size(-2),
1071
+ bucket_size=self.position_buckets,
1072
+ max_position=self.max_relative_positions,
1073
+ device=hidden_states.device,
1074
+ )
1075
+ return relative_pos
1076
+
1077
+ def forward(
1078
+ self,
1079
+ hidden_states,
1080
+ attention_mask,
1081
+ output_hidden_states=True,
1082
+ output_attentions=False,
1083
+ query_states=None,
1084
+ relative_pos=None,
1085
+ return_dict=True,
1086
+ ):
1087
+ if attention_mask.dim() <= 2:
1088
+ input_mask = attention_mask
1089
+ else:
1090
+ input_mask = attention_mask.sum(-2) > 0
1091
+ attention_mask = self.get_attention_mask(attention_mask)
1092
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
1093
+
1094
+ all_hidden_states = () if output_hidden_states else None
1095
+ all_attentions = () if output_attentions else None
1096
+
1097
+ if isinstance(hidden_states, Sequence):
1098
+ next_kv = hidden_states[0]
1099
+ else:
1100
+ next_kv = hidden_states
1101
+ rel_embeddings = self.get_rel_embedding()
1102
+ output_states = next_kv
1103
+ for i, layer_module in enumerate(self.layer):
1104
+ if output_hidden_states:
1105
+ all_hidden_states = all_hidden_states + (output_states,)
1106
+
1107
+ if self.gradient_checkpointing and self.training:
1108
+ output_states = self._gradient_checkpointing_func(
1109
+ layer_module.__call__,
1110
+ next_kv,
1111
+ attention_mask,
1112
+ query_states,
1113
+ relative_pos,
1114
+ rel_embeddings,
1115
+ output_attentions,
1116
+ )
1117
+ else:
1118
+ output_states = layer_module(
1119
+ next_kv,
1120
+ attention_mask,
1121
+ query_states=query_states,
1122
+ relative_pos=relative_pos,
1123
+ rel_embeddings=rel_embeddings,
1124
+ output_attentions=output_attentions,
1125
+ )
1126
+
1127
+ if output_attentions:
1128
+ output_states, att_m = output_states
1129
+
1130
+ if i == 0 and self.conv is not None:
1131
+ output_states = self.conv(hidden_states, output_states, input_mask)
1132
+
1133
+ if query_states is not None:
1134
+ query_states = output_states
1135
+ if isinstance(hidden_states, Sequence):
1136
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
1137
+ else:
1138
+ next_kv = output_states
1139
+
1140
+ if output_attentions:
1141
+ all_attentions = all_attentions + (att_m,)
1142
+
1143
+ if output_hidden_states:
1144
+ all_hidden_states = all_hidden_states + (output_states,)
1145
+
1146
+ if not return_dict:
1147
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
1148
+ return BaseModelOutput(
1149
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
1150
+ )
1151
+
1152
+
1153
+ class SEWDEncoder(nn.Module):
1154
+ def __init__(self, config):
1155
+ super().__init__()
1156
+ self.config = config
1157
+ self.pos_conv_embed = SEWDPositionalConvEmbedding(config)
1158
+ self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
1159
+ self.encoder = SEWDTransformerEncoder(config)
1160
+ self.upsample = SEWDUpsampling(config)
1161
+ self.gradient_checkpointing = False
1162
+
1163
+ def forward(
1164
+ self,
1165
+ hidden_states: torch.tensor,
1166
+ attention_mask: Optional[torch.Tensor] = None,
1167
+ output_attentions: bool = False,
1168
+ output_hidden_states: bool = False,
1169
+ return_dict: bool = True,
1170
+ ):
1171
+ max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
1172
+ if attention_mask is None:
1173
+ attention_mask = torch.ones(
1174
+ (hidden_states.shape[0], max_encoder_length), dtype=torch.long, device=hidden_states.device
1175
+ )
1176
+ else:
1177
+ # make sure padded tokens output 0
1178
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
1179
+ hidden_states[~expand_attention_mask.bool()] = 0.0
1180
+
1181
+ input_lengths = (attention_mask.long()).sum(-1)
1182
+ # apply pooling formula to get real output_lengths
1183
+ output_lengths = input_lengths // self.config.squeeze_factor
1184
+ attention_ids = (
1185
+ torch.arange(0, max_encoder_length, device=output_lengths.device)
1186
+ .view(1, -1)
1187
+ .expand(output_lengths.shape[0], -1)
1188
+ )
1189
+ attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
1190
+
1191
+ n_input_timesteps = hidden_states.shape[1]
1192
+
1193
+ hidden_states = hidden_states.transpose(1, 2)
1194
+ position_embeddings = self.pos_conv_embed(hidden_states)
1195
+ pooled_hidden_states = self.pool(hidden_states)
1196
+ min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
1197
+ hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
1198
+ hidden_states = hidden_states.transpose(1, 2)
1199
+
1200
+ encoder_outputs = self.encoder(hidden_states, attention_mask, output_hidden_states, output_attentions)
1201
+
1202
+ hidden_states = self.upsample(encoder_outputs.last_hidden_state)
1203
+ if hidden_states.shape[1] < n_input_timesteps:
1204
+ hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
1205
+
1206
+ if not return_dict:
1207
+ return tuple(
1208
+ v for v in [hidden_states, encoder_outputs.hidden_states, encoder_outputs.attentions] if v is not None
1209
+ )
1210
+ return BaseModelOutput(
1211
+ last_hidden_state=hidden_states,
1212
+ hidden_states=encoder_outputs.hidden_states,
1213
+ attentions=encoder_outputs.attentions,
1214
+ )
1215
+
1216
+
1217
+ class SEWDPreTrainedModel(PreTrainedModel):
1218
+ """
1219
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1220
+ models.
1221
+ """
1222
+
1223
+ config_class = SEWDConfig
1224
+ base_model_prefix = "sew-d"
1225
+ main_input_name = "input_values"
1226
+ supports_gradient_checkpointing = True
1227
+
1228
+ def _init_weights(self, module):
1229
+ """Initialize the weights"""
1230
+ if isinstance(module, SEWDPositionalConvEmbedding):
1231
+ nn.init.normal_(
1232
+ module.conv.weight,
1233
+ mean=0,
1234
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
1235
+ )
1236
+ nn.init.constant_(module.conv.bias, 0)
1237
+ elif isinstance(module, nn.Linear):
1238
+ # Slightly different from the TF version which uses truncated_normal for initialization
1239
+ # cf https://github.com/pytorch/pytorch/pull/5617
1240
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1241
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
1242
+ module.bias.data.zero_()
1243
+ module.weight.data.fill_(1.0)
1244
+ elif isinstance(module, nn.Conv1d):
1245
+ if is_deepspeed_zero3_enabled():
1246
+ import deepspeed
1247
+
1248
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
1249
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
1250
+ nn.init.kaiming_normal_(module.weight.data)
1251
+ else:
1252
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
1253
+ nn.init.kaiming_normal_(module.weight.data)
1254
+ else:
1255
+ nn.init.kaiming_normal_(module.weight.data)
1256
+ elif isinstance(module, nn.Embedding):
1257
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1258
+ if module.padding_idx is not None:
1259
+ module.weight.data[module.padding_idx].zero_()
1260
+
1261
+ if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
1262
+ module.bias.data.zero_()
1263
+
1264
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
1265
+ """
1266
+ Computes the output length of the convolutional layers
1267
+ """
1268
+
1269
+ def _conv_out_length(input_length, kernel_size, stride):
1270
+ # 1D convolutional layer output length formula taken
1271
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
1272
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
1273
+
1274
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
1275
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
1276
+
1277
+ return input_lengths
1278
+
1279
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
1280
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1281
+ batch_size = attention_mask.shape[0]
1282
+
1283
+ attention_mask = torch.zeros(
1284
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
1285
+ )
1286
+ # these two operations make sure that all values before the output length indices are attended to
1287
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
1288
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
1289
+ return attention_mask
1290
+
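+ # Illustrative note (not from the original module): each conv layer shrinks the time
+ # axis to floor((length - kernel) / stride) + 1. For a hypothetical stack with kernels
+ # (10, 3, 3) and strides (5, 2, 2), 16000 input samples become
+ #     (16000 - 10) // 5 + 1 = 3199
+ #     (3199 - 3) // 2 + 1 = 1599
+ #     (1599 - 3) // 2 + 1 = 799
+ # feature frames, which is the length the reduced attention mask above is built for.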
1291
+
1292
+ SEWD_START_DOCSTRING = r"""
1293
+ SEW-D was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
1294
+ Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
1295
+ Yoav Artzi.
1296
+
1297
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1298
+ library implements for all its models (such as downloading or saving, etc.).
1299
+
1300
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
1301
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
1302
+ behavior.
1303
+
1304
+ Parameters:
1305
+ config ([`SEWDConfig`]): Model configuration class with all the parameters of the model.
1306
+ Initializing with a config file does not load the weights associated with the model, only the
1307
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1308
+ """
1309
+
1310
+
1311
+ SEWD_INPUTS_DOCSTRING = r"""
1312
+ Args:
1313
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1314
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
1315
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
1316
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
1317
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
1318
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1319
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1320
+ 1]`:
1321
+
1322
+ - 1 for tokens that are **not masked**,
1323
+ - 0 for tokens that are **masked**.
1324
+
1325
+ [What are attention masks?](../glossary#attention-mask)
1326
+
1327
+ output_attentions (`bool`, *optional*):
1328
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1329
+ tensors for more detail.
1330
+ output_hidden_states (`bool`, *optional*):
1331
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1332
+ more detail.
1333
+ return_dict (`bool`, *optional*):
1334
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1335
+ """
1336
+
1337
+
1338
+ @add_start_docstrings(
1339
+ "The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
1340
+ SEWD_START_DOCSTRING,
1341
+ )
1342
+ # Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps
1343
+ class SEWDModel(SEWDPreTrainedModel):
1344
+ def __init__(self, config: SEWDConfig):
1345
+ super().__init__(config)
1346
+ self.config = config
1347
+ self.feature_extractor = SEWDFeatureEncoder(config)
1348
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps)
1349
+
1350
+ self.project_features = config.conv_dim[-1] != config.hidden_size
1351
+ if self.project_features:
1352
+ self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
1353
+ self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
1354
+
1355
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
1356
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
1357
+
1358
+ self.encoder = SEWDEncoder(config)
1359
+
1360
+ # Initialize weights and apply final processing
1361
+ self.post_init()
1362
+
1363
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
1364
+ def _mask_hidden_states(
1365
+ self,
1366
+ hidden_states: torch.FloatTensor,
1367
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1368
+ attention_mask: Optional[torch.LongTensor] = None,
1369
+ ):
1370
+ """
1371
+ Masks extracted features along time axis and/or along feature axis according to
1372
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
1373
+ """
1374
+
1375
+ # `config.apply_spec_augment` can set masking to False
1376
+ if not getattr(self.config, "apply_spec_augment", True):
1377
+ return hidden_states
1378
+
1379
+ # generate indices & apply SpecAugment along time axis
1380
+ batch_size, sequence_length, hidden_size = hidden_states.size()
1381
+
1382
+ if mask_time_indices is not None:
1383
+ # apply SpecAugment along time axis with given mask_time_indices
1384
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1385
+ elif self.config.mask_time_prob > 0 and self.training:
1386
+ mask_time_indices = _compute_mask_indices(
1387
+ (batch_size, sequence_length),
1388
+ mask_prob=self.config.mask_time_prob,
1389
+ mask_length=self.config.mask_time_length,
1390
+ attention_mask=attention_mask,
1391
+ min_masks=self.config.mask_time_min_masks,
1392
+ )
1393
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
1394
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1395
+
1396
+ if self.config.mask_feature_prob > 0 and self.training:
1397
+ # generate indices & apply SpecAugment along feature axis
1398
+ mask_feature_indices = _compute_mask_indices(
1399
+ (batch_size, hidden_size),
1400
+ mask_prob=self.config.mask_feature_prob,
1401
+ mask_length=self.config.mask_feature_length,
1402
+ min_masks=self.config.mask_feature_min_masks,
1403
+ )
1404
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
1405
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
1406
+ hidden_states[mask_feature_indices] = 0
1407
+
1408
+ return hidden_states
1409
+
1410
+ @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
1411
+ @add_code_sample_docstrings(
1412
+ checkpoint=_CHECKPOINT_FOR_DOC,
1413
+ output_type=BaseModelOutput,
1414
+ config_class=_CONFIG_FOR_DOC,
1415
+ modality="audio",
1416
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1417
+ )
1418
+ def forward(
1419
+ self,
1420
+ input_values: Optional[torch.Tensor],
1421
+ attention_mask: Optional[torch.Tensor] = None,
1422
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1423
+ output_attentions: Optional[bool] = None,
1424
+ output_hidden_states: Optional[bool] = None,
1425
+ return_dict: Optional[bool] = None,
1426
+ ) -> Union[Tuple, BaseModelOutput]:
1427
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1428
+ output_hidden_states = (
1429
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1430
+ )
1431
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1432
+
1433
+ extract_features = self.feature_extractor(input_values)
1434
+ extract_features = extract_features.transpose(1, 2)
1435
+ extract_features = self.layer_norm(extract_features)
1436
+
1437
+ if self.project_features:
1438
+ extract_features = self.feature_projection(extract_features)
1439
+ hidden_states = self.feature_dropout(extract_features)
1440
+
1441
+ if attention_mask is not None:
1442
+ # compute reduced attention_mask corresponding to feature vectors
1443
+ attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1444
+
1445
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
1446
+
1447
+ encoder_outputs = self.encoder(
1448
+ hidden_states,
1449
+ attention_mask=attention_mask,
1450
+ output_attentions=output_attentions,
1451
+ output_hidden_states=output_hidden_states,
1452
+ return_dict=return_dict,
1453
+ )
1454
+
1455
+ hidden_states = encoder_outputs[0]
1456
+
1457
+ if not return_dict:
1458
+ return (hidden_states,) + encoder_outputs[1:]
1459
+
1460
+ return BaseModelOutput(
1461
+ last_hidden_state=hidden_states,
1462
+ hidden_states=encoder_outputs.hidden_states,
1463
+ attentions=encoder_outputs.attentions,
1464
+ )
1465
+
1466
+
1467
+ @add_start_docstrings(
1468
+ """SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
1469
+ SEWD_START_DOCSTRING,
1470
+ )
1471
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
1472
+ class SEWDForCTC(SEWDPreTrainedModel):
1473
+ def __init__(self, config, target_lang: Optional[str] = None):
1474
+ super().__init__(config)
1475
+
1476
+ self.sew_d = SEWDModel(config)
1477
+ self.dropout = nn.Dropout(config.final_dropout)
1478
+
1479
+ self.target_lang = target_lang
1480
+
1481
+ if config.vocab_size is None:
1482
+ raise ValueError(
1483
+ f"You are trying to instantiate {self.__class__} with a configuration that "
1484
+ "does not define the vocabulary size of the language model head. Please "
1485
+ "instantiate the model as follows: `SEWDForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
1486
+ "or define `vocab_size` of your model's configuration."
1487
+ )
1488
+ output_hidden_size = (
1489
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
1490
+ )
1491
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
1492
+
1493
+ # Initialize weights and apply final processing
1494
+ self.post_init()
1495
+
1496
+ def tie_weights(self):
1497
+ """
1498
+ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
1499
+ passing `target_lang=...` to `from_pretrained(...)`.
1500
+
1501
+ This method is **not** supposed to be called by the user and is prone to be changed in the future.
1502
+ """
1503
+
1504
+ # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
1505
+ # correctly load adapter layers for SEWD so that we do not have to introduce a new API to
1506
+ # [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is
1507
+ # ok to repurpose this function here.
1508
+ target_lang = self.target_lang
1509
+
1510
+ if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
1511
+ raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
1512
+ elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
1513
+ logger.info("By default `target_lang` is set to 'eng'.")
1514
+ elif target_lang is not None:
1515
+ self.load_adapter(target_lang, force_load=True)
1516
+
1517
+ def freeze_feature_extractor(self):
1518
+ """
1519
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1520
+ not be updated during training.
1521
+ """
1522
+ warnings.warn(
1523
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
1524
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
1525
+ FutureWarning,
1526
+ )
1527
+ self.freeze_feature_encoder()
1528
+
1529
+ def freeze_feature_encoder(self):
1530
+ """
1531
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1532
+ not be updated during training.
1533
+ """
1534
+ self.sew_d.feature_extractor._freeze_parameters()
1535
+
1536
+ def freeze_base_model(self):
1537
+ """
1538
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1539
+ be updated during training. Only the classification head will be updated.
1540
+ """
1541
+ for param in self.sew_d.parameters():
1542
+ param.requires_grad = False
1543
+
1544
+ @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
1545
+ @add_code_sample_docstrings(
1546
+ checkpoint=_CHECKPOINT_FOR_DOC,
1547
+ output_type=CausalLMOutput,
1548
+ config_class=_CONFIG_FOR_DOC,
1549
+ expected_output=_CTC_EXPECTED_OUTPUT,
1550
+ expected_loss=_CTC_EXPECTED_LOSS,
1551
+ )
1552
+ def forward(
1553
+ self,
1554
+ input_values: Optional[torch.Tensor],
1555
+ attention_mask: Optional[torch.Tensor] = None,
1556
+ output_attentions: Optional[bool] = None,
1557
+ output_hidden_states: Optional[bool] = None,
1558
+ return_dict: Optional[bool] = None,
1559
+ labels: Optional[torch.Tensor] = None,
1560
+ ) -> Union[Tuple, CausalLMOutput]:
1561
+ r"""
1562
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
1563
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
1564
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
1565
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
1566
+ config.vocab_size - 1]`.
1567
+ """
1568
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1569
+
1570
+ if labels is not None and labels.max() >= self.config.vocab_size:
1571
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
1572
+
1573
+ outputs = self.sew_d(
1574
+ input_values,
1575
+ attention_mask=attention_mask,
1576
+ output_attentions=output_attentions,
1577
+ output_hidden_states=output_hidden_states,
1578
+ return_dict=return_dict,
1579
+ )
1580
+
1581
+ hidden_states = outputs[0]
1582
+ hidden_states = self.dropout(hidden_states)
1583
+
1584
+ logits = self.lm_head(hidden_states)
1585
+
1586
+ loss = None
1587
+ if labels is not None:
1588
+ # retrieve loss input_lengths from attention_mask
1589
+ attention_mask = (
1590
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
1591
+ )
1592
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1593
+
1594
+ # assuming that padded tokens are filled with -100
1595
+ # when not being attended to
1596
+ labels_mask = labels >= 0
1597
+ target_lengths = labels_mask.sum(-1)
1598
+ flattened_targets = labels.masked_select(labels_mask)
1599
+
1600
+ # ctc_loss doesn't support fp16
1601
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
1602
+
1603
+ with torch.backends.cudnn.flags(enabled=False):
1604
+ loss = nn.functional.ctc_loss(
1605
+ log_probs,
1606
+ flattened_targets,
1607
+ input_lengths,
1608
+ target_lengths,
1609
+ blank=self.config.pad_token_id,
1610
+ reduction=self.config.ctc_loss_reduction,
1611
+ zero_infinity=self.config.ctc_zero_infinity,
1612
+ )
1613
+
1614
+ if not return_dict:
1615
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1616
+ return ((loss,) + output) if loss is not None else output
1617
+
1618
+ return CausalLMOutput(
1619
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1620
+ )
1621
+
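+ # Illustrative sketch (hypothetical helper, not from the original module): how padded
+ # CTC labels (filled with -100) are turned into flattened targets and per-sample target
+ # lengths, mirroring the loss preparation in SEWDForCTC.forward. Values are made up.
+ def _example_ctc_label_prep():
+     labels = torch.tensor([[5, 8, 2, -100, -100], [7, 1, -100, -100, -100]])
+     labels_mask = labels >= 0
+     target_lengths = labels_mask.sum(-1)  # tensor([3, 2])
+     flattened_targets = labels.masked_select(labels_mask)  # tensor([5, 8, 2, 7, 1])
+     return flattened_targets, target_lengths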
1622
+
1623
+ @add_start_docstrings(
1624
+ """
1625
+ SEWD Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
1626
+ Keyword Spotting.
1627
+ """,
1628
+ SEWD_START_DOCSTRING,
1629
+ )
1630
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
1631
+ class SEWDForSequenceClassification(SEWDPreTrainedModel):
1632
+ def __init__(self, config):
1633
+ super().__init__(config)
1634
+
1635
+ if hasattr(config, "add_adapter") and config.add_adapter:
1636
+ raise ValueError(
1637
+ "Sequence classification does not support the use of SEWD adapters (config.add_adapter=True)"
1638
+ )
1639
+ self.sew_d = SEWDModel(config)
1640
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1641
+ if config.use_weighted_layer_sum:
1642
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1643
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
1644
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
1645
+
1646
+ # Initialize weights and apply final processing
1647
+ self.post_init()
1648
+
1649
+ def freeze_feature_extractor(self):
1650
+ """
1651
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1652
+ not be updated during training.
1653
+ """
1654
+ warnings.warn(
1655
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
1656
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
1657
+ FutureWarning,
1658
+ )
1659
+ self.freeze_feature_encoder()
1660
+
1661
+ def freeze_feature_encoder(self):
1662
+ """
1663
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1664
+ not be updated during training.
1665
+ """
1666
+ self.sew_d.feature_extractor._freeze_parameters()
1667
+
1668
+ def freeze_base_model(self):
1669
+ """
1670
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1671
+ be updated during training. Only the classification head will be updated.
1672
+ """
1673
+ for param in self.sew_d.parameters():
1674
+ param.requires_grad = False
1675
+
1676
+ @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
1677
+ @add_code_sample_docstrings(
1678
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
1679
+ output_type=SequenceClassifierOutput,
1680
+ config_class=_CONFIG_FOR_DOC,
1681
+ modality="audio",
1682
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1683
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1684
+ )
1685
+ def forward(
1686
+ self,
1687
+ input_values: Optional[torch.Tensor],
1688
+ attention_mask: Optional[torch.Tensor] = None,
1689
+ output_attentions: Optional[bool] = None,
1690
+ output_hidden_states: Optional[bool] = None,
1691
+ return_dict: Optional[bool] = None,
1692
+ labels: Optional[torch.Tensor] = None,
1693
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1694
+ r"""
1695
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1696
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1697
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1698
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1699
+ """
1700
+
1701
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1702
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1703
+
1704
+ outputs = self.sew_d(
1705
+ input_values,
1706
+ attention_mask=attention_mask,
1707
+ output_attentions=output_attentions,
1708
+ output_hidden_states=output_hidden_states,
1709
+ return_dict=return_dict,
1710
+ )
1711
+
1712
+ if self.config.use_weighted_layer_sum:
1713
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1714
+ hidden_states = torch.stack(hidden_states, dim=1)
1715
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1716
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1717
+ else:
1718
+ hidden_states = outputs[0]
1719
+
1720
+ hidden_states = self.projector(hidden_states)
1721
+ if attention_mask is None:
1722
+ pooled_output = hidden_states.mean(dim=1)
1723
+ else:
1724
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1725
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
1726
+ hidden_states[~expand_padding_mask] = 0.0
1727
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
1728
+
1729
+ logits = self.classifier(pooled_output)
1730
+
1731
+ loss = None
1732
+ if labels is not None:
1733
+ loss_fct = CrossEntropyLoss()
1734
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1735
+
1736
+ if not return_dict:
1737
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1738
+ return ((loss,) + output) if loss is not None else output
1739
+
1740
+ return SequenceClassifierOutput(
1741
+ loss=loss,
1742
+ logits=logits,
1743
+ hidden_states=outputs.hidden_states,
1744
+ attentions=outputs.attentions,
1745
+ )
1746
+
1747
+
1748
+ __all__ = ["SEWDForCTC", "SEWDForSequenceClassification", "SEWDModel", "SEWDPreTrainedModel"]
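Note: the CTC forward pass above takes raw waveform `input_values` and, for training, `labels` padded with -100. A minimal inference/training sketch, assuming the `asapp/sew-d-tiny-100k-ft-ls100h` checkpoint and its bundled Wav2Vec2-style processor are available (names are assumptions, not part of this diff):

import torch
from datasets import load_dataset
from transformers import AutoProcessor, SEWDForCTC

model_id = "asapp/sew-d-tiny-100k-ft-ls100h"   # assumed checkpoint name
processor = AutoProcessor.from_pretrained(model_id)
model = SEWDForCTC.from_pretrained(model_id)

ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits             # (batch, frames, vocab_size)

pred_ids = torch.argmax(logits, dim=-1)         # greedy CTC decoding
print(processor.batch_decode(pred_ids))

# For training, tokenize the reference transcript; padded label positions should be -100,
# as the `labels` docstring above requires (no padding needed for a single example).
labels = processor(text=ds[0]["text"], return_tensors="pt").input_ids
loss = model(**inputs, labels=labels).loss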
docs/transformers/build/lib/transformers/models/shieldgemma2/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_shieldgemma2 import *
22
+ from .modeling_shieldgemma2 import *
23
+ from .processing_shieldgemma2 import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/shieldgemma2/configuration_shieldgemma2.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ..auto import CONFIG_MAPPING, AutoConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class ShieldGemma2Config(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`ShieldGemma2ForImageClassification`]. It is used to instantiate a
28
+ ShieldGemma2ForImageClassification according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the shieldgemma-2-4b-it.
30
+
31
+ e.g. [google/shieldgemma-2-4b-it](https://huggingface.co/google/shieldgemma-2-4b-it)
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+ Args:
37
+ text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
38
+ The config object of the text backbone.
39
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
40
+ Custom vision config or dict.
41
+ mm_tokens_per_image (`int`, *optional*, defaults to 256):
42
+ The number of tokens per image embedding.
43
+ boi_token_index (`int`, *optional*, defaults to 255999):
44
+ The begin-of-image token index to wrap the image prompt.
45
+ eoi_token_index (`int`, *optional*, defaults to 256000):
46
+ The end-of-image token index to wrap the image prompt.
47
+ image_token_index (`int`, *optional*, defaults to 262144):
48
+ The image token index to encode the image prompt.
49
+ initializer_range (`float`, *optional*, defaults to 0.02):
50
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51
+
52
+
53
+ Example:
54
+
55
+ ```python
56
+ >>> from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Config, SiglipVisionConfig, Gemma3TextConfig
57
+
58
+ >>> # Initializing a Siglip-like vision config
59
+ >>> vision_config = SiglipVisionConfig()
60
+
61
+ >>> # Initializing a Gemma3 text config for the language backbone
62
+ >>> text_config = Gemma3TextConfig()
63
+
64
+ >>> # Initializing a ShieldGemma2 gemma-3-4b style configuration
65
+ >>> configuration = ShieldGemma2Config(vision_config=vision_config, text_config=text_config)
66
+
67
+ >>> # Initializing a model from the gemma-3-4b style configuration
68
+ >>> model = ShieldGemma2ForImageClassification(configuration)
69
+
70
+ >>> # Accessing the model configuration
71
+ >>> configuration = model.config
72
+ ```"""
73
+
74
+ model_type = "shieldgemma2"
75
+ attribute_map = {
76
+ "image_token_id": "image_token_index",
77
+ "boi_token_id": "boi_token_index",
78
+ "eoi_token_id": "eoi_token_index",
79
+ }
80
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
81
+
82
+ def __init__(
83
+ self,
84
+ text_config=None,
85
+ vision_config=None,
86
+ mm_tokens_per_image: int = 256,
87
+ boi_token_index: int = 255_999,
88
+ eoi_token_index: int = 256_000,
89
+ image_token_index: int = 262_144,
90
+ initializer_range: float = 0.02,
91
+ **kwargs,
92
+ ):
93
+ if isinstance(vision_config, dict):
94
+ vision_config["model_type"] = (
95
+ vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
96
+ )
97
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
98
+ elif vision_config is None:
99
+ vision_config = CONFIG_MAPPING["siglip_vision_model"]()
100
+
101
+ self.vision_config = vision_config
102
+
103
+ if isinstance(text_config, dict):
104
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma3_text"
105
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
106
+ elif text_config is None:
107
+ text_config = CONFIG_MAPPING["gemma3_text"]()
108
+
109
+ self.text_config = text_config
110
+ self.vision_config = vision_config
111
+ self.mm_tokens_per_image = mm_tokens_per_image
112
+ self.boi_token_index = boi_token_index
113
+ self.eoi_token_index = eoi_token_index
114
+ self.image_token_index = image_token_index
115
+ self.initializer_range = initializer_range
116
+
117
+ super().__init__(**kwargs)
118
+
119
+
120
+ __all__ = ["ShieldGemma2Config"]
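The `__init__` above accepts each sub-config as a config object, a plain dict, or `None` (falling back to the `gemma3_text` and `siglip_vision_model` defaults). A small sketch with illustrative values (not the released model's sizes):

from transformers import ShieldGemma2Config

# Both sub-configs given as dicts; "model_type" is filled in with the defaults noted above.
config = ShieldGemma2Config(
    text_config={"hidden_size": 256, "num_hidden_layers": 2, "num_attention_heads": 4, "num_key_value_heads": 1},
    vision_config={"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "image_size": 224, "patch_size": 16},
)

print(type(config.text_config).__name__)    # Gemma3TextConfig
print(type(config.vision_config).__name__)  # SiglipVisionConfig
print(config.boi_token_id)                  # 255999, resolved through the attribute_map alias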
docs/transformers/build/lib/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py ADDED
@@ -0,0 +1,470 @@
1
+ r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
2
+
3
+ python -m transformers.models.shieldgemma2.convert_shieldgemma2_weights_orbax_to_hf \
4
+ --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
5
+ --checkpoint_path_gemma="$HOME/gemma3/gemma3_4b_pt_orbax/" \
6
+ --checkpoint_path_shieldgemma="$HOME/shieldgemma2/shieldgemma-2_4b_orbax/" \
7
+ --output_path="$HOME/shieldgemma2/shieldgemma2_4b_pt_safetensors/" \
8
+ --precision='bfloat16'
9
+ """
10
+
11
+ import dataclasses
12
+ from collections.abc import Iterator, Mapping, Sequence
13
+ from typing import Any
14
+
15
+ import accelerate
16
+ import numpy as np
17
+ import torch
18
+ import tree
19
+ from absl import app, flags, logging
20
+ from orbax import checkpoint as obc
21
+
22
+ from ...image_utils import PILImageResampling
23
+ from ..gemma import GemmaTokenizerFast
24
+ from ..gemma3 import Gemma3ImageProcessor, Gemma3TextConfig
25
+ from ..siglip import SiglipVisionConfig
26
+ from . import ShieldGemma2Config, ShieldGemma2ForImageClassification, ShieldGemma2Processor
27
+
28
+
29
+ # ==== Internal Constants and Classes ====
30
+
31
+ _CHAT_TEMPLATE = """{{ bos_token }}
32
+ {%- for message in messages -%}
33
+ {%- if message['content'] is iterable -%}
34
+ {%- for item in message['content'] -%}
35
+ {{ '<start_of_turn>user\n' }}
36
+ {%- if item['type'] == 'image' -%}
37
+ {{ '<start_of_image>' }}
38
+ {%- elif item['type'] == 'text' -%}
39
+ {{ 'You are a policy expert trying to help determine whether a generated or real image is in' }}
40
+ {{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}}
41
+ {{ ' language used in the safety policy and should not include any interpretation outside of the' }}
42
+ {{ ' safety policy.\n\n' }}
43
+ {{ '<BEGIN_SAFETY_POLICY>\n' }}
44
+ {{ item['text'] + '\n' }}
45
+ {{ '</END_SAFETY_POLICY>\n\n' }}
46
+ {{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }}
47
+ {{ '<end_of_turn>\n' }}
48
+ {%- endif -%}
49
+ {%- endfor -%}
50
+ {{'<start_of_turn>model\n'}}
51
+ {%- else -%}
52
+ {{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }}
53
+ {%- endif -%}
54
+ {%- endfor -%}
55
+ """
56
+
57
+ _DTYPES = {
58
+ "float32": torch.float32,
59
+ "bfloat16": torch.bfloat16,
60
+ "float16": torch.float16,
61
+ }
62
+
63
+ _SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
64
+ "dangerous": (
65
+ "No Dangerous Content: The image shall not contain content that"
66
+ " facilitates or encourages activities that could cause real-world harm"
67
+ " (e.g., building firearms and explosive devices, promotion of"
68
+ " terrorism, instructions for suicide)."
69
+ ),
70
+ "sexual": (
71
+ "No Sexually Explicit content: The image shall not contain content that"
72
+ " depicts explicit or graphic sexual acts (e.g., pornography, erotic"
73
+ " nudity, depictions of rape or sexual assault)."
74
+ ),
75
+ "violence": (
76
+ "No Violence/Gore content: The image shall not contain content that"
77
+ " depicts shocking, sensational, or gratuitous violence (e.g.,"
78
+ " excessive blood and gore, gratuitous violence against animals,"
79
+ " extreme injury or moment of death)."
80
+ ),
81
+ }
82
+
83
+ _SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder"
84
+ _SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding"
85
+ _SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_"
86
+ _SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
87
+ _SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"
88
+
89
+ _TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
90
+ _TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
91
+ _TRANSFORMER_EMBEDDER = "transformer/embedder"
92
+ _TRANSFORMER_FINAL_NORM = "transformer/final_norm"
93
+ _TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
94
+ _TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
95
+
96
+ # ==== Flags ====
97
+
98
+ _GEMMA_CHECKPOINT_PATH = flags.DEFINE_string(
99
+ name="checkpoint_path_gemma",
100
+ default=None,
101
+ help="Path to the Orbax checkpoint containing the vision weights.",
102
+ required=True,
103
+ )
104
+
105
+ _SHIELDGEMMA_CHECKPOINT_PATH = flags.DEFINE_string(
106
+ name="checkpoint_path_shieldgemma",
107
+ default=None,
108
+ help="Path to the Orbax checkpoint containing the language model weights.",
109
+ required=True,
110
+ )
111
+
112
+ OUTPUT_PATH = flags.DEFINE_string(
113
+ name="output_path",
114
+ default=None,
115
+ help="Path to store the HF checkpoint.",
116
+ required=True,
117
+ )
118
+
119
+ PRECISION = flags.DEFINE_enum(
120
+ name="precision",
121
+ default=None,
122
+ help="The floating point precision (aka dtype) of the model.",
123
+ enum_values=set(_DTYPES.keys()),
124
+ required=True,
125
+ )
126
+
127
+ TOKENIZER_PATH = flags.DEFINE_string(
128
+ name="tokenizer_path",
129
+ default=None,
130
+ help="Path to the SentencePiece model file.",
131
+ required=True,
132
+ )
133
+
134
+
135
+ def convert_siglip_weight(
136
+ config: SiglipVisionConfig,
137
+ paths: Sequence[str],
138
+ weights: np.ndarray,
139
+ ) -> tuple[str, np.ndarray]:
140
+ path, prop = paths
141
+ normalized_path: str = ""
142
+ updated_weights: np.ndarray = None
143
+
144
+ if path == _SIGLIP_BASE:
145
+ normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight"
146
+ updated_weights = weights.reshape(-1, config.hidden_size)
147
+ elif path == _SIGLIP_EMBEDDING:
148
+ if prop == "kernel":
149
+ normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight"
150
+ updated_weights = weights.transpose(3, 2, 0, 1)
151
+ elif prop == "bias":
152
+ normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias"
153
+ updated_weights = weights
154
+ else:
155
+ raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
156
+ elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK):
157
+ encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:]
158
+ next_path_seperator_idx = encoder_block_path.find("/")
159
+ layer_idx = encoder_block_path[:next_path_seperator_idx]
160
+ encoder_block_path = encoder_block_path[next_path_seperator_idx:]
161
+ normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
162
+
163
+ if encoder_block_path.startswith("/LayerNorm"):
164
+ normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2"
165
+
166
+ if prop == "scale":
167
+ normalized_path += ".weight"
168
+ updated_weights = weights.transpose()
169
+ elif prop == "bias":
170
+ normalized_path += ".bias"
171
+ updated_weights = weights
172
+ else:
173
+ raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
174
+ elif encoder_block_path.startswith("/MlpBlock_0"):
175
+ normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2"
176
+
177
+ if prop == "kernel":
178
+ normalized_path += ".weight"
179
+ updated_weights = weights.transpose()
180
+ elif prop == "bias":
181
+ normalized_path += ".bias"
182
+ updated_weights = weights
183
+ else:
184
+ raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
185
+ elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"):
186
+ if encoder_block_path.endswith("/key"):
187
+ normalized_path += ".self_attn.k_proj"
188
+ elif encoder_block_path.endswith("/out"):
189
+ normalized_path += ".self_attn.out_proj"
190
+ elif encoder_block_path.endswith("/query"):
191
+ normalized_path += ".self_attn.q_proj"
192
+ elif encoder_block_path.endswith("/value"):
193
+ normalized_path += ".self_attn.v_proj"
194
+ else:
195
+ raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.")
196
+
197
+ if prop == "bias":
198
+ normalized_path += ".bias"
199
+ updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1)
200
+ elif prop == "kernel":
201
+ normalized_path += ".weight"
202
+ updated_weights = weights.reshape(-1, config.hidden_size).transpose()
203
+ else:
204
+ raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
205
+ else:
206
+ raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.")
207
+ elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM:
208
+ if prop == "scale":
209
+ normalized_path = "vision_tower.vision_model.post_layernorm.weight"
210
+ updated_weights = weights.transpose()
211
+ elif prop == "bias":
212
+ normalized_path = "vision_tower.vision_model.post_layernorm.bias"
213
+ updated_weights = weights
214
+ else:
215
+ raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
216
+ else:
217
+ raise ValueError(f"Unexpected path `{path}`.")
218
+
219
+ if "vision" in normalized_path:
220
+ print(normalized_path)
221
+ return normalized_path, updated_weights
222
+
223
+
224
+ def convert_transformer_weights(
225
+ config: Gemma3TextConfig,
226
+ paths: Sequence[str],
227
+ weights: np.ndarray,
228
+ ) -> Iterator[tuple[str, np.ndarray]]:
229
+ path, prop = paths
230
+
231
+ if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
232
+ path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
233
+
234
+ converted_paths: list[str] = []
235
+ converted_weights: list[Any] = []
236
+
237
+ attn_head_dim = config.num_attention_heads * config.head_dim
238
+ kv_head_dim = config.num_key_value_heads * config.head_dim
239
+
240
+ if path == _TRANSFORMER_EMBEDDER:
241
+ if prop == "input_embedding":
242
+ # Tied to language_model.lm_head.weight, assigned at the end.
243
+ converted_paths = ["language_model.model.embed_tokens.weight"]
244
+ # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
245
+ pre_expansion_embeddings = weights
246
+ mu = np.mean(pre_expansion_embeddings, axis=0)
247
+ sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
248
+ new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
249
+ weights = np.vstack([pre_expansion_embeddings, new_embeddings])
250
+ converted_weights = [weights]
251
+ else:
252
+ raise ValueError(f"Unexpected member, {prop}, in Embedder.")
253
+ elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
254
+ if path.endswith("/mm_input_projection"):
255
+ converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
256
+ converted_weights = [weights]
257
+ elif path.endswith("/mm_soft_embedding_norm"):
258
+ converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
259
+ converted_weights = [weights]
260
+ else:
261
+ raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
262
+ elif path == _TRANSFORMER_FINAL_NORM:
263
+ converted_paths = ["language_model.model.norm.weight"]
264
+ converted_weights = [weights]
265
+ elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
266
+ decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
267
+ next_path_seperator_idx = decoder_block_path.find("/")
268
+ layer_idx = decoder_block_path[:next_path_seperator_idx]
269
+ decoder_block_path = decoder_block_path[next_path_seperator_idx:]
270
+
271
+ base_path = f"language_model.model.layers.{layer_idx}"
272
+
273
+ if path.endswith("attn/attn_vec_einsum"):
274
+ converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
275
+ converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)]
276
+ elif path.endswith("attn/_key_norm"):
277
+ converted_paths = [f"{base_path}.self_attn.k_norm.weight"]
278
+ converted_weights = [weights]
279
+ elif path.endswith("attn/kv_einsum"):
280
+ converted_paths = [
281
+ f"{base_path}.self_attn.k_proj.weight",
282
+ f"{base_path}.self_attn.v_proj.weight",
283
+ ]
284
+ k_proj_weights, v_proj_weights = weights
285
+ converted_weights = [
286
+ k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
287
+ v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
288
+ ]
289
+ elif path.endswith("attn/q_einsum"):
290
+ converted_paths = [f"{base_path}.self_attn.q_proj.weight"]
291
+ converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)]
292
+ elif path.endswith("attn/_query_norm"):
293
+ converted_paths = [f"{base_path}.self_attn.q_norm.weight"]
294
+ converted_weights = [weights]
295
+ elif path.endswith("mlp/gating_einsum"):
296
+ converted_paths = [
297
+ f"{base_path}.mlp.gate_proj.weight",
298
+ f"{base_path}.mlp.up_proj.weight",
299
+ ]
300
+ gate_proj_weight, up_proj_weight = weights
301
+ converted_weights = [gate_proj_weight, up_proj_weight]
302
+ elif path.endswith("mlp/linear"):
303
+ converted_paths = [f"{base_path}.mlp.down_proj.weight"]
304
+ converted_weights = [weights.transpose()]
305
+ elif path.endswith("post_attention_norm"):
306
+ converted_paths = [f"{base_path}.post_attention_layernorm.weight"]
307
+ converted_weights = [weights]
308
+ elif path.endswith("post_ffw_norm"):
309
+ converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"]
310
+ converted_weights = [weights]
311
+ elif path.endswith("pre_attention_norm"):
312
+ converted_paths = [f"{base_path}.input_layernorm.weight"]
313
+ converted_weights = [weights]
314
+ elif path.endswith("pre_ffw_norm"):
315
+ converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"]
316
+ converted_weights = [weights]
317
+ else:
318
+ raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
319
+ else:
320
+ raise ValueError(f"Unexpected path `{path}`.")
321
+
322
+ if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
323
+ raise ValueError(
324
+ "The `converted_paths` and `converted_weights` should be the same "
325
+ f"length. Got {cpl} and {cwl}, respectively, for {path}."
326
+ )
327
+
328
+ return zip(converted_paths, converted_weights)
329
+
330
+
331
+ def transpose_reshape(x: torch.Tensor) -> torch.Tensor:
332
+ x = x.transpose(1, 2)
333
+ return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous()
334
+
335
+
336
+ @dataclasses.dataclass(frozen=True)
337
+ class ConversionResult:
338
+ state_tree: dict[str, torch.Tensor]
339
+ config: ShieldGemma2Config
340
+
341
+
342
+ def convert(
343
+ shieldgemma_checkpoint_path: str,
344
+ gemma_checkpoint_path: str,
345
+ config: ShieldGemma2Config,
346
+ target_dtype: torch.dtype,
347
+ ) -> ConversionResult:
348
+ """Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
349
+ checkpointer = obc.PyTreeCheckpointer()
350
+
351
+ sg2_ckpt = checkpointer.restore(shieldgemma_checkpoint_path)
352
+ g3_ckpt = checkpointer.restore(gemma_checkpoint_path)
353
+
354
+ hf_tree: dict[str, torch.Tensor] = {}
355
+
356
+ def update_tree(path: str, weights: np.ndarray) -> None:
357
+ torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype)
358
+ logging.info(
359
+ "%s converted shape=%s with dtype=%s",
360
+ path,
361
+ weights.shape,
362
+ torch_tensor.dtype,
363
+ )
364
+ hf_tree[f"model.{path}"] = torch_tensor
365
+
366
+ for paths, value in tree.flatten_with_path(g3_ckpt):
367
+ if paths[0].startswith("SigLiPFromPatches_"):
368
+ path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
369
+ update_tree(path, weights)
370
+
371
+ for paths, value in tree.flatten_with_path(sg2_ckpt):
372
+ for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
373
+ update_tree(path, weights)
374
+
375
+ hf_tree["model.language_model.lm_head.weight"] = hf_tree["model.language_model.model.embed_tokens.weight"]
376
+
377
+ return ConversionResult(state_tree=hf_tree, config=config)
378
+
379
+
380
+ def main(*args):
381
+ del args
382
+
383
+ dtype = getattr(torch, PRECISION.value)
384
+ output_path = OUTPUT_PATH.value
385
+
386
+ tokenizer = GemmaTokenizerFast(
387
+ TOKENIZER_PATH.value,
388
+ extra_special_tokens={
389
+ "image_token": "<image_soft_token>", # Should be ID=262_144
390
+ "boi_token": "<start_of_image>", # Should be ID=255_999
391
+ "eoi_token": "<end_of_image>", # Should be ID=256_000
392
+ },
393
+ )
394
+
395
+ yes_token_index, no_token_index = torch.tensor(tokenizer(["Yes", "No"])["input_ids"])[:, 1].numpy()
396
+
397
+ config = ShieldGemma2Config(
398
+ yes_token_index=int(yes_token_index),
399
+ no_token_index=int(no_token_index),
400
+ text_config=Gemma3TextConfig(
401
+ vocab_size=262_208,
402
+ hidden_size=2560,
403
+ intermediate_size=2560 * 8 // 2,
404
+ num_attention_heads=8,
405
+ head_dim=256,
406
+ num_hidden_layers=34,
407
+ num_key_value_heads=4,
408
+ sliding_window=1024,
409
+ rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
410
+ rope_theta=1_000_000,
411
+ rope_local_base_freq=10_000,
412
+ attn_logit_softcapping=None,
413
+ query_pre_attn_scalar=256,
414
+ max_position_embeddings=8192,
415
+ ),
416
+ vision_config={
417
+ "hidden_size": 1152,
418
+ "intermediate_size": 4304,
419
+ "num_hidden_layers": 27,
420
+ "num_attention_heads": 16,
421
+ "num_channels": 3,
422
+ "image_size": 896,
423
+ "patch_size": 14,
424
+ "hidden_act": "gelu_pytorch_tanh",
425
+ "layer_norm_eps": 1e-6,
426
+ "attention_dropout": 0.0,
427
+ "vision_use_head": False,
428
+ },
429
+ )
430
+
431
+ config.save_pretrained(output_path)
432
+
433
+ image_processor = Gemma3ImageProcessor(
434
+ image_seq_length=256,
435
+ image_mean=(0.5,) * 3,
436
+ image_std=(0.5,) * 3,
437
+ size={"height": 896, "width": 896},
438
+ resample=PILImageResampling.BILINEAR,
439
+ )
440
+ processor = ShieldGemma2Processor(
441
+ image_processor=image_processor,
442
+ tokenizer=tokenizer,
443
+ policy_definitions=_SHIELDGEMMA2_POLICIES,
444
+ )
445
+ tokenizer.chat_template = _CHAT_TEMPLATE
446
+ processor.chat_template = _CHAT_TEMPLATE
447
+
448
+ processor.save_pretrained(output_path)
449
+ logging.info("Saved Shieldgemma2Processor to %s", output_path)
450
+ del processor
451
+ del tokenizer
452
+
453
+ logging.info("Converting Shieldgemma2 @ %s", dtype)
454
+ result = convert(_SHIELDGEMMA_CHECKPOINT_PATH.value, _GEMMA_CHECKPOINT_PATH.value, config, dtype)
455
+ logging.info("Converted Shieldgemma2 state tree from Orbax to Hugging Face.")
456
+
457
+ with accelerate.init_empty_weights():
458
+ model = ShieldGemma2ForImageClassification(config=config)
459
+
460
+ model.load_state_dict(result.state_tree, assign=True, strict=True)
461
+ model.config.torch_dtype = dtype
462
+ logging.info("Loaded Shieldgemma2 in Hugging Face Transformers.")
463
+ model.save_pretrained(output_path, safe_serialization=True)
464
+ logging.info("Saved Shieldgemma2 to SafeTensors in %s", output_path)
465
+ del model
466
+ del result
467
+
468
+
469
+ if __name__ == "__main__":
470
+ app.run(main)
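The `_TRANSFORMER_EMBEDDER` branch in `convert_transformer_weights` above grows the input embedding table by 64 rows drawn from a Gaussian fitted to the existing rows, so the new image soft tokens start close to the embedding distribution. A standalone numpy sketch of that step on a toy table (shapes are illustrative only):

import numpy as np

rng = np.random.default_rng(0)
pre_expansion = rng.normal(size=(1000, 8)).astype(np.float32)      # toy (vocab, hidden) table

mu = np.mean(pre_expansion, axis=0)                                # per-dimension mean
sigma = np.cov(pre_expansion, rowvar=False, bias=True)             # (hidden, hidden) covariance
new_rows = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
expanded = np.vstack([pre_expansion, new_rows])

assert expanded.shape == (1064, 8)                                  # 64 extra rows appended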
docs/transformers/build/lib/transformers/models/shieldgemma2/modeling_shieldgemma2.py ADDED
@@ -0,0 +1,220 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+
22
+ from ...cache_utils import Cache
23
+ from ...modeling_outputs import ImageClassifierOutputWithNoAttention
24
+ from ...modeling_utils import PreTrainedModel
25
+ from ...utils import (
26
+ add_start_docstrings_to_model_forward,
27
+ logging,
28
+ )
29
+ from ..auto import AutoModelForImageTextToText
30
+ from .configuration_shieldgemma2 import ShieldGemma2Config
31
+
32
+
33
+ _CHECKPOINT_FOR_DOC = "google/shieldgemma-2-4b-it"
34
+ _CONFIG_FOR_DOC = "ShieldGemma2Config"
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ SHIELDGEMMA2_INPUTS_DOCSTRING = r"""
39
+ Args:
40
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
41
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
42
+ it.
43
+
44
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
45
+ [`PreTrainedTokenizer.__call__`] for details.
46
+
47
+ [What are input IDs?](../glossary#input-ids)
48
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
49
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
50
+
51
+ - 1 for tokens that are **not masked**,
52
+ - 0 for tokens that are **masked**.
53
+
54
+ [What are attention masks?](../glossary#attention-mask)
55
+
56
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
57
+ [`PreTrainedTokenizer.__call__`] for details.
58
+
59
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
60
+ `past_key_values`).
61
+
62
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
63
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
64
+ information on the default strategy.
65
+
66
+ - 1 indicates the head is **not masked**,
67
+ - 0 indicates the head is **masked**.
68
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
69
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
70
+ config.n_positions - 1]`.
71
+
72
+ [What are position IDs?](../glossary#position-ids)
73
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
74
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
75
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
76
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
77
+
78
+ Two formats are allowed:
79
+ - a [`~cache_utils.Cache`] instance, see our
80
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
81
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
82
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
83
+ cache format.
84
+
85
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
86
+ legacy cache format will be returned.
87
+
88
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
89
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
90
+ of shape `(batch_size, sequence_length)`.
91
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
92
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
93
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
94
+ model's internal embedding lookup matrix.
95
+ use_cache (`bool`, *optional*):
96
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
97
+ `past_key_values`).
98
+ output_attentions (`bool`, *optional*):
99
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
100
+ tensors for more detail.
101
+ output_hidden_states (`bool`, *optional*):
102
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
103
+ more detail.
104
+ return_dict (`bool`, *optional*):
105
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
106
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
107
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
108
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
109
+ the complete sequence length.
110
+ """
111
+
112
+
113
+ @dataclass
114
+ class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWithNoAttention):
115
+ """ShieldGemma2 classifies images as violative or not relative to a specific policy.
116
+ Args:
+ probabilities (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
+ Softmax probabilities of the `Yes` (policy-violating) and `No` tokens, respectively.
117
+ """
118
+
119
+ probabilities: Optional[torch.Tensor] = None
120
+
121
+
122
+ class ShieldGemma2ForImageClassification(PreTrainedModel):
123
+ config_class = ShieldGemma2Config
124
+
125
+ def __init__(self, config: ShieldGemma2Config):
126
+ super().__init__(config=config)
127
+ self.yes_token_index = getattr(config, "yes_token_index", 10_784)
128
+ self.no_token_index = getattr(config, "no_token_index", 3771)
129
+ self.model = AutoModelForImageTextToText.from_config(config=config)
130
+
131
+ def get_input_embeddings(self):
132
+ return self.model.language_model.get_input_embeddings()
133
+
134
+ def set_input_embeddings(self, value):
135
+ self.model.language_model.set_input_embeddings(value)
136
+
137
+ def get_output_embeddings(self):
138
+ return self.model.language_model.get_output_embeddings()
139
+
140
+ def set_output_embeddings(self, new_embeddings):
141
+ self.model.language_model.set_output_embeddings(new_embeddings)
142
+
143
+ def set_decoder(self, decoder):
144
+ self.model.language_model.set_decoder(decoder)
145
+
146
+ def get_decoder(self):
147
+ return self.model.language_model.get_decoder()
148
+
149
+ def tie_weights(self):
150
+ return self.model.language_model.tie_weights()
151
+
152
+ @add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING)
153
+ def forward(
154
+ self,
155
+ input_ids: Optional[torch.LongTensor] = None,
156
+ pixel_values: Optional[torch.FloatTensor] = None,
157
+ attention_mask: Optional[torch.Tensor] = None,
158
+ position_ids: Optional[torch.LongTensor] = None,
159
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
160
+ token_type_ids: Optional[torch.LongTensor] = None,
161
+ cache_position: Optional[torch.LongTensor] = None,
162
+ inputs_embeds: Optional[torch.FloatTensor] = None,
163
+ labels: Optional[torch.LongTensor] = None,
164
+ use_cache: Optional[bool] = None,
165
+ output_attentions: Optional[bool] = None,
166
+ output_hidden_states: Optional[bool] = None,
167
+ return_dict: Optional[bool] = None,
168
+ logits_to_keep: Union[int, torch.Tensor] = 0,
169
+ **lm_kwargs,
170
+ ) -> ShieldGemma2ImageClassifierOutputWithNoAttention:
171
+ """Predicts the binary probability that the image violates the specified policy.
172
+
173
+ Returns:
174
+ A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance containing the logits and probabilities
175
+ associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the
176
+ following properties.
177
+
178
+ * `logits` (`torch.Tensor` of shape `(batch_size, 2)`):
179
+ The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is
180
+ the logits for the `No` token.
181
+ * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`):
182
+ The first position along dim=1 is the probability of predicting the `Yes` token and the second position
183
+ along dim=1 is the probability of predicting the `No` token.
184
+
185
+ ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the
186
+ policy as described. If you are only interested in the violative condition, use
187
+ `violated = outputs.probabilities[:, 0]` to extract that slice from the output tensors.
188
+
189
+ When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`,
190
+ and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN.
191
+ """
192
+ outputs = self.model(
193
+ input_ids=input_ids,
194
+ pixel_values=pixel_values,
195
+ attention_mask=attention_mask,
196
+ position_ids=position_ids,
197
+ past_key_values=past_key_values,
198
+ token_type_ids=token_type_ids,
199
+ cache_position=cache_position,
200
+ inputs_embeds=inputs_embeds,
201
+ labels=labels,
202
+ use_cache=use_cache,
203
+ output_attentions=output_attentions,
204
+ output_hidden_states=output_hidden_states,
205
+ return_dict=return_dict,
206
+ logits_to_keep=logits_to_keep,
207
+ **lm_kwargs,
208
+ )
209
+ logits = outputs.logits
210
+ selected_logits = logits[:, -1, [self.yes_token_index, self.no_token_index]]
211
+ probabilities = torch.softmax(selected_logits, dim=-1)
212
+ return ShieldGemma2ImageClassifierOutputWithNoAttention(
213
+ logits=selected_logits,
214
+ probabilities=probabilities,
215
+ )
216
+
217
+
218
+ __all__ = [
219
+ "ShieldGemma2ForImageClassification",
220
+ ]
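A minimal usage sketch for the classifier above, using the `google/shieldgemma-2-4b-it` checkpoint named in `_CHECKPOINT_FOR_DOC`; the image path is a placeholder:

import torch
from PIL import Image
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
processor = AutoProcessor.from_pretrained(model_id)
model = ShieldGemma2ForImageClassification.from_pretrained(model_id)

image = Image.open("example.jpg")                       # placeholder image
inputs = processor(images=[image], return_tensors="pt")  # one prompt per default policy

with torch.no_grad():
    out = model(**inputs)

# Rows follow img1_policy1 ... img1_policyN; column 0 is P("Yes" = violates), column 1 is P("No").
for policy, probs in zip(processor.policy_definitions, out.probabilities):
    print(policy, float(probs[0]))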
docs/transformers/build/lib/transformers/models/shieldgemma2/processing_shieldgemma2.py ADDED
@@ -0,0 +1,195 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from collections.abc import Mapping, Sequence
17
+ from typing import Optional
18
+
19
+ from ...feature_extraction_utils import BatchFeature
20
+ from ...image_utils import ImageInput
21
+ from ...processing_utils import Unpack
22
+ from ...utils import logging
23
+ from ..gemma3.processing_gemma3 import Gemma3Processor, Gemma3ProcessorKwargs
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
29
+ "dangerous": (
30
+ "No Dangerous Content: The image shall not contain content that"
31
+ " facilitates or encourages activities that could cause real-world harm"
32
+ " (e.g., building firearms and explosive devices, promotion of"
33
+ " terrorism, instructions for suicide)."
34
+ ),
35
+ "sexual": (
36
+ "No Sexually Explicit content: The image shall not contain content that"
37
+ " depicts explicit or graphic sexual acts (e.g., pornography, erotic"
38
+ " nudity, depictions of rape or sexual assault)."
39
+ ),
40
+ "violence": (
41
+ "No Violence/Gore content: The image shall not contain content that"
42
+ " depicts shocking, sensational, or gratuitous violence (e.g.,"
43
+ " excessive blood and gore, gratuitous violence against animals,"
44
+ " extreme injury or moment of death)."
45
+ ),
46
+ }
47
+
48
+
49
+ class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
50
+ policies: Optional[Sequence[str]]
51
+ custom_policies: Optional[Mapping[str, str]]
52
+ _defaults = {
53
+ "text_kwargs": {
54
+ "padding": True,
55
+ },
56
+ "images_kwargs": {
57
+ "do_pan_and_scan": False,
58
+ },
59
+ }
60
+
61
+
62
+ class ShieldGemma2Processor(Gemma3Processor):
63
+ def __init__(
64
+ self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs
65
+ ):
66
+ """A processor for the ShieldGemma 2 model.
67
+
68
+ Args:
69
+ image_processor: The image processor to use, typically a `Gemma3ImageProcessorFast` instance.
70
+ tokenizer: The tokenizer to use, typically a `GemmaTokenizerFast` instance.
71
+ chat_template: The chat template to use with this processor. Typically, this is unset as the processor
72
+ configuration on Hugging Face Hub includes this value already.
73
+ image_seq_length: The number of soft tokens per image. Typically, this is unset as the processor
74
+ configuration on Hugging Face Hub includes this value already.
75
+ policy_definitions: A mapping from policy name to its text description, used as the default set of policies to
76
+ classify images against. The policy descriptions are included in the text of the prompts generated by
77
+ this processor. Typically, this is unset as the processor configuration on Hugging Face Hub includes
78
+ the base policies ShieldGemma was trained on.
79
+ """
80
+ super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
81
+ if policy_definitions is None:
82
+ self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES
83
+ else:
84
+ self.policy_definitions = policy_definitions
85
+
86
+ def __call__(
87
+ self,
88
+ images: ImageInput = None,
89
+ text=None,
90
+ videos=None,
91
+ audio=None,
92
+ **kwargs: Unpack[ShieldGemma2ProcessorKwargs],
93
+ ) -> BatchFeature:
94
+ """Generates a batch of inputs from the provided images.
95
+
96
+ ShieldGemma was trained to classify image content for policy compliance using a specific prompt construction.
97
+ This processor generates a batch of such prompts from the provided images by:
98
+
99
+ 1. Creating a list of conversations, one for each `<image, policy>` pair;
100
+ 2. Converting these conversations to text using `self.apply_chat_template()`; and
101
+ 3. Encoding the conversations and images using the same techniques as `Gemma3Processor`.
102
+
103
+ Args:
104
+ images: A single image or a list of images to include in the batch.
105
+ text: Not supported.
106
+ videos: Not supported.
107
+ audio: Not supported.
108
+ kwargs: An optional dictionary of keyword arguments to configure the
109
+ processor. Possible values include:
110
+
111
+ * `custom_policies`: Additional policy definitions that augment the `self.policy_definitions` passed
112
+ into the constructor. Note that `custom_policies` that share a key with `self.policy_definitions`
113
+ will override the policy description
114
+ * `policies`: (Optional) a list of keys in the joint `self.policy_definitions | custom_policies`
115
+ dictionary of specific interest for the provided images. If empty or None, prompts will be
116
+ generated for every key in the joint dictionary.
117
+
118
+ Returns:
119
+ A `BatchFeature` containing `input_ids`, `pixel_values`, etc. where each Tensor is of shape
120
+ `(len(images) * len(policies), )`, and the order within the batch will be
121
+ img1_policy1, ... img1_policyN, ... imgM_policyN.
122
+ """
123
+ del text, videos, audio
124
+
125
+ if not images:
126
+ raise ValueError("ShieldGemma 2 needs images to classify")
127
+ elif not isinstance(images, Sequence):
128
+ images = [images]
129
+
130
+ if not self.chat_template:
131
+ raise ValueError("ShieldGemma 2 requires the use of a specific chat template")
132
+
133
+ # Disable pan and scan
134
+ images_kwargs = kwargs.setdefault("images_kwargs", {})
135
+ if images_kwargs.get("do_pan_and_scan") is True:
136
+ logger.warning_once("ShieldGemma2 does not support pan and scan.")
137
+ images_kwargs["do_pan_and_scan"] = False
138
+
139
+ # Enable padding on the batch during tokenization
140
+ text_kwargs = kwargs.setdefault("text_kwargs", {})
141
+ if "padding" not in text_kwargs:
142
+ text_kwargs["padding"] = kwargs.pop("padding", True)
143
+ text_kwargs["padding_side"] = kwargs.pop("padding_side", "left")
144
+
145
+ policy_definitions: Mapping[str, str] = {
146
+ **self.policy_definitions,
147
+ **kwargs.get("custom_policies", {}),
148
+ }
149
+
150
+ if (policies := kwargs.get("policies")) is None:
151
+ policies = list(policy_definitions.keys())
152
+
153
+ # TODO(ryanmullins): Support images from PIL or URLs.
154
+ messages = []
155
+ expanded_images = []
156
+ for img in images:
157
+ for policy in policies:
158
+ messages.append(
159
+ [
160
+ {
161
+ "role": "user",
162
+ "content": [
163
+ {"type": "image"},
164
+ {"type": "text", "text": policy_definitions[policy]},
165
+ ],
166
+ }
167
+ ]
168
+ )
169
+ expanded_images.append([img])
170
+
171
+ text = self.apply_chat_template(messages, tokenize=False)
172
+ return super().__call__(images=expanded_images, text=text, **kwargs)
173
+
174
+ def batch_decode(self, *args, **kwargs):
175
+ """
176
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
177
+ refer to the docstring of this method for more information.
178
+ """
179
+ return self.tokenizer.batch_decode(*args, **kwargs)
180
+
181
+ def decode(self, *args, **kwargs):
182
+ """
183
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
184
+ the docstring of this method for more information.
185
+ """
186
+ return self.tokenizer.decode(*args, **kwargs)
187
+
188
+ @property
189
+ def model_input_names(self):
190
+ tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
191
+ image_processor_input_names = self.image_processor.model_input_names
192
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
193
+
194
+
195
+ __all__ = ["ShieldGemma2Processor"]
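A sketch of the `policies` and `custom_policies` keywords described in `__call__` above; the checkpoint name and image paths are placeholders:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/shieldgemma-2-4b-it")  # assumed checkpoint
images = [Image.open("a.jpg"), Image.open("b.jpg")]                      # placeholder images

# A custom policy augments (or overrides) the defaults passed into the constructor.
custom_policies = {"weapons": "No Weapons: The image shall not depict usable firearms."}

# Restricting to two policy keys gives batch size len(images) * len(policies) = 4,
# ordered img1_dangerous, img1_weapons, img2_dangerous, img2_weapons.
batch = processor(
    images=images,
    custom_policies=custom_policies,
    policies=["dangerous", "weapons"],
    return_tensors="pt",
)
print(batch["input_ids"].shape[0])  # 4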
docs/transformers/build/lib/transformers/models/siglip/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_siglip import *
22
+ from .image_processing_siglip import *
23
+ from .image_processing_siglip_fast import *
24
+ from .modeling_siglip import *
25
+ from .processing_siglip import *
26
+ from .tokenization_siglip import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/siglip/configuration_siglip.py ADDED
@@ -0,0 +1,269 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Siglip model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class SiglipTextConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
27
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
29
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 32000):
36
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
37
+ the `inputs_ids` passed when calling [`SiglipModel`].
38
+ hidden_size (`int`, *optional*, defaults to 768):
39
+ Dimensionality of the encoder layers and the pooler layer.
40
+ intermediate_size (`int`, *optional*, defaults to 3072):
41
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
42
+ num_hidden_layers (`int`, *optional*, defaults to 12):
43
+ Number of hidden layers in the Transformer encoder.
44
+ num_attention_heads (`int`, *optional*, defaults to 12):
45
+ Number of attention heads for each attention layer in the Transformer encoder.
46
+ max_position_embeddings (`int`, *optional*, defaults to 64):
47
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
48
+ just in case (e.g., 512 or 1024 or 2048).
49
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
50
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
51
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
52
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
53
+ The epsilon used by the layer normalization layers.
54
+ attention_dropout (`float`, *optional*, defaults to 0.0):
55
+ The dropout ratio for the attention probabilities.
56
+ pad_token_id (`int`, *optional*, defaults to 1):
57
+ The id of the padding token in the vocabulary.
58
+ bos_token_id (`int`, *optional*, defaults to 49406):
59
+ The id of the beginning-of-sequence token in the vocabulary.
60
+ eos_token_id (`int`, *optional*, defaults to 49407):
61
+ The id of the end-of-sequence token in the vocabulary.
62
+ projection_size (`int`, *optional*, defaults to `hidden_size`):
63
+ The size of the projection head.
64
+
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
69
+
70
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
71
+ >>> configuration = SiglipTextConfig()
72
+
73
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
74
+ >>> model = SiglipTextModel(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "siglip_text_model"
81
+ base_config_key = "text_config"
82
+
83
+ def __init__(
84
+ self,
85
+ vocab_size=32000,
86
+ hidden_size=768,
87
+ intermediate_size=3072,
88
+ num_hidden_layers=12,
89
+ num_attention_heads=12,
90
+ max_position_embeddings=64,
91
+ hidden_act="gelu_pytorch_tanh",
92
+ layer_norm_eps=1e-6,
93
+ attention_dropout=0.0,
94
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
95
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
96
+ pad_token_id=1,
97
+ bos_token_id=49406,
98
+ eos_token_id=49407,
99
+ projection_size=None,
100
+ **kwargs,
101
+ ):
102
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
103
+
104
+ self.vocab_size = vocab_size
105
+ self.hidden_size = hidden_size
106
+ self.intermediate_size = intermediate_size
107
+ self.num_hidden_layers = num_hidden_layers
108
+ self.num_attention_heads = num_attention_heads
109
+ self.max_position_embeddings = max_position_embeddings
110
+ self.layer_norm_eps = layer_norm_eps
111
+ self.hidden_act = hidden_act
112
+ self.attention_dropout = attention_dropout
113
+ self.projection_size = projection_size if projection_size is not None else hidden_size
114
+
115
+
116
+ class SiglipVisionConfig(PretrainedConfig):
117
+ r"""
118
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
119
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
120
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
121
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
122
+
123
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
124
+ documentation from [`PretrainedConfig`] for more information.
125
+
126
+ Args:
127
+ hidden_size (`int`, *optional*, defaults to 768):
128
+ Dimensionality of the encoder layers and the pooler layer.
129
+ intermediate_size (`int`, *optional*, defaults to 3072):
130
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
131
+ num_hidden_layers (`int`, *optional*, defaults to 12):
132
+ Number of hidden layers in the Transformer encoder.
133
+ num_attention_heads (`int`, *optional*, defaults to 12):
134
+ Number of attention heads for each attention layer in the Transformer encoder.
135
+ num_channels (`int`, *optional*, defaults to 3):
136
+ Number of channels in the input images.
137
+ image_size (`int`, *optional*, defaults to 224):
138
+ The size (resolution) of each image.
139
+ patch_size (`int`, *optional*, defaults to 16):
140
+ The size (resolution) of each patch.
141
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
142
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
143
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
144
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
145
+ The epsilon used by the layer normalization layers.
146
+ attention_dropout (`float`, *optional*, defaults to 0.0):
147
+ The dropout ratio for the attention probabilities.
148
+
149
+ Example:
150
+
151
+ ```python
152
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
153
+
154
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
155
+ >>> configuration = SiglipVisionConfig()
156
+
157
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
158
+ >>> model = SiglipVisionModel(configuration)
159
+
160
+ >>> # Accessing the model configuration
161
+ >>> configuration = model.config
162
+ ```"""
163
+
164
+ model_type = "siglip_vision_model"
165
+ base_config_key = "vision_config"
166
+
167
+ def __init__(
168
+ self,
169
+ hidden_size=768,
170
+ intermediate_size=3072,
171
+ num_hidden_layers=12,
172
+ num_attention_heads=12,
173
+ num_channels=3,
174
+ image_size=224,
175
+ patch_size=16,
176
+ hidden_act="gelu_pytorch_tanh",
177
+ layer_norm_eps=1e-6,
178
+ attention_dropout=0.0,
179
+ **kwargs,
180
+ ):
181
+ super().__init__(**kwargs)
182
+
183
+ self.hidden_size = hidden_size
184
+ self.intermediate_size = intermediate_size
185
+ self.num_hidden_layers = num_hidden_layers
186
+ self.num_attention_heads = num_attention_heads
187
+ self.num_channels = num_channels
188
+ self.patch_size = patch_size
189
+ self.image_size = image_size
190
+ self.attention_dropout = attention_dropout
191
+ self.layer_norm_eps = layer_norm_eps
192
+ self.hidden_act = hidden_act
193
+
194
+
195
+ class SiglipConfig(PretrainedConfig):
196
+ r"""
197
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
198
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
199
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
200
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
201
+
202
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
203
+ documentation from [`PretrainedConfig`] for more information.
204
+
205
+ Args:
206
+ text_config (`dict`, *optional*):
207
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
208
+ vision_config (`dict`, *optional*):
209
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
210
+ kwargs (*optional*):
211
+ Dictionary of keyword arguments.
212
+
213
+ Example:
214
+
215
+ ```python
216
+ >>> from transformers import SiglipConfig, SiglipModel
217
+
218
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
219
+ >>> configuration = SiglipConfig()
220
+
221
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
222
+ >>> model = SiglipModel(configuration)
223
+
224
+ >>> # Accessing the model configuration
225
+ >>> configuration = model.config
226
+
227
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
228
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
229
+
230
+ >>> # Initializing a SiglipText and SiglipVision configuration
231
+ >>> config_text = SiglipTextConfig()
232
+ >>> config_vision = SiglipVisionConfig()
233
+
234
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
235
+ ```"""
236
+
237
+ model_type = "siglip"
238
+ sub_configs = {"text_config": SiglipTextConfig, "vision_config": SiglipVisionConfig}
239
+
240
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
241
+ super().__init__(**kwargs)
242
+
243
+ if text_config is None:
244
+ text_config = {}
245
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
246
+
247
+ if vision_config is None:
248
+ vision_config = {}
249
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
250
+
251
+ self.text_config = SiglipTextConfig(**text_config)
252
+ self.vision_config = SiglipVisionConfig(**vision_config)
253
+
254
+ self.initializer_factor = 1.0
255
+
256
+ @classmethod
257
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
258
+ r"""
259
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
260
+ model configuration.
261
+
262
+ Returns:
263
+ [`SiglipConfig`]: An instance of a configuration object
264
+ """
265
+
266
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
267
+
268
+
269
+ __all__ = ["SiglipConfig", "SiglipTextConfig", "SiglipVisionConfig"]
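
As a quick illustration of how the sub-configs above compose, here is a minimal sketch (not part of the diff; the override values are arbitrary):

```python
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# Override a few sub-config fields via plain dicts (values are illustrative only).
config = SiglipConfig(
    text_config={"hidden_size": 768, "max_position_embeddings": 64},
    vision_config={"image_size": 384, "patch_size": 16},
)
print(config.vision_config.image_size)  # 384

# Equivalent construction from explicit sub-config objects.
config = SiglipConfig.from_text_vision_configs(
    SiglipTextConfig(vocab_size=32000),
    SiglipVisionConfig(image_size=384),
)
print(config.text_config.projection_size)  # falls back to hidden_size (768) when unset
```

The conversion script below relies on the same fallback: when the text and vision towers have different hidden sizes (the `giant-opt` case), it sets `projection_size` on the text config so both towers project into a shared embedding dimension.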
docs/transformers/build/lib/transformers/models/siglip/convert_siglip_to_hf.py ADDED
@@ -0,0 +1,533 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SigLIP checkpoints from the original repository.
16
+
17
+ URL: https://github.com/google-research/big_vision/tree/main
18
+ """
19
+
20
+ import argparse
21
+ import collections
22
+ import os
23
+ from typing import Tuple
24
+
25
+ import numpy as np
26
+ import requests
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from numpy import load
30
+ from PIL import Image
31
+
32
+ from transformers import (
33
+ GemmaTokenizerFast,
34
+ SiglipConfig,
35
+ SiglipImageProcessor,
36
+ SiglipModel,
37
+ SiglipProcessor,
38
+ SiglipTokenizer,
39
+ )
40
+ from transformers.utils import logging
41
+
42
+
43
+ logging.set_verbosity_info()
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ MODEL_CONFIGS = {
48
+ "base": {
49
+ "hidden_size": 768,
50
+ "intermediate_size": 3072,
51
+ "num_hidden_layers": 12,
52
+ "num_attention_heads": 12,
53
+ },
54
+ "large": {
55
+ "hidden_size": 1024,
56
+ "intermediate_size": 4096,
57
+ "num_hidden_layers": 24,
58
+ "num_attention_heads": 16,
59
+ },
60
+ "giant-opt": {
61
+ "hidden_size": 1536,
62
+ "intermediate_size": 6144,
63
+ "num_hidden_layers": 40,
64
+ "num_attention_heads": 16,
65
+ },
66
+ "so400m": {
67
+ "hidden_size": 1152,
68
+ "intermediate_size": 4304,
69
+ "num_hidden_layers": 27,
70
+ "num_attention_heads": 16,
71
+ },
72
+ }
73
+
74
+ model_name_to_checkpoint = {
75
+ # base checkpoints
76
+ "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
77
+ "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
78
+ "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
79
+ "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
80
+ # large checkpoints
81
+ "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
82
+ "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
83
+ # multilingual checkpoint
84
+ "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
85
+ # so400m checkpoints
86
+ "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
87
+ # ----------------- v2 -----------------
88
+ # base checkpoints
89
+ "siglip2-base-patch32-256": "gv-hf/siglip2/siglip2_b32_256.npz",
90
+ "siglip2-base-patch16-224": "gv-hf/siglip2/siglip2_b16_224.npz",
91
+ "siglip2-base-patch16-256": "gv-hf/siglip2/siglip2_b16_256.npz",
92
+ "siglip2-base-patch16-384": "gv-hf/siglip2/siglip2_b16_384.npz",
93
+ "siglip2-base-patch16-512": "gv-hf/siglip2/siglip2_b16_512.npz",
94
+ # large checkpoints
95
+ "siglip2-large-patch16-256": "gv-hf/siglip2/siglip2_l16_256.npz",
96
+ "siglip2-large-patch16-384": "gv-hf/siglip2/siglip2_l16_384.npz",
97
+ "siglip2-large-patch16-512": "gv-hf/siglip2/siglip2_l16_512.npz",
98
+ # giant opt checkpoints
99
+ "siglip2-giant-opt-patch16-256": "gv-hf/siglip2/siglip2_g-opt16_256.npz",
100
+ "siglip2-giant-opt-patch16-384": "gv-hf/siglip2/siglip2_g-opt16_384.npz",
101
+ # so400m checkpoints
102
+ "siglip2-so400m-patch14-224": "gv-hf/siglip2/siglip2_so400m14_224.npz",
103
+ "siglip2-so400m-patch14-384": "gv-hf/siglip2/siglip2_so400m14_384.npz",
104
+ "siglip2-so400m-patch16-256": "gv-hf/siglip2/siglip2_so400m16_256.npz",
105
+ "siglip2-so400m-patch16-384": "gv-hf/siglip2/siglip2_so400m16_384.npz",
106
+ "siglip2-so400m-patch16-512": "gv-hf/siglip2/siglip2_so400m16_512.npz",
107
+ }
108
+
109
+ # ------------------------------------------------------------------------------------------------------
110
+ # CONFIG
111
+ # ------------------------------------------------------------------------------------------------------
112
+
113
+
114
+ def get_image_size_from_model_name(model_name: str) -> int:
115
+ if "-i18n" not in model_name:
116
+ size = model_name.split("-")[-1]
117
+ else:
118
+ size = model_name.split("-")[-2]
119
+ return int(size)
120
+
121
+
122
+ def get_patch_size_from_model_name(model_name: str) -> int:
123
+ patch_str = [x for x in model_name.split("-") if "patch" in x][0]
124
+ return int(patch_str[-2:])
125
+
126
+
127
+ def get_vocab_size_from_model_name(model_name: str) -> int:
128
+ if "siglip2" in model_name:
129
+ vocab_size = 256000
130
+ elif "-i18n" in model_name:
131
+ vocab_size = 250000
132
+ else:
133
+ vocab_size = 32000
134
+ return vocab_size
135
+
136
+
137
+ def get_vocab_file_from_model_name(model_name: str) -> str:
138
+ # get vocab file
139
+ if "i18n" in model_name:
140
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
141
+ else:
142
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
143
+ return vocab_file
144
+
145
+
146
+ def get_text_and_vision_vit_variants(model_name: str) -> Tuple[str, str]:
147
+ variant = model_name.split("-")[1] if "giant-opt" not in model_name else "giant-opt"
148
+ return {
149
+ "base": ("base", "base"),
150
+ "large": ("large", "large"),
151
+ "so400m": ("so400m", "so400m"),
152
+ # g-opt siglip2 is not symmetric
153
+ "giant-opt": ("so400m", "giant-opt"),
154
+ }[variant]
155
+
156
+
157
+ def get_siglip_config(model_name):
158
+ text_variant, vision_variant = get_text_and_vision_vit_variants(model_name)
159
+ text_config = MODEL_CONFIGS[text_variant].copy()
160
+ vision_config = MODEL_CONFIGS[vision_variant].copy()
161
+
162
+ text_config["vocab_size"] = get_vocab_size_from_model_name(model_name)
163
+ vision_config["image_size"] = get_image_size_from_model_name(model_name)
164
+ vision_config["patch_size"] = get_patch_size_from_model_name(model_name)
165
+
166
+ if text_config["hidden_size"] != vision_config["hidden_size"]:
167
+ text_config["projection_size"] = vision_config["hidden_size"]
168
+
169
+ return SiglipConfig(text_config=text_config, vision_config=vision_config)
170
+
171
+
172
+ # ------------------------------------------------------------------------------------------------------
173
+ # PROCESSING
174
+ # ------------------------------------------------------------------------------------------------------
175
+
176
+
177
+ def get_tokenizer(model_name: str) -> GemmaTokenizerFast:
178
+ if "siglip2" in model_name:
179
+ tokenizer = GemmaTokenizerFast.from_pretrained(
180
+ "google/gemma-2-9b-it",
181
+ add_bos_token=False,
182
+ add_eos_token=True,
183
+ padding_side="right",
184
+ do_lower_case=True,
185
+ # important: make tokenizer NOT return attention_mask since original one doesn't require it
186
+ model_input_names=["input_ids"],
187
+ )
188
+ else:
189
+ # for siglip v1
190
+ vocab_file = get_vocab_file_from_model_name(model_name)
191
+ # important: make tokenizer not return attention_mask since original one doesn't require it
192
+ tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
193
+ return tokenizer
194
+
195
+
196
+ def get_image_processor(model_name: str) -> SiglipImageProcessor:
197
+ image_size = get_image_size_from_model_name(model_name)
198
+ size = {"height": image_size, "width": image_size}
199
+ if "siglip2" in model_name:
200
+ image_processor = SiglipImageProcessor(size=size, resample=2) # bilinear resampling
201
+ else:
202
+ image_processor = SiglipImageProcessor(size=size)
203
+ return image_processor
204
+
205
+
206
+ # ------------------------------------------------------------------------------------------------------
207
+ # CONVERT FUNCTIONS
208
+ # ------------------------------------------------------------------------------------------------------
209
+
210
+
211
+ def split_encoderblock_layers(state_dict: dict) -> dict:
212
+ """
213
+ Split the encoderblock weight into layers. In some cases they are concatenated in
214
+ the original checkpoints.
215
+ """
216
+ # Make shallow copy
217
+ state_dict = state_dict.copy()
218
+ # Split encoderblock weight into layers
219
+ keys = list(state_dict.keys())
220
+ for key in keys:
221
+ if "/encoderblock/" in key:
222
+ weight = state_dict.pop(key)
223
+ for i, weight_i in enumerate(weight):
224
+ new_name = key.replace("encoderblock", f"encoderblock_{i}")
225
+ state_dict[new_name] = weight_i
226
+ return state_dict
227
+
228
+
229
+ def create_rename_keys(config):
230
+ rename_keys = []
231
+ # fmt: off
232
+
233
+ # vision encoder
234
+
235
+ rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
236
+ rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
237
+ rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
238
+
239
+ for i in range(config.vision_config.num_hidden_layers):
240
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
241
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
242
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
243
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
244
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
245
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
246
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
247
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
248
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
249
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
250
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
251
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
252
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
253
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
254
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
255
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
256
+
257
+ rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
258
+ rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
259
+
260
+ rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
261
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
262
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
263
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
264
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
265
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
266
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
267
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
268
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
269
+
270
+ # text encoder
271
+
272
+ rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
273
+ rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
274
+
275
+ for i in range(config.text_config.num_hidden_layers):
276
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
277
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
278
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
279
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
280
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
281
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
282
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
283
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
284
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
285
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
286
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
287
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
288
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
289
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
290
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
291
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
292
+
293
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
294
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
295
+ rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
296
+ rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
297
+
298
+ # learned temperature and bias
299
+ rename_keys.append(("params/t", "logit_scale"))
300
+ rename_keys.append(("params/b", "logit_bias"))
301
+
302
+ # fmt: on
303
+ return rename_keys
304
+
305
+
306
+ def rename_key(dct, old, new, config):
307
+ val = dct.pop(old)
308
+
309
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
310
+ val = val.reshape(-1, config.vision_config.hidden_size)
311
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
312
+ val = val.reshape(-1, config.text_config.hidden_size)
313
+
314
+ if "patch_embedding.weight" in new:
315
+ val = val.transpose(3, 2, 0, 1)
316
+ elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
317
+ val = val.T
318
+
319
+ if "position_embedding" in new and "vision" in new:
320
+ val = val.reshape(-1, config.vision_config.hidden_size)
321
+ if "position_embedding" in new and "text" in new:
322
+ val = val.reshape(-1, config.text_config.hidden_size)
323
+
324
+ if new.endswith("bias"):
325
+ val = val.reshape(-1)
326
+
327
+ dct[new] = torch.from_numpy(val)
328
+
329
+
330
+ def read_in_q_k_v_head(state_dict, config):
331
+ # read in individual input projection layers
332
+ key_proj_weight = (
333
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
334
+ .reshape(-1, config.vision_config.hidden_size)
335
+ .T
336
+ )
337
+ key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
338
+ value_proj_weight = (
339
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
340
+ .reshape(-1, config.vision_config.hidden_size)
341
+ .T
342
+ )
343
+ value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
344
+ query_proj_weight = (
345
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
346
+ .reshape(-1, config.vision_config.hidden_size)
347
+ .T
348
+ )
349
+ query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
350
+
351
+ # next, add them to the state dict as a single matrix + vector
352
+ state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
353
+ np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
354
+ )
355
+ state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
356
+ np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
357
+ )
358
+
359
+
360
+ # We will verify our results on an image of cute cats
361
+ def prepare_img():
362
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
363
+ image = Image.open(requests.get(url, stream=True).raw)
364
+ return image
365
+
366
+
367
+ def flatten_nested_dict(params, parent_key="", sep="/"):
368
+ items = []
369
+
370
+ for k, v in params.items():
371
+ new_key = parent_key + sep + k if parent_key else k
372
+
373
+ if isinstance(v, collections.abc.MutableMapping):
374
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
375
+ else:
376
+ items.append((new_key, v))
377
+ return dict(items)
378
+
379
+
380
+ @torch.no_grad()
381
+ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
382
+ """
383
+ Copy/paste/tweak model's weights to our SigLIP structure.
384
+ """
385
+
386
+ # Define default SigLIP configuration
387
+ config = get_siglip_config(model_name)
388
+
389
+ # Get checkpoint
390
+ checkpoint = model_name_to_checkpoint[model_name]
391
+ if not os.path.exists(checkpoint):
392
+ org, repo_id, *filepath = checkpoint.split("/")
393
+ checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath))
394
+
395
+ # Load original state dict
396
+ data = load(checkpoint)
397
+ state_dict = flatten_nested_dict(data)
398
+ state_dict = split_encoderblock_layers(state_dict)
399
+
400
+ # Remove and rename some keys
401
+ rename_keys = create_rename_keys(config)
402
+ for src, dest in rename_keys:
403
+ rename_key(state_dict, src, dest, config)
404
+
405
+ # qkv matrices of attention pooling head need special treatment
406
+ read_in_q_k_v_head(state_dict, config)
407
+
408
+ # Load HuggingFace model
409
+ model = SiglipModel(config).eval()
410
+ model.load_state_dict(state_dict)
411
+
412
+ # Create processor
413
+ image_processor = get_image_processor(model_name)
414
+ tokenizer = get_tokenizer(model_name)
415
+ processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
416
+
417
+ # Verify forward pass on dummy images and texts
418
+ url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
419
+ image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
420
+ url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
421
+ image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
422
+ texts = ["an apple", "a picture of an apple"]
423
+
424
+ inputs = processor(images=[image_1, image_2], text=texts, padding="max_length", max_length=64, return_tensors="pt")
425
+ with torch.no_grad():
426
+ outputs = model(**inputs)
427
+
428
+ if verify_logits:
429
+ image_size = config.vision_config.image_size
430
+
431
+ # verify input_ids against original ones
432
+ if image_size == 224:
433
+ filename = "siglip_pixel_values.pt"
434
+ elif image_size == 256:
435
+ filename = "siglip_pixel_values_256.pt"
436
+ elif image_size == 384:
437
+ filename = "siglip_pixel_values_384.pt"
438
+ elif image_size == 512:
439
+ filename = "siglip_pixel_values_512.pt"
440
+ else:
441
+ raise ValueError("Image size not supported")
442
+
443
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
444
+ original_pixel_values = torch.load(filepath, weights_only=True)
445
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
446
+ original_input_ids = torch.load(filepath, weights_only=True)
447
+
448
+ if "i18n" not in model_name:
449
+ assert inputs.input_ids.tolist() == original_input_ids.tolist()
450
+
451
+ print("Mean of original pixel values:", original_pixel_values.mean())
452
+ print("Mean of new pixel values:", inputs.pixel_values.mean())
453
+
454
+ # note: we're testing with original pixel values here since we don't have exact pixel values
455
+ with torch.no_grad():
456
+ outputs = model(input_ids=original_input_ids, pixel_values=original_pixel_values)
457
+ print(outputs.logits_per_image[:3, :3])
458
+
459
+ probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
460
+ print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
461
+ print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
462
+
463
+ if model_name == "siglip-base-patch16-224":
464
+ expected_slice = torch.tensor(
465
+ [[-2.9621, -2.1672], [-0.2713, 0.2910]],
466
+ )
467
+ elif model_name == "siglip-base-patch16-256":
468
+ expected_slice = torch.tensor(
469
+ [[-3.1146, -1.9894], [-0.7312, 0.6387]],
470
+ )
471
+ elif model_name == "siglip-base-patch16-384":
472
+ expected_slice = torch.tensor(
473
+ [[-2.8098, -2.1891], [-0.4242, 0.4102]],
474
+ )
475
+ elif model_name == "siglip-base-patch16-512":
476
+ expected_slice = torch.tensor(
477
+ [[-2.7899, -2.2668], [-0.4295, -0.0735]],
478
+ )
479
+ elif model_name == "siglip-large-patch16-256":
480
+ expected_slice = torch.tensor(
481
+ [[-1.5827, -0.5801], [-0.9153, 0.1363]],
482
+ )
483
+ elif model_name == "siglip-large-patch16-384":
484
+ expected_slice = torch.tensor(
485
+ [[-2.1523, -0.2899], [-0.2959, 0.7884]],
486
+ )
487
+ elif model_name == "siglip-so400m-patch14-384":
488
+ expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
489
+ elif model_name == "siglip-base-patch16-256-i18n":
490
+ expected_slice = torch.tensor(
491
+ [[-0.9064, 0.1073], [-0.0299, 0.5304]],
492
+ )
493
+
494
+ assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
495
+ print("Looks ok!")
496
+
497
+ if pytorch_dump_folder_path is not None:
498
+ pytorch_dump_folder_path = os.path.join(pytorch_dump_folder_path, model_name)
499
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
500
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
501
+ model.save_pretrained(pytorch_dump_folder_path)
502
+ print(f"Saving processor to {pytorch_dump_folder_path}")
503
+ processor.save_pretrained(pytorch_dump_folder_path)
504
+
505
+ if push_to_hub:
506
+ model.push_to_hub(f"s0225/{model_name}", private=True)
507
+ processor.push_to_hub(f"s0225/{model_name}", private=True)
508
+
509
+
510
+ if __name__ == "__main__":
511
+ parser = argparse.ArgumentParser()
512
+ # Required parameters
513
+ parser.add_argument(
514
+ "--model_name",
515
+ default="siglip-base-patch16-224",
516
+ type=str,
517
+ choices=model_name_to_checkpoint.keys(),
518
+ help="Name of the model you'd like to convert.",
519
+ )
520
+ parser.add_argument(
521
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
522
+ )
523
+ parser.add_argument(
524
+ "--verify_logits",
525
+ action="store_true",
526
+ help="Whether to verify logits against the original implementation.",
527
+ )
528
+ parser.add_argument(
529
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
530
+ )
531
+
532
+ args = parser.parse_args()
533
+ convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
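
Once the converter above has written a checkpoint to `--pytorch_dump_folder_path`, the result loads back like any other 🤗 Transformers model. A minimal sketch, assuming a hypothetical local output directory:

```python
import requests
import torch
from PIL import Image

from transformers import SiglipModel, SiglipProcessor

# Hypothetical directory produced by convert_siglip_checkpoint above.
local_dir = "converted/siglip-base-patch16-224"

model = SiglipModel.from_pretrained(local_dir).eval()
processor = SiglipProcessor.from_pretrained(local_dir)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
inputs = processor(
    images=image,
    text=["a photo of two cats", "a photo of a dog"],
    padding="max_length",
    max_length=64,
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model(**inputs)

# SigLIP scores image-text pairs with a sigmoid rather than a softmax.
print(torch.sigmoid(outputs.logits_per_image))
```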
docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip.py ADDED
@@ -0,0 +1,244 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for SigLIP."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
20
+ from ...image_transforms import (
21
+ convert_to_rgb,
22
+ resize,
23
+ to_channel_dimension_format,
24
+ )
25
+ from ...image_utils import (
26
+ IMAGENET_STANDARD_MEAN,
27
+ IMAGENET_STANDARD_STD,
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ infer_channel_dimension_format,
32
+ is_scaled_image,
33
+ make_flat_list_of_images,
34
+ to_numpy_array,
35
+ valid_images,
36
+ validate_preprocess_arguments,
37
+ )
38
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ if is_vision_available():
45
+ import PIL
46
+
47
+
48
+ class SiglipImageProcessor(BaseImageProcessor):
49
+ r"""
50
+ Constructs a SigLIP image processor.
51
+
52
+ Args:
53
+ do_resize (`bool`, *optional*, defaults to `True`):
54
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
55
+ `do_resize` in the `preprocess` method.
56
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
57
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
58
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
59
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
60
+ do_rescale (`bool`, *optional*, defaults to `True`):
61
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
62
+ the `preprocess` method.
63
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
64
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
65
+ method.
66
+ do_normalize (`bool`, *optional*, defaults to `True`):
67
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
68
+ `do_normalize` in the `preprocess` method.
69
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
70
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
71
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
72
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
73
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
74
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
75
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
76
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
77
+ Whether to convert the image to RGB.
78
+ """
79
+
80
+ model_input_names = ["pixel_values"]
81
+
82
+ def __init__(
83
+ self,
84
+ do_resize: bool = True,
85
+ size: Dict[str, int] = None,
86
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
87
+ do_rescale: bool = True,
88
+ rescale_factor: Union[int, float] = 1 / 255,
89
+ do_normalize: bool = True,
90
+ image_mean: Optional[Union[float, List[float]]] = None,
91
+ image_std: Optional[Union[float, List[float]]] = None,
92
+ do_convert_rgb: Optional[bool] = None,
93
+ **kwargs,
94
+ ) -> None:
95
+ super().__init__(**kwargs)
96
+ size = size if size is not None else {"height": 224, "width": 224}
97
+ image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
98
+ image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
99
+
100
+ self.do_resize = do_resize
101
+ self.size = size
102
+ self.resample = resample
103
+ self.do_rescale = do_rescale
104
+ self.rescale_factor = rescale_factor
105
+ self.do_normalize = do_normalize
106
+ self.image_mean = image_mean
107
+ self.image_std = image_std
108
+ self.do_convert_rgb = do_convert_rgb
109
+
110
+ @filter_out_non_signature_kwargs()
111
+ def preprocess(
112
+ self,
113
+ images: ImageInput,
114
+ do_resize: Optional[bool] = None,
115
+ size: Dict[str, int] = None,
116
+ resample: PILImageResampling = None,
117
+ do_rescale: Optional[bool] = None,
118
+ rescale_factor: Optional[float] = None,
119
+ do_normalize: Optional[bool] = None,
120
+ image_mean: Optional[Union[float, List[float]]] = None,
121
+ image_std: Optional[Union[float, List[float]]] = None,
122
+ return_tensors: Optional[Union[str, TensorType]] = None,
123
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
124
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
125
+ do_convert_rgb: Optional[bool] = None,
126
+ ) -> PIL.Image.Image:
127
+ """
128
+ Preprocess an image or batch of images.
129
+
130
+ Args:
131
+ images (`ImageInput`):
132
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
133
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
134
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
135
+ Whether to resize the image.
136
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
137
+ Size of the image after resizing.
138
+ resample (`int`, *optional*, defaults to `self.resample`):
139
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
140
+ has an effect if `do_resize` is set to `True`.
141
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
142
+ Whether to rescale the image.
143
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
144
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
145
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
146
+ Whether to normalize the image.
147
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
148
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
149
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
150
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
151
+ `True`.
152
+ return_tensors (`str` or `TensorType`, *optional*):
153
+ The type of tensors to return. Can be one of:
154
+ - Unset: Return a list of `np.ndarray`.
155
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
156
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
157
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
158
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
159
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
160
+ The channel dimension format for the output image. Can be one of:
161
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
162
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
163
+ - Unset: Use the channel dimension format of the input image.
164
+ input_data_format (`ChannelDimension` or `str`, *optional*):
165
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
166
+ from the input image. Can be one of:
167
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
168
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
169
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
170
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
171
+ Whether to convert the image to RGB.
172
+ """
173
+ do_resize = do_resize if do_resize is not None else self.do_resize
174
+ size = size if size is not None else self.size
175
+ size = get_size_dict(size, param_name="size", default_to_square=False)
176
+ resample = resample if resample is not None else self.resample
177
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
178
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
179
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
180
+ image_mean = image_mean if image_mean is not None else self.image_mean
181
+ image_std = image_std if image_std is not None else self.image_std
182
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
183
+
184
+ images = make_flat_list_of_images(images)
185
+
186
+ if not valid_images(images):
187
+ raise ValueError(
188
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
189
+ "torch.Tensor, tf.Tensor or jax.ndarray."
190
+ )
191
+ validate_preprocess_arguments(
192
+ do_rescale=do_rescale,
193
+ rescale_factor=rescale_factor,
194
+ do_normalize=do_normalize,
195
+ image_mean=image_mean,
196
+ image_std=image_std,
197
+ do_resize=do_resize,
198
+ size=size,
199
+ resample=resample,
200
+ )
201
+ if do_convert_rgb:
202
+ images = [convert_to_rgb(image) for image in images]
203
+
204
+ # All transformations expect numpy arrays.
205
+ images = [to_numpy_array(image) for image in images]
206
+
207
+ if do_rescale and is_scaled_image(images[0]):
208
+ logger.warning_once(
209
+ "It looks like you are trying to rescale already rescaled images. If the input"
210
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
211
+ )
212
+
213
+ if input_data_format is None:
214
+ # We assume that all images have the same channel dimension format.
215
+ input_data_format = infer_channel_dimension_format(images[0])
216
+
217
+ if do_resize:
218
+ height, width = size["height"], size["width"]
219
+ images = [
220
+ resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
221
+ for image in images
222
+ ]
223
+
224
+ if do_rescale:
225
+ images = [
226
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
227
+ for image in images
228
+ ]
229
+
230
+ if do_normalize:
231
+ images = [
232
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
233
+ for image in images
234
+ ]
235
+
236
+ images = [
237
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
238
+ ]
239
+
240
+ data = {"pixel_values": images}
241
+ return BatchFeature(data=data, tensor_type=return_tensors)
242
+
243
+
244
+ __all__ = ["SiglipImageProcessor"]
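
A short usage sketch for the processor defined above (the random image is a stand-in for real data):

```python
import numpy as np
from PIL import Image

from transformers import SiglipImageProcessor

# Random RGB image in place of real data.
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

# Defaults: resize to 224x224 (bicubic), rescale by 1/255, normalize with mean/std 0.5.
processor = SiglipImageProcessor()
batch = processor(images=image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```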
docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip_fast.py ADDED
@@ -0,0 +1,41 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for SigLIP."""
16
+
17
+ from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
18
+ from ...image_utils import (
19
+ IMAGENET_STANDARD_MEAN,
20
+ IMAGENET_STANDARD_STD,
21
+ PILImageResampling,
22
+ )
23
+ from ...utils import add_start_docstrings
24
+
25
+
26
+ @add_start_docstrings(
27
+ "Constructs a fast SigLIP image processor.",
28
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
29
+ )
30
+ class SiglipImageProcessorFast(BaseImageProcessorFast):
31
+ resample = PILImageResampling.BICUBIC
32
+ image_mean = IMAGENET_STANDARD_MEAN
33
+ image_std = IMAGENET_STANDARD_STD
34
+ size = {"height": 224, "width": 224}
35
+ default_to_square = False
36
+ do_resize = True
37
+ do_rescale = True
38
+ do_normalize = True
39
+
40
+
41
+ __all__ = ["SiglipImageProcessorFast"]
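
And a matching sketch for the fast variant, which only overrides class-level defaults and inherits its preprocessing from `BaseImageProcessorFast` (the torchvision-backed fast path is assumed to be installed):

```python
import numpy as np
from PIL import Image

from transformers import SiglipImageProcessorFast

image = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))

# Same defaults as the slow processor above: 224x224 bicubic resize, 1/255 rescale, 0.5 mean/std.
fast_processor = SiglipImageProcessorFast()
pixel_values = fast_processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])
```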