Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add More Features and Pretty Page #1

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,8 @@ cython_debug/
**/*-local.*
**/*-local
**/*_local.*
**/*_local
**/*_local

# static and media files
static/**
media/**
32 changes: 22 additions & 10 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import os.path
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware

from util.logger import init_logger

init_logger()
from starlette.staticfiles import StaticFiles


def register_router(_app: FastAPI):
from core import api_router, page_router

_app.include_router(page_router.router)
_app.include_router(api_router.router, prefix="/api")


def add_custom_exception_handlers(_app: FastAPI):
from util.exception_util import request_validation_exception_handler, response_validation_exception_handler
def register_custom_exception_handlers(_app: FastAPI):
from util.exception_util import (
request_validation_exception_handler,
response_validation_exception_handler,
)

_app.add_exception_handler(RequestValidationError, request_validation_exception_handler)
_app.add_exception_handler(ResponseValidationError, response_validation_exception_handler)

Expand All @@ -33,19 +35,29 @@ def register_middlewares(_app: FastAPI):
)


def register_mounter(_app: FastAPI):
os.makedirs("static", exist_ok=True)
os.makedirs("media", exist_ok=True)
_app.mount("/static", StaticFiles(directory="static"), name="static")
_app.mount("/media", StaticFiles(directory="media"), name="media")


def create_app(span) -> FastAPI:
_app = FastAPI(lifespan=span)
register_router(_app)
register_middlewares(_app)
add_custom_exception_handlers(_app)
register_custom_exception_handlers(_app)
register_mounter(_app)
return _app


@asynccontextmanager
async def lifespan(application: FastAPI): # noqa
async def lifespan(application: FastAPI):
from base.connector import database_connector

"""
Use context manager to manage the lifespan of the application instead of using the startup and shutdown events.
Use context manager to manage the lifespan of the application instead of
using the startup and shutdown events.
"""
yield
await database_connector.engine.dispose()
Expand Down
105 changes: 80 additions & 25 deletions core/chat.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
import hashlib
import json
import os
import random
import shutil
import string
from datetime import timedelta
from typing import List, Dict
from datetime import datetime, timedelta

import starlette.datastructures
from aioredis import Redis
from fastapi import APIRouter, Depends
from starlette.websockets import WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.exceptions import RequestValidationError
from loguru import logger

from .deps import get_redis_session
from datetime import datetime

clients: List[WebSocket] = []
client_map: Dict[str, Dict[str, str]] = {} # {websocket_id: {"username": username, "ip": ip}}
# 存储 WebSocket 连接和映射关系
clients: list[WebSocket] = []
client_map: dict[WebSocket, dict[str, str]] = {}

router = APIRouter()


# 生成随机用户名
def generate_random_username() -> str:
return ''.join(random.choices(string.ascii_letters + string.digits, k=8))
return "".join(random.choices(string.ascii_letters + string.digits, k=8))


# 消息存储到 Redis
async def store_message_in_redis(message: str, username: str, ip: str, session: Redis):
async def store_message_in_redis(message: str | dict, username: str, ip: str, session: Redis):
message_data = {
"username": username,
"ip": ip,
"timestamp": datetime.now().isoformat(),
"message": message
"message": message,
}
# 将消息数据以 JSON 字符串的形式存储到 Redis 列表
await session.lpush("chat_history", json.dumps(message_data))
Expand All @@ -40,48 +46,97 @@ async def get_chat_history(session: Redis):
return [json.loads(msg) for msg in reversed(messages)]


# 使用 SHA-256 对文件名进行加密
def encrypt_filename(filename: str) -> str:
# 使用 SHA-256 对文件名进行哈希加密
sha256_hash = hashlib.sha256()
sha256_hash.update(filename.encode("utf-8")) # 对文件名进行编码
return sha256_hash.hexdigest() # 返回加密后的文件名


# 上传文件处理,文件大小小于50MB时通过WebSocket直接传输,超过50MB时使用POST上传
async def handle_file_upload(file: UploadFile | bytes, file_name: str = None) -> str:
# 获取文件的扩展名
file_name, ext = os.path.splitext(file_name or file.filename)
file_name = encrypt_filename(file_name) or encrypt_filename(file.filename)
filename = f"{file_name}{ext}"
file_location = f"/media/uploads/{filename}"

# 保存文件到指定目录
os.makedirs(os.path.dirname(file_location.lstrip("/")), exist_ok=True)
with open(file_location.lstrip("/"), "wb") as f:
shutil.copyfileobj(file.file, f) if isinstance(file, starlette.datastructures.UploadFile) else f.write(file)

return file_location


# WebSocket 连接处理
@router.websocket("/ws/chat")
async def websocket_endpoint(
websocket: WebSocket,
rds_session: Redis = Depends(get_redis_session)
):
async def websocket_endpoint(websocket: WebSocket, rds_session: Redis = Depends(get_redis_session)):
await websocket.accept()
clients.append(websocket)

# 获取客户端 IP 和生成随机用户名
client_ip = str(websocket.client)
client_real_ip = str(websocket.client)
username = generate_random_username()

# 保存映射关系
client_map[websocket] = {"username": username, "ip": client_ip}
client_map[websocket] = {"username": username, "ip": client_real_ip}

# 向客户端发送历史消息
chat_history = await get_chat_history(rds_session)
for message in chat_history:
await websocket.send_text(f"{message['timestamp']} - {message['username']} ({message['ip']}): {message['message']}")
await websocket.send_json(message)

try:
while True:
# 接收客户端发送的消息
message = await websocket.receive_text()
origin_message = await websocket.receive_json()

if origin_message.get("type") == "file":
if origin_message.get("filename") and origin_message.get("fileSize") < 50 * 1024 * 1024:
# 接收文件二进制数据
file_data = await websocket.receive_bytes()
if file_data:
file_location = await handle_file_upload(file_data, file_name=origin_message.get("filename"))
origin_message["url"] = file_location
logger.info(f"File saved: {file_location}")
logger.warning("File size exceeds 50MB, please use POST to upload.")
raise RequestValidationError("File size exceeds 50MB, please use POST to upload.")

# 获取客户端的用户名和 IP 地址
client_username = client_map[websocket]["username"]
client_ip = client_map[websocket]["ip"]
client_real_ip = client_map[websocket]["ip"]

# 将消息存储到 Redis
await store_message_in_redis(message, client_username, client_ip, rds_session)
await store_message_in_redis(origin_message, client_username, client_real_ip, rds_session)

# 向所有连接的客户端广播消息
for client in clients:
# 获取发送者的用户名
sender_username = client_map[client]["username"]
sender_ip = client_map[client]["ip"]
# 发送时,附加用户名、IP 和时间戳
await client.send_text(f"{datetime.now().isoformat()} - {sender_username} ({sender_ip}): {message}")
await client.send_json(
{
"username": client_map[client]["username"],
"ip": client_map[client]["ip"],
"timestamp": datetime.now().isoformat(),
"message": origin_message,
}
)

except WebSocketDisconnect:
# 断开连接时,清除客户端映射关系
del client_map[websocket]
clients.remove(websocket)


# 处理文件上传的HTTP POST请求
@router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
try:
# 处理文件上传,并返回文件存储位置
file_location = await handle_file_upload(file)

# 返回文件的 URL
return {"fileUrl": f"{file_location}"}

except ValueError as e:
return {"error": str(e)}
1 change: 1 addition & 0 deletions core/page_router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter

from .page import router as page_router

router = APIRouter()
Expand Down
6 changes: 3 additions & 3 deletions exc/database_exc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Any


class DatabaseErr(Exception):
class DatabaseError(Exception):
def __init__(self, message: str = ""):
super().__init__(f"{message}")


class NotFoundRecordsErr(DatabaseErr):
class NotFoundRecordsError(DatabaseError):
def __init__(self, reason: Any = None):
super().__init__(f"{f'{reason}' if reason else '.'}")


class IntegrityErr(DatabaseErr):
class IntegrityError(DatabaseError):
def __init__(self, reason: Any = None):
super().__init__(f"Record(s) Integrity Error{f': `{reason}`' if reason else '.'}")
6 changes: 3 additions & 3 deletions exc/service_exc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from builtins import Exception


class ServiceErr(Exception):
class ServiceError(Exception):
def __init__(self, message: str | dict = ""):
super().__init__(message)


class BadRequestErr(ServiceErr):
class BadRequestError(ServiceError):
def __init__(self, message: str | dict = ""):
super().__init__(message)


class NotFoundErr(ServiceErr):
class NotFoundError(ServiceError):
def __init__(self, message: str | dict = ""):
super().__init__(message)
Loading