Add multi-user auth, admin panel, and timezone support; rename to Yolkbook

- Rename app from Eggtracker to Yolkbook throughout
- Add JWT-based authentication (python-jose, passlib/bcrypt)
- Add users table; all data tables gain user_id FK for full data isolation
- Super admin credentials sourced from ADMIN_USERNAME/ADMIN_PASSWORD env vars,
  synced on every startup; orphaned rows auto-assigned to admin post-migration
- Login page with self-registration; JWT stored in localStorage (30-day expiry)
- Admin panel (/admin): list users, reset passwords, disable/enable, delete,
  and impersonate (Login As) with Return to Admin banner
- Settings modal (gear icon in nav): timezone selector and change password
- Timezone stored per-user; stats date windows computed in user's timezone;
  date input setToday() respects user timezone via Intl API
- migrate_v2.sql for existing single-user installs
- Auto-migration adds timezone column to users on startup
- Updated README with full setup, auth, admin, and migration docs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-17 23:19:29 -07:00
parent 7d50af0054
commit aa12648228
31 changed files with 1572 additions and 140 deletions

72
backend/auth.py Normal file
View File

@@ -0,0 +1,72 @@
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from database import get_db
from models import User
SECRET_KEY = os.environ["JWT_SECRET"]
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_DAYS = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def hash_password(password: str) -> str:
return pwd_context.hash(password)
def create_access_token(user_id: int, username: str, is_admin: bool, user_timezone: str = "UTC") -> str:
expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
payload = {
"sub": str(user_id),
"username": username,
"is_admin": is_admin,
"timezone": user_timezone,
"exp": expire,
}
return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id_str: Optional[str] = payload.get("sub")
if user_id_str is None:
raise credentials_exception
user_id = int(user_id_str)
except (JWTError, ValueError):
raise credentials_exception
user = db.get(User, user_id)
if user is None or user.is_disabled:
raise credentials_exception
return user
async def get_current_admin(current_user: User = Depends(get_current_user)) -> User:
if not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin access required",
)
return current_user

View File

@@ -1,11 +1,81 @@
import os
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select, update, text
from database import Base, engine, SessionLocal
from models import User, EggCollection, FlockHistory, FeedPurchase, OtherPurchase
from auth import hash_password
from routers import eggs, flock, feed, stats, other
from routers import auth_router, admin
app = FastAPI(title="Eggtracker API")
logger = logging.getLogger("yolkbook")
def _seed_admin():
"""Create or update the admin user from environment variables.
Also assigns any records with NULL user_id to the admin (post-migration).
"""
admin_username = os.environ["ADMIN_USERNAME"]
admin_password = os.environ["ADMIN_PASSWORD"]
with SessionLocal() as db:
admin_user = db.scalars(
select(User).where(User.username == admin_username)
).first()
if admin_user is None:
admin_user = User(
username=admin_username,
hashed_password=hash_password(admin_password),
is_admin=True,
)
db.add(admin_user)
db.commit()
db.refresh(admin_user)
logger.info("Admin user '%s' created.", admin_username)
else:
# Always sync password + admin flag from env vars
admin_user.hashed_password = hash_password(admin_password)
admin_user.is_admin = True
db.commit()
# Assign orphaned records (from pre-migration data) to admin
for model in [EggCollection, FlockHistory, FeedPurchase, OtherPurchase]:
db.execute(
update(model)
.where(model.user_id == None) # noqa: E711
.values(user_id=admin_user.id)
)
db.commit()
def _run_migrations():
"""Apply incremental schema changes that create_all won't handle on existing tables."""
with SessionLocal() as db:
# v2.1 — timezone column on users
try:
db.execute(text(
"ALTER TABLE users ADD COLUMN timezone VARCHAR(64) NOT NULL DEFAULT 'UTC'"
))
db.commit()
except Exception:
db.rollback() # column already exists — safe to ignore
@asynccontextmanager
async def lifespan(app: FastAPI):
Base.metadata.create_all(bind=engine)
_run_migrations()
_seed_admin()
yield
app = FastAPI(title="Yolkbook API", lifespan=lifespan)
# Allow requests from the Nginx frontend (same host, different port internally)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -13,6 +83,8 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(auth_router.router)
app.include_router(admin.router)
app.include_router(eggs.router)
app.include_router(flock.router)
app.include_router(feed.router)

