"""Daily-session API routes.

Starts and fetches a child's daily session, drives the per-block and break
timers, and stores per-block agenda text. Every state change is broadcast over
the websocket manager to the clients connected for that child.
"""

from datetime import date

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import delete as sql_delete
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.dependencies import get_db, get_current_user
from app.models.child import Child
from app.models.morning_routine import MorningRoutineItem
from app.models.break_activity import BreakActivityItem
from app.models.schedule import ScheduleBlock
from app.models.subject import Subject  # used in the selectinload chain below
from app.models.session import DailySession, TimerEvent
from app.models.session_block_agenda import SessionBlockAgenda
from app.models.user import User
from app.schemas.session import DailySessionOut, SessionStart, TimerAction
from app.utils.timer import compute_block_elapsed, compute_break_elapsed
from app.websocket.manager import manager

router = APIRouter(prefix="/api/sessions", tags=["sessions"])

# Break-timer events: recorded and broadcast, but they never switch blocks.
_BREAK_EVENTS = frozenset({"break_start", "break_pause", "break_resume", "break_reset"})

# Events that may move the session to a different block.
_BLOCK_SWITCH_EVENTS = frozenset({"start", "select", "reset"})


def _serialize_block(b: ScheduleBlock) -> dict:
    """Convert a schedule block (with its subject and options loaded) to a JSON-safe dict."""
    subject = None
    if b.subject:
        subject = {
            "id": b.subject.id,
            "name": b.subject.name,
            "color": b.subject.color,
            "icon": b.subject.icon,
            "options": [
                {"id": o.id, "text": o.text, "order_index": o.order_index}
                for o in b.subject.options
            ],
        }
    return {
        "id": b.id,
        "subject_id": b.subject_id,
        "subject": subject,
        "time_start": str(b.time_start),
        "time_end": str(b.time_end),
        "duration_minutes": b.duration_minutes,
        "label": b.label,
        "order_index": b.order_index,
        "break_time_enabled": b.break_time_enabled,
        "break_time_minutes": b.break_time_minutes,
    }


async def _broadcast_session(db: AsyncSession, session: DailySession) -> None:
    """Build a full snapshot of *session* and broadcast it to the child's connected TVs.

    The snapshot contains:
    - the session's core fields;
    - the template's schedule blocks, ordered by start time;
    - the IDs of blocks that have a ``complete`` timer event;
    - the parent's morning-routine and break-activity texts;
    - this session's per-block agenda texts.
    """
    blocks: list[dict] = []
    if session.template_id:
        blocks_result = await db.execute(
            select(ScheduleBlock)
            .where(ScheduleBlock.template_id == session.template_id)
            .options(selectinload(ScheduleBlock.subject).selectinload(Subject.options))
            .order_by(ScheduleBlock.time_start)
        )
        blocks = [_serialize_block(b) for b in blocks_result.scalars().all()]

    # Gather completed block IDs from timer events
    events_result = await db.execute(
        select(TimerEvent).where(
            TimerEvent.session_id == session.id,
            TimerEvent.event_type == "complete",
        )
    )
    completed_ids = [e.block_id for e in events_result.scalars().all() if e.block_id]

    # Morning routine and break activities belong to the parent user, so go via child → user_id.
    child_result = await db.execute(select(Child).where(Child.id == session.child_id))
    child = child_result.scalar_one_or_none()
    morning_routine: list[str] = []
    break_activities: list[str] = []
    if child:
        routine_result = await db.execute(
            select(MorningRoutineItem)
            .where(MorningRoutineItem.user_id == child.user_id)
            .order_by(MorningRoutineItem.order_index, MorningRoutineItem.id)
        )
        morning_routine = [item.text for item in routine_result.scalars().all()]

        break_result = await db.execute(
            select(BreakActivityItem)
            .where(BreakActivityItem.user_id == child.user_id)
            .order_by(BreakActivityItem.order_index, BreakActivityItem.id)
        )
        break_activities = [item.text for item in break_result.scalars().all()]

    agendas_result = await db.execute(
        select(SessionBlockAgenda).where(SessionBlockAgenda.session_id == session.id)
    )
    # Keys are stringified block IDs (JSON object keys must be strings).
    block_agendas = {str(item.block_id): item.text for item in agendas_result.scalars().all()}

    payload = {
        "event": "session_update",
        "session": {
            "id": session.id,
            "child_id": session.child_id,
            "session_date": str(session.session_date),
            "is_active": session.is_active,
            "current_block_id": session.current_block_id,
        },
        "blocks": blocks,
        "completed_block_ids": completed_ids,
        "morning_routine": morning_routine,
        "break_activities": break_activities,
        "block_agendas": block_agendas,
    }
    await manager.broadcast(session.child_id, payload)


