128 lines
4.1 KiB
Python
128 lines
4.1 KiB
Python
"""Server-side session storage in Postgres.

The cookie carries only an opaque token; the server resolves that token to the
user's email and decrypted IMAP password. Sessions can be listed and revoked
from /admin/. This module also provides helpers for writing and reading the
audit log.

Encryption: Fernet, key = urlsafe_b64encode(sha256(SECRET_KEY)). If SECRET_KEY
changes, existing session passwords become unreadable and users have to log in
again.
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import secrets
|
|
from datetime import datetime, timedelta
|
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
from flask import current_app
|
|
from sqlalchemy import delete, select, update
|
|
|
|
from ..db import db_session
|
|
from ..models import UserSession, AuditLog
|
|
|
|
|
|
# Random bytes per session token. secrets.token_hex renders each byte as two
# hex characters, so every token is 64 characters long.
_TOKEN_BYTES = 32
# Idle timeout: load() drops a session whose last_seen_at is older than this.
SESSION_TTL = timedelta(weeks=2)
|
|
|
|
|
|
def _fernet() -> Fernet:
    """Return the Fernet cipher used to encrypt stored session passwords.

    The key is derived from the app's ``SECRET_KEY``: its SHA-256 digest (the
    32 raw bytes Fernet requires) encoded with URL-safe base64.

    ``SECRET_KEY`` may be ``str`` or ``bytes``; Flask accepts both. A ``str``
    key is UTF-8 encoded first, so deployments with a ``str`` key derive
    exactly the same Fernet key as before. A ``bytes`` key is hashed directly,
    where previously it crashed on ``.encode()``.
    """
    secret = current_app.config.get("SECRET_KEY") or "dev-secret"
    # NOTE(review): with no SECRET_KEY configured this silently falls back to a
    # fixed, publicly known key, so anyone with DB access could decrypt the
    # stored passwords. Consider failing loudly outside development.
    if isinstance(secret, str):
        secret = secret.encode()
    return Fernet(base64.urlsafe_b64encode(hashlib.sha256(secret).digest()))
|
|
|
|
|
|
def encrypt(value: str) -> str:
    """Encrypt ``value`` with the app's Fernet key.

    The resulting Fernet token is returned as ``str`` so it can be stored in a
    text column and later passed back to ``decrypt``.
    """
    cipher = _fernet()
    token_bytes = cipher.encrypt(value.encode())
    return token_bytes.decode()
|
|
|
|
|
|
def decrypt(token: str) -> str:
    """Decrypt a Fernet token produced by ``encrypt``.

    Returns an empty string instead of raising when Fernet rejects the token
    (``InvalidToken``). This happens, for example, after ``SECRET_KEY`` has
    changed, because the data was encrypted under a different key.
    """
    cipher = _fernet()
    try:
        plaintext = cipher.decrypt(token.encode())
    except InvalidToken:
        return ""
    return plaintext.decode()
|
|
|
|
|
|
def create(user_email: str, password: str, ip: str = "", user_agent: str = "") -> str:
    """Open a new server-side session and return its opaque token.

    The IMAP password is stored Fernet-encrypted (via ``encrypt``). ``ip`` and
    ``user_agent`` are truncated to 64 and 400 characters respectively. The
    new row is committed before the token is returned.
    """
    s = db_session()
    new_token = secrets.token_hex(_TOKEN_BYTES)
    record = UserSession(
        token=new_token,
        user_email=user_email,
        password_enc=encrypt(password),
        ip=ip[:64],
        user_agent=user_agent[:400],
    )
    s.add(record)
    s.commit()
    return new_token
|
|
|
|
|
|
def load(token: str | None) -> dict | None:
    """Resolve a session token to the session's details.

    Returns ``None`` in four cases:

    - the token is empty;
    - no session row exists for it;
    - the session has been idle longer than ``SESSION_TTL``;
    - the stored password can no longer be decrypted (e.g. ``SECRET_KEY`` changed).

    In the last two cases the session row is also deleted.

    On success returns a dict with ``token``, ``user_email``, ``password``
    (decrypted), ``created_at``, ``last_seen_at``, ``ip`` and ``user_agent``.
    ``last_seen_at`` is refreshed at most once per minute to avoid a write on
    every request. When it is refreshed, the returned value is the new
    timestamp.
    """
    if not token:
        return None
    s = db_session()
    row = s.get(UserSession, token)
    if not row:
        return None

    # Read the clock once so the expiry check, the touch throttle and the
    # timestamp written all agree.
    now = datetime.utcnow()
    idle = now - row.last_seen_at
    if idle > SESSION_TTL:
        revoke(token)
        return None

    pw = decrypt(row.password_enc)
    if not pw:
        revoke(token)
        return None

    # Snapshot the row before committing. With SQLAlchemy's default
    # expire_on_commit, reading attributes after commit() would trigger an
    # extra SELECT to reload them.
    result = {"token": token, "user_email": row.user_email, "password": pw,
              "created_at": row.created_at, "last_seen_at": row.last_seen_at,
              "ip": row.ip, "user_agent": row.user_agent}

    # Touch last_seen_at, but at most once per minute to avoid hot writes.
    if idle.total_seconds() > 60:
        s.execute(update(UserSession).where(UserSession.token == token).values(last_seen_at=now))
        s.commit()
        result["last_seen_at"] = now
    return result
|
|
|
|
|
|
def revoke(token: str) -> None:
    """Delete the session row identified by ``token`` and commit.

    If no row matches, nothing is deleted; the commit still runs.
    """
    s = db_session()
    stmt = delete(UserSession).where(UserSession.token == token)
    s.execute(stmt)
    s.commit()
|
|
|
|
|
|
def list_for(user_email: str) -> list[dict]:
    """Return every stored session for ``user_email``, most recently used first.

    Each entry has ``token``, ``ip``, ``user_agent``, ``created_at`` and
    ``last_seen_at``. The encrypted password is not included.
    """
    s = db_session()
    query = (
        select(UserSession)
        .where(UserSession.user_email == user_email)
        .order_by(UserSession.last_seen_at.desc())
    )
    sessions = s.execute(query).scalars().all()
    return [
        {
            "token": sess.token,
            "ip": sess.ip,
            "user_agent": sess.user_agent,
            "created_at": sess.created_at,
            "last_seen_at": sess.last_seen_at,
        }
        for sess in sessions
    ]
|
|
|
|
|
|
def revoke_all_for(user_email: str, except_token: str | None = None) -> int:
    """Delete all sessions belonging to ``user_email`` and commit.

    If ``except_token`` is non-empty, the session with that token is kept.
    This lets a caller preserve one session (presumably its own).

    Returns the number of rows deleted, or 0 if the driver reports no count.
    """
    s = db_session()
    criteria = [UserSession.user_email == user_email]
    if except_token:
        criteria.append(UserSession.token != except_token)
    outcome = s.execute(delete(UserSession).where(*criteria))
    s.commit()
    return outcome.rowcount or 0
|
|
|
|
|
|
# ── Audit ───────────────────────────────────────────────────────────────
|
|
|
|
def audit(event: str, user_email: str = "", ip: str = "", user_agent: str = "", extra: str = "") -> None:
    """Append one entry to the audit log and commit it.

    ``ip``, ``user_agent`` and ``extra`` are truncated to 64, 400 and 5000
    characters respectively before storage. These presumably match the
    column sizes; verify against the AuditLog model.
    """
    s = db_session()
    entry = AuditLog(
        event=event,
        user_email=user_email,
        ip=ip[:64],
        user_agent=user_agent[:400],
        extra=extra[:5000],
    )
    s.add(entry)
    s.commit()
|
|
|
|
|
|
def recent_audit(limit: int = 50, user_email: str | None = None) -> list[dict]:
    """Return up to ``limit`` audit-log entries, newest first.

    If ``user_email`` is non-empty, only that user's entries are returned.
    Each entry has ``at``, ``user_email``, ``event``, ``ip``, ``user_agent``
    and ``extra``.
    """
    s = db_session()
    query = select(AuditLog)
    if user_email:
        query = query.where(AuditLog.user_email == user_email)
    query = query.order_by(AuditLog.at.desc()).limit(limit)
    entries = s.execute(query).scalars().all()
    return [
        {
            "at": entry.at,
            "user_email": entry.user_email,
            "event": entry.event,
            "ip": entry.ip,
            "user_agent": entry.user_agent,
            "extra": entry.extra,
        }
        for entry in entries
    ]
|