From 7c9b7cb251858b84bbcb7d4abc476ffa2c42b0d9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 27 Jun 2024 15:29:11 +0200 Subject: [PATCH] feat(candle): add FlashMistral (#308) --- Cargo.lock | 1 + backends/candle/src/flash_attn.rs | 18 +- backends/candle/src/layers/linear.rs | 1 + backends/candle/src/layers/mod.rs | 4 + backends/candle/src/layers/rms_norm.rs | 96 + backends/candle/src/lib.rs | 97 +- backends/candle/src/models/bert.rs | 6 + backends/candle/src/models/distilbert.rs | 9 +- backends/candle/src/models/flash_bert.rs | 60 +- .../candle/src/models/flash_distilbert.rs | 58 +- backends/candle/src/models/flash_jina.rs | 30 +- backends/candle/src/models/flash_jina_code.rs | 30 +- backends/candle/src/models/flash_mistral.rs | 442 + backends/candle/src/models/flash_nomic.rs | 30 +- backends/candle/src/models/jina.rs | 5 + backends/candle/src/models/jina_code.rs | 9 +- backends/candle/src/models/mistral.rs | 19 + backends/candle/src/models/mod.rs | 8 + backends/candle/src/models/nomic.rs | 5 + backends/candle/tests/common.rs | 97 +- .../test_flash_mistral__mistral_batch.snap | 12293 ++++++++++++++++ .../test_flash_mistral__mistral_single.snap | 4101 ++++++ backends/candle/tests/test_bert.rs | 24 +- backends/candle/tests/test_flash_bert.rs | 24 +- backends/candle/tests/test_flash_jina.rs | 10 +- backends/candle/tests/test_flash_jina_code.rs | 10 +- backends/candle/tests/test_flash_mistral.rs | 53 + backends/candle/tests/test_flash_nomic.rs | 10 +- backends/candle/tests/test_jina.rs | 10 +- backends/candle/tests/test_jina_code.rs | 10 +- backends/candle/tests/test_nomic.rs | 10 +- backends/core/src/lib.rs | 3 + backends/src/lib.rs | 57 + core/Cargo.toml | 1 + core/src/download.rs | 56 +- core/src/infer.rs | 14 +- load_tests/load.js | 2 +- router/src/grpc/server.rs | 4 + router/src/lib.rs | 37 +- 39 files changed, 17578 insertions(+), 176 deletions(-) create mode 100644 backends/candle/src/layers/rms_norm.rs create mode 100644 backends/candle/src/models/flash_mistral.rs create mode 100644 backends/candle/src/models/mistral.rs create mode 100644 backends/candle/tests/snapshots/test_flash_mistral__mistral_batch.snap create mode 100644 backends/candle/tests/snapshots/test_flash_mistral__mistral_single.snap create mode 100644 backends/candle/tests/test_flash_mistral.rs diff --git a/Cargo.lock b/Cargo.lock index cf8e1c0f..134036f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3981,6 +3981,7 @@ dependencies = [ "async-channel", "hf-hub", "metrics", + "serde_json", "text-embeddings-backend", "thiserror", "tokenizers", diff --git a/backends/candle/src/flash_attn.rs b/backends/candle/src/flash_attn.rs index 3afc6517..f2016928 100644 --- a/backends/candle/src/flash_attn.rs +++ b/backends/candle/src/flash_attn.rs @@ -31,6 +31,7 @@ pub(crate) fn flash_attn_varlen( max_seqlen_k: usize, softmax_scale: f32, causal: bool, + window_size_left: Option, ) -> Result { let runtime_compute_cap = get_runtime_compute_cap(); @@ -38,6 +39,9 @@ pub(crate) fn flash_attn_varlen( if alibi_slopes.is_some() { candle::bail!("Flash attention v1 does not support alibi"); } + if window_size_left.is_some() { + candle::bail!("Flash attention v1 does not support attention windowing"); + } #[cfg(feature = "flash-attn-v1")] { @@ -59,10 +63,12 @@ pub(crate) fn flash_attn_varlen( } else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 { #[cfg(feature = "flash-attn")] { - use candle_flash_attn::{flash_attn_varlen, flash_attn_varlen_alibi}; + use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed}; + + let window_size_right = if causal { Some(0) } else { None }; let attention = if let Some(alibi_slopes) = alibi_slopes { - flash_attn_varlen_alibi( + flash_attn_varlen_alibi_windowed( q, k, v, @@ -72,10 +78,11 @@ pub(crate) fn flash_attn_varlen( max_seqlen_q, max_seqlen_k, softmax_scale, - causal, + window_size_left, + window_size_right, ) } else { - flash_attn_varlen( + flash_attn_varlen_windowed( q, k, v, @@ -84,7 +91,8 @@ pub(crate) fn flash_attn_varlen( max_seqlen_q, max_seqlen_k, softmax_scale, - causal, + window_size_left, + window_size_right, ) }; diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs index 3fdd025b..fc8af1dd 100644 --- a/backends/candle/src/layers/linear.rs +++ b/backends/candle/src/layers/linear.rs @@ -7,6 +7,7 @@ use serde::Deserialize; pub enum HiddenAct { Gelu, Relu, + #[serde(alias = "silu")] Swiglu, } diff --git a/backends/candle/src/layers/mod.rs b/backends/candle/src/layers/mod.rs index 8e108fc2..81f63310 100644 --- a/backends/candle/src/layers/mod.rs +++ b/backends/candle/src/layers/mod.rs @@ -2,7 +2,11 @@ mod cublaslt; mod layer_norm; mod linear; +#[allow(dead_code, unused)] +mod rms_norm; pub use cublaslt::get_cublas_lt_wrapper; pub use layer_norm::LayerNorm; pub use linear::{HiddenAct, Linear}; +#[allow(unused_imports)] +pub use rms_norm::RMSNorm; diff --git a/backends/candle/src/layers/rms_norm.rs b/backends/candle/src/layers/rms_norm.rs new file mode 100644 index 00000000..e7dab642 --- /dev/null +++ b/backends/candle/src/layers/rms_norm.rs @@ -0,0 +1,96 @@ +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct RMSNorm { + weight: Tensor, + epsilon: f32, + span: tracing::Span, +} + +impl RMSNorm { + pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result { + Ok(Self { + weight: vb + .get(hidden_size, "weight") + .or_else(|_| vb.get(hidden_size, "gamma"))?, + epsilon, + span: tracing::span!(tracing::Level::TRACE, "rms-norm"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + residual: Option<&Tensor>, + ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); + + match hidden_states.device() { + Device::Cpu | Device::Metal(_) => { + let mut hidden_states = hidden_states.clone(); + let residual_add = if let Some(residual) = residual { + let residual_add = hidden_states.add(residual)?; + hidden_states = residual_add.clone(); + residual_add + } else { + hidden_states.clone() + }; + + let hidden_states_dtype = hidden_states.dtype(); + let internal_dtype = match hidden_states_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = hidden_states.dim(D::Minus1)?; + let hidden_states = hidden_states.to_dtype(internal_dtype)?; + let norm_hidden_states = + (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let hidden_states_normed = hidden_states + .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?; + Ok(( + hidden_states_normed + .to_dtype(hidden_states_dtype)? + .broadcast_mul(&self.weight)?, + residual_add, + )) + } + Device::Cuda(_) => { + #[cfg(feature = "cuda")] + { + use candle_layer_norm::{fused_add_rms_norm, rms_norm}; + + let original_shape = hidden_states.shape(); + let hidden_states = hidden_states.flatten_to(D::Minus2)?; + + if let Some(residual) = residual { + let residual = residual.flatten_to(D::Minus2)?; + + let (result, residual_add) = fused_add_rms_norm( + &hidden_states, + &residual, + &self.weight, + None, + self.epsilon, + )?; + Ok(( + result.reshape(original_shape)?, + residual_add.reshape(original_shape)?, + )) + } else { + let residual_add = hidden_states.clone(); + + let result = rms_norm(&hidden_states, &self.weight, None, self.epsilon)?; + + Ok(( + result.reshape(original_shape)?, + residual_add.reshape(original_shape)?, + )) + } + } + #[cfg(not(feature = "cuda"))] + candle::bail!("`cuda` feature is not enabled") + } + } + } +} diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index a7a5dcf0..b9d750dc 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -12,12 +12,12 @@ use crate::compute_cap::{ }; use crate::models::{ BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel, - Model, NomicBertModel, NomicConfig, + MistralConfig, Model, NomicBertModel, NomicConfig, }; #[cfg(feature = "cuda")] use crate::models::{ FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, - FlashNomicBertModel, + FlashMistralModel, FlashNomicBertModel, }; use anyhow::Context; use candle::{DType, Device}; @@ -56,6 +56,7 @@ enum Config { DistilBert(DistilBertConfig), #[serde(rename(deserialize = "nomic_bert"))] NomicBert(NomicConfig), + Mistral(MistralConfig), } pub struct CandleBackend { @@ -69,6 +70,54 @@ impl CandleBackend { dtype: String, model_type: ModelType, ) -> Result { + // Default files + let default_safetensors = model_path.join("model.safetensors"); + let default_pytorch = model_path.join("pytorch_model.bin"); + + // Single Files + let model_files = if default_safetensors.exists() { + vec![default_safetensors] + } else if default_pytorch.exists() { + vec![default_pytorch] + } + // Sharded weights + else { + // Get index file + let index_file = model_path.join("model.safetensors.index.json"); + + // Parse file + let index_file_string: String = std::fs::read_to_string(&index_file) + .map_err(|err| BackendError::Start(err.to_string()))?; + let json: serde_json::Value = serde_json::from_str(&index_file_string) + .map_err(|err| BackendError::Start(err.to_string()))?; + + let weight_map = match json.get("weight_map") { + None => { + return Err(BackendError::Start(format!( + "no weight map in {index_file:?}" + ))); + } + Some(serde_json::Value::Object(map)) => map, + Some(_) => { + return Err(BackendError::Start(format!( + "weight map in {index_file:?} is not a map" + ))); + } + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + + // Collect paths + safetensors_files + .iter() + .map(|n| model_path.join(n)) + .collect() + }; + // Load config let config: String = std::fs::read_to_string(model_path.join("config.json")) .context("Unable to read config file") @@ -115,17 +164,10 @@ impl CandleBackend { ))) }?; - let safetensors_path = model_path.join("model.safetensors"); - let vb = if safetensors_path.exists() { - unsafe { - VarBuilder::from_mmaped_safetensors( - &[model_path.join("model.safetensors")], - dtype, - &device, - ) - } + let vb = if model_files.len() == 1 && model_files[0].extension().unwrap() == "bin" { + VarBuilder::from_pth(&model_files[0], dtype, &device) } else { - VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device) + unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device) } } .s()?; @@ -136,7 +178,7 @@ impl CandleBackend { )), (Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config { BertConfigWrapper::JinaBert(config) => { - tracing::info!("Starting JinaBertModel model on {:?}", device); + tracing::info!("Starting JinaBert model on {:?}", device); Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) } BertConfigWrapper::JinaCodeBert(config) => { @@ -160,15 +202,19 @@ impl CandleBackend { )) } (Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => { - tracing::info!("Starting DistilBertModel model on {:?}", device); + tracing::info!("Starting DistilBert model on {:?}", device); Ok(Box::new( DistilBertModel::load(vb, &config, model_type).s()?, )) } (Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => { - tracing::info!("Starting NomicBertModel model on {:?}", device); + tracing::info!("Starting NomicBert model on {:?}", device); Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?)) } + (Config::Mistral(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start( + "Mistral is only supported on Cuda devices in fp16 with flash attention enabled" + .to_string(), + )), #[cfg(feature = "cuda")] (Config::Bert(config), Device::Cuda(_)) => { if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) @@ -198,7 +244,7 @@ impl CandleBackend { } else { match config { BertConfigWrapper::JinaBert(config) => { - tracing::info!("Starting JinaBertModel model on {:?}", device); + tracing::info!("Starting JinaBert model on {:?}", device); Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) } BertConfigWrapper::JinaCodeBert(config) => { @@ -245,7 +291,7 @@ impl CandleBackend { .to_lowercase() == "true" { - tracing::info!("Starting FlashDistilBertModel model on {:?}", device); + tracing::info!("Starting FlashDistilBert model on {:?}", device); Ok(Box::new( FlashDistilBertModel::load(vb, &config, model_type).s()?, )) @@ -265,15 +311,28 @@ impl CandleBackend { .to_lowercase() == "true" { - tracing::info!("Starting FlashNomicBertModel model on {:?}", device); + tracing::info!("Starting FlashNomicBert model on {:?}", device); Ok(Box::new( FlashNomicBertModel::load(vb, &config, model_type).s()?, )) } else { - tracing::info!("Starting NomicBertModel model on {:?}", device); + tracing::info!("Starting NomicBert model on {:?}", device); Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?)) } } + #[cfg(feature = "cuda")] + (Config::Mistral(config), Device::Cuda(_)) => { + if dtype != DType::F16 + || !cfg!(feature = "flash-attn") + || get_runtime_compute_cap().unwrap() < 80 + { + return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string())); + } + tracing::info!("Starting FlashMistral model on {:?}", device); + Ok(Box::new( + FlashMistralModel::load(vb, &config, model_type).s()?, + )) + } }; Ok(Self { diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index d12d90d2..5795fa27 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -638,6 +638,10 @@ impl BertModel { (pool, Some(classifier), None) } ModelType::Embedding(pool) => { + if pool == Pool::LastToken { + candle::bail!("`last_token` is not supported for Bert"); + } + let splade = if pool == Pool::Splade { Some(BertSpladeHead::load_roberta(vb.clone(), config)?) } else { @@ -832,6 +836,8 @@ impl BertModel { let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.i((.., 0))?, + // Last token pooling is not supported for this model + Pool::LastToken => unreachable!(), // Mean pooling Pool::Mean => { if let Some(ref attention_mask) = attention_mask { diff --git a/backends/candle/src/models/distilbert.rs b/backends/candle/src/models/distilbert.rs index e7145309..2cf62081 100644 --- a/backends/candle/src/models/distilbert.rs +++ b/backends/candle/src/models/distilbert.rs @@ -389,7 +389,12 @@ impl DistilBertModel { ModelType::Classifier => { candle::bail!("`classifier` model type is not supported for DistilBert") } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => { + if pool == Pool::LastToken { + candle::bail!("`last_token` is not supported for DistilBert"); + } + pool + } }; let (embeddings, encoder) = match ( @@ -564,6 +569,8 @@ impl DistilBertModel { let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.i((.., 0))?, + // Last token pooling is not supported for this model + Pool::LastToken => unreachable!(), // Mean pooling Pool::Mean => { if let Some(ref attention_mask) = attention_mask { diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index d248c91f..8f20f027 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -5,7 +5,7 @@ use crate::models::bert::{ PositionEmbeddingType, RobertaClassificationHead, }; use crate::models::Model; -use candle::{DType, Device, Result, Tensor}; +use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::VarBuilder; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -103,6 +103,7 @@ impl BertAttention { max_s, self.softmax_scale, false, + None, )?; let attention = attention.flatten_from(candle::D::Minus2)?; @@ -393,27 +394,44 @@ impl FlashBertModel { let pooled_embeddings = if has_pooling_requests { match self.pool { - // CLS pooling - Pool::Cls => { - // Get the indices of the cls tokens from cu_seqlens - let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; - - // If raw_indices is empty, we don't need to do anything with - // the pooled_indices - if has_raw_requests { - // We need the pooled indices to select the correct cls indices - let pooled_indices = Tensor::from_vec( - batch.pooled_indices.clone(), - batch.pooled_indices.len(), - &self.device, - )?; - - // Only select indices that requires pooling - cls_indices = cls_indices.index_select(&pooled_indices, 0)? + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { + if batch_size > 1 { + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?, + _ => unreachable!(), + }; + + // If raw_indices is empty, we don't need to do anything with + // the pooled_indices + if has_raw_requests { + // We need the pooled indices to select the correct cls indices + let pooled_indices = Tensor::from_vec( + batch.pooled_indices.clone(), + batch.pooled_indices.len(), + &self.device, + )?; + + // Only select indices that requires pooling + indices = indices.index_select(&pooled_indices, 0)? + } + + // Select tokens + Some(outputs.index_select(&indices, 0)?) + } else { + Some( + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)?, + ) } - - // Select cls tokens - Some(outputs.index_select(&cls_indices, 0)?) } // Mean pooling Pool::Mean => { diff --git a/backends/candle/src/models/flash_distilbert.rs b/backends/candle/src/models/flash_distilbert.rs index 3d30c3a8..f8dd294b 100644 --- a/backends/candle/src/models/flash_distilbert.rs +++ b/backends/candle/src/models/flash_distilbert.rs @@ -4,7 +4,7 @@ use crate::models::distilbert::{ DistilBertConfig, DistilBertEmbeddings, DistilBertMLP, DistilBertSpladeHead, }; use crate::models::Model; -use candle::{DType, Device, Result, Tensor}; +use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::VarBuilder; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -84,6 +84,7 @@ impl DistilBertAttention { max_s, self.softmax_scale, false, + None, )?; let attention = attention.flatten_from(candle::D::Minus2)?; @@ -259,27 +260,42 @@ impl FlashDistilBertModel { let pooled_embeddings = if has_pooling_requests { let pooled_embeddings = match self.pool { - // CLS pooling - Pool::Cls => { - // Get the indices of the cls tokens from cu_seqlens - let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; - - // If raw_indices is empty, we don't need to do anything with - // the pooled_indices - if has_raw_requests { - // We need the pooled indices to select the correct cls indices - let pooled_indices = Tensor::from_vec( - batch.pooled_indices.clone(), - batch.pooled_indices.len(), - &self.device, - )?; - - // Only select indices that requires pooling - cls_indices = cls_indices.index_select(&pooled_indices, 0)? + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { + if batch_size > 1 { + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?, + _ => unreachable!(), + }; + + // If raw_indices is empty, we don't need to do anything with + // the pooled_indices + if has_raw_requests { + // We need the pooled indices to select the correct cls indices + let pooled_indices = Tensor::from_vec( + batch.pooled_indices.clone(), + batch.pooled_indices.len(), + &self.device, + )?; + + // Only select indices that requires pooling + indices = indices.index_select(&pooled_indices, 0)? + } + + // Select tokens + outputs.index_select(&indices, 0)? + } else { + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)? } - - // Select cls tokens - outputs.index_select(&cls_indices, 0)? } // Mean pooling Pool::Mean => { diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index b2d47e51..c8efee18 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -105,6 +105,7 @@ impl JinaAttention { max_s, self.softmax_scale, false, + None, )?; let attention = attention.flatten_from(candle::D::Minus2)?; @@ -319,11 +320,15 @@ impl FlashJinaBertModel { let pooled_embeddings = if has_pooling_requests { match self.pool { - // CLS pooling - Pool::Cls => { + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { if batch_size > 1 { - // Get the indices of the cls tokens from cu_seqlens - let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?, + _ => unreachable!(), + }; // If raw_indices is empty, we don't need to do anything with // the pooled_indices @@ -336,13 +341,22 @@ impl FlashJinaBertModel { )?; // Only select indices that requires pooling - cls_indices = cls_indices.index_select(&pooled_indices, 0)? + indices = indices.index_select(&pooled_indices, 0)? } - // Select cls tokens - Some(outputs.index_select(&cls_indices, 0)?) + // Select tokens + Some(outputs.index_select(&indices, 0)?) } else { - Some(outputs.i(0)?) + Some( + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)?, + ) } } // Mean pooling diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs index 5df80be2..06ade2d5 100644 --- a/backends/candle/src/models/flash_jina_code.rs +++ b/backends/candle/src/models/flash_jina_code.rs @@ -141,6 +141,7 @@ impl JinaCodeAttention { max_s, self.softmax_scale, false, + None, )?; let attention = attention.flatten_from(candle::D::Minus2)?; @@ -372,11 +373,15 @@ impl FlashJinaCodeBertModel { let pooled_embeddings = if has_pooling_requests { match self.pool { - // CLS pooling - Pool::Cls => { + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { if batch_size > 1 { - // Get the indices of the cls tokens from cu_seqlens - let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?, + _ => unreachable!(), + }; // If raw_indices is empty, we don't need to do anything with // the pooled_indices @@ -389,13 +394,22 @@ impl FlashJinaCodeBertModel { )?; // Only select indices that requires pooling - cls_indices = cls_indices.index_select(&pooled_indices, 0)? + indices = indices.index_select(&pooled_indices, 0)? } - // Select cls tokens - Some(outputs.index_select(&cls_indices, 0)?) + // Select tokens + Some(outputs.index_select(&indices, 0)?) } else { - Some(outputs.i(0)?) + Some( + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)?, + ) } } // Mean pooling diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs new file mode 100644 index 00000000..53e5d3c4 --- /dev/null +++ b/backends/candle/src/models/flash_mistral.rs @@ -0,0 +1,442 @@ +use crate::flash_attn::flash_attn_varlen; +use crate::layers::{HiddenAct, Linear, RMSNorm}; +use crate::models::{MistralConfig, Model}; +use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle_nn::{Embedding, Module, VarBuilder}; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; + +struct MistralAttention { + qkv_linear: Linear, + o_proj: Linear, + + window_size_left: Option, + + num_attention_heads: usize, + num_key_value_heads: usize, + attention_head_size: usize, + + softmax_scale: f32, + + span: tracing::Span, +} + +impl MistralAttention { + pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result { + let window_size_left = config.sliding_window; + let num_attention_heads = config.num_attention_heads; + let attention_head_size = config.hidden_size / config.num_attention_heads; + let num_key_value_heads = config.num_key_value_heads; + let hidden_size = config.hidden_size; + + let query_weight = vb.pp("q_proj").get((hidden_size, hidden_size), "weight")?; + + let key_weight = vb.pp("k_proj").get( + (num_key_value_heads * attention_head_size, hidden_size), + "weight", + )?; + + let value_weight = vb.pp("v_proj").get( + (num_key_value_heads * attention_head_size, hidden_size), + "weight", + )?; + + let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?; + let qkv_linear = Linear::new(qkv_weight, None, None); + + let o_proj_weight = vb.pp("o_proj").get((hidden_size, hidden_size), "weight")?; + + let o_proj = Linear::new(o_proj_weight, None, None); + + let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32; + + Ok(Self { + qkv_linear, + o_proj, + window_size_left, + num_attention_heads, + num_key_value_heads, + attention_head_size, + softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + cu_seqlens: &Tensor, + cos: &Tensor, + sin: &Tensor, + max_s: usize, + ) -> Result { + let _enter = self.span.enter(); + + let qkv = self.qkv_linear.forward(hidden_states)?; + + // Reshape to [tokens, heads, head_size] + let mut new_qkv_shape = qkv.dims().to_vec(); + new_qkv_shape.pop(); + new_qkv_shape.push(self.num_attention_heads + 2 * self.num_key_value_heads); + new_qkv_shape.push(self.attention_head_size); + + let qkv = qkv.reshape(new_qkv_shape)?; + + // Split qkv tensor + let q = qkv.narrow(1, 0, self.num_attention_heads)?; + let k = qkv.narrow(1, self.num_attention_heads, self.num_key_value_heads)?; + let v = qkv.narrow( + 1, + self.num_attention_heads + self.num_key_value_heads, + self.num_key_value_heads, + )?; + + candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?; + + let attention = flash_attn_varlen( + &q, + &k, + &v, + None, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + self.softmax_scale, + true, + self.window_size_left, + )?; + let attention = attention.flatten_from(candle::D::Minus2)?; + + self.o_proj.forward(&attention) + } +} + +struct MistralMLP { + gate_up_proj: Linear, + down_proj: Linear, + + act: HiddenAct, + intermediate_size: usize, + + span: tracing::Span, +} + +impl MistralMLP { + pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result { + let intermediate_size = config.intermediate_size; + + let gate_proj_weight = vb + .pp("gate_proj") + .get((intermediate_size, config.hidden_size), "weight")?; + + let up_proj_weight = vb + .pp("up_proj") + .get((intermediate_size, config.hidden_size), "weight")?; + + let gate_up_proj_weight = Tensor::cat(&[&gate_proj_weight, &up_proj_weight], 0)?; + let gate_up_proj = Linear::new(gate_up_proj_weight, None, None); + + let down_proj_weight = vb + .pp("down_proj") + .get((config.hidden_size, intermediate_size), "weight")?; + let down_proj = Linear::new(down_proj_weight, None, None); + + Ok(Self { + gate_up_proj, + down_proj, + intermediate_size, + act: config.hidden_act.clone(), + span: tracing::span!(tracing::Level::TRACE, "mlp"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let gate_up_states = self.gate_up_proj.forward(hidden_states)?; + let gate_states = gate_up_states.narrow(1, 0, self.intermediate_size)?; + let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?; + + let gate_states = match self.act { + HiddenAct::Gelu => gate_states.gelu(), + HiddenAct::Relu => gate_states.relu(), + HiddenAct::Swiglu => gate_states.silu(), + }?; + let r = self.down_proj.forward(&(gate_states * up_states)?); + r + } +} + +struct MistralLayer { + attention: MistralAttention, + mlp: MistralMLP, + input_layer_norm: RMSNorm, + post_attention_layer_norm: RMSNorm, + + span: tracing::Span, +} + +impl MistralLayer { + pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result { + let attention = MistralAttention::load(vb.pp("self_attn"), config)?; + let mlp = MistralMLP::load(vb.pp("mlp"), config)?; + + let input_layer_norm = RMSNorm::load( + vb.pp("input_layernorm"), + config.hidden_size, + config.rms_norm_eps, + )?; + let post_attention_layer_norm = RMSNorm::load( + vb.pp("post_attention_layernorm"), + config.hidden_size, + config.rms_norm_eps, + )?; + + Ok(Self { + attention, + mlp, + input_layer_norm, + post_attention_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + residual: Option<&Tensor>, + cu_seqlens: &Tensor, + cos: &Tensor, + sin: &Tensor, + max_s: usize, + ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); + + let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?; + let attn_output = + self.attention + .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?; + let (normed_attn_res_output, attn_res) = self + .post_attention_layer_norm + .forward(&attn_output, Some(&res))?; + let mlp_output = self.mlp.forward(&normed_attn_res_output)?; + + Ok((mlp_output, attn_res)) + } +} + +pub struct FlashMistralModel { + embeddings: Embedding, + layers: Vec, + norm: RMSNorm, + cos_cache: Tensor, + sin_cache: Tensor, + pool: Pool, + pub device: Device, + + span: tracing::Span, +} + +impl FlashMistralModel { + pub fn load(vb: VarBuilder, config: &MistralConfig, model_type: ModelType) -> Result { + match vb.device() { + Device::Cuda(_) => {} + _ => candle::bail!("FlashMistral requires Cuda"), + } + + if vb.dtype() != DType::F16 { + candle::bail!("FlashMistral requires DType::F16") + } + + let pool = match model_type { + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for Mistral") + } + ModelType::Embedding(pool) => pool, + }; + + let embeddings = Embedding::new( + vb.pp("embed_tokens") + .get((config.vocab_size, config.hidden_size), "weight")?, + config.hidden_size, + ); + + let layers = (0..config.num_hidden_layers) + .map(|index| MistralLayer::load(vb.pp(format!("layers.{index}")), config)) + .collect::>>()?; + + let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?; + + let inv_freqs = candle_rotary::inv_freqs( + layers[0].attention.attention_head_size, + config.rope_theta, + vb.device(), + )?; + let (cos_cache, sin_cache) = + candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?; + + Ok(Self { + embeddings, + layers, + norm, + cos_cache, + sin_cache, + pool, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); + + let batch_size = batch.cumulative_seq_lengths.len() - 1; + let shape = batch.input_ids.len(); + + // Create Cuda tensors + let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + let cu_seqlens = Tensor::from_vec( + batch.cumulative_seq_lengths.clone(), + batch_size + 1, + &self.device, + )?; + + let mut hidden_states = self.embeddings.forward(&input_ids)?; + + let cos = self.cos_cache.index_select(&position_ids, 0)?; + let sin = self.sin_cache.index_select(&position_ids, 0)?; + + let mut residual = None; + for layer in &self.layers { + let (h, r) = layer.forward( + &hidden_states, + residual.as_ref(), + &cu_seqlens, + &cos, + &sin, + batch.max_length as usize, + )?; + hidden_states = h; + residual = Some(r); + } + + let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + let has_raw_requests = !batch.raw_indices.is_empty(); + + let pooled_embeddings = if has_pooling_requests { + match self.pool { + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { + if batch_size > 1 { + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?, + _ => unreachable!(), + }; + + // If raw_indices is empty, we don't need to do anything with + // the pooled_indices + if has_raw_requests { + // We need the pooled indices to select the correct cls indices + let pooled_indices = Tensor::from_vec( + batch.pooled_indices.clone(), + batch.pooled_indices.len(), + &self.device, + )?; + + // Only select indices that requires pooling + indices = indices.index_select(&pooled_indices, 0)? + } + + // Select tokens + Some(outputs.index_select(&indices, 0)?) + } else { + Some( + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)?, + ) + } + } + // Mean pooling + Pool::Mean => { + if batch_size > 1 { + // for each request that requires pooling + let results: Result> = batch + .pooled_indices + .into_iter() + .map(|i| { + let i = i as usize; + let start = batch.cumulative_seq_lengths[i]; + let len = batch.cumulative_seq_lengths[i + 1] - start; + + // Mean + let embeddings = outputs.narrow(0, start as usize, len as usize)?; + embeddings.sum_keepdim(0)? / (len as f64) + }) + .collect(); + + // Concatenate all results + Some(Tensor::cat(&results?, 0)?) + } else { + Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) + } + } + Pool::Splade => { + unreachable!(); + } + } + } else { + None + }; + + let raw_embeddings = if has_raw_requests { + if batch_size > 1 && has_pooling_requests { + // Create indexing vector for the embeddings + let mut final_indices: Vec = Vec::with_capacity(shape); + for i in batch.raw_indices.into_iter() { + let i = i as usize; + // Get start/end token index of this specific member of the batch + let start = batch.cumulative_seq_lengths[i]; + let end = batch.cumulative_seq_lengths[i + 1]; + + for j in start..end { + // Add indices for the tokens of this specific member of the batch + final_indices.push(j); + } + } + + let final_indices_length = final_indices.len(); + let final_indices = + Tensor::from_vec(final_indices, final_indices_length, &self.device)?; + + // Select the tokens with final indices + Some(outputs.index_select(&final_indices, 0)?) + } else { + Some(outputs) + } + } else { + None + }; + + Ok((pooled_embeddings, raw_embeddings)) + } +} + +impl Model for FlashMistralModel { + fn is_padded(&self) -> bool { + false + } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { + self.forward(batch) + } +} diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 12eff0f1..8ad1ab89 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -81,6 +81,7 @@ impl NomicAttention { max_s, self.softmax_scale, false, + None, )?; let attention = attention.flatten_from(D::Minus2)?; @@ -304,11 +305,15 @@ impl FlashNomicBertModel { let pooled_embeddings = if has_pooling_requests { match self.pool { - // CLS pooling - Pool::Cls => { + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { if batch_size > 1 { - // Get the indices of the cls tokens from cu_seqlens - let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?, + _ => unreachable!(), + }; // If raw_indices is empty, we don't need to do anything with // the pooled_indices @@ -321,13 +326,22 @@ impl FlashNomicBertModel { )?; // Only select indices that requires pooling - cls_indices = cls_indices.index_select(&pooled_indices, 0)? + indices = indices.index_select(&pooled_indices, 0)? } - // Select cls tokens - Some(outputs.index_select(&cls_indices, 0)?) + // Select tokens + Some(outputs.index_select(&indices, 0)?) } else { - Some(outputs.i(0)?) + Some( + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)?, + ) } } // Mean pooling diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index 768e1a6f..6884fcae 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -373,6 +373,9 @@ impl JinaBertModel { if pool == Pool::Splade { candle::bail!("`splade` is not supported for Jina") } + if pool == Pool::LastToken { + candle::bail!("`last_token` is not supported for Jina"); + } pool } }; @@ -594,6 +597,8 @@ impl JinaBertModel { let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.i((.., 0))?, + // Last token pooling is not supported for this model + Pool::LastToken => unreachable!(), // Mean pooling Pool::Mean => { if let Some(ref attention_mask) = attention_mask { diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index 348f5892..fd004bc6 100644 --- a/backends/candle/src/models/jina_code.rs +++ b/backends/candle/src/models/jina_code.rs @@ -356,11 +356,14 @@ impl JinaCodeBertModel { let pool = match model_type { ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Jina") + candle::bail!("`classifier` model type is not supported for JinaCode") } ModelType::Embedding(pool) => { if pool == Pool::Splade { - candle::bail!("`splade` is not supported for Jina") + candle::bail!("`splade` is not supported for JinaCode") + } + if pool == Pool::LastToken { + candle::bail!("`last_token` is not supported for JinaCode"); } pool } @@ -583,6 +586,8 @@ impl JinaCodeBertModel { let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.i((.., 0))?, + // Last token pooling is not supported for this model + Pool::LastToken => unreachable!(), // Mean pooling Pool::Mean => { if let Some(ref attention_mask) = attention_mask { diff --git a/backends/candle/src/models/mistral.rs b/backends/candle/src/models/mistral.rs new file mode 100644 index 00000000..33c5ab00 --- /dev/null +++ b/backends/candle/src/models/mistral.rs @@ -0,0 +1,19 @@ +use crate::layers::HiddenAct; +use serde::Deserialize; + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct MistralConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub hidden_act: HiddenAct, + pub max_position_embeddings: usize, + pub initializer_range: f64, + pub rms_norm_eps: f32, + pub model_type: Option, + pub rope_theta: f32, + pub sliding_window: Option, +} diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index a7d6b267..3e4a5785 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -8,6 +8,7 @@ mod bert; mod distilbert; mod jina; mod jina_code; +mod mistral; mod nomic; #[cfg(feature = "cuda")] @@ -25,11 +26,15 @@ mod flash_nomic; #[cfg(feature = "cuda")] mod flash_distilbert; +#[cfg(feature = "cuda")] +mod flash_mistral; + pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; pub use distilbert::{DistilBertConfig, DistilBertModel}; pub use jina::JinaBertModel; pub use jina_code::JinaCodeBertModel; +pub use mistral::MistralConfig; pub use nomic::{NomicBertModel, NomicConfig}; use text_embeddings_backend_core::Batch; @@ -48,6 +53,9 @@ pub use flash_nomic::FlashNomicBertModel; #[cfg(feature = "cuda")] pub use flash_distilbert::FlashDistilBertModel; +#[cfg(feature = "cuda")] +pub use flash_mistral::FlashMistralModel; + pub(crate) trait Model { fn is_padded(&self) -> bool; diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 4f9e7551..cdaaea92 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -405,6 +405,9 @@ impl NomicBertModel { if pool == Pool::Splade { candle::bail!("`splade` is not supported for Nomic") } + if pool == Pool::LastToken { + candle::bail!("`last_token` is not supported for Nomic"); + } pool } }; @@ -610,6 +613,8 @@ impl NomicBertModel { let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.i((.., 0))?, + // Last token pooling is not supported for this model + Pool::LastToken => unreachable!(), // Mean pooling Pool::Mean => { if let Some(ref attention_mask) = attention_mask { diff --git a/backends/candle/tests/common.rs b/backends/candle/tests/common.rs index d7ebc67d..a3d74d16 100644 --- a/backends/candle/tests/common.rs +++ b/backends/candle/tests/common.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use hf_hub::api::sync::ApiBuilder; +use hf_hub::api::sync::{ApiBuilder, ApiError, ApiRepo}; use hf_hub::{Repo, RepoType}; use insta::internals::YamlMatcher; use serde::{Deserialize, Serialize}; @@ -25,7 +25,7 @@ impl Score { impl PartialEq for Score { fn eq(&self, other: &Self) -> bool { // Default tolerance for equality - self.is_close(other, 6e-3) + self.is_close(other, 5e-3) } } @@ -51,6 +51,44 @@ impl From>> for SnapshotScores { } } +#[derive(Serialize, Deserialize, Debug)] +pub struct SnapEmbedding(Vec); + +impl PartialEq for SnapEmbedding { + fn eq(&self, other: &Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + + let mut sumxx = 0.0; + let mut sumyy = 0.0; + let mut sumxy = 0.0; + + for (x, y) in self.0.iter().zip(other.0.iter()) { + sumxx += x * x; + sumyy += y * y; + sumxy += x * y; + } + + (sumxy / (sumxx * sumyy).sqrt()) > 0.999 + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct SnapshotEmbeddings(Vec); + +impl Deref for SnapshotEmbeddings { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From>> for SnapshotEmbeddings { + fn from(value: Vec>) -> Self { + Self(value.into_iter().map(|v| SnapEmbedding(v)).collect()) + } +} + pub fn sort_embeddings(embeddings: Embeddings) -> (Vec>, Vec>) { let mut pooled_embeddings = Vec::new(); let mut raw_embeddings = Vec::new(); @@ -85,23 +123,66 @@ pub fn download_artifacts( api_repo.get("config.json")?; api_repo.get("tokenizer.json")?; - let model_root = match api_repo.get("model.safetensors") { + let model_files = match download_safetensors(&api_repo) { Ok(p) => p, Err(_) => { + tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); + tracing::info!("Downloading `pytorch_model.bin`"); let p = api_repo.get("pytorch_model.bin")?; - tracing::warn!("`model.safetensors` not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); - p + vec![p] } - } - .parent().unwrap() - .to_path_buf(); + }; + let model_root = model_files[0].parent().unwrap().to_path_buf(); Ok(model_root) } +fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { + // Single file + tracing::info!("Downloading `model.safetensors`"); + match api.get("model.safetensors") { + Ok(p) => return Ok(vec![p]), + Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err), + }; + + // Sharded weights + // Download and parse index file + tracing::info!("Downloading `model.safetensors.index.json`"); + let index_file = api.get("model.safetensors.index.json")?; + let index_file_string: String = + std::fs::read_to_string(index_file).expect("model.safetensors.index.json is corrupted"); + let json: serde_json::Value = serde_json::from_str(&index_file_string) + .expect("model.safetensors.index.json is corrupted"); + + let weight_map = match json.get("weight_map") { + Some(serde_json::Value::Object(map)) => map, + _ => panic!("model.safetensors.index.json is corrupted"), + }; + + let mut safetensors_filenames = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_filenames.insert(file.to_string()); + } + } + + // Download weight files + let mut safetensors_files = Vec::new(); + for n in safetensors_filenames { + tracing::info!("Downloading `{}`", n); + safetensors_files.push(api.get(&n)?); + } + + Ok(safetensors_files) +} + pub fn relative_matcher() -> YamlMatcher { YamlMatcher::new() } +pub fn cosine_matcher() -> YamlMatcher { + YamlMatcher::new() +} + pub fn load_tokenizer(model_root: &Path) -> Result { // Load tokenizer let tokenizer_path = model_root.join("tokenizer.json"); diff --git a/backends/candle/tests/snapshots/test_flash_mistral__mistral_batch.snap b/backends/candle/tests/snapshots/test_flash_mistral__mistral_batch.snap new file mode 100644 index 00000000..acea099b --- /dev/null +++ b/backends/candle/tests/snapshots/test_flash_mistral__mistral_batch.snap @@ -0,0 +1,12293 @@ +--- +source: backends/candle/tests/test_flash_mistral.rs +assertion_line: 37 +expression: embeddings_batch +--- +- - 3.2363281 + - -1.1582031 + - 1.0810547 + - -2.0292969 + - 1.609375 + - -1.0048828 + - 0.43676758 + - -0.8769531 + - 0.79785156 + - -0.27612305 + - 0.4963379 + - -0.82128906 + - 0.16906738 + - -0.734375 + - -0.34936523 + - 0.03515625 + - 0.34375 + - 1.3769531 + - 1.5234375 + - -1.875 + - -1.4082031 + - 1.6289063 + - -1.1650391 + - 0.65234375 + - 1.796875 + - 1.984375 + - -0.4350586 + - 1.4003906 + - -0.34985352 + - -2.5253906 + - 2.5351563 + - 0.32348633 + - 2.3007813 + - 1.5195313 + - -0.28295898 + - 1.1650391 + - -3.4472656 + - 0.07421875 + - -5.28125 + - -0.8310547 + - 0.7524414 + - -2.4257813 + - -0.91845703 + - -0.9814453 + - -1.7285156 + - 2.0761719 + - 0.23657227 + - -3.9003906 + - -1.4052734 + - 0.8310547 + - 3.9140625 + - -0.43408203 + - -3.0429688 + - -100.5625 + - -3.0703125 + - -0.93652344 + - 2.71875 + - -1.0527344 + - -1.3789063 + - -7.3671875 + - -2.3789063 + - 0.58251953 + - 0.8388672 + - 0.13110352 + - 2.4003906 + - 0.07421875 + - -2.5488281 + - 0.5126953 + - 2.0644531 + - -1.5556641 + - -4.6679688 + - 0.055236816 + - -2.9921875 + - -0.9038086 + - -1.2294922 + - -0.3984375 + - 2.9863281 + - 3.1328125 + - -0.13867188 + - -0.36523438 + - -0.63916016 + - -0.6064453 + - -1.5869141 + - -0.3425293 + - -2.0234375 + - 0.5336914 + - -1.8027344 + - -0.15185547 + - 2.2578125 + - 0.86376953 + - -1.234375 + - 5.9453125 + - 2.7089844 + - -19.703125 + - -2.8125 + - -2.7832031 + - -4.4375 + - 0.35595703 + - 1.5751953 + - -4.09375 + - 1.6884766 + - -1.3564453 + - -3.8652344 + - -0.61035156 + - 0.0055770874 + - -2.7949219 + - 0.08062744 + - -1.3369141 + - -1.5839844 + - -0.056915283 + - 0.04058838 + - 0.4296875 + - 0.47753906 + - -1.5585938 + - -0.055511475 + - 3.03125 + - 2.8515625 + - 0.70947266 + - -0.18884277 + - 0.29467773 + - 2.2421875 + - 0.59472656 + - 0.15393066 + - -2.4863281 + - -2.1992188 + - -0.27172852 + - 2.40625 + - -0.73095703 + - 0.32299805 + - 1.59375 + - 2.3808594 + - 0.17297363 + - -3.2519531 + - 1.1630859 + - 1.234375 + - 2.40625 + - -0.3088379 + - 0.78564453 + - -1.2050781 + - -1.4824219 + - 1.5166016 + - -0.4206543 + - 1.3535156 + - -2.7734375 + - 1.1757813 + - -2.8027344 + - -1.7998047 + - -0.9379883 + - -2.5703125 + - 4.5820313 + - 0.78564453 + - -1.9257813 + - -1.0478516 + - 0.03515625 + - 0.5151367 + - -2.7832031 + - 0.90722656 + - -0.5102539 + - -3.0390625 + - -3.1289063 + - -1.2509766 + - -2.6191406 + - -0.5546875 + - -1.1376953 + - 0.51416016 + - 1.3994141 + - 3.3613281 + - -1.1591797 + - -0.7583008 + - -0.46289063 + - -2.6386719 + - -1.9306641 + - -0.43896484 + - -2.9863281 + - -0.09875488 + - 0.25195313 + - -1.3115234 + - 2.09375 + - -4.265625 + - -2.2519531 + - 1.7910156 + - 0.8022461 + - -1.8603516 + - -1.8544922 + - 0.13891602 + - 5.1054688 + - -3.4863281 + - -0.85253906 + - -1.1806641 + - 0.07336426 + - -1.9082031 + - -3.7753906 + - -0.5541992 + - 0.640625 + - -2.2460938 + - 1.4951172 + - 3.6328125 + - -2.1640625 + - -1.4921875 + - 0.13476563 + - 0.44189453 + - -2.359375 + - 1.9189453 + - 0.7114258 + - 7.9375 + - 3.2929688 + - 4.2617188 + - -2.8378906 + - -0.3474121 + - -2.2304688 + - -2.0644531 + - -0.7504883 + - -2.9101563 + - -0.859375 + - 0.8330078 + - 3.9570313 + - -0.0036258698 + - -2.5214844 + - 3.0898438 + - -0.70458984 + - -3.8535156 + - 0.6298828 + - -0.32739258 + - 3.1289063 + - -0.08618164 + - -1.21875 + - 0.09649658 + - 0.7675781 + - 0.39672852 + - -3.1464844 + - 0.7763672 + - -0.7680664 + - -1.0068359 + - -0.88671875 + - -0.2064209 + - 1.5820313 + - 0.7441406 + - 2.3671875 + - 2.8554688 + - 1.6601563 + - 6.0390625 + - -0.35351563 + - -3.4589844 + - 0.23046875 + - -2.2324219 + - -1.7626953 + - 3.2714844 + - 2.5566406 + - -0.61572266 + - 0.20751953 + - 1.2539063 + - 0.4423828 + - -2.1269531 + - 0.5131836 + - 0.62353516 + - -0.6958008 + - -0.33032227 + - -2.28125 + - 0.032348633 + - 0.3408203 + - 1.4726563 + - 1.8486328 + - 1.890625 + - 1.8886719 + - -0.37426758 + - 2.4140625 + - -2.3027344 + - 3.9121094 + - 0.85546875 + - -4.6953125 + - 0.32983398 + - 0.8154297 + - 3.2304688 + - 0.8305664 + - -0.42773438 + - -1.1630859 + - -3.9277344 + - 1.3681641 + - 0.18469238 + - 1.0292969 + - -2.1328125 + - -2.6738281 + - 1.3876953 + - 0.1361084 + - 0.99902344 + - -0.77783203 + - -0.064697266 + - 1.828125 + - 0.65771484 + - 0.03390503 + - 1.7265625 + - 1.2138672 + - 10.0703125 + - 0.064697266 + - 0.6723633 + - -0.4819336 + - 1.8457031 + - -1.4023438 + - 2.2148438 + - -0.5493164 + - -0.07574463 + - -0.20422363 + - 2.7597656 + - 3.3242188 + - -1.6425781 + - 1.5322266 + - 2.4785156 + - 1.4394531 + - -0.09094238 + - -1.203125 + - -1.6650391 + - -0.10546875 + - -0.8964844 + - 0.072509766 + - 1.1875 + - -2.4375 + - 0.08258057 + - -0.14453125 + - -3.1816406 + - 1.2851563 + - 1.8339844 + - 1.2412109 + - -3.8457031 + - 2.5703125 + - -1.4052734 + - -0.78564453 + - -1.3427734 + - -1.5039063 + - 2.3652344 + - -3.5820313 + - -4.078125 + - 1.7050781 + - 1.5644531 + - 0.7709961 + - 2.34375 + - -0.11657715 + - 2.7832031 + - -0.49926758 + - 0.08984375 + - 0.105285645 + - 2.7597656 + - -0.4482422 + - 2.1015625 + - 1.5488281 + - 1.9433594 + - 1.1533203 + - -0.21252441 + - 2.6777344 + - -5.0664063 + - -0.8847656 + - 2.1464844 + - -1.265625 + - 0.3330078 + - 0.5102539 + - -2.1738281 + - -0.7841797 + - -4.1015625 + - -1.609375 + - -1.6220703 + - -1.4111328 + - -1.4921875 + - 1.7324219 + - 4.359375 + - -1.3857422 + - 2.9726563 + - -2.90625 + - 6.1757813 + - 1.6982422 + - 1.4638672 + - -2.6894531 + - 0.7714844 + - -1.5244141 + - -2.125 + - 3.5058594 + - -0.3996582 + - 3.5996094 + - -1.4482422 + - 0.3935547 + - 0.7109375 + - 2.4746094 + - -1.3896484 + - -1.2880859 + - -1.9433594 + - -0.859375 + - -0.703125 + - 1.8554688 + - 1.8632813 + - -4.2226563 + - -8.125 + - -2.1074219 + - 0.453125 + - -0.09375 + - -2.6660156 + - -0.95751953 + - 0.047698975 + - -0.29663086 + - 2.6464844 + - 2.1074219 + - -2.1464844 + - 1.5498047 + - -2.3339844 + - 1.5898438 + - -0.5654297 + - -4.3476563 + - -0.1673584 + - 1.7988281 + - 2.0488281 + - -2.1660156 + - -14.390625 + - -0.12243652 + - -2.2089844 + - -1.6064453 + - 3.1171875 + - -1.1591797 + - 1.4433594 + - -0.19689941 + - -3.6835938 + - -1.4238281 + - -3.6152344 + - 5.109375 + - -0.5004883 + - -0.4736328 + - 2.7988281 + - -0.32592773 + - -0.75927734 + - 1.0458984 + - 0.1619873 + - -2.0371094 + - 2.2246094 + - -1.4375 + - -1.921875 + - -1.7138672 + - -3.8613281 + - 0.85009766 + - -0.37939453 + - -1.8525391 + - 0.5839844 + - -1.9013672 + - 0.7519531 + - 1.6748047 + - -1.3095703 + - -1.5087891 + - -0.6269531 + - -1.6445313 + - -2.2011719 + - -0.9091797 + - 0.06640625 + - 2.7050781 + - -2.1679688 + - -3.5800781 + - -0.009483337 + - 1.5244141 + - -0.58935547 + - -2.0390625 + - -0.47583008 + - 5.609375 + - 4.625 + - -0.033477783 + - 0.07110596 + - 3.2851563 + - -0.44482422 + - -2.8945313 + - -1.7675781 + - 2.7714844 + - -0.9301758 + - -0.84521484 + - -0.9785156 + - 0.27197266 + - 0.33666992 + - -2.3515625 + - 4.9375 + - 2.3125 + - 0.29882813 + - 1.015625 + - 0.35131836 + - 0.43896484 + - 0.8076172 + - -0.91064453 + - -0.6064453 + - 3.8203125 + - 0.5683594 + - 0.55908203 + - 0.9736328 + - -1.9970703 + - -0.3269043 + - 1.2158203 + - -6.0039063 + - 0.13977051 + - 3.71875 + - -0.5605469 + - 0.46313477 + - 1.5683594 + - -0.7011719 + - -0.46362305 + - -2.6328125 + - -1.3330078 + - 2.4570313 + - -2.0488281 + - -2.9238281 + - 5.375 + - 0.21679688 + - -5.9726563 + - 2.0390625 + - 0.055786133 + - 1.3359375 + - 3.8378906 + - -0.6225586 + - -0.6113281 + - -1.5830078 + - 2.8535156 + - 3.6679688 + - -2.5703125 + - -1.5019531 + - 0.69091797 + - -2.0332031 + - 1.6210938 + - -0.3408203 + - -0.5522461 + - -1.4355469 + - -0.5078125 + - 0.5957031 + - 1.5869141 + - 3.6757813 + - -0.018692017 + - 0.55566406 + - 1.4609375 + - 0.20336914 + - -1.3769531 + - 1.6767578 + - 2.1894531 + - 0.85253906 + - 0.4519043 + - -0.00390625 + - -1.8789063 + - 3.5800781 + - 0.16516113 + - -4.5117188 + - -0.12890625 + - -0.3557129 + - -1.6269531 + - -1.9589844 + - -1.0107422 + - 3.1054688 + - -0.8457031 + - -4.8476563 + - -2.3652344 + - -1.3818359 + - 0.20703125 + - 1.9863281 + - 1.4814453 + - 0.6333008 + - 1.9667969 + - -17.671875 + - -1.453125 + - -1.0478516 + - -2.0019531 + - -1.3818359 + - 0.61279297 + - 0.20227051 + - 0.0055770874 + - 2.3476563 + - -3.4804688 + - -1.0546875 + - -2.2363281 + - 1.2685547 + - -1.0302734 + - 0.87597656 + - -2.4453125 + - -1.4394531 + - -2.3496094 + - -2.2890625 + - -0.8925781 + - -1.9296875 + - 0.9921875 + - 0.2939453 + - -1.2851563 + - 1.1201172 + - 0.578125 + - 0.30908203 + - 0.7246094 + - -3.2089844 + - 0.65478516 + - 2.5683594 + - -3.2148438 + - -2.9394531 + - 1.6816406 + - 1.6416016 + - -2.3417969 + - -3.5 + - -1.1904297 + - 1.4462891 + - -3.1875 + - -1.890625 + - -0.1015625 + - -1.9082031 + - 1.4306641 + - 5.1757813 + - 3.9101563 + - 1.0263672 + - 3.2402344 + - -0.8222656 + - -0.68603516 + - 0.055786133 + - -2.2578125 + - -2.3261719 + - 0.15234375 + - -3.6972656 + - 0.5625 + - -4.3789063 + - 0.9506836 + - 2.5957031 + - -1.7587891 + - -1.9824219 + - 1.9609375 + - -0.60595703 + - -0.2524414 + - -1.5576172 + - 1.8701172 + - -2.1386719 + - 0.00390625 + - 1.4619141 + - 1.8613281 + - 0.00027894974 + - 0.44140625 + - -1.6054688 + - 3.4902344 + - 0.036834717 + - 1.4169922 + - 0.7788086 + - -0.12384033 + - 1.7070313 + - -0.52197266 + - -3.2265625 + - -2.6875 + - 0.61572266 + - 2.6113281 + - -2.8164063 + - -0.83251953 + - -0.25439453 + - 0.037384033 + - -2.2226563 + - -2.5703125 + - -0.08013916 + - 2.7851563 + - 4.390625 + - -1.0810547 + - 0.59375 + - -4.6757813 + - 7.9140625 + - -3.1503906 + - 0.73339844 + - 3.3554688 + - -1.6220703 + - -2.59375 + - 0.984375 + - -1.6298828 + - -0.5546875 + - 2.6933594 + - 3.8125 + - -0.45922852 + - 1.4638672 + - 1.0556641 + - 1.6621094 + - 3.1113281 + - -0.55126953 + - 2.4003906 + - 1.8222656 + - -2.0507813 + - 0.22314453 + - 0.98535156 + - -0.5253906 + - -1.0029297 + - 0.6152344 + - 0.6113281 + - -0.71191406 + - -2.9492188 + - -0.19580078 + - -0.98828125 + - -0.1899414 + - 0.044067383 + - 1.5214844 + - 1.734375 + - 1.0146484 + - -1.4179688 + - 7.7578125 + - 3.3652344 + - 7.0976563 + - 1.4726563 + - -5.7226563 + - -5.890625 + - -0.3828125 + - -1.3154297 + - -0.31958008 + - -1.5888672 + - 0.1907959 + - -0.23181152 + - -1.046875 + - 1.6132813 + - -1.9482422 + - 2.6699219 + - 3.2246094 + - 3.6679688 + - -0.9091797 + - -2.5136719 + - 0.5102539 + - 24.09375 + - 1.2988281 + - 0.88183594 + - 0.09313965 + - -3.0195313 + - 1.8251953 + - 0.71484375 + - 0.77197266 + - -2.15625 + - 1.1113281 + - 3 + - 2.96875 + - -0.28686523 + - -0.0496521 + - 0.5957031 + - 4.7929688 + - 1.4414063 + - 3.0625 + - -5.0664063 + - -0.17687988 + - -1.8623047 + - -1.8876953 + - -3.6367188 + - 0.9038086 + - -0.4519043 + - 1.453125 + - -0.27124023 + - -1.8652344 + - 2.1582031 + - 0.65771484 + - -3.4160156 + - -5.7304688 + - -0.22070313 + - -3.03125 + - -0.9975586 + - 1.8378906 + - -1.4101563 + - 1.4414063 + - 3.9804688 + - -1.9648438 + - -1.5292969 + - -1.8769531 + - 2.2949219 + - -0.23254395 + - -0.5600586 + - 1.2783203 + - 0.60791016 + - 1.453125 + - 0.8408203 + - -0.73535156 + - -0.99658203 + - -3.1132813 + - 2.9472656 + - -0.5136719 + - 0.32617188 + - -2.6640625 + - -1.5917969 + - 1.0527344 + - 0.119384766 + - -1.2695313 + - -1.6621094 + - 2.1621094 + - -1.7226563 + - -1.7275391 + - -0.45898438 + - -0.26733398 + - 2.6152344 + - 0.4230957 + - -1.1201172 + - -0.47021484 + - 4.1289063 + - 1.4775391 + - -0.26342773 + - 2.9726563 + - -2.859375 + - 2.3222656 + - 0.52197266 + - -1.1865234 + - -3.2050781 + - -1.1943359 + - 2.2285156 + - -2.5 + - 5.8789063 + - -0.001953125 + - 2.4101563 + - -0.78027344 + - -1.4560547 + - 0.8540039 + - 2.6914063 + - 0.49853516 + - -1.1474609 + - -0.55566406 + - 0.46972656 + - 1.1582031 + - -3.6191406 + - 2.3203125 + - -4.75 + - -4.75 + - -3.7871094 + - 1.0068359 + - 3.9179688 + - 1.4345703 + - -1.3925781 + - 0.171875 + - 2.4257813 + - 1.21875 + - -2.6074219 + - 1.1171875 + - -1.5332031 + - -4.0273438 + - -0.3540039 + - 5.6328125 + - 0.23010254 + - 2.109375 + - 1.9853516 + - -0.9951172 + - 2.140625 + - -0.2705078 + - -2.8164063 + - -0.19946289 + - 4.5820313 + - -2.5664063 + - -0.3581543 + - 2.8847656 + - -1.4316406 + - 0.06585693 + - 1.0810547 + - -1.1972656 + - -9.3359375 + - 1.4482422 + - -47.25 + - -1.2919922 + - -0.6015625 + - -2.0625 + - -3.9179688 + - -0.47729492 + - 0.296875 + - 1.0654297 + - 1.6640625 + - 1.0595703 + - 0.18188477 + - -1.796875 + - 4.6875 + - -0.5253906 + - -2.0019531 + - 1.5869141 + - 1.1044922 + - -0.7211914 + - 16.984375 + - 0.42285156 + - -0.9765625 + - -1.2626953 + - -0.9379883 + - -0.57958984 + - 0.4038086 + - 2.8007813 + - 0.87353516 + - -1.625 + - -0.4267578 + - -2.6699219 + - -0.9609375 + - -2.4199219 + - 0.1784668 + - 0.49438477 + - -0.88183594 + - 2.4472656 + - 1.0351563 + - 0.8046875 + - 1.4453125 + - 0.5073242 + - 3.921875 + - -0.3798828 + - 1.046875 + - 0.2524414 + - -3.1367188 + - 2.5292969 + - 0.12658691 + - -1.2939453 + - -0.52246094 + - -2.9902344 + - 0.3515625 + - -1.6132813 + - -0.08203125 + - -0.66015625 + - -0.059143066 + - 0.21252441 + - 1.9482422 + - -4.1484375 + - -2.4863281 + - 0.35864258 + - 0.18481445 + - -1.0009766 + - -2.59375 + - 1.2685547 + - 6.6015625 + - -0.65283203 + - -0.7451172 + - 4.7226563 + - -2.2519531 + - 2.3105469 + - -2.0625 + - -0.16796875 + - 0.17907715 + - -2.3144531 + - 2.8964844 + - -4.5703125 + - 3.5996094 + - -1.0625 + - 5.2304688 + - 0.46972656 + - 0.31811523 + - -3.0722656 + - 1.9150391 + - 0.18713379 + - 1.9267578 + - 2.9316406 + - -1.0644531 + - -0.28515625 + - 0.26489258 + - -0.71972656 + - 2.5703125 + - -1.4707031 + - -1.5351563 + - -2.7070313 + - 1.2441406 + - -0.47607422 + - -0.3474121 + - -0.8457031 + - -3.4179688 + - -1.0927734 + - -2.1328125 + - -5.7382813 + - -1.1689453 + - 0.2512207 + - 1.3505859 + - 3.4101563 + - 3.4472656 + - 0.40112305 + - 0.56689453 + - 0.064697266 + - 0.7753906 + - 0.9980469 + - -1.6445313 + - 2.921875 + - 0.97314453 + - 1.3320313 + - -2.6816406 + - 2.3125 + - -2.0449219 + - 2.2089844 + - 1.6376953 + - 0.4819336 + - -1.6738281 + - -1.7792969 + - 0.17663574 + - 0.31298828 + - 4.0273438 + - -0.7270508 + - 3.1933594 + - 2.3964844 + - 2.65625 + - 1.4794922 + - -0.0524292 + - 1.9814453 + - 0.39282227 + - 0.23828125 + - 2.7226563 + - -0.80126953 + - -2.8105469 + - 0.1665039 + - -2.1660156 + - -2.0292969 + - -2.4453125 + - -3.0078125 + - 1.9033203 + - 2.8339844 + - 2.7753906 + - -2.4765625 + - 0.8408203 + - -3.203125 + - 2.265625 + - -1.7246094 + - 4.75 + - 4.6875 + - 0.59472656 + - -0.53466797 + - 1.7792969 + - 0.2956543 + - 2.3515625 + - -4.1757813 + - 3.9179688 + - -1.46875 + - -4.9453125 + - -1.9033203 + - -1.0390625 + - -0.34399414 + - -2.9414063 + - -15.546875 + - 2.0390625 + - -1.2695313 + - 4.1445313 + - 1.2197266 + - 3.3535156 + - 1.3818359 + - 1.5996094 + - -0.45141602 + - -0.6635742 + - 1.65625 + - -2.0996094 + - 2.4941406 + - 1.4921875 + - 2.0800781 + - -3.2675781 + - 0.96191406 + - -0.0072517395 + - -0.21252441 + - 1.2314453 + - 2.2519531 + - -1.0253906 + - 0.35327148 + - -0.015625 + - 1.5966797 + - -4.4726563 + - 0.20471191 + - -1.7744141 + - -16.671875 + - 0.61865234 + - 0.1204834 + - 2.9863281 + - -4.984375 + - -1.5673828 + - 0.2685547 + - 1.1904297 + - -5.015625 + - -2.6191406 + - -2.6132813 + - 3.6992188 + - -0.53271484 + - -0.45141602 + - -2.3652344 + - 0.70166016 + - -6.203125 + - -1.1904297 + - -0.35180664 + - 0.74072266 + - 1.1875 + - -0.9941406 + - -0.24536133 + - -2.4628906 + - -0.63623047 + - 2.921875 + - -3.5 + - -0.0418396 + - -0.52783203 + - 1.5361328 + - 3.4628906 + - -1.8183594 + - 0.32592773 + - -1.4794922 + - -0.74853516 + - 2.2285156 + - -0.75097656 + - 0.43237305 + - -18.859375 + - -0.33251953 + - -1.9013672 + - 2.4355469 + - -4.1875 + - 2.4121094 + - 0.5698242 + - 1.2294922 + - 1.6337891 + - -0.6972656 + - 1.4189453 + - -1.1513672 + - 2.2636719 + - -1.9921875 + - 0.50927734 + - -0.11621094 + - 0.58740234 + - 0.045196533 + - 1.4101563 + - -4.8007813 + - -1.421875 + - 2.3144531 + - -2.7324219 + - -0.19055176 + - 2.9023438 + - -1.4501953 + - 3.1484375 + - -2.5957031 + - -1.5234375 + - 2.0722656 + - 1.359375 + - 3.15625 + - -2.1503906 + - -1.5009766 + - -1.6464844 + - -0.4116211 + - -0.60595703 + - -1.6875 + - 1.4931641 + - 1.8671875 + - 3.7695313 + - 1.6650391 + - 2.296875 + - 3.6601563 + - -2.0839844 + - 0.4116211 + - -2.2988281 + - -1.4267578 + - -6.0625 + - 1.0380859 + - 2.4628906 + - 0.46191406 + - 0.2548828 + - 0.19689941 + - -2.0976563 + - 0.6020508 + - 0.14929199 + - 8.09375 + - -0.37939453 + - -1.6357422 + - -1.1328125 + - 1.1572266 + - 1.5166016 + - 1.8105469 + - -1.7607422 + - -1.9306641 + - 0.43115234 + - 2.6933594 + - 0.68603516 + - 3.0800781 + - -3.4238281 + - -4.5898438 + - 0.8173828 + - 0.81689453 + - 1.5869141 + - 0.9785156 + - 0.3359375 + - -0.2454834 + - 4.140625 + - 0.45922852 + - 0.1227417 + - -2.3183594 + - 1.6416016 + - -0.86376953 + - 1.2724609 + - -3.3242188 + - -0.48486328 + - 1.7539063 + - -2.6875 + - 1.2851563 + - 3.9628906 + - 2.2578125 + - -0.9003906 + - -0.890625 + - 1.5214844 + - 1.3681641 + - 0.6738281 + - 2.875 + - 4.9257813 + - -0.41552734 + - 1.0478516 + - -0.67822266 + - 0.17907715 + - 0.7519531 + - 2.2324219 + - 1.2285156 + - 1.1103516 + - 0.13671875 + - -4.5898438 + - -0.58251953 + - 3.1289063 + - -2.9101563 + - -0.5 + - -3.109375 + - -0.7890625 + - 2.46875 + - 6.3671875 + - 1.0234375 + - -1.5839844 + - 1.7226563 + - 2.2578125 + - -0.53271484 + - -1.3720703 + - 1.2597656 + - -5.4179688 + - 1.2451172 + - 2.6855469 + - 5.4140625 + - -0.4560547 + - 0.5136719 + - -1.0898438 + - -0.8725586 + - -2.5917969 + - -3.6132813 + - 3.6015625 + - -0.8730469 + - 0.97802734 + - 5.375 + - -2.1015625 + - -1.2539063 + - -2.5039063 + - -0.38916016 + - -0.047546387 + - 0.2939453 + - -1.1806641 + - -0.13952637 + - 3.3027344 + - -0.9951172 + - 0.3881836 + - 1.9726563 + - 0.578125 + - -0.53564453 + - -0.30908203 + - 3.3164063 + - -0.27539063 + - 0.8676758 + - 1.8466797 + - 2.5957031 + - 0.625 + - -0.63427734 + - -3.7246094 + - -3.3027344 + - 0.061645508 + - 3.0683594 + - -0.9375 + - 2.4726563 + - -0.6616211 + - 1.5009766 + - -0.15673828 + - -3.625 + - 0.9790039 + - 0.10180664 + - -0.1430664 + - -1.1445313 + - -2.4355469 + - 6.703125 + - -2.4082031 + - 0.82666016 + - -1.2753906 + - 2.6503906 + - 0.7402344 + - -0.68408203 + - -2.0527344 + - 0.01701355 + - -3.9707031 + - 0.9741211 + - 0.3684082 + - 1.9746094 + - 1.2275391 + - 11.5703125 + - -1.9726563 + - -1.2568359 + - 1.5556641 + - 0.38720703 + - 6.0625 + - 4.03125 + - 0.3269043 + - -1.5058594 + - -0.7089844 + - 0.52783203 + - 8.3125 + - 0.38867188 + - -0.64453125 + - 0.23876953 + - -1.2001953 + - 0.69921875 + - -3.109375 + - -2.7402344 + - -2.3964844 + - -3.6738281 + - 1.8652344 + - -3.6816406 + - -1.0703125 + - 1.0126953 + - 0.83251953 + - -4.9414063 + - -0.2487793 + - 0.36669922 + - 1.9873047 + - -0.4453125 + - -1.421875 + - 1.3291016 + - -1.1318359 + - -1.125 + - 2.25 + - 0.49023438 + - 1.9892578 + - 4.171875 + - -1.8466797 + - 1.5117188 + - 0.41845703 + - -4.1914063 + - -1.8828125 + - -0.3010254 + - -1.7539063 + - 3.1015625 + - -1.0146484 + - 0.4970703 + - 3.1601563 + - 0.080078125 + - 3.5722656 + - -0.74072266 + - 3.1738281 + - -1.8457031 + - 3.15625 + - -0.88671875 + - -3.90625 + - -2.7324219 + - -3.7539063 + - 1.6591797 + - 1.1328125 + - -0.9873047 + - -0.70703125 + - -0.78564453 + - -0.30078125 + - -2.2480469 + - -1.0400391 + - 1.1386719 + - 1.0878906 + - -0.74658203 + - 2.7128906 + - -9.265625 + - 3.6757813 + - 3.4140625 + - -0.7910156 + - 0.8730469 + - -2.4628906 + - -0.8623047 + - 0.82128906 + - -0.09765625 + - 1.9785156 + - 0.9145508 + - -0.8256836 + - 3.8378906 + - 0.45043945 + - -1.5556641 + - -2.703125 + - -0.60546875 + - 1.1132813 + - -0.43652344 + - -2.0175781 + - -0.31958008 + - -0.07867432 + - -1.5126953 + - 3.2539063 + - 0.37036133 + - -6.2109375 + - 1.9072266 + - 4.3515625 + - -0.01171875 + - 0.04852295 + - 0.296875 + - 0.8154297 + - 1.7441406 + - 2.4199219 + - 3.375 + - 0.42578125 + - 0.5605469 + - -0.43188477 + - -0.09667969 + - 1.4482422 + - 2.7324219 + - -0.17468262 + - -3.9589844 + - 10.7734375 + - 2.2988281 + - -3.1738281 + - -71 + - 0.8598633 + - -1.671875 + - -0.8847656 + - 2.8320313 + - 4.7929688 + - 1.6953125 + - 0.8984375 + - -0.09063721 + - -2.2050781 + - -2.765625 + - 1.6904297 + - -0.7163086 + - 2.3457031 + - 0.35083008 + - -5.0625 + - -2.6972656 + - -3.0078125 + - -0.32592773 + - 1.7851563 + - 2.4550781 + - 0.5205078 + - 1.1357422 + - -0.9584961 + - -1.6064453 + - -2.7480469 + - -1.6689453 + - -3.2753906 + - 1.0966797 + - -1.7207031 + - 1.1298828 + - -4.6367188 + - 0.08984375 + - -1.109375 + - -3.8867188 + - 1.0859375 + - 1.0166016 + - -0.043792725 + - 1.3095703 + - -2.6269531 + - -0.30297852 + - -1.3212891 + - 4.2148438 + - 1.796875 + - 1.2851563 + - -2.6074219 + - 2.0527344 + - 1.4707031 + - 2.9453125 + - 0.33374023 + - 1.2978516 + - -0.5600586 + - 1.0791016 + - 9.7578125 + - -4.8945313 + - 1.8242188 + - 0.14147949 + - 0.9223633 + - 0.3815918 + - -2.0175781 + - 0.9194336 + - 2.046875 + - 0.3852539 + - -3.15625 + - -0.7392578 + - 0.11602783 + - -4.640625 + - 0.7426758 + - -0.93603516 + - 0.4621582 + - -2.9628906 + - 2.0625 + - 2.890625 + - 0.58935547 + - 1.4394531 + - 0.2878418 + - -2.2128906 + - -0.7866211 + - 0.54345703 + - 1.0351563 + - -0.11187744 + - 0.4152832 + - -1.7988281 + - -1.1962891 + - 0.7685547 + - -2.7597656 + - 2.4375 + - 3.6503906 + - -0.6088867 + - -1.0214844 + - -1.2431641 + - 2.0878906 + - -0.15905762 + - 2.8632813 + - 2.4941406 + - 7.8046875 + - 1.8417969 + - 3.0839844 + - -1.7001953 + - 0.81103516 + - 1.5585938 + - -0.31445313 + - 0.3947754 + - 1.9375 + - -0.9941406 + - 0.13220215 + - -0.83740234 + - -2.9550781 + - 0.67822266 + - -1.1914063 + - 5.3007813 + - 16.75 + - 1.0976563 + - -0.65185547 + - -3.8984375 + - 1.375 + - -0.75 + - 1.6728516 + - 2.3945313 + - -0.31225586 + - -0.9316406 + - 3.2753906 + - 0.94970703 + - 1.359375 + - -1.875 + - 2.1777344 + - 2.2441406 + - -4.0898438 + - 1.3691406 + - 0.30395508 + - 2.1152344 + - 0.1126709 + - -1.7089844 + - 1.3037109 + - -0.82666016 + - 3.9414063 + - 1.4775391 + - -1.4306641 + - 3.2910156 + - 1.3632813 + - -1.796875 + - -3.2226563 + - 1.6689453 + - -0.072509766 + - -2.9960938 + - 0.76416016 + - 0.1616211 + - -2.6503906 + - 0.085510254 + - 1.9941406 + - 0.55908203 + - 0.34423828 + - 3.0351563 + - 1.4033203 + - -0.54785156 + - 0.37817383 + - 3.5644531 + - -0.7607422 + - 2.7578125 + - 0.76660156 + - 3.2304688 + - 2.390625 + - -2.2675781 + - -1.4804688 + - 2.2480469 + - 6.3867188 + - -2.7519531 + - -0.3305664 + - 3.0195313 + - -4.2539063 + - 0.103515625 + - -0.5175781 + - -2.2578125 + - 0.27441406 + - 0.76660156 + - 2.3105469 + - 1.1015625 + - 0.081726074 + - -0.16015625 + - -0.0078125 + - -1.9619141 + - -0.63720703 + - -2.21875 + - 0.4033203 + - 1.1953125 + - 0.39013672 + - -2.21875 + - -1.65625 + - -2.0566406 + - -1.6669922 + - -10.375 + - 0.6894531 + - 0.6230469 + - -0.0446167 + - -0.6328125 + - -1.4785156 + - -3.3125 + - 1.4169922 + - -0.5205078 + - 1.609375 + - 3.4453125 + - 1.1767578 + - 2.6171875 + - 5.765625 + - -1.453125 + - 1.8847656 + - -3.3789063 + - -3.6875 + - -2.703125 + - 1.6894531 + - 0.23828125 + - -2.6445313 + - 2.9140625 + - -2.3457031 + - -0.65478516 + - 0.69970703 + - 1.2314453 + - 5.4804688 + - -0.18164063 + - 0.48754883 + - 3.3339844 + - 4.1132813 + - -3.0664063 + - -5.390625 + - -0.29589844 + - 0.8984375 + - 1.0292969 + - 2.5839844 + - -0.093444824 + - -1.4394531 + - 2.6972656 + - 2.3828125 + - -0.29467773 + - -1.8320313 + - -1.3818359 + - 2.1191406 + - 0.82128906 + - 3.8769531 + - 1.8378906 + - -0.46313477 + - 3.375 + - 1.1123047 + - 1.0087891 + - 2.1347656 + - -3.4277344 + - -2.8945313 + - -2.65625 + - 2.4277344 + - 2.7734375 + - -1.9775391 + - -3.71875 + - -3.6953125 + - -1.5332031 + - -4.8945313 + - 0.98828125 + - -1.0302734 + - 2.1640625 + - 0.5756836 + - -2.96875 + - -4.15625 + - -0.06274414 + - 0.03515625 + - 3.4160156 + - 0.92285156 + - -0.64697266 + - -1.0117188 + - 20.421875 + - 1.1201172 + - 0.58251953 + - 2.1933594 + - 8.015625 + - -0.35546875 + - -0.2253418 + - 0.3088379 + - 0.7392578 + - -3.4335938 + - -0.8833008 + - 4.125 + - -2.3203125 + - 4.7304688 + - 0.66845703 + - 0.73535156 + - -0.64697266 + - 0.68310547 + - -2.9316406 + - -2.5644531 + - 5.1523438 + - -0.84277344 + - 0.48046875 + - 3.7089844 + - 0.16040039 + - -3.9765625 + - 1.3769531 + - 2.2441406 + - 0.9951172 + - 0.20532227 + - 0.63134766 + - 0.3720703 + - 3.1738281 + - 0.61279297 + - -4.0507813 + - 0.96191406 + - -0.62353516 + - -0.9472656 + - -1.0126953 + - -4.5390625 + - 5.3164063 + - 2.5136719 + - -6.2109375 + - -1.0478516 + - 1.4082031 + - 2.2832031 + - -1.5019531 + - 1.1425781 + - 1.7949219 + - -2.5058594 + - 3.6738281 + - 0.515625 + - 2.3613281 + - 0.29858398 + - 6.1289063 + - 1.1318359 + - 0.29174805 + - 1.046875 + - -2.0136719 + - -3.8242188 + - 4.546875 + - 3.0429688 + - 2.7207031 + - 0.028457642 + - 0.33691406 + - 0.15515137 + - 2.9394531 + - -3.4550781 + - 0.39282227 + - 0.38305664 + - -4.5078125 + - -1.8945313 + - 1.9765625 + - 2.75 + - -4.6992188 + - -2.0136719 + - -1.1396484 + - -3.2890625 + - -1.2226563 + - -2.7890625 + - 1.3349609 + - 1.0654297 + - 0.18237305 + - -3.5683594 + - -0.7392578 + - 2.5644531 + - 1.5683594 + - -1.3681641 + - -2.8691406 + - 1.3779297 + - -1.5214844 + - -0.83691406 + - -4.0742188 + - -2.375 + - -4.5429688 + - 2.6953125 + - 0.6816406 + - -3.203125 + - -2.5175781 + - -2.1894531 + - 1.2763672 + - 0.5151367 + - -0.6088867 + - 4.1289063 + - -3.0625 + - 0.6694336 + - -0.07446289 + - -1.6347656 + - 4.0546875 + - -3.6660156 + - 1.1875 + - -2.1308594 + - 2.0566406 + - -0.37890625 + - -4.78125 + - -1.0332031 + - 3.9765625 + - 0.3557129 + - 1.2753906 + - -2.8867188 + - 2.3613281 + - -6.140625 + - 1.2578125 + - 0.69873047 + - -0.89160156 + - 3.6640625 + - 3.5039063 + - 1.4873047 + - 2.4082031 + - -0.64160156 + - 0.66015625 + - -2.4589844 + - -3.3144531 + - -2.1328125 + - 2.8867188 + - 0.7421875 + - -1.4570313 + - 1.7060547 + - 1.0664063 + - -0.52685547 + - 2.5371094 + - -1.890625 + - -1.6679688 + - 1.2255859 + - -0.51953125 + - -1.5722656 + - 1.5800781 + - 0.42919922 + - 0.4934082 + - 3.7558594 + - 2.6347656 + - 0.0892334 + - -1.2910156 + - -5.2148438 + - 3.09375 + - 1.4492188 + - -2.1113281 + - 2.4453125 + - 1.5205078 + - -3.7050781 + - 2.1386719 + - 1.9863281 + - -1.7480469 + - 2.6875 + - -2.9941406 + - -1.9804688 + - -1.8417969 + - 0.51708984 + - 1.8808594 + - 0.34106445 + - -1.5683594 + - -5.5898438 + - -0.23840332 + - -1.6435547 + - -0.86816406 + - -1.3125 + - -5.1445313 + - 3.1347656 + - 0.6113281 + - -2.2421875 + - 1.0253906 + - -1.7421875 + - 3.6621094 + - -2.1660156 + - 2.3730469 + - -1.4462891 + - 0.33862305 + - -0.83984375 + - -0.49267578 + - 1.8681641 + - -0.2175293 + - -0.25854492 + - -3.2089844 + - 0.10430908 + - -1.5869141 + - 1.0126953 + - 1.2773438 + - 3.75 + - -1.6982422 + - -2.1621094 + - -0.034454346 + - 3.90625 + - 2.0703125 + - -1.0029297 + - -3.7441406 + - -1.1357422 + - -2.8867188 + - 8.7734375 + - -1.75 + - -0.11102295 + - -1.7871094 + - 4.3984375 + - 1.2919922 + - 1.1982422 + - 0.79785156 + - -1.3037109 + - 0.2175293 + - -0.7133789 + - 2.1738281 + - -5.390625 + - -2.6777344 + - 5.7382813 + - -4.1210938 + - 3.6914063 + - -1.0966797 + - 0.49926758 + - 0.63720703 + - 3.8164063 + - 0.39770508 + - -1.3974609 + - -0.011154175 + - 0.9560547 + - 2.171875 + - -4.8320313 + - 1.7783203 + - 0.55126953 + - -3.1738281 + - -1.4326172 + - -0.23596191 + - -1.140625 + - -0.22290039 + - -1.1679688 + - 0.34204102 + - 1.5605469 + - -0.85595703 + - -2.0996094 + - -3.8925781 + - 0.55126953 + - -1.4453125 + - -1.6191406 + - 0.23510742 + - 2.6875 + - 0.5488281 + - 2.5390625 + - -0.30566406 + - -0.31054688 + - -1.75 + - 3.4765625 + - 2.8691406 + - -1.8105469 + - -0.67822266 + - -3.6894531 + - -2.2324219 + - 1.7548828 + - 0.15344238 + - -2.2128906 + - -2.3222656 + - -0.578125 + - 1.2382813 + - -0.4765625 + - 0.88134766 + - 2.4453125 + - -0.92285156 + - -3.0878906 + - -2.65625 + - 0.1439209 + - -2.96875 + - -1.8652344 + - -1.0390625 + - -2.1757813 + - -2.8847656 + - -0.6171875 + - -0.8310547 + - -1.3662109 + - 5.4140625 + - 4.6992188 + - -4.21875 + - -0.35668945 + - -1.2822266 + - 1.4794922 + - -2.3300781 + - -2.2949219 + - 3.5800781 + - -1.3066406 + - -2.5527344 + - 1.4326172 + - 2.2753906 + - -2.203125 + - -3.6445313 + - -0.66503906 + - -1.7519531 + - -1.0224609 + - 0.15905762 + - -0.32299805 + - -0.7036133 + - -1.9609375 + - -1.0732422 + - -1.2900391 + - -0.7626953 + - -2.0644531 + - -2.2519531 + - -0.75390625 + - -0.3725586 + - 3.9863281 + - -2.7480469 + - 3.9023438 + - -1.9814453 + - -0.93847656 + - 6.5117188 + - 0.60546875 + - -0.82666016 + - -1.3544922 + - 0.6323242 + - -2.96875 + - 3.3164063 + - 6.4257813 + - -2.3164063 + - -0.70703125 + - 5.7226563 + - 0.9033203 + - 1.3867188 + - 0.39868164 + - -1.9765625 + - 1.0751953 + - 0.51123047 + - -2.9804688 + - 1.3408203 + - -0.8623047 + - -0.3305664 + - 2.6601563 + - -7.1601563 + - 0.71728516 + - 4.21875 + - -2.4765625 + - -0.79003906 + - -2.1503906 + - 4.2460938 + - -5.1679688 + - -2.3320313 + - -0.23156738 + - 1.5947266 + - 2.4082031 + - -0.6894531 + - 1.6523438 + - -2.3300781 + - -2.6777344 + - 2.3339844 + - -0.69189453 + - 0.39379883 + - -2.3339844 + - 3.765625 + - 0.6713867 + - -1.71875 + - -2.4199219 + - -1.2382813 + - -0.22509766 + - 0.57373047 + - -0.34472656 + - 0.5488281 + - 2.0214844 + - -2.5917969 + - -0.09649658 + - -2.7949219 + - 0.71972656 + - 0.95751953 + - 1.1845703 + - -1.2763672 + - -2.2324219 + - -3.1464844 + - 1.2744141 + - 0.5834961 + - 1.15625 + - -0.36157227 + - -2.1542969 + - -2.1152344 + - 1.2978516 + - -3.0253906 + - -2.5078125 + - -1.9648438 + - 3.6992188 + - -3.4804688 + - -1.9482422 + - -0.6015625 + - 2.3535156 + - -1.609375 + - 0.017578125 + - -1.0625 + - -0.9248047 + - -0.30395508 + - -4.1132813 + - 0.8129883 + - 1.6357422 + - 4.8632813 + - -1.6777344 + - 1.4501953 + - -0.2841797 + - 6.375 + - 1.9326172 + - -0.73095703 + - 1.4150391 + - 1.7363281 + - -0.64941406 + - -1.9150391 + - -1.2910156 + - 1.2724609 + - 1.7753906 + - 3.4375 + - -1.9316406 + - 2.3691406 + - -0.04574585 + - -0.054595947 + - 2.40625 + - -0.54248047 + - -0.9785156 + - 1.7080078 + - -1.4541016 + - -2.8515625 + - 0.9140625 + - 0.92041016 + - -3.3164063 + - -0.5415039 + - 1.859375 + - -1.9082031 + - -1.2275391 + - -0.16516113 + - -0.29711914 + - 4.4257813 + - 6.828125 + - -1.8183594 + - -0.18664551 + - -3.7402344 + - -2.1445313 + - 0.515625 + - 1.0849609 + - -2.375 + - 1.8476563 + - -3.6679688 + - -2.8671875 + - -0.51171875 + - -2.3496094 + - -0.9980469 + - -2.3925781 + - -0.021759033 + - 1.8232422 + - 1.421875 + - -0.38916016 + - 1.7294922 + - 2.8515625 + - -0.71875 + - -2.0195313 + - 1.3427734 + - 2.3515625 + - 0.8647461 + - -1.6259766 + - -0.9580078 + - 0.50634766 + - 0.05996704 + - -0.2841797 + - -3.6992188 + - -1.28125 + - -1.3017578 + - 1.7587891 + - -0.9296875 + - 0.9707031 + - 0.14562988 + - 2.8203125 + - -0.19946289 + - -1.4619141 + - 8.03125 + - -2.1171875 + - 3.65625 + - -4.03125 + - 3.6367188 + - 4.2148438 + - -4.0703125 + - 1.1347656 + - 1.7832031 + - -0.21923828 + - -1.1455078 + - -0.35864258 + - -0.16906738 + - 1.8251953 + - -1.71875 + - -1.2568359 + - -1.7851563 + - 3.9589844 + - -0.72753906 + - 1.2275391 + - 0.44628906 + - -1.2568359 + - 0.9194336 + - -0.515625 + - -0.5131836 + - -1.1142578 + - 3.3339844 + - 0.8959961 + - -2.1777344 + - 1.6064453 + - -0.6953125 + - -2.7265625 + - 0.44482422 + - -2.1367188 + - -0.85253906 + - 2.6328125 + - 2.1464844 + - 2.1816406 + - -8.9609375 + - 4.40625 + - -0.578125 + - 0.32617188 + - 0.48632813 + - -3.5039063 + - 1.9033203 + - 0.44970703 + - -1.4980469 + - 1.4433594 + - -4.6289063 + - 0.4033203 + - -0.2097168 + - -0.4741211 + - 0.07739258 + - 0.23547363 + - 1.1494141 + - -0.3383789 + - -0.7475586 + - 0.73291016 + - 2.0761719 + - -2.421875 + - 1.4589844 + - -2.5488281 + - 1.5820313 + - 2.3574219 + - 0.77978516 + - 1.0751953 + - 1.9609375 + - -0.33642578 + - 0.08258057 + - -1.2607422 + - 4.4570313 + - 1.421875 + - 2.5390625 + - 1.0185547 + - -4.046875 + - 0.6635742 + - -0.4050293 + - -0.3876953 + - -0.26391602 + - 1.1337891 + - -0.93896484 + - 1.3505859 + - 6.3554688 + - 1.0771484 + - -8.7421875 + - 1.2646484 + - 1.3359375 + - -0.11853027 + - -0.98535156 + - 2.9433594 + - 6.1757813 + - -1.8076172 + - -0.09399414 + - -0.6176758 + - -1.4550781 + - 1.4707031 + - -0.77441406 + - 0.2220459 + - -0.23046875 + - -2.4199219 + - -0.43237305 + - -0.49902344 + - 4.078125 + - -1.9355469 + - -1.4414063 + - 0.12658691 + - 1.7949219 + - 3.6269531 + - 2.203125 + - 1.0576172 + - 0.4970703 + - 2.703125 + - 0.66748047 + - -24.875 + - 1.6738281 + - -4.6367188 + - -1.8183594 + - -15.671875 + - -1.2578125 + - -0.6875 + - 3.0644531 + - -3.7109375 + - 2.6074219 + - -7.5507813 + - -7.9296875 + - 0.8076172 + - -0.953125 + - 2.0195313 + - -1.1660156 + - 0.38110352 + - 4.4414063 + - -0.9458008 + - 1.5400391 + - 1.0097656 + - 2.0351563 + - 1.9921875 + - -2.9023438 + - -2.4785156 + - 3.6640625 + - -2.578125 + - 1.8388672 + - 1.6982422 + - -5.0117188 + - 1.9042969 + - -0.31152344 + - -0.0836792 + - 2.3574219 + - 0.6328125 + - -1.6601563 + - 1.71875 + - -1.8515625 + - 0.73095703 + - -0.04421997 + - 0.4597168 + - 0.034576416 + - 3.46875 + - 1.4013672 + - 0.056915283 + - 3.71875 + - 2.7539063 + - 1.515625 + - -1.0654297 + - -1.0966797 + - 1.7587891 + - -1.0693359 + - -2.015625 + - 2.0742188 + - 1.3916016 + - 3.1171875 + - -1.6464844 + - -4.7148438 + - 0.67529297 + - -2.6191406 + - 0.16125488 + - 2.4453125 + - -3.1289063 + - -0.6386719 + - -0.37548828 + - -0.41308594 + - -0.12719727 + - 4.5664063 + - 2.8710938 + - 1.4658203 + - -4.6757813 + - -0.140625 + - 3.0175781 + - 0.5756836 + - -0.4440918 + - 1.3955078 + - 0.27856445 + - -0.7294922 + - -1.0048828 + - 2.1171875 + - -3.4804688 + - -0.22387695 + - 1.3056641 + - -0.33764648 + - 0.57910156 + - 4.0429688 + - -0.57177734 + - 0.72314453 + - -1.4560547 + - -3.84375 + - 0.8569336 + - -1.7167969 + - 0.9316406 + - -1.5507813 + - -2.4707031 + - 0.9458008 + - -3.0820313 + - -8.6328125 + - 0.87353516 + - -3.7128906 + - 0.2854004 + - 2.3984375 + - 1.1992188 + - -3.4628906 + - 0.6176758 + - -3.5625 + - -1.8496094 + - -5.140625 + - -0.8227539 + - 0.005859375 + - -0.0052986145 + - 3.953125 + - -0.890625 + - 1.4560547 + - -3.1464844 + - -2.7402344 + - -1.1064453 + - 0.2019043 + - -0.8989258 + - -3.078125 + - 0.8232422 + - -2.5 + - -0.43896484 + - -0.1282959 + - 1.2353516 + - -0.3251953 + - 0.5102539 + - -3.4140625 + - -1.6064453 + - 0.57910156 + - -5.2148438 + - -2.2265625 + - 2.5878906 + - 5.3945313 + - 5.4765625 + - -0.2890625 + - 0.234375 + - 4.4335938 + - 3.2617188 + - -1.6669922 + - -0.90234375 + - -2.3027344 + - 0.3310547 + - 2.8554688 + - -1.0009766 + - -0.7446289 + - -0.61035156 + - -0.75390625 + - -2.0234375 + - -2.2988281 + - 2.4609375 + - -1.8125 + - 1.2353516 + - -0.21203613 + - -2.3457031 + - -0.0234375 + - 0.78027344 + - 1.3662109 + - -0.5136719 + - -0.7988281 + - 0.52685547 + - 2.2109375 + - -0.9453125 + - -1.5009766 + - -4.6523438 + - -0.0446167 + - 0.20629883 + - 3.40625 + - -0.46484375 + - 0.18688965 + - 2.3476563 + - 23.5 + - -0.89501953 + - -3.078125 + - 4.3554688 + - 0.5859375 + - 4.0507813 + - -2.0214844 + - -13.3359375 + - 1.4970703 + - -1.0517578 + - 4.7578125 + - 0.66796875 + - 0.11383057 + - 1.2236328 + - 0.84375 + - 2.2851563 + - 1.4814453 + - -4.9257813 + - 0.3095703 + - -4.7148438 + - 1.0253906 + - -3.7539063 + - 0.3647461 + - -0.20080566 + - -1.4785156 + - 3.5820313 + - -0.93603516 + - -2.2539063 + - 0.28979492 + - 3.0644531 + - -0.5317383 + - -0.69189453 + - 1.3955078 + - -1.6269531 + - -1.3457031 + - -2.0546875 + - -0.33032227 + - -0.26245117 + - -0.96191406 + - 0.11212158 + - -2.59375 + - 2.2695313 + - -1.0654297 + - -1.7246094 + - 1.9658203 + - -0.79833984 + - 0.2915039 + - 1.7851563 + - -3.4238281 + - 3.5742188 + - 1.0439453 + - -1.3769531 + - 5.90625 + - -2.6601563 + - -2.3691406 + - 0.82666016 + - 0.78759766 + - 2.9375 + - -2.3515625 + - 1.5 + - -2.4375 + - 3.8339844 + - 0.71240234 + - -1.1992188 + - -0.064697266 + - 6.109375 + - 3.3691406 + - -0.4128418 + - -1.7158203 + - -0.36547852 + - -1.1796875 + - -0.25268555 + - -0.30004883 + - -0.19189453 + - -2.7128906 + - -5.9140625 + - 6.5351563 + - 0.93652344 + - -2.375 + - -1.8955078 + - 1.6201172 + - 0.37719727 + - -0.3203125 + - -0.21618652 + - 0.5834961 + - 1.2314453 + - 0.7866211 + - 1.6142578 + - -3.2421875 + - 0.8457031 + - 1.3232422 + - -1.9501953 + - 0.4663086 + - 0.171875 + - 5.1757813 + - 2.1445313 + - -1.6201172 + - 4.75 + - -1.0703125 + - 2.4765625 + - 4.703125 + - -0.546875 + - -1.9902344 + - 5.75 + - 0.78759766 + - 0.38598633 + - -1.2539063 + - -0.17272949 + - 2.4550781 + - 1.6503906 + - -1.2587891 + - -1.6191406 + - -1.8496094 + - -0.71777344 + - -0.42578125 + - 0.38891602 + - 0.73339844 + - 0.124572754 + - 0.29614258 + - -2.078125 + - 2.2597656 + - 23.0625 + - -3.9101563 + - 2.9414063 + - -0.17468262 + - 0.92871094 + - 2.359375 + - 0.18408203 + - -2.0410156 + - 0.2841797 + - -0.84375 + - -1.4482422 + - 1.9472656 + - -2.3066406 + - -1.7001953 + - -0.2607422 + - 0.31054688 + - -5.1601563 + - 1.984375 + - 2.1582031 + - 14.546875 + - -2.6972656 + - 1.4003906 + - -0.11602783 + - -1.4023438 + - 0.2097168 + - -0.65283203 + - 0.63623047 + - 0.6635742 + - -0.21679688 + - -1.2744141 + - -26 + - -0.5024414 + - 0.55078125 + - 1.0732422 + - -2.9140625 + - -0.4934082 + - -0.6484375 + - 0.9169922 + - -2.46875 + - 0.9277344 + - 0.59472656 + - -3.8222656 + - -1.3505859 + - -0.8232422 + - -0.15454102 + - -1.0322266 + - -1.2919922 + - -2.9804688 + - 0.62353516 + - -0.2298584 + - -2.3261719 + - 0.8232422 + - 2.6308594 + - 0.26000977 + - 3.421875 + - -1.4072266 + - 3.1738281 + - -0.5625 + - 7.6953125 + - -1.9335938 + - 2.5839844 + - 4.0078125 + - -6.6484375 + - 2.421875 + - -2.1796875 + - 4.359375 + - -0.8208008 + - -0.51123047 + - -1.7314453 + - 0.5083008 + - 0.62841797 + - 0.9926758 + - -5.5351563 + - 2.9492188 + - -0.17919922 + - -2.4003906 + - 0.0287323 + - 2.7089844 + - 2.53125 + - 2.6328125 + - 2.5039063 + - -1.953125 + - -1.2744141 + - 1.8378906 + - 4.15625 + - 1.4326172 + - -1.4902344 + - -3.828125 + - -0.64501953 + - -4.1679688 + - -1.1298828 + - 2.1113281 + - 2.2246094 + - 3.640625 + - -1.1396484 + - 4.890625 + - 4.9960938 + - 2.046875 + - -0.7363281 + - -1.0830078 + - 0.77001953 + - -1.2724609 + - 1.3398438 + - -1.2626953 + - 1.3603516 + - -1.4814453 + - -2.6640625 + - 0.6230469 + - -3.5585938 + - -0.33764648 + - -3.3710938 + - -3.9375 + - -0.76416016 + - 0.515625 + - 3.0039063 + - -1.4169922 + - -0.14941406 + - 2.9160156 + - 0.7988281 + - 0.52783203 + - -2.7890625 + - 3.3554688 + - 2.0605469 + - -1.4150391 + - -3.3203125 + - 3.6054688 + - -0.5683594 + - 3.9394531 + - -2.7871094 + - -0.92089844 + - -1.0517578 + - 0.8227539 + - 3.4941406 + - 2.4726563 + - -0.17443848 + - 0.9404297 + - -3.7363281 + - -6.046875 + - -0.46191406 + - -1.4882813 + - 2.6621094 + - 2.6914063 + - 0.81933594 + - 1.0390625 + - 2.1582031 + - 0.5991211 + - -0.0715332 + - 2.3574219 + - -1.8457031 + - 2.953125 + - 1 + - -0.45532227 + - -0.33251953 + - -0.8066406 + - -0.6645508 + - 12.1953125 + - 0.5239258 + - 2.53125 + - 5.7851563 + - 7.796875 + - -1.2158203 + - 0.42822266 + - -1.0888672 + - 1.4638672 + - -2.6542969 + - -1.7939453 + - 1.3466797 + - 0.6689453 + - 0.30126953 + - -2.5625 + - -0.71875 + - 1.0185547 + - 1.890625 + - 1.9335938 + - 0.34350586 + - -0.17382813 + - -0.18469238 + - -0.78125 + - -1.9404297 + - -2.1035156 + - -1.4277344 + - 1.2451172 + - -0.46313477 + - -2.4238281 + - -3.4238281 + - 2.7890625 + - 2.1503906 + - 1.9921875 + - 1.015625 + - 0.2241211 + - -0.98291016 + - 1.9423828 + - -1.75 + - 0.74072266 + - 1.8212891 + - -1.4931641 + - 1.2539063 + - -1.7744141 + - -0.55615234 + - 3.9394531 + - -0.7192383 + - 1.7138672 + - -2.6484375 + - -1.0947266 + - -2.9023438 + - 3.21875 + - 1.0126953 + - -2.4042969 + - -1.1142578 + - 4.1015625 + - 1.8300781 + - 1.0361328 + - 1.5976563 + - 4.1875 + - 0.8457031 + - -1.8183594 + - -1.6669922 + - 1.4794922 + - 1.5244141 + - 1.203125 + - 4.1875 + - 2.5175781 + - 2.2617188 + - 1.9628906 + - -1.4160156 + - -0.6542969 + - -1.8525391 + - 1.2382813 + - 0.2019043 + - -0.050201416 + - -1.1044922 + - 0.3461914 + - 1.390625 + - 0.10290527 + - 3.0859375 + - -0.97753906 + - 0.08258057 + - 0.86376953 + - -0.26757813 + - 23.46875 + - -3.4707031 + - -1.1474609 + - -4.2460938 + - -0.22851563 + - 0.73583984 + - 2.34375 + - -0.092041016 + - -4.7851563 + - 1.6845703 + - 2.5976563 + - -1.359375 + - 3.3945313 + - 2.5351563 + - 1.9492188 + - 0.52001953 + - 1.6367188 + - -3.0742188 + - 1.7148438 + - 0.96191406 + - -2.2128906 + - 1.7011719 + - -3.6757813 + - 1.7763672 + - 0.0758667 + - 0.82177734 + - -2.2089844 + - 0.11645508 + - 2.3359375 + - -3.7753906 + - -0.76953125 + - 1.3154297 + - 2.078125 + - 2.1328125 + - 2.4160156 + - -1.5634766 + - 6.2851563 + - -0.03125 + - 0.32592773 + - -0.65625 + - -4.3359375 + - -3.5664063 + - 0.5019531 + - 4.9257813 + - 0.38012695 + - 0.20166016 + - -1.5683594 + - 1.7353516 + - 2.8164063 + - 3.9121094 + - -0.57470703 + - -1.8261719 + - 0.39379883 + - 8.6640625 + - -3.2226563 + - -1.2158203 + - 0.6328125 + - -1.2607422 + - 1.1367188 + - 0.51123047 + - 1.3037109 + - -0.11773682 + - -0.11462402 + - -4.2421875 + - -3.546875 + - -2.6640625 + - -3.1269531 + - -2.9941406 + - 0.49536133 + - -2.1972656 + - -1.2841797 + - 3.2851563 + - -0.7211914 + - -1.8222656 + - 0.68310547 + - -3.3378906 + - -4.3945313 + - -0.29614258 + - 2.0722656 + - -2.6777344 + - -0.19885254 + - 1.1748047 + - 2.1855469 + - 1.2265625 + - -1.1201172 + - -3.0878906 + - -1.4257813 + - -0.8696289 + - -2.9550781 + - 0.012275696 + - -0.5029297 + - -0.26831055 + - 4.1679688 + - -1.1015625 + - 2.6386719 + - -3.3066406 + - -2.3125 + - -1.2939453 + - -0.6850586 + - 1.2021484 + - -1.3095703 + - 1.4707031 + - 1.0224609 + - 0.8652344 + - 0.40429688 + - -1.2783203 + - -1.6054688 + - 1.5166016 + - -1.4238281 + - 1.6367188 + - 0.48046875 + - -0.32885742 + - 2.7402344 + - 0.9326172 + - 0.21398926 + - 1.2578125 + - -3.8359375 + - -2.6425781 + - -3.2421875 + - -1.3925781 + - 0.29956055 + - -0.22302246 + - 0.52734375 + - 1.0439453 + - 1.1669922 + - 1.2773438 + - -1.2041016 + - -2.421875 + - 1.2001953 + - 2.1035156 + - -2.71875 + - 2.1171875 + - 0.453125 + - 0.3317871 + - 1.2675781 + - 0.6713867 + - -5.578125 + - -3.3398438 + - -1.0908203 + - 1.5175781 + - 0.0262146 + - -2.25 + - -0.95703125 + - 4.9179688 + - -0.171875 + - 1.3681641 + - 6.5859375 + - 2.5625 + - -2.6875 + - 0.84033203 + - -0.055236816 + - 6.015625 + - -4.9648438 + - -2.1777344 + - 0.98876953 + - -2.1269531 + - -0.57470703 + - -2.3886719 + - 1.8857422 + - -3.3496094 + - 3.1972656 + - -1.1943359 + - 0.71972656 + - 0.15234375 + - -0.51708984 + - -1.1992188 + - 0.9658203 + - -0.23144531 + - -1.9414063 + - 5.9726563 + - 0.78759766 + - 2.4453125 + - -0.31518555 + - -4.4648438 + - 2.4316406 + - 0.24658203 + - 1.3349609 + - -0.71484375 + - -1.3564453 + - -0.7675781 + - 1.1240234 + - -2.0175781 + - -3.0800781 + - -0.032348633 + - 0.69873047 + - 1.7294922 + - 2.8203125 + - -2.3183594 + - 1.2373047 + - 0.30688477 + - -2.703125 + - 0.3466797 + - 3.5585938 + - 1.3242188 + - 5.7539063 + - 0.24804688 + - 0.0625 + - 16.203125 + - -0.41845703 + - 2.3027344 + - -3.5488281 + - -0.90771484 + - -0.89697266 + - 0.5410156 + - 1.4794922 + - 4.1484375 + - -0.92089844 + - -3.5253906 + - -1.8222656 + - 0.8720703 + - 1.9169922 + - 1.0517578 + - -1.1318359 + - 4.453125 + - -0.26391602 + - -0.66796875 + - 0.24523926 + - -1.6455078 + - 0.3034668 + - -1.5175781 + - -2.2949219 + - -1.6777344 + - 2.3652344 + - -0.2253418 + - -3.9960938 + - -3.1015625 + - 0.74316406 + - -0.99609375 + - -0.87890625 + - -1.8613281 + - -1.890625 + - 0.1751709 + - -0.083984375 + - 3.0117188 + - 0.75634766 + - 2.7890625 + - 0.2861328 + - 1.9648438 + - -4.5898438 + - 0.88720703 + - 0.65283203 + - -0.06890869 + - 4.2070313 + - -1.3691406 + - -1.3691406 + - -2.0625 + - -5.4882813 + - 2.1308594 + - 1.9013672 + - -0.30786133 + - 2.8808594 + - 4.703125 + - -1.6386719 + - -0.17785645 + - -3.8339844 + - -0.13439941 + - -1.8310547 + - -0.77441406 + - -1.1064453 + - 1.7431641 + - -2.7011719 + - -0.38720703 + - 1.0185547 + - 1.9091797 + - -4.953125 + - 3.3925781 + - 0.92626953 + - -0.5727539 + - -1.6923828 + - 4.6914063 + - 0.94384766 + - 1.1826172 + - 1.0126953 + - -1.9609375 + - -2.4472656 + - 1.6650391 + - 1.3632813 + - 2.3925781 + - 0.17211914 + - 4.7539063 + - -1.6230469 + - -1.1386719 + - 0.9663086 + - -1.5556641 + - -0.7675781 + - -1.5439453 + - 0.62353516 + - -4.34375 + - -0.8286133 + - 1.6669922 + - 1.9033203 + - -2.3789063 + - 2.5566406 + - -3.9316406 + - 2.6816406 + - 0.78759766 + - -0.73876953 + - 4.6054688 + - -0.89160156 + - -2.6074219 + - 1.9169922 + - 2.4316406 + - 3.3085938 + - 1.7695313 + - -1.0097656 + - -0.22338867 + - 0.45361328 + - 33.40625 + - 13.4765625 + - -9.1796875 + - 2.265625 + - -1.0507813 + - 1.4277344 + - -2.734375 + - -4.1757813 + - -0.36376953 + - -0.20703125 + - 1.9589844 + - 0.51464844 + - -0.34057617 + - 1.5166016 + - -2.7890625 + - 1.9707031 + - -1.0009766 + - 0.91259766 + - -2.6933594 + - 0.7138672 + - 1.8779297 + - 3.4140625 + - -1.3193359 + - -1.1445313 + - -0.2253418 + - -2.1523438 + - 0.08703613 + - -0.4038086 + - -4.6054688 + - 0.75097656 + - -0.119384766 + - -0.16101074 + - 1.4169922 + - 2.4785156 + - 1.6337891 + - -4.3789063 + - -1.8554688 + - 2.0644531 + - -2.1699219 + - 1.2451172 + - 2.2324219 + - 1.5371094 + - -0.27978516 + - 4.2304688 + - -1.2050781 + - 0.29345703 + - -3.4941406 + - 2.1425781 + - 1.3066406 + - 0.5107422 + - 2.2910156 + - 8.7265625 + - -0.5673828 + - -1.4306641 + - 1.7226563 + - -0.9453125 + - -0.84521484 + - 0.05606079 + - 1.4580078 + - 0.2175293 + - 2.9785156 + - 2.3984375 + - 1.2050781 + - -3.9238281 + - -1.7402344 + - -1.1376953 + - 1.9384766 + - -0.83203125 + - -2.6855469 + - 0.2565918 + - -2.9277344 + - -0.20385742 + - -1.5039063 + - -2.265625 + - 0.92822266 + - -2.6640625 + - -0.18579102 + - 1.3486328 + - 5.4453125 + - 0.41503906 + - -1.7626953 + - -1.4189453 + - 1.6337891 + - 1.8632813 + - 1.6875 + - 2.3808594 + - 1.1025391 + - 0.22314453 + - 1.9453125 + - -1.5341797 + - 1.3691406 + - 0.5053711 + - -0.8886719 + - -0.99902344 + - 3.6582031 + - 1.2080078 + - -1.3974609 + - 4.03125 + - -1.9023438 + - 0.5214844 + - -3.4609375 + - -1.0595703 + - 0.75097656 + - 1.15625 + - 0.11743164 + - 0.4892578 + - 0.32250977 + - -2.3222656 + - -0.081970215 + - 1.4853516 + - -3.2910156 + - 3.6777344 + - -0.69384766 + - 4.28125 + - 1.8076172 + - 2.8300781 + - -2.9140625 + - -1.3212891 + - 3.5175781 + - 0.42773438 + - -2.3886719 + - -1.8847656 + - 0.8803711 + - 1.109375 + - 3.6132813 + - 1.3603516 + - -3.2714844 + - 2.0566406 + - 2.4140625 + - 0.1307373 + - -0.87890625 + - -1.2529297 + - -1.1123047 + - 1.2490234 + - 0.28198242 + - 0.3125 + - -0.18469238 + - -3.4375 + - 1.5390625 + - -1.3007813 + - -0.4399414 + - 1.9648438 + - 1.7783203 + - -2.1347656 + - -0.296875 + - -0.17236328 + - 2.0097656 + - -1.2041016 + - -0.14453125 + - -4.1132813 + - 1.1660156 + - 1.3193359 + - -1.4667969 + - -1.4375 + - 0.4111328 + - -0.91552734 + - -1.1474609 + - 0.41748047 + - 0.4025879 + - 2.1621094 + - 0.09051514 + - -2.5625 + - 2.7890625 + - 1.7763672 + - -0.9404297 + - 0.4248047 + - 0.32739258 + - 2.3457031 + - -0.119506836 + - -2.5625 + - -0.5102539 + - -0.26660156 + - -2.6132813 + - -1.3476563 + - 0.5800781 + - 0.7158203 + - 1.4140625 + - 1.9658203 + - -1.1708984 + - -1.7529297 + - -0.59765625 + - 0.38500977 + - -0.5258789 + - 0.9008789 + - 1.5195313 + - -1.5722656 + - -0.06945801 + - 1.7695313 + - 1.7246094 + - -1.2783203 + - 2.3789063 + - 2.3203125 + - 1.78125 + - 0.7128906 + - -2.4902344 + - -1.8623047 + - 2.984375 + - 1.1738281 + - 0.92285156 + - -3.3925781 + - -2.7636719 + - -1.4267578 + - -2.8496094 + - -0.41601563 + - 0.39208984 + - -12.4453125 + - -0.31689453 + - -0.46142578 + - 0.21984863 + - -0.89160156 + - 0.5493164 + - -1.2490234 + - 1.6689453 + - 0.4597168 + - -1.7109375 + - 2.34375 + - -5.3710938 + - 0.48706055 + - 0.3251953 + - -1.1757813 + - 1.375 + - 1.5214844 + - -2.0566406 + - -0.022598267 + - 3.4277344 + - 0.61816406 + - 1.828125 + - -0.5341797 + - 9.390625 + - 1.4433594 + - -2.1386719 + - 0.72509766 + - -0.5239258 + - 0.89208984 + - -0.89160156 + - -0.083618164 + - -2.6601563 + - 6.7539063 + - 0.6816406 + - -1.7734375 + - 0.74072266 + - 1.0400391 + - -6.0976563 + - 0.71777344 + - 0.2915039 + - 1.3701172 + - 0.43798828 + - 6.2929688 + - -0.5932617 + - -2.7695313 + - 1.8964844 + - 2.2207031 + - 2.4609375 + - 2.1035156 + - 1.1425781 + - -2.8378906 + - 1.5439453 + - 1.7998047 + - -3.1582031 + - -1.0820313 + - -0.32714844 + - -0.43115234 + - -3.2050781 + - -1.8183594 + - -3.2753906 + - -0.1986084 + - -3.8652344 + - 2.4101563 + - -1.6953125 + - -1.7978516 + - 3.5683594 + - -2.4199219 + - 0.19494629 + - -1.6347656 + - -1.6376953 + - 2.0566406 + - -0.3552246 + - -1.3388672 + - 1.7587891 + - 1.6367188 + - -0.61572266 + - 0.6455078 + - 0.6113281 + - 2.1738281 + - 0.86376953 + - 3.7558594 + - 0.019104004 + - -0.2692871 + - -1.7851563 + - 2.6640625 + - 0.18725586 + - -2.0234375 + - -1.2880859 + - -1.5732422 + - -0.09063721 + - 5.2382813 + - 4.703125 + - -1.1416016 + - 1.9345703 + - 2.3378906 + - -0.7207031 + - -1.2539063 + - -0.4033203 + - 2.0351563 + - -1.9433594 + - 2.2792969 + - -3.4765625 + - 2.8359375 + - 0.7871094 + - -3.9589844 + - -0.11071777 + - -2.6660156 + - 3.2460938 + - 0.30151367 + - -5.5117188 + - -0.2685547 + - -1.7626953 + - 1.6542969 + - 0.42626953 + - 0.66503906 + - 3.4492188 + - 0.47387695 + - 1.28125 + - -0.3215332 + - -3.09375 + - -1.6669922 + - -0.59765625 + - -3.7890625 + - 8.9296875 + - 1.1962891 + - 1.4658203 + - -0.5292969 + - 0.5283203 + - -1.4980469 + - 0.4362793 + - 1.1601563 + - -1.2988281 + - -5.4726563 + - -3.3964844 + - 4.6328125 + - -4.1757813 + - 1.8066406 + - -1.8466797 + - -2.8164063 + - 1.296875 + - 0.8886719 + - -0.58203125 + - 0.27270508 + - 1.25 + - 1.1113281 + - -3.1777344 + - 0.07476807 + - -4.0429688 + - 1.7041016 + - -1.5908203 + - 1.2070313 + - -3.5976563 + - 0.81103516 + - -1.4306641 + - 0.9394531 + - -2.4980469 + - -1.0517578 + - 0.07281494 + - 2.2519531 + - 3.2441406 + - 0.49902344 + - 1.6640625 + - -1.6152344 + - 2.421875 + - 1.2851563 + - -0.71875 + - -1.1757813 + - -2.6894531 + - -0.24438477 + - 0.5205078 + - 2.5664063 + - -2.8769531 + - -0.093566895 + - -0.00390625 + - 4.234375 + - -0.012275696 + - -2.2246094 + - 0.36572266 + - 1.9814453 + - -2.2167969 + - -2.3164063 + - -0.9794922 + - 1.2119141 + - 1.9492188 + - -0.5366211 + - 0.7207031 + - -1.4638672 + - -0.29589844 + - 0.8256836 + - 3.0742188 + - -2.9179688 + - -2.7089844 + - 1.5957031 + - 1.8466797 + - 5.8125 + - 2.6308594 + - -1.5351563 + - 1.4619141 + - -0.5991211 + - 1.0800781 + - -1.6582031 + - -2.0136719 + - -0.91308594 + - 1.2207031 + - -1.9169922 + - 1.1708984 + - -1.0449219 + - 3.5253906 + - 4.34375 + - -0.51708984 + - 0.18188477 + - -0.23486328 + - -1.4326172 + - -3.3300781 + - -2.8691406 + - -0.890625 + - 1.3818359 + - -1.0712891 + - 0.85791016 + - 2.171875 + - 1.5488281 + - 1.4101563 + - -0.41503906 + - 0.8691406 + - -4.9179688 + - -0.90283203 + - -8.3046875 + - -1.7314453 + - -2.0175781 + - -2.2753906 + - -2.9023438 + - -0.96533203 + - 2.8378906 + - -6.7421875 + - -4.4335938 + - 24.671875 + - -1.7314453 + - -1.6464844 + - -0.65722656 + - -0.1796875 + - 0.51416016 + - 2.3203125 + - 3.0976563 + - -2.1542969 + - 1.1396484 + - 1.6914063 + - -0.0390625 + - 0.88378906 + - -1.4277344 + - 0.4267578 + - 0.08758545 + - -3.4179688 + - 0.72802734 + - 4.8867188 + - -0.75634766 + - -0.5488281 + - -1.4765625 + - -2.4765625 + - 0.65625 + - -0.3408203 + - 3.7578125 + - 0.36083984 + - -2.0878906 + - 2.2285156 + - -0.27612305 + - 1.5869141 + - -2.5488281 + - 0.7753906 + - 0.4025879 + - 1.2587891 + - -0.55908203 + - 1.6416016 + - 2.9863281 + - 4.1796875 + - 0.13830566 + - -0.85595703 + - -0.55566406 + - 2.0410156 + - -3.8964844 + - 0.77978516 + - -0.2824707 + - 3.2734375 + - 1.1845703 + - -2.0351563 + - 0.7270508 + - 2.3515625 + - 0.83691406 + - -3.1015625 + - -1.3193359 + - -2.0195313 + - -1.6425781 + - -2.9023438 + - -0.42871094 + - 2.3789063 + - -3.4550781 + - -2.8339844 + - 1.1816406 + - -0.5722656 + - 2.453125 + - -2.5 + - -0.10070801 + - -1.1962891 + - -0.010597229 + - -2.734375 + - 1.5898438 + - -4.609375 + - -4.359375 + - -0.1171875 + - -1.5556641 + - 1.4550781 + - 8.6328125 + - 0.89501953 + - 3.6816406 + - -4.7578125 + - 1.1894531 + - -0.67626953 + - 1.3095703 + - 0.9038086 + - 0.67626953 + - -0.16235352 + - -4.78125 + - 0.53125 + - 0.7607422 + - 2.5625 + - -0.83447266 + - -2.8378906 + - 0.44628906 + - -0.08538818 + - -0.5522461 + - -2.4765625 + - 1.4394531 + - 2.1074219 + - -2.5625 + - 5.3554688 + - 0.30908203 + - 0.36865234 + - 0.9243164 + - 0.52734375 + - 4.0117188 + - 0.27416992 + - 2.0800781 + - -1.8203125 + - -0.51904297 + - 0.5410156 + - 2.3886719 + - 7.1640625 + - 1.7148438 + - 1.0996094 + - -1.0556641 + - 3.5546875 + - 0.050476074 + - 1.7128906 + - 1.7871094 + - 2.2246094 + - -0.30566406 + - 3.09375 + - -0.69628906 + - 3.6015625 + - -4.4882813 + - -1.4697266 + - -2.0253906 + - 0.94189453 + - 0.001115799 + - 1.3408203 + - -0.42285156 + - 4.0742188 + - -1.9775391 + - -2.1054688 + - -0.84228516 + - 0.016174316 + - 2.9785156 + - 2.40625 + - 0.7363281 + - 1.1787109 + - 3.2851563 + - 4.1992188 + - 0.75634766 + - -0.5756836 + - 1.3769531 + - 2.0800781 + - -4.9882813 + - -4.578125 + - -0.9609375 + - 3.3125 + - -1.5917969 + - -0.75097656 + - -1.9638672 + - 2.8613281 + - 3.2753906 + - 3.265625 + - -0.8544922 + - -0.28344727 + - 1.3613281 + - -1.3515625 + - -0.44604492 + - 2.5839844 + - 2.6875 + - -0.9711914 + - -0.3581543 + - 0.4165039 + - 1.7861328 + - 0.39453125 + - -0.12207031 + - -0.35864258 + - 1.2529297 + - 2.140625 + - 0.9091797 + - -2.1191406 + - -0.3251953 + - -3.6425781 + - -4.8789063 + - -0.092163086 + - 2.5820313 + - -0.86035156 + - -0.36767578 + - 3.125 + - -2.1777344 + - 2.0097656 + - 0.5566406 + - -0.9897461 + - -2.9140625 + - 1.4013672 + - -0.5180664 + - 3.0625 + - 3.3476563 + - 1.2998047 + - -6.8359375 + - -0.47680664 + - -0.41845703 + - -5.390625 + - 2.1210938 + - -2.6621094 + - 2.4355469 + - 1.3867188 + - -6.4453125 + - 1.3076172 + - -0.65478516 + - -2.7988281 + - -2.4296875 + - 1.1220703 + - -0.37597656 + - 2.0761719 + - -0.4309082 + - -0.8129883 + - -33.875 + - -2.53125 + - -2.4140625 + - -0.3881836 + - -1.4277344 + - 2.09375 + - 2.4121094 + - -4.7539063 + - -4.6601563 + - -0.9038086 + - 1.1162109 + - -1.4375 + - -1.0976563 + - 6.7734375 + - 0.4885254 + - 4.7304688 + - -1.6601563 + - 4.3242188 + - -0.25097656 + - -1.4335938 + - 0.11437988 + - -0.45507813 + - 1.0791016 + - 1.8134766 + - -0.4350586 + - -4.0117188 + - -1.2519531 + - 0.053833008 + - 1.8681641 + - -0.36206055 + - 0.5722656 + - -1.265625 + - 0.3642578 + - -0.5629883 + - -3.4941406 + - 4.8632813 + - -3.3046875 + - -0.8071289 + - -2.328125 + - -3.4863281 + - 0.029571533 + - 1.9746094 + - 2.6328125 + - 0.01576233 + - 0.25268555 + - 1.7089844 + - 4.0039063 + - -0.63720703 + - 1.90625 + - -2.8339844 + - 2.6796875 + - -1.0927734 + - 0.26220703 + - -3.9238281 + - 3.0117188 + - 2.6074219 + - -2.9648438 + - 3.4550781 + - 2.6816406 + - 0.6645508 + - -1.0673828 + - -4.0117188 + - 3.0097656 + - 1.3544922 + - 1.5175781 + - -0.3876953 + - 0.039611816 + - -5.0078125 + - 0.8300781 + - 1.3789063 + - -2.2207031 + - 0.77441406 + - 2.6035156 + - 0.40454102 + - -0.56103516 + - 2.2070313 + - -1.4003906 + - -2.6953125 + - 0.8046875 + - 0.42114258 + - -1.2441406 + - 2.0878906 + - 0.47314453 + - 1.0439453 + - 3.0527344 + - 0.85058594 + - -1.2832031 + - 1.1123047 + - 2.0527344 + - 0.74658203 + - -2.3789063 + - 2.7949219 + - -1.0400391 + - 8.5703125 + - -1.4746094 + - 2.03125 + - -0.5991211 + - -0.8847656 + - -0.44628906 + - -0.66796875 + - 2.8222656 + - 0.049102783 + - 3.53125 + - 1.0810547 + - 2.125 + - -2.1464844 + - -2.4277344 + - 3.5800781 + - -0.17236328 + - 5.921875 + - -1.0566406 + - 5.921875 + - -2.0253906 + - -0.95410156 + - -1.4013672 + - 1.5019531 + - 0.3852539 + - 0.79003906 + - -1.5839844 + - 4.1132813 + - 2.96875 + - 2.4902344 + - 4.6875 + - -0.7216797 + - -2.0976563 + - 1.7167969 + - -1.4580078 + - -4.0742188 + - -3.1113281 + - 0.44921875 + - -4.3554688 + - -0.16064453 + - 1.7939453 + - 3.7304688 + - -1.1054688 + - -0.67529297 + - -30.3125 + - -0.85595703 + - -0.027618408 + - -0.6660156 + - 0.7626953 + - 3.5800781 + - 0.79296875 + - 1.8632813 + - 0.12609863 + - 2.0976563 + - 0.012275696 + - -0.1484375 + - -2.9160156 + - -2.2011719 + - 1.3662109 + - -2.3691406 + - 0.55859375 + - 0.073791504 + - -0.63134766 + - -1.5576172 + - 1.4433594 + - 10.890625 + - 3.125 + - -1.265625 + - 1.1884766 + - 0.94140625 + - -0.84814453 + - 2.3105469 + - 0.37841797 + - -2.6035156 + - 1.296875 + - 0.2529297 + - -2.203125 + - 0.34057617 + - 0.38110352 + - -2.0644531 + - -3.2285156 + - 0.17248535 + - -0.55126953 + - -1.90625 + - 5.6289063 + - 1.6572266 + - -1.2236328 + - 3.1679688 + - 1.0341797 + - 1.2763672 + - 0.0011701584 + - 3.1445313 + - 0.6489258 + - -1.7949219 + - 0.19189453 + - 3.5175781 + - -2.3945313 + - 2.4589844 + - -1.5351563 + - -2.0097656 + - -0.9692383 + - 4.3242188 + - 0.4519043 + - -4.0820313 + - 1.6386719 + - -0.49804688 + - -0.6801758 + - -1.8076172 + - -2.5019531 + - 0.077819824 + - -3.75 + - 0.7397461 + - 3.0078125 + - -6.9453125 + - 0.48876953 + - -1.3095703 + - -3.3691406 + - -3.0175781 + - 1.7734375 + - -0.8691406 + - -3.1191406 + - 0.06640625 + - 0.18615723 + - -0.3959961 + - -1.3349609 + - -0.6459961 + - 1.8984375 + - 1.75 + - 6.6757813 + - -1.4882813 + - -0.46704102 + - -1.2744141 + - -1.8183594 + - 2.0644531 + - -1.9638672 + - -0.7011719 + - 2.0664063 + - 0.15258789 + - 3.4492188 + - 0.890625 + - 0.921875 + - -1.0634766 + - 3.0039063 + - -0.6928711 + - 1.6298828 + - 0.5488281 + - -2.703125 + - -1.1425781 + - 0.41503906 + - -0.5839844 + - -0.2109375 + - 4.5625 + - 1.4433594 + - -0.11102295 + - -1.6738281 + - 4.5078125 + - -0.49682617 + - 2.0371094 + - -2.7558594 + - -1.8857422 + - 2.1015625 + - 2.515625 + - -0.82177734 + - 0.87597656 + - 1.6611328 + - -1.1982422 + - -1.96875 + - -1.2451172 + - 0.07476807 + - -0.46923828 + - -4.9023438 + - 0.047424316 + - -1.0195313 + - 3.3046875 + - 0.25048828 + - 0.66015625 + - -0.43066406 + - -0.13110352 + - 1.1132813 + - -0.35327148 + - -0.6738281 + - -0.47021484 + - -1.140625 + - -4.4179688 + - 0.7680664 + - 4.2070313 + - 0.112854004 + - 1.3613281 + - 1.8691406 + - 0.6191406 + - 3.9082031 + - -1.546875 + - 0.0418396 + - 2.265625 + - 2.2480469 + - 2.8027344 + - -1.9775391 + - 1.8564453 + - -1.6796875 + - 1.6044922 + - -2.3691406 + - 0.18969727 + - 1.0859375 + - 2.8300781 + - -0.6640625 + - 2.6914063 + - 2.7753906 + - 1.3164063 + - 2.5449219 + - -2.40625 + - 4.4960938 + - -2.4257813 + - -0.54003906 + - 1.7001953 + - -0.63427734 + - -2.5 + - 1.7324219 + - 0.1015625 + - -2.2871094 + - -1.5751953 + - -1.5019531 + - -1.6982422 + - -2.8789063 + - 3.1425781 + - 1.8701172 + - 1.7558594 + - -2.7441406 + - -0.32348633 + - -0.13171387 + - 2.4902344 + - 0.3330078 + - 2.4199219 + - -3.0214844 + - -0.18884277 + - 0.44799805 + - 1.0439453 + - 0.17492676 + - 4.0351563 + - -0.08843994 + - 1.4238281 + - -0.7919922 + - -1.9882813 + - -0.9272461 + - 1.3662109 + - 1.046875 + - 0.63427734 + - 1.2451172 + - -3.4550781 + - 0.17297363 + - 1.7441406 + - 0.62353516 + - -0.3647461 + - 1.515625 + - -1.1552734 + - -2.4160156 + - -5.5429688 + - -4.09375 + - 6.078125 + - -1.3701172 + - -0.91015625 + - 1.1992188 + - -1.7529297 + - 2.0800781 + - -1.6416016 + - -2.3925781 + - -3.8867188 + - -2.203125 + - -2.6425781 + - 0.7397461 + - 0.2734375 + - 1.4511719 + - -0.7939453 + - -1.1513672 + - 0.75683594 + - 0.1204834 + - -3.5039063 + - -1.7607422 + - -1.4775391 + - 3.1015625 + - 2.0839844 + - 6.2929688 + - -0.44384766 + - 2.5175781 + - -1.7080078 + - 1.8369141 +- - 1.3066406 + - -2.1523438 + - 0.703125 + - 0.2529297 + - 1.2626953 + - -1.46875 + - -0.19042969 + - -0.14892578 + - 3.3066406 + - -1.8222656 + - 1.0253906 + - -0.51953125 + - 0.8203125 + - 0.2109375 + - 1.1699219 + - 0.109680176 + - 1.5429688 + - 1.2597656 + - 2.3242188 + - -2.4765625 + - -1.4189453 + - -0.6923828 + - -0.0078125 + - 0.44189453 + - 2.7128906 + - 1.8183594 + - -0.043762207 + - 1.6103516 + - 0.77734375 + - 1.21875 + - 3.8847656 + - -0.7583008 + - 5.4765625 + - 1.6425781 + - -2.4707031 + - 1.5048828 + - -1.8222656 + - -1.1347656 + - -6.5820313 + - -0.45825195 + - 0.9609375 + - -1.4111328 + - 1.1171875 + - -1.0078125 + - -0.67578125 + - 1.3095703 + - 0.9667969 + - -3.625 + - 0.6777344 + - 2.6757813 + - 3.109375 + - -0.94970703 + - -3.96875 + - -79.125 + - -2.3476563 + - -1.6230469 + - 3.4257813 + - -1.3222656 + - -2.5878906 + - -10.5 + - -1.8828125 + - -0.7763672 + - -0.20166016 + - -0.38671875 + - 0.066223145 + - 0.24121094 + - -2.9160156 + - 2.1953125 + - -0.5649414 + - -0.8515625 + - -5.0117188 + - -1.8183594 + - -1.2324219 + - -2.1738281 + - -1.2753906 + - 0.38012695 + - 2.3984375 + - 1.7548828 + - 0.31445313 + - 0.1796875 + - 0.74609375 + - -1.5439453 + - -0.69970703 + - 1.3261719 + - -2.4179688 + - 3.9316406 + - -2.2070313 + - 0.7993164 + - 3.7070313 + - 2.0117188 + - -0.48486328 + - 2.3808594 + - 2.2070313 + - -26.71875 + - 0.13146973 + - -4.5546875 + - -5.8632813 + - -0.53515625 + - -0.08850098 + - -5.8359375 + - -1.0390625 + - -2.6054688 + - -6.5507813 + - -2.9179688 + - -1.4267578 + - -2.7207031 + - 1.1035156 + - -1.9316406 + - -1.3251953 + - 0.1217041 + - -0.5 + - 0.953125 + - 3.2734375 + - -1.8398438 + - -1.109375 + - 5.4570313 + - 2.2636719 + - 1.78125 + - -2.0039063 + - 0.7607422 + - 3.6132813 + - 1 + - -2.1503906 + - 0.3461914 + - -0.95410156 + - -0.73535156 + - 3.3984375 + - -1.7480469 + - 0.08428955 + - 2.4414063 + - 1.2148438 + - 1.2958984 + - -2.2597656 + - 1.1669922 + - 0.5546875 + - 0.6875 + - 1.953125 + - 0.578125 + - -2.1875 + - -1.5830078 + - 1.1005859 + - -0.66015625 + - 2.1269531 + - 0.39160156 + - 2.5273438 + - 0.61035156 + - -1.8222656 + - -1.2480469 + - -3.453125 + - 5.515625 + - 1.0234375 + - -1.2080078 + - -2.0703125 + - -0.7324219 + - -0.64697266 + - -4.796875 + - -1.21875 + - -0.30126953 + - -1.1337891 + - -2.234375 + - -0.036132813 + - -1.7109375 + - -3.625 + - -2.1074219 + - -0.7133789 + - 0.78759766 + - 1.7910156 + - -0.48364258 + - -0.57128906 + - -1.4111328 + - -1.8066406 + - -1.0322266 + - -0.9736328 + - -1.2832031 + - -1.4316406 + - -0.91503906 + - -2.0410156 + - 3.2207031 + - -1.1191406 + - -0.4609375 + - 3.4726563 + - 0.73046875 + - -1.2910156 + - -1.8994141 + - -0.70166016 + - 5.953125 + - -4.6757813 + - -0.33642578 + - -1.3808594 + - -1.0087891 + - -3.4550781 + - -2.0703125 + - -0.11456299 + - 1.4150391 + - -2.3164063 + - 4.3203125 + - 3.0625 + - -3.1289063 + - -1.7910156 + - 2.7265625 + - 0.49414063 + - -3.7148438 + - 1.8212891 + - 0.04296875 + - 1.7988281 + - 6.609375 + - 7.0976563 + - -2.7851563 + - -2.453125 + - -4.2226563 + - -2.7910156 + - -0.026031494 + - -2.6015625 + - -0.49658203 + - 0.26220703 + - 3.2597656 + - 1.1660156 + - -2.0742188 + - 6 + - 1.4511719 + - -2.2148438 + - 2.4785156 + - -3.1953125 + - 2.0566406 + - -0.5751953 + - -2.5722656 + - 1.0351563 + - 1.0371094 + - 0.7368164 + - -0.65478516 + - 2.015625 + - -0.5395508 + - -0.77197266 + - -1.8203125 + - -0.59814453 + - 0.77197266 + - 2.0957031 + - 2.0429688 + - 4.4296875 + - 0.26733398 + - 11.1640625 + - 0.024246216 + - -6.1328125 + - -0.7373047 + - -1.765625 + - -1.8984375 + - 5.2890625 + - 3.6191406 + - -0.52685547 + - 0.5571289 + - -0.6923828 + - -0.18676758 + - -2.1582031 + - -1.0644531 + - -1.4501953 + - -0.65527344 + - -3.2617188 + - -1.4257813 + - -2.375 + - 2.4433594 + - 0.8105469 + - -0.2290039 + - 3.6132813 + - 1.6386719 + - 0.17578125 + - -0.28222656 + - -3.4179688 + - 5.8007813 + - -0.8408203 + - -7.125 + - 0.4477539 + - 1.1816406 + - 2.8007813 + - -1.1210938 + - 1.6542969 + - 0.024734497 + - -3.390625 + - 2.2402344 + - 0.5571289 + - -0.67089844 + - -3.1210938 + - -0.091796875 + - 1.8320313 + - 2.421875 + - -0.43115234 + - -0.41845703 + - 1.9492188 + - -1.0253906 + - 1.8066406 + - -1.1699219 + - -0.04067993 + - -1.3125 + - 18.59375 + - -0.49267578 + - 2.1640625 + - -1.1904297 + - 2.046875 + - -2.9882813 + - 3.0351563 + - 0.070129395 + - -0.2932129 + - 0.14709473 + - 3.140625 + - 0.6411133 + - -1.734375 + - 1.0273438 + - 3.25 + - 0.66796875 + - -0.24633789 + - 1.0820313 + - -0.81152344 + - 2.8691406 + - -0.22851563 + - 0.8828125 + - -0.84765625 + - -3.078125 + - -0.53466797 + - -1.3183594 + - -2.9101563 + - 2.5097656 + - 0.9892578 + - -0.7841797 + - 1.0058594 + - 2.09375 + - -0.4638672 + - -0.27783203 + - -1.4726563 + - -0.58935547 + - 1.0644531 + - -3.0273438 + - -5.5820313 + - 2.59375 + - 0.8964844 + - 1.4658203 + - 2.8945313 + - -2.796875 + - 3.1347656 + - 0.73535156 + - -0.9921875 + - 0.6640625 + - 2.2148438 + - -0.47998047 + - 2.6660156 + - 0.028152466 + - 0.88671875 + - 1.6191406 + - 0.18554688 + - 1.1972656 + - -4.5234375 + - -0.7114258 + - 1.9296875 + - -0.3076172 + - 1.2744141 + - -0.19140625 + - -3.65625 + - -0.27856445 + - -5.1523438 + - -2.9882813 + - -1.6640625 + - -1.6660156 + - -1.7089844 + - 2.65625 + - 3.1875 + - -2.65625 + - 2.140625 + - -2.5976563 + - 5.9453125 + - 0.00032544136 + - 0.24072266 + - -2.453125 + - 0.00390625 + - -3.0390625 + - -2.8125 + - 2.1640625 + - 0.04296875 + - 3.2910156 + - -3.5351563 + - 1.5039063 + - -0.6879883 + - 2.1210938 + - -0.13867188 + - 1.2568359 + - -2.7675781 + - -1.9736328 + - -3.2578125 + - 2.8164063 + - 4.2734375 + - -2.6953125 + - -8.328125 + - -1.2773438 + - 0.95214844 + - -1.4785156 + - -2.8066406 + - -2.5625 + - 0.31762695 + - -0.07287598 + - 2.9238281 + - 1.5556641 + - -1.234375 + - 1.2900391 + - -3 + - 3.5097656 + - 1.1171875 + - -4.359375 + - 3.1347656 + - 2.8691406 + - 4.7421875 + - -2.5039063 + - -11.0078125 + - 0.47558594 + - -2.21875 + - -2.3964844 + - 2.8046875 + - -2.3085938 + - -0.24182129 + - 2.6953125 + - -3.296875 + - -2.3847656 + - -3.3535156 + - 4.9257813 + - -2.2988281 + - 0.1973877 + - -0.5859375 + - 0.66308594 + - 0.53564453 + - 0.9667969 + - 1.984375 + - 2.1015625 + - 2.3496094 + - -1.4863281 + - -1.3291016 + - -1.640625 + - -3.546875 + - -1.1943359 + - -0.7705078 + - -2.5976563 + - 3.5039063 + - -2.75 + - 0.234375 + - 3.1796875 + - -4.5703125 + - -1.8574219 + - -0.6586914 + - -3.6054688 + - -2.5800781 + - -0.04034424 + - 0.48876953 + - 1.9150391 + - -2.6191406 + - -4.1875 + - 1.2519531 + - 0.5439453 + - -0.16992188 + - -2.0195313 + - -0.70751953 + - 5.5 + - 6.0625 + - 1.9619141 + - 4.0234375 + - 2.5332031 + - -0.94384766 + - -3.8242188 + - -2.4726563 + - 2.765625 + - -2.5703125 + - 0.14868164 + - 2.1289063 + - -0.029937744 + - -0.19921875 + - -1.5585938 + - 6.5546875 + - 1.2070313 + - 2.3320313 + - 1.4941406 + - 0.030761719 + - 0.42529297 + - 0.30664063 + - -2.03125 + - -0.46142578 + - 3.5019531 + - -0.21740723 + - -0.52441406 + - -1.015625 + - -4.1601563 + - -1.5078125 + - 0.44873047 + - -8.125 + - 0.90625 + - 2.7226563 + - -0.7109375 + - 1.4423828 + - 2.125 + - -2.3691406 + - -1.2714844 + - -0.7314453 + - -0.96484375 + - 3.7441406 + - -3.65625 + - -2.484375 + - 2.5 + - 0.27734375 + - -4.84375 + - 2.875 + - 0.5957031 + - 0.23510742 + - 3.8359375 + - -3.4023438 + - -0.5209961 + - -3.359375 + - 3.0253906 + - 2.9003906 + - -2.6640625 + - -0.9140625 + - 2.1484375 + - -3.6914063 + - 0.123535156 + - -2.3554688 + - 0.50146484 + - -1.9921875 + - -0.22851563 + - 0.5620117 + - 1.7978516 + - 3.9921875 + - -0.01626587 + - -0.1796875 + - 2.0039063 + - 1.5117188 + - -2.890625 + - 0.7758789 + - 3.7070313 + - 0.9814453 + - 0.9794922 + - -0.5517578 + - -0.6455078 + - 2.3554688 + - -0.01953125 + - -2.6328125 + - 1.1054688 + - -2.5917969 + - -3.5273438 + - -1.4472656 + - -1.1289063 + - 2.1367188 + - -2.8125 + - -4.703125 + - -3.0390625 + - -0.091796875 + - 1.2519531 + - 2.8691406 + - 3.484375 + - 2.6757813 + - 0.5048828 + - -7.5664063 + - -2.5976563 + - -1.0341797 + - -2.0488281 + - -0.90234375 + - 1.21875 + - 0.26953125 + - 3.9453125 + - 2.328125 + - -4.9609375 + - -1.1132813 + - -2.7910156 + - 2.3945313 + - -1.1445313 + - 1.0087891 + - -0.83447266 + - -2.4648438 + - -0.38891602 + - -3.0117188 + - 0.21484375 + - -0.48168945 + - 2.1523438 + - 0.15002441 + - -2.8925781 + - 1.7236328 + - 0.44360352 + - 3.9707031 + - -1.6025391 + - -2.2929688 + - 0.46020508 + - 0.028640747 + - -2.1523438 + - -1.9892578 + - 1.4970703 + - 2.3457031 + - -0.55859375 + - -3.0625 + - -1.9150391 + - 0.8359375 + - -4.4101563 + - -0.057281494 + - -0.71777344 + - -0.5722656 + - 1.0957031 + - 2.4804688 + - 1.4980469 + - 3.0410156 + - 2.765625 + - -0.54296875 + - 0.7167969 + - -0.38964844 + - 0.04360962 + - -2.7753906 + - 0.73828125 + - -4.2109375 + - 0.7705078 + - -3.4160156 + - 1.1552734 + - 3.4472656 + - -4.21875 + - -1.2353516 + - 0.2746582 + - -1.8798828 + - -1.2822266 + - 0.84765625 + - 4.1015625 + - -0.5810547 + - -0.74316406 + - 3.453125 + - 3.3007813 + - 1.7714844 + - -0.7939453 + - -1.4003906 + - 1.6298828 + - 0.5395508 + - 1.3300781 + - 1.0800781 + - 0.8129883 + - 3.5078125 + - 0.4074707 + - -3.0820313 + - -2.296875 + - 1.3847656 + - 1.1904297 + - -1.0195313 + - -1.5390625 + - -0.69384766 + - 0.39990234 + - -3.1875 + - -2.2578125 + - -0.12902832 + - 0.36132813 + - 5.0039063 + - -0.61376953 + - -0.73291016 + - -1.8564453 + - 7.7382813 + - -3.71875 + - 2.96875 + - 1.3554688 + - -5.4609375 + - -3.0410156 + - 2.6503906 + - -0.4189453 + - -1.3085938 + - 0.0390625 + - 2.78125 + - -0.47607422 + - 1.9746094 + - 2.7519531 + - 2.3769531 + - 2.3945313 + - -1.4921875 + - 3.109375 + - 5.734375 + - -2.0976563 + - 1.2939453 + - 3.3359375 + - -3.3144531 + - -1.0683594 + - 0.3671875 + - -0.02017212 + - 0.77734375 + - -4.2382813 + - 0.35351563 + - -1.6689453 + - -0.40673828 + - 2.2109375 + - 1.5234375 + - 1.8798828 + - 1.8173828 + - -2.5605469 + - 6.0390625 + - 3.828125 + - 4.6328125 + - 2.7285156 + - -5.1875 + - -4.4101563 + - -1.4423828 + - -1.8642578 + - 0.46923828 + - -1.4111328 + - 0.05987549 + - -0.39941406 + - -1.3876953 + - 2.8222656 + - -3.46875 + - 1.0136719 + - 4.4101563 + - 6.9453125 + - 1.0126953 + - -2.71875 + - 0.9794922 + - 7.3203125 + - 2.2539063 + - 0.49658203 + - -0.67871094 + - -3.296875 + - 0.38500977 + - 1.3925781 + - -0.42626953 + - -3.1289063 + - 0.78515625 + - 4.7421875 + - 3.6015625 + - 0.7763672 + - 0.049621582 + - 1.4746094 + - 3.625 + - -0.47192383 + - 2.3632813 + - -5.40625 + - 0.7128906 + - -3.8945313 + - -1.4023438 + - -3.8359375 + - 1.1113281 + - 0.042297363 + - 3.78125 + - 1.6738281 + - -1.609375 + - 2.7207031 + - 1.1787109 + - -3.2285156 + - -3.4550781 + - 0.21582031 + - -3.3847656 + - -3.75 + - 3.0039063 + - -2.6367188 + - 2.1953125 + - 3.9414063 + - -1.2861328 + - -2.6171875 + - -2.7128906 + - 0.99658203 + - 1.4394531 + - -0.3371582 + - 1.3027344 + - -0.4399414 + - 2.7578125 + - 0.38012695 + - -0.80566406 + - -0.5805664 + - -2.9101563 + - 1.9453125 + - 0.02734375 + - -0.24279785 + - -2.90625 + - -2.3476563 + - 3.9804688 + - -1.3994141 + - 0.4699707 + - -1.8886719 + - 2.40625 + - -1.8144531 + - -2.8046875 + - -1.7939453 + - -0.06768799 + - 2.1445313 + - 0.60546875 + - -1.5830078 + - -0.48486328 + - 3.7910156 + - 0.011062622 + - 1.453125 + - 3.6347656 + - -2.609375 + - 2.3496094 + - -0.98828125 + - -4.1445313 + - -2.1210938 + - -1.0595703 + - 3.1601563 + - -2.0371094 + - 4.6328125 + - 1.4697266 + - 1.0527344 + - 0.29003906 + - -1.2949219 + - 0.875 + - 2.2636719 + - -0.86572266 + - -0.8051758 + - -0.8642578 + - -0.5673828 + - -1.8525391 + - -3.09375 + - 2.2988281 + - -5.9726563 + - -3.4921875 + - -4.34375 + - 1.7275391 + - 4.8203125 + - 1.8798828 + - -1.0244141 + - 0.47314453 + - 3.2109375 + - -0.9238281 + - -4.3125 + - -0.35668945 + - 0.37109375 + - -2.796875 + - -1.0546875 + - 5.34375 + - 2.2519531 + - -0.37158203 + - 0.5292969 + - -1.9462891 + - 1.5556641 + - 2.5175781 + - -1.3378906 + - 0.7993164 + - 3.6796875 + - -2.2441406 + - -1.6298828 + - 1.9345703 + - -0.6977539 + - -0.5083008 + - 1.5673828 + - -1.5605469 + - -9.109375 + - 1.8837891 + - -34.78125 + - 1.3105469 + - -0.103149414 + - -1.1875 + - -4.9765625 + - 1.0761719 + - 0.13500977 + - 0.5058594 + - 1.7402344 + - 0.8461914 + - 0.7192383 + - -1.0214844 + - 5.6796875 + - -0.13208008 + - -0.94921875 + - 2.671875 + - 0.30297852 + - -1.2099609 + - 12.359375 + - -3.2695313 + - -0.25585938 + - -0.054016113 + - 0.5961914 + - -0.43896484 + - 0.040039063 + - 6.9609375 + - -1.2011719 + - -1.4970703 + - 1.1767578 + - -2.3085938 + - -1.6259766 + - -3.5644531 + - 1.71875 + - 1.3642578 + - -1.265625 + - 2.4648438 + - 0.8828125 + - -0.21289063 + - 1.4453125 + - -1.09375 + - 3.5644531 + - -2.21875 + - -0.5566406 + - -0.55029297 + - -2.71875 + - 2.5644531 + - -0.98095703 + - 1.7158203 + - 1.4765625 + - -2.6171875 + - 0.5673828 + - -3.3632813 + - 0.09112549 + - -2.0703125 + - -1.0898438 + - 0.40039063 + - 4.875 + - -2.7441406 + - 0.22814941 + - 0.11846924 + - 0.8798828 + - -1.6914063 + - -2.1640625 + - 0.18225098 + - 7.140625 + - -0.023101807 + - 1.1025391 + - 4.8828125 + - -2.0175781 + - 2.109375 + - -1 + - -0.2578125 + - 1.65625 + - -2.5703125 + - 3.0019531 + - -2.7304688 + - 0.52197266 + - 0.45825195 + - 2.9921875 + - 0.4621582 + - -3.1210938 + - -3.3046875 + - 2.5996094 + - 0.71728516 + - 3.1191406 + - 2.5332031 + - -3.1132813 + - -0.6665039 + - 1.0673828 + - -1.2158203 + - 1.890625 + - -1.8837891 + - -0.33325195 + - -2.2519531 + - 0.7036133 + - -0.5732422 + - -3.0039063 + - 0.17382813 + - -1.0527344 + - -1.3515625 + - -2.8925781 + - -5.5546875 + - -1.2675781 + - -0.6269531 + - 0.14086914 + - 3.40625 + - 3.8125 + - 0.027496338 + - 2.4101563 + - 0.11578369 + - 1.0292969 + - 0.5839844 + - -3.0976563 + - 4.7382813 + - 0.32885742 + - 2.6835938 + - -0.51708984 + - 3.2363281 + - -1.53125 + - 3.2910156 + - 1.8261719 + - -0.6567383 + - -1.8789063 + - -1.4707031 + - 0.6298828 + - 3.1035156 + - 2.4707031 + - -0.15686035 + - 0.28808594 + - 2.7851563 + - 3.125 + - 1.9501953 + - -1.8330078 + - 1.6298828 + - 0.8754883 + - -0.6196289 + - 3.0664063 + - -1.8173828 + - -3.4101563 + - 0.859375 + - -0.61328125 + - -1.0517578 + - -2.4921875 + - -2.8378906 + - 1.5820313 + - -1.5546875 + - 3.2910156 + - -2.1308594 + - 0.8564453 + - -3.296875 + - 0.09240723 + - -1.2421875 + - 0.74072266 + - 4.7695313 + - -0.0982666 + - -0.59228516 + - 0.45825195 + - -2.6972656 + - 4.3203125 + - -2.3066406 + - 2.21875 + - -4.6015625 + - -5.1171875 + - -0.2705078 + - -2.2597656 + - -0.6220703 + - -4.3164063 + - -14.125 + - 0.76416016 + - -0.33007813 + - 6.03125 + - 2.125 + - 2.6347656 + - 0.8642578 + - 1.6621094 + - -0.38916016 + - 0.22521973 + - -1.3671875 + - -2.5566406 + - 1.9296875 + - 3.03125 + - 0.859375 + - -4.3398438 + - 1.1103516 + - -1.6923828 + - 0.54003906 + - 0.30200195 + - 2.8222656 + - 1.9316406 + - 1.0556641 + - 2.0976563 + - 2.4023438 + - -2.8769531 + - 0.9243164 + - -1.2138672 + - -20.15625 + - 1.4511719 + - -0.03125 + - 0.9589844 + - -2.6992188 + - -1.0195313 + - 1.3925781 + - 0.34179688 + - -3.6875 + - -3.0175781 + - -0.3359375 + - 1.4033203 + - -1.140625 + - -1.1269531 + - -1.1074219 + - 1.0742188 + - -6.1171875 + - -0.8149414 + - 0.15356445 + - -0.53222656 + - 1.6142578 + - 0.95166016 + - 3.1582031 + - -1.6103516 + - -0.7763672 + - 1.5488281 + - -5.1132813 + - -0.63720703 + - 1.2666016 + - -0.25048828 + - 4.2421875 + - -3.3457031 + - 0.8129883 + - 0.28076172 + - 0.28637695 + - 4.4453125 + - 0.453125 + - 1.8876953 + - -15.375 + - 0.6738281 + - -1.4277344 + - 1.5019531 + - -3.5664063 + - 2.2441406 + - -1.1171875 + - 1.8828125 + - 1.7548828 + - -0.8828125 + - 2.3339844 + - -2.0078125 + - 0.8935547 + - -0.69628906 + - 0.10107422 + - -1.4277344 + - 1.234375 + - 3.796875 + - 2.2988281 + - -5.8632813 + - -2.6738281 + - 2.9316406 + - -3.5800781 + - 0.058898926 + - 2.8007813 + - -0.007484436 + - 4.3828125 + - -2.140625 + - -3.0820313 + - 1.2695313 + - 1.8994141 + - 1.8564453 + - -0.27270508 + - -0.09033203 + - -2.21875 + - -0.3930664 + - -1.734375 + - -0.4819336 + - 0.97558594 + - 1.6064453 + - 5.0664063 + - 0.82910156 + - 1.2167969 + - 2.671875 + - -2.7382813 + - 2.1132813 + - -2.8320313 + - -1.8486328 + - -1.7109375 + - 1.9003906 + - 6.0820313 + - 1.2011719 + - 1.7392578 + - -2.7890625 + - -1.9960938 + - 2.4023438 + - 1.515625 + - 4.5390625 + - 1.1542969 + - -3.7695313 + - 2.203125 + - -0.9223633 + - 1.0097656 + - 1.5361328 + - -1.9609375 + - -3.4316406 + - 1.15625 + - 2.15625 + - -0.53125 + - 0.9609375 + - 0.53515625 + - -3.2910156 + - 2.3496094 + - -0.46484375 + - 2.5195313 + - 3.8847656 + - 0.37109375 + - -0.8173828 + - 3.7128906 + - 1.5595703 + - -2.5234375 + - -2.140625 + - 3.734375 + - -0.25878906 + - 2.7207031 + - -3.15625 + - 0.640625 + - 1.7597656 + - -2.0703125 + - 1.5878906 + - 4.65625 + - -2.2460938 + - -1.2089844 + - 0.4621582 + - 0.23046875 + - 0.65234375 + - 2.0859375 + - 1.1845703 + - 4.453125 + - 0.6455078 + - -1.2285156 + - -2.4882813 + - -2.3222656 + - 2.375 + - 0.95703125 + - 0.7109375 + - 0.83447266 + - -1.1503906 + - -4.890625 + - -0.58935547 + - 3.8535156 + - -3.0878906 + - -0.23120117 + - -2.2773438 + - -0.82421875 + - 3.7207031 + - 5.15625 + - -0.5644531 + - -3.6894531 + - 0.49169922 + - -1.1660156 + - -0.7832031 + - -1.6738281 + - 1.171875 + - -4.4453125 + - 1.03125 + - 2.7285156 + - 7.9257813 + - -1.6503906 + - 1.8007813 + - -0.10284424 + - 0.84765625 + - -1.7128906 + - -3.0039063 + - 5.2109375 + - -1.3691406 + - 3.3125 + - 3.4570313 + - -2.9375 + - -1.640625 + - -5.34375 + - 2.0117188 + - 1.3642578 + - -0.19213867 + - -2.0703125 + - -3.9003906 + - 3.3359375 + - -1.1699219 + - -1.5244141 + - 1.2226563 + - 0.6279297 + - 0.15734863 + - 2.0175781 + - 5.6484375 + - -0.7236328 + - -1.1660156 + - 0.6064453 + - 3.34375 + - 1.7587891 + - -0.8173828 + - -4.1953125 + - -2.0117188 + - -1.7128906 + - 0.82910156 + - 1.3769531 + - 4.546875 + - -1.8222656 + - 2.21875 + - 1.09375 + - -2.6308594 + - 4.1640625 + - 1.5439453 + - 0.26367188 + - -1.7441406 + - -3.578125 + - 3.9882813 + - -3.328125 + - 0.90722656 + - -2.671875 + - 2.7753906 + - 2.3183594 + - -1.0273438 + - -0.5024414 + - 1.0234375 + - -2.6289063 + - 2.1738281 + - -0.72265625 + - 3.3769531 + - -0.25805664 + - 6.3945313 + - -2.5878906 + - -2.703125 + - 1.6796875 + - -1.2431641 + - 7.5664063 + - 3.5898438 + - 0.035949707 + - 0.5727539 + - -0.50683594 + - -0.36083984 + - 4.1171875 + - -0.6035156 + - 0.020828247 + - -0.05987549 + - 0.39941406 + - 2.5273438 + - -1.7587891 + - -2.0585938 + - -1.0625 + - -4.734375 + - 2.828125 + - -3.1738281 + - -2.3417969 + - 0.9707031 + - 1.2626953 + - -5.4726563 + - -1.2929688 + - -0.06347656 + - 1.7470703 + - 0.00504303 + - -1.1835938 + - 1.6425781 + - 0.033233643 + - -2.4277344 + - 3.703125 + - -0.30297852 + - 2.53125 + - 2.7460938 + - -3.7070313 + - -0.54589844 + - 2.6015625 + - -5.0039063 + - -0.7246094 + - -0.12365723 + - 0.7236328 + - 1.2978516 + - -1.3496094 + - -1.1367188 + - 1.421875 + - -0.7368164 + - 4.34375 + - -0.6015625 + - 4.796875 + - -0.0065078735 + - 2.9765625 + - -3.8984375 + - -2.9101563 + - -1.9511719 + - -3.1132813 + - -0.38012695 + - 0.099609375 + - 2.0019531 + - -3.5 + - -0.16796875 + - -0.16796875 + - -4.3671875 + - 2.1914063 + - 1.4023438 + - 1.7861328 + - -1.3066406 + - -0.28515625 + - -12.4453125 + - 4.234375 + - 2.2773438 + - 0.4777832 + - 2.3027344 + - -1.7939453 + - -3.65625 + - -0.48291016 + - 0.83447266 + - 3.3320313 + - -1.3720703 + - -0.60253906 + - 3.6035156 + - 2.3222656 + - 0.12719727 + - -3.0273438 + - -3.0878906 + - -0.09765625 + - -1.046875 + - -3.7695313 + - -1.5283203 + - 0.57910156 + - -1.3457031 + - 2.0332031 + - -0.2524414 + - -4.9101563 + - 3.1757813 + - 5.2890625 + - 0.6801758 + - 2.0097656 + - 0.36767578 + - 1.0224609 + - 3.0175781 + - 1.7402344 + - 2.921875 + - -0.20898438 + - 0.8227539 + - -1.5205078 + - 1.2421875 + - 1.5644531 + - 2.0195313 + - 1.1933594 + - -2.1523438 + - 5.171875 + - 2.9472656 + - -4.359375 + - -74 + - 2.8378906 + - -2.5117188 + - -1.3486328 + - 2.9960938 + - 4.4765625 + - 1.4042969 + - 1.7890625 + - -0.31225586 + - -3.9003906 + - 0.15649414 + - 0.43408203 + - 1.59375 + - 1.7929688 + - 0.35351563 + - -4.7421875 + - -1.1943359 + - -4.5 + - 0.43603516 + - 1.1015625 + - 2.3300781 + - 0.76416016 + - 1.6015625 + - 0.009765625 + - -3.1367188 + - -3.609375 + - 0.69384766 + - -2.5351563 + - 2.0429688 + - -0.9970703 + - 0.6977539 + - -4.625 + - 1.1503906 + - -1.109375 + - -2.8691406 + - 0.057617188 + - 2.0605469 + - -0.8798828 + - -0.65625 + - -3.3476563 + - 1.0224609 + - 1.2070313 + - 2.9316406 + - 2.0273438 + - 0.46044922 + - -1.5625 + - 0.9404297 + - 0.9863281 + - 1.1357422 + - 0.92871094 + - 3.03125 + - -0.49072266 + - -0.23156738 + - 19.15625 + - -5.7578125 + - 2.671875 + - -0.35791016 + - 2.8398438 + - 1.6015625 + - -7.21875 + - 4.8789063 + - 3.3378906 + - -0.16320801 + - -1.0761719 + - 0.14282227 + - 1.4921875 + - -4.6132813 + - 1.3359375 + - -1.4375 + - -1.1367188 + - -2.4160156 + - 3.5351563 + - 3.3359375 + - 2.9257813 + - 1.546875 + - -1.859375 + - -2.5507813 + - -0.75439453 + - 0.39257813 + - 1.6806641 + - -0.29638672 + - 0.5517578 + - -2.9238281 + - -2.5488281 + - -0.09875488 + - -2.3613281 + - 0.80859375 + - 3.3671875 + - -0.37353516 + - -0.94189453 + - -2.9472656 + - -1.59375 + - 0.87353516 + - 3.4414063 + - -0.61572266 + - 6.9140625 + - 2.8085938 + - 4.1640625 + - -2.9472656 + - 0.04425049 + - -0.2512207 + - -0.36157227 + - -1.2441406 + - 2.7734375 + - 0.2548828 + - -1.2197266 + - -0.13867188 + - -0.88134766 + - 1.8203125 + - -0.86328125 + - 6.1328125 + - 23.078125 + - -0.640625 + - -1.1552734 + - -3.1484375 + - 1.96875 + - 0.2619629 + - 3.5 + - 2.5332031 + - -2.0078125 + - -2.0585938 + - 5.171875 + - 1.8515625 + - 0.49267578 + - -1.8642578 + - 1.5039063 + - 1.1074219 + - -3.0820313 + - 0.1126709 + - 0.020507813 + - 4.2539063 + - 0.5571289 + - 1.0449219 + - 1.2050781 + - -0.7529297 + - 1.6533203 + - 0.54345703 + - -1.1298828 + - 4.90625 + - 2.0253906 + - -1.7011719 + - -2.59375 + - 2.421875 + - 0.57714844 + - -2.9003906 + - 0.32543945 + - -2.1386719 + - -1.9335938 + - -2.2304688 + - 3.2929688 + - 1.0253906 + - 1.3085938 + - 3.3359375 + - 1.0859375 + - -2.28125 + - -0.46533203 + - 5.328125 + - -0.6411133 + - 0.8408203 + - 1.609375 + - 2.7539063 + - 1.0498047 + - -5.109375 + - -1.2265625 + - 2.2773438 + - 6.6171875 + - -0.80566406 + - -1.7910156 + - 0.9345703 + - -3.9726563 + - -0.38012695 + - 2.0957031 + - -2.3828125 + - -0.13085938 + - -0.83251953 + - 3.265625 + - 2.40625 + - 1.1796875 + - 1.0087891 + - 1.0927734 + - -1.578125 + - -1.7167969 + - -4.4414063 + - 1.3691406 + - 1.1953125 + - -0.39892578 + - -1.78125 + - 0.022125244 + - -1.5292969 + - -0.37841797 + - -12.890625 + - 0.45507813 + - 0.3371582 + - -2.0722656 + - 1.0898438 + - -2.3398438 + - -1.0986328 + - 0.5566406 + - -0.47998047 + - 0.8769531 + - 2.7753906 + - 1.2236328 + - 4.3203125 + - 0.9736328 + - -1.7363281 + - 1.8417969 + - -3.8476563 + - -4.1875 + - -3.3710938 + - 0.15356445 + - -0.93847656 + - -3.78125 + - 2.765625 + - 0.87597656 + - -0.59814453 + - 0.7939453 + - 2.0429688 + - 5.7382813 + - 1.1347656 + - 0.28833008 + - 1.3955078 + - 4.2421875 + - -3.3125 + - -3.8554688 + - -0.09729004 + - 0.62353516 + - 2.703125 + - 0.68652344 + - -0.009109497 + - -2.28125 + - 2.0820313 + - 1.9179688 + - 0.4663086 + - -1.8876953 + - -2.1523438 + - 1.4589844 + - -2.4394531 + - 2.921875 + - 1.8095703 + - 0.32348633 + - 2.796875 + - 2.1875 + - -0.23535156 + - 1.4736328 + - -5.1484375 + - -4.3945313 + - -2.734375 + - 1.6347656 + - 2.125 + - -2.2695313 + - -2.4472656 + - -2.3398438 + - -4.4101563 + - -5.8007813 + - 1.4511719 + - -0.27783203 + - 2.2617188 + - -0.6044922 + - -1.1474609 + - -4.0625 + - -0.54589844 + - 1.5429688 + - 0.8984375 + - -0.3857422 + - 0.41015625 + - 0.8071289 + - 18 + - 0.61035156 + - -0.3479004 + - 1.5517578 + - 15.2578125 + - 0.20629883 + - 0.33007813 + - -0.16113281 + - 1.203125 + - -3.4609375 + - -0.73876953 + - 5.375 + - -0.5419922 + - 6.1171875 + - 0.9165039 + - 2.5566406 + - 0.52783203 + - -0.033843994 + - -2.7851563 + - -3.9609375 + - 7.0195313 + - -0.013832092 + - 2.2988281 + - 5.2890625 + - 0.9433594 + - -4.2109375 + - 0.5439453 + - 3.828125 + - 1.3691406 + - 0.084472656 + - -0.51416016 + - 1.9941406 + - 1.6728516 + - -0.5073242 + - -5.7734375 + - 0.1652832 + - -0.6064453 + - -0.9238281 + - -1.2880859 + - -3.7226563 + - 2.8769531 + - -0.27929688 + - -5.875 + - -1.0927734 + - 2.8789063 + - 0.14172363 + - -0.5678711 + - 0.37646484 + - 0.35205078 + - -4.265625 + - 4.203125 + - -1.1142578 + - 4.21875 + - 2.7851563 + - 2.6621094 + - -2.9238281 + - -0.36621094 + - -0.20227051 + - -2.7597656 + - -3.7851563 + - 4.2851563 + - 2.3164063 + - 0.47387695 + - -1.5878906 + - 1.0175781 + - -2.8925781 + - 2.2695313 + - -3.6914063 + - -2.90625 + - 1.0556641 + - -2.7617188 + - -2.3828125 + - 1.1035156 + - 3.6796875 + - -4.1796875 + - -3.6328125 + - -1.0761719 + - -3.8164063 + - -1.3251953 + - -3.2695313 + - 0.6142578 + - 0.33642578 + - -0.60546875 + - -3.3632813 + - 0.27856445 + - 2.4804688 + - -0.005859375 + - -3.453125 + - -3.1875 + - -0.30273438 + - 0.27001953 + - -0.025390625 + - -5.6132813 + - -2.9941406 + - -5.875 + - 3.1484375 + - 0.44140625 + - -1.6796875 + - -1.0410156 + - -3.4160156 + - 3.5820313 + - -0.81347656 + - 3.03125 + - 2.9101563 + - -5.4765625 + - 0.8930664 + - 1.3232422 + - -0.7001953 + - 4.234375 + - -2.5605469 + - 1.375 + - -0.32641602 + - 0.43847656 + - -1.6894531 + - -3.4863281 + - -0.0013017654 + - 2.3457031 + - -1.5449219 + - 1.9824219 + - -2.0859375 + - 0.011390686 + - -6.4765625 + - -0.7265625 + - 1.3144531 + - 0.72265625 + - 1.9667969 + - 3.2285156 + - 2.4492188 + - 3.2753906 + - -0.6191406 + - -0.20715332 + - -0.6738281 + - -1.6425781 + - -2.0429688 + - 2.75 + - 0.39453125 + - -2.234375 + - 1.2246094 + - 1.4462891 + - -1.1611328 + - -0.14904785 + - -3.4726563 + - -3.0878906 + - -0.2697754 + - 0.72753906 + - -1.2978516 + - 1.9814453 + - 1.6972656 + - 2.2578125 + - 4.6132813 + - 2.875 + - -1.4121094 + - -1.1679688 + - -5.0742188 + - 3.8691406 + - 3.1660156 + - -0.63134766 + - 3.8515625 + - 3.4023438 + - -4.703125 + - 0.8173828 + - 1.71875 + - -3.1015625 + - 1.7080078 + - -2.8554688 + - -0.7597656 + - -0.9326172 + - -0.109191895 + - 2.6972656 + - -0.2130127 + - -1.6132813 + - -4.0234375 + - 0.5908203 + - 1.0527344 + - -0.95751953 + - 1.6660156 + - -5.7226563 + - 3.6679688 + - -0.9609375 + - 1.8105469 + - 1.2666016 + - -2.5253906 + - 4.5742188 + - -2.3535156 + - 1.1855469 + - -1.7353516 + - 0.3647461 + - 0.4621582 + - -0.17773438 + - 2.1914063 + - -0.123046875 + - -1.8798828 + - -2.0722656 + - 2.4160156 + - -0.6821289 + - 0.9145508 + - -0.6699219 + - 3.6347656 + - -0.4506836 + - -2.5234375 + - 0.36083984 + - 3.8867188 + - 1.0419922 + - 0.26171875 + - -3.0488281 + - -1.7773438 + - -3.5644531 + - 8.484375 + - -2.2363281 + - 0.8208008 + - -1.0859375 + - 3.5 + - -1.0898438 + - 0.34301758 + - 3.1035156 + - -2.7539063 + - -0.7392578 + - -0.31958008 + - 1.5429688 + - -4.7421875 + - -1.5078125 + - 4.9453125 + - -2.2304688 + - 4.4765625 + - -0.57910156 + - 0.50097656 + - 0.8066406 + - 3.640625 + - 0.65185547 + - -1.6796875 + - -1.2626953 + - 1.1816406 + - -0.93847656 + - -5.1484375 + - 2.796875 + - -1.8652344 + - -3.5488281 + - -0.9433594 + - 1.9453125 + - 0.96191406 + - 2.0449219 + - -3.4863281 + - -1.5751953 + - 0.7236328 + - -0.9736328 + - -2.9609375 + - -4.0078125 + - 0.32543945 + - -3.0625 + - -1.9082031 + - 0.2536621 + - 1.0478516 + - -0.12597656 + - 2 + - 2.5058594 + - -1.6220703 + - -1.5644531 + - 2.1894531 + - 0.51660156 + - -0.79296875 + - -0.96533203 + - -2.53125 + - -1.0117188 + - 0.8876953 + - -1.3691406 + - -1.8613281 + - -3.0410156 + - -1.7900391 + - 1.9658203 + - -0.9121094 + - -0.27783203 + - 1.84375 + - 0.3996582 + - -1.0654297 + - -2.6601563 + - 2.1464844 + - -1.9316406 + - -2.9375 + - -2.375 + - -3.4160156 + - -2.4570313 + - 0.39501953 + - -1.2490234 + - 0.6035156 + - 4.7578125 + - 6.3125 + - -5.4570313 + - -1.9628906 + - -2.4863281 + - 4.7382813 + - -4.0429688 + - -2.8476563 + - 1.1337891 + - -2.4941406 + - -2.9492188 + - 0.68847656 + - 3.1503906 + - 1.1210938 + - -3.1191406 + - -0.6035156 + - -1.3535156 + - -1.6064453 + - 1.7626953 + - -1.9804688 + - -2.0917969 + - -1.3691406 + - -0.78222656 + - -3.7675781 + - 1.9072266 + - -1.2460938 + - 1.421875 + - -3.3378906 + - -0.48364258 + - 2.4375 + - -3.7910156 + - 9.40625 + - -3.3320313 + - -2.0078125 + - 4.4375 + - -0.16247559 + - 1.2167969 + - -1.5859375 + - -0.02961731 + - -2.2871094 + - 2.2089844 + - 6.4101563 + - -3.5625 + - -2.1816406 + - 5.1523438 + - -1.3691406 + - 1.7929688 + - -0.002603531 + - -2.6015625 + - 2.2851563 + - 3.7988281 + - -3.9414063 + - 1.6425781 + - -4.6875 + - -0.8071289 + - 3.3984375 + - -9.109375 + - -0.5864258 + - 5.3945313 + - -1.7861328 + - -1.1875 + - -0.7871094 + - 4.5507813 + - -3.2207031 + - -4.96875 + - -2.0664063 + - 0.5048828 + - 5.0429688 + - 1.0175781 + - 2.0585938 + - -0.9560547 + - -2.4648438 + - 0.03286743 + - -4.28125 + - -1.1728516 + - -0.59814453 + - 3.75 + - 1.2246094 + - -2.6386719 + - -3.546875 + - 0.17114258 + - -0.09472656 + - 1.046875 + - 1.3876953 + - 0.7265625 + - -0.47998047 + - -4.1953125 + - -1.9609375 + - -1.9501953 + - 1.5605469 + - 0.39990234 + - -0.71533203 + - -0.57470703 + - -3.5820313 + - -4.4570313 + - 2.1445313 + - 0.7578125 + - 0.18676758 + - -4.5039063 + - -0.08135986 + - -0.09631348 + - -1.8847656 + - -1.8984375 + - -0.3400879 + - -0.47998047 + - 3.1347656 + - -4.1328125 + - -0.8232422 + - 0.71777344 + - 2.1777344 + - -0.30322266 + - -1.8798828 + - -2.1523438 + - -0.1282959 + - 0.35302734 + - -5.2109375 + - 1.0439453 + - 3.7890625 + - 4.3203125 + - 0.9946289 + - 1.1191406 + - 0.5551758 + - 4.265625 + - 2.5566406 + - -3.1757813 + - 1.3759766 + - 1.7705078 + - 1.8789063 + - -3.515625 + - -0.57177734 + - 2.5957031 + - 2.7441406 + - 1.4775391 + - -1.7666016 + - 1.953125 + - -1.8046875 + - -0.12524414 + - 3.5 + - 0.18225098 + - -0.95703125 + - 4.3671875 + - -1.4648438 + - -0.9501953 + - -1.2714844 + - -1.8515625 + - -3.8671875 + - 0.9248047 + - 3.5644531 + - -3.2851563 + - -1.8759766 + - 0.5234375 + - 0.77441406 + - 5.0390625 + - 8.03125 + - -3.0878906 + - 0.10675049 + - -1.6738281 + - -1.5683594 + - 0.5629883 + - 0.98876953 + - -0.9711914 + - 3.5039063 + - -3.0117188 + - -4.2851563 + - -0.75097656 + - -2.6523438 + - -1.5585938 + - -0.95214844 + - -1.8955078 + - 2.4238281 + - 4.09375 + - 1.0087891 + - 2.1328125 + - 3.6210938 + - -1.8876953 + - -1.6953125 + - -0.9736328 + - 0.97509766 + - 1.7695313 + - 0.19726563 + - -2.953125 + - 0.07519531 + - -1.6572266 + - -0.55078125 + - -3.4492188 + - 0.86572266 + - -0.40283203 + - 0.51953125 + - -1.6298828 + - 1.9462891 + - -3.2382813 + - -0.4543457 + - 0.08459473 + - -0.3725586 + - 8.359375 + - -1.9736328 + - 3.078125 + - -6.90625 + - 3.5019531 + - 3.078125 + - -2.7441406 + - 1.2988281 + - 1.2304688 + - -0.87109375 + - -2.9941406 + - 0.11242676 + - -1.0742188 + - -1.0800781 + - -2.8847656 + - -0.8496094 + - -1.4003906 + - 4.9375 + - 0.011062622 + - 0.7714844 + - 0.9321289 + - -1.015625 + - 2.1484375 + - -3.4726563 + - 1.3017578 + - -0.2043457 + - 5.09375 + - -3.7441406 + - -3.4375 + - 2.5917969 + - -1.7236328 + - -2.96875 + - 2.671875 + - 0.48486328 + - -0.53515625 + - 1.5644531 + - 3.8925781 + - 1.5 + - -7.15625 + - 4.25 + - 0.5839844 + - -0.67089844 + - 1.4267578 + - -4.046875 + - 0.06085205 + - 1.5019531 + - -1.2285156 + - 3.0351563 + - -8 + - 2.3476563 + - -1.1425781 + - -0.47070313 + - -1.9033203 + - -1.4580078 + - 1.0644531 + - -0.9482422 + - 0.2734375 + - 1.9316406 + - 2.546875 + - 0.7626953 + - 0.62109375 + - -5.0507813 + - 0.8696289 + - 1.1464844 + - -0.50390625 + - 1.9472656 + - 0.26611328 + - -1.3447266 + - -1.2792969 + - -1.2011719 + - 2.8242188 + - 0.17150879 + - 2.7617188 + - 1.6142578 + - -4.9765625 + - 1.6386719 + - -1.9648438 + - -0.50683594 + - 0.005207062 + - 2.0917969 + - -1.3164063 + - -0.09765625 + - 7.171875 + - -1.4003906 + - -0.5078125 + - 4.2070313 + - 0.94628906 + - 0.2685547 + - -1.9238281 + - 2.2226563 + - 7 + - -1.4765625 + - -0.40161133 + - -0.4025879 + - -2.8945313 + - 4.7265625 + - -1.5859375 + - -3.6289063 + - -0.18481445 + - -1.7050781 + - -1.0244141 + - 0.16992188 + - 1.2441406 + - -4.1796875 + - 0.11584473 + - 0.12695313 + - 2.6210938 + - 3.7070313 + - 3.140625 + - 1.1640625 + - 0.17138672 + - 3.2753906 + - 0.6040039 + - -5.0703125 + - 3.1875 + - -3.40625 + - -3.4101563 + - -18.828125 + - -3.3867188 + - 0.34033203 + - 4.5078125 + - -4.2578125 + - 1.8261719 + - -15.546875 + - -6.8320313 + - -0.25146484 + - -1.1142578 + - 1.4101563 + - -2.1464844 + - -0.06311035 + - 5.6132813 + - 0.609375 + - 2.4941406 + - -0.095703125 + - 1.9628906 + - 1.8984375 + - -5.0390625 + - -1.5390625 + - 2.4101563 + - -1.3535156 + - 0.25048828 + - 1.6494141 + - -1.015625 + - 1.8330078 + - 0.032226563 + - -0.41333008 + - 1.9814453 + - -1.1152344 + - -5.0820313 + - 1.7158203 + - -2.3613281 + - 1.0039063 + - 1.1445313 + - 1.1855469 + - -2.3222656 + - 2.7597656 + - -2.234375 + - 0.30615234 + - 5.46875 + - 1.4003906 + - 0.33520508 + - -3.1113281 + - -0.9633789 + - 4.3125 + - -1.6455078 + - -1.6640625 + - 2.0117188 + - 2.4179688 + - 2.7929688 + - -1.6152344 + - -3.4414063 + - 0.44848633 + - -4.8984375 + - 1.0996094 + - 2.5820313 + - -3.7226563 + - -0.3215332 + - 0.93066406 + - -0.83447266 + - -0.38891602 + - 4.9296875 + - 2.3300781 + - 1.1542969 + - -2.9375 + - 0.4338379 + - 4.8984375 + - -0.52441406 + - -1.5908203 + - 2.5117188 + - -2.2929688 + - -0.87890625 + - -0.81347656 + - 1.4589844 + - -2.6425781 + - 0.38598633 + - 1.1855469 + - -2.0429688 + - 1.2470703 + - 4.2695313 + - 0.028320313 + - 0.12109375 + - -1.5 + - -4.3828125 + - 0.85791016 + - -2.6386719 + - 1.0341797 + - 0.2199707 + - -1.5253906 + - 1.3408203 + - -3.5625 + - -6.9882813 + - 0.3203125 + - -5.7578125 + - -0.9042969 + - 1.4853516 + - 1.3212891 + - -2.4160156 + - 1.0097656 + - -4.296875 + - -2.8925781 + - -3.3789063 + - -0.86572266 + - -1.8447266 + - -0.057922363 + - 4.3164063 + - -0.41357422 + - 2.2128906 + - -1.5957031 + - -3.0332031 + - -0.6298828 + - -1.1777344 + - -1.6542969 + - -0.5727539 + - 1.0410156 + - -2.8066406 + - 1.28125 + - -0.25708008 + - 2.5664063 + - -0.70214844 + - 2.6191406 + - -4.8320313 + - -2.2207031 + - 1.0322266 + - -4.5859375 + - -3.4707031 + - -0.82421875 + - 5.265625 + - 7.21875 + - -0.6113281 + - 0.14709473 + - 4.6757813 + - 3.2539063 + - -1.8515625 + - -0.8154297 + - -3.2285156 + - 1.9921875 + - 2.2148438 + - 0.71191406 + - -1.3535156 + - -1.4267578 + - 0.17541504 + - -3.3007813 + - -3.7207031 + - 1.2480469 + - -0.7211914 + - -1.2402344 + - -1.46875 + - -2.3671875 + - 1.3730469 + - 1.1972656 + - -1.9931641 + - 0.008460999 + - -2.7753906 + - -0.9765625 + - 1.4550781 + - 0.67089844 + - -2.78125 + - -3.9765625 + - -0.6464844 + - 0.97314453 + - 5.7226563 + - -1.7617188 + - 0.43798828 + - 2.4648438 + - 14.609375 + - -0.89160156 + - -3.5488281 + - 1.40625 + - -0.9633789 + - 4.6914063 + - -1.7617188 + - -9.1640625 + - 2.9785156 + - -1.6113281 + - 6.59375 + - 1.1025391 + - 0.38330078 + - 0.045898438 + - 1.7861328 + - 3.0253906 + - 1.6845703 + - -4.0664063 + - -0.6582031 + - -3.8476563 + - 1.6376953 + - -0.35473633 + - 1.7167969 + - -2.7832031 + - -1.6972656 + - 2.6484375 + - 0.05532837 + - -3.84375 + - -1.9736328 + - -1.2441406 + - -0.29760742 + - -0.20874023 + - 2.203125 + - -1.1289063 + - -1.96875 + - -1.7617188 + - -0.79589844 + - 2.0644531 + - -0.5283203 + - 0.4560547 + - -3.6113281 + - 4.2109375 + - 0.63623047 + - -3.1875 + - 1.65625 + - -2.8632813 + - -0.3671875 + - 2.3632813 + - -3.359375 + - 4.921875 + - 3.6289063 + - -0.55371094 + - 4.6875 + - -0.86621094 + - -1.6542969 + - 2.203125 + - 2.4003906 + - 2.4804688 + - -3.4003906 + - 2.6289063 + - -3.3457031 + - 4.8164063 + - 1.4804688 + - -1.8515625 + - -1.4667969 + - 2.953125 + - 0.6767578 + - -1.7666016 + - -2.9804688 + - -2.3554688 + - -0.016921997 + - 0.037261963 + - 1.6191406 + - 0.22387695 + - -1.9355469 + - -5.296875 + - 4.078125 + - -0.28320313 + - -2.9921875 + - -0.9472656 + - -0.5205078 + - 0.09436035 + - -0.024734497 + - -2.2226563 + - 0.18859863 + - 1.2792969 + - -1.7587891 + - 3.96875 + - -1.8046875 + - 1.1855469 + - 1.0712891 + - -3.03125 + - 1.1933594 + - -0.15588379 + - 1.9921875 + - 0.24865723 + - -1.7714844 + - 4.5234375 + - -2.078125 + - 1.8681641 + - 0.98046875 + - -0.33520508 + - -3.5195313 + - 4.7617188 + - 1.1386719 + - 0.24902344 + - -0.84277344 + - 0.40625 + - 3.7910156 + - -1.0361328 + - -2.6679688 + - -2.609375 + - -3.3378906 + - -0.018875122 + - -2.0644531 + - 1.3408203 + - 0.50146484 + - -1.4648438 + - -0.016921997 + - -3.7734375 + - 4.8984375 + - 22.46875 + - -2.5898438 + - 3.4765625 + - -0.9609375 + - 0.2861328 + - 0.2746582 + - 1.0527344 + - -3.6113281 + - 2.140625 + - -0.7060547 + - -2.6640625 + - 0.59277344 + - -2.5039063 + - -2.0117188 + - 0.6923828 + - 2.953125 + - -7.484375 + - 0.6113281 + - -0.24182129 + - 10.84375 + - -4.5 + - 2.25 + - 1.5820313 + - -0.46020508 + - 1.2763672 + - -1.4472656 + - -0.5859375 + - -2.4902344 + - 0.30664063 + - -0.7626953 + - -23.875 + - 0.13769531 + - -1.609375 + - 1.1171875 + - -2.3710938 + - -1.8945313 + - 1.0908203 + - 2.4921875 + - -3.21875 + - 0.11340332 + - -0.6767578 + - -0.44580078 + - -2.875 + - -1.1230469 + - -1.984375 + - -0.86376953 + - -3.296875 + - -0.4555664 + - -0.8203125 + - -0.15454102 + - -3.6738281 + - 2.9804688 + - 0.12854004 + - -1.1914063 + - 4.0234375 + - -1.8564453 + - -0.83984375 + - -1.3613281 + - 4.6640625 + - -1.8671875 + - 0.28735352 + - 2.7363281 + - -6.1015625 + - 0.7290039 + - -1.3056641 + - 3.7695313 + - -1.1601563 + - 1.0625 + - -3.78125 + - 2.2773438 + - 0.2277832 + - -0.37817383 + - -5.8515625 + - 3.671875 + - -1.3378906 + - -2.1328125 + - 2.2304688 + - 1.4345703 + - 1.7617188 + - 3.8515625 + - 1.5800781 + - -0.875 + - -2.1171875 + - 2.2539063 + - 4.5703125 + - 1.1855469 + - -1.3242188 + - -2.28125 + - -1.7871094 + - -0.83691406 + - -1.1425781 + - 0.06542969 + - 2.5722656 + - 3.4179688 + - 0.3774414 + - 3.0566406 + - 4.8046875 + - 3.234375 + - -0.27661133 + - -0.28564453 + - 0.32421875 + - -2.1894531 + - -0.26611328 + - -1.7158203 + - -0.8017578 + - 0.16564941 + - -3.03125 + - 0.86035156 + - -4.609375 + - -0.38916016 + - -3.0253906 + - -3.7070313 + - 1.2519531 + - 0.6308594 + - 2.625 + - -1.171875 + - 1.8955078 + - 4.3671875 + - 0.7158203 + - 0.41308594 + - -3.3222656 + - 3.0195313 + - 2.3242188 + - -1.4941406 + - -1.5791016 + - 5.890625 + - 0.2578125 + - 3.5039063 + - -1.0683594 + - -0.35864258 + - 1.4765625 + - 0.49047852 + - 3.7050781 + - 0.25341797 + - 0.31298828 + - -0.7685547 + - -3.1914063 + - -8.0859375 + - 1.5517578 + - -0.95751953 + - 2.3789063 + - 2.1582031 + - 0.8828125 + - 0.17248535 + - 2.7675781 + - 0.2130127 + - 0.421875 + - 1.1416016 + - -0.037750244 + - 3.7109375 + - 2.0234375 + - -0.0234375 + - -0.38476563 + - 0.5810547 + - -3.2597656 + - 7.3515625 + - 1.3300781 + - 2.2382813 + - 8.9453125 + - 14.390625 + - -0.80566406 + - -2.8847656 + - -0.19458008 + - -1.0244141 + - -0.7836914 + - 2.0332031 + - -0.25024414 + - 1.1953125 + - 0.16796875 + - -2.890625 + - 0.45751953 + - 2.0722656 + - 1.1640625 + - 0.4345703 + - 1.5634766 + - -0.96972656 + - 2.1953125 + - -1.9414063 + - -2.859375 + - -4.1640625 + - -1.1455078 + - 1.7265625 + - -0.72753906 + - -1.5800781 + - -4.5078125 + - -0.3244629 + - 0.98828125 + - 0.46923828 + - -1.0166016 + - -0.921875 + - -1.7265625 + - 3.3476563 + - -1.6611328 + - 1.7001953 + - 3.6132813 + - -0.921875 + - 0.4807129 + - -1.1152344 + - -1.2421875 + - 5.4375 + - -0.59765625 + - 0.88134766 + - -2.1542969 + - -0.44482422 + - -2.8945313 + - 4.2382813 + - -0.16369629 + - -3.4921875 + - -0.6894531 + - 3.8164063 + - -0.084472656 + - -0.40820313 + - 2.1269531 + - 1.9228516 + - 0.33813477 + - -3.0234375 + - -1.9277344 + - 0.22521973 + - 1.9921875 + - -1.0722656 + - 4.4375 + - 1.8457031 + - 3.5722656 + - 2.5078125 + - -2.7578125 + - 1.578125 + - -2.203125 + - 1.3535156 + - -0.59228516 + - -2.2070313 + - -1.0908203 + - 0.69628906 + - -0.20605469 + - 1.6328125 + - 2.4882813 + - -0.27734375 + - 0.00894928 + - 1.8417969 + - 0.70947266 + - 19.9375 + - -5.421875 + - -0.47705078 + - -4.2617188 + - -0.38085938 + - -0.26123047 + - 3.9101563 + - -0.67578125 + - -5.5078125 + - 3.8789063 + - 2.0234375 + - -0.032958984 + - 3.9257813 + - 3.5195313 + - 1.5126953 + - -0.68847656 + - 1.3222656 + - -5.328125 + - 3.4375 + - 1.8378906 + - -2.4726563 + - -0.5859375 + - -5.9882813 + - 2.4960938 + - -1.7119141 + - 0.8515625 + - -2.0839844 + - -0.019195557 + - 3.28125 + - -4.8828125 + - -1.3984375 + - -0.5126953 + - 0.5415039 + - 1.3134766 + - 3.7304688 + - 3.6660156 + - 5.8046875 + - 2.1132813 + - 2.4023438 + - 1.6210938 + - -3.3398438 + - -3.9472656 + - -1.1796875 + - 3.84375 + - 0.10559082 + - -1.4814453 + - -0.6899414 + - 4.0078125 + - 3.6445313 + - 2.0800781 + - -1.0830078 + - -2.6660156 + - 0.17626953 + - 15.890625 + - -3.5195313 + - -0.6542969 + - 1.1113281 + - -1.2714844 + - 0.5058594 + - 0.9790039 + - 2.1953125 + - -0.5185547 + - -0.015296936 + - -4.8710938 + - 0.45214844 + - -2.0976563 + - -1.7587891 + - -2.125 + - 1.3242188 + - -4.9453125 + - 0.9404297 + - 4.203125 + - -0.4453125 + - -2.0117188 + - 0.36254883 + - -2.5371094 + - -2.7109375 + - -0.4736328 + - 2.5546875 + - -3.6171875 + - 0.15441895 + - 0.32421875 + - 2.2421875 + - -0.05859375 + - -3.4414063 + - -2.7285156 + - -1.0400391 + - -1.2080078 + - -3.3789063 + - 0.6201172 + - -1.7148438 + - 0.9399414 + - 1.8457031 + - -1.9355469 + - 3.90625 + - -4.4609375 + - -2.8554688 + - -2.0859375 + - 1.0449219 + - 1.4804688 + - -0.51953125 + - -1.2929688 + - -0.90527344 + - 0.8515625 + - 1.4267578 + - -3.1972656 + - -1.7519531 + - 2.1191406 + - -2.0507813 + - 3.3066406 + - -0.96972656 + - -1.0117188 + - 3.6445313 + - 1.2128906 + - 0.75097656 + - -1.8925781 + - -4.4296875 + - 0.74072266 + - -2.7304688 + - 0.15820313 + - 2.0390625 + - 1.1806641 + - 0.02407837 + - 0.5703125 + - 0.74072266 + - -0.08947754 + - -0.68359375 + - -1.7851563 + - 1.2402344 + - 3.0722656 + - -3.7363281 + - 3.4804688 + - 1.5947266 + - -0.026687622 + - 1.8457031 + - 0.85595703 + - -8.34375 + - -4.5585938 + - -3.8925781 + - 0.21704102 + - -1.9277344 + - -0.074035645 + - -3.1953125 + - 3.359375 + - -2.5019531 + - 2.3535156 + - 5.03125 + - 7.1015625 + - -1.28125 + - 2.2109375 + - -1.2626953 + - 4.90625 + - -2.3300781 + - -0.8876953 + - 3.109375 + - -2.2070313 + - 1.9785156 + - 0.4543457 + - 0.93359375 + - -4.9882813 + - 2.4570313 + - -3.2910156 + - 0.19396973 + - 1.6191406 + - 1.2207031 + - -0.625 + - 0.5185547 + - 0.04751587 + - -2.078125 + - 4.3671875 + - 0.640625 + - 1.1728516 + - -0.67871094 + - -2.7265625 + - 1.984375 + - 2.21875 + - 4.2890625 + - 0.5439453 + - 0.80371094 + - 0.15490723 + - 3.796875 + - -1.3603516 + - -1.71875 + - -1.0683594 + - 0.6743164 + - 2.1289063 + - 3.1289063 + - -0.6176758 + - 3.4003906 + - 0.15136719 + - -2.7050781 + - 0.34594727 + - 2.5898438 + - 0.5234375 + - 5.5625 + - 0.91015625 + - 1.4609375 + - 14.859375 + - -1.3017578 + - 1.3212891 + - -5.2421875 + - -0.5214844 + - -2.3046875 + - 1.4150391 + - -1.203125 + - 3.953125 + - -2.0097656 + - -2.6992188 + - -0.8046875 + - 0.28833008 + - 2.7597656 + - 0.049804688 + - 0.91308594 + - -5.5703125 + - 0.25390625 + - 0.22265625 + - 0.024734497 + - -0.67626953 + - 1.3320313 + - -1.0410156 + - -3.640625 + - -0.25341797 + - 1.3417969 + - -1.5166016 + - -4.3671875 + - -2.4472656 + - 0.5439453 + - -1.8212891 + - -2.5585938 + - 0.5361328 + - -1.5664063 + - -1.0214844 + - 0.5654297 + - 2.5019531 + - 0.17297363 + - 4.5625 + - 0.49658203 + - 3.0566406 + - -4.6679688 + - 3.8378906 + - 0.25195313 + - -0.8876953 + - 8.140625 + - -1.640625 + - 0.22387695 + - -0.65722656 + - -5.0351563 + - 1.9902344 + - 1.2021484 + - -1.2587891 + - 4.3320313 + - 4.015625 + - -2.5078125 + - 0.609375 + - -3.09375 + - 0.4572754 + - -0.23364258 + - 0.1171875 + - 0.32739258 + - -0.19165039 + - -0.090148926 + - -1.8798828 + - 1.4228516 + - 2.6015625 + - -6.5703125 + - 2.609375 + - 1.6796875 + - 0.5102539 + - -0.8652344 + - 5.8476563 + - 1.5175781 + - 2.625 + - -0.23364258 + - -3.2832031 + - -1.5703125 + - 2.1601563 + - 2.2910156 + - 1.8681641 + - -0.49804688 + - 3.2441406 + - -0.22753906 + - -1.3798828 + - -0.14465332 + - -2.7597656 + - 2.3730469 + - -1.8300781 + - 1.2392578 + - -3.3203125 + - -2.5234375 + - 1.4462891 + - 2.6601563 + - -4.4882813 + - 2.1523438 + - -3.96875 + - 4.1875 + - 1.296875 + - -0.87109375 + - 4.6132813 + - -2.2578125 + - -2.4394531 + - 4.5039063 + - 1.5625 + - 4.5234375 + - 1.3134766 + - 0.890625 + - 0.9296875 + - 3.125 + - 35.3125 + - 14.140625 + - -9.8046875 + - 0.80566406 + - 0.46679688 + - 0.2388916 + - -1.8359375 + - -3.5703125 + - 1.5048828 + - 1.1679688 + - 1.9238281 + - -1.9316406 + - 0.390625 + - 1.7314453 + - -5.75 + - 0.51953125 + - -0.0259552 + - 0.54003906 + - -3.21875 + - 2.3359375 + - 0.29492188 + - 1.3408203 + - -1.4785156 + - -0.18762207 + - -0.43286133 + - -0.8017578 + - 1.234375 + - -0.73095703 + - -7.3320313 + - 1.9111328 + - 0.08721924 + - -0.56152344 + - 0.66552734 + - 1.2216797 + - 1.6660156 + - -3.3242188 + - 0.15881348 + - 5.359375 + - -1.8066406 + - 0.46606445 + - 1.8408203 + - 1.3925781 + - -1.0996094 + - 6.0195313 + - -1.1767578 + - 0.33618164 + - -1.9609375 + - 0.6040039 + - 1.3525391 + - 0.8286133 + - 2.8378906 + - 4.71875 + - -0.98339844 + - 0.24768066 + - 2.6523438 + - 1.0644531 + - -0.2685547 + - 0.8671875 + - -0.013015747 + - -2.2851563 + - 2.7597656 + - 4.7695313 + - 1.984375 + - -1.7236328 + - -0.20532227 + - -1.1162109 + - 2.0976563 + - -0.56933594 + - -6.0820313 + - 0.03515625 + - -1.5283203 + - -0.24816895 + - -2.9453125 + - -1.2636719 + - -0.31640625 + - -0.9946289 + - -0.3227539 + - -0.3232422 + - 5.5195313 + - 1.3876953 + - -1.6103516 + - 1.1777344 + - 0.8798828 + - 3.0117188 + - 1.7539063 + - 3.1132813 + - 0.38916016 + - 1.0009766 + - -0.27954102 + - -0.52734375 + - -1.2441406 + - 1.7978516 + - -0.52734375 + - -1.4316406 + - 4.7734375 + - 1.0517578 + - -0.8417969 + - 0.37353516 + - -1.390625 + - 0.013504028 + - -1.3125 + - -1.3105469 + - 0.3564453 + - 2.1289063 + - -0.7817383 + - 2.1816406 + - 0.61816406 + - -1.8378906 + - 2.3085938 + - 2.7304688 + - -2.4121094 + - 3.546875 + - -1.6015625 + - 5.25 + - 1.9033203 + - 1.71875 + - -3.9765625 + - -1.1386719 + - 2.6113281 + - 0.66503906 + - -3.75 + - -1.7431641 + - -0.765625 + - 2.6972656 + - 3.9335938 + - 2.4726563 + - -4.3320313 + - 2.8984375 + - -1.078125 + - -0.80126953 + - -0.14318848 + - -2.6601563 + - -0.91064453 + - 1.7587891 + - 2.2011719 + - -0.89697266 + - -1.9863281 + - -1.7695313 + - 2.6445313 + - -0.3449707 + - -0.8852539 + - 2.5625 + - 3.5722656 + - -1.4150391 + - 0.81152344 + - -0.9423828 + - 1.53125 + - -3.6367188 + - -4.6640625 + - -4.0390625 + - -1.5390625 + - -0.7294922 + - -2.1933594 + - 1.3330078 + - -0.35986328 + - 0.27075195 + - -1.8251953 + - -1.9804688 + - -1.609375 + - 2.4960938 + - -0.062408447 + - -2.3222656 + - 2.921875 + - 2.0546875 + - -3.0273438 + - 0.9316406 + - 0.48950195 + - 1.6035156 + - -0.19384766 + - -4.3203125 + - 0.21594238 + - 0.65722656 + - -2.4511719 + - -2.4238281 + - 9.5078125 + - 0.79296875 + - 3.9570313 + - 1.9072266 + - -2.578125 + - -2.5 + - 2.2050781 + - -1.2763672 + - -0.19104004 + - 1.3164063 + - 1.421875 + - 1.8671875 + - 0.62402344 + - 1.4189453 + - 2.0761719 + - -4.0859375 + - 1.6621094 + - 4.0234375 + - 0.7451172 + - 1.3007813 + - -1.7988281 + - -2.0234375 + - 0.93603516 + - 1.6611328 + - 1.7460938 + - -3.5039063 + - -1.8339844 + - 0.15356445 + - 1.8222656 + - -0.3371582 + - -1.3486328 + - -6.3789063 + - 0.18481445 + - -0.3762207 + - -1.1855469 + - -2.1796875 + - 1.3945313 + - -1.2001953 + - 0.9951172 + - 0.8515625 + - -1.3046875 + - 1.8066406 + - -4.6328125 + - -3.4648438 + - 2.0019531 + - -0.92089844 + - 1.7695313 + - 0.84228516 + - -2.453125 + - 0.89746094 + - 3.015625 + - 2.4082031 + - 3.3359375 + - 2.0429688 + - 3.359375 + - 0.98828125 + - -0.5395508 + - 0.7734375 + - -0.69921875 + - -0.022125244 + - -1.6035156 + - -0.92089844 + - -3.9453125 + - 3.2265625 + - 2.0742188 + - -1.7558594 + - 1.2539063 + - -1.7109375 + - -8.46875 + - 1 + - -0.0859375 + - -0.49853516 + - -0.5776367 + - 5.2109375 + - 0.15356445 + - -1.2011719 + - 0.51464844 + - 2.9941406 + - 3.5019531 + - 2.7988281 + - 1.4394531 + - -3.5 + - 3.3242188 + - 2.9238281 + - -3.0585938 + - 0.61035156 + - -2.3632813 + - -0.014320374 + - -3.9335938 + - -0.12188721 + - -3.6894531 + - -3.5351563 + - -3.5097656 + - 0.7763672 + - -3.6132813 + - -0.8251953 + - 4.8164063 + - -2.1816406 + - 0.08496094 + - -1.7275391 + - -2.546875 + - 1.9179688 + - -0.07159424 + - 0.5600586 + - 0.26953125 + - 0.8930664 + - 1.5214844 + - -0.3852539 + - 0.3918457 + - 1.9765625 + - -1.1972656 + - 10.28125 + - 0.10675049 + - 1.5996094 + - -5.140625 + - 1.5917969 + - 1.3613281 + - -1.1572266 + - -2.6503906 + - -0.92041016 + - 0.5595703 + - 5.9570313 + - 2.8691406 + - -0.47265625 + - -0.8173828 + - 2.4121094 + - -0.7080078 + - -1.546875 + - -1.1708984 + - 3.6347656 + - -2.0546875 + - 2.40625 + - -1.9707031 + - 3.3222656 + - -0.9892578 + - -2.9726563 + - 0.78759766 + - -3.1445313 + - 2.5390625 + - -0.1640625 + - -4.8203125 + - -1.359375 + - -2.4140625 + - 2.3359375 + - 1.0957031 + - 1.09375 + - 3.4804688 + - 0.09698486 + - 2.84375 + - -1.0722656 + - -2.6835938 + - -2.4023438 + - 0.3305664 + - -4.7421875 + - 7.9765625 + - 1.1757813 + - 2.3632813 + - 0.9433594 + - 0.9375 + - -2.1933594 + - 0.8671875 + - -0.35546875 + - -1.7910156 + - -4.7421875 + - -1.6884766 + - 2.75 + - -1.2597656 + - 1.5048828 + - -0.9902344 + - -3.4726563 + - 2.3359375 + - 1.4394531 + - -3.65625 + - 0.037109375 + - 1.6533203 + - 1.5869141 + - -3.453125 + - 1.9628906 + - -3.1289063 + - 0.921875 + - 1.0673828 + - 0.7294922 + - -2.640625 + - 1.59375 + - -3.2617188 + - -0.11260986 + - -0.56640625 + - -2.5410156 + - -1.296875 + - 2.8691406 + - 1.8642578 + - 0.81640625 + - 2.1640625 + - 1.4238281 + - 1.0595703 + - 2.5351563 + - 1.7939453 + - -0.1899414 + - -0.11529541 + - -0.007648468 + - -1.6503906 + - 2.78125 + - -2.3671875 + - -2.2753906 + - 1.2001953 + - 6.1796875 + - -0.62939453 + - -0.3984375 + - 0.7734375 + - 1.0205078 + - -1.7363281 + - -1.2089844 + - -0.32470703 + - 1.5488281 + - 1.8359375 + - -0.09472656 + - 3.4472656 + - -1.796875 + - -0.21679688 + - -1.1171875 + - 4.1171875 + - -4.84375 + - -1.0039063 + - 1.3798828 + - 2.1679688 + - 5.96875 + - -0.007160187 + - -0.7524414 + - 1.1494141 + - -1.4423828 + - 1.3359375 + - -1.6533203 + - -1.8291016 + - -2.3164063 + - 2.4257813 + - -3.3164063 + - 0.5654297 + - -0.17504883 + - -3.1523438 + - 3.0625 + - -1.0107422 + - 0.78759766 + - -2.0078125 + - 0.5644531 + - -3.609375 + - -3.875 + - -0.26171875 + - 4.7421875 + - -2.390625 + - 0.8071289 + - 2.5175781 + - 2.7539063 + - -0.8203125 + - -0.49731445 + - 0.005531311 + - -2.4589844 + - -0.3383789 + - 0.4633789 + - -1.3798828 + - -2.9238281 + - -1.0244141 + - -0.5551758 + - 1.4628906 + - 1.6816406 + - -7.8007813 + - -4.0703125 + - 18.453125 + - -1.7832031 + - -2.6523438 + - 0.90478516 + - 1.1425781 + - 1.1152344 + - 2.2109375 + - 2.9648438 + - 0.5527344 + - -0.3161621 + - 1.8134766 + - 0.8046875 + - 0.30908203 + - -3.5117188 + - 0.13476563 + - -3.0253906 + - -2.6484375 + - -1.6220703 + - 4.3789063 + - -0.81689453 + - -2.5 + - -3.0664063 + - -4.0078125 + - 0.69140625 + - -0.9267578 + - 3.2617188 + - 1.1308594 + - 0.01399231 + - 1.7246094 + - 0.4326172 + - 2.9375 + - -0.28857422 + - 6.5820313 + - 3.9335938 + - 0.7285156 + - -0.53808594 + - 0.20117188 + - 2.3007813 + - 4.6914063 + - -1.5195313 + - -0.71484375 + - 0.98046875 + - 1.3720703 + - -2.484375 + - 2.3574219 + - -2.6015625 + - 1.2705078 + - 2.6816406 + - 0.4543457 + - 0.53222656 + - 0.7138672 + - 0.2709961 + - -1.7832031 + - -1.7080078 + - 0.13085938 + - -4.7617188 + - -0.35498047 + - -1.5439453 + - 2.1484375 + - -3.4003906 + - -3.0546875 + - 0.06964111 + - -0.3400879 + - 3.3007813 + - -2.4648438 + - 0.42578125 + - 2.4003906 + - 3.015625 + - -2.0273438 + - 3.03125 + - -4.875 + - -3.0742188 + - 0.037750244 + - -1.5507813 + - 2.0078125 + - 8.1796875 + - 0.9716797 + - 4.4101563 + - -6.3320313 + - 0.41015625 + - -1.1025391 + - 3.03125 + - 0.037109375 + - -0.12988281 + - 0.2265625 + - -5.8671875 + - -0.8408203 + - -0.2854004 + - 2.4570313 + - -1.3232422 + - -1.3886719 + - 1.3066406 + - -2.6894531 + - -2.40625 + - -3.0625 + - 2.8867188 + - 3.0722656 + - -1.0683594 + - 5.109375 + - -3.4726563 + - -1.0136719 + - 0.5371094 + - 0.97265625 + - 2.5136719 + - 1.265625 + - 0.55908203 + - -0.33984375 + - -0.796875 + - 0.83691406 + - 0.42236328 + - 5.5390625 + - 1.5234375 + - 1.609375 + - 1.1035156 + - 3.9726563 + - 0.56689453 + - 0.7675781 + - 0.8461914 + - 1.125 + - -0.07092285 + - 2.8300781 + - 0.44262695 + - 4.7226563 + - -1.2949219 + - -1.296875 + - -3.4394531 + - 0.82910156 + - -0.0390625 + - -0.35302734 + - -0.41064453 + - 2 + - -0.5859375 + - -5.6640625 + - -0.95166016 + - 0.6816406 + - 2.5839844 + - 2.2539063 + - 1.7753906 + - -0.2446289 + - 3.1757813 + - 2.1015625 + - 2.6113281 + - 0.2355957 + - 2.0449219 + - 2.7207031 + - -2.1035156 + - -7.1914063 + - -2.6035156 + - 1.9921875 + - -3.4628906 + - -1.4902344 + - -0.55566406 + - 2.8378906 + - 3.9375 + - 3.6445313 + - -0.11584473 + - 0.31054688 + - 1.0019531 + - -0.61328125 + - -0.2763672 + - 3.5175781 + - 3.4804688 + - -3.5957031 + - 0.012039185 + - -0.38110352 + - 3.6601563 + - 0.25683594 + - -0.40551758 + - -0.64160156 + - 0.25732422 + - 0.79003906 + - -0.89697266 + - -2.1835938 + - -1.0742188 + - -1.6757813 + - -1.7851563 + - -2.0585938 + - 2.5898438 + - -1.0957031 + - 0.6035156 + - 2.265625 + - -3.1445313 + - -0.5493164 + - 1.2929688 + - -1.7363281 + - -2.3945313 + - -1.0546875 + - -2.1835938 + - 3.8320313 + - 3.1191406 + - 2.3144531 + - -6.7578125 + - -1.0976563 + - 0.35668945 + - -5.2851563 + - 1.7636719 + - -2.6367188 + - 0.97509766 + - 0.6538086 + - -8.5703125 + - 1 + - 0.4169922 + - 0.2602539 + - -2.1210938 + - -1.5859375 + - -1.46875 + - 0.5834961 + - 0.28320313 + - -0.33911133 + - -25.625 + - -1.5126953 + - -5.8125 + - 2.9765625 + - 0.24145508 + - 2.3144531 + - 3.0878906 + - -3.5878906 + - -5.1640625 + - 0.020828247 + - 0.49243164 + - -1.0859375 + - -1.9501953 + - 5.7734375 + - -0.13366699 + - 5.1953125 + - 0.08850098 + - 4.9921875 + - -0.98339844 + - 1.5410156 + - -0.08721924 + - -0.72509766 + - 1.2910156 + - 3.8125 + - -1.3193359 + - -3.4960938 + - -0.44189453 + - -0.16748047 + - 3.4414063 + - -0.5678711 + - 0.37939453 + - -0.43286133 + - 2.3046875 + - 0.40356445 + - -6.2226563 + - 4.4296875 + - -2.4609375 + - -1.8955078 + - -4.2421875 + - -1.4931641 + - -0.85791016 + - 1.5517578 + - 2.6621094 + - -0.15686035 + - -3.5273438 + - 3.125 + - 2.96875 + - -1.0556641 + - 0.40283203 + - -5.1601563 + - 3.0507813 + - -0.55029297 + - -2.0722656 + - -3.703125 + - 1.7236328 + - 3.7421875 + - -1.4472656 + - 2.5976563 + - 1.6269531 + - 0.29492188 + - 0.12524414 + - -2.1269531 + - 1.8564453 + - -1.2783203 + - 0.90527344 + - 0.0960083 + - 0.92041016 + - -3.8691406 + - 2.2910156 + - -1.1074219 + - -2.6953125 + - 1.5048828 + - 1.3212891 + - 1.3105469 + - -0.69921875 + - 1.109375 + - -0.84765625 + - -2.1894531 + - 1.7773438 + - 0.46289063 + - 1.0683594 + - 1.5205078 + - 0.45947266 + - 0.6953125 + - 2.515625 + - 1.7626953 + - -0.017410278 + - 0.37109375 + - 2.5585938 + - 0.52685547 + - -0.4777832 + - 2.1425781 + - -0.33813477 + - 6.3203125 + - -0.26220703 + - 2.9394531 + - -0.91015625 + - 0.6923828 + - 1.2431641 + - -0.40356445 + - 1.5996094 + - -1.4775391 + - 3.7519531 + - -1.3554688 + - 1.9160156 + - -1.5986328 + - -3.5078125 + - 5.140625 + - -1.1523438 + - 5.8007813 + - -1.2900391 + - 7.9765625 + - -1.7509766 + - 0.32861328 + - -2.2382813 + - 0.93652344 + - -0.7392578 + - 2.2539063 + - -1.0869141 + - -0.8466797 + - 2.7597656 + - 1.7753906 + - 4.6328125 + - -0.29882813 + - -0.6533203 + - 0.57910156 + - -0.515625 + - -3.8828125 + - -1.7470703 + - 0.18225098 + - -3.9160156 + - -1.1816406 + - 2.4863281 + - 3.9570313 + - -0.8852539 + - -2.1601563 + - -28.90625 + - 0.2919922 + - 2.2734375 + - -0.76660156 + - 0.6015625 + - 3.8164063 + - -0.01789856 + - 1.8408203 + - 1.7519531 + - 2.5898438 + - -0.31176758 + - -0.4230957 + - -1.1376953 + - -0.7158203 + - -2.2890625 + - -0.734375 + - 0.34375 + - 1.4375 + - 1.8173828 + - -2.7070313 + - 2.7675781 + - 8.484375 + - 3.6015625 + - -2.203125 + - 0.8564453 + - 0.796875 + - -0.41333008 + - 0.9296875 + - 1.9941406 + - -3.4882813 + - 0.7446289 + - 0.020996094 + - -0.27001953 + - -0.5830078 + - 2.6367188 + - -3.4160156 + - -3.4082031 + - -0.86816406 + - -0.6953125 + - -1.28125 + - 5.4765625 + - 0.37890625 + - -0.4609375 + - 3.0097656 + - -2.2636719 + - 3.6289063 + - 0.0012741089 + - 3.1953125 + - 1.0205078 + - -2.0507813 + - -0.6533203 + - 4.0703125 + - -0.9589844 + - 2.4824219 + - 1.3720703 + - 0.19848633 + - -1.6601563 + - 2.2304688 + - -0.88378906 + - -3.2988281 + - 1.7441406 + - -1.4970703 + - -0.6040039 + - -2.4921875 + - -4.8007813 + - 0.8598633 + - -4.3984375 + - 1.9423828 + - 4.296875 + - -8.7109375 + - 1.3457031 + - -0.068359375 + - -1.9345703 + - -4.8671875 + - 3.1914063 + - 1.0673828 + - -1.9160156 + - 2.15625 + - -1.0566406 + - 2.7050781 + - 1.0644531 + - 0.044921875 + - 1.9609375 + - 0.53759766 + - 1.7822266 + - -1.8515625 + - -0.28759766 + - -1.8173828 + - -0.82910156 + - 0.34765625 + - -2.0625 + - -1.6425781 + - 0.8466797 + - 0.7636719 + - 3.6132813 + - 2.0976563 + - 0.7114258 + - -0.21289063 + - 2.4335938 + - 0.03579712 + - 3.9882813 + - -0.88671875 + - -4.7070313 + - -1.7285156 + - 1.5078125 + - 0.094055176 + - -0.012367249 + - 3.2851563 + - 4.1523438 + - 1.1933594 + - -1.453125 + - 3.5820313 + - 0.89697266 + - 2.5136719 + - -2.203125 + - 0.107421875 + - 1.2734375 + - 4.0390625 + - -0.3244629 + - -0.6425781 + - 3.8125 + - -3.5703125 + - -4.6640625 + - -0.10461426 + - 0.107055664 + - -3.3007813 + - -4.375 + - 1.1699219 + - -0.87597656 + - 2.4726563 + - 3.265625 + - 0.8408203 + - 2.9882813 + - -0.35546875 + - 1.6015625 + - -1.0097656 + - -2.5429688 + - 0.014320374 + - 0.7524414 + - -1.5390625 + - 1.0507813 + - 3.59375 + - -0.16918945 + - 0.25976563 + - 2.7695313 + - 0.48168945 + - 3.7539063 + - -2.1289063 + - -1.890625 + - 0.53808594 + - 3.9726563 + - 4.609375 + - -0.17053223 + - 3.21875 + - -1.828125 + - 1.2353516 + - -2.8007813 + - 0.14550781 + - 0.9404297 + - 0.9580078 + - -0.76660156 + - 2.3359375 + - 2.0234375 + - 3.0253906 + - 2.5703125 + - -3.4179688 + - 4.515625 + - -3.578125 + - 1.1767578 + - -0.31396484 + - -0.3088379 + - -1.5947266 + - 0.2421875 + - -0.4482422 + - -3.765625 + - 0.39746094 + - 2.2148438 + - -0.29541016 + - -0.5517578 + - 2.65625 + - 0.48657227 + - -0.29711914 + - -2.671875 + - 1.4628906 + - 1.0449219 + - 4.421875 + - -1.2089844 + - 2.4667969 + - -3.2109375 + - -0.8457031 + - -0.15783691 + - -0.5551758 + - -0.73291016 + - 5.1835938 + - 0.078125 + - -1.5830078 + - -0.33081055 + - 0.07092285 + - -1.8925781 + - 0.57128906 + - 0.515625 + - -1.1679688 + - -1.2685547 + - -3.4785156 + - -0.05206299 + - -0.38623047 + - -0.98828125 + - -0.30981445 + - 0.98535156 + - -2.3984375 + - -1.1425781 + - -7.2890625 + - -4.9921875 + - 3.40625 + - -2.6191406 + - -1.9726563 + - 1.875 + - -0.65283203 + - 3.1640625 + - -3.1445313 + - -2.3847656 + - -5.6875 + - -1.5703125 + - -4.234375 + - 1.3476563 + - -0.11682129 + - 1.765625 + - 1.984375 + - -2.078125 + - 1.0410156 + - -1.4189453 + - -2.9609375 + - 0.45947266 + - -0.41918945 + - -1.3798828 + - 2.2890625 + - -0.97265625 + - 0.35766602 + - 4.2890625 + - -1.2666016 + - 4.546875 +- - 3.2363281 + - -1.1582031 + - 1.0810547 + - -2.0292969 + - 1.609375 + - -1.0048828 + - 0.43676758 + - -0.8769531 + - 0.79785156 + - -0.27612305 + - 0.4963379 + - -0.82128906 + - 0.16906738 + - -0.734375 + - -0.34936523 + - 0.03515625 + - 0.34375 + - 1.3769531 + - 1.5234375 + - -1.875 + - -1.4082031 + - 1.6289063 + - -1.1650391 + - 0.65234375 + - 1.796875 + - 1.984375 + - -0.4350586 + - 1.4003906 + - -0.34985352 + - -2.5253906 + - 2.5351563 + - 0.32348633 + - 2.3007813 + - 1.5195313 + - -0.28295898 + - 1.1650391 + - -3.4472656 + - 0.07421875 + - -5.28125 + - -0.8310547 + - 0.7524414 + - -2.4257813 + - -0.91845703 + - -0.9814453 + - -1.7285156 + - 2.0761719 + - 0.23657227 + - -3.9003906 + - -1.4052734 + - 0.8310547 + - 3.9140625 + - -0.43408203 + - -3.0429688 + - -100.5625 + - -3.0703125 + - -0.93652344 + - 2.71875 + - -1.0527344 + - -1.3789063 + - -7.3671875 + - -2.3789063 + - 0.58251953 + - 0.8388672 + - 0.13110352 + - 2.4003906 + - 0.07421875 + - -2.5488281 + - 0.5126953 + - 2.0644531 + - -1.5556641 + - -4.6679688 + - 0.055236816 + - -2.9921875 + - -0.9038086 + - -1.2294922 + - -0.3984375 + - 2.9863281 + - 3.1328125 + - -0.13867188 + - -0.36523438 + - -0.63916016 + - -0.6064453 + - -1.5869141 + - -0.3425293 + - -2.0234375 + - 0.5336914 + - -1.8027344 + - -0.15185547 + - 2.2578125 + - 0.86376953 + - -1.234375 + - 5.9453125 + - 2.7089844 + - -19.703125 + - -2.8125 + - -2.7832031 + - -4.4375 + - 0.35595703 + - 1.5751953 + - -4.09375 + - 1.6884766 + - -1.3564453 + - -3.8652344 + - -0.61035156 + - 0.0055770874 + - -2.7949219 + - 0.08062744 + - -1.3369141 + - -1.5839844 + - -0.056915283 + - 0.04058838 + - 0.4296875 + - 0.47753906 + - -1.5585938 + - -0.055511475 + - 3.03125 + - 2.8515625 + - 0.70947266 + - -0.18884277 + - 0.29467773 + - 2.2421875 + - 0.59472656 + - 0.15393066 + - -2.4863281 + - -2.1992188 + - -0.27172852 + - 2.40625 + - -0.73095703 + - 0.32299805 + - 1.59375 + - 2.3808594 + - 0.17297363 + - -3.2519531 + - 1.1630859 + - 1.234375 + - 2.40625 + - -0.3088379 + - 0.78564453 + - -1.2050781 + - -1.4824219 + - 1.5166016 + - -0.4206543 + - 1.3535156 + - -2.7734375 + - 1.1757813 + - -2.8027344 + - -1.7998047 + - -0.9379883 + - -2.5703125 + - 4.5820313 + - 0.78564453 + - -1.9257813 + - -1.0478516 + - 0.03515625 + - 0.5151367 + - -2.7832031 + - 0.90722656 + - -0.5102539 + - -3.0390625 + - -3.1289063 + - -1.2509766 + - -2.6191406 + - -0.5546875 + - -1.1376953 + - 0.51416016 + - 1.3994141 + - 3.3613281 + - -1.1591797 + - -0.7583008 + - -0.46289063 + - -2.6386719 + - -1.9306641 + - -0.43896484 + - -2.9863281 + - -0.09875488 + - 0.25195313 + - -1.3115234 + - 2.09375 + - -4.265625 + - -2.2519531 + - 1.7910156 + - 0.8022461 + - -1.8603516 + - -1.8544922 + - 0.13891602 + - 5.1054688 + - -3.4863281 + - -0.85253906 + - -1.1806641 + - 0.07336426 + - -1.9082031 + - -3.7753906 + - -0.5541992 + - 0.640625 + - -2.2460938 + - 1.4951172 + - 3.6328125 + - -2.1640625 + - -1.4921875 + - 0.13476563 + - 0.44189453 + - -2.359375 + - 1.9189453 + - 0.7114258 + - 7.9375 + - 3.2929688 + - 4.2617188 + - -2.8378906 + - -0.3474121 + - -2.2304688 + - -2.0644531 + - -0.7504883 + - -2.9101563 + - -0.859375 + - 0.8330078 + - 3.9570313 + - -0.0036258698 + - -2.5214844 + - 3.0898438 + - -0.70458984 + - -3.8535156 + - 0.6298828 + - -0.32739258 + - 3.1289063 + - -0.08618164 + - -1.21875 + - 0.09649658 + - 0.7675781 + - 0.39672852 + - -3.1464844 + - 0.7763672 + - -0.7680664 + - -1.0068359 + - -0.88671875 + - -0.2064209 + - 1.5820313 + - 0.7441406 + - 2.3671875 + - 2.8554688 + - 1.6601563 + - 6.0390625 + - -0.35351563 + - -3.4589844 + - 0.23046875 + - -2.2324219 + - -1.7626953 + - 3.2714844 + - 2.5566406 + - -0.61572266 + - 0.20751953 + - 1.2539063 + - 0.4423828 + - -2.1269531 + - 0.5131836 + - 0.62353516 + - -0.6958008 + - -0.33032227 + - -2.28125 + - 0.032348633 + - 0.3408203 + - 1.4726563 + - 1.8486328 + - 1.890625 + - 1.8886719 + - -0.37426758 + - 2.4140625 + - -2.3027344 + - 3.9121094 + - 0.85546875 + - -4.6953125 + - 0.32983398 + - 0.8154297 + - 3.2304688 + - 0.8305664 + - -0.42773438 + - -1.1630859 + - -3.9277344 + - 1.3681641 + - 0.18469238 + - 1.0292969 + - -2.1328125 + - -2.6738281 + - 1.3876953 + - 0.1361084 + - 0.99902344 + - -0.77783203 + - -0.064697266 + - 1.828125 + - 0.65771484 + - 0.03390503 + - 1.7265625 + - 1.2138672 + - 10.0703125 + - 0.064697266 + - 0.6723633 + - -0.4819336 + - 1.8457031 + - -1.4023438 + - 2.2148438 + - -0.5493164 + - -0.07574463 + - -0.20422363 + - 2.7597656 + - 3.3242188 + - -1.6425781 + - 1.5322266 + - 2.4785156 + - 1.4394531 + - -0.09094238 + - -1.203125 + - -1.6650391 + - -0.10546875 + - -0.8964844 + - 0.072509766 + - 1.1875 + - -2.4375 + - 0.08258057 + - -0.14453125 + - -3.1816406 + - 1.2851563 + - 1.8339844 + - 1.2412109 + - -3.8457031 + - 2.5703125 + - -1.4052734 + - -0.78564453 + - -1.3427734 + - -1.5039063 + - 2.3652344 + - -3.5820313 + - -4.078125 + - 1.7050781 + - 1.5644531 + - 0.7709961 + - 2.34375 + - -0.11657715 + - 2.7832031 + - -0.49926758 + - 0.08984375 + - 0.105285645 + - 2.7597656 + - -0.4482422 + - 2.1015625 + - 1.5488281 + - 1.9433594 + - 1.1533203 + - -0.21252441 + - 2.6777344 + - -5.0664063 + - -0.8847656 + - 2.1464844 + - -1.265625 + - 0.3330078 + - 0.5102539 + - -2.1738281 + - -0.7841797 + - -4.1015625 + - -1.609375 + - -1.6220703 + - -1.4111328 + - -1.4921875 + - 1.7324219 + - 4.359375 + - -1.3857422 + - 2.9726563 + - -2.90625 + - 6.1757813 + - 1.6982422 + - 1.4638672 + - -2.6894531 + - 0.7714844 + - -1.5244141 + - -2.125 + - 3.5058594 + - -0.3996582 + - 3.5996094 + - -1.4482422 + - 0.3935547 + - 0.7109375 + - 2.4746094 + - -1.3896484 + - -1.2880859 + - -1.9433594 + - -0.859375 + - -0.703125 + - 1.8554688 + - 1.8632813 + - -4.2226563 + - -8.125 + - -2.1074219 + - 0.453125 + - -0.09375 + - -2.6660156 + - -0.95751953 + - 0.047698975 + - -0.29663086 + - 2.6464844 + - 2.1074219 + - -2.1464844 + - 1.5498047 + - -2.3339844 + - 1.5898438 + - -0.5654297 + - -4.3476563 + - -0.1673584 + - 1.7988281 + - 2.0488281 + - -2.1660156 + - -14.390625 + - -0.12243652 + - -2.2089844 + - -1.6064453 + - 3.1171875 + - -1.1591797 + - 1.4433594 + - -0.19689941 + - -3.6835938 + - -1.4238281 + - -3.6152344 + - 5.109375 + - -0.5004883 + - -0.4736328 + - 2.7988281 + - -0.32592773 + - -0.75927734 + - 1.0458984 + - 0.1619873 + - -2.0371094 + - 2.2246094 + - -1.4375 + - -1.921875 + - -1.7138672 + - -3.8613281 + - 0.85009766 + - -0.37939453 + - -1.8525391 + - 0.5839844 + - -1.9013672 + - 0.7519531 + - 1.6748047 + - -1.3095703 + - -1.5087891 + - -0.6269531 + - -1.6445313 + - -2.2011719 + - -0.9091797 + - 0.06640625 + - 2.7050781 + - -2.1679688 + - -3.5800781 + - -0.009483337 + - 1.5244141 + - -0.58935547 + - -2.0390625 + - -0.47583008 + - 5.609375 + - 4.625 + - -0.033477783 + - 0.07110596 + - 3.2851563 + - -0.44482422 + - -2.8945313 + - -1.7675781 + - 2.7714844 + - -0.9301758 + - -0.84521484 + - -0.9785156 + - 0.27197266 + - 0.33666992 + - -2.3515625 + - 4.9375 + - 2.3125 + - 0.29882813 + - 1.015625 + - 0.35131836 + - 0.43896484 + - 0.8076172 + - -0.91064453 + - -0.6064453 + - 3.8203125 + - 0.5683594 + - 0.55908203 + - 0.9736328 + - -1.9970703 + - -0.3269043 + - 1.2158203 + - -6.0039063 + - 0.13977051 + - 3.71875 + - -0.5605469 + - 0.46313477 + - 1.5683594 + - -0.7011719 + - -0.46362305 + - -2.6328125 + - -1.3330078 + - 2.4570313 + - -2.0488281 + - -2.9238281 + - 5.375 + - 0.21679688 + - -5.9726563 + - 2.0390625 + - 0.055786133 + - 1.3359375 + - 3.8378906 + - -0.6225586 + - -0.6113281 + - -1.5830078 + - 2.8535156 + - 3.6679688 + - -2.5703125 + - -1.5019531 + - 0.69091797 + - -2.0332031 + - 1.6210938 + - -0.3408203 + - -0.5522461 + - -1.4355469 + - -0.5078125 + - 0.5957031 + - 1.5869141 + - 3.6757813 + - -0.018692017 + - 0.55566406 + - 1.4609375 + - 0.20336914 + - -1.3769531 + - 1.6767578 + - 2.1894531 + - 0.85253906 + - 0.4519043 + - -0.00390625 + - -1.8789063 + - 3.5800781 + - 0.16516113 + - -4.5117188 + - -0.12890625 + - -0.3557129 + - -1.6269531 + - -1.9589844 + - -1.0107422 + - 3.1054688 + - -0.8457031 + - -4.8476563 + - -2.3652344 + - -1.3818359 + - 0.20703125 + - 1.9863281 + - 1.4814453 + - 0.6333008 + - 1.9667969 + - -17.671875 + - -1.453125 + - -1.0478516 + - -2.0019531 + - -1.3818359 + - 0.61279297 + - 0.20227051 + - 0.0055770874 + - 2.3476563 + - -3.4804688 + - -1.0546875 + - -2.2363281 + - 1.2685547 + - -1.0302734 + - 0.87597656 + - -2.4453125 + - -1.4394531 + - -2.3496094 + - -2.2890625 + - -0.8925781 + - -1.9296875 + - 0.9921875 + - 0.2939453 + - -1.2851563 + - 1.1201172 + - 0.578125 + - 0.30908203 + - 0.7246094 + - -3.2089844 + - 0.65478516 + - 2.5683594 + - -3.2148438 + - -2.9394531 + - 1.6816406 + - 1.6416016 + - -2.3417969 + - -3.5 + - -1.1904297 + - 1.4462891 + - -3.1875 + - -1.890625 + - -0.1015625 + - -1.9082031 + - 1.4306641 + - 5.1757813 + - 3.9101563 + - 1.0263672 + - 3.2402344 + - -0.8222656 + - -0.68603516 + - 0.055786133 + - -2.2578125 + - -2.3261719 + - 0.15234375 + - -3.6972656 + - 0.5625 + - -4.3789063 + - 0.9506836 + - 2.5957031 + - -1.7587891 + - -1.9824219 + - 1.9609375 + - -0.60595703 + - -0.2524414 + - -1.5576172 + - 1.8701172 + - -2.1386719 + - 0.00390625 + - 1.4619141 + - 1.8613281 + - 0.00027894974 + - 0.44140625 + - -1.6054688 + - 3.4902344 + - 0.036834717 + - 1.4169922 + - 0.7788086 + - -0.12384033 + - 1.7070313 + - -0.52197266 + - -3.2265625 + - -2.6875 + - 0.61572266 + - 2.6113281 + - -2.8164063 + - -0.83251953 + - -0.25439453 + - 0.037384033 + - -2.2226563 + - -2.5703125 + - -0.08013916 + - 2.7851563 + - 4.390625 + - -1.0810547 + - 0.59375 + - -4.6757813 + - 7.9140625 + - -3.1503906 + - 0.73339844 + - 3.3554688 + - -1.6220703 + - -2.59375 + - 0.984375 + - -1.6298828 + - -0.5546875 + - 2.6933594 + - 3.8125 + - -0.45922852 + - 1.4638672 + - 1.0556641 + - 1.6621094 + - 3.1113281 + - -0.55126953 + - 2.4003906 + - 1.8222656 + - -2.0507813 + - 0.22314453 + - 0.98535156 + - -0.5253906 + - -1.0029297 + - 0.6152344 + - 0.6113281 + - -0.71191406 + - -2.9492188 + - -0.19580078 + - -0.98828125 + - -0.1899414 + - 0.044067383 + - 1.5214844 + - 1.734375 + - 1.0146484 + - -1.4179688 + - 7.7578125 + - 3.3652344 + - 7.0976563 + - 1.4726563 + - -5.7226563 + - -5.890625 + - -0.3828125 + - -1.3154297 + - -0.31958008 + - -1.5888672 + - 0.1907959 + - -0.23181152 + - -1.046875 + - 1.6132813 + - -1.9482422 + - 2.6699219 + - 3.2246094 + - 3.6679688 + - -0.9091797 + - -2.5136719 + - 0.5102539 + - 24.09375 + - 1.2988281 + - 0.88183594 + - 0.09313965 + - -3.0195313 + - 1.8251953 + - 0.71484375 + - 0.77197266 + - -2.15625 + - 1.1113281 + - 3 + - 2.96875 + - -0.28686523 + - -0.0496521 + - 0.5957031 + - 4.7929688 + - 1.4414063 + - 3.0625 + - -5.0664063 + - -0.17687988 + - -1.8623047 + - -1.8876953 + - -3.6367188 + - 0.9038086 + - -0.4519043 + - 1.453125 + - -0.27124023 + - -1.8652344 + - 2.1582031 + - 0.65771484 + - -3.4160156 + - -5.7304688 + - -0.22070313 + - -3.03125 + - -0.9975586 + - 1.8378906 + - -1.4101563 + - 1.4414063 + - 3.9804688 + - -1.9648438 + - -1.5292969 + - -1.8769531 + - 2.2949219 + - -0.23254395 + - -0.5600586 + - 1.2783203 + - 0.60791016 + - 1.453125 + - 0.8408203 + - -0.73535156 + - -0.99658203 + - -3.1132813 + - 2.9472656 + - -0.5136719 + - 0.32617188 + - -2.6640625 + - -1.5917969 + - 1.0527344 + - 0.119384766 + - -1.2695313 + - -1.6621094 + - 2.1621094 + - -1.7226563 + - -1.7275391 + - -0.45898438 + - -0.26733398 + - 2.6152344 + - 0.4230957 + - -1.1201172 + - -0.47021484 + - 4.1289063 + - 1.4775391 + - -0.26342773 + - 2.9726563 + - -2.859375 + - 2.3222656 + - 0.52197266 + - -1.1865234 + - -3.2050781 + - -1.1943359 + - 2.2285156 + - -2.5 + - 5.8789063 + - -0.001953125 + - 2.4101563 + - -0.78027344 + - -1.4560547 + - 0.8540039 + - 2.6914063 + - 0.49853516 + - -1.1474609 + - -0.55566406 + - 0.46972656 + - 1.1582031 + - -3.6191406 + - 2.3203125 + - -4.75 + - -4.75 + - -3.7871094 + - 1.0068359 + - 3.9179688 + - 1.4345703 + - -1.3925781 + - 0.171875 + - 2.4257813 + - 1.21875 + - -2.6074219 + - 1.1171875 + - -1.5332031 + - -4.0273438 + - -0.3540039 + - 5.6328125 + - 0.23010254 + - 2.109375 + - 1.9853516 + - -0.9951172 + - 2.140625 + - -0.2705078 + - -2.8164063 + - -0.19946289 + - 4.5820313 + - -2.5664063 + - -0.3581543 + - 2.8847656 + - -1.4316406 + - 0.06585693 + - 1.0810547 + - -1.1972656 + - -9.3359375 + - 1.4482422 + - -47.25 + - -1.2919922 + - -0.6015625 + - -2.0625 + - -3.9179688 + - -0.47729492 + - 0.296875 + - 1.0654297 + - 1.6640625 + - 1.0595703 + - 0.18188477 + - -1.796875 + - 4.6875 + - -0.5253906 + - -2.0019531 + - 1.5869141 + - 1.1044922 + - -0.7211914 + - 16.984375 + - 0.42285156 + - -0.9765625 + - -1.2626953 + - -0.9379883 + - -0.57958984 + - 0.4038086 + - 2.8007813 + - 0.87353516 + - -1.625 + - -0.4267578 + - -2.6699219 + - -0.9609375 + - -2.4199219 + - 0.1784668 + - 0.49438477 + - -0.88183594 + - 2.4472656 + - 1.0351563 + - 0.8046875 + - 1.4453125 + - 0.5073242 + - 3.921875 + - -0.3798828 + - 1.046875 + - 0.2524414 + - -3.1367188 + - 2.5292969 + - 0.12658691 + - -1.2939453 + - -0.52246094 + - -2.9902344 + - 0.3515625 + - -1.6132813 + - -0.08203125 + - -0.66015625 + - -0.059143066 + - 0.21252441 + - 1.9482422 + - -4.1484375 + - -2.4863281 + - 0.35864258 + - 0.18481445 + - -1.0009766 + - -2.59375 + - 1.2685547 + - 6.6015625 + - -0.65283203 + - -0.7451172 + - 4.7226563 + - -2.2519531 + - 2.3105469 + - -2.0625 + - -0.16796875 + - 0.17907715 + - -2.3144531 + - 2.8964844 + - -4.5703125 + - 3.5996094 + - -1.0625 + - 5.2304688 + - 0.46972656 + - 0.31811523 + - -3.0722656 + - 1.9150391 + - 0.18713379 + - 1.9267578 + - 2.9316406 + - -1.0644531 + - -0.28515625 + - 0.26489258 + - -0.71972656 + - 2.5703125 + - -1.4707031 + - -1.5351563 + - -2.7070313 + - 1.2441406 + - -0.47607422 + - -0.3474121 + - -0.8457031 + - -3.4179688 + - -1.0927734 + - -2.1328125 + - -5.7382813 + - -1.1689453 + - 0.2512207 + - 1.3505859 + - 3.4101563 + - 3.4472656 + - 0.40112305 + - 0.56689453 + - 0.064697266 + - 0.7753906 + - 0.9980469 + - -1.6445313 + - 2.921875 + - 0.97314453 + - 1.3320313 + - -2.6816406 + - 2.3125 + - -2.0449219 + - 2.2089844 + - 1.6376953 + - 0.4819336 + - -1.6738281 + - -1.7792969 + - 0.17663574 + - 0.31298828 + - 4.0273438 + - -0.7270508 + - 3.1933594 + - 2.3964844 + - 2.65625 + - 1.4794922 + - -0.0524292 + - 1.9814453 + - 0.39282227 + - 0.23828125 + - 2.7226563 + - -0.80126953 + - -2.8105469 + - 0.1665039 + - -2.1660156 + - -2.0292969 + - -2.4453125 + - -3.0078125 + - 1.9033203 + - 2.8339844 + - 2.7753906 + - -2.4765625 + - 0.8408203 + - -3.203125 + - 2.265625 + - -1.7246094 + - 4.75 + - 4.6875 + - 0.59472656 + - -0.53466797 + - 1.7792969 + - 0.2956543 + - 2.3515625 + - -4.1757813 + - 3.9179688 + - -1.46875 + - -4.9453125 + - -1.9033203 + - -1.0390625 + - -0.34399414 + - -2.9414063 + - -15.546875 + - 2.0390625 + - -1.2695313 + - 4.1445313 + - 1.2197266 + - 3.3535156 + - 1.3818359 + - 1.5996094 + - -0.45141602 + - -0.6635742 + - 1.65625 + - -2.0996094 + - 2.4941406 + - 1.4921875 + - 2.0800781 + - -3.2675781 + - 0.96191406 + - -0.0072517395 + - -0.21252441 + - 1.2314453 + - 2.2519531 + - -1.0253906 + - 0.35327148 + - -0.015625 + - 1.5966797 + - -4.4726563 + - 0.20471191 + - -1.7744141 + - -16.671875 + - 0.61865234 + - 0.1204834 + - 2.9863281 + - -4.984375 + - -1.5673828 + - 0.2685547 + - 1.1904297 + - -5.015625 + - -2.6191406 + - -2.6132813 + - 3.6992188 + - -0.53271484 + - -0.45141602 + - -2.3652344 + - 0.70166016 + - -6.203125 + - -1.1904297 + - -0.35180664 + - 0.74072266 + - 1.1875 + - -0.9941406 + - -0.24536133 + - -2.4628906 + - -0.63623047 + - 2.921875 + - -3.5 + - -0.0418396 + - -0.52783203 + - 1.5361328 + - 3.4628906 + - -1.8183594 + - 0.32592773 + - -1.4794922 + - -0.74853516 + - 2.2285156 + - -0.75097656 + - 0.43237305 + - -18.859375 + - -0.33251953 + - -1.9013672 + - 2.4355469 + - -4.1875 + - 2.4121094 + - 0.5698242 + - 1.2294922 + - 1.6337891 + - -0.6972656 + - 1.4189453 + - -1.1513672 + - 2.2636719 + - -1.9921875 + - 0.50927734 + - -0.11621094 + - 0.58740234 + - 0.045196533 + - 1.4101563 + - -4.8007813 + - -1.421875 + - 2.3144531 + - -2.7324219 + - -0.19055176 + - 2.9023438 + - -1.4501953 + - 3.1484375 + - -2.5957031 + - -1.5234375 + - 2.0722656 + - 1.359375 + - 3.15625 + - -2.1503906 + - -1.5009766 + - -1.6464844 + - -0.4116211 + - -0.60595703 + - -1.6875 + - 1.4931641 + - 1.8671875 + - 3.7695313 + - 1.6650391 + - 2.296875 + - 3.6601563 + - -2.0839844 + - 0.4116211 + - -2.2988281 + - -1.4267578 + - -6.0625 + - 1.0380859 + - 2.4628906 + - 0.46191406 + - 0.2548828 + - 0.19689941 + - -2.0976563 + - 0.6020508 + - 0.14929199 + - 8.09375 + - -0.37939453 + - -1.6357422 + - -1.1328125 + - 1.1572266 + - 1.5166016 + - 1.8105469 + - -1.7607422 + - -1.9306641 + - 0.43115234 + - 2.6933594 + - 0.68603516 + - 3.0800781 + - -3.4238281 + - -4.5898438 + - 0.8173828 + - 0.81689453 + - 1.5869141 + - 0.9785156 + - 0.3359375 + - -0.2454834 + - 4.140625 + - 0.45922852 + - 0.1227417 + - -2.3183594 + - 1.6416016 + - -0.86376953 + - 1.2724609 + - -3.3242188 + - -0.48486328 + - 1.7539063 + - -2.6875 + - 1.2851563 + - 3.9628906 + - 2.2578125 + - -0.9003906 + - -0.890625 + - 1.5214844 + - 1.3681641 + - 0.6738281 + - 2.875 + - 4.9257813 + - -0.41552734 + - 1.0478516 + - -0.67822266 + - 0.17907715 + - 0.7519531 + - 2.2324219 + - 1.2285156 + - 1.1103516 + - 0.13671875 + - -4.5898438 + - -0.58251953 + - 3.1289063 + - -2.9101563 + - -0.5 + - -3.109375 + - -0.7890625 + - 2.46875 + - 6.3671875 + - 1.0234375 + - -1.5839844 + - 1.7226563 + - 2.2578125 + - -0.53271484 + - -1.3720703 + - 1.2597656 + - -5.4179688 + - 1.2451172 + - 2.6855469 + - 5.4140625 + - -0.4560547 + - 0.5136719 + - -1.0898438 + - -0.8725586 + - -2.5917969 + - -3.6132813 + - 3.6015625 + - -0.8730469 + - 0.97802734 + - 5.375 + - -2.1015625 + - -1.2539063 + - -2.5039063 + - -0.38916016 + - -0.047546387 + - 0.2939453 + - -1.1806641 + - -0.13952637 + - 3.3027344 + - -0.9951172 + - 0.3881836 + - 1.9726563 + - 0.578125 + - -0.53564453 + - -0.30908203 + - 3.3164063 + - -0.27539063 + - 0.8676758 + - 1.8466797 + - 2.5957031 + - 0.625 + - -0.63427734 + - -3.7246094 + - -3.3027344 + - 0.061645508 + - 3.0683594 + - -0.9375 + - 2.4726563 + - -0.6616211 + - 1.5009766 + - -0.15673828 + - -3.625 + - 0.9790039 + - 0.10180664 + - -0.1430664 + - -1.1445313 + - -2.4355469 + - 6.703125 + - -2.4082031 + - 0.82666016 + - -1.2753906 + - 2.6503906 + - 0.7402344 + - -0.68408203 + - -2.0527344 + - 0.01701355 + - -3.9707031 + - 0.9741211 + - 0.3684082 + - 1.9746094 + - 1.2275391 + - 11.5703125 + - -1.9726563 + - -1.2568359 + - 1.5556641 + - 0.38720703 + - 6.0625 + - 4.03125 + - 0.3269043 + - -1.5058594 + - -0.7089844 + - 0.52783203 + - 8.3125 + - 0.38867188 + - -0.64453125 + - 0.23876953 + - -1.2001953 + - 0.69921875 + - -3.109375 + - -2.7402344 + - -2.3964844 + - -3.6738281 + - 1.8652344 + - -3.6816406 + - -1.0703125 + - 1.0126953 + - 0.83251953 + - -4.9414063 + - -0.2487793 + - 0.36669922 + - 1.9873047 + - -0.4453125 + - -1.421875 + - 1.3291016 + - -1.1318359 + - -1.125 + - 2.25 + - 0.49023438 + - 1.9892578 + - 4.171875 + - -1.8466797 + - 1.5117188 + - 0.41845703 + - -4.1914063 + - -1.8828125 + - -0.3010254 + - -1.7539063 + - 3.1015625 + - -1.0146484 + - 0.4970703 + - 3.1601563 + - 0.080078125 + - 3.5722656 + - -0.74072266 + - 3.1738281 + - -1.8457031 + - 3.15625 + - -0.88671875 + - -3.90625 + - -2.7324219 + - -3.7539063 + - 1.6591797 + - 1.1328125 + - -0.9873047 + - -0.70703125 + - -0.78564453 + - -0.30078125 + - -2.2480469 + - -1.0400391 + - 1.1386719 + - 1.0878906 + - -0.74658203 + - 2.7128906 + - -9.265625 + - 3.6757813 + - 3.4140625 + - -0.7910156 + - 0.8730469 + - -2.4628906 + - -0.8623047 + - 0.82128906 + - -0.09765625 + - 1.9785156 + - 0.9145508 + - -0.8256836 + - 3.8378906 + - 0.45043945 + - -1.5556641 + - -2.703125 + - -0.60546875 + - 1.1132813 + - -0.43652344 + - -2.0175781 + - -0.31958008 + - -0.07867432 + - -1.5126953 + - 3.2539063 + - 0.37036133 + - -6.2109375 + - 1.9072266 + - 4.3515625 + - -0.01171875 + - 0.04852295 + - 0.296875 + - 0.8154297 + - 1.7441406 + - 2.4199219 + - 3.375 + - 0.42578125 + - 0.5605469 + - -0.43188477 + - -0.09667969 + - 1.4482422 + - 2.7324219 + - -0.17468262 + - -3.9589844 + - 10.7734375 + - 2.2988281 + - -3.1738281 + - -71 + - 0.8598633 + - -1.671875 + - -0.8847656 + - 2.8320313 + - 4.7929688 + - 1.6953125 + - 0.8984375 + - -0.09063721 + - -2.2050781 + - -2.765625 + - 1.6904297 + - -0.7163086 + - 2.3457031 + - 0.35083008 + - -5.0625 + - -2.6972656 + - -3.0078125 + - -0.32592773 + - 1.7851563 + - 2.4550781 + - 0.5205078 + - 1.1357422 + - -0.9584961 + - -1.6064453 + - -2.7480469 + - -1.6689453 + - -3.2753906 + - 1.0966797 + - -1.7207031 + - 1.1298828 + - -4.6367188 + - 0.08984375 + - -1.109375 + - -3.8867188 + - 1.0859375 + - 1.0166016 + - -0.043792725 + - 1.3095703 + - -2.6269531 + - -0.30297852 + - -1.3212891 + - 4.2148438 + - 1.796875 + - 1.2851563 + - -2.6074219 + - 2.0527344 + - 1.4707031 + - 2.9453125 + - 0.33374023 + - 1.2978516 + - -0.5600586 + - 1.0791016 + - 9.7578125 + - -4.8945313 + - 1.8242188 + - 0.14147949 + - 0.9223633 + - 0.3815918 + - -2.0175781 + - 0.9194336 + - 2.046875 + - 0.3852539 + - -3.15625 + - -0.7392578 + - 0.11602783 + - -4.640625 + - 0.7426758 + - -0.93603516 + - 0.4621582 + - -2.9628906 + - 2.0625 + - 2.890625 + - 0.58935547 + - 1.4394531 + - 0.2878418 + - -2.2128906 + - -0.7866211 + - 0.54345703 + - 1.0351563 + - -0.11187744 + - 0.4152832 + - -1.7988281 + - -1.1962891 + - 0.7685547 + - -2.7597656 + - 2.4375 + - 3.6503906 + - -0.6088867 + - -1.0214844 + - -1.2431641 + - 2.0878906 + - -0.15905762 + - 2.8632813 + - 2.4941406 + - 7.8046875 + - 1.8417969 + - 3.0839844 + - -1.7001953 + - 0.81103516 + - 1.5585938 + - -0.31445313 + - 0.3947754 + - 1.9375 + - -0.9941406 + - 0.13220215 + - -0.83740234 + - -2.9550781 + - 0.67822266 + - -1.1914063 + - 5.3007813 + - 16.75 + - 1.0976563 + - -0.65185547 + - -3.8984375 + - 1.375 + - -0.75 + - 1.6728516 + - 2.3945313 + - -0.31225586 + - -0.9316406 + - 3.2753906 + - 0.94970703 + - 1.359375 + - -1.875 + - 2.1777344 + - 2.2441406 + - -4.0898438 + - 1.3691406 + - 0.30395508 + - 2.1152344 + - 0.1126709 + - -1.7089844 + - 1.3037109 + - -0.82666016 + - 3.9414063 + - 1.4775391 + - -1.4306641 + - 3.2910156 + - 1.3632813 + - -1.796875 + - -3.2226563 + - 1.6689453 + - -0.072509766 + - -2.9960938 + - 0.76416016 + - 0.1616211 + - -2.6503906 + - 0.085510254 + - 1.9941406 + - 0.55908203 + - 0.34423828 + - 3.0351563 + - 1.4033203 + - -0.54785156 + - 0.37817383 + - 3.5644531 + - -0.7607422 + - 2.7578125 + - 0.76660156 + - 3.2304688 + - 2.390625 + - -2.2675781 + - -1.4804688 + - 2.2480469 + - 6.3867188 + - -2.7519531 + - -0.3305664 + - 3.0195313 + - -4.2539063 + - 0.103515625 + - -0.5175781 + - -2.2578125 + - 0.27441406 + - 0.76660156 + - 2.3105469 + - 1.1015625 + - 0.081726074 + - -0.16015625 + - -0.0078125 + - -1.9619141 + - -0.63720703 + - -2.21875 + - 0.4033203 + - 1.1953125 + - 0.39013672 + - -2.21875 + - -1.65625 + - -2.0566406 + - -1.6669922 + - -10.375 + - 0.6894531 + - 0.6230469 + - -0.0446167 + - -0.6328125 + - -1.4785156 + - -3.3125 + - 1.4169922 + - -0.5205078 + - 1.609375 + - 3.4453125 + - 1.1767578 + - 2.6171875 + - 5.765625 + - -1.453125 + - 1.8847656 + - -3.3789063 + - -3.6875 + - -2.703125 + - 1.6894531 + - 0.23828125 + - -2.6445313 + - 2.9140625 + - -2.3457031 + - -0.65478516 + - 0.69970703 + - 1.2314453 + - 5.4804688 + - -0.18164063 + - 0.48754883 + - 3.3339844 + - 4.1132813 + - -3.0664063 + - -5.390625 + - -0.29589844 + - 0.8984375 + - 1.0292969 + - 2.5839844 + - -0.093444824 + - -1.4394531 + - 2.6972656 + - 2.3828125 + - -0.29467773 + - -1.8320313 + - -1.3818359 + - 2.1191406 + - 0.82128906 + - 3.8769531 + - 1.8378906 + - -0.46313477 + - 3.375 + - 1.1123047 + - 1.0087891 + - 2.1347656 + - -3.4277344 + - -2.8945313 + - -2.65625 + - 2.4277344 + - 2.7734375 + - -1.9775391 + - -3.71875 + - -3.6953125 + - -1.5332031 + - -4.8945313 + - 0.98828125 + - -1.0302734 + - 2.1640625 + - 0.5756836 + - -2.96875 + - -4.15625 + - -0.06274414 + - 0.03515625 + - 3.4160156 + - 0.92285156 + - -0.64697266 + - -1.0117188 + - 20.421875 + - 1.1201172 + - 0.58251953 + - 2.1933594 + - 8.015625 + - -0.35546875 + - -0.2253418 + - 0.3088379 + - 0.7392578 + - -3.4335938 + - -0.8833008 + - 4.125 + - -2.3203125 + - 4.7304688 + - 0.66845703 + - 0.73535156 + - -0.64697266 + - 0.68310547 + - -2.9316406 + - -2.5644531 + - 5.1523438 + - -0.84277344 + - 0.48046875 + - 3.7089844 + - 0.16040039 + - -3.9765625 + - 1.3769531 + - 2.2441406 + - 0.9951172 + - 0.20532227 + - 0.63134766 + - 0.3720703 + - 3.1738281 + - 0.61279297 + - -4.0507813 + - 0.96191406 + - -0.62353516 + - -0.9472656 + - -1.0126953 + - -4.5390625 + - 5.3164063 + - 2.5136719 + - -6.2109375 + - -1.0478516 + - 1.4082031 + - 2.2832031 + - -1.5019531 + - 1.1425781 + - 1.7949219 + - -2.5058594 + - 3.6738281 + - 0.515625 + - 2.3613281 + - 0.29858398 + - 6.1289063 + - 1.1318359 + - 0.29174805 + - 1.046875 + - -2.0136719 + - -3.8242188 + - 4.546875 + - 3.0429688 + - 2.7207031 + - 0.028457642 + - 0.33691406 + - 0.15515137 + - 2.9394531 + - -3.4550781 + - 0.39282227 + - 0.38305664 + - -4.5078125 + - -1.8945313 + - 1.9765625 + - 2.75 + - -4.6992188 + - -2.0136719 + - -1.1396484 + - -3.2890625 + - -1.2226563 + - -2.7890625 + - 1.3349609 + - 1.0654297 + - 0.18237305 + - -3.5683594 + - -0.7392578 + - 2.5644531 + - 1.5683594 + - -1.3681641 + - -2.8691406 + - 1.3779297 + - -1.5214844 + - -0.83691406 + - -4.0742188 + - -2.375 + - -4.5429688 + - 2.6953125 + - 0.6816406 + - -3.203125 + - -2.5175781 + - -2.1894531 + - 1.2763672 + - 0.5151367 + - -0.6088867 + - 4.1289063 + - -3.0625 + - 0.6694336 + - -0.07446289 + - -1.6347656 + - 4.0546875 + - -3.6660156 + - 1.1875 + - -2.1308594 + - 2.0566406 + - -0.37890625 + - -4.78125 + - -1.0332031 + - 3.9765625 + - 0.3557129 + - 1.2753906 + - -2.8867188 + - 2.3613281 + - -6.140625 + - 1.2578125 + - 0.69873047 + - -0.89160156 + - 3.6640625 + - 3.5039063 + - 1.4873047 + - 2.4082031 + - -0.64160156 + - 0.66015625 + - -2.4589844 + - -3.3144531 + - -2.1328125 + - 2.8867188 + - 0.7421875 + - -1.4570313 + - 1.7060547 + - 1.0664063 + - -0.52685547 + - 2.5371094 + - -1.890625 + - -1.6679688 + - 1.2255859 + - -0.51953125 + - -1.5722656 + - 1.5800781 + - 0.42919922 + - 0.4934082 + - 3.7558594 + - 2.6347656 + - 0.0892334 + - -1.2910156 + - -5.2148438 + - 3.09375 + - 1.4492188 + - -2.1113281 + - 2.4453125 + - 1.5205078 + - -3.7050781 + - 2.1386719 + - 1.9863281 + - -1.7480469 + - 2.6875 + - -2.9941406 + - -1.9804688 + - -1.8417969 + - 0.51708984 + - 1.8808594 + - 0.34106445 + - -1.5683594 + - -5.5898438 + - -0.23840332 + - -1.6435547 + - -0.86816406 + - -1.3125 + - -5.1445313 + - 3.1347656 + - 0.6113281 + - -2.2421875 + - 1.0253906 + - -1.7421875 + - 3.6621094 + - -2.1660156 + - 2.3730469 + - -1.4462891 + - 0.33862305 + - -0.83984375 + - -0.49267578 + - 1.8681641 + - -0.2175293 + - -0.25854492 + - -3.2089844 + - 0.10430908 + - -1.5869141 + - 1.0126953 + - 1.2773438 + - 3.75 + - -1.6982422 + - -2.1621094 + - -0.034454346 + - 3.90625 + - 2.0703125 + - -1.0029297 + - -3.7441406 + - -1.1357422 + - -2.8867188 + - 8.7734375 + - -1.75 + - -0.11102295 + - -1.7871094 + - 4.3984375 + - 1.2919922 + - 1.1982422 + - 0.79785156 + - -1.3037109 + - 0.2175293 + - -0.7133789 + - 2.1738281 + - -5.390625 + - -2.6777344 + - 5.7382813 + - -4.1210938 + - 3.6914063 + - -1.0966797 + - 0.49926758 + - 0.63720703 + - 3.8164063 + - 0.39770508 + - -1.3974609 + - -0.011154175 + - 0.9560547 + - 2.171875 + - -4.8320313 + - 1.7783203 + - 0.55126953 + - -3.1738281 + - -1.4326172 + - -0.23596191 + - -1.140625 + - -0.22290039 + - -1.1679688 + - 0.34204102 + - 1.5605469 + - -0.85595703 + - -2.0996094 + - -3.8925781 + - 0.55126953 + - -1.4453125 + - -1.6191406 + - 0.23510742 + - 2.6875 + - 0.5488281 + - 2.5390625 + - -0.30566406 + - -0.31054688 + - -1.75 + - 3.4765625 + - 2.8691406 + - -1.8105469 + - -0.67822266 + - -3.6894531 + - -2.2324219 + - 1.7548828 + - 0.15344238 + - -2.2128906 + - -2.3222656 + - -0.578125 + - 1.2382813 + - -0.4765625 + - 0.88134766 + - 2.4453125 + - -0.92285156 + - -3.0878906 + - -2.65625 + - 0.1439209 + - -2.96875 + - -1.8652344 + - -1.0390625 + - -2.1757813 + - -2.8847656 + - -0.6171875 + - -0.8310547 + - -1.3662109 + - 5.4140625 + - 4.6992188 + - -4.21875 + - -0.35668945 + - -1.2822266 + - 1.4794922 + - -2.3300781 + - -2.2949219 + - 3.5800781 + - -1.3066406 + - -2.5527344 + - 1.4326172 + - 2.2753906 + - -2.203125 + - -3.6445313 + - -0.66503906 + - -1.7519531 + - -1.0224609 + - 0.15905762 + - -0.32299805 + - -0.7036133 + - -1.9609375 + - -1.0732422 + - -1.2900391 + - -0.7626953 + - -2.0644531 + - -2.2519531 + - -0.75390625 + - -0.3725586 + - 3.9863281 + - -2.7480469 + - 3.9023438 + - -1.9814453 + - -0.93847656 + - 6.5117188 + - 0.60546875 + - -0.82666016 + - -1.3544922 + - 0.6323242 + - -2.96875 + - 3.3164063 + - 6.4257813 + - -2.3164063 + - -0.70703125 + - 5.7226563 + - 0.9033203 + - 1.3867188 + - 0.39868164 + - -1.9765625 + - 1.0751953 + - 0.51123047 + - -2.9804688 + - 1.3408203 + - -0.8623047 + - -0.3305664 + - 2.6601563 + - -7.1601563 + - 0.71728516 + - 4.21875 + - -2.4765625 + - -0.79003906 + - -2.1503906 + - 4.2460938 + - -5.1679688 + - -2.3320313 + - -0.23156738 + - 1.5947266 + - 2.4082031 + - -0.6894531 + - 1.6523438 + - -2.3300781 + - -2.6777344 + - 2.3339844 + - -0.69189453 + - 0.39379883 + - -2.3339844 + - 3.765625 + - 0.6713867 + - -1.71875 + - -2.4199219 + - -1.2382813 + - -0.22509766 + - 0.57373047 + - -0.34472656 + - 0.5488281 + - 2.0214844 + - -2.5917969 + - -0.09649658 + - -2.7949219 + - 0.71972656 + - 0.95751953 + - 1.1845703 + - -1.2763672 + - -2.2324219 + - -3.1464844 + - 1.2744141 + - 0.5834961 + - 1.15625 + - -0.36157227 + - -2.1542969 + - -2.1152344 + - 1.2978516 + - -3.0253906 + - -2.5078125 + - -1.9648438 + - 3.6992188 + - -3.4804688 + - -1.9482422 + - -0.6015625 + - 2.3535156 + - -1.609375 + - 0.017578125 + - -1.0625 + - -0.9248047 + - -0.30395508 + - -4.1132813 + - 0.8129883 + - 1.6357422 + - 4.8632813 + - -1.6777344 + - 1.4501953 + - -0.2841797 + - 6.375 + - 1.9326172 + - -0.73095703 + - 1.4150391 + - 1.7363281 + - -0.64941406 + - -1.9150391 + - -1.2910156 + - 1.2724609 + - 1.7753906 + - 3.4375 + - -1.9316406 + - 2.3691406 + - -0.04574585 + - -0.054595947 + - 2.40625 + - -0.54248047 + - -0.9785156 + - 1.7080078 + - -1.4541016 + - -2.8515625 + - 0.9140625 + - 0.92041016 + - -3.3164063 + - -0.5415039 + - 1.859375 + - -1.9082031 + - -1.2275391 + - -0.16516113 + - -0.29711914 + - 4.4257813 + - 6.828125 + - -1.8183594 + - -0.18664551 + - -3.7402344 + - -2.1445313 + - 0.515625 + - 1.0849609 + - -2.375 + - 1.8476563 + - -3.6679688 + - -2.8671875 + - -0.51171875 + - -2.3496094 + - -0.9980469 + - -2.3925781 + - -0.021759033 + - 1.8232422 + - 1.421875 + - -0.38916016 + - 1.7294922 + - 2.8515625 + - -0.71875 + - -2.0195313 + - 1.3427734 + - 2.3515625 + - 0.8647461 + - -1.6259766 + - -0.9580078 + - 0.50634766 + - 0.05996704 + - -0.2841797 + - -3.6992188 + - -1.28125 + - -1.3017578 + - 1.7587891 + - -0.9296875 + - 0.9707031 + - 0.14562988 + - 2.8203125 + - -0.19946289 + - -1.4619141 + - 8.03125 + - -2.1171875 + - 3.65625 + - -4.03125 + - 3.6367188 + - 4.2148438 + - -4.0703125 + - 1.1347656 + - 1.7832031 + - -0.21923828 + - -1.1455078 + - -0.35864258 + - -0.16906738 + - 1.8251953 + - -1.71875 + - -1.2568359 + - -1.7851563 + - 3.9589844 + - -0.72753906 + - 1.2275391 + - 0.44628906 + - -1.2568359 + - 0.9194336 + - -0.515625 + - -0.5131836 + - -1.1142578 + - 3.3339844 + - 0.8959961 + - -2.1777344 + - 1.6064453 + - -0.6953125 + - -2.7265625 + - 0.44482422 + - -2.1367188 + - -0.85253906 + - 2.6328125 + - 2.1464844 + - 2.1816406 + - -8.9609375 + - 4.40625 + - -0.578125 + - 0.32617188 + - 0.48632813 + - -3.5039063 + - 1.9033203 + - 0.44970703 + - -1.4980469 + - 1.4433594 + - -4.6289063 + - 0.4033203 + - -0.2097168 + - -0.4741211 + - 0.07739258 + - 0.23547363 + - 1.1494141 + - -0.3383789 + - -0.7475586 + - 0.73291016 + - 2.0761719 + - -2.421875 + - 1.4589844 + - -2.5488281 + - 1.5820313 + - 2.3574219 + - 0.77978516 + - 1.0751953 + - 1.9609375 + - -0.33642578 + - 0.08258057 + - -1.2607422 + - 4.4570313 + - 1.421875 + - 2.5390625 + - 1.0185547 + - -4.046875 + - 0.6635742 + - -0.4050293 + - -0.3876953 + - -0.26391602 + - 1.1337891 + - -0.93896484 + - 1.3505859 + - 6.3554688 + - 1.0771484 + - -8.7421875 + - 1.2646484 + - 1.3359375 + - -0.11853027 + - -0.98535156 + - 2.9433594 + - 6.1757813 + - -1.8076172 + - -0.09399414 + - -0.6176758 + - -1.4550781 + - 1.4707031 + - -0.77441406 + - 0.2220459 + - -0.23046875 + - -2.4199219 + - -0.43237305 + - -0.49902344 + - 4.078125 + - -1.9355469 + - -1.4414063 + - 0.12658691 + - 1.7949219 + - 3.6269531 + - 2.203125 + - 1.0576172 + - 0.4970703 + - 2.703125 + - 0.66748047 + - -24.875 + - 1.6738281 + - -4.6367188 + - -1.8183594 + - -15.671875 + - -1.2578125 + - -0.6875 + - 3.0644531 + - -3.7109375 + - 2.6074219 + - -7.5507813 + - -7.9296875 + - 0.8076172 + - -0.953125 + - 2.0195313 + - -1.1660156 + - 0.38110352 + - 4.4414063 + - -0.9458008 + - 1.5400391 + - 1.0097656 + - 2.0351563 + - 1.9921875 + - -2.9023438 + - -2.4785156 + - 3.6640625 + - -2.578125 + - 1.8388672 + - 1.6982422 + - -5.0117188 + - 1.9042969 + - -0.31152344 + - -0.0836792 + - 2.3574219 + - 0.6328125 + - -1.6601563 + - 1.71875 + - -1.8515625 + - 0.73095703 + - -0.04421997 + - 0.4597168 + - 0.034576416 + - 3.46875 + - 1.4013672 + - 0.056915283 + - 3.71875 + - 2.7539063 + - 1.515625 + - -1.0654297 + - -1.0966797 + - 1.7587891 + - -1.0693359 + - -2.015625 + - 2.0742188 + - 1.3916016 + - 3.1171875 + - -1.6464844 + - -4.7148438 + - 0.67529297 + - -2.6191406 + - 0.16125488 + - 2.4453125 + - -3.1289063 + - -0.6386719 + - -0.37548828 + - -0.41308594 + - -0.12719727 + - 4.5664063 + - 2.8710938 + - 1.4658203 + - -4.6757813 + - -0.140625 + - 3.0175781 + - 0.5756836 + - -0.4440918 + - 1.3955078 + - 0.27856445 + - -0.7294922 + - -1.0048828 + - 2.1171875 + - -3.4804688 + - -0.22387695 + - 1.3056641 + - -0.33764648 + - 0.57910156 + - 4.0429688 + - -0.57177734 + - 0.72314453 + - -1.4560547 + - -3.84375 + - 0.8569336 + - -1.7167969 + - 0.9316406 + - -1.5507813 + - -2.4707031 + - 0.9458008 + - -3.0820313 + - -8.6328125 + - 0.87353516 + - -3.7128906 + - 0.2854004 + - 2.3984375 + - 1.1992188 + - -3.4628906 + - 0.6176758 + - -3.5625 + - -1.8496094 + - -5.140625 + - -0.8227539 + - 0.005859375 + - -0.0052986145 + - 3.953125 + - -0.890625 + - 1.4560547 + - -3.1464844 + - -2.7402344 + - -1.1064453 + - 0.2019043 + - -0.8989258 + - -3.078125 + - 0.8232422 + - -2.5 + - -0.43896484 + - -0.1282959 + - 1.2353516 + - -0.3251953 + - 0.5102539 + - -3.4140625 + - -1.6064453 + - 0.57910156 + - -5.2148438 + - -2.2265625 + - 2.5878906 + - 5.3945313 + - 5.4765625 + - -0.2890625 + - 0.234375 + - 4.4335938 + - 3.2617188 + - -1.6669922 + - -0.90234375 + - -2.3027344 + - 0.3310547 + - 2.8554688 + - -1.0009766 + - -0.7446289 + - -0.61035156 + - -0.75390625 + - -2.0234375 + - -2.2988281 + - 2.4609375 + - -1.8125 + - 1.2353516 + - -0.21203613 + - -2.3457031 + - -0.0234375 + - 0.78027344 + - 1.3662109 + - -0.5136719 + - -0.7988281 + - 0.52685547 + - 2.2109375 + - -0.9453125 + - -1.5009766 + - -4.6523438 + - -0.0446167 + - 0.20629883 + - 3.40625 + - -0.46484375 + - 0.18688965 + - 2.3476563 + - 23.5 + - -0.89501953 + - -3.078125 + - 4.3554688 + - 0.5859375 + - 4.0507813 + - -2.0214844 + - -13.3359375 + - 1.4970703 + - -1.0517578 + - 4.7578125 + - 0.66796875 + - 0.11383057 + - 1.2236328 + - 0.84375 + - 2.2851563 + - 1.4814453 + - -4.9257813 + - 0.3095703 + - -4.7148438 + - 1.0253906 + - -3.7539063 + - 0.3647461 + - -0.20080566 + - -1.4785156 + - 3.5820313 + - -0.93603516 + - -2.2539063 + - 0.28979492 + - 3.0644531 + - -0.5317383 + - -0.69189453 + - 1.3955078 + - -1.6269531 + - -1.3457031 + - -2.0546875 + - -0.33032227 + - -0.26245117 + - -0.96191406 + - 0.11212158 + - -2.59375 + - 2.2695313 + - -1.0654297 + - -1.7246094 + - 1.9658203 + - -0.79833984 + - 0.2915039 + - 1.7851563 + - -3.4238281 + - 3.5742188 + - 1.0439453 + - -1.3769531 + - 5.90625 + - -2.6601563 + - -2.3691406 + - 0.82666016 + - 0.78759766 + - 2.9375 + - -2.3515625 + - 1.5 + - -2.4375 + - 3.8339844 + - 0.71240234 + - -1.1992188 + - -0.064697266 + - 6.109375 + - 3.3691406 + - -0.4128418 + - -1.7158203 + - -0.36547852 + - -1.1796875 + - -0.25268555 + - -0.30004883 + - -0.19189453 + - -2.7128906 + - -5.9140625 + - 6.5351563 + - 0.93652344 + - -2.375 + - -1.8955078 + - 1.6201172 + - 0.37719727 + - -0.3203125 + - -0.21618652 + - 0.5834961 + - 1.2314453 + - 0.7866211 + - 1.6142578 + - -3.2421875 + - 0.8457031 + - 1.3232422 + - -1.9501953 + - 0.4663086 + - 0.171875 + - 5.1757813 + - 2.1445313 + - -1.6201172 + - 4.75 + - -1.0703125 + - 2.4765625 + - 4.703125 + - -0.546875 + - -1.9902344 + - 5.75 + - 0.78759766 + - 0.38598633 + - -1.2539063 + - -0.17272949 + - 2.4550781 + - 1.6503906 + - -1.2587891 + - -1.6191406 + - -1.8496094 + - -0.71777344 + - -0.42578125 + - 0.38891602 + - 0.73339844 + - 0.124572754 + - 0.29614258 + - -2.078125 + - 2.2597656 + - 23.0625 + - -3.9101563 + - 2.9414063 + - -0.17468262 + - 0.92871094 + - 2.359375 + - 0.18408203 + - -2.0410156 + - 0.2841797 + - -0.84375 + - -1.4482422 + - 1.9472656 + - -2.3066406 + - -1.7001953 + - -0.2607422 + - 0.31054688 + - -5.1601563 + - 1.984375 + - 2.1582031 + - 14.546875 + - -2.6972656 + - 1.4003906 + - -0.11602783 + - -1.4023438 + - 0.2097168 + - -0.65283203 + - 0.63623047 + - 0.6635742 + - -0.21679688 + - -1.2744141 + - -26 + - -0.5024414 + - 0.55078125 + - 1.0732422 + - -2.9140625 + - -0.4934082 + - -0.6484375 + - 0.9169922 + - -2.46875 + - 0.9277344 + - 0.59472656 + - -3.8222656 + - -1.3505859 + - -0.8232422 + - -0.15454102 + - -1.0322266 + - -1.2919922 + - -2.9804688 + - 0.62353516 + - -0.2298584 + - -2.3261719 + - 0.8232422 + - 2.6308594 + - 0.26000977 + - 3.421875 + - -1.4072266 + - 3.1738281 + - -0.5625 + - 7.6953125 + - -1.9335938 + - 2.5839844 + - 4.0078125 + - -6.6484375 + - 2.421875 + - -2.1796875 + - 4.359375 + - -0.8208008 + - -0.51123047 + - -1.7314453 + - 0.5083008 + - 0.62841797 + - 0.9926758 + - -5.5351563 + - 2.9492188 + - -0.17919922 + - -2.4003906 + - 0.0287323 + - 2.7089844 + - 2.53125 + - 2.6328125 + - 2.5039063 + - -1.953125 + - -1.2744141 + - 1.8378906 + - 4.15625 + - 1.4326172 + - -1.4902344 + - -3.828125 + - -0.64501953 + - -4.1679688 + - -1.1298828 + - 2.1113281 + - 2.2246094 + - 3.640625 + - -1.1396484 + - 4.890625 + - 4.9960938 + - 2.046875 + - -0.7363281 + - -1.0830078 + - 0.77001953 + - -1.2724609 + - 1.3398438 + - -1.2626953 + - 1.3603516 + - -1.4814453 + - -2.6640625 + - 0.6230469 + - -3.5585938 + - -0.33764648 + - -3.3710938 + - -3.9375 + - -0.76416016 + - 0.515625 + - 3.0039063 + - -1.4169922 + - -0.14941406 + - 2.9160156 + - 0.7988281 + - 0.52783203 + - -2.7890625 + - 3.3554688 + - 2.0605469 + - -1.4150391 + - -3.3203125 + - 3.6054688 + - -0.5683594 + - 3.9394531 + - -2.7871094 + - -0.92089844 + - -1.0517578 + - 0.8227539 + - 3.4941406 + - 2.4726563 + - -0.17443848 + - 0.9404297 + - -3.7363281 + - -6.046875 + - -0.46191406 + - -1.4882813 + - 2.6621094 + - 2.6914063 + - 0.81933594 + - 1.0390625 + - 2.1582031 + - 0.5991211 + - -0.0715332 + - 2.3574219 + - -1.8457031 + - 2.953125 + - 1 + - -0.45532227 + - -0.33251953 + - -0.8066406 + - -0.6645508 + - 12.1953125 + - 0.5239258 + - 2.53125 + - 5.7851563 + - 7.796875 + - -1.2158203 + - 0.42822266 + - -1.0888672 + - 1.4638672 + - -2.6542969 + - -1.7939453 + - 1.3466797 + - 0.6689453 + - 0.30126953 + - -2.5625 + - -0.71875 + - 1.0185547 + - 1.890625 + - 1.9335938 + - 0.34350586 + - -0.17382813 + - -0.18469238 + - -0.78125 + - -1.9404297 + - -2.1035156 + - -1.4277344 + - 1.2451172 + - -0.46313477 + - -2.4238281 + - -3.4238281 + - 2.7890625 + - 2.1503906 + - 1.9921875 + - 1.015625 + - 0.2241211 + - -0.98291016 + - 1.9423828 + - -1.75 + - 0.74072266 + - 1.8212891 + - -1.4931641 + - 1.2539063 + - -1.7744141 + - -0.55615234 + - 3.9394531 + - -0.7192383 + - 1.7138672 + - -2.6484375 + - -1.0947266 + - -2.9023438 + - 3.21875 + - 1.0126953 + - -2.4042969 + - -1.1142578 + - 4.1015625 + - 1.8300781 + - 1.0361328 + - 1.5976563 + - 4.1875 + - 0.8457031 + - -1.8183594 + - -1.6669922 + - 1.4794922 + - 1.5244141 + - 1.203125 + - 4.1875 + - 2.5175781 + - 2.2617188 + - 1.9628906 + - -1.4160156 + - -0.6542969 + - -1.8525391 + - 1.2382813 + - 0.2019043 + - -0.050201416 + - -1.1044922 + - 0.3461914 + - 1.390625 + - 0.10290527 + - 3.0859375 + - -0.97753906 + - 0.08258057 + - 0.86376953 + - -0.26757813 + - 23.46875 + - -3.4707031 + - -1.1474609 + - -4.2460938 + - -0.22851563 + - 0.73583984 + - 2.34375 + - -0.092041016 + - -4.7851563 + - 1.6845703 + - 2.5976563 + - -1.359375 + - 3.3945313 + - 2.5351563 + - 1.9492188 + - 0.52001953 + - 1.6367188 + - -3.0742188 + - 1.7148438 + - 0.96191406 + - -2.2128906 + - 1.7011719 + - -3.6757813 + - 1.7763672 + - 0.0758667 + - 0.82177734 + - -2.2089844 + - 0.11645508 + - 2.3359375 + - -3.7753906 + - -0.76953125 + - 1.3154297 + - 2.078125 + - 2.1328125 + - 2.4160156 + - -1.5634766 + - 6.2851563 + - -0.03125 + - 0.32592773 + - -0.65625 + - -4.3359375 + - -3.5664063 + - 0.5019531 + - 4.9257813 + - 0.38012695 + - 0.20166016 + - -1.5683594 + - 1.7353516 + - 2.8164063 + - 3.9121094 + - -0.57470703 + - -1.8261719 + - 0.39379883 + - 8.6640625 + - -3.2226563 + - -1.2158203 + - 0.6328125 + - -1.2607422 + - 1.1367188 + - 0.51123047 + - 1.3037109 + - -0.11773682 + - -0.11462402 + - -4.2421875 + - -3.546875 + - -2.6640625 + - -3.1269531 + - -2.9941406 + - 0.49536133 + - -2.1972656 + - -1.2841797 + - 3.2851563 + - -0.7211914 + - -1.8222656 + - 0.68310547 + - -3.3378906 + - -4.3945313 + - -0.29614258 + - 2.0722656 + - -2.6777344 + - -0.19885254 + - 1.1748047 + - 2.1855469 + - 1.2265625 + - -1.1201172 + - -3.0878906 + - -1.4257813 + - -0.8696289 + - -2.9550781 + - 0.012275696 + - -0.5029297 + - -0.26831055 + - 4.1679688 + - -1.1015625 + - 2.6386719 + - -3.3066406 + - -2.3125 + - -1.2939453 + - -0.6850586 + - 1.2021484 + - -1.3095703 + - 1.4707031 + - 1.0224609 + - 0.8652344 + - 0.40429688 + - -1.2783203 + - -1.6054688 + - 1.5166016 + - -1.4238281 + - 1.6367188 + - 0.48046875 + - -0.32885742 + - 2.7402344 + - 0.9326172 + - 0.21398926 + - 1.2578125 + - -3.8359375 + - -2.6425781 + - -3.2421875 + - -1.3925781 + - 0.29956055 + - -0.22302246 + - 0.52734375 + - 1.0439453 + - 1.1669922 + - 1.2773438 + - -1.2041016 + - -2.421875 + - 1.2001953 + - 2.1035156 + - -2.71875 + - 2.1171875 + - 0.453125 + - 0.3317871 + - 1.2675781 + - 0.6713867 + - -5.578125 + - -3.3398438 + - -1.0908203 + - 1.5175781 + - 0.0262146 + - -2.25 + - -0.95703125 + - 4.9179688 + - -0.171875 + - 1.3681641 + - 6.5859375 + - 2.5625 + - -2.6875 + - 0.84033203 + - -0.055236816 + - 6.015625 + - -4.9648438 + - -2.1777344 + - 0.98876953 + - -2.1269531 + - -0.57470703 + - -2.3886719 + - 1.8857422 + - -3.3496094 + - 3.1972656 + - -1.1943359 + - 0.71972656 + - 0.15234375 + - -0.51708984 + - -1.1992188 + - 0.9658203 + - -0.23144531 + - -1.9414063 + - 5.9726563 + - 0.78759766 + - 2.4453125 + - -0.31518555 + - -4.4648438 + - 2.4316406 + - 0.24658203 + - 1.3349609 + - -0.71484375 + - -1.3564453 + - -0.7675781 + - 1.1240234 + - -2.0175781 + - -3.0800781 + - -0.032348633 + - 0.69873047 + - 1.7294922 + - 2.8203125 + - -2.3183594 + - 1.2373047 + - 0.30688477 + - -2.703125 + - 0.3466797 + - 3.5585938 + - 1.3242188 + - 5.7539063 + - 0.24804688 + - 0.0625 + - 16.203125 + - -0.41845703 + - 2.3027344 + - -3.5488281 + - -0.90771484 + - -0.89697266 + - 0.5410156 + - 1.4794922 + - 4.1484375 + - -0.92089844 + - -3.5253906 + - -1.8222656 + - 0.8720703 + - 1.9169922 + - 1.0517578 + - -1.1318359 + - 4.453125 + - -0.26391602 + - -0.66796875 + - 0.24523926 + - -1.6455078 + - 0.3034668 + - -1.5175781 + - -2.2949219 + - -1.6777344 + - 2.3652344 + - -0.2253418 + - -3.9960938 + - -3.1015625 + - 0.74316406 + - -0.99609375 + - -0.87890625 + - -1.8613281 + - -1.890625 + - 0.1751709 + - -0.083984375 + - 3.0117188 + - 0.75634766 + - 2.7890625 + - 0.2861328 + - 1.9648438 + - -4.5898438 + - 0.88720703 + - 0.65283203 + - -0.06890869 + - 4.2070313 + - -1.3691406 + - -1.3691406 + - -2.0625 + - -5.4882813 + - 2.1308594 + - 1.9013672 + - -0.30786133 + - 2.8808594 + - 4.703125 + - -1.6386719 + - -0.17785645 + - -3.8339844 + - -0.13439941 + - -1.8310547 + - -0.77441406 + - -1.1064453 + - 1.7431641 + - -2.7011719 + - -0.38720703 + - 1.0185547 + - 1.9091797 + - -4.953125 + - 3.3925781 + - 0.92626953 + - -0.5727539 + - -1.6923828 + - 4.6914063 + - 0.94384766 + - 1.1826172 + - 1.0126953 + - -1.9609375 + - -2.4472656 + - 1.6650391 + - 1.3632813 + - 2.3925781 + - 0.17211914 + - 4.7539063 + - -1.6230469 + - -1.1386719 + - 0.9663086 + - -1.5556641 + - -0.7675781 + - -1.5439453 + - 0.62353516 + - -4.34375 + - -0.8286133 + - 1.6669922 + - 1.9033203 + - -2.3789063 + - 2.5566406 + - -3.9316406 + - 2.6816406 + - 0.78759766 + - -0.73876953 + - 4.6054688 + - -0.89160156 + - -2.6074219 + - 1.9169922 + - 2.4316406 + - 3.3085938 + - 1.7695313 + - -1.0097656 + - -0.22338867 + - 0.45361328 + - 33.40625 + - 13.4765625 + - -9.1796875 + - 2.265625 + - -1.0507813 + - 1.4277344 + - -2.734375 + - -4.1757813 + - -0.36376953 + - -0.20703125 + - 1.9589844 + - 0.51464844 + - -0.34057617 + - 1.5166016 + - -2.7890625 + - 1.9707031 + - -1.0009766 + - 0.91259766 + - -2.6933594 + - 0.7138672 + - 1.8779297 + - 3.4140625 + - -1.3193359 + - -1.1445313 + - -0.2253418 + - -2.1523438 + - 0.08703613 + - -0.4038086 + - -4.6054688 + - 0.75097656 + - -0.119384766 + - -0.16101074 + - 1.4169922 + - 2.4785156 + - 1.6337891 + - -4.3789063 + - -1.8554688 + - 2.0644531 + - -2.1699219 + - 1.2451172 + - 2.2324219 + - 1.5371094 + - -0.27978516 + - 4.2304688 + - -1.2050781 + - 0.29345703 + - -3.4941406 + - 2.1425781 + - 1.3066406 + - 0.5107422 + - 2.2910156 + - 8.7265625 + - -0.5673828 + - -1.4306641 + - 1.7226563 + - -0.9453125 + - -0.84521484 + - 0.05606079 + - 1.4580078 + - 0.2175293 + - 2.9785156 + - 2.3984375 + - 1.2050781 + - -3.9238281 + - -1.7402344 + - -1.1376953 + - 1.9384766 + - -0.83203125 + - -2.6855469 + - 0.2565918 + - -2.9277344 + - -0.20385742 + - -1.5039063 + - -2.265625 + - 0.92822266 + - -2.6640625 + - -0.18579102 + - 1.3486328 + - 5.4453125 + - 0.41503906 + - -1.7626953 + - -1.4189453 + - 1.6337891 + - 1.8632813 + - 1.6875 + - 2.3808594 + - 1.1025391 + - 0.22314453 + - 1.9453125 + - -1.5341797 + - 1.3691406 + - 0.5053711 + - -0.8886719 + - -0.99902344 + - 3.6582031 + - 1.2080078 + - -1.3974609 + - 4.03125 + - -1.9023438 + - 0.5214844 + - -3.4609375 + - -1.0595703 + - 0.75097656 + - 1.15625 + - 0.11743164 + - 0.4892578 + - 0.32250977 + - -2.3222656 + - -0.081970215 + - 1.4853516 + - -3.2910156 + - 3.6777344 + - -0.69384766 + - 4.28125 + - 1.8076172 + - 2.8300781 + - -2.9140625 + - -1.3212891 + - 3.5175781 + - 0.42773438 + - -2.3886719 + - -1.8847656 + - 0.8803711 + - 1.109375 + - 3.6132813 + - 1.3603516 + - -3.2714844 + - 2.0566406 + - 2.4140625 + - 0.1307373 + - -0.87890625 + - -1.2529297 + - -1.1123047 + - 1.2490234 + - 0.28198242 + - 0.3125 + - -0.18469238 + - -3.4375 + - 1.5390625 + - -1.3007813 + - -0.4399414 + - 1.9648438 + - 1.7783203 + - -2.1347656 + - -0.296875 + - -0.17236328 + - 2.0097656 + - -1.2041016 + - -0.14453125 + - -4.1132813 + - 1.1660156 + - 1.3193359 + - -1.4667969 + - -1.4375 + - 0.4111328 + - -0.91552734 + - -1.1474609 + - 0.41748047 + - 0.4025879 + - 2.1621094 + - 0.09051514 + - -2.5625 + - 2.7890625 + - 1.7763672 + - -0.9404297 + - 0.4248047 + - 0.32739258 + - 2.3457031 + - -0.119506836 + - -2.5625 + - -0.5102539 + - -0.26660156 + - -2.6132813 + - -1.3476563 + - 0.5800781 + - 0.7158203 + - 1.4140625 + - 1.9658203 + - -1.1708984 + - -1.7529297 + - -0.59765625 + - 0.38500977 + - -0.5258789 + - 0.9008789 + - 1.5195313 + - -1.5722656 + - -0.06945801 + - 1.7695313 + - 1.7246094 + - -1.2783203 + - 2.3789063 + - 2.3203125 + - 1.78125 + - 0.7128906 + - -2.4902344 + - -1.8623047 + - 2.984375 + - 1.1738281 + - 0.92285156 + - -3.3925781 + - -2.7636719 + - -1.4267578 + - -2.8496094 + - -0.41601563 + - 0.39208984 + - -12.4453125 + - -0.31689453 + - -0.46142578 + - 0.21984863 + - -0.89160156 + - 0.5493164 + - -1.2490234 + - 1.6689453 + - 0.4597168 + - -1.7109375 + - 2.34375 + - -5.3710938 + - 0.48706055 + - 0.3251953 + - -1.1757813 + - 1.375 + - 1.5214844 + - -2.0566406 + - -0.022598267 + - 3.4277344 + - 0.61816406 + - 1.828125 + - -0.5341797 + - 9.390625 + - 1.4433594 + - -2.1386719 + - 0.72509766 + - -0.5239258 + - 0.89208984 + - -0.89160156 + - -0.083618164 + - -2.6601563 + - 6.7539063 + - 0.6816406 + - -1.7734375 + - 0.74072266 + - 1.0400391 + - -6.0976563 + - 0.71777344 + - 0.2915039 + - 1.3701172 + - 0.43798828 + - 6.2929688 + - -0.5932617 + - -2.7695313 + - 1.8964844 + - 2.2207031 + - 2.4609375 + - 2.1035156 + - 1.1425781 + - -2.8378906 + - 1.5439453 + - 1.7998047 + - -3.1582031 + - -1.0820313 + - -0.32714844 + - -0.43115234 + - -3.2050781 + - -1.8183594 + - -3.2753906 + - -0.1986084 + - -3.8652344 + - 2.4101563 + - -1.6953125 + - -1.7978516 + - 3.5683594 + - -2.4199219 + - 0.19494629 + - -1.6347656 + - -1.6376953 + - 2.0566406 + - -0.3552246 + - -1.3388672 + - 1.7587891 + - 1.6367188 + - -0.61572266 + - 0.6455078 + - 0.6113281 + - 2.1738281 + - 0.86376953 + - 3.7558594 + - 0.019104004 + - -0.2692871 + - -1.7851563 + - 2.6640625 + - 0.18725586 + - -2.0234375 + - -1.2880859 + - -1.5732422 + - -0.09063721 + - 5.2382813 + - 4.703125 + - -1.1416016 + - 1.9345703 + - 2.3378906 + - -0.7207031 + - -1.2539063 + - -0.4033203 + - 2.0351563 + - -1.9433594 + - 2.2792969 + - -3.4765625 + - 2.8359375 + - 0.7871094 + - -3.9589844 + - -0.11071777 + - -2.6660156 + - 3.2460938 + - 0.30151367 + - -5.5117188 + - -0.2685547 + - -1.7626953 + - 1.6542969 + - 0.42626953 + - 0.66503906 + - 3.4492188 + - 0.47387695 + - 1.28125 + - -0.3215332 + - -3.09375 + - -1.6669922 + - -0.59765625 + - -3.7890625 + - 8.9296875 + - 1.1962891 + - 1.4658203 + - -0.5292969 + - 0.5283203 + - -1.4980469 + - 0.4362793 + - 1.1601563 + - -1.2988281 + - -5.4726563 + - -3.3964844 + - 4.6328125 + - -4.1757813 + - 1.8066406 + - -1.8466797 + - -2.8164063 + - 1.296875 + - 0.8886719 + - -0.58203125 + - 0.27270508 + - 1.25 + - 1.1113281 + - -3.1777344 + - 0.07476807 + - -4.0429688 + - 1.7041016 + - -1.5908203 + - 1.2070313 + - -3.5976563 + - 0.81103516 + - -1.4306641 + - 0.9394531 + - -2.4980469 + - -1.0517578 + - 0.07281494 + - 2.2519531 + - 3.2441406 + - 0.49902344 + - 1.6640625 + - -1.6152344 + - 2.421875 + - 1.2851563 + - -0.71875 + - -1.1757813 + - -2.6894531 + - -0.24438477 + - 0.5205078 + - 2.5664063 + - -2.8769531 + - -0.093566895 + - -0.00390625 + - 4.234375 + - -0.012275696 + - -2.2246094 + - 0.36572266 + - 1.9814453 + - -2.2167969 + - -2.3164063 + - -0.9794922 + - 1.2119141 + - 1.9492188 + - -0.5366211 + - 0.7207031 + - -1.4638672 + - -0.29589844 + - 0.8256836 + - 3.0742188 + - -2.9179688 + - -2.7089844 + - 1.5957031 + - 1.8466797 + - 5.8125 + - 2.6308594 + - -1.5351563 + - 1.4619141 + - -0.5991211 + - 1.0800781 + - -1.6582031 + - -2.0136719 + - -0.91308594 + - 1.2207031 + - -1.9169922 + - 1.1708984 + - -1.0449219 + - 3.5253906 + - 4.34375 + - -0.51708984 + - 0.18188477 + - -0.23486328 + - -1.4326172 + - -3.3300781 + - -2.8691406 + - -0.890625 + - 1.3818359 + - -1.0712891 + - 0.85791016 + - 2.171875 + - 1.5488281 + - 1.4101563 + - -0.41503906 + - 0.8691406 + - -4.9179688 + - -0.90283203 + - -8.3046875 + - -1.7314453 + - -2.0175781 + - -2.2753906 + - -2.9023438 + - -0.96533203 + - 2.8378906 + - -6.7421875 + - -4.4335938 + - 24.671875 + - -1.7314453 + - -1.6464844 + - -0.65722656 + - -0.1796875 + - 0.51416016 + - 2.3203125 + - 3.0976563 + - -2.1542969 + - 1.1396484 + - 1.6914063 + - -0.0390625 + - 0.88378906 + - -1.4277344 + - 0.4267578 + - 0.08758545 + - -3.4179688 + - 0.72802734 + - 4.8867188 + - -0.75634766 + - -0.5488281 + - -1.4765625 + - -2.4765625 + - 0.65625 + - -0.3408203 + - 3.7578125 + - 0.36083984 + - -2.0878906 + - 2.2285156 + - -0.27612305 + - 1.5869141 + - -2.5488281 + - 0.7753906 + - 0.4025879 + - 1.2587891 + - -0.55908203 + - 1.6416016 + - 2.9863281 + - 4.1796875 + - 0.13830566 + - -0.85595703 + - -0.55566406 + - 2.0410156 + - -3.8964844 + - 0.77978516 + - -0.2824707 + - 3.2734375 + - 1.1845703 + - -2.0351563 + - 0.7270508 + - 2.3515625 + - 0.83691406 + - -3.1015625 + - -1.3193359 + - -2.0195313 + - -1.6425781 + - -2.9023438 + - -0.42871094 + - 2.3789063 + - -3.4550781 + - -2.8339844 + - 1.1816406 + - -0.5722656 + - 2.453125 + - -2.5 + - -0.10070801 + - -1.1962891 + - -0.010597229 + - -2.734375 + - 1.5898438 + - -4.609375 + - -4.359375 + - -0.1171875 + - -1.5556641 + - 1.4550781 + - 8.6328125 + - 0.89501953 + - 3.6816406 + - -4.7578125 + - 1.1894531 + - -0.67626953 + - 1.3095703 + - 0.9038086 + - 0.67626953 + - -0.16235352 + - -4.78125 + - 0.53125 + - 0.7607422 + - 2.5625 + - -0.83447266 + - -2.8378906 + - 0.44628906 + - -0.08538818 + - -0.5522461 + - -2.4765625 + - 1.4394531 + - 2.1074219 + - -2.5625 + - 5.3554688 + - 0.30908203 + - 0.36865234 + - 0.9243164 + - 0.52734375 + - 4.0117188 + - 0.27416992 + - 2.0800781 + - -1.8203125 + - -0.51904297 + - 0.5410156 + - 2.3886719 + - 7.1640625 + - 1.7148438 + - 1.0996094 + - -1.0556641 + - 3.5546875 + - 0.050476074 + - 1.7128906 + - 1.7871094 + - 2.2246094 + - -0.30566406 + - 3.09375 + - -0.69628906 + - 3.6015625 + - -4.4882813 + - -1.4697266 + - -2.0253906 + - 0.94189453 + - 0.001115799 + - 1.3408203 + - -0.42285156 + - 4.0742188 + - -1.9775391 + - -2.1054688 + - -0.84228516 + - 0.016174316 + - 2.9785156 + - 2.40625 + - 0.7363281 + - 1.1787109 + - 3.2851563 + - 4.1992188 + - 0.75634766 + - -0.5756836 + - 1.3769531 + - 2.0800781 + - -4.9882813 + - -4.578125 + - -0.9609375 + - 3.3125 + - -1.5917969 + - -0.75097656 + - -1.9638672 + - 2.8613281 + - 3.2753906 + - 3.265625 + - -0.8544922 + - -0.28344727 + - 1.3613281 + - -1.3515625 + - -0.44604492 + - 2.5839844 + - 2.6875 + - -0.9711914 + - -0.3581543 + - 0.4165039 + - 1.7861328 + - 0.39453125 + - -0.12207031 + - -0.35864258 + - 1.2529297 + - 2.140625 + - 0.9091797 + - -2.1191406 + - -0.3251953 + - -3.6425781 + - -4.8789063 + - -0.092163086 + - 2.5820313 + - -0.86035156 + - -0.36767578 + - 3.125 + - -2.1777344 + - 2.0097656 + - 0.5566406 + - -0.9897461 + - -2.9140625 + - 1.4013672 + - -0.5180664 + - 3.0625 + - 3.3476563 + - 1.2998047 + - -6.8359375 + - -0.47680664 + - -0.41845703 + - -5.390625 + - 2.1210938 + - -2.6621094 + - 2.4355469 + - 1.3867188 + - -6.4453125 + - 1.3076172 + - -0.65478516 + - -2.7988281 + - -2.4296875 + - 1.1220703 + - -0.37597656 + - 2.0761719 + - -0.4309082 + - -0.8129883 + - -33.875 + - -2.53125 + - -2.4140625 + - -0.3881836 + - -1.4277344 + - 2.09375 + - 2.4121094 + - -4.7539063 + - -4.6601563 + - -0.9038086 + - 1.1162109 + - -1.4375 + - -1.0976563 + - 6.7734375 + - 0.4885254 + - 4.7304688 + - -1.6601563 + - 4.3242188 + - -0.25097656 + - -1.4335938 + - 0.11437988 + - -0.45507813 + - 1.0791016 + - 1.8134766 + - -0.4350586 + - -4.0117188 + - -1.2519531 + - 0.053833008 + - 1.8681641 + - -0.36206055 + - 0.5722656 + - -1.265625 + - 0.3642578 + - -0.5629883 + - -3.4941406 + - 4.8632813 + - -3.3046875 + - -0.8071289 + - -2.328125 + - -3.4863281 + - 0.029571533 + - 1.9746094 + - 2.6328125 + - 0.01576233 + - 0.25268555 + - 1.7089844 + - 4.0039063 + - -0.63720703 + - 1.90625 + - -2.8339844 + - 2.6796875 + - -1.0927734 + - 0.26220703 + - -3.9238281 + - 3.0117188 + - 2.6074219 + - -2.9648438 + - 3.4550781 + - 2.6816406 + - 0.6645508 + - -1.0673828 + - -4.0117188 + - 3.0097656 + - 1.3544922 + - 1.5175781 + - -0.3876953 + - 0.039611816 + - -5.0078125 + - 0.8300781 + - 1.3789063 + - -2.2207031 + - 0.77441406 + - 2.6035156 + - 0.40454102 + - -0.56103516 + - 2.2070313 + - -1.4003906 + - -2.6953125 + - 0.8046875 + - 0.42114258 + - -1.2441406 + - 2.0878906 + - 0.47314453 + - 1.0439453 + - 3.0527344 + - 0.85058594 + - -1.2832031 + - 1.1123047 + - 2.0527344 + - 0.74658203 + - -2.3789063 + - 2.7949219 + - -1.0400391 + - 8.5703125 + - -1.4746094 + - 2.03125 + - -0.5991211 + - -0.8847656 + - -0.44628906 + - -0.66796875 + - 2.8222656 + - 0.049102783 + - 3.53125 + - 1.0810547 + - 2.125 + - -2.1464844 + - -2.4277344 + - 3.5800781 + - -0.17236328 + - 5.921875 + - -1.0566406 + - 5.921875 + - -2.0253906 + - -0.95410156 + - -1.4013672 + - 1.5019531 + - 0.3852539 + - 0.79003906 + - -1.5839844 + - 4.1132813 + - 2.96875 + - 2.4902344 + - 4.6875 + - -0.7216797 + - -2.0976563 + - 1.7167969 + - -1.4580078 + - -4.0742188 + - -3.1113281 + - 0.44921875 + - -4.3554688 + - -0.16064453 + - 1.7939453 + - 3.7304688 + - -1.1054688 + - -0.67529297 + - -30.3125 + - -0.85595703 + - -0.027618408 + - -0.6660156 + - 0.7626953 + - 3.5800781 + - 0.79296875 + - 1.8632813 + - 0.12609863 + - 2.0976563 + - 0.012275696 + - -0.1484375 + - -2.9160156 + - -2.2011719 + - 1.3662109 + - -2.3691406 + - 0.55859375 + - 0.073791504 + - -0.63134766 + - -1.5576172 + - 1.4433594 + - 10.890625 + - 3.125 + - -1.265625 + - 1.1884766 + - 0.94140625 + - -0.84814453 + - 2.3105469 + - 0.37841797 + - -2.6035156 + - 1.296875 + - 0.2529297 + - -2.203125 + - 0.34057617 + - 0.38110352 + - -2.0644531 + - -3.2285156 + - 0.17248535 + - -0.55126953 + - -1.90625 + - 5.6289063 + - 1.6572266 + - -1.2236328 + - 3.1679688 + - 1.0341797 + - 1.2763672 + - 0.0011701584 + - 3.1445313 + - 0.6489258 + - -1.7949219 + - 0.19189453 + - 3.5175781 + - -2.3945313 + - 2.4589844 + - -1.5351563 + - -2.0097656 + - -0.9692383 + - 4.3242188 + - 0.4519043 + - -4.0820313 + - 1.6386719 + - -0.49804688 + - -0.6801758 + - -1.8076172 + - -2.5019531 + - 0.077819824 + - -3.75 + - 0.7397461 + - 3.0078125 + - -6.9453125 + - 0.48876953 + - -1.3095703 + - -3.3691406 + - -3.0175781 + - 1.7734375 + - -0.8691406 + - -3.1191406 + - 0.06640625 + - 0.18615723 + - -0.3959961 + - -1.3349609 + - -0.6459961 + - 1.8984375 + - 1.75 + - 6.6757813 + - -1.4882813 + - -0.46704102 + - -1.2744141 + - -1.8183594 + - 2.0644531 + - -1.9638672 + - -0.7011719 + - 2.0664063 + - 0.15258789 + - 3.4492188 + - 0.890625 + - 0.921875 + - -1.0634766 + - 3.0039063 + - -0.6928711 + - 1.6298828 + - 0.5488281 + - -2.703125 + - -1.1425781 + - 0.41503906 + - -0.5839844 + - -0.2109375 + - 4.5625 + - 1.4433594 + - -0.11102295 + - -1.6738281 + - 4.5078125 + - -0.49682617 + - 2.0371094 + - -2.7558594 + - -1.8857422 + - 2.1015625 + - 2.515625 + - -0.82177734 + - 0.87597656 + - 1.6611328 + - -1.1982422 + - -1.96875 + - -1.2451172 + - 0.07476807 + - -0.46923828 + - -4.9023438 + - 0.047424316 + - -1.0195313 + - 3.3046875 + - 0.25048828 + - 0.66015625 + - -0.43066406 + - -0.13110352 + - 1.1132813 + - -0.35327148 + - -0.6738281 + - -0.47021484 + - -1.140625 + - -4.4179688 + - 0.7680664 + - 4.2070313 + - 0.112854004 + - 1.3613281 + - 1.8691406 + - 0.6191406 + - 3.9082031 + - -1.546875 + - 0.0418396 + - 2.265625 + - 2.2480469 + - 2.8027344 + - -1.9775391 + - 1.8564453 + - -1.6796875 + - 1.6044922 + - -2.3691406 + - 0.18969727 + - 1.0859375 + - 2.8300781 + - -0.6640625 + - 2.6914063 + - 2.7753906 + - 1.3164063 + - 2.5449219 + - -2.40625 + - 4.4960938 + - -2.4257813 + - -0.54003906 + - 1.7001953 + - -0.63427734 + - -2.5 + - 1.7324219 + - 0.1015625 + - -2.2871094 + - -1.5751953 + - -1.5019531 + - -1.6982422 + - -2.8789063 + - 3.1425781 + - 1.8701172 + - 1.7558594 + - -2.7441406 + - -0.32348633 + - -0.13171387 + - 2.4902344 + - 0.3330078 + - 2.4199219 + - -3.0214844 + - -0.18884277 + - 0.44799805 + - 1.0439453 + - 0.17492676 + - 4.0351563 + - -0.08843994 + - 1.4238281 + - -0.7919922 + - -1.9882813 + - -0.9272461 + - 1.3662109 + - 1.046875 + - 0.63427734 + - 1.2451172 + - -3.4550781 + - 0.17297363 + - 1.7441406 + - 0.62353516 + - -0.3647461 + - 1.515625 + - -1.1552734 + - -2.4160156 + - -5.5429688 + - -4.09375 + - 6.078125 + - -1.3701172 + - -0.91015625 + - 1.1992188 + - -1.7529297 + - 2.0800781 + - -1.6416016 + - -2.3925781 + - -3.8867188 + - -2.203125 + - -2.6425781 + - 0.7397461 + - 0.2734375 + - 1.4511719 + - -0.7939453 + - -1.1513672 + - 0.75683594 + - 0.1204834 + - -3.5039063 + - -1.7607422 + - -1.4775391 + - 3.1015625 + - 2.0839844 + - 6.2929688 + - -0.44384766 + - 2.5175781 + - -1.7080078 + - 1.8369141 diff --git a/backends/candle/tests/snapshots/test_flash_mistral__mistral_single.snap b/backends/candle/tests/snapshots/test_flash_mistral__mistral_single.snap new file mode 100644 index 00000000..b238aa6a --- /dev/null +++ b/backends/candle/tests/snapshots/test_flash_mistral__mistral_single.snap @@ -0,0 +1,4101 @@ +--- +source: backends/candle/tests/test_flash_mistral.rs +assertion_line: 48 +expression: embeddings_single +--- +- - 3.2363281 + - -1.1582031 + - 1.0810547 + - -2.0234375 + - 1.6054688 + - -1.0048828 + - 0.4362793 + - -0.87646484 + - 0.7988281 + - -0.2722168 + - 0.49365234 + - -0.8203125 + - 0.17041016 + - -0.73291016 + - -0.34936523 + - 0.03543091 + - 0.34277344 + - 1.3779297 + - 1.5234375 + - -1.8720703 + - -1.4052734 + - 1.6289063 + - -1.1650391 + - 0.6503906 + - 1.7939453 + - 1.9814453 + - -0.43286133 + - 1.3994141 + - -0.3486328 + - -2.5253906 + - 2.5390625 + - 0.32348633 + - 2.2988281 + - 1.5175781 + - -0.28735352 + - 1.1669922 + - -3.4550781 + - 0.07141113 + - -5.2773438 + - -0.8330078 + - 0.75683594 + - -2.4296875 + - -0.9194336 + - -0.98095703 + - -1.7236328 + - 2.0722656 + - 0.234375 + - -3.9003906 + - -1.4003906 + - 0.8334961 + - 3.9121094 + - -0.4350586 + - -3.0488281 + - -100.5625 + - -3.0742188 + - -0.93408203 + - 2.7128906 + - -1.0556641 + - -1.3759766 + - -7.3671875 + - -2.3769531 + - 0.57910156 + - 0.83740234 + - 0.13171387 + - 2.4042969 + - 0.07281494 + - -2.5449219 + - 0.5151367 + - 2.0644531 + - -1.5566406 + - -4.6640625 + - 0.051605225 + - -2.9902344 + - -0.9008789 + - -1.2304688 + - -0.40454102 + - 2.9863281 + - 3.1367188 + - -0.13916016 + - -0.36206055 + - -0.640625 + - -0.6069336 + - -1.5878906 + - -0.34594727 + - -2.0214844 + - 0.5366211 + - -1.8007813 + - -0.15222168 + - 2.2597656 + - 0.86816406 + - -1.2304688 + - 5.9375 + - 2.7089844 + - -19.703125 + - -2.8144531 + - -2.7832031 + - -4.4414063 + - 0.36035156 + - 1.5751953 + - -4.09375 + - 1.6904297 + - -1.3564453 + - -3.8652344 + - -0.61035156 + - 0.006626129 + - -2.7910156 + - 0.07922363 + - -1.3349609 + - -1.5810547 + - -0.059143066 + - 0.03945923 + - 0.43066406 + - 0.47851563 + - -1.5595703 + - -0.055236816 + - 3.03125 + - 2.8515625 + - 0.70703125 + - -0.18713379 + - 0.296875 + - 2.2421875 + - 0.5942383 + - 0.15258789 + - -2.4863281 + - -2.2011719 + - -0.26879883 + - 2.4003906 + - -0.7294922 + - 0.32739258 + - 1.5878906 + - 2.3789063 + - 0.171875 + - -3.2539063 + - 1.1572266 + - 1.2333984 + - 2.4101563 + - -0.30664063 + - 0.7890625 + - -1.2041016 + - -1.484375 + - 1.5195313 + - -0.41796875 + - 1.3525391 + - -2.7753906 + - 1.1738281 + - -2.8027344 + - -1.7988281 + - -0.93603516 + - -2.5703125 + - 4.578125 + - 0.7866211 + - -1.9257813 + - -1.0458984 + - 0.037109375 + - 0.5161133 + - -2.7832031 + - 0.90527344 + - -0.5083008 + - -3.0410156 + - -3.1289063 + - -1.2539063 + - -2.6191406 + - -0.5517578 + - -1.140625 + - 0.5136719 + - 1.4003906 + - 3.3613281 + - -1.1591797 + - -0.7578125 + - -0.4633789 + - -2.6328125 + - -1.9306641 + - -0.4375 + - -2.9804688 + - -0.09539795 + - 0.25195313 + - -1.3125 + - 2.09375 + - -4.265625 + - -2.2539063 + - 1.7919922 + - 0.8027344 + - -1.8613281 + - -1.8544922 + - 0.13720703 + - 5.1015625 + - -3.4863281 + - -0.8515625 + - -1.1826172 + - 0.073913574 + - -1.9101563 + - -3.7773438 + - -0.5566406 + - 0.6411133 + - -2.2441406 + - 1.4951172 + - 3.6308594 + - -2.1640625 + - -1.4902344 + - 0.13244629 + - 0.4428711 + - -2.3515625 + - 1.9189453 + - 0.7084961 + - 7.9296875 + - 3.2929688 + - 4.2617188 + - -2.84375 + - -0.34692383 + - -2.2246094 + - -2.0625 + - -0.74853516 + - -2.90625 + - -0.8613281 + - 0.83447266 + - 3.9550781 + - -0.0033473969 + - -2.5214844 + - 3.0957031 + - -0.7055664 + - -3.8515625 + - 0.63378906 + - -0.32470703 + - 3.125 + - -0.085510254 + - -1.2158203 + - 0.09539795 + - 0.765625 + - 0.3972168 + - -3.1484375 + - 0.77734375 + - -0.76708984 + - -1.0068359 + - -0.88720703 + - -0.203125 + - 1.5800781 + - 0.74072266 + - 2.3691406 + - 2.8554688 + - 1.6591797 + - 6.0390625 + - -0.35083008 + - -3.4589844 + - 0.22875977 + - -2.2265625 + - -1.7607422 + - 3.2695313 + - 2.5605469 + - -0.6118164 + - 0.20898438 + - 1.2519531 + - 0.4440918 + - -2.1269531 + - 0.515625 + - 0.625 + - -0.69921875 + - -0.33081055 + - -2.28125 + - 0.03012085 + - 0.34375 + - 1.4726563 + - 1.8476563 + - 1.8925781 + - 1.890625 + - -0.3762207 + - 2.4140625 + - -2.2988281 + - 3.9140625 + - 0.85595703 + - -4.6953125 + - 0.32910156 + - 0.8154297 + - 3.2382813 + - 0.82910156 + - -0.42822266 + - -1.1640625 + - -3.9316406 + - 1.3710938 + - 0.18383789 + - 1.0302734 + - -2.1308594 + - -2.6738281 + - 1.3876953 + - 0.13671875 + - 1 + - -0.7792969 + - -0.064697266 + - 1.8291016 + - 0.65722656 + - 0.03186035 + - 1.7236328 + - 1.2119141 + - 10.078125 + - 0.06500244 + - 0.6723633 + - -0.4814453 + - 1.8417969 + - -1.4003906 + - 2.2128906 + - -0.5473633 + - -0.07757568 + - -0.20861816 + - 2.7636719 + - 3.3300781 + - -1.640625 + - 1.5292969 + - 2.4765625 + - 1.4394531 + - -0.09094238 + - -1.203125 + - -1.6669922 + - -0.10656738 + - -0.8984375 + - 0.07366943 + - 1.1894531 + - -2.4375 + - 0.08148193 + - -0.140625 + - -3.1875 + - 1.2861328 + - 1.8310547 + - 1.2421875 + - -3.8359375 + - 2.5703125 + - -1.4082031 + - -0.7836914 + - -1.3457031 + - -1.5019531 + - 2.3652344 + - -3.5800781 + - -4.078125 + - 1.7050781 + - 1.5644531 + - 0.7675781 + - 2.3378906 + - -0.11633301 + - 2.78125 + - -0.4987793 + - 0.0914917 + - 0.10571289 + - 2.7597656 + - -0.4482422 + - 2.1015625 + - 1.5498047 + - 1.9423828 + - 1.1533203 + - -0.21398926 + - 2.6796875 + - -5.0664063 + - -0.8828125 + - 2.1503906 + - -1.2607422 + - 0.3330078 + - 0.5073242 + - -2.1738281 + - -0.7817383 + - -4.09375 + - -1.6074219 + - -1.6220703 + - -1.4130859 + - -1.4902344 + - 1.7304688 + - 4.359375 + - -1.3847656 + - 2.96875 + - -2.9003906 + - 6.1679688 + - 1.703125 + - 1.4638672 + - -2.6914063 + - 0.77001953 + - -1.5253906 + - -2.1230469 + - 3.5039063 + - -0.40283203 + - 3.5976563 + - -1.4462891 + - 0.39208984 + - 0.70947266 + - 2.4726563 + - -1.3896484 + - -1.2861328 + - -1.9472656 + - -0.86035156 + - -0.7050781 + - 1.8564453 + - 1.8613281 + - -4.2226563 + - -8.125 + - -2.109375 + - 0.45532227 + - -0.09313965 + - -2.6660156 + - -0.9580078 + - 0.046875 + - -0.29736328 + - 2.6464844 + - 2.1054688 + - -2.1464844 + - 1.5488281 + - -2.3359375 + - 1.5898438 + - -0.5644531 + - -4.34375 + - -0.17236328 + - 1.7988281 + - 2.046875 + - -2.1660156 + - -14.390625 + - -0.1204834 + - -2.2128906 + - -1.6064453 + - 3.1152344 + - -1.1582031 + - 1.4433594 + - -0.19799805 + - -3.6875 + - -1.4189453 + - -3.6191406 + - 5.109375 + - -0.5004883 + - -0.4711914 + - 2.7988281 + - -0.33129883 + - -0.76171875 + - 1.0517578 + - 0.16320801 + - -2.0371094 + - 2.2246094 + - -1.4384766 + - -1.9189453 + - -1.7138672 + - -3.8613281 + - 0.84814453 + - -0.37939453 + - -1.8515625 + - 0.58203125 + - -1.9013672 + - 0.75097656 + - 1.6738281 + - -1.3115234 + - -1.5058594 + - -0.6225586 + - -1.6416016 + - -2.203125 + - -0.9116211 + - 0.06585693 + - 2.7050781 + - -2.1699219 + - -3.5800781 + - -0.0075302124 + - 1.5263672 + - -0.5859375 + - -2.0429688 + - -0.47314453 + - 5.609375 + - 4.625 + - -0.036254883 + - 0.06878662 + - 3.2851563 + - -0.44848633 + - -2.8945313 + - -1.7666016 + - 2.7695313 + - -0.9326172 + - -0.84472656 + - -0.9819336 + - 0.27319336 + - 0.33789063 + - -2.3496094 + - 4.9335938 + - 2.3125 + - 0.296875 + - 1.015625 + - 0.34985352 + - 0.4375 + - 0.8125 + - -0.91259766 + - -0.60546875 + - 3.8242188 + - 0.56884766 + - 0.5625 + - 0.9741211 + - -1.9951172 + - -0.32543945 + - 1.2128906 + - -6.0039063 + - 0.13793945 + - 3.71875 + - -0.5605469 + - 0.46289063 + - 1.5683594 + - -0.7011719 + - -0.4658203 + - -2.6289063 + - -1.3330078 + - 2.4589844 + - -2.0410156 + - -2.9179688 + - 5.3789063 + - 0.21728516 + - -5.9609375 + - 2.0371094 + - 0.051330566 + - 1.3349609 + - 3.8339844 + - -0.62158203 + - -0.61035156 + - -1.5869141 + - 2.8496094 + - 3.6738281 + - -2.5761719 + - -1.5 + - 0.6928711 + - -2.0371094 + - 1.6220703 + - -0.34204102 + - -0.5527344 + - -1.4384766 + - -0.5102539 + - 0.5991211 + - 1.5878906 + - 3.6777344 + - -0.01701355 + - 0.55566406 + - 1.4580078 + - 0.20336914 + - -1.375 + - 1.6777344 + - 2.1894531 + - 0.85302734 + - 0.45385742 + - -0.0055770874 + - -1.8759766 + - 3.5820313 + - 0.16687012 + - -4.5078125 + - -0.12371826 + - -0.3569336 + - -1.6259766 + - -1.9589844 + - -1.0117188 + - 3.1054688 + - -0.84765625 + - -4.8398438 + - -2.3632813 + - -1.3837891 + - 0.20227051 + - 1.984375 + - 1.4824219 + - 0.63720703 + - 1.9658203 + - -17.703125 + - -1.4570313 + - -1.0488281 + - -2 + - -1.3818359 + - 0.6147461 + - 0.203125 + - 0.0036258698 + - 2.34375 + - -3.4863281 + - -1.0546875 + - -2.2402344 + - 1.2724609 + - -1.0302734 + - 0.8774414 + - -2.4511719 + - -1.4433594 + - -2.3476563 + - -2.2890625 + - -0.8935547 + - -1.9257813 + - 0.9921875 + - 0.2890625 + - -1.2851563 + - 1.1181641 + - 0.57421875 + - 0.31811523 + - 0.72314453 + - -3.2070313 + - 0.65966797 + - 2.5644531 + - -3.21875 + - -2.9375 + - 1.6806641 + - 1.6425781 + - -2.3378906 + - -3.4960938 + - -1.1923828 + - 1.4433594 + - -3.1875 + - -1.8876953 + - -0.10430908 + - -1.9082031 + - 1.4277344 + - 5.1757813 + - 3.9101563 + - 1.0273438 + - 3.2441406 + - -0.8261719 + - -0.68408203 + - 0.056915283 + - -2.2558594 + - -2.3261719 + - 0.15344238 + - -3.6953125 + - 0.5649414 + - -4.3789063 + - 0.9536133 + - 2.5917969 + - -1.7558594 + - -1.9824219 + - 1.9570313 + - -0.6069336 + - -0.25170898 + - -1.5556641 + - 1.8720703 + - -2.140625 + - 0.001115799 + - 1.4619141 + - 1.8613281 + - 0.002231598 + - 0.44140625 + - -1.609375 + - 3.4902344 + - 0.036834717 + - 1.4189453 + - 0.78222656 + - -0.125 + - 1.7041016 + - -0.5253906 + - -3.2265625 + - -2.6875 + - 0.61328125 + - 2.6132813 + - -2.8164063 + - -0.8310547 + - -0.25170898 + - 0.034576416 + - -2.2246094 + - -2.5664063 + - -0.08154297 + - 2.7851563 + - 4.390625 + - -1.0859375 + - 0.5961914 + - -4.6757813 + - 7.9101563 + - -3.1484375 + - 0.7319336 + - 3.3535156 + - -1.6201172 + - -2.59375 + - 0.98291016 + - -1.6289063 + - -0.5541992 + - 2.6914063 + - 3.8085938 + - -0.45996094 + - 1.4609375 + - 1.0556641 + - 1.6582031 + - 3.1054688 + - -0.5498047 + - 2.4003906 + - 1.8154297 + - -2.0449219 + - 0.22497559 + - 0.9868164 + - -0.52490234 + - -1.0039063 + - 0.6166992 + - 0.609375 + - -0.7138672 + - -2.9492188 + - -0.19580078 + - -0.9863281 + - -0.18981934 + - 0.0446167 + - 1.5244141 + - 1.7304688 + - 1.015625 + - -1.4150391 + - 7.7539063 + - 3.3671875 + - 7.0976563 + - 1.4716797 + - -5.71875 + - -5.8828125 + - -0.3815918 + - -1.3154297 + - -0.3232422 + - -1.5888672 + - 0.18579102 + - -0.23291016 + - -1.0429688 + - 1.6132813 + - -1.9462891 + - 2.6738281 + - 3.2207031 + - 3.6679688 + - -0.9086914 + - -2.5136719 + - 0.5102539 + - 24.09375 + - 1.2988281 + - 0.88183594 + - 0.09259033 + - -3.0175781 + - 1.8251953 + - 0.71240234 + - 0.7685547 + - -2.15625 + - 1.1123047 + - 3.0058594 + - 2.9707031 + - -0.28710938 + - -0.04937744 + - 0.5996094 + - 4.7890625 + - 1.4404297 + - 3.0644531 + - -5.0585938 + - -0.171875 + - -1.8632813 + - -1.8867188 + - -3.6425781 + - 0.9008789 + - -0.4501953 + - 1.4492188 + - -0.27001953 + - -1.8603516 + - 2.15625 + - 0.66259766 + - -3.4140625 + - -5.734375 + - -0.2175293 + - -3.0253906 + - -0.99658203 + - 1.8369141 + - -1.4111328 + - 1.4414063 + - 3.9785156 + - -1.9648438 + - -1.5273438 + - -1.875 + - 2.2949219 + - -0.2331543 + - -0.55810547 + - 1.2763672 + - 0.61083984 + - 1.4492188 + - 0.84228516 + - -0.7363281 + - -0.9975586 + - -3.1113281 + - 2.9492188 + - -0.51416016 + - 0.32739258 + - -2.6601563 + - -1.5888672 + - 1.0517578 + - 0.116882324 + - -1.2705078 + - -1.6640625 + - 2.1640625 + - -1.7226563 + - -1.7275391 + - -0.45581055 + - -0.26733398 + - 2.6152344 + - 0.42016602 + - -1.1191406 + - -0.46948242 + - 4.125 + - 1.4794922 + - -0.26660156 + - 2.9726563 + - -2.859375 + - 2.3183594 + - 0.52001953 + - -1.1894531 + - -3.203125 + - -1.1923828 + - 2.2304688 + - -2.4980469 + - 5.8789063 + - -0.002231598 + - 2.4101563 + - -0.78125 + - -1.4570313 + - 0.85595703 + - 2.6875 + - 0.5 + - -1.1445313 + - -0.55908203 + - 0.46972656 + - 1.1552734 + - -3.6191406 + - 2.3222656 + - -4.75 + - -4.75 + - -3.7851563 + - 1.0068359 + - 3.9140625 + - 1.4355469 + - -1.3916016 + - 0.17407227 + - 2.4257813 + - 1.2197266 + - -2.609375 + - 1.1171875 + - -1.5351563 + - -4.0273438 + - -0.3540039 + - 5.6328125 + - 0.22961426 + - 2.1113281 + - 1.9863281 + - -0.9980469 + - 2.140625 + - -0.2734375 + - -2.8144531 + - -0.19921875 + - 4.5820313 + - -2.5644531 + - -0.36279297 + - 2.8847656 + - -1.4326172 + - 0.06750488 + - 1.0771484 + - -1.1982422 + - -9.3359375 + - 1.4482422 + - -47.28125 + - -1.2910156 + - -0.60595703 + - -2.0683594 + - -3.9179688 + - -0.47753906 + - 0.29614258 + - 1.0644531 + - 1.6621094 + - 1.0615234 + - 0.18664551 + - -1.7929688 + - 4.6835938 + - -0.5258789 + - -2.0019531 + - 1.5908203 + - 1.1064453 + - -0.72509766 + - 16.984375 + - 0.42407227 + - -0.97509766 + - -1.2607422 + - -0.94140625 + - -0.58251953 + - 0.40063477 + - 2.8007813 + - 0.87109375 + - -1.6220703 + - -0.42578125 + - -2.6699219 + - -0.9589844 + - -2.4199219 + - 0.1784668 + - 0.50146484 + - -0.8803711 + - 2.4511719 + - 1.0332031 + - 0.80566406 + - 1.4453125 + - 0.50878906 + - 3.9179688 + - -0.37817383 + - 1.0478516 + - 0.25683594 + - -3.1425781 + - 2.5253906 + - 0.12548828 + - -1.2929688 + - -0.5229492 + - -2.9902344 + - 0.3515625 + - -1.6113281 + - -0.08203125 + - -0.65966797 + - -0.06137085 + - 0.20996094 + - 1.9462891 + - -4.1523438 + - -2.4902344 + - 0.3618164 + - 0.18371582 + - -1.0068359 + - -2.59375 + - 1.2685547 + - 6.5976563 + - -0.65185547 + - -0.7446289 + - 4.7265625 + - -2.2558594 + - 2.3105469 + - -2.0644531 + - -0.16882324 + - 0.17822266 + - -2.3066406 + - 2.8925781 + - -4.5742188 + - 3.5976563 + - -1.0625 + - 5.234375 + - 0.47021484 + - 0.3149414 + - -3.0703125 + - 1.9140625 + - 0.18664551 + - 1.9296875 + - 2.9335938 + - -1.0634766 + - -0.28735352 + - 0.26293945 + - -0.7158203 + - 2.5664063 + - -1.4658203 + - -1.5371094 + - -2.7050781 + - 1.2421875 + - -0.47607422 + - -0.35009766 + - -0.84472656 + - -3.4140625 + - -1.09375 + - -2.1328125 + - -5.7382813 + - -1.1669922 + - 0.2524414 + - 1.3486328 + - 3.4140625 + - 3.4492188 + - 0.40039063 + - 0.56640625 + - 0.06439209 + - 0.7709961 + - 0.99365234 + - -1.6416016 + - 2.9238281 + - 0.9736328 + - 1.3349609 + - -2.6855469 + - 2.3144531 + - -2.046875 + - 2.2109375 + - 1.6347656 + - 0.484375 + - -1.6738281 + - -1.7783203 + - 0.17663574 + - 0.31176758 + - 4.0273438 + - -0.72509766 + - 3.1933594 + - 2.3925781 + - 2.6542969 + - 1.484375 + - -0.05355835 + - 1.9794922 + - 0.39257813 + - 0.24121094 + - 2.7246094 + - -0.80126953 + - -2.8066406 + - 0.16589355 + - -2.1699219 + - -2.03125 + - -2.4511719 + - -3.0097656 + - 1.8994141 + - 2.8339844 + - 2.7753906 + - -2.4824219 + - 0.84228516 + - -3.1992188 + - 2.2734375 + - -1.7246094 + - 4.734375 + - 4.6914063 + - 0.59472656 + - -0.5366211 + - 1.7763672 + - 0.2956543 + - 2.3574219 + - -4.1796875 + - 3.9277344 + - -1.46875 + - -4.9414063 + - -1.9033203 + - -1.0361328 + - -0.3449707 + - -2.9414063 + - -15.5703125 + - 2.0390625 + - -1.2744141 + - 4.1445313 + - 1.2207031 + - 3.3535156 + - 1.3818359 + - 1.5976563 + - -0.45166016 + - -0.6635742 + - 1.65625 + - -2.0996094 + - 2.4941406 + - 1.4931641 + - 2.0800781 + - -3.2714844 + - 0.96191406 + - -0.0055770874 + - -0.21203613 + - 1.2304688 + - 2.2519531 + - -1.0205078 + - 0.35668945 + - -0.019805908 + - 1.59375 + - -4.4726563 + - 0.2109375 + - -1.7705078 + - -16.6875 + - 0.61816406 + - 0.119384766 + - 2.9882813 + - -4.9882813 + - -1.5654297 + - 0.2705078 + - 1.1875 + - -5.0273438 + - -2.6191406 + - -2.6113281 + - 3.7070313 + - -0.53222656 + - -0.44799805 + - -2.3652344 + - 0.7050781 + - -6.203125 + - -1.1806641 + - -0.3515625 + - 0.73828125 + - 1.1845703 + - -1 + - -0.24719238 + - -2.4667969 + - -0.6381836 + - 2.9179688 + - -3.5039063 + - -0.040161133 + - -0.52783203 + - 1.5332031 + - 3.4589844 + - -1.8183594 + - 0.32641602 + - -1.4794922 + - -0.75 + - 2.2285156 + - -0.75390625 + - 0.43066406 + - -18.859375 + - -0.33496094 + - -1.8964844 + - 2.4355469 + - -4.1835938 + - 2.4101563 + - 0.5703125 + - 1.2275391 + - 1.6376953 + - -0.6977539 + - 1.4189453 + - -1.1503906 + - 2.2636719 + - -1.9921875 + - 0.5078125 + - -0.11853027 + - 0.58691406 + - 0.04714966 + - 1.4111328 + - -4.8007813 + - -1.421875 + - 2.3105469 + - -2.7324219 + - -0.19165039 + - 2.9023438 + - -1.453125 + - 3.1464844 + - -2.5957031 + - -1.5205078 + - 2.0761719 + - 1.3583984 + - 3.15625 + - -2.1542969 + - -1.4980469 + - -1.6445313 + - -0.41552734 + - -0.60791016 + - -1.6884766 + - 1.4931641 + - 1.8642578 + - 3.7695313 + - 1.6601563 + - 2.2988281 + - 3.6582031 + - -2.0839844 + - 0.41430664 + - -2.2949219 + - -1.4238281 + - -6.0546875 + - 1.0351563 + - 2.46875 + - 0.46142578 + - 0.2512207 + - 0.19921875 + - -2.0976563 + - 0.60302734 + - 0.1508789 + - 8.0703125 + - -0.37890625 + - -1.6367188 + - -1.1289063 + - 1.1582031 + - 1.5166016 + - 1.8085938 + - -1.7597656 + - -1.9277344 + - 0.43237305 + - 2.6953125 + - 0.68310547 + - 3.0742188 + - -3.4238281 + - -4.5898438 + - 0.8183594 + - 0.8173828 + - 1.5820313 + - 0.97314453 + - 0.3359375 + - -0.24768066 + - 4.140625 + - 0.4609375 + - 0.12164307 + - -2.3164063 + - 1.6376953 + - -0.86328125 + - 1.2705078 + - -3.3242188 + - -0.4831543 + - 1.75 + - -2.6875 + - 1.2890625 + - 3.96875 + - 2.2597656 + - -0.89990234 + - -0.88964844 + - 1.5273438 + - 1.3662109 + - 0.67626953 + - 2.8710938 + - 4.9335938 + - -0.4152832 + - 1.0458984 + - -0.6816406 + - 0.17663574 + - 0.75 + - 2.2324219 + - 1.2294922 + - 1.1123047 + - 0.13781738 + - -4.578125 + - -0.58251953 + - 3.1289063 + - -2.9101563 + - -0.50390625 + - -3.1054688 + - -0.7910156 + - 2.46875 + - 6.375 + - 1.0224609 + - -1.5839844 + - 1.7207031 + - 2.2578125 + - -0.5307617 + - -1.3740234 + - 1.2626953 + - -5.4179688 + - 1.2460938 + - 2.6777344 + - 5.4140625 + - -0.45336914 + - 0.5151367 + - -1.0908203 + - -0.8769531 + - -2.59375 + - -3.6132813 + - 3.6015625 + - -0.8696289 + - 0.9765625 + - 5.375 + - -2.1015625 + - -1.2519531 + - -2.5078125 + - -0.39208984 + - -0.044769287 + - 0.2902832 + - -1.1806641 + - -0.1352539 + - 3.3046875 + - -0.9975586 + - 0.38891602 + - 1.9707031 + - 0.58154297 + - -0.54052734 + - -0.30859375 + - 3.3164063 + - -0.28027344 + - 0.87158203 + - 1.84375 + - 2.5957031 + - 0.625 + - -0.63720703 + - -3.7226563 + - -3.2988281 + - 0.060546875 + - 3.0703125 + - -0.93847656 + - 2.4707031 + - -0.65722656 + - 1.5 + - -0.15563965 + - -3.625 + - 0.98095703 + - 0.1015625 + - -0.14416504 + - -1.1445313 + - -2.4316406 + - 6.703125 + - -2.4082031 + - 0.82910156 + - -1.2744141 + - 2.6484375 + - 0.7402344 + - -0.6870117 + - -2.0546875 + - 0.016738892 + - -3.9648438 + - 0.97753906 + - 0.3684082 + - 1.9726563 + - 1.2236328 + - 11.5703125 + - -1.9707031 + - -1.2548828 + - 1.5488281 + - 0.38598633 + - 6.0546875 + - 4.0273438 + - 0.3269043 + - -1.5107422 + - -0.71191406 + - 0.52734375 + - 8.3046875 + - 0.3881836 + - -0.64404297 + - 0.2421875 + - -1.1992188 + - 0.69873047 + - -3.1113281 + - -2.7441406 + - -2.3984375 + - -3.6738281 + - 1.8623047 + - -3.6796875 + - -1.0703125 + - 1.0117188 + - 0.83203125 + - -4.9375 + - -0.24768066 + - 0.37231445 + - 1.9902344 + - -0.44458008 + - -1.4228516 + - 1.3271484 + - -1.1367188 + - -1.125 + - 2.2480469 + - 0.48657227 + - 1.9863281 + - 4.1679688 + - -1.84375 + - 1.5097656 + - 0.41918945 + - -4.1914063 + - -1.8837891 + - -0.30249023 + - -1.7529297 + - 3.1015625 + - -1.015625 + - 0.49438477 + - 3.1601563 + - 0.076171875 + - 3.5742188 + - -0.7426758 + - 3.171875 + - -1.8476563 + - 3.15625 + - -0.8876953 + - -3.9023438 + - -2.7324219 + - -3.7519531 + - 1.6601563 + - 1.1337891 + - -0.98876953 + - -0.70947266 + - -0.7890625 + - -0.30151367 + - -2.2441406 + - -1.0410156 + - 1.1416016 + - 1.0859375 + - -0.74365234 + - 2.7128906 + - -9.2578125 + - 3.6777344 + - 3.4101563 + - -0.7944336 + - 0.8720703 + - -2.4628906 + - -0.8623047 + - 0.82177734 + - -0.097351074 + - 1.9794922 + - 0.9145508 + - -0.82421875 + - 3.8378906 + - 0.4519043 + - -1.5556641 + - -2.7050781 + - -0.60253906 + - 1.1113281 + - -0.43481445 + - -2.0175781 + - -0.31811523 + - -0.0758667 + - -1.5087891 + - 3.2519531 + - 0.3737793 + - -6.2070313 + - 1.9091797 + - 4.3554688 + - -0.013671875 + - 0.04714966 + - 0.29467773 + - 0.8154297 + - 1.7441406 + - 2.4199219 + - 3.375 + - 0.42578125 + - 0.55810547 + - -0.4350586 + - -0.10180664 + - 1.4433594 + - 2.7324219 + - -0.17236328 + - -3.9609375 + - 10.78125 + - 2.2988281 + - -3.1757813 + - -71.0625 + - 0.85791016 + - -1.6738281 + - -0.8847656 + - 2.8320313 + - 4.7890625 + - 1.6933594 + - 0.89697266 + - -0.09313965 + - -2.2050781 + - -2.7636719 + - 1.6953125 + - -0.71533203 + - 2.3476563 + - 0.35327148 + - -5.0625 + - -2.6953125 + - -3.0058594 + - -0.32592773 + - 1.7832031 + - 2.4550781 + - 0.5229492 + - 1.1347656 + - -0.9584961 + - -1.6064453 + - -2.7519531 + - -1.6699219 + - -3.28125 + - 1.0976563 + - -1.7207031 + - 1.1289063 + - -4.6367188 + - 0.08868408 + - -1.1123047 + - -3.8847656 + - 1.0830078 + - 1.0185547 + - -0.043792725 + - 1.3076172 + - -2.6289063 + - -0.30395508 + - -1.3193359 + - 4.21875 + - 1.7939453 + - 1.2841797 + - -2.6074219 + - 2.0527344 + - 1.4726563 + - 2.9414063 + - 0.3347168 + - 1.2998047 + - -0.56591797 + - 1.0771484 + - 9.7265625 + - -4.9023438 + - 1.8222656 + - 0.13598633 + - 0.9267578 + - 0.3774414 + - -2.0136719 + - 0.92089844 + - 2.0449219 + - 0.38598633 + - -3.1523438 + - -0.7363281 + - 0.11602783 + - -4.6367188 + - 0.7373047 + - -0.9375 + - 0.46191406 + - -2.9609375 + - 2.0625 + - 2.8964844 + - 0.58447266 + - 1.4394531 + - 0.29077148 + - -2.2109375 + - -0.7861328 + - 0.54296875 + - 1.0341797 + - -0.111328125 + - 0.41235352 + - -1.7998047 + - -1.1992188 + - 0.7680664 + - -2.7578125 + - 2.4277344 + - 3.6503906 + - -0.6069336 + - -1.0185547 + - -1.2431641 + - 2.0898438 + - -0.15917969 + - 2.8671875 + - 2.4902344 + - 7.8007813 + - 1.8486328 + - 3.0820313 + - -1.703125 + - 0.8125 + - 1.5527344 + - -0.3125 + - 0.39379883 + - 1.9355469 + - -0.99658203 + - 0.13000488 + - -0.84033203 + - -2.9570313 + - 0.6801758 + - -1.1962891 + - 5.3007813 + - 16.75 + - 1.0966797 + - -0.65185547 + - -3.8945313 + - 1.375 + - -0.7519531 + - 1.6757813 + - 2.3925781 + - -0.3112793 + - -0.93359375 + - 3.2714844 + - 0.94921875 + - 1.359375 + - -1.8720703 + - 2.1757813 + - 2.2402344 + - -4.09375 + - 1.3691406 + - 0.3017578 + - 2.1171875 + - 0.10992432 + - -1.7070313 + - 1.2988281 + - -0.8232422 + - 3.9394531 + - 1.4765625 + - -1.4296875 + - 3.2890625 + - 1.3623047 + - -1.7988281 + - -3.2207031 + - 1.6689453 + - -0.06915283 + - -3 + - 0.7626953 + - 0.15979004 + - -2.6484375 + - 0.08618164 + - 1.9960938 + - 0.55322266 + - 0.3449707 + - 3.0351563 + - 1.4033203 + - -0.54345703 + - 0.3737793 + - 3.5664063 + - -0.76220703 + - 2.7558594 + - 0.7607422 + - 3.2363281 + - 2.3925781 + - -2.2617188 + - -1.4804688 + - 2.25 + - 6.3828125 + - -2.75 + - -0.32836914 + - 3.0234375 + - -4.2539063 + - 0.107666016 + - -0.51660156 + - -2.2578125 + - 0.2763672 + - 0.7685547 + - 2.3105469 + - 1.0986328 + - 0.08648682 + - -0.15844727 + - -0.0027885437 + - -1.9550781 + - -0.63671875 + - -2.2246094 + - 0.40283203 + - 1.1972656 + - 0.39086914 + - -2.2207031 + - -1.6533203 + - -2.0566406 + - -1.6660156 + - -10.375 + - 0.69091797 + - 0.6245117 + - -0.04574585 + - -0.63378906 + - -1.4775391 + - -3.3144531 + - 1.4140625 + - -0.5234375 + - 1.6064453 + - 3.4453125 + - 1.1767578 + - 2.6191406 + - 5.765625 + - -1.4560547 + - 1.8808594 + - -3.375 + - -3.6914063 + - -2.7050781 + - 1.6914063 + - 0.24243164 + - -2.6425781 + - 2.9160156 + - -2.34375 + - -0.6567383 + - 0.69628906 + - 1.2294922 + - 5.4804688 + - -0.18408203 + - 0.48876953 + - 3.3378906 + - 4.1132813 + - -3.0703125 + - -5.390625 + - -0.29760742 + - 0.8984375 + - 1.0292969 + - 2.5839844 + - -0.08984375 + - -1.4404297 + - 2.7011719 + - 2.3789063 + - -0.2915039 + - -1.8369141 + - -1.3837891 + - 2.1191406 + - 0.8208008 + - 3.875 + - 1.8369141 + - -0.4584961 + - 3.375 + - 1.1132813 + - 1.0107422 + - 2.1347656 + - -3.4238281 + - -2.9003906 + - -2.6542969 + - 2.4277344 + - 2.7695313 + - -1.9716797 + - -3.71875 + - -3.6953125 + - -1.53125 + - -4.890625 + - 0.98535156 + - -1.0332031 + - 2.1660156 + - 0.57177734 + - -2.96875 + - -4.15625 + - -0.06359863 + - 0.03375244 + - 3.421875 + - 0.9238281 + - -0.6503906 + - -1.0087891 + - 20.421875 + - 1.1191406 + - 0.57958984 + - 2.1933594 + - 8.015625 + - -0.359375 + - -0.22424316 + - 0.3095703 + - 0.73583984 + - -3.4316406 + - -0.8833008 + - 4.125 + - -2.3203125 + - 4.7304688 + - 0.6694336 + - 0.73828125 + - -0.64697266 + - 0.6850586 + - -2.9277344 + - -2.5664063 + - 5.1523438 + - -0.84033203 + - 0.48242188 + - 3.7050781 + - 0.15368652 + - -3.9765625 + - 1.375 + - 2.2460938 + - 0.9941406 + - 0.20471191 + - 0.63378906 + - 0.37158203 + - 3.1679688 + - 0.61279297 + - -4.0507813 + - 0.9628906 + - -0.625 + - -0.94433594 + - -1.0126953 + - -4.5390625 + - 5.3125 + - 2.5136719 + - -6.203125 + - -1.0429688 + - 1.4091797 + - 2.28125 + - -1.4980469 + - 1.140625 + - 1.7939453 + - -2.5078125 + - 3.671875 + - 0.52001953 + - 2.359375 + - 0.30126953 + - 6.125 + - 1.1328125 + - 0.2890625 + - 1.0439453 + - -2.0097656 + - -3.8300781 + - 4.5507813 + - 3.0390625 + - 2.7226563 + - 0.027053833 + - 0.33325195 + - 0.15283203 + - 2.9375 + - -3.4550781 + - 0.39501953 + - 0.38476563 + - -4.5078125 + - -1.8955078 + - 1.9746094 + - 2.75 + - -4.6992188 + - -2.0097656 + - -1.140625 + - -3.2929688 + - -1.2207031 + - -2.7890625 + - 1.3349609 + - 1.0644531 + - 0.18103027 + - -3.5664063 + - -0.7441406 + - 2.5605469 + - 1.5654297 + - -1.3662109 + - -2.8671875 + - 1.3818359 + - -1.5234375 + - -0.8388672 + - -4.0742188 + - -2.3789063 + - -4.5390625 + - 2.6972656 + - 0.6796875 + - -3.2050781 + - -2.5175781 + - -2.1894531 + - 1.2724609 + - 0.51416016 + - -0.60595703 + - 4.125 + - -3.0625 + - 0.67041016 + - -0.07757568 + - -1.6328125 + - 4.0585938 + - -3.6660156 + - 1.1875 + - -2.1308594 + - 2.0605469 + - -0.37939453 + - -4.78125 + - -1.0390625 + - 3.9726563 + - 0.35839844 + - 1.2685547 + - -2.8925781 + - 2.3574219 + - -6.140625 + - 1.2578125 + - 0.69873047 + - -0.88964844 + - 3.6660156 + - 3.4941406 + - 1.4863281 + - 2.40625 + - -0.640625 + - 0.66015625 + - -2.4589844 + - -3.3125 + - -2.1347656 + - 2.8867188 + - 0.7397461 + - -1.4589844 + - 1.7070313 + - 1.0664063 + - -0.52783203 + - 2.5449219 + - -1.8867188 + - -1.6669922 + - 1.2216797 + - -0.51660156 + - -1.5722656 + - 1.5830078 + - 0.42919922 + - 0.49487305 + - 3.7519531 + - 2.6386719 + - 0.0892334 + - -1.2861328 + - -5.2070313 + - 3.09375 + - 1.4482422 + - -2.1132813 + - 2.4472656 + - 1.5185547 + - -3.7050781 + - 2.1367188 + - 1.9863281 + - -1.7519531 + - 2.6875 + - -3 + - -1.9804688 + - -1.8457031 + - 0.51708984 + - 1.8808594 + - 0.33813477 + - -1.5712891 + - -5.5898438 + - -0.23986816 + - -1.6425781 + - -0.8676758 + - -1.3125 + - -5.1445313 + - 3.1328125 + - 0.61816406 + - -2.2441406 + - 1.0234375 + - -1.7402344 + - 3.6640625 + - -2.1699219 + - 2.3691406 + - -1.4482422 + - 0.34106445 + - -0.8408203 + - -0.49316406 + - 1.8691406 + - -0.21594238 + - -0.25708008 + - -3.2109375 + - 0.10406494 + - -1.5878906 + - 1.0107422 + - 1.2763672 + - 3.7441406 + - -1.6972656 + - -2.15625 + - -0.032348633 + - 3.90625 + - 2.0722656 + - -1.0029297 + - -3.7441406 + - -1.1396484 + - -2.8867188 + - 8.7734375 + - -1.75 + - -0.109375 + - -1.7861328 + - 4.3945313 + - 1.2861328 + - 1.1962891 + - 0.7944336 + - -1.3017578 + - 0.21643066 + - -0.7138672 + - 2.1738281 + - -5.390625 + - -2.6757813 + - 5.7382813 + - -4.125 + - 3.6875 + - -1.0947266 + - 0.5 + - 0.6381836 + - 3.8164063 + - 0.3984375 + - -1.3984375 + - -0.0078125 + - 0.95410156 + - 2.171875 + - -4.828125 + - 1.7792969 + - 0.54833984 + - -3.1738281 + - -1.4355469 + - -0.23962402 + - -1.1396484 + - -0.22302246 + - -1.1669922 + - 0.3425293 + - 1.5595703 + - -0.8535156 + - -2.1015625 + - -3.8867188 + - 0.54833984 + - -1.4433594 + - -1.6181641 + - 0.23596191 + - 2.6875 + - 0.5493164 + - 2.5390625 + - -0.3046875 + - -0.31103516 + - -1.7480469 + - 3.4765625 + - 2.8671875 + - -1.8125 + - -0.6796875 + - -3.6894531 + - -2.2324219 + - 1.75 + - 0.15234375 + - -2.2128906 + - -2.3203125 + - -0.578125 + - 1.2363281 + - -0.47875977 + - 0.8803711 + - 2.4414063 + - -0.9194336 + - -3.0878906 + - -2.6503906 + - 0.14672852 + - -2.9726563 + - -1.8681641 + - -1.0400391 + - -2.1738281 + - -2.8847656 + - -0.61816406 + - -0.8330078 + - -1.3642578 + - 5.4140625 + - 4.6953125 + - -4.2148438 + - -0.3569336 + - -1.28125 + - 1.4785156 + - -2.328125 + - -2.2949219 + - 3.5800781 + - -1.3017578 + - -2.5488281 + - 1.4306641 + - 2.2753906 + - -2.2050781 + - -3.6425781 + - -0.66845703 + - -1.7558594 + - -1.0195313 + - 0.15844727 + - -0.32080078 + - -0.70654297 + - -1.9628906 + - -1.0722656 + - -1.2929688 + - -0.76416016 + - -2.0664063 + - -2.2539063 + - -0.7558594 + - -0.37158203 + - 3.9863281 + - -2.7519531 + - 3.9023438 + - -1.9804688 + - -0.9316406 + - 6.5078125 + - 0.60253906 + - -0.82910156 + - -1.3535156 + - 0.6323242 + - -2.9726563 + - 3.3203125 + - 6.421875 + - -2.3164063 + - -0.7084961 + - 5.7226563 + - 0.90283203 + - 1.3837891 + - 0.3955078 + - -1.9765625 + - 1.0742188 + - 0.50878906 + - -2.9804688 + - 1.3427734 + - -0.8613281 + - -0.33447266 + - 2.6582031 + - -7.1601563 + - 0.71777344 + - 4.2148438 + - -2.4765625 + - -0.7910156 + - -2.1523438 + - 4.2460938 + - -5.1679688 + - -2.3320313 + - -0.23095703 + - 1.5947266 + - 2.4082031 + - -0.68847656 + - 1.6523438 + - -2.328125 + - -2.6777344 + - 2.3359375 + - -0.6948242 + - 0.39648438 + - -2.3339844 + - 3.7714844 + - 0.66845703 + - -1.71875 + - -2.4238281 + - -1.2421875 + - -0.2253418 + - 0.5722656 + - -0.34692383 + - 0.54541016 + - 2.0175781 + - -2.5878906 + - -0.09539795 + - -2.7949219 + - 0.7241211 + - 0.953125 + - 1.1865234 + - -1.2783203 + - -2.234375 + - -3.1484375 + - 1.2773438 + - 0.5834961 + - 1.1572266 + - -0.35473633 + - -2.15625 + - -2.1152344 + - 1.2978516 + - -3.0273438 + - -2.5136719 + - -1.9619141 + - 3.6992188 + - -3.4785156 + - -1.9482422 + - -0.60253906 + - 2.3535156 + - -1.6074219 + - 0.014503479 + - -1.0634766 + - -0.9248047 + - -0.30688477 + - -4.1210938 + - 0.8144531 + - 1.6376953 + - 4.859375 + - -1.6796875 + - 1.4482422 + - -0.28686523 + - 6.375 + - 1.9296875 + - -0.7294922 + - 1.4150391 + - 1.7324219 + - -0.64990234 + - -1.9150391 + - -1.2890625 + - 1.2744141 + - 1.7753906 + - 3.4375 + - -1.9316406 + - 2.3730469 + - -0.04574585 + - -0.055236816 + - 2.40625 + - -0.5361328 + - -0.97753906 + - 1.7050781 + - -1.4550781 + - -2.8496094 + - 0.9140625 + - 0.92285156 + - -3.3085938 + - -0.5410156 + - 1.8603516 + - -1.9072266 + - -1.2226563 + - -0.16955566 + - -0.29467773 + - 4.4257813 + - 6.8242188 + - -1.8144531 + - -0.18603516 + - -3.7402344 + - -2.1425781 + - 0.51416016 + - 1.0888672 + - -2.375 + - 1.8486328 + - -3.671875 + - -2.8691406 + - -0.50878906 + - -2.3476563 + - -0.9975586 + - -2.390625 + - -0.022872925 + - 1.8251953 + - 1.421875 + - -0.38720703 + - 1.7363281 + - 2.8496094 + - -0.7216797 + - -2.0195313 + - 1.3427734 + - 2.3515625 + - 0.8642578 + - -1.6220703 + - -0.9550781 + - 0.5053711 + - 0.060821533 + - -0.28515625 + - -3.6992188 + - -1.28125 + - -1.2978516 + - 1.7617188 + - -0.9326172 + - 0.96533203 + - 0.1439209 + - 2.8222656 + - -0.20129395 + - -1.4619141 + - 8.03125 + - -2.1132813 + - 3.6503906 + - -4.0273438 + - 3.6367188 + - 4.21875 + - -4.0664063 + - 1.1337891 + - 1.7832031 + - -0.22033691 + - -1.1425781 + - -0.35546875 + - -0.17297363 + - 1.8232422 + - -1.7207031 + - -1.2578125 + - -1.7851563 + - 3.9609375 + - -0.72802734 + - 1.2285156 + - 0.44677734 + - -1.2597656 + - 0.921875 + - -0.5136719 + - -0.51171875 + - -1.1142578 + - 3.3339844 + - 0.89208984 + - -2.1738281 + - 1.609375 + - -0.69873047 + - -2.7265625 + - 0.4440918 + - -2.1386719 + - -0.85253906 + - 2.6328125 + - 2.1425781 + - 2.1855469 + - -8.9609375 + - 4.40625 + - -0.5805664 + - 0.3293457 + - 0.48657227 + - -3.5019531 + - 1.9033203 + - 0.44970703 + - -1.5009766 + - 1.4414063 + - -4.625 + - 0.40112305 + - -0.21362305 + - -0.4753418 + - 0.07678223 + - 0.234375 + - 1.1494141 + - -0.34545898 + - -0.74853516 + - 0.7314453 + - 2.0800781 + - -2.4199219 + - 1.4638672 + - -2.5507813 + - 1.5810547 + - 2.359375 + - 0.77978516 + - 1.078125 + - 1.9570313 + - -0.3322754 + - 0.08258057 + - -1.2578125 + - 4.4570313 + - 1.421875 + - 2.5390625 + - 1.0166016 + - -4.0390625 + - 0.66503906 + - -0.40161133 + - -0.38891602 + - -0.26391602 + - 1.1357422 + - -0.9375 + - 1.3476563 + - 6.3554688 + - 1.0732422 + - -8.7421875 + - 1.2675781 + - 1.3388672 + - -0.11828613 + - -0.9863281 + - 2.9414063 + - 6.1757813 + - -1.8085938 + - -0.09820557 + - -0.61816406 + - -1.453125 + - 1.4726563 + - -0.7734375 + - 0.21923828 + - -0.22814941 + - -2.4238281 + - -0.43408203 + - -0.5 + - 4.0820313 + - -1.9326172 + - -1.4404297 + - 0.12634277 + - 1.7939453 + - 3.6191406 + - 2.1953125 + - 1.0546875 + - 0.49658203 + - 2.7050781 + - 0.66796875 + - -24.84375 + - 1.6748047 + - -4.6367188 + - -1.8183594 + - -15.671875 + - -1.2568359 + - -0.6870117 + - 3.0644531 + - -3.7128906 + - 2.609375 + - -7.5625 + - -7.9375 + - 0.80908203 + - -0.95410156 + - 2.0214844 + - -1.1650391 + - 0.3779297 + - 4.4375 + - -0.9453125 + - 1.5361328 + - 1.0087891 + - 2.0332031 + - 1.9931641 + - -2.9023438 + - -2.4765625 + - 3.6621094 + - -2.5761719 + - 1.8408203 + - 1.6982422 + - -5.0117188 + - 1.9042969 + - -0.31225586 + - -0.08258057 + - 2.3535156 + - 0.6352539 + - -1.6601563 + - 1.7197266 + - -1.8496094 + - 0.73046875 + - -0.04547119 + - 0.45996094 + - 0.036834717 + - 3.46875 + - 1.4023438 + - 0.061920166 + - 3.7128906 + - 2.75 + - 1.5185547 + - -1.0664063 + - -1.0947266 + - 1.7597656 + - -1.0664063 + - -2.015625 + - 2.078125 + - 1.390625 + - 3.1171875 + - -1.6494141 + - -4.7148438 + - 0.67285156 + - -2.6191406 + - 0.16210938 + - 2.4414063 + - -3.1289063 + - -0.6411133 + - -0.37329102 + - -0.4140625 + - -0.13000488 + - 4.5664063 + - 2.875 + - 1.4648438 + - -4.6757813 + - -0.13916016 + - 3.0117188 + - 0.57666016 + - -0.4453125 + - 1.3945313 + - 0.28149414 + - -0.7294922 + - -1.0039063 + - 2.1191406 + - -3.484375 + - -0.22729492 + - 1.3056641 + - -0.33862305 + - 0.5800781 + - 4.0390625 + - -0.5722656 + - 0.7241211 + - -1.4550781 + - -3.84375 + - 0.85791016 + - -1.71875 + - 0.92822266 + - -1.546875 + - -2.46875 + - 0.94970703 + - -3.0800781 + - -8.6328125 + - 0.8774414 + - -3.7089844 + - 0.2854004 + - 2.4003906 + - 1.1992188 + - -3.4628906 + - 0.6152344 + - -3.5566406 + - -1.8525391 + - -5.1367188 + - -0.82128906 + - 0.005718231 + - -0.0025100708 + - 3.9492188 + - -0.89208984 + - 1.4550781 + - -3.1503906 + - -2.7421875 + - -1.1074219 + - 0.19470215 + - -0.9003906 + - -3.0742188 + - 0.81884766 + - -2.4941406 + - -0.4404297 + - -0.12817383 + - 1.2353516 + - -0.32226563 + - 0.5078125 + - -3.4140625 + - -1.6044922 + - 0.5761719 + - -5.2070313 + - -2.2285156 + - 2.5839844 + - 5.3945313 + - 5.4726563 + - -0.2890625 + - 0.23120117 + - 4.4335938 + - 3.2597656 + - -1.6689453 + - -0.9008789 + - -2.3066406 + - 0.3330078 + - 2.8515625 + - -1.0039063 + - -0.74609375 + - -0.6118164 + - -0.7519531 + - -2.0234375 + - -2.296875 + - 2.4609375 + - -1.8095703 + - 1.2333984 + - -0.20812988 + - -2.3496094 + - -0.021194458 + - 0.78271484 + - 1.359375 + - -0.5175781 + - -0.7998047 + - 0.5258789 + - 2.2089844 + - -0.94970703 + - -1.5 + - -4.6523438 + - -0.04547119 + - 0.20422363 + - 3.4082031 + - -0.46362305 + - 0.18469238 + - 2.3476563 + - 23.5 + - -0.8959961 + - -3.0800781 + - 4.359375 + - 0.5830078 + - 4.0507813 + - -2.0234375 + - -13.3203125 + - 1.4960938 + - -1.0517578 + - 4.7539063 + - 0.66845703 + - 0.11383057 + - 1.2207031 + - 0.8408203 + - 2.2832031 + - 1.4814453 + - -4.9179688 + - 0.30908203 + - -4.7148438 + - 1.0234375 + - -3.7539063 + - 0.36450195 + - -0.19970703 + - -1.4775391 + - 3.5820313 + - -0.9350586 + - -2.2519531 + - 0.29345703 + - 3.0703125 + - -0.5292969 + - -0.6928711 + - 1.3974609 + - -1.6289063 + - -1.3476563 + - -2.0527344 + - -0.32861328 + - -0.2668457 + - -0.95947266 + - 0.1149292 + - -2.5957031 + - 2.2675781 + - -1.0664063 + - -1.7275391 + - 1.9658203 + - -0.79833984 + - 0.29541016 + - 1.7871094 + - -3.4179688 + - 3.5722656 + - 1.0419922 + - -1.3701172 + - 5.9101563 + - -2.6601563 + - -2.3671875 + - 0.8227539 + - 0.7866211 + - 2.9375 + - -2.3496094 + - 1.5 + - -2.4375 + - 3.8300781 + - 0.7109375 + - -1.203125 + - -0.06329346 + - 6.1054688 + - 3.3710938 + - -0.41015625 + - -1.71875 + - -0.3671875 + - -1.1767578 + - -0.25268555 + - -0.30078125 + - -0.1940918 + - -2.7109375 + - -5.9179688 + - 6.5351563 + - 0.9375 + - -2.3789063 + - -1.8955078 + - 1.6210938 + - 0.37548828 + - -0.31518555 + - -0.21875 + - 0.5830078 + - 1.2382813 + - 0.7890625 + - 1.6132813 + - -3.2402344 + - 0.8442383 + - 1.3203125 + - -1.9482422 + - 0.46557617 + - 0.17077637 + - 5.1757813 + - 2.1425781 + - -1.6201172 + - 4.75 + - -1.0703125 + - 2.4785156 + - 4.703125 + - -0.54296875 + - -1.9921875 + - 5.75 + - 0.78759766 + - 0.38354492 + - -1.2578125 + - -0.17211914 + - 2.4511719 + - 1.6533203 + - -1.2587891 + - -1.6181641 + - -1.8476563 + - -0.71875 + - -0.42626953 + - 0.3869629 + - 0.7348633 + - 0.12426758 + - 0.29516602 + - -2.078125 + - 2.2558594 + - 23.0625 + - -3.9101563 + - 2.9472656 + - -0.171875 + - 0.9301758 + - 2.3613281 + - 0.18798828 + - -2.0449219 + - 0.28344727 + - -0.8486328 + - -1.4492188 + - 1.9501953 + - -2.3046875 + - -1.6992188 + - -0.25854492 + - 0.31225586 + - -5.1601563 + - 1.9814453 + - 2.15625 + - 14.546875 + - -2.7011719 + - 1.4033203 + - -0.11602783 + - -1.4033203 + - 0.2109375 + - -0.6464844 + - 0.63916016 + - 0.6640625 + - -0.21984863 + - -1.2744141 + - -26 + - -0.5029297 + - 0.55078125 + - 1.0742188 + - -2.9101563 + - -0.4951172 + - -0.6484375 + - 0.9194336 + - -2.46875 + - 0.9267578 + - 0.5957031 + - -3.828125 + - -1.3505859 + - -0.8256836 + - -0.15515137 + - -1.0332031 + - -1.2939453 + - -2.9804688 + - 0.6225586 + - -0.23510742 + - -2.3261719 + - 0.8261719 + - 2.6347656 + - 0.2565918 + - 3.4257813 + - -1.4033203 + - 3.1738281 + - -0.5678711 + - 7.6953125 + - -1.9326172 + - 2.5859375 + - 4.0039063 + - -6.6484375 + - 2.4199219 + - -2.1757813 + - 4.3632813 + - -0.8208008 + - -0.5097656 + - -1.734375 + - 0.50439453 + - 0.62841797 + - 0.9951172 + - -5.5351563 + - 2.953125 + - -0.18005371 + - -2.4003906 + - 0.027893066 + - 2.7128906 + - 2.5332031 + - 2.6386719 + - 2.5058594 + - -1.9511719 + - -1.2734375 + - 1.8320313 + - 4.15625 + - 1.4335938 + - -1.4951172 + - -3.8300781 + - -0.64501953 + - -4.1640625 + - -1.1318359 + - 2.1132813 + - 2.2207031 + - 3.6367188 + - -1.140625 + - 4.890625 + - 4.9960938 + - 2.046875 + - -0.734375 + - -1.0810547 + - 0.76953125 + - -1.2734375 + - 1.3349609 + - -1.2626953 + - 1.3642578 + - -1.4804688 + - -2.6601563 + - 0.62158203 + - -3.5585938 + - -0.33520508 + - -3.3691406 + - -3.9375 + - -0.76464844 + - 0.5126953 + - 3.0058594 + - -1.4169922 + - -0.14758301 + - 2.9179688 + - 0.7988281 + - 0.52978516 + - -2.7910156 + - 3.359375 + - 2.0585938 + - -1.4140625 + - -3.3203125 + - 3.6015625 + - -0.56884766 + - 3.9375 + - -2.7890625 + - -0.921875 + - -1.0517578 + - 0.8203125 + - 3.4902344 + - 2.4726563 + - -0.17346191 + - 0.94189453 + - -3.7363281 + - -6.0507813 + - -0.46191406 + - -1.4873047 + - 2.65625 + - 2.6914063 + - 0.81689453 + - 1.0429688 + - 2.1601563 + - 0.59814453 + - -0.07366943 + - 2.3574219 + - -1.8486328 + - 2.9550781 + - 0.99902344 + - -0.4560547 + - -0.3359375 + - -0.8046875 + - -0.6621094 + - 12.1953125 + - 0.52441406 + - 2.53125 + - 5.7734375 + - 7.8046875 + - -1.21875 + - 0.42993164 + - -1.0869141 + - 1.4628906 + - -2.6542969 + - -1.7949219 + - 1.34375 + - 0.66845703 + - 0.29956055 + - -2.5566406 + - -0.7207031 + - 1.0195313 + - 1.8886719 + - 1.9316406 + - 0.34399414 + - -0.17321777 + - -0.1821289 + - -0.7832031 + - -1.9394531 + - -2.1015625 + - -1.4257813 + - 1.2460938 + - -0.46191406 + - -2.4238281 + - -3.4238281 + - 2.7890625 + - 2.1503906 + - 1.9941406 + - 1.0136719 + - 0.22485352 + - -0.98291016 + - 1.9404297 + - -1.7470703 + - 0.74072266 + - 1.8251953 + - -1.4882813 + - 1.2548828 + - -1.7763672 + - -0.55859375 + - 3.9375 + - -0.7192383 + - 1.7089844 + - -2.6484375 + - -1.0927734 + - -2.9003906 + - 3.2207031 + - 1.0126953 + - -2.4003906 + - -1.1132813 + - 4.1015625 + - 1.8291016 + - 1.0341797 + - 1.5966797 + - 4.1914063 + - 0.8461914 + - -1.8164063 + - -1.6669922 + - 1.4746094 + - 1.5244141 + - 1.2060547 + - 4.1875 + - 2.5195313 + - 2.265625 + - 1.9580078 + - -1.4179688 + - -0.6538086 + - -1.8564453 + - 1.2441406 + - 0.19885254 + - -0.050201416 + - -1.1044922 + - 0.34765625 + - 1.390625 + - 0.10595703 + - 3.0839844 + - -0.97753906 + - 0.080322266 + - 0.86376953 + - -0.27001953 + - 23.46875 + - -3.4648438 + - -1.1455078 + - -4.2460938 + - -0.22766113 + - 0.7368164 + - 2.34375 + - -0.09429932 + - -4.7851563 + - 1.6826172 + - 2.5976563 + - -1.3603516 + - 3.3925781 + - 2.5390625 + - 1.9511719 + - 0.51953125 + - 1.6357422 + - -3.0820313 + - 1.7158203 + - 0.9614258 + - -2.2148438 + - 1.7001953 + - -3.6777344 + - 1.7763672 + - 0.0758667 + - 0.8208008 + - -2.2089844 + - 0.12011719 + - 2.3339844 + - -3.7714844 + - -0.77197266 + - 1.3144531 + - 2.078125 + - 2.1347656 + - 2.4082031 + - -1.5664063 + - 6.2851563 + - -0.035705566 + - 0.3269043 + - -0.6582031 + - -4.3398438 + - -3.5703125 + - 0.5024414 + - 4.9257813 + - 0.38110352 + - 0.20275879 + - -1.5664063 + - 1.7324219 + - 2.8144531 + - 3.9101563 + - -0.5703125 + - -1.8300781 + - 0.39135742 + - 8.6640625 + - -3.2226563 + - -1.21875 + - 0.6303711 + - -1.2597656 + - 1.1396484 + - 0.5097656 + - 1.3017578 + - -0.11853027 + - -0.11633301 + - -4.2382813 + - -3.5429688 + - -2.6660156 + - -3.125 + - -2.9941406 + - 0.49731445 + - -2.203125 + - -1.2890625 + - 3.2851563 + - -0.7158203 + - -1.8212891 + - 0.6801758 + - -3.3378906 + - -4.4023438 + - -0.29785156 + - 2.0722656 + - -2.6738281 + - -0.19897461 + - 1.1738281 + - 2.1875 + - 1.2285156 + - -1.1191406 + - -3.0839844 + - -1.4257813 + - -0.87158203 + - -2.9550781 + - 0.016738892 + - -0.5004883 + - -0.26733398 + - 4.171875 + - -1.1015625 + - 2.6386719 + - -3.3027344 + - -2.3066406 + - -1.2890625 + - -0.68310547 + - 1.1992188 + - -1.3095703 + - 1.4726563 + - 1.0214844 + - 0.8647461 + - 0.40307617 + - -1.2763672 + - -1.6074219 + - 1.5175781 + - -1.4238281 + - 1.6337891 + - 0.4814453 + - -0.33032227 + - 2.7382813 + - 0.9296875 + - 0.21643066 + - 1.2539063 + - -3.8339844 + - -2.6425781 + - -3.2421875 + - -1.3925781 + - 0.30249023 + - -0.22033691 + - 0.5292969 + - 1.0478516 + - 1.1650391 + - 1.2773438 + - -1.2050781 + - -2.421875 + - 1.1992188 + - 2.1015625 + - -2.7226563 + - 2.1171875 + - 0.45581055 + - 0.33129883 + - 1.2685547 + - 0.67285156 + - -5.5898438 + - -3.34375 + - -1.0898438 + - 1.5175781 + - 0.026779175 + - -2.2480469 + - -0.9560547 + - 4.9257813 + - -0.17370605 + - 1.3681641 + - 6.5820313 + - 2.5605469 + - -2.6855469 + - 0.83984375 + - -0.056915283 + - 6.015625 + - -4.9570313 + - -2.1777344 + - 0.9863281 + - -2.1269531 + - -0.57910156 + - -2.3925781 + - 1.8867188 + - -3.3476563 + - 3.1953125 + - -1.1894531 + - 0.7207031 + - 0.15515137 + - -0.5161133 + - -1.1982422 + - 0.96875 + - -0.23339844 + - -1.9394531 + - 5.9726563 + - 0.79003906 + - 2.4414063 + - -0.31469727 + - -4.46875 + - 2.4296875 + - 0.24865723 + - 1.3359375 + - -0.7138672 + - -1.3564453 + - -0.7661133 + - 1.1220703 + - -2.015625 + - -3.0722656 + - -0.030685425 + - 0.69677734 + - 1.7275391 + - 2.8183594 + - -2.3203125 + - 1.234375 + - 0.3095703 + - -2.7070313 + - 0.34692383 + - 3.5566406 + - 1.3251953 + - 5.75 + - 0.24768066 + - 0.06359863 + - 16.1875 + - -0.41845703 + - 2.3007813 + - -3.5507813 + - -0.90722656 + - -0.89746094 + - 0.5439453 + - 1.4785156 + - 4.1484375 + - -0.9238281 + - -3.5253906 + - -1.8232422 + - 0.87402344 + - 1.9189453 + - 1.0517578 + - -1.1347656 + - 4.4570313 + - -0.26879883 + - -0.66796875 + - 0.24414063 + - -1.6445313 + - 0.30395508 + - -1.5214844 + - -2.2949219 + - -1.6738281 + - 2.3652344 + - -0.22375488 + - -4 + - -3.1015625 + - 0.7397461 + - -0.9951172 + - -0.88134766 + - -1.8613281 + - -1.8925781 + - 0.17687988 + - -0.08227539 + - 3.0117188 + - 0.75683594 + - 2.7890625 + - 0.28637695 + - 1.9667969 + - -4.5898438 + - 0.88378906 + - 0.64941406 + - -0.06854248 + - 4.2070313 + - -1.3662109 + - -1.3671875 + - -2.0664063 + - -5.4882813 + - 2.1308594 + - 1.8994141 + - -0.31152344 + - 2.8789063 + - 4.703125 + - -1.640625 + - -0.17565918 + - -3.8339844 + - -0.13244629 + - -1.8339844 + - -0.77197266 + - -1.1074219 + - 1.7451172 + - -2.703125 + - -0.38671875 + - 1.0224609 + - 1.9111328 + - -4.953125 + - 3.3925781 + - 0.9248047 + - -0.57373047 + - -1.6894531 + - 4.6914063 + - 0.9428711 + - 1.1796875 + - 1.0107422 + - -1.9638672 + - -2.4433594 + - 1.6601563 + - 1.3613281 + - 2.390625 + - 0.17053223 + - 4.7617188 + - -1.6230469 + - -1.1416016 + - 0.96484375 + - -1.5556641 + - -0.76660156 + - -1.5439453 + - 0.62353516 + - -4.3476563 + - -0.82666016 + - 1.6621094 + - 1.9033203 + - -2.375 + - 2.5566406 + - -3.9316406 + - 2.6777344 + - 0.7910156 + - -0.7397461 + - 4.5976563 + - -0.8935547 + - -2.609375 + - 1.921875 + - 2.4296875 + - 3.3144531 + - 1.7685547 + - -1.0107422 + - -0.22399902 + - 0.45361328 + - 33.40625 + - 13.4609375 + - -9.1796875 + - 2.265625 + - -1.0498047 + - 1.4277344 + - -2.7285156 + - -4.171875 + - -0.36083984 + - -0.20532227 + - 1.9619141 + - 0.51708984 + - -0.3388672 + - 1.5126953 + - -2.7910156 + - 1.9707031 + - -1.0048828 + - 0.9091797 + - -2.6953125 + - 0.71533203 + - 1.8789063 + - 3.4160156 + - -1.3212891 + - -1.1416016 + - -0.22705078 + - -2.1503906 + - 0.08703613 + - -0.40356445 + - -4.6054688 + - 0.75439453 + - -0.12780762 + - -0.15905762 + - 1.421875 + - 2.4765625 + - 1.6376953 + - -4.375 + - -1.8544922 + - 2.0644531 + - -2.1660156 + - 1.2460938 + - 2.2285156 + - 1.5400391 + - -0.2800293 + - 4.2265625 + - -1.2050781 + - 0.29296875 + - -3.4941406 + - 2.1425781 + - 1.3056641 + - 0.51171875 + - 2.2910156 + - 8.734375 + - -0.5722656 + - -1.4316406 + - 1.7226563 + - -0.9472656 + - -0.84472656 + - 0.054107666 + - 1.4589844 + - 0.21362305 + - 2.9804688 + - 2.3964844 + - 1.203125 + - -3.9238281 + - -1.7451172 + - -1.1357422 + - 1.9345703 + - -0.8339844 + - -2.6875 + - 0.25439453 + - -2.9238281 + - -0.20739746 + - -1.5019531 + - -2.2675781 + - 0.92626953 + - -2.6699219 + - -0.18823242 + - 1.3486328 + - 5.4453125 + - 0.4140625 + - -1.7626953 + - -1.4208984 + - 1.6337891 + - 1.8632813 + - 1.6884766 + - 2.3789063 + - 1.1064453 + - 0.22314453 + - 1.9423828 + - -1.53125 + - 1.3662109 + - 0.50439453 + - -0.8911133 + - -1.0019531 + - 3.65625 + - 1.2099609 + - -1.3984375 + - 4.0351563 + - -1.9003906 + - 0.5229492 + - -3.4648438 + - -1.0595703 + - 0.75097656 + - 1.15625 + - 0.12231445 + - 0.48754883 + - 0.32348633 + - -2.3203125 + - -0.081970215 + - 1.484375 + - -3.2929688 + - 3.6777344 + - -0.6933594 + - 4.28125 + - 1.8056641 + - 2.8339844 + - -2.9140625 + - -1.3173828 + - 3.515625 + - 0.4248047 + - -2.3886719 + - -1.8857422 + - 0.875 + - 1.1064453 + - 3.609375 + - 1.3613281 + - -3.2714844 + - 2.0546875 + - 2.4140625 + - 0.1270752 + - -0.8769531 + - -1.2519531 + - -1.1103516 + - 1.2451172 + - 0.2758789 + - 0.30737305 + - -0.18188477 + - -3.4394531 + - 1.5400391 + - -1.2939453 + - -0.4375 + - 1.9580078 + - 1.7792969 + - -2.1367188 + - -0.2956543 + - -0.17468262 + - 2.0078125 + - -1.203125 + - -0.140625 + - -4.109375 + - 1.1669922 + - 1.3193359 + - -1.4697266 + - -1.4335938 + - 0.4091797 + - -0.91503906 + - -1.1445313 + - 0.41333008 + - 0.4038086 + - 2.1660156 + - 0.09411621 + - -2.5546875 + - 2.7890625 + - 1.7773438 + - -0.9394531 + - 0.4284668 + - 0.328125 + - 2.3417969 + - -0.12164307 + - -2.5566406 + - -0.50927734 + - -0.265625 + - -2.6074219 + - -1.3457031 + - 0.58691406 + - 0.71728516 + - 1.4130859 + - 1.96875 + - -1.1738281 + - -1.75 + - -0.6010742 + - 0.38598633 + - -0.52441406 + - 0.90283203 + - 1.5185547 + - -1.5732422 + - -0.068359375 + - 1.7675781 + - 1.7275391 + - -1.2802734 + - 2.3789063 + - 2.3203125 + - 1.7792969 + - 0.7207031 + - -2.4882813 + - -1.8632813 + - 2.9804688 + - 1.1787109 + - 0.92089844 + - -3.390625 + - -2.7675781 + - -1.4277344 + - -2.8476563 + - -0.42285156 + - 0.39453125 + - -12.4453125 + - -0.31469727 + - -0.46240234 + - 0.21875 + - -0.88916016 + - 0.5488281 + - -1.2509766 + - 1.6689453 + - 0.45922852 + - -1.7119141 + - 2.3417969 + - -5.375 + - 0.4868164 + - 0.32421875 + - -1.1748047 + - 1.3769531 + - 1.5244141 + - -2.0566406 + - -0.025665283 + - 3.4238281 + - 0.61816406 + - 1.8251953 + - -0.53515625 + - 9.390625 + - 1.4433594 + - -2.1425781 + - 0.7246094 + - -0.52197266 + - 0.8935547 + - -0.88916016 + - -0.08459473 + - -2.6640625 + - 6.75 + - 0.68066406 + - -1.7714844 + - 0.7470703 + - 1.0390625 + - -6.09375 + - 0.71484375 + - 0.29418945 + - 1.3671875 + - 0.44189453 + - 6.2929688 + - -0.5942383 + - -2.7695313 + - 1.8964844 + - 2.2207031 + - 2.4628906 + - 2.109375 + - 1.1445313 + - -2.8378906 + - 1.5419922 + - 1.8007813 + - -3.15625 + - -1.0839844 + - -0.3232422 + - -0.43164063 + - -3.1992188 + - -1.8183594 + - -3.2753906 + - -0.1986084 + - -3.8652344 + - 2.4101563 + - -1.6914063 + - -1.796875 + - 3.5683594 + - -2.4199219 + - 0.18859863 + - -1.6337891 + - -1.6347656 + - 2.0566406 + - -0.3544922 + - -1.3388672 + - 1.7558594 + - 1.6328125 + - -0.6225586 + - 0.6425781 + - 0.61083984 + - 2.1738281 + - 0.8647461 + - 3.7578125 + - 0.01953125 + - -0.26611328 + - -1.7851563 + - 2.6621094 + - 0.1842041 + - -2.0214844 + - -1.2861328 + - -1.5732422 + - -0.09051514 + - 5.2382813 + - 4.703125 + - -1.1425781 + - 1.9355469 + - 2.3378906 + - -0.7207031 + - -1.25 + - -0.4050293 + - 2.0273438 + - -1.9423828 + - 2.2753906 + - -3.4765625 + - 2.8359375 + - 0.7866211 + - -3.9609375 + - -0.10961914 + - -2.6640625 + - 3.25 + - 0.3005371 + - -5.5078125 + - -0.27075195 + - -1.765625 + - 1.6582031 + - 0.4284668 + - 0.68310547 + - 3.4550781 + - 0.47021484 + - 1.2822266 + - -0.31884766 + - -3.0898438 + - -1.6689453 + - -0.5917969 + - -3.7890625 + - 8.9140625 + - 1.1953125 + - 1.4628906 + - -0.5317383 + - 0.52783203 + - -1.5 + - 0.43896484 + - 1.1591797 + - -1.2998047 + - -5.4804688 + - -3.4003906 + - 4.6367188 + - -4.171875 + - 1.8056641 + - -1.84375 + - -2.8164063 + - 1.2988281 + - 0.89208984 + - -0.5800781 + - 0.27661133 + - 1.2519531 + - 1.1083984 + - -3.1777344 + - 0.07696533 + - -4.0429688 + - 1.703125 + - -1.59375 + - 1.2041016 + - -3.5976563 + - 0.8105469 + - -1.4296875 + - 0.93847656 + - -2.5 + - -1.0498047 + - 0.07159424 + - 2.2539063 + - 3.2402344 + - 0.5004883 + - 1.6611328 + - -1.6152344 + - 2.4199219 + - 1.2880859 + - -0.7167969 + - -1.1738281 + - -2.6914063 + - -0.23876953 + - 0.51708984 + - 2.5664063 + - -2.8828125 + - -0.09454346 + - -0.0020923615 + - 4.2304688 + - -0.010597229 + - -2.2207031 + - 0.36743164 + - 1.984375 + - -2.21875 + - -2.3183594 + - -0.9819336 + - 1.2138672 + - 1.9511719 + - -0.53466797 + - 0.7192383 + - -1.4638672 + - -0.29736328 + - 0.82910156 + - 3.0742188 + - -2.9179688 + - -2.7089844 + - 1.5957031 + - 1.8515625 + - 5.8125 + - 2.6269531 + - -1.5332031 + - 1.4589844 + - -0.59716797 + - 1.0800781 + - -1.6582031 + - -2.015625 + - -0.9116211 + - 1.2197266 + - -1.9160156 + - 1.1708984 + - -1.0478516 + - 3.5195313 + - 4.3398438 + - -0.51708984 + - 0.17626953 + - -0.23376465 + - -1.4296875 + - -3.3242188 + - -2.8652344 + - -0.8925781 + - 1.3798828 + - -1.0742188 + - 0.85595703 + - 2.1699219 + - 1.5449219 + - 1.4101563 + - -0.4128418 + - 0.86865234 + - -4.921875 + - -0.9008789 + - -8.3046875 + - -1.734375 + - -2.0214844 + - -2.2714844 + - -2.90625 + - -0.96777344 + - 2.8417969 + - -6.7421875 + - -4.4335938 + - 24.671875 + - -1.7294922 + - -1.6435547 + - -0.6557617 + - -0.17883301 + - 0.50634766 + - 2.3261719 + - 3.0898438 + - -2.15625 + - 1.1416016 + - 1.6894531 + - -0.03488159 + - 0.88378906 + - -1.4248047 + - 0.42895508 + - 0.09020996 + - -3.4160156 + - 0.7285156 + - 4.890625 + - -0.75 + - -0.55126953 + - -1.4794922 + - -2.4765625 + - 0.6567383 + - -0.34155273 + - 3.7578125 + - 0.36376953 + - -2.0878906 + - 2.2304688 + - -0.27441406 + - 1.5878906 + - -2.5488281 + - 0.77246094 + - 0.4033203 + - 1.2587891 + - -0.55615234 + - 1.6416016 + - 2.984375 + - 4.1796875 + - 0.13500977 + - -0.85595703 + - -0.55322266 + - 2.0449219 + - -3.890625 + - 0.7788086 + - -0.2800293 + - 3.2695313 + - 1.1845703 + - -2.0371094 + - 0.7270508 + - 2.3496094 + - 0.83691406 + - -3.1035156 + - -1.3164063 + - -2.0175781 + - -1.6425781 + - -2.9003906 + - -0.42822266 + - 2.3769531 + - -3.4570313 + - -2.8359375 + - 1.1767578 + - -0.5722656 + - 2.4550781 + - -2.5039063 + - -0.0993042 + - -1.1953125 + - -0.012275696 + - -2.7324219 + - 1.5888672 + - -4.6132813 + - -4.3554688 + - -0.115478516 + - -1.5566406 + - 1.4550781 + - 8.6328125 + - 0.89697266 + - 3.6796875 + - -4.7578125 + - 1.1884766 + - -0.67285156 + - 1.3085938 + - 0.9038086 + - 0.6767578 + - -0.16455078 + - -4.7695313 + - 0.5332031 + - 0.76171875 + - 2.5664063 + - -0.84033203 + - -2.8378906 + - 0.4453125 + - -0.084106445 + - -0.55078125 + - -2.4765625 + - 1.4394531 + - 2.109375 + - -2.5664063 + - 5.3554688 + - 0.3088379 + - 0.37426758 + - 0.9243164 + - 0.53271484 + - 4.0078125 + - 0.27270508 + - 2.0820313 + - -1.8183594 + - -0.5209961 + - 0.54345703 + - 2.3847656 + - 7.1640625 + - 1.7158203 + - 1.0996094 + - -1.0556641 + - 3.5527344 + - 0.05078125 + - 1.7119141 + - 1.7900391 + - 2.2285156 + - -0.30566406 + - 3.09375 + - -0.6933594 + - 3.5976563 + - -4.484375 + - -1.4716797 + - -2.0273438 + - 0.9428711 + - 0.004463196 + - 1.3388672 + - -0.42236328 + - 4.0742188 + - -1.9814453 + - -2.109375 + - -0.8417969 + - 0.016311646 + - 2.9804688 + - 2.4042969 + - 0.7421875 + - 1.1767578 + - 3.2851563 + - 4.1992188 + - 0.7553711 + - -0.578125 + - 1.3769531 + - 2.078125 + - -4.9882813 + - -4.578125 + - -0.96484375 + - 3.3046875 + - -1.5917969 + - -0.75097656 + - -1.9638672 + - 2.8613281 + - 3.2753906 + - 3.2617188 + - -0.8564453 + - -0.28076172 + - 1.3603516 + - -1.3505859 + - -0.44799805 + - 2.5859375 + - 2.6894531 + - -0.9707031 + - -0.359375 + - 0.41503906 + - 1.7861328 + - 0.39282227 + - -0.1227417 + - -0.35986328 + - 1.2529297 + - 2.1425781 + - 0.90625 + - -2.1171875 + - -0.32250977 + - -3.6425781 + - -4.8789063 + - -0.09008789 + - 2.5820313 + - -0.8569336 + - -0.3659668 + - 3.1269531 + - -2.1777344 + - 2.0078125 + - 0.55859375 + - -0.9863281 + - -2.9140625 + - 1.4023438 + - -0.52001953 + - 3.0664063 + - 3.3515625 + - 1.2978516 + - -6.8359375 + - -0.47705078 + - -0.4194336 + - -5.390625 + - 2.1230469 + - -2.6640625 + - 2.4316406 + - 1.3896484 + - -6.4453125 + - 1.3085938 + - -0.65478516 + - -2.8007813 + - -2.4277344 + - 1.1220703 + - -0.37695313 + - 2.0820313 + - -0.42700195 + - -0.81347656 + - -33.90625 + - -2.5253906 + - -2.4140625 + - -0.39160156 + - -1.4277344 + - 2.0917969 + - 2.4101563 + - -4.7539063 + - -4.6601563 + - -0.90478516 + - 1.1181641 + - -1.4375 + - -1.0966797 + - 6.78125 + - 0.48706055 + - 4.7304688 + - -1.6582031 + - 4.3242188 + - -0.24768066 + - -1.4345703 + - 0.11437988 + - -0.453125 + - 1.0810547 + - 1.8134766 + - -0.4345703 + - -4.015625 + - -1.2519531 + - 0.05355835 + - 1.8691406 + - -0.36376953 + - 0.57177734 + - -1.2675781 + - 0.36206055 + - -0.5605469 + - -3.4941406 + - 4.8632813 + - -3.3027344 + - -0.8066406 + - -2.328125 + - -3.4863281 + - 0.029846191 + - 1.9746094 + - 2.6289063 + - 0.015411377 + - 0.25048828 + - 1.7070313 + - 4 + - -0.63671875 + - 1.9033203 + - -2.8378906 + - 2.6796875 + - -1.0927734 + - 0.2626953 + - -3.921875 + - 3.0117188 + - 2.6113281 + - -2.96875 + - 3.4550781 + - 2.6816406 + - 0.6640625 + - -1.0654297 + - -4.015625 + - 3.0058594 + - 1.3544922 + - 1.5175781 + - -0.38891602 + - 0.040161133 + - -5.0078125 + - 0.82666016 + - 1.3818359 + - -2.2207031 + - 0.7763672 + - 2.6074219 + - 0.4038086 + - -0.56103516 + - 2.2050781 + - -1.3994141 + - -2.6972656 + - 0.80566406 + - 0.42236328 + - -1.2441406 + - 2.0898438 + - 0.46972656 + - 1.0478516 + - 3.0527344 + - 0.8486328 + - -1.28125 + - 1.1132813 + - 2.0488281 + - 0.74658203 + - -2.3789063 + - 2.7949219 + - -1.0380859 + - 8.5703125 + - -1.4736328 + - 2.0292969 + - -0.59472656 + - -0.88183594 + - -0.4428711 + - -0.6660156 + - 2.8222656 + - 0.04714966 + - 3.53125 + - 1.0810547 + - 2.1230469 + - -2.1484375 + - -2.4238281 + - 3.5800781 + - -0.16760254 + - 5.9179688 + - -1.0576172 + - 5.9179688 + - -2.0292969 + - -0.9536133 + - -1.4013672 + - 1.5 + - 0.38745117 + - 0.7910156 + - -1.5820313 + - 4.1210938 + - 2.96875 + - 2.4902344 + - 4.6875 + - -0.7207031 + - -2.0996094 + - 1.7158203 + - -1.4609375 + - -4.0703125 + - -3.109375 + - 0.45117188 + - -4.3554688 + - -0.16455078 + - 1.7939453 + - 3.7363281 + - -1.1025391 + - -0.6791992 + - -30.3125 + - -0.8564453 + - -0.026504517 + - -0.66748047 + - 0.76416016 + - 3.5742188 + - 0.79296875 + - 1.8681641 + - 0.12719727 + - 2.0957031 + - 0.010040283 + - -0.14733887 + - -2.9140625 + - -2.2050781 + - 1.3681641 + - -2.3769531 + - 0.5546875 + - 0.07476807 + - -0.63378906 + - -1.5576172 + - 1.4462891 + - 10.890625 + - 3.125 + - -1.2587891 + - 1.1845703 + - 0.9394531 + - -0.8461914 + - 2.3105469 + - 0.3803711 + - -2.6035156 + - 1.2958984 + - 0.2529297 + - -2.2011719 + - 0.34106445 + - 0.37817383 + - -2.0605469 + - -3.2304688 + - 0.1685791 + - -0.5493164 + - -1.9033203 + - 5.6289063 + - 1.6601563 + - -1.2236328 + - 3.1679688 + - 1.0351563 + - 1.2753906 + - 0.0011701584 + - 3.140625 + - 0.6459961 + - -1.7978516 + - 0.19299316 + - 3.5117188 + - -2.3925781 + - 2.4589844 + - -1.5361328 + - -2.0097656 + - -0.9711914 + - 4.3320313 + - 0.4501953 + - -4.078125 + - 1.640625 + - -0.49487305 + - -0.68310547 + - -1.8125 + - -2.5019531 + - 0.07867432 + - -3.75 + - 0.7373047 + - 3.0117188 + - -6.9453125 + - 0.48876953 + - -1.3125 + - -3.3691406 + - -3.015625 + - 1.7744141 + - -0.86816406 + - -3.1210938 + - 0.06555176 + - 0.18383789 + - -0.3972168 + - -1.3349609 + - -0.6455078 + - 1.8955078 + - 1.7519531 + - 6.6796875 + - -1.4863281 + - -0.46948242 + - -1.2734375 + - -1.8232422 + - 2.0605469 + - -1.9619141 + - -0.69970703 + - 2.0683594 + - 0.15258789 + - 3.4492188 + - 0.89160156 + - 0.92285156 + - -1.0654297 + - 3.0019531 + - -0.6899414 + - 1.6308594 + - 0.5473633 + - -2.7011719 + - -1.1396484 + - 0.41479492 + - -0.5834961 + - -0.2142334 + - 4.5625 + - 1.4414063 + - -0.11456299 + - -1.6738281 + - 4.5039063 + - -0.5004883 + - 2.0371094 + - -2.7578125 + - -1.890625 + - 2.1015625 + - 2.5175781 + - -0.82128906 + - 0.8779297 + - 1.6621094 + - -1.1992188 + - -1.9658203 + - -1.2460938 + - 0.078125 + - -0.46875 + - -4.9023438 + - 0.04547119 + - -1.0234375 + - 3.3046875 + - 0.24829102 + - 0.66259766 + - -0.42407227 + - -0.1274414 + - 1.1132813 + - -0.35083008 + - -0.6723633 + - -0.47094727 + - -1.1416016 + - -4.4179688 + - 0.76953125 + - 4.2070313 + - 0.11364746 + - 1.3613281 + - 1.8681641 + - 0.6166992 + - 3.90625 + - -1.5507813 + - 0.046295166 + - 2.2636719 + - 2.2480469 + - 2.8027344 + - -1.9775391 + - 1.8564453 + - -1.6806641 + - 1.6044922 + - -2.3652344 + - 0.18908691 + - 1.0859375 + - 2.8300781 + - -0.6635742 + - 2.6914063 + - 2.7792969 + - 1.3203125 + - 2.5488281 + - -2.40625 + - 4.4882813 + - -2.4199219 + - -0.5385742 + - 1.7001953 + - -0.63720703 + - -2.5058594 + - 1.7324219 + - 0.103759766 + - -2.2871094 + - -1.5810547 + - -1.5009766 + - -1.6982422 + - -2.875 + - 3.1425781 + - 1.8691406 + - 1.7539063 + - -2.7480469 + - -0.32080078 + - -0.13049316 + - 2.4902344 + - 0.33203125 + - 2.4160156 + - -3.0175781 + - -0.18688965 + - 0.44848633 + - 1.0439453 + - 0.171875 + - 4.0351563 + - -0.09259033 + - 1.421875 + - -0.7915039 + - -1.9824219 + - -0.921875 + - 1.3632813 + - 1.0478516 + - 0.6333008 + - 1.2431641 + - -3.453125 + - 0.17626953 + - 1.7451172 + - 0.6254883 + - -0.36523438 + - 1.5126953 + - -1.1552734 + - -2.4199219 + - -5.5390625 + - -4.0976563 + - 6.078125 + - -1.3671875 + - -0.9116211 + - 1.2001953 + - -1.7539063 + - 2.0761719 + - -1.6425781 + - -2.3925781 + - -3.8867188 + - -2.203125 + - -2.640625 + - 0.74072266 + - 0.27661133 + - 1.4482422 + - -0.7949219 + - -1.1552734 + - 0.75683594 + - 0.123291016 + - -3.5039063 + - -1.7607422 + - -1.4736328 + - 3.1015625 + - 2.0839844 + - 6.2890625 + - -0.44213867 + - 2.5195313 + - -1.7119141 + - 1.8369141 diff --git a/backends/candle/tests/test_bert.rs b/backends/candle/tests/test_bert.rs index 45d02577..1bd5017f 100644 --- a/backends/candle/tests/test_bert.rs +++ b/backends/candle/tests/test_bert.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -28,10 +28,10 @@ fn test_mini() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_batch", embeddings_batch, &matcher); let input_single = batch( @@ -41,7 +41,7 @@ fn test_mini() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); @@ -57,8 +57,8 @@ fn test_mini() -> Result<()> { ); let (pooled_embeddings, raw_embeddings) = sort_embeddings(backend.embed(input_batch)?); - let pooled_embeddings = SnapshotScores::from(pooled_embeddings); - let raw_embeddings = SnapshotScores::from(raw_embeddings); + let pooled_embeddings = SnapshotEmbeddings::from(pooled_embeddings); + let raw_embeddings = SnapshotEmbeddings::from(raw_embeddings); assert_eq!(embeddings_batch[0], pooled_embeddings[0]); assert_eq!(raw_embeddings.len(), 8); @@ -91,13 +91,13 @@ fn test_mini_pooled_raw() -> Result<()> { [1, 4, 5].to_vec(), ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, raw_embeddings) = sort_embeddings(backend.embed(input_batch)?); - let pooled_embeddings_batch = SnapshotScores::from(pooled_embeddings); + let pooled_embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_batch_pooled", pooled_embeddings_batch, &matcher); - let raw_embeddings_batch = SnapshotScores::from(raw_embeddings); + let raw_embeddings_batch = SnapshotEmbeddings::from(raw_embeddings); insta::assert_yaml_snapshot!("mini_batch_raw", raw_embeddings_batch, &matcher); // Check that the first token of each raw embeddings member is the same as the cls pooling ones @@ -113,7 +113,7 @@ fn test_mini_pooled_raw() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_single_pooled", embeddings_single, &matcher); assert_eq!(pooled_embeddings_batch[0], embeddings_single[0]); @@ -126,7 +126,7 @@ fn test_mini_pooled_raw() -> Result<()> { ); let (_, raw_embeddings) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(raw_embeddings); + let embeddings_single = SnapshotEmbeddings::from(raw_embeddings); insta::assert_yaml_snapshot!("mini_single_raw", embeddings_single, &matcher); assert_eq!(raw_embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs index 1888a32b..ea150e7f 100644 --- a/backends/candle/tests/test_flash_bert.rs +++ b/backends/candle/tests/test_flash_bert.rs @@ -2,9 +2,9 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -34,10 +34,10 @@ fn test_flash_mini() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_batch", embeddings_batch, &matcher); let input_single = batch( @@ -47,7 +47,7 @@ fn test_flash_mini() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); @@ -63,8 +63,8 @@ fn test_flash_mini() -> Result<()> { ); let (pooled_embeddings, raw_embeddings) = sort_embeddings(backend.embed(input_batch)?); - let pooled_embeddings = SnapshotScores::from(pooled_embeddings); - let raw_embeddings = SnapshotScores::from(raw_embeddings); + let pooled_embeddings = SnapshotEmbeddings::from(pooled_embeddings); + let raw_embeddings = SnapshotEmbeddings::from(raw_embeddings); assert_eq!(embeddings_batch[0], pooled_embeddings[0]); assert_eq!(raw_embeddings.len(), 8); @@ -101,13 +101,13 @@ fn test_flash_mini_pooled_raw() -> Result<()> { [1, 4, 5].to_vec(), ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, raw_embeddings) = sort_embeddings(backend.embed(input_batch)?); - let pooled_embeddings_batch = SnapshotScores::from(pooled_embeddings); + let pooled_embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_batch_pooled", pooled_embeddings_batch, &matcher); - let raw_embeddings_batch = SnapshotScores::from(raw_embeddings); + let raw_embeddings_batch = SnapshotEmbeddings::from(raw_embeddings); insta::assert_yaml_snapshot!("mini_batch_raw", raw_embeddings_batch, &matcher); // Check that the first token of each raw embeddings member is the same as the cls pooling ones @@ -123,7 +123,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("mini_single_pooled", embeddings_single, &matcher); assert_eq!(pooled_embeddings_batch[0], embeddings_single[0]); @@ -136,7 +136,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> { ); let (_, raw_embeddings) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(raw_embeddings); + let embeddings_single = SnapshotEmbeddings::from(raw_embeddings); insta::assert_yaml_snapshot!("mini_single_raw", embeddings_single, &matcher); assert_eq!(raw_embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index 4a5f8276..255b82a2 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -30,10 +30,10 @@ fn test_flash_jina_small() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_batch", embeddings_batch, &matcher); let input_single = batch( @@ -43,7 +43,7 @@ fn test_flash_jina_small() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs index 508bf722..d84848dc 100644 --- a/backends/candle/tests/test_flash_jina_code.rs +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -30,10 +30,10 @@ fn test_flash_jina_code_base() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_code_batch", embeddings_batch, &matcher); let input_single = batch( @@ -43,7 +43,7 @@ fn test_flash_jina_code_base() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_code_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_flash_mistral.rs b/backends/candle/tests/test_flash_mistral.rs new file mode 100644 index 00000000..71749c8b --- /dev/null +++ b/backends/candle/tests/test_flash_mistral.rs @@ -0,0 +1,53 @@ +#![allow(dead_code, unused_imports)] +mod common; + +use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use anyhow::Result; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_core::{Backend, ModelType, Pool}; + +#[test] +#[serial_test::serial] +#[cfg(all(feature = "cuda", feature = "flash-attn"))] +fn test_flash_mistral() -> Result<()> { + let model_root = download_artifacts("Salesforce/SFR-Embedding-2_R", None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new( + model_root, + "float16".to_string(), + ModelType::Embedding(Pool::Mean), + )?; + + let input_batch = batch( + vec![ + tokenizer.encode("What is Deep Learning?", true).unwrap(), + tokenizer.encode("Deep Learning is...", true).unwrap(), + tokenizer.encode("What is Deep Learning?", true).unwrap(), + ], + [0, 1, 2].to_vec(), + vec![], + ); + + let matcher = cosine_matcher(); + + let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); + insta::assert_yaml_snapshot!("mistral_batch", embeddings_batch, &matcher); + + let input_single = batch( + vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], + [0].to_vec(), + vec![], + ); + + let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); + + insta::assert_yaml_snapshot!("mistral_single", embeddings_single, &matcher); + assert_eq!(embeddings_batch[0], embeddings_single[0]); + assert_eq!(embeddings_batch[2], embeddings_single[0]); + + Ok(()) +} diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs index 3e9b6e1d..263bbe43 100644 --- a/backends/candle/tests/test_flash_nomic.rs +++ b/backends/candle/tests/test_flash_nomic.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -30,10 +30,10 @@ fn test_flash_nomic_small() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("nomic_batch", embeddings_batch, &matcher); let input_single = batch( @@ -43,7 +43,7 @@ fn test_flash_nomic_small() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("nomic_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 4cd7bba6..4aa30d03 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -27,10 +27,10 @@ fn test_jina_small() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_batch", embeddings_batch, &matcher); let input_single = batch( @@ -40,7 +40,7 @@ fn test_jina_small() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs index 70248e1a..6c3b3f20 100644 --- a/backends/candle/tests/test_jina_code.rs +++ b/backends/candle/tests/test_jina_code.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -27,10 +27,10 @@ fn test_jina_code_base() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_code_batch", embeddings_batch, &matcher); let input_single = batch( @@ -40,7 +40,7 @@ fn test_jina_code_base() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("jina_code_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs index 914be7ea..ce0a4559 100644 --- a/backends/candle/tests/test_nomic.rs +++ b/backends/candle/tests/test_nomic.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -27,10 +27,10 @@ fn test_nomic_small() -> Result<()> { vec![], ); - let matcher = relative_matcher(); + let matcher = cosine_matcher(); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); - let embeddings_batch = SnapshotScores::from(pooled_embeddings); + let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("nomic_batch", embeddings_batch, &matcher); let input_single = batch( @@ -40,7 +40,7 @@ fn test_nomic_small() -> Result<()> { ); let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); - let embeddings_single = SnapshotScores::from(pooled_embeddings); + let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); insta::assert_yaml_snapshot!("nomic_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 06cef3ed..932c0083 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -63,6 +63,8 @@ pub enum Pool { /// This option is only available if the loaded model is a `ForMaskedLM` Transformer /// model. Splade, + /// Select the last token as embedding + LastToken, } impl fmt::Display for Pool { @@ -71,6 +73,7 @@ impl fmt::Display for Pool { Pool::Cls => write!(f, "cls"), Pool::Mean => write!(f, "mean"), Pool::Splade => write!(f, "splade"), + Pool::LastToken => write!(f, "last_token"), } } } diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 220e05cb..9b5d1762 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -1,5 +1,6 @@ mod dtype; +use std::cmp::{max, min}; use std::path::PathBuf; use std::sync::Arc; use std::thread::JoinHandle; @@ -67,6 +68,62 @@ impl Backend { }) } + #[instrument(skip(self))] + pub async fn warmup( + &self, + max_input_length: usize, + max_batch_tokens: usize, + max_batch_requests: Option, + ) -> Result<(), BackendError> { + let mut input_ids = Vec::with_capacity(max_batch_tokens); + let mut token_type_ids = Vec::with_capacity(max_batch_tokens); + let mut position_ids = Vec::with_capacity(max_batch_tokens); + + let mut cumulative_seq_lengths = vec![0]; + let mut pooled_indices = Vec::new(); + + let mut i = 0_u32; + let mut remaining = max_batch_tokens; + let mut cumulative_length = 0; + let mut max_length = 0; + + while remaining > 0 { + let request_length = min(remaining, max_input_length); + cumulative_length += request_length; + max_length = max(max_length, request_length as u32); + + input_ids.extend(vec![0; request_length]); + token_type_ids.extend(vec![0; request_length]); + position_ids.extend((0..request_length as u32).collect::>()); + + cumulative_seq_lengths.push(cumulative_length as u32); + pooled_indices.push(i); + + i += 1; + remaining = remaining.saturating_sub(max_input_length); + if let Some(max_batch_requests) = &max_batch_requests { + if i as usize == *max_batch_requests { + break; + } + } + } + + let batch = Batch { + input_ids, + token_type_ids, + position_ids, + cumulative_seq_lengths, + max_length, + pooled_indices, + raw_indices: vec![], + }; + + match &self.model_type { + ModelType::Classifier => self.predict(batch).await.map(|_| ()), + ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), + } + } + #[instrument(skip(self))] pub async fn health(&self) -> Result<(), BackendError> { if *self.health_receiver.borrow() { diff --git a/core/Cargo.toml b/core/Cargo.toml index d2f23eb0..d69871dc 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -9,6 +9,7 @@ homepage.workspace = true async-channel = "^2.3" hf-hub = { workspace = true } metrics = { workspace = true } +serde_json = { workspace = true } text-embeddings-backend = { path = "../backends" } thiserror = { workspace = true } tokenizers = { workspace = true } diff --git a/core/src/download.rs b/core/src/download.rs index 6cc60472..24dc041f 100644 --- a/core/src/download.rs +++ b/core/src/download.rs @@ -19,20 +19,22 @@ pub async fn download_artifacts(api: &ApiRepo) -> Result { tracing::info!("Starting download"); + tracing::info!("Downloading `config.json`"); api.get("config.json").await?; + + tracing::info!("Downloading `tokenizer.json`"); api.get("tokenizer.json").await?; - let model_root = match api.get("model.safetensors").await { + let model_files = match download_safetensors(api).await { Ok(p) => p, Err(_) => { + tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); + tracing::info!("Downloading `pytorch_model.bin`"); let p = api.get("pytorch_model.bin").await?; - tracing::warn!("`model.safetensors` not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); - p + vec![p] } - } - .parent() - .unwrap() - .to_path_buf(); + }; + let model_root = model_files[0].parent().unwrap().to_path_buf(); tracing::info!("Model artifacts downloaded in {:?}", start.elapsed()); Ok(model_root) @@ -40,10 +42,50 @@ pub async fn download_artifacts(api: &ApiRepo) -> Result { #[instrument(skip_all)] pub async fn download_pool_config(api: &ApiRepo) -> Result { + tracing::info!("Downloading `1_Pooling/config.json`"); let pool_config_path = api.get("1_Pooling/config.json").await?; Ok(pool_config_path) } +async fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { + // Single file + tracing::info!("Downloading `model.safetensors`"); + match api.get("model.safetensors").await { + Ok(p) => return Ok(vec![p]), + Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err), + }; + + // Sharded weights + // Download and parse index file + tracing::info!("Downloading `model.safetensors.index.json`"); + let index_file = api.get("model.safetensors.index.json").await?; + let index_file_string: String = + std::fs::read_to_string(index_file).expect("model.safetensors.index.json is corrupted"); + let json: serde_json::Value = serde_json::from_str(&index_file_string) + .expect("model.safetensors.index.json is corrupted"); + + let weight_map = match json.get("weight_map") { + Some(serde_json::Value::Object(map)) => map, + _ => panic!("model.safetensors.index.json is corrupted"), + }; + + let mut safetensors_filenames = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_filenames.insert(file.to_string()); + } + } + + // Download weight files + let mut safetensors_files = Vec::new(); + for n in safetensors_filenames { + tracing::info!("Downloading `{}`", n); + safetensors_files.push(api.get(&n).await?); + } + + Ok(safetensors_files) +} + #[instrument(skip_all)] pub async fn download_st_config(api: &ApiRepo) -> Result { // Try default path diff --git a/core/src/infer.rs b/core/src/infer.rs index 66d04e19..7e6a4629 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -494,11 +494,19 @@ async fn batching_task(queue: Queue, notify: Arc, embed_sender: mpsc::Se loop { notify.notified().await; - while let Some(next_batch) = queue.next_batch().await { - embed_sender - .send(next_batch) + { + let mut permit = embed_sender + .reserve() .await .expect("embed receiver was dropped. This is a bug."); + + while let Some(next_batch) = queue.next_batch().await { + permit.send(next_batch); + permit = embed_sender + .reserve() + .await + .expect("embed receiver was dropped. This is a bug."); + } } } } diff --git a/load_tests/load.js b/load_tests/load.js index 86719b25..b3705476 100644 --- a/load_tests/load.js +++ b/load_tests/load.js @@ -27,7 +27,7 @@ export const options = { executor: 'constant-arrival-rate', duration: '30s', preAllocatedVUs: 5000, - rate: 1000, + rate: 10, timeUnit: '1s', gracefulStop: '1s', }, diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index 389e3848..c428e065 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -904,6 +904,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { match &self.info.model_type { ModelType::Classifier(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); + counter.increment(1); let message = "model is not a re-ranker model".to_string(); tracing::error!("{message}"); Err(Status::new(Code::FailedPrecondition, message)) @@ -911,6 +912,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { ModelType::Reranker(_) => Ok(()), ModelType::Embedding(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); + counter.increment(1); let message = "model is not a classifier model".to_string(); tracing::error!("{message}"); Err(Status::new(Code::FailedPrecondition, message)) @@ -1080,6 +1082,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { match &self.info.model_type { ModelType::Classifier(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); + counter.increment(1); let message = "model is not a re-ranker model".to_string(); tracing::error!("{message}"); Err(Status::new(Code::FailedPrecondition, message)) @@ -1087,6 +1090,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { ModelType::Reranker(_) => Ok(()), ModelType::Embedding(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); + counter.increment(1); let message = "model is not a classifier model".to_string(); tracing::error!("{message}"); Err(Status::new(Code::FailedPrecondition, message)) diff --git a/router/src/lib.rs b/router/src/lib.rs index eca9c61f..5c7899ec 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -26,7 +26,7 @@ use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use std::time::{Duration, Instant}; -use text_embeddings_backend::DType; +use text_embeddings_backend::{DType, Pool}; use text_embeddings_core::download::{ download_artifacts, download_pool_config, download_st_config, ST_CONFIG_NAMES, }; @@ -207,6 +207,12 @@ pub async fn run( .await .context("Model backend is not healthy")?; + tracing::info!("Warming up model"); + backend + .warmup(max_input_length, max_batch_tokens, max_batch_requests) + .await + .context("Model backend is not healthy")?; + let max_batch_requests = backend .max_batch_size .map(|s| { @@ -336,13 +342,7 @@ fn get_backend_model_type( let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?; let config: PoolConfig = serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?; - if config.pooling_mode_cls_token { - text_embeddings_backend::Pool::Cls - } else if config.pooling_mode_mean_tokens { - text_embeddings_backend::Pool::Mean - } else { - return Err(anyhow!("Pooling config {config:?} is not supported")); - } + Pool::try_from(config)? } }; Ok(text_embeddings_backend::ModelType::Embedding(pool)) @@ -364,8 +364,25 @@ pub struct ModelConfig { pub struct PoolConfig { pooling_mode_cls_token: bool, pooling_mode_mean_tokens: bool, - pooling_mode_max_tokens: bool, - pooling_mode_mean_sqrt_len_tokens: bool, + #[serde(default)] + pooling_mode_lasttoken: bool, +} + +impl TryFrom for Pool { + type Error = anyhow::Error; + + fn try_from(config: PoolConfig) -> std::result::Result { + if config.pooling_mode_cls_token { + return Ok(Pool::Cls); + } + if config.pooling_mode_mean_tokens { + return Ok(Pool::Mean); + } + if config.pooling_mode_lasttoken { + return Ok(Pool::LastToken); + } + Err(anyhow!("Pooling config {config:?} is not supported")) + } } #[derive(Debug, Deserialize)]