Skip to content

Commit

Permalink
Training is now done on bot startup
Browse files Browse the repository at this point in the history
  • Loading branch information
Exenifix committed May 1, 2022
1 parent cd12513 commit 3921333
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 35 deletions.
8 changes: 0 additions & 8 deletions .github/workflows/runner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Train Model
run: |
python3 -m virtualenv train-venv
. train-venv/bin/activate
pip install -U -r train-requirements.txt
python3 ai/train.py
deactivate
- name: Build Docker Image
run: docker build -t aias .

Expand Down
34 changes: 7 additions & 27 deletions ai/train.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,20 @@
import asyncio
import os
import sys
from pickle import dump

import asyncpg
from dotenv import load_dotenv
from exencolorlogs import Logger
from numpy import array
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from utils.datamodels import Database
from utils.enums import FetchMode

log = Logger("TRAIN")

load_dotenv()
DATABASE = os.getenv("DATABASE")
HOST = os.getenv("HOST")
USER = os.getenv("USER")
PASSWORD = os.getenv("PASSWORD")

if DATABASE is None:
log.critical(".env file not filled up properly")
sys.exit(1)


async def main():
log.info("Establishing connection to the database...")
con: asyncpg.Connection = await asyncpg.connect(
database=DATABASE, host=HOST, user=USER, password=PASSWORD
async def train(db: Database):
log = Logger("TRAIN")
records = await db.execute(
"SELECT total_chars, unique_chars, total_words, unique_words, is_spam FROM data WHERE is_spam IS NOT NULL",
fetch_mode=FetchMode.ALL,
)
records = await con.fetch(
"SELECT total_chars, unique_chars, total_words, unique_words, is_spam FROM data WHERE is_spam IS NOT NULL"
)
await con.close()
log.info("Preparing data... Amount of records: %s", len(records))
data = array([list(r.values()) for r in records])
train_x = data[:, 0:4]
Expand All @@ -52,6 +35,3 @@ async def main():
with open("./ai/models/model.ai", "wb") as f:
dump(model, f)
log.info("File saved successfully to ./ai/models/model.ai")


asyncio.run(main())
4 changes: 4 additions & 0 deletions utils/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from exencolorlogs import Logger

from utils import embeds
from ai.train import train as train_ai
from utils.constants import EMOJIS, LOG_CHANNEL_ID, TRAIN_GUILD_IDS
from utils.datamodels import Database
from utils.views import AntispamView
Expand Down Expand Up @@ -44,6 +45,9 @@ async def start(self, *args, **kwargs):
await self.db.connect()
await self.db.setup()

self.log.info("Training AI model...")
await train_ai(self.db)

self.log.info("Loading extensions...")
self.load_extensions("./ext")

Expand Down

0 comments on commit 3921333

Please sign in to comment.