Add multi-user auth, admin panel, and timezone support; rename to Yolkbook
- Rename app from Eggtracker to Yolkbook throughout - Add JWT-based authentication (python-jose, passlib/bcrypt) - Add users table; all data tables gain user_id FK for full data isolation - Super admin credentials sourced from ADMIN_USERNAME/ADMIN_PASSWORD env vars, synced on every startup; orphaned rows auto-assigned to admin post-migration - Login page with self-registration; JWT stored in localStorage (30-day expiry) - Admin panel (/admin): list users, reset passwords, disable/enable, delete, and impersonate (Login As) with Return to Admin banner - Settings modal (gear icon in nav): timezone selector and change password - Timezone stored per-user; stats date windows computed in user's timezone; date input setToday() respects user timezone via Intl API - migrate_v2.sql for existing single-user installs - Auto-migration adds timezone column to users on startup - Updated README with full setup, auth, admin, and migration docs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
72
backend/auth.py
Normal file
72
backend/auth.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
|
||||
SECRET_KEY = os.environ["JWT_SECRET"]
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(user_id: int, username: str, is_admin: bool, user_timezone: str = "UTC") -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"username": username,
|
||||
"is_admin": is_admin,
|
||||
"timezone": user_timezone,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id_str: Optional[str] = payload.get("sub")
|
||||
if user_id_str is None:
|
||||
raise credentials_exception
|
||||
user_id = int(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
|
||||
user = db.get(User, user_id)
|
||||
if user is None or user.is_disabled:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin access required",
|
||||
)
|
||||
return current_user
|
||||
@@ -1,11 +1,81 @@
|
||||
import os
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import select, update, text
|
||||
|
||||
from database import Base, engine, SessionLocal
|
||||
from models import User, EggCollection, FlockHistory, FeedPurchase, OtherPurchase
|
||||
from auth import hash_password
|
||||
from routers import eggs, flock, feed, stats, other
|
||||
from routers import auth_router, admin
|
||||
|
||||
app = FastAPI(title="Eggtracker API")
|
||||
logger = logging.getLogger("yolkbook")
|
||||
|
||||
|
||||
def _seed_admin():
|
||||
"""Create or update the admin user from environment variables.
|
||||
Also assigns any records with NULL user_id to the admin (post-migration).
|
||||
"""
|
||||
admin_username = os.environ["ADMIN_USERNAME"]
|
||||
admin_password = os.environ["ADMIN_PASSWORD"]
|
||||
|
||||
with SessionLocal() as db:
|
||||
admin_user = db.scalars(
|
||||
select(User).where(User.username == admin_username)
|
||||
).first()
|
||||
|
||||
if admin_user is None:
|
||||
admin_user = User(
|
||||
username=admin_username,
|
||||
hashed_password=hash_password(admin_password),
|
||||
is_admin=True,
|
||||
)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.refresh(admin_user)
|
||||
logger.info("Admin user '%s' created.", admin_username)
|
||||
else:
|
||||
# Always sync password + admin flag from env vars
|
||||
admin_user.hashed_password = hash_password(admin_password)
|
||||
admin_user.is_admin = True
|
||||
db.commit()
|
||||
|
||||
# Assign orphaned records (from pre-migration data) to admin
|
||||
for model in [EggCollection, FlockHistory, FeedPurchase, OtherPurchase]:
|
||||
db.execute(
|
||||
update(model)
|
||||
.where(model.user_id == None) # noqa: E711
|
||||
.values(user_id=admin_user.id)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _run_migrations():
|
||||
"""Apply incremental schema changes that create_all won't handle on existing tables."""
|
||||
with SessionLocal() as db:
|
||||
# v2.1 — timezone column on users
|
||||
try:
|
||||
db.execute(text(
|
||||
"ALTER TABLE users ADD COLUMN timezone VARCHAR(64) NOT NULL DEFAULT 'UTC'"
|
||||
))
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback() # column already exists — safe to ignore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_run_migrations()
|
||||
_seed_admin()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Yolkbook API", lifespan=lifespan)
|
||||
|
||||
# Allow requests from the Nginx frontend (same host, different port internally)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -13,6 +83,8 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth_router.router)
|
||||
app.include_router(admin.router)
|
||||
app.include_router(eggs.router)
|
||||
app.include_router(flock.router)
|
||||
app.include_router(feed.router)
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
from datetime import date, datetime
|
||||
from sqlalchemy import Integer, Date, DateTime, Text, Numeric, func
|
||||
from sqlalchemy import Boolean, Integer, Date, DateTime, Text, Numeric, String, ForeignKey, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
is_disabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
timezone: Mapped[str] = mapped_column(String(64), nullable=False, default='UTC')
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
|
||||
class EggCollection(Base):
|
||||
__tablename__ = "egg_collections"
|
||||
__table_args__ = (UniqueConstraint("user_id", "date", name="uq_user_date"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
eggs: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
notes: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
@@ -18,6 +32,7 @@ class FlockHistory(Base):
|
||||
__tablename__ = "flock_history"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
chicken_count: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
notes: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
@@ -28,6 +43,7 @@ class FeedPurchase(Base):
|
||||
__tablename__ = "feed_purchases"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
bags: Mapped[float] = mapped_column(Numeric(5, 2), nullable=False)
|
||||
price_per_bag: Mapped[float] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
@@ -39,6 +55,7 @@ class OtherPurchase(Base):
|
||||
__tablename__ = "other_purchases"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
total: Mapped[float] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
notes: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
|
||||
@@ -4,3 +4,6 @@ sqlalchemy==2.0.36
|
||||
pymysql==1.1.1
|
||||
cryptography==43.0.3
|
||||
pydantic==2.9.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
bcrypt==4.0.1
|
||||
|
||||
111
backend/routers/admin.py
Normal file
111
backend/routers/admin.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
from schemas import UserCreate, UserOut, ResetPasswordRequest, TokenResponse
|
||||
from auth import hash_password, create_access_token, get_current_admin, get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserOut])
|
||||
def list_users(
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return db.scalars(select(User).order_by(User.created_at)).all()
|
||||
|
||||
|
||||
@router.post("/users", response_model=UserOut, status_code=201)
|
||||
def create_user(
|
||||
body: UserCreate,
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
existing = db.scalars(select(User).where(User.username == body.username)).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Username already taken")
|
||||
user = User(
|
||||
username=body.username,
|
||||
hashed_password=hash_password(body.password),
|
||||
is_admin=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/reset-password")
|
||||
def reset_password(
|
||||
user_id: int,
|
||||
body: ResetPasswordRequest,
|
||||
current_admin: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user.hashed_password = hash_password(body.new_password)
|
||||
db.commit()
|
||||
return {"detail": f"Password reset for {user.username}"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/disable")
|
||||
def disable_user(
|
||||
user_id: int,
|
||||
current_admin: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if user.id == current_admin.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot disable your own account")
|
||||
user.is_disabled = True
|
||||
db.commit()
|
||||
return {"detail": f"User {user.username} disabled"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/enable")
|
||||
def enable_user(
|
||||
user_id: int,
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user.is_disabled = False
|
||||
db.commit()
|
||||
return {"detail": f"User {user.username} enabled"}
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=204)
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
current_admin: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if user.id == current_admin.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete your own account")
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/impersonate", response_model=TokenResponse)
|
||||
def impersonate_user(
|
||||
user_id: int,
|
||||
_: User = Depends(get_current_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = create_access_token(user.id, user.username, user.is_admin, user.timezone)
|
||||
return TokenResponse(access_token=token)
|
||||
85
backend/routers/auth_router.py
Normal file
85
backend/routers/auth_router.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
from schemas import LoginRequest, TokenResponse, UserOut, UserCreate, ChangePasswordRequest, TimezoneUpdate
|
||||
from auth import verify_password, hash_password, create_access_token, get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _make_token(user: User) -> str:
|
||||
return create_access_token(user.id, user.username, user.is_admin, user.timezone)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
def login(body: LoginRequest, db: Session = Depends(get_db)):
|
||||
user = db.scalars(select(User).where(User.username == body.username)).first()
|
||||
if not user or not verify_password(body.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
if user.is_disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is disabled. Contact your administrator.",
|
||||
)
|
||||
return TokenResponse(access_token=_make_token(user))
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=201)
|
||||
def register(body: UserCreate, db: Session = Depends(get_db)):
|
||||
existing = db.scalars(select(User).where(User.username == body.username)).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Username already taken")
|
||||
# Default timezone to UTC; user can change it in settings
|
||||
user = User(
|
||||
username=body.username,
|
||||
hashed_password=hash_password(body.password),
|
||||
is_admin=False,
|
||||
timezone="UTC",
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return TokenResponse(access_token=_make_token(user))
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def me(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
def change_password(
|
||||
body: ChangePasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if not verify_password(body.current_password, current_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
current_user.hashed_password = hash_password(body.new_password)
|
||||
db.commit()
|
||||
return {"detail": "Password updated"}
|
||||
|
||||
|
||||
@router.put("/timezone", response_model=TokenResponse)
|
||||
def update_timezone(
|
||||
body: TimezoneUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
ZoneInfo(body.timezone) # validate it's a real IANA timezone
|
||||
except ZoneInfoNotFoundError:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown timezone: {body.timezone}")
|
||||
current_user.timezone = body.timezone
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
# Return a fresh token with the updated timezone embedded
|
||||
return TokenResponse(access_token=_make_token(current_user))
|
||||
@@ -6,8 +6,9 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import EggCollection
|
||||
from models import EggCollection, User
|
||||
from schemas import EggCollectionCreate, EggCollectionUpdate, EggCollectionOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/eggs", tags=["eggs"])
|
||||
|
||||
@@ -17,8 +18,13 @@ def list_eggs(
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = select(EggCollection).order_by(EggCollection.date.desc())
|
||||
q = (
|
||||
select(EggCollection)
|
||||
.where(EggCollection.user_id == current_user.id)
|
||||
.order_by(EggCollection.date.desc())
|
||||
)
|
||||
if start:
|
||||
q = q.where(EggCollection.date >= start)
|
||||
if end:
|
||||
@@ -27,8 +33,12 @@ def list_eggs(
|
||||
|
||||
|
||||
@router.post("", response_model=EggCollectionOut, status_code=201)
|
||||
def create_egg_collection(body: EggCollectionCreate, db: Session = Depends(get_db)):
|
||||
record = EggCollection(**body.model_dump())
|
||||
def create_egg_collection(
|
||||
body: EggCollectionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = EggCollection(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
try:
|
||||
db.commit()
|
||||
@@ -44,8 +54,12 @@ def update_egg_collection(
|
||||
record_id: int,
|
||||
body: EggCollectionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(EggCollection, record_id)
|
||||
record = db.scalars(
|
||||
select(EggCollection)
|
||||
.where(EggCollection.id == record_id, EggCollection.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -54,14 +68,21 @@ def update_egg_collection(
|
||||
db.commit()
|
||||
except IntegrityError:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail=f"An entry for that date already exists.")
|
||||
raise HTTPException(status_code=409, detail="An entry for that date already exists.")
|
||||
db.refresh(record)
|
||||
return record
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_egg_collection(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(EggCollection, record_id)
|
||||
def delete_egg_collection(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(EggCollection)
|
||||
.where(EggCollection.id == record_id, EggCollection.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -5,8 +5,9 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import FeedPurchase
|
||||
from models import FeedPurchase, User
|
||||
from schemas import FeedPurchaseCreate, FeedPurchaseUpdate, FeedPurchaseOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/feed", tags=["feed"])
|
||||
|
||||
@@ -16,8 +17,13 @@ def list_feed_purchases(
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = select(FeedPurchase).order_by(FeedPurchase.date.desc())
|
||||
q = (
|
||||
select(FeedPurchase)
|
||||
.where(FeedPurchase.user_id == current_user.id)
|
||||
.order_by(FeedPurchase.date.desc())
|
||||
)
|
||||
if start:
|
||||
q = q.where(FeedPurchase.date >= start)
|
||||
if end:
|
||||
@@ -26,8 +32,12 @@ def list_feed_purchases(
|
||||
|
||||
|
||||
@router.post("", response_model=FeedPurchaseOut, status_code=201)
|
||||
def create_feed_purchase(body: FeedPurchaseCreate, db: Session = Depends(get_db)):
|
||||
record = FeedPurchase(**body.model_dump())
|
||||
def create_feed_purchase(
|
||||
body: FeedPurchaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = FeedPurchase(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
@@ -39,8 +49,12 @@ def update_feed_purchase(
|
||||
record_id: int,
|
||||
body: FeedPurchaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(FeedPurchase, record_id)
|
||||
record = db.scalars(
|
||||
select(FeedPurchase)
|
||||
.where(FeedPurchase.id == record_id, FeedPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -51,8 +65,15 @@ def update_feed_purchase(
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_feed_purchase(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(FeedPurchase, record_id)
|
||||
def delete_feed_purchase(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(FeedPurchase)
|
||||
.where(FeedPurchase.id == record_id, FeedPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -5,30 +5,49 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import FlockHistory
|
||||
from models import FlockHistory, User
|
||||
from schemas import FlockHistoryCreate, FlockHistoryUpdate, FlockHistoryOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/flock", tags=["flock"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[FlockHistoryOut])
|
||||
def list_flock_history(db: Session = Depends(get_db)):
|
||||
q = select(FlockHistory).order_by(FlockHistory.date.desc())
|
||||
def list_flock_history(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = (
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == current_user.id)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
)
|
||||
return db.scalars(q).all()
|
||||
|
||||
|
||||
@router.get("/current", response_model=Optional[FlockHistoryOut])
|
||||
def get_current_flock(db: Session = Depends(get_db)):
|
||||
"""Returns the most recent flock entry — the current flock size."""
|
||||
q = select(FlockHistory).order_by(FlockHistory.date.desc()).limit(1)
|
||||
def get_current_flock(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = (
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == current_user.id)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return db.scalars(q).first()
|
||||
|
||||
|
||||
@router.get("/at/{target_date}", response_model=Optional[FlockHistoryOut])
|
||||
def get_flock_at_date(target_date: date, db: Session = Depends(get_db)):
|
||||
"""Returns the flock size that was in effect on a given date."""
|
||||
def get_flock_at_date(
|
||||
target_date: date,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = (
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == current_user.id)
|
||||
.where(FlockHistory.date <= target_date)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
@@ -37,8 +56,12 @@ def get_flock_at_date(target_date: date, db: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@router.post("", response_model=FlockHistoryOut, status_code=201)
|
||||
def create_flock_entry(body: FlockHistoryCreate, db: Session = Depends(get_db)):
|
||||
record = FlockHistory(**body.model_dump())
|
||||
def create_flock_entry(
|
||||
body: FlockHistoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = FlockHistory(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
@@ -50,8 +73,12 @@ def update_flock_entry(
|
||||
record_id: int,
|
||||
body: FlockHistoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(FlockHistory, record_id)
|
||||
record = db.scalars(
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.id == record_id, FlockHistory.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -62,8 +89,15 @@ def update_flock_entry(
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_flock_entry(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(FlockHistory, record_id)
|
||||
def delete_flock_entry(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.id == record_id, FlockHistory.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -5,8 +5,9 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import OtherPurchase
|
||||
from models import OtherPurchase, User
|
||||
from schemas import OtherPurchaseCreate, OtherPurchaseUpdate, OtherPurchaseOut
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/other", tags=["other"])
|
||||
|
||||
@@ -16,8 +17,13 @@ def list_other_purchases(
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
q = select(OtherPurchase).order_by(OtherPurchase.date.desc())
|
||||
q = (
|
||||
select(OtherPurchase)
|
||||
.where(OtherPurchase.user_id == current_user.id)
|
||||
.order_by(OtherPurchase.date.desc())
|
||||
)
|
||||
if start:
|
||||
q = q.where(OtherPurchase.date >= start)
|
||||
if end:
|
||||
@@ -26,8 +32,12 @@ def list_other_purchases(
|
||||
|
||||
|
||||
@router.post("", response_model=OtherPurchaseOut, status_code=201)
|
||||
def create_other_purchase(body: OtherPurchaseCreate, db: Session = Depends(get_db)):
|
||||
record = OtherPurchase(**body.model_dump())
|
||||
def create_other_purchase(
|
||||
body: OtherPurchaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = OtherPurchase(**body.model_dump(), user_id=current_user.id)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
@@ -39,8 +49,12 @@ def update_other_purchase(
|
||||
record_id: int,
|
||||
body: OtherPurchaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.get(OtherPurchase, record_id)
|
||||
record = db.scalars(
|
||||
select(OtherPurchase)
|
||||
.where(OtherPurchase.id == record_id, OtherPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
@@ -51,8 +65,15 @@ def update_other_purchase(
|
||||
|
||||
|
||||
@router.delete("/{record_id}", status_code=204)
|
||||
def delete_other_purchase(record_id: int, db: Session = Depends(get_db)):
|
||||
record = db.get(OtherPurchase, record_id)
|
||||
def delete_other_purchase(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
record = db.scalars(
|
||||
select(OtherPurchase)
|
||||
.where(OtherPurchase.id == record_id, OtherPurchase.user_id == current_user.id)
|
||||
).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
db.delete(record)
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
import calendar
|
||||
from datetime import date, timedelta
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import EggCollection, FlockHistory, FeedPurchase, OtherPurchase
|
||||
from models import EggCollection, FlockHistory, FeedPurchase, OtherPurchase, User
|
||||
from schemas import DashboardStats, BudgetStats, MonthlySummary
|
||||
from auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/stats", tags=["stats"])
|
||||
|
||||
|
||||
def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
|
||||
"""
|
||||
For each collection in the last 30 days, look up the flock size that was
|
||||
in effect on that date using a correlated subquery, then average eggs/hen
|
||||
across those days. This gives an accurate result even when flock size changed.
|
||||
"""
|
||||
def _today(user_timezone: str) -> date:
|
||||
try:
|
||||
return datetime.now(ZoneInfo(user_timezone)).date()
|
||||
except ZoneInfoNotFoundError:
|
||||
return date.today()
|
||||
|
||||
|
||||
def _avg_per_hen_30d(db: Session, user_id: int, start_30d: date) -> float | None:
|
||||
flock_at_date = (
|
||||
select(FlockHistory.chicken_count)
|
||||
.where(FlockHistory.user_id == user_id)
|
||||
.where(FlockHistory.date <= EggCollection.date)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
@@ -29,6 +34,7 @@ def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
|
||||
|
||||
rows = db.execute(
|
||||
select(EggCollection.eggs, flock_at_date.label('flock_count'))
|
||||
.where(EggCollection.user_id == user_id)
|
||||
.where(EggCollection.date >= start_30d)
|
||||
).all()
|
||||
|
||||
@@ -38,15 +44,18 @@ def _avg_per_hen_30d(db: Session, start_30d: date) -> float | None:
|
||||
return round(sum(e / f for e, f in valid) / len(valid), 3)
|
||||
|
||||
|
||||
def _current_flock(db: Session) -> int | None:
|
||||
def _current_flock(db: Session, user_id: int) -> int | None:
|
||||
row = db.scalars(
|
||||
select(FlockHistory).order_by(FlockHistory.date.desc()).limit(1)
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == user_id)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
).first()
|
||||
return row.chicken_count if row else None
|
||||
|
||||
|
||||
def _total_eggs(db: Session, start: date | None = None, end: date | None = None) -> int:
|
||||
q = select(func.coalesce(func.sum(EggCollection.eggs), 0))
|
||||
def _total_eggs(db: Session, user_id: int, start: date | None = None, end: date | None = None) -> int:
|
||||
q = select(func.coalesce(func.sum(EggCollection.eggs), 0)).where(EggCollection.user_id == user_id)
|
||||
if start:
|
||||
q = q.where(EggCollection.date >= start)
|
||||
if end:
|
||||
@@ -54,10 +63,10 @@ def _total_eggs(db: Session, start: date | None = None, end: date | None = None)
|
||||
return db.scalar(q)
|
||||
|
||||
|
||||
def _total_feed_cost(db: Session, start: date | None = None, end: date | None = None):
|
||||
def _total_feed_cost(db: Session, user_id: int, start: date | None = None, end: date | None = None):
|
||||
q = select(
|
||||
func.coalesce(func.sum(FeedPurchase.bags * FeedPurchase.price_per_bag), 0)
|
||||
)
|
||||
).where(FeedPurchase.user_id == user_id)
|
||||
if start:
|
||||
q = q.where(FeedPurchase.date >= start)
|
||||
if end:
|
||||
@@ -65,8 +74,8 @@ def _total_feed_cost(db: Session, start: date | None = None, end: date | None =
|
||||
return db.scalar(q)
|
||||
|
||||
|
||||
def _total_other_cost(db: Session, start: date | None = None, end: date | None = None):
|
||||
q = select(func.coalesce(func.sum(OtherPurchase.total), 0))
|
||||
def _total_other_cost(db: Session, user_id: int, start: date | None = None, end: date | None = None):
|
||||
q = select(func.coalesce(func.sum(OtherPurchase.total), 0)).where(OtherPurchase.user_id == user_id)
|
||||
if start:
|
||||
q = q.where(OtherPurchase.date >= start)
|
||||
if end:
|
||||
@@ -75,29 +84,33 @@ def _total_other_cost(db: Session, start: date | None = None, end: date | None =
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardStats)
|
||||
def dashboard_stats(db: Session = Depends(get_db)):
|
||||
today = date.today()
|
||||
start_30d = today - timedelta(days=30)
|
||||
start_7d = today - timedelta(days=7)
|
||||
def dashboard_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
uid = current_user.id
|
||||
today = _today(current_user.timezone)
|
||||
start_30d = today - timedelta(days=30)
|
||||
start_7d = today - timedelta(days=7)
|
||||
|
||||
total_alltime = _total_eggs(db)
|
||||
total_30d = _total_eggs(db, start=start_30d)
|
||||
total_7d = _total_eggs(db, start=start_7d)
|
||||
flock = _current_flock(db)
|
||||
total_alltime = _total_eggs(db, uid)
|
||||
total_30d = _total_eggs(db, uid, start=start_30d)
|
||||
total_7d = _total_eggs(db, uid, start=start_7d)
|
||||
flock = _current_flock(db, uid)
|
||||
|
||||
# Count how many distinct days have a collection logged
|
||||
days_tracked = db.scalar(
|
||||
select(func.count(func.distinct(EggCollection.date)))
|
||||
.where(EggCollection.user_id == uid)
|
||||
)
|
||||
|
||||
# Average eggs per day over the last 30 days (only counting days with data)
|
||||
days_with_data_30d = db.scalar(
|
||||
select(func.count(func.distinct(EggCollection.date)))
|
||||
.where(EggCollection.user_id == uid)
|
||||
.where(EggCollection.date >= start_30d)
|
||||
)
|
||||
|
||||
avg_per_day = round(total_30d / days_with_data_30d, 2) if days_with_data_30d else None
|
||||
avg_per_hen = _avg_per_hen_30d(db, start_30d)
|
||||
avg_per_hen = _avg_per_hen_30d(db, uid, start_30d)
|
||||
|
||||
return DashboardStats(
|
||||
current_flock=flock,
|
||||
@@ -111,10 +124,13 @@ def dashboard_stats(db: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@router.get("/monthly", response_model=list[MonthlySummary])
|
||||
def monthly_stats(db: Session = Depends(get_db)):
|
||||
def monthly_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
uid = current_user.id
|
||||
MONTH_NAMES = ['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec']
|
||||
|
||||
# Monthly egg totals
|
||||
egg_rows = db.execute(
|
||||
select(
|
||||
func.year(EggCollection.date).label('year'),
|
||||
@@ -122,6 +138,7 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
func.sum(EggCollection.eggs).label('total_eggs'),
|
||||
func.count(EggCollection.date).label('days_logged'),
|
||||
)
|
||||
.where(EggCollection.user_id == uid)
|
||||
.group_by(func.year(EggCollection.date), func.month(EggCollection.date))
|
||||
.order_by(func.year(EggCollection.date).desc(), func.month(EggCollection.date).desc())
|
||||
).all()
|
||||
@@ -129,25 +146,25 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
if not egg_rows:
|
||||
return []
|
||||
|
||||
# Monthly feed costs
|
||||
feed_rows = db.execute(
|
||||
select(
|
||||
func.year(FeedPurchase.date).label('year'),
|
||||
func.month(FeedPurchase.date).label('month'),
|
||||
func.sum(FeedPurchase.bags * FeedPurchase.price_per_bag).label('feed_cost'),
|
||||
)
|
||||
.where(FeedPurchase.user_id == uid)
|
||||
.group_by(func.year(FeedPurchase.date), func.month(FeedPurchase.date))
|
||||
).all()
|
||||
|
||||
feed_map = {(r.year, r.month): r.feed_cost for r in feed_rows}
|
||||
|
||||
# Monthly other costs
|
||||
other_rows = db.execute(
|
||||
select(
|
||||
func.year(OtherPurchase.date).label('year'),
|
||||
func.month(OtherPurchase.date).label('month'),
|
||||
func.sum(OtherPurchase.total).label('other_cost'),
|
||||
)
|
||||
.where(OtherPurchase.user_id == uid)
|
||||
.group_by(func.year(OtherPurchase.date), func.month(OtherPurchase.date))
|
||||
).all()
|
||||
|
||||
@@ -159,9 +176,9 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
last_day = calendar.monthrange(y, m)[1]
|
||||
month_end = date(y, m, last_day)
|
||||
|
||||
# Flock size in effect at the end of this month
|
||||
flock_row = db.scalars(
|
||||
select(FlockHistory)
|
||||
.where(FlockHistory.user_id == uid)
|
||||
.where(FlockHistory.date <= month_end)
|
||||
.order_by(FlockHistory.date.desc())
|
||||
.limit(1)
|
||||
@@ -201,16 +218,20 @@ def monthly_stats(db: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@router.get("/budget", response_model=BudgetStats)
|
||||
def budget_stats(db: Session = Depends(get_db)):
|
||||
today = date.today()
|
||||
def budget_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
uid = current_user.id
|
||||
today = _today(current_user.timezone)
|
||||
start_30d = today - timedelta(days=30)
|
||||
|
||||
total_feed_cost = _total_feed_cost(db)
|
||||
total_feed_cost_30d = _total_feed_cost(db, start=start_30d)
|
||||
total_other_cost = _total_other_cost(db)
|
||||
total_other_cost_30d = _total_other_cost(db, start=start_30d)
|
||||
total_eggs = _total_eggs(db)
|
||||
total_eggs_30d = _total_eggs(db, start=start_30d)
|
||||
total_feed_cost = _total_feed_cost(db, uid)
|
||||
total_feed_cost_30d = _total_feed_cost(db, uid, start=start_30d)
|
||||
total_other_cost = _total_other_cost(db, uid)
|
||||
total_other_cost_30d = _total_other_cost(db, uid, start=start_30d)
|
||||
total_eggs = _total_eggs(db, uid)
|
||||
total_eggs_30d = _total_eggs(db, uid, start=start_30d)
|
||||
|
||||
def cost_per_egg(cost, eggs):
|
||||
if not eggs or not cost:
|
||||
|
||||
@@ -4,6 +4,44 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str = Field(min_length=6)
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
new_password: str = Field(min_length=6)
|
||||
|
||||
class TimezoneUpdate(BaseModel):
|
||||
timezone: str = Field(min_length=1, max_length=64)
|
||||
|
||||
|
||||
# ── Users ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str = Field(min_length=2, max_length=64)
|
||||
password: str = Field(min_length=6)
|
||||
|
||||
class UserOut(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
is_admin: bool
|
||||
is_disabled: bool
|
||||
timezone: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Egg Collections ───────────────────────────────────────────────────────────
|
||||
|
||||
class EggCollectionCreate(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user