import json
import logging
from collections import defaultdict

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Track open WebSocket connections grouped by child id and fan out messages.

    Connections are stored per ``child_id`` so that a single update can be
    pushed to every client currently registered for that child.
    """

    def __init__(self):
        # child_id → list of active WebSocket connections. A key is removed
        # again once its last connection is unregistered in ``disconnect``.
        self.active: dict[int, list[WebSocket]] = defaultdict(list)

    async def connect(self, websocket: WebSocket, child_id: int) -> None:
        """Accept the WebSocket handshake and register it under ``child_id``."""
        await websocket.accept()
        self.active[child_id].append(websocket)
        logger.info(
            "WS connected for child %d — %d total",
            child_id,
            len(self.active[child_id]),
        )

    def disconnect(self, websocket: WebSocket, child_id: int) -> None:
        """Unregister ``websocket`` from ``child_id``.

        Safe to call more than once, or for a socket that is not registered.
        ``broadcast`` calls this for sockets whose send failed, so a later call
        from the code that owns the socket is harmless. This method does not
        close the socket itself.
        """
        # .get() instead of indexing: indexing the defaultdict would create an
        # empty entry for an unknown child_id.
        connections = self.active.get(child_id)
        remaining = 0
        if connections is not None:
            try:
                connections.remove(websocket)
            except ValueError:
                pass  # Not registered (already removed) — nothing to do.
            remaining = len(connections)
            if not connections:
                # Drop empty lists so ids of departed children don't accumulate.
                del self.active[child_id]
        logger.info("WS disconnected for child %d — %d remaining", child_id, remaining)

    async def broadcast(self, child_id: int, message: dict) -> None:
        """Send a JSON message to all TVs watching a given child.

        Every connection registered for ``child_id`` receives ``message``
        serialized with ``json.dumps``. Connections whose send raises are
        treated as dead and unregistered.

        Raises:
            TypeError: if ``message`` contains values that are not JSON-serializable.
            ValueError: if ``message`` contains a circular reference.
        """
        # Snapshot the list: each send awaits, and connect/disconnect may
        # mutate the underlying list in the meantime.
        connections = list(self.active.get(child_id, ()))
        if not connections:
            return

        # Serialize once, outside the per-socket try. Otherwise a bad payload
        # would be mistaken for every socket being dead, and all clients dropped.
        payload = json.dumps(message)

        dead = []
        for ws in connections:
            try:
                await ws.send_text(payload)
            except Exception as exc:
                logger.warning(
                    "WS send failed for child %d (%r); dropping connection",
                    child_id,
                    exc,
                )
                dead.append(ws)

        for ws in dead:
            self.disconnect(ws, child_id)


# Singleton — imported by routers and the WS endpoint
manager = ConnectionManager()