Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dev_scripts/ci_container_test.sh +41 -0
- 4JOB_TRAIN.jsonl +0 -0
- MANIFEST.in +5 -0
- Makefile +25 -0
- README.md +423 -0
- README_CN.md +413 -0
- checkMissing.py +86 -0
- clean_transcripts.py +95 -0
- count_audios.py +69 -0
- count_folders-Copy1.py +122 -0
- count_folders.py +122 -0
- dialogue_length_distribution.png +0 -0
- dialogue_length_ranges.png +0 -0
- docs/transformers/build/lib/transformers/models/sam/processing_sam.py +311 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py +396 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/modeling_seamless_m4t.py +0 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/processing_seamless_m4t.py +120 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py +450 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py +404 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +0 -0
- docs/transformers/build/lib/transformers/models/segformer/__init__.py +30 -0
- docs/transformers/build/lib/transformers/models/segformer/configuration_segformer.py +171 -0
- docs/transformers/build/lib/transformers/models/segformer/convert_segformer_original_to_pytorch.py +387 -0
- docs/transformers/build/lib/transformers/models/segformer/feature_extraction_segformer.py +38 -0
- docs/transformers/build/lib/transformers/models/segformer/image_processing_segformer.py +484 -0
- docs/transformers/build/lib/transformers/models/segformer/modeling_segformer.py +840 -0
- docs/transformers/build/lib/transformers/models/segformer/modeling_tf_segformer.py +1045 -0
- docs/transformers/build/lib/transformers/models/seggpt/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/seggpt/configuration_seggpt.py +143 -0
- docs/transformers/build/lib/transformers/models/seggpt/convert_seggpt_to_hf.py +221 -0
- docs/transformers/build/lib/transformers/models/seggpt/image_processing_seggpt.py +618 -0
- docs/transformers/build/lib/transformers/models/seggpt/modeling_seggpt.py +1031 -0
- docs/transformers/build/lib/transformers/models/sew/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/sew/configuration_sew.py +256 -0
- docs/transformers/build/lib/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py +305 -0
- docs/transformers/build/lib/transformers/models/sew/modeling_sew.py +1498 -0
- docs/transformers/build/lib/transformers/models/sew_d/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/sew_d/configuration_sew_d.py +291 -0
- docs/transformers/build/lib/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py +317 -0
- docs/transformers/build/lib/transformers/models/sew_d/modeling_sew_d.py +1748 -0
- docs/transformers/build/lib/transformers/models/shieldgemma2/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/shieldgemma2/configuration_shieldgemma2.py +120 -0
- docs/transformers/build/lib/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py +470 -0
- docs/transformers/build/lib/transformers/models/shieldgemma2/modeling_shieldgemma2.py +220 -0
- docs/transformers/build/lib/transformers/models/shieldgemma2/processing_shieldgemma2.py +195 -0
- docs/transformers/build/lib/transformers/models/siglip/__init__.py +31 -0
- docs/transformers/build/lib/transformers/models/siglip/configuration_siglip.py +269 -0
- docs/transformers/build/lib/transformers/models/siglip/convert_siglip_to_hf.py +533 -0
- docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip.py +244 -0
- docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip_fast.py +41 -0
.dev_scripts/ci_container_test.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
| 2 |
+
# pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
| 3 |
+
pip install -r requirements/tests.txt -i https://mirrors.aliyun.com/pypi/simple/
|
| 4 |
+
git config --global --add safe.directory /ms-swift
|
| 5 |
+
git config --global user.email tmp
|
| 6 |
+
git config --global user.name tmp.com
|
| 7 |
+
|
| 8 |
+
# linter test
|
| 9 |
+
# use internal project for pre-commit due to the network problem
|
| 10 |
+
if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then
|
| 11 |
+
pre-commit run -c .pre-commit-config_local.yaml --all-files
|
| 12 |
+
if [ $? -ne 0 ]; then
|
| 13 |
+
echo "linter test failed, please run 'pre-commit run --all-files' to check"
|
| 14 |
+
echo "From the repository folder"
|
| 15 |
+
echo "Run 'pip install -r requirements/tests.txt' install test dependencies."
|
| 16 |
+
echo "Run 'pre-commit install' install pre-commit hooks."
|
| 17 |
+
echo "Finally run linter with command: 'pre-commit run --all-files' to check."
|
| 18 |
+
echo "Ensure there is no failure!!!!!!!!"
|
| 19 |
+
exit -1
|
| 20 |
+
fi
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
pip install -r requirements/framework.txt -U -i https://mirrors.aliyun.com/pypi/simple/
|
| 24 |
+
pip install diffusers decord einops -U -i https://mirrors.aliyun.com/pypi/simple/
|
| 25 |
+
pip install autoawq -U --no-deps
|
| 26 |
+
|
| 27 |
+
# test with install
|
| 28 |
+
pip install .
|
| 29 |
+
pip install auto_gptq bitsandbytes deepspeed -U -i https://mirrors.aliyun.com/pypi/simple/
|
| 30 |
+
else
|
| 31 |
+
echo "Running case in release image, run case directly!"
|
| 32 |
+
fi
|
| 33 |
+
# remove torch_extensions folder to avoid ci hang.
|
| 34 |
+
rm -rf ~/.cache/torch_extensions
|
| 35 |
+
if [ $# -eq 0 ]; then
|
| 36 |
+
ci_command="python tests/run.py --subprocess"
|
| 37 |
+
else
|
| 38 |
+
ci_command="$@"
|
| 39 |
+
fi
|
| 40 |
+
echo "Running case with command: $ci_command"
|
| 41 |
+
$ci_command
|
4JOB_TRAIN.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
recursive-include swift/utils *.py
|
| 2 |
+
recursive-include swift/llm/dataset/data *.*
|
| 3 |
+
recursive-include swift/llm/ds_config *.json
|
| 4 |
+
recursive-include requirements *.txt
|
| 5 |
+
recursive-include swift/plugin/loss_scale/config *.json
|
Makefile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
WHL_BUILD_DIR :=package
|
| 2 |
+
DOC_BUILD_DIR :=docs/build/
|
| 3 |
+
|
| 4 |
+
# default rule
|
| 5 |
+
default: whl docs
|
| 6 |
+
|
| 7 |
+
.PHONY: docs
|
| 8 |
+
docs:
|
| 9 |
+
bash .dev_scripts/build_docs.sh
|
| 10 |
+
|
| 11 |
+
.PHONY: linter
|
| 12 |
+
linter:
|
| 13 |
+
bash .dev_scripts/linter.sh
|
| 14 |
+
|
| 15 |
+
.PHONY: test
|
| 16 |
+
test:
|
| 17 |
+
bash .dev_scripts/citest.sh
|
| 18 |
+
|
| 19 |
+
.PHONY: whl
|
| 20 |
+
whl:
|
| 21 |
+
python setup.py sdist bdist_wheel
|
| 22 |
+
|
| 23 |
+
.PHONY: clean
|
| 24 |
+
clean:
|
| 25 |
+
rm -rf $(WHL_BUILD_DIR) $(DOC_BUILD_DIR)
|
README.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">ModelScope Community Website</a>
|
| 10 |
+
<br>
|
| 11 |
+
<a href="README_CN.md">中文</a>   |   English  
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 16 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 17 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 18 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 19 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 20 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 21 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
<p align="center">
|
| 25 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
<p align="center">
|
| 29 |
+
<a href="https://arxiv.org/abs/2408.05517">Paper</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 30 |
+
</p>
|
| 31 |
+
|
| 32 |
+
## 📖 Table of Contents
|
| 33 |
+
- [Groups](#-Groups)
|
| 34 |
+
- [Introduction](#-introduction)
|
| 35 |
+
- [News](#-news)
|
| 36 |
+
- [Installation](#%EF%B8%8F-installation)
|
| 37 |
+
- [Quick Start](#-quick-Start)
|
| 38 |
+
- [Usage](#-Usage)
|
| 39 |
+
- [License](#-License)
|
| 40 |
+
- [Citation](#-citation)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## ☎ Groups
|
| 44 |
+
|
| 45 |
+
You can contact us and communicate with us by adding our group:
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
|
| 49 |
+
:-------------------------:|:-------------------------:
|
| 50 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## 📝 Introduction
|
| 54 |
+
🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama3.4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.
|
| 55 |
+
|
| 56 |
+
🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices.
|
| 57 |
+
|
| 58 |
+
**Why choose ms-swift?**
|
| 59 |
+
|
| 60 |
+
- 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
|
| 61 |
+
- **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets.
|
| 62 |
+
- **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
|
| 63 |
+
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
|
| 64 |
+
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
|
| 65 |
+
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
|
| 66 |
+
- **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
|
| 67 |
+
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
|
| 68 |
+
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
|
| 69 |
+
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
|
| 70 |
+
- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
|
| 71 |
+
- **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules.
|
| 72 |
+
- **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models.
|
| 73 |
+
- **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## 🎉 News
|
| 77 |
+
- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models) .
|
| 78 |
+
- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
|
| 79 |
+
- 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
|
| 80 |
+
- 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
|
| 81 |
+
- 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
|
| 82 |
+
- 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
|
| 83 |
+
- 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
|
| 84 |
+
- 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
|
| 85 |
+
- 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
|
| 86 |
+
- 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
|
| 87 |
+
<details><summary>More</summary>
|
| 88 |
+
|
| 89 |
+
- 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
|
| 90 |
+
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
|
| 91 |
+
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
|
| 92 |
+
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
|
| 93 |
+
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
## 🛠️ Installation
|
| 97 |
+
To install using pip:
|
| 98 |
+
```shell
|
| 99 |
+
pip install ms-swift -U
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
To install from source:
|
| 103 |
+
```shell
|
| 104 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 105 |
+
|
| 106 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 107 |
+
cd ms-swift
|
| 108 |
+
pip install -e .
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Running Environment:
|
| 112 |
+
|
| 113 |
+
| | Range | Recommended | Notes |
|
| 114 |
+
| ------------ |--------------| ----------- | ----------------------------------------- |
|
| 115 |
+
| python | >=3.9 | 3.10 | |
|
| 116 |
+
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
|
| 117 |
+
| torch | >=2.0 | | |
|
| 118 |
+
| transformers | >=4.33 | 4.51 | |
|
| 119 |
+
| modelscope | >=1.23 | | |
|
| 120 |
+
| peft | >=0.11,<0.16 | ||
|
| 121 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 122 |
+
| deepspeed | >=0.14 | 0.14.5 | Training |
|
| 123 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation |
|
| 124 |
+
| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
|
| 125 |
+
| evalscope | >=0.11 | | Evaluation |
|
| 126 |
+
|
| 127 |
+
For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
## 🚀 Quick Start
|
| 131 |
+
|
| 132 |
+
10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:
|
| 133 |
+
|
| 134 |
+
### Command Line Interface
|
| 135 |
+
|
| 136 |
+
```shell
|
| 137 |
+
# 22GB
|
| 138 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 139 |
+
swift sft \
|
| 140 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 141 |
+
--train_type lora \
|
| 142 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 143 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 144 |
+
'swift/self-cognition#500' \
|
| 145 |
+
--torch_dtype bfloat16 \
|
| 146 |
+
--num_train_epochs 1 \
|
| 147 |
+
--per_device_train_batch_size 1 \
|
| 148 |
+
--per_device_eval_batch_size 1 \
|
| 149 |
+
--learning_rate 1e-4 \
|
| 150 |
+
--lora_rank 8 \
|
| 151 |
+
--lora_alpha 32 \
|
| 152 |
+
--target_modules all-linear \
|
| 153 |
+
--gradient_accumulation_steps 16 \
|
| 154 |
+
--eval_steps 50 \
|
| 155 |
+
--save_steps 50 \
|
| 156 |
+
--save_total_limit 2 \
|
| 157 |
+
--logging_steps 5 \
|
| 158 |
+
--max_length 2048 \
|
| 159 |
+
--output_dir output \
|
| 160 |
+
--system 'You are a helpful assistant.' \
|
| 161 |
+
--warmup_ratio 0.05 \
|
| 162 |
+
--dataloader_num_workers 4 \
|
| 163 |
+
--model_author swift \
|
| 164 |
+
--model_name swift-robot
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
Tips:
|
| 168 |
+
|
| 169 |
+
- If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>`.
|
| 170 |
+
- The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
|
| 171 |
+
- To train with a different model, simply modify `--model <model_id/model_path>`.
|
| 172 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 173 |
+
|
| 174 |
+
After training is complete, use the following command to infer with the trained weights:
|
| 175 |
+
|
| 176 |
+
- Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`.
|
| 177 |
+
|
| 178 |
+
```shell
|
| 179 |
+
# Using an interactive command line for inference.
|
| 180 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 181 |
+
swift infer \
|
| 182 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 183 |
+
--stream true \
|
| 184 |
+
--temperature 0 \
|
| 185 |
+
--max_new_tokens 2048
|
| 186 |
+
|
| 187 |
+
# merge-lora and use vLLM for inference acceleration
|
| 188 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 189 |
+
swift infer \
|
| 190 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 191 |
+
--stream true \
|
| 192 |
+
--merge_lora true \
|
| 193 |
+
--infer_backend vllm \
|
| 194 |
+
--max_model_len 8192 \
|
| 195 |
+
--temperature 0 \
|
| 196 |
+
--max_new_tokens 2048
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
Finally, use the following command to push the model to ModelScope:
|
| 200 |
+
|
| 201 |
+
```shell
|
| 202 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 203 |
+
swift export \
|
| 204 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 205 |
+
--push_to_hub true \
|
| 206 |
+
--hub_model_id '<your-model-id>' \
|
| 207 |
+
--hub_token '<your-sdk-token>' \
|
| 208 |
+
--use_hf false
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
### Web-UI
|
| 213 |
+
The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio interface technology. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).
|
| 214 |
+
|
| 215 |
+
```shell
|
| 216 |
+
SWIFT_UI_LANG=en swift web-ui
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+

|
| 220 |
+
|
| 221 |
+
### Using Python
|
| 222 |
+
|
| 223 |
+
ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).
|
| 224 |
+
|
| 225 |
+
Training:
|
| 226 |
+
|
| 227 |
+
```python
|
| 228 |
+
# Retrieve the model and template, and add a trainable LoRA module
|
| 229 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 230 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 231 |
+
model = Swift.prepare_model(model, lora_config)
|
| 232 |
+
|
| 233 |
+
# Download and load the dataset, and encode the text into tokens
|
| 234 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 235 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 236 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 237 |
+
|
| 238 |
+
# Train the model
|
| 239 |
+
trainer = Seq2SeqTrainer(
|
| 240 |
+
model=model,
|
| 241 |
+
args=training_args,
|
| 242 |
+
data_collator=template.data_collator,
|
| 243 |
+
train_dataset=train_dataset,
|
| 244 |
+
eval_dataset=val_dataset,
|
| 245 |
+
template=template,
|
| 246 |
+
)
|
| 247 |
+
trainer.train()
|
| 248 |
+
```
|
| 249 |
+
Inference:
|
| 250 |
+
|
| 251 |
+
```python
|
| 252 |
+
# Perform inference using the native PyTorch engine
|
| 253 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 254 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 255 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 256 |
+
|
| 257 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 258 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
## ✨ Usage
|
| 262 |
+
Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).
|
| 263 |
+
|
| 264 |
+
- If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path.
|
| 265 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 266 |
+
|
| 267 |
+
| Useful Links |
|
| 268 |
+
| ------ |
|
| 269 |
+
| [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
|
| 270 |
+
| [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
|
| 271 |
+
| [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
|
| 272 |
+
| [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 273 |
+
|
| 274 |
+
### Training
|
| 275 |
+
|
| 276 |
+
Supported Training Methods:
|
| 277 |
+
|
| 278 |
+
| Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
|
| 279 |
+
|------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------|
|
| 280 |
+
| Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 281 |
+
| Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 282 |
+
| DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 283 |
+
| GRPO Training | [✅]((https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh)) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 284 |
+
| Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 285 |
+
| PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 286 |
+
| KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 287 |
+
| CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 288 |
+
| SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 289 |
+
| ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 290 |
+
| Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 291 |
+
| Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
Pre-training:
|
| 296 |
+
```shell
|
| 297 |
+
# 8*A100
|
| 298 |
+
NPROC_PER_NODE=8 \
|
| 299 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 300 |
+
swift pt \
|
| 301 |
+
--model Qwen/Qwen2.5-7B \
|
| 302 |
+
--dataset swift/chinese-c4 \
|
| 303 |
+
--streaming true \
|
| 304 |
+
--train_type full \
|
| 305 |
+
--deepspeed zero2 \
|
| 306 |
+
--output_dir output \
|
| 307 |
+
--max_steps 10000 \
|
| 308 |
+
...
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
Fine-tuning:
|
| 312 |
+
```shell
|
| 313 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 314 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 315 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-en \
|
| 316 |
+
--train_type lora \
|
| 317 |
+
--output_dir output \
|
| 318 |
+
...
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
RLHF:
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 324 |
+
--rlhf_type dpo \
|
| 325 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 326 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 327 |
+
--train_type lora \
|
| 328 |
+
--output_dir output \
|
| 329 |
+
...
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
### Inference
|
| 334 |
+
```shell
|
| 335 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 336 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 337 |
+
--stream true \
|
| 338 |
+
--infer_backend pt \
|
| 339 |
+
--max_new_tokens 2048
|
| 340 |
+
|
| 341 |
+
# LoRA
|
| 342 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 343 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 344 |
+
--adapters swift/test_lora \
|
| 345 |
+
--stream true \
|
| 346 |
+
--infer_backend pt \
|
| 347 |
+
--temperature 0 \
|
| 348 |
+
--max_new_tokens 2048
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
### Interface Inference
|
| 352 |
+
```shell
|
| 353 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 354 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 355 |
+
--stream true \
|
| 356 |
+
--infer_backend pt \
|
| 357 |
+
--max_new_tokens 2048
|
| 358 |
+
```
|
| 359 |
+
|
| 360 |
+
### Deployment
|
| 361 |
+
```shell
|
| 362 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 363 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 364 |
+
--infer_backend vllm
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
### Sampling
|
| 368 |
+
```shell
|
| 369 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 370 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 371 |
+
--sampler_engine pt \
|
| 372 |
+
--num_return_sequences 5 \
|
| 373 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
### Evaluation
|
| 377 |
+
```shell
|
| 378 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 379 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 380 |
+
--infer_backend lmdeploy \
|
| 381 |
+
--eval_backend OpenCompass \
|
| 382 |
+
--eval_dataset ARC_c
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
### Quantization
|
| 386 |
+
```shell
|
| 387 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 388 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 389 |
+
--quant_bits 4 --quant_method awq \
|
| 390 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 391 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
### Push Model
|
| 395 |
+
```shell
|
| 396 |
+
swift export \
|
| 397 |
+
--model <model-path> \
|
| 398 |
+
--push_to_hub true \
|
| 399 |
+
--hub_model_id '<model-id>' \
|
| 400 |
+
--hub_token '<sdk-token>'
|
| 401 |
+
```
|
| 402 |
+
|
| 403 |
+
## 🏛 License
|
| 404 |
+
|
| 405 |
+
This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License.
|
| 406 |
+
|
| 407 |
+
## 📎 Citation
|
| 408 |
+
|
| 409 |
+
```bibtex
|
| 410 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 411 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 412 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 413 |
+
year={2024},
|
| 414 |
+
eprint={2408.05517},
|
| 415 |
+
archivePrefix={arXiv},
|
| 416 |
+
primaryClass={cs.CL},
|
| 417 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 418 |
+
}
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
## Star History
|
| 422 |
+
|
| 423 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
README_CN.md
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">魔搭社区官网</a>
|
| 10 |
+
<br>
|
| 11 |
+
中文  |  <a href="README.md">English</a> 
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
<p align="center">
|
| 16 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 17 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 18 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 19 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 20 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 21 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 22 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 23 |
+
</p>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 27 |
+
</p>
|
| 28 |
+
|
| 29 |
+
<p align="center">
|
| 30 |
+
<a href="https://arxiv.org/abs/2408.05517">论文</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 31 |
+
</p>
|
| 32 |
+
|
| 33 |
+
## 📖 目录
|
| 34 |
+
- [用户群](#-用户群)
|
| 35 |
+
- [简介](#-简介)
|
| 36 |
+
- [新闻](#-新闻)
|
| 37 |
+
- [安装](#%EF%B8%8F-安装)
|
| 38 |
+
- [快速开始](#-快速开始)
|
| 39 |
+
- [如何使用](#-如何使用)
|
| 40 |
+
- [License](#-license)
|
| 41 |
+
- [引用](#-引用)
|
| 42 |
+
|
| 43 |
+
## ☎ 用户群
|
| 44 |
+
|
| 45 |
+
请扫描下面的二维码来加入我们的交流群:
|
| 46 |
+
|
| 47 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
|
| 48 |
+
:-------------------------:|:-------------------------:
|
| 49 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 50 |
+
|
| 51 |
+
## 📝 简介
|
| 52 |
+
🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
|
| 53 |
+
|
| 54 |
+
🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
|
| 55 |
+
|
| 56 |
+
**为什么选择ms-swift?**
|
| 57 |
+
- 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
|
| 58 |
+
- **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
|
| 59 |
+
- **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
|
| 60 |
+
- 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
|
| 61 |
+
- **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
|
| 62 |
+
- **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
|
| 63 |
+
- **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
|
| 64 |
+
- 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
|
| 65 |
+
- **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
|
| 66 |
+
- **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
|
| 67 |
+
- 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
|
| 68 |
+
- **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
|
| 69 |
+
- **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯��本和多模态模型进行评测。
|
| 70 |
+
- **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
|
| 71 |
+
|
| 72 |
+
## 🎉 新闻
|
| 73 |
+
- 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
|
| 74 |
+
- 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
|
| 75 |
+
- 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
|
| 76 |
+
- 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
|
| 77 |
+
- 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
|
| 78 |
+
- 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
|
| 79 |
+
- 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
|
| 80 |
+
- 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
|
| 81 |
+
- 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
|
| 82 |
+
- 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
|
| 83 |
+
<details><summary>更多</summary>
|
| 84 |
+
|
| 85 |
+
- 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
|
| 86 |
+
- 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
|
| 87 |
+
- 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
|
| 88 |
+
- 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
|
| 89 |
+
- 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
|
| 90 |
+
</details>
|
| 91 |
+
|
| 92 |
+
## 🛠️ 安装
|
| 93 |
+
使用pip进行安装:
|
| 94 |
+
```shell
|
| 95 |
+
pip install ms-swift -U
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
从源代码安装:
|
| 99 |
+
```shell
|
| 100 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 101 |
+
|
| 102 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 103 |
+
cd ms-swift
|
| 104 |
+
pip install -e .
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
运行环境:
|
| 108 |
+
|
| 109 |
+
| | 范围 | 推荐 | 备注 |
|
| 110 |
+
| ------ |--------------| ---- | --|
|
| 111 |
+
| python | >=3.9 | 3.10 ||
|
| 112 |
+
| cuda | | cuda12 |使用cpu、npu、mps则无需安装|
|
| 113 |
+
| torch | >=2.0 | ||
|
| 114 |
+
| transformers | >=4.33 | 4.51 ||
|
| 115 |
+
| modelscope | >=1.23 | ||
|
| 116 |
+
| peft | >=0.11,<0.16 | ||
|
| 117 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 118 |
+
| deepspeed | >=0.14 | 0.14.5 |训练|
|
| 119 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
|
| 120 |
+
| lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
|
| 121 |
+
| evalscope | >=0.11 | |评测|
|
| 122 |
+
|
| 123 |
+
更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
## 🚀 快速开始
|
| 127 |
+
|
| 128 |
+
**10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
|
| 129 |
+
|
| 130 |
+
### 命令行
|
| 131 |
+
```shell
|
| 132 |
+
# 22GB
|
| 133 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 134 |
+
swift sft \
|
| 135 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 136 |
+
--train_type lora \
|
| 137 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 138 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 139 |
+
'swift/self-cognition#500' \
|
| 140 |
+
--torch_dtype bfloat16 \
|
| 141 |
+
--num_train_epochs 1 \
|
| 142 |
+
--per_device_train_batch_size 1 \
|
| 143 |
+
--per_device_eval_batch_size 1 \
|
| 144 |
+
--learning_rate 1e-4 \
|
| 145 |
+
--lora_rank 8 \
|
| 146 |
+
--lora_alpha 32 \
|
| 147 |
+
--target_modules all-linear \
|
| 148 |
+
--gradient_accumulation_steps 16 \
|
| 149 |
+
--eval_steps 50 \
|
| 150 |
+
--save_steps 50 \
|
| 151 |
+
--save_total_limit 2 \
|
| 152 |
+
--logging_steps 5 \
|
| 153 |
+
--max_length 2048 \
|
| 154 |
+
--output_dir output \
|
| 155 |
+
--system 'You are a helpful assistant.' \
|
| 156 |
+
--warmup_ratio 0.05 \
|
| 157 |
+
--dataloader_num_workers 4 \
|
| 158 |
+
--model_author swift \
|
| 159 |
+
--model_name swift-robot
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
小贴士:
|
| 163 |
+
- 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。
|
| 164 |
+
- `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
|
| 165 |
+
- 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
|
| 166 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 167 |
+
|
| 168 |
+
训练完成后,使用以下命令对训练后的权重进行推理:
|
| 169 |
+
- 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
|
| 170 |
+
|
| 171 |
+
```shell
|
| 172 |
+
# 使用交互式命令行进行推理
|
| 173 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 174 |
+
swift infer \
|
| 175 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 176 |
+
--stream true \
|
| 177 |
+
--temperature 0 \
|
| 178 |
+
--max_new_tokens 2048
|
| 179 |
+
|
| 180 |
+
# merge-lora并使用vLLM进行推理加速
|
| 181 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 182 |
+
swift infer \
|
| 183 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 184 |
+
--stream true \
|
| 185 |
+
--merge_lora true \
|
| 186 |
+
--infer_backend vllm \
|
| 187 |
+
--max_model_len 8192 \
|
| 188 |
+
--temperature 0 \
|
| 189 |
+
--max_new_tokens 2048
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
最后,使用以下命令将模型推送到ModelScope:
|
| 193 |
+
```shell
|
| 194 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 195 |
+
swift export \
|
| 196 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 197 |
+
--push_to_hub true \
|
| 198 |
+
--hub_model_id '<your-model-id>' \
|
| 199 |
+
--hub_token '<your-sdk-token>' \
|
| 200 |
+
--use_hf false
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
### Web-UI
|
| 204 |
+
|
| 205 |
+
Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
|
| 206 |
+
|
| 207 |
+
```shell
|
| 208 |
+
swift web-ui
|
| 209 |
+
```
|
| 210 |
+

|
| 211 |
+
|
| 212 |
+
### 使用Python
|
| 213 |
+
ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
|
| 214 |
+
|
| 215 |
+
训练:
|
| 216 |
+
```python
|
| 217 |
+
# 获取模型和template,并加入可训练的LoRA模块
|
| 218 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 219 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 220 |
+
model = Swift.prepare_model(model, lora_config)
|
| 221 |
+
|
| 222 |
+
# 下载并载入数据集,并将文本encode成tokens
|
| 223 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 224 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 225 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 226 |
+
|
| 227 |
+
# 进行训练
|
| 228 |
+
trainer = Seq2SeqTrainer(
|
| 229 |
+
model=model,
|
| 230 |
+
args=training_args,
|
| 231 |
+
data_collator=template.data_collator,
|
| 232 |
+
train_dataset=train_dataset,
|
| 233 |
+
eval_dataset=val_dataset,
|
| 234 |
+
template=template,
|
| 235 |
+
)
|
| 236 |
+
trainer.train()
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
推理:
|
| 240 |
+
```python
|
| 241 |
+
# 使用原生pytorch引擎进行推理
|
| 242 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 243 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 244 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 245 |
+
|
| 246 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 247 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## ✨ 如何使用
|
| 251 |
+
|
| 252 |
+
这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
|
| 253 |
+
|
| 254 |
+
- 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
|
| 255 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 256 |
+
|
| 257 |
+
| 常用链接 |
|
| 258 |
+
| ------ |
|
| 259 |
+
| [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
|
| 260 |
+
| [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
| 261 |
+
| [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
| 262 |
+
| [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 263 |
+
|
| 264 |
+
### 训练
|
| 265 |
+
支持的训练方法:
|
| 266 |
+
|
| 267 |
+
| 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
|
| 268 |
+
| ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
|
| 269 |
+
| 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 270 |
+
| 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 271 |
+
| DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 272 |
+
| GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 273 |
+
| 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 274 |
+
| PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 275 |
+
| KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 276 |
+
| CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 277 |
+
| SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 278 |
+
| ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 279 |
+
| 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 280 |
+
| Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
预训练:
|
| 284 |
+
```shell
|
| 285 |
+
# 8*A100
|
| 286 |
+
NPROC_PER_NODE=8 \
|
| 287 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 288 |
+
swift pt \
|
| 289 |
+
--model Qwen/Qwen2.5-7B \
|
| 290 |
+
--dataset swift/chinese-c4 \
|
| 291 |
+
--streaming true \
|
| 292 |
+
--train_type full \
|
| 293 |
+
--deepspeed zero2 \
|
| 294 |
+
--output_dir output \
|
| 295 |
+
--max_steps 10000 \
|
| 296 |
+
...
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
微调:
|
| 300 |
+
```shell
|
| 301 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 302 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 303 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 304 |
+
--train_type lora \
|
| 305 |
+
--output_dir output \
|
| 306 |
+
...
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
RLHF:
|
| 310 |
+
```shell
|
| 311 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 312 |
+
--rlhf_type dpo \
|
| 313 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 314 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 315 |
+
--train_type lora \
|
| 316 |
+
--output_dir output \
|
| 317 |
+
...
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
### 推理
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 324 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 325 |
+
--stream true \
|
| 326 |
+
--infer_backend pt \
|
| 327 |
+
--max_new_tokens 2048
|
| 328 |
+
|
| 329 |
+
# LoRA
|
| 330 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 331 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 332 |
+
--adapters swift/test_lora \
|
| 333 |
+
--stream true \
|
| 334 |
+
--infer_backend pt \
|
| 335 |
+
--temperature 0 \
|
| 336 |
+
--max_new_tokens 2048
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
### 界面推理
|
| 340 |
+
```shell
|
| 341 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 342 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 343 |
+
--stream true \
|
| 344 |
+
--infer_backend pt \
|
| 345 |
+
--max_new_tokens 2048 \
|
| 346 |
+
--lang zh
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
### 部署
|
| 350 |
+
```shell
|
| 351 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 352 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 353 |
+
--infer_backend vllm
|
| 354 |
+
```
|
| 355 |
+
|
| 356 |
+
### 采样
|
| 357 |
+
```shell
|
| 358 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 359 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 360 |
+
--sampler_engine pt \
|
| 361 |
+
--num_return_sequences 5 \
|
| 362 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
### 评测
|
| 366 |
+
```shell
|
| 367 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 368 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 369 |
+
--infer_backend lmdeploy \
|
| 370 |
+
--eval_backend OpenCompass \
|
| 371 |
+
--eval_dataset ARC_c
|
| 372 |
+
```
|
| 373 |
+
|
| 374 |
+
### 量化
|
| 375 |
+
```shell
|
| 376 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 377 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 378 |
+
--quant_bits 4 --quant_method awq \
|
| 379 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 380 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
### 推送模型
|
| 384 |
+
```shell
|
| 385 |
+
swift export \
|
| 386 |
+
--model <model-path> \
|
| 387 |
+
--push_to_hub true \
|
| 388 |
+
--hub_model_id '<model-id>' \
|
| 389 |
+
--hub_token '<sdk-token>'
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
## 🏛 License
|
| 394 |
+
|
| 395 |
+
本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
|
| 396 |
+
|
| 397 |
+
## 📎 引用
|
| 398 |
+
|
| 399 |
+
```bibtex
|
| 400 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 401 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 402 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 403 |
+
year={2024},
|
| 404 |
+
eprint={2408.05517},
|
| 405 |
+
archivePrefix={arXiv},
|
| 406 |
+
primaryClass={cs.CL},
|
| 407 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 408 |
+
}
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
## Star History
|
| 412 |
+
|
| 413 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
checkMissing.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torchaudio
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
|
| 8 |
+
def validate_jsonl_audios(jsonl_path):
|
| 9 |
+
"""验证JSONL文件中所有音频文件的完整性"""
|
| 10 |
+
stats = defaultdict(int)
|
| 11 |
+
error_log = []
|
| 12 |
+
valid_samples = 0
|
| 13 |
+
|
| 14 |
+
# 第一次遍历:统计总行数(用于进度条)
|
| 15 |
+
with open(jsonl_path, 'r') as f:
|
| 16 |
+
total_lines = sum(1 for _ in f)
|
| 17 |
+
|
| 18 |
+
# 第二次遍历:实际验证
|
| 19 |
+
with open(jsonl_path, 'r') as f:
|
| 20 |
+
for line_num, line in enumerate(tqdm(f, total=total_lines, desc="验证进度", unit="line")):
|
| 21 |
+
try:
|
| 22 |
+
data = json.loads(line.strip())
|
| 23 |
+
if 'audios' not in data or not data['audios']:
|
| 24 |
+
stats['no_audio_field'] += 1
|
| 25 |
+
continue
|
| 26 |
+
|
| 27 |
+
for audio_path in data['audios']:
|
| 28 |
+
# 检查文件是否存在
|
| 29 |
+
if not os.path.exists(audio_path):
|
| 30 |
+
stats['missing'] += 1
|
| 31 |
+
error_log.append(f"[行{line_num+1}] 缺失文件: {audio_path}")
|
| 32 |
+
continue
|
| 33 |
+
|
| 34 |
+
# 检查文件大小
|
| 35 |
+
if os.path.getsize(audio_path) == 0:
|
| 36 |
+
stats['zero_size'] += 1
|
| 37 |
+
error_log.append(f"[行{line_num+1}] 空文件: {audio_path}")
|
| 38 |
+
continue
|
| 39 |
+
|
| 40 |
+
# 验证音频内容
|
| 41 |
+
try:
|
| 42 |
+
waveform, sr = torchaudio.load(audio_path)
|
| 43 |
+
if waveform.numel() == 0:
|
| 44 |
+
stats['empty_audio'] += 1
|
| 45 |
+
error_log.append(f"[行{line_num+1}] 空音频: {audio_path}")
|
| 46 |
+
elif sr not in [8000, 16000, 22050, 44100, 48000]:
|
| 47 |
+
stats['abnormal_sr'] += 1
|
| 48 |
+
error_log.append(f"[行{line_num+1}] 异常采样率({sr}Hz): {audio_path}")
|
| 49 |
+
else:
|
| 50 |
+
stats['valid'] += 1
|
| 51 |
+
except Exception as e:
|
| 52 |
+
stats['corrupted'] += 1
|
| 53 |
+
error_type = str(e).split('(')[0]
|
| 54 |
+
error_log.append(f"[行{line_num+1}] 损坏文件({error_type}): {audio_path}")
|
| 55 |
+
|
| 56 |
+
valid_samples += 1
|
| 57 |
+
|
| 58 |
+
except json.JSONDecodeError:
|
| 59 |
+
stats['invalid_json'] += 1
|
| 60 |
+
error_log.append(f"[行{line_num+1}] 无效JSON格式")
|
| 61 |
+
|
| 62 |
+
# 打印统计报告
|
| 63 |
+
print("\n===== 验证报告 =====")
|
| 64 |
+
print(f"总行数: {total_lines}")
|
| 65 |
+
print(f"有效样本: {valid_samples}")
|
| 66 |
+
print("--- 问题统计 ---")
|
| 67 |
+
for k, v in sorted(stats.items()):
|
| 68 |
+
print(f"{k}: {v}")
|
| 69 |
+
|
| 70 |
+
# 保存错误日志
|
| 71 |
+
if error_log:
|
| 72 |
+
log_file = f"{os.path.splitext(jsonl_path)[0]}_audio_errors.log"
|
| 73 |
+
with open(log_file, 'w') as f:
|
| 74 |
+
f.write("\n".join(error_log))
|
| 75 |
+
print(f"\n发现 {len(error_log)} 个问题,已保存到 {log_file}")
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
if len(sys.argv) != 2:
|
| 79 |
+
print("使用方法: python validate_audio_jsonl.py <input.jsonl>")
|
| 80 |
+
sys.exit(1)
|
| 81 |
+
|
| 82 |
+
if not os.path.exists(sys.argv[1]):
|
| 83 |
+
print(f"错误: 文件 {sys.argv[1]} 不存在")
|
| 84 |
+
sys.exit(1)
|
| 85 |
+
|
| 86 |
+
validate_jsonl_audios(sys.argv[1])
|
clean_transcripts.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
def parse_timestamp(timestamp: str) -> Tuple[int, int]:
|
| 6 |
+
"""Convert timestamp string like '00:15' to seconds."""
|
| 7 |
+
minutes, seconds = map(int, timestamp.split(':'))
|
| 8 |
+
return minutes * 60 + seconds
|
| 9 |
+
|
| 10 |
+
def extract_time_and_speaker(line: str) -> Tuple[Tuple[int, int], str]:
|
| 11 |
+
"""Extract time range and speaker from a line."""
|
| 12 |
+
# Extract time range
|
| 13 |
+
time_match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\] (Speaker [A-Z]):', line)
|
| 14 |
+
if not time_match:
|
| 15 |
+
return None, None
|
| 16 |
+
|
| 17 |
+
start_time = parse_timestamp(time_match.group(1))
|
| 18 |
+
end_time = parse_timestamp(time_match.group(2))
|
| 19 |
+
speaker = time_match.group(3)
|
| 20 |
+
|
| 21 |
+
return (start_time, end_time), speaker
|
| 22 |
+
|
| 23 |
+
def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
|
| 24 |
+
"""Check if two time ranges overlap."""
|
| 25 |
+
start1, end1 = range1
|
| 26 |
+
start2, end2 = range2
|
| 27 |
+
return not (end1 <= start2 or end2 <= start1)
|
| 28 |
+
|
| 29 |
+
def has_same_speaker_overlap(transcript: str) -> bool:
|
| 30 |
+
"""Check if a transcript contains overlapping timestamps for the same speaker."""
|
| 31 |
+
lines = transcript.split('\n')
|
| 32 |
+
# Dictionary to store time ranges for each speaker
|
| 33 |
+
speaker_ranges = {}
|
| 34 |
+
|
| 35 |
+
for line in lines:
|
| 36 |
+
if not line.strip():
|
| 37 |
+
continue
|
| 38 |
+
|
| 39 |
+
time_range, speaker = extract_time_and_speaker(line)
|
| 40 |
+
if time_range is None or speaker is None:
|
| 41 |
+
continue
|
| 42 |
+
|
| 43 |
+
# Check for overlaps with existing ranges of the same speaker
|
| 44 |
+
if speaker in speaker_ranges:
|
| 45 |
+
for existing_range in speaker_ranges[speaker]:
|
| 46 |
+
if has_overlap(time_range, existing_range):
|
| 47 |
+
return True
|
| 48 |
+
|
| 49 |
+
speaker_ranges[speaker].append(time_range)
|
| 50 |
+
else:
|
| 51 |
+
speaker_ranges[speaker] = [time_range]
|
| 52 |
+
|
| 53 |
+
return False
|
| 54 |
+
|
| 55 |
+
def process_file(input_file: str, output_file: str, delete_file: str):
|
| 56 |
+
"""Process the JSON file and separate entries with same-speaker overlapping timestamps."""
|
| 57 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
| 58 |
+
data = json.load(f)
|
| 59 |
+
|
| 60 |
+
if isinstance(data, dict):
|
| 61 |
+
data = [data]
|
| 62 |
+
|
| 63 |
+
cleaned_data = []
|
| 64 |
+
deleted_data = []
|
| 65 |
+
removed_count = 0
|
| 66 |
+
|
| 67 |
+
for entry in data:
|
| 68 |
+
if 'model_output' in entry:
|
| 69 |
+
if not has_same_speaker_overlap(entry['model_output']):
|
| 70 |
+
cleaned_data.append(entry)
|
| 71 |
+
else:
|
| 72 |
+
deleted_data.append(entry)
|
| 73 |
+
removed_count += 1
|
| 74 |
+
print(f"Removing entry with key: {entry.get('key', 'unknown')}")
|
| 75 |
+
|
| 76 |
+
# Save cleaned data
|
| 77 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 78 |
+
json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
|
| 79 |
+
|
| 80 |
+
# Save deleted data
|
| 81 |
+
with open(delete_file, 'w', encoding='utf-8') as f:
|
| 82 |
+
json.dump(deleted_data, f, ensure_ascii=False, indent=2)
|
| 83 |
+
|
| 84 |
+
print(f"\nProcessing Summary:")
|
| 85 |
+
print(f"Processed {len(data)} entries")
|
| 86 |
+
print(f"Removed {removed_count} entries with same-speaker overlapping timestamps")
|
| 87 |
+
print(f"Remaining entries: {len(cleaned_data)}")
|
| 88 |
+
|
| 89 |
+
if __name__ == '__main__':
|
| 90 |
+
input_file = 'silence_overlaps/transcriptions.json'
|
| 91 |
+
output_file = 'silence_overlaps/cleaned_transcriptions2.json'
|
| 92 |
+
delete_file = 'silence_overlaps/delete_transcript2.json'
|
| 93 |
+
process_file(input_file, output_file, delete_file)
|
| 94 |
+
print(f"\nCleaned transcriptions have been saved to {output_file}")
|
| 95 |
+
print(f"Deleted entries have been saved to {delete_file}")
|
count_audios.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
def collect_unique_audio_paths(json_file_path):
|
| 7 |
+
"""
|
| 8 |
+
提取 JSONL 文件中所有不重复的 audios 路径
|
| 9 |
+
"""
|
| 10 |
+
audio_set = set()
|
| 11 |
+
with open(json_file_path, 'r', encoding='utf-8') as f:
|
| 12 |
+
for line_num, line in enumerate(f, 1):
|
| 13 |
+
line = line.strip()
|
| 14 |
+
if not line:
|
| 15 |
+
continue
|
| 16 |
+
try:
|
| 17 |
+
data = json.loads(line)
|
| 18 |
+
if isinstance(data, dict) and 'audios' in data and data['audios']:
|
| 19 |
+
for audio_path in data['audios']:
|
| 20 |
+
audio_set.add(audio_path)
|
| 21 |
+
except Exception as e:
|
| 22 |
+
print(f"第 {line_num} 行处理错误: {e}")
|
| 23 |
+
return audio_set
|
| 24 |
+
|
| 25 |
+
def extract_first_subfolder_after_data(audio_path):
|
| 26 |
+
"""
|
| 27 |
+
提取 audio_path 中 'data/' 后的第一级子文件夹名称
|
| 28 |
+
例如:
|
| 29 |
+
/.../data/output_xxx/yyy/file.wav → 返回 output_xxx
|
| 30 |
+
"""
|
| 31 |
+
try:
|
| 32 |
+
path = Path(audio_path)
|
| 33 |
+
parts = path.parts
|
| 34 |
+
if "wavrewardDataset" in parts:
|
| 35 |
+
data_idx = parts.index("wavrewardDataset")
|
| 36 |
+
if data_idx + 1 < len(parts):
|
| 37 |
+
return parts[data_idx + 1]
|
| 38 |
+
return "unknown"
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"路径解析错误: {audio_path}, 错误: {e}")
|
| 41 |
+
return "error"
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
json_file = "all_dataset_train_resampled_16000.jsonl"
|
| 45 |
+
|
| 46 |
+
if not os.path.exists(json_file):
|
| 47 |
+
print(f"文件 {json_file} 不存在")
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
+
print(f"正在处理文件: {json_file}")
|
| 51 |
+
print("=" * 50)
|
| 52 |
+
|
| 53 |
+
# 步骤 1:收集去重后的音频路径
|
| 54 |
+
unique_audio_paths = collect_unique_audio_paths(json_file)
|
| 55 |
+
print(f"不重复音频文件数: {len(unique_audio_paths)}")
|
| 56 |
+
|
| 57 |
+
# 步骤 2:按 data 后的一级子目录统计
|
| 58 |
+
folder_counter = Counter()
|
| 59 |
+
for audio_path in unique_audio_paths:
|
| 60 |
+
first_subfolder = extract_first_subfolder_after_data(audio_path)
|
| 61 |
+
folder_counter[first_subfolder] += 1
|
| 62 |
+
|
| 63 |
+
print("\n按 data 后一级子文件夹统计(基于去重后的路径):")
|
| 64 |
+
print("-" * 50)
|
| 65 |
+
for folder, count in sorted(folder_counter.items(), key=lambda x: -x[1]):
|
| 66 |
+
print(f"{folder}: {count} 个文件")
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
main()
|
count_folders-Copy1.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
def count_folder_prefixes(json_file_path):
|
| 7 |
+
"""
|
| 8 |
+
统计JSON文件中音频路径的前缀文件夹出现次数
|
| 9 |
+
提取到 /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data 层级
|
| 10 |
+
然后根据子文件夹进行分类统计
|
| 11 |
+
支持JSONL格式(每行一个JSON对象)
|
| 12 |
+
"""
|
| 13 |
+
folder_counts = Counter()
|
| 14 |
+
|
| 15 |
+
# 读取JSON文件(支持JSONL格式)
|
| 16 |
+
with open(json_file_path, 'r', encoding='utf-8') as f:
|
| 17 |
+
for line_num, line in enumerate(f, 1):
|
| 18 |
+
line = line.strip()
|
| 19 |
+
if not line: # 跳过空行
|
| 20 |
+
continue
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
# 解析每一行的JSON对象
|
| 24 |
+
data = json.loads(line)
|
| 25 |
+
|
| 26 |
+
# 处理数据
|
| 27 |
+
if isinstance(data, dict) and 'audios' in data and data['audios']:
|
| 28 |
+
for audio_path in data['audios']:
|
| 29 |
+
prefix = extract_folder_prefix(audio_path)
|
| 30 |
+
if prefix:
|
| 31 |
+
folder_counts[prefix] += 1
|
| 32 |
+
|
| 33 |
+
except json.JSONDecodeError as e:
|
| 34 |
+
print(f"第 {line_num} 行JSON解析错误: {e}")
|
| 35 |
+
continue
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"第 {line_num} 行处理错误: {e}")
|
| 38 |
+
continue
|
| 39 |
+
|
| 40 |
+
return folder_counts
|
| 41 |
+
|
| 42 |
+
def extract_folder_prefix(audio_path):
|
| 43 |
+
"""
|
| 44 |
+
从音频路径中提取到指定层级的前缀,然后包含子文件夹进行分类
|
| 45 |
+
例如: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause/xxx.wav
|
| 46 |
+
提取到: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause
|
| 47 |
+
"""
|
| 48 |
+
try:
|
| 49 |
+
# 使用Path对象处理路径
|
| 50 |
+
path = Path(audio_path)
|
| 51 |
+
|
| 52 |
+
# 查找目标文件夹在路径中的位置
|
| 53 |
+
parts = path.parts
|
| 54 |
+
target_folder = "newdataset_10k"
|
| 55 |
+
|
| 56 |
+
if target_folder in parts:
|
| 57 |
+
# 找到目标文件夹的位置
|
| 58 |
+
target_index = parts.index(target_folder)
|
| 59 |
+
|
| 60 |
+
# 提取到目标文件夹(包含目标文件夹)
|
| 61 |
+
prefix_parts = parts[:target_index + 1]
|
| 62 |
+
|
| 63 |
+
# 如果目标文件夹后面还有子文件夹,也包含进来
|
| 64 |
+
if target_index + 1 < len(parts):
|
| 65 |
+
# 包含下一个子文件夹(如 output_2000_3000_wrongpause)
|
| 66 |
+
prefix_parts = parts[:target_index + 2]
|
| 67 |
+
|
| 68 |
+
return str(Path(*prefix_parts))
|
| 69 |
+
else:
|
| 70 |
+
# 如果没找到目标文件夹,返回路径的前几级
|
| 71 |
+
if len(parts) >= 4: # /root/autodl-tmp/xxx/...
|
| 72 |
+
return str(Path(*parts[:4]))
|
| 73 |
+
else:
|
| 74 |
+
return str(path.parent)
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"处理路径时出错: {audio_path}, 错误: {e}")
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
# 指定JSON文件路径
|
| 82 |
+
json_file = "needtouse/adddata_absolute.jsonl"
|
| 83 |
+
|
| 84 |
+
if not os.path.exists(json_file):
|
| 85 |
+
print(f"文件 {json_file} 不存在")
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
print(f"正在统计文件: {json_file}")
|
| 89 |
+
print("=" * 50)
|
| 90 |
+
|
| 91 |
+
# 统计文件夹前缀
|
| 92 |
+
folder_counts = count_folder_prefixes(json_file)
|
| 93 |
+
|
| 94 |
+
# 输出结果
|
| 95 |
+
if folder_counts:
|
| 96 |
+
print("文件夹前缀统计结果:")
|
| 97 |
+
print("-" * 50)
|
| 98 |
+
for folder, count in sorted(folder_counts.items()):
|
| 99 |
+
print(f"{folder}: {count} 次")
|
| 100 |
+
|
| 101 |
+
print(f"\n总计: {len(folder_counts)} 个不同的文件夹前缀")
|
| 102 |
+
print(f"总音频文件数: {sum(folder_counts.values())}")
|
| 103 |
+
|
| 104 |
+
# 按子文件夹类型统计
|
| 105 |
+
print("\n按子文件夹类型统计:")
|
| 106 |
+
print("-" * 30)
|
| 107 |
+
subfolder_counts = Counter()
|
| 108 |
+
for folder in folder_counts.keys():
|
| 109 |
+
path = Path(folder)
|
| 110 |
+
if len(path.parts) > 0:
|
| 111 |
+
# 获取最后一个文件夹名作为子文件夹类型
|
| 112 |
+
subfolder = path.parts[-1]
|
| 113 |
+
subfolder_counts[subfolder] += folder_counts[folder]
|
| 114 |
+
|
| 115 |
+
for subfolder, count in sorted(subfolder_counts.items()):
|
| 116 |
+
print(f"{subfolder}: {count} 次")
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
print("未找到任何音频路径")
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
main()
|
count_folders.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
def count_folder_prefixes(json_file_path):
|
| 7 |
+
"""
|
| 8 |
+
统计JSON文件中音频路径的前缀文件夹出现次数
|
| 9 |
+
提取到 /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data 层级
|
| 10 |
+
然后根据子文件夹进行分类统计
|
| 11 |
+
支持JSONL格式(每行一个JSON对象)
|
| 12 |
+
"""
|
| 13 |
+
folder_counts = Counter()
|
| 14 |
+
|
| 15 |
+
# 读取JSON文件(支持JSONL格式)
|
| 16 |
+
with open(json_file_path, 'r', encoding='utf-8') as f:
|
| 17 |
+
for line_num, line in enumerate(f, 1):
|
| 18 |
+
line = line.strip()
|
| 19 |
+
if not line: # 跳过空行
|
| 20 |
+
continue
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
# 解析每一行的JSON对象
|
| 24 |
+
data = json.loads(line)
|
| 25 |
+
|
| 26 |
+
# 处理数据
|
| 27 |
+
if isinstance(data, dict) and 'audios' in data and data['audios']:
|
| 28 |
+
for audio_path in data['audios']:
|
| 29 |
+
prefix = extract_folder_prefix(audio_path)
|
| 30 |
+
if prefix:
|
| 31 |
+
folder_counts[prefix] += 1
|
| 32 |
+
|
| 33 |
+
except json.JSONDecodeError as e:
|
| 34 |
+
print(f"第 {line_num} 行JSON解析错误: {e}")
|
| 35 |
+
continue
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"第 {line_num} 行处理错误: {e}")
|
| 38 |
+
continue
|
| 39 |
+
|
| 40 |
+
return folder_counts
|
| 41 |
+
|
| 42 |
+
def extract_folder_prefix(audio_path):
|
| 43 |
+
"""
|
| 44 |
+
从音频路径中提取到指定层级的前缀,然后包含子文件夹进行分类
|
| 45 |
+
例如: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause/xxx.wav
|
| 46 |
+
提取到: /root/autodl-tmp/wavrewardDataset/ultrachat_200k/data/output_2000_3000_wrongpause
|
| 47 |
+
"""
|
| 48 |
+
try:
|
| 49 |
+
# 使用Path对象处理路径
|
| 50 |
+
path = Path(audio_path)
|
| 51 |
+
|
| 52 |
+
# 查找目标文件夹在路径中的位置
|
| 53 |
+
parts = path.parts
|
| 54 |
+
target_folder = "data"
|
| 55 |
+
|
| 56 |
+
if target_folder in parts:
|
| 57 |
+
# 找到目标文件夹的位置
|
| 58 |
+
target_index = parts.index(target_folder)
|
| 59 |
+
|
| 60 |
+
# 提取到目标文件夹(包含目标文件夹)
|
| 61 |
+
prefix_parts = parts[:target_index + 1]
|
| 62 |
+
|
| 63 |
+
# 如果目标文件夹后面还有子文件夹,也包含进来
|
| 64 |
+
if target_index + 1 < len(parts):
|
| 65 |
+
# 包含下一个子文件夹(如 output_2000_3000_wrongpause)
|
| 66 |
+
prefix_parts = parts[:target_index + 2]
|
| 67 |
+
|
| 68 |
+
return str(Path(*prefix_parts))
|
| 69 |
+
else:
|
| 70 |
+
# 如果没找到目标文件夹,返回路径的前几级
|
| 71 |
+
if len(parts) >= 4: # /root/autodl-tmp/xxx/...
|
| 72 |
+
return str(Path(*parts[:4]))
|
| 73 |
+
else:
|
| 74 |
+
return str(path.parent)
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"处理路径时出错: {audio_path}, 错误: {e}")
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
# 指定JSON文件路径
|
| 82 |
+
json_file = "dataset_4JOB.jsonl"
|
| 83 |
+
|
| 84 |
+
if not os.path.exists(json_file):
|
| 85 |
+
print(f"文件 {json_file} 不存在")
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
print(f"正在统计文件: {json_file}")
|
| 89 |
+
print("=" * 50)
|
| 90 |
+
|
| 91 |
+
# 统计文件夹前缀
|
| 92 |
+
folder_counts = count_folder_prefixes(json_file)
|
| 93 |
+
|
| 94 |
+
# 输出结果
|
| 95 |
+
if folder_counts:
|
| 96 |
+
print("文件夹前缀统计结果:")
|
| 97 |
+
print("-" * 50)
|
| 98 |
+
for folder, count in sorted(folder_counts.items()):
|
| 99 |
+
print(f"{folder}: {count} 次")
|
| 100 |
+
|
| 101 |
+
print(f"\n总计: {len(folder_counts)} 个不同的文件夹前缀")
|
| 102 |
+
print(f"总音频文件数: {sum(folder_counts.values())}")
|
| 103 |
+
|
| 104 |
+
# 按子文件夹类型统计
|
| 105 |
+
print("\n按子文件夹类型统计:")
|
| 106 |
+
print("-" * 30)
|
| 107 |
+
subfolder_counts = Counter()
|
| 108 |
+
for folder in folder_counts.keys():
|
| 109 |
+
path = Path(folder)
|
| 110 |
+
if len(path.parts) > 0:
|
| 111 |
+
# 获取最后一个文件夹名作为子文件夹类型
|
| 112 |
+
subfolder = path.parts[-1]
|
| 113 |
+
subfolder_counts[subfolder] += folder_counts[folder]
|
| 114 |
+
|
| 115 |
+
for subfolder, count in sorted(subfolder_counts.items()):
|
| 116 |
+
print(f"{subfolder}: {count} 次")
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
print("未找到任何音频路径")
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
main()
|
dialogue_length_distribution.png
ADDED
|
dialogue_length_ranges.png
ADDED
|
docs/transformers/build/lib/transformers/models/sam/processing_sam.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Processor class for SAM.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from copy import deepcopy
|
| 20 |
+
from typing import List, Optional, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from ...image_utils import ImageInput, VideoInput
|
| 25 |
+
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
|
| 26 |
+
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
|
| 27 |
+
from ...utils import is_tf_available, is_torch_available
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if is_torch_available():
|
| 31 |
+
import torch
|
| 32 |
+
|
| 33 |
+
if is_tf_available():
|
| 34 |
+
import tensorflow as tf
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SamImagesKwargs(ImagesKwargs):
|
| 38 |
+
segmentation_maps: Optional[ImageInput]
|
| 39 |
+
input_points: Optional[List[List[float]]]
|
| 40 |
+
input_labels: Optional[List[List[int]]]
|
| 41 |
+
input_boxes: Optional[List[List[List[float]]]]
|
| 42 |
+
point_pad_value: Optional[int]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SamProcessorKwargs(ProcessingKwargs, total=False):
|
| 46 |
+
images_kwargs: SamImagesKwargs
|
| 47 |
+
_defaults = {
|
| 48 |
+
"images_kwargs": {
|
| 49 |
+
"point_pad_value": -10,
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SamProcessor(ProcessorMixin):
|
| 55 |
+
r"""
|
| 56 |
+
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
|
| 57 |
+
single processor.
|
| 58 |
+
|
| 59 |
+
[`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
|
| 60 |
+
[`~SamImageProcessor.__call__`] for more information.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
image_processor (`SamImageProcessor`):
|
| 64 |
+
An instance of [`SamImageProcessor`]. The image processor is a required input.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
attributes = ["image_processor"]
|
| 68 |
+
image_processor_class = "SamImageProcessor"
|
| 69 |
+
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
|
| 70 |
+
optional_call_args = [
|
| 71 |
+
"segmentation_maps",
|
| 72 |
+
"input_points",
|
| 73 |
+
"input_labels",
|
| 74 |
+
"input_boxes",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def __init__(self, image_processor):
|
| 78 |
+
super().__init__(image_processor)
|
| 79 |
+
self.target_size = self.image_processor.size["longest_edge"]
|
| 80 |
+
|
| 81 |
+
def __call__(
|
| 82 |
+
self,
|
| 83 |
+
images: Optional[ImageInput] = None,
|
| 84 |
+
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
|
| 85 |
+
# arguments that may be passed as a positional argument.
|
| 86 |
+
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
|
| 87 |
+
# or this conversation for more context:
|
| 88 |
+
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
|
| 89 |
+
# This behavior is only needed for backward compatibility and will be removed in future versions.
|
| 90 |
+
*args, # to be deprecated
|
| 91 |
+
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
| 92 |
+
audio: Optional[AudioInput] = None,
|
| 93 |
+
video: Optional[VideoInput] = None,
|
| 94 |
+
**kwargs,
|
| 95 |
+
) -> BatchEncoding:
|
| 96 |
+
"""
|
| 97 |
+
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
|
| 98 |
+
points and bounding boxes for the model if they are provided.
|
| 99 |
+
"""
|
| 100 |
+
output_kwargs = self._merge_kwargs(
|
| 101 |
+
SamProcessorKwargs,
|
| 102 |
+
tokenizer_init_kwargs={},
|
| 103 |
+
**kwargs,
|
| 104 |
+
**self.prepare_and_validate_optional_call_args(*args),
|
| 105 |
+
)
|
| 106 |
+
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
|
| 107 |
+
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
|
| 108 |
+
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
|
| 109 |
+
point_pad_value = output_kwargs["images_kwargs"].pop("point_pad_value", None)
|
| 110 |
+
|
| 111 |
+
encoding_image_processor = self.image_processor(
|
| 112 |
+
images,
|
| 113 |
+
**output_kwargs["images_kwargs"],
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# pop arguments that are not used in the foward but used nevertheless
|
| 117 |
+
original_sizes = encoding_image_processor["original_sizes"]
|
| 118 |
+
|
| 119 |
+
if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
|
| 120 |
+
original_sizes = original_sizes.numpy()
|
| 121 |
+
|
| 122 |
+
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
|
| 123 |
+
input_points=input_points,
|
| 124 |
+
input_labels=input_labels,
|
| 125 |
+
input_boxes=input_boxes,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
encoding_image_processor = self._normalize_and_convert(
|
| 129 |
+
encoding_image_processor,
|
| 130 |
+
original_sizes,
|
| 131 |
+
input_points=input_points,
|
| 132 |
+
input_labels=input_labels,
|
| 133 |
+
input_boxes=input_boxes,
|
| 134 |
+
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
|
| 135 |
+
point_pad_value=point_pad_value,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
return encoding_image_processor
|
| 139 |
+
|
| 140 |
+
def _normalize_and_convert(
|
| 141 |
+
self,
|
| 142 |
+
encoding_image_processor,
|
| 143 |
+
original_sizes,
|
| 144 |
+
input_points=None,
|
| 145 |
+
input_labels=None,
|
| 146 |
+
input_boxes=None,
|
| 147 |
+
return_tensors="pt",
|
| 148 |
+
point_pad_value=-10,
|
| 149 |
+
):
|
| 150 |
+
if input_points is not None:
|
| 151 |
+
if len(original_sizes) != len(input_points):
|
| 152 |
+
input_points = [
|
| 153 |
+
self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points
|
| 154 |
+
]
|
| 155 |
+
else:
|
| 156 |
+
input_points = [
|
| 157 |
+
self._normalize_coordinates(self.target_size, point, original_size)
|
| 158 |
+
for point, original_size in zip(input_points, original_sizes)
|
| 159 |
+
]
|
| 160 |
+
# check that all arrays have the same shape
|
| 161 |
+
if not all(point.shape == input_points[0].shape for point in input_points):
|
| 162 |
+
if input_labels is not None:
|
| 163 |
+
input_points, input_labels = self._pad_points_and_labels(
|
| 164 |
+
input_points, input_labels, point_pad_value
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
input_points = np.array(input_points)
|
| 168 |
+
|
| 169 |
+
if input_labels is not None:
|
| 170 |
+
input_labels = np.array(input_labels)
|
| 171 |
+
|
| 172 |
+
if input_boxes is not None:
|
| 173 |
+
if len(original_sizes) != len(input_boxes):
|
| 174 |
+
input_boxes = [
|
| 175 |
+
self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True)
|
| 176 |
+
for box in input_boxes
|
| 177 |
+
]
|
| 178 |
+
else:
|
| 179 |
+
input_boxes = [
|
| 180 |
+
self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True)
|
| 181 |
+
for box, original_size in zip(input_boxes, original_sizes)
|
| 182 |
+
]
|
| 183 |
+
input_boxes = np.array(input_boxes)
|
| 184 |
+
|
| 185 |
+
if input_boxes is not None:
|
| 186 |
+
if return_tensors == "pt":
|
| 187 |
+
input_boxes = torch.from_numpy(input_boxes)
|
| 188 |
+
# boxes batch size of 1 by default
|
| 189 |
+
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
|
| 190 |
+
elif return_tensors == "tf":
|
| 191 |
+
input_boxes = tf.convert_to_tensor(input_boxes)
|
| 192 |
+
# boxes batch size of 1 by default
|
| 193 |
+
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
|
| 194 |
+
encoding_image_processor.update({"input_boxes": input_boxes})
|
| 195 |
+
if input_points is not None:
|
| 196 |
+
if return_tensors == "pt":
|
| 197 |
+
input_points = torch.from_numpy(input_points)
|
| 198 |
+
# point batch size of 1 by default
|
| 199 |
+
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
|
| 200 |
+
elif return_tensors == "tf":
|
| 201 |
+
input_points = tf.convert_to_tensor(input_points)
|
| 202 |
+
# point batch size of 1 by default
|
| 203 |
+
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
|
| 204 |
+
encoding_image_processor.update({"input_points": input_points})
|
| 205 |
+
if input_labels is not None:
|
| 206 |
+
if return_tensors == "pt":
|
| 207 |
+
input_labels = torch.from_numpy(input_labels)
|
| 208 |
+
# point batch size of 1 by default
|
| 209 |
+
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
|
| 210 |
+
elif return_tensors == "tf":
|
| 211 |
+
input_labels = tf.convert_to_tensor(input_labels)
|
| 212 |
+
# point batch size of 1 by default
|
| 213 |
+
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
|
| 214 |
+
encoding_image_processor.update({"input_labels": input_labels})
|
| 215 |
+
|
| 216 |
+
return encoding_image_processor
|
| 217 |
+
|
| 218 |
+
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
|
| 219 |
+
r"""
|
| 220 |
+
The method pads the 2D points and labels to the maximum number of points in the batch.
|
| 221 |
+
"""
|
| 222 |
+
expected_nb_points = max([point.shape[0] for point in input_points])
|
| 223 |
+
processed_input_points = []
|
| 224 |
+
for i, point in enumerate(input_points):
|
| 225 |
+
if point.shape[0] != expected_nb_points:
|
| 226 |
+
point = np.concatenate(
|
| 227 |
+
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
|
| 228 |
+
)
|
| 229 |
+
input_labels[i] = np.append(input_labels[i], [point_pad_value])
|
| 230 |
+
processed_input_points.append(point)
|
| 231 |
+
input_points = processed_input_points
|
| 232 |
+
return input_points, input_labels
|
| 233 |
+
|
| 234 |
+
def _normalize_coordinates(
|
| 235 |
+
self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
|
| 236 |
+
) -> np.ndarray:
|
| 237 |
+
"""
|
| 238 |
+
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
|
| 239 |
+
"""
|
| 240 |
+
old_h, old_w = original_size
|
| 241 |
+
new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
|
| 242 |
+
coords = deepcopy(coords).astype(float)
|
| 243 |
+
|
| 244 |
+
if is_bounding_box:
|
| 245 |
+
coords = coords.reshape(-1, 2, 2)
|
| 246 |
+
|
| 247 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
| 248 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
| 249 |
+
|
| 250 |
+
if is_bounding_box:
|
| 251 |
+
coords = coords.reshape(-1, 4)
|
| 252 |
+
|
| 253 |
+
return coords
|
| 254 |
+
|
| 255 |
+
def _check_and_preprocess_points(
|
| 256 |
+
self,
|
| 257 |
+
input_points=None,
|
| 258 |
+
input_labels=None,
|
| 259 |
+
input_boxes=None,
|
| 260 |
+
):
|
| 261 |
+
r"""
|
| 262 |
+
Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they
|
| 263 |
+
are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`,
|
| 264 |
+
it is converted to a `numpy.ndarray` and then to a `list`.
|
| 265 |
+
"""
|
| 266 |
+
if input_points is not None:
|
| 267 |
+
if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
|
| 268 |
+
input_points = input_points.numpy().tolist()
|
| 269 |
+
|
| 270 |
+
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
|
| 271 |
+
raise ValueError("Input points must be a list of list of floating points.")
|
| 272 |
+
input_points = [np.array(input_point) for input_point in input_points]
|
| 273 |
+
else:
|
| 274 |
+
input_points = None
|
| 275 |
+
|
| 276 |
+
if input_labels is not None:
|
| 277 |
+
if hasattr(input_labels, "numpy"):
|
| 278 |
+
input_labels = input_labels.numpy().tolist()
|
| 279 |
+
|
| 280 |
+
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
|
| 281 |
+
raise ValueError("Input labels must be a list of list integers.")
|
| 282 |
+
input_labels = [np.array(label) for label in input_labels]
|
| 283 |
+
else:
|
| 284 |
+
input_labels = None
|
| 285 |
+
|
| 286 |
+
if input_boxes is not None:
|
| 287 |
+
if hasattr(input_boxes, "numpy"):
|
| 288 |
+
input_boxes = input_boxes.numpy().tolist()
|
| 289 |
+
|
| 290 |
+
if (
|
| 291 |
+
not isinstance(input_boxes, list)
|
| 292 |
+
or not isinstance(input_boxes[0], list)
|
| 293 |
+
or not isinstance(input_boxes[0][0], list)
|
| 294 |
+
):
|
| 295 |
+
raise ValueError("Input boxes must be a list of list of list of floating points.")
|
| 296 |
+
input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
|
| 297 |
+
else:
|
| 298 |
+
input_boxes = None
|
| 299 |
+
|
| 300 |
+
return input_points, input_labels, input_boxes
|
| 301 |
+
|
| 302 |
+
@property
|
| 303 |
+
def model_input_names(self):
|
| 304 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 305 |
+
return list(dict.fromkeys(image_processor_input_names))
|
| 306 |
+
|
| 307 |
+
def post_process_masks(self, *args, **kwargs):
|
| 308 |
+
return self.image_processor.post_process_masks(*args, **kwargs)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
__all__ = ["SamProcessor"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Converting Meta SeamlessM4T checkpoints from seamless_communication to HF."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from accelerate.utils.modeling import find_tied_parameters
|
| 23 |
+
from seamless_communication.models.inference.translator import Translator
|
| 24 |
+
|
| 25 |
+
from transformers import (
|
| 26 |
+
SeamlessM4TConfig,
|
| 27 |
+
SeamlessM4TFeatureExtractor,
|
| 28 |
+
SeamlessM4TModel,
|
| 29 |
+
SeamlessM4TProcessor,
|
| 30 |
+
SeamlessM4TTokenizer,
|
| 31 |
+
)
|
| 32 |
+
from transformers.utils import logging
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ] # fmt: skip
|
| 36 |
+
VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",] # fmt: skip
|
| 37 |
+
MEDIUM_SUPPORTED_LANGUAGES = ["ace","ace_Latn","acm","acq","aeb","afr","ajp","aka","amh","apc","arb","ars","ary","arz","asm","ast","awa","ayr","azb","azj","bak","bam","ban","bel","bem","ben","bho","bjn","bjn_Latn","bod","bos","bug","bul","cat","ceb","ces","cjk","ckb","crh","cym","dan","deu","dik","dyu","dzo","ell","eng","epo","est","eus","ewe","fao","pes","fij","fin","fon","fra","fur","fuv","gla","gle","glg","grn","guj","hat","hau","heb","hin","hne","hrv","hun","hye","ibo","ilo","ind","isl","ita","jav","jpn","kab","kac","kam","kan","kas","kas_Deva","kat","knc","knc_Latn","kaz","kbp","kea","khm","kik","kin","kir","kmb","kon","kor","kmr","lao","lvs","lij","lim","lin","lit","lmo","ltg","ltz","lua","lug","luo","lus","mag","mai","mal","mar","min","mkd","plt","mlt","mni","khk","mos","mri","zsm","mya","nld","nno","nob","npi","nso","nus","nya","oci","gaz","ory","pag","pan","pap","pol","por","prs","pbt","quy","ron","run","rus","sag","san","sat","scn","shn","sin","slk","slv","smo","sna","snd","som","sot","spa","als","srd","srp","ssw","sun","swe","swh","szl","tam","tat","tel","tgk","tgl","tha","tir","taq","taq_Tfng","tpi","tsn","tso","tuk","tum","tur","twi","tzm","uig","ukr","umb","urd","uzn","vec","vie","war","wol","xho","ydd","yor","yue","cmn","cmn_Hant","zul",] # fmt: skip
|
| 38 |
+
LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",] # fmt: skip
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def assert_param_count(model_1, model_2):
|
| 42 |
+
count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
|
| 43 |
+
count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
|
| 44 |
+
assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def param_count(model):
|
| 48 |
+
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _grab_best_device(use_gpu=True):
|
| 52 |
+
if torch.cuda.device_count() > 0 and use_gpu:
|
| 53 |
+
device = "cuda"
|
| 54 |
+
else:
|
| 55 |
+
device = "cpu"
|
| 56 |
+
return torch.device(device)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
logging.set_verbosity_info()
|
| 60 |
+
logger = logging.get_logger(__name__)
|
| 61 |
+
|
| 62 |
+
vocoder_convert_list = [
|
| 63 |
+
("ups", "hifi_gan.upsampler"),
|
| 64 |
+
("conv_pre", "hifi_gan.conv_pre"),
|
| 65 |
+
("resblocks", "hifi_gan.resblocks"),
|
| 66 |
+
("conv_post", "hifi_gan.conv_post"),
|
| 67 |
+
("lang", "language_embedding"),
|
| 68 |
+
("spkr", "speaker_embedding"),
|
| 69 |
+
("dict.", "unit_embedding."),
|
| 70 |
+
("dur_predictor.conv1.0", "dur_predictor.conv1"),
|
| 71 |
+
("dur_predictor.conv2.0", "dur_predictor.conv2"),
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
# order is important
|
| 75 |
+
wav2vec_convert_list = [
|
| 76 |
+
("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"),
|
| 77 |
+
("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
|
| 78 |
+
("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
|
| 79 |
+
("speech_encoder.inner.layers", "encoder.layers"),
|
| 80 |
+
("speech_encoder.inner_layer_norm", "encoder.layer_norm"),
|
| 81 |
+
("speech_encoder.adaptor_layers", "adapter.layers"),
|
| 82 |
+
("inner_proj", "intermediate_dense"),
|
| 83 |
+
("self_attn.output_proj", "self_attn.linear_out"),
|
| 84 |
+
("output_proj", "output_dense"),
|
| 85 |
+
("self_attn.k_proj", "self_attn.linear_k"),
|
| 86 |
+
("self_attn.v_proj", "self_attn.linear_v"),
|
| 87 |
+
("self_attn.q_proj", "self_attn.linear_q"),
|
| 88 |
+
("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
|
| 89 |
+
("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
|
| 90 |
+
("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
|
| 91 |
+
("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
|
| 92 |
+
("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
|
| 93 |
+
("conv.depthwise_conv", "conv_module.depthwise_conv"),
|
| 94 |
+
("conv.batch_norm", "conv_module.batch_norm"),
|
| 95 |
+
("conv_layer_norm", "conv_module.layer_norm"),
|
| 96 |
+
("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"),
|
| 97 |
+
("speech_encoder.proj2", "intermediate_ffn.output_dense"),
|
| 98 |
+
("speech_encoder.layer_norm", "inner_layer_norm"),
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
t2u_convert_list = [
|
| 102 |
+
("t2u_model.final_proj", "lm_head"),
|
| 103 |
+
("t2u_model.", "model."),
|
| 104 |
+
("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
|
| 105 |
+
("encoder_decoder_attn", "cross_attention"),
|
| 106 |
+
("linear_k", "k_proj"),
|
| 107 |
+
("linear_v", "v_proj"),
|
| 108 |
+
("linear_q", "q_proj"),
|
| 109 |
+
("ffn.inner_proj", "ffn.fc1"),
|
| 110 |
+
("ffn.output_proj", "ffn.fc2"),
|
| 111 |
+
("output_proj", "out_proj"),
|
| 112 |
+
("decoder_frontend.embed", "decoder.embed_tokens"),
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
text_convert_list = [
|
| 116 |
+
("text_encoder.", ""),
|
| 117 |
+
("text_decoder.", ""),
|
| 118 |
+
("text_encoder_frontend.embed", "embed_tokens"),
|
| 119 |
+
("text_decoder_frontend.embed", "embed_tokens"),
|
| 120 |
+
("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
|
| 121 |
+
("encoder_decoder_attn", "cross_attention"),
|
| 122 |
+
("linear_k", "k_proj"),
|
| 123 |
+
("linear_v", "v_proj"),
|
| 124 |
+
("linear_q", "q_proj"),
|
| 125 |
+
("ffn.inner_proj", "ffn.fc1"),
|
| 126 |
+
("ffn.output_proj", "ffn.fc2"),
|
| 127 |
+
("output_proj", "out_proj"),
|
| 128 |
+
("final_proj", "lm_head"),
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
| 132 |
+
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
| 133 |
+
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _load_hf_config(model_type="medium"):
|
| 137 |
+
if model_type == "medium":
|
| 138 |
+
kwargs = {
|
| 139 |
+
"vocab_size": 256206,
|
| 140 |
+
"t2u_vocab_size": 10082,
|
| 141 |
+
"hidden_size": 1024,
|
| 142 |
+
"max_position_embeddings": 4096,
|
| 143 |
+
"encoder_layers": 12,
|
| 144 |
+
"decoder_layers": 12,
|
| 145 |
+
"encoder_ffn_dim": 4096,
|
| 146 |
+
"decoder_ffn_dim": 4096,
|
| 147 |
+
"t2u_encoder_layers": 4,
|
| 148 |
+
"t2u_decoder_layers": 4,
|
| 149 |
+
"speech_encoder_layers": 12,
|
| 150 |
+
}
|
| 151 |
+
return SeamlessM4TConfig(**kwargs)
|
| 152 |
+
else:
|
| 153 |
+
return SeamlessM4TConfig()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _convert_model(
|
| 157 |
+
original_model,
|
| 158 |
+
hf_model,
|
| 159 |
+
convert_list,
|
| 160 |
+
device,
|
| 161 |
+
unwanted_prefix="model.",
|
| 162 |
+
filter_state_dict="speech",
|
| 163 |
+
exclude_state_dict=None,
|
| 164 |
+
):
|
| 165 |
+
state_dict = original_model.state_dict()
|
| 166 |
+
|
| 167 |
+
# filter func
|
| 168 |
+
if isinstance(filter_state_dict, str):
|
| 169 |
+
|
| 170 |
+
def filter_func(x):
|
| 171 |
+
return filter_state_dict in x[0]
|
| 172 |
+
|
| 173 |
+
else:
|
| 174 |
+
|
| 175 |
+
def filter_func(item):
|
| 176 |
+
if exclude_state_dict is not None and exclude_state_dict in item[0]:
|
| 177 |
+
return False
|
| 178 |
+
for filter_el in filter_state_dict:
|
| 179 |
+
if filter_el in item[0]:
|
| 180 |
+
return True
|
| 181 |
+
|
| 182 |
+
return False
|
| 183 |
+
|
| 184 |
+
state_dict = dict(filter(filter_func, state_dict.items()))
|
| 185 |
+
|
| 186 |
+
for k, v in list(state_dict.items()):
|
| 187 |
+
new_k = k[len(unwanted_prefix) :]
|
| 188 |
+
for old_layer_name, new_layer_name in convert_list:
|
| 189 |
+
if old_layer_name in new_k:
|
| 190 |
+
new_k = new_k.replace(old_layer_name, new_layer_name)
|
| 191 |
+
|
| 192 |
+
# must do it by hand
|
| 193 |
+
if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric():
|
| 194 |
+
new_k = new_k.replace("layer_norm", "final_layer_norm")
|
| 195 |
+
|
| 196 |
+
state_dict[new_k] = state_dict.pop(k)
|
| 197 |
+
|
| 198 |
+
extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
|
| 199 |
+
extra_keys = set(extra_keys)
|
| 200 |
+
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
|
| 201 |
+
missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k})
|
| 202 |
+
if len(extra_keys) != 0:
|
| 203 |
+
raise ValueError(f"extra keys found: {extra_keys}")
|
| 204 |
+
if len(missing_keys) != 0:
|
| 205 |
+
raise ValueError(f"missing keys: {missing_keys}")
|
| 206 |
+
hf_model.load_state_dict(state_dict, strict=False)
|
| 207 |
+
n_params = param_count(hf_model)
|
| 208 |
+
|
| 209 |
+
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params")
|
| 210 |
+
|
| 211 |
+
hf_model.eval()
|
| 212 |
+
hf_model.to(device)
|
| 213 |
+
del state_dict
|
| 214 |
+
|
| 215 |
+
return hf_model
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def load_model(save_dir, model_type, repo_id):
|
| 219 |
+
"""
|
| 220 |
+
Meta SeamlessM4T is made of 8 main components:
|
| 221 |
+
- speech_encoder (#1) and speech_encoder_frontend (#2)
|
| 222 |
+
- t2u_model (#3)
|
| 223 |
+
- text_encoder (#4) and text_encoder_frontend (#5)
|
| 224 |
+
- text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend]
|
| 225 |
+
- final_proj (#7)
|
| 226 |
+
- vocoder (#8)
|
| 227 |
+
"""
|
| 228 |
+
device = _grab_best_device()
|
| 229 |
+
if model_type == "medium":
|
| 230 |
+
name = "seamlessM4T_medium"
|
| 231 |
+
else:
|
| 232 |
+
name = "seamlessM4T_large"
|
| 233 |
+
|
| 234 |
+
original_model = Translator(name, "vocoder_36langs", device, torch.float32)
|
| 235 |
+
|
| 236 |
+
######### TOKENIZER
|
| 237 |
+
|
| 238 |
+
langs = MEDIUM_SUPPORTED_LANGUAGES if model_type == "medium" else LARGE_SUPPORTED_LANGUAGES
|
| 239 |
+
langs = [f"__{lang}__" for lang in langs]
|
| 240 |
+
vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model")
|
| 241 |
+
|
| 242 |
+
save_dir = os.path.join(save_dir, name)
|
| 243 |
+
Path(save_dir).mkdir(exist_ok=True)
|
| 244 |
+
|
| 245 |
+
tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs)
|
| 246 |
+
|
| 247 |
+
sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__")
|
| 248 |
+
|
| 249 |
+
tokenizer.save_pretrained(save_dir)
|
| 250 |
+
tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir)
|
| 251 |
+
|
| 252 |
+
if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"):
|
| 253 |
+
raise ValueError(
|
| 254 |
+
f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
####### get language to ids dict
|
| 258 |
+
text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs}
|
| 259 |
+
# offset: vocoder unit vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages)
|
| 260 |
+
t2u_lang_code_to_id = {
|
| 261 |
+
code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES)
|
| 262 |
+
for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES)
|
| 263 |
+
}
|
| 264 |
+
vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)}
|
| 265 |
+
|
| 266 |
+
######### FE
|
| 267 |
+
|
| 268 |
+
fe = SeamlessM4TFeatureExtractor(language_code=langs)
|
| 269 |
+
|
| 270 |
+
fe.save_pretrained(save_dir)
|
| 271 |
+
fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir)
|
| 272 |
+
|
| 273 |
+
processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer)
|
| 274 |
+
processor.save_pretrained(save_dir)
|
| 275 |
+
processor.push_to_hub(repo_id=repo_id, create_pr=True)
|
| 276 |
+
|
| 277 |
+
processor = SeamlessM4TProcessor.from_pretrained(save_dir)
|
| 278 |
+
|
| 279 |
+
######## Model
|
| 280 |
+
|
| 281 |
+
# init model
|
| 282 |
+
hf_config = _load_hf_config(model_type)
|
| 283 |
+
hf_model = SeamlessM4TModel(hf_config)
|
| 284 |
+
|
| 285 |
+
hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id)
|
| 286 |
+
hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id)
|
| 287 |
+
hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id)
|
| 288 |
+
|
| 289 |
+
# -1. take care of vocoder
|
| 290 |
+
# similarly to speech T5 must apply and remove weight norm
|
| 291 |
+
hf_model.vocoder.apply_weight_norm()
|
| 292 |
+
hf_model.vocoder = _convert_model(
|
| 293 |
+
original_model,
|
| 294 |
+
hf_model.vocoder,
|
| 295 |
+
vocoder_convert_list,
|
| 296 |
+
device,
|
| 297 |
+
unwanted_prefix="vocoder.code_generator.",
|
| 298 |
+
filter_state_dict="vocoder",
|
| 299 |
+
)
|
| 300 |
+
hf_model.vocoder.remove_weight_norm()
|
| 301 |
+
|
| 302 |
+
# 1. take care of speech encoder
|
| 303 |
+
wav2vec = hf_model.speech_encoder
|
| 304 |
+
hf_model.speech_encoder = _convert_model(
|
| 305 |
+
original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# 2. take care of t2u
|
| 309 |
+
|
| 310 |
+
hf_model.t2u_model = _convert_model(
|
| 311 |
+
original_model,
|
| 312 |
+
hf_model.t2u_model,
|
| 313 |
+
t2u_convert_list,
|
| 314 |
+
device,
|
| 315 |
+
unwanted_prefix="model.",
|
| 316 |
+
filter_state_dict="t2u_model",
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# 3. take care of text encoder
|
| 320 |
+
hf_model.text_encoder = _convert_model(
|
| 321 |
+
original_model,
|
| 322 |
+
hf_model.text_encoder,
|
| 323 |
+
text_convert_list,
|
| 324 |
+
device,
|
| 325 |
+
unwanted_prefix="model.",
|
| 326 |
+
filter_state_dict=["model.text_encoder"],
|
| 327 |
+
exclude_state_dict="t2u_model",
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# 4. take care of text decoder
|
| 331 |
+
hf_model.text_decoder = _convert_model(
|
| 332 |
+
original_model,
|
| 333 |
+
hf_model.text_decoder,
|
| 334 |
+
text_convert_list,
|
| 335 |
+
device,
|
| 336 |
+
unwanted_prefix="model.",
|
| 337 |
+
filter_state_dict=["model.text_decoder"],
|
| 338 |
+
exclude_state_dict="t2u_model",
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# 5. take care of final proj
|
| 342 |
+
hf_model.lm_head = _convert_model(
|
| 343 |
+
original_model,
|
| 344 |
+
hf_model.lm_head,
|
| 345 |
+
[("final_proj.", "")],
|
| 346 |
+
device,
|
| 347 |
+
unwanted_prefix="model.",
|
| 348 |
+
filter_state_dict=["model.final_proj"],
|
| 349 |
+
exclude_state_dict="t2u_model",
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# sanity check
|
| 353 |
+
print(find_tied_parameters(hf_model))
|
| 354 |
+
|
| 355 |
+
count_1 = param_count(hf_model)
|
| 356 |
+
count_2 = param_count(original_model)
|
| 357 |
+
|
| 358 |
+
print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}")
|
| 359 |
+
print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}")
|
| 360 |
+
|
| 361 |
+
del original_model
|
| 362 |
+
|
| 363 |
+
hf_model.generation_config._from_model_config = False
|
| 364 |
+
hf_model.save_pretrained(save_dir)
|
| 365 |
+
hf_model.push_to_hub(repo_id=repo_id, create_pr=True)
|
| 366 |
+
hf_model = SeamlessM4TModel.from_pretrained(save_dir)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
if __name__ == "__main__":
|
| 370 |
+
parser = argparse.ArgumentParser()
|
| 371 |
+
# Required parameters
|
| 372 |
+
|
| 373 |
+
parser.add_argument(
|
| 374 |
+
"--model_type",
|
| 375 |
+
default="medium",
|
| 376 |
+
type=str,
|
| 377 |
+
help="Model type.",
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
parser.add_argument(
|
| 381 |
+
"--save_dir",
|
| 382 |
+
default="/home/ubuntu/weights",
|
| 383 |
+
type=str,
|
| 384 |
+
help="Path to the output PyTorch model.",
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
parser.add_argument(
|
| 388 |
+
"--repo_id",
|
| 389 |
+
default="facebook/hf-seamless-m4t-medium",
|
| 390 |
+
type=str,
|
| 391 |
+
help="Repo ID.",
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
args = parser.parse_args()
|
| 395 |
+
|
| 396 |
+
load_model(args.save_dir, args.model_type, args.repo_id)
|
docs/transformers/build/lib/transformers/models/seamless_m4t/modeling_seamless_m4t.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/transformers/build/lib/transformers/models/seamless_m4t/processing_seamless_m4t.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Audio/Text processor class for SeamlessM4T
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from ...processing_utils import ProcessorMixin
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SeamlessM4TProcessor(ProcessorMixin):
|
| 23 |
+
r"""
|
| 24 |
+
Constructs a SeamlessM4T processor which wraps a SeamlessM4T feature extractor and a SeamlessM4T tokenizer into a
|
| 25 |
+
single processor.
|
| 26 |
+
|
| 27 |
+
[`SeamlessM4TProcessor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and
|
| 28 |
+
[`SeamlessM4TTokenizerFast`]. See the [`~SeamlessM4TProcessor.__call__`] and [`~SeamlessM4TProcessor.decode`] for
|
| 29 |
+
more information.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
feature_extractor ([`SeamlessM4TFeatureExtractor`]):
|
| 33 |
+
The audio processor is a required input.
|
| 34 |
+
tokenizer ([`SeamlessM4TTokenizerFast`]):
|
| 35 |
+
The tokenizer is a required input.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
feature_extractor_class = "SeamlessM4TFeatureExtractor"
|
| 39 |
+
tokenizer_class = ("SeamlessM4TTokenizer", "SeamlessM4TTokenizerFast")
|
| 40 |
+
|
| 41 |
+
def __init__(self, feature_extractor, tokenizer):
|
| 42 |
+
super().__init__(feature_extractor, tokenizer)
|
| 43 |
+
|
| 44 |
+
def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwargs):
|
| 45 |
+
"""
|
| 46 |
+
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
|
| 47 |
+
and `kwargs` arguments to SeamlessM4TTokenizerFast's [`~SeamlessM4TTokenizerFast.__call__`] if `text` is not
|
| 48 |
+
`None` to encode the text. To prepare the audio(s), this method forwards the `audios` and `kwrags` arguments to
|
| 49 |
+
SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audios` is not `None`. Please refer
|
| 50 |
+
to the docstring of the above two methods for more information.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
| 54 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 55 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 56 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 57 |
+
audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 58 |
+
The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case
|
| 59 |
+
of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
|
| 60 |
+
and T the sample length of the audio.
|
| 61 |
+
src_lang (`str`, *optional*):
|
| 62 |
+
The language code of the input texts/audios. If not specified, the last `src_lang` specified will be
|
| 63 |
+
used.
|
| 64 |
+
tgt_lang (`str`, *optional*):
|
| 65 |
+
The code of the target language. If not specified, the last `tgt_lang` specified will be used.
|
| 66 |
+
kwargs (*optional*):
|
| 67 |
+
Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
|
| 68 |
+
tokenizer.
|
| 69 |
+
Returns:
|
| 70 |
+
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
| 71 |
+
|
| 72 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 73 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 74 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 75 |
+
`None`).
|
| 76 |
+
- **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`.
|
| 77 |
+
"""
|
| 78 |
+
sampling_rate = kwargs.pop("sampling_rate", None)
|
| 79 |
+
|
| 80 |
+
if text is None and audios is None:
|
| 81 |
+
raise ValueError("You have to specify either text or audios. Both cannot be none.")
|
| 82 |
+
elif text is not None and audios is not None:
|
| 83 |
+
raise ValueError(
|
| 84 |
+
"Text and audios are mututally exclusive when passed to `SeamlessM4T`. Specify one or another."
|
| 85 |
+
)
|
| 86 |
+
elif text is not None:
|
| 87 |
+
if tgt_lang is not None:
|
| 88 |
+
self.tokenizer.tgt_lang = tgt_lang
|
| 89 |
+
if src_lang is not None:
|
| 90 |
+
self.tokenizer.src_lang = src_lang
|
| 91 |
+
encoding = self.tokenizer(text, **kwargs)
|
| 92 |
+
|
| 93 |
+
return encoding
|
| 94 |
+
|
| 95 |
+
else:
|
| 96 |
+
encoding = self.feature_extractor(audios, sampling_rate=sampling_rate, **kwargs)
|
| 97 |
+
return encoding
|
| 98 |
+
|
| 99 |
+
def batch_decode(self, *args, **kwargs):
|
| 100 |
+
"""
|
| 101 |
+
This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
|
| 102 |
+
Please refer to the docstring of this method for more information.
|
| 103 |
+
"""
|
| 104 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 105 |
+
|
| 106 |
+
def decode(self, *args, **kwargs):
|
| 107 |
+
"""
|
| 108 |
+
This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please
|
| 109 |
+
refer to the docstring of this method for more information.
|
| 110 |
+
"""
|
| 111 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 112 |
+
|
| 113 |
+
@property
|
| 114 |
+
def model_input_names(self):
|
| 115 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 116 |
+
feature_extractor_input_names = self.feature_extractor.model_input_names
|
| 117 |
+
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
__all__ = ["SeamlessM4TProcessor"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fast Tokenization class for SeamlessM4T."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from shutil import copyfile
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
from tokenizers import processors
|
| 22 |
+
|
| 23 |
+
from ...tokenization_utils import (
|
| 24 |
+
BatchEncoding,
|
| 25 |
+
PreTokenizedInput,
|
| 26 |
+
TextInput,
|
| 27 |
+
)
|
| 28 |
+
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
| 29 |
+
from ...utils import PaddingStrategy, is_sentencepiece_available, logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if is_sentencepiece_available():
|
| 33 |
+
from .tokenization_seamless_m4t import SeamlessM4TTokenizer
|
| 34 |
+
else:
|
| 35 |
+
SeamlessM4TTokenizer = None
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SeamlessM4TTokenizerFast(PreTrainedTokenizerFast):
|
| 43 |
+
"""
|
| 44 |
+
Construct a "fast" SeamlessM4T tokenizer (backed by HuggingFace's *tokenizers* library). Based on
|
| 45 |
+
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
|
| 46 |
+
|
| 47 |
+
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
| 48 |
+
refer to this superclass for more information regarding those methods.
|
| 49 |
+
|
| 50 |
+
The tokenization method is `<language code> <tokens> <eos>` for source language documents, and `<eos> <language
|
| 51 |
+
code> <tokens> <eos>` for target language documents.
|
| 52 |
+
|
| 53 |
+
Examples:
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
>>> from transformers import SeamlessM4TTokenizerFast
|
| 57 |
+
|
| 58 |
+
>>> tokenizer = SeamlessM4TTokenizerFast.from_pretrained(
|
| 59 |
+
... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra"
|
| 60 |
+
... )
|
| 61 |
+
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
| 62 |
+
>>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
|
| 63 |
+
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
vocab_file (`str`, *optional*):
|
| 68 |
+
Path to the vocabulary file.
|
| 69 |
+
tokenizer_file (`str`, *optional*):
|
| 70 |
+
The path to a tokenizer file to use instead of the vocab file.
|
| 71 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 72 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
| 73 |
+
|
| 74 |
+
<Tip>
|
| 75 |
+
|
| 76 |
+
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
| 77 |
+
sequence. The token used is the `cls_token`.
|
| 78 |
+
|
| 79 |
+
</Tip>
|
| 80 |
+
|
| 81 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 82 |
+
The end of sequence token.
|
| 83 |
+
|
| 84 |
+
<Tip>
|
| 85 |
+
|
| 86 |
+
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
| 87 |
+
The token used is the `sep_token`.
|
| 88 |
+
|
| 89 |
+
</Tip>
|
| 90 |
+
|
| 91 |
+
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
| 92 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 93 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 94 |
+
token of a sequence built with special tokens.
|
| 95 |
+
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
| 96 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 97 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 98 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 99 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 100 |
+
token instead.
|
| 101 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 102 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 103 |
+
src_lang (`str`, *optional*, defaults to `"eng"`):
|
| 104 |
+
The language to use as source language for translation.
|
| 105 |
+
tgt_lang (`str`, *optional*, defaults to `"fra"`):
|
| 106 |
+
The language to use as target language for translation.
|
| 107 |
+
additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
|
| 108 |
+
A tuple or a list of additional special tokens.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 112 |
+
slow_tokenizer_class = SeamlessM4TTokenizer
|
| 113 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 114 |
+
|
| 115 |
+
prefix_tokens: List[int] = []
|
| 116 |
+
suffix_tokens: List[int] = []
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
vocab_file=None,
|
| 121 |
+
tokenizer_file=None,
|
| 122 |
+
bos_token="<s>",
|
| 123 |
+
eos_token="</s>",
|
| 124 |
+
sep_token="</s>",
|
| 125 |
+
cls_token="<s>",
|
| 126 |
+
unk_token="<unk>",
|
| 127 |
+
pad_token="<pad>",
|
| 128 |
+
src_lang="eng",
|
| 129 |
+
tgt_lang="fra",
|
| 130 |
+
additional_special_tokens=None,
|
| 131 |
+
**kwargs,
|
| 132 |
+
):
|
| 133 |
+
super().__init__(
|
| 134 |
+
vocab_file=vocab_file,
|
| 135 |
+
tokenizer_file=tokenizer_file,
|
| 136 |
+
bos_token=bos_token,
|
| 137 |
+
eos_token=eos_token,
|
| 138 |
+
sep_token=sep_token,
|
| 139 |
+
cls_token=cls_token,
|
| 140 |
+
unk_token=unk_token,
|
| 141 |
+
pad_token=pad_token,
|
| 142 |
+
src_lang=src_lang,
|
| 143 |
+
tgt_lang=tgt_lang,
|
| 144 |
+
additional_special_tokens=additional_special_tokens,
|
| 145 |
+
**kwargs,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
self.vocab_file = vocab_file
|
| 149 |
+
self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang
|
| 150 |
+
self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang
|
| 151 |
+
self.set_src_lang_special_tokens(self._src_lang)
|
| 152 |
+
self.set_tgt_lang_special_tokens(self._tgt_lang)
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def can_save_slow_tokenizer(self) -> bool:
|
| 156 |
+
return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang
|
| 160 |
+
def src_lang(self) -> str:
|
| 161 |
+
return self._src_lang
|
| 162 |
+
|
| 163 |
+
@src_lang.setter
|
| 164 |
+
def src_lang(self, new_src_lang: str) -> None:
|
| 165 |
+
if "__" not in new_src_lang:
|
| 166 |
+
self._src_lang = f"__{new_src_lang}__"
|
| 167 |
+
else:
|
| 168 |
+
self._src_lang = new_src_lang
|
| 169 |
+
self.set_src_lang_special_tokens(self._src_lang)
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def tgt_lang(self) -> str:
|
| 173 |
+
return self._tgt_lang
|
| 174 |
+
|
| 175 |
+
@tgt_lang.setter
|
| 176 |
+
def tgt_lang(self, new_tgt_lang: str) -> None:
|
| 177 |
+
if "__" not in new_tgt_lang:
|
| 178 |
+
self._tgt_lang = f"__{new_tgt_lang}__"
|
| 179 |
+
else:
|
| 180 |
+
self._tgt_lang = new_tgt_lang
|
| 181 |
+
self.set_tgt_lang_special_tokens(self._tgt_lang)
|
| 182 |
+
|
| 183 |
+
def build_inputs_with_special_tokens(
|
| 184 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 185 |
+
) -> List[int]:
|
| 186 |
+
"""
|
| 187 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 188 |
+
adding special tokens. The special tokens depend on calling set_lang.
|
| 189 |
+
|
| 190 |
+
An SeamlessM4T sequence has the following format, where `X` represents the sequence:
|
| 191 |
+
|
| 192 |
+
- `input_ids` (for encoder) `[src_lang_code] X [eos]`
|
| 193 |
+
- `decoder_input_ids`: (for decoder) `[eos, tgt_lang_code] X [eos]`
|
| 194 |
+
|
| 195 |
+
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
| 196 |
+
separator.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
token_ids_0 (`List[int]`):
|
| 200 |
+
List of IDs to which the special tokens will be added.
|
| 201 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 202 |
+
Optional second list of IDs for sequence pairs.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
`List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 206 |
+
"""
|
| 207 |
+
if token_ids_1 is None:
|
| 208 |
+
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
| 209 |
+
# We don't expect to process pairs, but leave the pair logic for API consistency
|
| 210 |
+
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
| 211 |
+
|
| 212 |
+
# Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.create_token_type_ids_from_sequences
|
| 213 |
+
def create_token_type_ids_from_sequences(
|
| 214 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 215 |
+
) -> List[int]:
|
| 216 |
+
"""
|
| 217 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not
|
| 218 |
+
make use of token type ids, therefore a list of zeros is returned.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
token_ids_0 (`List[int]`):
|
| 222 |
+
List of IDs.
|
| 223 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 224 |
+
Optional second list of IDs for sequence pairs.
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
`List[int]`: List of zeros.
|
| 228 |
+
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
sep = [self.sep_token_id]
|
| 232 |
+
cls = [self.cls_token_id]
|
| 233 |
+
|
| 234 |
+
if token_ids_1 is None:
|
| 235 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 236 |
+
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
| 237 |
+
|
| 238 |
+
def _build_translation_inputs(
|
| 239 |
+
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
| 240 |
+
):
|
| 241 |
+
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
| 242 |
+
if src_lang is None or tgt_lang is None:
|
| 243 |
+
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
| 244 |
+
self.src_lang = src_lang
|
| 245 |
+
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
| 246 |
+
if "__" not in tgt_lang:
|
| 247 |
+
tgt_lang = f"__{tgt_lang}__"
|
| 248 |
+
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
| 249 |
+
inputs["forced_bos_token_id"] = tgt_lang_id
|
| 250 |
+
return inputs
|
| 251 |
+
|
| 252 |
+
# Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.prepare_seq2seq_batch with "fra_Latn"->"fra", "eng_Latn"->"eng"
|
| 253 |
+
def prepare_seq2seq_batch(
|
| 254 |
+
self,
|
| 255 |
+
src_texts: List[str],
|
| 256 |
+
src_lang: str = "eng",
|
| 257 |
+
tgt_texts: Optional[List[str]] = None,
|
| 258 |
+
tgt_lang: str = "fra",
|
| 259 |
+
**kwargs,
|
| 260 |
+
) -> BatchEncoding:
|
| 261 |
+
self.src_lang = src_lang
|
| 262 |
+
self.tgt_lang = tgt_lang
|
| 263 |
+
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
| 264 |
+
|
| 265 |
+
# Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_input_mode
|
| 266 |
+
def _switch_to_input_mode(self):
|
| 267 |
+
return self.set_src_lang_special_tokens(self.src_lang)
|
| 268 |
+
|
| 269 |
+
# Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_target_mode
|
| 270 |
+
def _switch_to_target_mode(self):
|
| 271 |
+
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
| 272 |
+
|
| 273 |
+
def set_src_lang_special_tokens(self, src_lang) -> None:
|
| 274 |
+
"""Reset the special tokens to the source lang setting.
|
| 275 |
+
Prefix=[src_lang_code], suffix = [eos]
|
| 276 |
+
"""
|
| 277 |
+
self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
|
| 278 |
+
|
| 279 |
+
if self.cur_lang_code == self.unk_token_id:
|
| 280 |
+
logger.warning_once(
|
| 281 |
+
f"`tgt_lang={src_lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
self.init_kwargs["src_lang"] = src_lang
|
| 285 |
+
|
| 286 |
+
self.prefix_tokens = [self.cur_lang_code]
|
| 287 |
+
self.suffix_tokens = [self.eos_token_id]
|
| 288 |
+
|
| 289 |
+
prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
|
| 290 |
+
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
|
| 291 |
+
|
| 292 |
+
self._tokenizer.post_processor = processors.TemplateProcessing(
|
| 293 |
+
single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
|
| 294 |
+
pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
|
| 295 |
+
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
| 299 |
+
"""Reset the special tokens to the target lang setting.
|
| 300 |
+
Prefix=[eos, tgt_lang_code] and suffix=[eos].
|
| 301 |
+
"""
|
| 302 |
+
self.cur_lang_code = self.convert_tokens_to_ids(lang)
|
| 303 |
+
|
| 304 |
+
if self.cur_lang_code == self.unk_token_id:
|
| 305 |
+
logger.warning_once(
|
| 306 |
+
f"`tgt_lang={lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
self.init_kwargs["tgt_lang"] = lang
|
| 310 |
+
|
| 311 |
+
self.prefix_tokens = [self.eos_token_id, self.cur_lang_code]
|
| 312 |
+
self.suffix_tokens = [self.eos_token_id]
|
| 313 |
+
|
| 314 |
+
prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
|
| 315 |
+
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
|
| 316 |
+
|
| 317 |
+
self._tokenizer.post_processor = processors.TemplateProcessing(
|
| 318 |
+
single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
|
| 319 |
+
pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
|
| 320 |
+
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.save_vocabulary
|
| 324 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 325 |
+
if not self.can_save_slow_tokenizer:
|
| 326 |
+
raise ValueError(
|
| 327 |
+
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
| 328 |
+
"tokenizer."
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if not os.path.isdir(save_directory):
|
| 332 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
|
| 333 |
+
return
|
| 334 |
+
out_vocab_file = os.path.join(
|
| 335 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
| 339 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 340 |
+
|
| 341 |
+
return (out_vocab_file,)
|
| 342 |
+
|
| 343 |
+
@classmethod
|
| 344 |
+
def _from_pretrained(
|
| 345 |
+
cls,
|
| 346 |
+
resolved_vocab_files,
|
| 347 |
+
pretrained_model_name_or_path,
|
| 348 |
+
init_configuration,
|
| 349 |
+
*init_inputs,
|
| 350 |
+
token=None,
|
| 351 |
+
cache_dir=None,
|
| 352 |
+
local_files_only=False,
|
| 353 |
+
_commit_hash=None,
|
| 354 |
+
_is_local=False,
|
| 355 |
+
**kwargs,
|
| 356 |
+
):
|
| 357 |
+
tokenizer = super()._from_pretrained(
|
| 358 |
+
resolved_vocab_files,
|
| 359 |
+
pretrained_model_name_or_path,
|
| 360 |
+
init_configuration,
|
| 361 |
+
*init_inputs,
|
| 362 |
+
token=token,
|
| 363 |
+
cache_dir=cache_dir,
|
| 364 |
+
local_files_only=local_files_only,
|
| 365 |
+
_commit_hash=_commit_hash,
|
| 366 |
+
_is_local=_is_local,
|
| 367 |
+
**kwargs,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# ensure also set after from pretrained
|
| 371 |
+
tokenizer.set_src_lang_special_tokens(tokenizer._src_lang)
|
| 372 |
+
tokenizer.set_tgt_lang_special_tokens(tokenizer._tgt_lang)
|
| 373 |
+
|
| 374 |
+
return tokenizer
|
| 375 |
+
|
| 376 |
+
def __call__(
|
| 377 |
+
self,
|
| 378 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 379 |
+
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
| 380 |
+
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 381 |
+
text_pair_target: Optional[
|
| 382 |
+
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
|
| 383 |
+
] = None,
|
| 384 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
| 385 |
+
pad_to_multiple_of: Optional[int] = 2,
|
| 386 |
+
src_lang: Optional[str] = None,
|
| 387 |
+
tgt_lang: Optional[str] = None,
|
| 388 |
+
**kwargs,
|
| 389 |
+
):
|
| 390 |
+
"""
|
| 391 |
+
Args:
|
| 392 |
+
text (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 393 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 394 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 395 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 396 |
+
text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 397 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 398 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 399 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 400 |
+
text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 401 |
+
The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
|
| 402 |
+
list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
|
| 403 |
+
you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 404 |
+
text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 405 |
+
The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
|
| 406 |
+
list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
|
| 407 |
+
you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 408 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 409 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
| 410 |
+
index) among:
|
| 411 |
+
|
| 412 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
| 413 |
+
sequence if provided).
|
| 414 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 415 |
+
acceptable input length for the model if that argument is not provided.
|
| 416 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
| 417 |
+
lengths).
|
| 418 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 419 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 420 |
+
|
| 421 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
| 422 |
+
`>= 7.5` (Volta).
|
| 423 |
+
src_lang (`str`, *optional*):
|
| 424 |
+
A string representing the source language. If not specified, the last `src_lang` specified (either
|
| 425 |
+
during initialization or when calling this tokenizer) will be used.
|
| 426 |
+
tgt_lang (`str`, *optional*):
|
| 427 |
+
A string representing the target language. If not specified, the last `tgt_lang` specified (either
|
| 428 |
+
during initialization or when calling this tokenizer) will be used.
|
| 429 |
+
kwargs (*optional*):
|
| 430 |
+
Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizerFast.__call__`].
|
| 431 |
+
"""
|
| 432 |
+
if src_lang is not None:
|
| 433 |
+
self.src_lang = src_lang
|
| 434 |
+
if tgt_lang is not None:
|
| 435 |
+
self.tgt_lang = tgt_lang
|
| 436 |
+
|
| 437 |
+
output = super().__call__(
|
| 438 |
+
text=text,
|
| 439 |
+
text_pair=text_pair,
|
| 440 |
+
text_target=text_target,
|
| 441 |
+
text_pair_target=text_pair_target,
|
| 442 |
+
padding=padding,
|
| 443 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 444 |
+
**kwargs,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
return output
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
__all__ = ["SeamlessM4TTokenizerFast"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Converting Meta SeamlessM4Tv2 checkpoints from seamless_communication to HF."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from accelerate.utils.modeling import find_tied_parameters
|
| 23 |
+
from seamless_communication.inference import Translator
|
| 24 |
+
|
| 25 |
+
from transformers import (
|
| 26 |
+
SeamlessM4TFeatureExtractor,
|
| 27 |
+
SeamlessM4TProcessor,
|
| 28 |
+
SeamlessM4TTokenizer,
|
| 29 |
+
SeamlessM4Tv2Config,
|
| 30 |
+
SeamlessM4Tv2Model,
|
| 31 |
+
)
|
| 32 |
+
from transformers.utils import logging
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# fmt: off
|
| 36 |
+
UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ]
|
| 37 |
+
# fmt: on
|
| 38 |
+
|
| 39 |
+
# fmt: off
|
| 40 |
+
VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",]
|
| 41 |
+
# fmt: on
|
| 42 |
+
|
| 43 |
+
# fmt: off
|
| 44 |
+
LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",]
|
| 45 |
+
# fmt: on
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def assert_param_count(model_1, model_2):
|
| 49 |
+
count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
|
| 50 |
+
count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
|
| 51 |
+
assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def param_count(model):
|
| 55 |
+
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _grab_best_device(use_gpu=True):
|
| 59 |
+
if torch.cuda.device_count() > 0 and use_gpu:
|
| 60 |
+
device = "cuda"
|
| 61 |
+
else:
|
| 62 |
+
device = "cpu"
|
| 63 |
+
return torch.device(device)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
logging.set_verbosity_info()
|
| 67 |
+
logger = logging.get_logger(__name__)
|
| 68 |
+
|
| 69 |
+
vocoder_convert_list = [
|
| 70 |
+
("ups", "hifi_gan.upsampler"),
|
| 71 |
+
("conv_pre", "hifi_gan.conv_pre"),
|
| 72 |
+
("resblocks", "hifi_gan.resblocks"),
|
| 73 |
+
("conv_post", "hifi_gan.conv_post"),
|
| 74 |
+
("lang", "language_embedding"),
|
| 75 |
+
("spkr", "speaker_embedding"),
|
| 76 |
+
("dict.", "unit_embedding."),
|
| 77 |
+
("dur_predictor.conv1.0", "dur_predictor.conv1"),
|
| 78 |
+
("dur_predictor.conv2.0", "dur_predictor.conv2"),
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
# order is important
|
| 82 |
+
wav2vec_convert_list = [
|
| 83 |
+
("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"),
|
| 84 |
+
("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
|
| 85 |
+
("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
|
| 86 |
+
("speech_encoder.inner.layers", "encoder.layers"),
|
| 87 |
+
("speech_encoder.inner_layer_norm", "encoder.layer_norm"),
|
| 88 |
+
("speech_encoder.adaptor_layers", "adapter.layers"),
|
| 89 |
+
("inner_proj", "intermediate_dense"),
|
| 90 |
+
("self_attn.output_proj", "self_attn.linear_out"),
|
| 91 |
+
("output_proj", "output_dense"),
|
| 92 |
+
("self_attn.k_proj", "self_attn.linear_k"),
|
| 93 |
+
("self_attn.v_proj", "self_attn.linear_v"),
|
| 94 |
+
("self_attn.q_proj", "self_attn.linear_q"),
|
| 95 |
+
("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
|
| 96 |
+
("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
|
| 97 |
+
("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"),
|
| 98 |
+
("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
|
| 99 |
+
("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
|
| 100 |
+
("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
|
| 101 |
+
("conv.depthwise_conv", "conv_module.depthwise_conv"),
|
| 102 |
+
("conv.batch_norm", "conv_module.batch_norm"),
|
| 103 |
+
("conv.layer_norm", "conv_module.depthwise_layer_norm"),
|
| 104 |
+
("conv_layer_norm", "conv_module.layer_norm"),
|
| 105 |
+
("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"),
|
| 106 |
+
("speech_encoder.proj2", "intermediate_ffn.output_dense"),
|
| 107 |
+
("speech_encoder.layer_norm", "inner_layer_norm"),
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
t2u_convert_list = [
|
| 111 |
+
("t2u_model.final_proj", "lm_head"),
|
| 112 |
+
("t2u_model.", "model."),
|
| 113 |
+
("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
|
| 114 |
+
("encoder_decoder_attn", "cross_attention"),
|
| 115 |
+
("linear_k", "k_proj"),
|
| 116 |
+
("linear_v", "v_proj"),
|
| 117 |
+
("linear_q", "q_proj"),
|
| 118 |
+
("ffn.inner_proj", "ffn.fc1"),
|
| 119 |
+
("ffn.output_proj", "ffn.fc2"),
|
| 120 |
+
("output_proj", "out_proj"),
|
| 121 |
+
("decoder_frontend.embed_char", "decoder.embed_char"),
|
| 122 |
+
("decoder_frontend.pos_emb_alpha_char", "decoder.pos_emb_alpha_char"),
|
| 123 |
+
("decoder_frontend.embed", "decoder.embed_tokens"),
|
| 124 |
+
("decoder_frontend.pos_emb_alpha", "decoder.pos_emb_alpha"),
|
| 125 |
+
("conv1d.conv", "conv"),
|
| 126 |
+
("conv1d_layer_norm", "conv_layer_norm"),
|
| 127 |
+
("decoder_frontend.variance_adaptor", "decoder"),
|
| 128 |
+
("duration_predictor.conv1.0", "duration_predictor.conv1"),
|
| 129 |
+
("duration_predictor.conv2.0", "duration_predictor.conv2"),
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
text_convert_list = [
|
| 133 |
+
("text_encoder.", ""),
|
| 134 |
+
("text_decoder.", ""),
|
| 135 |
+
("text_encoder_frontend.embed", "embed_tokens"),
|
| 136 |
+
("text_decoder_frontend.embed", "embed_tokens"),
|
| 137 |
+
("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"),
|
| 138 |
+
("encoder_decoder_attn", "cross_attention"),
|
| 139 |
+
("linear_k", "k_proj"),
|
| 140 |
+
("linear_v", "v_proj"),
|
| 141 |
+
("linear_q", "q_proj"),
|
| 142 |
+
("ffn.inner_proj", "ffn.fc1"),
|
| 143 |
+
("ffn.output_proj", "ffn.fc2"),
|
| 144 |
+
("output_proj", "out_proj"),
|
| 145 |
+
("final_proj", "lm_head"),
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
| 149 |
+
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
| 150 |
+
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _load_hf_config():
|
| 154 |
+
return SeamlessM4Tv2Config()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _convert_model(
|
| 158 |
+
original_model,
|
| 159 |
+
hf_model,
|
| 160 |
+
convert_list,
|
| 161 |
+
device,
|
| 162 |
+
unwanted_prefix="model.",
|
| 163 |
+
filter_state_dict="speech",
|
| 164 |
+
exclude_state_dict=None,
|
| 165 |
+
):
|
| 166 |
+
state_dict = original_model.state_dict()
|
| 167 |
+
|
| 168 |
+
# filter func
|
| 169 |
+
if isinstance(filter_state_dict, str):
|
| 170 |
+
|
| 171 |
+
def filter_func(x):
|
| 172 |
+
return filter_state_dict in x[0]
|
| 173 |
+
|
| 174 |
+
else:
|
| 175 |
+
|
| 176 |
+
def filter_func(item):
|
| 177 |
+
if exclude_state_dict is not None and exclude_state_dict in item[0]:
|
| 178 |
+
return False
|
| 179 |
+
for filter_el in filter_state_dict:
|
| 180 |
+
if filter_el in item[0]:
|
| 181 |
+
return True
|
| 182 |
+
|
| 183 |
+
return False
|
| 184 |
+
|
| 185 |
+
state_dict = dict(filter(filter_func, state_dict.items()))
|
| 186 |
+
|
| 187 |
+
for k, v in list(state_dict.items()):
|
| 188 |
+
new_k = k[len(unwanted_prefix) :]
|
| 189 |
+
for old_layer_name, new_layer_name in convert_list:
|
| 190 |
+
if old_layer_name in new_k:
|
| 191 |
+
new_k = new_k.replace(old_layer_name, new_layer_name)
|
| 192 |
+
|
| 193 |
+
# must do it by hand
|
| 194 |
+
if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric():
|
| 195 |
+
new_k = new_k.replace("layer_norm", "final_layer_norm")
|
| 196 |
+
|
| 197 |
+
state_dict[new_k] = state_dict.pop(k)
|
| 198 |
+
|
| 199 |
+
extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
|
| 200 |
+
extra_keys = set(extra_keys)
|
| 201 |
+
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
|
| 202 |
+
missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k})
|
| 203 |
+
if len(extra_keys) != 0:
|
| 204 |
+
raise ValueError(f"extra keys found: {extra_keys}")
|
| 205 |
+
if len(missing_keys) != 0:
|
| 206 |
+
raise ValueError(f"missing keys: {missing_keys}")
|
| 207 |
+
hf_model.load_state_dict(state_dict, strict=False)
|
| 208 |
+
n_params = param_count(hf_model)
|
| 209 |
+
|
| 210 |
+
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params")
|
| 211 |
+
|
| 212 |
+
hf_model.eval()
|
| 213 |
+
hf_model.to(device)
|
| 214 |
+
del state_dict
|
| 215 |
+
|
| 216 |
+
return hf_model
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def load_model(save_dir, model_type, repo_id):
|
| 220 |
+
"""
|
| 221 |
+
Meta SeamlessM4Tv2 is made of 8 main components:
|
| 222 |
+
- speech_encoder (#1) and speech_encoder_frontend (#2)
|
| 223 |
+
- t2u_model (#3)
|
| 224 |
+
- text_encoder (#4) and text_encoder_frontend (#5)
|
| 225 |
+
- text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend]
|
| 226 |
+
- final_proj (#7)
|
| 227 |
+
- vocoder (#8)
|
| 228 |
+
"""
|
| 229 |
+
device = _grab_best_device()
|
| 230 |
+
name = "seamlessM4T_v2_large"
|
| 231 |
+
|
| 232 |
+
original_model = Translator(name, "vocoder_v2", device, dtype=torch.float32)
|
| 233 |
+
|
| 234 |
+
######### TOKENIZER
|
| 235 |
+
|
| 236 |
+
langs = LARGE_SUPPORTED_LANGUAGES
|
| 237 |
+
langs = [f"__{lang}__" for lang in langs]
|
| 238 |
+
vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model")
|
| 239 |
+
|
| 240 |
+
save_dir = os.path.join(save_dir, name)
|
| 241 |
+
Path(save_dir).mkdir(exist_ok=True)
|
| 242 |
+
|
| 243 |
+
tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs)
|
| 244 |
+
|
| 245 |
+
sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__")
|
| 246 |
+
|
| 247 |
+
tokenizer.save_pretrained(save_dir)
|
| 248 |
+
tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir)
|
| 249 |
+
|
| 250 |
+
if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"):
|
| 251 |
+
raise ValueError(
|
| 252 |
+
f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
####### get language to ids dict
|
| 256 |
+
text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs}
|
| 257 |
+
# offset: vocoder unit vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages)
|
| 258 |
+
t2u_lang_code_to_id = {
|
| 259 |
+
code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES)
|
| 260 |
+
for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES)
|
| 261 |
+
}
|
| 262 |
+
vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)}
|
| 263 |
+
|
| 264 |
+
######### FE
|
| 265 |
+
|
| 266 |
+
fe = SeamlessM4TFeatureExtractor(language_code=langs)
|
| 267 |
+
|
| 268 |
+
fe.save_pretrained(save_dir)
|
| 269 |
+
fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir)
|
| 270 |
+
|
| 271 |
+
processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer)
|
| 272 |
+
processor.save_pretrained(save_dir)
|
| 273 |
+
processor.push_to_hub(repo_id=repo_id, create_pr=True)
|
| 274 |
+
|
| 275 |
+
processor = SeamlessM4TProcessor.from_pretrained(save_dir)
|
| 276 |
+
|
| 277 |
+
######## Model
|
| 278 |
+
|
| 279 |
+
# init config
|
| 280 |
+
hf_config = _load_hf_config()
|
| 281 |
+
|
| 282 |
+
######## get id_to_text and char_to_id from original model tokenizers
|
| 283 |
+
id_to_text = {i: original_model.text_tokenizer.model.index_to_token(i) for i in range(hf_config.vocab_size)}
|
| 284 |
+
char_to_id = {
|
| 285 |
+
original_model.model.t2u_model.decoder_frontend.char_tokenizer.model.index_to_token(i): i for i in range(10904)
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
# init model
|
| 289 |
+
hf_model = SeamlessM4Tv2Model(hf_config)
|
| 290 |
+
|
| 291 |
+
hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id)
|
| 292 |
+
hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id)
|
| 293 |
+
hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id)
|
| 294 |
+
hf_model.generation_config.__setattr__("id_to_text", id_to_text)
|
| 295 |
+
hf_model.generation_config.__setattr__("char_to_id", char_to_id)
|
| 296 |
+
|
| 297 |
+
# -1. take care of vocoder
|
| 298 |
+
# similarly to speech T5 must apply and remove weight norm
|
| 299 |
+
hf_model.vocoder.apply_weight_norm()
|
| 300 |
+
hf_model.vocoder = _convert_model(
|
| 301 |
+
original_model,
|
| 302 |
+
hf_model.vocoder,
|
| 303 |
+
vocoder_convert_list,
|
| 304 |
+
device,
|
| 305 |
+
unwanted_prefix="vocoder.code_generator.",
|
| 306 |
+
filter_state_dict="vocoder",
|
| 307 |
+
)
|
| 308 |
+
hf_model.vocoder.remove_weight_norm()
|
| 309 |
+
|
| 310 |
+
# 1. take care of speech encoder
|
| 311 |
+
wav2vec = hf_model.speech_encoder
|
| 312 |
+
hf_model.speech_encoder = _convert_model(
|
| 313 |
+
original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# 2. take care of t2u
|
| 317 |
+
|
| 318 |
+
hf_model.t2u_model = _convert_model(
|
| 319 |
+
original_model,
|
| 320 |
+
hf_model.t2u_model,
|
| 321 |
+
t2u_convert_list,
|
| 322 |
+
device,
|
| 323 |
+
unwanted_prefix="model.",
|
| 324 |
+
filter_state_dict="t2u_model",
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# 3. take care of text encoder
|
| 328 |
+
hf_model.text_encoder = _convert_model(
|
| 329 |
+
original_model,
|
| 330 |
+
hf_model.text_encoder,
|
| 331 |
+
text_convert_list,
|
| 332 |
+
device,
|
| 333 |
+
unwanted_prefix="model.",
|
| 334 |
+
filter_state_dict=["model.text_encoder"],
|
| 335 |
+
exclude_state_dict="t2u_model",
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# 4. take care of text decoder
|
| 339 |
+
hf_model.text_decoder = _convert_model(
|
| 340 |
+
original_model,
|
| 341 |
+
hf_model.text_decoder,
|
| 342 |
+
text_convert_list,
|
| 343 |
+
device,
|
| 344 |
+
unwanted_prefix="model.",
|
| 345 |
+
filter_state_dict=["model.text_decoder"],
|
| 346 |
+
exclude_state_dict="t2u_model",
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# 5. take care of final proj
|
| 350 |
+
hf_model.lm_head = _convert_model(
|
| 351 |
+
original_model,
|
| 352 |
+
hf_model.lm_head,
|
| 353 |
+
[("final_proj.", "")],
|
| 354 |
+
device,
|
| 355 |
+
unwanted_prefix="model.",
|
| 356 |
+
filter_state_dict=["model.final_proj"],
|
| 357 |
+
exclude_state_dict="t2u_model",
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# sanity check
|
| 361 |
+
print(find_tied_parameters(hf_model))
|
| 362 |
+
|
| 363 |
+
count_1 = param_count(hf_model)
|
| 364 |
+
count_2 = param_count(original_model)
|
| 365 |
+
|
| 366 |
+
print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}")
|
| 367 |
+
print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}")
|
| 368 |
+
|
| 369 |
+
del original_model
|
| 370 |
+
|
| 371 |
+
hf_model.generation_config._from_model_config = False
|
| 372 |
+
hf_model.save_pretrained(save_dir)
|
| 373 |
+
hf_model.push_to_hub(repo_id=repo_id, create_pr=True)
|
| 374 |
+
hf_model = SeamlessM4Tv2Model.from_pretrained(save_dir)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
if __name__ == "__main__":
|
| 378 |
+
parser = argparse.ArgumentParser()
|
| 379 |
+
# Required parameters
|
| 380 |
+
|
| 381 |
+
parser.add_argument(
|
| 382 |
+
"--model_type",
|
| 383 |
+
default="large",
|
| 384 |
+
type=str,
|
| 385 |
+
help="Model type.",
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
parser.add_argument(
|
| 389 |
+
"--save_dir",
|
| 390 |
+
default="/home/ubuntu/weights_v2",
|
| 391 |
+
type=str,
|
| 392 |
+
help="Path to the output PyTorch model.",
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
parser.add_argument(
|
| 396 |
+
"--repo_id",
|
| 397 |
+
default="facebook/seamless-m4t-v2-large",
|
| 398 |
+
type=str,
|
| 399 |
+
help="Repo ID.",
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
args = parser.parse_args()
|
| 403 |
+
|
| 404 |
+
load_model(args.save_dir, args.model_type, args.repo_id)
|
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/transformers/build/lib/transformers/models/segformer/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_segformer import *
|
| 22 |
+
from .feature_extraction_segformer import *
|
| 23 |
+
from .image_processing_segformer import *
|
| 24 |
+
from .modeling_segformer import *
|
| 25 |
+
from .modeling_tf_segformer import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/segformer/configuration_segformer.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 NVIDIA and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SegFormer model configuration"""
|
| 16 |
+
|
| 17 |
+
import warnings
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from typing import Mapping
|
| 20 |
+
|
| 21 |
+
from packaging import version
|
| 22 |
+
|
| 23 |
+
from ...configuration_utils import PretrainedConfig
|
| 24 |
+
from ...onnx import OnnxConfig
|
| 25 |
+
from ...utils import logging
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SegformerConfig(PretrainedConfig):
|
| 32 |
+
r"""
|
| 33 |
+
This is the configuration class to store the configuration of a [`SegformerModel`]. It is used to instantiate an
|
| 34 |
+
SegFormer model according to the specified arguments, defining the model architecture. Instantiating a
|
| 35 |
+
configuration with the defaults will yield a similar configuration to that of the SegFormer
|
| 36 |
+
[nvidia/segformer-b0-finetuned-ade-512-512](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512)
|
| 37 |
+
architecture.
|
| 38 |
+
|
| 39 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 40 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 44 |
+
The number of input channels.
|
| 45 |
+
num_encoder_blocks (`int`, *optional*, defaults to 4):
|
| 46 |
+
The number of encoder blocks (i.e. stages in the Mix Transformer encoder).
|
| 47 |
+
depths (`List[int]`, *optional*, defaults to `[2, 2, 2, 2]`):
|
| 48 |
+
The number of layers in each encoder block.
|
| 49 |
+
sr_ratios (`List[int]`, *optional*, defaults to `[8, 4, 2, 1]`):
|
| 50 |
+
Sequence reduction ratios in each encoder block.
|
| 51 |
+
hidden_sizes (`List[int]`, *optional*, defaults to `[32, 64, 160, 256]`):
|
| 52 |
+
Dimension of each of the encoder blocks.
|
| 53 |
+
patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`):
|
| 54 |
+
Patch size before each encoder block.
|
| 55 |
+
strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
|
| 56 |
+
Stride before each encoder block.
|
| 57 |
+
num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`):
|
| 58 |
+
Number of attention heads for each attention layer in each block of the Transformer encoder.
|
| 59 |
+
mlp_ratios (`List[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
|
| 60 |
+
Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
|
| 61 |
+
encoder blocks.
|
| 62 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 63 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 64 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 65 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 66 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 67 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 68 |
+
The dropout ratio for the attention probabilities.
|
| 69 |
+
classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 70 |
+
The dropout probability before the classification head.
|
| 71 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 72 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 73 |
+
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
| 74 |
+
The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
|
| 75 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 76 |
+
The epsilon used by the layer normalization layers.
|
| 77 |
+
decoder_hidden_size (`int`, *optional*, defaults to 256):
|
| 78 |
+
The dimension of the all-MLP decode head.
|
| 79 |
+
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
| 80 |
+
The index that is ignored by the loss function of the semantic segmentation model.
|
| 81 |
+
|
| 82 |
+
Example:
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
>>> from transformers import SegformerModel, SegformerConfig
|
| 86 |
+
|
| 87 |
+
>>> # Initializing a SegFormer nvidia/segformer-b0-finetuned-ade-512-512 style configuration
|
| 88 |
+
>>> configuration = SegformerConfig()
|
| 89 |
+
|
| 90 |
+
>>> # Initializing a model from the nvidia/segformer-b0-finetuned-ade-512-512 style configuration
|
| 91 |
+
>>> model = SegformerModel(configuration)
|
| 92 |
+
|
| 93 |
+
>>> # Accessing the model configuration
|
| 94 |
+
>>> configuration = model.config
|
| 95 |
+
```"""
|
| 96 |
+
|
| 97 |
+
model_type = "segformer"
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
num_channels=3,
|
| 102 |
+
num_encoder_blocks=4,
|
| 103 |
+
depths=[2, 2, 2, 2],
|
| 104 |
+
sr_ratios=[8, 4, 2, 1],
|
| 105 |
+
hidden_sizes=[32, 64, 160, 256],
|
| 106 |
+
patch_sizes=[7, 3, 3, 3],
|
| 107 |
+
strides=[4, 2, 2, 2],
|
| 108 |
+
num_attention_heads=[1, 2, 5, 8],
|
| 109 |
+
mlp_ratios=[4, 4, 4, 4],
|
| 110 |
+
hidden_act="gelu",
|
| 111 |
+
hidden_dropout_prob=0.0,
|
| 112 |
+
attention_probs_dropout_prob=0.0,
|
| 113 |
+
classifier_dropout_prob=0.1,
|
| 114 |
+
initializer_range=0.02,
|
| 115 |
+
drop_path_rate=0.1,
|
| 116 |
+
layer_norm_eps=1e-6,
|
| 117 |
+
decoder_hidden_size=256,
|
| 118 |
+
semantic_loss_ignore_index=255,
|
| 119 |
+
**kwargs,
|
| 120 |
+
):
|
| 121 |
+
super().__init__(**kwargs)
|
| 122 |
+
|
| 123 |
+
if "reshape_last_stage" in kwargs and kwargs["reshape_last_stage"] is False:
|
| 124 |
+
warnings.warn(
|
| 125 |
+
"Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be"
|
| 126 |
+
" removed, as the behaviour will default to that of reshape_last_stage = True.",
|
| 127 |
+
FutureWarning,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
self.num_channels = num_channels
|
| 131 |
+
self.num_encoder_blocks = num_encoder_blocks
|
| 132 |
+
self.depths = depths
|
| 133 |
+
self.sr_ratios = sr_ratios
|
| 134 |
+
self.hidden_sizes = hidden_sizes
|
| 135 |
+
self.patch_sizes = patch_sizes
|
| 136 |
+
self.strides = strides
|
| 137 |
+
self.mlp_ratios = mlp_ratios
|
| 138 |
+
self.num_attention_heads = num_attention_heads
|
| 139 |
+
self.hidden_act = hidden_act
|
| 140 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 141 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 142 |
+
self.classifier_dropout_prob = classifier_dropout_prob
|
| 143 |
+
self.initializer_range = initializer_range
|
| 144 |
+
self.drop_path_rate = drop_path_rate
|
| 145 |
+
self.layer_norm_eps = layer_norm_eps
|
| 146 |
+
self.decoder_hidden_size = decoder_hidden_size
|
| 147 |
+
self.reshape_last_stage = kwargs.get("reshape_last_stage", True)
|
| 148 |
+
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class SegformerOnnxConfig(OnnxConfig):
|
| 152 |
+
torch_onnx_minimum_version = version.parse("1.11")
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 156 |
+
return OrderedDict(
|
| 157 |
+
[
|
| 158 |
+
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
| 159 |
+
]
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
@property
|
| 163 |
+
def atol_for_validation(self) -> float:
|
| 164 |
+
return 1e-4
|
| 165 |
+
|
| 166 |
+
@property
|
| 167 |
+
def default_onnx_opset(self) -> int:
|
| 168 |
+
return 12
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
__all__ = ["SegformerConfig", "SegformerOnnxConfig"]
|
docs/transformers/build/lib/transformers/models/segformer/convert_segformer_original_to_pytorch.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert SegFormer checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
from collections import OrderedDict
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import requests
|
| 23 |
+
import torch
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
+
from PIL import Image
|
| 26 |
+
|
| 27 |
+
from transformers import (
|
| 28 |
+
SegformerConfig,
|
| 29 |
+
SegformerForImageClassification,
|
| 30 |
+
SegformerForSemanticSegmentation,
|
| 31 |
+
SegformerImageProcessor,
|
| 32 |
+
)
|
| 33 |
+
from transformers.utils import logging
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
logging.set_verbosity_info()
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def rename_keys(state_dict, encoder_only=False):
|
| 41 |
+
new_state_dict = OrderedDict()
|
| 42 |
+
for key, value in state_dict.items():
|
| 43 |
+
if encoder_only and not key.startswith("head"):
|
| 44 |
+
key = "segformer.encoder." + key
|
| 45 |
+
if key.startswith("backbone"):
|
| 46 |
+
key = key.replace("backbone", "segformer.encoder")
|
| 47 |
+
if "patch_embed" in key:
|
| 48 |
+
# replace for example patch_embed1 by patch_embeddings.0
|
| 49 |
+
idx = key[key.find("patch_embed") + len("patch_embed")]
|
| 50 |
+
key = key.replace(f"patch_embed{idx}", f"patch_embeddings.{int(idx) - 1}")
|
| 51 |
+
if "norm" in key:
|
| 52 |
+
key = key.replace("norm", "layer_norm")
|
| 53 |
+
if "segformer.encoder.layer_norm" in key:
|
| 54 |
+
# replace for example layer_norm1 by layer_norm.0
|
| 55 |
+
idx = key[key.find("segformer.encoder.layer_norm") + len("segformer.encoder.layer_norm")]
|
| 56 |
+
key = key.replace(f"layer_norm{idx}", f"layer_norm.{int(idx) - 1}")
|
| 57 |
+
if "layer_norm1" in key:
|
| 58 |
+
key = key.replace("layer_norm1", "layer_norm_1")
|
| 59 |
+
if "layer_norm2" in key:
|
| 60 |
+
key = key.replace("layer_norm2", "layer_norm_2")
|
| 61 |
+
if "block" in key:
|
| 62 |
+
# replace for example block1 by block.0
|
| 63 |
+
idx = key[key.find("block") + len("block")]
|
| 64 |
+
key = key.replace(f"block{idx}", f"block.{int(idx) - 1}")
|
| 65 |
+
if "attn.q" in key:
|
| 66 |
+
key = key.replace("attn.q", "attention.self.query")
|
| 67 |
+
if "attn.proj" in key:
|
| 68 |
+
key = key.replace("attn.proj", "attention.output.dense")
|
| 69 |
+
if "attn" in key:
|
| 70 |
+
key = key.replace("attn", "attention.self")
|
| 71 |
+
if "fc1" in key:
|
| 72 |
+
key = key.replace("fc1", "dense1")
|
| 73 |
+
if "fc2" in key:
|
| 74 |
+
key = key.replace("fc2", "dense2")
|
| 75 |
+
if "linear_pred" in key:
|
| 76 |
+
key = key.replace("linear_pred", "classifier")
|
| 77 |
+
if "linear_fuse" in key:
|
| 78 |
+
key = key.replace("linear_fuse.conv", "linear_fuse")
|
| 79 |
+
key = key.replace("linear_fuse.bn", "batch_norm")
|
| 80 |
+
if "linear_c" in key:
|
| 81 |
+
# replace for example linear_c4 by linear_c.3
|
| 82 |
+
idx = key[key.find("linear_c") + len("linear_c")]
|
| 83 |
+
key = key.replace(f"linear_c{idx}", f"linear_c.{int(idx) - 1}")
|
| 84 |
+
if key.startswith("head"):
|
| 85 |
+
key = key.replace("head", "classifier")
|
| 86 |
+
new_state_dict[key] = value
|
| 87 |
+
|
| 88 |
+
return new_state_dict
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def read_in_k_v(state_dict, config):
|
| 92 |
+
# for each of the encoder blocks:
|
| 93 |
+
for i in range(config.num_encoder_blocks):
|
| 94 |
+
for j in range(config.depths[i]):
|
| 95 |
+
# read in weights + bias of keys and values (which is a single matrix in the original implementation)
|
| 96 |
+
kv_weight = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.weight")
|
| 97 |
+
kv_bias = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.bias")
|
| 98 |
+
# next, add keys and values (in that order) to the state dict
|
| 99 |
+
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.weight"] = kv_weight[
|
| 100 |
+
: config.hidden_sizes[i], :
|
| 101 |
+
]
|
| 102 |
+
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.bias"] = kv_bias[: config.hidden_sizes[i]]
|
| 103 |
+
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.weight"] = kv_weight[
|
| 104 |
+
config.hidden_sizes[i] :, :
|
| 105 |
+
]
|
| 106 |
+
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.bias"] = kv_bias[
|
| 107 |
+
config.hidden_sizes[i] :
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# We will verify our results on a COCO image
|
| 112 |
+
def prepare_img():
|
| 113 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 114 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 115 |
+
|
| 116 |
+
return image
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@torch.no_grad()
|
| 120 |
+
def convert_segformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path):
|
| 121 |
+
"""
|
| 122 |
+
Copy/paste/tweak model's weights to our SegFormer structure.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
# load default SegFormer configuration
|
| 126 |
+
config = SegformerConfig()
|
| 127 |
+
encoder_only = False
|
| 128 |
+
|
| 129 |
+
# set attributes based on model_name
|
| 130 |
+
repo_id = "huggingface/label-files"
|
| 131 |
+
if "segformer" in model_name:
|
| 132 |
+
size = model_name[len("segformer.") : len("segformer.") + 2]
|
| 133 |
+
if "ade" in model_name:
|
| 134 |
+
config.num_labels = 150
|
| 135 |
+
filename = "ade20k-id2label.json"
|
| 136 |
+
expected_shape = (1, 150, 128, 128)
|
| 137 |
+
elif "city" in model_name:
|
| 138 |
+
config.num_labels = 19
|
| 139 |
+
filename = "cityscapes-id2label.json"
|
| 140 |
+
expected_shape = (1, 19, 128, 128)
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError(f"Model {model_name} not supported")
|
| 143 |
+
elif "mit" in model_name:
|
| 144 |
+
encoder_only = True
|
| 145 |
+
size = model_name[4:6]
|
| 146 |
+
config.num_labels = 1000
|
| 147 |
+
filename = "imagenet-1k-id2label.json"
|
| 148 |
+
expected_shape = (1, 1000)
|
| 149 |
+
else:
|
| 150 |
+
raise ValueError(f"Model {model_name} not supported")
|
| 151 |
+
|
| 152 |
+
# set config attributes
|
| 153 |
+
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
| 154 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 155 |
+
config.id2label = id2label
|
| 156 |
+
config.label2id = {v: k for k, v in id2label.items()}
|
| 157 |
+
if size == "b0":
|
| 158 |
+
pass
|
| 159 |
+
elif size == "b1":
|
| 160 |
+
config.hidden_sizes = [64, 128, 320, 512]
|
| 161 |
+
config.decoder_hidden_size = 256
|
| 162 |
+
elif size == "b2":
|
| 163 |
+
config.hidden_sizes = [64, 128, 320, 512]
|
| 164 |
+
config.decoder_hidden_size = 768
|
| 165 |
+
config.depths = [3, 4, 6, 3]
|
| 166 |
+
elif size == "b3":
|
| 167 |
+
config.hidden_sizes = [64, 128, 320, 512]
|
| 168 |
+
config.decoder_hidden_size = 768
|
| 169 |
+
config.depths = [3, 4, 18, 3]
|
| 170 |
+
elif size == "b4":
|
| 171 |
+
config.hidden_sizes = [64, 128, 320, 512]
|
| 172 |
+
config.decoder_hidden_size = 768
|
| 173 |
+
config.depths = [3, 8, 27, 3]
|
| 174 |
+
elif size == "b5":
|
| 175 |
+
config.hidden_sizes = [64, 128, 320, 512]
|
| 176 |
+
config.decoder_hidden_size = 768
|
| 177 |
+
config.depths = [3, 6, 40, 3]
|
| 178 |
+
else:
|
| 179 |
+
raise ValueError(f"Size {size} not supported")
|
| 180 |
+
|
| 181 |
+
# load image processor (only resize + normalize)
|
| 182 |
+
image_processor = SegformerImageProcessor(
|
| 183 |
+
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# prepare image
|
| 187 |
+
image = prepare_img()
|
| 188 |
+
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
|
| 189 |
+
|
| 190 |
+
logger.info(f"Converting model {model_name}...")
|
| 191 |
+
|
| 192 |
+
# load original state dict
|
| 193 |
+
if encoder_only:
|
| 194 |
+
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
|
| 195 |
+
else:
|
| 196 |
+
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)["state_dict"]
|
| 197 |
+
|
| 198 |
+
# rename keys
|
| 199 |
+
state_dict = rename_keys(state_dict, encoder_only=encoder_only)
|
| 200 |
+
if not encoder_only:
|
| 201 |
+
del state_dict["decode_head.conv_seg.weight"]
|
| 202 |
+
del state_dict["decode_head.conv_seg.bias"]
|
| 203 |
+
|
| 204 |
+
# key and value matrices need special treatment
|
| 205 |
+
read_in_k_v(state_dict, config)
|
| 206 |
+
|
| 207 |
+
# create HuggingFace model and load state dict
|
| 208 |
+
if encoder_only:
|
| 209 |
+
config.reshape_last_stage = False
|
| 210 |
+
model = SegformerForImageClassification(config)
|
| 211 |
+
else:
|
| 212 |
+
model = SegformerForSemanticSegmentation(config)
|
| 213 |
+
model.load_state_dict(state_dict)
|
| 214 |
+
model.eval()
|
| 215 |
+
|
| 216 |
+
# forward pass
|
| 217 |
+
outputs = model(pixel_values)
|
| 218 |
+
logits = outputs.logits
|
| 219 |
+
|
| 220 |
+
# set expected_slice based on model name
|
| 221 |
+
# ADE20k checkpoints
|
| 222 |
+
if model_name == "segformer.b0.512x512.ade.160k":
|
| 223 |
+
expected_slice = torch.tensor(
|
| 224 |
+
[
|
| 225 |
+
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
|
| 226 |
+
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
|
| 227 |
+
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
|
| 228 |
+
]
|
| 229 |
+
)
|
| 230 |
+
elif model_name == "segformer.b1.512x512.ade.160k":
|
| 231 |
+
expected_slice = torch.tensor(
|
| 232 |
+
[
|
| 233 |
+
[[-7.5820, -8.7231, -8.3215], [-8.0600, -10.3529, -10.0304], [-7.5208, -9.4103, -9.6239]],
|
| 234 |
+
[[-12.6918, -13.8994, -13.7137], [-13.3196, -15.7523, -15.4789], [-12.9343, -14.8757, -14.9689]],
|
| 235 |
+
[[-11.1911, -11.9421, -11.3243], [-11.3342, -13.6839, -13.3581], [-10.3909, -12.1832, -12.4858]],
|
| 236 |
+
]
|
| 237 |
+
)
|
| 238 |
+
elif model_name == "segformer.b2.512x512.ade.160k":
|
| 239 |
+
expected_slice = torch.tensor(
|
| 240 |
+
[
|
| 241 |
+
[[-11.8173, -14.3850, -16.3128], [-14.5648, -16.5804, -18.6568], [-14.7223, -15.7387, -18.4218]],
|
| 242 |
+
[[-15.7290, -17.9171, -19.4423], [-18.3105, -19.9448, -21.4661], [-17.9296, -18.6497, -20.7910]],
|
| 243 |
+
[[-15.0783, -17.0336, -18.2789], [-16.8771, -18.6870, -20.1612], [-16.2454, -17.1426, -19.5055]],
|
| 244 |
+
]
|
| 245 |
+
)
|
| 246 |
+
elif model_name == "segformer.b3.512x512.ade.160k":
|
| 247 |
+
expected_slice = torch.tensor(
|
| 248 |
+
[
|
| 249 |
+
[[-9.0878, -10.2081, -10.1891], [-9.3144, -10.7941, -10.9843], [-9.2294, -10.3855, -10.5704]],
|
| 250 |
+
[[-12.2316, -13.9068, -13.6102], [-12.9161, -14.3702, -14.3235], [-12.5233, -13.7174, -13.7932]],
|
| 251 |
+
[[-14.6275, -15.2490, -14.9727], [-14.3400, -15.9687, -16.2827], [-14.1484, -15.4033, -15.8937]],
|
| 252 |
+
]
|
| 253 |
+
)
|
| 254 |
+
elif model_name == "segformer.b4.512x512.ade.160k":
|
| 255 |
+
expected_slice = torch.tensor(
|
| 256 |
+
[
|
| 257 |
+
[[-12.3144, -13.2447, -14.0802], [-13.3614, -14.5816, -15.6117], [-13.3340, -14.4433, -16.2219]],
|
| 258 |
+
[[-19.2781, -20.4128, -20.7506], [-20.6153, -21.6566, -22.0998], [-19.9800, -21.0430, -22.1494]],
|
| 259 |
+
[[-18.8739, -19.7804, -21.1834], [-20.1233, -21.6765, -23.2944], [-20.0315, -21.2641, -23.6944]],
|
| 260 |
+
]
|
| 261 |
+
)
|
| 262 |
+
elif model_name == "segformer.b5.640x640.ade.160k":
|
| 263 |
+
expected_slice = torch.tensor(
|
| 264 |
+
[
|
| 265 |
+
[[-9.5524, -12.0835, -11.7348], [-10.5229, -13.6446, -14.5662], [-9.5842, -12.8851, -13.9414]],
|
| 266 |
+
[[-15.3432, -17.5323, -17.0818], [-16.3330, -18.9255, -19.2101], [-15.1340, -17.7848, -18.3971]],
|
| 267 |
+
[[-12.6072, -14.9486, -14.6631], [-13.7629, -17.0907, -17.7745], [-12.7899, -16.1695, -17.1671]],
|
| 268 |
+
]
|
| 269 |
+
)
|
| 270 |
+
# Cityscapes checkpoints
|
| 271 |
+
elif model_name == "segformer.b0.1024x1024.city.160k":
|
| 272 |
+
expected_slice = torch.tensor(
|
| 273 |
+
[
|
| 274 |
+
[[-11.9295, -13.4057, -14.8106], [-13.3431, -14.8179, -15.3781], [-14.2836, -15.5942, -16.1588]],
|
| 275 |
+
[[-11.4906, -12.8067, -13.6564], [-13.1189, -14.0500, -14.1543], [-13.8748, -14.5136, -14.8789]],
|
| 276 |
+
[[0.5374, 0.1067, -0.4742], [0.1141, -0.2255, -0.7099], [-0.3000, -0.5924, -1.3105]],
|
| 277 |
+
]
|
| 278 |
+
)
|
| 279 |
+
elif model_name == "segformer.b0.512x1024.city.160k":
|
| 280 |
+
expected_slice = torch.tensor(
|
| 281 |
+
[
|
| 282 |
+
[[-7.8217, -9.8767, -10.1717], [-9.4438, -10.9058, -11.4047], [-9.7939, -12.3495, -12.1079]],
|
| 283 |
+
[[-7.1514, -9.5336, -10.0860], [-9.7776, -11.6822, -11.8439], [-10.1411, -12.7655, -12.8972]],
|
| 284 |
+
[[0.3021, 0.0805, -0.2310], [-0.0328, -0.1605, -0.2714], [-0.1408, -0.5477, -0.6976]],
|
| 285 |
+
]
|
| 286 |
+
)
|
| 287 |
+
elif model_name == "segformer.b0.640x1280.city.160k":
|
| 288 |
+
expected_slice = torch.tensor(
|
| 289 |
+
[
|
| 290 |
+
[
|
| 291 |
+
[-1.1372e01, -1.2787e01, -1.3477e01],
|
| 292 |
+
[-1.2536e01, -1.4194e01, -1.4409e01],
|
| 293 |
+
[-1.3217e01, -1.4888e01, -1.5327e01],
|
| 294 |
+
],
|
| 295 |
+
[
|
| 296 |
+
[-1.4791e01, -1.7122e01, -1.8277e01],
|
| 297 |
+
[-1.7163e01, -1.9192e01, -1.9533e01],
|
| 298 |
+
[-1.7897e01, -1.9991e01, -2.0315e01],
|
| 299 |
+
],
|
| 300 |
+
[
|
| 301 |
+
[7.6723e-01, 4.1921e-01, -7.7878e-02],
|
| 302 |
+
[4.7772e-01, 9.5557e-03, -2.8082e-01],
|
| 303 |
+
[3.6032e-01, -2.4826e-01, -5.1168e-01],
|
| 304 |
+
],
|
| 305 |
+
]
|
| 306 |
+
)
|
| 307 |
+
elif model_name == "segformer.b0.768x768.city.160k":
|
| 308 |
+
expected_slice = torch.tensor(
|
| 309 |
+
[
|
| 310 |
+
[[-9.4959, -11.3087, -11.7479], [-11.0025, -12.6540, -12.3319], [-11.4064, -13.0487, -12.9905]],
|
| 311 |
+
[[-9.8905, -11.3084, -12.0854], [-11.1726, -12.7698, -12.9583], [-11.5985, -13.3278, -14.1774]],
|
| 312 |
+
[[0.2213, 0.0192, -0.2466], [-0.1731, -0.4213, -0.4874], [-0.3126, -0.6541, -1.1389]],
|
| 313 |
+
]
|
| 314 |
+
)
|
| 315 |
+
elif model_name == "segformer.b1.1024x1024.city.160k":
|
| 316 |
+
expected_slice = torch.tensor(
|
| 317 |
+
[
|
| 318 |
+
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
|
| 319 |
+
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
|
| 320 |
+
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
|
| 321 |
+
]
|
| 322 |
+
)
|
| 323 |
+
elif model_name == "segformer.b2.1024x1024.city.160k":
|
| 324 |
+
expected_slice = torch.tensor(
|
| 325 |
+
[
|
| 326 |
+
[[-16.0976, -16.4856, -17.3962], [-16.6234, -19.0342, -19.7685], [-16.0900, -18.0661, -19.1180]],
|
| 327 |
+
[[-18.4750, -18.8488, -19.5074], [-19.4030, -22.1570, -22.5977], [-19.1191, -20.8486, -22.3783]],
|
| 328 |
+
[[-4.5178, -5.5037, -6.5109], [-5.0884, -7.2174, -8.0334], [-4.4156, -5.8117, -7.2970]],
|
| 329 |
+
]
|
| 330 |
+
)
|
| 331 |
+
elif model_name == "segformer.b3.1024x1024.city.160k":
|
| 332 |
+
expected_slice = torch.tensor(
|
| 333 |
+
[
|
| 334 |
+
[[-14.2081, -14.4732, -14.1977], [-14.5867, -16.4423, -16.6356], [-13.4441, -14.9685, -16.8696]],
|
| 335 |
+
[[-14.4576, -14.7073, -15.0451], [-15.0816, -17.6237, -17.9873], [-14.4213, -16.0199, -18.5992]],
|
| 336 |
+
[[-4.7349, -4.9588, -5.0966], [-4.3210, -6.9325, -7.2591], [-3.4312, -4.7484, -7.1917]],
|
| 337 |
+
]
|
| 338 |
+
)
|
| 339 |
+
elif model_name == "segformer.b4.1024x1024.city.160k":
|
| 340 |
+
expected_slice = torch.tensor(
|
| 341 |
+
[
|
| 342 |
+
[[-11.7737, -11.9526, -11.3273], [-13.6692, -14.4574, -13.8878], [-13.8937, -14.6924, -15.9345]],
|
| 343 |
+
[[-14.6706, -14.5330, -14.1306], [-16.1502, -16.8180, -16.4269], [-16.8338, -17.8939, -20.1746]],
|
| 344 |
+
[[1.0491, 0.8289, 1.0310], [1.1044, 0.5219, 0.8055], [1.0899, 0.6926, 0.5590]],
|
| 345 |
+
]
|
| 346 |
+
)
|
| 347 |
+
elif model_name == "segformer.b5.1024x1024.city.160k":
|
| 348 |
+
expected_slice = torch.tensor(
|
| 349 |
+
[
|
| 350 |
+
[[-12.5641, -13.4777, -13.0684], [-13.9587, -15.8983, -16.6557], [-13.3109, -15.7350, -16.3141]],
|
| 351 |
+
[[-14.7074, -15.4352, -14.5944], [-16.6353, -18.1663, -18.6120], [-15.1702, -18.0329, -18.1547]],
|
| 352 |
+
[[-1.7990, -2.0951, -1.7784], [-2.6397, -3.8245, -3.9686], [-1.5264, -2.8126, -2.9316]],
|
| 353 |
+
]
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
predicted_class_idx = logits.argmax(-1).item()
|
| 357 |
+
print("Predicted class:", model.config.id2label[predicted_class_idx])
|
| 358 |
+
|
| 359 |
+
# verify logits
|
| 360 |
+
if not encoder_only:
|
| 361 |
+
assert logits.shape == expected_shape
|
| 362 |
+
assert torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-2)
|
| 363 |
+
|
| 364 |
+
# finally, save model and image processor
|
| 365 |
+
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
|
| 366 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 367 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 368 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
if __name__ == "__main__":
|
| 372 |
+
parser = argparse.ArgumentParser()
|
| 373 |
+
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--model_name",
|
| 376 |
+
default="segformer.b0.512x512.ade.160k",
|
| 377 |
+
type=str,
|
| 378 |
+
help="Name of the model you'd like to convert.",
|
| 379 |
+
)
|
| 380 |
+
parser.add_argument(
|
| 381 |
+
"--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)."
|
| 382 |
+
)
|
| 383 |
+
parser.add_argument(
|
| 384 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
| 385 |
+
)
|
| 386 |
+
args = parser.parse_args()
|
| 387 |
+
convert_segformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path)
|
docs/transformers/build/lib/transformers/models/segformer/feature_extraction_segformer.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Feature extractor class for SegFormer."""
|
| 16 |
+
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
from ...utils import logging
|
| 20 |
+
from ...utils.import_utils import requires
|
| 21 |
+
from .image_processing_segformer import SegformerImageProcessor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@requires(backends=("vision",))
|
| 28 |
+
class SegformerFeatureExtractor(SegformerImageProcessor):
|
| 29 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 30 |
+
warnings.warn(
|
| 31 |
+
"The class SegformerFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
|
| 32 |
+
" Please use SegformerImageProcessor instead.",
|
| 33 |
+
FutureWarning,
|
| 34 |
+
)
|
| 35 |
+
super().__init__(*args, **kwargs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
__all__ = ["SegformerFeatureExtractor"]
|
docs/transformers/build/lib/transformers/models/segformer/image_processing_segformer.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for Segformer."""
|
| 16 |
+
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from ...image_transforms import resize, to_channel_dimension_format
|
| 23 |
+
from ...image_utils import (
|
| 24 |
+
IMAGENET_DEFAULT_MEAN,
|
| 25 |
+
IMAGENET_DEFAULT_STD,
|
| 26 |
+
ChannelDimension,
|
| 27 |
+
ImageInput,
|
| 28 |
+
PILImageResampling,
|
| 29 |
+
infer_channel_dimension_format,
|
| 30 |
+
is_scaled_image,
|
| 31 |
+
make_list_of_images,
|
| 32 |
+
to_numpy_array,
|
| 33 |
+
valid_images,
|
| 34 |
+
validate_preprocess_arguments,
|
| 35 |
+
)
|
| 36 |
+
from ...utils import (
|
| 37 |
+
TensorType,
|
| 38 |
+
filter_out_non_signature_kwargs,
|
| 39 |
+
is_torch_available,
|
| 40 |
+
is_torch_tensor,
|
| 41 |
+
is_vision_available,
|
| 42 |
+
logging,
|
| 43 |
+
)
|
| 44 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 45 |
+
from ...utils.import_utils import requires
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if is_vision_available():
|
| 49 |
+
import PIL.Image
|
| 50 |
+
|
| 51 |
+
if is_torch_available():
|
| 52 |
+
import torch
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
logger = logging.get_logger(__name__)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@requires(backends=("vision",))
|
| 59 |
+
class SegformerImageProcessor(BaseImageProcessor):
|
| 60 |
+
r"""
|
| 61 |
+
Constructs a Segformer image processor.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 65 |
+
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
|
| 66 |
+
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
|
| 67 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
|
| 68 |
+
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
| 69 |
+
method.
|
| 70 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 71 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 72 |
+
`preprocess` method.
|
| 73 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 74 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 75 |
+
parameter in the `preprocess` method.
|
| 76 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 77 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 78 |
+
method.
|
| 79 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 80 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 81 |
+
method.
|
| 82 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
| 83 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 84 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 85 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
| 86 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 87 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 88 |
+
do_reduce_labels (`bool`, *optional*, defaults to `False`):
|
| 89 |
+
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
|
| 90 |
+
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
|
| 91 |
+
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
|
| 92 |
+
`preprocess` method.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
model_input_names = ["pixel_values"]
|
| 96 |
+
|
| 97 |
+
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
| 98 |
+
@filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
do_resize: bool = True,
|
| 102 |
+
size: Dict[str, int] = None,
|
| 103 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 104 |
+
do_rescale: bool = True,
|
| 105 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 106 |
+
do_normalize: bool = True,
|
| 107 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 108 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 109 |
+
do_reduce_labels: bool = False,
|
| 110 |
+
**kwargs,
|
| 111 |
+
) -> None:
|
| 112 |
+
super().__init__(**kwargs)
|
| 113 |
+
size = size if size is not None else {"height": 512, "width": 512}
|
| 114 |
+
size = get_size_dict(size)
|
| 115 |
+
self.do_resize = do_resize
|
| 116 |
+
self.size = size
|
| 117 |
+
self.resample = resample
|
| 118 |
+
self.do_rescale = do_rescale
|
| 119 |
+
self.rescale_factor = rescale_factor
|
| 120 |
+
self.do_normalize = do_normalize
|
| 121 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 122 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 123 |
+
self.do_reduce_labels = do_reduce_labels
|
| 124 |
+
|
| 125 |
+
@classmethod
|
| 126 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 127 |
+
"""
|
| 128 |
+
Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
|
| 129 |
+
"""
|
| 130 |
+
image_processor_dict = image_processor_dict.copy()
|
| 131 |
+
if "reduce_labels" in image_processor_dict:
|
| 132 |
+
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
| 133 |
+
return super().from_dict(image_processor_dict, **kwargs)
|
| 134 |
+
|
| 135 |
+
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
|
| 136 |
+
def resize(
|
| 137 |
+
self,
|
| 138 |
+
image: np.ndarray,
|
| 139 |
+
size: Dict[str, int],
|
| 140 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 141 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 142 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 143 |
+
**kwargs,
|
| 144 |
+
) -> np.ndarray:
|
| 145 |
+
"""
|
| 146 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
image (`np.ndarray`):
|
| 150 |
+
Image to resize.
|
| 151 |
+
size (`Dict[str, int]`):
|
| 152 |
+
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
| 153 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 154 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 155 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 156 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 157 |
+
image is used. Can be one of:
|
| 158 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 159 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 160 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 161 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 162 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 163 |
+
from the input image. Can be one of:
|
| 164 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 165 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 166 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
`np.ndarray`: The resized image.
|
| 170 |
+
"""
|
| 171 |
+
size = get_size_dict(size)
|
| 172 |
+
if "height" not in size or "width" not in size:
|
| 173 |
+
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
| 174 |
+
output_size = (size["height"], size["width"])
|
| 175 |
+
return resize(
|
| 176 |
+
image,
|
| 177 |
+
size=output_size,
|
| 178 |
+
resample=resample,
|
| 179 |
+
data_format=data_format,
|
| 180 |
+
input_data_format=input_data_format,
|
| 181 |
+
**kwargs,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
|
| 185 |
+
def reduce_label(self, label: ImageInput) -> np.ndarray:
|
| 186 |
+
label = to_numpy_array(label)
|
| 187 |
+
# Avoid using underflow conversion
|
| 188 |
+
label[label == 0] = 255
|
| 189 |
+
label = label - 1
|
| 190 |
+
label[label == 254] = 255
|
| 191 |
+
return label
|
| 192 |
+
|
| 193 |
+
def _preprocess(
|
| 194 |
+
self,
|
| 195 |
+
image: ImageInput,
|
| 196 |
+
do_reduce_labels: bool,
|
| 197 |
+
do_resize: bool,
|
| 198 |
+
do_rescale: bool,
|
| 199 |
+
do_normalize: bool,
|
| 200 |
+
size: Optional[Dict[str, int]] = None,
|
| 201 |
+
resample: PILImageResampling = None,
|
| 202 |
+
rescale_factor: Optional[float] = None,
|
| 203 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 204 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 205 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 206 |
+
):
|
| 207 |
+
if do_reduce_labels:
|
| 208 |
+
image = self.reduce_label(image)
|
| 209 |
+
|
| 210 |
+
if do_resize:
|
| 211 |
+
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 212 |
+
|
| 213 |
+
if do_rescale:
|
| 214 |
+
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 215 |
+
|
| 216 |
+
if do_normalize:
|
| 217 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 218 |
+
|
| 219 |
+
return image
|
| 220 |
+
|
| 221 |
+
def _preprocess_image(
|
| 222 |
+
self,
|
| 223 |
+
image: ImageInput,
|
| 224 |
+
do_resize: Optional[bool] = None,
|
| 225 |
+
size: Dict[str, int] = None,
|
| 226 |
+
resample: PILImageResampling = None,
|
| 227 |
+
do_rescale: Optional[bool] = None,
|
| 228 |
+
rescale_factor: Optional[float] = None,
|
| 229 |
+
do_normalize: Optional[bool] = None,
|
| 230 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 231 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 232 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 233 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 234 |
+
) -> np.ndarray:
|
| 235 |
+
"""Preprocesses a single image."""
|
| 236 |
+
# All transformations expect numpy arrays.
|
| 237 |
+
image = to_numpy_array(image)
|
| 238 |
+
if do_rescale and is_scaled_image(image):
|
| 239 |
+
logger.warning_once(
|
| 240 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 241 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 242 |
+
)
|
| 243 |
+
if input_data_format is None:
|
| 244 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 245 |
+
image = self._preprocess(
|
| 246 |
+
image=image,
|
| 247 |
+
do_reduce_labels=False,
|
| 248 |
+
do_resize=do_resize,
|
| 249 |
+
size=size,
|
| 250 |
+
resample=resample,
|
| 251 |
+
do_rescale=do_rescale,
|
| 252 |
+
rescale_factor=rescale_factor,
|
| 253 |
+
do_normalize=do_normalize,
|
| 254 |
+
image_mean=image_mean,
|
| 255 |
+
image_std=image_std,
|
| 256 |
+
input_data_format=input_data_format,
|
| 257 |
+
)
|
| 258 |
+
if data_format is not None:
|
| 259 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 260 |
+
return image
|
| 261 |
+
|
| 262 |
+
def _preprocess_mask(
|
| 263 |
+
self,
|
| 264 |
+
segmentation_map: ImageInput,
|
| 265 |
+
do_reduce_labels: Optional[bool] = None,
|
| 266 |
+
do_resize: Optional[bool] = None,
|
| 267 |
+
size: Dict[str, int] = None,
|
| 268 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 269 |
+
) -> np.ndarray:
|
| 270 |
+
"""Preprocesses a single mask."""
|
| 271 |
+
segmentation_map = to_numpy_array(segmentation_map)
|
| 272 |
+
# Add channel dimension if missing - needed for certain transformations
|
| 273 |
+
if segmentation_map.ndim == 2:
|
| 274 |
+
added_channel_dim = True
|
| 275 |
+
segmentation_map = segmentation_map[None, ...]
|
| 276 |
+
input_data_format = ChannelDimension.FIRST
|
| 277 |
+
else:
|
| 278 |
+
added_channel_dim = False
|
| 279 |
+
if input_data_format is None:
|
| 280 |
+
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
| 281 |
+
# reduce zero label if needed
|
| 282 |
+
segmentation_map = self._preprocess(
|
| 283 |
+
image=segmentation_map,
|
| 284 |
+
do_reduce_labels=do_reduce_labels,
|
| 285 |
+
do_resize=do_resize,
|
| 286 |
+
resample=PILImageResampling.NEAREST,
|
| 287 |
+
size=size,
|
| 288 |
+
do_rescale=False,
|
| 289 |
+
do_normalize=False,
|
| 290 |
+
input_data_format=input_data_format,
|
| 291 |
+
)
|
| 292 |
+
# Remove extra channel dimension if added for processing
|
| 293 |
+
if added_channel_dim:
|
| 294 |
+
segmentation_map = segmentation_map.squeeze(0)
|
| 295 |
+
segmentation_map = segmentation_map.astype(np.int64)
|
| 296 |
+
return segmentation_map
|
| 297 |
+
|
| 298 |
+
def __call__(self, images, segmentation_maps=None, **kwargs):
|
| 299 |
+
"""
|
| 300 |
+
Preprocesses a batch of images and optionally segmentation maps.
|
| 301 |
+
|
| 302 |
+
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
|
| 303 |
+
passed in as positional arguments.
|
| 304 |
+
"""
|
| 305 |
+
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
| 306 |
+
|
| 307 |
+
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
| 308 |
+
@filter_out_non_signature_kwargs()
|
| 309 |
+
def preprocess(
|
| 310 |
+
self,
|
| 311 |
+
images: ImageInput,
|
| 312 |
+
segmentation_maps: Optional[ImageInput] = None,
|
| 313 |
+
do_resize: Optional[bool] = None,
|
| 314 |
+
size: Optional[Dict[str, int]] = None,
|
| 315 |
+
resample: PILImageResampling = None,
|
| 316 |
+
do_rescale: Optional[bool] = None,
|
| 317 |
+
rescale_factor: Optional[float] = None,
|
| 318 |
+
do_normalize: Optional[bool] = None,
|
| 319 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 320 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 321 |
+
do_reduce_labels: Optional[bool] = None,
|
| 322 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 323 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
| 324 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 325 |
+
) -> PIL.Image.Image:
|
| 326 |
+
"""
|
| 327 |
+
Preprocess an image or batch of images.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
images (`ImageInput`):
|
| 331 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 332 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 333 |
+
segmentation_maps (`ImageInput`, *optional*):
|
| 334 |
+
Segmentation map to preprocess.
|
| 335 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 336 |
+
Whether to resize the image.
|
| 337 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 338 |
+
Size of the image after `resize` is applied.
|
| 339 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 340 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
|
| 341 |
+
has an effect if `do_resize` is set to `True`.
|
| 342 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 343 |
+
Whether to rescale the image values between [0 - 1].
|
| 344 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 345 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 346 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 347 |
+
Whether to normalize the image.
|
| 348 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 349 |
+
Image mean.
|
| 350 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 351 |
+
Image standard deviation.
|
| 352 |
+
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
| 353 |
+
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
| 354 |
+
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
| 355 |
+
ADE20k). The background label will be replaced by 255.
|
| 356 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 357 |
+
The type of tensors to return. Can be one of:
|
| 358 |
+
- Unset: Return a list of `np.ndarray`.
|
| 359 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 360 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 361 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 362 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 363 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 364 |
+
The channel dimension format for the output image. Can be one of:
|
| 365 |
+
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 366 |
+
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 367 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 368 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 369 |
+
from the input image. Can be one of:
|
| 370 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 371 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 372 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 373 |
+
"""
|
| 374 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 375 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 376 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 377 |
+
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
| 378 |
+
resample = resample if resample is not None else self.resample
|
| 379 |
+
size = size if size is not None else self.size
|
| 380 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 381 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 382 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 383 |
+
|
| 384 |
+
images = make_list_of_images(images)
|
| 385 |
+
|
| 386 |
+
if segmentation_maps is not None:
|
| 387 |
+
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
| 388 |
+
|
| 389 |
+
if not valid_images(images):
|
| 390 |
+
raise ValueError(
|
| 391 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 392 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 393 |
+
)
|
| 394 |
+
validate_preprocess_arguments(
|
| 395 |
+
do_rescale=do_rescale,
|
| 396 |
+
rescale_factor=rescale_factor,
|
| 397 |
+
do_normalize=do_normalize,
|
| 398 |
+
image_mean=image_mean,
|
| 399 |
+
image_std=image_std,
|
| 400 |
+
do_resize=do_resize,
|
| 401 |
+
size=size,
|
| 402 |
+
resample=resample,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
images = [
|
| 406 |
+
self._preprocess_image(
|
| 407 |
+
image=img,
|
| 408 |
+
do_resize=do_resize,
|
| 409 |
+
resample=resample,
|
| 410 |
+
size=size,
|
| 411 |
+
do_rescale=do_rescale,
|
| 412 |
+
rescale_factor=rescale_factor,
|
| 413 |
+
do_normalize=do_normalize,
|
| 414 |
+
image_mean=image_mean,
|
| 415 |
+
image_std=image_std,
|
| 416 |
+
data_format=data_format,
|
| 417 |
+
input_data_format=input_data_format,
|
| 418 |
+
)
|
| 419 |
+
for img in images
|
| 420 |
+
]
|
| 421 |
+
|
| 422 |
+
data = {"pixel_values": images}
|
| 423 |
+
|
| 424 |
+
if segmentation_maps is not None:
|
| 425 |
+
segmentation_maps = [
|
| 426 |
+
self._preprocess_mask(
|
| 427 |
+
segmentation_map=segmentation_map,
|
| 428 |
+
do_reduce_labels=do_reduce_labels,
|
| 429 |
+
do_resize=do_resize,
|
| 430 |
+
size=size,
|
| 431 |
+
input_data_format=input_data_format,
|
| 432 |
+
)
|
| 433 |
+
for segmentation_map in segmentation_maps
|
| 434 |
+
]
|
| 435 |
+
data["labels"] = segmentation_maps
|
| 436 |
+
|
| 437 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 438 |
+
|
| 439 |
+
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->Segformer
|
| 440 |
+
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
|
| 441 |
+
"""
|
| 442 |
+
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
| 443 |
+
|
| 444 |
+
Args:
|
| 445 |
+
outputs ([`SegformerForSemanticSegmentation`]):
|
| 446 |
+
Raw outputs of the model.
|
| 447 |
+
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
|
| 448 |
+
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
|
| 449 |
+
predictions will not be resized.
|
| 450 |
+
|
| 451 |
+
Returns:
|
| 452 |
+
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
| 453 |
+
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
| 454 |
+
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
| 455 |
+
"""
|
| 456 |
+
# TODO: add support for other frameworks
|
| 457 |
+
logits = outputs.logits
|
| 458 |
+
|
| 459 |
+
# Resize logits and compute semantic segmentation maps
|
| 460 |
+
if target_sizes is not None:
|
| 461 |
+
if len(logits) != len(target_sizes):
|
| 462 |
+
raise ValueError(
|
| 463 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if is_torch_tensor(target_sizes):
|
| 467 |
+
target_sizes = target_sizes.numpy()
|
| 468 |
+
|
| 469 |
+
semantic_segmentation = []
|
| 470 |
+
|
| 471 |
+
for idx in range(len(logits)):
|
| 472 |
+
resized_logits = torch.nn.functional.interpolate(
|
| 473 |
+
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
| 474 |
+
)
|
| 475 |
+
semantic_map = resized_logits[0].argmax(dim=0)
|
| 476 |
+
semantic_segmentation.append(semantic_map)
|
| 477 |
+
else:
|
| 478 |
+
semantic_segmentation = logits.argmax(dim=1)
|
| 479 |
+
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
| 480 |
+
|
| 481 |
+
return semantic_segmentation
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
__all__ = ["SegformerImageProcessor"]
|
docs/transformers/build/lib/transformers/models/segformer/modeling_segformer.py
ADDED
|
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 NVIDIA The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch SegFormer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
+
|
| 25 |
+
from ...activations import ACT2FN
|
| 26 |
+
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
|
| 27 |
+
from ...modeling_utils import PreTrainedModel
|
| 28 |
+
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 29 |
+
from ...utils import (
|
| 30 |
+
add_code_sample_docstrings,
|
| 31 |
+
add_start_docstrings,
|
| 32 |
+
add_start_docstrings_to_model_forward,
|
| 33 |
+
logging,
|
| 34 |
+
replace_return_docstrings,
|
| 35 |
+
)
|
| 36 |
+
from .configuration_segformer import SegformerConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# General docstring
|
| 43 |
+
_CONFIG_FOR_DOC = "SegformerConfig"
|
| 44 |
+
|
| 45 |
+
# Base docstring
|
| 46 |
+
_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
|
| 47 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
|
| 48 |
+
|
| 49 |
+
# Image classification docstring
|
| 50 |
+
_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
|
| 51 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SegFormerImageClassifierOutput(ImageClassifierOutput):
|
| 55 |
+
"""
|
| 56 |
+
Base class for outputs of image classification models.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 60 |
+
Classification (or regression if config.num_labels==1) loss.
|
| 61 |
+
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
| 62 |
+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
| 63 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 64 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 65 |
+
one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
|
| 66 |
+
called feature maps) of the model at the output of each stage.
|
| 67 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 68 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
|
| 69 |
+
sequence_length)`.
|
| 70 |
+
|
| 71 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 72 |
+
heads.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
loss: Optional[torch.FloatTensor] = None
|
| 76 |
+
logits: Optional[torch.FloatTensor] = None
|
| 77 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 78 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
| 82 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
| 83 |
+
"""
|
| 84 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 85 |
+
|
| 86 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 87 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 88 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 89 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 90 |
+
argument.
|
| 91 |
+
"""
|
| 92 |
+
if drop_prob == 0.0 or not training:
|
| 93 |
+
return input
|
| 94 |
+
keep_prob = 1 - drop_prob
|
| 95 |
+
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 96 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
| 97 |
+
random_tensor.floor_() # binarize
|
| 98 |
+
output = input.div(keep_prob) * random_tensor
|
| 99 |
+
return output
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Segformer
|
| 103 |
+
class SegformerDropPath(nn.Module):
|
| 104 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 105 |
+
|
| 106 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
| 107 |
+
super().__init__()
|
| 108 |
+
self.drop_prob = drop_prob
|
| 109 |
+
|
| 110 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 111 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
| 112 |
+
|
| 113 |
+
def extra_repr(self) -> str:
|
| 114 |
+
return "p={}".format(self.drop_prob)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class SegformerOverlapPatchEmbeddings(nn.Module):
|
| 118 |
+
"""Construct the overlapping patch embeddings."""
|
| 119 |
+
|
| 120 |
+
def __init__(self, patch_size, stride, num_channels, hidden_size):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.proj = nn.Conv2d(
|
| 123 |
+
num_channels,
|
| 124 |
+
hidden_size,
|
| 125 |
+
kernel_size=patch_size,
|
| 126 |
+
stride=stride,
|
| 127 |
+
padding=patch_size // 2,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 131 |
+
|
| 132 |
+
def forward(self, pixel_values):
|
| 133 |
+
embeddings = self.proj(pixel_values)
|
| 134 |
+
_, _, height, width = embeddings.shape
|
| 135 |
+
# (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
|
| 136 |
+
# this can be fed to a Transformer layer
|
| 137 |
+
embeddings = embeddings.flatten(2).transpose(1, 2)
|
| 138 |
+
embeddings = self.layer_norm(embeddings)
|
| 139 |
+
return embeddings, height, width
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class SegformerEfficientSelfAttention(nn.Module):
|
| 143 |
+
"""SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
|
| 144 |
+
paper](https://arxiv.org/abs/2102.12122)."""
|
| 145 |
+
|
| 146 |
+
def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.hidden_size = hidden_size
|
| 149 |
+
self.num_attention_heads = num_attention_heads
|
| 150 |
+
|
| 151 |
+
if self.hidden_size % self.num_attention_heads != 0:
|
| 152 |
+
raise ValueError(
|
| 153 |
+
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
| 154 |
+
f"heads ({self.num_attention_heads})"
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
|
| 158 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 159 |
+
|
| 160 |
+
self.query = nn.Linear(self.hidden_size, self.all_head_size)
|
| 161 |
+
self.key = nn.Linear(self.hidden_size, self.all_head_size)
|
| 162 |
+
self.value = nn.Linear(self.hidden_size, self.all_head_size)
|
| 163 |
+
|
| 164 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 165 |
+
|
| 166 |
+
self.sr_ratio = sequence_reduction_ratio
|
| 167 |
+
if sequence_reduction_ratio > 1:
|
| 168 |
+
self.sr = nn.Conv2d(
|
| 169 |
+
hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
|
| 170 |
+
)
|
| 171 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 172 |
+
|
| 173 |
+
def transpose_for_scores(self, hidden_states):
|
| 174 |
+
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 175 |
+
hidden_states = hidden_states.view(new_shape)
|
| 176 |
+
return hidden_states.permute(0, 2, 1, 3)
|
| 177 |
+
|
| 178 |
+
def forward(
|
| 179 |
+
self,
|
| 180 |
+
hidden_states,
|
| 181 |
+
height,
|
| 182 |
+
width,
|
| 183 |
+
output_attentions=False,
|
| 184 |
+
):
|
| 185 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
| 186 |
+
|
| 187 |
+
if self.sr_ratio > 1:
|
| 188 |
+
batch_size, seq_len, num_channels = hidden_states.shape
|
| 189 |
+
# Reshape to (batch_size, num_channels, height, width)
|
| 190 |
+
hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
|
| 191 |
+
# Apply sequence reduction
|
| 192 |
+
hidden_states = self.sr(hidden_states)
|
| 193 |
+
# Reshape back to (batch_size, seq_len, num_channels)
|
| 194 |
+
hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
|
| 195 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 196 |
+
|
| 197 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 198 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 199 |
+
|
| 200 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 201 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 202 |
+
|
| 203 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 204 |
+
|
| 205 |
+
# Normalize the attention scores to probabilities.
|
| 206 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 207 |
+
|
| 208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 210 |
+
attention_probs = self.dropout(attention_probs)
|
| 211 |
+
|
| 212 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 213 |
+
|
| 214 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 215 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 216 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 217 |
+
|
| 218 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 219 |
+
|
| 220 |
+
return outputs
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class SegformerSelfOutput(nn.Module):
|
| 224 |
+
def __init__(self, config, hidden_size):
|
| 225 |
+
super().__init__()
|
| 226 |
+
self.dense = nn.Linear(hidden_size, hidden_size)
|
| 227 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 228 |
+
|
| 229 |
+
def forward(self, hidden_states, input_tensor):
|
| 230 |
+
hidden_states = self.dense(hidden_states)
|
| 231 |
+
hidden_states = self.dropout(hidden_states)
|
| 232 |
+
return hidden_states
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class SegformerAttention(nn.Module):
|
| 236 |
+
def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
|
| 237 |
+
super().__init__()
|
| 238 |
+
self.self = SegformerEfficientSelfAttention(
|
| 239 |
+
config=config,
|
| 240 |
+
hidden_size=hidden_size,
|
| 241 |
+
num_attention_heads=num_attention_heads,
|
| 242 |
+
sequence_reduction_ratio=sequence_reduction_ratio,
|
| 243 |
+
)
|
| 244 |
+
self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
|
| 245 |
+
self.pruned_heads = set()
|
| 246 |
+
|
| 247 |
+
def prune_heads(self, heads):
|
| 248 |
+
if len(heads) == 0:
|
| 249 |
+
return
|
| 250 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 251 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Prune linear layers
|
| 255 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 256 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 257 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 258 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 259 |
+
|
| 260 |
+
# Update hyper params and store pruned heads
|
| 261 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 262 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 263 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 264 |
+
|
| 265 |
+
def forward(self, hidden_states, height, width, output_attentions=False):
|
| 266 |
+
self_outputs = self.self(hidden_states, height, width, output_attentions)
|
| 267 |
+
|
| 268 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 269 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 270 |
+
return outputs
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class SegformerDWConv(nn.Module):
|
| 274 |
+
def __init__(self, dim=768):
|
| 275 |
+
super().__init__()
|
| 276 |
+
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
| 277 |
+
|
| 278 |
+
def forward(self, hidden_states, height, width):
|
| 279 |
+
batch_size, seq_len, num_channels = hidden_states.shape
|
| 280 |
+
hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
|
| 281 |
+
hidden_states = self.dwconv(hidden_states)
|
| 282 |
+
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
| 283 |
+
|
| 284 |
+
return hidden_states
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class SegformerMixFFN(nn.Module):
|
| 288 |
+
def __init__(self, config, in_features, hidden_features=None, out_features=None):
|
| 289 |
+
super().__init__()
|
| 290 |
+
out_features = out_features or in_features
|
| 291 |
+
self.dense1 = nn.Linear(in_features, hidden_features)
|
| 292 |
+
self.dwconv = SegformerDWConv(hidden_features)
|
| 293 |
+
if isinstance(config.hidden_act, str):
|
| 294 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 295 |
+
else:
|
| 296 |
+
self.intermediate_act_fn = config.hidden_act
|
| 297 |
+
self.dense2 = nn.Linear(hidden_features, out_features)
|
| 298 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 299 |
+
|
| 300 |
+
def forward(self, hidden_states, height, width):
|
| 301 |
+
hidden_states = self.dense1(hidden_states)
|
| 302 |
+
hidden_states = self.dwconv(hidden_states, height, width)
|
| 303 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 304 |
+
hidden_states = self.dropout(hidden_states)
|
| 305 |
+
hidden_states = self.dense2(hidden_states)
|
| 306 |
+
hidden_states = self.dropout(hidden_states)
|
| 307 |
+
return hidden_states
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class SegformerLayer(nn.Module):
|
| 311 |
+
"""This corresponds to the Block class in the original implementation."""
|
| 312 |
+
|
| 313 |
+
def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.layer_norm_1 = nn.LayerNorm(hidden_size)
|
| 316 |
+
self.attention = SegformerAttention(
|
| 317 |
+
config,
|
| 318 |
+
hidden_size=hidden_size,
|
| 319 |
+
num_attention_heads=num_attention_heads,
|
| 320 |
+
sequence_reduction_ratio=sequence_reduction_ratio,
|
| 321 |
+
)
|
| 322 |
+
self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 323 |
+
self.layer_norm_2 = nn.LayerNorm(hidden_size)
|
| 324 |
+
mlp_hidden_size = int(hidden_size * mlp_ratio)
|
| 325 |
+
self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)
|
| 326 |
+
|
| 327 |
+
def forward(self, hidden_states, height, width, output_attentions=False):
|
| 328 |
+
self_attention_outputs = self.attention(
|
| 329 |
+
self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
|
| 330 |
+
height,
|
| 331 |
+
width,
|
| 332 |
+
output_attentions=output_attentions,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
attention_output = self_attention_outputs[0]
|
| 336 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 337 |
+
|
| 338 |
+
# first residual connection (with stochastic depth)
|
| 339 |
+
attention_output = self.drop_path(attention_output)
|
| 340 |
+
hidden_states = attention_output + hidden_states
|
| 341 |
+
|
| 342 |
+
mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
|
| 343 |
+
|
| 344 |
+
# second residual connection (with stochastic depth)
|
| 345 |
+
mlp_output = self.drop_path(mlp_output)
|
| 346 |
+
layer_output = mlp_output + hidden_states
|
| 347 |
+
|
| 348 |
+
outputs = (layer_output,) + outputs
|
| 349 |
+
|
| 350 |
+
return outputs
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class SegformerEncoder(nn.Module):
|
| 354 |
+
def __init__(self, config):
|
| 355 |
+
super().__init__()
|
| 356 |
+
self.config = config
|
| 357 |
+
|
| 358 |
+
# stochastic depth decay rule
|
| 359 |
+
drop_path_decays = [
|
| 360 |
+
x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
|
| 361 |
+
]
|
| 362 |
+
|
| 363 |
+
# patch embeddings
|
| 364 |
+
embeddings = []
|
| 365 |
+
for i in range(config.num_encoder_blocks):
|
| 366 |
+
embeddings.append(
|
| 367 |
+
SegformerOverlapPatchEmbeddings(
|
| 368 |
+
patch_size=config.patch_sizes[i],
|
| 369 |
+
stride=config.strides[i],
|
| 370 |
+
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
|
| 371 |
+
hidden_size=config.hidden_sizes[i],
|
| 372 |
+
)
|
| 373 |
+
)
|
| 374 |
+
self.patch_embeddings = nn.ModuleList(embeddings)
|
| 375 |
+
|
| 376 |
+
# Transformer blocks
|
| 377 |
+
blocks = []
|
| 378 |
+
cur = 0
|
| 379 |
+
for i in range(config.num_encoder_blocks):
|
| 380 |
+
# each block consists of layers
|
| 381 |
+
layers = []
|
| 382 |
+
if i != 0:
|
| 383 |
+
cur += config.depths[i - 1]
|
| 384 |
+
for j in range(config.depths[i]):
|
| 385 |
+
layers.append(
|
| 386 |
+
SegformerLayer(
|
| 387 |
+
config,
|
| 388 |
+
hidden_size=config.hidden_sizes[i],
|
| 389 |
+
num_attention_heads=config.num_attention_heads[i],
|
| 390 |
+
drop_path=drop_path_decays[cur + j],
|
| 391 |
+
sequence_reduction_ratio=config.sr_ratios[i],
|
| 392 |
+
mlp_ratio=config.mlp_ratios[i],
|
| 393 |
+
)
|
| 394 |
+
)
|
| 395 |
+
blocks.append(nn.ModuleList(layers))
|
| 396 |
+
|
| 397 |
+
self.block = nn.ModuleList(blocks)
|
| 398 |
+
|
| 399 |
+
# Layer norms
|
| 400 |
+
self.layer_norm = nn.ModuleList(
|
| 401 |
+
[nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
def forward(
|
| 405 |
+
self,
|
| 406 |
+
pixel_values: torch.FloatTensor,
|
| 407 |
+
output_attentions: Optional[bool] = False,
|
| 408 |
+
output_hidden_states: Optional[bool] = False,
|
| 409 |
+
return_dict: Optional[bool] = True,
|
| 410 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 411 |
+
all_hidden_states = () if output_hidden_states else None
|
| 412 |
+
all_self_attentions = () if output_attentions else None
|
| 413 |
+
|
| 414 |
+
batch_size = pixel_values.shape[0]
|
| 415 |
+
|
| 416 |
+
hidden_states = pixel_values
|
| 417 |
+
for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
|
| 418 |
+
embedding_layer, block_layer, norm_layer = x
|
| 419 |
+
# first, obtain patch embeddings
|
| 420 |
+
hidden_states, height, width = embedding_layer(hidden_states)
|
| 421 |
+
# second, send embeddings through blocks
|
| 422 |
+
for i, blk in enumerate(block_layer):
|
| 423 |
+
layer_outputs = blk(hidden_states, height, width, output_attentions)
|
| 424 |
+
hidden_states = layer_outputs[0]
|
| 425 |
+
if output_attentions:
|
| 426 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 427 |
+
# third, apply layer norm
|
| 428 |
+
hidden_states = norm_layer(hidden_states)
|
| 429 |
+
# fourth, optionally reshape back to (batch_size, num_channels, height, width)
|
| 430 |
+
if idx != len(self.patch_embeddings) - 1 or (
|
| 431 |
+
idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
|
| 432 |
+
):
|
| 433 |
+
hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
|
| 434 |
+
if output_hidden_states:
|
| 435 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 436 |
+
|
| 437 |
+
if not return_dict:
|
| 438 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 439 |
+
return BaseModelOutput(
|
| 440 |
+
last_hidden_state=hidden_states,
|
| 441 |
+
hidden_states=all_hidden_states,
|
| 442 |
+
attentions=all_self_attentions,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class SegformerPreTrainedModel(PreTrainedModel):
|
| 447 |
+
"""
|
| 448 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 449 |
+
models.
|
| 450 |
+
"""
|
| 451 |
+
|
| 452 |
+
config_class = SegformerConfig
|
| 453 |
+
base_model_prefix = "segformer"
|
| 454 |
+
main_input_name = "pixel_values"
|
| 455 |
+
|
| 456 |
+
def _init_weights(self, module):
|
| 457 |
+
"""Initialize the weights"""
|
| 458 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 459 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 460 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 461 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 462 |
+
if module.bias is not None:
|
| 463 |
+
module.bias.data.zero_()
|
| 464 |
+
elif isinstance(module, nn.Embedding):
|
| 465 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 466 |
+
if module.padding_idx is not None:
|
| 467 |
+
module.weight.data[module.padding_idx].zero_()
|
| 468 |
+
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
|
| 469 |
+
module.bias.data.zero_()
|
| 470 |
+
module.weight.data.fill_(1.0)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
SEGFORMER_START_DOCSTRING = r"""
|
| 474 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 475 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 476 |
+
behavior.
|
| 477 |
+
|
| 478 |
+
Parameters:
|
| 479 |
+
config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
|
| 480 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 481 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 482 |
+
"""
|
| 483 |
+
|
| 484 |
+
SEGFORMER_INPUTS_DOCSTRING = r"""
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 488 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
| 489 |
+
[`AutoImageProcessor`]. See [`SegformerImageProcessor.__call__`] for details.
|
| 490 |
+
|
| 491 |
+
output_attentions (`bool`, *optional*):
|
| 492 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 493 |
+
tensors for more detail.
|
| 494 |
+
output_hidden_states (`bool`, *optional*):
|
| 495 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 496 |
+
more detail.
|
| 497 |
+
return_dict (`bool`, *optional*):
|
| 498 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 499 |
+
"""
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
@add_start_docstrings(
|
| 503 |
+
"The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
|
| 504 |
+
SEGFORMER_START_DOCSTRING,
|
| 505 |
+
)
|
| 506 |
+
class SegformerModel(SegformerPreTrainedModel):
|
| 507 |
+
def __init__(self, config):
|
| 508 |
+
super().__init__(config)
|
| 509 |
+
self.config = config
|
| 510 |
+
|
| 511 |
+
# hierarchical Transformer encoder
|
| 512 |
+
self.encoder = SegformerEncoder(config)
|
| 513 |
+
|
| 514 |
+
# Initialize weights and apply final processing
|
| 515 |
+
self.post_init()
|
| 516 |
+
|
| 517 |
+
def _prune_heads(self, heads_to_prune):
|
| 518 |
+
"""
|
| 519 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 520 |
+
class PreTrainedModel
|
| 521 |
+
"""
|
| 522 |
+
for layer, heads in heads_to_prune.items():
|
| 523 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 524 |
+
|
| 525 |
+
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 526 |
+
@add_code_sample_docstrings(
|
| 527 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 528 |
+
output_type=BaseModelOutput,
|
| 529 |
+
config_class=_CONFIG_FOR_DOC,
|
| 530 |
+
modality="vision",
|
| 531 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 532 |
+
)
|
| 533 |
+
def forward(
|
| 534 |
+
self,
|
| 535 |
+
pixel_values: torch.FloatTensor,
|
| 536 |
+
output_attentions: Optional[bool] = None,
|
| 537 |
+
output_hidden_states: Optional[bool] = None,
|
| 538 |
+
return_dict: Optional[bool] = None,
|
| 539 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 540 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 541 |
+
output_hidden_states = (
|
| 542 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 543 |
+
)
|
| 544 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 545 |
+
|
| 546 |
+
encoder_outputs = self.encoder(
|
| 547 |
+
pixel_values,
|
| 548 |
+
output_attentions=output_attentions,
|
| 549 |
+
output_hidden_states=output_hidden_states,
|
| 550 |
+
return_dict=return_dict,
|
| 551 |
+
)
|
| 552 |
+
sequence_output = encoder_outputs[0]
|
| 553 |
+
|
| 554 |
+
if not return_dict:
|
| 555 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 556 |
+
|
| 557 |
+
return BaseModelOutput(
|
| 558 |
+
last_hidden_state=sequence_output,
|
| 559 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 560 |
+
attentions=encoder_outputs.attentions,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
@add_start_docstrings(
|
| 565 |
+
"""
|
| 566 |
+
SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
|
| 567 |
+
states) e.g. for ImageNet.
|
| 568 |
+
""",
|
| 569 |
+
SEGFORMER_START_DOCSTRING,
|
| 570 |
+
)
|
| 571 |
+
class SegformerForImageClassification(SegformerPreTrainedModel):
|
| 572 |
+
def __init__(self, config):
|
| 573 |
+
super().__init__(config)
|
| 574 |
+
|
| 575 |
+
self.num_labels = config.num_labels
|
| 576 |
+
self.segformer = SegformerModel(config)
|
| 577 |
+
|
| 578 |
+
# Classifier head
|
| 579 |
+
self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
|
| 580 |
+
|
| 581 |
+
# Initialize weights and apply final processing
|
| 582 |
+
self.post_init()
|
| 583 |
+
|
| 584 |
+
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 585 |
+
@add_code_sample_docstrings(
|
| 586 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 587 |
+
output_type=SegFormerImageClassifierOutput,
|
| 588 |
+
config_class=_CONFIG_FOR_DOC,
|
| 589 |
+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
| 590 |
+
)
|
| 591 |
+
def forward(
|
| 592 |
+
self,
|
| 593 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 594 |
+
labels: Optional[torch.LongTensor] = None,
|
| 595 |
+
output_attentions: Optional[bool] = None,
|
| 596 |
+
output_hidden_states: Optional[bool] = None,
|
| 597 |
+
return_dict: Optional[bool] = None,
|
| 598 |
+
) -> Union[Tuple, SegFormerImageClassifierOutput]:
|
| 599 |
+
r"""
|
| 600 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 601 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 602 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 603 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 604 |
+
"""
|
| 605 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 606 |
+
|
| 607 |
+
outputs = self.segformer(
|
| 608 |
+
pixel_values,
|
| 609 |
+
output_attentions=output_attentions,
|
| 610 |
+
output_hidden_states=output_hidden_states,
|
| 611 |
+
return_dict=return_dict,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
sequence_output = outputs[0]
|
| 615 |
+
|
| 616 |
+
# convert last hidden states to (batch_size, height*width, hidden_size)
|
| 617 |
+
batch_size = sequence_output.shape[0]
|
| 618 |
+
if self.config.reshape_last_stage:
|
| 619 |
+
# (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
|
| 620 |
+
sequence_output = sequence_output.permute(0, 2, 3, 1)
|
| 621 |
+
sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])
|
| 622 |
+
|
| 623 |
+
# global average pooling
|
| 624 |
+
sequence_output = sequence_output.mean(dim=1)
|
| 625 |
+
|
| 626 |
+
logits = self.classifier(sequence_output)
|
| 627 |
+
|
| 628 |
+
loss = None
|
| 629 |
+
if labels is not None:
|
| 630 |
+
if self.config.problem_type is None:
|
| 631 |
+
if self.num_labels == 1:
|
| 632 |
+
self.config.problem_type = "regression"
|
| 633 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 634 |
+
self.config.problem_type = "single_label_classification"
|
| 635 |
+
else:
|
| 636 |
+
self.config.problem_type = "multi_label_classification"
|
| 637 |
+
|
| 638 |
+
if self.config.problem_type == "regression":
|
| 639 |
+
loss_fct = MSELoss()
|
| 640 |
+
if self.num_labels == 1:
|
| 641 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 642 |
+
else:
|
| 643 |
+
loss = loss_fct(logits, labels)
|
| 644 |
+
elif self.config.problem_type == "single_label_classification":
|
| 645 |
+
loss_fct = CrossEntropyLoss()
|
| 646 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 647 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 648 |
+
loss_fct = BCEWithLogitsLoss()
|
| 649 |
+
loss = loss_fct(logits, labels)
|
| 650 |
+
if not return_dict:
|
| 651 |
+
output = (logits,) + outputs[1:]
|
| 652 |
+
return ((loss,) + output) if loss is not None else output
|
| 653 |
+
|
| 654 |
+
return SegFormerImageClassifierOutput(
|
| 655 |
+
loss=loss,
|
| 656 |
+
logits=logits,
|
| 657 |
+
hidden_states=outputs.hidden_states,
|
| 658 |
+
attentions=outputs.attentions,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
class SegformerMLP(nn.Module):
|
| 663 |
+
"""
|
| 664 |
+
Linear Embedding.
|
| 665 |
+
"""
|
| 666 |
+
|
| 667 |
+
def __init__(self, config: SegformerConfig, input_dim):
|
| 668 |
+
super().__init__()
|
| 669 |
+
self.proj = nn.Linear(input_dim, config.decoder_hidden_size)
|
| 670 |
+
|
| 671 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 672 |
+
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
| 673 |
+
hidden_states = self.proj(hidden_states)
|
| 674 |
+
return hidden_states
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class SegformerDecodeHead(SegformerPreTrainedModel):
|
| 678 |
+
def __init__(self, config):
|
| 679 |
+
super().__init__(config)
|
| 680 |
+
# linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
|
| 681 |
+
mlps = []
|
| 682 |
+
for i in range(config.num_encoder_blocks):
|
| 683 |
+
mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i])
|
| 684 |
+
mlps.append(mlp)
|
| 685 |
+
self.linear_c = nn.ModuleList(mlps)
|
| 686 |
+
|
| 687 |
+
# the following 3 layers implement the ConvModule of the original implementation
|
| 688 |
+
self.linear_fuse = nn.Conv2d(
|
| 689 |
+
in_channels=config.decoder_hidden_size * config.num_encoder_blocks,
|
| 690 |
+
out_channels=config.decoder_hidden_size,
|
| 691 |
+
kernel_size=1,
|
| 692 |
+
bias=False,
|
| 693 |
+
)
|
| 694 |
+
self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
|
| 695 |
+
self.activation = nn.ReLU()
|
| 696 |
+
|
| 697 |
+
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
| 698 |
+
self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
|
| 699 |
+
|
| 700 |
+
self.config = config
|
| 701 |
+
|
| 702 |
+
def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
|
| 703 |
+
batch_size = encoder_hidden_states[-1].shape[0]
|
| 704 |
+
|
| 705 |
+
all_hidden_states = ()
|
| 706 |
+
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
|
| 707 |
+
if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
|
| 708 |
+
height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
|
| 709 |
+
encoder_hidden_state = (
|
| 710 |
+
encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
# unify channel dimension
|
| 714 |
+
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
|
| 715 |
+
encoder_hidden_state = mlp(encoder_hidden_state)
|
| 716 |
+
encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
|
| 717 |
+
encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
|
| 718 |
+
# upsample
|
| 719 |
+
encoder_hidden_state = nn.functional.interpolate(
|
| 720 |
+
encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False
|
| 721 |
+
)
|
| 722 |
+
all_hidden_states += (encoder_hidden_state,)
|
| 723 |
+
|
| 724 |
+
hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
|
| 725 |
+
hidden_states = self.batch_norm(hidden_states)
|
| 726 |
+
hidden_states = self.activation(hidden_states)
|
| 727 |
+
hidden_states = self.dropout(hidden_states)
|
| 728 |
+
|
| 729 |
+
# logits are of shape (batch_size, num_labels, height/4, width/4)
|
| 730 |
+
logits = self.classifier(hidden_states)
|
| 731 |
+
|
| 732 |
+
return logits
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
@add_start_docstrings(
|
| 736 |
+
"""SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
|
| 737 |
+
SEGFORMER_START_DOCSTRING,
|
| 738 |
+
)
|
| 739 |
+
class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
| 740 |
+
def __init__(self, config):
|
| 741 |
+
super().__init__(config)
|
| 742 |
+
self.segformer = SegformerModel(config)
|
| 743 |
+
self.decode_head = SegformerDecodeHead(config)
|
| 744 |
+
|
| 745 |
+
# Initialize weights and apply final processing
|
| 746 |
+
self.post_init()
|
| 747 |
+
|
| 748 |
+
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 749 |
+
@replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
|
| 750 |
+
def forward(
|
| 751 |
+
self,
|
| 752 |
+
pixel_values: torch.FloatTensor,
|
| 753 |
+
labels: Optional[torch.LongTensor] = None,
|
| 754 |
+
output_attentions: Optional[bool] = None,
|
| 755 |
+
output_hidden_states: Optional[bool] = None,
|
| 756 |
+
return_dict: Optional[bool] = None,
|
| 757 |
+
) -> Union[Tuple, SemanticSegmenterOutput]:
|
| 758 |
+
r"""
|
| 759 |
+
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 760 |
+
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
| 761 |
+
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
| 762 |
+
|
| 763 |
+
Returns:
|
| 764 |
+
|
| 765 |
+
Examples:
|
| 766 |
+
|
| 767 |
+
```python
|
| 768 |
+
>>> from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
|
| 769 |
+
>>> from PIL import Image
|
| 770 |
+
>>> import requests
|
| 771 |
+
|
| 772 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
|
| 773 |
+
>>> model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
|
| 774 |
+
|
| 775 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 776 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 777 |
+
|
| 778 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 779 |
+
>>> outputs = model(**inputs)
|
| 780 |
+
>>> logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
|
| 781 |
+
>>> list(logits.shape)
|
| 782 |
+
[1, 150, 128, 128]
|
| 783 |
+
```"""
|
| 784 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 785 |
+
output_hidden_states = (
|
| 786 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
if labels is not None and self.config.num_labels < 1:
|
| 790 |
+
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")
|
| 791 |
+
|
| 792 |
+
outputs = self.segformer(
|
| 793 |
+
pixel_values,
|
| 794 |
+
output_attentions=output_attentions,
|
| 795 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
| 796 |
+
return_dict=return_dict,
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
| 800 |
+
|
| 801 |
+
logits = self.decode_head(encoder_hidden_states)
|
| 802 |
+
|
| 803 |
+
loss = None
|
| 804 |
+
if labels is not None:
|
| 805 |
+
# upsample logits to the images' original size
|
| 806 |
+
upsampled_logits = nn.functional.interpolate(
|
| 807 |
+
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
| 808 |
+
)
|
| 809 |
+
if self.config.num_labels > 1:
|
| 810 |
+
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
| 811 |
+
loss = loss_fct(upsampled_logits, labels)
|
| 812 |
+
elif self.config.num_labels == 1:
|
| 813 |
+
valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
|
| 814 |
+
loss_fct = BCEWithLogitsLoss(reduction="none")
|
| 815 |
+
loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
|
| 816 |
+
loss = (loss * valid_mask).mean()
|
| 817 |
+
|
| 818 |
+
if not return_dict:
|
| 819 |
+
if output_hidden_states:
|
| 820 |
+
output = (logits,) + outputs[1:]
|
| 821 |
+
else:
|
| 822 |
+
output = (logits,) + outputs[2:]
|
| 823 |
+
return ((loss,) + output) if loss is not None else output
|
| 824 |
+
|
| 825 |
+
return SemanticSegmenterOutput(
|
| 826 |
+
loss=loss,
|
| 827 |
+
logits=logits,
|
| 828 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 829 |
+
attentions=outputs.attentions,
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
__all__ = [
|
| 834 |
+
"SegformerDecodeHead",
|
| 835 |
+
"SegformerForImageClassification",
|
| 836 |
+
"SegformerForSemanticSegmentation",
|
| 837 |
+
"SegformerLayer",
|
| 838 |
+
"SegformerModel",
|
| 839 |
+
"SegformerPreTrainedModel",
|
| 840 |
+
]
|
docs/transformers/build/lib/transformers/models/segformer/modeling_tf_segformer.py
ADDED
|
@@ -0,0 +1,1045 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TensorFlow SegFormer model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from typing import Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import tensorflow as tf
|
| 23 |
+
|
| 24 |
+
from ...activations_tf import get_tf_activation
|
| 25 |
+
from ...file_utils import (
|
| 26 |
+
add_code_sample_docstrings,
|
| 27 |
+
add_start_docstrings,
|
| 28 |
+
add_start_docstrings_to_model_forward,
|
| 29 |
+
replace_return_docstrings,
|
| 30 |
+
)
|
| 31 |
+
from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput
|
| 32 |
+
from ...modeling_tf_utils import (
|
| 33 |
+
TFPreTrainedModel,
|
| 34 |
+
TFSequenceClassificationLoss,
|
| 35 |
+
keras,
|
| 36 |
+
keras_serializable,
|
| 37 |
+
unpack_inputs,
|
| 38 |
+
)
|
| 39 |
+
from ...tf_utils import shape_list, stable_softmax
|
| 40 |
+
from ...utils import logging
|
| 41 |
+
from .configuration_segformer import SegformerConfig
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
# General docstring
|
| 47 |
+
_CONFIG_FOR_DOC = "SegformerConfig"
|
| 48 |
+
|
| 49 |
+
# Base docstring
|
| 50 |
+
_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
|
| 51 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
|
| 52 |
+
|
| 53 |
+
# Image classification docstring
|
| 54 |
+
_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
|
| 55 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer
|
| 59 |
+
class TFSegformerDropPath(keras.layers.Layer):
|
| 60 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 61 |
+
References:
|
| 62 |
+
(1) github.com:rwightman/pytorch-image-models
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(self, drop_path: float, **kwargs):
|
| 66 |
+
super().__init__(**kwargs)
|
| 67 |
+
self.drop_path = drop_path
|
| 68 |
+
|
| 69 |
+
def call(self, x: tf.Tensor, training=None):
|
| 70 |
+
if training:
|
| 71 |
+
keep_prob = 1 - self.drop_path
|
| 72 |
+
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
|
| 73 |
+
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
|
| 74 |
+
random_tensor = tf.floor(random_tensor)
|
| 75 |
+
return (x / keep_prob) * random_tensor
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TFSegformerOverlapPatchEmbeddings(keras.layers.Layer):
|
| 80 |
+
"""Construct the overlapping patch embeddings."""
|
| 81 |
+
|
| 82 |
+
def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs):
|
| 83 |
+
super().__init__(**kwargs)
|
| 84 |
+
self.padding = keras.layers.ZeroPadding2D(padding=patch_size // 2)
|
| 85 |
+
self.proj = keras.layers.Conv2D(
|
| 86 |
+
filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
|
| 90 |
+
self.num_channels = num_channels
|
| 91 |
+
self.hidden_size = hidden_size
|
| 92 |
+
|
| 93 |
+
def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
|
| 94 |
+
embeddings = self.proj(self.padding(pixel_values))
|
| 95 |
+
height = shape_list(embeddings)[1]
|
| 96 |
+
width = shape_list(embeddings)[2]
|
| 97 |
+
hidden_dim = shape_list(embeddings)[3]
|
| 98 |
+
# (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels)
|
| 99 |
+
# this can be fed to a Transformer layer
|
| 100 |
+
embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim))
|
| 101 |
+
embeddings = self.layer_norm(embeddings)
|
| 102 |
+
return embeddings, height, width
|
| 103 |
+
|
| 104 |
+
def build(self, input_shape=None):
|
| 105 |
+
if self.built:
|
| 106 |
+
return
|
| 107 |
+
self.built = True
|
| 108 |
+
if getattr(self, "proj", None) is not None:
|
| 109 |
+
with tf.name_scope(self.proj.name):
|
| 110 |
+
self.proj.build([None, None, None, self.num_channels])
|
| 111 |
+
if getattr(self, "layer_norm", None) is not None:
|
| 112 |
+
with tf.name_scope(self.layer_norm.name):
|
| 113 |
+
self.layer_norm.build([None, None, self.hidden_size])
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class TFSegformerEfficientSelfAttention(keras.layers.Layer):
|
| 117 |
+
"""SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
|
| 118 |
+
paper](https://arxiv.org/abs/2102.12122)."""
|
| 119 |
+
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
config: SegformerConfig,
|
| 123 |
+
hidden_size: int,
|
| 124 |
+
num_attention_heads: int,
|
| 125 |
+
sequence_reduction_ratio: int,
|
| 126 |
+
**kwargs,
|
| 127 |
+
):
|
| 128 |
+
super().__init__(**kwargs)
|
| 129 |
+
self.hidden_size = hidden_size
|
| 130 |
+
self.num_attention_heads = num_attention_heads
|
| 131 |
+
|
| 132 |
+
if self.hidden_size % self.num_attention_heads != 0:
|
| 133 |
+
raise ValueError(
|
| 134 |
+
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
| 135 |
+
f"heads ({self.num_attention_heads})"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.attention_head_size = self.hidden_size // self.num_attention_heads
|
| 139 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 140 |
+
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
|
| 141 |
+
|
| 142 |
+
self.query = keras.layers.Dense(self.all_head_size, name="query")
|
| 143 |
+
self.key = keras.layers.Dense(self.all_head_size, name="key")
|
| 144 |
+
self.value = keras.layers.Dense(self.all_head_size, name="value")
|
| 145 |
+
|
| 146 |
+
self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
|
| 147 |
+
|
| 148 |
+
self.sr_ratio = sequence_reduction_ratio
|
| 149 |
+
if sequence_reduction_ratio > 1:
|
| 150 |
+
self.sr = keras.layers.Conv2D(
|
| 151 |
+
filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr"
|
| 152 |
+
)
|
| 153 |
+
self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
|
| 154 |
+
|
| 155 |
+
def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
|
| 156 |
+
# Reshape from [batch_size, seq_length, all_head_size]
|
| 157 |
+
# to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 158 |
+
batch_size = shape_list(tensor)[0]
|
| 159 |
+
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
| 160 |
+
|
| 161 |
+
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 162 |
+
# to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
| 163 |
+
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
| 164 |
+
|
| 165 |
+
def call(
|
| 166 |
+
self,
|
| 167 |
+
hidden_states: tf.Tensor,
|
| 168 |
+
height: int,
|
| 169 |
+
width: int,
|
| 170 |
+
output_attentions: bool = False,
|
| 171 |
+
training: bool = False,
|
| 172 |
+
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
|
| 173 |
+
batch_size = shape_list(hidden_states)[0]
|
| 174 |
+
num_channels = shape_list(hidden_states)[2]
|
| 175 |
+
|
| 176 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
| 177 |
+
|
| 178 |
+
if self.sr_ratio > 1:
|
| 179 |
+
# Reshape to (batch_size, height, width, num_channels)
|
| 180 |
+
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
|
| 181 |
+
# Apply sequence reduction
|
| 182 |
+
hidden_states = self.sr(hidden_states)
|
| 183 |
+
# Reshape back to (batch_size, seq_len, num_channels)
|
| 184 |
+
hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels))
|
| 185 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 186 |
+
|
| 187 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 188 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 189 |
+
|
| 190 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 191 |
+
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
| 192 |
+
|
| 193 |
+
scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
|
| 194 |
+
attention_scores = tf.divide(attention_scores, scale)
|
| 195 |
+
|
| 196 |
+
# Normalize the attention scores to probabilities.
|
| 197 |
+
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
|
| 198 |
+
|
| 199 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 200 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 201 |
+
attention_probs = self.dropout(attention_probs, training=training)
|
| 202 |
+
|
| 203 |
+
context_layer = tf.matmul(attention_probs, value_layer)
|
| 204 |
+
|
| 205 |
+
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
|
| 206 |
+
# (batch_size, seq_len_q, all_head_size)
|
| 207 |
+
context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))
|
| 208 |
+
|
| 209 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 210 |
+
return outputs
|
| 211 |
+
|
| 212 |
+
def build(self, input_shape=None):
|
| 213 |
+
if self.built:
|
| 214 |
+
return
|
| 215 |
+
self.built = True
|
| 216 |
+
if getattr(self, "query", None) is not None:
|
| 217 |
+
with tf.name_scope(self.query.name):
|
| 218 |
+
self.query.build([None, None, self.hidden_size])
|
| 219 |
+
if getattr(self, "key", None) is not None:
|
| 220 |
+
with tf.name_scope(self.key.name):
|
| 221 |
+
self.key.build([None, None, self.hidden_size])
|
| 222 |
+
if getattr(self, "value", None) is not None:
|
| 223 |
+
with tf.name_scope(self.value.name):
|
| 224 |
+
self.value.build([None, None, self.hidden_size])
|
| 225 |
+
if getattr(self, "sr", None) is not None:
|
| 226 |
+
with tf.name_scope(self.sr.name):
|
| 227 |
+
self.sr.build([None, None, None, self.hidden_size])
|
| 228 |
+
if getattr(self, "layer_norm", None) is not None:
|
| 229 |
+
with tf.name_scope(self.layer_norm.name):
|
| 230 |
+
self.layer_norm.build([None, None, self.hidden_size])
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class TFSegformerSelfOutput(keras.layers.Layer):
|
| 234 |
+
def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
|
| 235 |
+
super().__init__(**kwargs)
|
| 236 |
+
self.dense = keras.layers.Dense(hidden_size, name="dense")
|
| 237 |
+
self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
|
| 238 |
+
self.hidden_size = hidden_size
|
| 239 |
+
|
| 240 |
+
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 241 |
+
hidden_states = self.dense(hidden_states)
|
| 242 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 243 |
+
return hidden_states
|
| 244 |
+
|
| 245 |
+
def build(self, input_shape=None):
|
| 246 |
+
if self.built:
|
| 247 |
+
return
|
| 248 |
+
self.built = True
|
| 249 |
+
if getattr(self, "dense", None) is not None:
|
| 250 |
+
with tf.name_scope(self.dense.name):
|
| 251 |
+
self.dense.build([None, None, self.hidden_size])
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class TFSegformerAttention(keras.layers.Layer):
|
| 255 |
+
def __init__(
|
| 256 |
+
self,
|
| 257 |
+
config: SegformerConfig,
|
| 258 |
+
hidden_size: int,
|
| 259 |
+
num_attention_heads: int,
|
| 260 |
+
sequence_reduction_ratio: int,
|
| 261 |
+
**kwargs,
|
| 262 |
+
):
|
| 263 |
+
super().__init__(**kwargs)
|
| 264 |
+
self.self = TFSegformerEfficientSelfAttention(
|
| 265 |
+
config=config,
|
| 266 |
+
hidden_size=hidden_size,
|
| 267 |
+
num_attention_heads=num_attention_heads,
|
| 268 |
+
sequence_reduction_ratio=sequence_reduction_ratio,
|
| 269 |
+
name="self",
|
| 270 |
+
)
|
| 271 |
+
self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output")
|
| 272 |
+
|
| 273 |
+
def call(
|
| 274 |
+
self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False
|
| 275 |
+
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
|
| 276 |
+
self_outputs = self.self(hidden_states, height, width, output_attentions)
|
| 277 |
+
|
| 278 |
+
attention_output = self.dense_output(self_outputs[0])
|
| 279 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 280 |
+
return outputs
|
| 281 |
+
|
| 282 |
+
def build(self, input_shape=None):
|
| 283 |
+
if self.built:
|
| 284 |
+
return
|
| 285 |
+
self.built = True
|
| 286 |
+
if getattr(self, "self", None) is not None:
|
| 287 |
+
with tf.name_scope(self.self.name):
|
| 288 |
+
self.self.build(None)
|
| 289 |
+
if getattr(self, "dense_output", None) is not None:
|
| 290 |
+
with tf.name_scope(self.dense_output.name):
|
| 291 |
+
self.dense_output.build(None)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class TFSegformerDWConv(keras.layers.Layer):
|
| 295 |
+
def __init__(self, dim: int = 768, **kwargs):
|
| 296 |
+
super().__init__(**kwargs)
|
| 297 |
+
self.depthwise_convolution = keras.layers.Conv2D(
|
| 298 |
+
filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv"
|
| 299 |
+
)
|
| 300 |
+
self.dim = dim
|
| 301 |
+
|
| 302 |
+
def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
|
| 303 |
+
batch_size = shape_list(hidden_states)[0]
|
| 304 |
+
num_channels = shape_list(hidden_states)[-1]
|
| 305 |
+
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
|
| 306 |
+
hidden_states = self.depthwise_convolution(hidden_states)
|
| 307 |
+
|
| 308 |
+
new_height = shape_list(hidden_states)[1]
|
| 309 |
+
new_width = shape_list(hidden_states)[2]
|
| 310 |
+
num_channels = shape_list(hidden_states)[3]
|
| 311 |
+
hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels))
|
| 312 |
+
return hidden_states
|
| 313 |
+
|
| 314 |
+
def build(self, input_shape=None):
|
| 315 |
+
if self.built:
|
| 316 |
+
return
|
| 317 |
+
self.built = True
|
| 318 |
+
if getattr(self, "depthwise_convolution", None) is not None:
|
| 319 |
+
with tf.name_scope(self.depthwise_convolution.name):
|
| 320 |
+
self.depthwise_convolution.build([None, None, None, self.dim])
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class TFSegformerMixFFN(keras.layers.Layer):
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
config: SegformerConfig,
|
| 327 |
+
in_features: int,
|
| 328 |
+
hidden_features: Optional[int] = None,
|
| 329 |
+
out_features: Optional[int] = None,
|
| 330 |
+
**kwargs,
|
| 331 |
+
):
|
| 332 |
+
super().__init__(**kwargs)
|
| 333 |
+
out_features = out_features or in_features
|
| 334 |
+
self.dense1 = keras.layers.Dense(hidden_features, name="dense1")
|
| 335 |
+
self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv")
|
| 336 |
+
if isinstance(config.hidden_act, str):
|
| 337 |
+
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
| 338 |
+
else:
|
| 339 |
+
self.intermediate_act_fn = config.hidden_act
|
| 340 |
+
self.dense2 = keras.layers.Dense(out_features, name="dense2")
|
| 341 |
+
self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
|
| 342 |
+
self.hidden_features = hidden_features
|
| 343 |
+
self.in_features = in_features
|
| 344 |
+
|
| 345 |
+
def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
|
| 346 |
+
hidden_states = self.dense1(hidden_states)
|
| 347 |
+
hidden_states = self.depthwise_convolution(hidden_states, height, width)
|
| 348 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 349 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 350 |
+
hidden_states = self.dense2(hidden_states)
|
| 351 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 352 |
+
return hidden_states
|
| 353 |
+
|
| 354 |
+
def build(self, input_shape=None):
|
| 355 |
+
if self.built:
|
| 356 |
+
return
|
| 357 |
+
self.built = True
|
| 358 |
+
if getattr(self, "dense1", None) is not None:
|
| 359 |
+
with tf.name_scope(self.dense1.name):
|
| 360 |
+
self.dense1.build([None, None, self.in_features])
|
| 361 |
+
if getattr(self, "depthwise_convolution", None) is not None:
|
| 362 |
+
with tf.name_scope(self.depthwise_convolution.name):
|
| 363 |
+
self.depthwise_convolution.build(None)
|
| 364 |
+
if getattr(self, "dense2", None) is not None:
|
| 365 |
+
with tf.name_scope(self.dense2.name):
|
| 366 |
+
self.dense2.build([None, None, self.hidden_features])
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class TFSegformerLayer(keras.layers.Layer):
|
| 370 |
+
"""This corresponds to the Block class in the original implementation."""
|
| 371 |
+
|
| 372 |
+
def __init__(
|
| 373 |
+
self,
|
| 374 |
+
config,
|
| 375 |
+
hidden_size: int,
|
| 376 |
+
num_attention_heads: int,
|
| 377 |
+
drop_path: float,
|
| 378 |
+
sequence_reduction_ratio: int,
|
| 379 |
+
mlp_ratio: int,
|
| 380 |
+
**kwargs,
|
| 381 |
+
):
|
| 382 |
+
super().__init__(**kwargs)
|
| 383 |
+
self.layer_norm_1 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1")
|
| 384 |
+
self.attention = TFSegformerAttention(
|
| 385 |
+
config,
|
| 386 |
+
hidden_size=hidden_size,
|
| 387 |
+
num_attention_heads=num_attention_heads,
|
| 388 |
+
sequence_reduction_ratio=sequence_reduction_ratio,
|
| 389 |
+
name="attention",
|
| 390 |
+
)
|
| 391 |
+
self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else keras.layers.Activation("linear")
|
| 392 |
+
self.layer_norm_2 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2")
|
| 393 |
+
mlp_hidden_size = int(hidden_size * mlp_ratio)
|
| 394 |
+
self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp")
|
| 395 |
+
self.hidden_size = hidden_size
|
| 396 |
+
|
| 397 |
+
def call(
|
| 398 |
+
self,
|
| 399 |
+
hidden_states: tf.Tensor,
|
| 400 |
+
height: int,
|
| 401 |
+
width: int,
|
| 402 |
+
output_attentions: bool = False,
|
| 403 |
+
training: bool = False,
|
| 404 |
+
) -> Tuple:
|
| 405 |
+
self_attention_outputs = self.attention(
|
| 406 |
+
self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
|
| 407 |
+
height,
|
| 408 |
+
width,
|
| 409 |
+
output_attentions=output_attentions,
|
| 410 |
+
training=training,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
attention_output = self_attention_outputs[0]
|
| 414 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 415 |
+
|
| 416 |
+
# first residual connection (with stochastic depth)
|
| 417 |
+
attention_output = self.drop_path(attention_output, training=training)
|
| 418 |
+
hidden_states = attention_output + hidden_states
|
| 419 |
+
mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
|
| 420 |
+
|
| 421 |
+
# second residual connection (with stochastic depth)
|
| 422 |
+
mlp_output = self.drop_path(mlp_output, training=training)
|
| 423 |
+
layer_output = mlp_output + hidden_states
|
| 424 |
+
|
| 425 |
+
outputs = (layer_output,) + outputs
|
| 426 |
+
|
| 427 |
+
return outputs
|
| 428 |
+
|
| 429 |
+
def build(self, input_shape=None):
|
| 430 |
+
if self.built:
|
| 431 |
+
return
|
| 432 |
+
self.built = True
|
| 433 |
+
if getattr(self, "layer_norm_1", None) is not None:
|
| 434 |
+
with tf.name_scope(self.layer_norm_1.name):
|
| 435 |
+
self.layer_norm_1.build([None, None, self.hidden_size])
|
| 436 |
+
if getattr(self, "attention", None) is not None:
|
| 437 |
+
with tf.name_scope(self.attention.name):
|
| 438 |
+
self.attention.build(None)
|
| 439 |
+
if getattr(self, "layer_norm_2", None) is not None:
|
| 440 |
+
with tf.name_scope(self.layer_norm_2.name):
|
| 441 |
+
self.layer_norm_2.build([None, None, self.hidden_size])
|
| 442 |
+
if getattr(self, "mlp", None) is not None:
|
| 443 |
+
with tf.name_scope(self.mlp.name):
|
| 444 |
+
self.mlp.build(None)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class TFSegformerEncoder(keras.layers.Layer):
|
| 448 |
+
def __init__(self, config: SegformerConfig, **kwargs):
|
| 449 |
+
super().__init__(**kwargs)
|
| 450 |
+
self.config = config
|
| 451 |
+
|
| 452 |
+
# stochastic depth decay rule
|
| 453 |
+
drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
|
| 454 |
+
|
| 455 |
+
# patch embeddings
|
| 456 |
+
embeddings = []
|
| 457 |
+
for i in range(config.num_encoder_blocks):
|
| 458 |
+
embeddings.append(
|
| 459 |
+
TFSegformerOverlapPatchEmbeddings(
|
| 460 |
+
patch_size=config.patch_sizes[i],
|
| 461 |
+
stride=config.strides[i],
|
| 462 |
+
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
|
| 463 |
+
hidden_size=config.hidden_sizes[i],
|
| 464 |
+
name=f"patch_embeddings.{i}",
|
| 465 |
+
)
|
| 466 |
+
)
|
| 467 |
+
self.embeddings = embeddings
|
| 468 |
+
|
| 469 |
+
# Transformer blocks
|
| 470 |
+
blocks = []
|
| 471 |
+
cur = 0
|
| 472 |
+
for i in range(config.num_encoder_blocks):
|
| 473 |
+
# each block consists of layers
|
| 474 |
+
layers = []
|
| 475 |
+
if i != 0:
|
| 476 |
+
cur += config.depths[i - 1]
|
| 477 |
+
for j in range(config.depths[i]):
|
| 478 |
+
layers.append(
|
| 479 |
+
TFSegformerLayer(
|
| 480 |
+
config,
|
| 481 |
+
hidden_size=config.hidden_sizes[i],
|
| 482 |
+
num_attention_heads=config.num_attention_heads[i],
|
| 483 |
+
drop_path=drop_path_decays[cur + j],
|
| 484 |
+
sequence_reduction_ratio=config.sr_ratios[i],
|
| 485 |
+
mlp_ratio=config.mlp_ratios[i],
|
| 486 |
+
name=f"block.{i}.{j}",
|
| 487 |
+
)
|
| 488 |
+
)
|
| 489 |
+
blocks.append(layers)
|
| 490 |
+
|
| 491 |
+
self.block = blocks
|
| 492 |
+
|
| 493 |
+
# Layer norms
|
| 494 |
+
self.layer_norms = [
|
| 495 |
+
keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}")
|
| 496 |
+
for i in range(config.num_encoder_blocks)
|
| 497 |
+
]
|
| 498 |
+
|
| 499 |
+
def call(
|
| 500 |
+
self,
|
| 501 |
+
pixel_values: tf.Tensor,
|
| 502 |
+
output_attentions: Optional[bool] = False,
|
| 503 |
+
output_hidden_states: Optional[bool] = False,
|
| 504 |
+
return_dict: Optional[bool] = True,
|
| 505 |
+
training: bool = False,
|
| 506 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 507 |
+
all_hidden_states = () if output_hidden_states else None
|
| 508 |
+
all_self_attentions = () if output_attentions else None
|
| 509 |
+
|
| 510 |
+
batch_size = shape_list(pixel_values)[0]
|
| 511 |
+
|
| 512 |
+
hidden_states = pixel_values
|
| 513 |
+
for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
|
| 514 |
+
embedding_layer, block_layer, norm_layer = x
|
| 515 |
+
# first, obtain patch embeddings
|
| 516 |
+
hidden_states, height, width = embedding_layer(hidden_states)
|
| 517 |
+
|
| 518 |
+
# second, send embeddings through blocks
|
| 519 |
+
# (each block consists of multiple layers i.e., list of layers)
|
| 520 |
+
for i, blk in enumerate(block_layer):
|
| 521 |
+
layer_outputs = blk(
|
| 522 |
+
hidden_states,
|
| 523 |
+
height,
|
| 524 |
+
width,
|
| 525 |
+
output_attentions,
|
| 526 |
+
training=training,
|
| 527 |
+
)
|
| 528 |
+
hidden_states = layer_outputs[0]
|
| 529 |
+
if output_attentions:
|
| 530 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 531 |
+
|
| 532 |
+
# third, apply layer norm
|
| 533 |
+
hidden_states = norm_layer(hidden_states)
|
| 534 |
+
|
| 535 |
+
# fourth, optionally reshape back to (batch_size, height, width, num_channels)
|
| 536 |
+
if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage):
|
| 537 |
+
num_channels = shape_list(hidden_states)[-1]
|
| 538 |
+
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
|
| 539 |
+
|
| 540 |
+
if output_hidden_states:
|
| 541 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 542 |
+
|
| 543 |
+
if not return_dict:
|
| 544 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 545 |
+
return TFBaseModelOutput(
|
| 546 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
def build(self, input_shape=None):
|
| 550 |
+
if self.built:
|
| 551 |
+
return
|
| 552 |
+
self.built = True
|
| 553 |
+
if getattr(self, "layer_norms", None) is not None:
|
| 554 |
+
for layer, shape in zip(self.layer_norms, self.config.hidden_sizes):
|
| 555 |
+
with tf.name_scope(layer.name):
|
| 556 |
+
layer.build([None, None, shape])
|
| 557 |
+
if getattr(self, "block", None) is not None:
|
| 558 |
+
for block in self.block:
|
| 559 |
+
for layer in block:
|
| 560 |
+
with tf.name_scope(layer.name):
|
| 561 |
+
layer.build(None)
|
| 562 |
+
if getattr(self, "embeddings", None) is not None:
|
| 563 |
+
for layer in self.embeddings:
|
| 564 |
+
with tf.name_scope(layer.name):
|
| 565 |
+
layer.build(None)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
@keras_serializable
|
| 569 |
+
class TFSegformerMainLayer(keras.layers.Layer):
|
| 570 |
+
config_class = SegformerConfig
|
| 571 |
+
|
| 572 |
+
def __init__(self, config: SegformerConfig, **kwargs):
|
| 573 |
+
super().__init__(**kwargs)
|
| 574 |
+
|
| 575 |
+
self.config = config
|
| 576 |
+
# hierarchical Transformer encoder
|
| 577 |
+
self.encoder = TFSegformerEncoder(config, name="encoder")
|
| 578 |
+
|
| 579 |
+
@unpack_inputs
|
| 580 |
+
def call(
|
| 581 |
+
self,
|
| 582 |
+
pixel_values: tf.Tensor,
|
| 583 |
+
output_attentions: Optional[bool] = None,
|
| 584 |
+
output_hidden_states: Optional[bool] = None,
|
| 585 |
+
return_dict: Optional[bool] = None,
|
| 586 |
+
training: bool = False,
|
| 587 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 588 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 589 |
+
output_hidden_states = (
|
| 590 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 591 |
+
)
|
| 592 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 593 |
+
|
| 594 |
+
# When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
|
| 595 |
+
# So change the input format from `NCHW` to `NHWC`.
|
| 596 |
+
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
| 597 |
+
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
| 598 |
+
|
| 599 |
+
encoder_outputs = self.encoder(
|
| 600 |
+
pixel_values,
|
| 601 |
+
output_attentions=output_attentions,
|
| 602 |
+
output_hidden_states=output_hidden_states,
|
| 603 |
+
return_dict=return_dict,
|
| 604 |
+
training=training,
|
| 605 |
+
)
|
| 606 |
+
sequence_output = encoder_outputs[0]
|
| 607 |
+
# Change to NCHW output format to have uniformity in the modules
|
| 608 |
+
sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])
|
| 609 |
+
|
| 610 |
+
# Change the other hidden state outputs to NCHW as well
|
| 611 |
+
if output_hidden_states:
|
| 612 |
+
hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
|
| 613 |
+
|
| 614 |
+
if not return_dict:
|
| 615 |
+
if tf.greater(len(encoder_outputs[1:]), 0):
|
| 616 |
+
transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0])
|
| 617 |
+
return (sequence_output,) + (transposed_encoder_outputs,)
|
| 618 |
+
else:
|
| 619 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 620 |
+
|
| 621 |
+
return TFBaseModelOutput(
|
| 622 |
+
last_hidden_state=sequence_output,
|
| 623 |
+
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
|
| 624 |
+
attentions=encoder_outputs.attentions,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
def build(self, input_shape=None):
|
| 628 |
+
if self.built:
|
| 629 |
+
return
|
| 630 |
+
self.built = True
|
| 631 |
+
if getattr(self, "encoder", None) is not None:
|
| 632 |
+
with tf.name_scope(self.encoder.name):
|
| 633 |
+
self.encoder.build(None)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
class TFSegformerPreTrainedModel(TFPreTrainedModel):
|
| 637 |
+
"""
|
| 638 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 639 |
+
models.
|
| 640 |
+
"""
|
| 641 |
+
|
| 642 |
+
config_class = SegformerConfig
|
| 643 |
+
base_model_prefix = "segformer"
|
| 644 |
+
main_input_name = "pixel_values"
|
| 645 |
+
|
| 646 |
+
@property
|
| 647 |
+
def input_signature(self):
|
| 648 |
+
return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 512, 512), dtype=tf.float32)}
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
SEGFORMER_START_DOCSTRING = r"""
|
| 652 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 653 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 654 |
+
etc.)
|
| 655 |
+
|
| 656 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 657 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 658 |
+
behavior.
|
| 659 |
+
|
| 660 |
+
Parameters:
|
| 661 |
+
config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
|
| 662 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 663 |
+
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 664 |
+
"""
|
| 665 |
+
|
| 666 |
+
SEGFORMER_INPUTS_DOCSTRING = r"""
|
| 667 |
+
|
| 668 |
+
Args:
|
| 669 |
+
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
|
| 670 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 671 |
+
[`SegformerImageProcessor.__call__`] for details.
|
| 672 |
+
|
| 673 |
+
output_attentions (`bool`, *optional*):
|
| 674 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 675 |
+
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
| 676 |
+
config will be used instead.
|
| 677 |
+
|
| 678 |
+
output_hidden_states (`bool`, *optional*):
|
| 679 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 680 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 681 |
+
used instead.
|
| 682 |
+
|
| 683 |
+
return_dict (`bool`, *optional*):
|
| 684 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
|
| 685 |
+
eager mode, in graph mode the value will always be set to True.
|
| 686 |
+
|
| 687 |
+
training (`bool`, *optional*, defaults to `False``):
|
| 688 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 689 |
+
behaviors between training and evaluation).
|
| 690 |
+
"""
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
@add_start_docstrings(
|
| 694 |
+
"The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
|
| 695 |
+
SEGFORMER_START_DOCSTRING,
|
| 696 |
+
)
|
| 697 |
+
class TFSegformerModel(TFSegformerPreTrainedModel):
|
| 698 |
+
def __init__(self, config: SegformerConfig, *inputs, **kwargs):
|
| 699 |
+
super().__init__(config, *inputs, **kwargs)
|
| 700 |
+
self.config = config
|
| 701 |
+
|
| 702 |
+
# hierarchical Transformer encoder
|
| 703 |
+
self.segformer = TFSegformerMainLayer(config, name="segformer")
|
| 704 |
+
|
| 705 |
+
@unpack_inputs
|
| 706 |
+
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 707 |
+
@add_code_sample_docstrings(
|
| 708 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 709 |
+
output_type=TFBaseModelOutput,
|
| 710 |
+
config_class=_CONFIG_FOR_DOC,
|
| 711 |
+
modality="vision",
|
| 712 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 713 |
+
)
|
| 714 |
+
def call(
|
| 715 |
+
self,
|
| 716 |
+
pixel_values: tf.Tensor,
|
| 717 |
+
output_attentions: Optional[bool] = None,
|
| 718 |
+
output_hidden_states: Optional[bool] = None,
|
| 719 |
+
return_dict: Optional[bool] = None,
|
| 720 |
+
training: bool = False,
|
| 721 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 722 |
+
outputs = self.segformer(
|
| 723 |
+
pixel_values,
|
| 724 |
+
output_attentions=output_attentions,
|
| 725 |
+
output_hidden_states=output_hidden_states,
|
| 726 |
+
return_dict=return_dict,
|
| 727 |
+
training=training,
|
| 728 |
+
)
|
| 729 |
+
return outputs
|
| 730 |
+
|
| 731 |
+
def build(self, input_shape=None):
|
| 732 |
+
if self.built:
|
| 733 |
+
return
|
| 734 |
+
self.built = True
|
| 735 |
+
if getattr(self, "segformer", None) is not None:
|
| 736 |
+
with tf.name_scope(self.segformer.name):
|
| 737 |
+
self.segformer.build(None)
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
@add_start_docstrings(
|
| 741 |
+
"""
|
| 742 |
+
SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
|
| 743 |
+
states) e.g. for ImageNet.
|
| 744 |
+
""",
|
| 745 |
+
SEGFORMER_START_DOCSTRING,
|
| 746 |
+
)
|
| 747 |
+
class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss):
|
| 748 |
+
def __init__(self, config: SegformerConfig, *inputs, **kwargs):
|
| 749 |
+
super().__init__(config, *inputs, **kwargs)
|
| 750 |
+
|
| 751 |
+
self.num_labels = config.num_labels
|
| 752 |
+
self.segformer = TFSegformerMainLayer(config, name="segformer")
|
| 753 |
+
|
| 754 |
+
# Classifier head
|
| 755 |
+
self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
|
| 756 |
+
self.config = config
|
| 757 |
+
|
| 758 |
+
@unpack_inputs
|
| 759 |
+
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 760 |
+
@add_code_sample_docstrings(
|
| 761 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 762 |
+
output_type=TFSequenceClassifierOutput,
|
| 763 |
+
config_class=_CONFIG_FOR_DOC,
|
| 764 |
+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
| 765 |
+
)
|
| 766 |
+
def call(
|
| 767 |
+
self,
|
| 768 |
+
pixel_values: tf.Tensor | None = None,
|
| 769 |
+
labels: tf.Tensor | None = None,
|
| 770 |
+
output_attentions: Optional[bool] = None,
|
| 771 |
+
output_hidden_states: Optional[bool] = None,
|
| 772 |
+
return_dict: Optional[bool] = None,
|
| 773 |
+
) -> Union[Tuple, TFSequenceClassifierOutput]:
|
| 774 |
+
outputs = self.segformer(
|
| 775 |
+
pixel_values,
|
| 776 |
+
output_attentions=output_attentions,
|
| 777 |
+
output_hidden_states=output_hidden_states,
|
| 778 |
+
return_dict=return_dict,
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
sequence_output = outputs[0]
|
| 782 |
+
|
| 783 |
+
# convert last hidden states to (batch_size, height*width, hidden_size)
|
| 784 |
+
batch_size = shape_list(sequence_output)[0]
|
| 785 |
+
sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])
|
| 786 |
+
sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1]))
|
| 787 |
+
|
| 788 |
+
# global average pooling
|
| 789 |
+
sequence_output = tf.reduce_mean(sequence_output, axis=1)
|
| 790 |
+
|
| 791 |
+
logits = self.classifier(sequence_output)
|
| 792 |
+
|
| 793 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 794 |
+
|
| 795 |
+
if not return_dict:
|
| 796 |
+
output = (logits,) + outputs[1:]
|
| 797 |
+
return ((loss,) + output) if loss is not None else output
|
| 798 |
+
|
| 799 |
+
return TFSequenceClassifierOutput(
|
| 800 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
def build(self, input_shape=None):
|
| 804 |
+
if self.built:
|
| 805 |
+
return
|
| 806 |
+
self.built = True
|
| 807 |
+
if getattr(self, "segformer", None) is not None:
|
| 808 |
+
with tf.name_scope(self.segformer.name):
|
| 809 |
+
self.segformer.build(None)
|
| 810 |
+
if getattr(self, "classifier", None) is not None:
|
| 811 |
+
with tf.name_scope(self.classifier.name):
|
| 812 |
+
self.classifier.build([None, None, self.config.hidden_sizes[-1]])
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
class TFSegformerMLP(keras.layers.Layer):
|
| 816 |
+
"""
|
| 817 |
+
Linear Embedding.
|
| 818 |
+
"""
|
| 819 |
+
|
| 820 |
+
def __init__(self, input_dim: int, config: SegformerConfig, **kwargs):
|
| 821 |
+
super().__init__(**kwargs)
|
| 822 |
+
self.proj = keras.layers.Dense(config.decoder_hidden_size, name="proj")
|
| 823 |
+
self.input_dim = input_dim
|
| 824 |
+
|
| 825 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 826 |
+
height = shape_list(hidden_states)[1]
|
| 827 |
+
width = shape_list(hidden_states)[2]
|
| 828 |
+
hidden_dim = shape_list(hidden_states)[-1]
|
| 829 |
+
hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim))
|
| 830 |
+
hidden_states = self.proj(hidden_states)
|
| 831 |
+
return hidden_states
|
| 832 |
+
|
| 833 |
+
def build(self, input_shape=None):
|
| 834 |
+
if self.built:
|
| 835 |
+
return
|
| 836 |
+
self.built = True
|
| 837 |
+
if getattr(self, "proj", None) is not None:
|
| 838 |
+
with tf.name_scope(self.proj.name):
|
| 839 |
+
self.proj.build([None, None, self.input_dim])
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
|
| 843 |
+
def __init__(self, config: SegformerConfig, **kwargs):
|
| 844 |
+
super().__init__(config, **kwargs)
|
| 845 |
+
# linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
|
| 846 |
+
mlps = []
|
| 847 |
+
for i in range(config.num_encoder_blocks):
|
| 848 |
+
mlp = TFSegformerMLP(config=config, input_dim=config.hidden_sizes[i], name=f"linear_c.{i}")
|
| 849 |
+
mlps.append(mlp)
|
| 850 |
+
self.mlps = mlps
|
| 851 |
+
|
| 852 |
+
# the following 3 layers implement the ConvModule of the original implementation
|
| 853 |
+
self.linear_fuse = keras.layers.Conv2D(
|
| 854 |
+
filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse"
|
| 855 |
+
)
|
| 856 |
+
self.batch_norm = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="batch_norm")
|
| 857 |
+
self.activation = keras.layers.Activation("relu")
|
| 858 |
+
|
| 859 |
+
self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
|
| 860 |
+
self.classifier = keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier")
|
| 861 |
+
|
| 862 |
+
self.config = config
|
| 863 |
+
|
| 864 |
+
def call(self, encoder_hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 865 |
+
all_hidden_states = ()
|
| 866 |
+
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
|
| 867 |
+
if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
|
| 868 |
+
height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
|
| 869 |
+
height = width = tf.cast(height, tf.int32)
|
| 870 |
+
channel_dim = shape_list(encoder_hidden_state)[-1]
|
| 871 |
+
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
|
| 872 |
+
|
| 873 |
+
# unify channel dimension
|
| 874 |
+
encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
|
| 875 |
+
height, width = shape_list(encoder_hidden_state)[1:3]
|
| 876 |
+
encoder_hidden_state = mlp(encoder_hidden_state)
|
| 877 |
+
channel_dim = shape_list(encoder_hidden_state)[-1]
|
| 878 |
+
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
|
| 879 |
+
|
| 880 |
+
# upsample
|
| 881 |
+
temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
|
| 882 |
+
upsample_resolution = shape_list(temp_state)[1:-1]
|
| 883 |
+
encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear")
|
| 884 |
+
all_hidden_states += (encoder_hidden_state,)
|
| 885 |
+
|
| 886 |
+
hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
|
| 887 |
+
hidden_states = self.batch_norm(hidden_states, training=training)
|
| 888 |
+
hidden_states = self.activation(hidden_states)
|
| 889 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 890 |
+
|
| 891 |
+
# logits of shape (batch_size, height/4, width/4, num_labels)
|
| 892 |
+
logits = self.classifier(hidden_states)
|
| 893 |
+
|
| 894 |
+
return logits
|
| 895 |
+
|
| 896 |
+
def build(self, input_shape=None):
|
| 897 |
+
if self.built:
|
| 898 |
+
return
|
| 899 |
+
self.built = True
|
| 900 |
+
if getattr(self, "linear_fuse", None) is not None:
|
| 901 |
+
with tf.name_scope(self.linear_fuse.name):
|
| 902 |
+
self.linear_fuse.build(
|
| 903 |
+
[None, None, None, self.config.decoder_hidden_size * self.config.num_encoder_blocks]
|
| 904 |
+
)
|
| 905 |
+
if getattr(self, "batch_norm", None) is not None:
|
| 906 |
+
with tf.name_scope(self.batch_norm.name):
|
| 907 |
+
self.batch_norm.build([None, None, None, self.config.decoder_hidden_size])
|
| 908 |
+
if getattr(self, "classifier", None) is not None:
|
| 909 |
+
with tf.name_scope(self.classifier.name):
|
| 910 |
+
self.classifier.build([None, None, None, self.config.decoder_hidden_size])
|
| 911 |
+
if getattr(self, "mlps", None) is not None:
|
| 912 |
+
for layer in self.mlps:
|
| 913 |
+
with tf.name_scope(layer.name):
|
| 914 |
+
layer.build(None)
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
@add_start_docstrings(
|
| 918 |
+
"""SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
|
| 919 |
+
SEGFORMER_START_DOCSTRING,
|
| 920 |
+
)
|
| 921 |
+
class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
|
| 922 |
+
def __init__(self, config: SegformerConfig, **kwargs):
|
| 923 |
+
super().__init__(config, **kwargs)
|
| 924 |
+
self.segformer = TFSegformerMainLayer(config, name="segformer")
|
| 925 |
+
self.decode_head = TFSegformerDecodeHead(config, name="decode_head")
|
| 926 |
+
|
| 927 |
+
def hf_compute_loss(self, logits, labels):
|
| 928 |
+
# upsample logits to the images' original size
|
| 929 |
+
# `labels` is of shape (batch_size, height, width)
|
| 930 |
+
label_interp_shape = shape_list(labels)[1:]
|
| 931 |
+
|
| 932 |
+
upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
|
| 933 |
+
# compute weighted loss
|
| 934 |
+
loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
|
| 935 |
+
|
| 936 |
+
def masked_loss(real, pred):
|
| 937 |
+
unmasked_loss = loss_fct(real, pred)
|
| 938 |
+
mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
|
| 939 |
+
masked_loss = unmasked_loss * mask
|
| 940 |
+
# Reduction strategy in the similar spirit with
|
| 941 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
|
| 942 |
+
reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
|
| 943 |
+
return tf.reshape(reduced_masked_loss, (1,))
|
| 944 |
+
|
| 945 |
+
return masked_loss(labels, upsampled_logits)
|
| 946 |
+
|
| 947 |
+
@unpack_inputs
|
| 948 |
+
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 949 |
+
@replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
|
| 950 |
+
def call(
|
| 951 |
+
self,
|
| 952 |
+
pixel_values: tf.Tensor,
|
| 953 |
+
labels: tf.Tensor | None = None,
|
| 954 |
+
output_attentions: Optional[bool] = None,
|
| 955 |
+
output_hidden_states: Optional[bool] = None,
|
| 956 |
+
return_dict: Optional[bool] = None,
|
| 957 |
+
) -> Union[Tuple, TFSemanticSegmenterOutput]:
|
| 958 |
+
r"""
|
| 959 |
+
labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
|
| 960 |
+
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
| 961 |
+
config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed
|
| 962 |
+
(Cross-Entropy).
|
| 963 |
+
|
| 964 |
+
Returns:
|
| 965 |
+
|
| 966 |
+
Examples:
|
| 967 |
+
|
| 968 |
+
```python
|
| 969 |
+
>>> from transformers import AutoImageProcessor, TFSegformerForSemanticSegmentation
|
| 970 |
+
>>> from PIL import Image
|
| 971 |
+
>>> import requests
|
| 972 |
+
|
| 973 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 974 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 975 |
+
|
| 976 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
|
| 977 |
+
>>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
|
| 978 |
+
|
| 979 |
+
>>> inputs = image_processor(images=image, return_tensors="tf")
|
| 980 |
+
>>> outputs = model(**inputs, training=False)
|
| 981 |
+
>>> # logits are of shape (batch_size, num_labels, height/4, width/4)
|
| 982 |
+
>>> logits = outputs.logits
|
| 983 |
+
>>> list(logits.shape)
|
| 984 |
+
[1, 150, 128, 128]
|
| 985 |
+
```"""
|
| 986 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 987 |
+
output_hidden_states = (
|
| 988 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 989 |
+
)
|
| 990 |
+
|
| 991 |
+
if labels is not None and not self.config.num_labels > 1:
|
| 992 |
+
raise ValueError("The number of labels should be greater than one")
|
| 993 |
+
|
| 994 |
+
outputs = self.segformer(
|
| 995 |
+
pixel_values,
|
| 996 |
+
output_attentions=output_attentions,
|
| 997 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
| 998 |
+
return_dict=return_dict,
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
| 1002 |
+
|
| 1003 |
+
logits = self.decode_head(encoder_hidden_states)
|
| 1004 |
+
|
| 1005 |
+
loss = None
|
| 1006 |
+
if labels is not None:
|
| 1007 |
+
loss = self.hf_compute_loss(logits=logits, labels=labels)
|
| 1008 |
+
|
| 1009 |
+
# make logits of shape (batch_size, num_labels, height, width) to
|
| 1010 |
+
# keep them consistent across APIs
|
| 1011 |
+
logits = tf.transpose(logits, perm=[0, 3, 1, 2])
|
| 1012 |
+
|
| 1013 |
+
if not return_dict:
|
| 1014 |
+
if output_hidden_states:
|
| 1015 |
+
output = (logits,) + outputs[1:]
|
| 1016 |
+
else:
|
| 1017 |
+
output = (logits,) + outputs[2:]
|
| 1018 |
+
return ((loss,) + output) if loss is not None else output
|
| 1019 |
+
|
| 1020 |
+
return TFSemanticSegmenterOutput(
|
| 1021 |
+
loss=loss,
|
| 1022 |
+
logits=logits,
|
| 1023 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 1024 |
+
attentions=outputs.attentions,
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
def build(self, input_shape=None):
|
| 1028 |
+
if self.built:
|
| 1029 |
+
return
|
| 1030 |
+
self.built = True
|
| 1031 |
+
if getattr(self, "segformer", None) is not None:
|
| 1032 |
+
with tf.name_scope(self.segformer.name):
|
| 1033 |
+
self.segformer.build(None)
|
| 1034 |
+
if getattr(self, "decode_head", None) is not None:
|
| 1035 |
+
with tf.name_scope(self.decode_head.name):
|
| 1036 |
+
self.decode_head.build(None)
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
__all__ = [
|
| 1040 |
+
"TFSegformerDecodeHead",
|
| 1041 |
+
"TFSegformerForImageClassification",
|
| 1042 |
+
"TFSegformerForSemanticSegmentation",
|
| 1043 |
+
"TFSegformerModel",
|
| 1044 |
+
"TFSegformerPreTrainedModel",
|
| 1045 |
+
]
|
docs/transformers/build/lib/transformers/models/seggpt/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_seggpt import *
|
| 22 |
+
from .image_processing_seggpt import *
|
| 23 |
+
from .modeling_seggpt import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/seggpt/configuration_seggpt.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SegGpt model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SegGptConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`SegGptModel`]. It is used to instantiate a SegGPT
|
| 27 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 28 |
+
defaults will yield a similar configuration to that of the SegGPT
|
| 29 |
+
[BAAI/seggpt-vit-large](https://huggingface.co/BAAI/seggpt-vit-large) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
hidden_size (`int`, *optional*, defaults to 1024):
|
| 36 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 37 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
| 38 |
+
Number of hidden layers in the Transformer encoder.
|
| 39 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
| 40 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 42 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 43 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 44 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 45 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 46 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 47 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 48 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 49 |
+
The epsilon used by the layer normalization layers.
|
| 50 |
+
image_size (`List[int]`, *optional*, defaults to `[896, 448]`):
|
| 51 |
+
The size (resolution) of each image.
|
| 52 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 53 |
+
The size (resolution) of each patch.
|
| 54 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 55 |
+
The number of input channels.
|
| 56 |
+
qkv_bias (`bool`, *optional*, defaults to `True`):
|
| 57 |
+
Whether to add a bias to the queries, keys and values.
|
| 58 |
+
mlp_dim (`int`, *optional*):
|
| 59 |
+
The dimensionality of the MLP layer in the Transformer encoder. If unset, defaults to
|
| 60 |
+
`hidden_size` * 4.
|
| 61 |
+
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
| 62 |
+
The drop path rate for the dropout layers.
|
| 63 |
+
pretrain_image_size (`int`, *optional*, defaults to 224):
|
| 64 |
+
The pretrained size of the absolute position embeddings.
|
| 65 |
+
decoder_hidden_size (`int`, *optional*, defaults to 64):
|
| 66 |
+
Hidden size for decoder.
|
| 67 |
+
use_relative_position_embeddings (`bool`, *optional*, defaults to `True`):
|
| 68 |
+
Whether to use relative position embeddings in the attention layers.
|
| 69 |
+
merge_index (`int`, *optional*, defaults to 2):
|
| 70 |
+
The index of the encoder layer to merge the embeddings.
|
| 71 |
+
intermediate_hidden_state_indices (`List[int]`, *optional*, defaults to `[5, 11, 17, 23]`):
|
| 72 |
+
The indices of the encoder layers which we store as features for the decoder.
|
| 73 |
+
beta (`float`, *optional*, defaults to 0.01):
|
| 74 |
+
Regularization factor for SegGptLoss (smooth-l1 loss).
|
| 75 |
+
|
| 76 |
+
Example:
|
| 77 |
+
|
| 78 |
+
```python
|
| 79 |
+
>>> from transformers import SegGptConfig, SegGptModel
|
| 80 |
+
|
| 81 |
+
>>> # Initializing a SegGPT seggpt-vit-large style configuration
|
| 82 |
+
>>> configuration = SegGptConfig()
|
| 83 |
+
|
| 84 |
+
>>> # Initializing a model (with random weights) from the seggpt-vit-large style configuration
|
| 85 |
+
>>> model = SegGptModel(configuration)
|
| 86 |
+
|
| 87 |
+
>>> # Accessing the model configuration
|
| 88 |
+
>>> configuration = model.config
|
| 89 |
+
```"""
|
| 90 |
+
|
| 91 |
+
model_type = "seggpt"
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
hidden_size=1024,
|
| 96 |
+
num_hidden_layers=24,
|
| 97 |
+
num_attention_heads=16,
|
| 98 |
+
hidden_act="gelu",
|
| 99 |
+
hidden_dropout_prob=0.0,
|
| 100 |
+
initializer_range=0.02,
|
| 101 |
+
layer_norm_eps=1e-6,
|
| 102 |
+
image_size=[896, 448],
|
| 103 |
+
patch_size=16,
|
| 104 |
+
num_channels=3,
|
| 105 |
+
qkv_bias=True,
|
| 106 |
+
mlp_dim=None,
|
| 107 |
+
drop_path_rate=0.1,
|
| 108 |
+
pretrain_image_size=224,
|
| 109 |
+
decoder_hidden_size=64,
|
| 110 |
+
use_relative_position_embeddings=True,
|
| 111 |
+
merge_index=2,
|
| 112 |
+
intermediate_hidden_state_indices=[5, 11, 17, 23],
|
| 113 |
+
beta=0.01,
|
| 114 |
+
**kwargs,
|
| 115 |
+
):
|
| 116 |
+
super().__init__(**kwargs)
|
| 117 |
+
|
| 118 |
+
if merge_index > min(intermediate_hidden_state_indices):
|
| 119 |
+
raise ValueError(
|
| 120 |
+
f"Merge index must be less than the minimum encoder output index, but got {merge_index=} and {intermediate_hidden_state_indices=}"
|
| 121 |
+
)
|
| 122 |
+
self.hidden_size = hidden_size
|
| 123 |
+
self.num_hidden_layers = num_hidden_layers
|
| 124 |
+
self.num_attention_heads = num_attention_heads
|
| 125 |
+
self.hidden_act = hidden_act
|
| 126 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 127 |
+
self.initializer_range = initializer_range
|
| 128 |
+
self.layer_norm_eps = layer_norm_eps
|
| 129 |
+
self.image_size = image_size
|
| 130 |
+
self.patch_size = patch_size
|
| 131 |
+
self.num_channels = num_channels
|
| 132 |
+
self.qkv_bias = qkv_bias
|
| 133 |
+
self.drop_path_rate = drop_path_rate
|
| 134 |
+
self.pretrain_image_size = pretrain_image_size
|
| 135 |
+
self.decoder_hidden_size = decoder_hidden_size
|
| 136 |
+
self.use_relative_position_embeddings = use_relative_position_embeddings
|
| 137 |
+
self.merge_index = merge_index
|
| 138 |
+
self.intermediate_hidden_state_indices = intermediate_hidden_state_indices
|
| 139 |
+
self.beta = beta
|
| 140 |
+
self.mlp_dim = int(hidden_size * 4) if mlp_dim is None else mlp_dim
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
__all__ = ["SegGptConfig"]
|
docs/transformers/build/lib/transformers/models/seggpt/convert_seggpt_to_hf.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert SegGPT checkpoints from the original repository.
|
| 16 |
+
|
| 17 |
+
URL: https://github.com/baaivision/Painter/tree/main/SegGPT
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
|
| 22 |
+
import requests
|
| 23 |
+
import torch
|
| 24 |
+
from PIL import Image
|
| 25 |
+
|
| 26 |
+
from transformers import SegGptConfig, SegGptForImageSegmentation, SegGptImageProcessor
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logging.set_verbosity_info()
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# here we list all keys to be renamed (original name on the left, our name on the right)
|
| 35 |
+
def create_rename_keys(config):
|
| 36 |
+
rename_keys = []
|
| 37 |
+
|
| 38 |
+
# fmt: off
|
| 39 |
+
|
| 40 |
+
# rename embedding and its parameters
|
| 41 |
+
rename_keys.append(("patch_embed.proj.weight", "model.embeddings.patch_embeddings.projection.weight"))
|
| 42 |
+
rename_keys.append(("patch_embed.proj.bias", "model.embeddings.patch_embeddings.projection.bias"))
|
| 43 |
+
rename_keys.append(("mask_token", "model.embeddings.mask_token"))
|
| 44 |
+
rename_keys.append(("segment_token_x", "model.embeddings.segment_token_input"))
|
| 45 |
+
rename_keys.append(("segment_token_y", "model.embeddings.segment_token_prompt"))
|
| 46 |
+
rename_keys.append(("type_token_cls", "model.embeddings.type_token_semantic"))
|
| 47 |
+
rename_keys.append(("type_token_ins", "model.embeddings.type_token_instance"))
|
| 48 |
+
rename_keys.append(("pos_embed", "model.embeddings.position_embeddings"))
|
| 49 |
+
|
| 50 |
+
# rename decoder and other
|
| 51 |
+
rename_keys.append(("norm.weight", "model.encoder.layernorm.weight"))
|
| 52 |
+
rename_keys.append(("norm.bias", "model.encoder.layernorm.bias"))
|
| 53 |
+
rename_keys.append(("decoder_embed.weight", "decoder.decoder_embed.weight"))
|
| 54 |
+
rename_keys.append(("decoder_embed.bias", "decoder.decoder_embed.bias"))
|
| 55 |
+
rename_keys.append(("decoder_pred.0.weight", "decoder.decoder_pred.conv.weight"))
|
| 56 |
+
rename_keys.append(("decoder_pred.0.bias", "decoder.decoder_pred.conv.bias"))
|
| 57 |
+
rename_keys.append(("decoder_pred.1.weight", "decoder.decoder_pred.layernorm.weight"))
|
| 58 |
+
rename_keys.append(("decoder_pred.1.bias", "decoder.decoder_pred.layernorm.bias"))
|
| 59 |
+
rename_keys.append(("decoder_pred.3.weight", "decoder.decoder_pred.head.weight"))
|
| 60 |
+
rename_keys.append(("decoder_pred.3.bias", "decoder.decoder_pred.head.bias"))
|
| 61 |
+
|
| 62 |
+
# rename blocks
|
| 63 |
+
for i in range(config.num_hidden_layers):
|
| 64 |
+
rename_keys.append((f"blocks.{i}.attn.qkv.weight", f"model.encoder.layers.{i}.attention.qkv.weight"))
|
| 65 |
+
rename_keys.append((f"blocks.{i}.attn.qkv.bias", f"model.encoder.layers.{i}.attention.qkv.bias"))
|
| 66 |
+
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"model.encoder.layers.{i}.attention.proj.weight"))
|
| 67 |
+
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"model.encoder.layers.{i}.attention.proj.bias"))
|
| 68 |
+
rename_keys.append((f"blocks.{i}.attn.rel_pos_h", f"model.encoder.layers.{i}.attention.rel_pos_h"))
|
| 69 |
+
rename_keys.append((f"blocks.{i}.attn.rel_pos_w", f"model.encoder.layers.{i}.attention.rel_pos_w"))
|
| 70 |
+
|
| 71 |
+
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"model.encoder.layers.{i}.mlp.lin1.weight"))
|
| 72 |
+
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"model.encoder.layers.{i}.mlp.lin1.bias"))
|
| 73 |
+
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"model.encoder.layers.{i}.mlp.lin2.weight"))
|
| 74 |
+
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"model.encoder.layers.{i}.mlp.lin2.bias"))
|
| 75 |
+
|
| 76 |
+
rename_keys.append((f"blocks.{i}.norm1.weight", f"model.encoder.layers.{i}.layernorm_before.weight"))
|
| 77 |
+
rename_keys.append((f"blocks.{i}.norm1.bias", f"model.encoder.layers.{i}.layernorm_before.bias"))
|
| 78 |
+
rename_keys.append((f"blocks.{i}.norm2.weight", f"model.encoder.layers.{i}.layernorm_after.weight"))
|
| 79 |
+
rename_keys.append((f"blocks.{i}.norm2.bias", f"model.encoder.layers.{i}.layernorm_after.bias"))
|
| 80 |
+
|
| 81 |
+
# fmt: on
|
| 82 |
+
|
| 83 |
+
return rename_keys
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def rename_key(dct, old, new):
|
| 87 |
+
val = dct.pop(old)
|
| 88 |
+
dct[new] = val
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# We will verify our results on spongebob images
|
| 92 |
+
def prepare_input():
|
| 93 |
+
image_input_url = (
|
| 94 |
+
"https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
|
| 95 |
+
)
|
| 96 |
+
image_prompt_url = (
|
| 97 |
+
"https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
|
| 98 |
+
)
|
| 99 |
+
mask_prompt_url = (
|
| 100 |
+
"https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
image_input = Image.open(requests.get(image_input_url, stream=True).raw)
|
| 104 |
+
image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
|
| 105 |
+
mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw)
|
| 106 |
+
|
| 107 |
+
return image_input, image_prompt, mask_prompt
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
|
| 111 |
+
def convert_seggpt_checkpoint(args):
|
| 112 |
+
model_name = args.model_name
|
| 113 |
+
pytorch_dump_folder_path = args.pytorch_dump_folder_path
|
| 114 |
+
verify_logits = args.verify_logits
|
| 115 |
+
push_to_hub = args.push_to_hub
|
| 116 |
+
|
| 117 |
+
# Define default GroundingDINO configuation
|
| 118 |
+
config = SegGptConfig()
|
| 119 |
+
|
| 120 |
+
# Load original checkpoint
|
| 121 |
+
checkpoint_url = "https://huggingface.co/BAAI/SegGpt/blob/main/seggpt_vit_large.pth"
|
| 122 |
+
original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]
|
| 123 |
+
|
| 124 |
+
# # Rename keys
|
| 125 |
+
new_state_dict = original_state_dict.copy()
|
| 126 |
+
rename_keys = create_rename_keys(config)
|
| 127 |
+
|
| 128 |
+
for src, dest in rename_keys:
|
| 129 |
+
rename_key(new_state_dict, src, dest)
|
| 130 |
+
|
| 131 |
+
# Load HF model
|
| 132 |
+
model = SegGptForImageSegmentation(config)
|
| 133 |
+
model.eval()
|
| 134 |
+
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
|
| 135 |
+
print("Missing keys:", missing_keys)
|
| 136 |
+
print("Unexpected keys:", unexpected_keys)
|
| 137 |
+
|
| 138 |
+
input_img, prompt_img, prompt_mask = prepare_input()
|
| 139 |
+
image_processor = SegGptImageProcessor()
|
| 140 |
+
inputs = image_processor(images=input_img, prompt_images=prompt_img, prompt_masks=prompt_mask, return_tensors="pt")
|
| 141 |
+
|
| 142 |
+
expected_prompt_pixel_values = torch.tensor(
|
| 143 |
+
[
|
| 144 |
+
[[-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965]],
|
| 145 |
+
[[1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583]],
|
| 146 |
+
[[2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088]],
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
expected_pixel_values = torch.tensor(
|
| 151 |
+
[
|
| 152 |
+
[[1.6324, 1.6153, 1.5810], [1.6153, 1.5982, 1.5810], [1.5810, 1.5639, 1.5639]],
|
| 153 |
+
[[1.2731, 1.2556, 1.2206], [1.2556, 1.2381, 1.2031], [1.2206, 1.2031, 1.1681]],
|
| 154 |
+
[[1.6465, 1.6465, 1.6465], [1.6465, 1.6465, 1.6465], [1.6291, 1.6291, 1.6291]],
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
expected_prompt_masks = torch.tensor(
|
| 159 |
+
[
|
| 160 |
+
[[-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179]],
|
| 161 |
+
[[-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357]],
|
| 162 |
+
[[-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044]],
|
| 163 |
+
]
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
assert torch.allclose(inputs.pixel_values[0, :, :3, :3], expected_pixel_values, atol=1e-4)
|
| 167 |
+
assert torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
|
| 168 |
+
assert torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4)
|
| 169 |
+
|
| 170 |
+
torch.manual_seed(2)
|
| 171 |
+
outputs = model(**inputs)
|
| 172 |
+
print(outputs)
|
| 173 |
+
|
| 174 |
+
if verify_logits:
|
| 175 |
+
expected_output = torch.tensor(
|
| 176 |
+
[
|
| 177 |
+
[[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
|
| 178 |
+
[[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
|
| 179 |
+
[[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
|
| 180 |
+
]
|
| 181 |
+
)
|
| 182 |
+
assert torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_output, atol=1e-4)
|
| 183 |
+
print("Looks good!")
|
| 184 |
+
else:
|
| 185 |
+
print("Converted without verifying logits")
|
| 186 |
+
|
| 187 |
+
if pytorch_dump_folder_path is not None:
|
| 188 |
+
print(f"Saving model and processor for {model_name} to {pytorch_dump_folder_path}")
|
| 189 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 190 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 191 |
+
|
| 192 |
+
if push_to_hub:
|
| 193 |
+
print(f"Pushing model and processor for {model_name} to hub")
|
| 194 |
+
model.push_to_hub(f"EduardoPacheco/{model_name}")
|
| 195 |
+
image_processor.push_to_hub(f"EduardoPacheco/{model_name}")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
parser = argparse.ArgumentParser()
|
| 200 |
+
# Required parameters
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--model_name",
|
| 203 |
+
default="seggpt-vit-large",
|
| 204 |
+
type=str,
|
| 205 |
+
choices=["seggpt-vit-large"],
|
| 206 |
+
help="Name of the SegGpt model you'd like to convert.",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 210 |
+
)
|
| 211 |
+
parser.add_argument(
|
| 212 |
+
"--verify_logits",
|
| 213 |
+
action="store_false",
|
| 214 |
+
help="Whether or not to verify the logits against the original implementation.",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
args = parser.parse_args()
|
| 221 |
+
convert_seggpt_checkpoint(args)
|
docs/transformers/build/lib/transformers/models/seggpt/image_processing_seggpt.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for SegGPT."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from ...image_transforms import resize, to_channel_dimension_format
|
| 23 |
+
from ...image_utils import (
|
| 24 |
+
IMAGENET_DEFAULT_MEAN,
|
| 25 |
+
IMAGENET_DEFAULT_STD,
|
| 26 |
+
ChannelDimension,
|
| 27 |
+
ImageInput,
|
| 28 |
+
PILImageResampling,
|
| 29 |
+
infer_channel_dimension_format,
|
| 30 |
+
is_scaled_image,
|
| 31 |
+
make_list_of_images,
|
| 32 |
+
to_numpy_array,
|
| 33 |
+
valid_images,
|
| 34 |
+
)
|
| 35 |
+
from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if is_torch_available():
|
| 39 |
+
import torch
|
| 40 |
+
|
| 41 |
+
if is_vision_available():
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# See https://arxiv.org/pdf/2212.02499.pdf at 3.1 Redefining Output Spaces as "Images" - Semantic Segmentation from PAINTER paper
|
| 49 |
+
# Taken from https://github.com/Abdullah-Meda/Painter/blob/main/Painter/data/coco_semseg/gen_color_coco_panoptic_segm.py#L31
|
| 50 |
+
def build_palette(num_labels: int) -> List[Tuple[int, int]]:
|
| 51 |
+
base = int(num_labels ** (1 / 3)) + 1
|
| 52 |
+
margin = 256 // base
|
| 53 |
+
|
| 54 |
+
# we assume that class_idx 0 is the background which is mapped to black
|
| 55 |
+
color_list = [(0, 0, 0)]
|
| 56 |
+
for location in range(num_labels):
|
| 57 |
+
num_seq_r = location // base**2
|
| 58 |
+
num_seq_g = (location % base**2) // base
|
| 59 |
+
num_seq_b = location % base
|
| 60 |
+
|
| 61 |
+
R = 255 - num_seq_r * margin
|
| 62 |
+
G = 255 - num_seq_g * margin
|
| 63 |
+
B = 255 - num_seq_b * margin
|
| 64 |
+
|
| 65 |
+
color_list.append((R, G, B))
|
| 66 |
+
|
| 67 |
+
return color_list
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def mask_to_rgb(
|
| 71 |
+
mask: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None
|
| 72 |
+
) -> np.ndarray:
|
| 73 |
+
data_format = data_format if data_format is not None else ChannelDimension.FIRST
|
| 74 |
+
|
| 75 |
+
if palette is not None:
|
| 76 |
+
height, width = mask.shape
|
| 77 |
+
|
| 78 |
+
rgb_mask = np.zeros((3, height, width), dtype=np.uint8)
|
| 79 |
+
|
| 80 |
+
classes_in_mask = np.unique(mask)
|
| 81 |
+
|
| 82 |
+
for class_idx in classes_in_mask:
|
| 83 |
+
rgb_value = palette[class_idx]
|
| 84 |
+
class_mask = (mask == class_idx).astype(np.uint8)
|
| 85 |
+
class_mask = np.expand_dims(class_mask, axis=-1)
|
| 86 |
+
class_rgb_mask = class_mask * np.array(rgb_value)
|
| 87 |
+
class_rgb_mask = np.moveaxis(class_rgb_mask, -1, 0)
|
| 88 |
+
rgb_mask += class_rgb_mask.astype(np.uint8)
|
| 89 |
+
|
| 90 |
+
rgb_mask = np.clip(rgb_mask, 0, 255).astype(np.uint8)
|
| 91 |
+
|
| 92 |
+
else:
|
| 93 |
+
rgb_mask = np.repeat(mask[None, ...], 3, axis=0)
|
| 94 |
+
|
| 95 |
+
return to_channel_dimension_format(rgb_mask, data_format)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SegGptImageProcessor(BaseImageProcessor):
|
| 99 |
+
r"""
|
| 100 |
+
Constructs a SegGpt image processor.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 104 |
+
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
|
| 105 |
+
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
|
| 106 |
+
size (`dict`, *optional*, defaults to `{"height": 448, "width": 448}`):
|
| 107 |
+
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
| 108 |
+
method.
|
| 109 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 110 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 111 |
+
`preprocess` method.
|
| 112 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 113 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 114 |
+
parameter in the `preprocess` method.
|
| 115 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 116 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 117 |
+
`preprocess` method.
|
| 118 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 119 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 120 |
+
method.
|
| 121 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 122 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 123 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 124 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 125 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 126 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 127 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 128 |
+
Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in the
|
| 129 |
+
`preprocess` method.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
model_input_names = ["pixel_values"]
|
| 133 |
+
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
do_resize: bool = True,
|
| 137 |
+
size: Optional[Dict[str, int]] = None,
|
| 138 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 139 |
+
do_rescale: bool = True,
|
| 140 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 141 |
+
do_normalize: bool = True,
|
| 142 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 143 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 144 |
+
do_convert_rgb: bool = True,
|
| 145 |
+
**kwargs,
|
| 146 |
+
) -> None:
|
| 147 |
+
super().__init__(**kwargs)
|
| 148 |
+
size = size if size is not None else {"height": 448, "width": 448}
|
| 149 |
+
size = get_size_dict(size)
|
| 150 |
+
self.do_resize = do_resize
|
| 151 |
+
self.do_rescale = do_rescale
|
| 152 |
+
self.do_normalize = do_normalize
|
| 153 |
+
self.size = size
|
| 154 |
+
self.resample = resample
|
| 155 |
+
self.rescale_factor = rescale_factor
|
| 156 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 157 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 158 |
+
self.do_convert_rgb = do_convert_rgb
|
| 159 |
+
|
| 160 |
+
def get_palette(self, num_labels: int) -> List[Tuple[int, int]]:
|
| 161 |
+
"""Build a palette to map the prompt mask from a single channel to a 3 channel RGB.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
num_labels (`int`):
|
| 165 |
+
Number of classes in the segmentation task (excluding the background).
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
`List[Tuple[int, int]]`: Palette to map the prompt mask from a single channel to a 3 channel RGB.
|
| 169 |
+
"""
|
| 170 |
+
return build_palette(num_labels)
|
| 171 |
+
|
| 172 |
+
def mask_to_rgb(
|
| 173 |
+
self,
|
| 174 |
+
image: np.ndarray,
|
| 175 |
+
palette: Optional[List[Tuple[int, int]]] = None,
|
| 176 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 177 |
+
) -> np.ndarray:
|
| 178 |
+
"""Converts a segmentation map to RGB format.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
image (`np.ndarray`):
|
| 182 |
+
Segmentation map with dimensions (height, width) where pixel values represent the class index.
|
| 183 |
+
palette (`List[Tuple[int, int]]`, *optional*, defaults to `None`):
|
| 184 |
+
Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel
|
| 185 |
+
dimension.
|
| 186 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 187 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 188 |
+
image is used. Can be one of:
|
| 189 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 190 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
`np.ndarray`: The mask in RGB format.
|
| 194 |
+
"""
|
| 195 |
+
return mask_to_rgb(image, palette=palette, data_format=data_format)
|
| 196 |
+
|
| 197 |
+
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
| 198 |
+
def resize(
|
| 199 |
+
self,
|
| 200 |
+
image: np.ndarray,
|
| 201 |
+
size: Dict[str, int],
|
| 202 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 203 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 204 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 205 |
+
**kwargs,
|
| 206 |
+
) -> np.ndarray:
|
| 207 |
+
"""
|
| 208 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
image (`np.ndarray`):
|
| 212 |
+
Image to resize.
|
| 213 |
+
size (`Dict[str, int]`):
|
| 214 |
+
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
| 215 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 216 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
| 217 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 218 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 219 |
+
image is used. Can be one of:
|
| 220 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 221 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 222 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 223 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 224 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 225 |
+
from the input image. Can be one of:
|
| 226 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 227 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 228 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
`np.ndarray`: The resized image.
|
| 232 |
+
"""
|
| 233 |
+
size = get_size_dict(size)
|
| 234 |
+
if "height" not in size or "width" not in size:
|
| 235 |
+
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
| 236 |
+
output_size = (size["height"], size["width"])
|
| 237 |
+
return resize(
|
| 238 |
+
image,
|
| 239 |
+
size=output_size,
|
| 240 |
+
resample=resample,
|
| 241 |
+
data_format=data_format,
|
| 242 |
+
input_data_format=input_data_format,
|
| 243 |
+
**kwargs,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
def _preprocess_step(
|
| 247 |
+
self,
|
| 248 |
+
images: ImageInput,
|
| 249 |
+
do_resize: Optional[bool] = None,
|
| 250 |
+
size: Dict[str, int] = None,
|
| 251 |
+
resample: PILImageResampling = None,
|
| 252 |
+
do_rescale: Optional[bool] = None,
|
| 253 |
+
rescale_factor: Optional[float] = None,
|
| 254 |
+
do_normalize: Optional[bool] = None,
|
| 255 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 256 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 257 |
+
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
| 258 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 259 |
+
do_convert_rgb: Optional[bool] = None,
|
| 260 |
+
num_labels: Optional[int] = None,
|
| 261 |
+
**kwargs,
|
| 262 |
+
):
|
| 263 |
+
"""
|
| 264 |
+
Preprocess an image or batch of images.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
images (`ImageInput`):
|
| 268 |
+
Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 269 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 270 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 271 |
+
Whether to resize the image.
|
| 272 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 273 |
+
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
|
| 274 |
+
resizing.
|
| 275 |
+
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
|
| 276 |
+
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has
|
| 277 |
+
an effect if `do_resize` is set to `True`.
|
| 278 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 279 |
+
Whether to rescale the image values between [0 - 1].
|
| 280 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 281 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 282 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 283 |
+
Whether to normalize the image.
|
| 284 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 285 |
+
Image mean to use if `do_normalize` is set to `True`.
|
| 286 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 287 |
+
Image standard deviation to use if `do_normalize` is set to `True`.
|
| 288 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 289 |
+
The type of tensors to return. Can be one of:
|
| 290 |
+
- Unset: Return a list of `np.ndarray`.
|
| 291 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 292 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 293 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 294 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 295 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 296 |
+
The channel dimension format for the output image. Can be one of:
|
| 297 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 298 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 299 |
+
- Unset: Use the channel dimension format of the input image.
|
| 300 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 301 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 302 |
+
from the input image. Can be one of:
|
| 303 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 304 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 305 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 306 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 307 |
+
Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
|
| 308 |
+
to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
|
| 309 |
+
across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
|
| 310 |
+
num_labels: (`int`, *optional*):
|
| 311 |
+
Number of classes in the segmentation task (excluding the background). If specified, a palette will be
|
| 312 |
+
built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
|
| 313 |
+
channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
|
| 314 |
+
through as is if it is already in RGB format or being duplicated across the channel dimension.
|
| 315 |
+
"""
|
| 316 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 317 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 318 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 319 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 320 |
+
resample = resample if resample is not None else self.resample
|
| 321 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 322 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 323 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 324 |
+
|
| 325 |
+
size = size if size is not None else self.size
|
| 326 |
+
size_dict = get_size_dict(size)
|
| 327 |
+
|
| 328 |
+
# If segmentation map is passed we expect 2D images
|
| 329 |
+
images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3)
|
| 330 |
+
|
| 331 |
+
if not valid_images(images):
|
| 332 |
+
raise ValueError(
|
| 333 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 334 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if do_resize and size is None:
|
| 338 |
+
raise ValueError("Size must be specified if do_resize is True.")
|
| 339 |
+
|
| 340 |
+
if do_rescale and rescale_factor is None:
|
| 341 |
+
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
| 342 |
+
|
| 343 |
+
if do_normalize and (image_mean is None or image_std is None):
|
| 344 |
+
raise ValueError("Image mean and std must be specified if do_normalize is True.")
|
| 345 |
+
|
| 346 |
+
# All transformations expect numpy arrays.
|
| 347 |
+
images = [to_numpy_array(image) for image in images]
|
| 348 |
+
|
| 349 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 350 |
+
logger.warning_once(
|
| 351 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 352 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if input_data_format is None and not do_convert_rgb:
|
| 356 |
+
# We assume that all images have the same channel dimension format.
|
| 357 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 358 |
+
|
| 359 |
+
if do_convert_rgb:
|
| 360 |
+
palette = self.get_palette(num_labels) if num_labels is not None else None
|
| 361 |
+
# Since this is the input for the next transformations its format should be the same as the input_data_format
|
| 362 |
+
images = [
|
| 363 |
+
self.mask_to_rgb(image=image, palette=palette, data_format=ChannelDimension.FIRST) for image in images
|
| 364 |
+
]
|
| 365 |
+
input_data_format = ChannelDimension.FIRST
|
| 366 |
+
|
| 367 |
+
if do_resize:
|
| 368 |
+
images = [
|
| 369 |
+
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
|
| 370 |
+
for image in images
|
| 371 |
+
]
|
| 372 |
+
|
| 373 |
+
if do_rescale:
|
| 374 |
+
images = [
|
| 375 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 376 |
+
for image in images
|
| 377 |
+
]
|
| 378 |
+
|
| 379 |
+
if do_normalize:
|
| 380 |
+
images = [
|
| 381 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 382 |
+
for image in images
|
| 383 |
+
]
|
| 384 |
+
|
| 385 |
+
images = [
|
| 386 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
| 387 |
+
]
|
| 388 |
+
|
| 389 |
+
return images
|
| 390 |
+
|
| 391 |
+
def preprocess(
|
| 392 |
+
self,
|
| 393 |
+
images: Optional[ImageInput] = None,
|
| 394 |
+
prompt_images: Optional[ImageInput] = None,
|
| 395 |
+
prompt_masks: Optional[ImageInput] = None,
|
| 396 |
+
do_resize: Optional[bool] = None,
|
| 397 |
+
size: Dict[str, int] = None,
|
| 398 |
+
resample: PILImageResampling = None,
|
| 399 |
+
do_rescale: Optional[bool] = None,
|
| 400 |
+
rescale_factor: Optional[float] = None,
|
| 401 |
+
do_normalize: Optional[bool] = None,
|
| 402 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 403 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 404 |
+
do_convert_rgb: Optional[bool] = None,
|
| 405 |
+
num_labels: Optional[int] = None,
|
| 406 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 407 |
+
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
| 408 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 409 |
+
**kwargs,
|
| 410 |
+
):
|
| 411 |
+
"""
|
| 412 |
+
Preprocess an image or batch of images.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
images (`ImageInput`):
|
| 416 |
+
Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 417 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 418 |
+
prompt_images (`ImageInput`):
|
| 419 |
+
Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 420 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 421 |
+
prompt_masks (`ImageInput`):
|
| 422 |
+
Prompt mask from prompt image to _preprocess that specify prompt_masks value in the preprocessed output.
|
| 423 |
+
Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of
|
| 424 |
+
RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps, `num_labels`
|
| 425 |
+
specifying `num_labels` is recommended to build a palette to map the prompt mask from a single channel to
|
| 426 |
+
a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the channel
|
| 427 |
+
dimension.
|
| 428 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 429 |
+
Whether to resize the image.
|
| 430 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 431 |
+
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
|
| 432 |
+
resizing.
|
| 433 |
+
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
|
| 434 |
+
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has
|
| 435 |
+
an effect if `do_resize` is set to `True`. Doesn't apply to prompt mask as it is resized using nearest.
|
| 436 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 437 |
+
Whether to rescale the image values between [0 - 1].
|
| 438 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 439 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 440 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 441 |
+
Whether to normalize the image.
|
| 442 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 443 |
+
Image mean to use if `do_normalize` is set to `True`.
|
| 444 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 445 |
+
Image standard deviation to use if `do_normalize` is set to `True`.
|
| 446 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 447 |
+
Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
|
| 448 |
+
to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
|
| 449 |
+
across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
|
| 450 |
+
num_labels: (`int`, *optional*):
|
| 451 |
+
Number of classes in the segmentation task (excluding the background). If specified, a palette will be
|
| 452 |
+
built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map
|
| 453 |
+
with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
|
| 454 |
+
through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated
|
| 455 |
+
across the channel dimension.
|
| 456 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 457 |
+
The type of tensors to return. Can be one of:
|
| 458 |
+
- Unset: Return a list of `np.ndarray`.
|
| 459 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 460 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 461 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 462 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 463 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 464 |
+
The channel dimension format for the output image. Can be one of:
|
| 465 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 466 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 467 |
+
- Unset: Use the channel dimension format of the input image.
|
| 468 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 469 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 470 |
+
from the input image. Can be one of:
|
| 471 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 472 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 473 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 474 |
+
"""
|
| 475 |
+
if all(v is None for v in [images, prompt_images, prompt_masks]):
|
| 476 |
+
raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.")
|
| 477 |
+
|
| 478 |
+
data = {}
|
| 479 |
+
|
| 480 |
+
if images is not None:
|
| 481 |
+
images = self._preprocess_step(
|
| 482 |
+
images,
|
| 483 |
+
is_mask=False,
|
| 484 |
+
do_resize=do_resize,
|
| 485 |
+
size=size,
|
| 486 |
+
resample=resample,
|
| 487 |
+
do_rescale=do_rescale,
|
| 488 |
+
rescale_factor=rescale_factor,
|
| 489 |
+
do_normalize=do_normalize,
|
| 490 |
+
image_mean=image_mean,
|
| 491 |
+
image_std=image_std,
|
| 492 |
+
do_convert_rgb=False,
|
| 493 |
+
data_format=data_format,
|
| 494 |
+
input_data_format=input_data_format,
|
| 495 |
+
**kwargs,
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
data["pixel_values"] = images
|
| 499 |
+
|
| 500 |
+
if prompt_images is not None:
|
| 501 |
+
prompt_images = self._preprocess_step(
|
| 502 |
+
prompt_images,
|
| 503 |
+
is_mask=False,
|
| 504 |
+
do_resize=do_resize,
|
| 505 |
+
size=size,
|
| 506 |
+
resample=resample,
|
| 507 |
+
do_rescale=do_rescale,
|
| 508 |
+
rescale_factor=rescale_factor,
|
| 509 |
+
do_normalize=do_normalize,
|
| 510 |
+
image_mean=image_mean,
|
| 511 |
+
image_std=image_std,
|
| 512 |
+
do_convert_rgb=False,
|
| 513 |
+
data_format=data_format,
|
| 514 |
+
input_data_format=input_data_format,
|
| 515 |
+
**kwargs,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
data["prompt_pixel_values"] = prompt_images
|
| 519 |
+
|
| 520 |
+
if prompt_masks is not None:
|
| 521 |
+
prompt_masks = self._preprocess_step(
|
| 522 |
+
prompt_masks,
|
| 523 |
+
do_resize=do_resize,
|
| 524 |
+
size=size,
|
| 525 |
+
resample=PILImageResampling.NEAREST,
|
| 526 |
+
do_rescale=do_rescale,
|
| 527 |
+
rescale_factor=rescale_factor,
|
| 528 |
+
do_normalize=do_normalize,
|
| 529 |
+
image_mean=image_mean,
|
| 530 |
+
image_std=image_std,
|
| 531 |
+
do_convert_rgb=do_convert_rgb,
|
| 532 |
+
num_labels=num_labels,
|
| 533 |
+
data_format=data_format,
|
| 534 |
+
input_data_format=input_data_format,
|
| 535 |
+
**kwargs,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
data["prompt_masks"] = prompt_masks
|
| 539 |
+
|
| 540 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 541 |
+
|
| 542 |
+
def post_process_semantic_segmentation(
|
| 543 |
+
self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None, num_labels: Optional[int] = None
|
| 544 |
+
):
|
| 545 |
+
"""
|
| 546 |
+
Converts the output of [`SegGptImageSegmentationOutput`] into segmentation maps. Only supports
|
| 547 |
+
PyTorch.
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
outputs ([`SegGptImageSegmentationOutput`]):
|
| 551 |
+
Raw outputs of the model.
|
| 552 |
+
target_sizes (`List[Tuple[int, int]]`, *optional*):
|
| 553 |
+
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
|
| 554 |
+
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
| 555 |
+
num_labels (`int`, *optional*):
|
| 556 |
+
Number of classes in the segmentation task (excluding the background). If specified, a palette will be
|
| 557 |
+
built, assuming that class_idx 0 is the background, to map prediction masks from RGB values to class
|
| 558 |
+
indices. This value should be the same used when preprocessing inputs.
|
| 559 |
+
Returns:
|
| 560 |
+
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
| 561 |
+
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
| 562 |
+
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
| 563 |
+
"""
|
| 564 |
+
requires_backends(self, ["torch"])
|
| 565 |
+
# batch_size x num_channels x 2*height x width
|
| 566 |
+
masks = outputs.pred_masks
|
| 567 |
+
|
| 568 |
+
# Predicted mask and prompt are concatenated in the height dimension
|
| 569 |
+
# batch_size x num_channels x height x width
|
| 570 |
+
masks = masks[:, :, masks.shape[2] // 2 :, :]
|
| 571 |
+
|
| 572 |
+
# To unnormalize we need to permute to channel last
|
| 573 |
+
# batch_size x height x width x num_channels
|
| 574 |
+
std = torch.tensor(self.image_std).to(masks.device)
|
| 575 |
+
mean = torch.tensor(self.image_mean).to(masks.device)
|
| 576 |
+
|
| 577 |
+
masks = masks.permute(0, 2, 3, 1) * std + mean
|
| 578 |
+
|
| 579 |
+
# batch_size x num_channels x height x width
|
| 580 |
+
masks = masks.permute(0, 3, 1, 2)
|
| 581 |
+
|
| 582 |
+
# Clip to match with palette if specified
|
| 583 |
+
masks = torch.clip(masks * 255, 0, 255)
|
| 584 |
+
|
| 585 |
+
semantic_segmentation = []
|
| 586 |
+
palette_tensor = None
|
| 587 |
+
palette = self.get_palette(num_labels) if num_labels is not None else None
|
| 588 |
+
if palette is not None:
|
| 589 |
+
palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float)
|
| 590 |
+
_, num_channels, _, _ = masks.shape
|
| 591 |
+
palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels)
|
| 592 |
+
|
| 593 |
+
for idx, mask in enumerate(masks):
|
| 594 |
+
if target_sizes is not None:
|
| 595 |
+
mask = torch.nn.functional.interpolate(
|
| 596 |
+
mask.unsqueeze(0),
|
| 597 |
+
size=target_sizes[idx],
|
| 598 |
+
mode="nearest",
|
| 599 |
+
)[0]
|
| 600 |
+
|
| 601 |
+
if num_labels is not None:
|
| 602 |
+
channels, height, width = mask.shape
|
| 603 |
+
dist = mask.permute(1, 2, 0).view(height, width, 1, channels)
|
| 604 |
+
dist = dist - palette_tensor
|
| 605 |
+
dist = torch.pow(dist, 2)
|
| 606 |
+
dist = torch.sum(dist, dim=-1)
|
| 607 |
+
pred = dist.argmin(dim=-1)
|
| 608 |
+
|
| 609 |
+
else:
|
| 610 |
+
# If no palette is specified SegGpt will try to paint using the mask class idx as RGB
|
| 611 |
+
pred = mask.mean(dim=0).int()
|
| 612 |
+
|
| 613 |
+
semantic_segmentation.append(pred)
|
| 614 |
+
|
| 615 |
+
return semantic_segmentation
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
__all__ = ["SegGptImageProcessor"]
|
docs/transformers/build/lib/transformers/models/seggpt/modeling_seggpt.py
ADDED
|
@@ -0,0 +1,1031 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch SegGpt model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import functional as F
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...modeling_utils import PreTrainedModel
|
| 28 |
+
from ...utils import (
|
| 29 |
+
ModelOutput,
|
| 30 |
+
add_start_docstrings,
|
| 31 |
+
add_start_docstrings_to_model_forward,
|
| 32 |
+
logging,
|
| 33 |
+
replace_return_docstrings,
|
| 34 |
+
torch_int,
|
| 35 |
+
)
|
| 36 |
+
from .configuration_seggpt import SegGptConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
# General docstring
|
| 42 |
+
_CONFIG_FOR_DOC = "SegGptConfig"
|
| 43 |
+
|
| 44 |
+
# Base docstring
|
| 45 |
+
_CHECKPOINT_FOR_DOC = "BAAI/seggpt-vit-large"
|
| 46 |
+
_EXPECTED_OUTPUT_SHAPE = [3, 896, 448]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class SegGptEncoderOutput(ModelOutput):
|
| 51 |
+
"""
|
| 52 |
+
Output type of [`SegGptEncoderOutput`].
|
| 53 |
+
Args:
|
| 54 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
|
| 55 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 56 |
+
hidden_states (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
|
| 57 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 58 |
+
of shape `(batch_size, patch_height, patch_width, hidden_size)`.
|
| 59 |
+
attentions (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
|
| 60 |
+
Tuple of *torch.FloatTensor* (one for each layer) of shape
|
| 61 |
+
`(batch_size, num_heads, seq_len, seq_len)`.
|
| 62 |
+
intermediate_hidden_states (`Tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set):
|
| 63 |
+
Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`.
|
| 64 |
+
Each element in the Tuple corresponds to the output of the layer specified in `config.intermediate_hidden_state_indices`.
|
| 65 |
+
Additionaly, each feature passes through a LayerNorm.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
last_hidden_state: torch.FloatTensor
|
| 69 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 70 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 71 |
+
intermediate_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class SegGptImageSegmentationOutput(ModelOutput):
|
| 76 |
+
"""
|
| 77 |
+
Output type of [`SegGptImageSegmentationOutput`].
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
|
| 81 |
+
The loss value.
|
| 82 |
+
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 83 |
+
The predicted masks.
|
| 84 |
+
hidden_states (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
|
| 85 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 86 |
+
of shape `(batch_size, patch_height, patch_width, hidden_size)`.
|
| 87 |
+
attentions (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
|
| 88 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape
|
| 89 |
+
`(batch_size, num_heads, seq_len, seq_len)`.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
loss: Optional[torch.FloatTensor] = None
|
| 93 |
+
pred_masks: Optional[torch.FloatTensor] = None
|
| 94 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 95 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->SegGpt
|
| 99 |
+
class SegGptPatchEmbeddings(nn.Module):
|
| 100 |
+
"""
|
| 101 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 102 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 103 |
+
Transformer.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self, config):
|
| 107 |
+
super().__init__()
|
| 108 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 109 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 110 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 111 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 112 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 113 |
+
self.image_size = image_size
|
| 114 |
+
self.patch_size = patch_size
|
| 115 |
+
self.num_channels = num_channels
|
| 116 |
+
self.num_patches = num_patches
|
| 117 |
+
|
| 118 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
| 119 |
+
|
| 120 |
+
def forward(self, pixel_values):
|
| 121 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 122 |
+
if num_channels != self.num_channels:
|
| 123 |
+
raise ValueError(
|
| 124 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 125 |
+
)
|
| 126 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
| 129 |
+
)
|
| 130 |
+
embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
|
| 131 |
+
return embeddings
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class SegGptEmbeddings(nn.Module):
|
| 135 |
+
"""
|
| 136 |
+
Construct the embeddings from patch, position embeddings for input and prompt.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, config: SegGptConfig) -> None:
|
| 140 |
+
super().__init__()
|
| 141 |
+
|
| 142 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
|
| 143 |
+
self.segment_token_input = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
|
| 144 |
+
self.segment_token_prompt = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
|
| 145 |
+
# token for seg types
|
| 146 |
+
self.type_token_semantic = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
|
| 147 |
+
self.type_token_instance = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
|
| 148 |
+
|
| 149 |
+
self.patch_embeddings = SegGptPatchEmbeddings(config)
|
| 150 |
+
|
| 151 |
+
num_positions = (config.pretrain_image_size // config.patch_size) ** 2 + 1
|
| 152 |
+
self.position_embeddings = nn.Parameter(torch.randn(1, num_positions, config.hidden_size))
|
| 153 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 154 |
+
|
| 155 |
+
def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor:
|
| 156 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 157 |
+
num_patches = patch_pos_embed.shape[1]
|
| 158 |
+
pretrain_patch_size = torch_int(num_patches**0.5)
|
| 159 |
+
|
| 160 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 161 |
+
if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width:
|
| 162 |
+
patch_pos_embed = F.interpolate(
|
| 163 |
+
patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2),
|
| 164 |
+
size=(height, width),
|
| 165 |
+
mode="bicubic",
|
| 166 |
+
align_corners=False,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
return patch_pos_embed.permute(0, 2, 3, 1)
|
| 170 |
+
else:
|
| 171 |
+
return patch_pos_embed.reshape(1, height, width, -1)
|
| 172 |
+
|
| 173 |
+
def forward(
|
| 174 |
+
self,
|
| 175 |
+
pixel_values: torch.Tensor,
|
| 176 |
+
prompt_pixel_values: torch.Tensor,
|
| 177 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 178 |
+
embedding_type: Optional[str] = None,
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
input_embeddings = self.patch_embeddings(pixel_values)
|
| 181 |
+
prompt_embeddings = self.patch_embeddings(prompt_pixel_values)
|
| 182 |
+
|
| 183 |
+
batch_size, patch_height, patch_width, _ = input_embeddings.shape
|
| 184 |
+
|
| 185 |
+
mask_token = self.mask_token.expand(batch_size, patch_height, patch_width, -1)
|
| 186 |
+
# replace the masked visual tokens by mask_token
|
| 187 |
+
w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, patch_height, patch_width, 1)
|
| 188 |
+
prompt_embeddings = prompt_embeddings * (1 - w) + mask_token * w
|
| 189 |
+
|
| 190 |
+
embedding_type = embedding_type if embedding_type is not None else "instance"
|
| 191 |
+
|
| 192 |
+
# add positional encoding to each token
|
| 193 |
+
pos_embed = self.interpolate_pos_encoding(patch_height, patch_width)
|
| 194 |
+
|
| 195 |
+
# add segment token
|
| 196 |
+
input_embeddings = input_embeddings + self.segment_token_input
|
| 197 |
+
prompt_embeddings = prompt_embeddings + self.segment_token_prompt
|
| 198 |
+
|
| 199 |
+
# add position embedding skipping CLS
|
| 200 |
+
input_embeddings = input_embeddings + pos_embed
|
| 201 |
+
prompt_embeddings = prompt_embeddings + pos_embed
|
| 202 |
+
|
| 203 |
+
# add type embedding to each token
|
| 204 |
+
if embedding_type == "semantic":
|
| 205 |
+
type_embedding = self.type_token_semantic
|
| 206 |
+
elif embedding_type == "instance":
|
| 207 |
+
type_embedding = self.type_token_instance
|
| 208 |
+
else:
|
| 209 |
+
raise ValueError(f"Embedding type should be either 'semantic' or 'instance', but got {embedding_type}")
|
| 210 |
+
|
| 211 |
+
input_embeddings = input_embeddings + type_embedding
|
| 212 |
+
prompt_embeddings = prompt_embeddings + type_embedding
|
| 213 |
+
|
| 214 |
+
embeddings = torch.cat((input_embeddings, prompt_embeddings), dim=0)
|
| 215 |
+
|
| 216 |
+
return embeddings
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class SegGptAttention(nn.Module):
|
| 220 |
+
"""Multi-head Attention block with relative position embeddings."""
|
| 221 |
+
|
| 222 |
+
def __init__(self, config):
|
| 223 |
+
super().__init__()
|
| 224 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 225 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 226 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 227 |
+
|
| 228 |
+
input_size = (image_size[0] // config.patch_size, image_size[1] // config.patch_size)
|
| 229 |
+
head_dim = config.hidden_size // config.num_attention_heads
|
| 230 |
+
|
| 231 |
+
self.num_attention_heads = config.num_attention_heads
|
| 232 |
+
self.scale = head_dim**-0.5
|
| 233 |
+
|
| 234 |
+
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
|
| 235 |
+
self.proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 236 |
+
|
| 237 |
+
self.use_relative_position_embeddings = config.use_relative_position_embeddings
|
| 238 |
+
if self.use_relative_position_embeddings:
|
| 239 |
+
if input_size is None:
|
| 240 |
+
raise ValueError("Input size must be provided if using relative positional encoding.")
|
| 241 |
+
|
| 242 |
+
# initialize relative positional embeddings
|
| 243 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
| 244 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
| 245 |
+
|
| 246 |
+
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
| 247 |
+
"""
|
| 248 |
+
Get relative positional embeddings according to the relative positions of
|
| 249 |
+
query and key sizes.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
q_size (int):
|
| 253 |
+
size of the query.
|
| 254 |
+
k_size (int):
|
| 255 |
+
size of key k.
|
| 256 |
+
rel_pos (`torch.Tensor`):
|
| 257 |
+
relative position embeddings (L, channel).
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
Extracted positional embeddings according to relative positions.
|
| 261 |
+
"""
|
| 262 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
| 263 |
+
# Interpolate rel pos.
|
| 264 |
+
rel_pos_resized = F.interpolate(
|
| 265 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
| 266 |
+
size=max_rel_dist,
|
| 267 |
+
mode="linear",
|
| 268 |
+
)
|
| 269 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
| 270 |
+
|
| 271 |
+
# Scale the coords with short length if shapes for q and k are different.
|
| 272 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
| 273 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
| 274 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
| 275 |
+
|
| 276 |
+
return rel_pos_resized[relative_coords.long()]
|
| 277 |
+
|
| 278 |
+
def add_decomposed_rel_pos(
|
| 279 |
+
self,
|
| 280 |
+
attn: torch.Tensor,
|
| 281 |
+
query: torch.Tensor,
|
| 282 |
+
rel_pos_h: torch.Tensor,
|
| 283 |
+
rel_pos_w: torch.Tensor,
|
| 284 |
+
q_size: Tuple[int, int],
|
| 285 |
+
k_size: Tuple[int, int],
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""
|
| 288 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
| 289 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
attn (`torch.Tensor`):
|
| 293 |
+
attention map.
|
| 294 |
+
query (`torch.Tensor`):
|
| 295 |
+
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
|
| 296 |
+
rel_pos_h (`torch.Tensor`):
|
| 297 |
+
relative position embeddings (Lh, channel) for height axis.
|
| 298 |
+
rel_pos_w (`torch.Tensor`):
|
| 299 |
+
relative position embeddings (Lw, channel) for width axis.
|
| 300 |
+
q_size (tuple):
|
| 301 |
+
spatial sequence size of query q with (query_height, query_width).
|
| 302 |
+
k_size (tuple):
|
| 303 |
+
spatial sequence size of key k with (key_height, key_width).
|
| 304 |
+
|
| 305 |
+
Returns:
|
| 306 |
+
attn (`torch.Tensor`):
|
| 307 |
+
attention map with added relative positional embeddings.
|
| 308 |
+
"""
|
| 309 |
+
query_height, query_width = q_size
|
| 310 |
+
key_height, key_width = k_size
|
| 311 |
+
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
|
| 312 |
+
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
|
| 313 |
+
|
| 314 |
+
batch_size, _, dim = query.shape
|
| 315 |
+
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
|
| 316 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
|
| 317 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
|
| 318 |
+
attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
|
| 319 |
+
attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
| 320 |
+
attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
|
| 321 |
+
return attn
|
| 322 |
+
|
| 323 |
+
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
| 324 |
+
batch_size, height, width, _ = hidden_states.shape
|
| 325 |
+
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
| 326 |
+
qkv = (
|
| 327 |
+
self.qkv(hidden_states)
|
| 328 |
+
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
| 329 |
+
.permute(2, 0, 3, 1, 4)
|
| 330 |
+
)
|
| 331 |
+
# q, k, v with shape (batch_size * nHead, height * width, channel)
|
| 332 |
+
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
|
| 333 |
+
|
| 334 |
+
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
| 335 |
+
|
| 336 |
+
if self.use_relative_position_embeddings:
|
| 337 |
+
attn_weights = self.add_decomposed_rel_pos(
|
| 338 |
+
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
| 342 |
+
|
| 343 |
+
if output_attentions:
|
| 344 |
+
# this operation is a bit awkward, but it's required to
|
| 345 |
+
# make sure that attn_weights keeps its gradient.
|
| 346 |
+
# In order to do so, attn_weights have to reshaped
|
| 347 |
+
# twice and have to be reused in the following
|
| 348 |
+
attn_weights_reshaped = attn_weights.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 349 |
+
attn_weights = attn_weights_reshaped.view(batch_size * self.num_attention_heads, height * width, -1)
|
| 350 |
+
else:
|
| 351 |
+
attn_weights_reshaped = None
|
| 352 |
+
|
| 353 |
+
attn_output = (attn_weights @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
| 354 |
+
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
|
| 355 |
+
|
| 356 |
+
attn_output = self.proj(attn_output)
|
| 357 |
+
|
| 358 |
+
return (attn_output, attn_weights_reshaped)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# Copied from transformers.models.sam.modeling_sam.SamMLPBlock with SamMLPBlock->SegGptMlp
|
| 362 |
+
class SegGptMlp(nn.Module):
|
| 363 |
+
def __init__(self, config):
|
| 364 |
+
super().__init__()
|
| 365 |
+
self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
|
| 366 |
+
self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
|
| 367 |
+
self.act = ACT2FN[config.hidden_act]
|
| 368 |
+
|
| 369 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 370 |
+
hidden_states = self.lin1(hidden_states)
|
| 371 |
+
hidden_states = self.act(hidden_states)
|
| 372 |
+
hidden_states = self.lin2(hidden_states)
|
| 373 |
+
return hidden_states
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
| 377 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
| 378 |
+
"""
|
| 379 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 380 |
+
|
| 381 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 382 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 383 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 384 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 385 |
+
argument.
|
| 386 |
+
"""
|
| 387 |
+
if drop_prob == 0.0 or not training:
|
| 388 |
+
return input
|
| 389 |
+
keep_prob = 1 - drop_prob
|
| 390 |
+
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 391 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
| 392 |
+
random_tensor.floor_() # binarize
|
| 393 |
+
output = input.div(keep_prob) * random_tensor
|
| 394 |
+
return output
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->SegGpt
|
| 398 |
+
class SegGptDropPath(nn.Module):
|
| 399 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 400 |
+
|
| 401 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
| 402 |
+
super().__init__()
|
| 403 |
+
self.drop_prob = drop_prob
|
| 404 |
+
|
| 405 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 406 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
| 407 |
+
|
| 408 |
+
def extra_repr(self) -> str:
|
| 409 |
+
return "p={}".format(self.drop_prob)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class SegGptLayer(nn.Module):
|
| 413 |
+
def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None:
|
| 414 |
+
super().__init__()
|
| 415 |
+
self.attention = SegGptAttention(config)
|
| 416 |
+
self.mlp = SegGptMlp(config)
|
| 417 |
+
self.drop_path = SegGptDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
| 418 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 419 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 420 |
+
|
| 421 |
+
def forward(
|
| 422 |
+
self,
|
| 423 |
+
hidden_states: torch.Tensor,
|
| 424 |
+
ensemble_cond: int,
|
| 425 |
+
feature_ensemble: bool = False,
|
| 426 |
+
output_attentions: bool = False,
|
| 427 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 428 |
+
self_attention_outputs = self.attention(
|
| 429 |
+
self.layernorm_before(hidden_states), # in SegGpt, layernorm is applied before self-attention
|
| 430 |
+
output_attentions=output_attentions,
|
| 431 |
+
)
|
| 432 |
+
attention_output = self_attention_outputs[0]
|
| 433 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 434 |
+
|
| 435 |
+
if feature_ensemble and attention_output.shape[0] // 2 >= ensemble_cond:
|
| 436 |
+
prompt, inputs = attention_output.split(attention_output.shape[1] // 2, dim=1)
|
| 437 |
+
if ensemble_cond == 2:
|
| 438 |
+
num_prompts = attention_output.shape[0] // 2
|
| 439 |
+
inputs = inputs.reshape(2, num_prompts, -1)
|
| 440 |
+
inputs = inputs.mean(dim=1, keepdim=True).expand_as(inputs)
|
| 441 |
+
inputs = inputs.reshape(*prompt.shape)
|
| 442 |
+
else:
|
| 443 |
+
inputs = inputs.mean(dim=0, keepdim=True).expand_as(inputs)
|
| 444 |
+
attention_output = torch.cat([prompt, inputs], dim=1)
|
| 445 |
+
|
| 446 |
+
# first residual connection
|
| 447 |
+
hidden_states = self.drop_path(attention_output) + hidden_states
|
| 448 |
+
residual = hidden_states
|
| 449 |
+
|
| 450 |
+
hidden_states = self.layernorm_after(hidden_states)
|
| 451 |
+
hidden_states = self.mlp(hidden_states)
|
| 452 |
+
hidden_states = residual + self.drop_path(hidden_states)
|
| 453 |
+
|
| 454 |
+
outputs = (hidden_states,) + outputs
|
| 455 |
+
|
| 456 |
+
return outputs
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class SegGptEncoder(nn.Module):
|
| 460 |
+
def __init__(self, config: SegGptConfig) -> None:
|
| 461 |
+
super().__init__()
|
| 462 |
+
self.config = config
|
| 463 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
|
| 464 |
+
self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)])
|
| 465 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 466 |
+
self.gradient_checkpointing = False
|
| 467 |
+
|
| 468 |
+
def forward(
|
| 469 |
+
self,
|
| 470 |
+
hidden_states: torch.Tensor,
|
| 471 |
+
feature_ensemble: bool = False,
|
| 472 |
+
output_attentions: bool = False,
|
| 473 |
+
output_hidden_states: bool = False,
|
| 474 |
+
return_dict: bool = True,
|
| 475 |
+
) -> Union[tuple, SegGptEncoderOutput]:
|
| 476 |
+
all_hidden_states = () if output_hidden_states else None
|
| 477 |
+
all_self_attentions = () if output_attentions else None
|
| 478 |
+
intermediate_hidden_states = []
|
| 479 |
+
|
| 480 |
+
for i, layer_module in enumerate(self.layers):
|
| 481 |
+
if output_hidden_states:
|
| 482 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 483 |
+
|
| 484 |
+
# Condition to check if we have the appropriate number of prompts to ensemble
|
| 485 |
+
ensemble_cond = 2 if self.config.merge_index > i else 1
|
| 486 |
+
|
| 487 |
+
if self.gradient_checkpointing and self.training:
|
| 488 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 489 |
+
layer_module.__call__,
|
| 490 |
+
hidden_states,
|
| 491 |
+
ensemble_cond,
|
| 492 |
+
feature_ensemble,
|
| 493 |
+
output_attentions,
|
| 494 |
+
)
|
| 495 |
+
else:
|
| 496 |
+
layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions)
|
| 497 |
+
|
| 498 |
+
hidden_states = layer_outputs[0]
|
| 499 |
+
|
| 500 |
+
if i == self.config.merge_index:
|
| 501 |
+
hidden_states = (
|
| 502 |
+
hidden_states[: hidden_states.shape[0] // 2] + hidden_states[hidden_states.shape[0] // 2 :]
|
| 503 |
+
) * 0.5
|
| 504 |
+
|
| 505 |
+
if i in self.config.intermediate_hidden_state_indices:
|
| 506 |
+
intermediate_hidden_states.append(self.layernorm(hidden_states))
|
| 507 |
+
|
| 508 |
+
if output_attentions:
|
| 509 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 510 |
+
|
| 511 |
+
if output_hidden_states:
|
| 512 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 513 |
+
|
| 514 |
+
if not return_dict:
|
| 515 |
+
return tuple(
|
| 516 |
+
v
|
| 517 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions, intermediate_hidden_states]
|
| 518 |
+
if v is not None
|
| 519 |
+
)
|
| 520 |
+
return SegGptEncoderOutput(
|
| 521 |
+
last_hidden_state=hidden_states,
|
| 522 |
+
hidden_states=all_hidden_states,
|
| 523 |
+
attentions=all_self_attentions,
|
| 524 |
+
intermediate_hidden_states=intermediate_hidden_states,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->SegGpt
|
| 529 |
+
class SegGptLayerNorm(nn.Module):
|
| 530 |
+
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 531 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
|
| 532 |
+
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
|
| 533 |
+
"""
|
| 534 |
+
|
| 535 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 536 |
+
super().__init__()
|
| 537 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 538 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 539 |
+
self.eps = eps
|
| 540 |
+
self.data_format = data_format
|
| 541 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 542 |
+
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
|
| 543 |
+
self.normalized_shape = (normalized_shape,)
|
| 544 |
+
|
| 545 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 546 |
+
if self.data_format == "channels_last":
|
| 547 |
+
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 548 |
+
elif self.data_format == "channels_first":
|
| 549 |
+
input_dtype = x.dtype
|
| 550 |
+
x = x.float()
|
| 551 |
+
u = x.mean(1, keepdim=True)
|
| 552 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 553 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 554 |
+
x = x.to(dtype=input_dtype)
|
| 555 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 556 |
+
return x
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class SegGptDecoderHead(nn.Module):
|
| 560 |
+
def __init__(self, config):
|
| 561 |
+
super().__init__()
|
| 562 |
+
self.conv = nn.Conv2d(
|
| 563 |
+
config.decoder_hidden_size,
|
| 564 |
+
config.decoder_hidden_size,
|
| 565 |
+
kernel_size=3,
|
| 566 |
+
padding=1,
|
| 567 |
+
)
|
| 568 |
+
self.layernorm = SegGptLayerNorm(
|
| 569 |
+
normalized_shape=config.decoder_hidden_size, eps=config.layer_norm_eps, data_format="channels_first"
|
| 570 |
+
)
|
| 571 |
+
self.act_fct = ACT2FN[config.hidden_act]
|
| 572 |
+
self.head = nn.Conv2d(config.decoder_hidden_size, 3, kernel_size=1, bias=True) # decoder to patch
|
| 573 |
+
|
| 574 |
+
def forward(self, hidden_states: torch.FloatTensor):
|
| 575 |
+
hidden_states = self.conv(hidden_states)
|
| 576 |
+
hidden_states = self.layernorm(hidden_states)
|
| 577 |
+
hidden_states = self.act_fct(hidden_states)
|
| 578 |
+
hidden_states = self.head(hidden_states)
|
| 579 |
+
|
| 580 |
+
return hidden_states
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class SegGptDecoder(nn.Module):
|
| 584 |
+
def __init__(self, config):
|
| 585 |
+
super().__init__()
|
| 586 |
+
self.decoder_embed = nn.Linear(
|
| 587 |
+
config.hidden_size * len(config.intermediate_hidden_state_indices),
|
| 588 |
+
config.patch_size**2 * config.decoder_hidden_size,
|
| 589 |
+
bias=True,
|
| 590 |
+
)
|
| 591 |
+
self.decoder_pred = SegGptDecoderHead(config)
|
| 592 |
+
self.patch_size = config.patch_size
|
| 593 |
+
self.decoder_hidden_size = config.decoder_hidden_size
|
| 594 |
+
self.config = config
|
| 595 |
+
|
| 596 |
+
def _reshape_hidden_states(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 597 |
+
batch_size, patch_height, patch_width, _ = hidden_states.shape
|
| 598 |
+
hidden_states = hidden_states.reshape(
|
| 599 |
+
batch_size, patch_height, patch_width, self.patch_size, self.patch_size, self.decoder_hidden_size
|
| 600 |
+
)
|
| 601 |
+
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
|
| 602 |
+
hidden_states = hidden_states.reshape(
|
| 603 |
+
shape=(batch_size, -1, patch_height * self.patch_size, patch_width * self.patch_size)
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
return hidden_states
|
| 607 |
+
|
| 608 |
+
def forward(self, hidden_states: torch.FloatTensor):
|
| 609 |
+
hidden_states = self.decoder_embed(hidden_states)
|
| 610 |
+
hidden_states = self._reshape_hidden_states(hidden_states)
|
| 611 |
+
hidden_states = self.decoder_pred(hidden_states)
|
| 612 |
+
|
| 613 |
+
return hidden_states
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class SegGptPreTrainedModel(PreTrainedModel):
|
| 617 |
+
"""
|
| 618 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 619 |
+
models.
|
| 620 |
+
"""
|
| 621 |
+
|
| 622 |
+
config_class = SegGptConfig
|
| 623 |
+
base_model_prefix = "model"
|
| 624 |
+
main_input_name = "pixel_values"
|
| 625 |
+
supports_gradient_checkpointing = True
|
| 626 |
+
_no_split_modules = ["SegGptEmbeddings", "SegGptLayer"]
|
| 627 |
+
|
| 628 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
| 629 |
+
"""Initialize the weights"""
|
| 630 |
+
std = self.config.initializer_range
|
| 631 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 632 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
| 633 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
| 634 |
+
module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=std).to(
|
| 635 |
+
module.weight.dtype
|
| 636 |
+
)
|
| 637 |
+
if module.bias is not None:
|
| 638 |
+
module.bias.data.zero_()
|
| 639 |
+
elif isinstance(module, nn.LayerNorm):
|
| 640 |
+
module.bias.data.zero_()
|
| 641 |
+
module.weight.data.fill_(1.0)
|
| 642 |
+
elif isinstance(module, SegGptAttention):
|
| 643 |
+
module.rel_pos_h.data = nn.init.trunc_normal_(
|
| 644 |
+
module.rel_pos_h.data.to(torch.float32),
|
| 645 |
+
mean=0.0,
|
| 646 |
+
std=std,
|
| 647 |
+
).to(module.rel_pos_h.dtype)
|
| 648 |
+
|
| 649 |
+
module.rel_pos_w.data = nn.init.trunc_normal_(
|
| 650 |
+
module.rel_pos_w.data.to(torch.float32),
|
| 651 |
+
mean=0.0,
|
| 652 |
+
std=std,
|
| 653 |
+
).to(module.rel_pos_w.dtype)
|
| 654 |
+
|
| 655 |
+
elif isinstance(module, SegGptEmbeddings):
|
| 656 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
| 657 |
+
module.position_embeddings.data.to(torch.float32),
|
| 658 |
+
mean=0.0,
|
| 659 |
+
std=std,
|
| 660 |
+
).to(module.position_embeddings.dtype)
|
| 661 |
+
|
| 662 |
+
torch.nn.init.normal_(module.mask_token, std=std)
|
| 663 |
+
torch.nn.init.normal_(module.segment_token_input, std=std)
|
| 664 |
+
torch.nn.init.normal_(module.segment_token_prompt, std=std)
|
| 665 |
+
torch.nn.init.normal_(module.type_token_semantic, std=std)
|
| 666 |
+
torch.nn.init.normal_(module.type_token_instance, std=std)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
SEGGPT_START_DOCSTRING = r"""
|
| 670 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 671 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 672 |
+
behavior.
|
| 673 |
+
|
| 674 |
+
Parameters:
|
| 675 |
+
config ([`SegGptConfig`]): Model configuration class with all the parameters of the model.
|
| 676 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 677 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 678 |
+
"""
|
| 679 |
+
|
| 680 |
+
SEGGPT_INPUTS_DOCSTRING = r"""
|
| 681 |
+
Args:
|
| 682 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 683 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`]
|
| 684 |
+
for details.
|
| 685 |
+
|
| 686 |
+
prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 687 |
+
Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 688 |
+
[`SegGptImageProcessor.__call__`] for details.
|
| 689 |
+
|
| 690 |
+
prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 691 |
+
Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
|
| 692 |
+
details.
|
| 693 |
+
|
| 694 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
| 695 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 696 |
+
|
| 697 |
+
feature_ensemble (`bool`, *optional*):
|
| 698 |
+
Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
|
| 699 |
+
if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
|
| 700 |
+
be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
|
| 701 |
+
|
| 702 |
+
embedding_type (`str`, *optional*):
|
| 703 |
+
Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
|
| 704 |
+
instance or semantic.
|
| 705 |
+
|
| 706 |
+
output_attentions (`bool`, *optional*):
|
| 707 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 708 |
+
tensors for more detail.
|
| 709 |
+
output_hidden_states (`bool`, *optional*):
|
| 710 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 711 |
+
more detail.
|
| 712 |
+
return_dict (`bool`, *optional*):
|
| 713 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 714 |
+
"""
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
@add_start_docstrings(
|
| 718 |
+
"The bare SegGpt Model transformer outputting raw hidden-states without any specific head on top.",
|
| 719 |
+
SEGGPT_START_DOCSTRING,
|
| 720 |
+
)
|
| 721 |
+
class SegGptModel(SegGptPreTrainedModel):
|
| 722 |
+
def __init__(self, config: SegGptConfig):
|
| 723 |
+
super().__init__(config)
|
| 724 |
+
self.config = config
|
| 725 |
+
|
| 726 |
+
self.embeddings = SegGptEmbeddings(config)
|
| 727 |
+
self.encoder = SegGptEncoder(config)
|
| 728 |
+
|
| 729 |
+
# Initialize weights and apply final processing
|
| 730 |
+
self.post_init()
|
| 731 |
+
|
| 732 |
+
def get_input_embeddings(self) -> SegGptPatchEmbeddings:
|
| 733 |
+
return self.embeddings.patch_embeddings
|
| 734 |
+
|
| 735 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
| 736 |
+
"""
|
| 737 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 738 |
+
class PreTrainedModel
|
| 739 |
+
"""
|
| 740 |
+
for layer, heads in heads_to_prune.items():
|
| 741 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 742 |
+
|
| 743 |
+
@add_start_docstrings_to_model_forward(SEGGPT_INPUTS_DOCSTRING)
|
| 744 |
+
@replace_return_docstrings(output_type=SegGptEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
| 745 |
+
def forward(
|
| 746 |
+
self,
|
| 747 |
+
pixel_values: torch.Tensor,
|
| 748 |
+
prompt_pixel_values: torch.Tensor,
|
| 749 |
+
prompt_masks: torch.Tensor,
|
| 750 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 751 |
+
feature_ensemble: Optional[bool] = None,
|
| 752 |
+
embedding_type: Optional[str] = None,
|
| 753 |
+
labels: Optional[torch.FloatTensor] = None,
|
| 754 |
+
output_attentions: Optional[bool] = None,
|
| 755 |
+
output_hidden_states: Optional[bool] = None,
|
| 756 |
+
return_dict: Optional[bool] = None,
|
| 757 |
+
) -> Union[Tuple, SegGptEncoderOutput]:
|
| 758 |
+
r"""
|
| 759 |
+
labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
|
| 760 |
+
Ground truth mask for input images.
|
| 761 |
+
|
| 762 |
+
Returns:
|
| 763 |
+
|
| 764 |
+
Examples:
|
| 765 |
+
|
| 766 |
+
```python
|
| 767 |
+
>>> from transformers import SegGptImageProcessor, SegGptModel
|
| 768 |
+
>>> from PIL import Image
|
| 769 |
+
>>> import requests
|
| 770 |
+
|
| 771 |
+
>>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
|
| 772 |
+
>>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
|
| 773 |
+
>>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
|
| 774 |
+
|
| 775 |
+
>>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
|
| 776 |
+
>>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
|
| 777 |
+
>>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
|
| 778 |
+
|
| 779 |
+
>>> checkpoint = "BAAI/seggpt-vit-large"
|
| 780 |
+
>>> model = SegGptModel.from_pretrained(checkpoint)
|
| 781 |
+
>>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
|
| 782 |
+
|
| 783 |
+
>>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
|
| 784 |
+
|
| 785 |
+
>>> outputs = model(**inputs)
|
| 786 |
+
>>> list(outputs.last_hidden_state.shape)
|
| 787 |
+
[1, 56, 28, 1024]
|
| 788 |
+
```
|
| 789 |
+
"""
|
| 790 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 791 |
+
output_hidden_states = (
|
| 792 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 793 |
+
)
|
| 794 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 795 |
+
feature_ensemble = feature_ensemble if feature_ensemble is not None else False
|
| 796 |
+
|
| 797 |
+
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
|
| 798 |
+
pixel_values = pixel_values.to(expected_dtype)
|
| 799 |
+
prompt_pixel_values = prompt_pixel_values.to(expected_dtype)
|
| 800 |
+
|
| 801 |
+
# Prepare inputs
|
| 802 |
+
pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
|
| 803 |
+
prompt_pixel_values = (
|
| 804 |
+
torch.cat((prompt_masks, prompt_masks), dim=2)
|
| 805 |
+
if labels is None
|
| 806 |
+
else torch.cat((prompt_masks, labels), dim=2)
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
if bool_masked_pos is None and labels is not None:
|
| 810 |
+
logger.warning_once(
|
| 811 |
+
"Labels were provided, but bool_masked_pos were not. It will be set to default value. If you're training the model, make sure to provide a bool_masked_pos."
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
# We concat on height axis so SegGPT can handle as a single image, hence we need to mask the portion
|
| 815 |
+
# of the mask prompt pixels that will be destinated to the prediction as they don't add any information.
|
| 816 |
+
# This is only the case for inference. In training, the model concat of prompt mask and label is masked
|
| 817 |
+
# and reconstructed together (In-Context Painting).
|
| 818 |
+
if bool_masked_pos is None:
|
| 819 |
+
num_patches = self.embeddings.patch_embeddings.num_patches
|
| 820 |
+
bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
|
| 821 |
+
bool_masked_pos_ones = torch.ones(
|
| 822 |
+
num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
|
| 823 |
+
)
|
| 824 |
+
bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
|
| 825 |
+
bool_masked_pos = bool_masked_pos.unsqueeze(0)
|
| 826 |
+
|
| 827 |
+
embedding_output = self.embeddings(
|
| 828 |
+
pixel_values, prompt_pixel_values, embedding_type=embedding_type, bool_masked_pos=bool_masked_pos
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
encoder_outputs = self.encoder(
|
| 832 |
+
embedding_output,
|
| 833 |
+
feature_ensemble=feature_ensemble,
|
| 834 |
+
output_attentions=output_attentions,
|
| 835 |
+
output_hidden_states=output_hidden_states,
|
| 836 |
+
return_dict=return_dict,
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
return encoder_outputs
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
def patchify(tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
|
| 843 |
+
batch_size, num_channels, height, width = tensor.shape
|
| 844 |
+
patch_height = height // patch_size
|
| 845 |
+
patch_width = width // patch_size
|
| 846 |
+
|
| 847 |
+
tensor = tensor.reshape(shape=(batch_size, num_channels, patch_height, patch_size, patch_width, patch_size))
|
| 848 |
+
tensor = tensor.permute(0, 2, 4, 3, 5, 1)
|
| 849 |
+
tensor = tensor.reshape(shape=(batch_size, patch_height * patch_width, patch_size**2 * 3))
|
| 850 |
+
|
| 851 |
+
return tensor
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
|
| 855 |
+
batch_size = tensor.shape[0]
|
| 856 |
+
patch_size = int((tensor.shape[-1] / 3) ** 0.5)
|
| 857 |
+
if patch_height * patch_width != tensor.shape[1]:
|
| 858 |
+
raise ValueError(
|
| 859 |
+
f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
|
| 863 |
+
tensor = tensor.permute(0, 5, 1, 3, 2, 4)
|
| 864 |
+
tensor = tensor.reshape(shape=(batch_size, 3, patch_height * patch_size, patch_width * patch_size))
|
| 865 |
+
|
| 866 |
+
return tensor
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class SegGptLoss(nn.Module):
|
| 870 |
+
def __init__(self, config):
|
| 871 |
+
super().__init__()
|
| 872 |
+
self.beta = config.beta
|
| 873 |
+
self.patch_size = config.patch_size
|
| 874 |
+
|
| 875 |
+
def forward(
|
| 876 |
+
self,
|
| 877 |
+
prompt_masks: torch.FloatTensor,
|
| 878 |
+
pred_masks: torch.FloatTensor,
|
| 879 |
+
labels: torch.FloatTensor,
|
| 880 |
+
bool_masked_pos: torch.BoolTensor,
|
| 881 |
+
):
|
| 882 |
+
"""Computes the L1 loss between the predicted masks and the ground truth masks.
|
| 883 |
+
|
| 884 |
+
Args:
|
| 885 |
+
prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 886 |
+
Pixel values from mask prompt.
|
| 887 |
+
|
| 888 |
+
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
|
| 889 |
+
Predicted masks.
|
| 890 |
+
|
| 891 |
+
labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 892 |
+
Ground truth mask for input images.
|
| 893 |
+
|
| 894 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
| 895 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 896 |
+
|
| 897 |
+
Returns:
|
| 898 |
+
`torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
|
| 899 |
+
"""
|
| 900 |
+
ground_truth = torch.cat((prompt_masks, labels), dim=2)
|
| 901 |
+
|
| 902 |
+
mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
|
| 903 |
+
mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)
|
| 904 |
+
|
| 905 |
+
loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
|
| 906 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
| 907 |
+
|
| 908 |
+
return loss
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@add_start_docstrings(
|
| 912 |
+
"SegGpt model with a decoder on top for one-shot image segmentation.",
|
| 913 |
+
SEGGPT_START_DOCSTRING,
|
| 914 |
+
)
|
| 915 |
+
class SegGptForImageSegmentation(SegGptPreTrainedModel):
|
| 916 |
+
def __init__(self, config: SegGptConfig):
|
| 917 |
+
super().__init__(config)
|
| 918 |
+
self.config = config
|
| 919 |
+
|
| 920 |
+
self.model = SegGptModel(config)
|
| 921 |
+
self.decoder = SegGptDecoder(config)
|
| 922 |
+
|
| 923 |
+
# Initialize weights and apply final processing
|
| 924 |
+
self.post_init()
|
| 925 |
+
|
| 926 |
+
@add_start_docstrings_to_model_forward(SEGGPT_INPUTS_DOCSTRING)
|
| 927 |
+
@replace_return_docstrings(output_type=SegGptImageSegmentationOutput, config_class=_CONFIG_FOR_DOC)
|
| 928 |
+
def forward(
|
| 929 |
+
self,
|
| 930 |
+
pixel_values: torch.Tensor,
|
| 931 |
+
prompt_pixel_values: torch.Tensor,
|
| 932 |
+
prompt_masks: torch.Tensor,
|
| 933 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 934 |
+
feature_ensemble: Optional[bool] = None,
|
| 935 |
+
embedding_type: Optional[str] = None,
|
| 936 |
+
labels: Optional[torch.FloatTensor] = None,
|
| 937 |
+
output_attentions: Optional[bool] = None,
|
| 938 |
+
output_hidden_states: Optional[bool] = None,
|
| 939 |
+
return_dict: Optional[bool] = None,
|
| 940 |
+
) -> Union[Tuple, SegGptImageSegmentationOutput]:
|
| 941 |
+
r"""
|
| 942 |
+
labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
|
| 943 |
+
Ground truth mask for input images.
|
| 944 |
+
|
| 945 |
+
Returns:
|
| 946 |
+
|
| 947 |
+
Examples:
|
| 948 |
+
|
| 949 |
+
```python
|
| 950 |
+
>>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation
|
| 951 |
+
>>> from PIL import Image
|
| 952 |
+
>>> import requests
|
| 953 |
+
|
| 954 |
+
>>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
|
| 955 |
+
>>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
|
| 956 |
+
>>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
|
| 957 |
+
|
| 958 |
+
>>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
|
| 959 |
+
>>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
|
| 960 |
+
>>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
|
| 961 |
+
|
| 962 |
+
>>> checkpoint = "BAAI/seggpt-vit-large"
|
| 963 |
+
>>> model = SegGptForImageSegmentation.from_pretrained(checkpoint)
|
| 964 |
+
>>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
|
| 965 |
+
|
| 966 |
+
>>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
|
| 967 |
+
>>> outputs = model(**inputs)
|
| 968 |
+
>>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0]
|
| 969 |
+
>>> print(list(result.shape))
|
| 970 |
+
[170, 297]
|
| 971 |
+
```
|
| 972 |
+
"""
|
| 973 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 974 |
+
output_hidden_states = (
|
| 975 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 976 |
+
)
|
| 977 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 978 |
+
|
| 979 |
+
if bool_masked_pos is None:
|
| 980 |
+
num_patches = self.model.embeddings.patch_embeddings.num_patches
|
| 981 |
+
bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
|
| 982 |
+
bool_masked_pos_ones = torch.ones(
|
| 983 |
+
num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
|
| 984 |
+
)
|
| 985 |
+
bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
|
| 986 |
+
bool_masked_pos = bool_masked_pos.unsqueeze(0)
|
| 987 |
+
|
| 988 |
+
outputs = self.model(
|
| 989 |
+
pixel_values=pixel_values,
|
| 990 |
+
prompt_pixel_values=prompt_pixel_values,
|
| 991 |
+
prompt_masks=prompt_masks,
|
| 992 |
+
bool_masked_pos=bool_masked_pos,
|
| 993 |
+
feature_ensemble=feature_ensemble,
|
| 994 |
+
embedding_type=embedding_type,
|
| 995 |
+
labels=labels,
|
| 996 |
+
output_attentions=output_attentions,
|
| 997 |
+
output_hidden_states=output_hidden_states,
|
| 998 |
+
return_dict=return_dict,
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
intermediate_hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[-1]
|
| 1002 |
+
intermediate_hidden_states = torch.cat(intermediate_hidden_states, dim=-1)
|
| 1003 |
+
pred_masks = self.decoder(intermediate_hidden_states)
|
| 1004 |
+
|
| 1005 |
+
loss = None
|
| 1006 |
+
if labels is not None:
|
| 1007 |
+
loss_fn = SegGptLoss(self.config)
|
| 1008 |
+
loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)
|
| 1009 |
+
|
| 1010 |
+
if not return_dict:
|
| 1011 |
+
output = (pred_masks,)
|
| 1012 |
+
if output_hidden_states:
|
| 1013 |
+
output = output + (outputs[1],)
|
| 1014 |
+
|
| 1015 |
+
if output_attentions:
|
| 1016 |
+
idx = 2 if output_hidden_states else 1
|
| 1017 |
+
output = output + (outputs[idx],)
|
| 1018 |
+
|
| 1019 |
+
if loss is not None:
|
| 1020 |
+
output = (loss,) + output
|
| 1021 |
+
return output
|
| 1022 |
+
|
| 1023 |
+
return SegGptImageSegmentationOutput(
|
| 1024 |
+
loss=loss,
|
| 1025 |
+
pred_masks=pred_masks,
|
| 1026 |
+
hidden_states=outputs.hidden_states,
|
| 1027 |
+
attentions=outputs.attentions,
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
__all__ = ["SegGptModel", "SegGptPreTrainedModel", "SegGptForImageSegmentation"]
|
docs/transformers/build/lib/transformers/models/sew/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_sew import *
|
| 22 |
+
from .modeling_sew import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/sew/configuration_sew.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SEW model configuration"""
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
import operator
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...utils import logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SEWConfig(PretrainedConfig):
|
| 28 |
+
r"""
|
| 29 |
+
This is the configuration class to store the configuration of a [`SEWModel`]. It is used to instantiate a SEW model
|
| 30 |
+
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 31 |
+
defaults will yield a similar configuration to that of the SEW
|
| 32 |
+
[asapp/sew-tiny-100k](https://huggingface.co/asapp/sew-tiny-100k) architecture.
|
| 33 |
+
|
| 34 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 35 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
vocab_size (`int`, *optional*, defaults to 32):
|
| 40 |
+
Vocabulary size of the SEW model. Defines the number of different tokens that can be represented by the
|
| 41 |
+
`inputs_ids` passed when calling [`SEW`].
|
| 42 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 43 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 44 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 45 |
+
Number of hidden layers in the Transformer encoder.
|
| 46 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 47 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 48 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 49 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 50 |
+
squeeze_factor (`int`, *optional*, defaults to 2):
|
| 51 |
+
Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
|
| 52 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 53 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 54 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 55 |
+
hidden_dropout (`float`, *optional*, defaults to 0.1):
|
| 56 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 57 |
+
activation_dropout (`float`, *optional*, defaults to 0.1):
|
| 58 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 59 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 60 |
+
The dropout ratio for the attention probabilities.
|
| 61 |
+
final_dropout (`float`, *optional*, defaults to 0.1):
|
| 62 |
+
The dropout probability for the final projection layer of [`SEWForCTC`].
|
| 63 |
+
layerdrop (`float`, *optional*, defaults to 0.1):
|
| 64 |
+
The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
|
| 65 |
+
details.
|
| 66 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 67 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 68 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 69 |
+
The epsilon used by the layer normalization layers.
|
| 70 |
+
feat_extract_norm (`str`, *optional*, defaults to `"group"`):
|
| 71 |
+
The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
|
| 72 |
+
normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
|
| 73 |
+
convolutional layers.
|
| 74 |
+
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
|
| 75 |
+
The dropout probability for output of the feature encoder.
|
| 76 |
+
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
|
| 77 |
+
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
|
| 78 |
+
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 79 |
+
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
|
| 80 |
+
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
|
| 81 |
+
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
|
| 82 |
+
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
|
| 83 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
|
| 84 |
+
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
|
| 85 |
+
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
|
| 86 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
|
| 87 |
+
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
|
| 88 |
+
*conv_dim*.
|
| 89 |
+
conv_bias (`bool`, *optional*, defaults to `False`):
|
| 90 |
+
Whether the 1D convolutional layers have a bias.
|
| 91 |
+
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
|
| 92 |
+
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
|
| 93 |
+
embeddings layer.
|
| 94 |
+
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
| 95 |
+
Number of groups of 1D convolutional positional embeddings layer.
|
| 96 |
+
apply_spec_augment (`bool`, *optional*, defaults to `True`):
|
| 97 |
+
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
|
| 98 |
+
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
| 99 |
+
Recognition](https://arxiv.org/abs/1904.08779).
|
| 100 |
+
mask_time_prob (`float`, *optional*, defaults to 0.05):
|
| 101 |
+
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
|
| 102 |
+
procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
|
| 103 |
+
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
|
| 104 |
+
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
|
| 105 |
+
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
|
| 106 |
+
mask_time_length (`int`, *optional*, defaults to 10):
|
| 107 |
+
Length of vector span along the time axis.
|
| 108 |
+
mask_time_min_masks (`int`, *optional*, defaults to 2),:
|
| 109 |
+
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
|
| 110 |
+
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
|
| 111 |
+
mask_time_min_masks''
|
| 112 |
+
mask_feature_prob (`float`, *optional*, defaults to 0.0):
|
| 113 |
+
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
|
| 114 |
+
masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
|
| 115 |
+
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
|
| 116 |
+
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
|
| 117 |
+
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
|
| 118 |
+
True`.
|
| 119 |
+
mask_feature_length (`int`, *optional*, defaults to 10):
|
| 120 |
+
Length of vector span along the feature axis.
|
| 121 |
+
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
|
| 122 |
+
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
| 123 |
+
step, irrespectively of `mask_feature_prob`. Only relevant if
|
| 124 |
+
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
|
| 125 |
+
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
|
| 126 |
+
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
|
| 127 |
+
instance of [`SEWForCTC`].
|
| 128 |
+
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
|
| 129 |
+
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
|
| 130 |
+
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
|
| 131 |
+
of [`SEWForCTC`].
|
| 132 |
+
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
|
| 133 |
+
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
| 134 |
+
instance of [`Wav2Vec2ForSequenceClassification`].
|
| 135 |
+
classifier_proj_size (`int`, *optional*, defaults to 256):
|
| 136 |
+
Dimensionality of the projection before token mean-pooling for classification.
|
| 137 |
+
|
| 138 |
+
Example:
|
| 139 |
+
|
| 140 |
+
```python
|
| 141 |
+
>>> from transformers import SEWConfig, SEWModel
|
| 142 |
+
|
| 143 |
+
>>> # Initializing a SEW asapp/sew-tiny-100k style configuration
|
| 144 |
+
>>> configuration = SEWConfig()
|
| 145 |
+
|
| 146 |
+
>>> # Initializing a model (with random weights) from the asapp/sew-tiny-100k style configuration
|
| 147 |
+
>>> model = SEWModel(configuration)
|
| 148 |
+
|
| 149 |
+
>>> # Accessing the model configuration
|
| 150 |
+
>>> configuration = model.config
|
| 151 |
+
```"""
|
| 152 |
+
|
| 153 |
+
model_type = "sew"
|
| 154 |
+
|
| 155 |
+
def __init__(
|
| 156 |
+
self,
|
| 157 |
+
vocab_size=32,
|
| 158 |
+
hidden_size=768,
|
| 159 |
+
num_hidden_layers=12,
|
| 160 |
+
num_attention_heads=12,
|
| 161 |
+
intermediate_size=3072,
|
| 162 |
+
squeeze_factor=2,
|
| 163 |
+
hidden_act="gelu",
|
| 164 |
+
hidden_dropout=0.1,
|
| 165 |
+
activation_dropout=0.1,
|
| 166 |
+
attention_dropout=0.1,
|
| 167 |
+
feat_proj_dropout=0.0,
|
| 168 |
+
final_dropout=0.1,
|
| 169 |
+
layerdrop=0.1,
|
| 170 |
+
initializer_range=0.02,
|
| 171 |
+
layer_norm_eps=1e-5,
|
| 172 |
+
feat_extract_norm="group",
|
| 173 |
+
feat_extract_activation="gelu",
|
| 174 |
+
conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
|
| 175 |
+
conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1),
|
| 176 |
+
conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1),
|
| 177 |
+
conv_bias=False,
|
| 178 |
+
num_conv_pos_embeddings=128,
|
| 179 |
+
num_conv_pos_embedding_groups=16,
|
| 180 |
+
apply_spec_augment=True,
|
| 181 |
+
mask_time_prob=0.05,
|
| 182 |
+
mask_time_length=10,
|
| 183 |
+
mask_time_min_masks=2,
|
| 184 |
+
mask_feature_prob=0.0,
|
| 185 |
+
mask_feature_length=10,
|
| 186 |
+
mask_feature_min_masks=0,
|
| 187 |
+
ctc_loss_reduction="mean",
|
| 188 |
+
ctc_zero_infinity=False,
|
| 189 |
+
use_weighted_layer_sum=False,
|
| 190 |
+
classifier_proj_size=256,
|
| 191 |
+
pad_token_id=0,
|
| 192 |
+
bos_token_id=1,
|
| 193 |
+
eos_token_id=2,
|
| 194 |
+
**kwargs,
|
| 195 |
+
):
|
| 196 |
+
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
| 197 |
+
self.hidden_size = hidden_size
|
| 198 |
+
self.feat_extract_norm = feat_extract_norm
|
| 199 |
+
self.feat_extract_activation = feat_extract_activation
|
| 200 |
+
self.conv_dim = list(conv_dim)
|
| 201 |
+
self.conv_stride = list(conv_stride)
|
| 202 |
+
self.conv_kernel = list(conv_kernel)
|
| 203 |
+
self.conv_bias = conv_bias
|
| 204 |
+
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
| 205 |
+
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
| 206 |
+
self.num_feat_extract_layers = len(self.conv_dim)
|
| 207 |
+
self.num_hidden_layers = num_hidden_layers
|
| 208 |
+
self.intermediate_size = intermediate_size
|
| 209 |
+
self.squeeze_factor = squeeze_factor
|
| 210 |
+
self.hidden_act = hidden_act
|
| 211 |
+
self.num_attention_heads = num_attention_heads
|
| 212 |
+
self.hidden_dropout = hidden_dropout
|
| 213 |
+
self.attention_dropout = attention_dropout
|
| 214 |
+
self.activation_dropout = activation_dropout
|
| 215 |
+
self.feat_proj_dropout = feat_proj_dropout
|
| 216 |
+
self.final_dropout = final_dropout
|
| 217 |
+
self.layerdrop = layerdrop
|
| 218 |
+
self.layer_norm_eps = layer_norm_eps
|
| 219 |
+
self.initializer_range = initializer_range
|
| 220 |
+
self.vocab_size = vocab_size
|
| 221 |
+
|
| 222 |
+
if (
|
| 223 |
+
(len(self.conv_stride) != self.num_feat_extract_layers)
|
| 224 |
+
or (len(self.conv_kernel) != self.num_feat_extract_layers)
|
| 225 |
+
or (len(self.conv_dim) != self.num_feat_extract_layers)
|
| 226 |
+
):
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"Configuration for convolutional layers is incorrect. "
|
| 229 |
+
"It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
|
| 230 |
+
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
|
| 231 |
+
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
| 235 |
+
self.apply_spec_augment = apply_spec_augment
|
| 236 |
+
self.mask_time_prob = mask_time_prob
|
| 237 |
+
self.mask_time_length = mask_time_length
|
| 238 |
+
self.mask_time_min_masks = mask_time_min_masks
|
| 239 |
+
self.mask_feature_prob = mask_feature_prob
|
| 240 |
+
self.mask_feature_length = mask_feature_length
|
| 241 |
+
self.mask_feature_min_masks = mask_feature_min_masks
|
| 242 |
+
|
| 243 |
+
# ctc loss
|
| 244 |
+
self.ctc_loss_reduction = ctc_loss_reduction
|
| 245 |
+
self.ctc_zero_infinity = ctc_zero_infinity
|
| 246 |
+
|
| 247 |
+
# sequence classification
|
| 248 |
+
self.use_weighted_layer_sum = use_weighted_layer_sum
|
| 249 |
+
self.classifier_proj_size = classifier_proj_size
|
| 250 |
+
|
| 251 |
+
@property
|
| 252 |
+
def inputs_to_logits_ratio(self):
|
| 253 |
+
return functools.reduce(operator.mul, self.conv_stride, 1)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
__all__ = ["SEWConfig"]
|
docs/transformers/build/lib/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert SEW checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import fairseq
|
| 22 |
+
import torch
|
| 23 |
+
from fairseq.data import Dictionary
|
| 24 |
+
|
| 25 |
+
# Register SEW's fairseq modules
|
| 26 |
+
from sew_asapp import tasks # noqa: F401
|
| 27 |
+
|
| 28 |
+
from transformers import (
|
| 29 |
+
SEWConfig,
|
| 30 |
+
SEWForCTC,
|
| 31 |
+
SEWModel,
|
| 32 |
+
Wav2Vec2CTCTokenizer,
|
| 33 |
+
Wav2Vec2FeatureExtractor,
|
| 34 |
+
Wav2Vec2Processor,
|
| 35 |
+
logging,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logging.set_verbosity_info()
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
MAPPING = {
|
| 43 |
+
"post_extract_proj": "feature_projection",
|
| 44 |
+
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
| 45 |
+
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
| 46 |
+
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
| 47 |
+
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
| 48 |
+
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
| 49 |
+
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
| 50 |
+
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
| 51 |
+
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
| 52 |
+
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
| 53 |
+
"encoder.upsample.0": "encoder.upsample.projection",
|
| 54 |
+
"encoder.layer_norm": "encoder.layer_norm",
|
| 55 |
+
"w2v_model.layer_norm": "layer_norm",
|
| 56 |
+
"w2v_encoder.proj": "lm_head",
|
| 57 |
+
"mask_emb": "masked_spec_embed",
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
| 62 |
+
for attribute in key.split("."):
|
| 63 |
+
hf_pointer = getattr(hf_pointer, attribute)
|
| 64 |
+
|
| 65 |
+
if weight_type is not None:
|
| 66 |
+
hf_shape = getattr(hf_pointer, weight_type).shape
|
| 67 |
+
else:
|
| 68 |
+
hf_shape = hf_pointer.shape
|
| 69 |
+
|
| 70 |
+
assert hf_shape == value.shape, (
|
| 71 |
+
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
| 72 |
+
f" {value.shape} for {full_name}"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if weight_type == "weight":
|
| 76 |
+
hf_pointer.weight.data = value
|
| 77 |
+
elif weight_type == "weight_g":
|
| 78 |
+
hf_pointer.weight_g.data = value
|
| 79 |
+
elif weight_type == "weight_v":
|
| 80 |
+
hf_pointer.weight_v.data = value
|
| 81 |
+
elif weight_type == "bias":
|
| 82 |
+
hf_pointer.bias.data = value
|
| 83 |
+
else:
|
| 84 |
+
hf_pointer.data = value
|
| 85 |
+
|
| 86 |
+
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
| 90 |
+
unused_weights = []
|
| 91 |
+
fairseq_dict = fairseq_model.state_dict()
|
| 92 |
+
|
| 93 |
+
feature_extractor = hf_model.sew.feature_extractor if is_finetuned else hf_model.feature_extractor
|
| 94 |
+
|
| 95 |
+
for name, value in fairseq_dict.items():
|
| 96 |
+
is_used = False
|
| 97 |
+
if "conv_layers" in name:
|
| 98 |
+
load_conv_layer(
|
| 99 |
+
name,
|
| 100 |
+
value,
|
| 101 |
+
feature_extractor,
|
| 102 |
+
unused_weights,
|
| 103 |
+
hf_model.config.feat_extract_norm == "group",
|
| 104 |
+
)
|
| 105 |
+
is_used = True
|
| 106 |
+
else:
|
| 107 |
+
for key, mapped_key in MAPPING.items():
|
| 108 |
+
mapped_key = "sew." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
| 109 |
+
|
| 110 |
+
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
| 111 |
+
is_used = True
|
| 112 |
+
if "*" in mapped_key:
|
| 113 |
+
layer_index = name.split(key)[0].split(".")[-2]
|
| 114 |
+
mapped_key = mapped_key.replace("*", layer_index)
|
| 115 |
+
if "weight_g" in name:
|
| 116 |
+
weight_type = "weight_g"
|
| 117 |
+
elif "weight_v" in name:
|
| 118 |
+
weight_type = "weight_v"
|
| 119 |
+
elif "weight" in name:
|
| 120 |
+
weight_type = "weight"
|
| 121 |
+
elif "bias" in name:
|
| 122 |
+
weight_type = "bias"
|
| 123 |
+
else:
|
| 124 |
+
weight_type = None
|
| 125 |
+
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
| 126 |
+
continue
|
| 127 |
+
if not is_used:
|
| 128 |
+
unused_weights.append(name)
|
| 129 |
+
|
| 130 |
+
logger.warning(f"Unused weights: {unused_weights}")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
| 134 |
+
name = full_name.split("conv_layers.")[-1]
|
| 135 |
+
items = name.split(".")
|
| 136 |
+
layer_id = int(items[0])
|
| 137 |
+
type_id = int(items[1])
|
| 138 |
+
|
| 139 |
+
if type_id == 0:
|
| 140 |
+
if "bias" in name:
|
| 141 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
|
| 142 |
+
f"{full_name} has size {value.shape}, but"
|
| 143 |
+
f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
|
| 144 |
+
)
|
| 145 |
+
feature_extractor.conv_layers[layer_id].conv.bias.data = value
|
| 146 |
+
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
| 147 |
+
elif "weight" in name:
|
| 148 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
|
| 149 |
+
f"{full_name} has size {value.shape}, but"
|
| 150 |
+
f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
|
| 151 |
+
)
|
| 152 |
+
feature_extractor.conv_layers[layer_id].conv.weight.data = value
|
| 153 |
+
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
| 154 |
+
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
|
| 155 |
+
if "bias" in name:
|
| 156 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
|
| 157 |
+
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
|
| 158 |
+
" found."
|
| 159 |
+
)
|
| 160 |
+
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
|
| 161 |
+
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
| 162 |
+
elif "weight" in name:
|
| 163 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
|
| 164 |
+
f"{full_name} has size {value.shape}, but"
|
| 165 |
+
f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
|
| 166 |
+
)
|
| 167 |
+
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
|
| 168 |
+
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
| 169 |
+
else:
|
| 170 |
+
unused_weights.append(full_name)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def convert_config(model, is_finetuned):
|
| 174 |
+
config = SEWConfig()
|
| 175 |
+
if is_finetuned:
|
| 176 |
+
fs_config = model.w2v_encoder.w2v_model.cfg
|
| 177 |
+
else:
|
| 178 |
+
fs_config = model.cfg
|
| 179 |
+
|
| 180 |
+
config.conv_bias = fs_config.conv_bias
|
| 181 |
+
conv_layers = eval(fs_config.conv_feature_layers)
|
| 182 |
+
config.conv_dim = [x[0] for x in conv_layers]
|
| 183 |
+
config.conv_kernel = [x[1] for x in conv_layers]
|
| 184 |
+
config.conv_stride = [x[2] for x in conv_layers]
|
| 185 |
+
config.feat_extract_activation = "gelu"
|
| 186 |
+
config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
|
| 187 |
+
config.final_dropout = 0.0
|
| 188 |
+
config.hidden_act = fs_config.activation_fn.name
|
| 189 |
+
config.hidden_size = fs_config.encoder_embed_dim
|
| 190 |
+
config.initializer_range = 0.02
|
| 191 |
+
config.intermediate_size = fs_config.encoder_ffn_embed_dim
|
| 192 |
+
config.layer_norm_eps = 1e-5
|
| 193 |
+
config.layerdrop = fs_config.encoder_layerdrop
|
| 194 |
+
config.num_attention_heads = fs_config.encoder_attention_heads
|
| 195 |
+
config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
|
| 196 |
+
config.num_conv_pos_embeddings = fs_config.conv_pos
|
| 197 |
+
config.num_feat_extract_layers = len(conv_layers)
|
| 198 |
+
config.num_hidden_layers = fs_config.encoder_layers
|
| 199 |
+
config.squeeze_factor = fs_config.squeeze_factor
|
| 200 |
+
|
| 201 |
+
# take care of any params that are overridden by the Wav2VecCtc model
|
| 202 |
+
if is_finetuned:
|
| 203 |
+
fs_config = model.cfg
|
| 204 |
+
config.final_dropout = fs_config.final_dropout
|
| 205 |
+
config.layerdrop = fs_config.layerdrop
|
| 206 |
+
config.activation_dropout = fs_config.activation_dropout
|
| 207 |
+
config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
|
| 208 |
+
config.attention_dropout = fs_config.attention_dropout
|
| 209 |
+
config.feat_proj_dropout = fs_config.dropout_input
|
| 210 |
+
config.hidden_dropout = fs_config.dropout
|
| 211 |
+
config.mask_feature_length = fs_config.mask_channel_length
|
| 212 |
+
config.mask_feature_prob = fs_config.mask_channel_prob
|
| 213 |
+
config.mask_time_length = fs_config.mask_length
|
| 214 |
+
config.mask_time_prob = fs_config.mask_prob
|
| 215 |
+
|
| 216 |
+
config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
|
| 217 |
+
config.tokenizer_class = "Wav2Vec2CTCTokenizer"
|
| 218 |
+
|
| 219 |
+
return config
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@torch.no_grad()
|
| 223 |
+
def convert_sew_checkpoint(
|
| 224 |
+
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
|
| 225 |
+
):
|
| 226 |
+
"""
|
| 227 |
+
Copy/paste/tweak model's weights to transformers design.
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
if is_finetuned:
|
| 231 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
| 232 |
+
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
|
| 233 |
+
)
|
| 234 |
+
else:
|
| 235 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
|
| 236 |
+
|
| 237 |
+
if config_path is not None:
|
| 238 |
+
config = SEWConfig.from_pretrained(config_path)
|
| 239 |
+
else:
|
| 240 |
+
config = convert_config(model[0], is_finetuned)
|
| 241 |
+
model = model[0].eval()
|
| 242 |
+
|
| 243 |
+
return_attention_mask = True if config.feat_extract_norm == "layer" else False
|
| 244 |
+
feature_extractor = Wav2Vec2FeatureExtractor(
|
| 245 |
+
feature_size=1,
|
| 246 |
+
sampling_rate=16000,
|
| 247 |
+
padding_value=0,
|
| 248 |
+
do_normalize=True,
|
| 249 |
+
return_attention_mask=return_attention_mask,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
if is_finetuned:
|
| 253 |
+
if dict_path:
|
| 254 |
+
target_dict = Dictionary.load(dict_path)
|
| 255 |
+
|
| 256 |
+
# important change bos & pad token id since CTC symbol is <pad> and
|
| 257 |
+
# not <s> as in fairseq
|
| 258 |
+
target_dict.indices[target_dict.bos_word] = target_dict.pad_index
|
| 259 |
+
target_dict.indices[target_dict.pad_word] = target_dict.bos_index
|
| 260 |
+
config.bos_token_id = target_dict.pad_index
|
| 261 |
+
config.pad_token_id = target_dict.bos_index
|
| 262 |
+
config.eos_token_id = target_dict.eos_index
|
| 263 |
+
config.vocab_size = len(target_dict.symbols)
|
| 264 |
+
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
|
| 265 |
+
if not os.path.isdir(pytorch_dump_folder_path):
|
| 266 |
+
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
|
| 267 |
+
return
|
| 268 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
| 269 |
+
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
|
| 270 |
+
json.dump(target_dict.indices, vocab_handle)
|
| 271 |
+
tokenizer = Wav2Vec2CTCTokenizer(
|
| 272 |
+
vocab_path,
|
| 273 |
+
unk_token=target_dict.unk_word,
|
| 274 |
+
pad_token=target_dict.pad_word,
|
| 275 |
+
bos_token=target_dict.bos_word,
|
| 276 |
+
eos_token=target_dict.eos_word,
|
| 277 |
+
word_delimiter_token="|",
|
| 278 |
+
do_lower_case=False,
|
| 279 |
+
)
|
| 280 |
+
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
| 281 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 282 |
+
|
| 283 |
+
hf_model = SEWForCTC(config)
|
| 284 |
+
else:
|
| 285 |
+
hf_model = SEWModel(config)
|
| 286 |
+
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
| 287 |
+
|
| 288 |
+
recursively_load_weights(model, hf_model, is_finetuned)
|
| 289 |
+
|
| 290 |
+
hf_model.save_pretrained(pytorch_dump_folder_path)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
if __name__ == "__main__":
|
| 294 |
+
parser = argparse.ArgumentParser()
|
| 295 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
| 296 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
| 297 |
+
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
| 298 |
+
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--is_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
| 301 |
+
)
|
| 302 |
+
args = parser.parse_args()
|
| 303 |
+
convert_sew_checkpoint(
|
| 304 |
+
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, args.is_finetuned
|
| 305 |
+
)
|
docs/transformers/build/lib/transformers/models/sew/modeling_sew.py
ADDED
|
@@ -0,0 +1,1498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch SEW model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import warnings
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
|
| 27 |
+
from ...activations import ACT2FN
|
| 28 |
+
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 29 |
+
from ...integrations.fsdp import is_fsdp_managed_module
|
| 30 |
+
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
| 31 |
+
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
| 32 |
+
from ...modeling_utils import PreTrainedModel
|
| 33 |
+
from ...utils import (
|
| 34 |
+
add_code_sample_docstrings,
|
| 35 |
+
add_start_docstrings,
|
| 36 |
+
add_start_docstrings_to_model_forward,
|
| 37 |
+
logging,
|
| 38 |
+
)
|
| 39 |
+
from .configuration_sew import SEWConfig
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if is_flash_attn_available():
|
| 43 |
+
from ...modeling_flash_attention_utils import _flash_attention_forward
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_HIDDEN_STATES_START_POSITION = 1
|
| 50 |
+
|
| 51 |
+
# General docstring
|
| 52 |
+
_CONFIG_FOR_DOC = "SEWConfig"
|
| 53 |
+
|
| 54 |
+
# Base docstring
|
| 55 |
+
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
|
| 56 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 512]
|
| 57 |
+
|
| 58 |
+
# CTC docstring
|
| 59 |
+
_CTC_EXPECTED_OUTPUT = (
|
| 60 |
+
"'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
|
| 61 |
+
)
|
| 62 |
+
_CTC_EXPECTED_LOSS = 0.42
|
| 63 |
+
|
| 64 |
+
# Audio class docstring
|
| 65 |
+
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
|
| 66 |
+
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
|
| 67 |
+
_SEQ_CLASS_EXPECTED_LOSS = 9.52
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
| 71 |
+
def _compute_mask_indices(
|
| 72 |
+
shape: Tuple[int, int],
|
| 73 |
+
mask_prob: float,
|
| 74 |
+
mask_length: int,
|
| 75 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 76 |
+
min_masks: int = 0,
|
| 77 |
+
) -> np.ndarray:
|
| 78 |
+
"""
|
| 79 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 80 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 81 |
+
CPU as part of the preprocessing during training.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 85 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
| 86 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 87 |
+
independently generated mask spans of length `mask_length` is computed by
|
| 88 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 89 |
+
actual percentage will be smaller.
|
| 90 |
+
mask_length: size of the mask
|
| 91 |
+
min_masks: minimum number of masked spans
|
| 92 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 93 |
+
each batch dimension.
|
| 94 |
+
"""
|
| 95 |
+
batch_size, sequence_length = shape
|
| 96 |
+
|
| 97 |
+
if mask_length < 1:
|
| 98 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 99 |
+
|
| 100 |
+
if mask_length > sequence_length:
|
| 101 |
+
raise ValueError(
|
| 102 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 103 |
+
f" and `sequence_length`: {sequence_length}`"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# epsilon is used for probabilistic rounding
|
| 107 |
+
epsilon = np.random.rand(1).item()
|
| 108 |
+
|
| 109 |
+
def compute_num_masked_span(input_length):
|
| 110 |
+
"""Given input length, compute how many spans should be masked"""
|
| 111 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 112 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 113 |
+
|
| 114 |
+
# make sure num masked span <= sequence_length
|
| 115 |
+
if num_masked_span * mask_length > sequence_length:
|
| 116 |
+
num_masked_span = sequence_length // mask_length
|
| 117 |
+
|
| 118 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 119 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
| 120 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 121 |
+
|
| 122 |
+
return num_masked_span
|
| 123 |
+
|
| 124 |
+
# compute number of masked spans in batch
|
| 125 |
+
input_lengths = (
|
| 126 |
+
attention_mask.detach().sum(-1).tolist()
|
| 127 |
+
if attention_mask is not None
|
| 128 |
+
else [sequence_length for _ in range(batch_size)]
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# SpecAugment mask to fill
|
| 132 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 133 |
+
spec_aug_mask_idxs = []
|
| 134 |
+
|
| 135 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 136 |
+
|
| 137 |
+
if max_num_masked_span == 0:
|
| 138 |
+
return spec_aug_mask
|
| 139 |
+
|
| 140 |
+
for input_length in input_lengths:
|
| 141 |
+
# compute num of masked spans for this input
|
| 142 |
+
num_masked_span = compute_num_masked_span(input_length)
|
| 143 |
+
|
| 144 |
+
# get random indices to mask
|
| 145 |
+
spec_aug_mask_idx = np.random.choice(
|
| 146 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
| 150 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
| 151 |
+
# Picking first sample just pads those vectors twice.
|
| 152 |
+
if len(spec_aug_mask_idx) == 0:
|
| 153 |
+
# this case can only happen if `input_length` is strictly smaller then
|
| 154 |
+
# `sequence_length` in which case the last token has to be a padding
|
| 155 |
+
# token which we can use as a dummy mask id
|
| 156 |
+
dummy_mask_idx = sequence_length - 1
|
| 157 |
+
else:
|
| 158 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 159 |
+
|
| 160 |
+
spec_aug_mask_idx = np.concatenate(
|
| 161 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 162 |
+
)
|
| 163 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 164 |
+
|
| 165 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 166 |
+
|
| 167 |
+
# expand masked indices to masked spans
|
| 168 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
| 169 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 170 |
+
)
|
| 171 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 172 |
+
|
| 173 |
+
# add offset to the starting indexes so that indexes now create a span
|
| 174 |
+
offsets = np.arange(mask_length)[None, None, :]
|
| 175 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 176 |
+
batch_size, max_num_masked_span * mask_length
|
| 177 |
+
)
|
| 178 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 179 |
+
|
| 180 |
+
# ensure that we cannot have indices larger than sequence_length
|
| 181 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 182 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 183 |
+
|
| 184 |
+
# scatter indices to mask
|
| 185 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 186 |
+
|
| 187 |
+
return spec_aug_mask
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEW
|
| 191 |
+
class SEWNoLayerNormConvLayer(nn.Module):
|
| 192 |
+
def __init__(self, config, layer_id=0):
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 195 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 196 |
+
|
| 197 |
+
self.conv = nn.Conv1d(
|
| 198 |
+
self.in_conv_dim,
|
| 199 |
+
self.out_conv_dim,
|
| 200 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 201 |
+
stride=config.conv_stride[layer_id],
|
| 202 |
+
bias=config.conv_bias,
|
| 203 |
+
)
|
| 204 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 205 |
+
|
| 206 |
+
def forward(self, hidden_states):
|
| 207 |
+
hidden_states = self.conv(hidden_states)
|
| 208 |
+
hidden_states = self.activation(hidden_states)
|
| 209 |
+
return hidden_states
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEW
|
| 213 |
+
class SEWLayerNormConvLayer(nn.Module):
|
| 214 |
+
def __init__(self, config, layer_id=0):
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 217 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 218 |
+
|
| 219 |
+
self.conv = nn.Conv1d(
|
| 220 |
+
self.in_conv_dim,
|
| 221 |
+
self.out_conv_dim,
|
| 222 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 223 |
+
stride=config.conv_stride[layer_id],
|
| 224 |
+
bias=config.conv_bias,
|
| 225 |
+
)
|
| 226 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 227 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 228 |
+
|
| 229 |
+
def forward(self, hidden_states):
|
| 230 |
+
hidden_states = self.conv(hidden_states)
|
| 231 |
+
|
| 232 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 233 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 234 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 235 |
+
|
| 236 |
+
hidden_states = self.activation(hidden_states)
|
| 237 |
+
return hidden_states
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEW
|
| 241 |
+
class SEWGroupNormConvLayer(nn.Module):
|
| 242 |
+
def __init__(self, config, layer_id=0):
|
| 243 |
+
super().__init__()
|
| 244 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 245 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 246 |
+
|
| 247 |
+
self.conv = nn.Conv1d(
|
| 248 |
+
self.in_conv_dim,
|
| 249 |
+
self.out_conv_dim,
|
| 250 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 251 |
+
stride=config.conv_stride[layer_id],
|
| 252 |
+
bias=config.conv_bias,
|
| 253 |
+
)
|
| 254 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 255 |
+
|
| 256 |
+
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
| 257 |
+
|
| 258 |
+
def forward(self, hidden_states):
|
| 259 |
+
hidden_states = self.conv(hidden_states)
|
| 260 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 261 |
+
hidden_states = self.activation(hidden_states)
|
| 262 |
+
return hidden_states
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class SEWPositionalConvEmbedding(nn.Module):
|
| 266 |
+
def __init__(self, config):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.conv = nn.Conv1d(
|
| 269 |
+
config.hidden_size,
|
| 270 |
+
config.hidden_size,
|
| 271 |
+
kernel_size=config.num_conv_pos_embeddings,
|
| 272 |
+
padding=config.num_conv_pos_embeddings // 2,
|
| 273 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 274 |
+
stride=config.squeeze_factor,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
weight_norm = nn.utils.weight_norm
|
| 278 |
+
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
| 279 |
+
weight_norm = nn.utils.parametrizations.weight_norm
|
| 280 |
+
|
| 281 |
+
if is_deepspeed_zero3_enabled():
|
| 282 |
+
import deepspeed
|
| 283 |
+
|
| 284 |
+
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
| 285 |
+
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
| 286 |
+
if hasattr(self.conv, "parametrizations"):
|
| 287 |
+
weight_g = self.conv.parametrizations.weight.original0
|
| 288 |
+
weight_v = self.conv.parametrizations.weight.original1
|
| 289 |
+
else:
|
| 290 |
+
weight_g = self.conv.weight_g
|
| 291 |
+
weight_v = self.conv.weight_v
|
| 292 |
+
deepspeed.zero.register_external_parameter(self, weight_v)
|
| 293 |
+
deepspeed.zero.register_external_parameter(self, weight_g)
|
| 294 |
+
else:
|
| 295 |
+
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
| 296 |
+
|
| 297 |
+
self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
|
| 298 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 299 |
+
|
| 300 |
+
def forward(self, hidden_states):
|
| 301 |
+
hidden_states = self.conv(hidden_states)
|
| 302 |
+
hidden_states = self.padding(hidden_states)
|
| 303 |
+
hidden_states = self.activation(hidden_states)
|
| 304 |
+
|
| 305 |
+
return hidden_states
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW
|
| 309 |
+
class SEWSamePadLayer(nn.Module):
|
| 310 |
+
def __init__(self, num_conv_pos_embeddings):
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 313 |
+
|
| 314 |
+
def forward(self, hidden_states):
|
| 315 |
+
if self.num_pad_remove > 0:
|
| 316 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 317 |
+
return hidden_states
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class SEWUpsampling(nn.Module):
|
| 321 |
+
def __init__(self, config):
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
|
| 324 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 325 |
+
self.squeeze_factor = config.squeeze_factor
|
| 326 |
+
|
| 327 |
+
def forward(self, hidden_states):
|
| 328 |
+
hidden_states = self.projection(hidden_states)
|
| 329 |
+
hidden_states = self.activation(hidden_states)
|
| 330 |
+
|
| 331 |
+
if self.squeeze_factor > 1:
|
| 332 |
+
# transform embedding channels to sequence length
|
| 333 |
+
bsz, src_len, src_embed_dim = hidden_states.size()
|
| 334 |
+
tgt_len = src_len * self.squeeze_factor
|
| 335 |
+
tgt_embed_dim = src_embed_dim // self.squeeze_factor
|
| 336 |
+
hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
|
| 337 |
+
hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
|
| 338 |
+
|
| 339 |
+
return hidden_states
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEW
|
| 343 |
+
class SEWFeatureEncoder(nn.Module):
|
| 344 |
+
"""Construct the features from raw audio waveform"""
|
| 345 |
+
|
| 346 |
+
def __init__(self, config):
|
| 347 |
+
super().__init__()
|
| 348 |
+
|
| 349 |
+
if config.feat_extract_norm == "group":
|
| 350 |
+
conv_layers = [SEWGroupNormConvLayer(config, layer_id=0)] + [
|
| 351 |
+
SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
|
| 352 |
+
]
|
| 353 |
+
elif config.feat_extract_norm == "layer":
|
| 354 |
+
conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
|
| 355 |
+
else:
|
| 356 |
+
raise ValueError(
|
| 357 |
+
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
| 358 |
+
)
|
| 359 |
+
self.conv_layers = nn.ModuleList(conv_layers)
|
| 360 |
+
self.gradient_checkpointing = False
|
| 361 |
+
self._requires_grad = True
|
| 362 |
+
|
| 363 |
+
def _freeze_parameters(self):
|
| 364 |
+
for param in self.parameters():
|
| 365 |
+
param.requires_grad = False
|
| 366 |
+
self._requires_grad = False
|
| 367 |
+
|
| 368 |
+
def forward(self, input_values):
|
| 369 |
+
hidden_states = input_values[:, None]
|
| 370 |
+
|
| 371 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
| 372 |
+
if self._requires_grad and self.training:
|
| 373 |
+
hidden_states.requires_grad = True
|
| 374 |
+
|
| 375 |
+
for conv_layer in self.conv_layers:
|
| 376 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 377 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 378 |
+
conv_layer.__call__,
|
| 379 |
+
hidden_states,
|
| 380 |
+
)
|
| 381 |
+
else:
|
| 382 |
+
hidden_states = conv_layer(hidden_states)
|
| 383 |
+
|
| 384 |
+
return hidden_states
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class SEWFeatureExtractor(SEWFeatureEncoder):
|
| 388 |
+
def __init__(self, config):
|
| 389 |
+
super().__init__(config)
|
| 390 |
+
warnings.warn(
|
| 391 |
+
f"The class `{self.__class__.__name__}` has been depreciated "
|
| 392 |
+
"and will be removed in Transformers v5. "
|
| 393 |
+
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
|
| 394 |
+
FutureWarning,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SEW
|
| 399 |
+
class SEWAttention(nn.Module):
|
| 400 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 401 |
+
|
| 402 |
+
def __init__(
|
| 403 |
+
self,
|
| 404 |
+
embed_dim: int,
|
| 405 |
+
num_heads: int,
|
| 406 |
+
dropout: float = 0.0,
|
| 407 |
+
is_decoder: bool = False,
|
| 408 |
+
bias: bool = True,
|
| 409 |
+
is_causal: bool = False,
|
| 410 |
+
config: Optional[SEWConfig] = None,
|
| 411 |
+
):
|
| 412 |
+
super().__init__()
|
| 413 |
+
self.embed_dim = embed_dim
|
| 414 |
+
self.num_heads = num_heads
|
| 415 |
+
self.dropout = dropout
|
| 416 |
+
self.head_dim = embed_dim // num_heads
|
| 417 |
+
self.config = config
|
| 418 |
+
|
| 419 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
| 420 |
+
raise ValueError(
|
| 421 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
| 422 |
+
f" and `num_heads`: {num_heads})."
|
| 423 |
+
)
|
| 424 |
+
self.scaling = self.head_dim**-0.5
|
| 425 |
+
self.is_decoder = is_decoder
|
| 426 |
+
self.is_causal = is_causal
|
| 427 |
+
|
| 428 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 429 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 430 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 431 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 432 |
+
|
| 433 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 434 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 435 |
+
|
| 436 |
+
def forward(
|
| 437 |
+
self,
|
| 438 |
+
hidden_states: torch.Tensor,
|
| 439 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 440 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 441 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 442 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 443 |
+
output_attentions: bool = False,
|
| 444 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 445 |
+
"""Input shape: Batch x Time x Channel"""
|
| 446 |
+
|
| 447 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 448 |
+
# for the decoder
|
| 449 |
+
is_cross_attention = key_value_states is not None
|
| 450 |
+
|
| 451 |
+
bsz, tgt_len, _ = hidden_states.size()
|
| 452 |
+
|
| 453 |
+
# get query proj
|
| 454 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
| 455 |
+
# get key, value proj
|
| 456 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
| 457 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
| 458 |
+
# the provided `key_value_states` to support prefix tuning
|
| 459 |
+
if (
|
| 460 |
+
is_cross_attention
|
| 461 |
+
and past_key_value is not None
|
| 462 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 463 |
+
):
|
| 464 |
+
# reuse k,v, cross_attentions
|
| 465 |
+
key_states = past_key_value[0]
|
| 466 |
+
value_states = past_key_value[1]
|
| 467 |
+
elif is_cross_attention:
|
| 468 |
+
# cross_attentions
|
| 469 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 470 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 471 |
+
elif past_key_value is not None:
|
| 472 |
+
# reuse k, v, self_attention
|
| 473 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 474 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 475 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 476 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 477 |
+
else:
|
| 478 |
+
# self_attention
|
| 479 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 480 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 481 |
+
|
| 482 |
+
if self.is_decoder:
|
| 483 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 484 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 485 |
+
# key/value_states (first "if" case)
|
| 486 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 487 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 488 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 489 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 490 |
+
past_key_value = (key_states, value_states)
|
| 491 |
+
|
| 492 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
| 493 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
| 494 |
+
key_states = key_states.reshape(*proj_shape)
|
| 495 |
+
value_states = value_states.reshape(*proj_shape)
|
| 496 |
+
|
| 497 |
+
src_len = key_states.size(1)
|
| 498 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
| 499 |
+
|
| 500 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
| 501 |
+
raise ValueError(
|
| 502 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
| 503 |
+
f" {attn_weights.size()}"
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if attention_mask is not None:
|
| 507 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
| 508 |
+
raise ValueError(
|
| 509 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
| 510 |
+
)
|
| 511 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
| 512 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 513 |
+
|
| 514 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 515 |
+
|
| 516 |
+
if layer_head_mask is not None:
|
| 517 |
+
if layer_head_mask.size() != (self.num_heads,):
|
| 518 |
+
raise ValueError(
|
| 519 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
| 520 |
+
f" {layer_head_mask.size()}"
|
| 521 |
+
)
|
| 522 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 523 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 524 |
+
|
| 525 |
+
if output_attentions:
|
| 526 |
+
# this operation is a bit awkward, but it's required to
|
| 527 |
+
# make sure that attn_weights keeps its gradient.
|
| 528 |
+
# In order to do so, attn_weights have to be reshaped
|
| 529 |
+
# twice and have to be reused in the following
|
| 530 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 531 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
| 532 |
+
else:
|
| 533 |
+
attn_weights_reshaped = None
|
| 534 |
+
|
| 535 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 536 |
+
|
| 537 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
| 538 |
+
|
| 539 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
| 540 |
+
raise ValueError(
|
| 541 |
+
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
| 542 |
+
f" {attn_output.size()}"
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
| 546 |
+
attn_output = attn_output.transpose(1, 2)
|
| 547 |
+
|
| 548 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
| 549 |
+
# partitioned across GPUs when using tensor-parallelism.
|
| 550 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
| 551 |
+
|
| 552 |
+
attn_output = self.out_proj(attn_output)
|
| 553 |
+
|
| 554 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW
|
| 558 |
+
class SEWFlashAttention2(SEWAttention):
|
| 559 |
+
"""
|
| 560 |
+
SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays
|
| 561 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 562 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
def __init__(self, *args, **kwargs):
|
| 566 |
+
super().__init__(*args, **kwargs)
|
| 567 |
+
|
| 568 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 569 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 570 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 571 |
+
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
| 572 |
+
|
| 573 |
+
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 574 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
| 575 |
+
|
| 576 |
+
def forward(
|
| 577 |
+
self,
|
| 578 |
+
hidden_states: torch.Tensor,
|
| 579 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 580 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 581 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 582 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 583 |
+
output_attentions: bool = False,
|
| 584 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 585 |
+
# SEWFlashAttention2 attention does not support output_attentions
|
| 586 |
+
if output_attentions:
|
| 587 |
+
raise ValueError("SEWFlashAttention2 attention does not support output_attentions")
|
| 588 |
+
|
| 589 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 590 |
+
# for the decoder
|
| 591 |
+
is_cross_attention = key_value_states is not None
|
| 592 |
+
|
| 593 |
+
bsz, q_len, _ = hidden_states.size()
|
| 594 |
+
|
| 595 |
+
# get query proj
|
| 596 |
+
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
| 597 |
+
# get key, value proj
|
| 598 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
| 599 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
| 600 |
+
# the provided `key_value_states` to support prefix tuning
|
| 601 |
+
if (
|
| 602 |
+
is_cross_attention
|
| 603 |
+
and past_key_value is not None
|
| 604 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 605 |
+
):
|
| 606 |
+
# reuse k,v, cross_attentions
|
| 607 |
+
key_states = past_key_value[0].transpose(1, 2)
|
| 608 |
+
value_states = past_key_value[1].transpose(1, 2)
|
| 609 |
+
elif is_cross_attention:
|
| 610 |
+
# cross_attentions
|
| 611 |
+
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
| 612 |
+
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
| 613 |
+
elif past_key_value is not None:
|
| 614 |
+
# reuse k, v, self_attention
|
| 615 |
+
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 616 |
+
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
| 617 |
+
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
| 618 |
+
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
| 619 |
+
else:
|
| 620 |
+
# self_attention
|
| 621 |
+
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 622 |
+
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
| 623 |
+
|
| 624 |
+
if self.is_decoder:
|
| 625 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 626 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 627 |
+
# key/value_states (first "if" case)
|
| 628 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 629 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 630 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 631 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 632 |
+
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
| 633 |
+
|
| 634 |
+
kv_seq_len = key_states.shape[-2]
|
| 635 |
+
if past_key_value is not None:
|
| 636 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 637 |
+
|
| 638 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 639 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 640 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 641 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 642 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
| 643 |
+
|
| 644 |
+
input_dtype = query_states.dtype
|
| 645 |
+
if input_dtype == torch.float32:
|
| 646 |
+
if torch.is_autocast_enabled():
|
| 647 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 648 |
+
# Handle the case where the model is quantized
|
| 649 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 650 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 651 |
+
else:
|
| 652 |
+
target_dtype = self.q_proj.weight.dtype
|
| 653 |
+
|
| 654 |
+
logger.warning_once(
|
| 655 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 656 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 657 |
+
f" {target_dtype}."
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
query_states = query_states.to(target_dtype)
|
| 661 |
+
key_states = key_states.to(target_dtype)
|
| 662 |
+
value_states = value_states.to(target_dtype)
|
| 663 |
+
|
| 664 |
+
attn_output = _flash_attention_forward(
|
| 665 |
+
query_states,
|
| 666 |
+
key_states,
|
| 667 |
+
value_states,
|
| 668 |
+
attention_mask,
|
| 669 |
+
q_len,
|
| 670 |
+
dropout=self.dropout if self.training else 0.0,
|
| 671 |
+
is_causal=self.is_causal,
|
| 672 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
| 676 |
+
attn_output = self.out_proj(attn_output)
|
| 677 |
+
|
| 678 |
+
if not output_attentions:
|
| 679 |
+
attn_weights = None
|
| 680 |
+
|
| 681 |
+
return attn_output, attn_weights, past_key_value
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class SEWSdpaAttention(SEWAttention):
|
| 685 |
+
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW
|
| 686 |
+
def forward(
|
| 687 |
+
self,
|
| 688 |
+
hidden_states: torch.Tensor,
|
| 689 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 690 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 691 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 692 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 693 |
+
output_attentions: bool = False,
|
| 694 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 695 |
+
"""Input shape: Batch x Time x Channel"""
|
| 696 |
+
if output_attentions or layer_head_mask is not None:
|
| 697 |
+
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
| 698 |
+
logger.warning_once(
|
| 699 |
+
"SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
| 700 |
+
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 701 |
+
)
|
| 702 |
+
return super().forward(
|
| 703 |
+
hidden_states,
|
| 704 |
+
key_value_states=key_value_states,
|
| 705 |
+
past_key_value=past_key_value,
|
| 706 |
+
attention_mask=attention_mask,
|
| 707 |
+
layer_head_mask=layer_head_mask,
|
| 708 |
+
output_attentions=output_attentions,
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 712 |
+
# for the decoder
|
| 713 |
+
is_cross_attention = key_value_states is not None
|
| 714 |
+
|
| 715 |
+
bsz, tgt_len, _ = hidden_states.size()
|
| 716 |
+
|
| 717 |
+
# get query proj
|
| 718 |
+
query_states = self.q_proj(hidden_states)
|
| 719 |
+
# get key, value proj
|
| 720 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
| 721 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
| 722 |
+
# the provided `key_value_states` to support prefix tuning
|
| 723 |
+
if (
|
| 724 |
+
is_cross_attention
|
| 725 |
+
and past_key_value is not None
|
| 726 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 727 |
+
):
|
| 728 |
+
# reuse k,v, cross_attentions
|
| 729 |
+
key_states = past_key_value[0]
|
| 730 |
+
value_states = past_key_value[1]
|
| 731 |
+
elif is_cross_attention:
|
| 732 |
+
# cross_attentions
|
| 733 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 734 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 735 |
+
elif past_key_value is not None:
|
| 736 |
+
# reuse k, v, self_attention
|
| 737 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 738 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 739 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 740 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 741 |
+
else:
|
| 742 |
+
# self_attention
|
| 743 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 744 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 745 |
+
|
| 746 |
+
if self.is_decoder:
|
| 747 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 748 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 749 |
+
# key/value_states (first "if" case)
|
| 750 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 751 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 752 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 753 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 754 |
+
past_key_value = (key_states, value_states)
|
| 755 |
+
|
| 756 |
+
query_states = self._shape(query_states, tgt_len, bsz)
|
| 757 |
+
|
| 758 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 759 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 760 |
+
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
| 761 |
+
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
| 762 |
+
|
| 763 |
+
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
| 764 |
+
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
| 765 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 766 |
+
query_states,
|
| 767 |
+
key_states,
|
| 768 |
+
value_states,
|
| 769 |
+
attn_mask=attention_mask,
|
| 770 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 771 |
+
is_causal=is_causal,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
| 775 |
+
raise ValueError(
|
| 776 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
| 777 |
+
f" {attn_output.size()}"
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
attn_output = attn_output.transpose(1, 2)
|
| 781 |
+
|
| 782 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
| 783 |
+
# partitioned across GPUs when using tensor-parallelism.
|
| 784 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
| 785 |
+
|
| 786 |
+
attn_output = self.out_proj(attn_output)
|
| 787 |
+
|
| 788 |
+
return attn_output, None, past_key_value
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
SEW_ATTENTION_CLASSES = {
|
| 792 |
+
"eager": SEWAttention,
|
| 793 |
+
"sdpa": SEWSdpaAttention,
|
| 794 |
+
"flash_attention_2": SEWFlashAttention2,
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->SEW
|
| 799 |
+
class SEWFeedForward(nn.Module):
|
| 800 |
+
def __init__(self, config):
|
| 801 |
+
super().__init__()
|
| 802 |
+
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
| 803 |
+
|
| 804 |
+
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 805 |
+
if isinstance(config.hidden_act, str):
|
| 806 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 807 |
+
else:
|
| 808 |
+
self.intermediate_act_fn = config.hidden_act
|
| 809 |
+
|
| 810 |
+
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 811 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
| 812 |
+
|
| 813 |
+
def forward(self, hidden_states):
|
| 814 |
+
hidden_states = self.intermediate_dense(hidden_states)
|
| 815 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 816 |
+
hidden_states = self.intermediate_dropout(hidden_states)
|
| 817 |
+
|
| 818 |
+
hidden_states = self.output_dense(hidden_states)
|
| 819 |
+
hidden_states = self.output_dropout(hidden_states)
|
| 820 |
+
return hidden_states
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->SEW, WAV2VEC2->SEW
|
| 824 |
+
class SEWEncoderLayer(nn.Module):
|
| 825 |
+
def __init__(self, config):
|
| 826 |
+
super().__init__()
|
| 827 |
+
self.attention = SEW_ATTENTION_CLASSES[config._attn_implementation](
|
| 828 |
+
embed_dim=config.hidden_size,
|
| 829 |
+
num_heads=config.num_attention_heads,
|
| 830 |
+
dropout=config.attention_dropout,
|
| 831 |
+
is_decoder=False,
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 835 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 836 |
+
self.feed_forward = SEWFeedForward(config)
|
| 837 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 838 |
+
|
| 839 |
+
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
| 840 |
+
attn_residual = hidden_states
|
| 841 |
+
hidden_states, attn_weights, _ = self.attention(
|
| 842 |
+
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
| 843 |
+
)
|
| 844 |
+
hidden_states = self.dropout(hidden_states)
|
| 845 |
+
hidden_states = attn_residual + hidden_states
|
| 846 |
+
|
| 847 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 848 |
+
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
| 849 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 850 |
+
|
| 851 |
+
outputs = (hidden_states,)
|
| 852 |
+
|
| 853 |
+
if output_attentions:
|
| 854 |
+
outputs += (attn_weights,)
|
| 855 |
+
|
| 856 |
+
return outputs
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class SEWEncoder(nn.Module):
|
| 860 |
+
def __init__(self, config):
|
| 861 |
+
super().__init__()
|
| 862 |
+
self.config = config
|
| 863 |
+
self.pos_conv_embed = SEWPositionalConvEmbedding(config)
|
| 864 |
+
self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
|
| 865 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 866 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 867 |
+
self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 868 |
+
self.upsample = SEWUpsampling(config)
|
| 869 |
+
self.gradient_checkpointing = False
|
| 870 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 871 |
+
|
| 872 |
+
def forward(
|
| 873 |
+
self,
|
| 874 |
+
hidden_states,
|
| 875 |
+
attention_mask=None,
|
| 876 |
+
output_attentions=False,
|
| 877 |
+
output_hidden_states=False,
|
| 878 |
+
return_dict=True,
|
| 879 |
+
):
|
| 880 |
+
all_hidden_states = () if output_hidden_states else None
|
| 881 |
+
all_self_attentions = () if output_attentions else None
|
| 882 |
+
|
| 883 |
+
if attention_mask is not None:
|
| 884 |
+
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
|
| 885 |
+
if self._use_flash_attention_2:
|
| 886 |
+
# make sure padded tokens output 0
|
| 887 |
+
hidden_states[~expand_attention_mask] = 0.0
|
| 888 |
+
# 2d mask is passed through the layers
|
| 889 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
| 890 |
+
else:
|
| 891 |
+
# make sure padded tokens output 0
|
| 892 |
+
hidden_states[~expand_attention_mask] = 0.0
|
| 893 |
+
input_lengths = (attention_mask.long()).sum(-1)
|
| 894 |
+
# apply pooling formula to get real output_lengths
|
| 895 |
+
output_lengths = input_lengths // self.config.squeeze_factor
|
| 896 |
+
max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
|
| 897 |
+
attention_ids = (
|
| 898 |
+
torch.arange(0, max_encoder_length, device=output_lengths.device)
|
| 899 |
+
.view(1, -1)
|
| 900 |
+
.expand(output_lengths.shape[0], -1)
|
| 901 |
+
)
|
| 902 |
+
attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
|
| 903 |
+
|
| 904 |
+
# extend attention_mask
|
| 905 |
+
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
|
| 906 |
+
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
| 907 |
+
attention_mask = attention_mask.expand(
|
| 908 |
+
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
n_input_timesteps = hidden_states.shape[1]
|
| 912 |
+
|
| 913 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 914 |
+
position_embeddings = self.pos_conv_embed(hidden_states)
|
| 915 |
+
pooled_hidden_states = self.pool(hidden_states)
|
| 916 |
+
min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
|
| 917 |
+
hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
|
| 918 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 919 |
+
|
| 920 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 921 |
+
hidden_states = self.dropout(hidden_states)
|
| 922 |
+
|
| 923 |
+
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
| 924 |
+
|
| 925 |
+
for layer in self.layers:
|
| 926 |
+
if output_hidden_states:
|
| 927 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 928 |
+
|
| 929 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 930 |
+
dropout_probability = torch.rand([])
|
| 931 |
+
|
| 932 |
+
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
| 933 |
+
if not skip_the_layer or synced_gpus:
|
| 934 |
+
# under fsdp or deepspeed zero3 all gpus must run in sync
|
| 935 |
+
if self.gradient_checkpointing and self.training:
|
| 936 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 937 |
+
layer.__call__,
|
| 938 |
+
hidden_states,
|
| 939 |
+
attention_mask,
|
| 940 |
+
output_attentions,
|
| 941 |
+
)
|
| 942 |
+
else:
|
| 943 |
+
layer_outputs = layer(
|
| 944 |
+
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
| 945 |
+
)
|
| 946 |
+
hidden_states = layer_outputs[0]
|
| 947 |
+
|
| 948 |
+
if skip_the_layer:
|
| 949 |
+
layer_outputs = (None, None)
|
| 950 |
+
|
| 951 |
+
if output_attentions:
|
| 952 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 953 |
+
|
| 954 |
+
if output_hidden_states:
|
| 955 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 956 |
+
|
| 957 |
+
hidden_states = self.upsample(hidden_states)
|
| 958 |
+
if hidden_states.shape[1] < n_input_timesteps:
|
| 959 |
+
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
|
| 960 |
+
|
| 961 |
+
if not return_dict:
|
| 962 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 963 |
+
return BaseModelOutput(
|
| 964 |
+
last_hidden_state=hidden_states,
|
| 965 |
+
hidden_states=all_hidden_states,
|
| 966 |
+
attentions=all_self_attentions,
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class SEWPreTrainedModel(PreTrainedModel):
|
| 971 |
+
"""
|
| 972 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 973 |
+
models.
|
| 974 |
+
"""
|
| 975 |
+
|
| 976 |
+
config_class = SEWConfig
|
| 977 |
+
base_model_prefix = "sew"
|
| 978 |
+
main_input_name = "input_values"
|
| 979 |
+
supports_gradient_checkpointing = True
|
| 980 |
+
_supports_flash_attn_2 = True
|
| 981 |
+
_supports_sdpa = True
|
| 982 |
+
|
| 983 |
+
def _init_weights(self, module):
|
| 984 |
+
"""Initialize the weights"""
|
| 985 |
+
if isinstance(module, SEWPositionalConvEmbedding):
|
| 986 |
+
nn.init.normal_(
|
| 987 |
+
module.conv.weight,
|
| 988 |
+
mean=0,
|
| 989 |
+
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
|
| 990 |
+
)
|
| 991 |
+
nn.init.constant_(module.conv.bias, 0)
|
| 992 |
+
elif isinstance(module, nn.Linear):
|
| 993 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 994 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 995 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 996 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 997 |
+
module.bias.data.zero_()
|
| 998 |
+
module.weight.data.fill_(1.0)
|
| 999 |
+
elif isinstance(module, nn.Conv1d):
|
| 1000 |
+
if is_deepspeed_zero3_enabled():
|
| 1001 |
+
import deepspeed
|
| 1002 |
+
|
| 1003 |
+
if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
|
| 1004 |
+
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
|
| 1005 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 1006 |
+
else:
|
| 1007 |
+
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
|
| 1008 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 1009 |
+
else:
|
| 1010 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 1011 |
+
|
| 1012 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
| 1013 |
+
module.bias.data.zero_()
|
| 1014 |
+
|
| 1015 |
+
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
| 1016 |
+
"""
|
| 1017 |
+
Computes the output length of the convolutional layers
|
| 1018 |
+
"""
|
| 1019 |
+
|
| 1020 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 1021 |
+
# 1D convolutional layer output length formula taken
|
| 1022 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 1023 |
+
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 1024 |
+
|
| 1025 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
| 1026 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 1027 |
+
|
| 1028 |
+
return input_lengths
|
| 1029 |
+
|
| 1030 |
+
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
| 1031 |
+
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1032 |
+
batch_size = attention_mask.shape[0]
|
| 1033 |
+
|
| 1034 |
+
attention_mask = torch.zeros(
|
| 1035 |
+
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
| 1036 |
+
)
|
| 1037 |
+
# these two operations makes sure that all values before the output lengths idxs are attended to
|
| 1038 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
| 1039 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 1040 |
+
return attention_mask
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
SEW_START_DOCSTRING = r"""
|
| 1044 |
+
SEW was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
|
| 1045 |
+
Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
|
| 1046 |
+
Yoav Artzi.
|
| 1047 |
+
|
| 1048 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1049 |
+
library implements for all its model (such as downloading or saving etc.).
|
| 1050 |
+
|
| 1051 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 1052 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 1053 |
+
behavior.
|
| 1054 |
+
|
| 1055 |
+
Parameters:
|
| 1056 |
+
config ([`SEWConfig`]): Model configuration class with all the parameters of the model.
|
| 1057 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1058 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1059 |
+
"""
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
SEW_INPUTS_DOCSTRING = r"""
|
| 1063 |
+
Args:
|
| 1064 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 1065 |
+
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
| 1066 |
+
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
| 1067 |
+
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
| 1068 |
+
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
| 1069 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1070 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1071 |
+
1]`:
|
| 1072 |
+
|
| 1073 |
+
- 1 for tokens that are **not masked**,
|
| 1074 |
+
- 0 for tokens that are **masked**.
|
| 1075 |
+
|
| 1076 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1077 |
+
|
| 1078 |
+
output_attentions (`bool`, *optional*):
|
| 1079 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1080 |
+
tensors for more detail.
|
| 1081 |
+
output_hidden_states (`bool`, *optional*):
|
| 1082 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1083 |
+
more detail.
|
| 1084 |
+
return_dict (`bool`, *optional*):
|
| 1085 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1086 |
+
"""
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
@add_start_docstrings(
|
| 1090 |
+
"The bare SEW Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1091 |
+
SEW_START_DOCSTRING,
|
| 1092 |
+
)
|
| 1093 |
+
class SEWModel(SEWPreTrainedModel):
|
| 1094 |
+
def __init__(self, config: SEWConfig):
|
| 1095 |
+
super().__init__(config)
|
| 1096 |
+
self.config = config
|
| 1097 |
+
self.feature_extractor = SEWFeatureEncoder(config)
|
| 1098 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
| 1099 |
+
|
| 1100 |
+
self.project_features = config.conv_dim[-1] != config.hidden_size
|
| 1101 |
+
if self.project_features:
|
| 1102 |
+
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 1103 |
+
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
|
| 1104 |
+
|
| 1105 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 1106 |
+
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
|
| 1107 |
+
|
| 1108 |
+
self.encoder = SEWEncoder(config)
|
| 1109 |
+
|
| 1110 |
+
# Initialize weights and apply final processing
|
| 1111 |
+
self.post_init()
|
| 1112 |
+
|
| 1113 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
| 1114 |
+
def _mask_hidden_states(
|
| 1115 |
+
self,
|
| 1116 |
+
hidden_states: torch.FloatTensor,
|
| 1117 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1118 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1119 |
+
):
|
| 1120 |
+
"""
|
| 1121 |
+
Masks extracted features along time axis and/or along feature axis according to
|
| 1122 |
+
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
| 1123 |
+
"""
|
| 1124 |
+
|
| 1125 |
+
# `config.apply_spec_augment` can set masking to False
|
| 1126 |
+
if not getattr(self.config, "apply_spec_augment", True):
|
| 1127 |
+
return hidden_states
|
| 1128 |
+
|
| 1129 |
+
# generate indices & apply SpecAugment along time axis
|
| 1130 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 1131 |
+
|
| 1132 |
+
if mask_time_indices is not None:
|
| 1133 |
+
# apply SpecAugment along time axis with given mask_time_indices
|
| 1134 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1135 |
+
elif self.config.mask_time_prob > 0 and self.training:
|
| 1136 |
+
mask_time_indices = _compute_mask_indices(
|
| 1137 |
+
(batch_size, sequence_length),
|
| 1138 |
+
mask_prob=self.config.mask_time_prob,
|
| 1139 |
+
mask_length=self.config.mask_time_length,
|
| 1140 |
+
attention_mask=attention_mask,
|
| 1141 |
+
min_masks=self.config.mask_time_min_masks,
|
| 1142 |
+
)
|
| 1143 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1144 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1145 |
+
|
| 1146 |
+
if self.config.mask_feature_prob > 0 and self.training:
|
| 1147 |
+
# generate indices & apply SpecAugment along feature axis
|
| 1148 |
+
mask_feature_indices = _compute_mask_indices(
|
| 1149 |
+
(batch_size, hidden_size),
|
| 1150 |
+
mask_prob=self.config.mask_feature_prob,
|
| 1151 |
+
mask_length=self.config.mask_feature_length,
|
| 1152 |
+
min_masks=self.config.mask_feature_min_masks,
|
| 1153 |
+
)
|
| 1154 |
+
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1155 |
+
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
| 1156 |
+
hidden_states[mask_feature_indices] = 0
|
| 1157 |
+
|
| 1158 |
+
return hidden_states
|
| 1159 |
+
|
| 1160 |
+
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
| 1161 |
+
@add_code_sample_docstrings(
|
| 1162 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1163 |
+
output_type=BaseModelOutput,
|
| 1164 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1165 |
+
modality="audio",
|
| 1166 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 1167 |
+
)
|
| 1168 |
+
def forward(
|
| 1169 |
+
self,
|
| 1170 |
+
input_values: Optional[torch.Tensor],
|
| 1171 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1172 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1173 |
+
output_attentions: Optional[bool] = None,
|
| 1174 |
+
output_hidden_states: Optional[bool] = None,
|
| 1175 |
+
return_dict: Optional[bool] = None,
|
| 1176 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 1177 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1178 |
+
output_hidden_states = (
|
| 1179 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1180 |
+
)
|
| 1181 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1182 |
+
|
| 1183 |
+
extract_features = self.feature_extractor(input_values)
|
| 1184 |
+
extract_features = extract_features.transpose(1, 2)
|
| 1185 |
+
extract_features = self.layer_norm(extract_features)
|
| 1186 |
+
|
| 1187 |
+
if self.project_features:
|
| 1188 |
+
extract_features = self.feature_projection(extract_features)
|
| 1189 |
+
hidden_states = self.feature_dropout(extract_features)
|
| 1190 |
+
|
| 1191 |
+
if attention_mask is not None:
|
| 1192 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 1193 |
+
attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1194 |
+
|
| 1195 |
+
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
| 1196 |
+
|
| 1197 |
+
encoder_outputs = self.encoder(
|
| 1198 |
+
hidden_states,
|
| 1199 |
+
attention_mask=attention_mask,
|
| 1200 |
+
output_attentions=output_attentions,
|
| 1201 |
+
output_hidden_states=output_hidden_states,
|
| 1202 |
+
return_dict=return_dict,
|
| 1203 |
+
)
|
| 1204 |
+
|
| 1205 |
+
hidden_states = encoder_outputs[0]
|
| 1206 |
+
|
| 1207 |
+
if not return_dict:
|
| 1208 |
+
return (hidden_states,) + encoder_outputs[1:]
|
| 1209 |
+
|
| 1210 |
+
return BaseModelOutput(
|
| 1211 |
+
last_hidden_state=hidden_states,
|
| 1212 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1213 |
+
attentions=encoder_outputs.attentions,
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
@add_start_docstrings(
|
| 1218 |
+
"""SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 1219 |
+
SEW_START_DOCSTRING,
|
| 1220 |
+
)
|
| 1221 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
|
| 1222 |
+
class SEWForCTC(SEWPreTrainedModel):
|
| 1223 |
+
def __init__(self, config, target_lang: Optional[str] = None):
|
| 1224 |
+
super().__init__(config)
|
| 1225 |
+
|
| 1226 |
+
self.sew = SEWModel(config)
|
| 1227 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
| 1228 |
+
|
| 1229 |
+
self.target_lang = target_lang
|
| 1230 |
+
|
| 1231 |
+
if config.vocab_size is None:
|
| 1232 |
+
raise ValueError(
|
| 1233 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 1234 |
+
"does not define the vocabulary size of the language model head. Please "
|
| 1235 |
+
"instantiate the model as follows: `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 1236 |
+
"or define `vocab_size` of your model's configuration."
|
| 1237 |
+
)
|
| 1238 |
+
output_hidden_size = (
|
| 1239 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 1240 |
+
)
|
| 1241 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 1242 |
+
|
| 1243 |
+
# Initialize weights and apply final processing
|
| 1244 |
+
self.post_init()
|
| 1245 |
+
|
| 1246 |
+
def tie_weights(self):
|
| 1247 |
+
"""
|
| 1248 |
+
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
| 1249 |
+
passing `target_lang=...` to `from_pretrained(...)`.
|
| 1250 |
+
|
| 1251 |
+
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
| 1252 |
+
"""
|
| 1253 |
+
|
| 1254 |
+
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
| 1255 |
+
# correctly load adapter layers for SEW so that we do not have to introduce a new API to
|
| 1256 |
+
# [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is
|
| 1257 |
+
# ok to repurpose this function here.
|
| 1258 |
+
target_lang = self.target_lang
|
| 1259 |
+
|
| 1260 |
+
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
| 1261 |
+
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
| 1262 |
+
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
| 1263 |
+
logger.info("By default `target_lang` is set to 'eng'.")
|
| 1264 |
+
elif target_lang is not None:
|
| 1265 |
+
self.load_adapter(target_lang, force_load=True)
|
| 1266 |
+
|
| 1267 |
+
def freeze_feature_extractor(self):
|
| 1268 |
+
"""
|
| 1269 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1270 |
+
not be updated during training.
|
| 1271 |
+
"""
|
| 1272 |
+
warnings.warn(
|
| 1273 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1274 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1275 |
+
FutureWarning,
|
| 1276 |
+
)
|
| 1277 |
+
self.freeze_feature_encoder()
|
| 1278 |
+
|
| 1279 |
+
def freeze_feature_encoder(self):
|
| 1280 |
+
"""
|
| 1281 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1282 |
+
not be updated during training.
|
| 1283 |
+
"""
|
| 1284 |
+
self.sew.feature_extractor._freeze_parameters()
|
| 1285 |
+
|
| 1286 |
+
def freeze_base_model(self):
|
| 1287 |
+
"""
|
| 1288 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1289 |
+
be updated during training. Only the classification head will be updated.
|
| 1290 |
+
"""
|
| 1291 |
+
for param in self.sew.parameters():
|
| 1292 |
+
param.requires_grad = False
|
| 1293 |
+
|
| 1294 |
+
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
| 1295 |
+
@add_code_sample_docstrings(
|
| 1296 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1297 |
+
output_type=CausalLMOutput,
|
| 1298 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1299 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 1300 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
| 1301 |
+
)
|
| 1302 |
+
def forward(
|
| 1303 |
+
self,
|
| 1304 |
+
input_values: Optional[torch.Tensor],
|
| 1305 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1306 |
+
output_attentions: Optional[bool] = None,
|
| 1307 |
+
output_hidden_states: Optional[bool] = None,
|
| 1308 |
+
return_dict: Optional[bool] = None,
|
| 1309 |
+
labels: Optional[torch.Tensor] = None,
|
| 1310 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 1311 |
+
r"""
|
| 1312 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 1313 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 1314 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 1315 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1316 |
+
config.vocab_size - 1]`.
|
| 1317 |
+
"""
|
| 1318 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1319 |
+
|
| 1320 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
| 1321 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 1322 |
+
|
| 1323 |
+
outputs = self.sew(
|
| 1324 |
+
input_values,
|
| 1325 |
+
attention_mask=attention_mask,
|
| 1326 |
+
output_attentions=output_attentions,
|
| 1327 |
+
output_hidden_states=output_hidden_states,
|
| 1328 |
+
return_dict=return_dict,
|
| 1329 |
+
)
|
| 1330 |
+
|
| 1331 |
+
hidden_states = outputs[0]
|
| 1332 |
+
hidden_states = self.dropout(hidden_states)
|
| 1333 |
+
|
| 1334 |
+
logits = self.lm_head(hidden_states)
|
| 1335 |
+
|
| 1336 |
+
loss = None
|
| 1337 |
+
if labels is not None:
|
| 1338 |
+
# retrieve loss input_lengths from attention_mask
|
| 1339 |
+
attention_mask = (
|
| 1340 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 1341 |
+
)
|
| 1342 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1343 |
+
|
| 1344 |
+
# assuming that padded tokens are filled with -100
|
| 1345 |
+
# when not being attended to
|
| 1346 |
+
labels_mask = labels >= 0
|
| 1347 |
+
target_lengths = labels_mask.sum(-1)
|
| 1348 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 1349 |
+
|
| 1350 |
+
# ctc_loss doesn't support fp16
|
| 1351 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 1352 |
+
|
| 1353 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 1354 |
+
loss = nn.functional.ctc_loss(
|
| 1355 |
+
log_probs,
|
| 1356 |
+
flattened_targets,
|
| 1357 |
+
input_lengths,
|
| 1358 |
+
target_lengths,
|
| 1359 |
+
blank=self.config.pad_token_id,
|
| 1360 |
+
reduction=self.config.ctc_loss_reduction,
|
| 1361 |
+
zero_infinity=self.config.ctc_zero_infinity,
|
| 1362 |
+
)
|
| 1363 |
+
|
| 1364 |
+
if not return_dict:
|
| 1365 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1366 |
+
return ((loss,) + output) if loss is not None else output
|
| 1367 |
+
|
| 1368 |
+
return CausalLMOutput(
|
| 1369 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1370 |
+
)
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
@add_start_docstrings(
|
| 1374 |
+
"""
|
| 1375 |
+
SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
|
| 1376 |
+
Keyword Spotting.
|
| 1377 |
+
""",
|
| 1378 |
+
SEW_START_DOCSTRING,
|
| 1379 |
+
)
|
| 1380 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
|
| 1381 |
+
class SEWForSequenceClassification(SEWPreTrainedModel):
|
| 1382 |
+
def __init__(self, config):
|
| 1383 |
+
super().__init__(config)
|
| 1384 |
+
|
| 1385 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1386 |
+
raise ValueError(
|
| 1387 |
+
"Sequence classification does not support the use of SEW adapters (config.add_adapter=True)"
|
| 1388 |
+
)
|
| 1389 |
+
self.sew = SEWModel(config)
|
| 1390 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1391 |
+
if config.use_weighted_layer_sum:
|
| 1392 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1393 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1394 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1395 |
+
|
| 1396 |
+
# Initialize weights and apply final processing
|
| 1397 |
+
self.post_init()
|
| 1398 |
+
|
| 1399 |
+
def freeze_feature_extractor(self):
|
| 1400 |
+
"""
|
| 1401 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
|
| 1402 |
+
not be updated during training.
|
| 1403 |
+
"""
|
| 1404 |
+
warnings.warn(
|
| 1405 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1406 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1407 |
+
FutureWarning,
|
| 1408 |
+
)
|
| 1409 |
+
self.freeze_feature_encoder()
|
| 1410 |
+
|
| 1411 |
+
def freeze_feature_encoder(self):
|
| 1412 |
+
"""
|
| 1413 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1414 |
+
not be updated during training.
|
| 1415 |
+
"""
|
| 1416 |
+
self.sew.feature_extractor._freeze_parameters()
|
| 1417 |
+
|
| 1418 |
+
def freeze_base_model(self):
|
| 1419 |
+
"""
|
| 1420 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1421 |
+
be updated during training. Only the classification head will be updated.
|
| 1422 |
+
"""
|
| 1423 |
+
for param in self.sew.parameters():
|
| 1424 |
+
param.requires_grad = False
|
| 1425 |
+
|
| 1426 |
+
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
| 1427 |
+
@add_code_sample_docstrings(
|
| 1428 |
+
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
| 1429 |
+
output_type=SequenceClassifierOutput,
|
| 1430 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1431 |
+
modality="audio",
|
| 1432 |
+
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
| 1433 |
+
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
| 1434 |
+
)
|
| 1435 |
+
def forward(
|
| 1436 |
+
self,
|
| 1437 |
+
input_values: Optional[torch.Tensor],
|
| 1438 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1439 |
+
output_attentions: Optional[bool] = None,
|
| 1440 |
+
output_hidden_states: Optional[bool] = None,
|
| 1441 |
+
return_dict: Optional[bool] = None,
|
| 1442 |
+
labels: Optional[torch.Tensor] = None,
|
| 1443 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1444 |
+
r"""
|
| 1445 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1446 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1447 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1448 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1449 |
+
"""
|
| 1450 |
+
|
| 1451 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1452 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1453 |
+
|
| 1454 |
+
outputs = self.sew(
|
| 1455 |
+
input_values,
|
| 1456 |
+
attention_mask=attention_mask,
|
| 1457 |
+
output_attentions=output_attentions,
|
| 1458 |
+
output_hidden_states=output_hidden_states,
|
| 1459 |
+
return_dict=return_dict,
|
| 1460 |
+
)
|
| 1461 |
+
|
| 1462 |
+
if self.config.use_weighted_layer_sum:
|
| 1463 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1464 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1465 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1466 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1467 |
+
else:
|
| 1468 |
+
hidden_states = outputs[0]
|
| 1469 |
+
|
| 1470 |
+
hidden_states = self.projector(hidden_states)
|
| 1471 |
+
if attention_mask is None:
|
| 1472 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 1473 |
+
else:
|
| 1474 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1475 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
|
| 1476 |
+
hidden_states[~expand_padding_mask] = 0.0
|
| 1477 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1478 |
+
|
| 1479 |
+
logits = self.classifier(pooled_output)
|
| 1480 |
+
|
| 1481 |
+
loss = None
|
| 1482 |
+
if labels is not None:
|
| 1483 |
+
loss_fct = CrossEntropyLoss()
|
| 1484 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1485 |
+
|
| 1486 |
+
if not return_dict:
|
| 1487 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1488 |
+
return ((loss,) + output) if loss is not None else output
|
| 1489 |
+
|
| 1490 |
+
return SequenceClassifierOutput(
|
| 1491 |
+
loss=loss,
|
| 1492 |
+
logits=logits,
|
| 1493 |
+
hidden_states=outputs.hidden_states,
|
| 1494 |
+
attentions=outputs.attentions,
|
| 1495 |
+
)
|
| 1496 |
+
|
| 1497 |
+
|
| 1498 |
+
__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/sew_d/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_sew_d import *
|
| 22 |
+
from .modeling_sew_d import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/sew_d/configuration_sew_d.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SEW-D model configuration"""
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
import operator
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...utils import logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SEWDConfig(PretrainedConfig):
|
| 28 |
+
r"""
|
| 29 |
+
This is the configuration class to store the configuration of a [`SEWDModel`]. It is used to instantiate a SEW-D
|
| 30 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 31 |
+
defaults will yield a similar configuration to that of the SEW-D
|
| 32 |
+
[asapp/sew-d-tiny-100k](https://huggingface.co/asapp/sew-d-tiny-100k) architecture.
|
| 33 |
+
|
| 34 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 35 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
vocab_size (`int`, *optional*, defaults to 32):
|
| 40 |
+
Vocabulary size of the SEW-D model. Defines the number of different tokens that can be represented by the
|
| 41 |
+
`inputs_ids` passed when calling [`SEWD`].
|
| 42 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 43 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 44 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 45 |
+
Number of hidden layers in the Transformer encoder.
|
| 46 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 47 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 48 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 49 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 50 |
+
squeeze_factor (`int`, *optional*, defaults to 2):
|
| 51 |
+
Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
|
| 52 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
| 53 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 54 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 55 |
+
position_buckets (`int`, *optional*, defaults to 256):
|
| 56 |
+
The maximum size of relative position embeddings.
|
| 57 |
+
share_att_key (`bool`, *optional*, defaults to `True`):
|
| 58 |
+
Whether to share attention key with c2p and p2c.
|
| 59 |
+
relative_attention (`bool`, *optional*, defaults to `True`):
|
| 60 |
+
Whether to use relative position encoding.
|
| 61 |
+
pos_att_type (`Tuple[str]`, *optional*, defaults to `("p2c", "c2p")`):
|
| 62 |
+
The type of relative position attention, it can be a combination of `("p2c", "c2p")`, e.g. `("p2c")`,
|
| 63 |
+
`("p2c", "c2p")`, `("p2c", "c2p")`.
|
| 64 |
+
norm_rel_ebd (`str`, *optional*, defaults to `"layer_norm"`):
|
| 65 |
+
Whether to use layer norm in relative embedding (`"layer_norm"` if yes)
|
| 66 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
|
| 67 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 68 |
+
`"relu"`, `"selu"`, `"gelu_python"` and `"gelu_new"` are supported.
|
| 69 |
+
hidden_dropout (`float`, *optional*, defaults to 0.1):
|
| 70 |
+
Deprecated. Not used by the model and will be removed in a future version.
|
| 71 |
+
activation_dropout (`float`, *optional*, defaults to 0.1):
|
| 72 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 73 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 74 |
+
The dropout ratio for the attention probabilities.
|
| 75 |
+
final_dropout (`float`, *optional*, defaults to 0.1):
|
| 76 |
+
The dropout probability for the final projection layer of [`SEWDForCTC`].
|
| 77 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 78 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 79 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-7):
|
| 80 |
+
The epsilon used by the layer normalization layers in the transformer encoder.
|
| 81 |
+
feature_layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
| 82 |
+
The epsilon used by the layer normalization after the feature encoder.
|
| 83 |
+
feat_extract_norm (`str`, *optional*, defaults to `"group"`):
|
| 84 |
+
The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
|
| 85 |
+
normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
|
| 86 |
+
convolutional layers.
|
| 87 |
+
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
|
| 88 |
+
The dropout probability for output of the feature encoder.
|
| 89 |
+
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
|
| 90 |
+
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
|
| 91 |
+
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 92 |
+
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
|
| 93 |
+
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
|
| 94 |
+
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
|
| 95 |
+
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
|
| 96 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
|
| 97 |
+
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
|
| 98 |
+
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
|
| 99 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
|
| 100 |
+
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
|
| 101 |
+
*conv_dim*.
|
| 102 |
+
conv_bias (`bool`, *optional*, defaults to `False`):
|
| 103 |
+
Whether the 1D convolutional layers have a bias.
|
| 104 |
+
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
|
| 105 |
+
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
|
| 106 |
+
embeddings layer.
|
| 107 |
+
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
| 108 |
+
Number of groups of 1D convolutional positional embeddings layer.
|
| 109 |
+
apply_spec_augment (`bool`, *optional*, defaults to `True`):
|
| 110 |
+
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
|
| 111 |
+
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
| 112 |
+
Recognition](https://arxiv.org/abs/1904.08779).
|
| 113 |
+
mask_time_prob (`float`, *optional*, defaults to 0.05):
|
| 114 |
+
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
|
| 115 |
+
procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
|
| 116 |
+
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
|
| 117 |
+
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
|
| 118 |
+
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
|
| 119 |
+
mask_time_length (`int`, *optional*, defaults to 10):
|
| 120 |
+
Length of vector span along the time axis.
|
| 121 |
+
mask_time_min_masks (`int`, *optional*, defaults to 2),:
|
| 122 |
+
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
|
| 123 |
+
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
|
| 124 |
+
mask_time_min_masks''
|
| 125 |
+
mask_feature_prob (`float`, *optional*, defaults to 0.0):
|
| 126 |
+
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
|
| 127 |
+
masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
|
| 128 |
+
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
|
| 129 |
+
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
|
| 130 |
+
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
|
| 131 |
+
True`.
|
| 132 |
+
mask_feature_length (`int`, *optional*, defaults to 10):
|
| 133 |
+
Length of vector span along the feature axis.
|
| 134 |
+
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
|
| 135 |
+
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
| 136 |
+
step, irrespectively of `mask_feature_prob`. Only relevant if
|
| 137 |
+
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
|
| 138 |
+
diversity_loss_weight (`int`, *optional*, defaults to 0.1):
|
| 139 |
+
The weight of the codebook diversity loss component.
|
| 140 |
+
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
|
| 141 |
+
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
|
| 142 |
+
instance of [`SEWDForCTC`].
|
| 143 |
+
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
|
| 144 |
+
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
|
| 145 |
+
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
|
| 146 |
+
of [`SEWDForCTC`].
|
| 147 |
+
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
|
| 148 |
+
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
| 149 |
+
instance of [`Wav2Vec2ForSequenceClassification`].
|
| 150 |
+
classifier_proj_size (`int`, *optional*, defaults to 256):
|
| 151 |
+
Dimensionality of the projection before token mean-pooling for classification.
|
| 152 |
+
|
| 153 |
+
Example:
|
| 154 |
+
|
| 155 |
+
```python
|
| 156 |
+
>>> from transformers import SEWDConfig, SEWDModel
|
| 157 |
+
|
| 158 |
+
>>> # Initializing a SEW-D asapp/sew-d-tiny-100k style configuration
|
| 159 |
+
>>> configuration = SEWDConfig()
|
| 160 |
+
|
| 161 |
+
>>> # Initializing a model (with random weights) from the asapp/sew-d-tiny-100k style configuration
|
| 162 |
+
>>> model = SEWDModel(configuration)
|
| 163 |
+
|
| 164 |
+
>>> # Accessing the model configuration
|
| 165 |
+
>>> configuration = model.config
|
| 166 |
+
```"""
|
| 167 |
+
|
| 168 |
+
model_type = "sew-d"
|
| 169 |
+
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
vocab_size=32,
|
| 173 |
+
hidden_size=768,
|
| 174 |
+
num_hidden_layers=12,
|
| 175 |
+
num_attention_heads=12,
|
| 176 |
+
intermediate_size=3072,
|
| 177 |
+
squeeze_factor=2,
|
| 178 |
+
max_position_embeddings=512,
|
| 179 |
+
position_buckets=256,
|
| 180 |
+
share_att_key=True,
|
| 181 |
+
relative_attention=True,
|
| 182 |
+
pos_att_type=("p2c", "c2p"),
|
| 183 |
+
norm_rel_ebd="layer_norm",
|
| 184 |
+
hidden_act="gelu_python",
|
| 185 |
+
hidden_dropout=0.1,
|
| 186 |
+
activation_dropout=0.1,
|
| 187 |
+
attention_dropout=0.1,
|
| 188 |
+
feat_proj_dropout=0.0,
|
| 189 |
+
final_dropout=0.1,
|
| 190 |
+
initializer_range=0.02,
|
| 191 |
+
layer_norm_eps=1e-7,
|
| 192 |
+
feature_layer_norm_eps=1e-5,
|
| 193 |
+
feat_extract_norm="group",
|
| 194 |
+
feat_extract_activation="gelu",
|
| 195 |
+
conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
|
| 196 |
+
conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1),
|
| 197 |
+
conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1),
|
| 198 |
+
conv_bias=False,
|
| 199 |
+
num_conv_pos_embeddings=128,
|
| 200 |
+
num_conv_pos_embedding_groups=16,
|
| 201 |
+
apply_spec_augment=True,
|
| 202 |
+
mask_time_prob=0.05,
|
| 203 |
+
mask_time_length=10,
|
| 204 |
+
mask_time_min_masks=2,
|
| 205 |
+
mask_feature_prob=0.0,
|
| 206 |
+
mask_feature_length=10,
|
| 207 |
+
mask_feature_min_masks=0,
|
| 208 |
+
ctc_loss_reduction="mean",
|
| 209 |
+
ctc_zero_infinity=False,
|
| 210 |
+
use_weighted_layer_sum=False,
|
| 211 |
+
classifier_proj_size=256,
|
| 212 |
+
pad_token_id=0,
|
| 213 |
+
bos_token_id=1,
|
| 214 |
+
eos_token_id=2,
|
| 215 |
+
**kwargs,
|
| 216 |
+
):
|
| 217 |
+
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
| 218 |
+
self.hidden_size = hidden_size
|
| 219 |
+
self.feat_extract_norm = feat_extract_norm
|
| 220 |
+
self.feat_extract_activation = feat_extract_activation
|
| 221 |
+
self.conv_dim = list(conv_dim)
|
| 222 |
+
self.conv_stride = list(conv_stride)
|
| 223 |
+
self.conv_kernel = list(conv_kernel)
|
| 224 |
+
self.conv_bias = conv_bias
|
| 225 |
+
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
| 226 |
+
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
| 227 |
+
self.num_feat_extract_layers = len(self.conv_dim)
|
| 228 |
+
self.num_hidden_layers = num_hidden_layers
|
| 229 |
+
self.intermediate_size = intermediate_size
|
| 230 |
+
self.squeeze_factor = squeeze_factor
|
| 231 |
+
self.max_position_embeddings = max_position_embeddings
|
| 232 |
+
self.position_buckets = position_buckets
|
| 233 |
+
self.share_att_key = share_att_key
|
| 234 |
+
self.relative_attention = relative_attention
|
| 235 |
+
self.norm_rel_ebd = norm_rel_ebd
|
| 236 |
+
self.pos_att_type = list(pos_att_type)
|
| 237 |
+
self.hidden_act = hidden_act
|
| 238 |
+
self.num_attention_heads = num_attention_heads
|
| 239 |
+
self._hidden_dropout = hidden_dropout
|
| 240 |
+
self.attention_dropout = attention_dropout
|
| 241 |
+
self.activation_dropout = activation_dropout
|
| 242 |
+
self.feat_proj_dropout = feat_proj_dropout
|
| 243 |
+
self.final_dropout = final_dropout
|
| 244 |
+
self.layer_norm_eps = layer_norm_eps
|
| 245 |
+
self.feature_layer_norm_eps = feature_layer_norm_eps
|
| 246 |
+
self.initializer_range = initializer_range
|
| 247 |
+
self.vocab_size = vocab_size
|
| 248 |
+
|
| 249 |
+
if (
|
| 250 |
+
(len(self.conv_stride) != self.num_feat_extract_layers)
|
| 251 |
+
or (len(self.conv_kernel) != self.num_feat_extract_layers)
|
| 252 |
+
or (len(self.conv_dim) != self.num_feat_extract_layers)
|
| 253 |
+
):
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Configuration for convolutional layers is incorrect. "
|
| 256 |
+
"It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
|
| 257 |
+
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
|
| 258 |
+
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
| 262 |
+
self.apply_spec_augment = apply_spec_augment
|
| 263 |
+
self.mask_time_prob = mask_time_prob
|
| 264 |
+
self.mask_time_length = mask_time_length
|
| 265 |
+
self.mask_time_min_masks = mask_time_min_masks
|
| 266 |
+
self.mask_feature_prob = mask_feature_prob
|
| 267 |
+
self.mask_feature_length = mask_feature_length
|
| 268 |
+
self.mask_feature_min_masks = mask_feature_min_masks
|
| 269 |
+
|
| 270 |
+
# ctc loss
|
| 271 |
+
self.ctc_loss_reduction = ctc_loss_reduction
|
| 272 |
+
self.ctc_zero_infinity = ctc_zero_infinity
|
| 273 |
+
|
| 274 |
+
# sequence classification
|
| 275 |
+
self.use_weighted_layer_sum = use_weighted_layer_sum
|
| 276 |
+
self.classifier_proj_size = classifier_proj_size
|
| 277 |
+
|
| 278 |
+
@property
|
| 279 |
+
def inputs_to_logits_ratio(self):
|
| 280 |
+
return functools.reduce(operator.mul, self.conv_stride, 1)
|
| 281 |
+
|
| 282 |
+
def to_dict(self):
|
| 283 |
+
"""
|
| 284 |
+
Serializes this instance to a Python dictionary.
|
| 285 |
+
"""
|
| 286 |
+
output = super().to_dict()
|
| 287 |
+
output["hidden_dropout"] = output.pop("_hidden_dropout")
|
| 288 |
+
return output
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
__all__ = ["SEWDConfig"]
|
docs/transformers/build/lib/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert SEW checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import fairseq
|
| 22 |
+
import torch
|
| 23 |
+
from fairseq.data import Dictionary
|
| 24 |
+
|
| 25 |
+
# Register SEW's fairseq modules
|
| 26 |
+
from sew_asapp import tasks # noqa: F401
|
| 27 |
+
|
| 28 |
+
from transformers import (
|
| 29 |
+
SEWDConfig,
|
| 30 |
+
SEWDForCTC,
|
| 31 |
+
SEWDModel,
|
| 32 |
+
Wav2Vec2CTCTokenizer,
|
| 33 |
+
Wav2Vec2FeatureExtractor,
|
| 34 |
+
Wav2Vec2Processor,
|
| 35 |
+
logging,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logging.set_verbosity_info()
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
MAPPING = {
|
| 43 |
+
"post_extract_proj": "feature_projection",
|
| 44 |
+
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
| 45 |
+
"attention.self.query_proj": "encoder.encoder.layer.*.attention.self.query_proj",
|
| 46 |
+
"attention.self.key_proj": "encoder.encoder.layer.*.attention.self.key_proj",
|
| 47 |
+
"attention.self.value_proj": "encoder.encoder.layer.*.attention.self.value_proj",
|
| 48 |
+
"attention.output.dense": "encoder.encoder.layer.*.attention.output.dense",
|
| 49 |
+
"attention.output.LayerNorm": "encoder.encoder.layer.*.attention.output.LayerNorm",
|
| 50 |
+
"intermediate.dense": "encoder.encoder.layer.*.intermediate.dense",
|
| 51 |
+
"output.dense": "encoder.encoder.layer.*.output.dense",
|
| 52 |
+
"output.LayerNorm": "encoder.encoder.layer.*.output.LayerNorm",
|
| 53 |
+
"encoder.encoder.rel_embeddings": "encoder.encoder.rel_embeddings",
|
| 54 |
+
"encoder.encoder.LayerNorm": "encoder.encoder.LayerNorm",
|
| 55 |
+
"encoder.upsample.0": "encoder.upsample.projection",
|
| 56 |
+
"encoder.layer_norm": "encoder.layer_norm",
|
| 57 |
+
"w2v_model.layer_norm": "layer_norm",
|
| 58 |
+
"w2v_encoder.proj": "lm_head",
|
| 59 |
+
"mask_emb": "masked_spec_embed",
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
| 64 |
+
for attribute in key.split("."):
|
| 65 |
+
hf_pointer = getattr(hf_pointer, attribute)
|
| 66 |
+
|
| 67 |
+
if weight_type is not None:
|
| 68 |
+
hf_shape = getattr(hf_pointer, weight_type).shape
|
| 69 |
+
else:
|
| 70 |
+
hf_shape = hf_pointer.shape
|
| 71 |
+
|
| 72 |
+
assert hf_shape == value.shape, (
|
| 73 |
+
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
| 74 |
+
f" {value.shape} for {full_name}"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if weight_type == "weight":
|
| 78 |
+
hf_pointer.weight.data = value
|
| 79 |
+
elif weight_type == "weight_g":
|
| 80 |
+
hf_pointer.weight_g.data = value
|
| 81 |
+
elif weight_type == "weight_v":
|
| 82 |
+
hf_pointer.weight_v.data = value
|
| 83 |
+
elif weight_type == "bias":
|
| 84 |
+
hf_pointer.bias.data = value
|
| 85 |
+
else:
|
| 86 |
+
hf_pointer.data = value
|
| 87 |
+
|
| 88 |
+
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
| 92 |
+
unused_weights = []
|
| 93 |
+
fairseq_dict = fairseq_model.state_dict()
|
| 94 |
+
|
| 95 |
+
feature_extractor = hf_model.sew_d.feature_extractor if is_finetuned else hf_model.feature_extractor
|
| 96 |
+
|
| 97 |
+
for name, value in fairseq_dict.items():
|
| 98 |
+
is_used = False
|
| 99 |
+
if "conv_layers" in name:
|
| 100 |
+
load_conv_layer(
|
| 101 |
+
name,
|
| 102 |
+
value,
|
| 103 |
+
feature_extractor,
|
| 104 |
+
unused_weights,
|
| 105 |
+
hf_model.config.feat_extract_norm == "group",
|
| 106 |
+
)
|
| 107 |
+
is_used = True
|
| 108 |
+
else:
|
| 109 |
+
for key, mapped_key in MAPPING.items():
|
| 110 |
+
mapped_key = "sew_d." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
| 111 |
+
|
| 112 |
+
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
| 113 |
+
is_used = True
|
| 114 |
+
if "*" in mapped_key:
|
| 115 |
+
layer_index = name.split(key)[0].split(".")[-2]
|
| 116 |
+
if not layer_index.isnumeric():
|
| 117 |
+
continue
|
| 118 |
+
mapped_key = mapped_key.replace("*", layer_index)
|
| 119 |
+
if "weight_g" in name:
|
| 120 |
+
weight_type = "weight_g"
|
| 121 |
+
elif "weight_v" in name:
|
| 122 |
+
weight_type = "weight_v"
|
| 123 |
+
elif "weight" in name:
|
| 124 |
+
weight_type = "weight"
|
| 125 |
+
elif "bias" in name:
|
| 126 |
+
weight_type = "bias"
|
| 127 |
+
else:
|
| 128 |
+
weight_type = None
|
| 129 |
+
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
| 130 |
+
continue
|
| 131 |
+
if not is_used:
|
| 132 |
+
unused_weights.append(name)
|
| 133 |
+
|
| 134 |
+
logger.warning(f"Unused weights: {unused_weights}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
| 138 |
+
name = full_name.split("conv_layers.")[-1]
|
| 139 |
+
items = name.split(".")
|
| 140 |
+
layer_id = int(items[0])
|
| 141 |
+
type_id = int(items[1])
|
| 142 |
+
|
| 143 |
+
if type_id == 0:
|
| 144 |
+
if "bias" in name:
|
| 145 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
|
| 146 |
+
f"{full_name} has size {value.shape}, but"
|
| 147 |
+
f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
|
| 148 |
+
)
|
| 149 |
+
feature_extractor.conv_layers[layer_id].conv.bias.data = value
|
| 150 |
+
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
| 151 |
+
elif "weight" in name:
|
| 152 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
|
| 153 |
+
f"{full_name} has size {value.shape}, but"
|
| 154 |
+
f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
|
| 155 |
+
)
|
| 156 |
+
feature_extractor.conv_layers[layer_id].conv.weight.data = value
|
| 157 |
+
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
| 158 |
+
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
|
| 159 |
+
if "bias" in name:
|
| 160 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
|
| 161 |
+
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
|
| 162 |
+
" found."
|
| 163 |
+
)
|
| 164 |
+
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
|
| 165 |
+
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
| 166 |
+
elif "weight" in name:
|
| 167 |
+
assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
|
| 168 |
+
f"{full_name} has size {value.shape}, but"
|
| 169 |
+
f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
|
| 170 |
+
)
|
| 171 |
+
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
|
| 172 |
+
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
| 173 |
+
else:
|
| 174 |
+
unused_weights.append(full_name)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def convert_config(model, is_finetuned):
|
| 178 |
+
config = SEWDConfig()
|
| 179 |
+
if is_finetuned:
|
| 180 |
+
fs_config = model.w2v_encoder.w2v_model.cfg
|
| 181 |
+
else:
|
| 182 |
+
fs_config = model.cfg
|
| 183 |
+
|
| 184 |
+
config.conv_bias = fs_config.conv_bias
|
| 185 |
+
conv_layers = eval(fs_config.conv_feature_layers)
|
| 186 |
+
config.conv_dim = [x[0] for x in conv_layers]
|
| 187 |
+
config.conv_kernel = [x[1] for x in conv_layers]
|
| 188 |
+
config.conv_stride = [x[2] for x in conv_layers]
|
| 189 |
+
config.feat_extract_activation = "gelu"
|
| 190 |
+
config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
|
| 191 |
+
config.final_dropout = 0.0
|
| 192 |
+
config.hidden_act = fs_config.activation_fn.name
|
| 193 |
+
config.hidden_size = fs_config.encoder_embed_dim
|
| 194 |
+
config.initializer_range = 0.02
|
| 195 |
+
config.intermediate_size = fs_config.encoder_ffn_embed_dim
|
| 196 |
+
config.layer_norm_eps = 1e-5
|
| 197 |
+
config.layerdrop = fs_config.encoder_layerdrop
|
| 198 |
+
config.num_attention_heads = fs_config.encoder_attention_heads
|
| 199 |
+
config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
|
| 200 |
+
config.num_conv_pos_embeddings = fs_config.conv_pos
|
| 201 |
+
config.num_feat_extract_layers = len(conv_layers)
|
| 202 |
+
config.num_hidden_layers = fs_config.encoder_layers
|
| 203 |
+
config.squeeze_factor = fs_config.squeeze_factor
|
| 204 |
+
# DeBERTa-specific parameters:
|
| 205 |
+
config.max_position_embeddings = fs_config.max_position_embeddings
|
| 206 |
+
config.position_buckets = fs_config.position_buckets
|
| 207 |
+
config.share_att_key = fs_config.share_att_key
|
| 208 |
+
config.relative_attention = fs_config.relative_attention
|
| 209 |
+
config.position_biased_input = fs_config.position_biased_input
|
| 210 |
+
config.pos_att_type = tuple(fs_config.pos_att_type.split("|"))
|
| 211 |
+
config.norm_rel_ebd = fs_config.norm_rel_ebd
|
| 212 |
+
|
| 213 |
+
# take care of any params that are overridden by the Wav2VecCtc model
|
| 214 |
+
if is_finetuned:
|
| 215 |
+
fs_config = model.cfg
|
| 216 |
+
config.final_dropout = fs_config.final_dropout
|
| 217 |
+
config.layerdrop = fs_config.layerdrop
|
| 218 |
+
config.activation_dropout = fs_config.activation_dropout
|
| 219 |
+
config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
|
| 220 |
+
config.attention_dropout = fs_config.attention_dropout
|
| 221 |
+
config.feat_proj_dropout = fs_config.dropout_input
|
| 222 |
+
config.hidden_dropout = fs_config.dropout
|
| 223 |
+
config.mask_feature_length = fs_config.mask_channel_length
|
| 224 |
+
config.mask_feature_prob = fs_config.mask_channel_prob
|
| 225 |
+
config.mask_time_length = fs_config.mask_length
|
| 226 |
+
config.mask_time_prob = fs_config.mask_prob
|
| 227 |
+
|
| 228 |
+
config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
|
| 229 |
+
config.tokenizer_class = "Wav2Vec2CTCTokenizer"
|
| 230 |
+
|
| 231 |
+
return config
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@torch.no_grad()
|
| 235 |
+
def convert_sew_checkpoint(
|
| 236 |
+
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
|
| 237 |
+
):
|
| 238 |
+
"""
|
| 239 |
+
Copy/paste/tweak model's weights to transformers design.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
if is_finetuned:
|
| 243 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
| 244 |
+
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
|
| 248 |
+
|
| 249 |
+
if config_path is not None:
|
| 250 |
+
config = SEWDConfig.from_pretrained(config_path)
|
| 251 |
+
else:
|
| 252 |
+
config = convert_config(model[0], is_finetuned)
|
| 253 |
+
model = model[0].eval()
|
| 254 |
+
|
| 255 |
+
return_attention_mask = True if config.feat_extract_norm == "layer" else False
|
| 256 |
+
feature_extractor = Wav2Vec2FeatureExtractor(
|
| 257 |
+
feature_size=1,
|
| 258 |
+
sampling_rate=16000,
|
| 259 |
+
padding_value=0,
|
| 260 |
+
do_normalize=True,
|
| 261 |
+
return_attention_mask=return_attention_mask,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if is_finetuned:
|
| 265 |
+
if dict_path:
|
| 266 |
+
target_dict = Dictionary.load(dict_path)
|
| 267 |
+
|
| 268 |
+
# important change bos & pad token id since CTC symbol is <pad> and
|
| 269 |
+
# not <s> as in fairseq
|
| 270 |
+
target_dict.indices[target_dict.bos_word] = target_dict.pad_index
|
| 271 |
+
target_dict.indices[target_dict.pad_word] = target_dict.bos_index
|
| 272 |
+
config.bos_token_id = target_dict.pad_index
|
| 273 |
+
config.pad_token_id = target_dict.bos_index
|
| 274 |
+
config.eos_token_id = target_dict.eos_index
|
| 275 |
+
config.vocab_size = len(target_dict.symbols)
|
| 276 |
+
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
|
| 277 |
+
if not os.path.isdir(pytorch_dump_folder_path):
|
| 278 |
+
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
|
| 279 |
+
return
|
| 280 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
| 281 |
+
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
|
| 282 |
+
json.dump(target_dict.indices, vocab_handle)
|
| 283 |
+
tokenizer = Wav2Vec2CTCTokenizer(
|
| 284 |
+
vocab_path,
|
| 285 |
+
unk_token=target_dict.unk_word,
|
| 286 |
+
pad_token=target_dict.pad_word,
|
| 287 |
+
bos_token=target_dict.bos_word,
|
| 288 |
+
eos_token=target_dict.eos_word,
|
| 289 |
+
word_delimiter_token="|",
|
| 290 |
+
do_lower_case=False,
|
| 291 |
+
)
|
| 292 |
+
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
| 293 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 294 |
+
|
| 295 |
+
hf_model = SEWDForCTC(config)
|
| 296 |
+
else:
|
| 297 |
+
hf_model = SEWDModel(config)
|
| 298 |
+
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
| 299 |
+
|
| 300 |
+
recursively_load_weights(model, hf_model, is_finetuned)
|
| 301 |
+
|
| 302 |
+
hf_model.save_pretrained(pytorch_dump_folder_path)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
parser = argparse.ArgumentParser()
|
| 307 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
| 308 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
| 309 |
+
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
| 310 |
+
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--is_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
| 313 |
+
)
|
| 314 |
+
args = parser.parse_args()
|
| 315 |
+
convert_sew_checkpoint(
|
| 316 |
+
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, args.is_finetuned
|
| 317 |
+
)
|
docs/transformers/build/lib/transformers/models/sew_d/modeling_sew_d.py
ADDED
|
@@ -0,0 +1,1748 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch SEW model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import warnings
|
| 19 |
+
from collections.abc import Sequence
|
| 20 |
+
from typing import Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
| 27 |
+
|
| 28 |
+
from ...activations import ACT2FN
|
| 29 |
+
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 30 |
+
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
| 31 |
+
from ...modeling_utils import PreTrainedModel
|
| 32 |
+
from ...pytorch_utils import softmax_backward_data
|
| 33 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 34 |
+
from .configuration_sew_d import SEWDConfig
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
_HIDDEN_STATES_START_POSITION = 1
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# General docstring
|
| 43 |
+
_CONFIG_FOR_DOC = "SEWDConfig"
|
| 44 |
+
|
| 45 |
+
# Base docstring
|
| 46 |
+
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
|
| 47 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 384]
|
| 48 |
+
|
| 49 |
+
# CTC docstring
|
| 50 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 51 |
+
_CTC_EXPECTED_LOSS = 0.21
|
| 52 |
+
|
| 53 |
+
# Audio class docstring
|
| 54 |
+
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
|
| 55 |
+
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
|
| 56 |
+
_SEQ_CLASS_EXPECTED_LOSS = 3.16
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
| 60 |
+
def _compute_mask_indices(
|
| 61 |
+
shape: Tuple[int, int],
|
| 62 |
+
mask_prob: float,
|
| 63 |
+
mask_length: int,
|
| 64 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 65 |
+
min_masks: int = 0,
|
| 66 |
+
) -> np.ndarray:
|
| 67 |
+
"""
|
| 68 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 69 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 70 |
+
CPU as part of the preprocessing during training.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 74 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
| 75 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 76 |
+
independently generated mask spans of length `mask_length` is computed by
|
| 77 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 78 |
+
actual percentage will be smaller.
|
| 79 |
+
mask_length: size of the mask
|
| 80 |
+
min_masks: minimum number of masked spans
|
| 81 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 82 |
+
each batch dimension.
|
| 83 |
+
"""
|
| 84 |
+
batch_size, sequence_length = shape
|
| 85 |
+
|
| 86 |
+
if mask_length < 1:
|
| 87 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 88 |
+
|
| 89 |
+
if mask_length > sequence_length:
|
| 90 |
+
raise ValueError(
|
| 91 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 92 |
+
f" and `sequence_length`: {sequence_length}`"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# epsilon is used for probabilistic rounding
|
| 96 |
+
epsilon = np.random.rand(1).item()
|
| 97 |
+
|
| 98 |
+
def compute_num_masked_span(input_length):
|
| 99 |
+
"""Given input length, compute how many spans should be masked"""
|
| 100 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 101 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 102 |
+
|
| 103 |
+
# make sure num masked span <= sequence_length
|
| 104 |
+
if num_masked_span * mask_length > sequence_length:
|
| 105 |
+
num_masked_span = sequence_length // mask_length
|
| 106 |
+
|
| 107 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 108 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
| 109 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 110 |
+
|
| 111 |
+
return num_masked_span
|
| 112 |
+
|
| 113 |
+
# compute number of masked spans in batch
|
| 114 |
+
input_lengths = (
|
| 115 |
+
attention_mask.detach().sum(-1).tolist()
|
| 116 |
+
if attention_mask is not None
|
| 117 |
+
else [sequence_length for _ in range(batch_size)]
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# SpecAugment mask to fill
|
| 121 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 122 |
+
spec_aug_mask_idxs = []
|
| 123 |
+
|
| 124 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 125 |
+
|
| 126 |
+
if max_num_masked_span == 0:
|
| 127 |
+
return spec_aug_mask
|
| 128 |
+
|
| 129 |
+
for input_length in input_lengths:
|
| 130 |
+
# compute num of masked spans for this input
|
| 131 |
+
num_masked_span = compute_num_masked_span(input_length)
|
| 132 |
+
|
| 133 |
+
# get random indices to mask
|
| 134 |
+
spec_aug_mask_idx = np.random.choice(
|
| 135 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
| 139 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
| 140 |
+
# Picking first sample just pads those vectors twice.
|
| 141 |
+
if len(spec_aug_mask_idx) == 0:
|
| 142 |
+
# this case can only happen if `input_length` is strictly smaller then
|
| 143 |
+
# `sequence_length` in which case the last token has to be a padding
|
| 144 |
+
# token which we can use as a dummy mask id
|
| 145 |
+
dummy_mask_idx = sequence_length - 1
|
| 146 |
+
else:
|
| 147 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 148 |
+
|
| 149 |
+
spec_aug_mask_idx = np.concatenate(
|
| 150 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 151 |
+
)
|
| 152 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 153 |
+
|
| 154 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 155 |
+
|
| 156 |
+
# expand masked indices to masked spans
|
| 157 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
| 158 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 159 |
+
)
|
| 160 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 161 |
+
|
| 162 |
+
# add offset to the starting indexes so that indexes now create a span
|
| 163 |
+
offsets = np.arange(mask_length)[None, None, :]
|
| 164 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 165 |
+
batch_size, max_num_masked_span * mask_length
|
| 166 |
+
)
|
| 167 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 168 |
+
|
| 169 |
+
# ensure that we cannot have indices larger than sequence_length
|
| 170 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 171 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 172 |
+
|
| 173 |
+
# scatter indices to mask
|
| 174 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 175 |
+
|
| 176 |
+
return spec_aug_mask
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def make_log_bucket_position(relative_pos, bucket_size, max_position):
|
| 180 |
+
sign = torch.sign(relative_pos)
|
| 181 |
+
mid = bucket_size // 2
|
| 182 |
+
abs_pos = torch.where(
|
| 183 |
+
(relative_pos < mid) & (relative_pos > -mid),
|
| 184 |
+
torch.tensor(mid - 1).type_as(relative_pos),
|
| 185 |
+
torch.abs(relative_pos),
|
| 186 |
+
)
|
| 187 |
+
log_pos = (
|
| 188 |
+
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
|
| 189 |
+
)
|
| 190 |
+
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
|
| 191 |
+
return bucket_pos
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
|
| 195 |
+
"""
|
| 196 |
+
Build relative position according to the query and key
|
| 197 |
+
|
| 198 |
+
We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
|
| 199 |
+
\\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
|
| 200 |
+
P_k\\)
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
query_size (int): the length of query
|
| 204 |
+
key_size (int): the length of key
|
| 205 |
+
bucket_size (int): the size of position bucket
|
| 206 |
+
max_position (int): the maximum allowed absolute position
|
| 207 |
+
device (`torch.device`): the device on which tensors will be created.
|
| 208 |
+
|
| 209 |
+
Return:
|
| 210 |
+
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
q_ids = torch.arange(0, query_size, device=device)
|
| 214 |
+
k_ids = torch.arange(0, key_size, device=device)
|
| 215 |
+
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
|
| 216 |
+
if bucket_size > 0 and max_position > 0:
|
| 217 |
+
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
|
| 218 |
+
rel_pos_ids = rel_pos_ids.to(torch.long)
|
| 219 |
+
rel_pos_ids = rel_pos_ids[:query_size, :]
|
| 220 |
+
rel_pos_ids = rel_pos_ids.unsqueeze(0)
|
| 221 |
+
return rel_pos_ids
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@torch.jit.script
|
| 225 |
+
# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
|
| 226 |
+
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
|
| 227 |
+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@torch.jit.script
|
| 231 |
+
# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
|
| 232 |
+
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
|
| 233 |
+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@torch.jit.script
|
| 237 |
+
# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
|
| 238 |
+
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
|
| 239 |
+
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_mask(input, local_context):
|
| 243 |
+
if not isinstance(local_context, DropoutContext):
|
| 244 |
+
dropout = local_context
|
| 245 |
+
mask = None
|
| 246 |
+
else:
|
| 247 |
+
dropout = local_context.dropout
|
| 248 |
+
dropout *= local_context.scale
|
| 249 |
+
mask = local_context.mask if local_context.reuse_mask else None
|
| 250 |
+
|
| 251 |
+
if dropout > 0 and mask is None:
|
| 252 |
+
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
| 253 |
+
|
| 254 |
+
if isinstance(local_context, DropoutContext):
|
| 255 |
+
if local_context.mask is None:
|
| 256 |
+
local_context.mask = mask
|
| 257 |
+
|
| 258 |
+
return mask, dropout
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEWD
|
| 262 |
+
class SEWDNoLayerNormConvLayer(nn.Module):
|
| 263 |
+
def __init__(self, config, layer_id=0):
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 266 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 267 |
+
|
| 268 |
+
self.conv = nn.Conv1d(
|
| 269 |
+
self.in_conv_dim,
|
| 270 |
+
self.out_conv_dim,
|
| 271 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 272 |
+
stride=config.conv_stride[layer_id],
|
| 273 |
+
bias=config.conv_bias,
|
| 274 |
+
)
|
| 275 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 276 |
+
|
| 277 |
+
def forward(self, hidden_states):
|
| 278 |
+
hidden_states = self.conv(hidden_states)
|
| 279 |
+
hidden_states = self.activation(hidden_states)
|
| 280 |
+
return hidden_states
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEWD
|
| 284 |
+
class SEWDLayerNormConvLayer(nn.Module):
|
| 285 |
+
def __init__(self, config, layer_id=0):
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 288 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 289 |
+
|
| 290 |
+
self.conv = nn.Conv1d(
|
| 291 |
+
self.in_conv_dim,
|
| 292 |
+
self.out_conv_dim,
|
| 293 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 294 |
+
stride=config.conv_stride[layer_id],
|
| 295 |
+
bias=config.conv_bias,
|
| 296 |
+
)
|
| 297 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 298 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 299 |
+
|
| 300 |
+
def forward(self, hidden_states):
|
| 301 |
+
hidden_states = self.conv(hidden_states)
|
| 302 |
+
|
| 303 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 304 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 305 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 306 |
+
|
| 307 |
+
hidden_states = self.activation(hidden_states)
|
| 308 |
+
return hidden_states
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEWD
|
| 312 |
+
class SEWDGroupNormConvLayer(nn.Module):
|
| 313 |
+
def __init__(self, config, layer_id=0):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 316 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 317 |
+
|
| 318 |
+
self.conv = nn.Conv1d(
|
| 319 |
+
self.in_conv_dim,
|
| 320 |
+
self.out_conv_dim,
|
| 321 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 322 |
+
stride=config.conv_stride[layer_id],
|
| 323 |
+
bias=config.conv_bias,
|
| 324 |
+
)
|
| 325 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 326 |
+
|
| 327 |
+
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
| 328 |
+
|
| 329 |
+
def forward(self, hidden_states):
|
| 330 |
+
hidden_states = self.conv(hidden_states)
|
| 331 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 332 |
+
hidden_states = self.activation(hidden_states)
|
| 333 |
+
return hidden_states
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# Copied from transformers.models.sew.modeling_sew.SEWPositionalConvEmbedding with SEW->SEWD
|
| 337 |
+
class SEWDPositionalConvEmbedding(nn.Module):
|
| 338 |
+
def __init__(self, config):
|
| 339 |
+
super().__init__()
|
| 340 |
+
self.conv = nn.Conv1d(
|
| 341 |
+
config.hidden_size,
|
| 342 |
+
config.hidden_size,
|
| 343 |
+
kernel_size=config.num_conv_pos_embeddings,
|
| 344 |
+
padding=config.num_conv_pos_embeddings // 2,
|
| 345 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 346 |
+
stride=config.squeeze_factor,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
weight_norm = nn.utils.weight_norm
|
| 350 |
+
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
| 351 |
+
weight_norm = nn.utils.parametrizations.weight_norm
|
| 352 |
+
|
| 353 |
+
if is_deepspeed_zero3_enabled():
|
| 354 |
+
import deepspeed
|
| 355 |
+
|
| 356 |
+
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
| 357 |
+
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
| 358 |
+
if hasattr(self.conv, "parametrizations"):
|
| 359 |
+
weight_g = self.conv.parametrizations.weight.original0
|
| 360 |
+
weight_v = self.conv.parametrizations.weight.original1
|
| 361 |
+
else:
|
| 362 |
+
weight_g = self.conv.weight_g
|
| 363 |
+
weight_v = self.conv.weight_v
|
| 364 |
+
deepspeed.zero.register_external_parameter(self, weight_v)
|
| 365 |
+
deepspeed.zero.register_external_parameter(self, weight_g)
|
| 366 |
+
else:
|
| 367 |
+
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
| 368 |
+
|
| 369 |
+
self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings)
|
| 370 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 371 |
+
|
| 372 |
+
def forward(self, hidden_states):
|
| 373 |
+
hidden_states = self.conv(hidden_states)
|
| 374 |
+
hidden_states = self.padding(hidden_states)
|
| 375 |
+
hidden_states = self.activation(hidden_states)
|
| 376 |
+
|
| 377 |
+
return hidden_states
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW
|
| 381 |
+
class SEWDSamePadLayer(nn.Module):
|
| 382 |
+
def __init__(self, num_conv_pos_embeddings):
|
| 383 |
+
super().__init__()
|
| 384 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 385 |
+
|
| 386 |
+
def forward(self, hidden_states):
|
| 387 |
+
if self.num_pad_remove > 0:
|
| 388 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 389 |
+
return hidden_states
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# Copied from transformers.models.sew.modeling_sew.SEWUpsampling with SEW->SEWD
|
| 393 |
+
class SEWDUpsampling(nn.Module):
|
| 394 |
+
def __init__(self, config):
|
| 395 |
+
super().__init__()
|
| 396 |
+
self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
|
| 397 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 398 |
+
self.squeeze_factor = config.squeeze_factor
|
| 399 |
+
|
| 400 |
+
def forward(self, hidden_states):
|
| 401 |
+
hidden_states = self.projection(hidden_states)
|
| 402 |
+
hidden_states = self.activation(hidden_states)
|
| 403 |
+
|
| 404 |
+
if self.squeeze_factor > 1:
|
| 405 |
+
# transform embedding channels to sequence length
|
| 406 |
+
bsz, src_len, src_embed_dim = hidden_states.size()
|
| 407 |
+
tgt_len = src_len * self.squeeze_factor
|
| 408 |
+
tgt_embed_dim = src_embed_dim // self.squeeze_factor
|
| 409 |
+
hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
|
| 410 |
+
hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
|
| 411 |
+
|
| 412 |
+
return hidden_states
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEWD
|
| 416 |
+
class SEWDFeatureEncoder(nn.Module):
|
| 417 |
+
"""Construct the features from raw audio waveform"""
|
| 418 |
+
|
| 419 |
+
def __init__(self, config):
|
| 420 |
+
super().__init__()
|
| 421 |
+
|
| 422 |
+
if config.feat_extract_norm == "group":
|
| 423 |
+
conv_layers = [SEWDGroupNormConvLayer(config, layer_id=0)] + [
|
| 424 |
+
SEWDNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
|
| 425 |
+
]
|
| 426 |
+
elif config.feat_extract_norm == "layer":
|
| 427 |
+
conv_layers = [SEWDLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
|
| 428 |
+
else:
|
| 429 |
+
raise ValueError(
|
| 430 |
+
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
| 431 |
+
)
|
| 432 |
+
self.conv_layers = nn.ModuleList(conv_layers)
|
| 433 |
+
self.gradient_checkpointing = False
|
| 434 |
+
self._requires_grad = True
|
| 435 |
+
|
| 436 |
+
def _freeze_parameters(self):
|
| 437 |
+
for param in self.parameters():
|
| 438 |
+
param.requires_grad = False
|
| 439 |
+
self._requires_grad = False
|
| 440 |
+
|
| 441 |
+
def forward(self, input_values):
|
| 442 |
+
hidden_states = input_values[:, None]
|
| 443 |
+
|
| 444 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
| 445 |
+
if self._requires_grad and self.training:
|
| 446 |
+
hidden_states.requires_grad = True
|
| 447 |
+
|
| 448 |
+
for conv_layer in self.conv_layers:
|
| 449 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 450 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 451 |
+
conv_layer.__call__,
|
| 452 |
+
hidden_states,
|
| 453 |
+
)
|
| 454 |
+
else:
|
| 455 |
+
hidden_states = conv_layer(hidden_states)
|
| 456 |
+
|
| 457 |
+
return hidden_states
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class SEWDFeatureExtractor(SEWDFeatureEncoder):
|
| 461 |
+
def __init__(self, config):
|
| 462 |
+
super().__init__(config)
|
| 463 |
+
warnings.warn(
|
| 464 |
+
f"The class `{self.__class__.__name__}` has been depreciated "
|
| 465 |
+
"and will be removed in Transformers v5. "
|
| 466 |
+
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
|
| 467 |
+
FutureWarning,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class ContextPooler(nn.Module):
|
| 472 |
+
def __init__(self, config):
|
| 473 |
+
super().__init__()
|
| 474 |
+
self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
|
| 475 |
+
self.dropout = StableDropout(config.pooler_dropout)
|
| 476 |
+
self.config = config
|
| 477 |
+
|
| 478 |
+
def forward(self, hidden_states):
|
| 479 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 480 |
+
# to the first token.
|
| 481 |
+
|
| 482 |
+
context_token = hidden_states[:, 0]
|
| 483 |
+
context_token = self.dropout(context_token)
|
| 484 |
+
pooled_output = self.dense(context_token)
|
| 485 |
+
pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
|
| 486 |
+
return pooled_output
|
| 487 |
+
|
| 488 |
+
@property
|
| 489 |
+
def output_dim(self):
|
| 490 |
+
return self.config.hidden_size
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class XSoftmax(torch.autograd.Function):
|
| 494 |
+
"""
|
| 495 |
+
Masked Softmax which is optimized for saving memory
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
input (`torch.tensor`): The input tensor that will apply softmax.
|
| 499 |
+
mask (`torch.IntTensor`):
|
| 500 |
+
The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
|
| 501 |
+
dim (int): The dimension that will apply softmax
|
| 502 |
+
|
| 503 |
+
Example:
|
| 504 |
+
|
| 505 |
+
```python
|
| 506 |
+
>>> import torch
|
| 507 |
+
>>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
|
| 508 |
+
|
| 509 |
+
>>> # Make a tensor
|
| 510 |
+
>>> x = torch.randn([4, 20, 100])
|
| 511 |
+
|
| 512 |
+
>>> # Create a mask
|
| 513 |
+
>>> mask = (x > 0).int()
|
| 514 |
+
|
| 515 |
+
>>> # Specify the dimension to apply softmax
|
| 516 |
+
>>> dim = -1
|
| 517 |
+
|
| 518 |
+
>>> y = XSoftmax.apply(x, mask, dim)
|
| 519 |
+
```"""
|
| 520 |
+
|
| 521 |
+
@staticmethod
|
| 522 |
+
def forward(ctx, input, mask, dim):
|
| 523 |
+
ctx.dim = dim
|
| 524 |
+
rmask = ~(mask.to(torch.bool))
|
| 525 |
+
|
| 526 |
+
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
| 527 |
+
output = torch.softmax(output, ctx.dim)
|
| 528 |
+
output.masked_fill_(rmask, 0)
|
| 529 |
+
ctx.save_for_backward(output)
|
| 530 |
+
return output
|
| 531 |
+
|
| 532 |
+
@staticmethod
|
| 533 |
+
def backward(ctx, grad_output):
|
| 534 |
+
(output,) = ctx.saved_tensors
|
| 535 |
+
inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
|
| 536 |
+
return inputGrad, None, None
|
| 537 |
+
|
| 538 |
+
@staticmethod
|
| 539 |
+
def symbolic(g, self, mask, dim):
|
| 540 |
+
import torch.onnx.symbolic_helper as sym_help
|
| 541 |
+
from torch.onnx.symbolic_opset9 import masked_fill, softmax
|
| 542 |
+
|
| 543 |
+
mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
|
| 544 |
+
r_mask = g.op(
|
| 545 |
+
"Cast",
|
| 546 |
+
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
| 547 |
+
to_i=sym_help.cast_pytorch_to_onnx["Bool"],
|
| 548 |
+
)
|
| 549 |
+
output = masked_fill(
|
| 550 |
+
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
| 551 |
+
)
|
| 552 |
+
output = softmax(g, output, dim)
|
| 553 |
+
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
class DropoutContext:
|
| 557 |
+
def __init__(self):
|
| 558 |
+
self.dropout = 0
|
| 559 |
+
self.mask = None
|
| 560 |
+
self.scale = 1
|
| 561 |
+
self.reuse_mask = True
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class XDropout(torch.autograd.Function):
|
| 565 |
+
"""Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
|
| 566 |
+
|
| 567 |
+
@staticmethod
|
| 568 |
+
def forward(ctx, input, local_ctx):
|
| 569 |
+
mask, dropout = get_mask(input, local_ctx)
|
| 570 |
+
ctx.scale = 1.0 / (1 - dropout)
|
| 571 |
+
if dropout > 0:
|
| 572 |
+
ctx.save_for_backward(mask)
|
| 573 |
+
return input.masked_fill(mask, 0) * ctx.scale
|
| 574 |
+
else:
|
| 575 |
+
return input
|
| 576 |
+
|
| 577 |
+
@staticmethod
|
| 578 |
+
def backward(ctx, grad_output):
|
| 579 |
+
if ctx.scale > 1:
|
| 580 |
+
(mask,) = ctx.saved_tensors
|
| 581 |
+
return grad_output.masked_fill(mask, 0) * ctx.scale, None
|
| 582 |
+
else:
|
| 583 |
+
return grad_output, None
|
| 584 |
+
|
| 585 |
+
@staticmethod
|
| 586 |
+
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
|
| 587 |
+
from torch.onnx import symbolic_opset12
|
| 588 |
+
|
| 589 |
+
dropout_p = local_ctx
|
| 590 |
+
if isinstance(local_ctx, DropoutContext):
|
| 591 |
+
dropout_p = local_ctx.dropout
|
| 592 |
+
# StableDropout only calls this function when training.
|
| 593 |
+
train = True
|
| 594 |
+
# TODO: We should check if the opset_version being used to export
|
| 595 |
+
# is > 12 here, but there's no good way to do that. As-is, if the
|
| 596 |
+
# opset_version < 12, export will fail with a CheckerError.
|
| 597 |
+
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
|
| 598 |
+
# if opset_version < 12:
|
| 599 |
+
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
|
| 600 |
+
return symbolic_opset12.dropout(g, input, dropout_p, train)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class StableDropout(nn.Module):
|
| 604 |
+
"""
|
| 605 |
+
Optimized dropout module for stabilizing the training
|
| 606 |
+
|
| 607 |
+
Args:
|
| 608 |
+
drop_prob (float): the dropout probabilities
|
| 609 |
+
"""
|
| 610 |
+
|
| 611 |
+
def __init__(self, drop_prob):
|
| 612 |
+
super().__init__()
|
| 613 |
+
self.drop_prob = drop_prob
|
| 614 |
+
self.count = 0
|
| 615 |
+
self.context_stack = None
|
| 616 |
+
|
| 617 |
+
def forward(self, x):
|
| 618 |
+
"""
|
| 619 |
+
Call the module
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
x (`torch.tensor`): The input tensor to apply dropout
|
| 623 |
+
"""
|
| 624 |
+
if self.training and self.drop_prob > 0:
|
| 625 |
+
return XDropout.apply(x, self.get_context())
|
| 626 |
+
return x
|
| 627 |
+
|
| 628 |
+
def clear_context(self):
|
| 629 |
+
self.count = 0
|
| 630 |
+
self.context_stack = None
|
| 631 |
+
|
| 632 |
+
def init_context(self, reuse_mask=True, scale=1):
|
| 633 |
+
if self.context_stack is None:
|
| 634 |
+
self.context_stack = []
|
| 635 |
+
self.count = 0
|
| 636 |
+
for c in self.context_stack:
|
| 637 |
+
c.reuse_mask = reuse_mask
|
| 638 |
+
c.scale = scale
|
| 639 |
+
|
| 640 |
+
def get_context(self):
|
| 641 |
+
if self.context_stack is not None:
|
| 642 |
+
if self.count >= len(self.context_stack):
|
| 643 |
+
self.context_stack.append(DropoutContext())
|
| 644 |
+
ctx = self.context_stack[self.count]
|
| 645 |
+
ctx.dropout = self.drop_prob
|
| 646 |
+
self.count += 1
|
| 647 |
+
return ctx
|
| 648 |
+
else:
|
| 649 |
+
return self.drop_prob
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class SEWDSelfOutput(nn.Module):
|
| 653 |
+
def __init__(self, config):
|
| 654 |
+
super().__init__()
|
| 655 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 656 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 657 |
+
self.dropout = nn.Dropout(config.activation_dropout)
|
| 658 |
+
|
| 659 |
+
def forward(self, hidden_states, input_tensor):
|
| 660 |
+
hidden_states = self.dense(hidden_states)
|
| 661 |
+
hidden_states = self.dropout(hidden_states)
|
| 662 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 663 |
+
return hidden_states
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
class DisentangledSelfAttention(nn.Module):
|
| 667 |
+
"""
|
| 668 |
+
Disentangled self-attention module
|
| 669 |
+
|
| 670 |
+
Parameters:
|
| 671 |
+
config (`DebertaV2Config`):
|
| 672 |
+
A model config class instance with the configuration to build a new model. The schema is similar to
|
| 673 |
+
*BertConfig*, for more details, please refer [`DebertaV2Config`]
|
| 674 |
+
|
| 675 |
+
"""
|
| 676 |
+
|
| 677 |
+
def __init__(self, config):
|
| 678 |
+
super().__init__()
|
| 679 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 680 |
+
raise ValueError(
|
| 681 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 682 |
+
f"heads ({config.num_attention_heads})"
|
| 683 |
+
)
|
| 684 |
+
self.num_attention_heads = config.num_attention_heads
|
| 685 |
+
_attention_head_size = config.hidden_size // config.num_attention_heads
|
| 686 |
+
self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
|
| 687 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 688 |
+
self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 689 |
+
self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 690 |
+
self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 691 |
+
|
| 692 |
+
self.share_att_key = getattr(config, "share_att_key", False)
|
| 693 |
+
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
|
| 694 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 695 |
+
|
| 696 |
+
if self.relative_attention:
|
| 697 |
+
self.position_buckets = getattr(config, "position_buckets", -1)
|
| 698 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 699 |
+
if self.max_relative_positions < 1:
|
| 700 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 701 |
+
self.pos_ebd_size = self.max_relative_positions
|
| 702 |
+
if self.position_buckets > 0:
|
| 703 |
+
self.pos_ebd_size = self.position_buckets
|
| 704 |
+
|
| 705 |
+
self.pos_dropout = StableDropout(config.activation_dropout)
|
| 706 |
+
|
| 707 |
+
if not self.share_att_key:
|
| 708 |
+
if "c2p" in self.pos_att_type:
|
| 709 |
+
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 710 |
+
if "p2c" in self.pos_att_type:
|
| 711 |
+
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
|
| 712 |
+
|
| 713 |
+
self.dropout = StableDropout(config.attention_dropout)
|
| 714 |
+
|
| 715 |
+
def transpose_for_scores(self, x, attention_heads):
|
| 716 |
+
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
| 717 |
+
x = x.view(new_x_shape)
|
| 718 |
+
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
| 719 |
+
|
| 720 |
+
def forward(
|
| 721 |
+
self,
|
| 722 |
+
hidden_states,
|
| 723 |
+
attention_mask,
|
| 724 |
+
output_attentions=False,
|
| 725 |
+
query_states=None,
|
| 726 |
+
relative_pos=None,
|
| 727 |
+
rel_embeddings=None,
|
| 728 |
+
):
|
| 729 |
+
"""
|
| 730 |
+
Call the module
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
hidden_states (`torch.FloatTensor`):
|
| 734 |
+
Input states to the module usually the output from previous layer, it will be the Q,K and V in
|
| 735 |
+
*Attention(Q,K,V)*
|
| 736 |
+
|
| 737 |
+
attention_mask (`torch.BoolTensor`):
|
| 738 |
+
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
|
| 739 |
+
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
|
| 740 |
+
th token.
|
| 741 |
+
|
| 742 |
+
output_attentions (`bool`, *optional*):
|
| 743 |
+
Whether return the attention matrix.
|
| 744 |
+
|
| 745 |
+
query_states (`torch.FloatTensor`, *optional*):
|
| 746 |
+
The *Q* state in *Attention(Q,K,V)*.
|
| 747 |
+
|
| 748 |
+
relative_pos (`torch.LongTensor`):
|
| 749 |
+
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
|
| 750 |
+
values ranging in [*-max_relative_positions*, *max_relative_positions*].
|
| 751 |
+
|
| 752 |
+
rel_embeddings (`torch.FloatTensor`):
|
| 753 |
+
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
|
| 754 |
+
\\text{max_relative_positions}\\), *hidden_size*].
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
"""
|
| 758 |
+
if query_states is None:
|
| 759 |
+
query_states = hidden_states
|
| 760 |
+
query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
|
| 761 |
+
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
|
| 762 |
+
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
|
| 763 |
+
|
| 764 |
+
rel_att = None
|
| 765 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 766 |
+
scale_factor = 1
|
| 767 |
+
if "c2p" in self.pos_att_type:
|
| 768 |
+
scale_factor += 1
|
| 769 |
+
if "p2c" in self.pos_att_type:
|
| 770 |
+
scale_factor += 1
|
| 771 |
+
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
| 772 |
+
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
|
| 773 |
+
if self.relative_attention:
|
| 774 |
+
rel_embeddings = self.pos_dropout(rel_embeddings)
|
| 775 |
+
rel_att = self.disentangled_attention_bias(
|
| 776 |
+
query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
if rel_att is not None:
|
| 780 |
+
attention_scores = attention_scores + rel_att
|
| 781 |
+
attention_scores = attention_scores
|
| 782 |
+
attention_scores = attention_scores.view(
|
| 783 |
+
-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# bsz x height x length x dimension
|
| 787 |
+
attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
|
| 788 |
+
attention_probs = self.dropout(attention_probs)
|
| 789 |
+
context_layer = torch.bmm(
|
| 790 |
+
attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
|
| 791 |
+
)
|
| 792 |
+
context_layer = (
|
| 793 |
+
context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
|
| 794 |
+
.permute(0, 2, 1, 3)
|
| 795 |
+
.contiguous()
|
| 796 |
+
)
|
| 797 |
+
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
| 798 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 799 |
+
if output_attentions:
|
| 800 |
+
return (context_layer, attention_probs)
|
| 801 |
+
else:
|
| 802 |
+
return context_layer
|
| 803 |
+
|
| 804 |
+
def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
|
| 805 |
+
if relative_pos is None:
|
| 806 |
+
q = query_layer.size(-2)
|
| 807 |
+
relative_pos = build_relative_position(
|
| 808 |
+
q,
|
| 809 |
+
key_layer.size(-2),
|
| 810 |
+
bucket_size=self.position_buckets,
|
| 811 |
+
max_position=self.max_relative_positions,
|
| 812 |
+
device=query_layer.device,
|
| 813 |
+
)
|
| 814 |
+
if relative_pos.dim() == 2:
|
| 815 |
+
relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
|
| 816 |
+
elif relative_pos.dim() == 3:
|
| 817 |
+
relative_pos = relative_pos.unsqueeze(1)
|
| 818 |
+
# bsz x height x query x key
|
| 819 |
+
elif relative_pos.dim() != 4:
|
| 820 |
+
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
|
| 821 |
+
|
| 822 |
+
att_span = self.pos_ebd_size
|
| 823 |
+
relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
|
| 824 |
+
|
| 825 |
+
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
|
| 826 |
+
if self.share_att_key:
|
| 827 |
+
pos_query_layer = self.transpose_for_scores(
|
| 828 |
+
self.query_proj(rel_embeddings), self.num_attention_heads
|
| 829 |
+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
|
| 830 |
+
pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
|
| 831 |
+
query_layer.size(0) // self.num_attention_heads, 1, 1
|
| 832 |
+
)
|
| 833 |
+
else:
|
| 834 |
+
if "c2p" in self.pos_att_type:
|
| 835 |
+
pos_key_layer = self.transpose_for_scores(
|
| 836 |
+
self.pos_key_proj(rel_embeddings), self.num_attention_heads
|
| 837 |
+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
|
| 838 |
+
if "p2c" in self.pos_att_type:
|
| 839 |
+
pos_query_layer = self.transpose_for_scores(
|
| 840 |
+
self.pos_query_proj(rel_embeddings), self.num_attention_heads
|
| 841 |
+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
|
| 842 |
+
|
| 843 |
+
score = 0
|
| 844 |
+
# content->position
|
| 845 |
+
if "c2p" in self.pos_att_type:
|
| 846 |
+
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
|
| 847 |
+
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
|
| 848 |
+
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
|
| 849 |
+
c2p_att = torch.gather(
|
| 850 |
+
c2p_att,
|
| 851 |
+
dim=-1,
|
| 852 |
+
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
|
| 853 |
+
)
|
| 854 |
+
score += c2p_att / scale.to(dtype=c2p_att.dtype)
|
| 855 |
+
|
| 856 |
+
# position->content
|
| 857 |
+
if "p2c" in self.pos_att_type:
|
| 858 |
+
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
|
| 859 |
+
if key_layer.size(-2) != query_layer.size(-2):
|
| 860 |
+
r_pos = build_relative_position(
|
| 861 |
+
key_layer.size(-2),
|
| 862 |
+
key_layer.size(-2),
|
| 863 |
+
bucket_size=self.position_buckets,
|
| 864 |
+
max_position=self.max_relative_positions,
|
| 865 |
+
device=query_layer.device,
|
| 866 |
+
)
|
| 867 |
+
r_pos = r_pos.unsqueeze(0)
|
| 868 |
+
else:
|
| 869 |
+
r_pos = relative_pos
|
| 870 |
+
|
| 871 |
+
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
|
| 872 |
+
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
|
| 873 |
+
p2c_att = torch.gather(
|
| 874 |
+
p2c_att,
|
| 875 |
+
dim=-1,
|
| 876 |
+
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
|
| 877 |
+
).transpose(-1, -2)
|
| 878 |
+
score += p2c_att / scale.to(dtype=p2c_att.dtype)
|
| 879 |
+
|
| 880 |
+
return score
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
class SEWDAttention(nn.Module):
|
| 884 |
+
def __init__(self, config):
|
| 885 |
+
super().__init__()
|
| 886 |
+
self.self = DisentangledSelfAttention(config)
|
| 887 |
+
self.output = SEWDSelfOutput(config)
|
| 888 |
+
self.config = config
|
| 889 |
+
|
| 890 |
+
def forward(
|
| 891 |
+
self,
|
| 892 |
+
hidden_states,
|
| 893 |
+
attention_mask,
|
| 894 |
+
output_attentions=False,
|
| 895 |
+
query_states=None,
|
| 896 |
+
relative_pos=None,
|
| 897 |
+
rel_embeddings=None,
|
| 898 |
+
):
|
| 899 |
+
self_output = self.self(
|
| 900 |
+
hidden_states,
|
| 901 |
+
attention_mask,
|
| 902 |
+
output_attentions,
|
| 903 |
+
query_states=query_states,
|
| 904 |
+
relative_pos=relative_pos,
|
| 905 |
+
rel_embeddings=rel_embeddings,
|
| 906 |
+
)
|
| 907 |
+
if output_attentions:
|
| 908 |
+
self_output, att_matrix = self_output
|
| 909 |
+
if query_states is None:
|
| 910 |
+
query_states = hidden_states
|
| 911 |
+
attention_output = self.output(self_output, query_states)
|
| 912 |
+
|
| 913 |
+
if output_attentions:
|
| 914 |
+
return (attention_output, att_matrix)
|
| 915 |
+
else:
|
| 916 |
+
return attention_output
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->SEWD
|
| 920 |
+
class SEWDIntermediate(nn.Module):
|
| 921 |
+
def __init__(self, config):
|
| 922 |
+
super().__init__()
|
| 923 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 924 |
+
if isinstance(config.hidden_act, str):
|
| 925 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 926 |
+
else:
|
| 927 |
+
self.intermediate_act_fn = config.hidden_act
|
| 928 |
+
|
| 929 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 930 |
+
hidden_states = self.dense(hidden_states)
|
| 931 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 932 |
+
return hidden_states
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
class SEWDOutput(nn.Module):
|
| 936 |
+
def __init__(self, config):
|
| 937 |
+
super().__init__()
|
| 938 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 939 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 940 |
+
self.dropout = nn.Dropout(config.activation_dropout)
|
| 941 |
+
self.config = config
|
| 942 |
+
|
| 943 |
+
def forward(self, hidden_states, input_tensor):
|
| 944 |
+
hidden_states = self.dense(hidden_states)
|
| 945 |
+
hidden_states = self.dropout(hidden_states)
|
| 946 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 947 |
+
return hidden_states
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
class SEWDLayer(nn.Module):
|
| 951 |
+
def __init__(self, config):
|
| 952 |
+
super().__init__()
|
| 953 |
+
self.attention = SEWDAttention(config)
|
| 954 |
+
self.intermediate = SEWDIntermediate(config)
|
| 955 |
+
self.output = SEWDOutput(config)
|
| 956 |
+
|
| 957 |
+
def forward(
|
| 958 |
+
self,
|
| 959 |
+
hidden_states,
|
| 960 |
+
attention_mask,
|
| 961 |
+
query_states=None,
|
| 962 |
+
relative_pos=None,
|
| 963 |
+
rel_embeddings=None,
|
| 964 |
+
output_attentions=False,
|
| 965 |
+
):
|
| 966 |
+
attention_output = self.attention(
|
| 967 |
+
hidden_states,
|
| 968 |
+
attention_mask,
|
| 969 |
+
output_attentions=output_attentions,
|
| 970 |
+
query_states=query_states,
|
| 971 |
+
relative_pos=relative_pos,
|
| 972 |
+
rel_embeddings=rel_embeddings,
|
| 973 |
+
)
|
| 974 |
+
if output_attentions:
|
| 975 |
+
attention_output, att_matrix = attention_output
|
| 976 |
+
intermediate_output = self.intermediate(attention_output)
|
| 977 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 978 |
+
if output_attentions:
|
| 979 |
+
return (layer_output, att_matrix)
|
| 980 |
+
else:
|
| 981 |
+
return layer_output
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
class ConvLayer(nn.Module):
|
| 985 |
+
def __init__(self, config):
|
| 986 |
+
super().__init__()
|
| 987 |
+
kernel_size = getattr(config, "conv_kernel_size", 3)
|
| 988 |
+
groups = getattr(config, "conv_groups", 1)
|
| 989 |
+
self.conv_act = getattr(config, "conv_act", "tanh")
|
| 990 |
+
self.conv = nn.Conv1d(
|
| 991 |
+
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
|
| 992 |
+
)
|
| 993 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 994 |
+
self.dropout = StableDropout(config.hidden_dropout_prob)
|
| 995 |
+
self.config = config
|
| 996 |
+
|
| 997 |
+
def forward(self, hidden_states, residual_states, input_mask):
|
| 998 |
+
out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
|
| 999 |
+
rmask = (1 - input_mask).bool()
|
| 1000 |
+
out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
|
| 1001 |
+
out = ACT2FN[self.conv_act](self.dropout(out))
|
| 1002 |
+
|
| 1003 |
+
layer_norm_input = residual_states + out
|
| 1004 |
+
output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
|
| 1005 |
+
|
| 1006 |
+
if input_mask is None:
|
| 1007 |
+
output_states = output
|
| 1008 |
+
else:
|
| 1009 |
+
if input_mask.dim() != layer_norm_input.dim():
|
| 1010 |
+
if input_mask.dim() == 4:
|
| 1011 |
+
input_mask = input_mask.squeeze(1).squeeze(1)
|
| 1012 |
+
input_mask = input_mask.unsqueeze(2)
|
| 1013 |
+
|
| 1014 |
+
input_mask = input_mask.to(output.dtype)
|
| 1015 |
+
output_states = output * input_mask
|
| 1016 |
+
|
| 1017 |
+
return output_states
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
class SEWDTransformerEncoder(nn.Module):
|
| 1021 |
+
"""Modified BertEncoder with relative position bias support"""
|
| 1022 |
+
|
| 1023 |
+
def __init__(self, config):
|
| 1024 |
+
super().__init__()
|
| 1025 |
+
|
| 1026 |
+
self.layer = nn.ModuleList([SEWDLayer(config) for _ in range(config.num_hidden_layers)])
|
| 1027 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 1028 |
+
|
| 1029 |
+
if self.relative_attention:
|
| 1030 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 1031 |
+
if self.max_relative_positions < 1:
|
| 1032 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 1033 |
+
|
| 1034 |
+
self.position_buckets = getattr(config, "position_buckets", -1)
|
| 1035 |
+
pos_ebd_size = self.max_relative_positions * 2
|
| 1036 |
+
|
| 1037 |
+
if self.position_buckets > 0:
|
| 1038 |
+
pos_ebd_size = self.position_buckets * 2
|
| 1039 |
+
|
| 1040 |
+
self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
|
| 1041 |
+
|
| 1042 |
+
self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
|
| 1043 |
+
|
| 1044 |
+
if "layer_norm" in self.norm_rel_ebd:
|
| 1045 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
|
| 1046 |
+
|
| 1047 |
+
self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
|
| 1048 |
+
self.gradient_checkpointing = False
|
| 1049 |
+
|
| 1050 |
+
def get_rel_embedding(self):
|
| 1051 |
+
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
|
| 1052 |
+
if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
|
| 1053 |
+
rel_embeddings = self.LayerNorm(rel_embeddings)
|
| 1054 |
+
return rel_embeddings
|
| 1055 |
+
|
| 1056 |
+
def get_attention_mask(self, attention_mask):
|
| 1057 |
+
if attention_mask.dim() <= 2:
|
| 1058 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 1059 |
+
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
|
| 1060 |
+
elif attention_mask.dim() == 3:
|
| 1061 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 1062 |
+
|
| 1063 |
+
return attention_mask
|
| 1064 |
+
|
| 1065 |
+
def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
|
| 1066 |
+
if self.relative_attention and relative_pos is None:
|
| 1067 |
+
q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
|
| 1068 |
+
relative_pos = build_relative_position(
|
| 1069 |
+
q,
|
| 1070 |
+
hidden_states.size(-2),
|
| 1071 |
+
bucket_size=self.position_buckets,
|
| 1072 |
+
max_position=self.max_relative_positions,
|
| 1073 |
+
device=hidden_states.device,
|
| 1074 |
+
)
|
| 1075 |
+
return relative_pos
|
| 1076 |
+
|
| 1077 |
+
def forward(
|
| 1078 |
+
self,
|
| 1079 |
+
hidden_states,
|
| 1080 |
+
attention_mask,
|
| 1081 |
+
output_hidden_states=True,
|
| 1082 |
+
output_attentions=False,
|
| 1083 |
+
query_states=None,
|
| 1084 |
+
relative_pos=None,
|
| 1085 |
+
return_dict=True,
|
| 1086 |
+
):
|
| 1087 |
+
if attention_mask.dim() <= 2:
|
| 1088 |
+
input_mask = attention_mask
|
| 1089 |
+
else:
|
| 1090 |
+
input_mask = attention_mask.sum(-2) > 0
|
| 1091 |
+
attention_mask = self.get_attention_mask(attention_mask)
|
| 1092 |
+
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
|
| 1093 |
+
|
| 1094 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1095 |
+
all_attentions = () if output_attentions else None
|
| 1096 |
+
|
| 1097 |
+
if isinstance(hidden_states, Sequence):
|
| 1098 |
+
next_kv = hidden_states[0]
|
| 1099 |
+
else:
|
| 1100 |
+
next_kv = hidden_states
|
| 1101 |
+
rel_embeddings = self.get_rel_embedding()
|
| 1102 |
+
output_states = next_kv
|
| 1103 |
+
for i, layer_module in enumerate(self.layer):
|
| 1104 |
+
if output_hidden_states:
|
| 1105 |
+
all_hidden_states = all_hidden_states + (output_states,)
|
| 1106 |
+
|
| 1107 |
+
if self.gradient_checkpointing and self.training:
|
| 1108 |
+
output_states = self._gradient_checkpointing_func(
|
| 1109 |
+
layer_module.__call__,
|
| 1110 |
+
next_kv,
|
| 1111 |
+
attention_mask,
|
| 1112 |
+
query_states,
|
| 1113 |
+
relative_pos,
|
| 1114 |
+
rel_embeddings,
|
| 1115 |
+
output_attentions,
|
| 1116 |
+
)
|
| 1117 |
+
else:
|
| 1118 |
+
output_states = layer_module(
|
| 1119 |
+
next_kv,
|
| 1120 |
+
attention_mask,
|
| 1121 |
+
query_states=query_states,
|
| 1122 |
+
relative_pos=relative_pos,
|
| 1123 |
+
rel_embeddings=rel_embeddings,
|
| 1124 |
+
output_attentions=output_attentions,
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
if output_attentions:
|
| 1128 |
+
output_states, att_m = output_states
|
| 1129 |
+
|
| 1130 |
+
if i == 0 and self.conv is not None:
|
| 1131 |
+
output_states = self.conv(hidden_states, output_states, input_mask)
|
| 1132 |
+
|
| 1133 |
+
if query_states is not None:
|
| 1134 |
+
query_states = output_states
|
| 1135 |
+
if isinstance(hidden_states, Sequence):
|
| 1136 |
+
next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
|
| 1137 |
+
else:
|
| 1138 |
+
next_kv = output_states
|
| 1139 |
+
|
| 1140 |
+
if output_attentions:
|
| 1141 |
+
all_attentions = all_attentions + (att_m,)
|
| 1142 |
+
|
| 1143 |
+
if output_hidden_states:
|
| 1144 |
+
all_hidden_states = all_hidden_states + (output_states,)
|
| 1145 |
+
|
| 1146 |
+
if not return_dict:
|
| 1147 |
+
return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
|
| 1148 |
+
return BaseModelOutput(
|
| 1149 |
+
last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 1150 |
+
)
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
class SEWDEncoder(nn.Module):
|
| 1154 |
+
def __init__(self, config):
|
| 1155 |
+
super().__init__()
|
| 1156 |
+
self.config = config
|
| 1157 |
+
self.pos_conv_embed = SEWDPositionalConvEmbedding(config)
|
| 1158 |
+
self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
|
| 1159 |
+
self.encoder = SEWDTransformerEncoder(config)
|
| 1160 |
+
self.upsample = SEWDUpsampling(config)
|
| 1161 |
+
self.gradient_checkpointing = False
|
| 1162 |
+
|
| 1163 |
+
def forward(
|
| 1164 |
+
self,
|
| 1165 |
+
hidden_states: torch.tensor,
|
| 1166 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1167 |
+
output_attentions: bool = False,
|
| 1168 |
+
output_hidden_states: bool = False,
|
| 1169 |
+
return_dict: bool = True,
|
| 1170 |
+
):
|
| 1171 |
+
max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
|
| 1172 |
+
if attention_mask is None:
|
| 1173 |
+
attention_mask = torch.ones(
|
| 1174 |
+
(hidden_states.shape[0], max_encoder_length), dtype=torch.long, device=hidden_states.device
|
| 1175 |
+
)
|
| 1176 |
+
else:
|
| 1177 |
+
# make sure padded tokens output 0
|
| 1178 |
+
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
|
| 1179 |
+
hidden_states[~expand_attention_mask.bool()] = 0.0
|
| 1180 |
+
|
| 1181 |
+
input_lengths = (attention_mask.long()).sum(-1)
|
| 1182 |
+
# apply pooling formula to get real output_lengths
|
| 1183 |
+
output_lengths = input_lengths // self.config.squeeze_factor
|
| 1184 |
+
attention_ids = (
|
| 1185 |
+
torch.arange(0, max_encoder_length, device=output_lengths.device)
|
| 1186 |
+
.view(1, -1)
|
| 1187 |
+
.expand(output_lengths.shape[0], -1)
|
| 1188 |
+
)
|
| 1189 |
+
attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
|
| 1190 |
+
|
| 1191 |
+
n_input_timesteps = hidden_states.shape[1]
|
| 1192 |
+
|
| 1193 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1194 |
+
position_embeddings = self.pos_conv_embed(hidden_states)
|
| 1195 |
+
pooled_hidden_states = self.pool(hidden_states)
|
| 1196 |
+
min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
|
| 1197 |
+
hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
|
| 1198 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1199 |
+
|
| 1200 |
+
encoder_outputs = self.encoder(hidden_states, attention_mask, output_hidden_states, output_attentions)
|
| 1201 |
+
|
| 1202 |
+
hidden_states = self.upsample(encoder_outputs.last_hidden_state)
|
| 1203 |
+
if hidden_states.shape[1] < n_input_timesteps:
|
| 1204 |
+
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
|
| 1205 |
+
|
| 1206 |
+
if not return_dict:
|
| 1207 |
+
return tuple(
|
| 1208 |
+
v for v in [hidden_states, encoder_outputs.hidden_states, encoder_outputs.attentions] if v is not None
|
| 1209 |
+
)
|
| 1210 |
+
return BaseModelOutput(
|
| 1211 |
+
last_hidden_state=hidden_states,
|
| 1212 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1213 |
+
attentions=encoder_outputs.attentions,
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
class SEWDPreTrainedModel(PreTrainedModel):
|
| 1218 |
+
"""
|
| 1219 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 1220 |
+
models.
|
| 1221 |
+
"""
|
| 1222 |
+
|
| 1223 |
+
config_class = SEWDConfig
|
| 1224 |
+
base_model_prefix = "sew-d"
|
| 1225 |
+
main_input_name = "input_values"
|
| 1226 |
+
supports_gradient_checkpointing = True
|
| 1227 |
+
|
| 1228 |
+
def _init_weights(self, module):
|
| 1229 |
+
"""Initialize the weights"""
|
| 1230 |
+
if isinstance(module, SEWDPositionalConvEmbedding):
|
| 1231 |
+
nn.init.normal_(
|
| 1232 |
+
module.conv.weight,
|
| 1233 |
+
mean=0,
|
| 1234 |
+
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
|
| 1235 |
+
)
|
| 1236 |
+
nn.init.constant_(module.conv.bias, 0)
|
| 1237 |
+
elif isinstance(module, nn.Linear):
|
| 1238 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 1239 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 1240 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1241 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 1242 |
+
module.bias.data.zero_()
|
| 1243 |
+
module.weight.data.fill_(1.0)
|
| 1244 |
+
elif isinstance(module, nn.Conv1d):
|
| 1245 |
+
if is_deepspeed_zero3_enabled():
|
| 1246 |
+
import deepspeed
|
| 1247 |
+
|
| 1248 |
+
if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
|
| 1249 |
+
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
|
| 1250 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 1251 |
+
else:
|
| 1252 |
+
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
|
| 1253 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 1254 |
+
else:
|
| 1255 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 1256 |
+
elif isinstance(module, nn.Embedding):
|
| 1257 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1258 |
+
if module.padding_idx is not None:
|
| 1259 |
+
module.weight.data[module.padding_idx].zero_()
|
| 1260 |
+
|
| 1261 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
| 1262 |
+
module.bias.data.zero_()
|
| 1263 |
+
|
| 1264 |
+
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
| 1265 |
+
"""
|
| 1266 |
+
Computes the output length of the convolutional layers
|
| 1267 |
+
"""
|
| 1268 |
+
|
| 1269 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 1270 |
+
# 1D convolutional layer output length formula taken
|
| 1271 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 1272 |
+
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 1273 |
+
|
| 1274 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
| 1275 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 1276 |
+
|
| 1277 |
+
return input_lengths
|
| 1278 |
+
|
| 1279 |
+
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
| 1280 |
+
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1281 |
+
batch_size = attention_mask.shape[0]
|
| 1282 |
+
|
| 1283 |
+
attention_mask = torch.zeros(
|
| 1284 |
+
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
| 1285 |
+
)
|
| 1286 |
+
# these two operations makes sure that all values before the output lengths idxs are attended to
|
| 1287 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
| 1288 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 1289 |
+
return attention_mask
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
SEWD_START_DOCSTRING = r"""
|
| 1293 |
+
SEW-D was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
|
| 1294 |
+
Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
|
| 1295 |
+
Yoav Artzi.
|
| 1296 |
+
|
| 1297 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1298 |
+
library implements for all its model (such as downloading or saving etc.).
|
| 1299 |
+
|
| 1300 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 1301 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 1302 |
+
behavior.
|
| 1303 |
+
|
| 1304 |
+
Parameters:
|
| 1305 |
+
config ([`SEWDConfig`]): Model configuration class with all the parameters of the model.
|
| 1306 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1307 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1308 |
+
"""
|
| 1309 |
+
|
| 1310 |
+
|
| 1311 |
+
SEWD_INPUTS_DOCSTRING = r"""
|
| 1312 |
+
Args:
|
| 1313 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 1314 |
+
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
| 1315 |
+
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
| 1316 |
+
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
| 1317 |
+
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
| 1318 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1319 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1320 |
+
1]`:
|
| 1321 |
+
|
| 1322 |
+
- 1 for tokens that are **not masked**,
|
| 1323 |
+
- 0 for tokens that are **masked**.
|
| 1324 |
+
|
| 1325 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1326 |
+
|
| 1327 |
+
output_attentions (`bool`, *optional*):
|
| 1328 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1329 |
+
tensors for more detail.
|
| 1330 |
+
output_hidden_states (`bool`, *optional*):
|
| 1331 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1332 |
+
more detail.
|
| 1333 |
+
return_dict (`bool`, *optional*):
|
| 1334 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1335 |
+
"""
|
| 1336 |
+
|
| 1337 |
+
|
| 1338 |
+
@add_start_docstrings(
|
| 1339 |
+
"The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1340 |
+
SEWD_START_DOCSTRING,
|
| 1341 |
+
)
|
| 1342 |
+
# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps
|
| 1343 |
+
class SEWDModel(SEWDPreTrainedModel):
|
| 1344 |
+
def __init__(self, config: SEWDConfig):
|
| 1345 |
+
super().__init__(config)
|
| 1346 |
+
self.config = config
|
| 1347 |
+
self.feature_extractor = SEWDFeatureEncoder(config)
|
| 1348 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps)
|
| 1349 |
+
|
| 1350 |
+
self.project_features = config.conv_dim[-1] != config.hidden_size
|
| 1351 |
+
if self.project_features:
|
| 1352 |
+
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 1353 |
+
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
|
| 1354 |
+
|
| 1355 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 1356 |
+
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
|
| 1357 |
+
|
| 1358 |
+
self.encoder = SEWDEncoder(config)
|
| 1359 |
+
|
| 1360 |
+
# Initialize weights and apply final processing
|
| 1361 |
+
self.post_init()
|
| 1362 |
+
|
| 1363 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
| 1364 |
+
def _mask_hidden_states(
|
| 1365 |
+
self,
|
| 1366 |
+
hidden_states: torch.FloatTensor,
|
| 1367 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1368 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1369 |
+
):
|
| 1370 |
+
"""
|
| 1371 |
+
Masks extracted features along time axis and/or along feature axis according to
|
| 1372 |
+
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
| 1373 |
+
"""
|
| 1374 |
+
|
| 1375 |
+
# `config.apply_spec_augment` can set masking to False
|
| 1376 |
+
if not getattr(self.config, "apply_spec_augment", True):
|
| 1377 |
+
return hidden_states
|
| 1378 |
+
|
| 1379 |
+
# generate indices & apply SpecAugment along time axis
|
| 1380 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 1381 |
+
|
| 1382 |
+
if mask_time_indices is not None:
|
| 1383 |
+
# apply SpecAugment along time axis with given mask_time_indices
|
| 1384 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1385 |
+
elif self.config.mask_time_prob > 0 and self.training:
|
| 1386 |
+
mask_time_indices = _compute_mask_indices(
|
| 1387 |
+
(batch_size, sequence_length),
|
| 1388 |
+
mask_prob=self.config.mask_time_prob,
|
| 1389 |
+
mask_length=self.config.mask_time_length,
|
| 1390 |
+
attention_mask=attention_mask,
|
| 1391 |
+
min_masks=self.config.mask_time_min_masks,
|
| 1392 |
+
)
|
| 1393 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1394 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1395 |
+
|
| 1396 |
+
if self.config.mask_feature_prob > 0 and self.training:
|
| 1397 |
+
# generate indices & apply SpecAugment along feature axis
|
| 1398 |
+
mask_feature_indices = _compute_mask_indices(
|
| 1399 |
+
(batch_size, hidden_size),
|
| 1400 |
+
mask_prob=self.config.mask_feature_prob,
|
| 1401 |
+
mask_length=self.config.mask_feature_length,
|
| 1402 |
+
min_masks=self.config.mask_feature_min_masks,
|
| 1403 |
+
)
|
| 1404 |
+
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1405 |
+
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
| 1406 |
+
hidden_states[mask_feature_indices] = 0
|
| 1407 |
+
|
| 1408 |
+
return hidden_states
|
| 1409 |
+
|
| 1410 |
+
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
| 1411 |
+
@add_code_sample_docstrings(
|
| 1412 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1413 |
+
output_type=BaseModelOutput,
|
| 1414 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1415 |
+
modality="audio",
|
| 1416 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 1417 |
+
)
|
| 1418 |
+
def forward(
|
| 1419 |
+
self,
|
| 1420 |
+
input_values: Optional[torch.Tensor],
|
| 1421 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1422 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1423 |
+
output_attentions: Optional[bool] = None,
|
| 1424 |
+
output_hidden_states: Optional[bool] = None,
|
| 1425 |
+
return_dict: Optional[bool] = None,
|
| 1426 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 1427 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1428 |
+
output_hidden_states = (
|
| 1429 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1430 |
+
)
|
| 1431 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1432 |
+
|
| 1433 |
+
extract_features = self.feature_extractor(input_values)
|
| 1434 |
+
extract_features = extract_features.transpose(1, 2)
|
| 1435 |
+
extract_features = self.layer_norm(extract_features)
|
| 1436 |
+
|
| 1437 |
+
if self.project_features:
|
| 1438 |
+
extract_features = self.feature_projection(extract_features)
|
| 1439 |
+
hidden_states = self.feature_dropout(extract_features)
|
| 1440 |
+
|
| 1441 |
+
if attention_mask is not None:
|
| 1442 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 1443 |
+
attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1444 |
+
|
| 1445 |
+
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
| 1446 |
+
|
| 1447 |
+
encoder_outputs = self.encoder(
|
| 1448 |
+
hidden_states,
|
| 1449 |
+
attention_mask=attention_mask,
|
| 1450 |
+
output_attentions=output_attentions,
|
| 1451 |
+
output_hidden_states=output_hidden_states,
|
| 1452 |
+
return_dict=return_dict,
|
| 1453 |
+
)
|
| 1454 |
+
|
| 1455 |
+
hidden_states = encoder_outputs[0]
|
| 1456 |
+
|
| 1457 |
+
if not return_dict:
|
| 1458 |
+
return (hidden_states,) + encoder_outputs[1:]
|
| 1459 |
+
|
| 1460 |
+
return BaseModelOutput(
|
| 1461 |
+
last_hidden_state=hidden_states,
|
| 1462 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1463 |
+
attentions=encoder_outputs.attentions,
|
| 1464 |
+
)
|
| 1465 |
+
|
| 1466 |
+
|
| 1467 |
+
@add_start_docstrings(
|
| 1468 |
+
"""SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 1469 |
+
SEWD_START_DOCSTRING,
|
| 1470 |
+
)
|
| 1471 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
|
| 1472 |
+
class SEWDForCTC(SEWDPreTrainedModel):
|
| 1473 |
+
def __init__(self, config, target_lang: Optional[str] = None):
|
| 1474 |
+
super().__init__(config)
|
| 1475 |
+
|
| 1476 |
+
self.sew_d = SEWDModel(config)
|
| 1477 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
| 1478 |
+
|
| 1479 |
+
self.target_lang = target_lang
|
| 1480 |
+
|
| 1481 |
+
if config.vocab_size is None:
|
| 1482 |
+
raise ValueError(
|
| 1483 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 1484 |
+
"does not define the vocabulary size of the language model head. Please "
|
| 1485 |
+
"instantiate the model as follows: `SEWDForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 1486 |
+
"or define `vocab_size` of your model's configuration."
|
| 1487 |
+
)
|
| 1488 |
+
output_hidden_size = (
|
| 1489 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 1490 |
+
)
|
| 1491 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 1492 |
+
|
| 1493 |
+
# Initialize weights and apply final processing
|
| 1494 |
+
self.post_init()
|
| 1495 |
+
|
| 1496 |
+
def tie_weights(self):
|
| 1497 |
+
"""
|
| 1498 |
+
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
| 1499 |
+
passing `target_lang=...` to `from_pretrained(...)`.
|
| 1500 |
+
|
| 1501 |
+
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
| 1502 |
+
"""
|
| 1503 |
+
|
| 1504 |
+
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
| 1505 |
+
# correctly load adapter layers for SEWD so that we do not have to introduce a new API to
|
| 1506 |
+
# [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is
|
| 1507 |
+
# ok to repurpose this function here.
|
| 1508 |
+
target_lang = self.target_lang
|
| 1509 |
+
|
| 1510 |
+
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
| 1511 |
+
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
| 1512 |
+
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
| 1513 |
+
logger.info("By default `target_lang` is set to 'eng'.")
|
| 1514 |
+
elif target_lang is not None:
|
| 1515 |
+
self.load_adapter(target_lang, force_load=True)
|
| 1516 |
+
|
| 1517 |
+
def freeze_feature_extractor(self):
|
| 1518 |
+
"""
|
| 1519 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1520 |
+
not be updated during training.
|
| 1521 |
+
"""
|
| 1522 |
+
warnings.warn(
|
| 1523 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1524 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1525 |
+
FutureWarning,
|
| 1526 |
+
)
|
| 1527 |
+
self.freeze_feature_encoder()
|
| 1528 |
+
|
| 1529 |
+
def freeze_feature_encoder(self):
|
| 1530 |
+
"""
|
| 1531 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1532 |
+
not be updated during training.
|
| 1533 |
+
"""
|
| 1534 |
+
self.sew_d.feature_extractor._freeze_parameters()
|
| 1535 |
+
|
| 1536 |
+
def freeze_base_model(self):
|
| 1537 |
+
"""
|
| 1538 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1539 |
+
be updated during training. Only the classification head will be updated.
|
| 1540 |
+
"""
|
| 1541 |
+
for param in self.sew_d.parameters():
|
| 1542 |
+
param.requires_grad = False
|
| 1543 |
+
|
| 1544 |
+
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
| 1545 |
+
@add_code_sample_docstrings(
|
| 1546 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1547 |
+
output_type=CausalLMOutput,
|
| 1548 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1549 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 1550 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
| 1551 |
+
)
|
| 1552 |
+
def forward(
|
| 1553 |
+
self,
|
| 1554 |
+
input_values: Optional[torch.Tensor],
|
| 1555 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1556 |
+
output_attentions: Optional[bool] = None,
|
| 1557 |
+
output_hidden_states: Optional[bool] = None,
|
| 1558 |
+
return_dict: Optional[bool] = None,
|
| 1559 |
+
labels: Optional[torch.Tensor] = None,
|
| 1560 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 1561 |
+
r"""
|
| 1562 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 1563 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 1564 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 1565 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1566 |
+
config.vocab_size - 1]`.
|
| 1567 |
+
"""
|
| 1568 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1569 |
+
|
| 1570 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
| 1571 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 1572 |
+
|
| 1573 |
+
outputs = self.sew_d(
|
| 1574 |
+
input_values,
|
| 1575 |
+
attention_mask=attention_mask,
|
| 1576 |
+
output_attentions=output_attentions,
|
| 1577 |
+
output_hidden_states=output_hidden_states,
|
| 1578 |
+
return_dict=return_dict,
|
| 1579 |
+
)
|
| 1580 |
+
|
| 1581 |
+
hidden_states = outputs[0]
|
| 1582 |
+
hidden_states = self.dropout(hidden_states)
|
| 1583 |
+
|
| 1584 |
+
logits = self.lm_head(hidden_states)
|
| 1585 |
+
|
| 1586 |
+
loss = None
|
| 1587 |
+
if labels is not None:
|
| 1588 |
+
# retrieve loss input_lengths from attention_mask
|
| 1589 |
+
attention_mask = (
|
| 1590 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 1591 |
+
)
|
| 1592 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1593 |
+
|
| 1594 |
+
# assuming that padded tokens are filled with -100
|
| 1595 |
+
# when not being attended to
|
| 1596 |
+
labels_mask = labels >= 0
|
| 1597 |
+
target_lengths = labels_mask.sum(-1)
|
| 1598 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 1599 |
+
|
| 1600 |
+
# ctc_loss doesn't support fp16
|
| 1601 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 1602 |
+
|
| 1603 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 1604 |
+
loss = nn.functional.ctc_loss(
|
| 1605 |
+
log_probs,
|
| 1606 |
+
flattened_targets,
|
| 1607 |
+
input_lengths,
|
| 1608 |
+
target_lengths,
|
| 1609 |
+
blank=self.config.pad_token_id,
|
| 1610 |
+
reduction=self.config.ctc_loss_reduction,
|
| 1611 |
+
zero_infinity=self.config.ctc_zero_infinity,
|
| 1612 |
+
)
|
| 1613 |
+
|
| 1614 |
+
if not return_dict:
|
| 1615 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1616 |
+
return ((loss,) + output) if loss is not None else output
|
| 1617 |
+
|
| 1618 |
+
return CausalLMOutput(
|
| 1619 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1620 |
+
)
|
| 1621 |
+
|
| 1622 |
+
|
| 1623 |
+
@add_start_docstrings(
|
| 1624 |
+
"""
|
| 1625 |
+
SEWD Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
|
| 1626 |
+
Keyword Spotting.
|
| 1627 |
+
""",
|
| 1628 |
+
SEWD_START_DOCSTRING,
|
| 1629 |
+
)
|
| 1630 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
|
| 1631 |
+
class SEWDForSequenceClassification(SEWDPreTrainedModel):
|
| 1632 |
+
def __init__(self, config):
|
| 1633 |
+
super().__init__(config)
|
| 1634 |
+
|
| 1635 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1636 |
+
raise ValueError(
|
| 1637 |
+
"Sequence classification does not support the use of SEWD adapters (config.add_adapter=True)"
|
| 1638 |
+
)
|
| 1639 |
+
self.sew_d = SEWDModel(config)
|
| 1640 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1641 |
+
if config.use_weighted_layer_sum:
|
| 1642 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1643 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1644 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1645 |
+
|
| 1646 |
+
# Initialize weights and apply final processing
|
| 1647 |
+
self.post_init()
|
| 1648 |
+
|
| 1649 |
+
def freeze_feature_extractor(self):
|
| 1650 |
+
"""
|
| 1651 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
|
| 1652 |
+
not be updated during training.
|
| 1653 |
+
"""
|
| 1654 |
+
warnings.warn(
|
| 1655 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1656 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1657 |
+
FutureWarning,
|
| 1658 |
+
)
|
| 1659 |
+
self.freeze_feature_encoder()
|
| 1660 |
+
|
| 1661 |
+
def freeze_feature_encoder(self):
|
| 1662 |
+
"""
|
| 1663 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1664 |
+
not be updated during training.
|
| 1665 |
+
"""
|
| 1666 |
+
self.sew_d.feature_extractor._freeze_parameters()
|
| 1667 |
+
|
| 1668 |
+
def freeze_base_model(self):
|
| 1669 |
+
"""
|
| 1670 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1671 |
+
be updated during training. Only the classification head will be updated.
|
| 1672 |
+
"""
|
| 1673 |
+
for param in self.sew_d.parameters():
|
| 1674 |
+
param.requires_grad = False
|
| 1675 |
+
|
| 1676 |
+
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
| 1677 |
+
@add_code_sample_docstrings(
|
| 1678 |
+
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
| 1679 |
+
output_type=SequenceClassifierOutput,
|
| 1680 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1681 |
+
modality="audio",
|
| 1682 |
+
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
| 1683 |
+
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
| 1684 |
+
)
|
| 1685 |
+
def forward(
|
| 1686 |
+
self,
|
| 1687 |
+
input_values: Optional[torch.Tensor],
|
| 1688 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1689 |
+
output_attentions: Optional[bool] = None,
|
| 1690 |
+
output_hidden_states: Optional[bool] = None,
|
| 1691 |
+
return_dict: Optional[bool] = None,
|
| 1692 |
+
labels: Optional[torch.Tensor] = None,
|
| 1693 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1694 |
+
r"""
|
| 1695 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1696 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1697 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1698 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1699 |
+
"""
|
| 1700 |
+
|
| 1701 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1702 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1703 |
+
|
| 1704 |
+
outputs = self.sew_d(
|
| 1705 |
+
input_values,
|
| 1706 |
+
attention_mask=attention_mask,
|
| 1707 |
+
output_attentions=output_attentions,
|
| 1708 |
+
output_hidden_states=output_hidden_states,
|
| 1709 |
+
return_dict=return_dict,
|
| 1710 |
+
)
|
| 1711 |
+
|
| 1712 |
+
if self.config.use_weighted_layer_sum:
|
| 1713 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1714 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1715 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1716 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1717 |
+
else:
|
| 1718 |
+
hidden_states = outputs[0]
|
| 1719 |
+
|
| 1720 |
+
hidden_states = self.projector(hidden_states)
|
| 1721 |
+
if attention_mask is None:
|
| 1722 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 1723 |
+
else:
|
| 1724 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1725 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
|
| 1726 |
+
hidden_states[~expand_padding_mask] = 0.0
|
| 1727 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1728 |
+
|
| 1729 |
+
logits = self.classifier(pooled_output)
|
| 1730 |
+
|
| 1731 |
+
loss = None
|
| 1732 |
+
if labels is not None:
|
| 1733 |
+
loss_fct = CrossEntropyLoss()
|
| 1734 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1735 |
+
|
| 1736 |
+
if not return_dict:
|
| 1737 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1738 |
+
return ((loss,) + output) if loss is not None else output
|
| 1739 |
+
|
| 1740 |
+
return SequenceClassifierOutput(
|
| 1741 |
+
loss=loss,
|
| 1742 |
+
logits=logits,
|
| 1743 |
+
hidden_states=outputs.hidden_states,
|
| 1744 |
+
attentions=outputs.attentions,
|
| 1745 |
+
)
|
| 1746 |
+
|
| 1747 |
+
|
| 1748 |
+
__all__ = ["SEWDForCTC", "SEWDForSequenceClassification", "SEWDModel", "SEWDPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/shieldgemma2/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_shieldgemma2 import *
|
| 22 |
+
from .modeling_shieldgemma2 import *
|
| 23 |
+
from .processing_shieldgemma2 import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/shieldgemma2/configuration_shieldgemma2.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
from ..auto import CONFIG_MAPPING, AutoConfig
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ShieldGemma2Config(PretrainedConfig):
|
| 26 |
+
r"""
|
| 27 |
+
This is the configuration class to store the configuration of a [`ShieldGemma2ForImageClassification`]. It is used to instantiate an
|
| 28 |
+
ShieldGemma2ForImageClassification according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 29 |
+
with the defaults will yield a similar configuration to that of the shieldgemma-2-4b-it.
|
| 30 |
+
|
| 31 |
+
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
| 32 |
+
|
| 33 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 34 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
text_config (`Union[ShieldGemma2TextConfig, dict]`, *optional*):
|
| 38 |
+
The config object of the text backbone.
|
| 39 |
+
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
| 40 |
+
Custom vision config or dict.
|
| 41 |
+
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
| 42 |
+
The number of tokens per image embedding.
|
| 43 |
+
boi_token_index (`int`, *optional*, defaults to 255999):
|
| 44 |
+
The begin-of-image token index to wrap the image prompt.
|
| 45 |
+
eoi_token_index (`int`, *optional*, defaults to 256000):
|
| 46 |
+
The end-of-image token index to wrap the image prompt.
|
| 47 |
+
image_token_index (`int`, *optional*, defaults to 262144):
|
| 48 |
+
The image token index to encode the image prompt.
|
| 49 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 50 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
Example:
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
>>> from transformers import ShieldGemma2ForConditionalGeneration, ShieldGemma2Config, SiglipVisionConfig, ShieldGemma2TextConfig
|
| 57 |
+
|
| 58 |
+
>>> # Initializing a Siglip-like vision config
|
| 59 |
+
>>> vision_config = SiglipVisionConfig()
|
| 60 |
+
|
| 61 |
+
>>> # Initializing a ShieldGemma2 Text config
|
| 62 |
+
>>> text_config = ShieldGemma2TextConfig()
|
| 63 |
+
|
| 64 |
+
>>> # Initializing a ShieldGemma2 gemma-3-4b style configuration
|
| 65 |
+
>>> configuration = ShieldGemma2Config(vision_config, text_config)
|
| 66 |
+
|
| 67 |
+
>>> # Initializing a model from the gemma-3-4b style configuration
|
| 68 |
+
>>> model = ShieldGemma2TextConfig(configuration)
|
| 69 |
+
|
| 70 |
+
>>> # Accessing the model configuration
|
| 71 |
+
>>> configuration = model.config
|
| 72 |
+
```"""
|
| 73 |
+
|
| 74 |
+
model_type = "shieldgemma2"
|
| 75 |
+
attribute_map = {
|
| 76 |
+
"image_token_id": "image_token_index",
|
| 77 |
+
"boi_token_id": "boi_token_index",
|
| 78 |
+
"eoi_token_id": "eoi_token_index",
|
| 79 |
+
}
|
| 80 |
+
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
text_config=None,
|
| 85 |
+
vision_config=None,
|
| 86 |
+
mm_tokens_per_image: int = 256,
|
| 87 |
+
boi_token_index: int = 255_999,
|
| 88 |
+
eoi_token_index: int = 256_000,
|
| 89 |
+
image_token_index: int = 262_144,
|
| 90 |
+
initializer_range: float = 0.02,
|
| 91 |
+
**kwargs,
|
| 92 |
+
):
|
| 93 |
+
if isinstance(vision_config, dict):
|
| 94 |
+
vision_config["model_type"] = (
|
| 95 |
+
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
|
| 96 |
+
)
|
| 97 |
+
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
| 98 |
+
elif vision_config is None:
|
| 99 |
+
vision_config = CONFIG_MAPPING["siglip_vision_model"]()
|
| 100 |
+
|
| 101 |
+
self.vision_config = vision_config
|
| 102 |
+
|
| 103 |
+
if isinstance(text_config, dict):
|
| 104 |
+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma3_text"
|
| 105 |
+
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
| 106 |
+
elif text_config is None:
|
| 107 |
+
text_config = CONFIG_MAPPING["gemma3_text"]()
|
| 108 |
+
|
| 109 |
+
self.text_config = text_config
|
| 110 |
+
self.vision_config = vision_config
|
| 111 |
+
self.mm_tokens_per_image = mm_tokens_per_image
|
| 112 |
+
self.boi_token_index = boi_token_index
|
| 113 |
+
self.eoi_token_index = eoi_token_index
|
| 114 |
+
self.image_token_index = image_token_index
|
| 115 |
+
self.initializer_range = initializer_range
|
| 116 |
+
|
| 117 |
+
super().__init__(**kwargs)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
__all__ = ["ShieldGemma2Config"]
|
docs/transformers/build/lib/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
|
| 2 |
+
|
| 3 |
+
python -m transformers.models.shieldgemma2.convert_shieldgemma2_weights_orbax_to_hf \
|
| 4 |
+
--tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
|
| 5 |
+
--checkpoint_path_gemma="$HOME/gemma3/gemma3_4b_pt_orbax/" \
|
| 6 |
+
--checkpoint_path_shieldgemma="$HOME/shieldgemma2/shieldgemma-2_4b_orbax/" \
|
| 7 |
+
--output_path="$HOME/shieldgemma2/shieldgemma2_4b_pt_safetensors/" \
|
| 8 |
+
--precision='bfloat16'
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import dataclasses
|
| 12 |
+
from collections.abc import Iterator, Mapping, Sequence
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import accelerate
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import tree
|
| 19 |
+
from absl import app, flags, logging
|
| 20 |
+
from orbax import checkpoint as obc
|
| 21 |
+
|
| 22 |
+
from ...image_utils import PILImageResampling
|
| 23 |
+
from ..gemma import GemmaTokenizerFast
|
| 24 |
+
from ..gemma3 import Gemma3ImageProcessor, Gemma3TextConfig
|
| 25 |
+
from ..siglip import SiglipVisionConfig
|
| 26 |
+
from . import ShieldGemma2Config, ShieldGemma2ForImageClassification, ShieldGemma2Processor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ==== Internal Constants and Classes ====
|
| 30 |
+
|
| 31 |
+
_CHAT_TEMPLATE = """{{ bos_token }}
|
| 32 |
+
{%- for message in messages -%}
|
| 33 |
+
{%- if message['content'] is iterable -%}
|
| 34 |
+
{%- for item in message['content'] -%}
|
| 35 |
+
{{ '<start_of_turn>user\n' }}
|
| 36 |
+
{%- if item['type'] == 'image' -%}
|
| 37 |
+
{{ '<start_of_image>' }}
|
| 38 |
+
{%- elif item['type'] == 'text' -%}
|
| 39 |
+
{{ 'You are a policy expert trying to help determine whether a generated or real image is in' }}
|
| 40 |
+
{{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}}
|
| 41 |
+
{{ ' language used in the safety policy and should not include any interpretation outside of the' }}
|
| 42 |
+
{{ ' safety policy.\n\n' }}
|
| 43 |
+
{{ '<BEGIN_SAFETY_POLICY>\n' }}
|
| 44 |
+
{{ item['text'] + '\n' }}
|
| 45 |
+
{{ '</END_SAFETY_POLICY>\n\n' }}
|
| 46 |
+
{{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }}
|
| 47 |
+
{{ '<end_of_turn>\n' }}
|
| 48 |
+
{%- endif -%}
|
| 49 |
+
{%- endfor -%}
|
| 50 |
+
{{'<start_of_turn>model\n'}}
|
| 51 |
+
{%- else -%}
|
| 52 |
+
{{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }}
|
| 53 |
+
{%- endif -%}
|
| 54 |
+
{%- endfor -%}
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
_DTYPES = {
|
| 58 |
+
"float32": torch.float32,
|
| 59 |
+
"bfloat16": torch.bfloat16,
|
| 60 |
+
"float16": torch.float16,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
|
| 64 |
+
"dangerous": (
|
| 65 |
+
"No Dangerous Content: The image shall not contain content that"
|
| 66 |
+
" facilitates or encourages activities that could cause real-world harm"
|
| 67 |
+
" (e.g., building firearms and explosive devices, promotion of"
|
| 68 |
+
" terrorism, instructions for suicide)."
|
| 69 |
+
),
|
| 70 |
+
"sexual": (
|
| 71 |
+
"No Sexually Explicit content: The image shall not contain content that"
|
| 72 |
+
" depicts explicit or graphic sexual acts (e.g., pornography, erotic"
|
| 73 |
+
" nudity, depictions of rape or sexual assault)."
|
| 74 |
+
),
|
| 75 |
+
"violence": (
|
| 76 |
+
"No Violence/Gore content: The image shall not contain content that"
|
| 77 |
+
" depicts shocking, sensational, or gratuitous violence (e.g.,"
|
| 78 |
+
" excessive blood and gore, gratuitous violence against animals,"
|
| 79 |
+
" extreme injury or moment of death)."
|
| 80 |
+
),
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder"
|
| 84 |
+
_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding"
|
| 85 |
+
_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_"
|
| 86 |
+
_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
|
| 87 |
+
_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"
|
| 88 |
+
|
| 89 |
+
_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
|
| 90 |
+
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
|
| 91 |
+
_TRANSFORMER_EMBEDDER = "transformer/embedder"
|
| 92 |
+
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
|
| 93 |
+
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
|
| 94 |
+
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
|
| 95 |
+
|
| 96 |
+
# ==== Flags ====
|
| 97 |
+
|
| 98 |
+
_GEMMA_CHECKPOINT_PATH = flags.DEFINE_string(
|
| 99 |
+
name="checkpoint_path_gemma",
|
| 100 |
+
default=None,
|
| 101 |
+
help="Path to the Orbax checkpoint containing the vision weights.",
|
| 102 |
+
required=True,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
_SHIELDGEMMA_CHECKPOINT_PATH = flags.DEFINE_string(
|
| 106 |
+
name="checkpoint_path_shieldgemma",
|
| 107 |
+
default=None,
|
| 108 |
+
help="Path to the Orbax checkpoint containing the language model weights.",
|
| 109 |
+
required=True,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
OUTPUT_PATH = flags.DEFINE_string(
|
| 113 |
+
name="output_path",
|
| 114 |
+
default=None,
|
| 115 |
+
help="Path to store the HF checkpoint.",
|
| 116 |
+
required=True,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
PRECISION = flags.DEFINE_enum(
|
| 120 |
+
name="precision",
|
| 121 |
+
default=None,
|
| 122 |
+
help="The floating point precision (aka dtype) of the model.",
|
| 123 |
+
enum_values=set(_DTYPES.keys()),
|
| 124 |
+
required=True,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
TOKENIZER_PATH = flags.DEFINE_string(
|
| 128 |
+
name="tokenizer_path",
|
| 129 |
+
default=None,
|
| 130 |
+
help="Path to the SentencePiece model file.",
|
| 131 |
+
required=True,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def convert_siglip_weight(
|
| 136 |
+
config: SiglipVisionConfig,
|
| 137 |
+
paths: Sequence[str],
|
| 138 |
+
weights: np.ndarray,
|
| 139 |
+
) -> tuple[str, np.ndarray]:
|
| 140 |
+
path, prop = paths
|
| 141 |
+
normalized_path: str = ""
|
| 142 |
+
updated_weights: np.ndarray = None
|
| 143 |
+
|
| 144 |
+
if path == _SIGLIP_BASE:
|
| 145 |
+
normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight"
|
| 146 |
+
updated_weights = weights.reshape(-1, config.hidden_size)
|
| 147 |
+
elif path == _SIGLIP_EMBEDDING:
|
| 148 |
+
if prop == "kernel":
|
| 149 |
+
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight"
|
| 150 |
+
updated_weights = weights.transpose(3, 2, 0, 1)
|
| 151 |
+
elif prop == "bias":
|
| 152 |
+
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias"
|
| 153 |
+
updated_weights = weights
|
| 154 |
+
else:
|
| 155 |
+
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
|
| 156 |
+
elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK):
|
| 157 |
+
encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:]
|
| 158 |
+
next_path_seperator_idx = encoder_block_path.find("/")
|
| 159 |
+
layer_idx = encoder_block_path[:next_path_seperator_idx]
|
| 160 |
+
encoder_block_path = encoder_block_path[next_path_seperator_idx:]
|
| 161 |
+
normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
|
| 162 |
+
|
| 163 |
+
if encoder_block_path.startswith("/LayerNorm"):
|
| 164 |
+
normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2"
|
| 165 |
+
|
| 166 |
+
if prop == "scale":
|
| 167 |
+
normalized_path += ".weight"
|
| 168 |
+
updated_weights = weights.transpose()
|
| 169 |
+
elif prop == "bias":
|
| 170 |
+
normalized_path += ".bias"
|
| 171 |
+
updated_weights = weights
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
|
| 174 |
+
elif encoder_block_path.startswith("/MlpBlock_0"):
|
| 175 |
+
normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2"
|
| 176 |
+
|
| 177 |
+
if prop == "kernel":
|
| 178 |
+
normalized_path += ".weight"
|
| 179 |
+
updated_weights = weights.transpose()
|
| 180 |
+
elif prop == "bias":
|
| 181 |
+
normalized_path += ".bias"
|
| 182 |
+
updated_weights = weights
|
| 183 |
+
else:
|
| 184 |
+
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
|
| 185 |
+
elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"):
|
| 186 |
+
if encoder_block_path.endswith("/key"):
|
| 187 |
+
normalized_path += ".self_attn.k_proj"
|
| 188 |
+
elif encoder_block_path.endswith("/out"):
|
| 189 |
+
normalized_path += ".self_attn.out_proj"
|
| 190 |
+
elif encoder_block_path.endswith("/query"):
|
| 191 |
+
normalized_path += ".self_attn.q_proj"
|
| 192 |
+
elif encoder_block_path.endswith("/value"):
|
| 193 |
+
normalized_path += ".self_attn.v_proj"
|
| 194 |
+
else:
|
| 195 |
+
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.")
|
| 196 |
+
|
| 197 |
+
if prop == "bias":
|
| 198 |
+
normalized_path += ".bias"
|
| 199 |
+
updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1)
|
| 200 |
+
elif prop == "kernel":
|
| 201 |
+
normalized_path += ".weight"
|
| 202 |
+
updated_weights = weights.reshape(-1, config.hidden_size).transpose()
|
| 203 |
+
else:
|
| 204 |
+
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
|
| 205 |
+
else:
|
| 206 |
+
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.")
|
| 207 |
+
elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM:
|
| 208 |
+
if prop == "scale":
|
| 209 |
+
normalized_path = "vision_tower.vision_model.post_layernorm.weight"
|
| 210 |
+
updated_weights = weights.transpose()
|
| 211 |
+
elif prop == "bias":
|
| 212 |
+
normalized_path = "vision_tower.vision_model.post_layernorm.bias"
|
| 213 |
+
updated_weights = weights
|
| 214 |
+
else:
|
| 215 |
+
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
|
| 216 |
+
else:
|
| 217 |
+
raise ValueError(f"Unexpected path `{path}`.")
|
| 218 |
+
|
| 219 |
+
if "vision" in normalized_path:
|
| 220 |
+
print(normalized_path)
|
| 221 |
+
return normalized_path, updated_weights
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def convert_transformer_weights(
|
| 225 |
+
config: Gemma3TextConfig,
|
| 226 |
+
paths: Sequence[str],
|
| 227 |
+
weights: np.ndarray,
|
| 228 |
+
) -> Iterator[tuple[str, np.ndarray]]:
|
| 229 |
+
path, prop = paths
|
| 230 |
+
|
| 231 |
+
if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
|
| 232 |
+
path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
|
| 233 |
+
|
| 234 |
+
converted_paths: list[str] = []
|
| 235 |
+
converted_weights: list[Any] = []
|
| 236 |
+
|
| 237 |
+
attn_head_dim = config.num_attention_heads * config.head_dim
|
| 238 |
+
kv_head_dim = config.num_key_value_heads * config.head_dim
|
| 239 |
+
|
| 240 |
+
if path == _TRANSFORMER_EMBEDDER:
|
| 241 |
+
if prop == "input_embedding":
|
| 242 |
+
# Tied to language_model.lm_head.weight, assigned at the end.
|
| 243 |
+
converted_paths = ["language_model.model.embed_tokens.weight"]
|
| 244 |
+
# Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
|
| 245 |
+
pre_expansion_embeddings = weights
|
| 246 |
+
mu = np.mean(pre_expansion_embeddings, axis=0)
|
| 247 |
+
sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
|
| 248 |
+
new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
|
| 249 |
+
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
|
| 250 |
+
converted_weights = [weights]
|
| 251 |
+
else:
|
| 252 |
+
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
|
| 253 |
+
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
|
| 254 |
+
if path.endswith("/mm_input_projection"):
|
| 255 |
+
converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
|
| 256 |
+
converted_weights = [weights]
|
| 257 |
+
elif path.endswith("/mm_soft_embedding_norm"):
|
| 258 |
+
converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
|
| 259 |
+
converted_weights = [weights]
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
|
| 262 |
+
elif path == _TRANSFORMER_FINAL_NORM:
|
| 263 |
+
converted_paths = ["language_model.model.norm.weight"]
|
| 264 |
+
converted_weights = [weights]
|
| 265 |
+
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
|
| 266 |
+
decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
|
| 267 |
+
next_path_seperator_idx = decoder_block_path.find("/")
|
| 268 |
+
layer_idx = decoder_block_path[:next_path_seperator_idx]
|
| 269 |
+
decoder_block_path = decoder_block_path[next_path_seperator_idx:]
|
| 270 |
+
|
| 271 |
+
base_path = f"language_model.model.layers.{layer_idx}"
|
| 272 |
+
|
| 273 |
+
if path.endswith("attn/attn_vec_einsum"):
|
| 274 |
+
converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
|
| 275 |
+
converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)]
|
| 276 |
+
elif path.endswith("attn/_key_norm"):
|
| 277 |
+
converted_paths = [f"{base_path}.self_attn.k_norm.weight"]
|
| 278 |
+
converted_weights = [weights]
|
| 279 |
+
elif path.endswith("attn/kv_einsum"):
|
| 280 |
+
converted_paths = [
|
| 281 |
+
f"{base_path}.self_attn.k_proj.weight",
|
| 282 |
+
f"{base_path}.self_attn.v_proj.weight",
|
| 283 |
+
]
|
| 284 |
+
k_proj_weights, v_proj_weights = weights
|
| 285 |
+
converted_weights = [
|
| 286 |
+
k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
|
| 287 |
+
v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
|
| 288 |
+
]
|
| 289 |
+
elif path.endswith("attn/q_einsum"):
|
| 290 |
+
converted_paths = [f"{base_path}.self_attn.q_proj.weight"]
|
| 291 |
+
converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)]
|
| 292 |
+
elif path.endswith("attn/_query_norm"):
|
| 293 |
+
converted_paths = [f"{base_path}.self_attn.q_norm.weight"]
|
| 294 |
+
converted_weights = [weights]
|
| 295 |
+
elif path.endswith("mlp/gating_einsum"):
|
| 296 |
+
converted_paths = [
|
| 297 |
+
f"{base_path}.mlp.gate_proj.weight",
|
| 298 |
+
f"{base_path}.mlp.up_proj.weight",
|
| 299 |
+
]
|
| 300 |
+
gate_proj_weight, up_proj_weight = weights
|
| 301 |
+
converted_weights = [gate_proj_weight, up_proj_weight]
|
| 302 |
+
elif path.endswith("mlp/linear"):
|
| 303 |
+
converted_paths = [f"{base_path}.mlp.down_proj.weight"]
|
| 304 |
+
converted_weights = [weights.transpose()]
|
| 305 |
+
elif path.endswith("post_attention_norm"):
|
| 306 |
+
converted_paths = [f"{base_path}.post_attention_layernorm.weight"]
|
| 307 |
+
converted_weights = [weights]
|
| 308 |
+
elif path.endswith("post_ffw_norm"):
|
| 309 |
+
converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"]
|
| 310 |
+
converted_weights = [weights]
|
| 311 |
+
elif path.endswith("pre_attention_norm"):
|
| 312 |
+
converted_paths = [f"{base_path}.input_layernorm.weight"]
|
| 313 |
+
converted_weights = [weights]
|
| 314 |
+
elif path.endswith("pre_ffw_norm"):
|
| 315 |
+
converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"]
|
| 316 |
+
converted_weights = [weights]
|
| 317 |
+
else:
|
| 318 |
+
raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
|
| 319 |
+
else:
|
| 320 |
+
raise ValueError(f"Unexpected path `{path}`.")
|
| 321 |
+
|
| 322 |
+
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
|
| 323 |
+
raise ValueError(
|
| 324 |
+
"The `converted_paths` and `converted_weights` should be the same "
|
| 325 |
+
f"length. Got {cpl} and {cwl}, respectively, for {path}."
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
return zip(converted_paths, converted_weights)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def transpose_reshape(x: torch.Tensor) -> torch.Tensor:
|
| 332 |
+
x = x.transpose(1, 2)
|
| 333 |
+
return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous()
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@dataclasses.dataclass(frozen=True)
|
| 337 |
+
class ConversionResult:
|
| 338 |
+
state_tree: dict[str, torch.Tensor]
|
| 339 |
+
config: ShieldGemma2Config
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def convert(
|
| 343 |
+
shieldgemma_checkpoint_path: str,
|
| 344 |
+
gemma_checkpoint_path: str,
|
| 345 |
+
config: ShieldGemma2Config,
|
| 346 |
+
target_dtype: torch.dtype,
|
| 347 |
+
) -> ConversionResult:
|
| 348 |
+
"""Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
|
| 349 |
+
checkpointer = obc.PyTreeCheckpointer()
|
| 350 |
+
|
| 351 |
+
sg2_ckpt = checkpointer.restore(shieldgemma_checkpoint_path)
|
| 352 |
+
g3_ckpt = checkpointer.restore(gemma_checkpoint_path)
|
| 353 |
+
|
| 354 |
+
hf_tree: dict[str, torch.Tensor] = {}
|
| 355 |
+
|
| 356 |
+
def update_tree(path: str, weights: np.ndarray) -> None:
|
| 357 |
+
torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype)
|
| 358 |
+
logging.info(
|
| 359 |
+
"%s converted shape=%s with dtype=%s",
|
| 360 |
+
path,
|
| 361 |
+
weights.shape,
|
| 362 |
+
torch_tensor.dtype,
|
| 363 |
+
)
|
| 364 |
+
hf_tree[f"model.{path}"] = torch_tensor
|
| 365 |
+
|
| 366 |
+
for paths, value in tree.flatten_with_path(g3_ckpt):
|
| 367 |
+
if paths[0].startswith("SigLiPFromPatches_"):
|
| 368 |
+
path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
|
| 369 |
+
update_tree(path, weights)
|
| 370 |
+
|
| 371 |
+
for paths, value in tree.flatten_with_path(sg2_ckpt):
|
| 372 |
+
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
|
| 373 |
+
update_tree(path, weights)
|
| 374 |
+
|
| 375 |
+
hf_tree["model.language_model.lm_head.weight"] = hf_tree["model.language_model.model.embed_tokens.weight"]
|
| 376 |
+
|
| 377 |
+
return ConversionResult(state_tree=hf_tree, config=config)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def main(*args):
|
| 381 |
+
del args
|
| 382 |
+
|
| 383 |
+
dtype = getattr(torch, PRECISION.value)
|
| 384 |
+
output_path = OUTPUT_PATH.value
|
| 385 |
+
|
| 386 |
+
tokenizer = GemmaTokenizerFast(
|
| 387 |
+
TOKENIZER_PATH.value,
|
| 388 |
+
extra_special_tokens={
|
| 389 |
+
"image_token": "<image_soft_token>", # Should be ID=262_144
|
| 390 |
+
"boi_token": "<start_of_image>", # Should be ID=255_999
|
| 391 |
+
"eoi_token": "<end_of_image>", # Should be ID=256_000
|
| 392 |
+
},
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
yes_token_index, no_token_index = torch.tensor(tokenizer(["Yes", "No"])["input_ids"])[:, 1].numpy()
|
| 396 |
+
|
| 397 |
+
config = ShieldGemma2Config(
|
| 398 |
+
yes_token_index=int(yes_token_index),
|
| 399 |
+
no_token_index=int(no_token_index),
|
| 400 |
+
text_config=Gemma3TextConfig(
|
| 401 |
+
vocab_size=262_208,
|
| 402 |
+
hidden_size=2560,
|
| 403 |
+
intermediate_size=2560 * 8 // 2,
|
| 404 |
+
num_attention_heads=8,
|
| 405 |
+
head_dim=256,
|
| 406 |
+
num_hidden_layers=34,
|
| 407 |
+
num_key_value_heads=4,
|
| 408 |
+
sliding_window=1024,
|
| 409 |
+
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
|
| 410 |
+
rope_theta=1_000_000,
|
| 411 |
+
rope_local_base_freq=10_000,
|
| 412 |
+
attn_logit_softcapping=None,
|
| 413 |
+
query_pre_attn_scalar=256,
|
| 414 |
+
max_position_embeddings=8192,
|
| 415 |
+
),
|
| 416 |
+
vision_config={
|
| 417 |
+
"hidden_size": 1152,
|
| 418 |
+
"intermediate_size": 4304,
|
| 419 |
+
"num_hidden_layers": 27,
|
| 420 |
+
"num_attention_heads": 16,
|
| 421 |
+
"num_channels": 3,
|
| 422 |
+
"image_size": 896,
|
| 423 |
+
"patch_size": 14,
|
| 424 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 425 |
+
"layer_norm_eps": 1e-6,
|
| 426 |
+
"attention_dropout": 0.0,
|
| 427 |
+
"vision_use_head": False,
|
| 428 |
+
},
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
config.save_pretrained(output_path)
|
| 432 |
+
|
| 433 |
+
image_processor = Gemma3ImageProcessor(
|
| 434 |
+
image_seq_length=256,
|
| 435 |
+
image_mean=(0.5,) * 3,
|
| 436 |
+
image_std=(0.5,) * 3,
|
| 437 |
+
size={"height": 896, "width": 896},
|
| 438 |
+
resample=PILImageResampling.BILINEAR,
|
| 439 |
+
)
|
| 440 |
+
processor = ShieldGemma2Processor(
|
| 441 |
+
image_processor=image_processor,
|
| 442 |
+
tokenizer=tokenizer,
|
| 443 |
+
policy_definitions=_SHIELDGEMMA2_POLICIES,
|
| 444 |
+
)
|
| 445 |
+
tokenizer.chat_template = _CHAT_TEMPLATE
|
| 446 |
+
processor.chat_template = _CHAT_TEMPLATE
|
| 447 |
+
|
| 448 |
+
processor.save_pretrained(output_path)
|
| 449 |
+
logging.info("Saved Shieldgemma2Processor to %s", output_path)
|
| 450 |
+
del processor
|
| 451 |
+
del tokenizer
|
| 452 |
+
|
| 453 |
+
logging.info("Converting Shieldgemma2 @ %s", dtype)
|
| 454 |
+
result = convert(_SHIELDGEMMA_CHECKPOINT_PATH.value, _GEMMA_CHECKPOINT_PATH.value, config, dtype)
|
| 455 |
+
logging.info("Converted Shieldgemma2 state tree from Orbax to Hugging Face.")
|
| 456 |
+
|
| 457 |
+
with accelerate.init_empty_weights():
|
| 458 |
+
model = ShieldGemma2ForImageClassification(config=config)
|
| 459 |
+
|
| 460 |
+
model.load_state_dict(result.state_tree, assign=True, strict=True)
|
| 461 |
+
model.config.torch_dtype = dtype
|
| 462 |
+
logging.info("Loaded Shieldgemma2 in Hugging Face Transformers.")
|
| 463 |
+
model.save_pretrained(output_path, safe_serialization=True)
|
| 464 |
+
logging.info("Saved Shieldgemma2 to SafeTensors in %s", output_path)
|
| 465 |
+
del model
|
| 466 |
+
del result
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
if __name__ == "__main__":
|
| 470 |
+
app.run(main)
|
docs/transformers/build/lib/transformers/models/shieldgemma2/modeling_shieldgemma2.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
|
| 22 |
+
from ...cache_utils import Cache
|
| 23 |
+
from ...modeling_outputs import ImageClassifierOutputWithNoAttention
|
| 24 |
+
from ...modeling_utils import PreTrainedModel
|
| 25 |
+
from ...utils import (
|
| 26 |
+
add_start_docstrings_to_model_forward,
|
| 27 |
+
logging,
|
| 28 |
+
)
|
| 29 |
+
from ..auto import AutoModelForImageTextToText
|
| 30 |
+
from .configuration_shieldgemma2 import ShieldGemma2Config
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
_CHECKPOINT_FOR_DOC = "google/shieldgemma-2-4b-it"
|
| 34 |
+
_CONFIG_FOR_DOC = "ShieldGemma2Config"
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
SHIELDGEMMA2_INPUTS_DOCSTRING = r"""
|
| 39 |
+
Args:
|
| 40 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 41 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 42 |
+
it.
|
| 43 |
+
|
| 44 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 45 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 46 |
+
|
| 47 |
+
[What are input IDs?](../glossary#input-ids)
|
| 48 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 49 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 50 |
+
|
| 51 |
+
- 1 for tokens that are **not masked**,
|
| 52 |
+
- 0 for tokens that are **masked**.
|
| 53 |
+
|
| 54 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 55 |
+
|
| 56 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 57 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 58 |
+
|
| 59 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 60 |
+
`past_key_values`).
|
| 61 |
+
|
| 62 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 63 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 64 |
+
information on the default strategy.
|
| 65 |
+
|
| 66 |
+
- 1 indicates the head is **not masked**,
|
| 67 |
+
- 0 indicates the head is **masked**.
|
| 68 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 69 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 70 |
+
config.n_positions - 1]`.
|
| 71 |
+
|
| 72 |
+
[What are position IDs?](../glossary#position-ids)
|
| 73 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 74 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 75 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 76 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 77 |
+
|
| 78 |
+
Two formats are allowed:
|
| 79 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 80 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 81 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 82 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 83 |
+
cache format.
|
| 84 |
+
|
| 85 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 86 |
+
legacy cache format will be returned.
|
| 87 |
+
|
| 88 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 89 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 90 |
+
of shape `(batch_size, sequence_length)`.
|
| 91 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 92 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 93 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 94 |
+
model's internal embedding lookup matrix.
|
| 95 |
+
use_cache (`bool`, *optional*):
|
| 96 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 97 |
+
`past_key_values`).
|
| 98 |
+
output_attentions (`bool`, *optional*):
|
| 99 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 100 |
+
tensors for more detail.
|
| 101 |
+
output_hidden_states (`bool`, *optional*):
|
| 102 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 103 |
+
more detail.
|
| 104 |
+
return_dict (`bool`, *optional*):
|
| 105 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 106 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 107 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 108 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 109 |
+
the complete sequence length.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass
|
| 114 |
+
class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWithNoAttention):
|
| 115 |
+
"""ShieldGemma2 classifies imags as violative or not relative to a specific policy
|
| 116 |
+
Args:
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
probabilities: Optional[torch.Tensor] = None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class ShieldGemma2ForImageClassification(PreTrainedModel):
|
| 123 |
+
config_class = ShieldGemma2Config
|
| 124 |
+
|
| 125 |
+
def __init__(self, config: ShieldGemma2Config):
|
| 126 |
+
super().__init__(config=config)
|
| 127 |
+
self.yes_token_index = getattr(config, "yes_token_index", 10_784)
|
| 128 |
+
self.no_token_index = getattr(config, "no_token_index", 3771)
|
| 129 |
+
self.model = AutoModelForImageTextToText.from_config(config=config)
|
| 130 |
+
|
| 131 |
+
def get_input_embeddings(self):
|
| 132 |
+
return self.model.language_model.get_input_embeddings()
|
| 133 |
+
|
| 134 |
+
def set_input_embeddings(self, value):
|
| 135 |
+
self.model.language_model.set_input_embeddings(value)
|
| 136 |
+
|
| 137 |
+
def get_output_embeddings(self):
|
| 138 |
+
return self.model.language_model.get_output_embeddings()
|
| 139 |
+
|
| 140 |
+
def set_output_embeddings(self, new_embeddings):
|
| 141 |
+
self.model.language_model.set_output_embeddings(new_embeddings)
|
| 142 |
+
|
| 143 |
+
def set_decoder(self, decoder):
|
| 144 |
+
self.model.language_model.set_decoder(decoder)
|
| 145 |
+
|
| 146 |
+
def get_decoder(self):
|
| 147 |
+
return self.model.language_model.get_decoder()
|
| 148 |
+
|
| 149 |
+
def tie_weights(self):
|
| 150 |
+
return self.model.language_model.tie_weights()
|
| 151 |
+
|
| 152 |
+
@add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING)
|
| 153 |
+
def forward(
|
| 154 |
+
self,
|
| 155 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 156 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 157 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 158 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 159 |
+
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
| 160 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 161 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 162 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 163 |
+
labels: Optional[torch.LongTensor] = None,
|
| 164 |
+
use_cache: Optional[bool] = None,
|
| 165 |
+
output_attentions: Optional[bool] = None,
|
| 166 |
+
output_hidden_states: Optional[bool] = None,
|
| 167 |
+
return_dict: Optional[bool] = None,
|
| 168 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 169 |
+
**lm_kwargs,
|
| 170 |
+
) -> ShieldGemma2ImageClassifierOutputWithNoAttention:
|
| 171 |
+
"""Predicts the binary probability that the image violates the specified policy.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance containing the logits and probabilities
|
| 175 |
+
associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the
|
| 176 |
+
following properties.
|
| 177 |
+
|
| 178 |
+
* `logits` (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 179 |
+
The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is
|
| 180 |
+
the logits for the `No` token.
|
| 181 |
+
* `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 182 |
+
The first position along dim=1 is the probability of predicting the `Yes` token and the second position
|
| 183 |
+
along dim=1 is the probability of predicting the `No` token.
|
| 184 |
+
|
| 185 |
+
ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the
|
| 186 |
+
policy as described. If you are only interested in the violative condition, use
|
| 187 |
+
`violated = outputs.probabilities[:, 1]` to extract that slice from the output tensors.
|
| 188 |
+
|
| 189 |
+
When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`,
|
| 190 |
+
and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN.
|
| 191 |
+
"""
|
| 192 |
+
outputs = self.model(
|
| 193 |
+
input_ids=input_ids,
|
| 194 |
+
pixel_values=pixel_values,
|
| 195 |
+
attention_mask=attention_mask,
|
| 196 |
+
position_ids=position_ids,
|
| 197 |
+
past_key_values=past_key_values,
|
| 198 |
+
token_type_ids=token_type_ids,
|
| 199 |
+
cache_position=cache_position,
|
| 200 |
+
inputs_embeds=inputs_embeds,
|
| 201 |
+
labels=labels,
|
| 202 |
+
use_cache=use_cache,
|
| 203 |
+
output_attentions=output_attentions,
|
| 204 |
+
output_hidden_states=output_hidden_states,
|
| 205 |
+
return_dict=return_dict,
|
| 206 |
+
logits_to_keep=logits_to_keep,
|
| 207 |
+
**lm_kwargs,
|
| 208 |
+
)
|
| 209 |
+
logits = outputs.logits
|
| 210 |
+
selected_logits = logits[:, -1, [self.yes_token_index, self.no_token_index]]
|
| 211 |
+
probabilities = torch.softmax(selected_logits, dim=-1)
|
| 212 |
+
return ShieldGemma2ImageClassifierOutputWithNoAttention(
|
| 213 |
+
logits=selected_logits,
|
| 214 |
+
probabilities=probabilities,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
__all__ = [
|
| 219 |
+
"ShieldGemma2ForImageClassification",
|
| 220 |
+
]
|
docs/transformers/build/lib/transformers/models/shieldgemma2/processing_shieldgemma2.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from collections.abc import Mapping, Sequence
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
from ...feature_extraction_utils import BatchFeature
|
| 20 |
+
from ...image_utils import ImageInput
|
| 21 |
+
from ...processing_utils import Unpack
|
| 22 |
+
from ...utils import logging
|
| 23 |
+
from ..gemma3.processing_gemma3 import Gemma3Processor, Gemma3ProcessorKwargs
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
|
| 29 |
+
"dangerous": (
|
| 30 |
+
"No Dangerous Content: The image shall not contain content that"
|
| 31 |
+
" facilitates or encourages activities that could cause real-world harm"
|
| 32 |
+
" (e.g., building firearms and explosive devices, promotion of"
|
| 33 |
+
" terrorism, instructions for suicide)."
|
| 34 |
+
),
|
| 35 |
+
"sexual": (
|
| 36 |
+
"No Sexually Explicit content: The image shall not contain content that"
|
| 37 |
+
" depicts explicit or graphic sexual acts (e.g., pornography, erotic"
|
| 38 |
+
" nudity, depictions of rape or sexual assault)."
|
| 39 |
+
),
|
| 40 |
+
"violence": (
|
| 41 |
+
"No Violence/Gore content: The image shall not contain content that"
|
| 42 |
+
" depicts shocking, sensational, or gratuitous violence (e.g.,"
|
| 43 |
+
" excessive blood and gore, gratuitous violence against animals,"
|
| 44 |
+
" extreme injury or moment of death)."
|
| 45 |
+
),
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
|
| 50 |
+
policies: Optional[Sequence[str]]
|
| 51 |
+
custom_policies: Optional[Mapping[str, str]]
|
| 52 |
+
_defaults = {
|
| 53 |
+
"text_kwargs": {
|
| 54 |
+
"padding": True,
|
| 55 |
+
},
|
| 56 |
+
"images_kwargs": {
|
| 57 |
+
"do_pan_and_scan": False,
|
| 58 |
+
},
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ShieldGemma2Processor(Gemma3Processor):
|
| 63 |
+
def __init__(
|
| 64 |
+
self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs
|
| 65 |
+
):
|
| 66 |
+
"""A processor for the ShieldGemma 2 model.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
image_processor: The image processor to use, typically a `Gemma3ImageProcessorFast` instance.
|
| 70 |
+
tokenizer: The tokenizer to use, typically a `GemmaTokenizerFast` instance.
|
| 71 |
+
chat_template: The chat template to use with this processor. Typically, this is unset as the processor
|
| 72 |
+
configuration on Hugging Face Hub includes this value already.
|
| 73 |
+
image_seq_length: The number of soft tokens per image. Typically, this is unset as the processor
|
| 74 |
+
configuration on Hugging Face Hub includes this value already.
|
| 75 |
+
policy_definitions: A mapping from policy name to its description in text used as the default policies to
|
| 76 |
+
classify images against. The policy descriptions are included in the text of the prompts generated by
|
| 77 |
+
this processor. Typically, this is unset as the processor configuration on Hugging Face Hub includes
|
| 78 |
+
the base policies ShieldGemma was trained on.
|
| 79 |
+
"""
|
| 80 |
+
super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
|
| 81 |
+
if policy_definitions is None:
|
| 82 |
+
self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES
|
| 83 |
+
else:
|
| 84 |
+
self.policy_definitions = policy_definitions
|
| 85 |
+
|
| 86 |
+
def __call__(
|
| 87 |
+
self,
|
| 88 |
+
images: ImageInput = None,
|
| 89 |
+
text=None,
|
| 90 |
+
videos=None,
|
| 91 |
+
audio=None,
|
| 92 |
+
**kwargs: Unpack[ShieldGemma2ProcessorKwargs],
|
| 93 |
+
) -> BatchFeature:
|
| 94 |
+
"""Generates a batch of inputs from the provided images.
|
| 95 |
+
|
| 96 |
+
ShieldGemma was trained to classify image content for policy compliance using a specific prompt construction.
|
| 97 |
+
This processor generates a batch of such prompts from the provided images by:
|
| 98 |
+
|
| 99 |
+
1. Creating a list of conversations, one for each `<image, policy>` pair;
|
| 100 |
+
2. Converting these conversations to text using `self.apply_chat_template()`; and
|
| 101 |
+
3. Encoding the conversations and images using the same techniques as `Gemma3Processor`.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
images: A single image or a list of images to include in the batch.
|
| 105 |
+
text: Not supported.
|
| 106 |
+
videos: Not supported.
|
| 107 |
+
audio: Not supported.
|
| 108 |
+
kwargs: An optional dictionary of keyword arguments to configre the
|
| 109 |
+
processor. Possible values include:
|
| 110 |
+
|
| 111 |
+
* `custom_policies`: Additional policy definitions that augment the `self.policy_definitions` passed
|
| 112 |
+
into the constructor. Note that `custom_policies` that share a key with `self.policy_definitions`
|
| 113 |
+
will override the policy description
|
| 114 |
+
* `policies`: (Optional) a list of keys in the joint `self.policy_definitions | custom_policies`
|
| 115 |
+
dictionary of specific interest for the provided images. If empty or None, prompts will be
|
| 116 |
+
generated for every key in the joint dictionary.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
A `BatchFeature` continaing `input_ids`, `pixel_values`, etc. where each Tensor is of shape
|
| 120 |
+
`(len(images) * len(policies), )`, and the order within the batch will be
|
| 121 |
+
img1_policy1, ... img1_policyN, ... imgM_policyN.
|
| 122 |
+
"""
|
| 123 |
+
del text, videos, audio
|
| 124 |
+
|
| 125 |
+
if not images:
|
| 126 |
+
raise ValueError("ShieldGemma 2 needs images to classify")
|
| 127 |
+
elif not isinstance(images, Sequence):
|
| 128 |
+
images = [images]
|
| 129 |
+
|
| 130 |
+
if not self.chat_template:
|
| 131 |
+
raise ValueError("ShieldGemma 2 requires the use of a specific chat template")
|
| 132 |
+
|
| 133 |
+
# Disable pan and scan
|
| 134 |
+
images_kwargs = kwargs.setdefault("images_kwargs", {})
|
| 135 |
+
if images_kwargs.get("do_pan_and_scan") is True:
|
| 136 |
+
logger.warning_once("ShieldGemma2 does not support pan and scan.")
|
| 137 |
+
images_kwargs["do_pan_and_scan"] = False
|
| 138 |
+
|
| 139 |
+
# Enable padding on the batch during tokenization
|
| 140 |
+
text_kwargs = kwargs.setdefault("text_kwargs", {})
|
| 141 |
+
if "padding" not in text_kwargs:
|
| 142 |
+
text_kwargs["padding"] = kwargs.pop("padding", True)
|
| 143 |
+
text_kwargs["padding_side"] = kwargs.pop("padding_side", "left")
|
| 144 |
+
|
| 145 |
+
policy_definitions: Mapping[str, str] = {
|
| 146 |
+
**self.policy_definitions,
|
| 147 |
+
**kwargs.get("custom_policies", {}),
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
if (policies := kwargs.get("policies")) is None:
|
| 151 |
+
policies = list(policy_definitions.keys())
|
| 152 |
+
|
| 153 |
+
# TODO(ryanmullins): Support images from PIL or URLs.
|
| 154 |
+
messages = []
|
| 155 |
+
expanded_images = []
|
| 156 |
+
for img in images:
|
| 157 |
+
for policy in policies:
|
| 158 |
+
messages.append(
|
| 159 |
+
[
|
| 160 |
+
{
|
| 161 |
+
"role": "user",
|
| 162 |
+
"content": [
|
| 163 |
+
{"type": "image"},
|
| 164 |
+
{"type": "text", "text": policy_definitions[policy]},
|
| 165 |
+
],
|
| 166 |
+
}
|
| 167 |
+
]
|
| 168 |
+
)
|
| 169 |
+
expanded_images.append([img])
|
| 170 |
+
|
| 171 |
+
text = self.apply_chat_template(messages, tokenize=False)
|
| 172 |
+
return super().__call__(images=expanded_images, text=text, **kwargs)
|
| 173 |
+
|
| 174 |
+
def batch_decode(self, *args, **kwargs):
|
| 175 |
+
"""
|
| 176 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 177 |
+
refer to the docstring of this method for more information.
|
| 178 |
+
"""
|
| 179 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 180 |
+
|
| 181 |
+
def decode(self, *args, **kwargs):
|
| 182 |
+
"""
|
| 183 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 184 |
+
the docstring of this method for more information.
|
| 185 |
+
"""
|
| 186 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 187 |
+
|
| 188 |
+
@property
|
| 189 |
+
def model_input_names(self):
|
| 190 |
+
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
| 191 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 192 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
__all__ = ["ShieldGemma2Processor"]
|
docs/transformers/build/lib/transformers/models/siglip/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_siglip import *
|
| 22 |
+
from .image_processing_siglip import *
|
| 23 |
+
from .image_processing_siglip_fast import *
|
| 24 |
+
from .modeling_siglip import *
|
| 25 |
+
from .processing_siglip import *
|
| 26 |
+
from .tokenization_siglip import *
|
| 27 |
+
else:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
_file = globals()["__file__"]
|
| 31 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/siglip/configuration_siglip.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Siglip model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SiglipTextConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
|
| 27 |
+
Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 28 |
+
configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
|
| 29 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 36 |
+
Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
|
| 37 |
+
the `inputs_ids` passed when calling [`SiglipModel`].
|
| 38 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 39 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 40 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 41 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 42 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 43 |
+
Number of hidden layers in the Transformer encoder.
|
| 44 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 45 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 46 |
+
max_position_embeddings (`int`, *optional*, defaults to 64):
|
| 47 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 48 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 49 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
| 50 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 51 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 52 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 53 |
+
The epsilon used by the layer normalization layers.
|
| 54 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 55 |
+
The dropout ratio for the attention probabilities.
|
| 56 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 57 |
+
The id of the padding token in the vocabulary.
|
| 58 |
+
bos_token_id (`int`, *optional*, defaults to 49406):
|
| 59 |
+
The id of the beginning-of-sequence token in the vocabulary.
|
| 60 |
+
eos_token_id (`int`, *optional*, defaults to 49407):
|
| 61 |
+
The id of the end-of-sequence token in the vocabulary.
|
| 62 |
+
projection_size (`int`, *optional*, defaults to `hidden_size`):
|
| 63 |
+
The size of the projection head.
|
| 64 |
+
|
| 65 |
+
Example:
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
>>> from transformers import SiglipTextConfig, SiglipTextModel
|
| 69 |
+
|
| 70 |
+
>>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
|
| 71 |
+
>>> configuration = SiglipTextConfig()
|
| 72 |
+
|
| 73 |
+
>>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 74 |
+
>>> model = SiglipTextModel(configuration)
|
| 75 |
+
|
| 76 |
+
>>> # Accessing the model configuration
|
| 77 |
+
>>> configuration = model.config
|
| 78 |
+
```"""
|
| 79 |
+
|
| 80 |
+
model_type = "siglip_text_model"
|
| 81 |
+
base_config_key = "text_config"
|
| 82 |
+
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
vocab_size=32000,
|
| 86 |
+
hidden_size=768,
|
| 87 |
+
intermediate_size=3072,
|
| 88 |
+
num_hidden_layers=12,
|
| 89 |
+
num_attention_heads=12,
|
| 90 |
+
max_position_embeddings=64,
|
| 91 |
+
hidden_act="gelu_pytorch_tanh",
|
| 92 |
+
layer_norm_eps=1e-6,
|
| 93 |
+
attention_dropout=0.0,
|
| 94 |
+
# This differs from `CLIPTokenizer`'s default and from openai/siglip
|
| 95 |
+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
| 96 |
+
pad_token_id=1,
|
| 97 |
+
bos_token_id=49406,
|
| 98 |
+
eos_token_id=49407,
|
| 99 |
+
projection_size=None,
|
| 100 |
+
**kwargs,
|
| 101 |
+
):
|
| 102 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 103 |
+
|
| 104 |
+
self.vocab_size = vocab_size
|
| 105 |
+
self.hidden_size = hidden_size
|
| 106 |
+
self.intermediate_size = intermediate_size
|
| 107 |
+
self.num_hidden_layers = num_hidden_layers
|
| 108 |
+
self.num_attention_heads = num_attention_heads
|
| 109 |
+
self.max_position_embeddings = max_position_embeddings
|
| 110 |
+
self.layer_norm_eps = layer_norm_eps
|
| 111 |
+
self.hidden_act = hidden_act
|
| 112 |
+
self.attention_dropout = attention_dropout
|
| 113 |
+
self.projection_size = projection_size if projection_size is not None else hidden_size
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class SiglipVisionConfig(PretrainedConfig):
|
| 117 |
+
r"""
|
| 118 |
+
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
|
| 119 |
+
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 120 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
|
| 121 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 122 |
+
|
| 123 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 124 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 128 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 129 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 130 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 131 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 132 |
+
Number of hidden layers in the Transformer encoder.
|
| 133 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 134 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 135 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 136 |
+
Number of channels in the input images.
|
| 137 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 138 |
+
The size (resolution) of each image.
|
| 139 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 140 |
+
The size (resolution) of each patch.
|
| 141 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
| 142 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 143 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 144 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 145 |
+
The epsilon used by the layer normalization layers.
|
| 146 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 147 |
+
The dropout ratio for the attention probabilities.
|
| 148 |
+
|
| 149 |
+
Example:
|
| 150 |
+
|
| 151 |
+
```python
|
| 152 |
+
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
|
| 153 |
+
|
| 154 |
+
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
|
| 155 |
+
>>> configuration = SiglipVisionConfig()
|
| 156 |
+
|
| 157 |
+
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 158 |
+
>>> model = SiglipVisionModel(configuration)
|
| 159 |
+
|
| 160 |
+
>>> # Accessing the model configuration
|
| 161 |
+
>>> configuration = model.config
|
| 162 |
+
```"""
|
| 163 |
+
|
| 164 |
+
model_type = "siglip_vision_model"
|
| 165 |
+
base_config_key = "vision_config"
|
| 166 |
+
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
hidden_size=768,
|
| 170 |
+
intermediate_size=3072,
|
| 171 |
+
num_hidden_layers=12,
|
| 172 |
+
num_attention_heads=12,
|
| 173 |
+
num_channels=3,
|
| 174 |
+
image_size=224,
|
| 175 |
+
patch_size=16,
|
| 176 |
+
hidden_act="gelu_pytorch_tanh",
|
| 177 |
+
layer_norm_eps=1e-6,
|
| 178 |
+
attention_dropout=0.0,
|
| 179 |
+
**kwargs,
|
| 180 |
+
):
|
| 181 |
+
super().__init__(**kwargs)
|
| 182 |
+
|
| 183 |
+
self.hidden_size = hidden_size
|
| 184 |
+
self.intermediate_size = intermediate_size
|
| 185 |
+
self.num_hidden_layers = num_hidden_layers
|
| 186 |
+
self.num_attention_heads = num_attention_heads
|
| 187 |
+
self.num_channels = num_channels
|
| 188 |
+
self.patch_size = patch_size
|
| 189 |
+
self.image_size = image_size
|
| 190 |
+
self.attention_dropout = attention_dropout
|
| 191 |
+
self.layer_norm_eps = layer_norm_eps
|
| 192 |
+
self.hidden_act = hidden_act
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class SiglipConfig(PretrainedConfig):
|
| 196 |
+
r"""
|
| 197 |
+
[`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
|
| 198 |
+
instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
|
| 199 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
|
| 200 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 201 |
+
|
| 202 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 203 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
text_config (`dict`, *optional*):
|
| 207 |
+
Dictionary of configuration options used to initialize [`SiglipTextConfig`].
|
| 208 |
+
vision_config (`dict`, *optional*):
|
| 209 |
+
Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
|
| 210 |
+
kwargs (*optional*):
|
| 211 |
+
Dictionary of keyword arguments.
|
| 212 |
+
|
| 213 |
+
Example:
|
| 214 |
+
|
| 215 |
+
```python
|
| 216 |
+
>>> from transformers import SiglipConfig, SiglipModel
|
| 217 |
+
|
| 218 |
+
>>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
|
| 219 |
+
>>> configuration = SiglipConfig()
|
| 220 |
+
|
| 221 |
+
>>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 222 |
+
>>> model = SiglipModel(configuration)
|
| 223 |
+
|
| 224 |
+
>>> # Accessing the model configuration
|
| 225 |
+
>>> configuration = model.config
|
| 226 |
+
|
| 227 |
+
>>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
|
| 228 |
+
>>> from transformers import SiglipTextConfig, SiglipVisionConfig
|
| 229 |
+
|
| 230 |
+
>>> # Initializing a SiglipText and SiglipVision configuration
|
| 231 |
+
>>> config_text = SiglipTextConfig()
|
| 232 |
+
>>> config_vision = SiglipVisionConfig()
|
| 233 |
+
|
| 234 |
+
>>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
|
| 235 |
+
```"""
|
| 236 |
+
|
| 237 |
+
model_type = "siglip"
|
| 238 |
+
sub_configs = {"text_config": SiglipTextConfig, "vision_config": SiglipVisionConfig}
|
| 239 |
+
|
| 240 |
+
def __init__(self, text_config=None, vision_config=None, **kwargs):
|
| 241 |
+
super().__init__(**kwargs)
|
| 242 |
+
|
| 243 |
+
if text_config is None:
|
| 244 |
+
text_config = {}
|
| 245 |
+
logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
|
| 246 |
+
|
| 247 |
+
if vision_config is None:
|
| 248 |
+
vision_config = {}
|
| 249 |
+
logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
|
| 250 |
+
|
| 251 |
+
self.text_config = SiglipTextConfig(**text_config)
|
| 252 |
+
self.vision_config = SiglipVisionConfig(**vision_config)
|
| 253 |
+
|
| 254 |
+
self.initializer_factor = 1.0
|
| 255 |
+
|
| 256 |
+
@classmethod
|
| 257 |
+
def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
|
| 258 |
+
r"""
|
| 259 |
+
Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
|
| 260 |
+
model configuration.
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
[`SiglipConfig`]: An instance of a configuration object
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
__all__ = ["SiglipConfig", "SiglipTextConfig", "SiglipVisionConfig"]
|
docs/transformers/build/lib/transformers/models/siglip/convert_siglip_to_hf.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert SigLIP checkpoints from the original repository.
|
| 16 |
+
|
| 17 |
+
URL: https://github.com/google-research/big_vision/tree/main
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import collections
|
| 22 |
+
import os
|
| 23 |
+
from typing import Tuple
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import requests
|
| 27 |
+
import torch
|
| 28 |
+
from huggingface_hub import hf_hub_download
|
| 29 |
+
from numpy import load
|
| 30 |
+
from PIL import Image
|
| 31 |
+
|
| 32 |
+
from transformers import (
|
| 33 |
+
GemmaTokenizerFast,
|
| 34 |
+
SiglipConfig,
|
| 35 |
+
SiglipImageProcessor,
|
| 36 |
+
SiglipModel,
|
| 37 |
+
SiglipProcessor,
|
| 38 |
+
SiglipTokenizer,
|
| 39 |
+
)
|
| 40 |
+
from transformers.utils import logging
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logging.set_verbosity_info()
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
MODEL_CONFIGS = {
|
| 48 |
+
"base": {
|
| 49 |
+
"hidden_size": 768,
|
| 50 |
+
"intermediate_size": 3072,
|
| 51 |
+
"num_hidden_layers": 12,
|
| 52 |
+
"num_attention_heads": 12,
|
| 53 |
+
},
|
| 54 |
+
"large": {
|
| 55 |
+
"hidden_size": 1024,
|
| 56 |
+
"intermediate_size": 4096,
|
| 57 |
+
"num_hidden_layers": 24,
|
| 58 |
+
"num_attention_heads": 16,
|
| 59 |
+
},
|
| 60 |
+
"giant-opt": {
|
| 61 |
+
"hidden_size": 1536,
|
| 62 |
+
"intermediate_size": 6144,
|
| 63 |
+
"num_hidden_layers": 40,
|
| 64 |
+
"num_attention_heads": 16,
|
| 65 |
+
},
|
| 66 |
+
"so400m": {
|
| 67 |
+
"hidden_size": 1152,
|
| 68 |
+
"intermediate_size": 4304,
|
| 69 |
+
"num_hidden_layers": 27,
|
| 70 |
+
"num_attention_heads": 16,
|
| 71 |
+
},
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
model_name_to_checkpoint = {
|
| 75 |
+
# base checkpoints
|
| 76 |
+
"siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
|
| 77 |
+
"siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
|
| 78 |
+
"siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
|
| 79 |
+
"siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
|
| 80 |
+
# large checkpoints
|
| 81 |
+
"siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
|
| 82 |
+
"siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
|
| 83 |
+
# multilingual checkpoint
|
| 84 |
+
"siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
|
| 85 |
+
# so400m checkpoints
|
| 86 |
+
"siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
|
| 87 |
+
# ----------------- v2 -----------------
|
| 88 |
+
# base checkpoints
|
| 89 |
+
"siglip2-base-patch32-256": "gv-hf/siglip2/siglip2_b32_256.npz",
|
| 90 |
+
"siglip2-base-patch16-224": "gv-hf/siglip2/siglip2_b16_224.npz",
|
| 91 |
+
"siglip2-base-patch16-256": "gv-hf/siglip2/siglip2_b16_256.npz",
|
| 92 |
+
"siglip2-base-patch16-384": "gv-hf/siglip2/siglip2_b16_384.npz",
|
| 93 |
+
"siglip2-base-patch16-512": "gv-hf/siglip2/siglip2_b16_512.npz",
|
| 94 |
+
# large checkpoints
|
| 95 |
+
"siglip2-large-patch16-256": "gv-hf/siglip2/siglip2_l16_256.npz",
|
| 96 |
+
"siglip2-large-patch16-384": "gv-hf/siglip2/siglip2_l16_384.npz",
|
| 97 |
+
"siglip2-large-patch16-512": "gv-hf/siglip2/siglip2_l16_512.npz",
|
| 98 |
+
# giant opt checkpoints
|
| 99 |
+
"siglip2-giant-opt-patch16-256": "gv-hf/siglip2/siglip2_g-opt16_256.npz",
|
| 100 |
+
"siglip2-giant-opt-patch16-384": "gv-hf/siglip2/siglip2_g-opt16_384.npz",
|
| 101 |
+
# so400m checkpoints
|
| 102 |
+
"siglip2-so400m-patch14-224": "gv-hf/siglip2/siglip2_so400m14_224.npz",
|
| 103 |
+
"siglip2-so400m-patch14-384": "gv-hf/siglip2/siglip2_so400m14_384.npz",
|
| 104 |
+
"siglip2-so400m-patch16-256": "gv-hf/siglip2/siglip2_so400m16_256.npz",
|
| 105 |
+
"siglip2-so400m-patch16-384": "gv-hf/siglip2/siglip2_so400m16_384.npz",
|
| 106 |
+
"siglip2-so400m-patch16-512": "gv-hf/siglip2/siglip2_so400m16_512.npz",
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
# ------------------------------------------------------------------------------------------------------
|
| 110 |
+
# CONFIG
|
| 111 |
+
# ------------------------------------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_image_size_from_model_name(model_name: str) -> int:
|
| 115 |
+
if "-i18n" not in model_name:
|
| 116 |
+
size = model_name.split("-")[-1]
|
| 117 |
+
else:
|
| 118 |
+
size = model_name.split("-")[-2]
|
| 119 |
+
return int(size)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_patch_size_from_model_name(model_name: str) -> int:
|
| 123 |
+
patch_str = [x for x in model_name.split("-") if "patch" in x][0]
|
| 124 |
+
return int(patch_str[-2:])
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_vocab_size_from_model_name(model_name: str) -> int:
|
| 128 |
+
if "siglip2" in model_name:
|
| 129 |
+
vocab_size = 256000
|
| 130 |
+
elif "-i18n" in model_name:
|
| 131 |
+
vocab_size = 250000
|
| 132 |
+
else:
|
| 133 |
+
vocab_size = 32000
|
| 134 |
+
return vocab_size
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_vocab_file_from_model_name(model_name: str) -> str:
|
| 138 |
+
# get vocab file
|
| 139 |
+
if "i18n" in model_name:
|
| 140 |
+
vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
|
| 141 |
+
else:
|
| 142 |
+
vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
|
| 143 |
+
return vocab_file
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_text_and_vision_vit_variants(model_name: str) -> Tuple[str, str]:
|
| 147 |
+
variant = model_name.split("-")[1] if "giant-opt" not in model_name else "giant-opt"
|
| 148 |
+
return {
|
| 149 |
+
"base": ("base", "base"),
|
| 150 |
+
"large": ("large", "large"),
|
| 151 |
+
"so400m": ("so400m", "so400m"),
|
| 152 |
+
# g-opt siglip2 is not symmetric
|
| 153 |
+
"giant-opt": ("so400m", "giant-opt"),
|
| 154 |
+
}[variant]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_siglip_config(model_name):
|
| 158 |
+
text_variant, vision_variant = get_text_and_vision_vit_variants(model_name)
|
| 159 |
+
text_config = MODEL_CONFIGS[text_variant].copy()
|
| 160 |
+
vision_config = MODEL_CONFIGS[vision_variant].copy()
|
| 161 |
+
|
| 162 |
+
text_config["vocab_size"] = get_vocab_size_from_model_name(model_name)
|
| 163 |
+
vision_config["image_size"] = get_image_size_from_model_name(model_name)
|
| 164 |
+
vision_config["patch_size"] = get_patch_size_from_model_name(model_name)
|
| 165 |
+
|
| 166 |
+
if text_config["hidden_size"] != vision_config["hidden_size"]:
|
| 167 |
+
text_config["projection_size"] = vision_config["hidden_size"]
|
| 168 |
+
|
| 169 |
+
return SiglipConfig(text_config=text_config, vision_config=vision_config)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ------------------------------------------------------------------------------------------------------
|
| 173 |
+
# PROCESSING
|
| 174 |
+
# ------------------------------------------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_tokenizer(model_name: str) -> GemmaTokenizerFast:
|
| 178 |
+
if "siglip2" in model_name:
|
| 179 |
+
tokenizer = GemmaTokenizerFast.from_pretrained(
|
| 180 |
+
"google/gemma-2-9b-it",
|
| 181 |
+
add_bos_token=False,
|
| 182 |
+
add_eos_token=True,
|
| 183 |
+
padding_side="right",
|
| 184 |
+
do_lower_case=True,
|
| 185 |
+
# important: make tokenizer NOT return attention_mask since original one doesn't require it
|
| 186 |
+
model_input_names=["input_ids"],
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
# for siglip v1
|
| 190 |
+
vocab_file = get_vocab_file_from_model_name(model_name)
|
| 191 |
+
# important: make tokenizer not return attention_mask since original one doesn't require it
|
| 192 |
+
tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
|
| 193 |
+
return tokenizer
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_image_processor(model_name: str) -> SiglipImageProcessor:
|
| 197 |
+
image_size = get_image_size_from_model_name(model_name)
|
| 198 |
+
size = {"height": image_size, "width": image_size}
|
| 199 |
+
if "siglip2" in model_name:
|
| 200 |
+
image_processor = SiglipImageProcessor(size=size, resample=2) # bilinear resampling
|
| 201 |
+
else:
|
| 202 |
+
image_processor = SiglipImageProcessor(size=size)
|
| 203 |
+
return image_processor
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# ------------------------------------------------------------------------------------------------------
|
| 207 |
+
# CONVERT FUNCTIONS
|
| 208 |
+
# ------------------------------------------------------------------------------------------------------
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def split_encoderblock_layers(state_dict: dict) -> dict:
|
| 212 |
+
"""
|
| 213 |
+
Split the encoderblock weight into layers. In some cases they are concatenated in
|
| 214 |
+
the original checkpoints.
|
| 215 |
+
"""
|
| 216 |
+
# Make shallow copy
|
| 217 |
+
state_dict = state_dict.copy()
|
| 218 |
+
# Split encoderblock weight into layers
|
| 219 |
+
keys = list(state_dict.keys())
|
| 220 |
+
for key in keys:
|
| 221 |
+
if "/encoderblock/" in key:
|
| 222 |
+
weight = state_dict.pop(key)
|
| 223 |
+
for i, weight_i in enumerate(weight):
|
| 224 |
+
new_name = key.replace("encoderblock", f"encoderblock_{i}")
|
| 225 |
+
state_dict[new_name] = weight_i
|
| 226 |
+
return state_dict
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def create_rename_keys(config):
|
| 230 |
+
rename_keys = []
|
| 231 |
+
# fmt: off
|
| 232 |
+
|
| 233 |
+
# vision encoder
|
| 234 |
+
|
| 235 |
+
rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
|
| 236 |
+
rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
|
| 237 |
+
rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
|
| 238 |
+
|
| 239 |
+
for i in range(config.vision_config.num_hidden_layers):
|
| 240 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
| 241 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
| 242 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
| 243 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
| 244 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
| 245 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
| 246 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
| 247 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
| 248 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
|
| 249 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
|
| 250 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
|
| 251 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
|
| 252 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
|
| 253 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
|
| 254 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
|
| 255 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
|
| 256 |
+
|
| 257 |
+
rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
|
| 258 |
+
rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
|
| 259 |
+
|
| 260 |
+
rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
|
| 261 |
+
rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
|
| 262 |
+
rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
|
| 263 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
|
| 264 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
|
| 265 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
|
| 266 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
|
| 267 |
+
rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
|
| 268 |
+
rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
|
| 269 |
+
|
| 270 |
+
# text encoder
|
| 271 |
+
|
| 272 |
+
rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
|
| 273 |
+
rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
|
| 274 |
+
|
| 275 |
+
for i in range(config.text_config.num_hidden_layers):
|
| 276 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
|
| 277 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
|
| 278 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
|
| 279 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
|
| 280 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
|
| 281 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
|
| 282 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
|
| 283 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
|
| 284 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
|
| 285 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
|
| 286 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
|
| 287 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
|
| 288 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
|
| 289 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
|
| 290 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
|
| 291 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
|
| 292 |
+
|
| 293 |
+
rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
|
| 294 |
+
rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
|
| 295 |
+
rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
|
| 296 |
+
rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
|
| 297 |
+
|
| 298 |
+
# learned temperature and bias
|
| 299 |
+
rename_keys.append(("params/t", "logit_scale"))
|
| 300 |
+
rename_keys.append(("params/b", "logit_bias"))
|
| 301 |
+
|
| 302 |
+
# fmt: on
|
| 303 |
+
return rename_keys
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def rename_key(dct, old, new, config):
|
| 307 |
+
val = dct.pop(old)
|
| 308 |
+
|
| 309 |
+
if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
|
| 310 |
+
val = val.reshape(-1, config.vision_config.hidden_size)
|
| 311 |
+
if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
|
| 312 |
+
val = val.reshape(-1, config.text_config.hidden_size)
|
| 313 |
+
|
| 314 |
+
if "patch_embedding.weight" in new:
|
| 315 |
+
val = val.transpose(3, 2, 0, 1)
|
| 316 |
+
elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
|
| 317 |
+
val = val.T
|
| 318 |
+
|
| 319 |
+
if "position_embedding" in new and "vision" in new:
|
| 320 |
+
val = val.reshape(-1, config.vision_config.hidden_size)
|
| 321 |
+
if "position_embedding" in new and "text" in new:
|
| 322 |
+
val = val.reshape(-1, config.text_config.hidden_size)
|
| 323 |
+
|
| 324 |
+
if new.endswith("bias"):
|
| 325 |
+
val = val.reshape(-1)
|
| 326 |
+
|
| 327 |
+
dct[new] = torch.from_numpy(val)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def read_in_q_k_v_head(state_dict, config):
|
| 331 |
+
# read in individual input projection layers
|
| 332 |
+
key_proj_weight = (
|
| 333 |
+
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
|
| 334 |
+
.reshape(-1, config.vision_config.hidden_size)
|
| 335 |
+
.T
|
| 336 |
+
)
|
| 337 |
+
key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
|
| 338 |
+
value_proj_weight = (
|
| 339 |
+
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
|
| 340 |
+
.reshape(-1, config.vision_config.hidden_size)
|
| 341 |
+
.T
|
| 342 |
+
)
|
| 343 |
+
value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
|
| 344 |
+
query_proj_weight = (
|
| 345 |
+
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
|
| 346 |
+
.reshape(-1, config.vision_config.hidden_size)
|
| 347 |
+
.T
|
| 348 |
+
)
|
| 349 |
+
query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
|
| 350 |
+
|
| 351 |
+
# next, add them to the state dict as a single matrix + vector
|
| 352 |
+
state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
|
| 353 |
+
np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
|
| 354 |
+
)
|
| 355 |
+
state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
|
| 356 |
+
np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# We will verify our results on an image of cute cats
|
| 361 |
+
def prepare_img():
|
| 362 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 363 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 364 |
+
return image
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def flatten_nested_dict(params, parent_key="", sep="/"):
|
| 368 |
+
items = []
|
| 369 |
+
|
| 370 |
+
for k, v in params.items():
|
| 371 |
+
new_key = parent_key + sep + k if parent_key else k
|
| 372 |
+
|
| 373 |
+
if isinstance(v, collections.abc.MutableMapping):
|
| 374 |
+
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
|
| 375 |
+
else:
|
| 376 |
+
items.append((new_key, v))
|
| 377 |
+
return dict(items)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@torch.no_grad()
|
| 381 |
+
def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
|
| 382 |
+
"""
|
| 383 |
+
Copy/paste/tweak model's weights to our SigLIP structure.
|
| 384 |
+
"""
|
| 385 |
+
|
| 386 |
+
# Define default SigLIP configuration
|
| 387 |
+
config = get_siglip_config(model_name)
|
| 388 |
+
|
| 389 |
+
# Get checkpoint
|
| 390 |
+
checkpoint = model_name_to_checkpoint[model_name]
|
| 391 |
+
if not os.path.exists(checkpoint):
|
| 392 |
+
org, repo_id, *filepath = checkpoint.split("/")
|
| 393 |
+
checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath))
|
| 394 |
+
|
| 395 |
+
# Load original state dict
|
| 396 |
+
data = load(checkpoint)
|
| 397 |
+
state_dict = flatten_nested_dict(data)
|
| 398 |
+
state_dict = split_encoderblock_layers(state_dict)
|
| 399 |
+
|
| 400 |
+
# Remove and rename some keys
|
| 401 |
+
rename_keys = create_rename_keys(config)
|
| 402 |
+
for src, dest in rename_keys:
|
| 403 |
+
rename_key(state_dict, src, dest, config)
|
| 404 |
+
|
| 405 |
+
# qkv matrices of attention pooling head need special treatment
|
| 406 |
+
read_in_q_k_v_head(state_dict, config)
|
| 407 |
+
|
| 408 |
+
# Load HuggingFace model
|
| 409 |
+
model = SiglipModel(config).eval()
|
| 410 |
+
model.load_state_dict(state_dict)
|
| 411 |
+
|
| 412 |
+
# Create processor
|
| 413 |
+
image_processor = get_image_processor(model_name)
|
| 414 |
+
tokenizer = get_tokenizer(model_name)
|
| 415 |
+
processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
| 416 |
+
|
| 417 |
+
# Verify forward pass on dummy images and texts
|
| 418 |
+
url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
|
| 419 |
+
image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
|
| 420 |
+
url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
|
| 421 |
+
image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
|
| 422 |
+
texts = ["an apple", "a picture of an apple"]
|
| 423 |
+
|
| 424 |
+
inputs = processor(images=[image_1, image_2], text=texts, padding="max_length", max_length=64, return_tensors="pt")
|
| 425 |
+
with torch.no_grad():
|
| 426 |
+
outputs = model(**inputs)
|
| 427 |
+
|
| 428 |
+
if verify_logits:
|
| 429 |
+
image_size = config.vision_config.image_size
|
| 430 |
+
|
| 431 |
+
# verify input_ids against original ones
|
| 432 |
+
if image_size == 224:
|
| 433 |
+
filename = "siglip_pixel_values.pt"
|
| 434 |
+
elif image_size == 256:
|
| 435 |
+
filename = "siglip_pixel_values_256.pt"
|
| 436 |
+
elif image_size == 384:
|
| 437 |
+
filename = "siglip_pixel_values_384.pt"
|
| 438 |
+
elif image_size == 512:
|
| 439 |
+
filename = "siglip_pixel_values_512.pt"
|
| 440 |
+
else:
|
| 441 |
+
raise ValueError("Image size not supported")
|
| 442 |
+
|
| 443 |
+
filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
|
| 444 |
+
original_pixel_values = torch.load(filepath, weights_only=True)
|
| 445 |
+
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
|
| 446 |
+
original_input_ids = torch.load(filepath, weights_only=True)
|
| 447 |
+
|
| 448 |
+
if "i18n" not in model_name:
|
| 449 |
+
assert inputs.input_ids.tolist() == original_input_ids.tolist()
|
| 450 |
+
|
| 451 |
+
print("Mean of original pixel values:", original_pixel_values.mean())
|
| 452 |
+
print("Mean of new pixel values:", inputs.pixel_values.mean())
|
| 453 |
+
|
| 454 |
+
# note: we're testing with original pixel values here since we don't have exact pixel values
|
| 455 |
+
with torch.no_grad():
|
| 456 |
+
outputs = model(input_ids=original_input_ids, pixel_values=original_pixel_values)
|
| 457 |
+
print(outputs.logits_per_image[:3, :3])
|
| 458 |
+
|
| 459 |
+
probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
|
| 460 |
+
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
|
| 461 |
+
print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
|
| 462 |
+
|
| 463 |
+
if model_name == "siglip-base-patch16-224":
|
| 464 |
+
expected_slice = torch.tensor(
|
| 465 |
+
[[-2.9621, -2.1672], [-0.2713, 0.2910]],
|
| 466 |
+
)
|
| 467 |
+
elif model_name == "siglip-base-patch16-256":
|
| 468 |
+
expected_slice = torch.tensor(
|
| 469 |
+
[[-3.1146, -1.9894], [-0.7312, 0.6387]],
|
| 470 |
+
)
|
| 471 |
+
elif model_name == "siglip-base-patch16-384":
|
| 472 |
+
expected_slice = torch.tensor(
|
| 473 |
+
[[-2.8098, -2.1891], [-0.4242, 0.4102]],
|
| 474 |
+
)
|
| 475 |
+
elif model_name == "siglip-base-patch16-512":
|
| 476 |
+
expected_slice = torch.tensor(
|
| 477 |
+
[[-2.7899, -2.2668], [-0.4295, -0.0735]],
|
| 478 |
+
)
|
| 479 |
+
elif model_name == "siglip-large-patch16-256":
|
| 480 |
+
expected_slice = torch.tensor(
|
| 481 |
+
[[-1.5827, -0.5801], [-0.9153, 0.1363]],
|
| 482 |
+
)
|
| 483 |
+
elif model_name == "siglip-large-patch16-384":
|
| 484 |
+
expected_slice = torch.tensor(
|
| 485 |
+
[[-2.1523, -0.2899], [-0.2959, 0.7884]],
|
| 486 |
+
)
|
| 487 |
+
elif model_name == "siglip-so400m-patch14-384":
|
| 488 |
+
expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
|
| 489 |
+
elif model_name == "siglip-base-patch16-256-i18n":
|
| 490 |
+
expected_slice = torch.tensor(
|
| 491 |
+
[[-0.9064, 0.1073], [-0.0299, 0.5304]],
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
|
| 495 |
+
print("Looks ok!")
|
| 496 |
+
|
| 497 |
+
if pytorch_dump_folder_path is not None:
|
| 498 |
+
pytorch_dump_folder_path = os.path.join(pytorch_dump_folder_path, model_name)
|
| 499 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
| 500 |
+
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
| 501 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 502 |
+
print(f"Saving processor to {pytorch_dump_folder_path}")
|
| 503 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 504 |
+
|
| 505 |
+
if push_to_hub:
|
| 506 |
+
model.push_to_hub(f"s0225/{model_name}", private=True)
|
| 507 |
+
processor.push_to_hub(f"s0225/{model_name}", private=True)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
if __name__ == "__main__":
|
| 511 |
+
parser = argparse.ArgumentParser()
|
| 512 |
+
# Required parameters
|
| 513 |
+
parser.add_argument(
|
| 514 |
+
"--model_name",
|
| 515 |
+
default="siglip-base-patch16-224",
|
| 516 |
+
type=str,
|
| 517 |
+
choices=model_name_to_checkpoint.keys(),
|
| 518 |
+
help="Name of the model you'd like to convert.",
|
| 519 |
+
)
|
| 520 |
+
parser.add_argument(
|
| 521 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 522 |
+
)
|
| 523 |
+
parser.add_argument(
|
| 524 |
+
"--verify_logits",
|
| 525 |
+
action="store_true",
|
| 526 |
+
help="Whether to verify logits against the original implementation.",
|
| 527 |
+
)
|
| 528 |
+
parser.add_argument(
|
| 529 |
+
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
args = parser.parse_args()
|
| 533 |
+
convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
|
docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for SigLIP."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 20 |
+
from ...image_transforms import (
|
| 21 |
+
convert_to_rgb,
|
| 22 |
+
resize,
|
| 23 |
+
to_channel_dimension_format,
|
| 24 |
+
)
|
| 25 |
+
from ...image_utils import (
|
| 26 |
+
IMAGENET_STANDARD_MEAN,
|
| 27 |
+
IMAGENET_STANDARD_STD,
|
| 28 |
+
ChannelDimension,
|
| 29 |
+
ImageInput,
|
| 30 |
+
PILImageResampling,
|
| 31 |
+
infer_channel_dimension_format,
|
| 32 |
+
is_scaled_image,
|
| 33 |
+
make_flat_list_of_images,
|
| 34 |
+
to_numpy_array,
|
| 35 |
+
valid_images,
|
| 36 |
+
validate_preprocess_arguments,
|
| 37 |
+
)
|
| 38 |
+
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if is_vision_available():
|
| 45 |
+
import PIL
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SiglipImageProcessor(BaseImageProcessor):
|
| 49 |
+
r"""
|
| 50 |
+
Constructs a SigLIP image processor.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 54 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
| 55 |
+
`do_resize` in the `preprocess` method.
|
| 56 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 57 |
+
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
| 58 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 59 |
+
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
| 60 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 61 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
| 62 |
+
the `preprocess` method.
|
| 63 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 64 |
+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
| 65 |
+
method.
|
| 66 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 67 |
+
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
| 68 |
+
`do_normalize` in the `preprocess` method.
|
| 69 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 70 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 71 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 72 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 73 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 74 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 75 |
+
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 76 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 77 |
+
Whether to convert the image to RGB.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
model_input_names = ["pixel_values"]
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
do_resize: bool = True,
|
| 85 |
+
size: Dict[str, int] = None,
|
| 86 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 87 |
+
do_rescale: bool = True,
|
| 88 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 89 |
+
do_normalize: bool = True,
|
| 90 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 91 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 92 |
+
do_convert_rgb: Optional[bool] = None,
|
| 93 |
+
**kwargs,
|
| 94 |
+
) -> None:
|
| 95 |
+
super().__init__(**kwargs)
|
| 96 |
+
size = size if size is not None else {"height": 224, "width": 224}
|
| 97 |
+
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 98 |
+
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 99 |
+
|
| 100 |
+
self.do_resize = do_resize
|
| 101 |
+
self.size = size
|
| 102 |
+
self.resample = resample
|
| 103 |
+
self.do_rescale = do_rescale
|
| 104 |
+
self.rescale_factor = rescale_factor
|
| 105 |
+
self.do_normalize = do_normalize
|
| 106 |
+
self.image_mean = image_mean
|
| 107 |
+
self.image_std = image_std
|
| 108 |
+
self.do_convert_rgb = do_convert_rgb
|
| 109 |
+
|
| 110 |
+
@filter_out_non_signature_kwargs()
|
| 111 |
+
def preprocess(
|
| 112 |
+
self,
|
| 113 |
+
images: ImageInput,
|
| 114 |
+
do_resize: Optional[bool] = None,
|
| 115 |
+
size: Dict[str, int] = None,
|
| 116 |
+
resample: PILImageResampling = None,
|
| 117 |
+
do_rescale: Optional[bool] = None,
|
| 118 |
+
rescale_factor: Optional[float] = None,
|
| 119 |
+
do_normalize: Optional[bool] = None,
|
| 120 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 121 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 122 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 123 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 124 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 125 |
+
do_convert_rgb: Optional[bool] = None,
|
| 126 |
+
) -> PIL.Image.Image:
|
| 127 |
+
"""
|
| 128 |
+
Preprocess an image or batch of images.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
images (`ImageInput`):
|
| 132 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 133 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 134 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 135 |
+
Whether to resize the image.
|
| 136 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 137 |
+
Size of the image after resizing.
|
| 138 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 139 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 140 |
+
has an effect if `do_resize` is set to `True`.
|
| 141 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 142 |
+
Whether to rescale the image.
|
| 143 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 144 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 145 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 146 |
+
Whether to normalize the image.
|
| 147 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 148 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 149 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 150 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 151 |
+
`True`.
|
| 152 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 153 |
+
The type of tensors to return. Can be one of:
|
| 154 |
+
- Unset: Return a list of `np.ndarray`.
|
| 155 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 156 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 157 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 158 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 159 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 160 |
+
The channel dimension format for the output image. Can be one of:
|
| 161 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 162 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 163 |
+
- Unset: Use the channel dimension format of the input image.
|
| 164 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 165 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 166 |
+
from the input image. Can be one of:
|
| 167 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 168 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 169 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 170 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 171 |
+
Whether to convert the image to RGB.
|
| 172 |
+
"""
|
| 173 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 174 |
+
size = size if size is not None else self.size
|
| 175 |
+
size = get_size_dict(size, param_name="size", default_to_square=False)
|
| 176 |
+
resample = resample if resample is not None else self.resample
|
| 177 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 178 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 179 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 180 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 181 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 182 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 183 |
+
|
| 184 |
+
images = make_flat_list_of_images(images)
|
| 185 |
+
|
| 186 |
+
if not valid_images(images):
|
| 187 |
+
raise ValueError(
|
| 188 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 189 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 190 |
+
)
|
| 191 |
+
validate_preprocess_arguments(
|
| 192 |
+
do_rescale=do_rescale,
|
| 193 |
+
rescale_factor=rescale_factor,
|
| 194 |
+
do_normalize=do_normalize,
|
| 195 |
+
image_mean=image_mean,
|
| 196 |
+
image_std=image_std,
|
| 197 |
+
do_resize=do_resize,
|
| 198 |
+
size=size,
|
| 199 |
+
resample=resample,
|
| 200 |
+
)
|
| 201 |
+
if do_convert_rgb:
|
| 202 |
+
images = [convert_to_rgb(image) for image in images]
|
| 203 |
+
|
| 204 |
+
# All transformations expect numpy arrays.
|
| 205 |
+
images = [to_numpy_array(image) for image in images]
|
| 206 |
+
|
| 207 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 208 |
+
logger.warning_once(
|
| 209 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 210 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
if input_data_format is None:
|
| 214 |
+
# We assume that all images have the same channel dimension format.
|
| 215 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 216 |
+
|
| 217 |
+
if do_resize:
|
| 218 |
+
height, width = size["height"], size["width"]
|
| 219 |
+
images = [
|
| 220 |
+
resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
|
| 221 |
+
for image in images
|
| 222 |
+
]
|
| 223 |
+
|
| 224 |
+
if do_rescale:
|
| 225 |
+
images = [
|
| 226 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 227 |
+
for image in images
|
| 228 |
+
]
|
| 229 |
+
|
| 230 |
+
if do_normalize:
|
| 231 |
+
images = [
|
| 232 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 233 |
+
for image in images
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
images = [
|
| 237 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
data = {"pixel_values": images}
|
| 241 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
__all__ = ["SiglipImageProcessor"]
|
docs/transformers/build/lib/transformers/models/siglip/image_processing_siglip_fast.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fast Image processor class for SigLIP."""
|
| 16 |
+
|
| 17 |
+
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
|
| 18 |
+
from ...image_utils import (
|
| 19 |
+
IMAGENET_STANDARD_MEAN,
|
| 20 |
+
IMAGENET_STANDARD_STD,
|
| 21 |
+
PILImageResampling,
|
| 22 |
+
)
|
| 23 |
+
from ...utils import add_start_docstrings
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@add_start_docstrings(
|
| 27 |
+
"Constructs a fast SigLIP image processor.",
|
| 28 |
+
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
| 29 |
+
)
|
| 30 |
+
class SiglipImageProcessorFast(BaseImageProcessorFast):
|
| 31 |
+
resample = PILImageResampling.BICUBIC
|
| 32 |
+
image_mean = IMAGENET_STANDARD_MEAN
|
| 33 |
+
image_std = IMAGENET_STANDARD_STD
|
| 34 |
+
size = {"height": 224, "width": 224}
|
| 35 |
+
default_to_square = False
|
| 36 |
+
do_resize = True
|
| 37 |
+
do_rescale = True
|
| 38 |
+
do_normalize = True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
__all__ = ["SiglipImageProcessorFast"]
|