Skip to content
This repository has been archived by the owner on Nov 24, 2024. It is now read-only.

Commit

Permalink
fix(ll4j): Thread safety model implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
siuank committed Oct 2, 2024
1 parent 1b408d5 commit 672a828
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 10 deletions.
60 changes: 60 additions & 0 deletions src/main/java/huzpsb/ll4j/minrt/JFuncModelScript.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package huzpsb.ll4j.minrt;

import java.util.function.Function;

/**
 * Compiles a textual model script into a chain of {@link Function}s over {@code double[]}.
 *
 * <p>Each script line describes one layer; its first whitespace-separated token selects the kind:
 * <ul>
 *   <li>{@code D <in> <out> <w...>} - dense layer with {@code in * out} weights and no bias term;
 *       the weight for input {@code j} and output {@code i} is the token at offset
 *       {@code i + j * out} after the header.</li>
 *   <li>{@code L <n>} - leaky ReLU over {@code n} values (non-positive values are scaled by 0.01).</li>
 * </ul>
 * Lines with any other leading token are ignored.
 *
 * <p>The compiled function keeps no state besides the parsed weight arrays, which are never written
 * after {@link #compile(String[])} returns, and every layer writes its result into a new array.
 */
public class JFuncModelScript {
    /** Factor applied by the leaky ReLU layer to values that are not strictly positive. */
    private static final double LEAKY_SLOPE = 0.01;

    /**
     * Builds the function pipeline described by {@code script}.
     *
     * @param script model script, one layer per element; {@code null}, blank and
     *               single-character lines are skipped
     * @return a function that copies its input and then applies every recognised layer in order
     * @throws IllegalArgumentException if a Dense line is missing its header or weights, or a
     *                                  numeric token cannot be parsed
     */
    public static Function<double[], double[]> compile(String[] script) {
        // Seed the pipeline with a copy, so even an empty script returns a fresh array.
        Function<double[], double[]> current = (input) -> input.clone();

        for (String str : script) {
            if (str == null) {
                continue;
            }
            // Trim before splitting: String.split drops trailing empty strings, so a
            // whitespace-only line would otherwise yield an empty token array, and leading
            // whitespace would yield an empty first token that matches no layer kind.
            String line = str.trim();
            if (line.length() < 2) {
                continue;
            }
            String[] tokens = line.split("\\s+");
            switch (tokens[0]) {
                case "D":
                    current = current.andThen(dense(tokens));
                    break;
                case "L":
                    current = current.andThen(leakyRelu(tokens));
                    break;
                default:
                    // Unrecognised layer kinds are skipped.
                    break;
            }
        }
        return current;
    }

    /** Parses a {@code D <in> <out> <w...>} line into a bias-free dense layer. */
    private static Function<double[], double[]> dense(String[] tokens) {
        if (tokens.length < 3) {
            throw new IllegalArgumentException("Malformed Dense layer: expected \"D <in> <out> <weights...>\"");
        }
        final int ic = Integer.parseInt(tokens[1]);
        final int oc = Integer.parseInt(tokens[2]);
        if (ic < 0 || oc < 0) {
            throw new IllegalArgumentException("Dense layer sizes must be non-negative (got " + ic + " and " + oc + ")");
        }
        // Computed as long so that an absurd header cannot overflow past the length check.
        final long expected = (long) ic * oc;
        final int available = tokens.length - 3;
        if (expected > available) {
            throw new IllegalArgumentException("Dense layer " + ic + "x" + oc + " needs " + expected + " weights, got " + available);
        }
        // Weights keep the script's order; surplus trailing tokens are ignored.
        final double[] weights = new double[(int) expected];
        for (int k = 0; k < weights.length; k++) {
            weights[k] = Double.parseDouble(tokens[3 + k]);
        }
        return (input) -> {
            if (input.length != ic) {
                throw new RuntimeException("Wrong input size for Dense layer (expected " + ic + ", got " + input.length + ")");
            }
            double[] out = new double[oc];
            for (int i = 0; i < oc; i++) {
                double sum = 0;
                for (int j = 0; j < ic; j++) {
                    sum += input[j] * weights[i + j * oc];
                }
                out[i] = sum;
            }
            return out;
        };
    }

