Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Next #160

Merged
merged 11 commits into from
Apr 22, 2024
Merged
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ venv
env
.env
__pycache__
certificate.*
private.*
2 changes: 1 addition & 1 deletion _version.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from packaging import version

__version__ = "5.2.10"
__version__ = "5.4.2"


def version_major() -> int:
Expand Down
212 changes: 205 additions & 7 deletions bot_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
BOT_COMMAND_CHAT = "chat"
BOT_COMMAND_MODULE = "module"
BOT_COMMAND_STYLE = "style"
BOT_COMMAND_MODEL = "model"
BOT_COMMAND_CLEAR = "clear"
BOT_COMMAND_LANG = "lang"
BOT_COMMAND_CHAT_ID = "chatid"
Expand Down Expand Up @@ -121,6 +122,8 @@ def __init__(
logging_queue: multiprocessing.Queue,
queue_handler_: queue_handler.QueueHandler,
modules: Dict,
web_cooldown_timer: multiprocessing.Value,
web_request_lock: multiprocessing.Lock,
):
self.config = config
self.config_file = config_file
Expand All @@ -130,6 +133,10 @@ def __init__(
self.queue_handler = queue_handler_
self.modules = modules

# LMAO
self.web_cooldown_timer = web_cooldown_timer
self.web_request_lock = web_request_lock

self.prevent_shutdown_flag = multiprocessing.Value(c_bool, False)

self._application = None
Expand Down Expand Up @@ -159,7 +166,17 @@ def start_bot(self):

# Build bot
telegram_config = self.config.get("telegram")
builder = ApplicationBuilder().token(telegram_config.get("api_key"))
proxy = telegram_config.get("proxy")
if proxy:
logging.info(f"Using proxy {proxy} for Telegram bot")
builder = (
ApplicationBuilder()
.token(telegram_config.get("api_key"))
.proxy(proxy)
.get_updates_proxy(proxy)
)
else:
builder = ApplicationBuilder().token(telegram_config.get("api_key"))
self._application = builder.build()

# Set commands
Expand All @@ -172,6 +189,7 @@ def start_bot(self):
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_CHAT, self.bot_module_request))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_MODULE, self.bot_command_module))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_STYLE, self.bot_command_style))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_MODEL, self.bot_command_model))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_CLEAR, self.bot_command_clear))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_LANG, self.bot_command_lang))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_CHAT_ID, self.bot_command_chatid))
Expand Down Expand Up @@ -302,15 +320,13 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP
return

# Parse data from markup
action, data_, reply_message_id = data_.split("|")
action, data_, argument_ = data_.split("|")
if not action:
raise Exception("No action in callback data")
if not data_:
data_ = None
if not reply_message_id:
reply_message_id = None
else:
reply_message_id = int(reply_message_id.strip())
if not argument_:
argument_ = None

# Get user
banned, user = await self._user_get_check(update, context, prompt_language_selection=False)
Expand All @@ -329,6 +345,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Regenerate request
if action == "regenerate":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand Down Expand Up @@ -364,6 +386,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Continue generating
elif action == "continue":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand All @@ -385,6 +413,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Send suggestion
elif action == "suggestion":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand Down Expand Up @@ -421,6 +455,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Stop generating
elif action == "stop":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand Down Expand Up @@ -466,6 +506,10 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP
elif action == "style":
await self._bot_command_style_raw(data_, user, context)

# Change model
elif action == "model":
await self._bot_command_model_raw(data_, argument_, user, context)

