From 672a8287d4a41f0e6d7fce6d5503f890d947537e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B1=91=E8=90=BD=E5=8F=B6?= <100185989+siuank@users.noreply.github.com>
Date: Wed, 2 Oct 2024 19:50:44 +0800
Subject: [PATCH] fix(ll4j): Thread safety model implementation

---
Reviewer notes (placed between the three-dash separator and the diffstat, so
they are not part of the commit message or of any hunk):

 * JFuncModelScript.compile(String[] script) builds a chain of
   java.util.function.Function stages, one per script line:
     . lines shorter than 2 characters are skipped;
     . "D <ic> <oc> <weights...>" appends a dense stage that rejects inputs
       whose length is not ic and returns oc sums; the weight applied to
       input j for output i is read from token 3 + i + j * oc;
     . "L <n>" appends a leaky-ReLU stage that rejects inputs whose length
       is not n and multiplies non-positive values by 0.01;
     . any other leading token is ignored (the switch has no default case).
   The first stage copies its input and every later stage writes into a
   freshly allocated array, so the array passed by the caller is never
   written to.
 * LL4JUtils.model becomes a (DoubleArray) -> DoubleArray wrapper around the
   compiled function; predict() now returns true when output[1] > output[0].
 * learn() still builds a DataSet but no longer trains (model.trainOn is
   commented out); the command now answers "Deprecated" instead of "Done.".
   downloadModel() no longer replaces the model but still reloads the
   tokenizer and the version string.
 * NOTE(review): the type arguments of Function are written out below as
   <double[], double[]>; the lambdas dereference input.length and index
   input[j], which does not compile against the raw type. The type
   arguments of Pair in the predictDebug lines cannot be recovered from
   this patch alone - confirm against LL4JUtils.kt before applying.
 * NOTE(review): the third LL4JUtils.kt hunk is declared as -17,23 +18,25
   but only 22 old / 24 new lines are recoverable here; presumably one
   whitespace-only context line is missing - verify against the target file.
 * NOTE(review): context line `time.month.name.substring(0..3)` takes four
   characters (the range is inclusive) while the default version is "FEB25",
   and it would be out of range for "MAY" - looks like 0..2 was intended;
   confirm and fix in a follow-up, it is not touched by this patch.

 .../huzpsb/ll4j/minrt/JFuncModelScript.java   | 60 +++++++++++++++++++
 .../ltd/guimc/lgzbot/command/LGZBotCommand.kt |  2 +-
 .../ltd/guimc/lgzbot/utils/LL4JUtils.kt       | 22 ++++---
 3 files changed, 74 insertions(+), 10 deletions(-)
 create mode 100644 src/main/java/huzpsb/ll4j/minrt/JFuncModelScript.java

