Files
s1ne/backend/web/db_extra.py
2026-03-29 23:50:49 -05:00

321 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
backend/web/db_extra.py - 16 May 2025
"""
from __future__ import annotations
import asyncio
import datetime as dt
from datetime import timezone
import structlog
from typing import List
from sqlalchemy import (
Table, Column, Text, Float, Integer, Boolean, DateTime, func,
select, insert, update, delete, or_, inspect, text
)
from sqlalchemy.dialects.postgresql import insert as pg_insert
from backend.core.db import SessionLocal, metadata, engine
from backend.core.settings import (
MAX_IN_USE, FAIL_COOLDOWN_SEC, MIN_SCORE, PROXY_LIST_FILE,
_WINDOW_MINUTES, PROXY_USERNAME, PROXY_PASSWORD,
_MAX_LOGIN_FAILS, _MAX_INVALID_URLS
)
log = structlog.get_logger()
_IS_PG = engine.url.get_backend_name().startswith("postgres")
def _insert_ignore(tbl: Table, **vals):
    """Build an INSERT that silently skips conflicting (duplicate-key) rows.

    On PostgreSQL this is ``INSERT ... ON CONFLICT DO NOTHING``; on other
    backends (SQLite) it falls back to the ``INSERT OR IGNORE`` prefix.
    """
    if not _IS_PG:
        return insert(tbl).prefix_with("OR IGNORE").values(**vals)
    return pg_insert(tbl).values(**vals).on_conflict_do_nothing()
def _clamp_zero(expr):
    """Portable SQL expression for ``max(expr, 0)``.

    PostgreSQL spells the scalar maximum ``GREATEST()``; SQLite overloads
    ``MAX()`` as a scalar function when given two arguments.
    """
    if not _IS_PG:
        return func.max(expr, 0)
    return func.greatest(expr, 0)
# Proxy pool: one row per proxy URL with health/score bookkeeping.
proxy_tbl = Table(
    "proxies", metadata,
    Column("url", Text, primary_key=True),                     # full proxy URL (may embed credentials)
    Column("score", Float, nullable=False, server_default="1.0"),   # health score; raised/lowered by _flusher
    Column("fails", Integer, nullable=False, server_default="0"),   # rolling failure counter
    Column("banned", Boolean, nullable=False, server_default="false"),
    Column("in_use", Integer, nullable=False, server_default="0"),  # concurrent checkouts (see acquire/release)
    Column("last_fail", DateTime),                             # NULL until a failure; cleared on success
    Column("updated_at", DateTime, server_default=func.now(), index=True),
)
# Per-IP failed-login counter (reset on successful login or window expiry).
login_tbl = Table(
    "login_attempts", metadata,
    Column("ip", Text, primary_key=True),
    Column("count", Integer, nullable=False, server_default="0"),
    Column("updated_at", DateTime, nullable=False, server_default=func.now()),
)
# Per-IP invalid-URL counter (same windowed-limit scheme as login_attempts).
invalid_tbl = Table(
    "invalid_urls", metadata,
    Column("ip", Text, primary_key=True),
    Column("count", Integer, nullable=False, server_default="0"),
    Column("updated_at", DateTime, nullable=False, server_default=func.now()),
)
# Download outcome log; trimmed to the newest 500 rows by add_dl_stat().
dl_stats = Table(
    "dl_stats", metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("ok", Boolean, nullable=False),
    Column("ts", DateTime, nullable=False, server_default=func.now(), index=True),
)
def _ensure_proxy_columns() -> None:
    """Lightweight auto-migration: add ``in_use`` / ``last_fail`` columns.

    Older deployments may have a ``proxies`` table created before these
    columns existed; this adds any that are missing and logs the change.
    """
    inspector = inspect(engine)
    if "proxies" not in inspector.get_table_names():
        return
    present = {col["name"] for col in inspector.get_columns("proxies")}
    wanted = (("in_use", "INTEGER DEFAULT 0"), ("last_fail", "TIMESTAMP"))
    missing = [(name, ddl) for name, ddl in wanted if name not in present]
    if not missing:
        return
    with engine.begin() as conn:
        for name, ddl in missing:
            # Postgres supports the idempotent IF NOT EXISTS guard; SQLite does not.
            guard = "IF NOT EXISTS " if _IS_PG else ""
            conn.execute(text(f"ALTER TABLE proxies ADD COLUMN {guard}{name} {ddl};"))
    log.info("proxy.schema.auto_migrated", added=[name for name, _ in missing])
# NOTE(review): create_all is commented out — presumably table creation
# happens elsewhere (e.g. backend.core.db); confirm before re-enabling.
#metadata.create_all(engine)
# Run the column auto-migration once at import time.
_ensure_proxy_columns()
def _seed() -> None:
    """Seed the proxy pool from PROXY_LIST_FILE (one ``ip:port`` per line).

    Rows already present are left untouched (insert-ignore).  Blank
    lines, ``#`` comment lines and lines without a ``:`` separator are
    skipped.

    Fix: the original called ``ln.split(":", 1)`` unconditionally and
    raised ValueError on any malformed line, aborting the whole seed
    transaction; malformed lines are now tolerated.
    """
    if not PROXY_LIST_FILE.exists():
        return
    with SessionLocal.begin() as s:
        for raw in PROXY_LIST_FILE.read_text().splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or ":" not in line:
                continue
            ip, port = line.split(":", 1)
            px = (
                f"http://{PROXY_USERNAME}:{PROXY_PASSWORD}@{ip}:{port}"
                if PROXY_USERNAME else f"http://{ip}:{port}"
            )
            s.execute(_insert_ignore(proxy_tbl, url=px))
def _candidate_stmt(now: dt.datetime):
    """Build the SELECT choosing the single best proxy to check out.

    Eligible proxies are not banned, above MIN_SCORE, below the
    concurrent-use cap, and past the failure cool-down.  A small random
    jitter is added to the ordering so equally-scored proxies rotate.
    The row is locked with ``SELECT ... FOR UPDATE`` on backends that
    support it; SQLite ignores the locking clause.
    """
    cool_ts = now - dt.timedelta(seconds=FAIL_COOLDOWN_SEC)
    # Bug fix: random() yields a float in [0, 1) on Postgres but a signed
    # 64-bit *integer* on SQLite, so `random() * 0.01` swamped the score
    # there and selection became effectively uniform-random.  Scale the
    # SQLite value into roughly the same +/-0.01 tie-breaking band.
    if _IS_PG:
        jitter = func.random() * 0.01
    else:
        jitter = (func.random() % 100) / 10000.0
    return (
        select(proxy_tbl.c.url)
        .where(
            proxy_tbl.c.banned.is_(False),
            proxy_tbl.c.score > MIN_SCORE,
            proxy_tbl.c.in_use < MAX_IN_USE,
            or_(proxy_tbl.c.last_fail.is_(None), proxy_tbl.c.last_fail < cool_ts),
        )
        .order_by((proxy_tbl.c.score + jitter).desc())
        .limit(1)
        .with_for_update(nowait=False)
    )
def acquire_proxy() -> str | None:
    """Check out the best available proxy URL, or None when the pool is dry.

    The in_use counter is incremented in the same transaction that
    selected (and, on Postgres, row-locked) the candidate.
    """
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as session:
        candidate = session.execute(_candidate_stmt(now)).first()
        if candidate is None:
            return None
        url = candidate[0]
        session.execute(
            update(proxy_tbl)
            .where(proxy_tbl.c.url == url)
            .values(in_use=proxy_tbl.c.in_use + 1, updated_at=now)
        )
        return url
def release_proxy(px: str, ok: bool) -> None:
    """Return a checked-out proxy to the pool.

    Decrements ``in_use`` (clamped at zero server-side).  On failure the
    ``last_fail`` timestamp is stamped; on success it is cleared.
    """
    if not px or px == "DIRECT":
        # Direct connections are not tracked in the pool.
        return
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as session:
        session.execute(
            update(proxy_tbl)
            .where(proxy_tbl.c.url == px)
            .values(
                in_use=_clamp_zero(proxy_tbl.c.in_use - 1),
                updated_at=now,
                last_fail=now if not ok else None,
            )
        )
# Bounded in-memory queue of (proxy_url, success) results awaiting the
# periodic DB flush in _flusher(); queue_proxy_result drops the oldest
# entry when the queue is full.
_buffer: asyncio.Queue[tuple[str, bool]] = asyncio.Queue(maxsize=2048)
def queue_proxy_result(px: str, ok: bool) -> None:
    """Enqueue a proxy outcome for asynchronous, batched flushing to the DB.

    When the buffer is full the oldest entry is dropped to make room; if
    that still fails the result is discarded (deliberate best effort).
    """
    try:
        _buffer.put_nowait((px, ok))
    except asyncio.QueueFull:
        # Fix: catch only the queue's own exceptions instead of a blanket
        # `except Exception: pass`, which would hide unrelated bugs.
        try:
            _buffer.get_nowait()
            _buffer.put_nowait((px, ok))
        except (asyncio.QueueEmpty, asyncio.QueueFull):
            pass
async def _flusher() -> None:
    """Background task: drain _buffer every 0.4 s and apply score updates.

    Queued results are aggregated per proxy URL into (success, fail)
    counts so each proxy receives a single UPDATE per flush cycle.
    """
    while True:
        await asyncio.sleep(0.4)
        if _buffer.empty():
            continue
        # url -> (success_count, fail_count) for this flush cycle.
        batch: dict[str, tuple[int, int]] = {}
        while not _buffer.empty():
            px, ok = _buffer.get_nowait()
            succ, fail = batch.get(px, (0, 0))
            if ok:
                succ += 1
            else:
                fail += 1
            batch[px] = (succ, fail)
        now = dt.datetime.now(timezone.utc)
        with SessionLocal.begin() as s:
            for px, (succ, fail) in batch.items():
                # Score moves +0.1 per success, -0.2 per failure,
                # clamped at zero server-side.
                delta = 0.1 * succ - 0.2 * fail
                stmt = (
                    update(proxy_tbl)
                    .where(proxy_tbl.c.url == px)
                    .values(
                        score=_clamp_zero(proxy_tbl.c.score + delta),
                        # `fails` is credited back by successes...
                        fails=_clamp_zero(proxy_tbl.c.fails + fail - succ),
                        # ...but the ban test ignores that credit.
                        # NOTE(review): a proxy can be banned here even when
                        # its updated fails count ends up <= 5 — confirm
                        # this asymmetry is intentional.
                        banned=(proxy_tbl.c.fails + fail) > 5,
                        updated_at=now,
                    )
                )
                s.execute(stmt)
def start_background_tasks(loop: asyncio.AbstractEventLoop) -> None:
    """Schedule the score flusher and the proxy-list seeding on *loop*.

    Seeding is blocking file/DB work, so it runs in a worker thread via
    asyncio.to_thread.
    """
    for coro in (_flusher(), asyncio.to_thread(_seed)):
        loop.create_task(coro)
# Default number of most-recent download attempts considered by
# recent_success_rate().
_WINDOW_N = 50
def add_dl_stat(ok: bool) -> None:
    """Record one download outcome and trim dl_stats to its 500 newest rows."""
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as session:
        session.execute(insert(dl_stats).values(ok=ok, ts=now))
        # Keep only the 500 newest rows (highest autoincrement ids).
        newest_ids = (
            select(dl_stats.c.id)
            .order_by(dl_stats.c.id.desc())
            .limit(500)
        )
        session.execute(delete(dl_stats).where(~dl_stats.c.id.in_(newest_ids)))
def recent_success_rate(n: int = _WINDOW_N) -> float:
    """Fraction of the last *n* downloads that succeeded (0.5 with no data)."""
    with SessionLocal() as session:
        stmt = select(dl_stats.c.ok).order_by(dl_stats.c.id.desc()).limit(n)
        outcomes = session.execute(stmt).scalars().all()
    if not outcomes:
        # Neutral prior when there is no history yet.
        return 0.5
    return sum(outcomes) / len(outcomes)
def _inc(table: Table, ip: str) -> None:
    """Atomically bump the per-IP counter in *table*, creating the row at 0.

    Fix: the original read the current count into Python and wrote
    ``count + 1`` back, which loses increments when two requests race.
    An insert-ignore followed by a server-side ``count = count + 1`` is
    race-free and produces the same result (new IPs end at 1).
    """
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as s:
        s.execute(_insert_ignore(table, ip=ip, count=0, updated_at=now))
        s.execute(
            update(table)
            .where(table.c.ip == ip)
            .values(count=table.c.count + 1, updated_at=now)
        )
def record_login(ip: str, success: bool) -> None:
    """On failure count one attempt for *ip*; on success reset its counter."""
    if not success:
        _inc(login_tbl, ip)
        return
    with SessionLocal.begin() as session:
        session.execute(update(login_tbl).where(login_tbl.c.ip == ip).values(count=0))
def inc_invalid(ip: str) -> None:
    """Count one invalid-URL request from *ip*."""
    _inc(invalid_tbl, ip)
def _over_limit(table: Table, ip: str, cap: int) -> bool:
    """Return True when *ip*'s counter in *table* has reached *cap*.

    A counter last updated more than _WINDOW_MINUTES ago is reset to
    zero and treated as not over the limit.
    """
    with SessionLocal() as session:
        record = session.execute(
            select(table.c.count, table.c.updated_at).where(table.c.ip == ip)
        ).first()
        if record is None:
            return False
        count, stamp = record
        # Legacy/SQLite rows may carry naive timestamps; treat them as UTC.
        if stamp.tzinfo is None:
            stamp = stamp.replace(tzinfo=timezone.utc)
        age_sec = (dt.datetime.now(timezone.utc) - stamp).total_seconds()
        if age_sec > _WINDOW_MINUTES * 60:
            # Window expired: reset in a fresh write transaction.
            with SessionLocal.begin() as writer:
                writer.execute(update(table).where(table.c.ip == ip).values(count=0))
            return False
        return count >= cap
def too_many_attempts(ip: str) -> bool:
    """True when *ip* has hit the failed-login cap inside the window."""
    return _over_limit(login_tbl, ip, _MAX_LOGIN_FAILS)
def invalid_over_limit(ip: str) -> bool:
    """True when *ip* has hit the invalid-URL cap inside the window."""
    return _over_limit(invalid_tbl, ip, _MAX_INVALID_URLS)
def pick_proxy() -> str | None:
    """Backward-compatible alias for acquire_proxy()."""
    return acquire_proxy()
def ensure_proxy(px: str) -> None:
    """Add *px* to the proxy pool unless it is already present."""
    with SessionLocal.begin() as session:
        session.execute(_insert_ignore(proxy_tbl, url=px))
def update_proxy(px: str, ok: bool) -> None:
    """Backward-compatible alias: queue a proxy outcome for batched flushing."""
    queue_proxy_result(px, ok)
async def init_proxy_seed() -> None:
    """Run the blocking proxy-list seeding in a worker thread."""
    await asyncio.to_thread(_seed)