# Change language
elif action == "lang":
await self._bot_command_lang_raw(data_, user, context)
Expand Down Expand Up @@ -752,8 +796,20 @@ async def bot_command_restart(self, update: Update, context: ContextTypes.DEFAUL
continue
logging.info(f"Trying to load and initialize {module_name} module")
try:
use_web = (
module_name.startswith("lmao_")
and module_name in self.config.get("modules").get("lmao_web_for_modules", [])
and "lmao_web_api_url" in self.config.get("modules")
)
module = module_wrapper_global.ModuleWrapperGlobal(
module_name, self.config, self.messages, self.users_handler, self.logging_queue
module_name,
self.config,
self.messages,
self.users_handler,
self.logging_queue,
use_web=use_web,
web_cooldown_timer=self.web_cooldown_timer,
web_request_lock=self.web_request_lock,
)
self.modules[module_name] = module
reload_logs += f"Intialized and loaded {module_name} module\n"
Expand Down Expand Up @@ -1042,6 +1098,148 @@ async def _bot_command_style_raw(self, style: str or None, user: Dict, context:
context,
)

async def bot_command_model(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""/model commands callback

Args:
update (Update): update object from bot's callback
context (ContextTypes.DEFAULT_TYPE): context object from bot's callback
"""
# Get user
banned, user = await self._user_get_check(update, context)
if user is None:
return
user_id = user.get("user_id")
user_name = self.users_handler.get_key(0, "user_name", "", user=user)
lang_id = self.users_handler.get_key(0, "lang_id", user=user)

# Log command
logging.info(f"/model command from {user_name} ({user_id})")

# Exit if banned
if banned:
return

module_id = self.users_handler.get_key(0, "module", self.config.get("modules").get("default"), user=user)

model = None

# User specified model
if context.args and len(context.args) >= 1:
try:
model = context.args[0].strip().lower()

# Get available models
current_module_id = self.users_handler.get_key(
0, "module", self.config.get("modules").get("default"), user=user
)
available_models = self.config.get(current_module_id).get("models", [])

# Get current model
model_current = self.config.get(module_id).get("model_default")
model_current = self.users_handler.get_key(0, f"{module_id}_model", model_current, user=user)

# Check
if not model_current or len(available_models) == 0:
await _send_safe(
user_id,
self.messages.get_message("model_no_models", lang_id=lang_id),
context,
)
return

# Check
if model not in available_models:
raise Exception(f"No model {model} in {' '.join(available_models)}")
except Exception as e:
logging.error("Error retrieving requested model", exc_info=e)
await _send_safe(
user["user_id"],
self.messages.get_message("model_change_error", lang_id=lang_id).format(error_text=str(e)),
context,
)
return

# Change model or ask the user
await self._bot_command_model_raw(module_id, model, user, context)

async def _bot_command_model_raw(
self, module_id: str or None, model: str or None, user: Dict, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Changes model of module

Args:
module_id (str or None): id of module to change model of
model (str or None): model name or None to ask user
user (Dict): user's data as dictionary
context (ContextTypes.DEFAULT_TYPE): context object from bot's callback
"""
user_id = user.get("user_id")
lang_id = self.users_handler.get_key(0, "lang_id", user=user)

# Extract current user's module and model
module_icon_names = self.messages.get_message("modules", lang_id=lang_id)
if not module_id:
module_id = self.users_handler.get_key(0, "module", self.config.get("modules").get("default"), user=user)
current_module_name = module_icon_names.get(module_id).get("name")
current_module_icon = module_icon_names.get(module_id).get("icon")
current_module_name = f"{current_module_icon} {current_module_name}"

# Get available models
available_models = self.config.get(module_id).get("models", [])

# Get current model
model_current = self.config.get(module_id).get("model_default")
model_current = self.users_handler.get_key(0, f"{module_id}_model", model_current, user=user)

# Check
if not model_current or len(available_models) == 0:
await _send_safe(
user_id,
self.messages.get_message("model_no_models", lang_id=lang_id),
context,
)
return

# Ask user
if not model:
buttons = []
for model_ in available_models:
buttons.append(InlineKeyboardButton(model_, callback_data=f"model|{module_id}|{model_}"))

await _send_safe(
user_id,
self.messages.get_message("model_select", lang_id=lang_id).format(
module_name=current_module_name, current_model=model_current
),
context,
reply_markup=InlineKeyboardMarkup(bot_sender.build_menu(buttons)),
)
return

# Change model
try:
# Change model of user
self.users_handler.set_key(user_id, f"{module_id}_model", model)

# Send confirmation
await _send_safe(
user_id,
self.messages.get_message("model_changed", lang_id=lang_id).format(
module_name=current_module_name, changed_model=model
),
context,
)

# Error changing model
except Exception as e:
logging.error("Error changing model", exc_info=e)
await _send_safe(
user_id,
self.messages.get_message("model_change_error", lang_id=lang_id).format(error_text=str(e)),
context,
)

########################################
# General (non-modules) commands below #
########################################
Expand Down
8 changes: 8 additions & 0 deletions bot_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,14 @@ def build_markup(
)
buttons.append(button_style)

# Add change model button
if request_response.module_name in module_wrapper_global.MODULES_WITH_MODELS:
button_model = InlineKeyboardButton(
messages_.get_message("button_model_change", user_id=user_id),
callback_data=f"model|{request_response.module_name}|",
)
buttons.append(button_model)

# Add change module button for all modules
button_module = InlineKeyboardButton(
messages_.get_message("button_module", user_id=user_id),
Expand Down
Loading
Loading