Skip to content

Commit

Permalink
receive new server address from external indexing router, check proto…
Browse files Browse the repository at this point in the history
…col version on first message, handle error messages correctly
  • Loading branch information
var77 committed Sep 16, 2024
1 parent 9db16cc commit 1dfe10f
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ static void BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, ldb_
buildstate->external_socket, &num_added_vectors, &buildstate->index_buffer_size, buildstate->status);
CheckBuildIndexError(buildstate);

uint32 bytes_read = external_index_receive_index_part(
uint32 bytes_read = external_index_receive_all(
buildstate->external_socket, buildstate->index_buffer, USEARCH_HEADER_SIZE, buildstate->status);
CheckBuildIndexError(buildstate);

Expand Down
8 changes: 4 additions & 4 deletions src/hnsw/external_index.c
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,10 @@ void StoreExternalIndex(Relation index,
while(tuples_indexed < num_added_vectors) {
local_progress = 0;

bytes_read = external_index_receive_index_part(external_index_socket,
external_index_data + buffer_position,
EXTERNAL_INDEX_FILE_BUFFER_SIZE - buffer_position,
status);
bytes_read = external_index_receive_all(external_index_socket,
external_index_data + buffer_position,
EXTERNAL_INDEX_FILE_BUFFER_SIZE - buffer_position,
status);
total_bytes_read += bytes_read;

if(status->code != BUILD_INDEX_OK) {
Expand Down
243 changes: 201 additions & 42 deletions src/hnsw/external_index_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <errno.h>
#include <hnsw/build.h>
#include <miscadmin.h>
#include <netdb.h>
#include <string.h>
#include <unistd.h>

static bool is_little_endian()
Expand Down Expand Up @@ -70,11 +71,9 @@ static int connect_with_timeout(int sockfd, const struct sockaddr *addr, socklen
// Set the socket to non-blocking mode
int flags = fcntl(sockfd, F_GETFL, 0);
if(flags == -1) {
perror("fcntl F_GETFL");
return -1;
}
if(fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) {
perror("fcntl F_SETFL");
return -1;
}

Expand Down Expand Up @@ -127,6 +126,60 @@ static int connect_with_timeout(int sockfd, const struct sockaddr *addr, socklen
return 0;
}

/**
 * Block until the socket has data available for reading, waking up
 * periodically so that pending interrupts can be noticed.
 *
 * The descriptor is put into non-blocking mode while waiting and its original
 * flags are restored on every exit path once that switch has succeeded.
 *
 * @param socket_con connection whose `fd` is waited on
 * @param status     receives BUILD_INDEX_FAILED (with a message in
 *                   status->error) on a syscall failure, or
 *                   BUILD_INDEX_INTERRUPT when an interrupt is pending;
 *                   it is not modified when data becomes available and the
 *                   flags are restored successfully
 */
static void wait_for_data(external_index_socket_t *socket_con, BuildIndexStatus *status)
{
    // how long a single select() call may sleep before interrupts are re-checked
    const int poll_interval_sec = 5;
    bool      failed = false;

    // Set the socket to non-blocking mode
    int original_flags = fcntl(socket_con->fd, F_GETFL, 0);
    if(original_flags == -1) {
        status->code = BUILD_INDEX_FAILED;
        snprintf(status->error, BUILD_INDEX_MAX_ERROR_SIZE, "%s", "error getting socket flags");
        return;
    }

    if(fcntl(socket_con->fd, F_SETFL, original_flags | O_NONBLOCK) == -1) {
        status->code = BUILD_INDEX_FAILED;
        snprintf(status->error, BUILD_INDEX_MAX_ERROR_SIZE, "%s", "error setting socket to non-blocking mode");
        return;
    }

    while(1) {
        fd_set         read_fds;
        struct timeval timeout;

        // select() may modify both the fd set and the timeout, so rebuild them every round
        FD_ZERO(&read_fds);
        FD_SET(socket_con->fd, &read_fds);
        timeout.tv_sec = poll_interval_sec;
        timeout.tv_usec = 0;

        int activity = select(socket_con->fd + 1, &read_fds, NULL, NULL, &timeout);

        // EINTR only means a signal arrived while sleeping; that is not a socket
        // failure, so fall through to the interrupt check instead of bailing out
        if(activity < 0 && errno != EINTR) {
            status->code = BUILD_INDEX_FAILED;
            snprintf(status->error, BUILD_INDEX_MAX_ERROR_SIZE, "%s", "select syscall error");
            failed = true;
            break;
        }

        // If socket has data to read (the fd set is only meaningful when select() succeeded)
        if(activity > 0 && FD_ISSET(socket_con->fd, &read_fds)) {
            break;
        }

        // Check for interrupts on each iteration
        if(INTERRUPTS_PENDING_CONDITION()) {
            status->code = BUILD_INDEX_INTERRUPT;
            failed = true;
            break;
        }
    }

    // Restore the socket to blocking mode; an earlier failure takes precedence
    // over a failure to restore the flags
    if(fcntl(socket_con->fd, F_SETFL, original_flags) == -1 && !failed) {
        status->code = BUILD_INDEX_FAILED;
        snprintf(status->error, BUILD_INDEX_MAX_ERROR_SIZE, "%s", "error setting socket to blocking mode");
    }
}

/**
* Check for error received from socket response
* This function will return void setting the corresponding error code and error message
Expand All @@ -142,6 +195,10 @@ static void set_external_index_response_status(external_index_socket_t *socket_c
BuildIndexStatus *status)
{
uint32 hdr;
uint32 err_msg_size = 0;
uint32 bytes_read = 0;
uint32 total_bytes_read = 0;
char recv_error[ BUILD_INDEX_MAX_ERROR_SIZE ];

if(size < 0) {
status->code = BUILD_INDEX_FAILED;
Expand All @@ -160,11 +217,34 @@ static void set_external_index_response_status(external_index_socket_t *socket_c
return;
};

buffer[ size - 1 ] = '\0';
// if we receive EXTERNAL_INDEX_ERR_MSG header
// the server should send err_msg_bytes (uint32) followed by the actual error message
// we will read and check errors here manually to not get stuck into recursion

bytes_read = socket_con->read(socket_con, (char *)&recv_error, sizeof(uint32));

if(bytes_read != sizeof(uint32)) {
status->code = BUILD_INDEX_FAILED;
strncpy(status->error, "external index socket read failed", BUILD_INDEX_MAX_ERROR_SIZE);
return;
}
memcpy(&err_msg_size, recv_error, sizeof(uint32));

while(total_bytes_read < err_msg_size) {
bytes_read
= socket_con->read(socket_con, (char *)&recv_error + total_bytes_read, err_msg_size - total_bytes_read);

if(bytes_read < 0) {
status->code = BUILD_INDEX_FAILED;
strncpy(status->error, "external index socket read failed", BUILD_INDEX_MAX_ERROR_SIZE);
return;
}

total_bytes_read += bytes_read;
}

status->code = BUILD_INDEX_FAILED;
snprintf(
status->error, BUILD_INDEX_MAX_ERROR_SIZE, "external index error: %s", buffer + EXTERNAL_INDEX_MAGIC_MSG_SIZE);
snprintf(status->error, BUILD_INDEX_MAX_ERROR_SIZE, "external index error: %s", (char *)&recv_error);
}

static void set_external_index_request_status(external_index_socket_t *socket_con,
Expand Down Expand Up @@ -207,6 +287,40 @@ static void write_all(
status->code = BUILD_INDEX_OK;
}

/**
 * Keep reading from the socket until `size` bytes have been stored in
 * `result_buf`, a read returns 0 bytes, an interrupt is pending, or the
 * response status check reports something other than BUILD_INDEX_OK.
 *
 * @param socket_con connection to read from
 * @param result_buf destination buffer; must have room for `size` bytes
 * @param size       number of bytes requested
 * @param status     set to BUILD_INDEX_INTERRUPT when an interrupt is pending,
 *                   otherwise updated by set_external_index_response_status()
 * @return number of bytes stored in `result_buf`; this can be smaller than
 *         `size`, so callers are expected to inspect status->code
 */
uint64 external_index_receive_all(external_index_socket_t *socket_con,
                                  char                    *result_buf,
                                  uint64                   size,
                                  BuildIndexStatus        *status)
{
    int64  bytes_read;
    uint64 total_received = 0;

    // start reading index into buffer
    while(total_received < size) {
        bytes_read = socket_con->read(socket_con, result_buf + total_received, size - total_received);

        // Check for CTRL-C interrupts
        if(INTERRUPTS_PENDING_CONDITION()) {
            status->code = BUILD_INDEX_INTERRUPT;
            return total_received;
        }

        // NOTE(review): the status check is handed the start of result_buf rather
        // than the start of the chunk that was just read - confirm this is intended
        set_external_index_response_status(socket_con, result_buf, bytes_read, status);

        if(status->code != BUILD_INDEX_OK) {
            return total_received;
        }

        // a zero-length read ends the loop early
        if(bytes_read == 0) {
            break;
        }

        // negative values were rejected by the status check above; widen to the
        // accumulator's own type instead of narrowing through uint32
        total_received += (uint64)bytes_read;
    }

    return total_received;
}

static void external_index_send_codebook(external_index_socket_t *socket_con,
float *codebook,
uint32 dimensions,
Expand Down Expand Up @@ -243,6 +357,7 @@ external_index_socket_t *create_external_index_session(const char
char port_str[ 5 ];
struct addrinfo *serv_addr, hints = {0};
char init_response[ EXTERNAL_INDEX_INIT_BUFFER_SIZE ] = {0};
int64 bytes_read = 0;

if(!is_little_endian()) {
buildstate->status->code = BUILD_INDEX_FAILED;
Expand Down Expand Up @@ -286,7 +401,7 @@ external_index_socket_t *create_external_index_session(const char

socket_con->fd = client_fd;
hints.ai_socktype = SOCK_STREAM; // TCP socket
snprintf(port_str, 5, "%u", port);
snprintf(port_str, 6, "%u", port);
status = getaddrinfo(host, port_str, &hints, &serv_addr);

if(status != 0) {
Expand Down Expand Up @@ -321,6 +436,84 @@ external_index_socket_t *create_external_index_session(const char
elog(INFO, "successfully connected to external indexing server");
socket_con->init(socket_con);

// receive and check protocol version
bytes_read = socket_con->read(socket_con, (char *)&init_response, EXTERNAL_INDEX_MAGIC_MSG_SIZE);
set_external_index_response_status(socket_con, (char *)init_response, bytes_read, buildstate->status);
if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}
uint32 protocol_version = 0;
memcpy(&protocol_version, init_response, sizeof(uint32));

if(protocol_version != EXTERNAL_INDEX_PROTOCOL_VERSION) {
buildstate->status->code = BUILD_INDEX_FAILED;
snprintf(buildstate->status->error,
BUILD_INDEX_MAX_ERROR_SIZE,
"external index protocol version mismatch - client version: %u, server version: %u",
EXTERNAL_INDEX_PROTOCOL_VERSION,
protocol_version);
return socket_con;
}
// check server type
bytes_read = socket_con->read(socket_con, (char *)&init_response, EXTERNAL_INDEX_MAGIC_MSG_SIZE);
set_external_index_response_status(socket_con, (char *)init_response, bytes_read, buildstate->status);
if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}
uint32 server_type = 0;
memcpy(&server_type, init_response, sizeof(uint32));

if(server_type == EXTERNAL_INDEX_ROUTER_SERVER_TYPE) {
uint32 is_secure = 0;
uint32 address_length = 0;
uint32 port_number = 0;
char address[ 1024 ] = {0};
uint32 get_server_msg = 0x3;

elog(INFO, "receiving new server address from router... (this may take up to 10m)");
memcpy(init_buf, &get_server_msg, sizeof(uint32));
write_all(socket_con, init_buf, sizeof(uint32), 0, buildstate->status);

// wait for data to be available for read and also check for interrupts each 5s
wait_for_data(socket_con, buildstate->status);

if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}

bytes_read = socket_con->read(socket_con, (char *)&init_response, sizeof(uint32));
set_external_index_response_status(socket_con, (char *)init_response, bytes_read, buildstate->status);
if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}
memcpy(&is_secure, init_response, sizeof(uint32));

bytes_read = socket_con->read(socket_con, (char *)&init_response, sizeof(uint32));
set_external_index_response_status(socket_con, (char *)init_response, bytes_read, buildstate->status);
if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}
memcpy(&address_length, init_response, sizeof(uint32));

external_index_receive_all(socket_con, (char *)&address, address_length, buildstate->status);
if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}

bytes_read = socket_con->read(socket_con, (char *)&init_response, sizeof(uint32));
set_external_index_response_status(socket_con, (char *)init_response, bytes_read, buildstate->status);
if(buildstate->status->code != BUILD_INDEX_OK) {
return socket_con;
}
memcpy(&port_number, init_response, sizeof(uint32));

socket_con->close(socket_con);

// connect to new address
return create_external_index_session(
address, port_number, (bool)is_secure, params, buildstate, estimated_row_count);
}

external_index_params_t index_params = {
.pq = params->pq,
.metric_kind = params->metric_kind,
Expand Down Expand Up @@ -358,9 +551,9 @@ external_index_socket_t *create_external_index_session(const char
}
}

int64 buf_size = socket_con->read(socket_con, (char *)&init_response, EXTERNAL_INDEX_INIT_BUFFER_SIZE);
bytes_read = socket_con->read(socket_con, (char *)&init_response, EXTERNAL_INDEX_INIT_BUFFER_SIZE);

set_external_index_response_status(socket_con, (char *)init_response, buf_size, buildstate->status);
set_external_index_response_status(socket_con, (char *)init_response, bytes_read, buildstate->status);

return socket_con;
}
Expand Down Expand Up @@ -407,40 +600,6 @@ void external_index_receive_metadata(external_index_socket_t *socket_con,
memcpy(index_size, buffer, sizeof(uint64));
}

/**
 * Fill `result_buf` from the socket, one read at a time, until `size` bytes
 * have arrived or the transfer stops early.
 *
 * The transfer stops early when an interrupt is pending (status->code becomes
 * BUILD_INDEX_INTERRUPT), when the response status check leaves status->code
 * different from BUILD_INDEX_OK, or when a read yields 0 bytes.
 *
 * @return the number of bytes accumulated in `result_buf` so far
 */
uint64 external_index_receive_index_part(external_index_socket_t *socket_con,
                                         char                    *result_buf,
                                         uint64                   size,
                                         BuildIndexStatus        *status)
{
    uint64 received = 0;

    while(received < size) {
        int64 chunk = socket_con->read(socket_con, result_buf + received, size - received);

        // a pending CTRL-C wins over whatever the read returned
        if(INTERRUPTS_PENDING_CONDITION()) {
            status->code = BUILD_INDEX_INTERRUPT;
            break;
        }

        set_external_index_response_status(socket_con, result_buf, chunk, status);

        // stop on a reported failure or when nothing more was read
        if(status->code != BUILD_INDEX_OK || chunk == 0) {
            break;
        }

        received += (uint32)chunk;
    }

    return received;
}

void external_index_send_tuple(external_index_socket_t *socket_con,
usearch_label_t *label,
void *vector,
Expand Down
27 changes: 15 additions & 12 deletions src/hnsw/external_index_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
#include "external_index_socket_ssl.h"
#include "usearch.h"

#define EXTERNAL_INDEX_MAGIC_MSG_SIZE 4
#define EXTERNAL_INDEX_INIT_MSG 0x13333337
#define EXTERNAL_INDEX_END_MSG 0x31333337
#define EXTERNAL_INDEX_ERR_MSG 0x37333337
#define EXTERNAL_INDEX_INIT_BUFFER_SIZE 1024
#define EXTERNAL_INDEX_FILE_BUFFER_SIZE 1024 * 1024 * 10 // 10MB
#define EXTERNAL_INDEX_SOCKET_TIMEOUT 10 // 10 seconds
// Number of bytes read for a single header word of the external index protocol
#define EXTERNAL_INDEX_MAGIC_MSG_SIZE 4
// Magic header values; presumably INIT opens and END closes a transfer - TODO confirm against the server
#define EXTERNAL_INDEX_INIT_MSG 0x13333337
#define EXTERNAL_INDEX_END_MSG 0x31333337
// Header announcing that the server is sending an error message
#define EXTERNAL_INDEX_ERR_MSG 0x37333337
// Size in bytes of the buffer that holds server responses while a session is set up
#define EXTERNAL_INDEX_INIT_BUFFER_SIZE 1024
// NOTE(review): the expansion is not parenthesized, so it is only safe next to
// operators of lower precedence than `*` (e.g. `X - n`); `n / X` or `n % X` would misparse
#define EXTERNAL_INDEX_FILE_BUFFER_SIZE 1024 * 1024 * 10 // 10MB
#define EXTERNAL_INDEX_SOCKET_TIMEOUT 10 // 10 seconds
#define EXTERNAL_INDEX_ROUTER_SOCKET_TIMEOUT 600 // 10 minutes
// maximum tuple size can be 8kb (8192 byte) + 8 byte label
#define EXTERNAL_INDEX_MAX_TUPLE_SIZE 8200
#define EXTERNAL_INDEX_MAX_TUPLE_SIZE 8200
#define EXTERNAL_INDEX_PROTOCOL_VERSION 1
#define EXTERNAL_INDEX_ROUTER_SERVER_TYPE 0x2

typedef struct external_index_params_t
{
Expand Down Expand Up @@ -66,10 +69,10 @@ void external_index_receive_metadata(external_index_socket_t
uint64 *num_added_vectors,
uint64 *index_size,
BuildIndexStatus *status);
uint64 external_index_receive_index_part(external_index_socket_t *socket_con,
char *result_buf,
uint64 size,
BuildIndexStatus *status);
uint64 external_index_receive_all(external_index_socket_t *socket_con,
char *result_buf,
uint64 size,
BuildIndexStatus *status);
void external_index_send_tuple(external_index_socket_t *socket_con,
usearch_label_t *label,
void *vector,
Expand Down

0 comments on commit 1dfe10f

Please sign in to comment.