Add multi-user auth, admin panel, and timezone support; rename to Yolkbook
- Rename app from Eggtracker to Yolkbook throughout - Add JWT-based authentication (python-jose, passlib/bcrypt) - Add users table; all data tables gain user_id FK for full data isolation - Super admin credentials sourced from ADMIN_USERNAME/ADMIN_PASSWORD env vars, synced on every startup; orphaned rows auto-assigned to admin post-migration - Login page with self-registration; JWT stored in localStorage (30-day expiry) - Admin panel (/admin): list users, reset passwords, disable/enable, delete, and impersonate (Login As) with Return to Admin banner - Settings modal (gear icon in nav): timezone selector and change password - Timezone stored per-user; stats date windows computed in user's timezone; date input setToday() respects user timezone via Intl API - migrate_v2.sql for existing single-user installs - Auto-migration adds timezone column to users on startup - Updated README with full setup, auth, admin, and migration docs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
111
backend/routers/admin.py
Normal file
111
backend/routers/admin.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
from schemas import UserCreate, UserOut, ResetPasswordRequest, TokenResponse
|
||||
from auth import hash_password, create_access_token, get_current_admin, get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserOut])
|
||||
def list_users(
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return db.scalars(select(User).order_by(User.created_at)).all()
|
||||
|
||||
|
||||
@router.post("/users", response_model=UserOut, status_code=201)
|
||||
def create_user(
|
||||
body: UserCreate,
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
existing = db.scalars(select(User).where(User.username == body.username)).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Username already taken")
|
||||
user = User(
|
||||
username=body.username,
|
||||
hashed_password=hash_password(body.password),
|
||||
is_admin=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/reset-password")
|
||||
def reset_password(
|
||||
user_id: int,
|
||||
body: ResetPasswordRequest,
|
||||
current_admin: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user.hashed_password = hash_password(body.new_password)
|
||||
db.commit()
|
||||
return {"detail": f"Password reset for {user.username}"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/disable")
|
||||
def disable_user(
|
||||
user_id: int,
|
||||
current_admin: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if user.id == current_admin.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot disable your own account")
|
||||
user.is_disabled = True
|
||||
db.commit()
|
||||
return {"detail": f"User {user.username} disabled"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/enable")
|
||||
def enable_user(
|
||||
user_id: int,
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user.is_disabled = False
|
||||
db.commit()
|
||||
return {"detail": f"User {user.username} enabled"}
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=204)
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
current_admin: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if user.id == current_admin.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete your own account")
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/impersonate", response_model=TokenResponse)
|
||||
def impersonate_user(
|
||||
user_id: int,
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = create_access_token(user.id, user.username, user.is_admin, user.timezone)
|
||||
return TokenResponse(access_token=token)
|
||||
85
backend/routers/auth_router.py
Normal file
85
backend/routers/auth_router.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
from schemas import LoginRequest, TokenResponse, UserOut, UserCreate, ChangePasswordRequest, TimezoneUpdate
|
||||
from auth import verify_password, hash_password, create_access_token, get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _make_token(user: User) -> str:
|
||||
return create_access_token(user.id, user.username, user.is_admin, user.timezone)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
def login(body: LoginRequest, db: Session = Depends(get_db)):
|
||||
user = db.scalars(select(User).where(User.username == body.username)).first()
|
||||
if not user or not verify_password(body.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
if user.is_disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is disabled. Contact your administrator.",
|
||||
)
|
||||
return TokenResponse(access_token=_make_token(user))
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=201)
|
||||
def register(body: UserCreate, db: Session = Depends(get_db)):
|
||||
existing = db.scalars(select(User).where(User.username == body.username)).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Username already taken")
|
||||
# Default timezone to UTC; user can change it in settings
|
||||
user = User(
|
||||
username=body.username,
|
||||
hashed_password=hash_password(body.password),
|
||||
is_admin=False,
|
||||
timezone="UTC",
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return TokenResponse(access_token=_make_token(user))
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def me(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
def change_password(
|
||||
body: ChangePasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if not verify_password(body.current_password, current_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
current_user.hashed_password = hash_password(body.new_password)
|
||||
db.commit()
|
||||
return {"detail": "Password updated"}
|
||||
|
||||
|
||||
@router.put("/timezone", response_model=TokenResponse)
|
||||
def update_timezone(
|
||||
body: TimezoneUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
ZoneInfo(body.timezone) # validate it's a real IANA timezone
|
||||
except ZoneInfoNotFoundError:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown timezone: {body.timezone}")
|
||||
current_user.timezone = body.timezone
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
# Return a fresh token with the updated timezone embedded
|
||||
return TokenResponse(access_token=_make_token(current_user))
|
||||
@@ -6,8 +6,9 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import EggCollection
|
||||
from models import EggCollection, User
|
||||
from schemas import EggCollectionCreate, EggCollectionUpdate, EggCollectionOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/eggs", tags=["eggs"])
|
||||
|
||||
@@ -17,8 +18,13 @@ def list_eggs(
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = select(EggCollection).order_by(EggCollection.date.desc())
|
||||
q = (
|
||||
select(EggCollection)
|
||||
.where(EggCollection.user_id == current_user.id)
|
||||
.order_by(EggCollection.date.desc())
|
||||
)
|
||||
if start:
|
||||
q = q.where(EggCollection.date >= start)
|
||||
if end:
|
||||
@@ -27,8 +33,12 @@ def list_eggs(
|
||||
|
||||
|
||||
@router.post("", response_model=EggCollectionOut, status_code=201)
|
||||
def create_egg_collection(body: EggCollectionCreate, db: Session = Depends(get_db)):
|
||||
record = EggCollection(**body.model_dump())
|
||||
def create_egg_collection(
|
||||
body: EggCollectionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = EggCollection(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
try:
|
||||
db.commit()
|
||||
@@ -44,8 +54,12 @@ def update_egg_collection(
|
||||
record_id: int,
|
||||
body: EggCollectionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(EggCollection, record_id)
|
||||
record = db.scalars(
|
||||
select(EggCollection)
|
||||
.where(EggCollection.id == record_id, EggCollection.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -54,14 +68,21 @@ def update_egg_collection(
|
||||
db.commit()
|
||||
except IntegrityError:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail=f"An entry for that date already exists.")
|
||||
raise HTTPException(status_code=409, detail="An entry for that date already exists.")
|
||||
db.refresh(record)
|
||||
return record
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_egg_collection(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(EggCollection, record_id)
|
||||
def delete_egg_collection(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(EggCollection)
|
||||
.where(EggCollection.id == record_id, EggCollection.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -5,8 +5,9 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import FeedPurchase
|
||||
from models import FeedPurchase, User
|
||||
from schemas import FeedPurchaseCreate, FeedPurchaseUpdate, FeedPurchaseOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/feed", tags=["feed"])
|
||||
|
||||
@@ -16,8 +17,13 @@ def list_feed_purchases(
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = select(FeedPurchase).order_by(FeedPurchase.date.desc())
|
||||
q = (
|
||||
select(FeedPurchase)
|
||||
.where(FeedPurchase.user_id == current_user.id)
|
||||
.order_by(FeedPurchase.date.desc())
|
||||
)
|
||||
if start:
|
||||
q = q.where(FeedPurchase.date >= start)
|
||||
if end:
|
||||
@@ -26,8 +32,12 @@ def list_feed_purchases(
|
||||
|
||||
|
||||
@router.post("", response_model=FeedPurchaseOut, status_code=201)
|
||||
def create_feed_purchase(body: FeedPurchaseCreate, db: Session = Depends(get_db)):
|
||||
record = FeedPurchase(**body.model_dump())
|
||||
def create_feed_purchase(
|
||||
body: FeedPurchaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = FeedPurchase(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
@@ -39,8 +49,12 @@ def update_feed_purchase(
|
||||
record_id: int,
|
||||
body: FeedPurchaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(FeedPurchase, record_id)
|
||||
record = db.scalars(
|
||||
select(FeedPurchase)
|
||||
.where(FeedPurchase.id == record_id, FeedPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -51,8 +65,15 @@ def update_feed_purchase(
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_feed_purchase(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(FeedPurchase, record_id)
|
||||
def delete_feed_purchase(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(FeedPurchase)
|
||||
.where(FeedPurchase.id == record_id, FeedPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -5,30 +5,49 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import FlockHistory
|
||||
from models import FlockHistory, User
|
||||
from schemas import FlockHistoryCreate, FlockHistoryUpdate, FlockHistoryOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/flock", tags=["flock"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[FlockHistoryOut])
|
||||
def list_flock_history(db: Session = Depends(get_db)):
|
||||
q = select(FlockHistory).order_by(FlockHistory.date.desc())
|
||||
def list_flock_history(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = (
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == current_user.id)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
)
|
||||
return db.scalars(q).all()
|
||||
|
||||
|
||||
@router.get("/current", response_model=Optional[FlockHistoryOut])
|
||||
def get_current_flock(db: Session = Depends(get_db)):
|
||||
"""Returns the most recent flock entry — the current flock size."""
|
||||
q = select(FlockHistory).order_by(FlockHistory.date.desc()).limit(1)
|
||||
def get_current_flock(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = (
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == current_user.id)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return db.scalars(q).first()
|
||||
|
||||
|
||||
@router.get("/at/{target_date}", response_model=Optional[FlockHistoryOut])
|
||||
def get_flock_at_date(target_date: date, db: Session = Depends(get_db)):
|
||||
"""Returns the flock size that was in effect on a given date."""
|
||||
def get_flock_at_date(
|
||||
target_date: date,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = (
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == current_user.id)
|
||||
.where(FlockHistory.date <= target_date)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
@@ -37,8 +56,12 @@ def get_flock_at_date(target_date: date, db: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@router.post("", response_model=FlockHistoryOut, status_code=201)
|
||||
def create_flock_entry(body: FlockHistoryCreate, db: Session = Depends(get_db)):
|
||||
record = FlockHistory(**body.model_dump())
|
||||
def create_flock_entry(
|
||||
body: FlockHistoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = FlockHistory(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
@@ -50,8 +73,12 @@ def update_flock_entry(
|
||||
record_id: int,
|
||||
body: FlockHistoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(FlockHistory, record_id)
|
||||
record = db.scalars(
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.id == record_id, FlockHistory.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -62,8 +89,15 @@ def update_flock_entry(
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_flock_entry(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(FlockHistory, record_id)
|
||||
def delete_flock_entry(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.id == record_id, FlockHistory.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -5,8 +5,9 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import OtherPurchase
|
||||
from models import OtherPurchase, User
|
||||
from schemas import OtherPurchaseCreate, OtherPurchaseUpdate, OtherPurchaseOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/other", tags=["other"])
|
||||
|
||||
@@ -16,8 +17,13 @@ def list_other_purchases(
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = select(OtherPurchase).order_by(OtherPurchase.date.desc())
|
||||
q = (
|
||||
select(OtherPurchase)
|
||||
.where(OtherPurchase.user_id == current_user.id)
|
||||
.order_by(OtherPurchase.date.desc())
|
||||
)
|
||||
if start:
|
||||
q = q.where(OtherPurchase.date >= start)
|
||||
if end:
|
||||
@@ -26,8 +32,12 @@ def list_other_purchases(
|
||||
|
||||
|
||||
@router.post("", response_model=OtherPurchaseOut, status_code=201)
|
||||
def create_other_purchase(body: OtherPurchaseCreate, db: Session = Depends(get_db)):
|
||||
record = OtherPurchase(**body.model_dump())
|
||||
def create_other_purchase(
|
||||
body: OtherPurchaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = OtherPurchase(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
@@ -39,8 +49,12 @@ def update_other_purchase(
|
||||
record_id: int,
|
||||
body: OtherPurchaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(OtherPurchase, record_id)
|
||||
record = db.scalars(
|
||||
select(OtherPurchase)
|
||||
.where(OtherPurchase.id == record_id, OtherPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -51,8 +65,15 @@ def update_other_purchase(
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_other_purchase(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(OtherPurchase, record_id)
|
||||
def delete_other_purchase(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(OtherPurchase)
|
||||
.where(OtherPurchase.id == record_id, OtherPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
import calendar
|
||||
from datetime import date, timedelta
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import EggCollection, FlockHistory, FeedPurchase, OtherPurchase
|
||||
from models import EggCollection, FlockHistory, FeedPurchase, OtherPurchase, User
|
||||
from schemas import DashboardStats, BudgetStats, MonthlySummary
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/stats", tags=["stats"])
|
||||
|
||||
|
||||
def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
|
||||
"""
|
||||
For each collection in the last 30 days, look up the flock size that was
|
||||
in effect on that date using a correlated subquery, then average eggs/hen
|
||||
across those days. This gives an accurate result even when flock size changed.
|
||||
"""
|
||||
def _today(user_timezone: str) -> date:
|
||||
try:
|
||||
return datetime.now(ZoneInfo(user_timezone)).date()
|
||||
except ZoneInfoNotFoundError:
|
||||
return date.today()
|
||||
|
||||
|
||||
def _avg_per_hen_30d(db: Session, user_id: int, start_30d: date) -> float | None:
|
||||
flock_at_date = (
|
||||
select(FlockHistory.chicken_count)
|
||||
.where(FlockHistory.user_id == user_id)
|
||||
.where(FlockHistory.date <= EggCollection.date)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
@@ -29,6 +34,7 @@ def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
|
||||
|
||||
rows = db.execute(
|
||||
select(EggCollection.eggs, flock_at_date.label('flock_count'))
|
||||
.where(EggCollection.user_id == user_id)
|
||||
.where(EggCollection.date >= start_30d)
|
||||
).all()
|
||||
|
||||
@@ -38,15 +44,18 @@ def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
|
||||
return round(sum(e / f for e, f in valid) / len(valid), 3)
|
||||
|
||||
|
||||
def _current_flock(db: Session) -> int | None:
|
||||
def _current_flock(db: Session, user_id: int) -> int | None:
|
||||
row = db.scalars(
|
||||
select(FlockHistory).order_by(FlockHistory.date.desc()).limit(1)
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == user_id)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
).first()
|
||||
return row.chicken_count if row else None
|
||||
|
||||
|
||||
def _total_eggs(db: Session, start: date | None = None, end: date | None = None) -> int:
|
||||
q = select(func.coalesce(func.sum(EggCollection.eggs), 0))
|
||||
def _total_eggs(db: Session, user_id: int, start: date | None = None, end: date | None = None) -> int:
|
||||
q = select(func.coalesce(func.sum(EggCollection.eggs), 0)).where(EggCollection.user_id == user_id)
|
||||
if start:
|
||||
q = q.where(EggCollection.date >= start)
|
||||
if end:
|
||||
@@ -54,10 +63,10 @@ def _total_eggs(db: Session, start: date | None = None, end: date | None = None)
|
||||
return db.scalar(q)
|
||||
|
||||
|
||||
def _total_feed_cost(db: Session, start: date | None = None, end: date | None = None):
|
||||
def _total_feed_cost(db: Session, user_id: int, start: date | None = None, end: date | None = None):
|
||||
q = select(
|
||||
func.coalesce(func.sum(FeedPurchase.bags * FeedPurchase.price_per_bag), 0)
|
||||
)
|
||||
).where(FeedPurchase.user_id == user_id)
|
||||
if start:
|
||||
q = q.where(FeedPurchase.date >= start)
|
||||
if end:
|
||||
@@ -65,8 +74,8 @@ def _total_feed_cost(db: Session, start: date | None = None, end: date | None =
|
||||
return db.scalar(q)
|
||||
|
||||
|
||||
def _total_other_cost(db: Session, start: date | None = None, end: date | None = None):
|
||||
q = select(func.coalesce(func.sum(OtherPurchase.total), 0))
|
||||
def _total_other_cost(db: Session, user_id: int, start: date | None = None, end: date | None = None):
|
||||
q = select(func.coalesce(func.sum(OtherPurchase.total), 0)).where(OtherPurchase.user_id == user_id)
|
||||
if start:
|
||||
q = q.where(OtherPurchase.date >= start)
|
||||
if end:
|
||||
@@ -75,29 +84,33 @@ def _total_other_cost(db: Session, start: date | None = None, end: date | None =
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardStats)
|
||||
def dashboard_stats(db: Session = Depends(get_db)):
|
||||
today = date.today()
|
||||
start_30d = today - timedelta(days=30)
|
||||
start_7d = today - timedelta(days=7)
|
||||
def dashboard_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
uid = current_user.id
|
||||
today = _today(current_user.timezone)
|
||||
start_30d = today - timedelta(days=30)
|
||||
start_7d = today - timedelta(days=7)
|
||||
|
||||
total_alltime = _total_eggs(db)
|
||||
total_30d = _total_eggs(db, start=start_30d)
|
||||
total_7d = _total_eggs(db, start=start_7d)
|
||||
flock = _current_flock(db)
|
||||
total_alltime = _total_eggs(db, uid)
|
||||
total_30d = _total_eggs(db, uid, start=start_30d)
|
||||
total_7d = _total_eggs(db, uid, start=start_7d)
|
||||
flock = _current_flock(db, uid)
|
||||
|
||||
# Count how many distinct days have a collection logged
|
||||
days_tracked = db.scalar(
|
||||
select(func.count(func.distinct(EggCollection.date)))
|
||||
.where(EggCollection.user_id == uid)
|
||||
)
|
||||
|
||||
# Average eggs per day over the last 30 days (only counting days with data)
|
||||
days_with_data_30d = db.scalar(
|
||||
select(func.count(func.distinct(EggCollection.date)))
|
||||
.where(EggCollection.user_id == uid)
|
||||
.where(EggCollection.date >= start_30d)
|
||||
)
|
||||
|
||||
avg_per_day = round(total_30d / days_with_data_30d, 2) if days_with_data_30d else None
|
||||
avg_per_hen = _avg_per_hen_30d(db, start_30d)
|
||||
avg_per_hen = _avg_per_hen_30d(db, uid, start_30d)
|
||||
|
||||
return DashboardStats(
|
||||
current_flock=flock,
|
||||
@@ -111,10 +124,13 @@ def dashboard_stats(db: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@router.get("/monthly", response_model=list[MonthlySummary])
|
||||
def monthly_stats(db: Session = Depends(get_db)):
|
||||
def monthly_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
uid = current_user.id
|
||||
MONTH_NAMES = ['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec']
|
||||
|
||||
# Monthly egg totals
|
||||
egg_rows = db.execute(
|
||||
select(
|
||||
func.year(EggCollection.date).label('year'),
|
||||
@@ -122,6 +138,7 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
func.sum(EggCollection.eggs).label('total_eggs'),
|
||||
func.count(EggCollection.date).label('days_logged'),
|
||||
)
|
||||
.where(EggCollection.user_id == uid)
|
||||
.group_by(func.year(EggCollection.date), func.month(EggCollection.date))
|
||||
.order_by(func.year(EggCollection.date).desc(), func.month(EggCollection.date).desc())
|
||||
).all()
|
||||
@@ -129,25 +146,25 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
if not egg_rows:
|
||||
return []
|
||||
|
||||
# Monthly feed costs
|
||||
feed_rows = db.execute(
|
||||
select(
|
||||
func.year(FeedPurchase.date).label('year'),
|
||||
func.month(FeedPurchase.date).label('month'),
|
||||
func.sum(FeedPurchase.bags * FeedPurchase.price_per_bag).label('feed_cost'),
|
||||
)
|
||||
.where(FeedPurchase.user_id == uid)
|
||||
.group_by(func.year(FeedPurchase.date), func.month(FeedPurchase.date))
|
||||
).all()
|
||||
|
||||
feed_map = {(r.year, r.month): r.feed_cost for r in feed_rows}
|
||||
|
||||
# Monthly other costs
|
||||
other_rows = db.execute(
|
||||
select(
|
||||
func.year(OtherPurchase.date).label('year'),
|
||||
func.month(OtherPurchase.date).label('month'),
|
||||
func.sum(OtherPurchase.total).label('other_cost'),
|
||||
)
|
||||
.where(OtherPurchase.user_id == uid)
|
||||
.group_by(func.year(OtherPurchase.date), func.month(OtherPurchase.date))
|
||||
).all()
|
||||
|
||||
@@ -159,9 +176,9 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
last_day = calendar.monthrange(y, m)[1]
|
||||
month_end = date(y, m, last_day)
|
||||
|
||||
# Flock size in effect at the end of this month
|
||||
flock_row = db.scalars(
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == uid)
|
||||
.where(FlockHistory.date <= month_end)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
@@ -201,16 +218,20 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@router.get("/budget", response_model=BudgetStats)
|
||||
def budget_stats(db: Session = Depends(get_db)):
|
||||
today = date.today()
|
||||
def budget_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
uid = current_user.id
|
||||
today = _today(current_user.timezone)
|
||||
start_30d = today - timedelta(days=30)
|
||||
|
||||
total_feed_cost = _total_feed_cost(db)
|
||||
total_feed_cost_30d = _total_feed_cost(db, start=start_30d)
|
||||
total_other_cost = _total_other_cost(db)
|
||||
total_other_cost_30d = _total_other_cost(db, start=start_30d)
|
||||
total_eggs = _total_eggs(db)
|
||||
total_eggs_30d = _total_eggs(db, start=start_30d)
|
||||
total_feed_cost = _total_feed_cost(db, uid)
|
||||
total_feed_cost_30d = _total_feed_cost(db, uid, start=start_30d)
|
||||
total_other_cost = _total_other_cost(db, uid)
|
||||
total_other_cost_30d = _total_other_cost(db, uid, start=start_30d)
|
||||
total_eggs = _total_eggs(db, uid)
|
||||
total_eggs_30d = _total_eggs(db, uid, start=start_30d)
|
||||
|
||||
def cost_per_egg(cost, eggs):
|
||||
if not eggs or not cost:
|
||||
|
||||
Reference in New Issue
Block a user