Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from dataclasses import dataclass | |
| import os | |
| from supabase import create_client, Client | |
| from supabase.client import ClientOptions | |
| from enum import Enum | |
| from datasets import get_dataset_infos | |
| from transformers import AutoConfig | |
class GenerationStatus(Enum):
    """Lifecycle state of a generation request as stored in the database."""

    PENDING = "PENDING"      # submitted, waiting for a worker
    RUNNING = "RUNNING"      # generation in progress
    COMPLETED = "COMPLETED"  # output dataset written
    FAILED = "FAILED"        # generation aborted with an error
# Hard limits enforced by validate_request on incoming requests.
MAX_SAMPLES = 10000  # max number of samples in the input dataset
MAX_TOKENS = 32768  # max generation tokens per request
# 20 billion parameters (for now).
# NOTE(review): not referenced anywhere in this file — presumably checked
# elsewhere or planned; confirm before removing.
MAX_MODEL_PARAMS = 20_000_000_000
@dataclass
class GenerationRequest:
    """A single synthetic-data generation job submitted through the UI.

    Bug fix: the original class declared annotated fields but was never
    decorated with ``@dataclass`` (despite the import), so constructing it
    with keyword arguments — as the submit handler does — raised TypeError.
    It also lacked the ``output_dataset_name`` field the handler passes.
    """

    id: str  # assigned by the database; empty string until inserted
    status: GenerationStatus
    input_dataset_name: str
    input_dataset_config: str
    input_dataset_split: str
    prompt_column: str  # column of the input dataset holding the prompts
    model_name_or_path: str
    model_revision: str
    model_token: str | None  # HF token for gated/private models
    system_prompt: str | None
    max_tokens: int
    temperature: float
    top_k: int
    top_p: float
    input_dataset_token: str | None  # None for public input datasets
    output_dataset_token: str  # must have write access to the output repo
    username: str
    email: str  # notified here when the request completes
    # Destination dataset repo for the generated data. Defaulted so callers
    # that predate this field keep working.
    output_dataset_name: str = ""
def validate_request(request: GenerationRequest) -> None:
    """Validate a generation request before it is queued.

    Checks, in order: the input dataset exists and is readable with the
    provided token; the requested split exists and is within the sample
    limit; the prompt column exists; the model exists and its context
    window covers the requested ``max_tokens``; sampling parameters are in
    range; the email address looks plausible.

    Raises:
        Exception: describing the first failed check.
    """
    # - input dataset exists and can be accessed with the provided token
    try:
        input_dataset_info = get_dataset_infos(
            request.input_dataset_name, token=request.input_dataset_token
        )[request.input_dataset_config]
    except Exception as e:
        raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.") from e
    # check that the input dataset split exists
    if request.input_dataset_split not in input_dataset_info.splits:
        raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")
    # check that the number of samples is less than MAX_SAMPLES
    # Bug fix: SplitInfo exposes the count as `num_examples`; the original
    # read the nonexistent `num_samples`, raising AttributeError.
    if input_dataset_info.splits[request.input_dataset_split].num_examples > MAX_SAMPLES:
        raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")
    # check the prompt column exists in the dataset
    if request.prompt_column not in input_dataset_info.features:
        raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(input_dataset_info.features.keys())}")
    # check the models exists
    try:
        model_config = AutoConfig.from_pretrained(
            request.model_name_or_path,
            revision=request.model_revision,
            token=request.model_token,
        )
    except Exception as e:
        raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.") from e
    # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
    # Not every architecture's config defines max_position_embeddings, so
    # skip the context-window check when the attribute is absent.
    max_pos = getattr(model_config, "max_position_embeddings", None)
    if max_pos is not None and max_pos < request.max_tokens:
        raise Exception(f"Model {request.model_name_or_path} max position embeddings {max_pos} is less than the requested max tokens {request.max_tokens}.")
    if request.max_tokens > MAX_TOKENS:
        raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.")
    # check sampling parameters are valid
    if request.temperature < 0.0 or request.temperature > 2.0:
        raise Exception("Temperature must be between 0.0 and 2.0")
    if request.top_k < 1 or request.top_k > 100:
        raise Exception("Top K must be between 1 and 100")
    if request.top_p < 0.0 or request.top_p > 1.0:
        raise Exception("Top P must be between 0.0 and 1.0")
    # check valid email address TODO: use py3-validate-email https://stackoverflow.com/questions/8022530/how-to-check-for-valid-email-address
    if "@" not in request.email or "." not in request.email.split("@")[-1]:
        raise Exception("Invalid email address")
