"""backend/core/download.py — patched 2025-06-03""" from __future__ import annotations import asyncio import contextvars import datetime import hashlib import os import random import re from pathlib import Path from typing import Dict import yt_dlp from sqlalchemy import select, delete from sqlalchemy.exc import NoResultFound from core.db_utils import upsert from core.settings import ( DOWNLOAD_DIR, TMP_DIR, PER_IP_CONCURRENCY, DOWNLOAD_CACHE_TTL_SEC, ) from core.network import get_proxy, record_proxy from core.db_xp import ensure_user from core.db import SessionLocal, download_cache from core.progress_bus import update as set_progress from core.formats import _cached_metadata_fetch, _clean_proxy os.makedirs(DOWNLOAD_DIR, exist_ok=True) os.makedirs(TMP_DIR, exist_ok=True) EST_MB = contextvars.ContextVar("est_mb", default=0) MAX_GLOBAL_DOWNLOADS = PER_IP_CONCURRENCY * 4 _global_semaphore = asyncio.Semaphore(MAX_GLOBAL_DOWNLOADS) _ip_semaphores: Dict[str, asyncio.BoundedSemaphore] = {} _inflight: Dict[str, asyncio.Task[str]] = {} _ip_cache: Dict[str, set[str]] = {} def _get_ip_cache(ip: str) -> set[str]: return _ip_cache.setdefault(ip, set()) def _url_fmt_hash(url: str, fmt: str) -> str: return hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=16).hexdigest() _ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") async def download(url: str, fmt_id: str, ip: str, sid: str) -> str: fmt = fmt_id or "bestaudio" key = _url_fmt_hash(url, fmt) dup_key = f"{ip}::{url}::{fmt}" cached = await asyncio.to_thread(_lookup_cache_sync, key) if cached: set_progress(sid, status="cached", pct=100, progress="Instant – served from cache") return cached if key in _inflight: return await _inflight[key] ip_cache_set = _get_ip_cache(ip) if dup_key in ip_cache_set: cached2 = await asyncio.to_thread(_lookup_cache_sync, key) if cached2: set_progress(sid, status="cached", pct=100, progress="Instant – served from cache") return cached2 ip_cache_set.add(dup_key) sem = _ip_semaphores.setdefault(ip, asyncio.BoundedSemaphore(PER_IP_CONCURRENCY)) async def _run() -> str: async with _global_semaphore, sem: ensure_user(ip) set_progress(sid, status="starting", pct=0, progress="Starting…") try: info = await asyncio.to_thread(_cached_metadata_fetch, url) except Exception as e: set_progress(sid, status="error", progress=f"Metadata fetch failed: {e}") raise attempt = 0 last_exc: Exception | None = None while attempt < 3: attempt += 1 proxies = [get_proxy() for _ in range(5)] random.shuffle(proxies) proxies.append("DIRECT") for proxy_url in proxies: try: final_path = await _single_download( url, fmt, key, sid, proxy_url, info, ) except asyncio.CancelledError: raise except Exception as exc: last_exc = exc record_proxy(proxy_url, False) clean_proxy = _clean_proxy(proxy_url) set_progress( sid, status="retrying", progress=f"Retry {attempt} failed (proxy {clean_proxy})", ) await asyncio.sleep(1 + random.random()) continue else: record_proxy(proxy_url, True) await asyncio.to_thread(_store_cache_sync, key, final_path) set_progress(sid, status="finished", pct=100, progress="Done") return final_path set_progress(sid, status="error", progress="Download failed") raise RuntimeError(f"All download attempts failed: {last_exc!r}") task = asyncio.create_task(_run()) _inflight[key] = task try: return await task finally: _inflight.pop(key, None) asyncio.create_task(_expire_ip_cache_entry(ip, dup_key)) async def _single_download( url: str, fmt: str, _unused_cache_key: str, sid: str, proxy_url: str, info: dict, ) -> str: title = info.get("title") or "unknown" artist = info.get("artist") or info.get("uploader") or "unknown" def _clean(s: str) -> str: return re.sub(r'[\\/*?:"<>|]', "", s) safe_title = _clean(title) safe_artist = _clean(artist) short_id = hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=8).hexdigest() base = f"{safe_title} - {safe_artist} - {short_id}" fmt_entry = next((f for f in info.get("formats", []) if f.get("format_id") == fmt), None) is_audio_only = bool(fmt_entry and fmt_entry.get("vcodec") == "none") if "soundcloud.com" in url.lower(): is_audio_only = True # force .mp3 for audio-only, .mp4 otherwise ext_guess = "mp3" if is_audio_only else "mp4" outtmpl_path = DOWNLOAD_DIR / f"{base}.%(ext)s" final_path_expected = DOWNLOAD_DIR / f"{base}.{ext_guess}" if final_path_expected.exists() and final_path_expected.stat().st_size > 0: return str(final_path_expected) cmd = ["yt-dlp", "-f", fmt, "-o", str(outtmpl_path), url] if is_audio_only: cmd = ["yt-dlp", "-x", "--audio-format", "mp3", "-o", str(outtmpl_path), url] else: cmd = ["yt-dlp", "-f", f"{fmt}+bestaudio", "-o", str(outtmpl_path), url] if proxy_url and proxy_url.upper() != "DIRECT": cmd.insert(1, f"--proxy={proxy_url}") proc = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) stderr_buffer: list[str] = [] try: while True: try: line = await asyncio.wait_for(proc.stderr.readline(), timeout=10) except asyncio.TimeoutError: break if not line: break decoded = line.decode(errors="ignore").strip() if decoded: stderr_buffer.append(decoded) if len(stderr_buffer) > 10: stderr_buffer.pop(0) set_progress(sid, status="running", pct=None, progress=decoded) try: rc = await asyncio.wait_for(proc.wait(), timeout=15) except asyncio.TimeoutError: proc.kill() await proc.wait() raise RuntimeError("yt-dlp stalled and was killed (timeout)") if rc != 0: raise RuntimeError(f"yt-dlp exited with code {rc}. Last lines: {' | '.join(stderr_buffer)}") candidates = [ p for p in DOWNLOAD_DIR.glob(f"{base}.*") if p.is_file() and p.stat().st_size > 0 ] if not candidates: raise RuntimeError("No output file produced") candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True) return str(candidates[0]) except asyncio.CancelledError: proc.kill() await proc.wait() raise except Exception: for f in DOWNLOAD_DIR.glob(f"{base}.*"): f.unlink(missing_ok=True) raise def _lookup_cache_sync(key: str) -> str | None: now = datetime.datetime.now(datetime.timezone.utc) with SessionLocal() as session: try: row = session.execute( select(download_cache.c.path, download_cache.c.created_at) .where(download_cache.c.key == key) ).one() except NoResultFound: return None path_on_disk, created_at = row if created_at.tzinfo is None: created_at = created_at.replace(tzinfo=datetime.timezone.utc) age = (now - created_at).total_seconds() if age > DOWNLOAD_CACHE_TTL_SEC: session.execute(delete(download_cache).where(download_cache.c.key == key)) session.commit() try: os.remove(path_on_disk) except OSError: pass return None if not os.path.exists(path_on_disk): session.execute(delete(download_cache).where(download_cache.c.key == key)) session.commit() return None return path_on_disk def _store_cache_sync(key: str, path: str) -> None: now = datetime.datetime.now(datetime.timezone.utc) insert_values = { "key": key, "path": path, "ext": Path(path).suffix.lstrip("."), "created_at": now, } stmt = upsert( download_cache, insert_values=insert_values, conflict_cols=["key"], update_values={"path": path, "ext": Path(path).suffix.lstrip("."), "created_at": now}, ) with SessionLocal.begin() as session: session.execute(stmt) async def _expire_ip_cache_entry(ip: str, dup_key: str, delay: int = 300) -> None: await asyncio.sleep(delay) _get_ip_cache(ip).discard(dup_key)