Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt bot to handle groups independently using chat_id #232

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 94 additions & 97 deletions bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
db = database.Database()
logger = logging.getLogger(__name__)

user_semaphores = {}
user_tasks = {}
chat_semaphores = {}
chat_tasks = {}

HELP_MESSAGE = """Commands:
⚪ /retry – Regenerate last bot answer
Expand All @@ -56,48 +56,45 @@ def split_text_into_chunks(text, chunk_size):
yield text[i:i + chunk_size]


async def register_user_if_not_exists(update: Update, context: CallbackContext, user: User):
if not db.check_if_user_exists(user.id):
db.add_new_user(
user.id,
update.message.chat_id,
username=user.username,
first_name=user.first_name,
last_name= user.last_name
)
db.start_new_dialog(user.id)
async def register_chat_if_not_exists(update: Update):

chat_id = update.message.chat_id

if not db.chat_exists(chat_id):
db.create_chat(chat_id)
db.start_new_dialog(chat_id)

if db.get_user_attribute(user.id, "current_dialog_id") is None:
db.start_new_dialog(user.id)
if db.get_chat_attribute(chat_id, "current_dialog_id") is None:
db.start_new_dialog(chat_id)

if user.id not in user_semaphores:
user_semaphores[user.id] = asyncio.Semaphore(1)
if chat_id not in chat_semaphores:
chat_semaphores[chat_id] = asyncio.Semaphore(1)

if db.get_user_attribute(user.id, "current_model") is None:
db.set_user_attribute(user.id, "current_model", config.models["available_text_models"][0])
if db.get_chat_attribute(chat_id, "current_model") is None:
db.set_chat_attribute(chat_id, "current_model", config.models["available_text_models"][0])

# back compatibility for n_used_tokens field
n_used_tokens = db.get_user_attribute(user.id, "n_used_tokens")
n_used_tokens = db.get_chat_attribute(chat_id, "n_used_tokens")
if isinstance(n_used_tokens, int): # old format
new_n_used_tokens = {
"gpt-3.5-turbo": {
"n_input_tokens": 0,
"n_output_tokens": n_used_tokens
}
}
db.set_user_attribute(user.id, "n_used_tokens", new_n_used_tokens)
db.set_chat_attribute(chat_id, "n_used_tokens", new_n_used_tokens)

# voice message transcription
if db.get_user_attribute(user.id, "n_transcribed_seconds") is None:
db.set_user_attribute(user.id, "n_transcribed_seconds", 0.0)
if db.get_chat_attribute(chat_id, "n_transcribed_seconds") is None:
db.set_chat_attribute(chat_id, "n_transcribed_seconds", 0.0)


async def start_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
user_id = update.message.from_user.id
await register_chat_if_not_exists(update)
chat_id = update.message.chat_id

db.set_user_attribute(user_id, "last_interaction", datetime.now())
db.start_new_dialog(user_id)
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())
db.start_new_dialog(chat_id)

reply_text = "Hi! I'm <b>ChatGPT</b> bot implemented with GPT-3.5 OpenAI API 🤖\n\n"
reply_text += HELP_MESSAGE
Expand All @@ -108,26 +105,26 @@ async def start_handle(update: Update, context: CallbackContext):


async def help_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
await register_chat_if_not_exists(update)
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())
await update.message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML)


async def retry_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
if await is_previous_message_not_answered_yet(update, context): return
await register_chat_if_not_exists(update)
if await is_previous_message_not_answered_yet(update): return

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
dialog_messages = db.get_dialog_messages(chat_id, dialog_id=None)
if len(dialog_messages) == 0:
await update.message.reply_text("No message to retry 🤷‍♂️")
return

last_dialog_message = dialog_messages.pop()
db.set_dialog_messages(user_id, dialog_messages, dialog_id=None) # last message was removed from the context
db.set_dialog_messages(chat_id, dialog_messages, dialog_id=None) # last message was removed from the context

await message_handle(update, context, message=last_dialog_message["user"], use_new_dialog_timeout=False)

Expand All @@ -138,23 +135,23 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
await edited_message_handle(update, context)
return

await register_user_if_not_exists(update, context, update.message.from_user)
if await is_previous_message_not_answered_yet(update, context): return
await register_chat_if_not_exists(update)
if await is_previous_message_not_answered_yet(update): return

user_id = update.message.from_user.id
chat_id = update.message.chat_id
async def message_handle_fn():
chat_mode = db.get_user_attribute(user_id, "current_chat_mode")
chat_mode = db.get_chat_attribute(chat_id, "current_chat_mode")

# new dialog timeout
if use_new_dialog_timeout:
if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0:
db.start_new_dialog(user_id)
if (datetime.now() - db.get_chat_attribute(chat_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(chat_id)) > 0:
db.start_new_dialog(chat_id)
await update.message.reply_text(f"Starting new dialog due to timeout (<b>{openai_utils.CHAT_MODES[chat_mode]['name']}</b> mode) ✅", parse_mode=ParseMode.HTML)
db.set_user_attribute(user_id, "last_interaction", datetime.now())
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

# in case of CancelledError
n_input_tokens, n_output_tokens = 0, 0
current_model = db.get_user_attribute(user_id, "current_model")
current_model = db.get_chat_attribute(chat_id, "current_model")

try:
# send placeholder message to user
Expand All @@ -165,7 +162,7 @@ async def message_handle_fn():

_message = message or update.message.text

dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
dialog_messages = db.get_dialog_messages(chat_id, dialog_id=None)
parse_mode = {
"html": ParseMode.HTML,
"markdown": ParseMode.MARKDOWN
Expand Down Expand Up @@ -211,16 +208,16 @@ async def fake_gen():
# update user data
new_dialog_message = {"user": _message, "bot": answer, "date": datetime.now()}
db.set_dialog_messages(
user_id,
db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
chat_id,
db.get_dialog_messages(chat_id, dialog_id=None) + [new_dialog_message],
dialog_id=None
)

db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)
db.update_n_used_tokens(chat_id, current_model, n_input_tokens, n_output_tokens)

except asyncio.CancelledError:
# note: intermediate token updates only work when enable_message_streaming=True (config.yml)
db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)
db.update_n_used_tokens(chat_id, current_model, n_input_tokens, n_output_tokens)
raise

except Exception as e:
Expand All @@ -237,9 +234,9 @@ async def fake_gen():
text = f"✍️ <i>Note:</i> Your current dialog is too long, so <b>{n_first_dialog_messages_removed} first messages</b> were removed from the context.\n Send /new command to start new dialog"
await update.message.reply_text(text, parse_mode=ParseMode.HTML)

async with user_semaphores[user_id]:
async with chat_semaphores[chat_id]:
task = asyncio.create_task(message_handle_fn())
user_tasks[user_id] = task
chat_tasks[chat_id] = task

try:
await task
Expand All @@ -248,15 +245,15 @@ async def fake_gen():
else:
pass
finally:
if user_id in user_tasks:
del user_tasks[user_id]
if chat_id in chat_tasks:
del chat_tasks[chat_id]


async def is_previous_message_not_answered_yet(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
async def is_previous_message_not_answered_yet(update: Update):
await register_chat_if_not_exists(update)

user_id = update.message.from_user.id
if user_semaphores[user_id].locked():
chat_id = update.message.chat_id
if chat_semaphores[chat_id].locked():
text = "⏳ Please <b>wait</b> for a reply to the previous message\n"
text += "Or you can /cancel it"
await update.message.reply_text(text, reply_to_message_id=update.message.id, parse_mode=ParseMode.HTML)
Expand All @@ -266,11 +263,11 @@ async def is_previous_message_not_answered_yet(update: Update, context: Callback


async def voice_message_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
if await is_previous_message_not_answered_yet(update, context): return
await register_chat_if_not_exists(update)
if await is_previous_message_not_answered_yet(update): return

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

voice = update.message.voice
with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -293,44 +290,44 @@ async def voice_message_handle(update: Update, context: CallbackContext):
await update.message.reply_text(text, parse_mode=ParseMode.HTML)

# update n_transcribed_seconds
db.set_user_attribute(user_id, "n_transcribed_seconds", voice.duration + db.get_user_attribute(user_id, "n_transcribed_seconds"))
db.set_chat_attribute(chat_id, "n_transcribed_seconds", voice.duration + db.get_chat_attribute(chat_id, "n_transcribed_seconds"))

await message_handle(update, context, message=transcribed_text)


async def new_dialog_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
if await is_previous_message_not_answered_yet(update, context): return
await register_chat_if_not_exists(update)
if await is_previous_message_not_answered_yet(update): return

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

db.start_new_dialog(user_id)
db.start_new_dialog(chat_id)
await update.message.reply_text("Starting new dialog ✅")

chat_mode = db.get_user_attribute(user_id, "current_chat_mode")
chat_mode = db.get_chat_attribute(chat_id, "current_chat_mode")
await update.message.reply_text(f"{openai_utils.CHAT_MODES[chat_mode]['welcome_message']}", parse_mode=ParseMode.HTML)


async def cancel_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
await register_chat_if_not_exists(update)

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

if user_id in user_tasks:
task = user_tasks[user_id]
if chat_id in chat_tasks:
task = chat_tasks[chat_id]
task.cancel()
else:
await update.message.reply_text("<i>Nothing to cancel...</i>", parse_mode=ParseMode.HTML)


async def show_chat_modes_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
if await is_previous_message_not_answered_yet(update, context): return
await register_chat_if_not_exists(update)
if await is_previous_message_not_answered_yet(update): return

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

keyboard = []
for chat_mode, chat_mode_dict in openai_utils.CHAT_MODES.items():
Expand All @@ -341,22 +338,22 @@ async def show_chat_modes_handle(update: Update, context: CallbackContext):


async def set_chat_mode_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
user_id = update.callback_query.from_user.id
await register_chat_if_not_exists(update.callback_query)
chat_id = update.callback_query.message.chat_id

query = update.callback_query
await query.answer()

chat_mode = query.data.split("|")[1]

db.set_user_attribute(user_id, "current_chat_mode", chat_mode)
db.start_new_dialog(user_id)
db.set_chat_attribute(chat_id, "current_chat_mode", chat_mode)
db.start_new_dialog(chat_id)

await query.edit_message_text(f"{openai_utils.CHAT_MODES[chat_mode]['welcome_message']}", parse_mode=ParseMode.HTML)


def get_settings_menu(user_id: int):
current_model = db.get_user_attribute(user_id, "current_model")
def get_settings_menu(chat_id: int):
current_model = db.get_chat_attribute(chat_id, "current_model")
text = config.models["info"][current_model]["description"]

text += "\n\n"
Expand All @@ -382,28 +379,28 @@ def get_settings_menu(user_id: int):


async def settings_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
if await is_previous_message_not_answered_yet(update, context): return
await register_chat_if_not_exists(update)
if await is_previous_message_not_answered_yet(update): return

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

text, reply_markup = get_settings_menu(user_id)
text, reply_markup = get_settings_menu(chat_id)
await update.message.reply_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)


async def set_settings_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
user_id = update.callback_query.from_user.id
await register_chat_if_not_exists(update.callback_query)
chat_id = update.callback_query.message.chat_id

query = update.callback_query
await query.answer()

_, model_key = query.data.split("|")
db.set_user_attribute(user_id, "current_model", model_key)
db.start_new_dialog(user_id)
db.set_chat_attribute(chat_id, "current_model", model_key)
db.start_new_dialog(chat_id)

text, reply_markup = get_settings_menu(user_id)
text, reply_markup = get_settings_menu(chat_id)
try:
await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)
except telegram.error.BadRequest as e:
Expand All @@ -412,17 +409,17 @@ async def set_settings_handle(update: Update, context: CallbackContext):


async def show_balance_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
await register_chat_if_not_exists(update)

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
chat_id = update.message.chat_id
db.set_chat_attribute(chat_id, "last_interaction", datetime.now())

# count total usage statistics
total_n_spent_dollars = 0
total_n_used_tokens = 0

n_used_tokens_dict = db.get_user_attribute(user_id, "n_used_tokens")
n_transcribed_seconds = db.get_user_attribute(user_id, "n_transcribed_seconds")
n_used_tokens_dict = db.get_chat_attribute(chat_id, "n_used_tokens")
n_transcribed_seconds = db.get_chat_attribute(chat_id, "n_transcribed_seconds")

details_text = "🏷️ Details:\n"
for model_key in sorted(n_used_tokens_dict.keys()):
Expand Down
Loading