import os

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
import gradio as gr
import spaces  # Hugging Face helper that provides the @spaces.GPU decorator
# Device setup: fp16 on GPU, fall back to fp32 on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# Whisper setup
whisper_model_id = "openai/whisper-large-v3"
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True,
)
whisper_model.to(device)
whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model=whisper_model,
    tokenizer=whisper_processor.tokenizer,
    feature_extractor=whisper_processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=device,
)
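# chunk_length_s=30 enables chunked long-form transcription: audio longer than
# Whisper's 30 s window is split into chunks, transcribed in batches of 16, and
# the overlapping transcripts are merged back into a single result.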
# Llama 3 8B setup (Llama 3 ships a fast tokenizer, so the Auto classes are
# required; LlamaTokenizer's SentencePiece loader cannot read it)
llama_model_id = "meta-llama/Meta-Llama-3-8B"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
llama_model = AutoModelForCausalLM.from_pretrained(llama_model_id, torch_dtype=torch_dtype).to(device)
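# Note: meta-llama/Meta-Llama-3-8B is a gated checkpoint; access requires
# accepting Meta's license on the Hugging Face Hub and an authenticated token
# (e.g. `huggingface-cli login` or an HF_TOKEN secret on Spaces). It is also a
# base (non-instruct) model, so it continues the transcript rather than
# answering it conversationally.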
# OpenVoice setup (base-speaker TTS plus tone-color converter)
ckpt_base = 'checkpoints/base_speakers/EN'
ckpt_converter = 'checkpoints/converter'
output_dir = 'outputs'

base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

os.makedirs(output_dir, exist_ok=True)
source_se = torch.load(f'{ckpt_base}/en_default_se.pth', map_location=device)
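# source_se is the tone-color (speaker) embedding of the bundled English base
# speaker; at conversion time it is mapped onto target_se, the embedding
# extracted from the reference recording.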
@spaces.GPU  # on a ZeroGPU Space, requests GPU time for the duration of this call
def process_audio(input_audio):
    # ASR with Whisper
    whisper_result = whisper_pipe(input_audio)["text"]

    # Text generation with Llama 3
    inputs = llama_tokenizer(whisper_result, return_tensors="pt").to(device)
    outputs = llama_model.generate(**inputs, max_new_tokens=50)
    generated_text = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # TTS with OpenVoice: synthesize with the base speaker, then convert the
    # tone color to match the reference recording
    reference_speaker = 'resources/example_reference.mp3'
    target_se, _ = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
    save_path = f'{output_dir}/output_en_default.wav'
    src_path = f'{output_dir}/tmp.wav'
    base_speaker_tts.tts(generated_text, src_path, speaker='default', language='English', speed=1.0)
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MyShell"
    )
    return save_path
def real_time_processing(input_audio):
    return process_audio(input_audio)
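# Usage sketch (the .wav path below is hypothetical; any local recording works):
#   result_path = process_audio("resources/sample_question.wav")
#   # -> 'outputs/output_en_default.wav' containing the spoken Llama reply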
# Gradio interface
iface = gr.Interface(
    fn=real_time_processing,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=gr.Audio(type="filepath"),
    live=True,
    title="ASR + Text-to-Text + TTS with Whisper, Llama 3 8B, and OpenVoiceV2",
    description="Real-time processing using Whisper for ASR, Llama 3 8B for text generation, and OpenVoiceV2 for TTS.",
)
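# Note: live=True re-runs the whole pipeline whenever a new recording arrives;
# with three large models in the chain, expect several seconds of latency per
# utterance rather than true real-time turnaround.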
if __name__ == "__main__":
    iface.launch()