def add_request_to_db(request: GenerationRequest):
    """Insert a validated request into the Supabase ``generation-requests`` table.

    Reads the Supabase credentials from the SUPABASE_URL / SUPABASE_KEY
    environment variables.

    Returns:
        The inserted row(s) as returned by Supabase.

    Raises:
        Exception: if credentials are missing or the insert returns no rows
            (supabase-py raises its own APIError on HTTP failures).
    """
    url = os.getenv("SUPABASE_URL")
    key = os.getenv("SUPABASE_KEY")
    # Fail fast with a clear message instead of letting create_client choke
    # on None credentials.
    if not url or not key:
        raise Exception("SUPABASE_URL and SUPABASE_KEY environment variables must be set.")
    # Bug fix: ClientOptions is a class, not a mapping — the original passed
    # a plain dict, which create_client does not accept.
    options = ClientOptions(schema="public")
    supabase: Client = create_client(url, key, options=options)
    # `id` is intentionally omitted: the database generates it.
    # NOTE(review): output_dataset_name is not persisted here — presumably
    # the table has no such column yet; confirm against the DB schema.
    data = {
        "status": request.status.value,
        "input_dataset_name": request.input_dataset_name,
        "input_dataset_config": request.input_dataset_config,
        "input_dataset_split": request.input_dataset_split,
        "prompt_column": request.prompt_column,
        "model_name_or_path": request.model_name_or_path,
        "model_revision": request.model_revision,
        "model_token": request.model_token,
        "system_prompt": request.system_prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "input_dataset_token": request.input_dataset_token,
        "output_dataset_token": request.output_dataset_token,
        "username": request.username,
        "email": request.email
    }
    response = supabase.table("generation-requests").insert(data).execute()
    # Bug fix: execute() returns an APIResponse with no `status_code`
    # attribute (HTTP errors raise APIError instead), so the original check
    # always crashed. Treat an empty result set as failure.
    if not response.data:
        raise Exception(f"Failed to add request to database: {response.data}")
    return response.data
