init
This commit is contained in:
287
backend/core/download.py
Normal file
287
backend/core/download.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""backend/core/download.py — patched 2025-06-03"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import datetime
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import yt_dlp
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from core.db_utils import upsert
|
||||
from core.settings import (
|
||||
DOWNLOAD_DIR,
|
||||
TMP_DIR,
|
||||
PER_IP_CONCURRENCY,
|
||||
DOWNLOAD_CACHE_TTL_SEC,
|
||||
)
|
||||
from core.network import get_proxy, record_proxy
|
||||
from core.db_xp import ensure_user
|
||||
from core.db import SessionLocal, download_cache
|
||||
from core.progress_bus import update as set_progress
|
||||
from core.formats import _cached_metadata_fetch, _clean_proxy
|
||||
|
||||
# Ensure the download and temp directories exist before any worker touches them.
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)

# Estimated size (MB) of the current download; not written in this chunk —
# presumably set by callers elsewhere. 0 means "unknown".
EST_MB = contextvars.ContextVar("est_mb", default=0)

# Hard ceiling on simultaneous downloads across all client IPs.
MAX_GLOBAL_DOWNLOADS = PER_IP_CONCURRENCY * 4
_global_semaphore = asyncio.Semaphore(MAX_GLOBAL_DOWNLOADS)

# Per-IP concurrency gates, created lazily in download().
_ip_semaphores: Dict[str, asyncio.BoundedSemaphore] = {}
# In-flight download tasks keyed by the url+format hash, used to coalesce
# identical concurrent requests onto one task.
_inflight: Dict[str, asyncio.Task[str]] = {}

# Per-IP set of recently requested "ip::url::fmt" keys (duplicate suppression;
# entries expire via _expire_ip_cache_entry).
_ip_cache: Dict[str, set[str]] = {}
|
||||
|
||||
|
||||
def _get_ip_cache(ip: str) -> set[str]:
    """Return the duplicate-suppression set for *ip*, creating it on first use."""
    if ip not in _ip_cache:
        _ip_cache[ip] = set()
    return _ip_cache[ip]
|
||||
|
||||
|
||||
def _url_fmt_hash(url: str, fmt: str) -> str:
|
||||
return hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=16).hexdigest()
|
||||
|
||||
|
||||
# Matches ANSI CSI escape sequences (terminal colors/cursor codes) so they can
# be stripped from subprocess output before forwarding it to clients.
_ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
|
||||
async def download(url: str, fmt_id: str, ip: str, sid: str) -> str:
    """Download *url* in format *fmt_id* and return the resulting file path.

    Layered guards, checked in order:
      1. persistent DB cache — serve instantly if a fresh copy exists;
      2. in-flight coalescing — piggyback on an identical running download;
      3. per-IP duplicate suppression — short-lived in-memory key set;
      4. global + per-IP semaphores bounding concurrency.

    Args:
        url:    media URL to fetch.
        fmt_id: yt-dlp format id; falsy values fall back to "bestaudio".
        ip:     requesting client's IP, used for rate limiting and accounting.
        sid:    session id used to publish progress events.

    Raises:
        RuntimeError: when every proxy/retry combination fails.
    """
    fmt = fmt_id or "bestaudio"
    key = _url_fmt_hash(url, fmt)
    dup_key = f"{ip}::{url}::{fmt}"

    # 1. Persistent cache hit → no work at all.
    cached = await asyncio.to_thread(_lookup_cache_sync, key)
    if cached:
        set_progress(sid, status="cached", pct=100, progress="Instant – served from cache")
        return cached

    # 2. Identical download already in flight → share its result.
    # NOTE(review): in this path progress events go to the sid that started
    # the in-flight task, not to this sid — confirm that is acceptable.
    if key in _inflight:
        return await _inflight[key]

    # 3. Same ip/url/fmt seen recently → re-check the cache once, then proceed.
    ip_cache_set = _get_ip_cache(ip)
    if dup_key in ip_cache_set:
        cached2 = await asyncio.to_thread(_lookup_cache_sync, key)
        if cached2:
            set_progress(sid, status="cached", pct=100, progress="Instant – served from cache")
            return cached2
    ip_cache_set.add(dup_key)

    # 4. Lazily created per-IP gate; _global_semaphore caps the whole service.
    sem = _ip_semaphores.setdefault(ip, asyncio.BoundedSemaphore(PER_IP_CONCURRENCY))

    async def _run() -> str:
        async with _global_semaphore, sem:
            ensure_user(ip)
            set_progress(sid, status="starting", pct=0, progress="Starting…")

            try:
                info = await asyncio.to_thread(_cached_metadata_fetch, url)
            except Exception as e:
                set_progress(sid, status="error", progress=f"Metadata fetch failed: {e}")
                raise

            attempt = 0
            last_exc: Exception | None = None

            while attempt < 3:
                attempt += 1
                # Try a handful of proxies in random order, then a direct hit.
                proxies = [get_proxy() for _ in range(5)]
                random.shuffle(proxies)
                proxies.append("DIRECT")

                for proxy_url in proxies:
                    try:
                        final_path = await _single_download(
                            url,
                            fmt,
                            key,
                            sid,
                            proxy_url,
                            info,
                        )
                    except asyncio.CancelledError:
                        raise
                    except Exception as exc:
                        last_exc = exc
                        record_proxy(proxy_url, False)
                        clean_proxy = _clean_proxy(proxy_url)
                        set_progress(
                            sid,
                            status="retrying",
                            progress=f"Retry {attempt} failed (proxy {clean_proxy})",
                        )
                        # Small jittered back-off before the next proxy.
                        await asyncio.sleep(1 + random.random())
                        continue
                    else:
                        record_proxy(proxy_url, True)
                        await asyncio.to_thread(_store_cache_sync, key, final_path)
                        set_progress(sid, status="finished", pct=100, progress="Done")
                        return final_path

            set_progress(sid, status="error", progress="Download failed")
            raise RuntimeError(f"All download attempts failed: {last_exc!r}")

    task = asyncio.create_task(_run())
    _inflight[key] = task
    try:
        return await task
    finally:
        _inflight.pop(key, None)
        # BUGFIX: the original spawned a fire-and-forget Task here without
        # keeping any reference to it, so per the asyncio docs the expiry
        # coroutine could be garbage-collected before it ever ran and the
        # dup-key would never be released.  A plain loop callback is retained
        # by the event loop until it fires, so no reference bookkeeping is
        # needed (300 s matches _expire_ip_cache_entry's default delay).
        asyncio.get_running_loop().call_later(300, ip_cache_set.discard, dup_key)