    /** Parses an {@code L <n>} line into a leaky ReLU layer over {@code n} values. */
    private static Function<double[], double[]> leakyRelu(String[] tokens) {
        if (tokens.length < 2) {
            throw new IllegalArgumentException("Malformed LeakyRelu layer: expected \"L <size>\"");
        }
        final int n = Integer.parseInt(tokens[1]);
        return (input) -> {
            if (input.length != n) {
                throw new RuntimeException("Wrong input size for LeakyRelu layer (expected " + n + ", got " + input.length + ")");
            }
            double[] out = new double[n];
            for (int i = 0; i < n; i++) {
                out[i] = input[i] > 0 ? input[i] : input[i] * LEAKY_SLOPE;
            }
            return out;
        };
    }
}
2 changes: 1 addition & 1 deletion src/main/kotlin/ltd/guimc/lgzbot/command/LGZBotCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ object LGZBotCommand : CompositeCommand(
@Description("让模型学习一段文本")
suspend fun CommandSender.iI1I1i1iIi1I(type: Int, string: String) {
LL4JUtils.learn(type, string)
sendMessage("Done.")
sendMessage("Deprecated")
}

@SubCommand("downloadModel")
Expand Down
22 changes: 13 additions & 9 deletions src/main/kotlin/ltd/guimc/lgzbot/utils/LL4JUtils.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ltd.guimc.lgzbot.utils

import huzpsb.ll4j.minrt.JFuncModelScript
import huzpsb.ll4j.model.Model
import huzpsb.ll4j.nlp.token.Tokenizer
import huzpsb.ll4j.utils.data.DataSet
Expand All @@ -8,7 +9,7 @@ import ltd.guimc.lgzbot.utils.TextUtils.removeInterference
import ltd.guimc.lgzbot.utils.TextUtils.removeNonVisible

object LL4JUtils {
lateinit var model: Model
lateinit var model: (DoubleArray) -> DoubleArray
lateinit var tokenizer: Tokenizer
var version = "FEB25"

Expand All @@ -17,23 +18,25 @@ object LL4JUtils {
val modelFile = LL4JUtils.javaClass.getResourceAsStream("/anti-ad.model")!!
tokenizer = Tokenizer.load(tokenizerFile.bufferedReader(Charsets.UTF_8))
modelFile.bufferedReader(Charsets.UTF_8).use {
model = Model.read(it)
val compiled = JFuncModelScript.compile(it.readLines().toTypedArray())
model = { arr -> compiled.apply(arr) } // to kotlin style lambda (wtf)
}
}

fun predict(string: String): Boolean =
model.predictDebug(
model(
tokenizer.tokenize(
0,
string.replace("\n", "").replace("live.bilibili.com", "")
).values
).first == 1
).let { it[1] > it[0] }

fun predictDebug(string: String): Pair<Int, Double> =
model.predictDebug(tokenizer.tokenize(0, string.replace("\n", "").replace("live.bilibili.com", "")).values)
// [unused]
// fun predictDebug(string: String): Pair<Int, Double> =
// model.predictDebug(tokenizer.tokenize(0, string.replace("\n", "").replace("live.bilibili.com", "")).values)

fun predictAllResult(string: String): DoubleArray =
model.predictAllResult(
model(
tokenizer.tokenize(
0, sbc2dbcCase(string.replace("\n", "").replace("live.bilibili.com", ""))
.lowercase()
Expand All @@ -42,15 +45,16 @@ object LL4JUtils {
).values
)

// [deprecated]
fun learn(type: Int, string: String) {
val dataSet = DataSet()
dataSet.split.add(tokenizer.tokenize(type, string.replace("\n", "")))
model.trainOn(dataSet)
// model.trainOn(dataSet)
}

fun downloadModel() {
try {
model = Model.read(HttpUtils.getResponse("https://raw.githubusercontent.com/siuank/ADDetector4J/main/anti-ad.model"))
// model = Model.read(HttpUtils.getResponse("https://raw.githubusercontent.com/siuank/ADDetector4J/main/anti-ad.model"))
tokenizer = Tokenizer.load(HttpUtils.getResponse("https://raw.githubusercontent.com/siuank/ADDetector4J/main/t1.tokenized.txt").reader())
val time = GithubUtils.getLastCommit("siuank/ADDetector4J").commitTime
version = "${time.month.name.substring(0..3)}${time.dayOfMonth}"
Expand Down

0 comments on commit 672a828

Please sign in to comment.