diff --git a/src/main/java/huzpsb/ll4j/minrt/JFuncModelScript.java b/src/main/java/huzpsb/ll4j/minrt/JFuncModelScript.java
new file mode 100644
index 0000000..3f9261f
--- /dev/null
+++ b/src/main/java/huzpsb/ll4j/minrt/JFuncModelScript.java
@@ -0,0 +1,60 @@
+package huzpsb.ll4j.minrt;
+
+import java.util.function.Function;
+
+public class JFuncModelScript {
+    public static Function<double[], double[]> compile(String[] script) {
+        Function<double[], double[]> current = (input) -> {
+            double[] copied = new double[input.length];
+            System.arraycopy(input, 0, copied, 0, input.length);
+            return copied;
+        };
+
+        for (String str : script) {
+            if (str.length() < 2) {
+                continue;
+            }
+            String[] tokens = str.split(" ");
+            switch (tokens[0]) {
+                case "D":
+                    int ic = Integer.parseInt(tokens[1]);
+                    int oc = Integer.parseInt(tokens[2]);
+                    double[] weights = new double[ic * oc];
+                    for (int i = 0; i < oc; i++) {
+                        for (int j = 0; j < ic; j++) {
+                            weights[i + j * oc] = Double.parseDouble(tokens[3 + i + j * oc]);
+                        }
+                    }
+                    current = current.andThen((input) -> {
+                        if (input.length != ic) {
+                            throw new RuntimeException("Wrong input size for Dense layer (expected " + ic + ", got " + input.length + ")");
+                        }
+                        double[] tmp = new double[oc];
+                        for (int i = 0; i < oc; i++) {
+                            double sum = 0;
+                            for (int j = 0; j < ic; j++) {
+                                sum += input[j] * weights[i + j * oc];
+                            }
+                            tmp[i] = sum;
+                        }
+                        return tmp;
+                    });
+                    break;
+                case "L":
+                    int n = Integer.parseInt(tokens[1]);
+                    current = current.andThen((input) -> {
+                        if (input.length != n) {
+                            throw new RuntimeException("Wrong input size for LeakyRelu layer (expected " + n + ", got " + input.length + ")");
+                        }
+                        double[] tmp = new double[n];
+                        for (int i = 0; i < n; i++) {
+                            tmp[i] = input[i] > 0 ? input[i] : input[i] * 0.01;
+                        }
+                        return tmp;
+                    });
+                    break;
+            }
+        }
+        return current;
+    }
+}
diff --git a/src/main/kotlin/ltd/guimc/lgzbot/command/LGZBotCommand.kt b/src/main/kotlin/ltd/guimc/lgzbot/command/LGZBotCommand.kt
index 80cb0aa..60fe702 100644
--- a/src/main/kotlin/ltd/guimc/lgzbot/command/LGZBotCommand.kt
+++ b/src/main/kotlin/ltd/guimc/lgzbot/command/LGZBotCommand.kt
@@ -129,7 +129,7 @@ object LGZBotCommand : CompositeCommand(
     @Description("让模型学习一段文本")
     suspend fun CommandSender.iI1I1i1iIi1I(type: Int, string: String) {
         LL4JUtils.learn(type, string)
-        sendMessage("Done.")
+        sendMessage("Deprecated")
     }
 
     @SubCommand("downloadModel")
diff --git a/src/main/kotlin/ltd/guimc/lgzbot/utils/LL4JUtils.kt b/src/main/kotlin/ltd/guimc/lgzbot/utils/LL4JUtils.kt
index 235c293..a7a028a 100644
--- a/src/main/kotlin/ltd/guimc/lgzbot/utils/LL4JUtils.kt
+++ b/src/main/kotlin/ltd/guimc/lgzbot/utils/LL4JUtils.kt
@@ -1,5 +1,6 @@
 package ltd.guimc.lgzbot.utils
 
+import huzpsb.ll4j.minrt.JFuncModelScript
 import huzpsb.ll4j.model.Model
 import huzpsb.ll4j.nlp.token.Tokenizer
 import huzpsb.ll4j.utils.data.DataSet
@@ -8,7 +9,7 @@ import ltd.guimc.lgzbot.utils.TextUtils.removeInterference
 import ltd.guimc.lgzbot.utils.TextUtils.removeNonVisible
 
 object LL4JUtils {
-    lateinit var model: Model
+    lateinit var model: (DoubleArray) -> DoubleArray
     lateinit var tokenizer: Tokenizer
     var version = "FEB25"
 
@@ -17,23 +18,25 @@ object LL4JUtils {
         val modelFile = LL4JUtils.javaClass.getResourceAsStream("/anti-ad.model")!!
         tokenizer = Tokenizer.load(tokenizerFile.bufferedReader(Charsets.UTF_8))
         modelFile.bufferedReader(Charsets.UTF_8).use {
-            model = Model.read(it)
+            val compiled = JFuncModelScript.compile(it.readLines().toTypedArray())
+            model = { arr -> compiled.apply(arr) } // to kotlin style lambda (wtf)
         }
     }
 
     fun predict(string: String): Boolean =
-        model.predictDebug(
+        model(
             tokenizer.tokenize(
                 0, string.replace("\n", "").replace("live.bilibili.com", "")
             ).values
-        ).first == 1
+        ).let { it[1] > it[0] }
 
-    fun predictDebug(string: String): Pair =
-        model.predictDebug(tokenizer.tokenize(0, string.replace("\n", "").replace("live.bilibili.com", "")).values)
+    // [unused]
+//    fun predictDebug(string: String): Pair =
+//        model.predictDebug(tokenizer.tokenize(0, string.replace("\n", "").replace("live.bilibili.com", "")).values)
 
     fun predictAllResult(string: String): DoubleArray =
-        model.predictAllResult(
+        model(
             tokenizer.tokenize(
                 0, sbc2dbcCase(string.replace("\n", "").replace("live.bilibili.com", ""))
                     .lowercase()
@@ -42,15 +45,16 @@ object LL4JUtils {
             ).values
         )
 
+    // [deprecated]
     fun learn(type: Int, string: String) {
         val dataSet = DataSet()
         dataSet.split.add(tokenizer.tokenize(type, string.replace("\n", "")))
-        model.trainOn(dataSet)
+//        model.trainOn(dataSet)
     }
 
     fun downloadModel() {
         try {
-            model = Model.read(HttpUtils.getResponse("https://raw.githubusercontent.com/siuank/ADDetector4J/main/anti-ad.model"))
+//            model = Model.read(HttpUtils.getResponse("https://raw.githubusercontent.com/siuank/ADDetector4J/main/anti-ad.model"))
             tokenizer = Tokenizer.load(HttpUtils.getResponse("https://raw.githubusercontent.com/siuank/ADDetector4J/main/t1.tokenized.txt").reader())
             val time = GithubUtils.getLastCommit("siuank/ADDetector4J").commitTime
             version = "${time.month.name.substring(0..3)}${time.dayOfMonth}"