|
||||
|
||||
|
||||
async def _single_download(
    url: str,
    fmt: str,
    _unused_cache_key: str,
    sid: str,
    proxy_url: str,
    info: dict,
) -> str:
    """Run one yt-dlp subprocess attempt and return the downloaded file path.

    Args:
        url:               media URL.
        fmt:               yt-dlp format id.
        _unused_cache_key: kept for interface compatibility; not used here.
        sid:               session id for progress events.
        proxy_url:         proxy to route through, or "DIRECT" for none.
        info:              pre-fetched yt-dlp metadata dict for *url*.

    Raises:
        RuntimeError: if the process fails, stalls, or produces no file.
                      Partial output files for this item are removed on any
                      failure so a retry starts from scratch.
    """
    title = info.get("title") or "unknown"
    artist = info.get("artist") or info.get("uploader") or "unknown"

    def _clean(s: str) -> str:
        # Strip characters that are illegal in filenames on common filesystems.
        return re.sub(r'[\\/*?:"<>|]', "", s)

    safe_title = _clean(title)
    safe_artist = _clean(artist)
    # Short per-(url, fmt) id keeps distinct requests from colliding on name.
    short_id = hashlib.blake2s(f"{url}::{fmt}".encode(), digest_size=8).hexdigest()
    base = f"{safe_title} - {safe_artist} - {short_id}"

    fmt_entry = next((f for f in info.get("formats", []) if f.get("format_id") == fmt), None)
    is_audio_only = bool(fmt_entry and fmt_entry.get("vcodec") == "none")
    if "soundcloud.com" in url.lower():
        is_audio_only = True

    # force .mp3 for audio-only, .mp4 otherwise
    ext_guess = "mp3" if is_audio_only else "mp4"
    outtmpl_path = DOWNLOAD_DIR / f"{base}.%(ext)s"
    final_path_expected = DOWNLOAD_DIR / f"{base}.{ext_guess}"

    # A previous attempt may already have produced a usable file.
    if final_path_expected.exists() and final_path_expected.stat().st_size > 0:
        return str(final_path_expected)

    # (The original built a throwaway cmd here that both branches immediately
    # overwrote — removed.)
    if is_audio_only:
        cmd = ["yt-dlp", "-x", "--audio-format", "mp3", "-o", str(outtmpl_path), url]
    else:
        cmd = ["yt-dlp", "-f", f"{fmt}+bestaudio", "-o", str(outtmpl_path), url]
    if proxy_url and proxy_url.upper() != "DIRECT":
        cmd.insert(1, f"--proxy={proxy_url}")

    # BUGFIX: the original opened *two* pipes but only ever read stderr, so a
    # chatty yt-dlp could fill the unread stdout pipe buffer, block, and be
    # mis-diagnosed as "stalled" and killed.  Merge stderr into stdout and
    # drain the single combined stream instead.
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )

    output_buffer: list[str] = []

    try:
        while True:
            try:
                # If yt-dlp goes quiet for 10 s, stop streaming output and
                # fall through to the bounded wait below.
                line = await asyncio.wait_for(proc.stdout.readline(), timeout=10)
            except asyncio.TimeoutError:
                break
            if not line:
                break  # EOF: process closed its output stream
            # Strip terminal color codes before forwarding as progress text
            # (this is what the module-level _ansi_escape exists for).
            decoded = _ansi_escape.sub("", line.decode(errors="ignore")).strip()
            if decoded:
                output_buffer.append(decoded)
                if len(output_buffer) > 10:
                    output_buffer.pop(0)  # keep only the last 10 lines
                set_progress(sid, status="running", pct=None, progress=decoded)

        try:
            rc = await asyncio.wait_for(proc.wait(), timeout=15)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            raise RuntimeError("yt-dlp stalled and was killed (timeout)")

        if rc != 0:
            raise RuntimeError(f"yt-dlp exited with code {rc}. Last lines: {' | '.join(output_buffer)}")

        # yt-dlp may choose a different extension than ext_guess, so glob for
        # any non-empty file carrying our base name.
        candidates = [
            p for p in DOWNLOAD_DIR.glob(f"{base}.*")
            if p.is_file() and p.stat().st_size > 0
        ]
        if not candidates:
            raise RuntimeError("No output file produced")

        candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
        return str(candidates[0])

    except asyncio.CancelledError:
        proc.kill()
        await proc.wait()
        raise

    except Exception:
        # Remove partial downloads so the next attempt starts clean.
        for f in DOWNLOAD_DIR.glob(f"{base}.*"):
            f.unlink(missing_ok=True)
        raise
