# timesfm_backend.py
"""TimesFM forecasting backend exposed through the chat-completion interface.

When the ``timesfm`` package is missing or the model cannot be loaded/run,
forecasts fall back to a naive "mean of the last few observations" baseline
and the response carries a ``note`` saying why.
"""
import asyncio
import json
import logging
import threading
import time
import uuid
from typing import Any, Dict, Optional

import numpy as np
import torch

from backends_base import ChatBackend, ImagesBackend
from config import settings  # noqa: F401 - not referenced in this module

logger = logging.getLogger(__name__)

try:
    from timesfm import TimesFm

    _TIMESFM_AVAILABLE = True
except Exception as e:  # any import-time failure (missing package, broken deps)
    logger.warning("timesfm not available (%s)", e)
    TimesFm = None
    _TIMESFM_AVAILABLE = False

# Geometry passed to the TimesFm constructor. These may need adjusting to
# match the checkpoint being loaded.
_CONTEXT_LEN = 512
_HORIZON_LEN = 128
_INPUT_PATCH_LEN = 32


# ---------- small helpers ----------


def _parse_series(series: Any) -> np.ndarray:
    """Normalize a user-supplied series into a 1-D float32 array.

    Accepted forms:
      * a list/tuple (or 1-D numpy array) of numbers;
      * a list/tuple of dicts, each carrying a ``"y"`` or ``"value"`` key;
      * a dict wrapping either of the above under ``"values"`` or ``"y"``.

    Raises:
        ValueError: if the series is missing, empty, has an unsupported
            container type, or contains entries that are not numeric.
    """
    if series is None:
        raise ValueError("series is required")
    if isinstance(series, dict):
        if "values" in series:
            series = series["values"]
        elif "y" in series:
            series = series["y"]
    if isinstance(series, np.ndarray) and series.ndim == 1:
        series = series.tolist()
    if not isinstance(series, (list, tuple)):
        raise ValueError("series must be a list/tuple or dict with 'values'/'y'")

    vals = []
    try:
        if series and isinstance(series[0], dict):
            for i, item in enumerate(series):
                # Dropping a malformed point would silently shift every later
                # observation, so reject it instead.
                if not isinstance(item, dict):
                    raise ValueError(f"series[{i}] must be an object with 'y' or 'value'")
                if "y" in item:
                    vals.append(float(item["y"]))
                elif "value" in item:
                    vals.append(float(item["value"]))
                else:
                    raise ValueError(f"series[{i}] has neither 'y' nor 'value'")
        else:
            vals = [float(x) for x in series]
    except TypeError as err:
        raise ValueError(f"series contains a non-numeric value: {err}") from err

    if not vals:
        raise ValueError("series is empty")
    return np.asarray(vals, dtype=np.float32)


def _fallback_forecast(y: np.ndarray, horizon: int, window: int = 4) -> np.ndarray:
    """Naive baseline: repeat the mean of the last ``window`` observations.

    Args:
        y: observed series (1-D).
        horizon: number of future steps to produce; ``<= 0`` yields an empty array.
        window: how many trailing observations to average (clamped to ``[1, len(y)]``).

    Raises:
        ValueError: if ``y`` is empty and ``horizon > 0``.
    """
    if horizon <= 0:
        return np.zeros((0,), dtype=np.float32)
    if y.shape[0] == 0:
        raise ValueError("cannot forecast from an empty series")
    k = max(1, min(window, y.shape[0]))
    base = float(np.mean(y[-k:]))
    return np.full((horizon,), base, dtype=np.float32)


def _completion_chunk(rid: str, created: int, model: str, content: str) -> Dict[str, Any]:
    """Build a single, terminal OpenAI-style ``chat.completion.chunk`` dict."""
    return {
        "id": rid,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
    }


# ---------- backend ----------


