Skip to content

Commit

Permalink
Merge pull request #19 from firstbatchxyz/erhant/workflow-fixes
Browse files Browse the repository at this point in the history
fix: workflow & start fixes
  • Loading branch information
erhant authored Oct 31, 2024
2 parents 944d411 + ebac87d commit 04ef3d1
Show file tree
Hide file tree
Showing 11 changed files with 521 additions and 338 deletions.
734 changes: 442 additions & 292 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "dkn-oracle"
description = "Dria Knowledge Network: Oracle Node"
version = "0.1.6"
version = "0.1.7"
edition = "2021"
license = "Apache-2.0"
readme = "README.md"
Expand Down
4 changes: 4 additions & 0 deletions src/commands/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ pub async fn run_oracle(
kinds.push(kind);
}
}

if kinds.is_empty() {
return Err(eyre!("You are not registered as any type of oracle."))?;
}
} else {
// otherwise, make sure we are registered to required kinds
for kind in &kinds {
Expand Down
8 changes: 6 additions & 2 deletions src/compute/handlers/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,18 @@ pub async fn handle_generation(
log::debug!("Executing the workflow");
let protocol_string = bytes32_to_string(&protocol)?;
let executor = Executor::new(model);
let (output, metadata) = executor
let (output, metadata, use_storage) = executor
.execute_raw(&request.input, &protocol_string)
.await?;

// do the Arweave trick for large inputs
log::debug!("Uploading to Arweave if required");
let arweave = Arweave::new_from_env().wrap_err("could not create Arweave instance")?;
let output = arweave.put_if_large(output).await?;
let output = if use_storage {
arweave.put_if_large(output).await?
} else {
output
};
let metadata = arweave.put_if_large(metadata).await?;

// mine nonce
Expand Down
83 changes: 51 additions & 32 deletions src/compute/workflows/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ use crate::data::{Arweave, OracleExternalData};

#[async_trait(?Send)]
pub trait WorkflowsExt {
async fn prepare_input(&self, input_bytes: &Bytes) -> Result<(Option<Entry>, Workflow)>;
async fn execute_raw(&self, input_bytes: &Bytes, protocol: &str) -> Result<(Bytes, Bytes)>;
async fn parse_input_bytes(&self, input_bytes: &Bytes) -> Result<(Option<Entry>, Workflow)>;
async fn parse_input_string(&self, input_str: String) -> Result<(Option<Entry>, Workflow)>;
async fn execute_raw(
&self,
input_bytes: &Bytes,
protocol: &str,
) -> Result<(Bytes, Bytes, bool)>;

/// Returns a generation workflow for the executor.
#[inline]
Expand All @@ -22,37 +27,40 @@ pub trait WorkflowsExt {

#[async_trait(?Send)]
impl WorkflowsExt for Executor {
/// Given an input, prepares it for the executer by providing the entry and workflow.
/// Given an input of byte-slice, parses it to an entry and workflow.
///
/// - If input is a JSON txid string (64-char hex), the entry is fetched from Arweave, and then we recurse.
/// - If input is a JSON string, it is converted to an entry and a generation workflow is returned.
/// - If input is a JSON workflow, entry is `None` and input is casted to a workflow
/// - Otherwise, error is returned.
async fn prepare_input(&self, input_bytes: &Bytes) -> Result<(Option<Entry>, Workflow)> {
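/// - If input is a byte-slice of a JSON string, it is passed to `parse_input_string`.
/// - If input is a byte-slice of a JSON workflow, entry is `None` and the input is cast to a workflow.
/// - Otherwise, the bytes are lossily converted to a UTF-8 string and passed to `parse_input_string`.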
async fn parse_input_bytes(&self, input_bytes: &Bytes) -> Result<(Option<Entry>, Workflow)> {
if let Ok(input_str) = serde_json::from_slice::<String>(input_bytes) {
// this is a string, lets see if its a txid
if Arweave::is_key(input_str.clone()) {
// if its a txid, we download the data and parse it again
// we dont expect to recurse here too much, because there would have to txid within txid
// but still it is possible
let input_bytes = Arweave::default()
.get(input_str)
.await
.wrap_err("could not download from Arweave")?;
self.prepare_input(&input_bytes).await
} else {
// it is not a key, so we treat it as a generation request with plaintext input
let entry = Some(Entry::String(input_str));
let workflow = self.get_generation_workflow()?;
Ok((entry, workflow))
}
self.parse_input_string(input_str).await
} else if let Ok(workflow) = serde_json::from_slice::<Workflow>(input_bytes) {
// it is a workflow, so we can directly use it with no entry
Ok((None, workflow))
} else {
// it is unparsable, return as lossy-converted string
let input_string = String::from_utf8_lossy(input_bytes);
let entry = Some(Entry::String(input_string.into()));
let input_str = String::from_utf8_lossy(input_bytes);
self.parse_input_string(input_str.into()).await
}
}

/// Given an input of string, parses it to an entry and workflow.
///
/// - If input is a txid (64-char hex, without 0x), the entry is fetched from Arweave, and then we recurse back to `parse_input_bytes`.
/// - Otherwise, it is treated as a plaintext input and a generation workflow is returned.
async fn parse_input_string(&self, input_string: String) -> Result<(Option<Entry>, Workflow)> {
if Arweave::is_key(input_string.clone()) {
// if it's a txid, we download the data and parse it again
let input_bytes = Arweave::default()
.get(input_string)
.await
.wrap_err("could not download from Arweave")?;

// we don't expect to recurse much further here, because that would require a txid nested within a txid
self.parse_input_bytes(&input_bytes).await
} else {
// it is not a key, so we treat it as a generation request with plaintext input
let entry = Some(Entry::String(input_string));
let workflow = self.get_generation_workflow()?;
Ok((entry, workflow))
}
Expand All @@ -61,10 +69,14 @@ impl WorkflowsExt for Executor {
/// Executes a generation task for the given input.
/// The workflow & entry is derived from the input.
///
/// Returns output and metadata.
async fn execute_raw(&self, input_bytes: &Bytes, protocol: &str) -> Result<(Bytes, Bytes)> {
/// Returns output, metadata, and a boolean indicating whether we shall upload the `output` to storage if large enough.
async fn execute_raw(
&self,
input_bytes: &Bytes,
protocol: &str,
) -> Result<(Bytes, Bytes, bool)> {
// parse & prepare input
let (entry, workflow) = self.prepare_input(input_bytes).await?;
let (entry, workflow) = self.parse_input_bytes(input_bytes).await?;

// obtain raw output
let mut memory = ProgramMemory::new();
Expand Down Expand Up @@ -92,7 +104,7 @@ mod tests {
let input_str = "foobar";

let (entry, _) = executor
.prepare_input(&input_str.as_bytes().into())
.parse_input_bytes(&input_str.as_bytes().into())
.await
.unwrap();
assert_eq!(entry.unwrap(), Entry::String(input_str.into()));
Expand All @@ -108,7 +120,14 @@ mod tests {
let expected_str = "Hello, Arweave!";

let (entry, _) = executor
.prepare_input(&arweave_key.as_bytes().into())
.parse_input_bytes(&arweave_key.as_bytes().into())
.await
.unwrap();
assert_eq!(entry.unwrap(), Entry::String(expected_str.into()));

// without `"`s
let (entry, _) = executor
.parse_input_bytes(&arweave_key.trim_matches('"').as_bytes().into())
.await
.unwrap();
assert_eq!(entry.unwrap(), Entry::String(expected_str.into()));
Expand All @@ -120,7 +139,7 @@ mod tests {

let workflow_str = include_str!("presets/generation.json");
let (entry, _) = executor
.prepare_input(&workflow_str.as_bytes().into())
.parse_input_bytes(&workflow_str.as_bytes().into())
.await
.unwrap();

Expand Down
4 changes: 2 additions & 2 deletions src/compute/workflows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod tests {
async fn test_ollama_generation() {
dotenvy::dotenv().unwrap();
let executor = Executor::new(Model::Llama3_1_8B);
let (output, _) = executor
let (output, _, _) = executor
.execute_raw(&Bytes::from_static(b"What is the result of 2 + 2?"), "")
.await
.unwrap();
Expand All @@ -29,7 +29,7 @@ mod tests {
async fn test_openai_generation() {
dotenvy::dotenv().unwrap();
let executor = Executor::new(Model::Llama3_1_8B);
let (output, _) = executor
let (output, _, _) = executor
.execute_raw(&Bytes::from_static(b"What is the result of 2 + 2?"), "")
.await
.unwrap();
Expand Down
6 changes: 3 additions & 3 deletions src/compute/workflows/postprocess/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ pub struct IdentityPostProcessor;
impl PostProcess for IdentityPostProcessor {
const PROTOCOL: &'static str = "";

fn post_process(&self, input: String) -> Result<(Bytes, Bytes)> {
Ok((input.into(), Default::default()))
fn post_process(&self, input: String) -> Result<(Bytes, Bytes, bool)> {
Ok((input.into(), Default::default(), true))
}
}

Expand All @@ -23,7 +23,7 @@ mod tests {
#[test]
fn test_identity_post_processor() {
let input = "hello".to_string();
let (output, metadata) = IdentityPostProcessor.post_process(input).unwrap();
let (output, metadata, _) = IdentityPostProcessor.post_process(input).unwrap();
assert_eq!(output, Bytes::from("hello"));
assert_eq!(metadata, Bytes::default());
}
Expand Down
7 changes: 6 additions & 1 deletion src/compute/workflows/postprocess/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,10 @@ pub trait PostProcess {
const PROTOCOL: &'static str;

/// A post-processing step that takes the raw output from the LLM and splits it into an output and metadata.
fn post_process(&self, input: String) -> eyre::Result<(Bytes, Bytes)>;
///
/// Returns:
/// - The output that is used within the contract.
/// - The metadata that is externally checked.
/// - A boolean indicating whether the output should be uploaded to storage if it is large enough.
fn post_process(&self, input: String) -> eyre::Result<(Bytes, Bytes, bool)>;
}
8 changes: 4 additions & 4 deletions src/compute/workflows/postprocess/swan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl SwanPurchasePostProcessor {
impl PostProcess for SwanPurchasePostProcessor {
const PROTOCOL: &'static str = "swan-buyer-purchase";

fn post_process(&self, input: String) -> Result<(Bytes, Bytes)> {
fn post_process(&self, input: String) -> Result<(Bytes, Bytes, bool)> {
// we will cast strings to Address here
use alloy::primitives::Address;

Expand Down Expand Up @@ -58,7 +58,7 @@ impl PostProcess for SwanPurchasePostProcessor {
// `abi.encode` the list of addresses to be decodable by contract
let addresses_encoded = addresses.abi_encode();

Ok((Bytes::from(addresses_encoded), Bytes::from(input)))
Ok((Bytes::from(addresses_encoded), Bytes::from(input), false))
}
}

Expand All @@ -85,7 +85,7 @@ some more blabla here

let post_processor = SwanPurchasePostProcessor::new("<buy_list>", "</buy_list>");

let (output, metadata) = post_processor.post_process(INPUT.to_string()).unwrap();
let (output, metadata, _) = post_processor.post_process(INPUT.to_string()).unwrap();
assert_eq!(
metadata,
Bytes::from(INPUT),
Expand Down Expand Up @@ -121,7 +121,7 @@ some more blabla here

let post_processor = SwanPurchasePostProcessor::new("<shop_list>", "</shop_list>");

let (output, _) = post_processor.post_process(INPUT.to_string()).unwrap();
let (output, _, _) = post_processor.post_process(INPUT.to_string()).unwrap();
println!("{}", output);

let addresses = <Vec<Address>>::abi_decode(&output, true).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/compute/workflows/presets/generation.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "LLM generation",
"description": "Directly generate text with input",
"config": {
"max_steps": 1,
"max_steps": 10,
"max_time": 50,
"tools": [""]
},
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ async fn main() -> Result<()> {

dkn_oracle::cli().await?;

log::info!("Bye!");
Ok(())
}

0 comments on commit 04ef3d1

Please sign in to comment.