feat(core): Add support for passing temperature as a parameter (#1229)
* Draft changes

* First compiling implementation

* Current changes

* Unfinished changes

* Temperature now works

* Only call srand() once to prevent accidental determinism

* Seed RNG with model parameters instead of current time

* Add seed parameter

* Use better rng

* Formatting, make weighted_random take a const array

* Pass seed instead of rng

* Fix weighted_random

* Make golden tests use a fixed seed
boxbeam authored Jan 24, 2024
1 parent e9f7e90 commit 4fd4c2e
Showing 7 changed files with 107 additions and 17 deletions.
4 changes: 3 additions & 1 deletion crates/llama-cpp-bindings/include/engine.h
@@ -1,6 +1,7 @@
#pragma once

#include "rust/cxx.h"
#include <cmath>
#include <memory>

namespace llama {
@@ -10,7 +11,8 @@ class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();

virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) = 0;
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length,
float temperature, uint64_t seed) = 0;
virtual void stop_request(uint32_t request_id) = 0;
virtual rust::Vec<StepOutput> step() = 0;
};
56 changes: 49 additions & 7 deletions crates/llama-cpp-bindings/src/engine.cc
@@ -1,5 +1,8 @@
#include <algorithm>
#include <functional>
#include <random>
#include <vector>
#include <cmath>
#include <deque>
#include <unordered_set>
#include <mutex>
@@ -17,13 +20,17 @@ namespace {
constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history.
struct Request {
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
Request(size_t request_id, std::vector<llama_token> input_token_ids, float temperature, uint64_t seed) :
id(request_id),
tokens(input_token_ids.begin(), input_token_ids.end()) {
tokens(input_token_ids.begin(), input_token_ids.end()),
temperature(temperature),
seed(seed) {
}

uint32_t id = -1;
llama_seq_id seq_id = -1;
float temperature = 0;
uint64_t seed = 0;

std::vector<llama_token> tokens;
size_t i_batch = -1;
@@ -77,6 +84,39 @@ std::string string_format(const std::string& format, Args ... args)
return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
}

float compute_softmax_inplace(float* nums, size_t len, float temperature) {
float sum = 0;
float max = *std::max_element(nums, nums + len);
for (size_t i = 0; i < len; i++) {
nums[i] -= max;
nums[i] = std::exp(nums[i] / temperature);
sum += nums[i];
}
for (size_t i = 0; i < len; i++) {
nums[i] /= sum;
}
return sum;
}

size_t weighted_random(const float* nums, size_t len, uint64_t seed) {
std::mt19937 rng(seed);
float sum = 0;
for (size_t i = 0; i < len; i++) {
sum += nums[i];
}

float random = std::uniform_real_distribution<float>(0, sum)(rng);
sum = 0;
size_t i;
for (i = 0; i < len; i++) {
sum += nums[i];
if (sum >= random) {
return i;
}
}
return i;
}
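
The two helpers above implement the new sampling path: logits are rescaled by the requested temperature and softmax-normalized in place, then a token index is drawn from the resulting distribution with an RNG seeded per request. For illustration only, a minimal Rust sketch of the same idea (not part of this commit; it assumes the rand crate and uses StdRng rather than the C++ mt19937):

use rand::{rngs::StdRng, Rng, SeedableRng};

// Temperature-scaled softmax followed by a seeded weighted draw, mirroring
// compute_softmax_inplace + weighted_random above.
fn sample_with_temperature(logits: &[f32], temperature: f32, seed: u64) -> usize {
    // Subtract the max logit for numerical stability before exponentiating.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits
        .iter()
        .map(|l| ((l - max) / temperature).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // Walk the cumulative weights until they cross a uniform sample in [0, total).
    let mut rng = StdRng::seed_from_u64(seed);
    let target: f32 = rng.gen_range(0.0..total);
    let mut cumulative = 0.0;
    for (i, w) in weights.iter().enumerate() {
        cumulative += w;
        if cumulative >= target {
            return i;
        }
    }
    weights.len() - 1
}

Lower temperatures concentrate probability mass on the top logits (approaching the old argmax behaviour), while higher temperatures flatten the distribution.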

template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;

@@ -110,13 +150,15 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
llama_batch_free(batch_);
}

virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length,
float temperature, uint64_t seed) override {

auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
if (tokens.size() > max_input_length) {
int start = tokens.size() - max_input_length;
tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
}
pending_requests_.push_back(Request(request_id, tokens));
pending_requests_.push_back(Request(request_id, tokens, temperature, seed));
}

void stop_request(uint32_t request_id) override {
@@ -213,9 +255,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}

int32_t i_batch = request.i_batch - i;
auto logits = llama_get_logits_ith(ctx, i_batch);
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));

float* logits = llama_get_logits_ith(ctx, i_batch);
compute_softmax_inplace(logits, n_vocab, request.temperature);
auto next_token = weighted_random(logits, n_vocab, request.seed);
request.n_past += request.tokens.size();

request.tokens.clear();
10 changes: 9 additions & 1 deletion crates/llama-cpp-bindings/src/lib.rs
@@ -35,6 +35,8 @@ mod ffi {
request_id: u32,
prompt: &str,
max_input_length: usize,
temperature: f32,
seed: u64,
);
fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<StepOutput>>;
@@ -101,7 +103,13 @@ impl TextGeneration for LlamaTextGeneration {

let mut rx = self
.service
.add_request(prompt, options.max_input_length, stop_condition)
.add_request(
prompt,
options.max_input_length,
options.sampling_temperature,
options.seed,
stop_condition,
)
.await;

let s = stream! {
19 changes: 15 additions & 4 deletions crates/llama-cpp-bindings/src/llama.rs
@@ -9,6 +9,8 @@ use crate::ffi;
struct LlamaInitRequest {
prompt: String,
max_input_length: usize,
temperature: f32,
seed: u64,

tx: Sender<String>,
stop_condition: StopCondition,
@@ -56,6 +58,8 @@ impl LlamaServiceImpl {
async fn background_job(&mut self) {
while let Some(LlamaInitRequest {
prompt,
temperature,
seed,
tx,
max_input_length,
stop_condition,
@@ -69,10 +73,13 @@
let request_id = self.alloc_request_id();
self.requests
.insert(request_id, LlamaRunningRequest { tx, stop_condition });
self.engine
.as_mut()
.unwrap()
.add_request(request_id, &prompt, max_input_length);
self.engine.as_mut().unwrap().add_request(
request_id,
&prompt,
max_input_length,
temperature,
seed,
);
}

let result = match self.engine.as_mut().unwrap().step() {
@@ -144,12 +151,16 @@ impl LlamaService {
&self,
prompt: &str,
max_input_length: usize,
temperature: f32,
seed: u64,
stop_condition: StopCondition,
) -> Receiver<String> {
let (tx, rx) = channel(8);
self.tx
.send(LlamaInitRequest {
prompt: prompt.to_owned(),
temperature,
seed,
tx,
max_input_length,
stop_condition,
5 changes: 4 additions & 1 deletion crates/tabby-inference/src/lib.rs
@@ -14,9 +14,12 @@ pub struct TextGenerationOptions {
#[builder(default = "256")]
pub max_decoding_length: usize,

#[builder(default = "1.0")]
#[builder(default = "0.1")]
pub sampling_temperature: f32,

#[builder(default = "0")]
pub seed: u64,

#[builder(default = "None")]
pub language: Option<&'static Language>,
}
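
With the two new fields declared on TextGenerationOptions, existing call sites that never set them keep compiling and pick up the declared defaults (temperature 0.1, seed 0) from the builder macro. A minimal sketch of that behaviour, assuming the generated TextGenerationOptionsBuilder:

// Omitting the new fields falls back to the defaults declared above.
let options = TextGenerationOptionsBuilder::default()
    .max_input_length(1536)
    .max_decoding_length(128)
    .build()
    .unwrap();
assert_eq!(options.sampling_temperature, 0.1);
assert_eq!(options.seed, 0);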
26 changes: 23 additions & 3 deletions crates/tabby/src/services/completion.rs
@@ -47,6 +47,12 @@ pub struct CompletionRequest {
user: Option<String>,

debug_options: Option<DebugOptions>,

/// The temperature parameter for the model, used to tune variance and "creativity" of the model output
temperature: Option<f32>,

/// The seed used for randomly selecting tokens
seed: Option<u64>,
}
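
Both new fields are optional in the request, so clients can pin them per call. A hypothetical request body exercising them (field names taken from the struct above; the segments shape follows the golden tests further down):

let body = serde_json::json!({
    "language": "python",
    "temperature": 0.2,   // a bit more varied than the 0.1 default
    "seed": 12345,        // fixed seed for reproducible completions
    "segments": {
        "prefix": "def fib(n):\n    "
    }
});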

impl CompletionRequest {
@@ -200,11 +206,16 @@ impl CompletionService {
}
}

fn text_generation_options(language: &str) -> TextGenerationOptions {
fn text_generation_options(
language: &str,
temperature: f32,
seed: u64,
) -> TextGenerationOptions {
TextGenerationOptionsBuilder::default()
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.sampling_temperature(temperature)
.seed(seed)
.language(Some(get_language(language)))
.build()
.unwrap()
@@ -216,7 +227,16 @@
) -> Result<CompletionResponse, CompletionError> {
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let language = request.language_or_unknown();
let options = Self::text_generation_options(language.as_str());
let options = Self::text_generation_options(
language.as_str(),
request.temperature.unwrap_or(0.1),
request.seed.unwrap_or_else(|| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64
}),
);

let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() {
(prompt, None, vec![])
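
One design note on the defaults chosen above: temperature falls back to 0.1, and the seed falls back to the current Unix time in milliseconds, so identical requests without an explicit seed will generally sample differently, while the golden tests below pin seed 0 for reproducibility. A minimal sketch of that fallback pulled out into a helper (hypothetical name; the handler above inlines the same expression via unwrap_or_else):

use std::time::{SystemTime, UNIX_EPOCH};

// Hypothetical helper: millisecond-resolution wall-clock time as a u64 seed.
fn default_seed() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before the Unix epoch")
        .as_millis() as u64
}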
4 changes: 4 additions & 0 deletions crates/tabby/tests/goldentests.rs
@@ -103,6 +103,7 @@ async fn run_golden_tests() {

assert_golden!(json!({
"language": "python",
"seed": 0,
"segments": {
"prefix": "def fib(n):\n ",
"suffix": "\n return fib(n - 1) + fib(n - 2)"
@@ -111,6 +112,7 @@

assert_golden!(json!({
"language": "python",
"seed": 0,
"segments": {
"prefix": "import datetime\n\ndef parse_expenses(expenses_string):\n \"\"\"Parse the list of expenses and return the list of triples (date, value, currency).\n Ignore lines starting with #.\n Parse the date using datetime.\n Example expenses_string:\n 2016-01-02 -34.01 USD\n 2016-01-03 2.59 DKK\n 2016-01-03 -2.72 EUR\n \"\"\"\n for line in expenses_string.split('\\n'):\n "
}
@@ -124,13 +126,15 @@

assert_golden!(json!({
"language": "python",
"seed": 0,
"segments": {
"prefix": "def is_prime(n):\n",
}
}));

assert_golden!(json!({
"language": "python",
"seed": 0,
"segments": {
"prefix": "def char_frequencies(str):\n freqs = {}\n ",
}
