Files
s1ne/backend/core/download.py
2026-03-29 23:50:49 -05:00

288 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""backend/core/download.py — patched 2025-06-03"""
from __future__ import annotations
import asyncio
import contextvars
import datetime
import hashlib
import os
import random
import re
from pathlib import Path
from typing import Dict
import yt_dlp
from sqlalchemy import select, delete
from sqlalchemy.exc import NoResultFound
from core.db_utils import upsert
from core.settings import (
DOWNLOAD_DIR,
TMP_DIR,
PER_IP_CONCURRENCY,
DOWNLOAD_CACHE_TTL_SEC,
)
from core.network import get_proxy, record_proxy
from core.db_xp import ensure_user
from core.db import SessionLocal, download_cache
from core.progress_bus import update as set_progress
from core.formats import _cached_metadata_fetch, _clean_proxy
# Ensure target directories exist before any download starts (idempotent).
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)
# Per-task estimated size in MB; defined here but not read anywhere in this chunk —
# presumably consumed elsewhere in the file (288 lines total). TODO confirm.
EST_MB = contextvars.ContextVar("est_mb", default=0)
# Hard global ceiling on concurrent downloads: 4x the per-IP limit.
MAX_GLOBAL_DOWNLOADS = PER_IP_CONCURRENCY * 4
_global_semaphore = asyncio.Semaphore(MAX_GLOBAL_DOWNLOADS)
_ip_semaphores: Dict[str, asyncio.BoundedSemaphore] = {}
_inflight: Dict[str, asyncio.Task[str]] = {}
_ip_cache: Dict[str, set[str]] = {}
def _get_ip_cache(ip: str) -> set[str]:
return _ip_cache.setdefault(ip, set())
def _url_fmt_hash(url: str, fmt: str) -> str:
return hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=16).hexdigest()
_ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
async def download(url: str, fmt_id: str, ip: str, sid: str) -> str:
    """Download *url* in format *fmt_id*, returning the path of the file on disk.

    Layered short-circuits before any work is done:
      1. DB-backed cache lookup keyed on (url, fmt).
      2. Coalescing: if an identical download is in flight, await its task.
      3. Per-IP duplicate suppression via an in-memory key set.

    The actual work runs under both the global and the per-IP semaphore and
    retries up to 3 rounds, each round trying 5 proxies plus a direct attempt.

    Progress is published to *sid* via ``set_progress``.  Raises the metadata
    fetch error, or ``RuntimeError`` when every attempt fails.
    """
    fmt = fmt_id or "bestaudio"  # empty/None format falls back to bestaudio
    key = _url_fmt_hash(url, fmt)
    dup_key = f"{ip}::{url}::{fmt}"
    # 1. Cache hit: serve instantly without touching semaphores.
    cached = await asyncio.to_thread(_lookup_cache_sync, key)
    if cached:
        set_progress(sid, status="cached", pct=100, progress="Instant served from cache")
        return cached
    # 2. Coalesce with an identical in-flight download.
    # NOTE(review): if the original caller's task fails or is cancelled, that
    # propagates to every awaiter here — confirm this is the intended behavior.
    if key in _inflight:
        return await _inflight[key]
    # 3. Per-IP duplicate suppression: a repeat within the expiry window only
    # re-checks the cache (the first request may have finished meanwhile).
    ip_cache_set = _get_ip_cache(ip)
    if dup_key in ip_cache_set:
        cached2 = await asyncio.to_thread(_lookup_cache_sync, key)
        if cached2:
            set_progress(sid, status="cached", pct=100, progress="Instant served from cache")
            return cached2
    ip_cache_set.add(dup_key)
    sem = _ip_semaphores.setdefault(ip, asyncio.BoundedSemaphore(PER_IP_CONCURRENCY))

    async def _run() -> str:
        # Holds both the global and the per-IP semaphore for the whole attempt loop.
        async with _global_semaphore, sem:
            ensure_user(ip)
            set_progress(sid, status="starting", pct=0, progress="Starting…")
            try:
                info = await asyncio.to_thread(_cached_metadata_fetch, url)
            except Exception as e:
                set_progress(sid, status="error", progress=f"Metadata fetch failed: {e}")
                raise
            attempt = 0
            last_exc: Exception | None = None
            while attempt < 3:
                attempt += 1
                # Fresh batch of 5 proxies per round, shuffled, with a direct
                # (proxyless) attempt always last.
                proxies = [get_proxy() for _ in range(5)]
                random.shuffle(proxies)
                proxies.append("DIRECT")
                for proxy_url in proxies:
                    try:
                        final_path = await _single_download(
                            url,
                            fmt,
                            key,
                            sid,
                            proxy_url,
                            info,
                        )
                    except asyncio.CancelledError:
                        # Never swallow cancellation as a retryable failure.
                        raise
                    except Exception as exc:
                        last_exc = exc
                        record_proxy(proxy_url, False)
                        clean_proxy = _clean_proxy(proxy_url)
                        set_progress(
                            sid,
                            status="retrying",
                            progress=f"Retry {attempt} failed (proxy {clean_proxy})",
                        )
                        # Jittered backoff before the next proxy.
                        await asyncio.sleep(1 + random.random())
                        continue
                    else:
                        # Success: mark proxy good, persist to cache, report done.
                        record_proxy(proxy_url, True)
                        await asyncio.to_thread(_store_cache_sync, key, final_path)
                        set_progress(sid, status="finished", pct=100, progress="Done")
                        return final_path
            set_progress(sid, status="error", progress="Download failed")
            raise RuntimeError(f"All download attempts failed: {last_exc!r}")

    # Register the task so concurrent identical requests can await it.
    task = asyncio.create_task(_run())
    _inflight[key] = task
    try:
        return await task
    finally:
        _inflight.pop(key, None)
        # Schedule removal of the duplicate-suppression key after its TTL.
        # NOTE(review): this task is not referenced anywhere; per asyncio docs a
        # task with no strong reference may be garbage-collected before it runs —
        # consider keeping a reference (e.g. a module-level set).
        asyncio.create_task(_expire_ip_cache_entry(ip, dup_key))
async def _single_download(
    url: str,
    fmt: str,
    _unused_cache_key: str,
    sid: str,
    proxy_url: str,
    info: dict,
) -> str:
    """Run one yt-dlp subprocess attempt for *url* and return the output path.

    Builds a sanitized ``"title - artist - shortid"`` base name from *info*,
    short-circuits if a non-empty file with the guessed extension already
    exists, then streams yt-dlp's stderr into progress updates for *sid*.

    *proxy_url* of ``"DIRECT"`` (or empty) means no ``--proxy`` flag.

    Raises ``RuntimeError`` on a non-zero exit, a stall, or missing output;
    partial output files matching the base name are deleted on any failure.
    """
    title = info.get("title") or "unknown"
    artist = info.get("artist") or info.get("uploader") or "unknown"

    def _clean(s: str) -> str:
        # Strip characters that are invalid in filenames on common filesystems.
        return re.sub(r'[\\/*?:"<>|]', "", s)

    safe_title = _clean(title)
    safe_artist = _clean(artist)
    # Short per-(url, fmt) suffix keeps distinct downloads of the same title apart.
    short_id = hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=8).hexdigest()
    base = f"{safe_title} - {safe_artist} - {short_id}"
    # Audio-only when the selected format entry has no video codec;
    # SoundCloud is always treated as audio-only.
    fmt_entry = next((f for f in info.get("formats", []) if f.get("format_id") == fmt), None)
    is_audio_only = bool(fmt_entry and fmt_entry.get("vcodec") == "none")
    if "soundcloud.com" in url.lower():
        is_audio_only = True
    # force .mp3 for audio-only, .mp4 otherwise
    ext_guess = "mp3" if is_audio_only else "mp4"
    outtmpl_path = DOWNLOAD_DIR / f"{base}.%(ext)s"
    final_path_expected = DOWNLOAD_DIR / f"{base}.{ext_guess}"
    # Reuse an existing non-empty result (a concurrent/previous attempt may
    # have produced it); the glob fallback below covers other extensions.
    if final_path_expected.exists() and final_path_expected.stat().st_size > 0:
        return str(final_path_expected)
    # Fixed: an earlier unconditional `cmd = [...]` assignment here was dead
    # code — it was always overwritten by one of the two branches below.
    if is_audio_only:
        cmd = ["yt-dlp", "-x", "--audio-format", "mp3", "-o", str(outtmpl_path), url]
    else:
        cmd = ["yt-dlp", "-f", f"{fmt}+bestaudio", "-o", str(outtmpl_path), url]
    if proxy_url and proxy_url.upper() != "DIRECT":
        cmd.insert(1, f"--proxy={proxy_url}")
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stderr_buffer: list[str] = []  # rolling window of the last 10 stderr lines
    try:
        while True:
            try:
                line = await asyncio.wait_for(proc.stderr.readline(), timeout=10)
            except asyncio.TimeoutError:
                # NOTE(review): 10 s of stderr silence stops monitoring; combined
                # with the 15 s proc.wait() below, a quiet long-running download
                # would be killed as "stalled" — confirm this is intended.
                break
            if not line:
                break  # EOF: process closed stderr
            decoded = line.decode(errors="ignore").strip()
            if decoded:
                stderr_buffer.append(decoded)
                if len(stderr_buffer) > 10:
                    stderr_buffer.pop(0)
                set_progress(sid, status="running", pct=None, progress=decoded)
        try:
            rc = await asyncio.wait_for(proc.wait(), timeout=15)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            raise RuntimeError("yt-dlp stalled and was killed (timeout)")
        if rc != 0:
            raise RuntimeError(f"yt-dlp exited with code {rc}. Last lines: {' | '.join(stderr_buffer)}")
        # Locate the real output; yt-dlp may have chosen a different extension
        # than our guess. Prefer the most recently modified match.
        candidates = [
            p for p in DOWNLOAD_DIR.glob(f"{base}.*")
            if p.is_file() and p.stat().st_size > 0
        ]
        if not candidates:
            raise RuntimeError("No output file produced")
        candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
        return str(candidates[0])
    except asyncio.CancelledError:
        # Propagate cancellation but never leave an orphaned subprocess.
        proc.kill()
        await proc.wait()
        raise
    except Exception:
        # Remove partial/corrupt output so a retry starts clean.
        for f in DOWNLOAD_DIR.glob(f"{base}.*"):
            f.unlink(missing_ok=True)
        raise
def _lookup_cache_sync(key: str) -> str | None:
    """Return the cached file path for *key*, or None when absent, expired,
    or no longer present on disk (stale rows are deleted as a side effect).

    Runs synchronously; callers dispatch it via ``asyncio.to_thread``.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    with SessionLocal() as session:
        try:
            path_on_disk, created_at = session.execute(
                select(download_cache.c.path, download_cache.c.created_at)
                .where(download_cache.c.key == key)
            ).one()
        except NoResultFound:
            return None
        # Legacy rows may carry naive timestamps; treat them as UTC.
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=datetime.timezone.utc)
        expired = (now - created_at).total_seconds() > DOWNLOAD_CACHE_TTL_SEC
        if expired:
            # Evict the row and best-effort delete the stale file.
            session.execute(delete(download_cache).where(download_cache.c.key == key))
            session.commit()
            try:
                os.remove(path_on_disk)
            except OSError:
                pass
            return None
        if not os.path.exists(path_on_disk):
            # File vanished out from under us: drop the dangling row.
            session.execute(delete(download_cache).where(download_cache.c.key == key))
            session.commit()
            return None
        return path_on_disk
def _store_cache_sync(key: str, path: str) -> None:
    """Insert or refresh the download-cache row mapping *key* to *path*.

    Upserts on the ``key`` column; runs synchronously (dispatched via
    ``asyncio.to_thread`` by callers).
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    ext = Path(path).suffix.lstrip(".")
    row = {
        "key": key,
        "path": path,
        "ext": ext,
        "created_at": now,
    }
    stmt = upsert(
        download_cache,
        insert_values=row,
        conflict_cols=["key"],
        update_values={"path": path, "ext": ext, "created_at": now},
    )
    with SessionLocal.begin() as session:
        session.execute(stmt)
async def _expire_ip_cache_entry(ip: str, dup_key: str, delay: int = 300) -> None:
    """After *delay* seconds, drop *dup_key* from *ip*'s duplicate-suppression set."""
    await asyncio.sleep(delay)
    bucket = _get_ip_cache(ip)
    bucket.discard(dup_key)