Skip to content

Commit

Permalink
Update to register listeners.
Browse files Browse the repository at this point in the history
  • Loading branch information
Buhake Sindi committed Dec 14, 2024
1 parent 2711775 commit 084a1be
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fraud.memory.max.messages=20
# Location of documents to RAG
app.docs-for-rag.dir=docs-for-rag

# Micrprofile Telemetry
# Microprofile Telemetry
otel.service.name=liberty-car-booking
otel.sdk.disabled=false
otel.logs.exporter=otlp,console
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -15,6 +17,12 @@
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.enterprise.util.TypeLiteral;

import org.jboss.logging.Logger;

import dev.langchain4j.data.segment.TextSegment;
Expand All @@ -23,11 +31,6 @@
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfig;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfigProvider;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.inject.spi.CDI;
import jakarta.enterprise.util.TypeLiteral;

/*
smallrye.llm.plugin.content-retriever.class=dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever
Expand Down Expand Up @@ -76,11 +79,18 @@ public static void createAllLLMBeans(LLMConfig llmConfig, Consumer<BeanData> bea
}
beanBuilder.accept(
new BeanData(targetClass, builderCLass, scopeClass, beanName,
(Instance<Object> creationalContext) -> CommonLLMPluginCreator.create(
creationalContext,
beanName,
targetClass,
builderCLass)));
(Instance<Object> creationalContext) -> {
try {
return CommonLLMPluginCreator.create(
creationalContext,
beanName,
targetClass,
builderCLass);
} catch (ClassNotFoundException e) {
// TODO Auto-generated catch block
throw new DeploymentException(e);
}
}));
}
}
}
Expand Down Expand Up @@ -123,7 +133,8 @@ public Function<Instance<Object>, Object> getCallback() {
}
}

public static Object create(Instance<Object> lookup, String beanName, Class<?> targetClass, Class<?> builderClass) {
@SuppressWarnings("unchecked")
public static Object create(Instance<Object> lookup, String beanName, Class<?> targetClass, Class<?> builderClass) throws ClassNotFoundException {
LLMConfig llmConfig = LLMConfigProvider.getLlmConfig();
LOGGER.info(
"Create instance config:" + beanName + ", target class : " + targetClass + ", builderClass : " + builderClass);
Expand All @@ -144,30 +155,37 @@ public static Object create(Instance<Object> lookup, String beanName, Class<?> t
} else {
for (Method methodToCall : methodsToCall) {
Class<?> parameterType = methodToCall.getParameterTypes()[0];
if ("listeners".equals(property) && "@all".equals(stringValue)) {
Class<?> typeParameterClass = ChatLanguageModel.class.isAssignableFrom(targetClass)
if ("listeners".equals(property)) {
Class<?> typeParameterClass = ChatLanguageModel.class.isAssignableFrom(targetClass)
? ChatModelListener.class
: parameterType.getTypeParameters()[0].getGenericDeclaration();
Instance<?> inst = getInstance(lookup, typeParameterClass);
if (inst != null) {
List<?> listeners = StreamSupport.stream(inst.spliterator(), false)
List<Object> listeners = (List<Object>) Collections.checkedList(new ArrayList<>(), typeParameterClass);
if ("@all".equals(stringValue.trim())) {
Instance<Object> inst = (Instance<Object>) getInstance(lookup, typeParameterClass);
if (inst != null) {
inst.forEach(listeners::add);
listeners = StreamSupport.stream(inst.spliterator(), false)
.collect(Collectors.toList());
if (listeners != null && !listeners.isEmpty()) {
listeners.stream().forEach(l -> LOGGER.info("Adding listener: " + l.getClass().getName()));
methodToCall.invoke(builder, listeners);
}
}
} else {
for (String className : stringValue.split(",")) {
Instance<?> inst = getInstance(lookup, loadClass(className.trim()));
listeners.add(inst.get());
}
}

if (listeners != null && !listeners.isEmpty()) {
listeners.stream().forEach(l -> LOGGER.info("Adding listener: " + l.getClass().getName()));
methodToCall.invoke(builder, listeners);
}
} else if (stringValue.startsWith("lookup:")) {
String lookupableBean = stringValue.substring("lookup:".length());
LOGGER.info("Lookup " + lookupableBean + " " + parameterType);
Instance<?> inst;
if ("default".equals(lookupableBean)) {
inst = lookup.select(parameterType);
if (!inst.isResolvable()) {
inst = CDI.current().select(parameterType);
}
inst = getInstance(lookup, parameterType);
} else {
inst = lookup.select(parameterType, NamedLiteral.of(lookupableBean));
inst = getInstance(lookup, parameterType, lookupableBean);
}
methodToCall.invoke(builder, inst.get());
break;
Expand All @@ -189,8 +207,8 @@ public static Object create(Instance<Object> lookup, String beanName, Class<?> t
}
}

private static Class<?> loadClass(String scopeClassName) throws ClassNotFoundException {
return Thread.currentThread().getContextClassLoader().loadClass(scopeClassName);
private static Class<?> loadClass(String className) throws ClassNotFoundException {
return Thread.currentThread().getContextClassLoader().loadClass(className);
}

@SuppressWarnings("unchecked")
Expand Down

0 comments on commit 084a1be

Please sign in to comment.