@router.post("", response_model=DailySessionOut, status_code=status.HTTP_201_CREATED)
async def start_session(
    body: SessionStart,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a new active daily session for one of the current user's children.

    Deactivates any other active session for the same child and date, and
    resets the child's strikes to 0. It then creates the session together
    with a ``session_start`` timer event in a single transaction, and
    broadcasts the new state.

    Raises:
        HTTPException: 404 if the child does not exist or belongs to another user.
    """
    # Verify child belongs to user
    child_result = await db.execute(
        select(Child).where(Child.id == body.child_id, Child.user_id == current_user.id)
    )
    child = child_result.scalar_one_or_none()
    if not child:
        raise HTTPException(status_code=404, detail="Child not found")

    # Reset strikes when a session starts.
    # NOTE(review): this also fires for a second session on the same date — confirm intended.
    strikes_were_reset = child.strikes != 0
    if strikes_were_reset:
        child.strikes = 0

    session_date = body.session_date or date.today()

    # Deactivate any existing active session for this child on this date
    existing = await db.execute(
        select(DailySession).where(
            DailySession.child_id == body.child_id,
            DailySession.session_date == session_date,
            DailySession.is_active == True,  # noqa: E712 — SQL comparison, not Python truthiness
        )
    )
    for old in existing.scalars().all():
        old.is_active = False

    session = DailySession(
        child_id=body.child_id,
        template_id=body.template_id,
        session_date=session_date,
        is_active=True,
    )
    db.add(session)
    # Flush (not commit) to obtain session.id, so the session row and its
    # start event land in the same transaction and can't be half-written.
    await db.flush()

    # Record session start as a timer event so it appears in the activity log
    db.add(TimerEvent(session_id=session.id, block_id=None, event_type="session_start"))
    await db.commit()
    await db.refresh(session)

    # Broadcast only after the commit so clients never see a reset that was rolled back.
    if strikes_were_reset:
        await manager.broadcast(body.child_id, {"event": "strikes_update", "strikes": 0})
    await _broadcast_session(db, session)
    return session


@router.get("/{session_id}", response_model=DailySessionOut)
async def get_session(
    session_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a session owned (via its child) by the current user, or 404."""
    result = await db.execute(
        select(DailySession)
        .join(Child)
        .where(DailySession.id == session_id, Child.user_id == current_user.id)
        .options(selectinload(DailySession.current_block).selectinload(ScheduleBlock.subject))
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    return session


async def _handle_break_event(
    db: AsyncSession, session: DailySession, body: TimerAction
) -> DailySession:
    """Record a break-timer event for a block and broadcast it.

    Break events never change ``current_block_id``. ``break_start``
    additionally writes an implicit ``pause`` for the block if its main timer
    is running. That way the block's elapsed time is captured accurately in
    the activity log and on page reload.
    """
    block_id = body.block_id or session.current_block_id

    if body.event_type == "break_start" and block_id:
        _, already_paused = await compute_block_elapsed(db, session.id, block_id)
        if not already_paused:
            db.add(TimerEvent(
                session_id=session.id,
                block_id=block_id,
                event_type="pause",
            ))

    db.add(TimerEvent(
        session_id=session.id,
        block_id=block_id,
        event_type=body.event_type,
    ))
    await db.commit()
    await db.refresh(session)

    break_elapsed_seconds = 0
    block_elapsed_seconds = 0
    if body.event_type in ("break_start", "break_reset") and block_id:
        break_elapsed_seconds, _ = await compute_break_elapsed(db, session.id, block_id)
    if body.event_type == "break_start" and block_id:
        block_elapsed_seconds, _ = await compute_block_elapsed(db, session.id, block_id)

    ws_payload = {
        "event": body.event_type,
        "session_id": session.id,
        "block_id": block_id,
        "current_block_id": session.current_block_id,
        "is_active": session.is_active,
        "break_elapsed_seconds": break_elapsed_seconds,
        "block_elapsed_seconds": block_elapsed_seconds,
    }
    await manager.broadcast(session.child_id, ws_payload)
    return session


@router.post("/{session_id}/timer", response_model=DailySessionOut)
async def timer_action(
    session_id: int,
    body: TimerAction,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply a timer event to a session and broadcast it to the child's TVs.

    Break events are delegated to ``_handle_break_event``. For other events:
    - switching blocks implicitly pauses the previous block if it was running;
    - ``select`` is broadcast but never persisted;
    - a ``complete`` with no ``block_id`` ends the session;
    - ``reset`` deletes the block's ``complete`` events.

    Raises:
        HTTPException: 404 if the session does not exist or belongs to another user.
    """
    result = await db.execute(
        select(DailySession)
        .join(Child)
        .where(DailySession.id == session_id, Child.user_id == current_user.id)
        .options(selectinload(DailySession.current_block).selectinload(ScheduleBlock.subject))
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if body.event_type in _BREAK_EVENTS:
        return await _handle_break_event(db, session, body)

    # When switching to a different block (start / select / reset), implicitly
    # pause the previous block so the activity log stays accurate.
    prev_block_id = None
    prev_block_elapsed_seconds = 0
    if body.event_type in _BLOCK_SWITCH_EVENTS and body.block_id is not None:
        prev_block_id = session.current_block_id
        if prev_block_id and prev_block_id != body.block_id:
            prev_block_elapsed_seconds, prev_already_paused = await compute_block_elapsed(
                db, session.id, prev_block_id
            )
            # Only write an implicit pause if the previous block was actually running.
            if not prev_already_paused:
                db.add(TimerEvent(
                    session_id=session.id,
                    block_id=prev_block_id,
                    event_type="pause",
                ))
                # Recompute elapsed now that the (autoflushed) pause event is included.
                prev_block_elapsed_seconds, _ = await compute_block_elapsed(
                    db, session.id, prev_block_id
                )

    # Update current block if provided
    if body.block_id is not None:
        session.current_block_id = body.block_id

    target_block_id = body.block_id or session.current_block_id

    # "select" only moves the cursor; it drives WS broadcasts but is not persisted.
    if body.event_type != "select":
        db.add(TimerEvent(
            session_id=session.id,
            block_id=target_block_id,
            event_type=body.event_type,
        ))

    # A "complete" without a block_id completes the whole session.
    # NOTE(review): the recorded event still falls back to current_block_id, so
    # that block is also reported as completed in later snapshots — confirm intended.
    if body.event_type == "complete" and body.block_id is None:
        session.is_active = False

    # Reset removes completed status — delete any complete events for this block
    if body.event_type == "reset" and target_block_id:
        await db.execute(
            sql_delete(TimerEvent).where(
                TimerEvent.session_id == session.id,
                TimerEvent.block_id == target_block_id,
                TimerEvent.event_type == "complete",
            )
        )

    await db.commit()
    await db.refresh(session)

    # For start / select / reset, compute elapsed for the new block so every
    # client can restore the correct offset without a local cache.
    block_elapsed_seconds = 0
    if body.event_type in _BLOCK_SWITCH_EVENTS and target_block_id:
        block_elapsed_seconds, _ = await compute_block_elapsed(db, session.id, target_block_id)

    ws_payload = {
        "event": body.event_type,
        "session_id": session.id,
        "block_id": target_block_id,
        "current_block_id": session.current_block_id,
        "is_active": session.is_active,
        "is_paused": body.event_type == "select",
        "block_elapsed_seconds": block_elapsed_seconds,
        "prev_block_id": prev_block_id,
        "prev_block_elapsed_seconds": prev_block_elapsed_seconds,
        "uncomplete_block_id": target_block_id if body.event_type == "reset" else None,
    }
    await manager.broadcast(session.child_id, ws_payload)
    return session


class AgendaUpdate(BaseModel):
    """Request body for setting a block's agenda text; blank text clears it."""

    text: str


@router.put("/{session_id}/blocks/{block_id}/agenda", status_code=status.HTTP_204_NO_CONTENT)
async def set_block_agenda(
    session_id: int,
    block_id: int,
    body: AgendaUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create, update, or delete the agenda text for one block of a session.

    The text is stripped of surrounding whitespace. Non-empty text upserts the
    agenda row; empty text deletes any existing row. The stripped text is then
    broadcast as ``agenda_update``.

    Raises:
        HTTPException: 404 if the session does not exist or belongs to another user.
    """
    result = await db.execute(
        select(DailySession)
        .join(Child)
        .where(DailySession.id == session_id, Child.user_id == current_user.id)
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    existing = await db.execute(
        select(SessionBlockAgenda).where(
            SessionBlockAgenda.session_id == session_id,
            SessionBlockAgenda.block_id == block_id,
        )
    )
    agenda = existing.scalar_one_or_none()

    clean_text = body.text.strip()
    if clean_text:
        if agenda:
            agenda.text = clean_text
        else:
            db.add(SessionBlockAgenda(session_id=session_id, block_id=block_id, text=clean_text))
    elif agenda:
        await db.delete(agenda)

    await db.commit()

    await manager.broadcast(session.child_id, {
        "event": "agenda_update",
        "block_id": block_id,
        "text": clean_text,
    })