View File

@@ -1,13 +1,27 @@
from datetime import date, datetime
from sqlalchemy import Integer, Date, DateTime, Text, Numeric, func
from sqlalchemy import Boolean, Integer, Date, DateTime, Text, Numeric, String, ForeignKey, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
from database import Base
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
is_admin: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
is_disabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
timezone: Mapped[str] = mapped_column(String(64), nullable=False, default='UTC')
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
class EggCollection(Base):
__tablename__ = "egg_collections"
__table_args__ = (UniqueConstraint("user_id", "date", name="uq_user_date"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
eggs: Mapped[int] = mapped_column(Integer, nullable=False)
notes: Mapped[str] = mapped_column(Text, nullable=True)
@@ -18,6 +32,7 @@ class FlockHistory(Base):
__tablename__ = "flock_history"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
chicken_count: Mapped[int] = mapped_column(Integer, nullable=False)
notes: Mapped[str] = mapped_column(Text, nullable=True)
@@ -28,6 +43,7 @@ class FeedPurchase(Base):
__tablename__ = "feed_purchases"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
bags: Mapped[float] = mapped_column(Numeric(5, 2), nullable=False)
price_per_bag: Mapped[float] = mapped_column(Numeric(10, 2), nullable=False)
@@ -39,6 +55,7 @@ class OtherPurchase(Base):
__tablename__ = "other_purchases"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
total: Mapped[float] = mapped_column(Numeric(10, 2), nullable=False)
notes: Mapped[str] = mapped_column(Text, nullable=True)

View File

@@ -4,3 +4,6 @@ sqlalchemy==2.0.36
pymysql==1.1.1
cryptography==43.0.3
pydantic==2.9.2
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.0.1

111
backend/routers/admin.py Normal file
View File

@@ -0,0 +1,111 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from database import get_db
from models import User
from schemas import UserCreate, UserOut, ResetPasswordRequest, TokenResponse
from auth import hash_password, create_access_token, get_current_admin, get_current_user
router = APIRouter(prefix="/api/admin", tags=["admin"])
@router.get("/users", response_model=list[UserOut])
def list_users(
_: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
return db.scalars(select(User).order_by(User.created_at)).all()
@router.post("/users", response_model=UserOut, status_code=201)
def create_user(
body: UserCreate,
_: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
existing = db.scalars(select(User).where(User.username == body.username)).first()
if existing:
raise HTTPException(status_code=409, detail="Username already taken")
user = User(
username=body.username,
hashed_password=hash_password(body.password),
is_admin=False,
)
db.add(user)
db.commit()
db.refresh(user)
return user
@router.post("/users/{user_id}/reset-password")
def reset_password(
user_id: int,
body: ResetPasswordRequest,
current_admin: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
user = db.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
user.hashed_password = hash_password(body.new_password)
db.commit()
return {"detail": f"Password reset for {user.username}"}
@router.post("/users/{user_id}/disable")
def disable_user(
user_id: int,
current_admin: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
user = db.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if user.id == current_admin.id:
raise HTTPException(status_code=400, detail="Cannot disable your own account")
user.is_disabled = True
db.commit()
return {"detail": f"User {user.username} disabled"}
@router.post("/users/{user_id}/enable")
def enable_user(
user_id: int,
_: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
user = db.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
user.is_disabled = False
db.commit()
return {"detail": f"User {user.username} enabled"}
@router.delete("/users/{user_id}", status_code=204)
def delete_user(
user_id: int,
current_admin: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
user = db.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if user.id == current_admin.id:
raise HTTPException(status_code=400, detail="Cannot delete your own account")
db.delete(user)
db.commit()
@router.post("/users/{user_id}/impersonate", response_model=TokenResponse)
def impersonate_user(
user_id: int,
_: User = Depends(get_current_admin),
db: Session = Depends(get_db),
):
user = db.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
token = create_access_token(user.id, user.username, user.is_admin, user.timezone)
return TokenResponse(access_token=token)

View File

@@ -0,0 +1,85 @@
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.orm import Session
from database import get_db
from models import User
from schemas import LoginRequest, TokenResponse, UserOut, UserCreate, ChangePasswordRequest, TimezoneUpdate
from auth import verify_password, hash_password, create_access_token, get_current_user
router = APIRouter(prefix="/api/auth", tags=["auth"])
def _make_token(user: User) -> str:
return create_access_token(user.id, user.username, user.is_admin, user.timezone)
@router.post("/login", response_model=TokenResponse)
def login(body: LoginRequest, db: Session = Depends(get_db)):
user = db.scalars(select(User).where(User.username == body.username)).first()
if not user or not verify_password(body.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password",
)
if user.is_disabled:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is disabled. Contact your administrator.",
)
return TokenResponse(access_token=_make_token(user))
@router.post("/register", response_model=TokenResponse, status_code=201)
def register(body: UserCreate, db: Session = Depends(get_db)):
existing = db.scalars(select(User).where(User.username == body.username)).first()
if existing:
raise HTTPException(status_code=409, detail="Username already taken")
# Default timezone to UTC; user can change it in settings
user = User(
username=body.username,
hashed_password=hash_password(body.password),
is_admin=False,
timezone="UTC",
)
db.add(user)
db.commit()
db.refresh(user)
return TokenResponse(access_token=_make_token(user))
@router.get("/me", response_model=UserOut)
def me(current_user: User = Depends(get_current_user)):
return current_user
@router.post("/change-password")
def change_password(
body: ChangePasswordRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
if not verify_password(body.current_password, current_user.hashed_password):
raise HTTPException(status_code=400, detail="Current password is incorrect")
current_user.hashed_password = hash_password(body.new_password)
db.commit()
return {"detail": "Password updated"}
@router.put("/timezone", response_model=TokenResponse)
def update_timezone(
body: TimezoneUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
try:
ZoneInfo(body.timezone) # validate it's a real IANA timezone
except ZoneInfoNotFoundError:
raise HTTPException(status_code=400, detail=f"Unknown timezone: {body.timezone}")
current_user.timezone = body.timezone
db.commit()
db.refresh(current_user)
# Return a fresh token with the updated timezone embedded
return TokenResponse(access_token=_make_token(current_user))

View File

@@ -6,8 +6,9 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from database import get_db
from models import EggCollection
from models import EggCollection, User
from schemas import EggCollectionCreate, EggCollectionUpdate, EggCollectionOut
from auth import get_current_user
router = APIRouter(prefix="/api/eggs", tags=["eggs"])
@@ -17,8 +18,13 @@ def list_eggs(
start: Optional[date] = None,
end: Optional[date] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
q = select(EggCollection).order_by(EggCollection.date.desc())
q = (
select(EggCollection)
.where(EggCollection.user_id == current_user.id)
.order_by(EggCollection.date.desc())
)
if start:
q = q.where(EggCollection.date >= start)
if end:
@@ -27,8 +33,12 @@ def list_eggs(
@router.post("", response_model=EggCollectionOut, status_code=201)
def create_egg_collection(body: EggCollectionCreate, db: Session = Depends(get_db)):
record = EggCollection(**body.model_dump())
def create_egg_collection(
body: EggCollectionCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = EggCollection(**body.model_dump(), user_id=current_user.id)
db.add(record)
try:
db.commit()
@@ -44,8 +54,12 @@ def update_egg_collection(
record_id: int,
body: EggCollectionUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.get(EggCollection, record_id)
record = db.scalars(
select(EggCollection)
.where(EggCollection.id == record_id, EggCollection.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
for field, value in body.model_dump(exclude_none=True).items():
@@ -54,14 +68,21 @@ def update_egg_collection(
db.commit()
except IntegrityError:
db.rollback()
raise HTTPException(status_code=409, detail=f"An entry for that date already exists.")
raise HTTPException(status_code=409, detail="An entry for that date already exists.")
db.refresh(record)
return record
@router.delete("/{record_id}", status_code=204)
def delete_egg_collection(record_id: int, db: Session = Depends(get_db)):
record = db.get(EggCollection, record_id)
def delete_egg_collection(
record_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.scalars(
select(EggCollection)
.where(EggCollection.id == record_id, EggCollection.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
db.delete(record)

View File

@@ -5,8 +5,9 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from database import get_db
from models import FeedPurchase
from models import FeedPurchase, User
from schemas import FeedPurchaseCreate, FeedPurchaseUpdate, FeedPurchaseOut
from auth import get_current_user
router = APIRouter(prefix="/api/feed", tags=["feed"])
@@ -16,8 +17,13 @@ def list_feed_purchases(
start: Optional[date] = None,
end: Optional[date] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
q = select(FeedPurchase).order_by(FeedPurchase.date.desc())
q = (
select(FeedPurchase)
.where(FeedPurchase.user_id == current_user.id)
.order_by(FeedPurchase.date.desc())
)
if start:
q = q.where(FeedPurchase.date >= start)
if end:
@@ -26,8 +32,12 @@ def list_feed_purchases(
@router.post("", response_model=FeedPurchaseOut, status_code=201)
def create_feed_purchase(body: FeedPurchaseCreate, db: Session = Depends(get_db)):
record = FeedPurchase(**body.model_dump())
def create_feed_purchase(
body: FeedPurchaseCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = FeedPurchase(**body.model_dump(), user_id=current_user.id)
db.add(record)
db.commit()
db.refresh(record)
@@ -39,8 +49,12 @@ def update_feed_purchase(
record_id: int,
body: FeedPurchaseUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.get(FeedPurchase, record_id)
record = db.scalars(
select(FeedPurchase)
.where(FeedPurchase.id == record_id, FeedPurchase.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
for field, value in body.model_dump(exclude_none=True).items():
@@ -51,8 +65,15 @@ def update_feed_purchase(
@router.delete("/{record_id}", status_code=204)
def delete_feed_purchase(record_id: int, db: Session = Depends(get_db)):
record = db.get(FeedPurchase, record_id)
def delete_feed_purchase(
record_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.scalars(
select(FeedPurchase)
.where(FeedPurchase.id == record_id, FeedPurchase.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
db.delete(record)

View File

@@ -5,30 +5,49 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from database import get_db
from models import FlockHistory
from models import FlockHistory, User
from schemas import FlockHistoryCreate, FlockHistoryUpdate, FlockHistoryOut
from auth import get_current_user
router = APIRouter(prefix="/api/flock", tags=["flock"])
@router.get("", response_model=list[FlockHistoryOut])
def list_flock_history(db: Session = Depends(get_db)):
q = select(FlockHistory).order_by(FlockHistory.date.desc())
def list_flock_history(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
q = (
select(FlockHistory)
.where(FlockHistory.user_id == current_user.id)
.order_by(FlockHistory.date.desc())
)
return db.scalars(q).all()
@router.get("/current", response_model=Optional[FlockHistoryOut])
def get_current_flock(db: Session = Depends(get_db)):
"""Returns the most recent flock entry — the current flock size."""
q = select(FlockHistory).order_by(FlockHistory.date.desc()).limit(1)
def get_current_flock(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
q = (
select(FlockHistory)
.where(FlockHistory.user_id == current_user.id)
.order_by(FlockHistory.date.desc())
.limit(1)
)
return db.scalars(q).first()
@router.get("/at/{target_date}", response_model=Optional[FlockHistoryOut])
def get_flock_at_date(target_date: date, db: Session = Depends(get_db)):
"""Returns the flock size that was in effect on a given date."""
def get_flock_at_date(
target_date: date,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
q = (
select(FlockHistory)
.where(FlockHistory.user_id == current_user.id)
.where(FlockHistory.date <= target_date)
.order_by(FlockHistory.date.desc())
.limit(1)
@@ -37,8 +56,12 @@ def get_flock_at_date(target_date: date, db: Session = Depends(get_db)):
@router.post("", response_model=FlockHistoryOut, status_code=201)
def create_flock_entry(body: FlockHistoryCreate, db: Session = Depends(get_db)):
record = FlockHistory(**body.model_dump())
def create_flock_entry(
body: FlockHistoryCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = FlockHistory(**body.model_dump(), user_id=current_user.id)
db.add(record)
db.commit()
db.refresh(record)
@@ -50,8 +73,12 @@ def update_flock_entry(
record_id: int,
body: FlockHistoryUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.get(FlockHistory, record_id)
record = db.scalars(
select(FlockHistory)
.where(FlockHistory.id == record_id, FlockHistory.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
for field, value in body.model_dump(exclude_none=True).items():
@@ -62,8 +89,15 @@ def update_flock_entry(
@router.delete("/{record_id}", status_code=204)
def delete_flock_entry(record_id: int, db: Session = Depends(get_db)):
record = db.get(FlockHistory, record_id)
def delete_flock_entry(
record_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.scalars(
select(FlockHistory)
.where(FlockHistory.id == record_id, FlockHistory.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
db.delete(record)

View File

@@ -5,8 +5,9 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from database import get_db
from models import OtherPurchase
from models import OtherPurchase, User
from schemas import OtherPurchaseCreate, OtherPurchaseUpdate, OtherPurchaseOut
from auth import get_current_user
router = APIRouter(prefix="/api/other", tags=["other"])
@@ -16,8 +17,13 @@ def list_other_purchases(
start: Optional[date] = None,
end: Optional[date] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
q = select(OtherPurchase).order_by(OtherPurchase.date.desc())
q = (
select(OtherPurchase)
.where(OtherPurchase.user_id == current_user.id)
.order_by(OtherPurchase.date.desc())
)
if start:
q = q.where(OtherPurchase.date >= start)
if end:
@@ -26,8 +32,12 @@ def list_other_purchases(
@router.post("", response_model=OtherPurchaseOut, status_code=201)
def create_other_purchase(body: OtherPurchaseCreate, db: Session = Depends(get_db)):
record = OtherPurchase(**body.model_dump())
def create_other_purchase(
body: OtherPurchaseCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = OtherPurchase(**body.model_dump(), user_id=current_user.id)
db.add(record)
db.commit()
db.refresh(record)
@@ -39,8 +49,12 @@ def update_other_purchase(
record_id: int,
body: OtherPurchaseUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.get(OtherPurchase, record_id)
record = db.scalars(
select(OtherPurchase)
.where(OtherPurchase.id == record_id, OtherPurchase.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
for field, value in body.model_dump(exclude_none=True).items():
@@ -51,8 +65,15 @@ def update_other_purchase(
@router.delete("/{record_id}", status_code=204)
def delete_other_purchase(record_id: int, db: Session = Depends(get_db)):
record = db.get(OtherPurchase, record_id)
def delete_other_purchase(
record_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
record = db.scalars(
select(OtherPurchase)
.where(OtherPurchase.id == record_id, OtherPurchase.user_id == current_user.id)
).first()
if not record:
raise HTTPException(status_code=404, detail="Record not found")
db.delete(record)

View File

@@ -1,25 +1,30 @@
import calendar
from datetime import date, timedelta
from datetime import date, datetime, timedelta
from decimal import Decimal
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from fastapi import APIRouter, Depends
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from database import get_db
from models import EggCollection, FlockHistory, FeedPurchase, OtherPurchase
from models import EggCollection, FlockHistory, FeedPurchase, OtherPurchase, User
from schemas import DashboardStats, BudgetStats, MonthlySummary
from auth import get_current_user
router = APIRouter(prefix="/api/stats", tags=["stats"])
def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
"""
For each collection in the last 30 days, look up the flock size that was
in effect on that date using a correlated subquery, then average eggs/hen
across those days. This gives an accurate result even when flock size changed.
"""
def _today(user_timezone: str) -> date:
try:
return datetime.now(ZoneInfo(user_timezone)).date()
except ZoneInfoNotFoundError:
return date.today()
def _avg_per_hen_30d(db: Session, user_id: int, start_30d: date) -> float | None:
flock_at_date = (
select(FlockHistory.chicken_count)
.where(FlockHistory.user_id == user_id)
.where(FlockHistory.date <= EggCollection.date)
.order_by(FlockHistory.date.desc())
.limit(1)
@@ -29,6 +34,7 @@ def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
rows = db.execute(
select(EggCollection.eggs, flock_at_date.label('flock_count'))
.where(EggCollection.user_id == user_id)
.where(EggCollection.date >= start_30d)
).all()
@@ -38,15 +44,18 @@ def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
return round(sum(e / f for e, f in valid) / len(valid), 3)
def _current_flock(db: Session) -> int | None:
def _current_flock(db: Session, user_id: int) -> int | None:
row = db.scalars(
select(FlockHistory).order_by(FlockHistory.date.desc()).limit(1)
select(FlockHistory)
.where(FlockHistory.user_id == user_id)
.order_by(FlockHistory.date.desc())
.limit(1)
).first()
return row.chicken_count if row else None
def _total_eggs(db: Session, start: date | None = None, end: date | None = None) -> int:
q = select(func.coalesce(func.sum(EggCollection.eggs), 0))
def _total_eggs(db: Session, user_id: int, start: date | None = None, end: date | None = None) -> int:
q = select(func.coalesce(func.sum(EggCollection.eggs), 0)).where(EggCollection.user_id == user_id)
if start:
q = q.where(EggCollection.date >= start)
if end:
@@ -54,10 +63,10 @@ def _total_eggs(db: Session, start: date | None = None, end: date | None = None)
return db.scalar(q)
def _total_feed_cost(db: Session, start: date | None = None, end: date | None = None):
def _total_feed_cost(db: Session, user_id: int, start: date | None = None, end: date | None = None):
q = select(
func.coalesce(func.sum(FeedPurchase.bags * FeedPurchase.price_per_bag), 0)
)
).where(FeedPurchase.user_id == user_id)
if start:
q = q.where(FeedPurchase.date >= start)
if end:
@@ -65,8 +74,8 @@ def _total_feed_cost(db: Session, start: date | None = None, end: date | None =
return db.scalar(q)
def _total_other_cost(db: Session, start: date | None = None, end: date | None = None):
q = select(func.coalesce(func.sum(OtherPurchase.total), 0))
def _total_other_cost(db: Session, user_id: int, start: date | None = None, end: date | None = None):
q = select(func.coalesce(func.sum(OtherPurchase.total), 0)).where(OtherPurchase.user_id == user_id)
if start:
q = q.where(OtherPurchase.date >= start)
if end:
@@ -75,29 +84,33 @@ def _total_other_cost(db: Session, start: date | None = None, end: date | None =
@router.get("/dashboard", response_model=DashboardStats)
def dashboard_stats(db: Session = Depends(get_db)):
today = date.today()
start_30d = today - timedelta(days=30)
start_7d = today - timedelta(days=7)
def dashboard_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
uid = current_user.id
today = _today(current_user.timezone)
start_30d = today - timedelta(days=30)
start_7d = today - timedelta(days=7)
total_alltime = _total_eggs(db)
total_30d = _total_eggs(db, start=start_30d)
total_7d = _total_eggs(db, start=start_7d)
flock = _current_flock(db)
total_alltime = _total_eggs(db, uid)
total_30d = _total_eggs(db, uid, start=start_30d)
total_7d = _total_eggs(db, uid, start=start_7d)
flock = _current_flock(db, uid)
# Count how many distinct days have a collection logged
days_tracked = db.scalar(
select(func.count(func.distinct(EggCollection.date)))
.where(EggCollection.user_id == uid)
)
# Average eggs per day over the last 30 days (only counting days with data)
days_with_data_30d = db.scalar(
select(func.count(func.distinct(EggCollection.date)))
.where(EggCollection.user_id == uid)
.where(EggCollection.date >= start_30d)
)
avg_per_day = round(total_30d / days_with_data_30d, 2) if days_with_data_30d else None
avg_per_hen = _avg_per_hen_30d(db, start_30d)
avg_per_hen = _avg_per_hen_30d(db, uid, start_30d)
return DashboardStats(
current_flock=flock,
@@ -111,10 +124,13 @@ def dashboard_stats(db: Session = Depends(get_db)):
@router.get("/monthly", response_model=list[MonthlySummary])
def monthly_stats(db: Session = Depends(get_db)):
def monthly_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
uid = current_user.id
MONTH_NAMES = ['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec']
# Monthly egg totals
egg_rows = db.execute(
select(
func.year(EggCollection.date).label('year'),
@@ -122,6 +138,7 @@ def monthly_stats(db: Session = Depends(get_db)):
func.sum(EggCollection.eggs).label('total_eggs'),
func.count(EggCollection.date).label('days_logged'),
)
.where(EggCollection.user_id == uid)
.group_by(func.year(EggCollection.date), func.month(EggCollection.date))
.order_by(func.year(EggCollection.date).desc(), func.month(EggCollection.date).desc())
).all()
@@ -129,25 +146,25 @@ def monthly_stats(db: Session = Depends(get_db)):
if not egg_rows:
return []
# Monthly feed costs
feed_rows = db.execute(
select(
func.year(FeedPurchase.date).label('year'),
func.month(FeedPurchase.date).label('month'),
func.sum(FeedPurchase.bags * FeedPurchase.price_per_bag).label('feed_cost'),
)
.where(FeedPurchase.user_id == uid)
.group_by(func.year(FeedPurchase.date), func.month(FeedPurchase.date))
).all()
feed_map = {(r.year, r.month): r.feed_cost for r in feed_rows}
# Monthly other costs
other_rows = db.execute(
select(
func.year(OtherPurchase.date).label('year'),
func.month(OtherPurchase.date).label('month'),
func.sum(OtherPurchase.total).label('other_cost'),
)
.where(OtherPurchase.user_id == uid)
.group_by(func.year(OtherPurchase.date), func.month(OtherPurchase.date))
).all()
@@ -159,9 +176,9 @@ def monthly_stats(db: Session = Depends(get_db)):
last_day = calendar.monthrange(y, m)[1]
month_end = date(y, m, last_day)
# Flock size in effect at the end of this month
flock_row = db.scalars(
select(FlockHistory)
.where(FlockHistory.user_id == uid)
.where(FlockHistory.date <= month_end)
.order_by(FlockHistory.date.desc())
.limit(1)
@@ -201,16 +218,20 @@ def monthly_stats(db: Session = Depends(get_db)):
@router.get("/budget", response_model=BudgetStats)
def budget_stats(db: Session = Depends(get_db)):
today = date.today()
def budget_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
uid = current_user.id
today = _today(current_user.timezone)
start_30d = today - timedelta(days=30)
total_feed_cost = _total_feed_cost(db)
total_feed_cost_30d = _total_feed_cost(db, start=start_30d)
total_other_cost = _total_other_cost(db)
total_other_cost_30d = _total_other_cost(db, start=start_30d)
total_eggs = _total_eggs(db)
total_eggs_30d = _total_eggs(db, start=start_30d)
total_feed_cost = _total_feed_cost(db, uid)
total_feed_cost_30d = _total_feed_cost(db, uid, start=start_30d)
total_other_cost = _total_other_cost(db, uid)
total_other_cost_30d = _total_other_cost(db, uid, start=start_30d)
total_eggs = _total_eggs(db, uid)
total_eggs_30d = _total_eggs(db, uid, start=start_30d)
def cost_per_egg(cost, eggs):
if not eggs or not cost:

View File

@@ -4,6 +4,44 @@ from typing import Optional
from pydantic import BaseModel, Field
# ── Auth ──────────────────────────────────────────────────────────────────────
class LoginRequest(BaseModel):
username: str
password: str
class TokenResponse(BaseModel):
access_token: str
token_type: str = "bearer"
class ChangePasswordRequest(BaseModel):
current_password: str
new_password: str = Field(min_length=6)
class ResetPasswordRequest(BaseModel):
new_password: str = Field(min_length=6)
class TimezoneUpdate(BaseModel):
timezone: str = Field(min_length=1, max_length=64)
# ── Users ─────────────────────────────────────────────────────────────────────
class UserCreate(BaseModel):
username: str = Field(min_length=2, max_length=64)
password: str = Field(min_length=6)
class UserOut(BaseModel):
id: int
username: str
is_admin: bool
is_disabled: bool
timezone: str
created_at: datetime
model_config = {"from_attributes": True}
# ── Egg Collections ───────────────────────────────────────────────────────────
class EggCollectionCreate(BaseModel):