init
0
backend/__init__.py
Normal file
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
Generic single-database configuration.
97
backend/alembic/env.py
Normal file
@@ -0,0 +1,97 @@
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
import os
from dotenv import load_dotenv
from alembic import context
from backend.core.settings import SQLALCHEMY_DATABASE_URI
from backend.core.db import metadata

# Load .env manually (if not loaded by your settings.py or entry point)
load_dotenv()

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Get environment (default to development if not set)
APP_ENV = os.getenv("APP_ENV", "development")
config.set_main_option("sqlalchemy.url", SQLALCHEMY_DATABASE_URI)

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

# Determine DB URL
if APP_ENV == "production":
    db_url = os.getenv("DATABASE_URL")
    if not db_url:
        raise RuntimeError("DATABASE_URL is required in production.")
else:
    # fall back to a local SQLite db
    db_url = os.getenv("SQLALCHEMY_DATABASE_URI", "sqlite:///data/local.db")


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url", db_url)
    print(f"→ Alembic using DB URL: {db_url}")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    # sslmode is a PostgreSQL-only connect arg; SQLite would reject it
    connect_args = {"sslmode": "require"} if db_url.startswith("postgresql") else {}
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        connect_args=connect_args,
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
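For reference, a minimal sketch of driving these migrations from Python rather than the shell (the argv values are illustrative; alembic.config.main is the same entry point the alembic command uses):

from alembic.config import main

# equivalent to `alembic -c alembic.ini upgrade head` from the project root,
# assuming alembic.ini sits next to the backend/ package
main(argv=["-c", "alembic.ini", "upgrade", "head"])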
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}
@@ -0,0 +1,28 @@
"""merge user-branch fixes

Revision ID: 03568bb37289
Revises: user_counters_defaults, user_counters_defaults_old
Create Date: 2025-05-01 16:01:28.514674

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '03568bb37289'
down_revision: Union[str, Sequence[str], None] = ('user_counters_defaults', 'user_counters_defaults_old')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    pass


def downgrade() -> None:
    """Downgrade schema."""
    pass
67
backend/alembic/versions/05d6342e2105_tier_text_not_int.py
Normal file
@@ -0,0 +1,67 @@
"""tier text not int

Revision ID: 05d6342e2105
Revises: 20250506abcd
Create Date: 2025-05-06 16:42:38.378374

"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text

# revision identifiers, used by Alembic.
revision: str = '05d6342e2105'
down_revision: Union[str, None] = '20250506abcd'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # 1) make sure every numeric tier is inside the allowed range
    #    (optional safety; no type change yet)
    op.execute("""
        UPDATE users
        SET    tier = 0
        WHERE  pg_typeof(tier)::text = 'integer'
          AND  tier NOT IN (0,1,2,3);
    """)

    # 2) ALTER COLUMN to TEXT first – numeric values become '0','1',...
    op.alter_column(
        "users", "tier",
        existing_type=sa.Integer(),
        type_=sa.Text(),
        postgresql_using="tier::text",
        nullable=False,
        server_default=sa.text("'Online'")
    )

    # 3) now map the stringified numbers to names
    op.execute("""
        UPDATE users SET tier =
            CASE tier
                WHEN '0' THEN 'Online'
                WHEN '1' THEN 'Rank 1'
                WHEN '2' THEN 'Rank 2'
                WHEN '3' THEN 'Rank 3'
                ELSE tier
            END;
    """)


def downgrade() -> None:
    # revert to integer, mapping back Online→0 etc. if you really need it
    op.alter_column(
        "users", "tier",
        existing_type=sa.Text(),
        type_=sa.Integer(),
        postgresql_using="""
            CASE tier
                WHEN 'Rank 1' THEN 1
                WHEN 'Rank 2' THEN 2
                WHEN 'Rank 3' THEN 3
                ELSE 0
            END::integer
        """,
        server_default=sa.text("0"),
        nullable=False,
    )
@@ -0,0 +1,37 @@
"""add bonus_active_until to users

Revision ID: 175f03f1c9f7
Revises: ff38ddad43af
Create Date: 2025-05-01 16:58:15.855501

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector


# revision identifiers, used by Alembic.
revision: str = '175f03f1c9f7'
down_revision: Union[str, None] = 'ff38ddad43af'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


# helper
def _has_column(table, column):
    bind = op.get_bind()
    insp = Inspector.from_engine(bind)
    return column in [c["name"] for c in insp.get_columns(table)]


def upgrade():
    if not _has_column("users", "bonus_active_until"):
        op.add_column(
            "users",
            sa.Column("bonus_active_until", sa.DateTime(), nullable=True),
        )


def downgrade():
    op.drop_column("users", "bonus_active_until")
55
backend/alembic/versions/20250521_remove_xp_system.py
Normal file
@@ -0,0 +1,55 @@
"""Drop XP / tier related columns (+data) if they exist

Revision ID: 20250521_remove_xp_system
Revises: 05d6342e2105
Create Date: 2025-05-21
"""
from alembic import op
import sqlalchemy as sa


revision = "20250521_remove_xp_system"
down_revision = "05d6342e2105"
branch_labels = None
depends_on = None


# added "data" here ↓↓↓
XP_COLUMNS = (
    "videos_downloaded",
    "mb_usage",
    "level",
    "xp",
    "tier",
    "admin",
    "vip_badge",
    "bonus_active_until",
    "score",
    "data",  # ← drop the NOT-NULL JSON/profile blob
)


def upgrade() -> None:
    conn = op.get_bind()
    existing = {c["name"] for c in sa.inspect(conn).get_columns("users")}
    with op.batch_alter_table("users") as batch:
        for col in XP_COLUMNS:
            if col in existing:
                batch.drop_column(col)


def downgrade() -> None:
    conn = op.get_bind()
    existing = {c["name"] for c in sa.inspect(conn).get_columns("users")}
    with op.batch_alter_table("users") as batch:
        # Only re-add if missing; "data" comes back as JSON/Text and NULL-able
        add = lambda name, *args, **kw: (
            batch.add_column(sa.Column(name, *args, **kw))
            if name not in existing else None
        )
        add("videos_downloaded", sa.Integer(), server_default="0", nullable=False)
        add("mb_usage", sa.Float(), server_default="0", nullable=False)
        add("level", sa.Integer(), server_default="1", nullable=False)
        add("xp", sa.Integer(), server_default="0", nullable=False)
        add("tier", sa.Text(), server_default="Online")
        add("admin", sa.Boolean(), server_default="false", nullable=False)
        add("vip_badge", sa.Text())
        add("bonus_active_until", sa.DateTime(timezone=True))
        add("score", sa.Float(), server_default="0")
        add("data", sa.JSON(), server_default="{}")  # <<<
@@ -0,0 +1,36 @@
"""default zeros for new user counters"""

from alembic import op
import sqlalchemy as sa


revision = "user_counters_defaults"
down_revision = "add_ok_to_dl_stats"
branch_labels = None
depends_on = None


def _add(column: str, coltype, default_sql: str):
    # Add column nullable=True with default, then make NOT NULL & drop default
    op.add_column(
        "users",
        sa.Column(column, coltype, nullable=True, server_default=sa.text(default_sql)),
    )
    op.alter_column("users", column, server_default=None, nullable=False)


def upgrade():
    _add("videos_downloaded", sa.Integer(), "0")
    _add("mb_usage", sa.Float(), "0")
    _add("level", sa.Integer(), "1")
    _add("xp", sa.Integer(), "0")
    _add("tier", sa.Integer(), "0")
    _add("ban_status", sa.Boolean(), "false")
    _add("soft_banned", sa.Boolean(), "false")


def downgrade():
    for col in (
        "soft_banned", "ban_status", "tier",
        "xp", "level", "mb_usage", "videos_downloaded",
    ):
        op.drop_column("users", col)
@@ -0,0 +1,64 @@
"""drop url column, make ip the primary key on users

Revision ID: 55327cbf08df
Revises: 70e118917866
Create Date: 2025-05-01 07:26:50.279482

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector


# revision identifiers, used by Alembic.
revision: str = '55327cbf08df'
down_revision: Union[str, None] = '70e118917866'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _has_column(conn, table, column):
    return column in [
        c["name"] for c in Inspector.from_engine(conn).get_columns(table)
    ]


def upgrade():
    conn = op.get_bind()
    insp = Inspector.from_engine(conn)

    # 1) Drop any UNIQUE constraints on url
    for uc in insp.get_unique_constraints("users"):
        if "url" in uc["column_names"]:
            op.drop_constraint(uc["name"], "users", type_="unique")

    # 2) Drop whatever PRIMARY KEY exists today
    pk = insp.get_pk_constraint("users")["name"]
    if pk:
        op.drop_constraint(pk, "users", type_="primary")

    # 3) Remove the old url column
    if "url" in {c["name"] for c in insp.get_columns("users")}:
        op.drop_column("users", "url")

    # 4) Create a new PK on ip
    op.create_primary_key("users_pkey", "users", ["ip"])


def downgrade():
    conn = op.get_bind()
    insp = Inspector.from_engine(conn)

    # 1) Drop the ip PK
    pk = insp.get_pk_constraint("users")["name"]
    if pk:
        op.drop_constraint(pk, "users", type_="primary")

    # 2) Re-create url column
    op.add_column("users", sa.Column("url", sa.Text(), nullable=False))

    # 3) Restore the PK on url
    op.create_primary_key("users_pkey", "users", ["url"])
@@ -0,0 +1,22 @@
"""add ok column to dl_stats"""

from alembic import op
import sqlalchemy as sa


revision = "add_ok_to_dl_stats"
down_revision = "55327cbf08df"  # or latest hash
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        "dl_stats",
        sa.Column("ok", sa.Boolean(), nullable=False, server_default=sa.text("false"))
    )
    # drop default if you like
    op.alter_column("dl_stats", "ok", server_default=None)


def downgrade():
    op.drop_column("dl_stats", "ok")
@@ -0,0 +1,43 @@
"""add default to users.videos_downloaded

Revision ID: 20250506abcd
Revises: abe00f7f8f72
Create Date: 2025-05-06 22:10:00.000000
"""
from typing import Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "20250506abcd"
down_revision = 'abe00f7f8f72'
branch_labels = None
depends_on = None


def upgrade():
    # 1) back-fill existing NULLs
    op.execute(
        "UPDATE users SET videos_downloaded = 0 "
        "WHERE videos_downloaded IS NULL;"
    )

    # 2) give the column a server-side default
    op.alter_column(
        "users",
        "videos_downloaded",
        existing_type=sa.Integer(),
        server_default=sa.text("0"),
        nullable=False,
    )


def downgrade():
    op.alter_column(
        "users",
        "videos_downloaded",
        existing_type=sa.Integer(),
        server_default=None,
        nullable=True,
    )
@@ -0,0 +1,27 @@
"""add ok column to dl_stats

Revision ID: 7064708f684e
Revises: 86141e89fea3
Create Date: 2025-05-01 16:29:31.009976

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '7064708f684e'
down_revision: Union[str, None] = '86141e89fea3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade():
    # just add "ok"; back-filling from an existing "success" column is
    # handled by the ff38ddad43af revision
    op.add_column("dl_stats", sa.Column("ok", sa.Boolean(), server_default="false", nullable=False))


def downgrade():
    op.drop_column("dl_stats", "ok")
@@ -0,0 +1,48 @@
"""make ip the primary key on users, drop obsolete url

Revision ID: 70e118917866
Revises: e03269ce4058
Create Date: 2025-05-01 07:03:06.119500

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine import reflection

# revision identifiers, used by Alembic.
revision: str = '70e118917866'
down_revision: Union[str, None] = 'e03269ce4058'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _column_exists(conn, table, column):
    insp = reflection.Inspector.from_engine(conn)
    return column in [c["name"] for c in insp.get_columns(table)]


def upgrade():
    conn = op.get_bind()

    # 1) add ip column if the old table never had it
    if not _column_exists(conn, "users", "ip"):
        op.add_column("users", sa.Column("ip", sa.Text(), nullable=False))

    # 2) drop the old PK that was on url (or any other column mix)
    op.drop_constraint("users_pkey", "users", type_="primary")

    # 3) create new PK on ip
    op.create_primary_key("users_pkey", "users", ["ip"])

    # 4) drop obsolete url column if it's still there
    if _column_exists(conn, "users", "url"):
        op.drop_column("users", "url")


def downgrade():
    # reverse: recreate url, restore old PK
    op.add_column("users", sa.Column("url", sa.Text(), nullable=False))
    op.drop_constraint("users_pkey", "users", type_="primary")
    op.create_primary_key("users_pkey", "users", ["url"])
    op.drop_column("users", "ip")
28
backend/alembic/versions/86141e89fea3_merge_heads.py
Normal file
@@ -0,0 +1,28 @@
"""merge heads

Revision ID: 86141e89fea3
Revises: 03568bb37289, ffae4495003d
Create Date: 2025-05-01 16:08:14.791524

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '86141e89fea3'
down_revision: Union[str, Sequence[str], None] = ('03568bb37289', 'ffae4495003d')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    pass


def downgrade() -> None:
    """Downgrade schema."""
    pass
@@ -0,0 +1,40 @@
"""add banned and updated_at columns to proxies

Revision ID: 957c893a8a67
Revises:
Create Date: 2025-05-01 06:45:41.546150

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '957c893a8a67'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ── Add the new 'banned' column with default FALSE ──────────
    op.add_column(
        'proxies',
        sa.Column('banned', sa.Boolean(), nullable=False, server_default=sa.text('FALSE'))
    )
    # ── (Optional) Add an 'updated_at' timestamp column ─────────
    op.add_column(
        'proxies',
        sa.Column('updated_at', sa.DateTime(), nullable=True)
    )
    # ── (Optional) Remove server_default if you don't need it going forward ─
    op.alter_column('proxies', 'banned', server_default=None)


def downgrade() -> None:
    # ── Drop in reverse order ───────────────────────────────
    op.drop_column('proxies', 'updated_at')
    op.drop_column('proxies', 'banned')
@@ -0,0 +1,51 @@
"""proxy leasing columns

Revision ID: abe00f7f8f72
Revises: 175f03f1c9f7
Create Date: 2025-05-05 18:12:44
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect

# revision identifiers, used by Alembic.
revision = "abe00f7f8f72"
down_revision = "175f03f1c9f7"
branch_labels = None
depends_on = None


def _has_column(bind, table: str, column: str) -> bool:
    insp = inspect(bind)
    return column in {c["name"] for c in insp.get_columns(table)}


def upgrade() -> None:
    bind = op.get_bind()

    # add only the columns that are missing
    add_in_use = not _has_column(bind, "proxies", "in_use")
    add_last_fail = not _has_column(bind, "proxies", "last_fail")

    if add_in_use or add_last_fail:
        with op.batch_alter_table("proxies") as batch:
            if add_in_use:
                batch.add_column(sa.Column("in_use", sa.Integer(), server_default="0"))
            if add_last_fail:
                batch.add_column(sa.Column("last_fail", sa.DateTime()))


def downgrade() -> None:
    # downgrade assumes the columns exist, so drop them only if present
    bind = op.get_bind()

    drop_in_use = _has_column(bind, "proxies", "in_use")
    drop_last_fail = _has_column(bind, "proxies", "last_fail")

    if drop_in_use or drop_last_fail:
        with op.batch_alter_table("proxies") as batch:
            if drop_last_fail:
                batch.drop_column("last_fail")
            if drop_in_use:
                batch.drop_column("in_use")
@@ -0,0 +1,37 @@
"""add first_visit column to users

Revision ID: e03269ce4058
Revises: 957c893a8a67
Create Date: 2025-05-01 06:58:05.118501

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'e03269ce4058'
down_revision: Union[str, None] = '957c893a8a67'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade():
    # 1) add nullable column with UTC default
    op.add_column(
        "users",
        sa.Column(
            "first_visit",
            sa.TIMESTAMP(timezone=True),
            nullable=True,
            server_default=sa.text("timezone('utc', now())")
        ),
    )
    # 2) once it's there, make it NOT NULL and drop the default
    op.alter_column("users", "first_visit", nullable=False, server_default=None)


def downgrade():
    op.drop_column("users", "first_visit")
@@ -0,0 +1,41 @@
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector

revision = "ff38ddad43af"
down_revision = "7064708f684e"
branch_labels = None
depends_on = None


def upgrade():
    conn = op.get_bind()
    insp = Inspector.from_engine(conn)
    cols = {c["name"]: c for c in insp.get_columns("dl_stats")}

    # ── 1. guarantee "id" exists and is the PK ──────────────────────
    if "id" not in cols:
        op.add_column("dl_stats", sa.Column("id", sa.Integer(), primary_key=True))
        # rely on PostgreSQL's implicit sequence; no ALTER ... ADD GENERATED
        op.create_primary_key("dl_stats_pkey", "dl_stats", ["id"])
    else:
        pk_cols = insp.get_pk_constraint("dl_stats")["constrained_columns"]
        if "id" not in pk_cols:
            op.drop_constraint("dl_stats_pkey", "dl_stats", type_="primary")
            op.create_primary_key("dl_stats_pkey", "dl_stats", ["id"])
        # do **not** attempt to alter the column's default/identity

    # ── 2. add "ok" boolean if missing, back-fill from "success" ───
    if "ok" not in cols:
        op.add_column(
            "dl_stats",
            sa.Column("ok", sa.Boolean(), nullable=False,
                      server_default=sa.text("false")),
        )
        if "success" in cols:
            op.execute("UPDATE dl_stats SET ok = success")


def downgrade():
    op.drop_column("dl_stats", "ok")
    op.drop_constraint("dl_stats_pkey", "dl_stats", type_="primary")
    op.drop_column("dl_stats", "id")
@@ -0,0 +1,10 @@
revision = 'ffae4495003d'
down_revision = '55327cbf08df'  # or whatever its real parent was
branch_labels = None
depends_on = None


def upgrade():
    pass


def downgrade():
    pass
@@ -0,0 +1,45 @@
"""default zeros for new user counters"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector

revision = "user_counters_defaults_old"
down_revision = "add_ok_to_dl_stats"
branch_labels = None
depends_on = None


def _has_column(table: str, column: str, conn) -> bool:
    insp = Inspector.from_engine(conn)
    return column in [c["name"] for c in insp.get_columns(table)]


def _add(column: str, coltype, default_sql: str, conn):
    if not _has_column("users", column, conn):
        op.add_column(
            "users",
            sa.Column(column, coltype, nullable=True, server_default=sa.text(default_sql)),
        )
    # Whether it was just added or already existed, be sure it is NOT NULL and no default remains
    op.alter_column("users", column, nullable=False, server_default=None)


def upgrade():
    conn = op.get_bind()

    _add("videos_downloaded", sa.Integer(), "0", conn)
    _add("mb_usage", sa.Float(), "0", conn)
    _add("level", sa.Integer(), "1", conn)
    _add("xp", sa.Integer(), "0", conn)
    _add("tier", sa.Integer(), "0", conn)
    _add("ban_status", sa.Boolean(), "false", conn)
    _add("soft_banned", sa.Boolean(), "false", conn)


def downgrade():
    for col in (
        "soft_banned", "ban_status", "tier",
        "xp", "level", "mb_usage", "videos_downloaded",
    ):
        op.drop_column("users", col)
0
backend/core/__init__.py
Normal file
73
backend/core/db.py
Normal file
@@ -0,0 +1,73 @@
"""
Shared SQLAlchemy engine / session + schema bootstrap
"""
from __future__ import annotations
import importlib, datetime

from sqlalchemy import (
    create_engine, event, Table, Column, Text, DateTime, Index
)
from sqlalchemy.engine import make_url
from sqlalchemy.orm import sessionmaker, registry
from backend.core.settings import (
    SQLALCHEMY_DATABASE_URI, DB_POOL_SIZE, DB_ECHO, SKIP_SCHEMA_BOOTSTRAP
)

IS_PG = SQLALCHEMY_DATABASE_URI.startswith("postgresql")

parsed_url = make_url(SQLALCHEMY_DATABASE_URI)  # string into URL object

# engine & session
# sslmode is a PostgreSQL-only connect arg; SQLite would reject it
connect_args = {"sslmode": "require"} if parsed_url.drivername.startswith("postgresql") else {}
engine = create_engine(
    SQLALCHEMY_DATABASE_URI,
    pool_size=DB_POOL_SIZE,
    max_overflow=20,
    pool_timeout=30,
    echo=DB_ECHO,
    future=True,
    pool_pre_ping=True,
    pool_recycle=3600,
    connect_args=connect_args,
)

# SQLite -> WAL for concurrency
if SQLALCHEMY_DATABASE_URI.startswith("sqlite:///"):

    @event.listens_for(engine, "connect")
    def _set_wal(dbapi_conn, _):
        dbapi_conn.execute("PRAGMA journal_mode=WAL;")

SessionLocal = sessionmaker(bind=engine, autoflush=False,
                            expire_on_commit=False, future=True)

# metadata (tables from every module)
mapper_registry = registry()
metadata = mapper_registry.metadata

# download-cache table
download_cache = Table(
    "download_cache", metadata,
    Column("key", Text, primary_key=True),
    Column("path", Text, nullable=False),
    Column("ext", Text, nullable=False),
    Column("created_at", DateTime, default=datetime.datetime.utcnow,
           nullable=False, index=True),
)
Index("ix_download_cache_created", download_cache.c.created_at)

# auto-bootstrap all
def _bootstrap_schema() -> None:
    """Import modules then create."""
    table_modules = (
        "backend.core.db_xp",
        "backend.web.db_extra",
        "backend.core.formats",
    )
    for mod in table_modules:
        importlib.import_module(mod)

    metadata.create_all(engine)

# settings.py exposes SKIP_SCHEMA_BOOTSTRAP as an int, so test it as one
if not SKIP_SCHEMA_BOOTSTRAP:
    _bootstrap_schema()
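A quick usage sketch for the engine/session pair defined above (the cache key below is made up for illustration):

from sqlalchemy import select
from backend.core.db import SessionLocal, download_cache

with SessionLocal() as session:
    # look up one cached download by its (hypothetical) key
    path = session.execute(
        select(download_cache.c.path).where(download_cache.c.key == "deadbeef")
    ).scalar_one_or_none()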
14
backend/core/db_cache.py
Normal file
@@ -0,0 +1,14 @@
"""
compat layer - exposes getconn used by older code
"""
from contextlib import contextmanager
from backend.core.db import engine

@contextmanager
def getconn():
    conn = engine.raw_connection()
    try:
        yield conn
    finally:
        conn.commit()
        conn.close()
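A sketch of how the legacy call sites would consume this shim (cursor semantics depend on the underlying DBAPI, psycopg2 or sqlite3 here):

from backend.core.db_cache import getconn

with getconn() as conn:
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM users")
    print(cur.fetchone())
# commit + close happen in the context manager's finally block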
31
backend/core/db_utils.py
Normal file
@@ -0,0 +1,31 @@
# backend/core/db_utils.py

from sqlalchemy import insert as sa_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from core.db import engine

_IS_PG = engine.url.get_backend_name().startswith("postgres")

def upsert(
    tbl,
    insert_values: dict,
    conflict_cols: list[str],
    update_values: dict | None = None,
):
    if _IS_PG:
        stmt = (
            pg_insert(tbl)
            .values(**insert_values)
            .on_conflict_do_update(
                index_elements=conflict_cols,
                set_=update_values or insert_values,
            )
        )
    else:
        # SQLite: OR REPLACE swaps the whole row, so update_values is ignored here
        stmt = (
            sa_insert(tbl)
            .values(**insert_values)
            .prefix_with("OR REPLACE")
        )
    return stmt
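A usage sketch for the helper above, reusing the download_cache table from core.db (the key and path values are made up):

import datetime
from core.db import SessionLocal, download_cache
from core.db_utils import upsert

now = datetime.datetime.now(datetime.timezone.utc)
stmt = upsert(
    download_cache,
    insert_values={"key": "k1", "path": "/tmp/a.mp3", "ext": "mp3", "created_at": now},
    conflict_cols=["key"],
    update_values={"path": "/tmp/a.mp3", "created_at": now},
)
with SessionLocal.begin() as session:
    session.execute(stmt)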
70
backend/core/db_xp.py
Normal file
@@ -0,0 +1,70 @@
"""db_xp.py – minimal user helper"""

from __future__ import annotations
from datetime import datetime, timezone
from typing import Dict, Any

from sqlalchemy import (
    Table, MetaData, select, insert as sa_insert, inspect
)
from sqlalchemy.dialects.postgresql import insert as pg_insert

from backend.core.db import SessionLocal, engine

_NOW = lambda: datetime.now(timezone.utc)
metadata = MetaData()

_IS_PG = engine.url.get_backend_name().startswith("postgres")

def _get_users_table() -> Table:
    # extend_existing lets repeated reflection reuse the same MetaData entry
    # instead of raising "Table 'users' is already defined"
    return Table("users", metadata, autoload_with=engine, extend_existing=True)

def _get_column_info():
    try:
        insp = inspect(engine)
        cols = {c["name"]: c for c in insp.get_columns("users")}
        return cols
    except Exception:
        return {}

def _insert_ignore(**vals):
    users = _get_users_table()
    if _IS_PG:
        return (
            pg_insert(users)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=["ip"])
        )
    return sa_insert(users).values(**vals).prefix_with("OR IGNORE")

def ensure_user(ip: str) -> None:
    cols = _get_column_info()
    has_data = (
        "data" in cols and
        not cols["data"].get("nullable", True)  # NOT NULL
    )

    vals = dict(
        ip=ip,
        first_visit=_NOW(),
        ban_status=False,
        soft_banned=False,
    )
    if has_data:
        vals["data"] = {}

    stmt = _insert_ignore(**vals)
    with SessionLocal.begin() as s:
        s.execute(stmt)

def is_ip_banned(ip: str) -> bool:
    users = _get_users_table()
    with SessionLocal() as s:
        return bool(s.scalar(select(users.c.ban_status).where(users.c.ip == ip)))

def get_status(ip: str) -> Dict[str, Any]:
    ensure_user(ip)
    users = _get_users_table()
    with SessionLocal() as s:
        soft = s.scalar(select(users.c.soft_banned).where(users.c.ip == ip))
        return {"soft_banned": bool(soft)}
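How a request handler would typically use these helpers (the IP literal is just an example):

from backend.core.db_xp import ensure_user, is_ip_banned, get_status

ensure_user("203.0.113.7")            # idempotent via INSERT ... DO NOTHING / OR IGNORE
if not is_ip_banned("203.0.113.7"):
    print(get_status("203.0.113.7"))  # {'soft_banned': False}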
287
backend/core/download.py
Normal file
@@ -0,0 +1,287 @@
"""backend/core/download.py — patched 2025-06-03"""
from __future__ import annotations

import asyncio
import contextvars
import datetime
import hashlib
import os
import random
import re
from pathlib import Path
from typing import Dict

import yt_dlp
from sqlalchemy import select, delete
from sqlalchemy.exc import NoResultFound

from core.db_utils import upsert
from core.settings import (
    DOWNLOAD_DIR,
    TMP_DIR,
    PER_IP_CONCURRENCY,
    DOWNLOAD_CACHE_TTL_SEC,
)
from core.network import get_proxy, record_proxy
from core.db_xp import ensure_user
from core.db import SessionLocal, download_cache
from core.progress_bus import update as set_progress
from core.formats import _cached_metadata_fetch, _clean_proxy

os.makedirs(DOWNLOAD_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)

EST_MB = contextvars.ContextVar("est_mb", default=0)

MAX_GLOBAL_DOWNLOADS = PER_IP_CONCURRENCY * 4
_global_semaphore = asyncio.Semaphore(MAX_GLOBAL_DOWNLOADS)

_ip_semaphores: Dict[str, asyncio.BoundedSemaphore] = {}
_inflight: Dict[str, asyncio.Task[str]] = {}

_ip_cache: Dict[str, set[str]] = {}


def _get_ip_cache(ip: str) -> set[str]:
    return _ip_cache.setdefault(ip, set())


def _url_fmt_hash(url: str, fmt: str) -> str:
    return hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=16).hexdigest()


_ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")


async def download(url: str, fmt_id: str, ip: str, sid: str) -> str:
    fmt = fmt_id or "bestaudio"
    key = _url_fmt_hash(url, fmt)
    dup_key = f"{ip}::{url}::{fmt}"

    cached = await asyncio.to_thread(_lookup_cache_sync, key)
    if cached:
        set_progress(sid, status="cached", pct=100, progress="Instant – served from cache")
        return cached

    if key in _inflight:
        return await _inflight[key]

    ip_cache_set = _get_ip_cache(ip)
    if dup_key in ip_cache_set:
        cached2 = await asyncio.to_thread(_lookup_cache_sync, key)
        if cached2:
            set_progress(sid, status="cached", pct=100, progress="Instant – served from cache")
            return cached2
    ip_cache_set.add(dup_key)

    sem = _ip_semaphores.setdefault(ip, asyncio.BoundedSemaphore(PER_IP_CONCURRENCY))

    async def _run() -> str:
        async with _global_semaphore, sem:
            ensure_user(ip)
            set_progress(sid, status="starting", pct=0, progress="Starting…")

            try:
                info = await asyncio.to_thread(_cached_metadata_fetch, url)
            except Exception as e:
                set_progress(sid, status="error", progress=f"Metadata fetch failed: {e}")
                raise

            attempt = 0
            last_exc: Exception | None = None

            while attempt < 3:
                attempt += 1
                proxies = [get_proxy() for _ in range(5)]
                random.shuffle(proxies)
                proxies.append("DIRECT")

                for proxy_url in proxies:
                    try:
                        final_path = await _single_download(
                            url,
                            fmt,
                            key,
                            sid,
                            proxy_url,
                            info,
                        )
                    except asyncio.CancelledError:
                        raise
                    except Exception as exc:
                        last_exc = exc
                        record_proxy(proxy_url, False)
                        clean_proxy = _clean_proxy(proxy_url)
                        set_progress(
                            sid,
                            status="retrying",
                            progress=f"Retry {attempt} failed (proxy {clean_proxy})",
                        )
                        await asyncio.sleep(1 + random.random())
                        continue
                    else:
                        record_proxy(proxy_url, True)
                        await asyncio.to_thread(_store_cache_sync, key, final_path)
                        set_progress(sid, status="finished", pct=100, progress="Done")
                        return final_path

            set_progress(sid, status="error", progress="Download failed")
            raise RuntimeError(f"All download attempts failed: {last_exc!r}")

    task = asyncio.create_task(_run())
    _inflight[key] = task
    try:
        return await task
    finally:
        _inflight.pop(key, None)
        asyncio.create_task(_expire_ip_cache_entry(ip, dup_key))


async def _single_download(
    url: str,
    fmt: str,
    _unused_cache_key: str,
    sid: str,
    proxy_url: str,
    info: dict,
) -> str:
    title = info.get("title") or "unknown"
    artist = info.get("artist") or info.get("uploader") or "unknown"

    def _clean(s: str) -> str:
        return re.sub(r'[\\/*?:"<>|]', "", s)

    safe_title = _clean(title)
    safe_artist = _clean(artist)
    short_id = hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=8).hexdigest()
    base = f"{safe_title} - {safe_artist} - {short_id}"

    fmt_entry = next((f for f in info.get("formats", []) if f.get("format_id") == fmt), None)
    is_audio_only = bool(fmt_entry and fmt_entry.get("vcodec") == "none")
    if "soundcloud.com" in url.lower():
        is_audio_only = True

    # force .mp3 for audio-only, .mp4 otherwise
    ext_guess = "mp3" if is_audio_only else "mp4"
    outtmpl_path = DOWNLOAD_DIR / f"{base}.%(ext)s"
    final_path_expected = DOWNLOAD_DIR / f"{base}.{ext_guess}"

    if final_path_expected.exists() and final_path_expected.stat().st_size > 0:
        return str(final_path_expected)

    if is_audio_only:
        cmd = ["yt-dlp", "-x", "--audio-format", "mp3", "-o", str(outtmpl_path), url]
    else:
        cmd = ["yt-dlp", "-f", f"{fmt}+bestaudio", "-o", str(outtmpl_path), url]
    if proxy_url and proxy_url.upper() != "DIRECT":
        cmd.insert(1, f"--proxy={proxy_url}")

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    stderr_buffer: list[str] = []

    try:
        while True:
            try:
                line = await asyncio.wait_for(proc.stderr.readline(), timeout=10)
            except asyncio.TimeoutError:
                break
            if not line:
                break
            decoded = line.decode(errors="ignore").strip()
            if decoded:
                stderr_buffer.append(decoded)
                if len(stderr_buffer) > 10:
                    stderr_buffer.pop(0)
                set_progress(sid, status="running", pct=None, progress=decoded)

        try:
            rc = await asyncio.wait_for(proc.wait(), timeout=15)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            raise RuntimeError("yt-dlp stalled and was killed (timeout)")

        if rc != 0:
            raise RuntimeError(f"yt-dlp exited with code {rc}. Last lines: {' | '.join(stderr_buffer)}")

        candidates = [
            p for p in DOWNLOAD_DIR.glob(f"{base}.*")
            if p.is_file() and p.stat().st_size > 0
        ]
        if not candidates:
            raise RuntimeError("No output file produced")

        candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
        return str(candidates[0])

    except asyncio.CancelledError:
        proc.kill()
        await proc.wait()
        raise

    except Exception:
        for f in DOWNLOAD_DIR.glob(f"{base}.*"):
            f.unlink(missing_ok=True)
        raise


def _lookup_cache_sync(key: str) -> str | None:
    now = datetime.datetime.now(datetime.timezone.utc)
    with SessionLocal() as session:
        try:
            row = session.execute(
                select(download_cache.c.path, download_cache.c.created_at)
                .where(download_cache.c.key == key)
            ).one()
        except NoResultFound:
            return None

        path_on_disk, created_at = row
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=datetime.timezone.utc)

        age = (now - created_at).total_seconds()
        if age > DOWNLOAD_CACHE_TTL_SEC:
            session.execute(delete(download_cache).where(download_cache.c.key == key))
            session.commit()
            try:
                os.remove(path_on_disk)
            except OSError:
                pass
            return None

        if not os.path.exists(path_on_disk):
            session.execute(delete(download_cache).where(download_cache.c.key == key))
            session.commit()
            return None

        return path_on_disk


def _store_cache_sync(key: str, path: str) -> None:
    now = datetime.datetime.now(datetime.timezone.utc)
    insert_values = {
        "key": key,
        "path": path,
        "ext": Path(path).suffix.lstrip("."),
        "created_at": now,
    }
    stmt = upsert(
        download_cache,
        insert_values=insert_values,
        conflict_cols=["key"],
        update_values={"path": path, "ext": Path(path).suffix.lstrip("."), "created_at": now},
    )
    with SessionLocal.begin() as session:
        session.execute(stmt)


async def _expire_ip_cache_entry(ip: str, dup_key: str, delay: int = 300) -> None:
    await asyncio.sleep(delay)
    _get_ip_cache(ip).discard(dup_key)
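A minimal driver sketch for the coroutine above (URL, IP, and session id are placeholders; in the app these come from the request and the SSE session):

import asyncio
from core.download import download

async def main() -> None:
    path = await download(
        url="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
        fmt_id="bestaudio",
        ip="203.0.113.7",
        sid="demo-session",
    )
    print("saved to", path)

asyncio.run(main())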
265
backend/core/formats.py
Normal file
@@ -0,0 +1,265 @@
"""backend/core/formats.py — patched 2025-06-03"""
from __future__ import annotations

import asyncio
import os
import re
import urllib.parse as _url
from datetime import datetime, timezone
from functools import lru_cache
from pathlib import Path
from urllib.parse import urlparse

import yt_dlp
import structlog
from sqlalchemy import select, delete, Table, Column, Text, DateTime, JSON
from sqlalchemy.exc import NoResultFound

from core.db import SessionLocal, metadata
from core.db_utils import upsert
from core.network import get_proxy, record_proxy, stealth_headers
from core.settings import FORMAT_CACHE_TTL_SEC

log = structlog.get_logger()

format_cache = Table(
    "format_cache",
    metadata,
    Column("url", Text, primary_key=True),
    Column("cached_at", DateTime, nullable=False),
    Column("info", JSON, nullable=False),
)

_YT_PAT = re.compile(r"(youtu\.be/|youtube\.com/(?:watch|shorts))", re.I)
_BC_PAT = re.compile(r"\.bandcamp\.com", re.I)
_SC_PAT = re.compile(r"(?:soundcloud\.com|on\.soundcloud\.com|m\.soundcloud\.com)", re.I)
_TW_PAT = re.compile(r"(?:twitter\.com|x\.com|mobile\.twitter\.com)", re.I)
_ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")

# resolve cookie file path from env or fallback to root-relative path
COOKIE_FILE = Path(os.getenv("YT_COOKIE_FILE", Path(__file__).resolve().parents[2] / "playwright_cookies.txt"))
log.info("cookie_file_resolved", path=str(COOKIE_FILE), exists=COOKIE_FILE.exists())

def _canonical_url(u: str) -> str:
    u = u.strip()
    if not u.lower().startswith(("http://", "https://")):
        return u

    if _YT_PAT.search(u):
        parsed = _url.urlparse(u)
        if "youtu.be" in parsed.netloc:
            vid = parsed.path.lstrip("/")
        else:
            q = _url.parse_qs(parsed.query)
            vid = (q.get("v") or [None])[0]
            if not vid and parsed.path.startswith("/shorts/"):
                vid = parsed.path.split("/")[2]
        return f"https://www.youtube.com/watch?v={vid}" if vid else u

    if _BC_PAT.search(u):
        parsed = _url.urlparse(u)
        clean = parsed._replace(query="", fragment="")
        return _url.urlunparse(clean)

    if _SC_PAT.search(u):
        u2 = (
            u.replace("m.soundcloud.com", "soundcloud.com")
            .replace("on.soundcloud.com", "soundcloud.com")
        )
        return u2.split("?")[0].split("#")[0]

    if _TW_PAT.search(u):
        parsed = _url.urlparse(
            u.replace("mobile.twitter.com", "x.com").replace("twitter.com", "x.com")
        )
        clean = parsed._replace(query="", fragment="")
        return _url.urlunparse(clean)

    parsed = _url.urlparse(u)
    clean = parsed._replace(query="", fragment="")
    return _url.urlunparse(clean)


def _clean_proxy(proxy: str) -> str:
    if not proxy or proxy.upper() == "DIRECT":
        return "DIRECT"
    parsed = urlparse(proxy)
    return (
        f"{parsed.scheme}://{parsed.hostname}{f':{parsed.port}' if parsed.port else ''}"
        if parsed.hostname
        else proxy
    )


def platform_badge(u: str) -> str:
    l = u.lower()
    if "youtu" in l:
        return "youtube"
    if "soundcloud" in l:
        return "soundcloud"
    if "twitter" in l or "x.com" in l:
        return "twitterx"
    if "bandcamp" in l:
        return "bandcamp"
    return "other"


def user_facing_formats(fmts: list[dict]) -> list[dict]:
    desired_heights = [1440, 1080, 720, 480, 360]
    out: list[dict] = []

    audio_only = [
        f for f in fmts if f.get("vcodec") == "none" and f.get("acodec") != "none"
    ]
    if audio_only:
        best = max(audio_only, key=lambda x: x.get("tbr") or 0)
        out.append(
            {
                "format_id": best["format_id"],
                "ext": best.get("ext", "mp3"),
                "label": "Audio (.mp3)",
            }
        )

    for h in desired_heights:
        candidates = [f for f in fmts if f.get("height") == h and f.get("vcodec") != "none"]
        if candidates:
            best = max(candidates, key=lambda x: x.get("tbr") or 0)
            out.append(
                {
                    "format_id": best["format_id"],
                    "ext": best.get("ext", "mp4"),
                    "label": f"{h}p",
                }
            )

    return out


@lru_cache(maxsize=1024)
def _cached_metadata_fetch(url: str) -> dict:
    opts = {"quiet": True, "skip_download": True}
    try:
        with yt_dlp.YoutubeDL(opts) as ydl:
            return ydl.extract_info(url, download=False)
    except Exception as e:
        msg = _ansi_escape.sub("", str(e)).strip()
        log.warning("metadata_fail_direct", url=url, err=msg)
        raise


def _fetch_metadata_sync(url: str, proxy_url: str = "DIRECT") -> dict:
    opts = {
        "quiet": True,
        "skip_download": True,
        "proxy": None if proxy_url == "DIRECT" else proxy_url,
        "http_headers": stealth_headers(),
        "cookiefile": str(COOKIE_FILE),
    }

    if not COOKIE_FILE.exists():
        log.warning("cookie_file_missing", path=str(COOKIE_FILE))

    try:
        with yt_dlp.YoutubeDL(opts) as ydl:
            return ydl.extract_info(url, download=False)
    except Exception as e:
        clean_proxy = _clean_proxy(proxy_url)
        msg = _ansi_escape.sub("", str(e)).strip()
        log.warning("metadata_fail_proxy", url=url, proxy=clean_proxy, err=msg)
        raise


async def _fetch_metadata(url: str) -> dict:
    if any(x in url.lower() for x in ("youtube.com", "youtu.be", "bandcamp.com")):
        return await asyncio.to_thread(_cached_metadata_fetch, url)

    for attempt in range(1, 4):
        proxy = get_proxy()
        try:
            info = await asyncio.to_thread(_fetch_metadata_sync, url, proxy)
            if not info.get("formats"):
                raise ValueError("No formats found")
            record_proxy(proxy, True)
            return info
        except Exception as e:
            record_proxy(proxy, False)
            err_msg = _ansi_escape.sub("", str(e)).strip()
            log.warning(
                "metadata_retry_fail",
                attempt=attempt,
                proxy=_clean_proxy(proxy),
                err=err_msg,
            )

    raise RuntimeError("Format fetch failed after 3 attempts")


async def choose_format(url: str) -> dict:
    url = _canonical_url(url)
    if not re.match(r"^https?://", url, re.I):
        return {"error": "Invalid URL"}

    if any(x in url.lower() for x in ("soundcloud.com", "x.com")):
        return {"auto_download": True, "fmt_id": "bestaudio", "url": url}

    info = await asyncio.to_thread(_lookup_cache_sync, url)
    if info:
        return {
            "formats": user_facing_formats(info["formats"]),
            "title": info.get("title", "Unknown"),
            "platform": info.get("platform", ""),
            "url": url,
        }

    info_raw = await _fetch_metadata(url)

    cache_doc = {
        "title": info_raw.get("title", "Unknown"),
        "formats": info_raw.get("formats", []),
        "platform": platform_badge(url),
    }

    await asyncio.to_thread(_store_cache_sync, url, cache_doc)

    return {
        "formats": user_facing_formats(info_raw.get("formats", [])),
        "title": cache_doc["title"],
        "platform": cache_doc["platform"],
        "url": url,
    }


def _lookup_cache_sync(url: str) -> dict | None:
    now = datetime.now(timezone.utc)
    with SessionLocal() as session:
        try:
            row = session.execute(
                select(format_cache.c.info, format_cache.c.cached_at).where(
                    format_cache.c.url == url
                )
            ).one()
        except NoResultFound:
            return None

        info, cached_at = row
        if cached_at.tzinfo is None:
            cached_at = cached_at.replace(tzinfo=timezone.utc)

        if (now - cached_at).total_seconds() > FORMAT_CACHE_TTL_SEC:
            session.execute(delete(format_cache).where(format_cache.c.url == url))
            session.commit()
            return None

        return info


def _store_cache_sync(url: str, info: dict) -> None:
    now = datetime.now(timezone.utc)
    # use the dialect-aware helper (as download.py does) so the SQLite
    # dev fallback works too, instead of a PostgreSQL-only pg_insert
    stmt = upsert(
        format_cache,
        insert_values={"url": url, "cached_at": now, "info": info},
        conflict_cols=["url"],
        update_values={"cached_at": now, "info": info},
    )
    with SessionLocal.begin() as session:
        session.execute(stmt)
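And a small sketch of calling the resolver end-to-end (the URL is illustrative; SoundCloud/X links short-circuit to auto_download per the code above):

import asyncio
from core.formats import choose_format

result = asyncio.run(choose_format("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))
print(result.get("title"), [f["label"] for f in result.get("formats", [])])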
40
backend/core/logging.py
Normal file
@@ -0,0 +1,40 @@
"""
Logging - 16 May 2025
Dev - colored console
Prod - structured JSON
"""
from __future__ import annotations
import logging, os, structlog
from core.settings import LOG_LEVEL, ENV

def init_logging() -> None:
    log_level = getattr(logging, LOG_LEVEL.upper(), logging.INFO)

    if ENV == "production":
        processors = [
            structlog.processors.TimeStamper(fmt="%Y-%m-%dT%H:%M:%S", utc=True),
            structlog.processors.add_log_level,
            _add_path,
            structlog.processors.JSONRenderer(),
        ]
    else:
        processors = [
            structlog.processors.TimeStamper(fmt="%H:%M:%S"),
            structlog.processors.add_log_level,
            structlog.dev.ConsoleRenderer(colors=True),
        ]

    structlog.configure(
        wrapper_class=structlog.make_filtering_bound_logger(log_level),
        processors=processors,
    )

    logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)

def _add_path(_, __, event_dict):
    from quart import request
    if request:
        event_dict["path"] = request.path
    return event_dict
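Typical call order, assuming this runs once at app start-up (matching how app.py imports it):

from backend.core.logging import init_logging
import structlog

init_logging()
log = structlog.get_logger()
log.info("boot", env="development")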
49
backend/core/network.py
Normal file
@@ -0,0 +1,49 @@
"""
network.py - 16 May 2025
"""
from __future__ import annotations
import random, structlog
from functools import lru_cache
from typing import Optional
from fake_useragent import UserAgent
from tls_client import Session as TLSSession
from backend.web.db_extra import acquire_proxy, release_proxy, queue_proxy_result

log = structlog.get_logger()

@lru_cache(maxsize=128)
def stealth_headers(rotate: bool = False) -> dict[str, str]:
    # NOTE: because the function itself is lru_cached, rotate=True clears the
    # cache once but the result is then re-cached under the rotate=True key
    if rotate:
        stealth_headers.cache_clear()

    browsers = ["chrome", "firefox", "edge"]
    browser = random.choice(browsers)
    client_id_map = {
        "chrome": ["chrome_122", "chrome_121", "chrome_120"],
        "firefox": ["firefox_123"],
        "edge": ["edge_121"],
    }
    client_id = random.choice(client_id_map[browser])
    # session is instantiated but not used further; presumably a placeholder
    # for TLS-fingerprint matching of the chosen browser profile
    TLSSession(client_identifier=client_id)

    headers = {
        "User-Agent": UserAgent()[browser],
        "Accept-Language": random.choice(
            ["en-US,en;q=0.9", "en-GB,en;q=0.9", "en;q=0.8"]
        ),
    }
    return headers


def get_proxy() -> str:
    px = acquire_proxy()
    if px:
        return px
    log.debug("proxy.none", msg="DIRECT fallback")
    return "DIRECT"

def record_proxy(px: str, ok: bool) -> None:
    if not px or px == "DIRECT":
        return
    queue_proxy_result(px, ok)
    release_proxy(px, ok)
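The acquire/record pair is meant to wrap each outbound attempt, roughly like this sketch:

from core.network import get_proxy, record_proxy, stealth_headers

proxy = get_proxy()          # "DIRECT" when the pool is empty
headers = stealth_headers()
try:
    ...                      # perform the request via `proxy` with `headers`
    record_proxy(proxy, True)
except Exception:
    record_proxy(proxy, False)
    raise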
62
backend/core/progress_bus.py
Normal file
@@ -0,0 +1,62 @@
"""
progress_bus.py - 07 May 2025
"""
from __future__ import annotations
import asyncio, json, time
from typing import Dict, Any

# in-mem state
_progress: Dict[str, Dict[str, Any]] = {}
_watchers: Dict[str, list[asyncio.Queue[str]]] = {}
_TTL = 60 * 60  # keep finished/error records 1 h

def _now() -> float: return time.time()

def register(sid: str) -> None:
    _progress[sid] = dict(pct=0, progress="", status="running", ts=_now())
    _broadcast(sid)

def update(sid: str, *, pct: float | None = None,
           progress: str | None = None, status: str | None = None) -> None:
    if sid not in _progress:
        register(sid)
    p = _progress[sid]
    if pct is not None: p["pct"] = pct
    if progress is not None: p["progress"] = progress
    if status is not None: p["status"] = status
    p["ts"] = _now()
    _broadcast(sid)

def get(sid: str) -> Dict[str, Any]:
    _gc()
    return _progress.get(sid, {"status": "idle"})

def clear(sid: str) -> None:
    _progress.pop(sid, None)
    _watchers.pop(sid, None)

# SSE integration
def subscribe(sid: str) -> asyncio.Queue[str]:
    q: asyncio.Queue[str] = asyncio.Queue(maxsize=16)
    _watchers.setdefault(sid, []).append(q)
    # immediately push current state
    q.put_nowait(json.dumps({"sid": sid, **get(sid)}))
    return q

def _broadcast(sid: str) -> None:
    if sid not in _watchers:
        return
    payload = json.dumps({"sid": sid, **_progress[sid]})
    for q in list(_watchers[sid]):
        try:
            q.put_nowait(payload)
        except asyncio.QueueFull:
            pass  # drop frame

# garbage collector
def _gc() -> None:
    now = _now()
    stale = [k for k, v in _progress.items()
             if v["status"] in ("finished", "error") and now - v["ts"] > _TTL]
    for k in stale:
        clear(k)
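A consumer sketch showing how an SSE endpoint would drain a watcher queue (the generator shape mirrors a Quart streaming response but is illustrative):

from core.progress_bus import subscribe

async def stream(sid: str):
    q = subscribe(sid)           # first frame arrives immediately
    while True:
        payload = await q.get()  # JSON string: {"sid": ..., "pct": ..., ...}
        yield f"data: {payload}\n\n"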
48
backend/core/settings.py
Normal file
@@ -0,0 +1,48 @@
import os
from pathlib import Path
from functools import lru_cache
from dotenv import load_dotenv

load_dotenv()

# ─── Paths ───────────────────────────────────────────────
ROOT_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = ROOT_DIR / "data"
TMP_DIR = Path("/tmp")
DOWNLOAD_DIR = DATA_DIR / "downloads"
USERS_DIR = DATA_DIR / "users"
ENV = os.getenv("APP_ENV", "development")
PROXY_LIST_FILE = Path(os.getenv("PROXY_LIST_FILE", ".50.txt"))

# ─── Additional ───────────────────────────────────────────────
PROXY_USERNAME = os.getenv("PROXY_USERNAME")
PROXY_PASSWORD = os.getenv("PROXY_PASSWORD")

# ─── DB and SQLAlchemy ───────────────────────────────────────────
SQLALCHEMY_DATABASE_URI = os.getenv("DATABASE_URL") or f"sqlite:///{DATA_DIR / 'local.db'}"
DB_POOL_SIZE = int(os.getenv("DB_POOL_SIZE", 20))
# bool("0") is True, so the flag string has to be parsed explicitly.
DB_ECHO = os.getenv("DB_ECHO", "0").strip().lower() in ("1", "true", "yes")

# ─── Concurrency ───────────────────────────────────────────────
CPU_COUNT = os.cpu_count() or 2
THREADS_MAX = min(32, CPU_COUNT * 4)
PROCS_MAX = min(CPU_COUNT, 4)
PER_IP_CONCURRENCY = int(os.getenv("PER_IP_CONCURRENCY", 2))

# ─── Cache and Tuning knobs ──────────────────────────────────────
FORMAT_CACHE_TTL_SEC = int(os.getenv("FORMAT_CACHE_TTL_SEC", 8_000))
DOWNLOAD_CACHE_TTL_SEC = int(os.getenv("DOWNLOAD_CACHE_TTL_SEC", 86_400))  # 24 h
PARALLEL_CHUNK_MB = int(os.getenv("PARALLEL_CHUNK_MB", 2))
MAX_CONCURRENT_FRAG = int(os.getenv("MAX_CONCURRENT_FRAG", 4))
ARIA2C_THRESHOLD_MB = int(os.getenv("ARIA2C_THRESHOLD_MB", 512))
MIN_SCORE = float(os.getenv("PROXY_MIN_SCORE", "0.05"))
MAX_IN_USE = int(os.getenv("PROXY_CONCURRENCY_LIMIT", "4"))
FAIL_COOLDOWN_SEC = int(os.getenv("PROXY_FAIL_COOLDOWN", "600"))
_MAX_LOGIN_FAILS = int(os.getenv("MAX_LOGIN_FAILS", "12"))
_MAX_INVALID_URLS = int(os.getenv("MAX_INVALID_URLS", "20"))
_WINDOW_MINUTES = int(os.getenv("WINDOW_MINUTES", "60"))

# ─── Logging ───────────────────────────────────────────────────
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

# ─── db ───────────────────────────────────────────────────
SKIP_SCHEMA_BOOTSTRAP = int(os.getenv("SKIP_SCHEMA_BOOTSTRAP", "0"))


@lru_cache  # run-once guard: repeated calls after import cost nothing
def ensure_dirs() -> None:
    for p in (DATA_DIR, USERS_DIR, DOWNLOAD_DIR):
        p.mkdir(parents=True, exist_ok=True)


ensure_dirs()
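The explicit DB_ECHO parse above matters because Python's bool() treats any non-empty string as true. A minimal demonstration, with env_flag as a hypothetical helper name:

import os

def env_flag(name: str, default: str = "0") -> bool:
    # Only the usual truthy spellings count; everything else is False.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

os.environ["DB_ECHO"] = "0"
assert bool(os.environ["DB_ECHO"]) is True   # the naive parse: "0" is a truthy string
assert env_flag("DB_ECHO") is False          # the explicit parse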
0
backend/web/__init__.py
Normal file
303
backend/web/app.py
Normal file
@@ -0,0 +1,303 @@
"""app.py – Jul 15 2025"""

from __future__ import annotations

import asyncio
import mimetypes
import os
import secrets
import shutil
import signal
import sysconfig
import threading
import time
from pathlib import Path
from typing import Dict

import aiofiles
import structlog
from quart import (
    Quart,
    Response,
    jsonify,
    redirect,
    render_template,
    request,
    session,
    url_for,
)

from backend.core.logging import init_logging
from backend.core.settings import TMP_DIR, DOWNLOAD_DIR
from backend.core.formats import choose_format as choose_format_logic, _lookup_cache_sync
from backend.core.download import download, EST_MB
from backend.core.db_xp import is_ip_banned, ensure_user, get_status
from backend.web.db_extra import invalid_over_limit, init_proxy_seed, start_background_tasks
from backend.core.db import metadata, engine
from backend.core import progress_bus

init_logging()
log = structlog.get_logger()


def _frontend_root() -> Path:
    here = Path(__file__).resolve().parent
    dev = here.parent.parent / "frontend"
    return dev if dev.exists() else Path(sysconfig.get_path("data")) / "share" / "s1ne" / "frontend"


FRONTEND_ROOT = _frontend_root()
app = Quart(
    __name__,
    template_folder=str(FRONTEND_ROOT / "templates"),
    static_folder=str(FRONTEND_ROOT / "static"),
)
# Sessions break silently if the key is unset; fall back to an ephemeral key
# (sessions then reset on restart, which still beats a None secret).
app.secret_key = os.getenv("SECRET_KEY_WORD") or secrets.token_urlsafe(32)

_tasks: Dict[str, asyncio.Task] = {}
async def _cleanup_temp(interval: int = 900) -> None:
    while True:
        # st_mtime is wall-clock epoch time, so the cutoff must come from
        # time.time(), not the event loop's monotonic clock.
        cutoff = time.time() - 60 * 60 * 12
        for p in (TMP_DIR / "ytlocks").glob("*.lock"):
            if p.stat().st_mtime < cutoff:
                p.unlink(missing_ok=True)
        for pattern in ("yt_*", "tmp*"):
            for p in TMP_DIR.glob(pattern):
                if p.is_dir() and p.stat().st_mtime < cutoff:
                    shutil.rmtree(p, ignore_errors=True)
        await asyncio.sleep(interval)


async def _file_iter(path: Path, chunk: int = 1 << 15):
    async with aiofiles.open(path, "rb") as f:
        while (blk := await f.read(chunk)):
            yield blk


async def _shutdown_waiter():
    await asyncio.sleep(0.1)
    log.info("shutdown.tasks_cancelled")


def _graceful_exit() -> None:
    log.info("shutdown.initiated")
    for t in list(_tasks.values()):
        if not t.done():
            t.cancel()
    asyncio.create_task(_shutdown_waiter())

    # Hard-exit safety net: if cancellation hangs, kill the process after 5 s.
    def force_exit():
        time.sleep(5)
        os._exit(1)

    threading.Thread(target=force_exit, daemon=True).start()
@app.before_serving
async def _launch_tasks() -> None:
    metadata.create_all(engine)
    await init_proxy_seed()
    start_background_tasks(asyncio.get_running_loop())
    asyncio.create_task(_cleanup_temp())
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _graceful_exit)
@app.route("/")
|
||||
async def home():
|
||||
if not request.cookies.get("auth"):
|
||||
return await render_template("login.html")
|
||||
ip = request.remote_addr or "0.0.0.0"
|
||||
ensure_user(ip)
|
||||
soft_banned = get_status(ip)["soft_banned"]
|
||||
return await render_template("index.html", soft_banned=soft_banned)
|
||||
|
||||
|
||||
@app.route("/login", methods=["GET", "POST"])
|
||||
async def login():
|
||||
if request.method == "GET":
|
||||
return await render_template("login.html")
|
||||
form = await request.form
|
||||
if form.get("password") == os.getenv("MASTER_PASSWORD"):
|
||||
resp = redirect(url_for("home"))
|
||||
resp.set_cookie("auth", "1", httponly=True, secure=False)
|
||||
return resp
|
||||
return await render_template("login.html", error_badge="Incorrect password")
|
||||
|
||||
|
||||
@app.route("/logout")
|
||||
async def logout():
|
||||
session.clear()
|
||||
resp = redirect(url_for("login"))
|
||||
resp.delete_cookie("auth")
|
||||
return resp
|
||||
|
||||
|
||||
@app.route("/choose_format", methods=["POST"])
|
||||
async def handle_choose_format() -> Response:
|
||||
url: str
|
||||
try:
|
||||
|
||||
if request.content_type and "application/json" in request.content_type:
|
||||
data = await request.get_json(silent=True) or {}
|
||||
url = (data.get("url") or "").strip()
|
||||
else:
|
||||
form = await request.form
|
||||
url = (form.get("url") or "").strip()
|
||||
|
||||
if not url:
|
||||
#log.warning("choose_format.missing_url")
|
||||
return jsonify({"error": "url field required"}), 422
|
||||
|
||||
|
||||
run_id: str = session.get("run_id") or secrets.token_urlsafe(10)
|
||||
session["run_id"] = run_id
|
||||
#log.info("choose_format.run_id_set", run_id=run_id, url=url)
|
||||
|
||||
|
||||
res: dict = await choose_format_logic(url)
|
||||
res["sid"] = run_id
|
||||
|
||||
if "error" in res:
|
||||
#log.warning("choose_format.logic_error", error=res["error"], url=url)
|
||||
return jsonify(res), 400
|
||||
|
||||
log.info("choose_format.success", url=url, title=res.get("title"), platform=res.get("platform"))
|
||||
return jsonify(res)
|
||||
|
||||
except Exception as e:
|
||||
log.exception("choose_format.exception", err=str(e))
|
||||
return jsonify({"error": "Internal error during format selection"}), 500
|
||||
|
||||
|
||||
@app.route("/download_file")
|
||||
async def dl():
|
||||
ip = request.remote_addr or "0.0.0.0"
|
||||
|
||||
|
||||
if ip == "127.0.0.1":
|
||||
log.info("dev_mode.skip_ban_check", ip=ip)
|
||||
else:
|
||||
if is_ip_banned(ip):
|
||||
log.warning("download.reject.banned", ip=ip)
|
||||
return jsonify({"error": "Banned"}), 403
|
||||
|
||||
|
||||
url = request.args.get("url", "").strip()
|
||||
fmt = request.args.get("format_id", "").strip()
|
||||
sid = request.args.get("sid", "").strip()
|
||||
run_id = session.get("run_id")
|
||||
|
||||
if run_id is None:
|
||||
pass
|
||||
|
||||
if is_ip_banned(ip):
|
||||
log.warning("download.reject.banned", ip=ip)
|
||||
return jsonify({"error": "Banned"}), 403
|
||||
|
||||
if not url or not fmt:
|
||||
log.warning("download.reject.missing_params", url=url, fmt=fmt)
|
||||
return jsonify({"error": "Missing URL or format"}), 400
|
||||
|
||||
if sid in _tasks and not _tasks[sid].done():
|
||||
log.warning("download.reject.already_running", sid=sid)
|
||||
return jsonify({"error": "download already running"}), 409
|
||||
|
||||
if sid != run_id:
|
||||
log.warning("download.reject.sid_mismatch", sid=sid, session_run_id=run_id)
|
||||
return jsonify({
|
||||
"error": "Session mismatch – please refresh the page and select a format again."
|
||||
}), 403
|
||||
|
||||
progress_bus.register(sid)
|
||||
|
||||
|
||||
async def _run_download() -> Path:
|
||||
try:
|
||||
meta = await asyncio.to_thread(_lookup_cache_sync, url)
|
||||
if meta:
|
||||
chosen = next((f for f in meta["formats"] if f["format_id"] == fmt), None)
|
||||
est = (
|
||||
chosen.get("filesize")
|
||||
or chosen.get("filesize_approx")
|
||||
or 0
|
||||
) if chosen else 0
|
||||
EST_MB.set(int(est / 1_048_576))
|
||||
|
||||
log.info("download.starting", sid=sid, url=url, fmt=fmt)
|
||||
path_str = await download(url, fmt, ip, sid)
|
||||
return Path(path_str)
|
||||
|
||||
finally:
|
||||
_tasks.pop(sid, None)
|
||||
|
||||
task = asyncio.create_task(_run_download())
|
||||
_tasks[sid] = task
|
||||
|
||||
try:
|
||||
tmp_path = await task
|
||||
mime = mimetypes.guess_type(tmp_path.name)[0] or "application/octet-stream"
|
||||
log.info("download.success", file=str(tmp_path), sid=sid)
|
||||
|
||||
resp = Response(
|
||||
_file_iter(tmp_path),
|
||||
headers={
|
||||
"Content-Type": mime,
|
||||
"Content-Disposition": f'attachment; filename="{tmp_path.name}"',
|
||||
},
|
||||
)
|
||||
|
||||
if hasattr(resp, "call_after_response"):
|
||||
def _after():
|
||||
progress_bus.update(sid, status="finished", pct=100, progress="Done")
|
||||
progress_bus.clear(sid)
|
||||
if str(tmp_path.parent).startswith(str(TMP_DIR)):
|
||||
shutil.rmtree(tmp_path.parent, ignore_errors=True)
|
||||
|
||||
resp.call_after_response(_after)
|
||||
|
||||
return resp
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.warning("download.cancelled", sid=sid)
|
||||
progress_bus.update(sid, status="cancelled", progress="Cancelled")
|
||||
return jsonify({"error": "Download cancelled"}), 499
|
||||
|
||||
except Exception as e:
|
||||
log.exception("download.failed", sid=sid, err=str(e))
|
||||
progress_bus.update(sid, status="error", progress="Error")
|
||||
return jsonify({"error": "Download failed"}), 500
|
||||
|
||||
|
||||
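A rough end-to-end call order, assuming the server runs on localhost:5000 and httpx is installed; the "formats" key in the /choose_format response is an assumption about the response shape (only "sid", "title", and "platform" are confirmed above):

import httpx

client = httpx.Client(base_url="http://localhost:5000")
client.post("/login", data={"password": "..."})            # sets the auth cookie
res = client.post("/choose_format", json={"url": "https://example.com/v"}).json()
sid = res["sid"]                                           # bound to this session
fmt = res["formats"][0]["format_id"]                       # hypothetical response field
# The same client (hence the same session cookie) must issue the download,
# or the sid check above returns 403.
dl = client.get("/download_file",
                params={"url": "https://example.com/v", "format_id": fmt, "sid": sid})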
@app.route("/cancel_download", methods=["POST"])
|
||||
async def cancel_dl():
|
||||
sid = request.args.get("sid", "").strip()
|
||||
if sid:
|
||||
task = _tasks.get(sid)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
progress_bus.update(sid, status="cancelled", progress="Cancelled")
|
||||
return jsonify({"status": "cancelled"})
|
||||
|
||||
|
||||
@app.route("/api/progress/<sid>")
|
||||
async def progress_stream(sid: str):
|
||||
q = progress_bus.subscribe(sid)
|
||||
|
||||
async def gen():
|
||||
while True:
|
||||
msg = await q.get()
|
||||
yield f"data: {msg}\n\n"
|
||||
|
||||
return Response(
|
||||
gen(),
|
||||
content_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
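A minimal sketch of consuming the progress stream from Python, assuming the app on localhost:5000 (any SSE-capable client works; watch_progress is a hypothetical name):

import json
import httpx

def watch_progress(sid: str, base: str = "http://localhost:5000") -> None:
    with httpx.stream("GET", f"{base}/api/progress/{sid}", timeout=None) as r:
        for line in r.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separators between events
            evt = json.loads(line[len("data: "):])
            print(evt.get("pct"), evt.get("progress"))
            if evt.get("status") in ("finished", "error", "cancelled"):
                break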
@app.before_serving
async def _on_startup():
    pass  # placeholder hook; the real startup work lives in _launch_tasks above
320
backend/web/db_extra.py
Normal file
@@ -0,0 +1,320 @@
"""
backend/web/db_extra.py - 16 May 2025
"""

from __future__ import annotations

import asyncio
import datetime as dt
from datetime import timezone

import structlog
from sqlalchemy import (
    Table, Column, Text, Float, Integer, Boolean, DateTime, func,
    select, insert, update, delete, or_, inspect, text
)
from sqlalchemy.dialects.postgresql import insert as pg_insert

from backend.core.db import SessionLocal, metadata, engine
from backend.core.settings import (
    MAX_IN_USE, FAIL_COOLDOWN_SEC, MIN_SCORE, PROXY_LIST_FILE,
    _WINDOW_MINUTES, PROXY_USERNAME, PROXY_PASSWORD,
    _MAX_LOGIN_FAILS, _MAX_INVALID_URLS
)

log = structlog.get_logger()
_IS_PG = engine.url.get_backend_name().startswith("postgres")


def _insert_ignore(tbl: Table, **vals):
    """INSERT that silently skips duplicate primary keys on both backends."""
    if _IS_PG:
        return pg_insert(tbl).values(**vals).on_conflict_do_nothing()
    return insert(tbl).prefix_with("OR IGNORE").values(**vals)


def _clamp_zero(expr):
    """SQL-portable max(expr, 0)."""
    return func.greatest(expr, 0) if _IS_PG else func.max(expr, 0)


proxy_tbl = Table(
    "proxies", metadata,
    Column("url", Text, primary_key=True),
    Column("score", Float, nullable=False, server_default="1.0"),
    Column("fails", Integer, nullable=False, server_default="0"),
    Column("banned", Boolean, nullable=False, server_default="false"),
    Column("in_use", Integer, nullable=False, server_default="0"),
    Column("last_fail", DateTime),
    Column("updated_at", DateTime, server_default=func.now(), index=True),
)

login_tbl = Table(
    "login_attempts", metadata,
    Column("ip", Text, primary_key=True),
    Column("count", Integer, nullable=False, server_default="0"),
    Column("updated_at", DateTime, nullable=False, server_default=func.now()),
)

invalid_tbl = Table(
    "invalid_urls", metadata,
    Column("ip", Text, primary_key=True),
    Column("count", Integer, nullable=False, server_default="0"),
    Column("updated_at", DateTime, nullable=False, server_default=func.now()),
)

dl_stats = Table(
    "dl_stats", metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("ok", Boolean, nullable=False),
    Column("ts", DateTime, nullable=False, server_default=func.now(), index=True),
)
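A quick sketch of the _insert_ignore semantics: re-running the same insert is a no-op on both backends instead of raising IntegrityError (the proxy URL below is illustrative only):

from backend.core.db import SessionLocal
from backend.web.db_extra import _insert_ignore, proxy_tbl

with SessionLocal.begin() as s:
    s.execute(_insert_ignore(proxy_tbl, url="http://10.0.0.1:8080"))
    s.execute(_insert_ignore(proxy_tbl, url="http://10.0.0.1:8080"))  # silently ignored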
def _ensure_proxy_columns() -> None:
    insp = inspect(engine)
    if "proxies" not in insp.get_table_names():
        return

    existing = {c["name"] for c in insp.get_columns("proxies")}
    add: list[tuple[str, str]] = []
    if "in_use" not in existing:
        add.append(("in_use", "INTEGER DEFAULT 0"))
    if "last_fail" not in existing:
        add.append(("last_fail", "TIMESTAMP"))
    if not add:
        return

    with engine.begin() as conn:
        for col, ddl in add:
            if _IS_PG:
                conn.execute(text(f"ALTER TABLE proxies ADD COLUMN IF NOT EXISTS {col} {ddl};"))
            else:
                # SQLite lacks IF NOT EXISTS for ADD COLUMN; the membership
                # check above keeps this from raising on re-runs.
                conn.execute(text(f"ALTER TABLE proxies ADD COLUMN {col} {ddl};"))
    log.info("proxy.schema.auto_migrated", added=[c for c, _ in add])


# metadata.create_all(engine)  # handled at app start-up in _launch_tasks
_ensure_proxy_columns()


def _seed() -> None:
    if not PROXY_LIST_FILE.exists():
        return

    with SessionLocal.begin() as s:
        for ln in PROXY_LIST_FILE.read_text().splitlines():
            ln = ln.strip()
            if not ln or ln.startswith("#"):  # allow comment lines in the list
                continue
            ip, port = ln.split(":", 1)
            px = (
                f"http://{PROXY_USERNAME}:{PROXY_PASSWORD}@{ip}:{port}"
                if PROXY_USERNAME else f"http://{ip}:{port}"
            )
            s.execute(_insert_ignore(proxy_tbl, url=px))


def _candidate_stmt(now: dt.datetime):
    cool_ts = now - dt.timedelta(seconds=FAIL_COOLDOWN_SEC)
    jitter = func.random() * 0.01  # break score ties so one proxy isn't hammered
    return (
        select(proxy_tbl.c.url)
        .where(
            proxy_tbl.c.banned.is_(False),
            proxy_tbl.c.score > MIN_SCORE,
            proxy_tbl.c.in_use < MAX_IN_USE,
            or_(proxy_tbl.c.last_fail.is_(None), proxy_tbl.c.last_fail < cool_ts),
        )
        .order_by((proxy_tbl.c.score + jitter).desc())
        .limit(1)
        .with_for_update(nowait=False)  # row lock on Postgres; no-op on SQLite
    )
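For a quick look at what the candidate query emits on the active backend, the statement can be compiled by hand (a REPL sketch, not part of the module):

import datetime as dt
from datetime import timezone
from backend.core.db import engine
from backend.web.db_extra import _candidate_stmt

stmt = _candidate_stmt(dt.datetime.now(timezone.utc))
print(stmt.compile(engine))  # dialect-specific SQL with bound-parameter markers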
def acquire_proxy() -> str | None:
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as s:
        row = s.execute(_candidate_stmt(now)).first()
        if not row:
            return None
        px = row[0]
        s.execute(
            update(proxy_tbl)
            .where(proxy_tbl.c.url == px)
            .values(in_use=proxy_tbl.c.in_use + 1, updated_at=now)
        )
        return px


def release_proxy(px: str, ok: bool) -> None:
    if not px or px == "DIRECT":
        return

    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as s:
        new_in_use = proxy_tbl.c.in_use - 1
        s.execute(
            update(proxy_tbl)
            .where(proxy_tbl.c.url == px)
            .values(
                in_use=_clamp_zero(new_in_use),
                updated_at=now,
                last_fail=None if ok else now,
            )
        )


# Score updates are buffered and flushed in batches; put_nowait keeps the
# producers non-blocking, but the queue must be touched from the loop thread.
_buffer: asyncio.Queue[tuple[str, bool]] = asyncio.Queue(maxsize=2048)


def queue_proxy_result(px: str, ok: bool) -> None:
    try:
        _buffer.put_nowait((px, ok))
    except asyncio.QueueFull:
        # Shed the oldest entry rather than block or raise.
        try:
            _buffer.get_nowait()
            _buffer.put_nowait((px, ok))
        except Exception:
            pass
async def _flusher() -> None:
    while True:
        await asyncio.sleep(0.4)
        if _buffer.empty():
            continue

        # Fold queued results into one (successes, failures) pair per proxy.
        batch: dict[str, tuple[int, int]] = {}
        while not _buffer.empty():
            px, ok = _buffer.get_nowait()
            succ, fail = batch.get(px, (0, 0))
            if ok:
                succ += 1
            else:
                fail += 1
            batch[px] = (succ, fail)

        now = dt.datetime.now(timezone.utc)
        with SessionLocal.begin() as s:
            for px, (succ, fail) in batch.items():
                delta = 0.1 * succ - 0.2 * fail  # failures cost twice as much as successes earn
                stmt = (
                    update(proxy_tbl)
                    .where(proxy_tbl.c.url == px)
                    .values(
                        score=_clamp_zero(proxy_tbl.c.score + delta),
                        fails=_clamp_zero(proxy_tbl.c.fails + fail - succ),
                        # note: the ban test uses the pre-update fail count
                        banned=(proxy_tbl.c.fails + fail) > 5,
                        updated_at=now,
                    )
                )
                s.execute(stmt)
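A worked example of the batch fold and scoring above: three queued results for one proxy (ok, ok, fail) collapse into succ=2, fail=1, so delta = 0.1·2 − 0.2·1 = 0. A pure-Python rehearsal:

results = [("http://1.2.3.4:8080", True),
           ("http://1.2.3.4:8080", True),
           ("http://1.2.3.4:8080", False)]
batch: dict[str, tuple[int, int]] = {}
for px, ok in results:
    succ, fail = batch.get(px, (0, 0))
    batch[px] = (succ + 1, fail) if ok else (succ, fail + 1)

succ, fail = batch["http://1.2.3.4:8080"]
assert (succ, fail) == (2, 1)
assert 0.1 * succ - 0.2 * fail == 0.0  # two wins exactly offset one loss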
def start_background_tasks(loop: asyncio.AbstractEventLoop) -> None:
    loop.create_task(_flusher())
    loop.create_task(asyncio.to_thread(_seed))  # seeding does blocking file/DB I/O


_WINDOW_N = 50


def add_dl_stat(ok: bool) -> None:
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as s:
        s.execute(insert(dl_stats).values(ok=ok, ts=now))

        # Trim the table to the newest 500 rows so it never grows unbounded.
        newest_keep = select(dl_stats.c.id).order_by(
            dl_stats.c.id.desc()
        ).limit(500)
        s.execute(
            delete(dl_stats).where(~dl_stats.c.id.in_(newest_keep))
        )
def recent_success_rate(n: int = _WINDOW_N) -> float:
    with SessionLocal() as s:
        vals = (
            s.execute(select(dl_stats.c.ok).order_by(dl_stats.c.id.desc()).limit(n))
            .scalars()
            .all()
        )
    # Booleans sum as 0/1, so this is the fraction of recent successes;
    # with no history yet, assume a neutral 50 %.
    return 0.5 if not vals else sum(vals) / len(vals)


def _inc(table: Table, ip: str) -> None:
    now = dt.datetime.now(timezone.utc)
    with SessionLocal.begin() as s:
        # Select the counter column directly: on a full-row Row object,
        # `.count` resolves to the tuple method, not the column value.
        current = s.execute(
            select(table.c.count).where(table.c.ip == ip)
        ).scalar_one_or_none()
        if current is None:
            s.execute(insert(table).values(ip=ip, count=1, updated_at=now))
        else:
            s.execute(
                update(table)
                .where(table.c.ip == ip)
                .values(count=current + 1, updated_at=now)
            )
def record_login(ip: str, success: bool) -> None:
    if success:
        with SessionLocal.begin() as s:
            s.execute(update(login_tbl).where(login_tbl.c.ip == ip).values(count=0))
    else:
        _inc(login_tbl, ip)


def inc_invalid(ip: str) -> None:
    _inc(invalid_tbl, ip)


def _over_limit(table: Table, ip: str, cap: int) -> bool:
    with SessionLocal() as s:
        row = s.execute(
            select(table.c.count, table.c.updated_at).where(table.c.ip == ip)
        ).first()
        if not row:
            return False

    count, ts = row
    now = dt.datetime.now(timezone.utc)

    # SQLite hands back naive timestamps; normalise before comparing.
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)

    # Counters reset once the rolling window has elapsed.
    if (now - ts).total_seconds() > _WINDOW_MINUTES * 60:
        with SessionLocal.begin() as sx:
            sx.execute(update(table).where(table.c.ip == ip).values(count=0))
        return False

    return count >= cap


def too_many_attempts(ip: str) -> bool:
    return _over_limit(login_tbl, ip, _MAX_LOGIN_FAILS)


def invalid_over_limit(ip: str) -> bool:
    return _over_limit(invalid_tbl, ip, _MAX_INVALID_URLS)


def pick_proxy() -> str | None:
    return acquire_proxy()


def ensure_proxy(px: str) -> None:
    with SessionLocal.begin() as s:
        s.execute(_insert_ignore(proxy_tbl, url=px))


def update_proxy(px: str, ok: bool) -> None:
    queue_proxy_result(px, ok)


async def init_proxy_seed() -> None:
    await asyncio.to_thread(_seed)