|
||||
|
||||
|
||||
def _lookup_cache_sync(key: str) -> str | None:
    """Blocking cache lookup: return a valid on-disk path for *key*, or None.

    Expired rows (older than DOWNLOAD_CACHE_TTL_SEC) and dangling rows
    (file vanished from disk) are purged as a side effect.
    """
    now = datetime.datetime.now(datetime.timezone.utc)

    with SessionLocal() as session:

        def _purge() -> None:
            # Drop the stale row so the next lookup is a clean miss.
            session.execute(delete(download_cache).where(download_cache.c.key == key))
            session.commit()

        stmt = select(download_cache.c.path, download_cache.c.created_at).where(
            download_cache.c.key == key
        )
        try:
            path_on_disk, created_at = session.execute(stmt).one()
        except NoResultFound:
            return None

        # Rows may carry naive timestamps; interpret those as UTC.
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=datetime.timezone.utc)

        if (now - created_at).total_seconds() > DOWNLOAD_CACHE_TTL_SEC:
            _purge()
            try:
                os.remove(path_on_disk)
            except OSError:
                pass  # best effort — file may already be gone
            return None

        if not os.path.exists(path_on_disk):
            _purge()
            return None

        return path_on_disk
|
||||
|
||||
|
||||
def _store_cache_sync(key: str, path: str) -> None:
    """Blocking cache write: insert or refresh the row mapping *key* → *path*."""
    now = datetime.datetime.now(datetime.timezone.utc)
    ext = Path(path).suffix.lstrip(".")
    row = {"key": key, "path": path, "ext": ext, "created_at": now}
    stmt = upsert(
        download_cache,
        insert_values=row,
        conflict_cols=["key"],
        update_values={"path": path, "ext": ext, "created_at": now},
    )
    # SessionLocal.begin() commits on success and rolls back on exception.
    with SessionLocal.begin() as session:
        session.execute(stmt)
|
||||
|
||||
|
||||
async def _expire_ip_cache_entry(ip: str, dup_key: str, delay: int = 300) -> None:
    """After *delay* seconds, forget *dup_key* for *ip* (ends the dup window)."""
    await asyncio.sleep(delay)
    bucket = _get_ip_cache(ip)
    bucket.discard(dup_key)
|
||||
Reference in New Issue
Block a user