Spaces:
Runtime error
import asyncio
import base64
import io
from typing import List, Tuple

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image, ImageDraw
# The FastAPI application object.
app = FastAPI()

# Front-end assets (HTML, JavaScript, CSS) are read from the ./static
# directory and exposed under the /static URL prefix.
_static_assets = StaticFiles(directory="static")
app.mount("/static", _static_assets, name="static")
# Serve the index.html file at the root URL
async def get():
    """Return the contents of ``static/index.html`` as an HTML response.

    The path is resolved relative to the process's current working directory.

    NOTE(review): no ``@app.get("/")`` route decorator is visible for this
    handler here, although the comment above says it serves the root URL --
    confirm it is actually registered with ``app``.
    """
    # Use a context manager so the file handle is closed even if read() fails,
    # and decode explicitly rather than relying on the platform default.
    with open("static/index.html", encoding="utf-8") as index_file:
        html = index_file.read()
    return HTMLResponse(html)
def generate_random_image(width: int, height: int) -> np.ndarray:
    """Return a ``(height, width, 3)`` uint8 array of uniform random noise.

    Each value is drawn from ``[0, 255]`` using NumPy's global random state.
    """
    shape = (height, width, 3)
    return np.random.randint(low=0, high=256, size=shape, dtype=np.uint8)
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
    """Draw the path described by *previous_actions* onto *image*.

    Every action is marked with a small filled dot. The dot is red for
    ``"move"`` actions and green for any other action type. From the second
    action on, a 1-px line in the current action's colour joins it to the
    preceding action. The colour tuples assume a 3-channel (RGB) image.

    Args:
        image: Array passed to ``PIL.Image.fromarray``.
        previous_actions: ``(action_type, [x, y])`` pairs, in drawing order.

    Returns:
        The annotated image as a new NumPy array.
    """
    canvas = Image.fromarray(image)
    pen = ImageDraw.Draw(canvas)
    last_point = None
    for action_type, position in previous_actions:
        x, y = position
        color = (0, 255, 0) if action_type != "move" else (255, 0, 0)
        pen.ellipse([x - 2, y - 2, x + 2, y + 2], fill=color)
        if last_point is not None:
            # Segment from the previous action's point to this one.
            pen.line([*last_point, x, y], fill=color, width=1)
        last_point = (x, y)
    return np.array(canvas)
def predict_next_frame(
    previous_frames: List[np.ndarray],
    previous_actions: List[Tuple[str, List[int]]],
    width: int = 800,
    height: int = 600,
) -> np.ndarray:
    """Produce the next frame from the frame and action history.

    This is a stand-in predictor:

    * If there is no previous frame, or the latest action is ``"move"``, a
      fresh random-noise frame of ``width`` x ``height`` is generated.
    * Otherwise the most recent frame is reused; this includes the case where
      no actions have been recorded yet.

    In both cases the full action trace is then drawn on top with
    ``draw_trace``.

    Args:
        previous_frames: Frames produced so far; only the last one is read.
        previous_actions: ``(action_type, [x, y])`` pairs in arrival order.
            May be empty.
        width: Width in pixels of a freshly generated frame.
        height: Height in pixels of a freshly generated frame.

    Returns:
        The new frame with the action trace drawn on it.
    """
    # Guard the index: with frames present but no actions, previous_actions[-1]
    # would raise IndexError.
    last_action_is_move = bool(previous_actions) and previous_actions[-1][0] == "move"
    if not previous_frames or last_action_is_move:
        new_frame = generate_random_image(width, height)
    else:
        # Start from a copy of the most recent frame.
        new_frame = previous_frames[-1].copy()
    return draw_trace(new_frame, previous_actions)
def _encode_frame_as_base64_png(frame: np.ndarray) -> str:
    """Encode *frame* as a PNG image and return it as base64 text."""
    buffered = io.BytesIO()
    Image.fromarray(frame).save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def _is_valid_position(position) -> bool:
    """Return True when *position* is an ``[x, y]`` pair of (non-bool) numbers."""
    return (
        isinstance(position, (list, tuple))
        and len(position) == 2
        and all(
            isinstance(coord, (int, float)) and not isinstance(coord, bool)
            for coord in position
        )
    )


# WebSocket endpoint for continuous user interaction
async def websocket_endpoint(websocket: WebSocket):
    """Run an interactive frame-streaming session over *websocket*.

    After the connection is accepted, each JSON message is read for an
    ``action_type`` and a ``mouse_position`` (``[x, y]``). For every valid
    message:

    * the action is appended to this session's history;
    * a frame is produced with ``predict_next_frame``;
    * the frame is sent back as ``{"image": <base64-encoded PNG>}``.

    Messages that are not JSON objects, or whose position is not a numeric
    pair, are logged and skipped. If no message arrives within 30 seconds,
    the connection is closed with code 1000.

    NOTE(review): no ``@app.websocket(...)`` route decorator is visible for
    this handler here -- confirm it is registered with ``app``.
    NOTE(review): frame generation and PNG encoding are synchronous calls made
    on the event loop. The action history also grows for the lifetime of the
    connection.
    """
    await websocket.accept()
    previous_frames: List[np.ndarray] = []
    previous_actions: List[Tuple[str, List[int]]] = []
    # Set once the connection is closed (by us or by the peer) so the
    # ``finally`` block does not try to close it a second time.
    already_closed = False
    try:
        while True:
            try:
                # Receive user input with a timeout
                data = await asyncio.wait_for(websocket.receive_json(), timeout=30.0)
                if not isinstance(data, dict):
                    print(f"Ignoring non-object WebSocket message: {data!r}")
                    continue
                action_type = data.get("action_type")
                mouse_position = data.get("mouse_position")
                # Reject bad positions before they enter the history; otherwise
                # draw_trace would fail on this and every later frame.
                if not _is_valid_position(mouse_position):
                    print(f"Ignoring action with invalid mouse_position: {mouse_position!r}")
                    continue
                previous_actions.append((action_type, mouse_position))
                next_frame = predict_next_frame(previous_frames, previous_actions)
                # predict_next_frame only reads the most recent frame, so keep
                # just that one instead of accumulating every frame in memory.
                previous_frames[:] = [next_frame]
                # Send the generated frame back to the client
                await websocket.send_json({"image": _encode_frame_as_base64_png(next_frame)})
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
                already_closed = True
                await websocket.close(code=1000)
                break
            except WebSocketDisconnect:
                print("WebSocket disconnected")
                already_closed = True
                break
    except Exception as e:
        print(f"Error in WebSocket connection: {e}")
    finally:
        print("WebSocket connection closed")
        if not already_closed:
            # Best-effort close after an unexpected error; the transport may
            # already be unusable, so log rather than raise from ``finally``.
            try:
                await websocket.close()
            except Exception as close_error:
                print(f"Error closing WebSocket: {close_error}")