diff --git a/llm_perplexity.py b/llm_perplexity.py
index 73e7265..93b7e21 100644
--- a/llm_perplexity.py
+++ b/llm_perplexity.py
@@ -6,7 +6,7 @@
 
 @llm.hookimpl
 def register_models(register):
-    # https://docs.perplexity.ai/docs/model-cards
+    # https://docs.perplexity.ai/guides/model-cards
     register(Perplexity("llama-3.1-sonar-small-128k-online"), aliases=("sonar-small",))
     register(Perplexity("llama-3.1-sonar-large-128k-online"), aliases=("sonar-large",))
     register(Perplexity("llama-3.1-sonar-huge-128k-online"), aliases=("sonar-huge",))
@@ -52,6 +52,11 @@
         default=None,
     )
 
+    return_citations: Optional[bool] = Field(
+        description="Determines whether or not a request to an online model should return citations",
+        default=False,
+    )
+
     @field_validator("temperature")
     @classmethod
     def validate_temperature(cls, temperature):
@@ -115,7 +120,7 @@ def execute(self, prompt, stream, response, conversation):
         kwargs = {
             "model": self.model_id,
             "messages": self.build_messages(prompt, conversation),
-            "stream": prompt.options.stream,
+            "stream": stream,
             "max_tokens": prompt.options.max_tokens or None,
         }
 
@@ -126,14 +131,28 @@ def execute(self, prompt, stream, response, conversation):
 
         if prompt.options.top_k:
             kwargs["top_k"] = prompt.options.top_k
+
+        if prompt.options.return_citations:
+            kwargs["return_citations"] = prompt.options.return_citations
 
         if stream:
             with client.chat.completions.create(**kwargs) as stream:
                 for text in stream:
                     yield text.choices[0].delta.content
+
+            if hasattr(text, 'citations') and text.citations:
+                yield "\n\nCitations:\n"
+                for i, citation in enumerate(text.citations, 1):
+                    yield f"[{i}] {citation}\n"
+
         else:
             completion = client.chat.completions.create(**kwargs)
             yield completion.choices[0].message.content
 
+            if hasattr(completion, 'citations') and completion.citations:
+                yield "\n\nCitations:\n"
+                for i, citation in enumerate(completion.citations, 1):
+                    yield f"[{i}] {citation}\n"
+
     def __str__(self):
         return f"Perplexity: {self.model_id}"
diff --git a/pyproject.toml b/pyproject.toml
index ef8ae62..7c51be3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "llm-perplexity"
-version = "0.9"
+version = "2024.11.0"
 description = "LLM access to pplx-api 3 by Perplexity Labs"
 readme = "README.md"
 authors = [{name = "hex"}]