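"""AsianMOM: a webcam Gradio app that captions what the camera sees with
SmolVLM, turns the caption into a nagging "Asian mom" roast with
Llama-3.2-1B-Instruct, and speaks the roast aloud with Parler-TTS Mini
(Expresso).

Pipeline per frame: webcam image -> caption -> roast text -> speech audio.
"""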
import os
import re
import time
from string import punctuation

import gradio as gr
import nltk
import numpy as np
import pkg_resources
import torch
from PIL import Image
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    set_seed,
)
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
from parler_tts import ParlerTTSForConditionalGeneration

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"
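# Environment notes (assumptions, not enforced in code): parler_tts is
# typically installed from source
# (pip install git+https://github.com/huggingface/parler-tts.git), and
# meta-llama/Llama-3.2-1B-Instruct is a gated model, so HF_TOKEN must hold an
# access token that has been granted access to it.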
def initialize_vision_model():
    """Load the SmolVLM vision-language model used to caption webcam frames."""
    model_id = "HuggingFaceTB/SmolVLM-500M-Instruct"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        # Half precision on GPU to save memory; full precision on CPU.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return {
        "processor": processor,
        "model": model,
        "device": device,
    }
def analyze_image(image, vision_components, instruction="What do you see?"):
    processor = vision_components["processor"]
    model = vision_components["model"]
    device = vision_components["device"]
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    try:
        # Prepare chat template
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instruction},
                ],
            }
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        # Pass text/images as keyword arguments: the SmolVLM processor's
        # positional order is (images, text), so a positional call would
        # swap the two.
        inputs = processor(
            text=text, images=[image], return_tensors="pt", do_image_splitting=False
        ).to(device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=100)
        # Decode only the tokens generated after the prompt.
        output = processor.batch_decode(
            generated_ids[:, inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )
        return output[0].strip() if output else ""
    except Exception as e:
        print(f"Error in analyze_image: {str(e)}")
        return ""
def initialize_llm():
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    hf_token = os.environ.get("HF_TOKEN")
    # Load and patch config: older transformers releases only understand
    # rope_scaling dicts of the form {"type", "factor"}, so reduce the
    # Llama 3.2 rope_scaling entry to that shape if present.
    config = AutoConfig.from_pretrained(model_id, token=hf_token)
    if hasattr(config, "rope_scaling"):
        rope_scaling = config.rope_scaling
        if isinstance(rope_scaling, dict):
            config.rope_scaling = {
                "type": rope_scaling.get("type", "linear"),
                "factor": rope_scaling.get("factor", 1.0),
            }
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
    )
    return {
        "model": model,
        "tokenizer": tokenizer,
    }
def generate_roast(caption, llm_components):
    # create_app() may supply a fallback dict that only carries a
    # "generate_roast" callable (used when the real LLM failed to load).
    if callable(llm_components.get("generate_roast")):
        return llm_components["generate_roast"](caption, None)
    model = llm_components["model"]
    tokenizer = llm_components["tokenizer"]
    prompt = (
        "You are AsianMOM, a stereotypical Asian mother who always has high expectations. "
        f'You just observed your child doing this: "{caption}"\n\n'
        "Respond with a short, humorous roast (maximum 2-3 sentences) in the style of a "
        "stereotypical Asian mother. Include at least one of these elements:\n"
        "- Comparison to more successful relatives/cousins\n"
        "- High expectations about academic success\n"
        "- Mild threats about using slippers\n"
        "- Questioning life choices\n"
        "- Asking when they'll get married or have kids\n"
        "- Commenting on appearance\n"
        '- Saying "back in my day" and describing hardship\n\n'
        "Be funny but not hurtful. Keep it brief."
    )
    try:
        # Llama 3.2 Instruct uses its own chat template, not Llama-2 [INST] tags.
        input_ids = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=120,  # cap the reply length, not the whole sequence
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        # Decode only the tokens generated after the prompt.
        return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error in generate_roast: {str(e)}")
        return ""  # Return empty string on error
# Parler-TTS setup
def setup_tts():
    try:
        parler_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        parler_repo_id = "parler-tts/parler-tts-mini-expresso"
        parler_model = ParlerTTSForConditionalGeneration.from_pretrained(parler_repo_id).to(parler_device)
        parler_tokenizer = AutoTokenizer.from_pretrained(parler_repo_id)
        parler_feature_extractor = AutoFeatureExtractor.from_pretrained(parler_repo_id)
        parler_sample_rate = parler_feature_extractor.sampling_rate
        parler_seed = 42
        parler_number_normalizer = EnglishNumberNormalizer()
        return {
            "model": parler_model,
            "tokenizer": parler_tokenizer,
            "feature_extractor": parler_feature_extractor,
            "sample_rate": parler_sample_rate,
            "seed": parler_seed,
            "number_normalizer": parler_number_normalizer,
            "device": parler_device,
        }
    except Exception as e:
        print(f"Error setting up TTS: {str(e)}")
        return None
def parler_preprocess(text, number_normalizer):
    text = number_normalizer(text).strip()
    if text and text[-1] not in punctuation:
        text = f"{text}."
    abbreviations_pattern = r"\b[A-Z][A-Z\.]+\b"

    def separate_abb(chunk):
        chunk = chunk.replace(".", "")
        return " ".join(chunk)

    abbreviations = re.findall(abbreviations_pattern, text)
    for abv in abbreviations:
        if abv in text:
            text = text.replace(abv, separate_abb(abv))
    return text
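# Example: parler_preprocess("I scored 100 on the SAT", EnglishNumberNormalizer())
# yields roughly "I scored one hundred on the S A T." -- numbers are written out
# and abbreviations are spaced so the TTS model spells them letter by letter.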
def text_to_speech(text, tts_components):
    if tts_components is None:
        return (16000, np.zeros(1))  # Default sample rate if components failed to load
    model = tts_components["model"]
    tokenizer = tts_components["tokenizer"]
    device = tts_components["device"]
    sample_rate = tts_components["sample_rate"]
    seed = tts_components["seed"]
    number_normalizer = tts_components["number_normalizer"]
    description = (
        "Elisabeth speaks in a mature, strict, nagging, and slightly disappointed tone, "
        "with a hint of love and high expectations, at a moderate pace with high quality audio. "
        "She sounds like a stereotypical Asian mother who compares you to your cousins, "
        "questions your life choices, and threatens you with a slipper, but ultimately wants the best for you."
    )
    if not text or not isinstance(text, str):
        return (sample_rate, np.zeros(1))
    try:
        inputs = tokenizer(description, return_tensors="pt").to(device)
        prompt = tokenizer(parler_preprocess(text, number_normalizer), return_tensors="pt").to(device)
        set_seed(seed)
        generation = model.generate(input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids)
        audio_arr = generation.cpu().numpy().squeeze()
        return (sample_rate, audio_arr)
    except Exception as e:
        print(f"Error in text_to_speech: {str(e)}")
        return (sample_rate, np.zeros(1))
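# Gradio's gr.Audio component accepts a (sample_rate, numpy_array) tuple as a
# value, which is why every path above and below returns audio in that shape.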
def process_frame(image, vision_components, llm_components, tts_components):
    # Resolve the sample rate before the try block so the except handler
    # can always reference it.
    default_sample_rate = 16000
    if tts_components is not None:
        default_sample_rate = tts_components["sample_rate"]
    try:
        caption = analyze_image(image, vision_components)
        roast = generate_roast(caption, llm_components)
        if not roast or not isinstance(roast, str):
            audio = (default_sample_rate, np.zeros(1))
        else:
            audio = text_to_speech(roast, tts_components)
        return caption, roast, audio
    except Exception as e:
        print(f"Error in process_frame: {str(e)}")
        return "", "", (default_sample_rate, np.zeros(1))
def create_app():
    try:
        # Initialize components before creating the app
        vision_components = initialize_vision_model()
        tts_components = setup_tts()
        # Try to initialize LLM with Hugging Face token
        hf_token = os.environ.get("HF_TOKEN")
        llm_components = None
        if hf_token:
            try:
                llm_components = initialize_llm()
            except Exception as e:
                print(f"Error initializing LLM: {str(e)}. Will use fallback.")
        # Fallback if LLM initialization failed: a canned response that
        # generate_roast() knows to dispatch to.
        if llm_components is None:
            def fallback_generate_roast(caption, _):
                return f"I see you {caption}. Why you not doctor yet? Your cousin studying at Harvard!"
            llm_components = {"generate_roast": fallback_generate_roast}
        # Set initial values and processing parameters
        last_process_time = time.time() - 10
        processing_interval = 5  # seconds between processed frames
        with gr.Blocks(theme=gr.themes.Monochrome()) as app:
            gr.Markdown("# AsianMOM: Artificial Surveillance and Interactive Analysis with a Nagging Maternal Oversight Model")
            gr.Markdown("### Camera captures what you're doing and your Asian mom responds appropriately")
            with gr.Row():
                with gr.Column():
                    video_feed = gr.Image(sources=["webcam"], streaming=True, label="Camera Feed")
                with gr.Column():
                    analysis_output = gr.Textbox(label="What AsianMOM Sees", lines=2)
                    roast_output = gr.Textbox(label="AsianMOM's Thoughts", lines=4)
                    audio_output = gr.Audio(label="AsianMOM Says", autoplay=True)
            # Define processing function
            def process_webcam(image):
                nonlocal last_process_time
                current_time = time.time()
                default_caption = ""
                default_roast = ""
                default_sample_rate = 16000
                if tts_components is not None:
                    default_sample_rate = tts_components["sample_rate"]
                default_audio = (default_sample_rate, np.zeros(1))
                if current_time - last_process_time >= processing_interval and image is not None:
                    last_process_time = current_time
                    try:
                        caption, roast, audio = process_frame(
                            image,
                            vision_components,
                            llm_components,
                            tts_components,
                        )
                        final_caption = caption if isinstance(caption, str) else default_caption
                        final_roast = roast if isinstance(roast, str) else default_roast
                        final_audio = (
                            audio
                            if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[1], np.ndarray)
                            else default_audio
                        )
                        return image, final_caption, final_roast, final_audio
                    except Exception as e:
                        print(f"Error in process_webcam: {str(e)}")
                        return image, default_caption, default_roast, default_audio
                # Between processing intervals, still return a full 4-tuple so
                # Gradio receives a value for every output component.
                return image, default_caption, default_roast, default_audio
            # Setup the processing chain
            video_feed.change(
                process_webcam,
                inputs=[video_feed],
                outputs=[video_feed, analysis_output, roast_output, audio_output],
            )
        return app
    except Exception as e:
        print(f"Error creating app: {str(e)}")
        # Create a fallback simple app that reports the error
        with gr.Blocks() as fallback_app:
            gr.Markdown("# AsianMOM: Error Initializing")
            gr.Markdown(f"Error: {str(e)}")
            gr.Markdown("Please check your environment setup and try again.")
        return fallback_app
if __name__ == "__main__":
    try:
        # Check and report Gradio version
        gradio_version = pkg_resources.get_distribution("gradio").version
        print(f"Using Gradio version: {gradio_version}")
        # Try to download required resources
        try:
            os.system("python -m unidic download")
            nltk.download("averaged_perceptron_tagger_eng")
        except Exception as e:
            print(f"Warning: Could not download some resources: {str(e)}")
        # Create the app
        app = create_app()
        # Try multiple launch configurations if needed
        try:
            # First attempt with share=True
            print("Launching app with share=True and debug=True")
            app.launch(share=True, debug=True)
        except ValueError as e:
            if "localhost is not accessible" in str(e):
                # Second attempt with server_name to bind to all interfaces
                print("Retrying with server_name='0.0.0.0'")
                app.launch(share=True, debug=True, server_name="0.0.0.0")
            else:
                raise
    except Exception as e:
        print(f"Fatal error: {str(e)}")
        # If all else fails, create a minimal app
        with gr.Blocks() as minimal_app:
            gr.Markdown("# AsianMOM: Fatal Error")
            gr.Markdown(f"Fatal error: {str(e)}")
            gr.Markdown("Try updating Gradio with: pip install --upgrade gradio")
        try:
            minimal_app.launch(share=True, server_name="0.0.0.0")
        except Exception:
            minimal_app.launch(share=True)  # Last attempt with minimal options