Skip to content

Commit

Permalink
Better ML (#46)
Browse files Browse the repository at this point in the history
added recurring retrain and additional marker (less than X messages) to the model
  • Loading branch information
Szer committed Aug 7, 2024
1 parent 2566ef3 commit 45e56e9
Show file tree
Hide file tree
Showing 9 changed files with 653 additions and 588 deletions.
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ CLEANUP_OLD_MESSAGES=true
CLEANUP_INTERVAL_SEC=86400
CLEANUP_OLD_LIMIT_SEC=259200
ML_ENABLED=false
ML_RETRAIN_INTERVAL_SEC=86400
ML_SEED=
ML_SPAM_DELETION_ENABLED=false
ML_TRAIN_BEFORE_DATE=2021-01-01
ML_TRAIN_INTERVAL_DAYS=30
ML_TRAIN_CRITICAL_MSG_COUNT=5
ML_TRAINING_SET_FRACTION=0.2
ML_SPAM_THRESHOLD=0.5
ML_WARNING_THRESHOLD=0.0
Expand Down
2 changes: 0 additions & 2 deletions src/VahterBanBot.Tests/ContainerTestBase.fs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ type VahterTestContainers() =
.WithEnvironment("DATABASE_URL", internalConnectionString)
.WithEnvironment("CLEANUP_OLD_MESSAGES", "false")
.WithEnvironment("ML_ENABLED", "true")
// seed data uses 2021-01-01 as a date for all messages
.WithEnvironment("ML_TRAIN_BEFORE_DATE", "2021-01-02T00:00:00Z")
.WithEnvironment("ML_SEED", "42")
.WithEnvironment("ML_SPAM_DELETION_ENABLED", "true")
.WithEnvironment("ML_SPAM_THRESHOLD", "1.0")
Expand Down
1,026 changes: 513 additions & 513 deletions src/VahterBanBot.Tests/test_seed.sql

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/VahterBanBot/Bot.fs
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ let justMessage
%mlActivity.SetTag("skipPrediction", shouldBeSkipped)

if not shouldBeSkipped then
match ml.Predict message.Text with
let! usrMsgCount = DB.countUniqueUserMsg message.From.Id

match ml.Predict(message.Text, usrMsgCount) with
| Some prediction ->
%mlActivity.SetTag("spamScoreMl", prediction.Score)

Expand Down
88 changes: 57 additions & 31 deletions src/VahterBanBot/DB.fs
Original file line number Diff line number Diff line change
Expand Up @@ -171,55 +171,70 @@ let getUserById (userId: int64): Task<DbUser option> =
return users |> Seq.tryHead
}

type SpamOrHam =
{ [<LoadColumn(0)>]
text: string
[<LoadColumn(1)>]
spam: bool }
type SpamOrHamDb =
{ text: string
spam: bool
less_than_n_messages: bool
created_at: DateTime }

let mlData(criticalDate: DateTime) : Task<SpamOrHam array> =
let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task<SpamOrHamDb array> =
task {
use conn = new NpgsqlConnection(connString)

//language=postgresql
let sql =
"""
WITH really_banned AS (SELECT *
WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsgCount AS less_than_n_messages
FROM "user" u
LEFT JOIN message m ON u.id = m.user_id
GROUP BY u.id),
really_banned AS (SELECT *
FROM banned b
-- known false positive spam messages
WHERE NOT EXISTS(SELECT 1 FROM false_positive_users fpu WHERE fpu.user_id = b.banned_user_id)
AND NOT EXISTS(SELECT 1
FROM false_positive_messages fpm
WHERE fpm.text = b.message_text)
AND b.message_text IS NOT NULL
AND b.banned_at <= @criticalDate),
spam_or_ham AS (SELECT DISTINCT COALESCE(m.text, re_id.message_text) AS text,
CASE
-- known false negative spam messages
WHEN (EXISTS(SELECT 1
FROM false_negative_messages fnm
WHERE fnm.chat_id = m.chat_id
AND fnm.message_id = m.message_id)
-- known banned spam messages by bot, and not marked as false positive
OR EXISTS(SELECT 1
FROM banned_by_bot bbb
WHERE bbb.banned_in_chat_id = m.chat_id
AND bbb.message_id = m.message_id))
THEN TRUE
WHEN re_id.banned_user_id IS NULL AND re_text.banned_user_id IS NULL
THEN FALSE
ELSE TRUE
END AS spam
FROM (SELECT * FROM message WHERE text IS NOT NULL AND created_at <= @criticalDate) m
FULL OUTER JOIN really_banned re_id
ON m.message_id = re_id.message_id AND m.chat_id = re_id.banned_in_chat_id
LEFT JOIN really_banned re_text ON m.text = re_text.message_text)
AND b.banned_at >= @criticalDate),
spam_or_ham AS (SELECT x.text,
x.spam,
x.less_than_n_messages,
MAX(x.created_at) AS created_at
FROM (SELECT DISTINCT COALESCE(m.text, re_id.message_text) AS text,
CASE
-- known false negative spam messages
WHEN (EXISTS(SELECT 1
FROM false_negative_messages fnm
WHERE fnm.chat_id = m.chat_id
AND fnm.message_id = m.message_id)
-- known banned spam messages by bot, and not marked as false positive
OR EXISTS(SELECT 1
FROM banned_by_bot bbb
WHERE bbb.banned_in_chat_id = m.chat_id
AND bbb.message_id = m.message_id))
THEN TRUE
WHEN re_id.banned_user_id IS NULL AND re_text.banned_user_id IS NULL
THEN FALSE
ELSE TRUE
END AS spam,
COALESCE(l.less_than_n_messages, TRUE) AS less_than_n_messages,
COALESCE(re_id.banned_at, re_text.banned_at, m.created_at) AS created_at
FROM (SELECT *
FROM message
WHERE text IS NOT NULL
AND created_at >= @criticalDate) m
FULL OUTER JOIN really_banned re_id
ON m.message_id = re_id.message_id AND m.chat_id = re_id.banned_in_chat_id
LEFT JOIN really_banned re_text ON m.text = re_text.message_text
LEFT JOIN less_than_n_messages l ON m.user_id = l.id) x
GROUP BY text, spam, less_than_n_messages)
SELECT *
FROM spam_or_ham
ORDER BY RANDOM();
ORDER BY created_at;
"""

let! data = conn.QueryAsync<SpamOrHam>(sql, {| criticalDate = criticalDate |})
let! data = conn.QueryAsync<SpamOrHamDb>(sql, {| criticalDate = criticalDate; criticalMsgCount = criticalMsgCount |})
return Array.ofSeq data
}

