diff --git a/src/main/java/huzpsb/ll4j/model/Model.java b/src/main/java/huzpsb/ll4j/model/Model.java index 84d21d2..18dc164 100644 --- a/src/main/java/huzpsb/ll4j/model/Model.java +++ b/src/main/java/huzpsb/ll4j/model/Model.java @@ -155,6 +155,20 @@ public List testAndGetWA(DataSet dataSet) { return Map.of(((JudgeLayer) layers[layers.length - 1]).result, layers[layers.length - 1].input[((JudgeLayer) layers[layers.length - 1]).result]); } + public double @NotNull [] predictAllResult(double[] input) { + AbstractLayer[] layers = this.layers; + if (!(layers[layers.length - 1] instanceof JudgeLayer)) { + throw new RuntimeException("Last layer is not output layer"); + } + for (AbstractLayer layer : layers) { + layer.training = false; + layer.input = input; + layer.forward(); + input = layer.output; + } + return layers[layers.length - 1].input; + } + public int predict(DataEntry de) { double[] input = de.values; for (AbstractLayer layer : layers) { diff --git a/src/main/kotlin/ltd/guimc/lgzbot/listener/message/MessageFilter.kt b/src/main/kotlin/ltd/guimc/lgzbot/listener/message/MessageFilter.kt index afe5a22..a0cea1a 100644 --- a/src/main/kotlin/ltd/guimc/lgzbot/listener/message/MessageFilter.kt +++ b/src/main/kotlin/ltd/guimc/lgzbot/listener/message/MessageFilter.kt @@ -135,19 +135,16 @@ object MessageFilter { if (!muted && textMessage.length >= stringLength) { if (LL4JUtils.predict(textMessage)) { - if (RegexUtils.matchRegexPinyin(adPinyinRegex, textMessage)) { e.group.mute(e.sender, "非法发言内容 (模型预测, 强检查证实)") riskList.add(e.sender) setVl(e.sender.id, 99.0) muted = true - } else { - val botOwner = e.bot.getFriend(Config.BotOwner) - requireNotNull(botOwner) - // botOwner.sendMessage("发现一条模型认为违规的消息, 但正则匹配失败, 请检查.") - // val outputMessage = ForwardMessageBuilder(e.group) - // outputMessage.add(e) - // botOwner.sendMessage(outputMessage.build()) + } else if (LL4JUtils.predictMap(textMessage)[1]!! >= 0.9) { + e.sender.mute(120, "非法发言内容 (模型预测认为可能性较高)") + riskList.add(e.sender) + setVl(e.sender.id, 99.0) + muted = true } } }