From 9b10e525f0d7d76f90b97f9fa714af6a3f763ff0 Mon Sep 17 00:00:00 2001 From: boyska Date: Fri, 17 Sep 2021 10:45:36 +0200 Subject: [PATCH] basic auth support --- techrec/default_config.py | 1 + techrec/forge.py | 4 ++-- techrec/http_retriever.py | 29 +++++++++++++++++++++++------ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/techrec/default_config.py b/techrec/default_config.py index b03f1b5..4c21ac3 100644 --- a/techrec/default_config.py +++ b/techrec/default_config.py @@ -9,6 +9,7 @@ DEBUG = True DB_URI = "sqlite:///techrec.db" AUDIO_OUTPUT = "output/" AUDIO_INPUT = "rec/" +AUDIO_INPUT_BASICAUTH = None # Could be a ("user", "pass") tuple instead AUDIO_INPUT_FORMAT = "%Y-%m/%d/rec-%Y-%m-%d-%H-%M-%S.mp3" AUDIO_OUTPUT_FORMAT = "techrec-%(startdt)s-%(endtime)s-%(name)s.mp3" FORGE_TIMEOUT = 20 diff --git a/techrec/forge.py b/techrec/forge.py index a492dc3..10a8c00 100644 --- a/techrec/forge.py +++ b/techrec/forge.py @@ -1,4 +1,3 @@ -import aiohttp import asyncio import logging import tempfile @@ -26,7 +25,8 @@ async def get_timefile_exact(time) -> str: ) if path.startswith("http://") or path.startswith("https://"): logger.info(f"downloading: {path}") - local = await download(path) + local = await download(path, + basic_auth=get_config()['AUDIO_INPUT_BASICAUTH']) return local return path diff --git a/techrec/http_retriever.py b/techrec/http_retriever.py index 68d8934..d67308c 100644 --- a/techrec/http_retriever.py +++ b/techrec/http_retriever.py @@ -1,17 +1,21 @@ # -*- encoding: utf-8 -*- -import asyncio import os -from typing import Optional +from typing import Optional, Tuple from tempfile import mkdtemp +from logging import getLogger import aiohttp # type: ignore -from .config_manager import get_config - CHUNK_SIZE = 2 ** 12 +log = getLogger("http") -async def download(remote: str, staging: Optional[str] = None) -> str: + +async def download( + remote: str, + staging: Optional[str] = None, + basic_auth: Optional[Tuple[str, str]] = None, +) -> str: """ This will download to AUDIO_STAGING the remote file and return the local path of the downloaded file @@ -24,12 +28,25 @@ async def download(remote: str, staging: Optional[str] = None) -> str: # used by techrec: rm -rf /tmp/techrec* base = mkdtemp(prefix="techrec-", dir="/tmp") local = os.path.join(base, filename) - async with aiohttp.ClientSession() as session: + + session_args = {} + if basic_auth is not None: + session_args["auth"] = aiohttp.BasicAuth( + login=basic_auth[0], password=basic_auth[1], encoding="utf-8" + ) + + log.debug("Downloading %s with %s options", remote, ",".join(session_args.keys())) + async with aiohttp.ClientSession(**session_args) as session: async with session.get(remote) as resp: + if resp.status != 200: + raise ValueError( + "Could not download %s: error %d" % (remote, resp.status) + ) with open(local, "wb") as f: while True: chunk = await resp.content.read(CHUNK_SIZE) if not chunk: break f.write(chunk) + log.debug("Downloading %s complete", remote) return local