diff --git a/bot/bot.py b/bot/bot.py index cefb14b46..74dc5b256 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -38,8 +38,8 @@ db = database.Database() logger = logging.getLogger(__name__) -user_semaphores = {} -user_tasks = {} +chat_semaphores = {} +chat_tasks = {} HELP_MESSAGE = """Commands: ⚪ /retry – Regenerate last bot answer @@ -56,28 +56,25 @@ def split_text_into_chunks(text, chunk_size): yield text[i:i + chunk_size] -async def register_user_if_not_exists(update: Update, context: CallbackContext, user: User): - if not db.check_if_user_exists(user.id): - db.add_new_user( - user.id, - update.message.chat_id, - username=user.username, - first_name=user.first_name, - last_name= user.last_name - ) - db.start_new_dialog(user.id) +async def register_chat_if_not_exists(update: Update): + + chat_id = update.message.chat_id + + if not db.chat_exists(chat_id): + db.create_chat(chat_id) + db.start_new_dialog(chat_id) - if db.get_user_attribute(user.id, "current_dialog_id") is None: - db.start_new_dialog(user.id) + if db.get_chat_attribute(chat_id, "current_dialog_id") is None: + db.start_new_dialog(chat_id) - if user.id not in user_semaphores: - user_semaphores[user.id] = asyncio.Semaphore(1) + if chat_id not in chat_semaphores: + chat_semaphores[chat_id] = asyncio.Semaphore(1) - if db.get_user_attribute(user.id, "current_model") is None: - db.set_user_attribute(user.id, "current_model", config.models["available_text_models"][0]) + if db.get_chat_attribute(chat_id, "current_model") is None: + db.set_chat_attribute(chat_id, "current_model", config.models["available_text_models"][0]) # back compatibility for n_used_tokens field - n_used_tokens = db.get_user_attribute(user.id, "n_used_tokens") + n_used_tokens = db.get_chat_attribute(chat_id, "n_used_tokens") if isinstance(n_used_tokens, int): # old format new_n_used_tokens = { "gpt-3.5-turbo": { @@ -85,19 +82,19 @@ async def register_user_if_not_exists(update: Update, context: CallbackContext, "n_output_tokens": n_used_tokens } } - db.set_user_attribute(user.id, "n_used_tokens", new_n_used_tokens) + db.set_chat_attribute(chat_id, "n_used_tokens", new_n_used_tokens) # voice message transcription - if db.get_user_attribute(user.id, "n_transcribed_seconds") is None: - db.set_user_attribute(user.id, "n_transcribed_seconds", 0.0) + if db.get_chat_attribute(chat_id, "n_transcribed_seconds") is None: + db.set_chat_attribute(chat_id, "n_transcribed_seconds", 0.0) async def start_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - user_id = update.message.from_user.id + await register_chat_if_not_exists(update) + chat_id = update.message.chat_id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) - db.start_new_dialog(user_id) + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) + db.start_new_dialog(chat_id) reply_text = "Hi! I'm ChatGPT bot implemented with GPT-3.5 OpenAI API 🤖\n\n" reply_text += HELP_MESSAGE @@ -108,26 +105,26 @@ async def start_handle(update: Update, context: CallbackContext): async def help_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + await register_chat_if_not_exists(update) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) await update.message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML) async def retry_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - if await is_previous_message_not_answered_yet(update, context): return + await register_chat_if_not_exists(update) + if await is_previous_message_not_answered_yet(update): return - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) - dialog_messages = db.get_dialog_messages(user_id, dialog_id=None) + dialog_messages = db.get_dialog_messages(chat_id, dialog_id=None) if len(dialog_messages) == 0: await update.message.reply_text("No message to retry 🤷‍♂️") return last_dialog_message = dialog_messages.pop() - db.set_dialog_messages(user_id, dialog_messages, dialog_id=None) # last message was removed from the context + db.set_dialog_messages(chat_id, dialog_messages, dialog_id=None) # last message was removed from the context await message_handle(update, context, message=last_dialog_message["user"], use_new_dialog_timeout=False) @@ -138,23 +135,23 @@ async def message_handle(update: Update, context: CallbackContext, message=None, await edited_message_handle(update, context) return - await register_user_if_not_exists(update, context, update.message.from_user) - if await is_previous_message_not_answered_yet(update, context): return + await register_chat_if_not_exists(update) + if await is_previous_message_not_answered_yet(update): return - user_id = update.message.from_user.id + chat_id = update.message.chat_id async def message_handle_fn(): - chat_mode = db.get_user_attribute(user_id, "current_chat_mode") + chat_mode = db.get_chat_attribute(chat_id, "current_chat_mode") # new dialog timeout if use_new_dialog_timeout: - if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0: - db.start_new_dialog(user_id) + if (datetime.now() - db.get_chat_attribute(chat_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(chat_id)) > 0: + db.start_new_dialog(chat_id) await update.message.reply_text(f"Starting new dialog due to timeout ({openai_utils.CHAT_MODES[chat_mode]['name']} mode) ✅", parse_mode=ParseMode.HTML) - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) # in case of CancelledError n_input_tokens, n_output_tokens = 0, 0 - current_model = db.get_user_attribute(user_id, "current_model") + current_model = db.get_chat_attribute(chat_id, "current_model") try: # send placeholder message to user @@ -165,7 +162,7 @@ async def message_handle_fn(): _message = message or update.message.text - dialog_messages = db.get_dialog_messages(user_id, dialog_id=None) + dialog_messages = db.get_dialog_messages(chat_id, dialog_id=None) parse_mode = { "html": ParseMode.HTML, "markdown": ParseMode.MARKDOWN @@ -211,16 +208,16 @@ async def fake_gen(): # update user data new_dialog_message = {"user": _message, "bot": answer, "date": datetime.now()} db.set_dialog_messages( - user_id, - db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message], + chat_id, + db.get_dialog_messages(chat_id, dialog_id=None) + [new_dialog_message], dialog_id=None ) - db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens) + db.update_n_used_tokens(chat_id, current_model, n_input_tokens, n_output_tokens) except asyncio.CancelledError: # note: intermediate token updates only work when enable_message_streaming=True (config.yml) - db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens) + db.update_n_used_tokens(chat_id, current_model, n_input_tokens, n_output_tokens) raise except Exception as e: @@ -237,9 +234,9 @@ async def fake_gen(): text = f"✍️ Note: Your current dialog is too long, so {n_first_dialog_messages_removed} first messages were removed from the context.\n Send /new command to start new dialog" await update.message.reply_text(text, parse_mode=ParseMode.HTML) - async with user_semaphores[user_id]: + async with chat_semaphores[chat_id]: task = asyncio.create_task(message_handle_fn()) - user_tasks[user_id] = task + chat_tasks[chat_id] = task try: await task @@ -248,15 +245,15 @@ async def fake_gen(): else: pass finally: - if user_id in user_tasks: - del user_tasks[user_id] + if chat_id in chat_tasks: + del chat_tasks[chat_id] -async def is_previous_message_not_answered_yet(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) +async def is_previous_message_not_answered_yet(update: Update): + await register_chat_if_not_exists(update) - user_id = update.message.from_user.id - if user_semaphores[user_id].locked(): + chat_id = update.message.chat_id + if chat_semaphores[chat_id].locked(): text = "⏳ Please wait for a reply to the previous message\n" text += "Or you can /cancel it" await update.message.reply_text(text, reply_to_message_id=update.message.id, parse_mode=ParseMode.HTML) @@ -266,11 +263,11 @@ async def is_previous_message_not_answered_yet(update: Update, context: Callback async def voice_message_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - if await is_previous_message_not_answered_yet(update, context): return + await register_chat_if_not_exists(update) + if await is_previous_message_not_answered_yet(update): return - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) voice = update.message.voice with tempfile.TemporaryDirectory() as tmp_dir: @@ -293,44 +290,44 @@ async def voice_message_handle(update: Update, context: CallbackContext): await update.message.reply_text(text, parse_mode=ParseMode.HTML) # update n_transcribed_seconds - db.set_user_attribute(user_id, "n_transcribed_seconds", voice.duration + db.get_user_attribute(user_id, "n_transcribed_seconds")) + db.set_chat_attribute(chat_id, "n_transcribed_seconds", voice.duration + db.get_chat_attribute(chat_id, "n_transcribed_seconds")) await message_handle(update, context, message=transcribed_text) async def new_dialog_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - if await is_previous_message_not_answered_yet(update, context): return + await register_chat_if_not_exists(update) + if await is_previous_message_not_answered_yet(update): return - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) - db.start_new_dialog(user_id) + db.start_new_dialog(chat_id) await update.message.reply_text("Starting new dialog ✅") - chat_mode = db.get_user_attribute(user_id, "current_chat_mode") + chat_mode = db.get_chat_attribute(chat_id, "current_chat_mode") await update.message.reply_text(f"{openai_utils.CHAT_MODES[chat_mode]['welcome_message']}", parse_mode=ParseMode.HTML) async def cancel_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) + await register_chat_if_not_exists(update) - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) - if user_id in user_tasks: - task = user_tasks[user_id] + if chat_id in chat_tasks: + task = chat_tasks[chat_id] task.cancel() else: await update.message.reply_text("Nothing to cancel...", parse_mode=ParseMode.HTML) async def show_chat_modes_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - if await is_previous_message_not_answered_yet(update, context): return + await register_chat_if_not_exists(update) + if await is_previous_message_not_answered_yet(update): return - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) keyboard = [] for chat_mode, chat_mode_dict in openai_utils.CHAT_MODES.items(): @@ -341,22 +338,22 @@ async def show_chat_modes_handle(update: Update, context: CallbackContext): async def set_chat_mode_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user) - user_id = update.callback_query.from_user.id + await register_chat_if_not_exists(update.callback_query) + chat_id = update.callback_query.message.chat_id query = update.callback_query await query.answer() chat_mode = query.data.split("|")[1] - db.set_user_attribute(user_id, "current_chat_mode", chat_mode) - db.start_new_dialog(user_id) + db.set_chat_attribute(chat_id, "current_chat_mode", chat_mode) + db.start_new_dialog(chat_id) await query.edit_message_text(f"{openai_utils.CHAT_MODES[chat_mode]['welcome_message']}", parse_mode=ParseMode.HTML) -def get_settings_menu(user_id: int): - current_model = db.get_user_attribute(user_id, "current_model") +def get_settings_menu(chat_id: int): + current_model = db.get_chat_attribute(chat_id, "current_model") text = config.models["info"][current_model]["description"] text += "\n\n" @@ -382,28 +379,28 @@ def get_settings_menu(user_id: int): async def settings_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - if await is_previous_message_not_answered_yet(update, context): return + await register_chat_if_not_exists(update) + if await is_previous_message_not_answered_yet(update): return - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) - text, reply_markup = get_settings_menu(user_id) + text, reply_markup = get_settings_menu(chat_id) await update.message.reply_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML) async def set_settings_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user) - user_id = update.callback_query.from_user.id + await register_chat_if_not_exists(update.callback_query) + chat_id = update.callback_query.message.chat_id query = update.callback_query await query.answer() _, model_key = query.data.split("|") - db.set_user_attribute(user_id, "current_model", model_key) - db.start_new_dialog(user_id) + db.set_chat_attribute(chat_id, "current_model", model_key) + db.start_new_dialog(chat_id) - text, reply_markup = get_settings_menu(user_id) + text, reply_markup = get_settings_menu(chat_id) try: await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML) except telegram.error.BadRequest as e: @@ -412,17 +409,17 @@ async def set_settings_handle(update: Update, context: CallbackContext): async def show_balance_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) + await register_chat_if_not_exists(update) - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + chat_id = update.message.chat_id + db.set_chat_attribute(chat_id, "last_interaction", datetime.now()) # count total usage statistics total_n_spent_dollars = 0 total_n_used_tokens = 0 - n_used_tokens_dict = db.get_user_attribute(user_id, "n_used_tokens") - n_transcribed_seconds = db.get_user_attribute(user_id, "n_transcribed_seconds") + n_used_tokens_dict = db.get_chat_attribute(chat_id, "n_used_tokens") + n_transcribed_seconds = db.get_chat_attribute(chat_id, "n_transcribed_seconds") details_text = "🏷️ Details:\n" for model_key in sorted(n_used_tokens_dict.keys()): diff --git a/bot/database.py b/bot/database.py index 92cb56694..05b320969 100644 --- a/bot/database.py +++ b/bot/database.py @@ -12,88 +12,71 @@ def __init__(self): self.client = pymongo.MongoClient(config.mongodb_uri) self.db = self.client["chatgpt_telegram_bot"] - self.user_collection = self.db["user"] + self.chat_collection = self.db["chat"] self.dialog_collection = self.db["dialog"] - def check_if_user_exists(self, user_id: int, raise_exception: bool = False): - if self.user_collection.count_documents({"_id": user_id}) > 0: - return True - else: - if raise_exception: - raise ValueError(f"User {user_id} does not exist") - else: - return False - - def add_new_user( - self, - user_id: int, - chat_id: int, - username: str = "", - first_name: str = "", - last_name: str = "", - ): - user_dict = { - "_id": user_id, - "chat_id": chat_id, - - "username": username, - "first_name": first_name, - "last_name": last_name, + def chat_exists(self, chat_id: int): + return self.chat_collection.count_documents({"_id": chat_id}) > 0 + def create_chat(self, chat_id: int): + chat_dict = { + "_id": chat_id, "last_interaction": datetime.now(), "first_seen": datetime.now(), - "current_dialog_id": None, "current_chat_mode": "assistant", "current_model": config.models["available_text_models"][0], - - "n_used_tokens": {}, - + "n_used_tokens": {}, # {"model_name": {"n_input_tokens": 0, "n_output_tokens": 0}} "n_transcribed_seconds": 0.0 # voice message transcription } - if not self.check_if_user_exists(user_id): - self.user_collection.insert_one(user_dict) + if not self.chat_exists(chat_id): + self.chat_collection.insert_one(chat_dict) - def start_new_dialog(self, user_id: int): - self.check_if_user_exists(user_id, raise_exception=True) + def start_new_dialog(self, chat_id: int): + if not self.chat_exists(chat_id): + raise ValueError(f"Chat {chat_id} does not exist") dialog_id = str(uuid.uuid4()) dialog_dict = { "_id": dialog_id, - "user_id": user_id, - "chat_mode": self.get_user_attribute(user_id, "current_chat_mode"), + "chat_id": chat_id, + "chat_mode": self.get_chat_attribute(chat_id, "current_chat_mode"), "start_time": datetime.now(), - "model": self.get_user_attribute(user_id, "current_model"), + "model": self.get_chat_attribute(chat_id, "current_model"), "messages": [] } # add new dialog self.dialog_collection.insert_one(dialog_dict) - # update user's current dialog - self.user_collection.update_one( - {"_id": user_id}, + # update chat's current dialog + self.chat_collection.update_one( + {"_id": chat_id}, {"$set": {"current_dialog_id": dialog_id}} ) return dialog_id - def get_user_attribute(self, user_id: int, key: str): - self.check_if_user_exists(user_id, raise_exception=True) - user_dict = self.user_collection.find_one({"_id": user_id}) + def get_chat_attribute(self, chat_id: int, key: str): + if not self.chat_exists(chat_id): + raise ValueError(f"Chat {chat_id} does not exist") - if key not in user_dict: + chat_dict = self.chat_collection.find_one({"_id": chat_id}) + + if key not in chat_dict: return None - return user_dict[key] + return chat_dict[key] + + def set_chat_attribute(self, chat_id: int, key: str, value: Any): + if not self.chat_exists(chat_id): + raise ValueError(f"Chat {chat_id} does not exist") - def set_user_attribute(self, user_id: int, key: str, value: Any): - self.check_if_user_exists(user_id, raise_exception=True) - self.user_collection.update_one({"_id": user_id}, {"$set": {key: value}}) + self.chat_collection.update_one({"_id": chat_id}, {"$set": {key: value}}) - def update_n_used_tokens(self, user_id: int, model: str, n_input_tokens: int, n_output_tokens: int): - n_used_tokens_dict = self.get_user_attribute(user_id, "n_used_tokens") + def update_n_used_tokens(self, chat_id: int, model: str, n_input_tokens: int, n_output_tokens: int): + n_used_tokens_dict = self.get_chat_attribute(chat_id, "n_used_tokens") if model in n_used_tokens_dict: n_used_tokens_dict[model]["n_input_tokens"] += n_input_tokens @@ -104,24 +87,26 @@ def update_n_used_tokens(self, user_id: int, model: str, n_input_tokens: int, n_ "n_output_tokens": n_output_tokens } - self.set_user_attribute(user_id, "n_used_tokens", n_used_tokens_dict) + self.set_chat_attribute(chat_id, "n_used_tokens", n_used_tokens_dict) - def get_dialog_messages(self, user_id: int, dialog_id: Optional[str] = None): - self.check_if_user_exists(user_id, raise_exception=True) + def get_dialog_messages(self, chat_id: int, dialog_id: Optional[str] = None): + if not self.chat_exists(chat_id): + raise ValueError(f"Chat {chat_id} does not exist") if dialog_id is None: - dialog_id = self.get_user_attribute(user_id, "current_dialog_id") + dialog_id = self.get_chat_attribute(chat_id, "current_dialog_id") - dialog_dict = self.dialog_collection.find_one({"_id": dialog_id, "user_id": user_id}) + dialog_dict = self.dialog_collection.find_one({"_id": dialog_id, "chat_id": chat_id}) return dialog_dict["messages"] - def set_dialog_messages(self, user_id: int, dialog_messages: list, dialog_id: Optional[str] = None): - self.check_if_user_exists(user_id, raise_exception=True) + def set_dialog_messages(self, chat_id: int, dialog_messages: list, dialog_id: Optional[str] = None): + if not self.chat_exists(chat_id): + raise ValueError(f"Chat {chat_id} does not exist") if dialog_id is None: - dialog_id = self.get_user_attribute(user_id, "current_dialog_id") + dialog_id = self.get_chat_attribute(chat_id, "current_dialog_id") self.dialog_collection.update_one( - {"_id": dialog_id, "user_id": user_id}, + {"_id": dialog_id, "chat_id": chat_id}, {"$set": {"messages": dialog_messages}} )