class TimesFMBackend(ChatBackend):
    """Chat backend that answers each request with one JSON forecast chunk."""

    # Upper bound on requested horizon; the request is external input and the
    # fallback allocates ``horizon`` floats.
    MAX_HORIZON = 10_000

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model = None
        # Set after a failed load so later requests do not retry (and re-log)
        # an expensive checkpoint load that already failed.
        self._load_failed = False
        # Lazy loading runs in executor threads; serialize it.
        self._load_lock = threading.Lock()

    def _ensure_model(self):
        """Lazily construct and load the TimesFM model, attempting it at most once.

        Leaves ``self._model`` as ``None`` when timesfm is unavailable or the
        load fails; callers then use the fallback forecaster.
        """
        if self._model is not None or not _TIMESFM_AVAILABLE or self._load_failed:
            return
        with self._load_lock:
            # Re-check under the lock: another thread may have finished first.
            if self._model is not None or self._load_failed:
                return
            try:
                # NOTE(review): these constructor kwargs, load_from_checkpoint()
                # and .to() must match the installed timesfm API - TODO confirm.
                model = TimesFm(
                    context_len=_CONTEXT_LEN,
                    horizon_len=_HORIZON_LEN,
                    input_patch_len=_INPUT_PATCH_LEN,
                )
                model.load_from_checkpoint(self.model_id)
                model.to(self.device)
            except Exception as e:
                logger.exception("Failed to init TimesFM, fallback only. %s", e)
                self._load_failed = True
                self._model = None
                return
            # Publish only a fully loaded model.
            self._model = model
            logger.info("TimesFM model loaded from %s on %s", self.model_id, self.device)

    def _predict(self, y: np.ndarray, horizon: int):
        """Blocking forecast used from a worker thread.

        Returns:
            ``(forecast_list, note)`` where ``note`` is ``None`` on model
            success or a string describing why the fallback was used.
        """
        self._ensure_model()
        model = self._model
        if model is None:
            return _fallback_forecast(y, horizon).tolist(), "fallback_used_timesfm_missing"
        try:
            # The model was built for at most _CONTEXT_LEN points of context.
            context = y[-_CONTEXT_LEN:]
            x = torch.tensor(context, dtype=torch.float32, device=self.device)[None, :]
            with torch.no_grad():
                preds = model.forecast_on_batch(x, horizon)
            row = preds[0]
            if torch.is_tensor(row):
                row = row.detach().cpu().numpy()
            fc = np.asarray(row, dtype=float)
            if fc.ndim != 1 or fc.shape[0] < horizon:
                raise ValueError(
                    f"model returned shape {fc.shape}, expected at least {horizon} steps"
                )
            return fc[:horizon].tolist(), None
        except Exception as e:
            logger.exception("TimesFM forecast failed, using fallback. %s", e)
            return (
                _fallback_forecast(y, horizon).tolist(),
                "fallback_used_due_to_predict_error",
            )

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Forecast ``horizon`` steps ahead of ``series``.

        Fields are read from the top level of ``payload``. Values under a
        nested ``"data"`` dict, and then a ``"timeseries"`` dict, override
        top-level keys.

        Raises:
            ValueError: on a missing or malformed series, or a non-integer,
                non-positive or too-large horizon.
        """
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}

        y = _parse_series(payload.get("series"))
        raw_horizon = payload.get("horizon", 0)
        try:
            horizon = int(raw_horizon)
        except (TypeError, ValueError) as err:
            raise ValueError(f"horizon must be an integer, got {raw_horizon!r}") from err
        freq = payload.get("freq")
        if horizon <= 0:
            raise ValueError("horizon must be positive")
        if horizon > self.MAX_HORIZON:
            raise ValueError(f"horizon must be <= {self.MAX_HORIZON}")

        # Model loading and torch inference are blocking; keep them off the event loop.
        loop = asyncio.get_running_loop()
        fc, note = await loop.run_in_executor(None, self._predict, y, horizon)

        return {
            "model": self.model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": fc,
            "note": note,
        }

    async def stream(self, request: Dict[str, Any]):
        """Yield exactly one terminal ``chat.completion.chunk``.

        The chunk's ``content`` is a compact JSON string holding either the
        forecast result or ``{"error": "..."}``.

        NOTE(review): ``request`` is passed to ``forecast`` as-is; series and
        horizon are read from its top-level, ``data`` or ``timeseries`` keys,
        not from chat ``messages``.
        """
        now = int(time.time())
        # A timestamp alone repeats for requests in the same second; add a random suffix.
        rid = f"chatcmpl-timesfm-{now}-{uuid.uuid4().hex[:12]}"
        payload = dict(request) if isinstance(request, dict) else {}
        try:
            result = await self.forecast(payload)
        except Exception as e:
            body = {"error": str(e)}
        else:
            body = {
                "model": result["model"],
                "horizon": result["horizon"],
                "freq": result["freq"],
                "forecast": result["forecast"],
                "note": result.get("note"),
                "backend": "timesfm",
            }
        content = json.dumps(body, separators=(",", ":"), ensure_ascii=False)
        yield _completion_chunk(rid, now, self.model_id, content)


class StubImagesBackend(ImagesBackend):
    """Placeholder images backend; always returns a fixed tiny base64 PNG."""

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in TimesFM backend.")
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="