def create_gradio_interface():
    """Build the Gradio Blocks UI for submitting generation requests.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    with gr.Blocks(title="Synthetic Data Generation") as interface:
        with gr.Group():
            with gr.Row():
                gr.Markdown("# Synthetic Data Generation Request")
            with gr.Row():
                gr.Markdown("""
                Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models.
                Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n
                """)
        with gr.Group():
            with gr.Row():
                gr.Markdown("""
                **How it works:**
                1. Provide an input dataset with prompts
                2. Select a language model for generation
                3. Configure generation parameters
                4. Submit your request and receive generated data
                """)
                gr.Markdown("""
                **Requirements:**
                - Input dataset must be publicly accessible or you must provide a valid HuggingFace token
                - Output dataset repository must exist and you must have write access
                - Model must be accessible (public or with valid token)
                - Maximum 10,000 samples per dataset
                - Maximum of 32k generation tokens
                **Note:** Generation requests are processed asynchronously. You will be notified via email when your request is complete.
                """)
        with gr.Row():
            with gr.Group():
                gr.Markdown("## Dataset information")
                with gr.Column():
                    with gr.Row():
                        input_dataset_name = gr.Textbox(label="Input Dataset Name", placeholder="e.g., simplescaling/s1K-1.1")
                        input_dataset_split = gr.Textbox(label="Input Dataset Split", value="train", placeholder="e.g., train, test, validation")
                        input_dataset_config = gr.Textbox(label="Input Dataset Config", value="default", placeholder="e.g., default, custom")
                        prompt_column = gr.Textbox(label="Prompt Column", placeholder="e.g., text, prompt, question")
                    with gr.Column():
                        output_dataset_name = gr.Textbox(label="Output Dataset Name", placeholder="e.g., MyOrg/my-generated-dataset")
        with gr.Group():
            gr.Markdown("## Model information")
            with gr.Column():
                with gr.Row():
                    model_name_or_path = gr.Textbox(label="Model Name or Path", placeholder="e.g., Qwen/Qwen3-4B-Instruct-2507")
                    model_revision = gr.Textbox(label="Model Revision", value="main", placeholder="e.g., main, v1.0")
                    model_token = gr.Textbox(label="Model Token (Optional)", type="password", placeholder="Your HF token with read/write access to the model...")
        with gr.Group():
            gr.Markdown("## Generation Parameters")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        max_tokens = gr.Slider(label="Max Tokens", value=512, minimum=256, maximum=MAX_TOKENS, step=256)
                        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
                    with gr.Row():
                        top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
                        top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.1)
                with gr.Column():
                    system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")
        with gr.Group():
            gr.Markdown("## User Information, for tokens refer to guide [here](https://huggingface.co/docs/hub/en/security-tokens#user-access-tokens)")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        username = gr.Textbox(label="Hugging Face Username", placeholder="Your HF username")
                        email = gr.Textbox(label="Email", placeholder="[email protected]")
                    with gr.Row():
                        input_dataset_token = gr.Textbox(label="Input dataset token", type="password", placeholder="Your HF token with read access to the input dataset, leave blank if public dataset")
                        output_dataset_token = gr.Textbox(label="Output dataset token", type="password", placeholder="Your HF token with write access to the output dataset")
        submit_btn = gr.Button("Submit Generation Request", variant="primary")
        output_status = gr.Textbox(label="Status", interactive=False)

        def submit_request(input_ds, input_split, input_config, prompt_col, model_name, model_rev, model_token, sys_prompt,
                           max_tok, temp, top_k_val, top_p_val, output_ds, user, email_addr, input_dataset_token, output_dataset_token):
            """Build, validate, and enqueue a GenerationRequest; return a status string."""
            try:
                request = GenerationRequest(
                    id="",  # Will be generated when adding to the database
                    status=GenerationStatus.PENDING,
                    input_dataset_name=input_ds,
                    input_dataset_split=input_split,
                    # Bug fix: the original bound the Gradio Textbox component
                    # object here (and omitted it from the click inputs);
                    # the submitted value is now passed through instead.
                    input_dataset_config=input_config,
                    prompt_column=prompt_col,
                    model_name_or_path=model_name,
                    model_revision=model_rev,
                    model_token=model_token if model_token else None,
                    system_prompt=sys_prompt if sys_prompt else None,
                    max_tokens=int(max_tok),
                    temperature=temp,
                    top_k=int(top_k_val),
                    top_p=top_p_val,
                    output_dataset_name=output_ds,
                    input_dataset_token=input_dataset_token if input_dataset_token else None,
                    output_dataset_token=output_dataset_token,
                    username=user,
                    email=email_addr
                )
                # check the input dataset exists and can be accessed with the provided token
                validate_request(request)
                add_request_to_db(request)
                return "Request submitted successfully!"
            except Exception as e:
                return f"Error: {str(e)}"

        submit_btn.click(
            submit_request,
            # Order must match submit_request's parameter order exactly.
            inputs=[input_dataset_name, input_dataset_split, input_dataset_config, prompt_column,
                    model_name_or_path, model_revision, model_token, system_prompt, max_tokens,
                    temperature, top_k, top_p, output_dataset_name, username, email,
                    input_dataset_token, output_dataset_token],
            outputs=output_status
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and start serving it.
    create_gradio_interface().launch()