from datetime import date

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.dependencies import get_db, get_current_user
from app.models.child import Child
from app.models.morning_routine import MorningRoutineItem
from app.models.break_activity import BreakActivityItem
from app.models.schedule import ScheduleBlock
from app.models.subject import Subject  # used in the selectinload chain below
from app.models.session import DailySession, TimerEvent
from app.models.user import User
from app.schemas.session import DailySessionOut, SessionStart, TimerAction
from app.utils.timer import compute_block_elapsed, compute_break_elapsed
from app.websocket.manager import manager

router = APIRouter(prefix="/api/sessions", tags=["sessions"])

# Break-time events never switch blocks or trigger block-switch pauses.
_BREAK_EVENTS = frozenset({"break_start", "break_pause", "break_resume", "break_reset"})
# Events that (re)focus a block; the previously current block is implicitly paused.
_BLOCK_SWITCH_EVENTS = ("start", "select", "reset")


def _block_to_dict(block: ScheduleBlock) -> dict:
    """Serialize a schedule block (with its subject and subject options, if any)."""
    subject = block.subject
    subject_dict = None
    if subject:
        subject_dict = {
            "id": subject.id,
            "name": subject.name,
            "color": subject.color,
            "icon": subject.icon,
            "options": [
                {"id": o.id, "text": o.text, "order_index": o.order_index}
                for o in subject.options
            ],
        }
    return {
        "id": block.id,
        "subject_id": block.subject_id,
        "subject": subject_dict,
        "time_start": str(block.time_start),
        "time_end": str(block.time_end),
        "duration_minutes": block.duration_minutes,
        "label": block.label,
        "order_index": block.order_index,
        "break_time_enabled": block.break_time_enabled,
        "break_time_minutes": block.break_time_minutes,
    }


async def _ordered_item_texts(db: AsyncSession, model: type, user_id) -> list[str]:
    """Return the ``text`` of every ``model`` row owned by ``user_id``.

    ``model`` is a mapped class with ``user_id``, ``order_index``, ``id`` and
    ``text`` columns (here MorningRoutineItem or BreakActivityItem). Rows are
    ordered by ``order_index`` then ``id`` so ties are stable.
    """
    result = await db.execute(
        select(model.text)
        .where(model.user_id == user_id)
        .order_by(model.order_index, model.id)
    )
    return list(result.scalars().all())


async def _broadcast_session(db: AsyncSession, session: DailySession) -> None:
    """Build a snapshot dict and broadcast it to all connected TVs for this child.

    The ``session_update`` payload contains the session's core fields, the
    template's schedule blocks (ordered by start time), the ids of blocks with
    a ``complete`` timer event, and the owning user's morning-routine and
    break-activity texts.
    """
    blocks = []
    if session.template_id:
        blocks_result = await db.execute(
            select(ScheduleBlock)
            .where(ScheduleBlock.template_id == session.template_id)
            .options(selectinload(ScheduleBlock.subject).selectinload(Subject.options))
            .order_by(ScheduleBlock.time_start)
        )
        blocks = [_block_to_dict(b) for b in blocks_result.scalars().all()]

    # Gather completed block IDs from timer events. A block can be completed
    # more than once (e.g. after a reset), so report each id only once while
    # keeping first-seen order.
    events_result = await db.execute(
        select(TimerEvent.block_id).where(
            TimerEvent.session_id == session.id,
            TimerEvent.event_type == "complete",
        )
    )
    completed_ids = list(
        dict.fromkeys(bid for bid in events_result.scalars().all() if bid)
    )

    # Morning routine and break activities are stored per user, reached via child.user_id.
    child_result = await db.execute(select(Child).where(Child.id == session.child_id))
    child = child_result.scalar_one_or_none()
    morning_routine: list[str] = []
    break_activities: list[str] = []
    if child:
        morning_routine = await _ordered_item_texts(db, MorningRoutineItem, child.user_id)
        break_activities = await _ordered_item_texts(db, BreakActivityItem, child.user_id)

    payload = {
        "event": "session_update",
        "session": {
            "id": session.id,
            "child_id": session.child_id,
            "session_date": str(session.session_date),
            "is_active": session.is_active,
            "current_block_id": session.current_block_id,
        },
        "blocks": blocks,
        "completed_block_ids": completed_ids,
        "morning_routine": morning_routine,
        "break_activities": break_activities,
    }
    await manager.broadcast(session.child_id, payload)


