Skip to content

Commit

Permalink
chore: reformatted using ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
anonymous-org-za committed Jul 8, 2024
1 parent b5c15cd commit 9c83fb0
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 77 deletions.
2 changes: 0 additions & 2 deletions functions/authFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def checkCorrectCredentials(credentials: HTTPBasicCredentials = Depends(security
- boolean: the boolean determines if the credentials match.
"""


current_username_bytes = credentials.username.encode("utf8")
correct_username_bytes = AUTH_USERNAME.encode("utf-8")
is_correct_username = secrets.compare_digest(
Expand All @@ -31,4 +30,3 @@ def checkCorrectCredentials(credentials: HTTPBasicCredentials = Depends(security
if not (is_correct_username and is_correct_password):
return False
return True

20 changes: 14 additions & 6 deletions functions/serverFunctions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from library.tinfoil import errorMessage, SWITCH_UID


def checkAllowed(authenticated: bool, switch_uid: str):
"""
Checks if a user should be allowed to finish the request, otherwise returns an error message.
Requires:
- authenticated: a boolean which tells the server if the user is authenticated or not.
- switch_uid: a string which either contains a uid or not. If no UID then a user is not using a switch. Also has the ability to check if the switch UID matches the required UID.
Expand All @@ -12,13 +13,20 @@ def checkAllowed(authenticated: bool, switch_uid: str):
- boolean, dict: the boolean determines if the user is allowed past, the dict gives the errorMessage.
"""


if not authenticated:
return False, errorMessage("Your given credentials are incorrect. Please try again.", error_code="BAD_TOKEN")
return False, errorMessage(
"Your given credentials are incorrect. Please try again.",
error_code="BAD_TOKEN",
)
if not switch_uid:
return False, errorMessage("Please use your Nintendo Switch using Tinfoil to access this server.", error_code="INVALID_DEVICE")
return False, errorMessage(
"Please use your Nintendo Switch using Tinfoil to access this server.",
error_code="INVALID_DEVICE",
)
if SWITCH_UID and switch_uid != SWITCH_UID:
return False, errorMessage("This switch is not authorized to use this server. Please use the correct switch to access.", error_code="INVALID_DEVICE")
return False, errorMessage(
"This switch is not authorized to use this server. Please use the correct switch to access.",
error_code="INVALID_DEVICE",
)

return True, None

55 changes: 32 additions & 23 deletions functions/tinfoilFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ async def generateIndex(base_url: str):

try:
# runs all requests in parallel for faster responses
torrents, usenet_downloads, web_downloads = await asyncio.gather(getDownloads("torrents"), getDownloads("usenet"), getDownloads("webdl"))

torrents, usenet_downloads, web_downloads = await asyncio.gather(
getDownloads("torrents"), getDownloads("usenet"), getDownloads("webdl")
)

file_list = torrents + usenet_downloads + web_downloads

for file in file_list:
Expand All @@ -37,52 +39,63 @@ async def generateIndex(base_url: str):
download_id = file.get("id")
file_id = file.get("file_id", 0)


for acceptable_file_type in ACCEPTABLE_SWITCH_FILES:
if fnmatch.fnmatch(file_name, f"*{acceptable_file_type}"):
# create a url for the download
files.append({
# example base_url = http://192.168.0.1/
# example complete_url = http://192.168.0.1/torrents/1/1#Game_Name
"url": f"{base_url}{download_type}/{download_id}/{file_id}#{file_name}",
"size": file.get("size", 0)
})
files.append(
{
# example base_url = http://192.168.0.1/
# example complete_url = http://192.168.0.1/torrents/1/1#Game_Name
"url": f"{base_url}{download_type}/{download_id}/{file_id}#{file_name}",
"size": file.get("size", 0),
}
)

success_message += f"Total Files: {len(files)}"
success_message += f"\nTotal Size: {human_readable.file_size(sum([file.get('size', 0) for file in files]))}"
return JSONResponse(
status_code=200,
content={
"success": success_message,
"files": files
}
status_code=200, content={"success": success_message, "files": files}
)
except Exception as e:
logging.error(f"There was an error generating the index. Error: {str(e)}")
return JSONResponse(
status_code=500,
content=errorMessage(f"There was an error generating the index. Error: {str(e)}", error_code="UNKOWN_ERROR")
content=errorMessage(
f"There was an error generating the index. Error: {str(e)}",
error_code="UNKOWN_ERROR",
),
)

async def serveFile(background_task: BackgroundTasks, download_type: str, download_id: int, file_id: int = 0):

async def serveFile(
background_task: BackgroundTasks,
download_type: str,
download_id: int,
file_id: int = 0,
):
"""
Retrieves the TorBox download link and starts proxying the download through the server. This is necessary as generating a bunch of links through the index generation process can take some time, and is wasteful.
Requires:
- download_type: the download type of the file. Must be either 'torrents', 'usenet' or 'webdl'.
- download_id: an integer which represents the id of the download in the TorBox database.
- file_id: an integer which represents the id of the file which is inside of the download.
Returns:
- Streaming Response: containing the download of the file to be served on the fly.
"""

download_link = await getDownloadLink(download_type=download_type, download_id=download_id, file_id=file_id)
download_link = await getDownloadLink(
download_type=download_type, download_id=download_id, file_id=file_id
)

if not download_link:
return JSONResponse(
status_code=500,
content=errorMessage("There was an error retrieving the download link from TorBox. Please try again.", error_code="DATABASE_ERROR")
content=errorMessage(
"There was an error retrieving the download link from TorBox. Please try again.",
error_code="DATABASE_ERROR",
),
)

# now stream link and stream out
Expand All @@ -93,7 +106,3 @@ async def serveFile(background_task: BackgroundTasks, download_type: str, downlo
cleanup = background_task.add_task(response.aclose)

return StreamingResponse(content=response.aiter_bytes(), background=cleanup)




73 changes: 43 additions & 30 deletions functions/torboxFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
import logging
import traceback


async def getDownloads(type: str):
"""
Gets a download type list from TorBox.
Gets a download type list from TorBox.
Requires:
Requires:
- type: the download type to be retrieved, either "torrents", "usenet", or "webdl"
Returns:
- file_list: a list containing all of the files retrieved from the TorBox API.
"""

if type not in ["torrents", "usenet", "webdl"]:
logging.error("Please provide a type of either 'torrents', 'usenet' or 'webdl'.")
logging.error(
"Please provide a type of either 'torrents', 'usenet' or 'webdl'."
)
return []

try:
Expand All @@ -24,11 +27,13 @@ async def getDownloads(type: str):
url=f"{TORBOX_API_URL}/{type}/mylist",
headers={
"Authorization": f"Bearer {TORBOX_API_KEY}",
"User-Agent": "TorBox SelfHosted Tinfoil Server/1.0.0"
}
"User-Agent": "TorBox SelfHosted Tinfoil Server/1.0.0",
},
)
if response.status_code != httpx.codes.OK:
logging.error(f"Unable to retrieve TorBox {type} downloads. Response Code: {response.status_code}. Response: {response.json()}")
logging.error(
f"Unable to retrieve TorBox {type} downloads. Response Code: {response.status_code}. Response: {response.json()}"
)
return []
files = []
json = response.json()
Expand All @@ -41,27 +46,38 @@ async def getDownloads(type: str):
if not file.get("s3_path", None):
continue
try:
files.append({
"type": type,
"id": id,
"file_id": file.get("id", 0),
"name": file.get("s3_path", None).split("/")[-1], # gets only the filename of the file
"size": file.get("size", 0)
})
files.append(
{
"type": type,
"id": id,
"file_id": file.get("id", 0),
"name": file.get("s3_path", None).split("/")[
-1
], # gets only the filename of the file
"size": file.get("size", 0),
}
)
except Exception as e:
logging.error(f"There was an error trying to add {type} download file to file list. Error: {str(e)}")
logging.error(
f"There was an error trying to add {type} download file to file list. Error: {str(e)}"
)
continue
return files
except Exception as e:
traceback.print_exc()
logging.error(f"There was an error getting {type} downloads from TorBox. Error: {str(e)}")
logging.error(
f"There was an error getting {type} downloads from TorBox. Error: {str(e)}"
)
return []



async def getDownloadLink(download_type: str, download_id: int, file_id: int = 0):
if download_type not in ["torrents", "usenet", "webdl"]:
logging.error("Please provide a type of either 'torrents', 'usenet' or 'webdl'.")
logging.error(
"Please provide a type of either 'torrents', 'usenet' or 'webdl'."
)
return None

try:
async with httpx.AsyncClient() as client:
if download_type == "torrents":
Expand All @@ -70,25 +86,22 @@ async def getDownloadLink(download_type: str, download_id: int, file_id: int = 0
id_type = "usenet_id"
elif download_type == "webdl":
id_type = "web_id"
params = {
"token": TORBOX_API_KEY,
id_type: download_id,
file_id: file_id
}
params = {"token": TORBOX_API_KEY, id_type: download_id, file_id: file_id}
response = await client.get(
url=f"{TORBOX_API_URL}/{download_type}/requestdl",
params=params,
headers={
"User-Agent": "TorBox SelfHosted Tinfoil Server/1.0.0"
}
headers={"User-Agent": "TorBox SelfHosted Tinfoil Server/1.0.0"},
)
if response.status_code != httpx.codes.OK:
logging.error(f"Unable to retrieve TorBox {download_type} download link. Response Code: {response.status_code}. Response: {response.json()}")
logging.error(
f"Unable to retrieve TorBox {download_type} download link. Response Code: {response.status_code}. Response: {response.json()}"
)
return None

json = response.json()
link = json.get("data", None)
return link
except Exception as e:
logging.error(f"There was an error retrieving the download url from TorBox. Error: {str(e)}")

logging.error(
f"There was an error retrieving the download url from TorBox. Error: {str(e)}"
)
1 change: 1 addition & 0 deletions library/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from dotenv import load_dotenv

load_dotenv()

AUTH_USERNAME = os.getenv("AUTH_USERNAME", "admin")
Expand Down
7 changes: 3 additions & 4 deletions library/tinfoil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from dotenv import load_dotenv

load_dotenv()

SWITCH_UID = os.getenv("SWITCH_UID", None)
Expand All @@ -16,7 +17,5 @@ def errorMessage(message: str, error_code: str):
Returns:
- error: a dict with the 'error' key which is read by Tinfoil. Documentation here: https://blawar.github.io/tinfoil/custom_index/
"""

return {
"error": f"TorBox\n\n{message}\n\nError: {error_code}"
}

return {"error": f"TorBox\n\n{message}\n\nError: {error_code}"}
3 changes: 2 additions & 1 deletion library/torbox.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from dotenv import load_dotenv

load_dotenv()

TORBOX_API_KEY = os.getenv("TORBOX_API_KEY")
TORBOX_API_URL = "https://api.torbox.app/v1/api"
TORBOX_API_URL = "https://api.torbox.app/v1/api"
39 changes: 28 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import uvicorn
from fastapi import FastAPI, HTTPException, status, Depends, Header, Request, BackgroundTasks
from fastapi import (
FastAPI,
HTTPException,
status,
Depends,
Header,
Request,
BackgroundTasks,
)
from fastapi.responses import JSONResponse
from typing_extensions import Annotated, Union
from library.tinfoil import errorMessage
Expand All @@ -12,39 +20,48 @@
app = FastAPI()
logging.basicConfig(level=logging.INFO)


# Custom exemption handler to be well-formatted with Tinfoil so the user knows what has happened if no authentication is sent, as it is required.
@app.exception_handler(HTTPException)
async def custom_http_exception_handler(request, exc):
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=errorMessage(message="Please authenticate using your username and password set when the server was created.", error_code="NO_AUTH"),
content=errorMessage(
message="Please authenticate using your username and password set when the server was created.",
error_code="NO_AUTH",
),
)
return await request.app.exception_handler(exc)


@app.get("/")
async def get_user_files(
request: Request,
authenticated: bool = Depends(checkCorrectCredentials),
uid: Annotated[Union[str, None], Header()] = None
uid: Annotated[Union[str, None], Header()] = None,
):
logging.info(f"Request from Switch with UID: {uid}")
allowed, response = checkAllowed(authenticated=authenticated, switch_uid=uid)
if not allowed:
return JSONResponse(
content=response,
status_code=401
)
return JSONResponse(content=response, status_code=401)
return await generateIndex(base_url=request.base_url)


@app.get("/{download_type}/{download_id}/{file_id}")
async def get_file(
background_task: BackgroundTasks, # background_task is used to clean up the httpx response afterwards to prevent leakage
background_task: BackgroundTasks, # background_task is used to clean up the httpx response afterwards to prevent leakage
download_type: str,
download_id: int,
file_id: int = 0
file_id: int = 0,
):
return await serveFile(background_task=background_task, download_type=download_type, download_id=download_id, file_id=file_id)
return await serveFile(
background_task=background_task,
download_type=download_type,
download_id=download_id,
file_id=file_id,
)


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=PORT)
uvicorn.run(app, host="0.0.0.0", port=PORT)

0 comments on commit 9c83fb0

Please sign in to comment.