Expand Down Expand Up @@ -289,3 +304,14 @@ let deleteCallback (id: Guid): Task =
let! _ = conn.QueryAsync<DbCallback>(sql, {| id = id |})
return ()
}

let countUniqueUserMsg (userId: int64): Task<int> =
task {
use conn = new NpgsqlConnection(connString)

//language=postgresql
let sql = "SELECT COUNT(DISTINCT text) FROM message WHERE user_id = @userId"

let! result = conn.QuerySingleAsync<int>(sql, {| userId = userId |})
return result
}
105 changes: 69 additions & 36 deletions src/VahterBanBot/ML.fs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module VahterBanBot.ML
open System
open System.Diagnostics
open System.Text
open System.Threading
open System.Threading.Tasks
open Microsoft.Extensions.Hosting
open Microsoft.Extensions.Logging
Expand All @@ -15,6 +16,13 @@ open VahterBanBot.DB
open VahterBanBot.Types
open VahterBanBot.Utils

[<CLIMutable>]
type SpamOrHam =
{ text: string
spam: bool
lessThanNMessagesF: single
createdAt: DateTime }

[<CLIMutable>]
type Prediction =
{ Score: single
Expand Down Expand Up @@ -44,45 +52,66 @@ type MachineLearning(
sb.ToString()

let mutable predictionEngine: PredictionEngine<SpamOrHam, Prediction> option = None
let mutable timer: Timer = null

let trainModel() = task {
// switch to thread pool
do! Task.Yield()

let sw = Stopwatch.StartNew()
let trainModel _ = task {
try
// switch to thread pool
do! Task.Yield()
logger.LogInformation "Training model..."

let sw = Stopwatch.StartNew()

let mlContext = MLContext(botConf.MlSeed)
let mlContext = MLContext(botConf.MlSeed)

let trainDate = DateTime.UtcNow - botConf.MlTrainInterval
let! rawData = DB.mlData botConf.MlTrainCriticalMsgCount trainDate

let data =
rawData
|> Array.map (fun x ->
{ text = x.text
spam = x.spam
createdAt = x.created_at
lessThanNMessagesF = if x.less_than_n_messages then 1.0f else 0.0f }
)

let dataView = mlContext.Data.LoadFromEnumerable data
let trainTestSplit = mlContext.Data.TrainTestSplit(dataView, testFraction = botConf.MlTrainingSetFraction)
let trainingData = trainTestSplit.TrainSet
let testData = trainTestSplit.TestSet

let dataProcessPipeline =
mlContext.Transforms.Text
.FeaturizeText(outputColumnName = "TextFeaturized", inputColumnName = "text")
.Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF";|]))
.Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(labelColumnName = "spam", featureColumnName = "Features"))

let! data = DB.mlData botConf.MlTrainBeforeDate

let dataView = mlContext.Data.LoadFromEnumerable data
let trainTestSplit = mlContext.Data.TrainTestSplit(dataView, testFraction = botConf.MlTrainingSetFraction)
let trainingData = trainTestSplit.TrainSet
let testData = trainTestSplit.TestSet

let dataProcessPipeline = mlContext.Transforms.Text.FeaturizeText(outputColumnName = "Features", inputColumnName = "text")
let trainer = mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(labelColumnName = "spam", featureColumnName = "Features")
let trainingPipeline = dataProcessPipeline.Append(trainer)

let trainedModel = trainingPipeline.Fit(trainingData)
predictionEngine <- Some(mlContext.Model.CreatePredictionEngine<SpamOrHam, Prediction>(trainedModel))

let predictions = trainedModel.Transform(testData)
let metrics = mlContext.BinaryClassification.Evaluate(data = predictions, labelColumnName = "spam", scoreColumnName = "Score")

sw.Stop()

let metricsStr = metricsToString metrics sw.Elapsed
logger.LogInformation metricsStr
do! telegramClient.SendTextMessageAsync(ChatId(botConf.LogsChannelId), metricsStr, parseMode = ParseMode.Markdown)
|> taskIgnore
let trainedModel = dataProcessPipeline.Fit(trainingData)
predictionEngine <- Some(mlContext.Model.CreatePredictionEngine<SpamOrHam, Prediction>(trainedModel))

let predictions = trainedModel.Transform(testData)
let metrics = mlContext.BinaryClassification.Evaluate(data = predictions, labelColumnName = "spam", scoreColumnName = "Score")

sw.Stop()

let metricsStr = metricsToString metrics sw.Elapsed
logger.LogInformation metricsStr
do! telegramClient.SendTextMessageAsync(ChatId(botConf.LogsChannelId), metricsStr, parseMode = ParseMode.Markdown)
|> taskIgnore
with ex ->
logger.LogError(ex, "Error training model")
}

member _.Predict(text: string) =
member _.Predict(text: string, userMsgCount: int) =
try
match predictionEngine with
| Some predictionEngine ->
predictionEngine.Predict({ text = text; spam = false })
predictionEngine.Predict
{ text = text
spam = false
lessThanNMessagesF = if userMsgCount < botConf.MlTrainCriticalMsgCount then 1.0f else 0.0f
createdAt = DateTime.UtcNow }
|> Some
| None ->
logger.LogInformation "Model not trained yet"
Expand All @@ -94,11 +123,15 @@ type MachineLearning(
interface IHostedService with
member this.StartAsync _ = task {
if botConf.MlEnabled then
try
logger.LogInformation "Training model..."
if botConf.MlRetrainInterval.IsSome then
// recurring
timer <- new Timer(TimerCallback(trainModel >> ignore), null, TimeSpan.Zero, botConf.MlRetrainInterval.Value)
else
// once
do! trainModel()
with ex ->
logger.LogError(ex, "Error training model")
}

member this.StopAsync _ = Task.CompletedTask
member this.StopAsync _ =
match timer with
| null -> Task.CompletedTask
| timer -> timer.DisposeAsync().AsTask()
4 changes: 3 additions & 1 deletion src/VahterBanBot/Program.fs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ let botConf =
UpdateChatAdminsInterval = getEnvOrWith "UPDATE_CHAT_ADMINS_INTERVAL_SEC" None (int >> TimeSpan.FromSeconds >> Some)
UpdateChatAdmins = getEnvOr "UPDATE_CHAT_ADMINS" "false" |> bool.Parse
MlEnabled = getEnvOr "ML_ENABLED" "false" |> bool.Parse
MlRetrainInterval = getEnvOrWith "ML_RETRAIN_INTERVAL_SEC" None (int >> TimeSpan.FromSeconds >> Some)
MlSeed = getEnvOrWith "ML_SEED" (Nullable<int>()) (int >> Nullable)
MlSpamDeletionEnabled = getEnvOr "ML_SPAM_DELETION_ENABLED" "false" |> bool.Parse
MlTrainBeforeDate = getEnvOrWith "ML_TRAIN_BEFORE_DATE" DateTime.UtcNow (DateTimeOffset.Parse >> _.UtcDateTime)
MlTrainInterval = getEnvOr "ML_TRAIN_INTERVAL_DAYS" "30" |> int |> TimeSpan.FromDays
MlTrainCriticalMsgCount = getEnvOr "ML_TRAIN_CRITICAL_MSG_COUNT" "5" |> int
MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float
MlSpamThreshold = getEnvOr "ML_SPAM_THRESHOLD" "0.5" |> single
MlWarningThreshold = getEnvOr "ML_WARNING_THRESHOLD" "0.0" |> single
Expand Down
4 changes: 3 additions & 1 deletion src/VahterBanBot/Types.fs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ type BotConfiguration =
UpdateChatAdminsInterval: TimeSpan option
UpdateChatAdmins: bool
MlEnabled: bool
MlRetrainInterval: TimeSpan option
MlSeed: Nullable<int>
MlSpamDeletionEnabled: bool
MlTrainBeforeDate: DateTime
MlTrainInterval: TimeSpan
MlTrainCriticalMsgCount: int
MlTrainingSetFraction: float
MlSpamThreshold: single
MlWarningThreshold: single
Expand Down
4 changes: 2 additions & 2 deletions src/VahterBanBot/UpdateChatAdmins.fs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ type UpdateChatAdmins(
do! Task.Delay 100

for admin in admins do
%result.Add admin.User.Id
%sb.AppendJoin(",", $"{prependUsername admin.User.Username} ({admin.User.Id})")
if result.Add admin.User.Id then
%sb.AppendJoin(",", $"{prependUsername admin.User.Username} ({admin.User.Id})")
localAdmins <- result
logger.LogInformation (sb.ToString())
}
Expand Down

0 comments on commit 45e56e9

Please sign in to comment.