async def _get_owned_session(
    db: AsyncSession, session_id: int, current_user: User
) -> DailySession:
    """Load a session owned (via its child) by ``current_user``.

    The session's current block and that block's subject are eagerly loaded.

    Raises:
        HTTPException: 404 if no such session exists for this user.
    """
    result = await db.execute(
        select(DailySession)
        .join(Child)
        .where(DailySession.id == session_id, Child.user_id == current_user.id)
        .options(selectinload(DailySession.current_block).selectinload(ScheduleBlock.subject))
    )
    session = result.scalar_one_or_none()
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return session


@router.post("", response_model=DailySessionOut, status_code=status.HTTP_201_CREATED)
async def start_session(
    body: SessionStart,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a new daily session for one of the current user's children.

    Resets the child's strikes to 0 and deactivates any session already
    active for that child on the same date. Creates the new session together
    with its ``session_start`` timer event in a single transaction. After
    committing, broadcasts ``strikes_update`` (only if strikes changed) and a
    full ``session_update`` snapshot.

    Raises:
        HTTPException: 404 if the child does not exist or belongs to another user.
    """
    # Verify child belongs to user
    child_result = await db.execute(
        select(Child).where(Child.id == body.child_id, Child.user_id == current_user.id)
    )
    child = child_result.scalar_one_or_none()
    if not child:
        raise HTTPException(status_code=404, detail="Child not found")

    # Reset strikes at the start of each new day (broadcast only after commit).
    strikes_reset = child.strikes != 0
    if strikes_reset:
        child.strikes = 0

    session_date = body.session_date or date.today()

    # Deactivate any existing active session for this child on that date.
    existing = await db.execute(
        select(DailySession).where(
            DailySession.child_id == body.child_id,
            DailySession.session_date == session_date,
            DailySession.is_active == True,  # noqa: E712 — SQL expression, not a Python comparison
        )
    )
    for old in existing.scalars().all():
        old.is_active = False

    # NOTE(review): body.template_id is not checked against current_user here,
    # yet its blocks are broadcast to this child's TVs — confirm ownership is
    # validated elsewhere (e.g. in the SessionStart schema or a dependency).
    session = DailySession(
        child_id=body.child_id,
        template_id=body.template_id,
        session_date=session_date,
        is_active=True,
    )
    db.add(session)
    # Flush to obtain session.id, then record the start event (shown in the
    # activity log) in the same transaction, so a session is never persisted
    # without its session_start event.
    await db.flush()
    db.add(TimerEvent(session_id=session.id, block_id=None, event_type="session_start"))
    await db.commit()
    # Reload after commit so attribute access below needs no implicit IO.
    await db.refresh(session)

    if strikes_reset:
        await manager.broadcast(body.child_id, {"event": "strikes_update", "strikes": 0})
    await _broadcast_session(db, session)
    return session


@router.get("/{session_id}", response_model=DailySessionOut)
async def get_session(
    session_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one of the current user's sessions with its current block and subject loaded.

    Raises:
        HTTPException: 404 if the session does not exist or belongs to another user.
    """
    return await _get_owned_session(db, session_id, current_user)


async def _record_break_event(
    db: AsyncSession, session: DailySession, body: TimerAction
) -> DailySession:
    """Persist a break-time event and broadcast it without switching blocks.

    The event targets ``body.block_id``, falling back to the session's current
    block. Elapsed break time is included for ``break_start``/``break_reset``,
    and elapsed block time for ``break_start``.
    """
    block_id = body.block_id or session.current_block_id
    is_break_start = body.event_type == "break_start"

    # When break starts, implicitly pause the main block timer so elapsed
    # time is captured accurately in the activity log and on page reload.
    if is_break_start and block_id:
        db.add(TimerEvent(session_id=session.id, block_id=block_id, event_type="pause"))
    db.add(TimerEvent(session_id=session.id, block_id=block_id, event_type=body.event_type))
    await db.commit()
    await db.refresh(session)

    break_elapsed_seconds = 0
    block_elapsed_seconds = 0
    if block_id:
        if body.event_type in ("break_start", "break_reset"):
            break_elapsed_seconds, _ = await compute_break_elapsed(db, session.id, block_id)
        if is_break_start:
            block_elapsed_seconds, _ = await compute_block_elapsed(db, session.id, block_id)

    ws_payload = {
        "event": body.event_type,
        "session_id": session.id,
        "block_id": block_id,
        "current_block_id": session.current_block_id,
        "is_active": session.is_active,
        "break_elapsed_seconds": break_elapsed_seconds,
        "block_elapsed_seconds": block_elapsed_seconds,
    }
    await manager.broadcast(session.child_id, ws_payload)
    return session


@router.post("/{session_id}/timer", response_model=DailySessionOut)
async def timer_action(
    session_id: int,
    body: TimerAction,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Record a timer event for a session and broadcast it to the child's TVs.

    - Break events are delegated to ``_record_break_event`` and never change
      the current block.
    - ``start``/``select``/``reset`` with a ``block_id`` make that block
      current. A different previously-current block gets an implicit
      ``pause`` event.
    - ``complete`` without a ``block_id`` deactivates the session.

    Raises:
        HTTPException: 404 if the session does not exist or belongs to another user.
    """
    session = await _get_owned_session(db, session_id, current_user)

    if body.event_type in _BREAK_EVENTS:
        return await _record_break_event(db, session, body)

    is_block_switch = body.event_type in _BLOCK_SWITCH_EVENTS

    # When switching to a different block (start / select / reset), implicitly
    # pause the previous block so the activity log stays accurate.
    prev_block_id = None
    prev_block_elapsed_seconds = 0
    if is_block_switch and body.block_id is not None:
        prev_block_id = session.current_block_id
        if prev_block_id and prev_block_id != body.block_id:
            db.add(TimerEvent(
                session_id=session.id,
                block_id=prev_block_id,
                event_type="pause",
            ))
            # Autoflush means the implicit pause above is visible to the helper.
            prev_block_elapsed_seconds, _ = await compute_block_elapsed(
                db, session.id, prev_block_id
            )

    # Update current block if provided
    if body.block_id is not None:
        session.current_block_id = body.block_id

    # Keep the event's block id in a local so it is not read back from the
    # (possibly expired) ORM object after commit.
    # NOTE(review): a session-level "complete" (no block_id) is stored against
    # the current block, so _broadcast_session will list that block as
    # completed — confirm this is intended.
    event_block_id = body.block_id or session.current_block_id
    db.add(TimerEvent(
        session_id=session.id,
        block_id=event_block_id,
        event_type=body.event_type,
    ))

    # Mark session complete if event is session-level complete
    if body.event_type == "complete" and body.block_id is None:
        session.is_active = False

    await db.commit()
    await db.refresh(session)

    # For start / select / reset, compute elapsed for the new block so every
    # client can restore the correct offset without a local cache.
    block_elapsed_seconds = 0
    if is_block_switch and event_block_id:
        block_elapsed_seconds, _ = await compute_block_elapsed(
            db, session.id, event_block_id
        )

    # Broadcast the timer event to all TV clients
    ws_payload = {
        "event": body.event_type,
        "session_id": session.id,
        "block_id": event_block_id,
        "current_block_id": session.current_block_id,
        "is_active": session.is_active,
        "is_paused": body.event_type == "select",
        "block_elapsed_seconds": block_elapsed_seconds,
        "prev_block_id": prev_block_id,
        "prev_block_elapsed_seconds": prev_block_elapsed_seconds,
    }
    await manager.broadcast(session.child_id, ws_payload)
    return session