diff --git a/.github/workflows/runner.yml b/.github/workflows/runner.yml index 38a60e3..946e787 100644 --- a/.github/workflows/runner.yml +++ b/.github/workflows/runner.yml @@ -18,14 +18,6 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Train Model - run: | - python3 -m virtualenv train-venv - . train-venv/bin/activate - pip install -U -r train-requirements.txt - python3 ai/train.py - deactivate - - name: Build Docker Image run: docker build -t aias . diff --git a/ai/train.py b/ai/train.py index 7b1bff2..046e99c 100644 --- a/ai/train.py +++ b/ai/train.py @@ -1,37 +1,20 @@ -import asyncio import os -import sys from pickle import dump -import asyncpg -from dotenv import load_dotenv from exencolorlogs import Logger from numpy import array from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_val_score +from utils.datamodels import Database +from utils.enums import FetchMode -log = Logger("TRAIN") -load_dotenv() -DATABASE = os.getenv("DATABASE") -HOST = os.getenv("HOST") -USER = os.getenv("USER") -PASSWORD = os.getenv("PASSWORD") - -if DATABASE is None: - log.critical(".env file not filled up properly") - sys.exit(1) - - -async def main(): - log.info("Establishing connection to the database...") - con: asyncpg.Connection = await asyncpg.connect( - database=DATABASE, host=HOST, user=USER, password=PASSWORD +async def train(db: Database): + log = Logger("TRAIN") + records = await db.execute( + "SELECT total_chars, unique_chars, total_words, unique_words, is_spam FROM data WHERE is_spam IS NOT NULL", + fetch_mode=FetchMode.ALL, ) - records = await con.fetch( - "SELECT total_chars, unique_chars, total_words, unique_words, is_spam FROM data WHERE is_spam IS NOT NULL" - ) - await con.close() log.info("Preparing data... Amount of records: %s", len(records)) data = array([list(r.values()) for r in records]) train_x = data[:, 0:4] @@ -52,6 +35,3 @@ async def main(): with open("./ai/models/model.ai", "wb") as f: dump(model, f) log.info("File saved successfully to ./ai/models/model.ai") - - -asyncio.run(main()) diff --git a/utils/bot.py b/utils/bot.py index 9f4de69..9e5a9d5 100644 --- a/utils/bot.py +++ b/utils/bot.py @@ -11,6 +11,7 @@ from exencolorlogs import Logger from utils import embeds +from ai.train import train as train_ai from utils.constants import EMOJIS, LOG_CHANNEL_ID, TRAIN_GUILD_IDS from utils.datamodels import Database from utils.views import AntispamView @@ -44,6 +45,9 @@ async def start(self, *args, **kwargs): await self.db.connect() await self.db.setup() + self.log.info("Training AI model...") + await train_ai(self.db) + self.log.info("Loading extensions...") self.load_extensions("./ext")