diff --git a/.bazelrc b/.bazelrc index de238aa..2f1baa6 100644 --- a/.bazelrc +++ b/.bazelrc @@ -21,11 +21,26 @@ build --enable_platform_specific_config build --cxxopt=-std=c++17 build --host_cxxopt=-std=c++17 +build:avx --copt=-mavx +build:avx --host_copt=-mavx +build:avx --copt=-DCHECK_AVX +build:avx --host_copt=-DCHECK_AVX + +# Binary safety flags +build --copt=-fPIC +build --copt=-fstack-protector-strong +build:linux --copt=-Wl,-z,noexecstack +build:macos --copt=-Wa,--noexecstack + test --keep_going test --test_output=errors test --test_timeout=1800 +# static link runtime libraries on Linux +build:linux-release --action_env=BAZEL_LINKOPTS=-static-libstdc++:-static-libgcc +build:linux-release --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-l%:libgcc.a + # platform specific config # Bazel will automatic pick platform config since we have enable_platform_specific_config set build:macos --copt="-Xpreprocessor -fopenmp" @@ -36,10 +51,8 @@ build:macos --cxxopt -Wno-deprecated-anon-enum-enum-conversion build:macos --macos_minimum_os=11.0 build:macos --host_macos_minimum_os=11.0 -# static link libstdc++ & libgcc on Linux build:linux --copt=-fopenmp -build:linux --action_env=BAZEL_LINKOPTS=-static-libstdc++:-static-libgcc -build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-l%:libgcc.a +build:linux --linkopt=-fopenmp build:asan --strip=never build:asan --copt -fno-sanitize-recover=all diff --git a/.ci/accuracy_test.py b/.ci/accuracy_test.py new file mode 100644 index 0000000..b4ead7b --- /dev/null +++ b/.ci/accuracy_test.py @@ -0,0 +1,317 @@ +#! python3 + +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import getpass +import json +import os +import subprocess +import sys +import tempfile +import time +from dataclasses import dataclass +from typing import Any, Dict, List + +from pyarrow import csv as pa_csv +from google.protobuf.json_format import MessageToJson + +from secretflow_serving_lib.feature_pb2 import FeatureParam + + +TEST_STORAGE_ROOT = os.path.join(tempfile.gettempdir(), getpass.getuser()) + +# set up global env +g_script_name = os.path.abspath(sys.argv[0]) +g_script_dir = os.path.dirname(g_script_name) +g_repo_dir = os.path.dirname(g_script_dir) + +g_clean_up_service = True +g_clean_up_files = True + + +def global_ip_config(index): + cluster_ip = ["127.0.0.1:9810", "127.0.0.1:9811"] + metrics_port = [8318, 8319] + brpc_builtin_port = [8328, 8329] + assert index < len(cluster_ip) + return { + "cluster_ip": cluster_ip[index], + "metrics_port": metrics_port[index], + "brpc_builtin_service_port": brpc_builtin_port[index], + } + + +def dump_json(obj, filename, indent=2): + with open(filename, "w") as ofile: + json.dump(obj, ofile, indent=indent) + + +def is_approximately_equal(a, b, epsilon) -> bool: + return abs(a - b) < epsilon + + +class CurlWrapper: + def __init__(self, url: str, data: str): + self.url = url + self.data = data + + def cmd(self): + return f'curl --location "{self.url}" --header "Content-Type: application/json" --data \'{self.data}\'' + + def exe(self): + return os.popen(self.cmd()) + + +@dataclass +class PartyConfig: + id: str + cluster_ip: str + metrics_port: int + brpc_builtin_service_port: int + channel_protocol: str + 
model_id: str + model_package_path: str + csv_path: str + query_datas: List[str] = None + query_context: str = None + + +class ConfigBuilder: + def __init__( + self, + party_configs: Dict[str, PartyConfig], + service_id: str, + serving_config_filename: str = "serving.config", + log_config_filename: str = "logging.config", + ): + self.service_id = service_id + self.party_configs = party_configs + self.parties = [] + self.log_config = log_config_filename + self.serving_config = serving_config_filename + for id, config in self.party_configs.items(): + self.parties.append({"id": id, "address": config.cluster_ip}) + + self.logging_config_paths = {} + self.serving_config_paths = {} + + def _dump_logging_config(self, path: str, log_path: str): + logging_config_path = os.path.join(path, self.log_config) + with open(logging_config_path, "w") as ofile: + json.dump({"systemLogPath": os.path.abspath(log_path)}, ofile, indent=2) + return logging_config_path + + def _dump_serving_config(self, path: str, config: PartyConfig): + config_dict = { + "id": self.service_id, + "serverConf": { + "metricsExposerPort": config.metrics_port, + "brpcBuiltinServicePort": config.brpc_builtin_service_port, + }, + "modelConf": { + "modelId": config.model_id, + "basePath": os.path.abspath(path), + "sourcePath": os.path.abspath(config.model_package_path), + "sourceType": "ST_FILE", + }, + "clusterConf": { + "selfId": config.id, + "parties": self.parties, + "channelDesc": {"protocol": config.channel_protocol}, + }, + "featureSourceConf": { + "csvOpts": { + "filePath": os.path.abspath(config.csv_path), + "id_name": "id", + } + }, + } + config_path = os.path.join(path, self.serving_config) + dump_json(config_dict, config_path) + return config_path + + def finish(self, path="."): + for id, config in self.party_configs.items(): + config_path = os.path.join(path, id) + if not os.path.exists(config_path): + os.makedirs(config_path, exist_ok=True) + self.logging_config_paths[id] = 
self._dump_logging_config( + config_path, os.path.join(config_path, "log") + ) + self.serving_config_paths[id] = self._dump_serving_config( + config_path, config + ) + + def get_logging_config_paths(self) -> Dict[str, str]: + return self.logging_config_paths + + def get_serving_config_paths(self) -> Dict[str, str]: + return self.serving_config_paths + + +# simple test +class AccuracyTestCase: + def __init__( + self, + service_id: str, + parties: List[str], + case_dir: str, + package_name: str, + input_csv_names: Dict[str, str], + expect_csv_name: str, + query_ids: List[str], + score_col_name: str, + ): + self.service_id = service_id + self.case_dir = case_dir + self.expect_csv_path = os.path.join(case_dir, expect_csv_name) + self.query_ids = query_ids + self.score_col_name = score_col_name + self.party_configs = {} + for index, party in enumerate(parties): + base_dir = os.path.join(case_dir, party) + self.party_configs[party] = PartyConfig( + id=party, + **global_ip_config(index), + channel_protocol="baidu_std", + model_id="accuracy_test", + model_package_path=os.path.join(base_dir, package_name), + csv_path=os.path.join(base_dir, input_csv_names[party]), + query_datas=query_ids, + ) + self.background_proc = [] + self.data_path = os.path.join(TEST_STORAGE_ROOT, self.service_id) + + def _exe_cmd(self, cmd, background=False): + print("Execute: ", cmd) + if not background: + ret = subprocess.run(cmd, shell=True, check=True, capture_output=True) + ret.check_returncode() + return ret + else: + proc = subprocess.Popen(cmd.split(), shell=False) + self.background_proc.append(proc) + return proc + + def start_server(self, start_interval_s=0): + config_builder = ConfigBuilder( + party_configs=self.party_configs, service_id=self.service_id + ) + config_builder.finish(self.data_path) + + logging_config_paths = config_builder.get_logging_config_paths() + serving_config_paths = config_builder.get_serving_config_paths() + + for id in self.party_configs.keys(): + self._exe_cmd( + 
f"./bazel-bin/secretflow_serving/server/secretflow_serving --serving_config_file={serving_config_paths[id]} --logging_config_file={logging_config_paths[id]}", + True, + ) + if start_interval_s: + time.sleep(start_interval_s) + + # wait 10s for servers be ready + time.sleep(10) + + def stop_server(self): + if g_clean_up_service: + for proc in self.background_proc: + proc.kill() + proc.wait() + if g_clean_up_files: + os.system(f"rm -rf {self.data_path}") + + def _make_request_body(self): + fs_param = {} + for id, config in self.party_configs.items(): + fs_param[id] = FeatureParam(query_datas=self.query_ids) + + body_dict = { + "service_spec": { + "id": self.service_id, + }, + "fs_params": { + k: json.loads(MessageToJson(v, preserving_proto_field_name=True)) + for k, v in fs_param.items() + }, + } + + return json.dumps(body_dict) + + def make_predict_curl_cmd(self, party: str): + url = None + for id, config in self.party_configs.items(): + if id == party: + url = f"http://{config.cluster_ip}/PredictionService/Predict" + break + if not url: + raise Exception(f"{party} is not in config({self.party_configs.keys()})") + curl_wrapper = CurlWrapper( + url=url, + data=self._make_request_body(), + ) + return curl_wrapper.cmd() + + def exec(self): + try: + self.start_server() + + # read expect csv + expect_table = pa_csv.read_csv(self.expect_csv_path) + expect_df = expect_table.to_pandas() + + print(f"expect_df: {expect_df}") + + for party in self.party_configs.keys(): + res = self._exe_cmd(self.make_predict_curl_cmd(party)) + out = res.stdout.decode() + print("Predict Result: ", out) + res = json.loads(out) + assert ( + res["status"]["code"] == 1 + ), f'return status code({res["status"]["code"]}) should be OK(1)' + assert len(res['results']) == len( + self.query_ids + ), f'return results size not match the query size {len(res["results"])} vs {len(self.query_ids)}' + + # check result accuracy + for index, r in enumerate(res['results']): + # TODO: support multi-col score + s = 
float(r["scores"][0]["value"]) + id = self.query_ids[index] + expect_score = float( + expect_df[expect_df['id'] == int(id)][self.score_col_name].iloc[ + 0 + ] + ) + assert is_approximately_equal( + expect_score, s, 0.0001 + ), f'result not match, {s} vs {expect_score}' + finally: + self.stop_server() + + +if __name__ == "__main__": + AccuracyTestCase( + service_id="bin_onehot_glm", + parties=['alice', 'bob'], + case_dir='.ci/test_data/bin_onehot_glm', + package_name='s_model.tar.gz', + input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'}, + expect_csv_name='predict.csv', + query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'], + score_col_name='pred', + ).exec() diff --git a/.ci/integration_test.py b/.ci/integration_test.py index bbe71b2..beb0295 100644 --- a/.ci/integration_test.py +++ b/.ci/integration_test.py @@ -22,166 +22,58 @@ import subprocess import sys import tarfile +import time from dataclasses import dataclass -from enum import Enum from typing import Any, Dict, List +import pyarrow as pa +from google.protobuf.json_format import MessageToJson + +from secretflow_serving_lib import get_op +from secretflow_serving_lib.attr_pb2 import AttrValue, DoubleList, StringList +from secretflow_serving_lib.bundle_pb2 import FileFormatType, ModelBundle, ModelManifest +from secretflow_serving_lib.feature_pb2 import ( + Feature, + FeatureField, + FeatureParam, + FeatureValue, + FieldType, +) +from secretflow_serving_lib.graph_pb2 import ( + DispatchType, + ExecutionDef, + GraphDef, + NodeDef, + RuntimeConfig, +) +from secretflow_serving_lib.link_function_pb2 import LinkFunctionType + # set up global env g_script_name = os.path.abspath(sys.argv[0]) g_script_dir = os.path.dirname(g_script_name) g_repo_dir = os.path.dirname(g_script_dir) +g_clean_up_service = True +g_clean_up_files = True -class AttrValue: - def __init__(self, value, type=None): - assert value, "value cannot be null" - self.value = value - self.type = None - - def to_json(self): - def 
make_dict(type, value): - if isinstance(self.value, list): - return {self.type: {"data": self.value}} - else: - return {self.type: self.value} - - if self.type: - return make_dict(self.type, self.value) - - def get_attr_prefix(value): - if isinstance(value, int): - return "i32" - elif isinstance(value, float): - return "f" - elif isinstance(value, bool): - return "b" - else: - return "s" - - if isinstance(self.value, list): - assert len(self.value) != 0, "list is None, cannont deduce type" - type = get_attr_prefix(self.value[0]) - return make_dict(type + "s", self.value) - else: - return {get_attr_prefix(self.value): self.value} - - -class MapAttrValue: - def __init__(self, attr_values: Dict[str, AttrValue]): - self.data = attr_values - - def to_json(self): - return {k: v.to_json() for k, v in self.data.items()} - - -class DispatchType(Enum): - UNKNOWN_DP_TYPE = 0 - DP_ALL = 1 - DP_ANYONE = 2 - - -class RuntimeConfig: - def __init__(self, dispatch_type: Enum, session_run: bool): - self.dispatch_type = dispatch_type - self.session_run = session_run - - def to_json(self): - return { - "dispatch_type": self.dispatch_type.name, - "session_run": self.session_run, - } - -class NodeDef: - def __init__(self, name: str, op: str, parents: List[str], data_dict): - self.name = name - self.op = op - self.parents = parents - self.attr_values = trans_normal_dict_to_attr_dict(data_dict) - - def to_json(self): - return { - "name": self.name, - "op": self.op, - "parents": self.parents, - "attr_values": self.attr_values.to_json(), - } - - -class ExecutionDef: - def __init__(self, nodes: List[str], config: RuntimeConfig): - self.nodes = nodes - self.config = config - - def to_json(self): - return {"nodes": self.nodes, "config": self.config.to_json()} - - -class GraphDef: - def __init__( - self, - version: str, - node_list: List[NodeDef] = None, - execution_list: List[ExecutionDef] = None, - ): - self.node_list = node_list if node_list else [] - self.execution_list = execution_list if 
execution_list else [] - self.version = version - - def to_json(self): - return { - "version": self.version, - "node_list": [node.to_json() for node in self.node_list], - "execution_list": [exe.to_json() for exe in self.execution_list], - } - - -class AttrValue: - def __init__(self, value, type=None): - assert value, "value cannot be null" - self.value = value - self.type = None - - def to_json(self): - def make_dict(type, value): - if isinstance(value, list): - return {type: {"data": value}} - else: - return {type: value} - - if self.type: - return make_dict(self.type, self.value) - - def get_attr_prefix(value): - if isinstance(value, int): - return "i32" - elif isinstance(value, float): - return "d" - elif isinstance(value, bool): - return "b" - else: - return "s" - - if isinstance(self.value, list): - assert len(self.value) != 0, "list is None, cannont deduce type" - type = get_attr_prefix(self.value[0]) - return make_dict(type + "s", self.value) - else: - return {get_attr_prefix(self.value): self.value} - - -def trans_normal_dict_to_attr_dict(data_dict): - ret = {} - for k, v in data_dict.items(): - ret[k] = AttrValue(v) - return MapAttrValue(ret) +def global_ip_config(index): + cluster_ip = ["127.0.0.1:9910", "127.0.0.1:9911"] + metrics_port = [10318, 10319] + brpc_builtin_port = [10328, 10329] + assert index < len(cluster_ip) + return { + "cluster_ip": cluster_ip[index], + "metrics_port": metrics_port[index], + "brpc_builtin_service_port": brpc_builtin_port[index], + } -class JsonModel: +class ModelBuilder: def __init__(self, name, desc, graph_def: GraphDef): self.name = name self.desc = desc - self.graph_json = {"name": name, "desc": desc, "graph": graph_def.to_json()} + self.bundle = ModelBundle(name=name, desc=desc, graph=graph_def) def dump_tar_gz(self, path=".", filename=None): if filename is None: @@ -193,12 +85,15 @@ def dump_tar_gz(self, path=".", filename=None): model_graph_filename = "model_graph.json" - dump_json( - {"bundle_path": 
model_graph_filename, "bundle_format": "FF_JSON"}, + # dump manifest + dump_pb_json_file( + ModelManifest( + bundle_path=model_graph_filename, bundle_format=FileFormatType.FF_JSON + ), os.path.join(path, "MANIFEST"), ) - - dump_json(self.graph_json, os.path.join(path, model_graph_filename)) + # dump model file + dump_pb_json_file(self.bundle, os.path.join(path, model_graph_filename)) with tarfile.open(filename, "w:gz") as model_tar: model_tar.add(os.path.join(path, "MANIFEST"), arcname="MANIFEST") @@ -211,107 +106,86 @@ def dump_tar_gz(self, path=".", filename=None): os.remove(os.path.join(path, "MANIFEST")) os.remove(os.path.join(path, model_graph_filename)) with open(filename, "rb") as ifile: - return filename, hashlib.md5(ifile.read()).hexdigest() - - -class DotProcutAttr: - def __init__(self, weight_dict: Dict[str, float], output_col_name, intercept=None): - self.name = "DOT_PRODUCT" - self.version = "0.0.1" - self.desc = "" - self.weight_dict = weight_dict - self.feature_names = list(weight_dict.keys()) - self.feature_weights = list(weight_dict.values()) - self.output_col_name = output_col_name - self.intercept = intercept - - def get_node_attr_map(self): - assert self.feature_names - assert self.feature_weights - assert self.output_col_name - ret = { - "feature_names": self.feature_names, - "feature_weights": self.feature_weights, - "output_col_name": self.output_col_name, - } - if self.intercept: - ret["intercept"] = self.intercept - return ret - - -class LinkFunction(Enum): - LF_LOG = 1 - LF_LOGIT = 2 - LF_INVERSE = 3 - LF_LOGIT_V2 = 4 - LF_RECIPROCAL = 5 - LF_INDENTITY = 6 - LF_SIGMOID_RAW = 11 - LF_SIGMOID_MM1 = 12 - LF_SIGMOID_MM3 = 13 - LF_SIGMOID_GA = 14 - LF_SIGMOID_T1 = 15 - LF_SIGMOID_T3 = 16 - LF_SIGMOID_T5 = 17 - LF_SIGMOID_T7 = 18 - LF_SIGMOID_T9 = 19 - LF_SIGMOID_LS7 = 20 - LF_SIGMOID_SEG3 = 21 - LF_SIGMOID_SEG5 = 22 - LF_SIGMOID_DF = 23 - LF_SIGMOID_SR = 24 - LF_SIGMOID_SEGLS = 25 - - -class MergeYAttr: - def __init__( - self, - link_function: 
LinkFunction, - input_col_name: str, - output_col_name: str, - yhat_scale: float = None, - ): - self.name = "MERGE_Y" - self.version = "0.0.1" - self.desc = "" - self.link_function = link_function - self.yhat_scale = yhat_scale - self.input_col_name = input_col_name - self.output_col_name = output_col_name - - def get_node_attr_map(self): - assert self.link_function - assert self.input_col_name - assert self.output_col_name - ret = { - "link_function": self.link_function.name, - "input_col_name": self.input_col_name, - "output_col_name": self.output_col_name, - } - if self.yhat_scale: - ret["yhat_scale"] = self.yhat_scale - return ret + return filename, hashlib.sha256(ifile.read()).hexdigest() + + +def make_processing_node_def( + name, + parents, + input_schema: pa.Schema, + output_schema: pa.Schema, + trace_content=None, +): + op_def = get_op("ARROW_PROCESSING") + attrs = { + "input_schema_bytes": AttrValue(by=input_schema.serialize().to_pybytes()), + "output_schema_bytes": AttrValue(by=output_schema.serialize().to_pybytes()), + "content_json_flag": AttrValue(b=True), + } + if trace_content: + attrs["trace_content"] = AttrValue(by=trace_content) + + return NodeDef( + name=name, + parents=parents, + op=op_def.name, + attr_values=attrs, + op_version=op_def.version, + ) def make_dot_product_node_def( - name, parents, weight_dict, output_col_name, intercept=None + name, parents, weight_dict, output_col_name, input_types, intercept=None ): - dot_op = DotProcutAttr(weight_dict, output_col_name, intercept) + op_def = get_op("DOT_PRODUCT") + attrs = { + "feature_names": AttrValue(ss=StringList(data=list(weight_dict.keys()))), + "feature_weights": AttrValue(ds=DoubleList(data=list(weight_dict.values()))), + "output_col_name": AttrValue(s=output_col_name), + "input_types": AttrValue(ss=StringList(data=input_types)), + } + if intercept: + attrs["intercept"] = AttrValue(d=intercept) + return NodeDef( - name, dot_op.name, parents=parents, data_dict=dot_op.get_node_attr_map() + 
name=name, + parents=parents, + op=op_def.name, + attr_values=attrs, + op_version=op_def.version, ) def make_merge_y_node_def( name, parents, - link_function: LinkFunction, + link_function: LinkFunctionType, input_col_name: str, output_col_name: str, yhat_scale: float = None, ): - op = MergeYAttr(link_function, input_col_name, output_col_name, yhat_scale) - return NodeDef(name, op.name, parents=parents, data_dict=op.get_node_attr_map()) + op_def = get_op("MERGE_Y") + attrs = { + "link_function": AttrValue(s=LinkFunctionType.Name(link_function)), + "input_col_name": AttrValue(s=input_col_name), + "output_col_name": AttrValue(s=output_col_name), + } + if yhat_scale: + attrs["yhat_scale"] = AttrValue(d=yhat_scale) + + return NodeDef( + name=name, + parents=parents, + op=op_def.name, + attr_values=attrs, + op_version=op_def.version, + ) + + +def dump_pb_json_file(pb_obj, file_name, indent=2): + json_str = MessageToJson(pb_obj) + with open(file_name, "w") as file: + file.write(json_str) def dump_json(obj, filename, indent=2): @@ -357,7 +231,9 @@ def _dump_logging_config(self, path: str, logging_path: str): json.dump({"systemLogPath": os.path.abspath(logging_path)}, ofile, indent=2) def _dump_model_tar_gz(self, path: str, graph_def: GraphDef): - return JsonModel("test_model", "just for test", graph_def).dump_tar_gz( + graph_def_str = MessageToJson(graph_def, preserving_proto_field_name=True) + print(f"graph_def: \n {graph_def_str}") + return ModelBuilder("test_model", "just for test", graph_def).dump_tar_gz( path, self.tar_name ) @@ -382,7 +258,7 @@ def make_csv_config(self, data_dict: Dict[str, List[Any]], path: str): return {"csv_opts": {"file_path": file_path, "id_name": "id"}} def _dump_serving_config( - self, path: str, config: PartyConfig, model_name: str, model_md5: str + self, path: str, config: PartyConfig, model_name: str, model_sha256: str ): config_dict = { "id": self.service_id, @@ -395,7 +271,7 @@ def _dump_serving_config( "modelId": config.model_id, 
"basePath": os.path.abspath(path), "sourcePath": os.path.abspath(model_name), - "sourceMd5": model_md5, + "sourceSha256": model_sha256, "sourceType": "ST_FILE", }, "clusterConf": { @@ -415,10 +291,10 @@ def dump(self, path="."): if not os.path.exists(config_path): os.makedirs(config_path, exist_ok=True) self._dump_logging_config(config_path, os.path.join(config_path, "log")) - model_name, model_md5 = self._dump_model_tar_gz( + model_name, model_sha256 = self._dump_model_tar_gz( config_path, config.graph_def ) - self._dump_serving_config(config_path, config, model_name, model_md5) + self._dump_serving_config(config_path, config, model_name, model_sha256) # for every testcase, there should be a TestConfig instance @@ -430,6 +306,7 @@ def __init__( header_dict: Dict[str, str] = None, service_spec_id: str = None, predefined_features: Dict[str, List[Any]] = None, + predefined_types: Dict[str, str] = None, log_config_name=None, serving_config_name=None, tar_name=None, @@ -437,6 +314,7 @@ def __init__( self.header_dict = header_dict self.service_spec_id = service_spec_id self.predefined_features = predefined_features + self.predefined_types = predefined_types self.model_path = os.path.join(g_script_dir, model_path) self.party_config = party_config self.log_config_name = ( @@ -448,15 +326,6 @@ def __init__( self.tar_name = tar_name if tar_name is not None else "model.tar.gz" self.background_proc = [] - def get_module_path(self): - return self.model_path - - def get_request_loactions(self): - return [ - f"http://{config.cluster_ip}/PredictionService/Predict" - for config in self.party_config - ] - def dump_config(self): ConfigDumper( self.party_config, @@ -483,7 +352,9 @@ def make_request(self): if self.predefined_features: pre_features = [] for name, data_list in self.predefined_features.items(): - pre_features.append(make_feature(name, data_list)) + pre_features.append( + make_feature(name, data_list, self.predefined_types[name]) + ) else: pre_features = None @@ -491,7 
+362,7 @@ def make_request(self): fs_param = {} for config in self.party_config: fs_param[config.id] = FeatureParam( - config.query_datas, config.query_context + query_datas=config.query_datas, query_context=config.query_context ) else: fs_param = None @@ -500,14 +371,16 @@ def make_request(self): self.header_dict, self.service_spec_id, fs_param, pre_features ) - def make_curl_cmd(self, party: str): + def make_predict_curl_cmd(self, party: str): url = None for p_cfg in self.party_config: if p_cfg.id == party: url = f"http://{p_cfg.cluster_ip}/PredictionService/Predict" break if not url: - raise Exception(f"{party} is not in TestConfig({config.get_party_ids()})") + raise Exception( + f"{party} is not in TestConfig({self.config.get_party_ids()})" + ) curl_wrapper = CurlWrapper( url=url, header="Content-Type: application/json", @@ -515,9 +388,24 @@ def make_curl_cmd(self, party: str): ) return curl_wrapper.cmd() - def _exe_cmd(self, cmd, backgroud=False): + def make_get_model_info_curl_cmd(self, party: str): + url = None + for p_cfg in self.party_config: + if p_cfg.id == party: + url = f"http://{p_cfg.cluster_ip}/ModelService/GetModelInfo" + break + if not url: + raise Exception(f"{party} is not in TestConfig({config.get_party_ids()})") + curl_wrapper = CurlWrapper( + url=url, + header="Content-Type: application/json", + data='{}', + ) + return curl_wrapper.cmd() + + def _exe_cmd(self, cmd, background=False): print("Execute: ", cmd) - if not backgroud: + if not background: ret = subprocess.run(cmd, shell=True, check=True, capture_output=True) ret.check_returncode() return ret @@ -527,112 +415,50 @@ def _exe_cmd(self, cmd, backgroud=False): return proc def finish(self): - for proc in self.background_proc: - proc.kill() - proc.wait() - os.system(f"rm -rf {self.model_path}") - - def exe_start_server_scripts(self): - main_process_name = "//secretflow_serving/server:secretflow_serving" - self._exe_cmd(f"bazel build {main_process_name}") - - [ + if g_clean_up_service: + for 
proc in self.background_proc: + proc.kill() + proc.wait() + if g_clean_up_files: + os.system(f"rm -rf {self.model_path}") + + def exe_start_server_scripts(self, start_interval_s=0): + for arg in self.get_server_start_args(): self._exe_cmd( f"./bazel-bin/secretflow_serving/server/secretflow_serving {arg}", True ) - for arg in self.get_server_start_args() - ] - - def exe_curl_request_scripts(self, party: str): - return self._exe_cmd(self.make_curl_cmd(party)) - + if start_interval_s: + time.sleep(start_interval_s) -class FeatureParam: - def __init__(self, query_datas: List[str], query_context: str = None): - self.query_datas = query_datas - self.query_context = query_context + # wait 10s for servers be ready + time.sleep(10) - def to_json(self): - ret = {"query_datas": self.query_datas} - if self.query_context: - ret["query_context"] = self.query_context - return ret + def exe_curl_request_scripts(self, party: str): + return self._exe_cmd(self.make_predict_curl_cmd(party)) + def exe_get_model_info_request_scripts(self, party: str): + return self._exe_cmd(self.make_get_model_info_curl_cmd(party)) -class FieldType(Enum): - FIELD_BOOL = 1 - FIELD_INT32 = 2 - FIELD_INT64 = 3 - FIELD_FLOAT = 4 - FIELD_DOUBLE = 5 - FIELD_STRING = 6 +def make_feature(name: str, value: List[Any], f_type: str): + assert len(value) != 0 -class FeatureValue: - def __init__(self, data_list, types): - self.data_list = data_list - assert types in ["i32s", "i64s", "fs", "ds", "ss", "bs"] - self.types = types + field_type = FieldType.Value(f_type) - def to_json(self): - if self.types: - return {self.types: self.data_list} - assert len(self.data_list) != 0 - if isinstance(self.data_list[0], int): - types = "i32s" - elif isinstance(self.data_list[0], float): - types = "ds" - elif isinstance(self.data_list[0], bool): - types = "bs" - else: - types = "ss" - return {types: self.data_list} - - -def make_feature(name: str, value: List[Any], f_type: FieldType = None): - type_dict = { - FieldType.FIELD_BOOL: 
"bs", - FieldType.FIELD_DOUBLE: "ds", - FieldType.FIELD_FLOAT: "fs", - FieldType.FIELD_INT32: "i32s", - FieldType.FIELD_INT64: "i64s", - FieldType.FIELD_STRING: "ss", - } - if f_type: - if f_type == FieldType.FIELD_BOOL: - value = [bool(v) for v in value] - elif f_type in (FieldType.FIELD_DOUBLE, FieldType.FIELD_FLOAT): - value = [float(v) for v in value] - elif f_type in (FieldType.FIELD_INT32, FieldType.FIELD_INT64): - value = [int(v) for v in value] - else: - value = [str(v) for v in value] - return Feature(name, f_type, FeatureValue(value, type_dict[f_type])) + if field_type == FieldType.FIELD_BOOL: + f_value = FeatureValue(bs=[bool(v) for v in value]) + elif field_type == FieldType.FIELD_FLOAT: + f_value = FeatureValue(fs=[float(v) for v in value]) + elif field_type == FieldType.FIELD_DOUBLE: + f_value = FeatureValue(ds=[float(v) for v in value]) + elif field_type == FieldType.FIELD_INT32: + f_value = FeatureValue(i32s=[int(v) for v in value]) + elif field_type == FieldType.FIELD_INT64: + f_value = FeatureValue(i64s=[int(v) for v in value]) else: - assert len(value) != 0 - if isinstance(value[0], int): - f_type = FieldType.FIELD_INT64 - elif isinstance(value[0], float): - f_type = FieldType.FIELD_DOUBLE - elif isinstance(value[0], bool): - f_type = FieldType.FIELD_BOOL - else: - f_type = FieldType.FIELD_STRING - value = [str(v) for v in value] - return Feature(name, f_type, FeatureValue(value, type_dict[f_type])) + f_value = FeatureValue(ss=[str(v) for v in value]) - -class Feature: - def __init__(self, field_name: str, field_type: FieldType, value: FeatureValue): - self.name = field_name - self.type = field_type - self.value = value - - def to_json(self): - return { - "field": {"name": self.name, "type": self.type.name}, - "value": self.value.to_json(), - } + return Feature(field=FeatureField(name=name, type=field_type), value=f_value) class PredictRequest: @@ -656,10 +482,14 @@ def to_json(self): ret["service_spec"] = {"id": self.service_spec_id} if 
self.party_param_dict: ret["fs_params"] = { - k: v.to_json() for k, v in self.party_param_dict.items() + k: json.loads(MessageToJson(v, preserving_proto_field_name=True)) + for k, v in self.party_param_dict.items() } if self.predefined_feature: - ret["predefined_features"] = [i.to_json() for i in self.predefined_feature] + ret["predefined_features"] = [ + json.loads(MessageToJson(i, preserving_proto_field_name=True)) + for i in self.predefined_feature + ] return json.dumps(ret) @@ -677,89 +507,28 @@ def exe(self): # simple test -def get_example_config(path: str): - dot_node_1 = make_dot_product_node_def( - name="node_dot_product", - parents=[], - weight_dict={"x21": -0.3, "x22": 0.95, "x23": 1.01, "x24": 1.35, "x25": -0.97}, - output_col_name="y", - intercept=1.313, - ) - dot_node_2 = make_dot_product_node_def( - name="node_dot_product", - parents=[], - weight_dict={"x6": -0.53, "x7": 0.92, "x8": -0.72, "x9": 0.146, "x10": -0.07}, - output_col_name="y", - ) - merge_y_node = make_merge_y_node_def( - "node_merge_y", - ["node_dot_product"], - LinkFunction.LF_LOGIT_V2, - input_col_name="y", - output_col_name="score", - yhat_scale=1.2, - ) - execution_1 = ExecutionDef( - ["node_dot_product"], config=RuntimeConfig(DispatchType.DP_ALL, False) - ) - execution_2 = ExecutionDef( - ["node_merge_y"], config=RuntimeConfig(DispatchType.DP_ANYONE, False) - ) +class TestCase: + def __init__(self, path: str): + self.path = path - alice_graph = GraphDef( - "0.0.1", - node_list=[dot_node_1, merge_y_node], - execution_list=[execution_1, execution_2], - ) - bob_graph = GraphDef( - "0.0.1", - node_list=[dot_node_2, merge_y_node], - execution_list=[execution_1, execution_2], - ) + def exec(self): + config = self.get_config(self.path) + try: + self.test(config) + finally: + config.finish() - alice_config = PartyConfig( - id="alice", - feature_mapping={ - "v24": "x24", - "v22": "x22", - "v21": "x21", - "v25": "x25", - "v23": "x23", - }, - cluster_ip="127.0.0.1:9010", - metrics_port=10306, 
- brpc_builtin_service_port=10307, - channel_protocol="baidu_std", - model_id="integration_model", - graph_def=alice_graph, - query_datas=["a"], - ) - bob_config = PartyConfig( - id="bob", - feature_mapping={"v6": "x6", "v7": "x7", "v8": "x8", "v9": "x9", "v10": "x10"}, - cluster_ip="127.0.0.1:9011", - metrics_port=10308, - brpc_builtin_service_port=10309, - channel_protocol="baidu_std", - model_id="integration_model", - graph_def=bob_graph, - query_datas=["b"], - ) - return TestConfig( - path, - service_spec_id="integration_test", - party_config=[alice_config, bob_config], - ) + def test(config: TestConfig): + raise NotImplementedError + def get_config(self, path: str) -> TestConfig: + raise NotImplementedError -def simple_mock_test(): - config = get_example_config("model_path") - try: + +class SimpleTest(TestCase): + def test(self, config: TestConfig): config.dump_config() config.exe_start_server_scripts() - import time - - time.sleep(1) for party in config.get_party_ids(): res = config.exe_curl_request_scripts(party) out = res.stdout.decode() @@ -769,93 +538,282 @@ def simple_mock_test(): assert len(res["results"]) == len( config.party_config[0].query_datas ), f"result rows({len(res['results'])}) not equal to query_data({len(config.party_config[0].query_datas)})" - finally: - config.finish() - - -# predefine_test -def get_predefine_config(path: str): - dot_node_1 = make_dot_product_node_def( - name="node_dot_product", - parents=[], - weight_dict={"x1": 1.0, "x2": 2.0}, - output_col_name="y", - intercept=0, - ) - dot_node_2 = make_dot_product_node_def( - name="node_dot_product", - parents=[], - weight_dict={"x1": -1.0, "x2": -2.0}, - output_col_name="y", - intercept=0, - ) - merge_y_node = make_merge_y_node_def( - "node_merge_y", - ["node_dot_product"], - LinkFunction.LF_INDENTITY, - input_col_name="y", - output_col_name="score", - yhat_scale=1.0, - ) - execution_1 = ExecutionDef( - ["node_dot_product"], config=RuntimeConfig(DispatchType.DP_ALL, False) - ) - 
execution_2 = ExecutionDef( - ["node_merge_y"], config=RuntimeConfig(DispatchType.DP_ANYONE, False) - ) + model_info = config.exe_get_model_info_request_scripts(party) + out = model_info.stdout.decode() + print("Model info: ", out) + + def get_config(self, path: str): + with open(".ci/simple_test/node_processing_alice.json", "rb") as f: + alice_trace_content = f.read() + + processing_node_alice = make_processing_node_def( + name="node_processing", + parents=[], + input_schema=pa.schema( + [ + ('a', pa.int32()), + ('b', pa.float32()), + ('c', pa.utf8()), + ('x21', pa.float64()), + ('x22', pa.float32()), + ('x23', pa.int8()), + ('x24', pa.int16()), + ('x25', pa.int32()), + ] + ), + output_schema=pa.schema( + [ + ('a_0', pa.int64()), + ('a_1', pa.int64()), + ('c_0', pa.int64()), + ('b_0', pa.int64()), + ('x21', pa.float64()), + ('x22', pa.float32()), + ('x23', pa.int8()), + ('x24', pa.int16()), + ('x25', pa.int32()), + ] + ), + trace_content=alice_trace_content, + ) + # bob run dummy node (no trace) + processing_node_bob = make_processing_node_def( + name="node_processing", + parents=[], + input_schema=pa.schema( + [ + ('x6', pa.int64()), + ('x7', pa.uint8()), + ('x8', pa.uint16()), + ('x9', pa.uint32()), + ('x10', pa.uint64()), + ] + ), + output_schema=pa.schema( + [ + ('x6', pa.int64()), + ('x7', pa.uint8()), + ('x8', pa.uint16()), + ('x9', pa.uint32()), + ('x10', pa.uint64()), + ] + ), + ) - alice_graph = GraphDef( - "0.0.1", - node_list=[dot_node_1, merge_y_node], - execution_list=[execution_1, execution_2], - ) - bob_graph = GraphDef( - "0.0.1", - node_list=[dot_node_2, merge_y_node], - execution_list=[execution_1, execution_2], - ) + dot_node_alice = make_dot_product_node_def( + name="node_dot_product", + parents=['node_processing'], + weight_dict={ + "x21": -0.3, + "x22": 0.95, + "x23": 1.01, + "x24": 1.35, + "x25": -0.97, + "a_0": 1.0, + "c_0": 1.0, + "b_0": 1.0, + }, + output_col_name="y", + input_types=[ + "DT_DOUBLE", + "DT_FLOAT", + "DT_INT8", + 
"DT_INT16", + "DT_INT32", + "DT_INT64", + "DT_INT64", + "DT_INT64", + ], + intercept=1.313, + ) + dot_node_bob = make_dot_product_node_def( + name="node_dot_product", + parents=['node_processing'], + weight_dict={ + "x6": -0.53, + "x7": 0.92, + "x8": -0.72, + "x9": 0.146, + "x10": -0.07, + }, + input_types=["DT_INT64", "DT_UINT8", "DT_UINT16", "DT_UINT32", "DT_UINT64"], + output_col_name="y", + ) + merge_y_node = make_merge_y_node_def( + "node_merge_y", + ["node_dot_product"], + LinkFunctionType.LF_LOGIT, + input_col_name="y", + output_col_name="score", + yhat_scale=1.2, + ) + execution_1 = ExecutionDef( + nodes=["node_processing", "node_dot_product"], + config=RuntimeConfig(dispatch_type=DispatchType.DP_ALL, session_run=False), + ) + execution_2 = ExecutionDef( + nodes=["node_merge_y"], + config=RuntimeConfig( + dispatch_type=DispatchType.DP_ANYONE, session_run=False + ), + ) - alice_config = PartyConfig( - id="alice", - feature_mapping={}, - cluster_ip="127.0.0.1:9010", - metrics_port=10306, - brpc_builtin_service_port=10307, - channel_protocol="baidu_std", - model_id="integration_model", - graph_def=alice_graph, - query_datas=["a", "a", "a"], # only length matters - ) - bob_config = PartyConfig( - id="bob", - feature_mapping={}, - cluster_ip="127.0.0.1:9011", - metrics_port=10308, - brpc_builtin_service_port=10309, - channel_protocol="baidu_std", - model_id="integration_model", - graph_def=bob_graph, - query_datas=["a", "a", "a"], # only length matters - ) - return TestConfig( - path, - service_spec_id="integration_test", - party_config=[alice_config, bob_config], - predefined_features={ - "x1": [1.0, 2.0, 3.4], - "x2": [6.0, 7.0, 8.0], - }, - ) + alice_graph = GraphDef( + version="0.0.1", + node_list=[processing_node_alice, dot_node_alice, merge_y_node], + execution_list=[execution_1, execution_2], + ) + bob_graph = GraphDef( + version="0.0.1", + node_list=[processing_node_bob, dot_node_bob, merge_y_node], + execution_list=[execution_1, execution_2], + ) + + 
alice_config = PartyConfig( + id="alice", + feature_mapping={ + "a": "a", + "b": "b", + "c": "c", + "v24": "x24", + "v22": "x22", + "v21": "x21", + "v25": "x25", + "v23": "x23", + }, + **global_ip_config(0), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=alice_graph, + query_datas=["a"], + ) + bob_config = PartyConfig( + id="bob", + feature_mapping={ + "v6": "x6", + "v7": "x7", + "v8": "x8", + "v9": "x9", + "v10": "x10", + }, + **global_ip_config(1), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=bob_graph, + query_datas=["b"], + ) + return TestConfig( + path, + service_spec_id="integration_test", + party_config=[alice_config, bob_config], + ) + + +class PredefinedErrorTest(TestCase): + def get_config(self, path: str) -> TestConfig: + dot_node_alice = make_dot_product_node_def( + name="node_dot_product", + parents=[], + weight_dict={"x1": 1.0, "x2": 2.0, "x3": 3.0}, + input_types=["DT_FLOAT", "DT_DOUBLE", "DT_INT32"], + output_col_name="y", + intercept=0, + ) + dot_node_bob = make_dot_product_node_def( + name="node_dot_product", + parents=[], + weight_dict={"x1": -1.0, "x2": -2.0, "x3": -3.0}, + input_types=["DT_FLOAT", "DT_DOUBLE", "DT_INT32"], + output_col_name="y", + intercept=0, + ) + merge_y_node = make_merge_y_node_def( + "node_merge_y", + ["node_dot_product"], + LinkFunctionType.LF_IDENTITY, + input_col_name="y", + output_col_name="score", + yhat_scale=1.0, + ) + execution_1 = ExecutionDef( + nodes=["node_dot_product"], + config=RuntimeConfig(dispatch_type=DispatchType.DP_ALL, session_run=False), + ) + execution_2 = ExecutionDef( + nodes=["node_merge_y"], + config=RuntimeConfig( + dispatch_type=DispatchType.DP_ANYONE, session_run=False + ), + ) + + alice_graph = GraphDef( + version="0.0.1", + node_list=[dot_node_alice, merge_y_node], + execution_list=[execution_1, execution_2], + ) + bob_graph = GraphDef( + version="0.0.1", + node_list=[dot_node_bob, merge_y_node], + execution_list=[execution_1, 
execution_2], + ) + alice_config = PartyConfig( + id="alice", + feature_mapping={}, + **global_ip_config(0), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=alice_graph, + query_datas=["a", "a", "a"], # only length matters + ) + bob_config = PartyConfig( + id="bob", + feature_mapping={}, + **global_ip_config(1), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=bob_graph, + query_datas=["a", "a", "a"], # only length matters + ) + return TestConfig( + path, + service_spec_id="integration_test", + party_config=[alice_config, bob_config], + predefined_features={ + "x1": [1.0, 2.0, 3.4], + "x2": [6.0, 7.0, 8.0], + "x3": [-9, -10, -11], + }, + predefined_types={ + "x1": "FIELD_FLOAT", + "x2": "FIELD_DOUBLE", + "x3": "FIELD_INT32", + }, + ) -def predefine_test(): - config = get_predefine_config("module_path") - try: + def test(self, config): + new_config = {} + for k, v in config.predefined_features.items(): + v.append(9.9) + new_config[k] = v + config.predefined_features = new_config config.dump_config() config.exe_start_server_scripts() - import time - time.sleep(1) + for party in config.get_party_ids(): + res = config.exe_curl_request_scripts(party) + out = res.stdout.decode() + print("Result: ", out) + res = json.loads(out) + assert ( + res["status"]["code"] != 1 + ), f'return status code({res["status"]["code"]}) should not be OK(1)' + + +class PredefineTest(PredefinedErrorTest): + def test(self, config): + config.dump_config() + config.exe_start_server_scripts() results = [] for party in config.get_party_ids(): res = config.exe_curl_request_scripts(party) @@ -869,103 +827,98 @@ def predefine_test(): config.party_config[0].query_datas ), f"result rows({len(res['results'])}) not equal to query_data({len(config.party_config[0].query_datas)})" results.append(res) - # std::rand in MockAdapater start with same seed at both sides + # std::rand in MockAdapter start with same seed at both sides for a_score, b_score in 
zip(results[0]["results"], results[1]["results"]): assert a_score["scores"][0]["value"] + b_score["scores"][0]["value"] == 0 - finally: - config.finish() - - -# csv test -def get_csv_config(path: str): - dot_node_1 = make_dot_product_node_def( - name="node_dot_product", - parents=[], - weight_dict={"x1": 1.0, "x2": 2.0}, - output_col_name="y", - intercept=0, - ) - dot_node_2 = make_dot_product_node_def( - name="node_dot_product", - parents=[], - weight_dict={"x1": -1.0, "x2": -2.0}, - output_col_name="y", - intercept=0, - ) - merge_y_node = make_merge_y_node_def( - "node_merge_y", - ["node_dot_product"], - LinkFunction.LF_INDENTITY, - input_col_name="y", - output_col_name="score", - yhat_scale=1.0, - ) - execution_1 = ExecutionDef( - ["node_dot_product"], config=RuntimeConfig(DispatchType.DP_ALL, False) - ) - execution_2 = ExecutionDef( - ["node_merge_y"], config=RuntimeConfig(DispatchType.DP_ANYONE, False) - ) - alice_graph = GraphDef( - "0.0.1", - node_list=[dot_node_1, merge_y_node], - execution_list=[execution_1, execution_2], - ) - bob_graph = GraphDef( - "0.0.1", - node_list=[dot_node_2, merge_y_node], - execution_list=[execution_1, execution_2], - ) - alice_config = PartyConfig( - id="alice", - feature_mapping={"v1": "x1", "v2": "x2"}, - cluster_ip="127.0.0.1:9010", - metrics_port=10306, - brpc_builtin_service_port=10307, - channel_protocol="baidu_std", - model_id="integration_model", - graph_def=alice_graph, - query_datas=["a", "b", "c"], # Corresponds to the id column in csv - csv_dict={ - "id": ["a", "b", "c", "d"], - "v1": [1.0, 2.0, 3.0, 4.0], - "v2": [5.0, 6.0, 7.0, 8.0], - }, - ) - bob_config = PartyConfig( - id="bob", - feature_mapping={"vv2": "x2", "vv3": "x1"}, - cluster_ip="127.0.0.1:9011", - metrics_port=10308, - brpc_builtin_service_port=10309, - channel_protocol="baidu_std", - model_id="integration_model", - graph_def=bob_graph, - query_datas=["a", "b", "c"], # Corresponds to the id column in csv - csv_dict={ - "id": ["a", "b", "c"], - "vv3": 
[1.0, 2.0, 3.0], - "vv2": [5.0, 6.0, 7.0], - }, - ) - return TestConfig( - path, - service_spec_id="integration_test", - party_config=[alice_config, bob_config], - ) +class CsvTest(TestCase): + def get_config(self, path: str): + dot_node_alice = make_dot_product_node_def( + name="node_dot_product", + parents=[], + weight_dict={"x1": 1.0, "x2": 2.0}, + input_types=["DT_INT8", "DT_UINT8"], + output_col_name="y", + intercept=0, + ) + dot_node_bob = make_dot_product_node_def( + name="node_dot_product", + parents=[], + weight_dict={"x1": -1.0, "x2": -2.0}, + input_types=["DT_INT16", "DT_UINT16"], + output_col_name="y", + intercept=0, + ) + merge_y_node = make_merge_y_node_def( + "node_merge_y", + ["node_dot_product"], + LinkFunctionType.LF_IDENTITY, + input_col_name="y", + output_col_name="score", + yhat_scale=1.0, + ) + + execution_1 = ExecutionDef( + nodes=["node_dot_product"], + config=RuntimeConfig(dispatch_type=DispatchType.DP_ALL, session_run=False), + ) + execution_2 = ExecutionDef( + nodes=["node_merge_y"], + config=RuntimeConfig( + dispatch_type=DispatchType.DP_ANYONE, session_run=False + ), + ) + + alice_graph = GraphDef( + version="0.0.1", + node_list=[dot_node_alice, merge_y_node], + execution_list=[execution_1, execution_2], + ) + bob_graph = GraphDef( + version="0.0.1", + node_list=[dot_node_bob, merge_y_node], + execution_list=[execution_1, execution_2], + ) + alice_config = PartyConfig( + id="alice", + feature_mapping={"v1": "x1", "v2": "x2"}, + **global_ip_config(0), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=alice_graph, + query_datas=["a", "b", "c"], # Corresponds to the id column in csv + csv_dict={ + "id": ["a", "b", "c", "d"], + "v1": [1, 2, 3, 4], + "v2": [5, 6, 7, 8], + }, + ) + bob_config = PartyConfig( + id="bob", + feature_mapping={"vv2": "x2", "vv3": "x1"}, + **global_ip_config(1), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=bob_graph, + query_datas=["a", "b", "c"], # 
Corresponds to the id column in csv + csv_dict={ + "id": ["a", "b", "c"], + "vv3": [1, 2, 3], + "vv2": [5, 6, 7], + }, + ) + return TestConfig( + path, + service_spec_id="integration_test", + party_config=[alice_config, bob_config], + ) -def csv_test(): - config = get_csv_config("module_path") - try: + def test(self, config): config.dump_config() config.exe_start_server_scripts() - import time - time.sleep(1) - results = [] for party in config.get_party_ids(): res = config.exe_curl_request_scripts(party) out = res.stdout.decode() @@ -979,11 +932,160 @@ def csv_test(): ), f"result rows({len(res['results'])}) not equal to query_data({len(config.party_config[0].query_datas)})" for score in res["results"]: assert score["scores"][0]["value"] == 0, "result should be 0" - finally: - config.finish() + + +class SpecificTest(TestCase): + def get_config(self, path: str): + dot_node_alice = make_dot_product_node_def( + name="node_dot_product", + parents=[], + weight_dict={"x1": 1.0, "x2": 2.0}, + input_types=["DT_DOUBLE", "DT_DOUBLE"], + output_col_name="y", + intercept=0, + ) + dot_node_bob = make_dot_product_node_def( + name="node_dot_product", + parents=[], + weight_dict={"x1": -1.0, "x2": -2.0}, + input_types=["DT_DOUBLE", "DT_DOUBLE"], + output_col_name="y", + intercept=0, + ) + merge_y_node = make_merge_y_node_def( + "node_merge_y", + ["node_dot_product"], + LinkFunctionType.LF_IDENTITY, + input_col_name="y", + output_col_name="score", + yhat_scale=1.0, + ) + dot_node_specific_1 = make_dot_product_node_def( + name="node_dot_product_spec", + parents=["node_merge_y"], + weight_dict={"score": 1.0}, + input_types=["DT_DOUBLE"], + output_col_name="y", + intercept=1234.0, + ) + dot_node_specific_2 = make_dot_product_node_def( + name="node_dot_product_spec", + parents=["node_merge_y"], + input_types=["DT_DOUBLE"], + weight_dict={"score": 1.0}, + output_col_name="y", + intercept=2468.0, + ) + merge_y_node_res = make_merge_y_node_def( + "node_merge_y_res", + 
["node_dot_product_spec"], + LinkFunctionType.LF_IDENTITY, + input_col_name="y", + output_col_name="score", + yhat_scale=1.0, + ) + execution_1 = ExecutionDef( + nodes=["node_dot_product"], + config=RuntimeConfig(dispatch_type=DispatchType.DP_ALL, session_run=False), + ) + execution_2_alice = ExecutionDef( + nodes=["node_merge_y", "node_dot_product_spec", "node_merge_y_res"], + config=RuntimeConfig( + dispatch_type=DispatchType.DP_SPECIFIED, + session_run=False, + specific_flag=True, + ), + ) + + execution_2_bob = ExecutionDef( + nodes=["node_merge_y", "node_dot_product_spec", "node_merge_y_res"], + config=RuntimeConfig( + dispatch_type=DispatchType.DP_SPECIFIED, session_run=False + ), + ) + + alice_graph = GraphDef( + version="0.0.1", + node_list=[ + dot_node_alice, + merge_y_node, + dot_node_specific_1, + merge_y_node_res, + ], + execution_list=[ + execution_1, + execution_2_alice, + ], + ) + bob_graph = GraphDef( + version="0.0.1", + node_list=[ + dot_node_bob, + merge_y_node, + dot_node_specific_2, + merge_y_node_res, + ], + execution_list=[execution_1, execution_2_bob], + ) + + alice_config = PartyConfig( + id="alice", + feature_mapping={"v1": "x1", "v2": "x2"}, + **global_ip_config(0), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=alice_graph, + query_datas=["a", "b", "c"], # Corresponds to the id column in csv + csv_dict={ + "id": ["a", "b", "c", "d"], + "v1": [1.0, 2.0, 3.0, 4.0], + "v2": [5.0, 6.0, 7.0, 8.0], + }, + ) + bob_config = PartyConfig( + id="bob", + feature_mapping={"vv2": "x2", "vv3": "x1"}, + **global_ip_config(1), + channel_protocol="baidu_std", + model_id="integration_model", + graph_def=bob_graph, + query_datas=["a", "b", "c"], # Corresponds to the id column in csv + csv_dict={ + "id": ["a", "b", "c"], + "vv3": [1.0, 2.0, 3.0], + "vv2": [5.0, 6.0, 7.0], + }, + ) + return TestConfig( + path, + service_spec_id="integration_test", + party_config=[alice_config, bob_config], + ) + + def test(self, config): + 
config.dump_config() + config.exe_start_server_scripts(2) + + for party in config.get_party_ids(): + res = config.exe_curl_request_scripts(party) + out = res.stdout.decode() + print("Result: ", out) + res = json.loads(out) + assert ( + res["status"]["code"] == 1 + ), f'return status code({res["status"]["code"]}) is not OK(1)' + assert len(res["results"]) == len( + config.party_config[0].query_datas + ), f"result rows({len(res['results'])}) not equal to query_data({len(config.party_config[0].query_datas)})" + for score in res["results"]: + assert ( + score["scores"][0]["value"] == 1234.0 + ), f'result should be 0, got: {score["scores"][0]["value"]}' if __name__ == "__main__": - simple_mock_test() - predefine_test() - csv_test() + SimpleTest('model_path').exec() + PredefinedErrorTest('model_path').exec() + PredefineTest('model_path').exec() + CsvTest('model_path').exec() + SpecificTest('model_path').exec() diff --git a/.ci/requirements-ci.txt b/.ci/requirements-ci.txt new file mode 100644 index 0000000..453abc6 --- /dev/null +++ b/.ci/requirements-ci.txt @@ -0,0 +1 @@ +pandas==1.5.3 diff --git a/.ci/simple_test/node_processing_alice.json b/.ci/simple_test/node_processing_alice.json new file mode 100644 index 0000000..9fef115 --- /dev/null +++ b/.ci/simple_test/node_processing_alice.json @@ -0,0 +1,432 @@ +{ + "name": "onehot", + "funcTraces": [ + { + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "dataId": 0 + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 1 + } + }, + { + "name": "equal", + "inputs": [ + { + "dataId": 1 + }, + { + "customScalar": { + "i64": "3" + } + } + ], + "output": { + "dataId": 2 + } + }, + { + "name": "equal", + "inputs": [ + { + "dataId": 1 + }, + { + "customScalar": { + "i64": "2" + } + } + ], + "output": { + "dataId": 3 + } + }, + { + "name": "equal", + "inputs": [ + { + "dataId": 1 + }, + { + "customScalar": { + "i64": "1" + } + } + ], + "output": { + "dataId": 4 + } + }, + { + "name": "or", + "inputs": [ + { 
+ "dataId": 3 + }, + { + "dataId": 2 + } + ], + "output": { + "dataId": 5 + } + }, + { + "name": "if_else", + "inputs": [ + { + "dataId": 4 + }, + { + "customScalar": { + "i64": "1" + } + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 6 + } + }, + { + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "dataId": 0 + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 7 + } + }, + { + "name": "if_else", + "inputs": [ + { + "dataId": 5 + }, + { + "customScalar": { + "i64": "1" + } + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 8 + } + }, + { + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "dataId": 7 + }, + { + "customScalar": { + "i64": "2" + } + }, + { + "customScalar": { + "s": "a_0" + } + }, + { + "dataId": 6 + } + ], + "output": { + "dataId": 9 + } + }, + { + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "dataId": 9 + }, + { + "customScalar": { + "i64": "3" + } + }, + { + "customScalar": { + "s": "a_1" + } + }, + { + "dataId": 8 + } + ], + "output": { + "dataId": 10 + } + }, + { + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "dataId": 10 + }, + { + "customScalar": { + "i64": "1" + } + } + ], + "output": { + "dataId": 11 + } + }, + { + "name": "equal", + "inputs": [ + { + "dataId": 11 + }, + { + "customScalar": { + "s": "m" + } + } + ], + "output": { + "dataId": 12 + } + }, + { + "name": "equal", + "inputs": [ + { + "dataId": 11 + }, + { + "customScalar": { + "s": "k" + } + } + ], + "output": { + "dataId": 13 + } + }, + { + "name": "or", + "inputs": [ + { + "dataId": 13 + }, + { + "dataId": 12 + } + ], + "output": { + "dataId": 14 + } + }, + { + "name": "if_else", + "inputs": [ + { + "dataId": 14 + }, + { + "customScalar": { + "i64": "1" + } + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 15 + } + }, + { + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "dataId": 10 + }, + { + "customScalar": { + "i64": "1" + } + } + ], + 
"output": { + "dataId": 16 + } + }, + { + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "dataId": 16 + }, + { + "customScalar": { + "i64": "3" + } + }, + { + "customScalar": { + "s": "c_0" + } + }, + { + "dataId": 15 + } + ], + "output": { + "dataId": 17 + } + }, + { + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "dataId": 17 + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 18 + } + }, + { + "name": "subtract", + "inputs": [ + { + "dataId": 18 + }, + { + "customScalar": { + "d": 1.11 + } + } + ], + "output": { + "dataId": 19 + } + }, + { + "name": "abs", + "inputs": [ + { + "dataId": 19 + } + ], + "output": { + "dataId": 20 + } + }, + { + "name": "less", + "inputs": [ + { + "dataId": 20 + }, + { + "customScalar": { + "d": 1e-07 + } + } + ], + "output": { + "dataId": 21 + } + }, + { + "name": "if_else", + "inputs": [ + { + "dataId": 21 + }, + { + "customScalar": { + "i64": "1" + } + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 22 + } + }, + { + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "dataId": 17 + }, + { + "customScalar": { + "i64": "0" + } + } + ], + "output": { + "dataId": 23 + } + }, + { + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "dataId": 23 + }, + { + "customScalar": { + "i64": "3" + } + }, + { + "customScalar": { + "s": "b_0" + } + }, + { + "dataId": 22 + } + ], + "output": { + "dataId": 24 + } + } + ] +} diff --git a/.ci/test_data/bin_onehot_glm/alice/alice.csv b/.ci/test_data/bin_onehot_glm/alice/alice.csv new file mode 100644 index 0000000..d83612e --- /dev/null +++ b/.ci/test_data/bin_onehot_glm/alice/alice.csv @@ -0,0 +1,21 @@ +id,f1,f2,f3,f4,f5,f6,f7,f8,b1,o1,y +1,-0.5591187559186066,-0.2814808888968485,-0.4310124259272412,0.42698761472064484,0.6996194412496004,-0.8971369921102821,-0.1297223511775294,-0.6458798434413631,0.1396383060036957,D,0.0 
+2,0.17853136775181744,-0.4768767452292948,0.6595296048985997,0.16383724799230648,0.4191750631343425,-0.0553416809559204,-0.835222222053031,0.5381071271563711,0.08999384285997897,C,1.0 +3,0.6188609133556533,0.025633080047547052,0.8300179478963079,-0.45405441885735187,-0.08493499972471619,0.04889614165329914,-0.9523316995614857,-0.05353206715441927,0.2981044873444963,B,1.0 +4,-0.987002480643878,-0.005445413798033316,-0.3038095035035151,0.5789854645875725,-0.6726110640591769,-0.11796517197878131,-0.7913629960788118,0.5252666701315694,0.32159023992596675,B,1.0 +5,0.6116385036656158,-0.8027835368690399,0.9265753301309776,0.2876623562458911,-0.724878081055228,-0.4616345585134092,0.12250598496354348,-0.8694099398481183,0.16167450163600877,C,0.0 +6,0.3962787899764537,0.9626369136931534,-0.4469843026057556,-0.19952894216630912,0.3919110519057767,0.6720178809053261,0.2193264721156456,-0.09114188927646016,0.20676337842801856,B,1.0 +7,-0.31949896696401625,-0.061019155954883164,0.21272681598450438,-0.04834743340794745,-0.41783686135384324,-0.41386938227609527,-0.3882048353908272,-0.39360862010344855,0.018901318801212597,C,1.0 +8,-0.6890410003764369,0.6794623631273697,-0.6109496138280779,0.2063308648785771,-0.7103542356227883,-0.5119420092003868,0.9845747170799963,-0.5513685278692384,0.1172886051599123,D,0.0 +9,0.9144261444135624,0.8286609977949781,0.8594276816242405,-0.1911128016185284,-0.563727053934955,-0.5138533145337412,0.541481543886571,-0.5981519864733384,0.0376931235673269,C,0.0 +10,-0.32681090977474647,-0.2585901571374831,0.1498134434916294,-0.030934281477115855,0.49676503532610927,0.5352868285398893,0.5416854721911608,0.7276956182064944,0.16392625797797383,A,1.0 +11,-0.8145083132397042,-0.17213965884275084,-0.47778498865612185,-0.4422187336180494,0.4706377931123673,-0.27607008470757766,0.39780330740087533,0.6789359372382642,0.31034068728585906,D,0.0 
+12,-0.806567246333072,0.12504945492870867,-0.18573074778559673,0.5359991111551932,-0.9817349707748626,-0.08835765250300809,0.05510025965060583,0.9270163534670308,0.3898825715991485,B,1.0 +13,0.6949887326949196,-0.5574518033675495,-0.7870700872148468,-0.7839266639621547,0.45191502356739766,0.49637655872697906,-0.9167349161493943,-0.46072101353916106,0.3119186144921719,B,0.0 +14,0.20745206273378214,-0.7081545737949029,-0.8552280589465922,-0.4986202792921457,0.15136419388574107,-0.7691661612231564,0.8935597105203195,-0.03929942517629792,0.36079060086816067,A,0.0 +15,0.6142565465487604,-0.47845178018046086,-0.4114963510637599,-0.00545243574944454,0.3566356120332963,-0.5179070466010649,-0.38576033872228077,-0.12374141954173834,0.4730381159140901,D,1.0 +16,0.45946357338763577,0.869516500592707,0.8940309303553251,0.25674205932621663,0.9769683757190519,0.4655966603399595,0.1749397180654313,-0.7330265340634321,0.057493900732114966,A,1.0 +17,0.07245618290940148,0.15828585207480095,0.6058844851447147,-0.7307832385640813,-0.7047890815652045,0.7604535581160456,0.800552689601107,0.5620156263473772,0.18038921596661117,D,1.0 +18,0.9462315279587412,-0.16484385288965275,0.9134186991705504,-0.8007932195395477,0.9007963088557205,-0.2762376803205897,0.08557406326172878,0.2904407156393314,0.39477993406748413,D,0.0 +19,-0.24293124558329304,-0.6951771796317134,0.7522260924861339,-0.894295158530586,-0.14577787102095408,-0.2960307812622618,0.28186485638698366,-0.2308367011118997,0.014778052261847086,C,0.0 +20,0.104081262546454,-0.3402694280442802,0.6262946992127485,-0.8635426536480353,0.6833201820883588,0.011286885685380721,0.4785141057961586,0.9729862354474565,0.08367496429611809,D,0.0 diff --git a/.ci/test_data/bin_onehot_glm/alice/s_model.tar.gz b/.ci/test_data/bin_onehot_glm/alice/s_model.tar.gz new file mode 100644 index 0000000..6330550 Binary files /dev/null and b/.ci/test_data/bin_onehot_glm/alice/s_model.tar.gz differ diff --git a/.ci/test_data/bin_onehot_glm/bob/bob.csv 
b/.ci/test_data/bin_onehot_glm/bob/bob.csv new file mode 100644 index 0000000..e70da7c --- /dev/null +++ b/.ci/test_data/bin_onehot_glm/bob/bob.csv @@ -0,0 +1,20 @@ +id,f9,f10,f11,f12,f13,f14,f15,f16,b2,o2 +1,0.5209858850305862,0.9053351714559352,-0.48894315289504453,0.9829900835655276,0.15858969894616792,-0.6509650422128417,-0.1925996324660748,0.7337426610542273,0.2069654682246781,C +2,0.958857708228386,-0.03544303040194019,0.1405452156194371,0.26566930360269203,-0.07037219203097589,-0.06238026085591453,-0.9576483200594723,0.10816205842994098,0.1731894617440778,A +3,-0.9699719830103539,-0.22871450938326254,0.9948664607223907,-0.22279916938471844,0.25927550014205414,0.3878264311318409,0.27177589619065845,0.6383352653093646,0.25407196544940047,D +4,-0.050829342174382175,0.3012852377213351,-0.6689844077984184,0.40506254578438017,0.8293328989949791,-0.33125222538898,-0.10029643042382363,-0.5102837414762278,0.4975597939138244,A +5,0.03254075321523131,-0.1550161278915756,0.748292942247121,-0.3110546506844072,-0.9357791012167433,-0.6037439666968589,-0.6685984505490399,0.3264405397456147,0.47940341433092076,C +6,-0.20783808110355806,0.25829933171498265,0.5454217424402565,-0.18908102504737956,0.9705429784843074,0.6544709622583322,0.5338931897402897,-0.22007933551686087,0.35156914944455486,B +7,0.42193410537442877,-0.2595318581831192,-0.2605537512100149,0.8343221815816269,0.22255791135656344,0.6519908236107419,-0.8582444787325789,0.9720562883664492,0.402478217150546,A +8,0.9832739317242707,0.16400234650147172,-0.5492110046967817,-0.6174901634775414,0.732487308619052,-0.802784012378206,-0.36245973518487573,-0.508591015050454,0.19336452375546404,D +9,-0.7442956984656484,-0.9874521101295344,0.6841333919922283,-0.7822304746507185,0.40523640971796415,-0.5573698692579074,-0.23805678835787902,0.8964462739632588,0.4198405689280362,D 
+10,0.7492376130689999,0.7977079143275674,-0.7263690468963742,-0.7226107456080437,-0.07819549018956495,-0.9606123576424457,0.21630587326725403,-0.14395471673536475,0.32775602239291934,A +11,-0.6542216900312965,0.07645253125796714,0.11882381987303448,0.8614954588574071,0.27210232282280633,-0.0371962155037604,0.3201004955320794,0.48841165779914575,0.3282744447473197,D +12,0.1910487842929145,-0.7990976891472847,-0.9170751231778567,0.7211491585750245,0.7840706287053596,-0.4455316016020341,-0.26381822731077764,-0.3487559445781614,0.35421728688837884,A +13,0.8835371488089101,0.30837838073400636,0.3764288197101724,0.17726234541250996,-0.3427559189371696,-0.5028957779843914,0.3936751770454463,0.24834986779951707,0.281776736179031,A +14,-0.11182332142106888,-0.46719792116638614,-0.6540867954811456,0.7985565707173228,0.8914160990779798,0.7962924217918101,0.16273208957591456,0.657259185120165,0.4073024579275322,D +15,0.5711869582126345,0.5406490575531178,0.031435391811630575,-0.9773036805505819,0.8514585657857008,0.3584283025293262,0.7959478839608298,0.8963306585117048,0.44468516803316027,A +16,-0.9727600525172746,0.24233982156866274,0.4678844169690879,-0.20917678146571816,-0.16292590796261908,-0.8990527513505742,-0.2930935478104002,0.9876944462530497,0.47849543068945494,C +17,0.6038280170142221,-0.06582223464997417,-0.4675591727611532,0.14531487617290728,0.01644454992371669,0.9715205465257579,-0.010725165298184791,0.7775149649626913,0.28136158103093967,B +18,-0.10776926376059648,-0.1108106381219256,0.30567391234767327,0.011254218913324898,-0.8669148850836634,-0.15379881396552686,0.13669193302481908,0.10165178123069141,0.2809494691187582,D +19,0.9884320455533016,-0.1804968846324253,0.9585057043623681,0.3298875552404499,-0.5876084980995562,-0.7970582716892087,0.22321082010928328,0.3179780628164999,0.1317701920136854,C diff --git a/.ci/test_data/bin_onehot_glm/bob/s_model.tar.gz b/.ci/test_data/bin_onehot_glm/bob/s_model.tar.gz new file mode 100644 index 0000000..aaf545a Binary 
files /dev/null and b/.ci/test_data/bin_onehot_glm/bob/s_model.tar.gz differ diff --git a/.ci/test_data/bin_onehot_glm/predict.csv b/.ci/test_data/bin_onehot_glm/predict.csv new file mode 100644 index 0000000..b6d9de0 --- /dev/null +++ b/.ci/test_data/bin_onehot_glm/predict.csv @@ -0,0 +1,20 @@ +id,pred,y +1, 0.39350164,0.0 +2, 0.8956253,1.0 +3, 0.8731979,1.0 +4, 0.96265507,1.0 +5, 0.16241527,0.0 +6, 0.9365009,1.0 +7, 0.8754473,1.0 +8, 0.48002118,0.0 +9, 0.42156637,0.0 +10, 0.714669,1.0 +11, 0.35087457,0.0 +12, 0.92386246,1.0 +13, 0.38156974,0.0 +14, 0.22711706,0.0 +15, 0.8340757,1.0 +16, 0.6390995,1.0 +17, 0.64777297,1.0 +18, 0.10322996,0.0 +19, 0.14100587,0.0 diff --git a/.circleci/config.yml b/.circleci/config.yml index 2805d63..3b907a4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -71,13 +71,21 @@ jobs: command: git clone https://github.com/secretflow/devtools.git ../devtools - run: name: "build" - command: bazel build //... -c opt --ui_event_filters=-info,-debug,-warning --jobs 16 + command: | + source ~/miniconda3/bin/activate py3.8 + + python3 -m pip install -r requirements.txt + python3 -m pip install -r .ci/requirements-ci.txt + + bazel build //... -c opt --ui_event_filters=-info,-debug,-warning --jobs 16 + + sh ./build_wheel_entrypoint.sh - run: name: "test" command: | set +e declare -i test_status - bazel test //... -c opt --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} + bazel test //secretflow_serving/... 
-c opt --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} sh ../devtools/rename-junit-xml.sh find bazel-bin/ -executable -type f -name "*_test" -print0 | xargs -0 tar -cvzf test_binary.tar.gz @@ -88,8 +96,22 @@ jobs: name: "integration test" command: | set +e + source ~/miniconda3/bin/activate py3.8 + + declare -i test_status + + python3 .ci/integration_test.py 2>&1 | tee integration_test.log; test_status=${PIPESTATUS[0]} + + exit ${test_status} + - run: + name: "accuracy test" + command: | + set +e + source ~/miniconda3/bin/activate py3.8 + declare -i test_status - python .ci/integration_test.py 2>&1 | tee integration_test.log; test_status=${PIPESTATUS[0]} + + python3 .ci/accuracy_test.py 2>&1 | tee accuracy_test.py.log; test_status=${PIPESTATUS[0]} exit ${test_status} - store_test_results: diff --git a/.clang-tidy b/.clang-tidy index 5af8a18..4ad05fc 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -8,6 +8,8 @@ Checks: "abseil-cleanup-ctad, abseil-upgrade-duration-conversions bugprone-*, -bugprone-easily-swappable-parameters, + -bugprone-implicit-widening-of-multiplication-result, + -bugprone-narrowing-conversions, # too many false positives around `std::size_t` vs. `*::difference_type`. 
google-build-using-namespace, google-explicit-constructor, google-global-names-in-headers, @@ -17,11 +19,14 @@ Checks: "abseil-cleanup-ctad, misc-unused-using-decls, modernize-*, -modernize-use-trailing-return-type, + -modernize-avoid-c-arrays, + -modernize-return-braced-init-list, # can hurt readability -modernize-use-nodiscard, performance-*, readability-*, -readability-else-after-return, -readability-identifier-length, + -readability-function-cognitive-complexity, -readability-magic-numbers, -readability-named-parameter" @@ -68,3 +73,6 @@ CheckOptions: - key: readability-identifier-naming.FunctionCase value: "CamelBack" + + - key: performance-unnecessary-value-param.AllowedTypes + value: PtBufferView diff --git a/.gitignore b/.gitignore index 465b3be..3443bee 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,10 @@ rpc_data #python *.pyc +*egg-info +build/ +dist/ + +#docs +_build/ +*.mo diff --git a/.vscode/settings.json b/.vscode/settings.json index 52c78ac..4739d5a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,6 +35,10 @@ }, "editor.formatOnSave": false, }, - "python.formatting.provider": "black", - "esbonio.sphinx.confDir": "" -} \ No newline at end of file + "python.formatting.provider": "none", + "esbonio.sphinx.confDir": "", + "git.ignoreLimitWarning": true, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + } +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 4779f54..6d140a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,18 @@ > - Add `[Feature]` prefix for new features > - Add `[Bugfix]` prefix for bug fixes > - Add `[API]` prefix for API changes +> - Add `[DOC]` prefix for Doc changes ## staging > please add your unreleased change here. 
+- [Feature] Add `ARROW_PROCESSING` operator +- [Feature] Add Python binding libs for serving +- [Feature] Add thread pool for ops executing +- [Feature] Node could have multiple out edges +- [Feature] Add support for Execution with specific party + ## 20210928 - [serving] Init release. diff --git a/README.md b/README.md index 7dffd4f..9561c45 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![CircleCI](https://dl.circleci.com/status-badge/img/gh/secretflow/serving/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/secretflow/serving/tree/main) -SecretFlow-Serving is a serving system for privacy-preserving machine learning models. +SecretFlow-Serving is a serving system for privacy-preserving machine learning models. ## Serve a model @@ -11,17 +11,9 @@ SecretFlow-Serving is a serving system for privacy-preserving machine learning docker pull secretflow/serving-anolis8:latest # Start Secretflow Serving container and open the REST API port -# run alice -docker run -t --rm --name serving-example-alice --network=host \ - --entrypoint "/root/sf_serving/secretflow_serving" \ - secretflow/serving-anolis8:latest \ - "--serving_config_file=/root/sf_serving/examples/alice/serving.config" & - -# run bob -docker run -t --rm --name serving-example-bob --network=host \ - --entrypoint "/root/sf_serving/secretflow_serving" \ - secretflow/serving-anolis8:latest \ - "--serving_config_file=/root/sf_serving/examples/bob/serving.config" & +cd examples + +docker-compose up -d # Query the model using the predict API curl --location 'http://127.0.0.1:9010/PredictionService/Predict' \ diff --git a/WORKSPACE b/WORKSPACE index ac6183e..2f69af5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,3 +43,16 @@ boost_deps() load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") rules_pkg_dependencies() + +load("@pybind11_bazel//:python_configure.bzl", "python_configure") + +python_configure( + name = "local_config_python", + python_version = "3", +) + 
+load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_repos", "rules_proto_grpc_toolchains") + +rules_proto_grpc_toolchains() + +rules_proto_grpc_repos() diff --git a/bazel/arrow.BUILD b/bazel/arrow.BUILD index 7b6d515..3773b19 100644 --- a/bazel/arrow.BUILD +++ b/bazel/arrow.BUILD @@ -84,6 +84,8 @@ cc_library( "cpp/src/arrow/vendored/string_view.hpp", "cpp/src/arrow/vendored/variant.hpp", "cpp/src/arrow/vendored/base64.cpp", + "cpp/src/arrow/vendored/double-conversion/*.cc", + "cpp/src/arrow/vendored/double-conversion/*.h", "cpp/src/arrow/**/*.h", "cpp/src/parquet/**/*.h", "cpp/src/parquet/**/*.cc", diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index bf79731..e86248e 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -18,13 +18,14 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") SECRETFLOW_GIT = "https://github.com/secretflow" -YACL_COMMIT_ID = "ebcc0a27e5cd511bc5f87e97f5695b0b8d07fc74" +YACL_COMMIT_ID = "5feaa30e6a2ab3be5a01a7a4ee3c1613d11386d9" -KUSCIA_COMMIT_ID = "fbc6e69433e6320e896f103f2360b17a3863784c" +KUSCIA_COMMIT_ID = "75d37fa346830eb4798ff56fcf919de14a9ef657" def sf_serving_deps(): _bazel_platform() _bazel_rules_pkg() + _rules_proto_grpc() _com_github_nelhage_rules_boost() _com_github_facebook_zstd() @@ -39,6 +40,8 @@ def sf_serving_deps(): _com_github_jupp0r_prometheus_cpp() _org_apache_thrift() _org_apache_arrow() + _com_github_pybind11_bazel() + _com_github_pybind11() # aws s3 _com_aws_c_common() @@ -291,10 +294,43 @@ def _org_apache_arrow(): maybe( http_archive, name = "org_apache_arrow", - sha256 = "f01b76a42ceb30409e7b1953ef64379297dd0c08502547cae6aaafd2c4a4d92e", - strip_prefix = "arrow-apache-arrow-12.0.1", + sha256 = "07cdb4da6795487c800526b2865c150ab7d80b8512a31793e6a7147c8ccd270f", + strip_prefix = "arrow-apache-arrow-14.0.2", build_file = "@sf_serving//bazel:arrow.BUILD", urls = [ - "https://github.com/apache/arrow/archive/refs/tags/apache-arrow-12.0.1.tar.gz", + 
"https://github.com/apache/arrow/archive/refs/tags/apache-arrow-14.0.2.tar.gz", + ], + ) + +def _com_github_pybind11_bazel(): + maybe( + http_archive, + name = "pybind11_bazel", + sha256 = "2d3316d89b581966fc11eab9aa9320276baee95c8233c7a8efc7158623a48de0", + strip_prefix = "pybind11_bazel-ff261d2e9190955d0830040b20ea59ab9dbe66c8", + urls = [ + "https://github.com/pybind/pybind11_bazel/archive/ff261d2e9190955d0830040b20ea59ab9dbe66c8.zip", + ], + ) + +def _com_github_pybind11(): + maybe( + http_archive, + name = "pybind11", + build_file = "@pybind11_bazel//:pybind11.BUILD", + sha256 = "d475978da0cdc2d43b73f30910786759d593a9d8ee05b1b6846d1eb16c6d2e0c", + strip_prefix = "pybind11-2.11.1", + urls = [ + "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz", + ], + ) + +def _rules_proto_grpc(): + http_archive( + name = "rules_proto_grpc", + sha256 = "928e4205f701b7798ce32f3d2171c1918b363e9a600390a25c876f075f1efc0a", + strip_prefix = "rules_proto_grpc-4.4.0", + urls = [ + "https://github.com/rules-proto-grpc/rules_proto_grpc/releases/download/4.4.0/rules_proto_grpc-4.4.0.tar.gz", ], ) diff --git a/build_wheel_entrypoint.sh b/build_wheel_entrypoint.sh new file mode 100644 index 0000000..c0fca59 --- /dev/null +++ b/build_wheel_entrypoint.sh @@ -0,0 +1,21 @@ +#! /bin/bash +# +# Copyright 2022 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e + + +rm -rf dist +python setup.py bdist_wheel +python3 -m pip install dist/*.whl --force-reinstall diff --git a/docker/Dockerfile b/docker/Dockerfile index 5d973ae..550e1bd 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -27,6 +27,10 @@ LABEL maintainer="secretflow-contact@service.alipay.com" COPY --from=python /root/miniconda3/envs/secretflow/bin/ /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/lib/ /usr/local/lib/ +RUN yum install -y protobuf libnl3 libgomp && yum clean all + +RUN grep -rl '#!/root/miniconda3/envs/secretflow/bin' /usr/local/bin/ | xargs sed -i -e 's/#!\/root\/miniconda3\/envs\/secretflow/#!\/usr\/local/g' + COPY sf_serving.tar.gz /root/sf_serving.tgz RUN tar -C /root -xzf /root/sf_serving.tgz diff --git a/docker/version.txt b/docker/version.txt index 0e2f711..36a5d7e 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -version = "0.1.0b" +version = "0.2.0dev" diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..92dd33a --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/build.sh b/docs/build.sh new file mode 100755 index 0000000..1eaa6f1 --- /dev/null +++ b/docs/build.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +usage() { echo "Usage: $0 [-l ]" 1>&2; exit 1; } + + +while getopts ":l:" o; do + case "${o}" in + l) + l=${OPTARG} + ;; + *) + usage + ;; + esac +done +shift $((OPTIND-1)) + +if [ -z "${l}" ]; then + usage +fi + +echo "selected language is: ${l}" + +if [[ "$l" != "en" && "$l" != "zh_CN" ]]; then + usage +fi + + +SPHINX_APIDOC_OPTIONS=members,autosummary +make clean +env PYTHONPATH=$PYTHONPATH:$PWD/.. make SPHINXOPTS="-D language='${l}'" html diff --git a/docs/locales/zh_CN/LC_MESSAGES/index.po b/docs/locales/zh_CN/LC_MESSAGES/index.po new file mode 100644 index 0000000..e975cb2 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/index.po @@ -0,0 +1,85 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 18:17+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/index.rst:7 +msgid "Welcome to SecretFlow-Serving's documentation!" +msgstr "欢迎来到SecretFlow-Serving文档页面!" + +#: ../../source/index.rst:9 +msgid "" +"SecretFlow-Serving is a serving system for privacy-preserving machine " +"learning models." +msgstr "SecretFlow-Serving 是一个加载隐私保护机器学习模型的在线服务系统。" + +#: ../../source/index.rst:13 +msgid "Getting started" +msgstr "开始" + +#: ../../source/index.rst:15 +msgid "" +"Follow the :doc:`tutorial ` and try out SecretFlow-" +"Serving on your machine!" 
+msgstr "按照 :doc:`快速教程 ` 在本地尝试SecretFlow-Serving" + +#: ../../source/index.rst:19 +msgid "SecretFlow-Serving Systems" +msgstr "SecretFlow-Serving 系统" + +#: ../../source/index.rst:21 +msgid "" +"**Overview**: :doc:`System overview and architecture " +"`" +msgstr "**概述**::doc:`系统简介及架构设计 `" + +#: ../../source/index.rst:26 +msgid "Deployment" +msgstr "部署" + +#: ../../source/index.rst:28 +msgid "" +"**Guides**: :doc:`How to deploy an SecretFlow-Serving " +"cluster`" +msgstr "**指南**::doc:`如何部署 SecretFlow-Serving 集群`" + +#: ../../source/index.rst:31 +msgid "" +"**Reference**: :doc:`SecretFlow-Serving service API ` | " +":doc:`SecretFlow-Serving system config ` | :doc" +":`SecretFlow-Serving feature service spi `" +msgstr "" +"**参考**::doc:`SecretFlow-Serving 服务 API ` | :doc" +":`SecretFlow-Serving 系统配置 ` | :doc:`SecretFlow-Serving" +" 特征服务 SPI `" + +#: ../../source/index.rst:38 +msgid "Graph" +msgstr "模型图" + +#: ../../source/index.rst:40 +msgid "" +"**Overview**: :doc:`Introduction to graphs " +"` | :doc:`Operators " +"`" +msgstr "" +"**概述**::doc:`模型图的介绍 ` | :doc:`算子 " +"`" + +#: ../../source/index.rst:44 +msgid "**Reference**: :doc:`SecretFlow-Serving model `" +msgstr "**参考**::doc:`SecretFlow-Serving 模型 `" diff --git a/docs/locales/zh_CN/LC_MESSAGES/intro/index.po b/docs/locales/zh_CN/LC_MESSAGES/intro/index.po new file mode 100644 index 0000000..49213d8 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/intro/index.po @@ -0,0 +1,23 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. 
+# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-12-25 11:37+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/intro/index.rst:2 +msgid "Introduction" +msgstr "介绍" diff --git a/docs/locales/zh_CN/LC_MESSAGES/intro/tutorial.po b/docs/locales/zh_CN/LC_MESSAGES/intro/tutorial.po new file mode 100644 index 0000000..bbac4d2 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/intro/tutorial.po @@ -0,0 +1,109 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 16:56+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/intro/tutorial.rst:2 +msgid "Quickstart" +msgstr "快速开始" + +#: ../../source/intro/tutorial.rst:5 +msgid "TL;DR" +msgstr "简略" + +#: ../../source/intro/tutorial.rst:7 +msgid "" +"Use ``docker-compose`` to deploy a SecretFlow-Serving cluster, the query " +"the model using the predict API." 
+msgstr "" +"通过使用 ``docker-compose`` 来部署一个SecretFlow-Serving集群,然后通过访问predict " +"API来请求serving加载的模型。" + +#: ../../source/intro/tutorial.rst:11 +msgid "Start SecretFlow-Serving Service" +msgstr "启动SecretFlow-Serving" + +#: ../../source/intro/tutorial.rst:13 +msgid "" +"You could start SecretFlow-Serving service via `docker-compose " +"`_, it would " +"deploy and start services as shown in the following figure, it contains " +"two SecretFlow-Serving from party ``Alice``, ``Bob``." +msgstr "" +"你可以通过使用 `docker-compose " +"`_ " +"来启动和部署服务,如下图所示,其包含两个分别来自 ``Alice`` 和 ``Bob`` 的SecretFlow-Serving" + +#: ../../source/intro/tutorial.rst:-1 +msgid "docker-compose deployment for quickstart example" +msgstr "快速入门示例的 docker-compose 部署" + +#: ../../source/intro/tutorial.rst:20 +msgid "" +"To demonstrate SecretFlow-Serving, we conducted the following simplified " +"operations:" +msgstr "为了演示 SecretFlow-Serving,我们进行了以下简化操作:" + +#: ../../source/intro/tutorial.rst:22 +msgid "" +"Both parties of Secretflow-Serving use mock feature source to produce " +"random feature values." +msgstr "每方的 Secretflow-Serving 均使用仿真特征数据源产生随机的特征值。" + +#: ../../source/intro/tutorial.rst:23 +msgid "" +"The model files in the examples directory are loaded by ``Alice`` and " +"``Bob``'s Secretflow-Serving respectively。" +msgstr "examples目录下的模型文件会分别被 ``Alice`` 和 ``Bob`` 的Secretflow-Serving加载使用。" + +#: ../../source/intro/tutorial.rst:24 +msgid "" +"The SecretFlow-Serving is served through the HTTP protocol. However, for " +"production environments, it is recommended to use HTTPS instead. Please " +"check :ref:`TLS Configuration ` for details." 
+msgstr "" +"本示例的SecretFlow-Serving 通过 HTTP 协议提供服务。然而,对于生产环境,建议使用 HTTPS 协议来代替。请查看 " +":ref:`TLS 配置 ` 获取详细信息。" + +#: ../../source/intro/tutorial.rst:33 +msgid "" +"Now, the ``Alice``'s SecretFlow-Serving is listening on " +"``http://localhost:9010``, the ``Bob``'s SecretFlow-Serving is listening " +"on ``http://localhost:9011``, you could send predict request to it via " +"curl or other http tools." +msgstr "" +"现在,``Alice`` 的 SecretFlow-Serving 监听 ``http://localhost:9010``, ``Bob`` 的" +" SecretFlow-Serving 监听 " +"``http://localhost:9011``,你可以通过使用curl或其他http工具向其发送预测请求。" + +#: ../../source/intro/tutorial.rst:37 +msgid "Do Predict" +msgstr "执行预测请求" + +#: ../../source/intro/tutorial.rst:39 +msgid "send predict request to ``Alice``" +msgstr "向 ``Alice`` 发送预测请求" + +#: ../../source/intro/tutorial.rst:64 +msgid "send predict request to ``Bob``" +msgstr "向 ``Bob`` 发送预测请求" + +#: ../../source/intro/tutorial.rst:89 +msgid "" +"Please checkout :ref:`SecretFlow-Serving API ` for the" +" Predict API details." +msgstr "请查询 :ref:`SecretFlow-Serving API ` 以获得更多关于预测API的信息。" diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/api.po b/docs/locales/zh_CN/LC_MESSAGES/reference/api.po new file mode 100644 index 0000000..21262c6 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/api.po @@ -0,0 +1,1388 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. 
+# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-05 11:07+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/reference/api.md:1 +msgid "SecretFlow-Serving API" +msgstr "" + +#: ../../source/reference/api.md:3 +msgid "Table of Contents" +msgstr "" + +#: ../../source/reference/api.md:4 +msgid "Services" +msgstr "" + +#: ../../source/reference/api.md:14 +msgid "[ExecutionService](#executionservice)" +msgstr "" + +#: ../../source/reference/api.md:20 +msgid "[metrics](#metrics)" +msgstr "" + +#: ../../source/reference/api.md:26 +msgid "[ModelService](#modelservice)" +msgstr "" + +#: ../../source/reference/api.md:32 +msgid "[PredictionService](#predictionservice)" +msgstr "" + +#: ../../source/reference/api.md:43 ../../source/reference/api.md:207 +msgid "Messages" +msgstr "" + +#: ../../source/reference/api.md:47 +msgid "[Header](#header)" +msgstr "" + +#: ../../source/reference/api.md:48 +msgid "[Header.DataEntry](#header-dataentry)" +msgstr "" + +#: ../../source/reference/api.md:49 +msgid "[ServiceSpec](#servicespec)" +msgstr "" + +#: ../../source/reference/api.md:58 +msgid "[ExecuteRequest](#executerequest)" +msgstr "" + +#: ../../source/reference/api.md:59 +msgid "[ExecuteResponse](#executeresponse)" +msgstr "" + +#: ../../source/reference/api.md:60 +msgid "[ExecuteResult](#executeresult)" +msgstr "" + +#: ../../source/reference/api.md:61 +msgid "[ExecutionTask](#executiontask)" +msgstr "" + +#: ../../source/reference/api.md:62 +msgid "[FeatureSource](#featuresource)" +msgstr "" + +#: ../../source/reference/api.md:63 +msgid "[IoData](#iodata)" +msgstr "" + +#: ../../source/reference/api.md:64 +msgid "[NodeIo](#nodeio)" +msgstr "" + +#: ../../source/reference/api.md:70 +msgid 
"[MetricsRequest](#metricsrequest)" +msgstr "" + +#: ../../source/reference/api.md:71 +msgid "[MetricsResponse](#metricsresponse)" +msgstr "" + +#: ../../source/reference/api.md:77 +msgid "[GetModelInfoRequest](#getmodelinforequest)" +msgstr "" + +#: ../../source/reference/api.md:78 +msgid "[GetModelInfoResponse](#getmodelinforesponse)" +msgstr "" + +#: ../../source/reference/api.md:84 +msgid "[PredictRequest](#predictrequest)" +msgstr "" + +#: ../../source/reference/api.md:85 +msgid "[PredictRequest.FsParamsEntry](#predictrequest-fsparamsentry)" +msgstr "" + +#: ../../source/reference/api.md:86 +msgid "[PredictResponse](#predictresponse)" +msgstr "" + +#: ../../source/reference/api.md:87 +msgid "[PredictResult](#predictresult)" +msgstr "" + +#: ../../source/reference/api.md:88 +msgid "[Score](#score)" +msgstr "" + +#: ../../source/reference/api.md:94 +msgid "[Status](#status)" +msgstr "" + +#: ../../source/reference/api.md:100 +msgid "[Feature](#feature)" +msgstr "" + +#: ../../source/reference/api.md:101 +msgid "[FeatureField](#featurefield)" +msgstr "" + +#: ../../source/reference/api.md:102 +msgid "[FeatureParam](#featureparam)" +msgstr "" + +#: ../../source/reference/api.md:103 +msgid "[FeatureValue](#featurevalue)" +msgstr "" + +#: ../../source/reference/api.md:108 ../../source/reference/api.md:598 +msgid "Enums" +msgstr "" + +#: ../../source/reference/api.md:115 +msgid "[ErrorCode](#errorcode)" +msgstr "" + +#: ../../source/reference/api.md:121 +msgid "[FeatureSourceType](#featuresourcetype)" +msgstr "" + +#: ../../source/reference/api.md:139 +msgid "[FieldType](#fieldtype)" +msgstr "" + +#: ../../source/reference/api.md:143 +msgid "[Scalar Value Types](#scalar-value-types)" +msgstr "" + +#: ../../source/reference/api.md:150 +msgid "{#ExecutionService}" +msgstr "" + +#: ../../source/reference/api.md:151 +msgid "ExecutionService" +msgstr "" + +#: ../../source/reference/api.md:152 +msgid "ExecutionService provides access to run execution defined in the 
GraphDef." +msgstr "" + +#: ../../source/reference/api.md:154 +msgid "Execute" +msgstr "" + +#: ../../source/reference/api.md:156 +msgid "" +"**rpc** Execute([ExecuteRequest](#executerequest)) " +"[ExecuteResponse](#executeresponse)" +msgstr "" + +#: ../../source/reference/api.md:163 +msgid "{#metrics}" +msgstr "" + +#: ../../source/reference/api.md:164 +msgid "metrics" +msgstr "" + +#: ../../source/reference/api.md:167 +msgid "default_method" +msgstr "" + +#: ../../source/reference/api.md:169 +msgid "" +"**rpc** default_method([MetricsRequest](#metricsrequest)) " +"[MetricsResponse](#metricsresponse)" +msgstr "" + +#: ../../source/reference/api.md:176 +msgid "{#ModelService}" +msgstr "" + +#: ../../source/reference/api.md:177 +msgid "ModelService" +msgstr "" + +#: ../../source/reference/api.md:178 +msgid "ModelService provides operation ralated to models." +msgstr "" + +#: ../../source/reference/api.md:180 +msgid "GetModelInfo" +msgstr "" + +#: ../../source/reference/api.md:182 +msgid "" +"**rpc** GetModelInfo([GetModelInfoRequest](#getmodelinforequest)) " +"[GetModelInfoResponse](#getmodelinforesponse)" +msgstr "" + +#: ../../source/reference/api.md:189 +msgid "{#PredictionService}" +msgstr "" + +#: ../../source/reference/api.md:190 +msgid "PredictionService" +msgstr "" + +#: ../../source/reference/api.md:191 +msgid "PredictionService provides access to the serving model." +msgstr "" + +#: ../../source/reference/api.md:193 +msgid "Predict" +msgstr "" + +#: ../../source/reference/api.md:195 +msgid "" +"**rpc** Predict([PredictRequest](#predictrequest)) " +"[PredictResponse](#predictresponse)" +msgstr "" + +#: ../../source/reference/api.md:198 +msgid "Predict." 
+msgstr "" + +#: ../../source/reference/api.md:211 +msgid "{#Header}" +msgstr "" + +#: ../../source/reference/api.md:212 +msgid "Header" +msgstr "" + +#: ../../source/reference/api.md:213 +msgid "Header containing custom data" +msgstr "" + +#: ../../source/reference/api.md +msgid "Field" +msgstr "" + +#: ../../source/reference/api.md +msgid "Type" +msgstr "" + +#: ../../source/reference/api.md +msgid "Description" +msgstr "" + +#: ../../source/reference/api.md +msgid "data" +msgstr "" + +#: ../../source/reference/api.md +msgid "[map Header.DataEntry](#header-dataentry )" +msgstr "" + +#: ../../source/reference/api.md +msgid "none" +msgstr "" + +#: ../../source/reference/api.md:223 +msgid "{#Header.DataEntry}" +msgstr "" + +#: ../../source/reference/api.md:224 +msgid "Header.DataEntry" +msgstr "" + +#: ../../source/reference/api.md +msgid "key" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ string](#string )" +msgstr "" + +#: ../../source/reference/api.md +msgid "value" +msgstr "" + +#: ../../source/reference/api.md:236 +msgid "{#ServiceSpec}" +msgstr "" + +#: ../../source/reference/api.md:237 +msgid "ServiceSpec" +msgstr "" + +#: ../../source/reference/api.md:238 +msgid "Metadata for an predict or execute request." +msgstr "" + +#: ../../source/reference/api.md +msgid "id" +msgstr "" + +#: ../../source/reference/api.md +msgid "The id of the model service." +msgstr "" + +#: ../../source/reference/api.md:252 +msgid "{#ExecuteRequest}" +msgstr "" + +#: ../../source/reference/api.md:253 +msgid "ExecuteRequest" +msgstr "" + +#: ../../source/reference/api.md:254 +msgid "Execute request containing one or more requests." +msgstr "" + +#: ../../source/reference/api.md +msgid "header" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ Header](#header )" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Custom data. The header will be passed to the downstream system which " +"implement the feature service spi." 
+msgstr "" + +#: ../../source/reference/api.md +msgid "requester_id" +msgstr "" + +#: ../../source/reference/api.md +msgid "Represents the id of the requesting party" +msgstr "" + +#: ../../source/reference/api.md +msgid "service_spec" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ ServiceSpec](#servicespec )" +msgstr "" + +#: ../../source/reference/api.md +msgid "Model service specification." +msgstr "" + +#: ../../source/reference/api.md +msgid "session_id" +msgstr "" + +#: ../../source/reference/api.md +msgid "Represents the session of this execute." +msgstr "" + +#: ../../source/reference/api.md +msgid "feature_source" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ FeatureSource](#featuresource )" +msgstr "" + +#: ../../source/reference/api.md +msgid "task" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ ExecutionTask](#executiontask )" +msgstr "" + +#: ../../source/reference/api.md:269 +msgid "{#ExecuteResponse}" +msgstr "" + +#: ../../source/reference/api.md:270 +msgid "ExecuteResponse" +msgstr "" + +#: ../../source/reference/api.md:271 +msgid "Execute response containing one or more responses." +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Custom data. Passed by the downstream system which implement the feature " +"service spi." +msgstr "" + +#: ../../source/reference/api.md +msgid "status" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ Status](#status )" +msgstr "" + +#: ../../source/reference/api.md +msgid "Staus of this response." +msgstr "" + +#: ../../source/reference/api.md +msgid "result" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ ExecuteResult](#executeresult )" +msgstr "" + +#: ../../source/reference/api.md:285 +msgid "{#ExecuteResult}" +msgstr "" + +#: ../../source/reference/api.md:286 +msgid "ExecuteResult" +msgstr "" + +#: ../../source/reference/api.md:287 +msgid "Execute result of the request task." 
+msgstr "" + +#: ../../source/reference/api.md +msgid "execution_id" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ int32](#int32 )" +msgstr "" + +#: ../../source/reference/api.md +msgid "Specified the execution id." +msgstr "" + +#: ../../source/reference/api.md +msgid "nodes" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated NodeIo](#nodeio )" +msgstr "" + +#: ../../source/reference/api.md:298 +msgid "{#ExecutionTask}" +msgstr "" + +#: ../../source/reference/api.md:299 +msgid "ExecutionTask" +msgstr "" + +#: ../../source/reference/api.md:300 +msgid "Execute request task." +msgstr "" + +#: ../../source/reference/api.md:311 +msgid "{#FeatureSource}" +msgstr "" + +#: ../../source/reference/api.md:312 +msgid "FeatureSource" +msgstr "" + +#: ../../source/reference/api.md:313 +msgid "Descriptive feature source" +msgstr "" + +#: ../../source/reference/api.md +msgid "type" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ FeatureSourceType](#featuresourcetype )" +msgstr "" + +#: ../../source/reference/api.md +msgid "Identifies the source type of the features" +msgstr "" + +#: ../../source/reference/api.md +msgid "fs_param" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ secretflow.serving.FeatureParam](#featureparam )" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Custom parameter for fetch features from feature service or other " +"systems. Valid when `type==FeatureSourceType::FS_SERVICE`" +msgstr "" + +#: ../../source/reference/api.md +msgid "predefineds" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated secretflow.serving.Feature](#feature )" +msgstr "" + +#: ../../source/reference/api.md +msgid "Defined features. 
Valid when `type==FeatureSourceType::FS_PREDEFINED`" +msgstr "" + +#: ../../source/reference/api.md:325 +msgid "{#IoData}" +msgstr "" + +#: ../../source/reference/api.md:326 +msgid "IoData" +msgstr "" + +#: ../../source/reference/api.md:327 +msgid "The serialized data of the node input/output." +msgstr "" + +#: ../../source/reference/api.md +msgid "datas" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated bytes](#bytes )" +msgstr "" + +#: ../../source/reference/api.md:337 +msgid "{#NodeIo}" +msgstr "" + +#: ../../source/reference/api.md:338 +msgid "NodeIo" +msgstr "" + +#: ../../source/reference/api.md:339 +msgid "Represents the node input/output data." +msgstr "" + +#: ../../source/reference/api.md +msgid "name" +msgstr "" + +#: ../../source/reference/api.md +msgid "Node name." +msgstr "" + +#: ../../source/reference/api.md +msgid "ios" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated IoData](#iodata )" +msgstr "" + +#: ../../source/reference/api.md:352 +msgid "{#MetricsRequest}" +msgstr "" + +#: ../../source/reference/api.md:353 +msgid "MetricsRequest" +msgstr "" + +#: ../../source/reference/api.md:359 +msgid "{#MetricsResponse}" +msgstr "" + +#: ../../source/reference/api.md:360 +msgid "MetricsResponse" +msgstr "" + +#: ../../source/reference/api.md:368 +msgid "{#GetModelInfoRequest}" +msgstr "" + +#: ../../source/reference/api.md:369 +msgid "GetModelInfoRequest" +msgstr "" + +#: ../../source/reference/api.md +msgid "Custom data." 
+msgstr "" + +#: ../../source/reference/api.md:381 +msgid "{#GetModelInfoResponse}" +msgstr "" + +#: ../../source/reference/api.md:382 +msgid "GetModelInfoResponse" +msgstr "" + +#: ../../source/reference/api.md +msgid "model_info" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ secretflow.serving.ModelInfo](#modelinfo )" +msgstr "" + +#: ../../source/reference/api.md:398 +msgid "{#PredictRequest}" +msgstr "" + +#: ../../source/reference/api.md:399 +msgid "PredictRequest" +msgstr "" + +#: ../../source/reference/api.md:400 +msgid "Predict request containing one or more requests. examples:" +msgstr "" + +#: ../../source/reference/api.md +msgid "fs_params" +msgstr "" + +#: ../../source/reference/api.md +msgid "[map PredictRequest.FsParamsEntry](#predictrequest-fsparamsentry )" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"The params for fetch features. Note that this should include all the " +"parties involved in the prediction. Key: party's id. Value: params for " +"fetch features." +msgstr "" + +#: ../../source/reference/api.md +msgid "predefined_features" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Optional. If defined, the request party will no longer query for the " +"feature but will use defined fetures in `predefined_features` for the " +"prediction." +msgstr "" + +#: ../../source/reference/api.md:442 +msgid "{#PredictRequest.FsParamsEntry}" +msgstr "" + +#: ../../source/reference/api.md:443 +msgid "PredictRequest.FsParamsEntry" +msgstr "" + +#: ../../source/reference/api.md:455 +msgid "{#PredictResponse}" +msgstr "" + +#: ../../source/reference/api.md:456 +msgid "PredictResponse" +msgstr "" + +#: ../../source/reference/api.md:457 +msgid "Predict response containing one or more responses. 
examples:" +msgstr "" + +#: ../../source/reference/api.md +msgid "results" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated PredictResult](#predictresult )" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"List of the predict result. Returned in the same order as the request's " +"feature query data." +msgstr "" + +#: ../../source/reference/api.md:499 +msgid "{#PredictResult}" +msgstr "" + +#: ../../source/reference/api.md:500 +msgid "PredictResult" +msgstr "" + +#: ../../source/reference/api.md:501 +msgid "Result of single predict request." +msgstr "" + +#: ../../source/reference/api.md +msgid "scores" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated Score](#score )" +msgstr "" + +#: ../../source/reference/api.md +msgid "According to the model, there may be one or multi scores." +msgstr "" + +#: ../../source/reference/api.md:511 +msgid "{#Score}" +msgstr "" + +#: ../../source/reference/api.md:512 +msgid "Score" +msgstr "" + +#: ../../source/reference/api.md:513 +msgid "Result of regression or one class of Classifications" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"The name of the score, it depends on the attribute configuration of the " +"model." +msgstr "" + +#: ../../source/reference/api.md +msgid "[ double](#double )" +msgstr "" + +#: ../../source/reference/api.md +msgid "The value of the score." +msgstr "" + +#: ../../source/reference/api.md:526 +msgid "{#Status}" +msgstr "" + +#: ../../source/reference/api.md:527 +msgid "Status" +msgstr "" + +#: ../../source/reference/api.md:528 +msgid "Represents the status of a request" +msgstr "" + +#: ../../source/reference/api.md +msgid "code" +msgstr "" + +#: ../../source/reference/api.md +msgid "The code of this status. Must be one of ErrorCode in error_code.proto" +msgstr "" + +#: ../../source/reference/api.md +msgid "msg" +msgstr "" + +#: ../../source/reference/api.md +msgid "The msg of this status." 
+msgstr "" + +#: ../../source/reference/api.md:541 +msgid "{#Feature}" +msgstr "" + +#: ../../source/reference/api.md:542 +msgid "Feature" +msgstr "" + +#: ../../source/reference/api.md:543 +msgid "The definition of a feature" +msgstr "" + +#: ../../source/reference/api.md +msgid "field" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ FeatureField](#featurefield )" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ FeatureValue](#featurevalue )" +msgstr "" + +#: ../../source/reference/api.md:554 +msgid "{#FeatureField}" +msgstr "" + +#: ../../source/reference/api.md:555 +msgid "FeatureField" +msgstr "" + +#: ../../source/reference/api.md:556 +msgid "The definition of a feature field." +msgstr "" + +#: ../../source/reference/api.md +msgid "Unique name of the feature" +msgstr "" + +#: ../../source/reference/api.md +msgid "[ FieldType](#fieldtype )" +msgstr "" + +#: ../../source/reference/api.md +msgid "Field type of the feature" +msgstr "" + +#: ../../source/reference/api.md:567 +msgid "{#FeatureParam}" +msgstr "" + +#: ../../source/reference/api.md:568 +msgid "FeatureParam" +msgstr "" + +#: ../../source/reference/api.md:569 +msgid "The param for fetch features" +msgstr "" + +#: ../../source/reference/api.md +msgid "query_datas" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated string](#string )" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"The serialized datas for query features. Each one for query one row of " +"features." +msgstr "" + +#: ../../source/reference/api.md +msgid "query_context" +msgstr "" + +#: ../../source/reference/api.md +msgid "Optional. Represents the common part of the query datas." 
+msgstr "" + +#: ../../source/reference/api.md:580 +msgid "{#FeatureValue}" +msgstr "" + +#: ../../source/reference/api.md:581 +msgid "FeatureValue" +msgstr "" + +#: ../../source/reference/api.md:582 +msgid "The value of a feature" +msgstr "" + +#: ../../source/reference/api.md +msgid "i32s" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated int32](#int32 )" +msgstr "" + +#: ../../source/reference/api.md +msgid "int list" +msgstr "" + +#: ../../source/reference/api.md +msgid "i64s" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated int64](#int64 )" +msgstr "" + +#: ../../source/reference/api.md +msgid "fs" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated float](#float )" +msgstr "" + +#: ../../source/reference/api.md +msgid "float list" +msgstr "" + +#: ../../source/reference/api.md +msgid "ds" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated double](#double )" +msgstr "" + +#: ../../source/reference/api.md +msgid "ss" +msgstr "" + +#: ../../source/reference/api.md +msgid "string list" +msgstr "" + +#: ../../source/reference/api.md +msgid "bs" +msgstr "" + +#: ../../source/reference/api.md +msgid "[repeated bool](#bool )" +msgstr "" + +#: ../../source/reference/api.md +msgid "bool list" +msgstr "" + +#: ../../source/reference/api.md:604 +msgid "ErrorCode" +msgstr "" + +#: ../../source/reference/api.md +msgid "Name" +msgstr "" + +#: ../../source/reference/api.md +msgid "Number" +msgstr "" + +#: ../../source/reference/api.md +msgid "UNKNOWN" +msgstr "" + +#: ../../source/reference/api.md +msgid "0" +msgstr "" + +#: ../../source/reference/api.md +msgid "Placeholder for proto3 default value, do not use it" +msgstr "" + +#: ../../source/reference/api.md +msgid "OK" +msgstr "" + +#: ../../source/reference/api.md +msgid "1" +msgstr "" + +#: ../../source/reference/api.md +msgid "UNEXPECTED_ERROR" +msgstr "" + +#: ../../source/reference/api.md +msgid "2" +msgstr "" + +#: ../../source/reference/api.md +msgid 
"INVALID_ARGUMENT" +msgstr "" + +#: ../../source/reference/api.md +msgid "3" +msgstr "" + +#: ../../source/reference/api.md +msgid "NETWORK_ERROR" +msgstr "" + +#: ../../source/reference/api.md +msgid "4" +msgstr "" + +#: ../../source/reference/api.md +msgid "NOT_FOUND" +msgstr "" + +#: ../../source/reference/api.md +msgid "5" +msgstr "" + +#: ../../source/reference/api.md +msgid "Some requested entity (e.g., file or directory) was not found." +msgstr "" + +#: ../../source/reference/api.md +msgid "NOT_IMPLEMENTED" +msgstr "" + +#: ../../source/reference/api.md +msgid "6" +msgstr "" + +#: ../../source/reference/api.md +msgid "LOGIC_ERROR" +msgstr "" + +#: ../../source/reference/api.md +msgid "7" +msgstr "" + +#: ../../source/reference/api.md +msgid "SERIALIZE_FAILED" +msgstr "" + +#: ../../source/reference/api.md +msgid "8" +msgstr "" + +#: ../../source/reference/api.md +msgid "DESERIALIZE_FAILED" +msgstr "" + +#: ../../source/reference/api.md +msgid "9" +msgstr "" + +#: ../../source/reference/api.md +msgid "IO_ERROR" +msgstr "" + +#: ../../source/reference/api.md +msgid "10" +msgstr "" + +#: ../../source/reference/api.md +msgid "NOT_READY" +msgstr "" + +#: ../../source/reference/api.md +msgid "11" +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_UNAUTHENTICATED" +msgstr "" + +#: ../../source/reference/api.md +msgid "100" +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_INVALID_ARGUMENT" +msgstr "" + +#: ../../source/reference/api.md +msgid "101" +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_DEADLINE_EXCEEDED" +msgstr "" + +#: ../../source/reference/api.md +msgid "102" +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_NOT_FOUND" +msgstr "" + +#: ../../source/reference/api.md +msgid "103" +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_INTERNAL_ERROR" +msgstr "" + +#: ../../source/reference/api.md +msgid "104" +msgstr "" + +#: ../../source/reference/api.md:632 +msgid "FeatureSourceType" +msgstr "" + +#: 
../../source/reference/api.md:633 +msgid "Support feature source type" +msgstr "" + +#: ../../source/reference/api.md +msgid "UNKNOWN_FS_TYPE" +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_NONE" +msgstr "" + +#: ../../source/reference/api.md +msgid "No need features." +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_SERVICE" +msgstr "" + +#: ../../source/reference/api.md +msgid "Fetch features from feature service." +msgstr "" + +#: ../../source/reference/api.md +msgid "FS_PREDEFINED" +msgstr "" + +#: ../../source/reference/api.md +msgid "The feature is defined in the request." +msgstr "" + +#: ../../source/reference/api.md:655 +msgid "FieldType" +msgstr "" + +#: ../../source/reference/api.md:656 +msgid "Supported feature field type." +msgstr "" + +#: ../../source/reference/api.md +msgid "UNKNOWN_FIELD_TYPE" +msgstr "" + +#: ../../source/reference/api.md +msgid "Placeholder for proto3 default value, do not use it." +msgstr "" + +#: ../../source/reference/api.md +msgid "FIELD_BOOL" +msgstr "" + +#: ../../source/reference/api.md +msgid "BOOL" +msgstr "" + +#: ../../source/reference/api.md +msgid "FIELD_INT32" +msgstr "" + +#: ../../source/reference/api.md +msgid "INT32" +msgstr "" + +#: ../../source/reference/api.md +msgid "FIELD_INT64" +msgstr "" + +#: ../../source/reference/api.md +msgid "INT64" +msgstr "" + +#: ../../source/reference/api.md +msgid "FIELD_FLOAT" +msgstr "" + +#: ../../source/reference/api.md +msgid "FLOAT" +msgstr "" + +#: ../../source/reference/api.md +msgid "FIELD_DOUBLE" +msgstr "" + +#: ../../source/reference/api.md +msgid "DOUBLE" +msgstr "" + +#: ../../source/reference/api.md +msgid "FIELD_STRING" +msgstr "" + +#: ../../source/reference/api.md +msgid "STRING" +msgstr "" + +#: ../../source/reference/api.md:672 +msgid "Scalar Value Types" +msgstr "" + +#: ../../source/reference/api.md +msgid ".proto Type" +msgstr "" + +#: ../../source/reference/api.md +msgid "Notes" +msgstr "" + +#: ../../source/reference/api.md +msgid "C++ 
Type" +msgstr "" + +#: ../../source/reference/api.md +msgid "Java Type" +msgstr "" + +#: ../../source/reference/api.md +msgid "Python Type" +msgstr "" + +#: ../../source/reference/api.md +msgid "

double" +msgstr "" + +#: ../../source/reference/api.md +msgid "double" +msgstr "" + +#: ../../source/reference/api.md +msgid "float" +msgstr "" + +#: ../../source/reference/api.md +msgid "

float" +msgstr "" + +#: ../../source/reference/api.md +msgid "

int32" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint32 instead." +msgstr "" + +#: ../../source/reference/api.md +msgid "int32" +msgstr "" + +#: ../../source/reference/api.md +msgid "int" +msgstr "" + +#: ../../source/reference/api.md +msgid "

int64" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint64 instead." +msgstr "" + +#: ../../source/reference/api.md +msgid "int64" +msgstr "" + +#: ../../source/reference/api.md +msgid "long" +msgstr "" + +#: ../../source/reference/api.md +msgid "int/long" +msgstr "" + +#: ../../source/reference/api.md +msgid "

uint32" +msgstr "" + +#: ../../source/reference/api.md +msgid "Uses variable-length encoding." +msgstr "" + +#: ../../source/reference/api.md +msgid "uint32" +msgstr "" + +#: ../../source/reference/api.md +msgid "

uint64" +msgstr "" + +#: ../../source/reference/api.md +msgid "uint64" +msgstr "" + +#: ../../source/reference/api.md +msgid "

sint32" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int32s." +msgstr "" + +#: ../../source/reference/api.md +msgid "

sint64" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int64s." +msgstr "" + +#: ../../source/reference/api.md +msgid "

fixed32" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Always four bytes. More efficient than uint32 if values are often greater" +" than 2^28." +msgstr "" + +#: ../../source/reference/api.md +msgid "

fixed64" +msgstr "" + +#: ../../source/reference/api.md +msgid "" +"Always eight bytes. More efficient than uint64 if values are often " +"greater than 2^56." +msgstr "" + +#: ../../source/reference/api.md +msgid "

sfixed32" +msgstr "" + +#: ../../source/reference/api.md +msgid "Always four bytes." +msgstr "" + +#: ../../source/reference/api.md +msgid "

sfixed64" +msgstr "" + +#: ../../source/reference/api.md +msgid "Always eight bytes." +msgstr "" + +#: ../../source/reference/api.md +msgid "

bool" +msgstr "" + +#: ../../source/reference/api.md +msgid "bool" +msgstr "" + +#: ../../source/reference/api.md +msgid "boolean" +msgstr "" + +#: ../../source/reference/api.md +msgid "

string" +msgstr "" + +#: ../../source/reference/api.md +msgid "A string must always contain UTF-8 encoded or 7-bit ASCII text." +msgstr "" + +#: ../../source/reference/api.md +msgid "string" +msgstr "" + +#: ../../source/reference/api.md +msgid "String" +msgstr "" + +#: ../../source/reference/api.md +msgid "str/unicode" +msgstr "" + +#: ../../source/reference/api.md +msgid "

bytes" +msgstr "" + +#: ../../source/reference/api.md +msgid "May contain any arbitrary sequence of bytes." +msgstr "" + +#: ../../source/reference/api.md +msgid "ByteString" +msgstr "" + +#: ../../source/reference/api.md +msgid "str" +msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/config.po b/docs/locales/zh_CN/LC_MESSAGES/reference/config.po new file mode 100644 index 0000000..50ca06f --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/config.po @@ -0,0 +1,1084 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 16:56+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/reference/config.md:1 +msgid "SecretFlow-Serving Config" +msgstr "" + +#: ../../source/reference/config.md:3 +msgid "Table of Contents" +msgstr "" + +#: ../../source/reference/config.md:4 +msgid "Services" +msgstr "" + +#: ../../source/reference/config.md:28 ../../source/reference/config.md:131 +msgid "Messages" +msgstr "" + +#: ../../source/reference/config.md:32 +msgid "[ChannelDesc](#channeldesc)" +msgstr "" + +#: ../../source/reference/config.md:33 +msgid "[ClusterConfig](#clusterconfig)" +msgstr "" + +#: ../../source/reference/config.md:34 +msgid "[PartyDesc](#partydesc)" +msgstr "" + +#: ../../source/reference/config.md:40 +msgid "[CsvOptions](#csvoptions)" +msgstr "" + +#: ../../source/reference/config.md:41 +msgid "[FeatureSourceConfig](#featuresourceconfig)" +msgstr "" + +#: ../../source/reference/config.md:42 +msgid "[HttpOptions](#httpoptions)" +msgstr "" + +#: 
../../source/reference/config.md:43 +msgid "[MockOptions](#mockoptions)" +msgstr "" + +#: ../../source/reference/config.md:49 +msgid "[LoggingConfig](#loggingconfig)" +msgstr "" + +#: ../../source/reference/config.md:55 +msgid "[FileSourceMeta](#filesourcemeta)" +msgstr "" + +#: ../../source/reference/config.md:56 +msgid "[ModelConfig](#modelconfig)" +msgstr "" + +#: ../../source/reference/config.md:57 +msgid "[OSSSourceMeta](#osssourcemeta)" +msgstr "" + +#: ../../source/reference/config.md:63 +msgid "[ServerConfig](#serverconfig)" +msgstr "" + +#: ../../source/reference/config.md:64 +msgid "[ServerConfig.FeatureMappingEntry](#serverconfig-featuremappingentry)" +msgstr "" + +#: ../../source/reference/config.md:70 +msgid "[ServingConfig](#servingconfig)" +msgstr "" + +#: ../../source/reference/config.md:76 +msgid "[TlsConfig](#tlsconfig)" +msgstr "" + +#: ../../source/reference/config.md:81 ../../source/reference/config.md:369 +msgid "Enums" +msgstr "" + +#: ../../source/reference/config.md:88 +msgid "[MockDataType](#mockdatatype)" +msgstr "" + +#: ../../source/reference/config.md:94 +msgid "[LogLevel](#loglevel)" +msgstr "" + +#: ../../source/reference/config.md:100 +msgid "[SourceType](#sourcetype)" +msgstr "" + +#: ../../source/reference/config.md:113 +msgid "[Scalar Value Types](#scalar-value-types)" +msgstr "" + +#: ../../source/reference/config.md:135 +msgid "{#ChannelDesc}" +msgstr "" + +#: ../../source/reference/config.md:136 +msgid "ChannelDesc" +msgstr "" + +#: ../../source/reference/config.md:137 +msgid "Description for channels between joined parties" +msgstr "" + +#: ../../source/reference/config.md +msgid "Field" +msgstr "" + +#: ../../source/reference/config.md +msgid "Type" +msgstr "" + +#: ../../source/reference/config.md +msgid "Description" +msgstr "" + +#: ../../source/reference/config.md +msgid "protocol" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ string](#string )" +msgstr "" + +#: ../../source/reference/config.md +msgid 
"https://github.com/apache/brpc/blob/master/docs/en/client.md#protocols" +msgstr "" + +#: ../../source/reference/config.md +msgid "rpc_timeout_ms" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ int32](#int32 )" +msgstr "" + +#: ../../source/reference/config.md +msgid "Max duration of RPC. -1 means wait indefinitely. Default: 2000 (ms)" +msgstr "" + +#: ../../source/reference/config.md +msgid "connect_timeout_ms" +msgstr "" + +#: ../../source/reference/config.md +msgid "Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms)" +msgstr "" + +#: ../../source/reference/config.md +msgid "tls_config" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ TlsConfig](#tlsconfig )" +msgstr "" + +#: ../../source/reference/config.md +msgid "TLS related config." +msgstr "" + +#: ../../source/reference/config.md +msgid "handshake_max_retry_cnt" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"When the server starts, model information from all parties will be " +"collected. At this time, the remote servers may not have started yet, and" +" we need to retry. And if we connect gateway,the max waiting time for " +"each operation will be rpc_timeout_ms + handshake_retry_interval_ms. 
" +"Maximum number of retries, default: 60" +msgstr "" + +#: ../../source/reference/config.md +msgid "handshake_retry_interval_ms" +msgstr "" + +#: ../../source/reference/config.md +msgid "time between retries, default: 5000ms" +msgstr "" + +#: ../../source/reference/config.md:152 +msgid "{#ClusterConfig}" +msgstr "" + +#: ../../source/reference/config.md:153 +msgid "ClusterConfig" +msgstr "" + +#: ../../source/reference/config.md:154 +msgid "Runtime config for a serving cluster" +msgstr "" + +#: ../../source/reference/config.md +msgid "parties" +msgstr "" + +#: ../../source/reference/config.md +msgid "[repeated PartyDesc](#partydesc )" +msgstr "" + +#: ../../source/reference/config.md +msgid "none" +msgstr "" + +#: ../../source/reference/config.md +msgid "self_id" +msgstr "" + +#: ../../source/reference/config.md +msgid "channel_desc" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ ChannelDesc](#channeldesc )" +msgstr "" + +#: ../../source/reference/config.md:166 +msgid "{#PartyDesc}" +msgstr "" + +#: ../../source/reference/config.md:167 +msgid "PartyDesc" +msgstr "" + +#: ../../source/reference/config.md:168 +msgid "Description for a joined party" +msgstr "" + +#: ../../source/reference/config.md +msgid "id" +msgstr "" + +#: ../../source/reference/config.md +msgid "Unique id of the party" +msgstr "" + +#: ../../source/reference/config.md +msgid "address" +msgstr "" + +#: ../../source/reference/config.md +msgid "e.g. 127.0.0.1:9001" +msgstr "" + +#: ../../source/reference/config.md +msgid "listen_address" +msgstr "" + +#: ../../source/reference/config.md +msgid "Optional. Address will be used if listen_address is empty." +msgstr "" + +#: ../../source/reference/config.md:182 +msgid "{#CsvOptions}" +msgstr "" + +#: ../../source/reference/config.md:183 +msgid "CsvOptions" +msgstr "" + +#: ../../source/reference/config.md:184 +msgid "Options of a csv feature source." 
+msgstr "" + +#: ../../source/reference/config.md +msgid "file_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Input file path, specifies where to load data Note that this will load " +"all of the data into memory at once" +msgstr "" + +#: ../../source/reference/config.md +msgid "id_name" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Id column name, associated with FeatureParam::query_datas Query datas is " +"a subset of id column" +msgstr "" + +#: ../../source/reference/config.md:195 +msgid "{#FeatureSourceConfig}" +msgstr "" + +#: ../../source/reference/config.md:196 +msgid "FeatureSourceConfig" +msgstr "" + +#: ../../source/reference/config.md:197 +msgid "Config for a feature source" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) options.mock_opts" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ MockOptions](#mockoptions )" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) options.http_opts" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ HttpOptions](#httpoptions )" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) options.csv_opts" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ CsvOptions](#csvoptions )" +msgstr "" + +#: ../../source/reference/config.md:209 +msgid "{#HttpOptions}" +msgstr "" + +#: ../../source/reference/config.md:210 +msgid "HttpOptions" +msgstr "" + +#: ../../source/reference/config.md:211 +msgid "" +"Options for a http feature source which should implement the feature " +"service spi. 
The defined of spi can be found in " +"secretflow_serving/spis/batch_feature_service.proto" +msgstr "" + +#: ../../source/reference/config.md +msgid "endpoint" +msgstr "" + +#: ../../source/reference/config.md +msgid "enable_lb" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ bool](#bool )" +msgstr "" + +#: ../../source/reference/config.md +msgid "Whether to enable round robin load balancer." +msgstr "" + +#: ../../source/reference/config.md +msgid "timeout_ms" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Max duration of http request. -1 means wait indefinitely. Default: 1000 " +"(ms)" +msgstr "" + +#: ../../source/reference/config.md:227 +msgid "{#MockOptions}" +msgstr "" + +#: ../../source/reference/config.md:228 +msgid "MockOptions" +msgstr "" + +#: ../../source/reference/config.md:229 +msgid "" +"Options for a mock feature source. Mock feature source will generates " +"values(random or fixed, according to type) for the desired features." +msgstr "" + +#: ../../source/reference/config.md +msgid "type" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ MockDataType](#mockdatatype )" +msgstr "" + +#: ../../source/reference/config.md +msgid "default MDT_RANDOM" +msgstr "" + +#: ../../source/reference/config.md:243 +msgid "{#LoggingConfig}" +msgstr "" + +#: ../../source/reference/config.md:244 +msgid "LoggingConfig" +msgstr "" + +#: ../../source/reference/config.md +msgid "system_log_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "system log default value: \"serving.log\"" +msgstr "" + +#: ../../source/reference/config.md +msgid "log_level" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ LogLevel](#loglevel )" +msgstr "" + +#: ../../source/reference/config.md +msgid "default value: LogLevel.INFO_LOG_LEVEL" +msgstr "" + +#: ../../source/reference/config.md +msgid "max_log_file_size" +msgstr "" + +#: ../../source/reference/config.md +msgid "Byte. 
default value: 500 * 1024 * 1024 (500MB)" +msgstr "" + +#: ../../source/reference/config.md +msgid "max_log_file_count" +msgstr "" + +#: ../../source/reference/config.md +msgid "default value: 10" +msgstr "" + +#: ../../source/reference/config.md:260 +msgid "{#FileSourceMeta}" +msgstr "" + +#: ../../source/reference/config.md:261 +msgid "FileSourceMeta" +msgstr "" + +#: ../../source/reference/config.md:262 +msgid "empty by design" +msgstr "" + +#: ../../source/reference/config.md:267 +msgid "{#ModelConfig}" +msgstr "" + +#: ../../source/reference/config.md:268 +msgid "ModelConfig" +msgstr "" + +#: ../../source/reference/config.md:269 +msgid "Config for serving model" +msgstr "" + +#: ../../source/reference/config.md +msgid "model_id" +msgstr "" + +#: ../../source/reference/config.md +msgid "Unique id of the model package" +msgstr "" + +#: ../../source/reference/config.md +msgid "base_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "Path used to cache and load model package" +msgstr "" + +#: ../../source/reference/config.md +msgid "source_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "Represent the path of the model package in the model source" +msgstr "" + +#: ../../source/reference/config.md +msgid "source_sha256" +msgstr "" + +#: ../../source/reference/config.md +msgid "Optional. 
The expect sha256 of the model package" +msgstr "" + +#: ../../source/reference/config.md +msgid "source_type" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ SourceType](#sourcetype )" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) kind.file_source_meta" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ FileSourceMeta](#filesourcemeta )" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) kind.oss_source_meta" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ OSSSourceMeta](#osssourcemeta )" +msgstr "" + +#: ../../source/reference/config.md:285 +msgid "{#OSSSourceMeta}" +msgstr "" + +#: ../../source/reference/config.md:286 +msgid "OSSSourceMeta" +msgstr "" + +#: ../../source/reference/config.md:287 +msgid "Options for a S3 Oss model source" +msgstr "" + +#: ../../source/reference/config.md +msgid "access_key" +msgstr "" + +#: ../../source/reference/config.md +msgid "Bucket access key" +msgstr "" + +#: ../../source/reference/config.md +msgid "secret_key" +msgstr "" + +#: ../../source/reference/config.md +msgid "Bucket secret key" +msgstr "" + +#: ../../source/reference/config.md +msgid "virtual_hosted" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Whether to use virtual host mode, " +"https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html" +msgstr "" + +#: ../../source/reference/config.md +msgid "bucket" +msgstr "" + +#: ../../source/reference/config.md:303 +msgid "{#ServerConfig}" +msgstr "" + +#: ../../source/reference/config.md:304 +msgid "ServerConfig" +msgstr "" + +#: ../../source/reference/config.md +msgid "feature_mapping" +msgstr "" + +#: ../../source/reference/config.md +msgid "[map ServerConfig.FeatureMappingEntry](#serverconfig-featuremappingentry )" +msgstr "" + +#: ../../source/reference/config.md 
+msgid "" +"Optional. Feature name mapping rules. Key: source or predefined feature " +"name Value: model feature name" +msgstr "" + +#: ../../source/reference/config.md +msgid "Whether to enable tls for server" +msgstr "" + +#: ../../source/reference/config.md +msgid "brpc_builtin_service_port" +msgstr "" + +#: ../../source/reference/config.md +msgid "Brpc builtin service listen port Default: disable service" +msgstr "" + +#: ../../source/reference/config.md +msgid "metrics_exposer_port" +msgstr "" + +#: ../../source/reference/config.md +msgid "Whether `/metrics` service is enable/disable." +msgstr "" + +#: ../../source/reference/config.md +msgid "worker_num" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Number of pthreads that server runs on. If this option <= 0, use default " +"value. Default: #cpu-cores" +msgstr "" + +#: ../../source/reference/config.md +msgid "max_concurrency" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Server-level max number of requests processed in parallel Default: 0 " +"(unlimited)" +msgstr "" + +#: ../../source/reference/config.md +msgid "op_exec_worker_num" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Number of pthreads that server runs to execute ops. If this option <= 0, " +"use default value. 
Default: #cpu-cores" +msgstr "" + +#: ../../source/reference/config.md:321 +msgid "{#ServerConfig.FeatureMappingEntry}" +msgstr "" + +#: ../../source/reference/config.md:322 +msgid "ServerConfig.FeatureMappingEntry" +msgstr "" + +#: ../../source/reference/config.md +msgid "key" +msgstr "" + +#: ../../source/reference/config.md +msgid "value" +msgstr "" + +#: ../../source/reference/config.md:336 +msgid "{#ServingConfig}" +msgstr "" + +#: ../../source/reference/config.md:337 +msgid "ServingConfig" +msgstr "" + +#: ../../source/reference/config.md:338 +msgid "Related config of serving" +msgstr "" + +#: ../../source/reference/config.md +msgid "Unique id of the serving service" +msgstr "" + +#: ../../source/reference/config.md +msgid "server_conf" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ ServerConfig](#serverconfig )" +msgstr "" + +#: ../../source/reference/config.md +msgid "model_conf" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ ModelConfig](#modelconfig )" +msgstr "" + +#: ../../source/reference/config.md +msgid "cluster_conf" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ ClusterConfig](#clusterconfig )" +msgstr "" + +#: ../../source/reference/config.md +msgid "feature_source_conf" +msgstr "" + +#: ../../source/reference/config.md +msgid "[ FeatureSourceConfig](#featuresourceconfig )" +msgstr "" + +#: ../../source/reference/config.md:354 +msgid "{#TlsConfig}" +msgstr "" + +#: ../../source/reference/config.md:355 +msgid "TlsConfig" +msgstr "" + +#: ../../source/reference/config.md +msgid "certificate_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "Certificate file path" +msgstr "" + +#: ../../source/reference/config.md +msgid "private_key_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "Private key file path" +msgstr "" + +#: ../../source/reference/config.md +msgid "ca_file_path" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"The trusted CA file to verify the peer's 
certificate If empty, use the " +"system default CA files" +msgstr "" + +#: ../../source/reference/config.md:375 +msgid "MockDataType" +msgstr "" + +#: ../../source/reference/config.md +msgid "Name" +msgstr "" + +#: ../../source/reference/config.md +msgid "Number" +msgstr "" + +#: ../../source/reference/config.md +msgid "INVALID_MOCK_DATA_TYPE" +msgstr "" + +#: ../../source/reference/config.md +msgid "0" +msgstr "" + +#: ../../source/reference/config.md +msgid "Placeholder for proto3 default value, do not use it." +msgstr "" + +#: ../../source/reference/config.md +msgid "MDT_RANDOM" +msgstr "" + +#: ../../source/reference/config.md +msgid "1" +msgstr "" + +#: ../../source/reference/config.md +msgid "random value for each feature" +msgstr "" + +#: ../../source/reference/config.md +msgid "MDT_FIXED" +msgstr "" + +#: ../../source/reference/config.md +msgid "2" +msgstr "" + +#: ../../source/reference/config.md +msgid "fixed value for each feature" +msgstr "" + +#: ../../source/reference/config.md:389 +msgid "LogLevel" +msgstr "" + +#: ../../source/reference/config.md +msgid "INVALID_LOG_LEVEL" +msgstr "" + +#: ../../source/reference/config.md +msgid "DEBUG_LOG_LEVEL" +msgstr "" + +#: ../../source/reference/config.md +msgid "debug" +msgstr "" + +#: ../../source/reference/config.md +msgid "INFO_LOG_LEVEL" +msgstr "" + +#: ../../source/reference/config.md +msgid "info" +msgstr "" + +#: ../../source/reference/config.md +msgid "WARN_LOG_LEVEL" +msgstr "" + +#: ../../source/reference/config.md +msgid "3" +msgstr "" + +#: ../../source/reference/config.md +msgid "warn" +msgstr "" + +#: ../../source/reference/config.md +msgid "ERROR_LOG_LEVEL" +msgstr "" + +#: ../../source/reference/config.md +msgid "4" +msgstr "" + +#: ../../source/reference/config.md +msgid "error" +msgstr "" + +#: ../../source/reference/config.md:405 +msgid "SourceType" +msgstr "" + +#: ../../source/reference/config.md:406 +msgid "Supported model source type" +msgstr "" + +#: ../../source/reference/config.md 
+msgid "INVALID_SOURCE_TYPE" +msgstr "" + +#: ../../source/reference/config.md +msgid "ST_FILE" +msgstr "" + +#: ../../source/reference/config.md +msgid "Local filesystem" +msgstr "" + +#: ../../source/reference/config.md +msgid "ST_OSS" +msgstr "" + +#: ../../source/reference/config.md +msgid "S3 OSS" +msgstr "" + +#: ../../source/reference/config.md:424 +msgid "Scalar Value Types" +msgstr "" + +#: ../../source/reference/config.md +msgid ".proto Type" +msgstr "" + +#: ../../source/reference/config.md +msgid "Notes" +msgstr "" + +#: ../../source/reference/config.md +msgid "C++ Type" +msgstr "" + +#: ../../source/reference/config.md +msgid "Java Type" +msgstr "" + +#: ../../source/reference/config.md +msgid "Python Type" +msgstr "" + +#: ../../source/reference/config.md +msgid "

double" +msgstr "" + +#: ../../source/reference/config.md +msgid "double" +msgstr "" + +#: ../../source/reference/config.md +msgid "float" +msgstr "" + +#: ../../source/reference/config.md +msgid "

float" +msgstr "" + +#: ../../source/reference/config.md +msgid "

int32" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint32 instead." +msgstr "" + +#: ../../source/reference/config.md +msgid "int32" +msgstr "" + +#: ../../source/reference/config.md +msgid "int" +msgstr "" + +#: ../../source/reference/config.md +msgid "

int64" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint64 instead." +msgstr "" + +#: ../../source/reference/config.md +msgid "int64" +msgstr "" + +#: ../../source/reference/config.md +msgid "long" +msgstr "" + +#: ../../source/reference/config.md +msgid "int/long" +msgstr "" + +#: ../../source/reference/config.md +msgid "

uint32" +msgstr "" + +#: ../../source/reference/config.md +msgid "Uses variable-length encoding." +msgstr "" + +#: ../../source/reference/config.md +msgid "uint32" +msgstr "" + +#: ../../source/reference/config.md +msgid "

uint64" +msgstr "" + +#: ../../source/reference/config.md +msgid "uint64" +msgstr "" + +#: ../../source/reference/config.md +msgid "

sint32" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int32s." +msgstr "" + +#: ../../source/reference/config.md +msgid "

sint64" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int64s." +msgstr "" + +#: ../../source/reference/config.md +msgid "

fixed32" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Always four bytes. More efficient than uint32 if values are often greater" +" than 2^28." +msgstr "" + +#: ../../source/reference/config.md +msgid "

fixed64" +msgstr "" + +#: ../../source/reference/config.md +msgid "" +"Always eight bytes. More efficient than uint64 if values are often " +"greater than 2^56." +msgstr "" + +#: ../../source/reference/config.md +msgid "

sfixed32" +msgstr "" + +#: ../../source/reference/config.md +msgid "Always four bytes." +msgstr "" + +#: ../../source/reference/config.md +msgid "

sfixed64" +msgstr "" + +#: ../../source/reference/config.md +msgid "Always eight bytes." +msgstr "" + +#: ../../source/reference/config.md +msgid "

bool" +msgstr "" + +#: ../../source/reference/config.md +msgid "bool" +msgstr "" + +#: ../../source/reference/config.md +msgid "boolean" +msgstr "" + +#: ../../source/reference/config.md +msgid "

string" +msgstr "" + +#: ../../source/reference/config.md +msgid "A string must always contain UTF-8 encoded or 7-bit ASCII text." +msgstr "" + +#: ../../source/reference/config.md +msgid "string" +msgstr "" + +#: ../../source/reference/config.md +msgid "String" +msgstr "" + +#: ../../source/reference/config.md +msgid "str/unicode" +msgstr "" + +#: ../../source/reference/config.md +msgid "

bytes" +msgstr "" + +#: ../../source/reference/config.md +msgid "May contain any arbitrary sequence of bytes." +msgstr "" + +#: ../../source/reference/config.md +msgid "ByteString" +msgstr "" + +#: ../../source/reference/config.md +msgid "str" +msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/index.po b/docs/locales/zh_CN/LC_MESSAGES/reference/index.po new file mode 100644 index 0000000..6009c0b --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/index.po @@ -0,0 +1,42 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 15:21+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/reference/index.rst:9 +msgid "API" +msgstr "" + +#: ../../source/reference/index.rst:15 +msgid "Config" +msgstr "" + +#: ../../source/reference/index.rst:21 +msgid "SPI" +msgstr "" + +#: ../../source/reference/index.rst:27 +msgid "Model" +msgstr "" + +#: ../../source/reference/index.rst:4 +msgid "Reference" +msgstr "参考" + +#: ../../source/reference/index.rst:6 +msgid "This part contains detailed explanation of Model, Configs, SPIs and APIs." +msgstr "本主题包含算子、模型、配置、服务提供接口以及服务接口的细节内容。" diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/model.po b/docs/locales/zh_CN/LC_MESSAGES/reference/model.po new file mode 100644 index 0000000..99754ad --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/model.po @@ -0,0 +1,1755 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. 
+# FIRST AUTHOR , 2023. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 16:56+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/reference/model.md:1 +msgid "SecretFlow-Serving Model" +msgstr "" + +#: ../../source/reference/model.md:3 +msgid "Table of Contents" +msgstr "" + +#: ../../source/reference/model.md:4 +msgid "Services" +msgstr "" + +#: ../../source/reference/model.md:28 ../../source/reference/model.md:155 +msgid "Messages" +msgstr "" + +#: ../../source/reference/model.md:32 +msgid "[AttrDef](#attrdef)" +msgstr "" + +#: ../../source/reference/model.md:33 +msgid "[AttrValue](#attrvalue)" +msgstr "" + +#: ../../source/reference/model.md:34 +msgid "[BoolList](#boollist)" +msgstr "" + +#: ../../source/reference/model.md:35 +msgid "[BytesList](#byteslist)" +msgstr "" + +#: ../../source/reference/model.md:36 +msgid "[DoubleList](#doublelist)" +msgstr "" + +#: ../../source/reference/model.md:37 +msgid "[FloatList](#floatlist)" +msgstr "" + +#: ../../source/reference/model.md:38 +msgid "[Int32List](#int32list)" +msgstr "" + +#: ../../source/reference/model.md:39 +msgid "[Int64List](#int64list)" +msgstr "" + +#: ../../source/reference/model.md:40 +msgid "[StringList](#stringlist)" +msgstr "" + +#: ../../source/reference/model.md:46 +msgid "[IoDef](#iodef)" +msgstr "" + +#: ../../source/reference/model.md:47 +msgid "[OpDef](#opdef)" +msgstr "" + +#: ../../source/reference/model.md:48 +msgid "[OpTag](#optag)" +msgstr "" + +#: ../../source/reference/model.md:54 +msgid "[ExecutionDef](#executiondef)" +msgstr "" + +#: ../../source/reference/model.md:55 +msgid "[GraphDef](#graphdef)" +msgstr "" + +#: ../../source/reference/model.md:56 +msgid 
"[GraphView](#graphview)" +msgstr "" + +#: ../../source/reference/model.md:57 +msgid "[NodeDef](#nodedef)" +msgstr "" + +#: ../../source/reference/model.md:58 +msgid "[NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry)" +msgstr "" + +#: ../../source/reference/model.md:59 +msgid "[NodeView](#nodeview)" +msgstr "" + +#: ../../source/reference/model.md:60 +msgid "[RuntimeConfig](#runtimeconfig)" +msgstr "" + +#: ../../source/reference/model.md:66 +msgid "[ModelBundle](#modelbundle)" +msgstr "" + +#: ../../source/reference/model.md:67 +msgid "[ModelInfo](#modelinfo)" +msgstr "" + +#: ../../source/reference/model.md:68 +msgid "[ModelManifest](#modelmanifest)" +msgstr "" + +#: ../../source/reference/model.md:77 ../../source/reference/model.md:87 +msgid "[ComputeTrace](#computetrace)" +msgstr "" + +#: ../../source/reference/model.md:78 ../../source/reference/model.md:88 +msgid "[FunctionInput](#functioninput)" +msgstr "" + +#: ../../source/reference/model.md:79 ../../source/reference/model.md:89 +msgid "[FunctionOutput](#functionoutput)" +msgstr "" + +#: ../../source/reference/model.md:80 ../../source/reference/model.md:90 +msgid "[FunctionTrace](#functiontrace)" +msgstr "" + +#: ../../source/reference/model.md:81 ../../source/reference/model.md:91 +msgid "[Scalar](#scalar)" +msgstr "" + +#: ../../source/reference/model.md:96 ../../source/reference/model.md:643 +msgid "Enums" +msgstr "" + +#: ../../source/reference/model.md:100 +msgid "[AttrType](#attrtype)" +msgstr "" + +#: ../../source/reference/model.md:109 +msgid "[DispatchType](#dispatchtype)" +msgstr "" + +#: ../../source/reference/model.md:115 +msgid "[FileFormatType](#fileformattype)" +msgstr "" + +#: ../../source/reference/model.md:121 +msgid "[DataType](#datatype)" +msgstr "" + +#: ../../source/reference/model.md:127 ../../source/reference/model.md:133 +msgid "[ExtendFunctionName](#extendfunctionname)" +msgstr "" + +#: ../../source/reference/model.md:137 +msgid "[Scalar Value Types](#scalar-value-types)" +msgstr 
"" + +#: ../../source/reference/model.md:159 +msgid "{#AttrDef}" +msgstr "" + +#: ../../source/reference/model.md:160 +msgid "AttrDef" +msgstr "" + +#: ../../source/reference/model.md:161 +msgid "The definition of an attribute." +msgstr "" + +#: ../../source/reference/model.md +msgid "Field" +msgstr "" + +#: ../../source/reference/model.md +msgid "Type" +msgstr "" + +#: ../../source/reference/model.md +msgid "Description" +msgstr "" + +#: ../../source/reference/model.md +msgid "name" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ string](#string )" +msgstr "" + +#: ../../source/reference/model.md +msgid "Must be unique among all attr of the operator." +msgstr "" + +#: ../../source/reference/model.md +msgid "desc" +msgstr "" + +#: ../../source/reference/model.md +msgid "Description of the attribute" +msgstr "" + +#: ../../source/reference/model.md +msgid "type" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ AttrType](#attrtype )" +msgstr "" + +#: ../../source/reference/model.md +msgid "none" +msgstr "" + +#: ../../source/reference/model.md +msgid "is_optional" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ bool](#bool )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"If True, when AttrValue is not provided or is_na, default_value would be " +"used. Else, AttrValue must be provided." +msgstr "" + +#: ../../source/reference/model.md +msgid "default_value" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ AttrValue](#attrvalue )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"A reasonable default for this attribute if it's optional and the user " +"does not supply a value. If not, the user must supply a value." 
+msgstr "" + +#: ../../source/reference/model.md:175 +msgid "{#AttrValue}" +msgstr "" + +#: ../../source/reference/model.md:176 +msgid "AttrValue" +msgstr "" + +#: ../../source/reference/model.md:177 +msgid "The value of an attribute" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.i32" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ int32](#int32 )" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.i64" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ int64](#int64 )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.f" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ float](#float )" +msgstr "" + +#: ../../source/reference/model.md +msgid "FLOAT" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.d" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ double](#double )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.s" +msgstr "" + +#: ../../source/reference/model.md +msgid "STRING" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.b" +msgstr "" + +#: ../../source/reference/model.md +msgid "BOOL" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.by" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ bytes](#bytes )" +msgstr "" + +#: ../../source/reference/model.md +msgid 
"BYTES" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.i32s" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ Int32List](#int32list )" +msgstr "" + +#: ../../source/reference/model.md +msgid "INTS" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.i64s" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ Int64List](#int64list )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.fs" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ FloatList](#floatlist )" +msgstr "" + +#: ../../source/reference/model.md +msgid "FLOATS" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.ds" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ DoubleList](#doublelist )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.ss" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ StringList](#stringlist )" +msgstr "" + +#: ../../source/reference/model.md +msgid "STRINGS" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.bs" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ BoolList](#boollist )" +msgstr "" + +#: ../../source/reference/model.md +msgid "BOOLS" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.bys" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ BytesList](#byteslist )" +msgstr "" + +#: ../../source/reference/model.md +msgid "BYTESS" 
+msgstr "" + +#: ../../source/reference/model.md:200 +msgid "{#BoolList}" +msgstr "" + +#: ../../source/reference/model.md:201 +msgid "BoolList" +msgstr "" + +#: ../../source/reference/model.md +msgid "data" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated bool](#bool )" +msgstr "" + +#: ../../source/reference/model.md:212 +msgid "{#BytesList}" +msgstr "" + +#: ../../source/reference/model.md:213 +msgid "BytesList" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated bytes](#bytes )" +msgstr "" + +#: ../../source/reference/model.md:224 +msgid "{#DoubleList}" +msgstr "" + +#: ../../source/reference/model.md:225 +msgid "DoubleList" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated double](#double )" +msgstr "" + +#: ../../source/reference/model.md:236 +msgid "{#FloatList}" +msgstr "" + +#: ../../source/reference/model.md:237 +msgid "FloatList" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated float](#float )" +msgstr "" + +#: ../../source/reference/model.md:248 +msgid "{#Int32List}" +msgstr "" + +#: ../../source/reference/model.md:249 +msgid "Int32List" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated int32](#int32 )" +msgstr "" + +#: ../../source/reference/model.md:260 +msgid "{#Int64List}" +msgstr "" + +#: ../../source/reference/model.md:261 +msgid "Int64List" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated int64](#int64 )" +msgstr "" + +#: ../../source/reference/model.md:272 +msgid "{#StringList}" +msgstr "" + +#: ../../source/reference/model.md:273 +msgid "StringList" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated string](#string )" +msgstr "" + +#: ../../source/reference/model.md:286 +msgid "{#IoDef}" +msgstr "" + +#: ../../source/reference/model.md:287 +msgid "IoDef" +msgstr "" + +#: ../../source/reference/model.md:288 +msgid "Define an input/output for operator." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "Must be unique among all IOs of the operator." +msgstr "" + +#: ../../source/reference/model.md +msgid "Description of the IO" +msgstr "" + +#: ../../source/reference/model.md:299 +msgid "{#OpDef}" +msgstr "" + +#: ../../source/reference/model.md:300 +msgid "OpDef" +msgstr "" + +#: ../../source/reference/model.md:301 +msgid "The definition of a operator." +msgstr "" + +#: ../../source/reference/model.md +msgid "Unique name of the op" +msgstr "" + +#: ../../source/reference/model.md +msgid "Description of the op" +msgstr "" + +#: ../../source/reference/model.md +msgid "version" +msgstr "" + +#: ../../source/reference/model.md +msgid "Version of the op" +msgstr "" + +#: ../../source/reference/model.md +msgid "tag" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ OpTag](#optag )" +msgstr "" + +#: ../../source/reference/model.md +msgid "inputs" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated IoDef](#iodef )" +msgstr "" + +#: ../../source/reference/model.md +msgid "output" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ IoDef](#iodef )" +msgstr "" + +#: ../../source/reference/model.md +msgid "attrs" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated AttrDef](#attrdef )" +msgstr "" + +#: ../../source/reference/model.md:317 +msgid "{#OpTag}" +msgstr "" + +#: ../../source/reference/model.md:318 +msgid "OpTag" +msgstr "" + +#: ../../source/reference/model.md:319 +msgid "Representation operator property" +msgstr "" + +#: ../../source/reference/model.md +msgid "returnable" +msgstr "" + +#: ../../source/reference/model.md +msgid "The operator's output can be the final result" +msgstr "" + +#: ../../source/reference/model.md +msgid "mergeable" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"The operator accept the output of operators with different participants " +"and will somehow merge them." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "session_run" +msgstr "" + +#: ../../source/reference/model.md +msgid "The operator needs to be executed in session." +msgstr "" + +#: ../../source/reference/model.md:333 +msgid "{#ExecutionDef}" +msgstr "" + +#: ../../source/reference/model.md:334 +msgid "ExecutionDef" +msgstr "" + +#: ../../source/reference/model.md:335 +msgid "" +"The definition of a execution. A execution represents a subgraph within a" +" graph that can be scheduled for execution in a specified pattern." +msgstr "" + +#: ../../source/reference/model.md +msgid "nodes" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Represents the nodes contained in this execution. Note that these node " +"names should be findable and unique within the node definitions. One node" +" can only exist in one execution and must exist in one." +msgstr "" + +#: ../../source/reference/model.md +msgid "config" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ RuntimeConfig](#runtimeconfig )" +msgstr "" + +#: ../../source/reference/model.md ../../source/reference/model.md:423 +msgid "The runtime config of the execution." +msgstr "" + +#: ../../source/reference/model.md:347 +msgid "{#GraphDef}" +msgstr "" + +#: ../../source/reference/model.md:348 +msgid "GraphDef" +msgstr "" + +#: ../../source/reference/model.md:349 +msgid "" +"The definition of a Graph. A graph consists of a set of nodes carrying " +"data and a set of executions that describes the scheduling of the graph." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "Version of the graph" +msgstr "" + +#: ../../source/reference/model.md +msgid "node_list" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated NodeDef](#nodedef )" +msgstr "" + +#: ../../source/reference/model.md +msgid "execution_list" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated ExecutionDef](#executiondef )" +msgstr "" + +#: ../../source/reference/model.md:362 +msgid "{#GraphView}" +msgstr "" + +#: ../../source/reference/model.md:363 +msgid "GraphView" +msgstr "" + +#: ../../source/reference/model.md:364 +msgid "" +"The view of a graph is used to display the structure of the graph, " +"containing only structural information and excluding the data components." +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated NodeView](#nodeview )" +msgstr "" + +#: ../../source/reference/model.md:377 +msgid "{#NodeDef}" +msgstr "" + +#: ../../source/reference/model.md:378 +msgid "NodeDef" +msgstr "" + +#: ../../source/reference/model.md:379 +msgid "The definition of a node." +msgstr "" + +#: ../../source/reference/model.md +msgid "Must be unique among all nodes of the graph." +msgstr "" + +#: ../../source/reference/model.md +msgid "op" +msgstr "" + +#: ../../source/reference/model.md +msgid "The operator name." +msgstr "" + +#: ../../source/reference/model.md +msgid "parents" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"The parent node names of the node. The order of the parent nodes should " +"match the order of the inputs of the node." +msgstr "" + +#: ../../source/reference/model.md +msgid "attr_values" +msgstr "" + +#: ../../source/reference/model.md +msgid "[map NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry )" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"The attribute values configed in the node. Note that this should include " +"all attrs defined in the corresponding OpDef." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "op_version" +msgstr "" + +#: ../../source/reference/model.md +msgid "The operator version." +msgstr "" + +#: ../../source/reference/model.md:393 +msgid "{#NodeDef.AttrValuesEntry}" +msgstr "" + +#: ../../source/reference/model.md:394 +msgid "NodeDef.AttrValuesEntry" +msgstr "" + +#: ../../source/reference/model.md +msgid "key" +msgstr "" + +#: ../../source/reference/model.md +msgid "value" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ op.AttrValue](#attrvalue )" +msgstr "" + +#: ../../source/reference/model.md:406 +msgid "{#NodeView}" +msgstr "" + +#: ../../source/reference/model.md:407 +msgid "NodeView" +msgstr "" + +#: ../../source/reference/model.md:408 +msgid "The view of a node, which could be public to other parties" +msgstr "" + +#: ../../source/reference/model.md:421 +msgid "{#RuntimeConfig}" +msgstr "" + +#: ../../source/reference/model.md:422 +msgid "RuntimeConfig" +msgstr "" + +#: ../../source/reference/model.md +msgid "dispatch_type" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ DispatchType](#dispatchtype )" +msgstr "" + +#: ../../source/reference/model.md +msgid "The dispatch type of the execution." +msgstr "" + +#: ../../source/reference/model.md +msgid "The execution need run in session(stateful) TODO: not support yet." +msgstr "" + +#: ../../source/reference/model.md +msgid "specific_flag" +msgstr "" + +#: ../../source/reference/model.md +msgid "if dispatch_type is DP_SPECIFIED, only one party should be true" +msgstr "" + +#: ../../source/reference/model.md:437 +msgid "{#ModelBundle}" +msgstr "" + +#: ../../source/reference/model.md:438 +msgid "ModelBundle" +msgstr "" + +#: ../../source/reference/model.md:439 +msgid "" +"Represents an exported secertflow model. It consists of a GraphDef and " +"extra metadata required for serving." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "graph" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ GraphDef](#graphdef )" +msgstr "" + +#: ../../source/reference/model.md:452 +msgid "{#ModelInfo}" +msgstr "" + +#: ../../source/reference/model.md:453 +msgid "ModelInfo" +msgstr "" + +#: ../../source/reference/model.md:454 +msgid "Represents a secertflow model without private data." +msgstr "" + +#: ../../source/reference/model.md +msgid "graph_view" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ GraphView](#graphview )" +msgstr "" + +#: ../../source/reference/model.md:466 +msgid "{#ModelManifest}" +msgstr "" + +#: ../../source/reference/model.md:467 +msgid "ModelManifest" +msgstr "" + +#: ../../source/reference/model.md:468 +msgid "" +"The manifest of the model package. Package format is as follows: " +"model.tar.gz ├ MANIFIEST ├ model_file └ some op meta files MANIFIEST " +"should be json format" +msgstr "" + +#: ../../source/reference/model.md +msgid "bundle_path" +msgstr "" + +#: ../../source/reference/model.md +msgid "Model bundle file path." +msgstr "" + +#: ../../source/reference/model.md +msgid "bundle_format" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ FileFormatType](#fileformattype )" +msgstr "" + +#: ../../source/reference/model.md +msgid "The format type of the model bundle file." +msgstr "" + +#: ../../source/reference/model.md:488 ../../source/reference/model.md:566 +msgid "{#ComputeTrace}" +msgstr "" + +#: ../../source/reference/model.md:489 ../../source/reference/model.md:567 +msgid "ComputeTrace" +msgstr "" + +#: ../../source/reference/model.md +msgid "The name of this Compute." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "func_traces" +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated FunctionTrace](#functiontrace )" +msgstr "" + +#: ../../source/reference/model.md:501 ../../source/reference/model.md:579 +msgid "{#FunctionInput}" +msgstr "" + +#: ../../source/reference/model.md:502 ../../source/reference/model.md:580 +msgid "FunctionInput" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.data_id" +msgstr "" + +#: ../../source/reference/model.md +msgid "'0' means root input data" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.custom_scalar" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ Scalar](#scalar )" +msgstr "" + +#: ../../source/reference/model.md:514 ../../source/reference/model.md:592 +msgid "{#FunctionOutput}" +msgstr "" + +#: ../../source/reference/model.md:515 ../../source/reference/model.md:593 +msgid "FunctionOutput" +msgstr "" + +#: ../../source/reference/model.md +msgid "data_id" +msgstr "" + +#: ../../source/reference/model.md:526 ../../source/reference/model.md:604 +msgid "{#FunctionTrace}" +msgstr "" + +#: ../../source/reference/model.md:527 ../../source/reference/model.md:605 +msgid "FunctionTrace" +msgstr "" + +#: ../../source/reference/model.md +msgid "The Function name." +msgstr "" + +#: ../../source/reference/model.md +msgid "option_bytes" +msgstr "" + +#: ../../source/reference/model.md +msgid "The serialized function options." +msgstr "" + +#: ../../source/reference/model.md +msgid "[repeated FunctionInput](#functioninput )" +msgstr "" + +#: ../../source/reference/model.md +msgid "Inputs of this function." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "[ FunctionOutput](#functionoutput )" +msgstr "" + +#: ../../source/reference/model.md +msgid "Output of this function." +msgstr "" + +#: ../../source/reference/model.md:541 ../../source/reference/model.md:619 +msgid "{#Scalar}" +msgstr "" + +#: ../../source/reference/model.md:542 ../../source/reference/model.md:620 +msgid "Scalar" +msgstr "" + +#: ../../source/reference/model.md:543 ../../source/reference/model.md:621 +msgid "Represents a single value with a specific data type." +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.i8" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT8." +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.ui8" +msgstr "" + +#: ../../source/reference/model.md +msgid "UINT8" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.i16" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT16" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.ui16" +msgstr "" + +#: ../../source/reference/model.md +msgid "UINT16" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT32" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.ui32" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ uint32](#uint32 )" +msgstr "" + +#: ../../source/reference/model.md +msgid "UINT32" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT64" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"[**oneof**](https://developers.google.com/protocol-" +"buffers/docs/proto3#oneof) value.ui64" 
+msgstr "" + +#: ../../source/reference/model.md +msgid "[ uint64](#uint64 )" +msgstr "" + +#: ../../source/reference/model.md +msgid "UINT64" +msgstr "" + +#: ../../source/reference/model.md +msgid "DOUBLE" +msgstr "" + +#: ../../source/reference/model.md:647 +msgid "{#AttrType}" +msgstr "" + +#: ../../source/reference/model.md:648 +msgid "AttrType" +msgstr "" + +#: ../../source/reference/model.md:649 +msgid "Supported attribute types." +msgstr "" + +#: ../../source/reference/model.md +msgid "Name" +msgstr "" + +#: ../../source/reference/model.md +msgid "Number" +msgstr "" + +#: ../../source/reference/model.md +msgid "UNKNOWN_AT_TYPE" +msgstr "" + +#: ../../source/reference/model.md +msgid "0" +msgstr "" + +#: ../../source/reference/model.md +msgid "Placeholder for proto3 default value, do not use it." +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_INT32" +msgstr "" + +#: ../../source/reference/model.md +msgid "1" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_INT64" +msgstr "" + +#: ../../source/reference/model.md +msgid "2" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_FLOAT" +msgstr "" + +#: ../../source/reference/model.md +msgid "3" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_DOUBLE" +msgstr "" + +#: ../../source/reference/model.md +msgid "4" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_STRING" +msgstr "" + +#: ../../source/reference/model.md +msgid "5" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_BOOL" +msgstr "" + +#: ../../source/reference/model.md +msgid "6" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_BYTES" +msgstr "" + +#: ../../source/reference/model.md +msgid "7" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_INT32_LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "11" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT32 LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_INT64_LIST" +msgstr "" + +#: 
../../source/reference/model.md +msgid "12" +msgstr "" + +#: ../../source/reference/model.md +msgid "INT64 LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_FLOAT_LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "13" +msgstr "" + +#: ../../source/reference/model.md +msgid "FLOAT LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_DOUBLE_LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "14" +msgstr "" + +#: ../../source/reference/model.md +msgid "DOUBLE LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_STRING_LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "15" +msgstr "" + +#: ../../source/reference/model.md +msgid "STRING LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_BOOL_LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "16" +msgstr "" + +#: ../../source/reference/model.md +msgid "BOOL LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "AT_BYTES_LIST" +msgstr "" + +#: ../../source/reference/model.md +msgid "17" +msgstr "" + +#: ../../source/reference/model.md +msgid "BYTES LIST" +msgstr "" + +#: ../../source/reference/model.md:676 +msgid "{#DispatchType}" +msgstr "" + +#: ../../source/reference/model.md:677 +msgid "DispatchType" +msgstr "" + +#: ../../source/reference/model.md:678 +msgid "Supported dispatch type" +msgstr "" + +#: ../../source/reference/model.md +msgid "UNKNOWN_DP_TYPE" +msgstr "" + +#: ../../source/reference/model.md +msgid "DP_ALL" +msgstr "" + +#: ../../source/reference/model.md +msgid "Dispatch all participants." +msgstr "" + +#: ../../source/reference/model.md +msgid "DP_ANYONE" +msgstr "" + +#: ../../source/reference/model.md +msgid "Dispatch any participant." +msgstr "" + +#: ../../source/reference/model.md +msgid "DP_SPECIFIED" +msgstr "" + +#: ../../source/reference/model.md +msgid "Dispatch specified participant." 
+msgstr "" + +#: ../../source/reference/model.md:692 +msgid "{#FileFormatType}" +msgstr "" + +#: ../../source/reference/model.md:693 +msgid "FileFormatType" +msgstr "" + +#: ../../source/reference/model.md:694 +msgid "Support model file format" +msgstr "" + +#: ../../source/reference/model.md +msgid "UNKNOWN_FF_TYPE" +msgstr "" + +#: ../../source/reference/model.md +msgid "FF_PB" +msgstr "" + +#: ../../source/reference/model.md +msgid "Protobuf" +msgstr "" + +#: ../../source/reference/model.md +msgid "FF_JSON" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Json It is recommended to use protobuf's official json serialization " +"method to ensure compatibility" +msgstr "" + +#: ../../source/reference/model.md:707 +msgid "{#DataType}" +msgstr "" + +#: ../../source/reference/model.md:708 +msgid "DataType" +msgstr "" + +#: ../../source/reference/model.md:709 +msgid "" +"Mapping arrow::DataType " +"`https://arrow.apache.org/docs/cpp/api/datatype.html`." +msgstr "" + +#: ../../source/reference/model.md +msgid "UNKNOWN_DT_TYPE" +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_BOOL" +msgstr "" + +#: ../../source/reference/model.md +msgid "Boolean as 1 bit, LSB bit-packed ordering." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_UINT8" +msgstr "" + +#: ../../source/reference/model.md +msgid "Unsigned 8-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_INT8" +msgstr "" + +#: ../../source/reference/model.md +msgid "Signed 8-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_UINT16" +msgstr "" + +#: ../../source/reference/model.md +msgid "Unsigned 16-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_INT16" +msgstr "" + +#: ../../source/reference/model.md +msgid "Signed 16-bit little-endian integer." 
+msgstr "" + +#: ../../source/reference/model.md +msgid "DT_UINT32" +msgstr "" + +#: ../../source/reference/model.md +msgid "Unsigned 32-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_INT32" +msgstr "" + +#: ../../source/reference/model.md +msgid "Signed 32-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_UINT64" +msgstr "" + +#: ../../source/reference/model.md +msgid "8" +msgstr "" + +#: ../../source/reference/model.md +msgid "Unsigned 64-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_INT64" +msgstr "" + +#: ../../source/reference/model.md +msgid "9" +msgstr "" + +#: ../../source/reference/model.md +msgid "Signed 64-bit little-endian integer." +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_FLOAT" +msgstr "" + +#: ../../source/reference/model.md +msgid "4-byte floating point value" +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_DOUBLE" +msgstr "" + +#: ../../source/reference/model.md +msgid "8-byte floating point value" +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_STRING" +msgstr "" + +#: ../../source/reference/model.md +msgid "UTF8 variable-length string as List" +msgstr "" + +#: ../../source/reference/model.md +msgid "DT_BINARY" +msgstr "" + +#: ../../source/reference/model.md +msgid "Variable-length bytes (no guarantee of UTF8-ness)" +msgstr "" + +#: ../../source/reference/model.md:734 ../../source/reference/model.md:751 +msgid "{#ExtendFunctionName}" +msgstr "" + +#: ../../source/reference/model.md:735 ../../source/reference/model.md:752 +msgid "ExtendFunctionName" +msgstr "" + +#: ../../source/reference/model.md +msgid "UNKOWN_EX_FUNCTION_NAME" +msgstr "" + +#: ../../source/reference/model.md +msgid "Placeholder for proto3 default value, do not use it" +msgstr "" + +#: ../../source/reference/model.md +msgid "EFN_TB_COLUMN" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Get colunm from 
table(record_batch). see " +"https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch6columnEi" +msgstr "" + +#: ../../source/reference/model.md +msgid "EFN_TB_ADD_COLUMN" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Add colum to table(record_batch). see " +"https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiNSt6stringERKNSt10shared_ptrI5ArrayEE" +msgstr "" + +#: ../../source/reference/model.md +msgid "EFN_TB_REMOVE_COLUMN" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Remove colunm from table(record_batch). see " +"https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch12RemoveColumnEi" +msgstr "" + +#: ../../source/reference/model.md +msgid "EFN_TB_SET_COLUMN" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Set colunm to table(record_batch). see " +"https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE" +msgstr "" + +#: ../../source/reference/model.md:767 +msgid "Scalar Value Types" +msgstr "" + +#: ../../source/reference/model.md +msgid ".proto Type" +msgstr "" + +#: ../../source/reference/model.md +msgid "Notes" +msgstr "" + +#: ../../source/reference/model.md +msgid "C++ Type" +msgstr "" + +#: ../../source/reference/model.md +msgid "Java Type" +msgstr "" + +#: ../../source/reference/model.md +msgid "Python Type" +msgstr "" + +#: ../../source/reference/model.md +msgid "

double" +msgstr "" + +#: ../../source/reference/model.md +msgid "double" +msgstr "" + +#: ../../source/reference/model.md +msgid "float" +msgstr "" + +#: ../../source/reference/model.md +msgid "

float" +msgstr "" + +#: ../../source/reference/model.md +msgid "

int32" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint32 instead." +msgstr "" + +#: ../../source/reference/model.md +msgid "int32" +msgstr "" + +#: ../../source/reference/model.md +msgid "int" +msgstr "" + +#: ../../source/reference/model.md +msgid "

int64" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint64 instead." +msgstr "" + +#: ../../source/reference/model.md +msgid "int64" +msgstr "" + +#: ../../source/reference/model.md +msgid "long" +msgstr "" + +#: ../../source/reference/model.md +msgid "int/long" +msgstr "" + +#: ../../source/reference/model.md +msgid "

uint32" +msgstr "" + +#: ../../source/reference/model.md +msgid "Uses variable-length encoding." +msgstr "" + +#: ../../source/reference/model.md +msgid "uint32" +msgstr "" + +#: ../../source/reference/model.md +msgid "

uint64" +msgstr "" + +#: ../../source/reference/model.md +msgid "uint64" +msgstr "" + +#: ../../source/reference/model.md +msgid "

sint32" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int32s." +msgstr "" + +#: ../../source/reference/model.md +msgid "

sint64" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int64s." +msgstr "" + +#: ../../source/reference/model.md +msgid "

fixed32" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Always four bytes. More efficient than uint32 if values are often greater" +" than 2^28." +msgstr "" + +#: ../../source/reference/model.md +msgid "

fixed64" +msgstr "" + +#: ../../source/reference/model.md +msgid "" +"Always eight bytes. More efficient than uint64 if values are often " +"greater than 2^56." +msgstr "" + +#: ../../source/reference/model.md +msgid "

sfixed32" +msgstr "" + +#: ../../source/reference/model.md +msgid "Always four bytes." +msgstr "" + +#: ../../source/reference/model.md +msgid "

sfixed64" +msgstr "" + +#: ../../source/reference/model.md +msgid "Always eight bytes." +msgstr "" + +#: ../../source/reference/model.md +msgid "

bool" +msgstr "" + +#: ../../source/reference/model.md +msgid "bool" +msgstr "" + +#: ../../source/reference/model.md +msgid "boolean" +msgstr "" + +#: ../../source/reference/model.md +msgid "

string" +msgstr "" + +#: ../../source/reference/model.md +msgid "A string must always contain UTF-8 encoded or 7-bit ASCII text." +msgstr "" + +#: ../../source/reference/model.md +msgid "string" +msgstr "" + +#: ../../source/reference/model.md +msgid "String" +msgstr "" + +#: ../../source/reference/model.md +msgid "str/unicode" +msgstr "" + +#: ../../source/reference/model.md +msgid "

bytes" +msgstr "" + +#: ../../source/reference/model.md +msgid "May contain any arbitrary sequence of bytes." +msgstr "" + +#: ../../source/reference/model.md +msgid "ByteString" +msgstr "" + +#: ../../source/reference/model.md +msgid "str" +msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/spi.po b/docs/locales/zh_CN/LC_MESSAGES/reference/spi.po new file mode 100644 index 0000000..ed566a5 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/spi.po @@ -0,0 +1,834 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-05 11:07+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/reference/spi.md:1 +msgid "SecretFlow-Serving SPI" +msgstr "" + +#: ../../source/reference/spi.md:3 +msgid "Table of Contents" +msgstr "" + +#: ../../source/reference/spi.md:4 +msgid "Services" +msgstr "" + +#: ../../source/reference/spi.md:8 +msgid "[BatchFeatureService](#batchfeatureservice)" +msgstr "" + +#: ../../source/reference/spi.md:22 ../../source/reference/spi.md:95 +msgid "Messages" +msgstr "" + +#: ../../source/reference/spi.md:26 +msgid "[BatchFetchFeatureRequest](#batchfetchfeaturerequest)" +msgstr "" + +#: ../../source/reference/spi.md:27 +msgid "[BatchFetchFeatureResponse](#batchfetchfeatureresponse)" +msgstr "" + +#: ../../source/reference/spi.md:33 +msgid "[Header](#header)" +msgstr "" + +#: ../../source/reference/spi.md:34 +msgid "[Header.DataEntry](#header-dataentry)" +msgstr "" + +#: ../../source/reference/spi.md:35 +msgid "[Status](#status)" +msgstr "" + +#: 
../../source/reference/spi.md:44 +msgid "[Feature](#feature)" +msgstr "" + +#: ../../source/reference/spi.md:45 +msgid "[FeatureField](#featurefield)" +msgstr "" + +#: ../../source/reference/spi.md:46 +msgid "[FeatureParam](#featureparam)" +msgstr "" + +#: ../../source/reference/spi.md:47 +msgid "[FeatureValue](#featurevalue)" +msgstr "" + +#: ../../source/reference/spi.md:52 ../../source/reference/spi.md:299 +msgid "Enums" +msgstr "" + +#: ../../source/reference/spi.md:62 +msgid "[ErrorCode](#errorcode)" +msgstr "" + +#: ../../source/reference/spi.md:68 +msgid "[FieldType](#fieldtype)" +msgstr "" + +#: ../../source/reference/spi.md:72 +msgid "[Scalar Value Types](#scalar-value-types)" +msgstr "" + +#: ../../source/reference/spi.md:75 +msgid "{#BatchFeatureService}" +msgstr "" + +#: ../../source/reference/spi.md:76 +msgid "BatchFeatureService" +msgstr "" + +#: ../../source/reference/spi.md:77 +msgid "BatchFeatureService provides access to fetch features." +msgstr "" + +#: ../../source/reference/spi.md:79 +msgid "BatchFetchFeature" +msgstr "" + +#: ../../source/reference/spi.md:81 +msgid "" +"**rpc** " +"BatchFetchFeature([BatchFetchFeatureRequest](#batchfetchfeaturerequest))" +" [BatchFetchFeatureResponse](#batchfetchfeatureresponse)" +msgstr "" + +#: ../../source/reference/spi.md:99 +msgid "{#BatchFetchFeatureRequest}" +msgstr "" + +#: ../../source/reference/spi.md:100 +msgid "BatchFetchFeatureRequest" +msgstr "" + +#: ../../source/reference/spi.md:101 +msgid "BatchFetchFeature request containing one or more requests. examples:" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Field" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Type" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Description" +msgstr "" + +#: ../../source/reference/spi.md +msgid "header" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ Header](#header )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Custom data passed by the Predict request's header." 
+msgstr "" + +#: ../../source/reference/spi.md +msgid "model_service_id" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ string](#string )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Model service specification." +msgstr "" + +#: ../../source/reference/spi.md +msgid "party_id" +msgstr "" + +#: ../../source/reference/spi.md +msgid "The request party id." +msgstr "" + +#: ../../source/reference/spi.md +msgid "feature_fields" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated secretflow.serving.FeatureField](#featurefield )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Request feature field list" +msgstr "" + +#: ../../source/reference/spi.md +msgid "param" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ secretflow.serving.FeatureParam](#featureparam )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Custom query paramters for fetch features" +msgstr "" + +#: ../../source/reference/spi.md:144 +msgid "{#BatchFetchFeatureResponse}" +msgstr "" + +#: ../../source/reference/spi.md:145 +msgid "BatchFetchFeatureResponse" +msgstr "" + +#: ../../source/reference/spi.md:146 +msgid "" +"BatchFetchFeatureResponse response containing one or more responses. " +"examples:" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Custom data." 
+msgstr "" + +#: ../../source/reference/spi.md +msgid "status" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ Status](#status )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "none" +msgstr "" + +#: ../../source/reference/spi.md +msgid "features" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated secretflow.serving.Feature](#feature )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Should include all the features mentioned in the " +"BatchFetchFeatureRequest.feature_fields" +msgstr "" + +#: ../../source/reference/spi.md:200 +msgid "{#Header}" +msgstr "" + +#: ../../source/reference/spi.md:201 +msgid "Header" +msgstr "" + +#: ../../source/reference/spi.md:202 +msgid "Header containing custom data" +msgstr "" + +#: ../../source/reference/spi.md +msgid "data" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[map Header.DataEntry](#header-dataentry )" +msgstr "" + +#: ../../source/reference/spi.md:212 +msgid "{#Header.DataEntry}" +msgstr "" + +#: ../../source/reference/spi.md:213 +msgid "Header.DataEntry" +msgstr "" + +#: ../../source/reference/spi.md +msgid "key" +msgstr "" + +#: ../../source/reference/spi.md +msgid "value" +msgstr "" + +#: ../../source/reference/spi.md:225 +msgid "{#Status}" +msgstr "" + +#: ../../source/reference/spi.md:226 +msgid "Status" +msgstr "" + +#: ../../source/reference/spi.md:227 +msgid "Represents the status of spi request" +msgstr "" + +#: ../../source/reference/spi.md +msgid "code" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ int32](#int32 )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"code value reference `ErrorCode` in " +"secretflow_serving/spis/error_code.proto" +msgstr "" + +#: ../../source/reference/spi.md +msgid "msg" +msgstr "" + +#: ../../source/reference/spi.md:242 +msgid "{#Feature}" +msgstr "" + +#: ../../source/reference/spi.md:243 +msgid "Feature" +msgstr "" + +#: ../../source/reference/spi.md:244 +msgid "The definition of a feature" +msgstr "" + 
+#: ../../source/reference/spi.md +msgid "field" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ FeatureField](#featurefield )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ FeatureValue](#featurevalue )" +msgstr "" + +#: ../../source/reference/spi.md:255 +msgid "{#FeatureField}" +msgstr "" + +#: ../../source/reference/spi.md:256 +msgid "FeatureField" +msgstr "" + +#: ../../source/reference/spi.md:257 +msgid "The definition of a feature field." +msgstr "" + +#: ../../source/reference/spi.md +msgid "name" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Unique name of the feature" +msgstr "" + +#: ../../source/reference/spi.md +msgid "type" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[ FieldType](#fieldtype )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Field type of the feature" +msgstr "" + +#: ../../source/reference/spi.md:268 +msgid "{#FeatureParam}" +msgstr "" + +#: ../../source/reference/spi.md:269 +msgid "FeatureParam" +msgstr "" + +#: ../../source/reference/spi.md:270 +msgid "The param for fetch features" +msgstr "" + +#: ../../source/reference/spi.md +msgid "query_datas" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated string](#string )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"The serialized datas for query features. Each one for query one row of " +"features." +msgstr "" + +#: ../../source/reference/spi.md +msgid "query_context" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Optional. Represents the common part of the query datas." 
+msgstr "" + +#: ../../source/reference/spi.md:281 +msgid "{#FeatureValue}" +msgstr "" + +#: ../../source/reference/spi.md:282 +msgid "FeatureValue" +msgstr "" + +#: ../../source/reference/spi.md:283 +msgid "The value of a feature" +msgstr "" + +#: ../../source/reference/spi.md +msgid "i32s" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated int32](#int32 )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "int list" +msgstr "" + +#: ../../source/reference/spi.md +msgid "i64s" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated int64](#int64 )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "fs" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated float](#float )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "float list" +msgstr "" + +#: ../../source/reference/spi.md +msgid "ds" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated double](#double )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "ss" +msgstr "" + +#: ../../source/reference/spi.md +msgid "string list" +msgstr "" + +#: ../../source/reference/spi.md +msgid "bs" +msgstr "" + +#: ../../source/reference/spi.md +msgid "[repeated bool](#bool )" +msgstr "" + +#: ../../source/reference/spi.md +msgid "bool list" +msgstr "" + +#: ../../source/reference/spi.md:307 +msgid "ErrorCode" +msgstr "" + +#: ../../source/reference/spi.md:308 +msgid "ONLY for Reference by ResponseHeader It's subset of google.rpc.Code" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Name" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Number" +msgstr "" + +#: ../../source/reference/spi.md +msgid "OK" +msgstr "" + +#: ../../source/reference/spi.md +msgid "0" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Not an error; returned on success" +msgstr "" + +#: ../../source/reference/spi.md:315 +msgid "" +"HTTP Mapping: 200 OK | | INVALID_ARGUMENT | 3 | The client specified an " +"invalid argument. 
Note that this differs from `FAILED_PRECONDITION`. " +"`INVALID_ARGUMENT` indicates arguments that are problematic regardless of" +" the state of the system (e.g., a malformed file name)." +msgstr "" + +#: ../../source/reference/spi.md:318 +msgid "" +"HTTP Mapping: 400 Bad Request | | DEADLINE_EXCEEDED | 4 | The deadline " +"expired before the operation could complete. For operations that change " +"the state of the system, this error may be returned even if the operation" +" has completed successfully. For example, a successful response from a " +"server could have been delayed long enough for the deadline to expire." +msgstr "" + +#: ../../source/reference/spi.md:321 +msgid "" +"HTTP Mapping: 504 Gateway Timeout | | NOT_FOUND | 5 | Some requested " +"entity (e.g., file or directory) was not found." +msgstr "" + +#: ../../source/reference/spi.md:324 +msgid "" +"Note to server developers: if a request is denied for an entire class of " +"users, such as gradual feature rollout or undocumented whitelist, " +"`NOT_FOUND` may be used. If a request is denied for some users within a " +"class of users, such as user-based access control, `PERMISSION_DENIED` " +"must be used." +msgstr "" + +#: ../../source/reference/spi.md:326 +msgid "" +"HTTP Mapping: 404 Not Found | | INTERNAL_ERROR | 13 | Internal errors. " +"This means that some invariants expected by the underlying system have " +"been broken. This error code is reserved for serious errors." +msgstr "" + +#: ../../source/reference/spi.md:329 +msgid "" +"HTTP Mapping: 500 Internal Server Error | | UNAUTHENTICATED | 16 | The " +"request does not have valid authentication credentials for the operation." +msgstr "" + +#: ../../source/reference/spi.md:332 +msgid "HTTP Mapping: 401 Unauthorized |" +msgstr "" + +#: ../../source/reference/spi.md:339 +msgid "FieldType" +msgstr "" + +#: ../../source/reference/spi.md:340 +msgid "Supported feature field type." 
+msgstr "" + +#: ../../source/reference/spi.md +msgid "UNKNOWN_FIELD_TYPE" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Placeholder for proto3 default value, do not use it." +msgstr "" + +#: ../../source/reference/spi.md +msgid "FIELD_BOOL" +msgstr "" + +#: ../../source/reference/spi.md +msgid "1" +msgstr "" + +#: ../../source/reference/spi.md +msgid "BOOL" +msgstr "" + +#: ../../source/reference/spi.md +msgid "FIELD_INT32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "2" +msgstr "" + +#: ../../source/reference/spi.md +msgid "INT32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "FIELD_INT64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "3" +msgstr "" + +#: ../../source/reference/spi.md +msgid "INT64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "FIELD_FLOAT" +msgstr "" + +#: ../../source/reference/spi.md +msgid "4" +msgstr "" + +#: ../../source/reference/spi.md +msgid "FLOAT" +msgstr "" + +#: ../../source/reference/spi.md +msgid "FIELD_DOUBLE" +msgstr "" + +#: ../../source/reference/spi.md +msgid "5" +msgstr "" + +#: ../../source/reference/spi.md +msgid "DOUBLE" +msgstr "" + +#: ../../source/reference/spi.md +msgid "FIELD_STRING" +msgstr "" + +#: ../../source/reference/spi.md +msgid "6" +msgstr "" + +#: ../../source/reference/spi.md +msgid "STRING" +msgstr "" + +#: ../../source/reference/spi.md:356 +msgid "Scalar Value Types" +msgstr "" + +#: ../../source/reference/spi.md +msgid ".proto Type" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Notes" +msgstr "" + +#: ../../source/reference/spi.md +msgid "C++ Type" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Java Type" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Python Type" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

double" +msgstr "" + +#: ../../source/reference/spi.md +msgid "double" +msgstr "" + +#: ../../source/reference/spi.md +msgid "float" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

float" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

int32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint32 instead." +msgstr "" + +#: ../../source/reference/spi.md +msgid "int32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "int" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

int64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Uses variable-length encoding. Inefficient for encoding negative numbers " +"– if your field is likely to have negative values, use sint64 instead." +msgstr "" + +#: ../../source/reference/spi.md +msgid "int64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "long" +msgstr "" + +#: ../../source/reference/spi.md +msgid "int/long" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

uint32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Uses variable-length encoding." +msgstr "" + +#: ../../source/reference/spi.md +msgid "uint32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

uint64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "uint64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

sint32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int32s." +msgstr "" + +#: ../../source/reference/spi.md +msgid "

sint64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Uses variable-length encoding. Signed int value. These more efficiently " +"encode negative numbers than regular int64s." +msgstr "" + +#: ../../source/reference/spi.md +msgid "

fixed32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Always four bytes. More efficient than uint32 if values are often greater" +" than 2^28." +msgstr "" + +#: ../../source/reference/spi.md +msgid "

fixed64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "" +"Always eight bytes. More efficient than uint64 if values are often " +"greater than 2^56." +msgstr "" + +#: ../../source/reference/spi.md +msgid "

sfixed32" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Always four bytes." +msgstr "" + +#: ../../source/reference/spi.md +msgid "

sfixed64" +msgstr "" + +#: ../../source/reference/spi.md +msgid "Always eight bytes." +msgstr "" + +#: ../../source/reference/spi.md +msgid "

bool" +msgstr "" + +#: ../../source/reference/spi.md +msgid "bool" +msgstr "" + +#: ../../source/reference/spi.md +msgid "boolean" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

string" +msgstr "" + +#: ../../source/reference/spi.md +msgid "A string must always contain UTF-8 encoded or 7-bit ASCII text." +msgstr "" + +#: ../../source/reference/spi.md +msgid "string" +msgstr "" + +#: ../../source/reference/spi.md +msgid "String" +msgstr "" + +#: ../../source/reference/spi.md +msgid "str/unicode" +msgstr "" + +#: ../../source/reference/spi.md +msgid "

bytes" +msgstr "" + +#: ../../source/reference/spi.md +msgid "May contain any arbitrary sequence of bytes." +msgstr "" + +#: ../../source/reference/spi.md +msgid "ByteString" +msgstr "" + +#: ../../source/reference/spi.md +msgid "str" +msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po new file mode 100644 index 0000000..56992a6 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po @@ -0,0 +1,232 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 16:56+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/topics/deployment/deployment.rst:3 +msgid "How to deployment SecretFlow-Serving" +msgstr "如何部署 SecretFlow-Serving" + +#: ../../source/topics/deployment/deployment.rst:5 +msgid "" +"This document describes how to deploy SecretFlow-Serving with docker, " +"it's basically same with :doc:`/intro/tutorial`, but deployed in multi-" +"machine." +msgstr "" +"此文档说明了如何通过 docker 部署 SecretFlow-Serving,本文档的内容类似 :doc:`/intro/tutorial` " +"的内容但 SecretFlow-Serving 将会使用多机进行部署。" + +#: ../../source/topics/deployment/deployment.rst:7 +msgid "" +"Before start this doc, we assume that the reader has some experience " +"using the docker-compose utility. If you are new to Docker Compose, " +"please consider reviewing the `official Docker Compose overview " +"`_, or checking out the `Getting " +"Started guide `_." 
+msgstr "" +"在开始之前,我们假设读者具有使用 docker-compose 的相关经验。如果您没有相关经验,可以考虑参考 `official Docker " +"Compose overview `_,或参考 `Getting " +"Started guide `_。" + +#: ../../source/topics/deployment/deployment.rst:10 +msgid "Deployment Diagram" +msgstr "部署图" + +#: ../../source/topics/deployment/deployment.rst:12 +msgid "" +"The deployment diagram of the SecretFlow-Serving system that we plan to " +"deploy is shown as the following figure, it involves a total of two " +"party, including two parties named ``Alice`` and ``Bob``. We use two " +"machines to simulate different parties." +msgstr "" +"SecretFlow-Serving系统的部署图如下所示,涉及到总共两个参与方,包括名为 ``Alice`` 和 ``Bob`` " +"的两个参与方。这里使用两台计算机来模拟不同的参与方。" + +#: ../../source/topics/deployment/deployment.rst:17 +msgid "" +"The SecretFlow-Serving is served through the HTTP protocol. It is " +"recommended to use HTTPS instead in production environments. Please check" +" :ref:`TLS Configuration ` for details." +msgstr "" +"本示例的SecretFlow-Serving 通过 HTTP 协议提供服务。然而对于生产环境,建议使用 HTTPS 协议来代替。请查看 " +":ref:`TLS 配置 ` 获取详细信息。" + +#: ../../source/topics/deployment/deployment.rst:20 +msgid "Step 1: Deploy SecretFlow-Serving" +msgstr "步骤1:部署 SecretFlow-Serving" + +#: ../../source/topics/deployment/deployment.rst:22 +msgid "" +"Here we present how to deploy serving in party ``Alice``, it's same with " +"party ``Bob``." +msgstr "这里我们会展示 ``Alice`` 如何部署 serving ,对于 ``Bob`` 流程基本类似。" + +#: ../../source/topics/deployment/deployment.rst:25 +msgid "1.1 Create a Workspace" +msgstr "1.1 创建工作空间" + +#: ../../source/topics/deployment/deployment.rst:32 +msgid "" +"Here, we use the model file from the \"examples\" directory as a " +"demonstration and place it in the \"serving\" directory. Please replace " +"the following path with the actual path according to your situation." 
+msgstr "" +"我们这里使用 \"examples\" 目录下的模型包作为示例文件然后将其放置在 \"serving\" " +"目录下,请根据您的实际情况替换下面命令中的路径地址。" + +#: ../../source/topics/deployment/deployment.rst:41 +msgid "For ``Bob`` should use model file `serving/examples/bob/glm-test.tar.gz`." +msgstr "对 ``Bob`` 来说可以使用 `serving/examples/bob/glm-test.tar.gz` 模型包。" + +#: ../../source/topics/deployment/deployment.rst:45 +msgid "1.2 Create Serving config file" +msgstr "1.2 创建 Serving 配置文件" + +#: ../../source/topics/deployment/deployment.rst:47 +msgid "" +"Create a file called ``serving.config`` in your workspace and paste the " +"following code in:" +msgstr "在您的工作空间目录下创建一个名为 ``serving.config`` 的文件,并将下面的内容添加到其的内容中:" + +#: ../../source/topics/deployment/deployment.rst:92 +msgid "See :ref:`Serving Config ` for more config information" +msgstr "请参考 :ref:`Serving Config ` 来获取更多信息。" + +#: ../../source/topics/deployment/deployment.rst:96 +msgid "" +"The above configuration is referenced from `alice-serving-config " +"`_." +msgstr "" +"以上配置内容参考自 `alice-serving-config " +"`_。" + +#: ../../source/topics/deployment/deployment.rst:98 +msgid "" +"For ``Bob``, you should refer to `bob-serving-config " +"`_" +" ." +msgstr "" +"对于 ``Bob`` ,可以参考 `bob-serving-config " +"`_。" + +#: ../../source/topics/deployment/deployment.rst:102 +msgid "1.3 Create logging config file" +msgstr "1.3 创建 logging 配置文件" + +#: ../../source/topics/deployment/deployment.rst:104 +msgid "" +"Create a file called ``logging.config`` in your workspace and paste the " +"following code in:" +msgstr "在您的工作空间目录下创建一个名为 ``logging.config`` 的文件,并将下面的内容添加到其的内容中:" + +#: ../../source/topics/deployment/deployment.rst:115 +msgid "" +"See :ref:`Logging Config ` for more logging config " +"information." +msgstr "" +"请参考 :ref:`Logging Config ` 获得更多关于 logging " +"配置的信息。" + +#: ../../source/topics/deployment/deployment.rst:119 +msgid "" +"The above configuration is referenced from `alice-logging-config " +"`_." 
+msgstr "" +"以上配置内容参考自 `alice-logging-config " +"`_。" + +#: ../../source/topics/deployment/deployment.rst:121 +msgid "" +"For ``Bob``, you should refer to `bob-logging-config " +"`_" +" ." +msgstr "" +"对于 ``Bob`` ,可以参考 `bob-logging-config " +"`_。" + +#: ../../source/topics/deployment/deployment.rst:125 +msgid "1.4 Create docker-compose file" +msgstr "1.4 创建 docker-compose 文件" + +#: ../../source/topics/deployment/deployment.rst:127 +msgid "" +"Create a file called ``docker-compose.yaml`` in your workspace and paste " +"the following code in:" +msgstr "在您的工作空间目录下创建一个名为 ``docker-compose.yaml`` 文件并将下列内容添加到其中:" + +#: ../../source/topics/deployment/deployment.rst:149 +msgid "" +"``__ALICE_PORT__`` is the published port on the host machine which is " +"used for SecretFlow-Serving service to listen on, you need to replace it " +"with an accessible port number. In this case, we have designated it as " +"``9010`` for ``Alice``, ``9011`` for ``Bob``." +msgstr "" +"``__ALICE_PORT__`` 是 Alice 的 SecretFlow-Serving " +"服务在宿主机上的监听端口,您需要用一个可访问的端口号替换它,这里为 ``Alice`` 设置 ``9010``,为 ``Bob`` 设置 " +"``9011``。" + +#: ../../source/topics/deployment/deployment.rst:153 +msgid "Step 2: Start Serving Service" +msgstr "步骤 2:启动 Serving 服务" + +#: ../../source/topics/deployment/deployment.rst:155 +msgid "The file your workspace should be as follows:" +msgstr "您工作区的文件应如下所示:" + +#: ../../source/topics/deployment/deployment.rst:164 +msgid "Then you can start serving service by running docker compose up" +msgstr "然后您可以运行 docker compose up 来启动 serving 服务" + +#: ../../source/topics/deployment/deployment.rst:171 +msgid "You can use docker logs to check whether serving works well" +msgstr "您可以使用 docker logs 检查 serving 是否正常工作" + +#: ../../source/topics/deployment/deployment.rst:177 +msgid "" +"Now, ``Alice`` serving is listening on ``9010``, you can confirm if the " +"service is ready by accessing the ``/health`` endpoint." 
+msgstr "" +"现在 ``Alice`` 的 serving 服务监听了 ``9010`` 端口,您可以通过访问 ``/health`` " +"接口来检测服务是否已经准备完成。" + +#: ../../source/topics/deployment/deployment.rst:183 +msgid "" +"When the endpoint returns a status code of ``200``, it means that the " +"service is ready." +msgstr "若接口返回状态码 ``200`` 则标志服务已经准备完成。" + +#: ../../source/topics/deployment/deployment.rst:186 +msgid "Step 3: Predict Test" +msgstr "步骤 3:预测测试" + +#: ../../source/topics/deployment/deployment.rst:188 +msgid "" +"Based on the capabilities of `Brpc " +"`_, serving supports accessing " +"through various protocols. Here, we are using an HTTP request to test the" +" predict interface of serving." +msgstr "" +"得益于 `Brpc `_ 的能力,serving " +"能够支持通过多种协议进行访问。这里我们通过发送一个 HTTP 的预测请求来测试服务。" + +#: ../../source/topics/deployment/deployment.rst:190 +msgid "" +"You can read :ref:`SecretFlow-Serving API ` for more " +"information about serving APIs." +msgstr "" +"您可以参考 :ref:`SecretFlow-Serving API ` 获得更多关于 serving API " +"的信息。" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po new file mode 100644 index 0000000..f1165ec --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po @@ -0,0 +1,220 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. 
+# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-01-04 16:56+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/topics/graph/intro_to_graph.rst:4 +msgid "Introduction to Graph" +msgstr "模型图的介绍" + +#: ../../source/topics/graph/intro_to_graph.rst:6 +msgid "" +"Secretflow-Serving has defined a protocol for describing prediction " +"computations, which mainly includes descriptions of operators, " +"attributes, nodes, graphs, and executions." +msgstr "Secretflow-Serving 定义了一个协议来描述模型预测,其主要包含了关于算子、属性、节点、模型图以及执行体的描述。" + +#: ../../source/topics/graph/intro_to_graph.rst:-1 +msgid "graph structure" +msgstr "模型图结构" + +#: ../../source/topics/graph/intro_to_graph.rst:12 +msgid "Operators" +msgstr "算子" + +#: ../../source/topics/graph/intro_to_graph.rst:13 +msgid "" +"Operators describe specific computations. By combining operators, " +"different model computations can be achieved. Plain operators perform " +"computations using local data only, while secure computation operators " +"collaborate with peer operators from other participants for secure " +"computations." +msgstr "算子描述了特定的计算函数。通过组合多个算子,可以实现对不同模型计算的支持。其中,明文计算算子在计算过程中只会使用本方的特征数据,密文算子则会利用安全协议在参与方之间进行密态计算。" + +#: ../../source/topics/graph/intro_to_graph.rst:16 +msgid "OpDef" +msgstr "OpDef" + +#: ../../source/topics/graph/intro_to_graph.rst:18 +msgid "name: Unique name of the operator." +msgstr "name: 算子的唯一名称。" + +#: ../../source/topics/graph/intro_to_graph.rst:19 +msgid "desc: Description of the operator." +msgstr "desc: 算子的描述信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:20 +msgid "version: The version of the operator." 
+msgstr "version: 算子的版本号信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:21 +msgid "tag: Some properties of the operator." +msgstr "tag: 算子的特定属性描述。" + +#: ../../source/topics/graph/intro_to_graph.rst:22 +msgid "attributes: Please check `Attributes` part below." +msgstr "attributes: 请参考下面 `属性` 的内容。" + +#: ../../source/topics/graph/intro_to_graph.rst:23 +msgid "inputs and output: The info of the inputs or output of the operator." +msgstr "inputs and output: 算子的输入以及输出的描述信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:26 +msgid "Attributes" +msgstr "属性" + +#: ../../source/topics/graph/intro_to_graph.rst:27 +msgid "" +"Operators have various attributes determined by their definitions. These " +"attributes and their data support the operators in completing " +"computations." +msgstr "算子的定义中包含多种不同的属性。属性以及属性的值会被算子在计算过程中使用。" + +#: ../../source/topics/graph/intro_to_graph.rst:30 +msgid "AttrDef" +msgstr "AttrDef" + +#: ../../source/topics/graph/intro_to_graph.rst:32 +msgid "name: Must be unique among all attrs of the operator." +msgstr "name: 属性的名称需要在算子定义下唯一。" + +#: ../../source/topics/graph/intro_to_graph.rst:33 +msgid "desc: Description of the attribute." +msgstr "desc: 属性的描述信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:34 +msgid "type: Please check :ref:`AttrType `." +msgstr "type: 请参考 :ref:`AttrType ` 获取详细信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:35 +msgid "" +"is_optional: If True, when AttrValue is not provided, `default_value` " +"would be used. Else, AttrValue must be provided." +msgstr "" +"is_optional: 为 True 时,如果对应的 AttrValue 没有设置,`default_value` 的值将会被使用。否则,对应的" +" AttrValue 必须提供。" + +#: ../../source/topics/graph/intro_to_graph.rst:36 +msgid "default_value: Please check :ref:`AttrValue `." +msgstr "default_value: 请参考 :ref:`AttrValue ` 获取详细信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:39 +msgid "Nodes" +msgstr "节点" + +#: ../../source/topics/graph/intro_to_graph.rst:40 +msgid "" +"Nodes are instances of operators. 
They store the attribute values " +"(`AttrValue`) of the operators." +msgstr "节点是算子的实例。节点内包含算子的属性对应的属性值(`AttrValue`)" + +#: ../../source/topics/graph/intro_to_graph.rst:43 +msgid "NodeDef" +msgstr "NodeDef" + +#: ../../source/topics/graph/intro_to_graph.rst:45 +msgid "name: Must be unique among all nodes of the graph." +msgstr "name: 节点的名称需要在模型图内唯一。" + +#: ../../source/topics/graph/intro_to_graph.rst:46 +msgid "op: The operator name." +msgstr "op: 节点对应的算子名称。" + +#: ../../source/topics/graph/intro_to_graph.rst:47 +msgid "" +"parents: The parent node names of the node. The order of the parent nodes" +" should match the order of the inputs of the node." +msgstr "parents: 本节点的父母节点名称列表。其顺序应保持和该节点对应算子的输入元素顺序一致。" + +#: ../../source/topics/graph/intro_to_graph.rst:48 +msgid "" +"attr_values: The attribute values config in the node. Note that this " +"should include all attrs defined in the corresponding OpDef" +msgstr "attr_values: 本节点拥有的属性值。注意节点的属性值应该包含其对应的算子的所有属性。" + +#: ../../source/topics/graph/intro_to_graph.rst:49 +msgid "op_version: The operator version." +msgstr "op_version: 节点对应的算子的版本信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:52 +msgid "Graphs" +msgstr "模型图" + +#: ../../source/topics/graph/intro_to_graph.rst:53 +msgid "" +"Graphs can consist of one or multiple nodes. They form a directed acyclic" +" graph, where the direction represents the flow of data computation. A " +"graph can represent a complete prediction computation process, including " +"preprocessing, model prediction, and post-processing." 
+msgstr "模型图可以包含一个或多个节点。这些节点组成了一个有向无环图,图的方向表示了节点间数据的计算传播方向。一个图能够表示一个完整的,包含预处理、模型预测、结果后处理的预测计算过程。" + +#: ../../source/topics/graph/intro_to_graph.rst:55 +msgid "" +"Each participant will have a graph with the same structure but different " +"data。" +msgstr "每个参与方都拥有着一个有着相同结构但是不同数据的模型图。" + +#: ../../source/topics/graph/intro_to_graph.rst:58 +msgid "GraphDef" +msgstr "GraphDef" + +#: ../../source/topics/graph/intro_to_graph.rst:60 +msgid "version: Version of the graph." +msgstr "version: 图的版本信息。" + +#: ../../source/topics/graph/intro_to_graph.rst:61 +msgid "node_list: The node list of the graph." +msgstr "node_list: 图拥有的节点列表。" + +#: ../../source/topics/graph/intro_to_graph.rst:62 +msgid "execution_list: Please check `Executions` part below." +msgstr "execution_list: 请参考下面 `执行体` 的部分。" + +#: ../../source/topics/graph/intro_to_graph.rst:65 +msgid "Executions" +msgstr "执行体" + +#: ../../source/topics/graph/intro_to_graph.rst:66 +msgid "" +"Execution contain a subset of nodes from the main graph and form a " +"subgraph. They represent the model computation scheduling patterns. A " +"graph can have multiple executions." +msgstr "执行体包含了一组节点,它们是图的节点的子集并且能够组成一个子图。执行体描述了模型计算的调度模式,一个图中会包含复数个执行体" + +#: ../../source/topics/graph/intro_to_graph.rst:-1 +msgid "execution" +msgstr "execution" + +#: ../../source/topics/graph/intro_to_graph.rst:72 +msgid "ExecutionDef" +msgstr "ExecutionDef" + +#: ../../source/topics/graph/intro_to_graph.rst:74 +msgid "" +"nodes: Represents the nodes contained in this execution. Note that these " +"node names should be findable and unique within the node definitions. One" +" node can only exist in one execution and must exist in one." +msgstr "nodes: 执行体中包含的节点列表。注意,这些节点应该是包含在模型图中的节点,且每个节点只能属于某一个执行体。" + +#: ../../source/topics/graph/intro_to_graph.rst:75 +msgid "" +"config: The runtime config of the execution. It describes the scheduling " +"logic and session-related states of this execution unit. for more " +"details, please check :ref:`RuntimeConfig `." 
+msgstr "" +"config: 执行体的运行配置。其描述执行体的调度逻辑以及会话状态。请查看 :ref:`RuntimeConfig ` " +"获取更多信息。" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po new file mode 100644 index 0000000..5da662e --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po @@ -0,0 +1,326 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-12-28 14:32+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/topics/graph/operator_list.md:5 +msgid "SecretFlow-Serving Operator List" +msgstr "SecretFlow-Serving 算子列表" + +#: ../../source/topics/graph/operator_list.md:9 +msgid "Last update: Thu Dec 28 14:28:43 2023" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:10 +msgid "MERGE_Y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:13 +#: ../../source/topics/graph/operator_list.md:51 +msgid "Operator version: 0.0.2" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:15 +msgid "Merge all partial y(score) and apply link function" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:16 +#: ../../source/topics/graph/operator_list.md:54 +#: ../../source/topics/graph/operator_list.md:85 +msgid "Attrs" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Description" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Type" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Required" 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Notes" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "output_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The column name of merged score" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "String" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "link_function" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"Type of link function, defined in " +"`secretflow_serving/protos/link_function.proto`. Optional value: LF_LOG, " +"LF_LOGIT, LF_INVERSE, LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, " +"LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, " +"LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, " +"LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, " +"LF_SIGMOID_SR, LF_SIGMOID_SEGLS" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "input_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The column name of partial_y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "yhat_scale" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"In order to prevent value overflow, GLM training is performed on the " +"scaled y label. So in the prediction process, you need to enlarge yhat " +"back to get the real predicted value, `yhat = yhat_scale * link(X * W)`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Double" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "N" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Default: 1.0." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md:26 +#: ../../source/topics/graph/operator_list.md:95 +msgid "Tags" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "returnable" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The operator's output can be the final result" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "mergeable" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"The operator accept the output of operators with different participants " +"and will somehow merge them." +msgstr "" + +#: ../../source/topics/graph/operator_list.md:34 +#: ../../source/topics/graph/operator_list.md:65 +#: ../../source/topics/graph/operator_list.md:102 +msgid "Inputs" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "partial_ys" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The list of partial y, data type: `double`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:41 +#: ../../source/topics/graph/operator_list.md:72 +#: ../../source/topics/graph/operator_list.md:109 +msgid "Output" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "scores" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The merge result of `partial_ys`, data type: `double`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:48 +msgid "DOT_PRODUCT" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:53 +msgid "Calculate the dot product of feature weights and values" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "intercept" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Value of model intercept" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Default: 0.0." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Column name of partial y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "feature_weights" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "List of feature weights" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Double List" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "input_types" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"List of input feature data types, Note that there is a loss of precision " +"when using `DT_FLOAT` type. Optional value: DT_UINT8, DT_INT8, DT_UINT16," +" DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, DT_INT64, DT_FLOAT, DT_DOUBLE" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "String List" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "feature_names" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "List of feature names" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "features" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Input feature table" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The calculation results, they have a data type of `double`." +msgstr "" + +#: ../../source/topics/graph/operator_list.md:79 +msgid "ARROW_PROCESSING" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:82 +msgid "Operator version: 0.0.1" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:84 +msgid "Replay secretflow compute functions" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "content_json_flag" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Whether `trace_content` is serialized json" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Boolean" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Default: False." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "trace_content" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Serialized data of secretflow compute trace" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Bytes" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "output_schema_bytes" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Serialized data of output schema(arrow::Schema)" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "input_schema_bytes" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Serialized data of input schema(arrow::Schema)" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "input" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "output" +msgstr "" + +#~ msgid "Last update: Fri Dec 22 14:29:43 2023" +#~ msgstr "" + +#~ msgid "" +#~ "Optional value: LF_LOG, LF_LOGIT, LF_INVERSE," +#~ " LF_LOGIT_V2, LF_RECIPROCAL, LF_IDENTITY, " +#~ "LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, " +#~ "LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, " +#~ "LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, " +#~ "LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, " +#~ "LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS" +#~ msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/index.po b/docs/locales/zh_CN/LC_MESSAGES/topics/index.po new file mode 100644 index 0000000..da4af89 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/index.po @@ -0,0 +1,34 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. 
+# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-12-25 11:37+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/topics/index.rst:7 +msgid "system" +msgstr "系统" + +#: ../../source/topics/index.rst:14 +msgid "deployment" +msgstr "部署" + +#: ../../source/topics/index.rst:20 +msgid "graph" +msgstr "模型图" + +#: ../../source/topics/index.rst:4 +msgid "Topics" +msgstr "主题" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/system/intro.po b/docs/locales/zh_CN/LC_MESSAGES/topics/system/intro.po new file mode 100644 index 0000000..313e437 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/system/intro.po @@ -0,0 +1,180 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-12-26 11:01+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/topics/system/intro.rst:2 +msgid "SecretFlow-Serving System Introduction" +msgstr "SecretFlow-Serving 系统介绍" + +#: ../../source/topics/system/intro.rst:4 +msgid "" +"SecretFlow-Serving is a serving system for privacy-preserving machine " +"learning models." +msgstr "SecretFlow-Serving 是一个加载隐私保护机器学习模型的在线服务系统。" + +#: ../../source/topics/system/intro.rst:7 +msgid "Key Features" +msgstr "关键特性" + +#: ../../source/topics/system/intro.rst:9 +msgid "Support multiple parties (N >= 2)." 
+msgstr "支持多个参与方(N >= 2)。" + +#: ../../source/topics/system/intro.rst:10 +msgid "Parallel compute between parties." +msgstr "多方并行计算。" + +#: ../../source/topics/system/intro.rst:11 +msgid "Batch Predict API Supported." +msgstr "支持批量预测请求API。" + +#: ../../source/topics/system/intro.rst:12 +msgid "" +"Multi-protocol support. Secretflow-Serving is built on brpc, a high-" +"performance rpc framework, and is capable of using multiple communication" +" protocols." +msgstr "多种协议支持。Secretflow-Serving基于brpc构建,其可支持多种通信协议。" + +#: ../../source/topics/system/intro.rst:13 +msgid "Support multiple types feature sources, e.g. SPI, CSV file, Mock data." +msgstr "多种特征数据源的支持:SPI、CSV文件、仿真数据等。" + +#: ../../source/topics/system/intro.rst:14 +msgid "Specific model graph definition." +msgstr "标准模型图定义。" + +#: ../../source/topics/system/intro.rst:15 +msgid "Federated learning model predict." +msgstr "联邦模型预测。" + +#: ../../source/topics/system/intro.rst:16 +msgid "One process one model/version." +msgstr "单进程单模型模式。" + +#: ../../source/topics/system/intro.rst:20 +msgid "Architecture" +msgstr "架构" + +#: ../../source/topics/system/intro.rst:22 +msgid "" +"Secretflow-Serving leverages the model package trained with Secretflow to" +" provide model prediction capabilities at different security levels. It " +"achieves this by utilizing the online feature data provided by each " +"participant without compromising the integrity of the original data " +"domain." 
+msgstr "Secretflow-Serving 使用由 Secretflow 训练产生的模型包,在预测参与方各方的在线特征数据不出域的前提下,提供不同的安全级别的模型预测能力。" + +#: ../../source/topics/system/intro.rst:-1 +msgid "Secretflow-Serving Deployment Architecture" +msgstr "Secretflow-Serving 部署架构" + +#: ../../source/topics/system/intro.rst:29 +msgid "Key Concepts" +msgstr "关键概念" + +#: ../../source/topics/system/intro.rst:31 +msgid "" +"To understand the architecture of Secretflow-Serving, you need to " +"understand the following key concepts:" +msgstr "为了理解 Secretflow-Serving 的系统架构,你需要理解下列的一些关键概念:" + +#: ../../source/topics/system/intro.rst:35 +msgid "Model Package" +msgstr "模型包" + +#: ../../source/topics/system/intro.rst:37 +msgid "" +"A Secretflow-Serving model package is a compressed package comprising a " +"model file, a manifest file, and other metadata files." +msgstr "Secretflow-Serving 的模型包是一个包含模型文件、描述文件以及一些元数据文件的压缩包。" + +#: ../../source/topics/system/intro.rst:39 +msgid "" +"The manifest file provides meta-information about the model file and " +"follows the defined structure outlined :ref:`here `." +msgstr "描述文件包含了关于模型文件的元数据信息,它的内容结构可参考 :ref:`这里 `。" + +#: ../../source/topics/system/intro.rst:41 +msgid "" +"The model file contains the graph that represents the model inference " +"process, encompassing pre-processing, post-processing, and the specific " +"inference algorithm. For graph details, please check :ref:`Introduction " +"to Graph `." +msgstr "模型文件包含了一个描述了模型预测过程、特征预处理、结果后处理以及特定预测算法的图。想了解更多关于图的细节,可以查看" +" :ref:`模型图的介绍 ` " + +#: ../../source/topics/system/intro.rst:43 +msgid "" +"The metadata files, while optional, stores additional data information " +"required during the model inference process." +msgstr "元数据文件存储了一些在模型预测中可能被用到的额外数据信息,它们只会在必要场景时存在。" + +#: ../../source/topics/system/intro.rst:47 +msgid "Model Source" +msgstr "模型数据源" + +#: ../../source/topics/system/intro.rst:49 +msgid "" +"Secretflow-Serving supports retrieving model packages from different " +"storage sources. 
Currently, the following data sources are supported:" +msgstr "Secretflow-Serving 支持从不同的存储数据源中获取模型包。目前已支持下列的数据源:" + +#: ../../source/topics/system/intro.rst:51 +msgid "" +"Local Filesystem Data Source: Secretflow-Serving loads the model package " +"from a specified local path." +msgstr "本地文件系统数据源:Secretflow-Serving 可从一个可访问的本地路径读取模型包。" + +#: ../../source/topics/system/intro.rst:52 +msgid "" +"OSS/S3 Data Source: Secretflow-Serving attempts to download the model " +"package from the OSS/S3 storage based on the provided configuration " +"before loading it locally." +msgstr "OSS/S3 数据源:Secretflow-Serving 支持根据提供的配置从 OSS/S3 存储中下载模型包到本地进行加载。" + +#: ../../source/topics/system/intro.rst:56 +msgid "Feature Source" +msgstr "特征数据源" + +#: ../../source/topics/system/intro.rst:58 +msgid "" +"Secretflow-Serving obtains the necessary features for the online " +"inference process through the Feature Source. Currently, the platform " +"supports the following feature data sources:" +msgstr "Secretflow-Serving 在预测过程中会向特征数据源请求获取必要的特征数据。当前已支持的特征数据源如下所示:" + +#: ../../source/topics/system/intro.rst:60 +msgid "" +"HTTP Source: Secretflow-Serving defines a Service Provider Interface " +"(:doc:`SPI `) for retrieving feature data. Feature " +"providers can implement this SPI to supply features to Secretflow-" +"Serving." +msgstr "HTTP数据源:Secretflow-Serving 定义一组服务提供接口(:doc:`SPI `)来获取特征数据。特征提供方可以通过实现该 SPI 来为 Secretflow-Serving 提供特征数据。" + +#: ../../source/topics/system/intro.rst:61 +msgid "" +"CSV Source: Secretflow-Serving supports direct loading of CSV file as a " +"feature source. For performance reasons, the CSV file is fully loaded " +"into memory and features are filtered based on the ID column." +msgstr "CSV数据源:Secretflow-Serving 支持直接加载CSV文件作为特征数据源。性能原因考虑,CSV文件会被整个加载到内存中以支持通过ID列查询到特定特征。" + +#: ../../source/topics/system/intro.rst:62 +msgid "" +"Mock Source: In this scenario, Secretflow-Serving uses randomly generated" +" values as feature data." 
+msgstr "仿真数据源:此模式下,Secretflow-Serving 会使用随机值作为特征数据参与计算。" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/system/observability.po b/docs/locales/zh_CN/LC_MESSAGES/topics/system/observability.po new file mode 100644 index 0000000..e2810d3 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/system/observability.po @@ -0,0 +1,22 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023 Ant Group Co., Ltd. +# This file is distributed under the same license as the SecretFlow-Serving +# package. +# FIRST AUTHOR , 2023. +# +msgid "" +msgstr "" +"Project-Id-Version: SecretFlow-Serving \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-12-25 11:37+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.13.0\n" + +#: ../../source/topics/system/observability.rst:2 +msgid "SecretFlow-Serving System Observability" +msgstr "SecretFlow-Serving 系统可观测性" diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..b344c6a --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..29df61a --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,16 @@ +myst-parser==0.18.1 +rstcheck==6.1.1 +sphinx==5.3.0 +nbsphinx==0.8.9 +sphinx-autobuild==2021.3.14 +pydata-sphinx-theme==0.10.1 +sphinx-markdown-parser==0.2.4 +sphinxcontrib-actdiag==3.0.0 +sphinxcontrib-blockdiag==3.0.0 +sphinxcontrib-nwdiag==2.0.0 +sphinxcontrib-seqdiag==3.0.0 +pytablewriter==0.64.2 +linkify-it-py==2.0.0 +sphinx_design==0.3.0 +sphinx-intl==2.0.1 +mdutils==1.6.0 diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css new file mode 100644 index 0000000..275c70b --- /dev/null +++ b/docs/source/_static/css/custom.css @@ -0,0 +1,12 @@ +@import "../basic.css"; + +html[data-theme="light"] { + --pst-color-primary: rgb(22 119 255); + --pst-color-secondary: rgb(22 255 201); +} + +html[data-theme="dark"] { + --pst-color-primary: rgb(22 119 255); + --pst-color-secondary: rgb(22 255 201); + --pst-color-background: rgb(56, 56, 56); +} \ No newline at end of file diff --git a/docs/source/_static/favicon.ico b/docs/source/_static/favicon.ico new file mode 100644 index 0000000..58c2592 Binary files /dev/null and b/docs/source/_static/favicon.ico differ diff --git a/docs/source/_static/js/custom.js b/docs/source/_static/js/custom.js new file mode 100644 index 0000000..184b5ae --- /dev/null +++ b/docs/source/_static/js/custom.js @@ -0,0 +1,3 @@ +$(document).ready(function () { + $('a.external').attr('target', '_blank'); +}); diff --git a/docs/source/_static/logo-dark.png b/docs/source/_static/logo-dark.png new file mode 100644 index 0000000..e77d487 Binary files /dev/null and 
b/docs/source/_static/logo-dark.png differ diff --git a/docs/source/_static/logo-light.png b/docs/source/_static/logo-light.png new file mode 100644 index 0000000..2117fe6 Binary files /dev/null and b/docs/source/_static/logo-light.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..dfe650b --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,118 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "SecretFlow-Serving" +copyright = "2023 Ant Group Co., Ltd." +author = "SecretFlow-Serving authors" +release = "0.1" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinx.ext.extlinks", + "sphinx.ext.autosectionlabel", + "myst_parser", + "nbsphinx", + "sphinxcontrib.actdiag", + "sphinxcontrib.blockdiag", + "sphinxcontrib.nwdiag", + "sphinxcontrib.packetdiag", + "sphinxcontrib.rackdiag", + "sphinxcontrib.seqdiag", + "sphinx_design", +] + +templates_path = ["_templates"] +exclude_patterns = [] + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +# multi-language docs + +language = "en" +locale_dirs = ["../locales/"] # path is example but recommended. +gettext_compact = False # optional. +gettext_uuid = False # optional. 
+ +# Enable TODO +todo_include_todos = True + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "pydata_sphinx_theme" + +html_static_path = ["_static"] + +html_favicon = "_static/favicon.ico" + +html_css_files = [ + "css/custom.css", +] + +html_js_files = ["js/custom.js"] + +html_theme_options = { + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/secretflow/serving", + "icon": "fab fa-github-square", + "type": "fontawesome", + }, + ], + "logo": { + "text": "SecretFlow-Serving", + "image_light": "logo-light.png", + "image_dark": "logo-dark.png", + }, +} + +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "fieldlist", + "html_admonition", + "html_image", + "linkify", + "replacements", + "smartquotes", + "strikethrough", + "substitution", + "tasklist", +] + +suppress_warnings = ["myst.header"] + +myst_gfm_only = True +myst_heading_anchors = 1 +myst_title_to_header = True + + +# app setup hook +def setup(app): + app.add_config_value( + "recommonmark_config", + { + "auto_toc_tree_section": "Contents", + }, + True, + ) diff --git a/docs/source/imgs/architecture.png b/docs/source/imgs/architecture.png new file mode 100644 index 0000000..7cc730f Binary files /dev/null and b/docs/source/imgs/architecture.png differ diff --git a/docs/source/imgs/execution.png b/docs/source/imgs/execution.png new file mode 100644 index 0000000..19fdce2 Binary files /dev/null and b/docs/source/imgs/execution.png differ diff --git a/docs/source/imgs/graph.png b/docs/source/imgs/graph.png new file mode 100644 index 0000000..79a5234 Binary files /dev/null and b/docs/source/imgs/graph.png differ diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..da35d9d --- 
/dev/null +++ b/docs/source/index.rst @@ -0,0 +1,53 @@ +.. SecretFlow-Serving documentation master file, created by + sphinx-quickstart on Sat Oct 7 17:42:31 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to SecretFlow-Serving's documentation! +============================================== + +SecretFlow-Serving is a serving system for privacy-preserving machine learning models. + + +Getting started +--------------- + +Follow the :doc:`tutorial ` and try out SecretFlow-Serving on your machine! + + +SecretFlow-Serving Systems +-------------------------- + +- **Overview**: + :doc:`System overview and architecture ` + + +Deployment +---------- + +- **Guides**: + :doc:`How to deploy an SecretFlow-Serving cluster` + +- **Reference**: + :doc:`SecretFlow-Serving service API ` | + :doc:`SecretFlow-Serving system config ` | + :doc:`SecretFlow-Serving feature service spi ` + + +Graph +----------------- + +- **Overview**: + :doc:`Introduction to graphs ` | + :doc:`Operators ` + +- **Reference**: + :doc:`SecretFlow-Serving model ` + + +.. toctree:: + :hidden: + + intro/index + topics/index + reference/index diff --git a/docs/source/intro/index.rst b/docs/source/intro/index.rst new file mode 100644 index 0000000..7792c30 --- /dev/null +++ b/docs/source/intro/index.rst @@ -0,0 +1,6 @@ +Introduction +============ + +.. toctree:: + + tutorial diff --git a/docs/source/intro/tutorial.rst b/docs/source/intro/tutorial.rst new file mode 100644 index 0000000..e04237f --- /dev/null +++ b/docs/source/intro/tutorial.rst @@ -0,0 +1,89 @@ +Quickstart +========== + +TL;DR +----- + +Use ``docker-compose`` to deploy a SecretFlow-Serving cluster, the query the model using the predict API. 
+ + +Start SecretFlow-Serving Service +-------------------------------- + +You could start SecretFlow-Serving service via `docker-compose `_, it would deploy and start services as shown in the following figure, it contains two SecretFlow-Serving from party ``Alice``, ``Bob``. + +.. image:: /imgs/architecture.png + :alt: docker-compose deployment for quickstart example + + +.. note:: + To demonstrate SecretFlow-Serving, we conducted the following simplified operations: + + 1. Both parties of Secretflow-Serving use mock feature source to produce random feature values. + 2. The model files in the examples directory are loaded by ``Alice`` and ``Bob``'s Secretflow-Serving respectively。 + 3. The SecretFlow-Serving is served through the HTTP protocol. However, for production environments, it is recommended to use HTTPS instead. Please check :ref:`TLS Configuration ` for details. + + +.. code-block:: bash + + # startup docker-compose + # If you install docker with Compose V1, pleas use `docker-compose` instead of `docker compose` + (cd examples && docker compose up -d) + +Now, the ``Alice``'s SecretFlow-Serving is listening on ``http://localhost:9010``, the ``Bob``'s SecretFlow-Serving is listening on ``http://localhost:9011``, you could send predict request to it via curl or other http tools. + + +Do Predict +---------- + +send predict request to ``Alice`` + + +.. code-block:: bash + + curl --location 'http://127.0.0.1:9010/PredictionService/Predict' \ + --header 'Content-Type: application/json' \ + --data '{ + "service_spec": { + "id": "test_service_id" + }, + "fs_params": { + "alice": { + "query_datas": [ + "a" + ] + }, + "bob": { + "query_datas": [ + "a" + ] + } + } + }' + +send predict request to ``Bob`` + +.. 
code-block:: bash + + curl --location 'http://127.0.0.1:9011/PredictionService/Predict' \ + --header 'Content-Type: application/json' \ + --data '{ + "service_spec": { + "id": "test_service_id" + }, + "fs_params": { + "alice": { + "query_datas": [ + "a" + ] + }, + "bob": { + "query_datas": [ + "a" + ] + } + } + }' + +.. note:: + Please checkout :ref:`SecretFlow-Serving API ` for the Predict API details. diff --git a/docs/source/reference/api.md b/docs/source/reference/api.md new file mode 100644 index 0000000..a226a4f --- /dev/null +++ b/docs/source/reference/api.md @@ -0,0 +1,690 @@ +# SecretFlow-Serving API + +## Table of Contents +- Services + + + + + + + + + + - [ExecutionService](#executionservice) + + + + + + - [metrics](#metrics) + + + + + + - [ModelService](#modelservice) + + + + + + - [PredictionService](#predictionservice) + + + + + + + + + + +- Messages + + + + - [Header](#header) + - [Header.DataEntry](#header-dataentry) + - [ServiceSpec](#servicespec) + + + + + + + + + - [ExecuteRequest](#executerequest) + - [ExecuteResponse](#executeresponse) + - [ExecuteResult](#executeresult) + - [ExecutionTask](#executiontask) + - [FeatureSource](#featuresource) + - [IoData](#iodata) + - [NodeIo](#nodeio) + + + + + + - [MetricsRequest](#metricsrequest) + - [MetricsResponse](#metricsresponse) + + + + + + - [GetModelInfoRequest](#getmodelinforequest) + - [GetModelInfoResponse](#getmodelinforesponse) + + + + + + - [PredictRequest](#predictrequest) + - [PredictRequest.FsParamsEntry](#predictrequest-fsparamsentry) + - [PredictResponse](#predictresponse) + - [PredictResult](#predictresult) + - [Score](#score) + + + + + + - [Status](#status) + + + + + + - [Feature](#feature) + - [FeatureField](#featurefield) + - [FeatureParam](#featureparam) + - [FeatureValue](#featurevalue) + + + + +- Enums + + + + + + + - [ErrorCode](#errorcode) + + + + + + - [FeatureSourceType](#featuresourcetype) + + + + + + + + + + + + + + + + + + - [FieldType](#fieldtype) + + + +- [Scalar Value 
Types](#scalar-value-types) + + + + + + +{#ExecutionService} +## ExecutionService +ExecutionService provides access to run execution defined in the GraphDef. + +### Execute + +> **rpc** Execute([ExecuteRequest](#executerequest)) + [ExecuteResponse](#executeresponse) + + + + + +{#metrics} +## metrics + + +### default_method + +> **rpc** default_method([MetricsRequest](#metricsrequest)) + [MetricsResponse](#metricsresponse) + + + + + +{#ModelService} +## ModelService +ModelService provides operation ralated to models. + +### GetModelInfo + +> **rpc** GetModelInfo([GetModelInfoRequest](#getmodelinforequest)) + [GetModelInfoResponse](#getmodelinforesponse) + + + + + +{#PredictionService} +## PredictionService +PredictionService provides access to the serving model. + +### Predict + +> **rpc** Predict([PredictRequest](#predictrequest)) + [PredictResponse](#predictresponse) + +Predict. + + + + + + + + +## Messages + + + +{#Header} +### Header +Header containing custom data + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [map Header.DataEntry](#header-dataentry ) | none | + + + + +{#Header.DataEntry} +### Header.DataEntry + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| key | [ string](#string ) | none | +| value | [ string](#string ) | none | + + + + +{#ServiceSpec} +### ServiceSpec +Metadata for an predict or execute request. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| id | [ string](#string ) | The id of the model service. | + + + + + + + + +{#ExecuteRequest} +### ExecuteRequest +Execute request containing one or more requests. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. The header will be passed to the downstream system which implement the feature service spi. 
| +| requester_id | [ string](#string ) | Represents the id of the requesting party | +| service_spec | [ ServiceSpec](#servicespec ) | Model service specification. | +| session_id | [ string](#string ) | Represents the session of this execute. | +| feature_source | [ FeatureSource](#featuresource ) | none | +| task | [ ExecutionTask](#executiontask ) | none | + + + + +{#ExecuteResponse} +### ExecuteResponse +Execute response containing one or more responses. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. Passed by the downstream system which implement the feature service spi. | +| status | [ Status](#status ) | Staus of this response. | +| service_spec | [ ServiceSpec](#servicespec ) | Model service specification. | +| session_id | [ string](#string ) | Represents the session of this execute. | +| result | [ ExecuteResult](#executeresult ) | none | + + + + +{#ExecuteResult} +### ExecuteResult +Execute result of the request task. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| execution_id | [ int32](#int32 ) | Specified the execution id. | +| nodes | [repeated NodeIo](#nodeio ) | none | + + + + +{#ExecutionTask} +### ExecutionTask +Execute request task. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| execution_id | [ int32](#int32 ) | Specified the execution id. | +| nodes | [repeated NodeIo](#nodeio ) | none | + + + + +{#FeatureSource} +### FeatureSource +Descriptive feature source + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| type | [ FeatureSourceType](#featuresourcetype ) | Identifies the source type of the features | +| fs_param | [ secretflow.serving.FeatureParam](#featureparam ) | Custom parameter for fetch features from feature service or other systems. Valid when `type==FeatureSourceType::FS_SERVICE` | +| predefineds | [repeated secretflow.serving.Feature](#feature ) | Defined features. 
Valid when `type==FeatureSourceType::FS_PREDEFINED` | + + + + +{#IoData} +### IoData +The serialized data of the node input/output. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| datas | [repeated bytes](#bytes ) | none | + + + + +{#NodeIo} +### NodeIo +Represents the node input/output data. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Node name. | +| ios | [repeated IoData](#iodata ) | none | + + + + + + +{#MetricsRequest} +### MetricsRequest + + + + + +{#MetricsResponse} +### MetricsResponse + + + + + + + +{#GetModelInfoRequest} +### GetModelInfoRequest + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. | +| service_spec | [ ServiceSpec](#servicespec ) | Model service specification. | + + + + +{#GetModelInfoResponse} +### GetModelInfoResponse + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. | +| status | [ Status](#status ) | Staus of this response. | +| service_spec | [ ServiceSpec](#servicespec ) | Model service specification. | +| model_info | [ secretflow.serving.ModelInfo](#modelinfo ) | none | + + + + + + +{#PredictRequest} +### PredictRequest +Predict request containing one or more requests. +examples: +```json + { + "header": { + "data": { + "custom_str": "id_12345" + }, + }, + "service_spec": { + "id": "test_service_id" + }, + "fs_params": { + "alice": { + "query_datas": [ + "x1", + "x2" + ], + "query_context": "context_x" + }, + "bob": { + "query_datas": [ + "y1", + "y2" + ], + "query_context": "context_y" + } + } + } +``` + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. The header will be passed to the downstream system which implement the feature service spi. | +| service_spec | [ ServiceSpec](#servicespec ) | Model service specification. 
| +| fs_params | [map PredictRequest.FsParamsEntry](#predictrequest-fsparamsentry ) | The params for fetch features. Note that this should include all the parties involved in the prediction. Key: party's id. Value: params for fetch features. | +| predefined_features | [repeated secretflow.serving.Feature](#feature ) | Optional. If defined, the request party will no longer query for the feature but will use defined fetures in `predefined_features` for the prediction. | + + + + +{#PredictRequest.FsParamsEntry} +### PredictRequest.FsParamsEntry + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| key | [ string](#string ) | none | +| value | [ secretflow.serving.FeatureParam](#featureparam ) | none | + + + + +{#PredictResponse} +### PredictResponse +Predict response containing one or more responses. +examples: +```json + { + "header": { + "data": { + "custom_value": "asdfvb" + }, + }, + "status": { + "code": 0, + "msg": "success." + }, + "service_spec": { + "id": "test_service_id" + }, + "results": { + "scores": [ + { + "name": "pred_y", + "value": 0.32456 + }, + { + "name": "pred_y", + "value": 0.02456 + } + ] + } + } +``` + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. Passed by the downstream system which implement the feature service spi. | +| status | [ Status](#status ) | Staus of this response. | +| service_spec | [ ServiceSpec](#servicespec ) | Model service specification. | +| results | [repeated PredictResult](#predictresult ) | List of the predict result. Returned in the same order as the request's feature query data. | + + + + +{#PredictResult} +### PredictResult +Result of single predict request. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| scores | [repeated Score](#score ) | According to the model, there may be one or multi scores. 
| + + + + +{#Score} +### Score +Result of regression or one class of Classifications + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | The name of the score, it depends on the attribute configuration of the model. | +| value | [ double](#double ) | The value of the score. | + + + + + + +{#Status} +### Status +Represents the status of a request + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| code | [ int32](#int32 ) | The code of this status. Must be one of ErrorCode in error_code.proto | +| msg | [ string](#string ) | The msg of this status. | + + + + + + +{#Feature} +### Feature +The definition of a feature + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| field | [ FeatureField](#featurefield ) | none | +| value | [ FeatureValue](#featurevalue ) | none | + + + + +{#FeatureField} +### FeatureField +The definition of a feature field. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Unique name of the feature | +| type | [ FieldType](#fieldtype ) | Field type of the feature | + + + + +{#FeatureParam} +### FeatureParam +The param for fetch features + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| query_datas | [repeated string](#string ) | The serialized datas for query features. Each one for query one row of features. | +| query_context | [ string](#string ) | Optional. Represents the common part of the query datas. 
| + + + + +{#FeatureValue} +### FeatureValue +The value of a feature + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| i32s | [repeated int32](#int32 ) | int list | +| i64s | [repeated int64](#int64 ) | none | +| fs | [repeated float](#float ) | float list | +| ds | [repeated double](#double ) | none | +| ss | [repeated string](#string ) | string list | +| bs | [repeated bool](#bool ) | bool list | + + + + + +## Enums + + + + + +### ErrorCode + + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN | 0 | Placeholder for proto3 default value, do not use it | +| OK | 1 | none | +| UNEXPECTED_ERROR | 2 | none | +| INVALID_ARGUMENT | 3 | none | +| NETWORK_ERROR | 4 | none | +| NOT_FOUND | 5 | Some requested entity (e.g., file or directory) was not found. | +| NOT_IMPLEMENTED | 6 | none | +| LOGIC_ERROR | 7 | none | +| SERIALIZE_FAILED | 8 | none | +| DESERIALIZE_FAILED | 9 | none | +| IO_ERROR | 10 | none | +| NOT_READY | 11 | none | +| FS_UNAUTHENTICATED | 100 | none | +| FS_INVALID_ARGUMENT | 101 | none | +| FS_DEADLINE_EXCEEDED | 102 | none | +| FS_NOT_FOUND | 103 | none | +| FS_INTERNAL_ERROR | 104 | none | + + + + + + +### FeatureSourceType +Support feature source type + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_FS_TYPE | 0 | none | +| FS_NONE | 1 | No need features. | +| FS_SERVICE | 2 | Fetch features from feature service. | +| FS_PREDEFINED | 3 | The feature is defined in the request. | + + + + + + + + + + + + + + +### FieldType +Supported feature field type. + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_FIELD_TYPE | 0 | Placeholder for proto3 default value, do not use it. 
| +| FIELD_BOOL | 1 | BOOL | +| FIELD_INT32 | 2 | INT32 | +| FIELD_INT64 | 3 | INT64 | +| FIELD_FLOAT | 4 | FLOAT | +| FIELD_DOUBLE | 5 | DOUBLE | +| FIELD_STRING | 6 | STRING | + + + + + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +|

double | | double | double | float | +|

float | | float | float | float | +|

int32 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead. | int32 | int | int | +|

int64 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead. | int64 | long | int/long | +|

uint32 | Uses variable-length encoding. | uint32 | int | int/long | +|

uint64 | Uses variable-length encoding. | uint64 | long | int/long | +|

sint32 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s. | int32 | int | int | +|

sint64 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s. | int64 | long | int/long | +|

fixed32 | Always four bytes. More efficient than uint32 if values are often greater than 2^28. | uint32 | int | int | +|

fixed64 | Always eight bytes. More efficient than uint64 if values are often greater than 2^56. | uint64 | long | int/long | +|

sfixed32 | Always four bytes. | int32 | int | int | +|

sfixed64 | Always eight bytes. | int64 | long | int/long | +|

bool | | bool | boolean | boolean | +|

string | A string must always contain UTF-8 encoded or 7-bit ASCII text. | string | String | str/unicode | +|

bytes | May contain any arbitrary sequence of bytes. | string | ByteString | str | diff --git a/docs/source/reference/api_md.tmpl b/docs/source/reference/api_md.tmpl new file mode 100644 index 0000000..2b22081 --- /dev/null +++ b/docs/source/reference/api_md.tmpl @@ -0,0 +1,89 @@ +# SecretFlow-Serving API + +## Table of Contents +- Services +{{range .Files}} + +{{if .HasServices}} + {{range .Services}} - [{{.Name}}](#{{.Name | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} + +- Messages +{{range .Files}} + +{{if .HasMessages}} + {{range .Messages}} - [{{.LongName}}](#{{.LongName | lower | replace "." "-"}}) + {{end}} +{{end}} +{{end}} + +- Enums +{{range .Files}} + +{{if .HasEnums}} + {{range .Enums}} - [{{.LongName}}](#{{.LongName | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} +- [Scalar Value Types](#scalar-value-types) + +{{range .Files}} +{{range .Services -}} +{#{{.LongName}}} +## {{.Name}} +{{.Description}} + +{{range .Methods -}} +### {{.Name}} + +> **rpc** {{.Name}}([{{.RequestLongType}}](#{{.RequestLongType | lower | replace "." ""}})) + [{{.ResponseLongType}}](#{{.ResponseLongType | lower | replace "." ""}}) + +{{ .Description}} +{{end}} +{{end}} +{{end}} + +## Messages +{{range .Files}} +{{range .Messages}} + +{#{{.LongName}}} +### {{.LongName}} +{{.Description}} + +{{if .HasFields}} +| Field | Type | Description | +| ----- | ---- | ----------- | +{{range .Fields -}} + | {{if .IsOneof}}[**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) {{.OneofDecl}}.{{end}}{{.Name}} | [{{if .IsMap}}map {{else}}{{.Label}} {{end}}{{.LongType}}](#{{if .IsMap}}{{.LongType | lower | replace "." "-"}} {{else}}{{.Type | lower | replace "." 
"-"}} {{end}}) | {{if .Description}}{{nobr .Description}}{{if .DefaultValue}} Default: {{.DefaultValue}}{{end}}{{else}}none{{end}} | +{{end}} +{{end}} +{{end}} +{{end}} + +## Enums +{{range .Files}} +{{range .Enums}} + +### {{.LongName}} +{{.Description}} + +| Name | Number | Description | +| ---- | ------ | ----------- | +{{range .Values -}} + | {{.Name}} | {{.Number}} | {{if .Description}}{{nobr .Description}}{{else}}none{{end}} | +{{end}} + +{{end}} +{{end}} + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +{{range .Scalars -}} + |

{{.ProtoType}} | {{.Notes}} | {{.CppType}} | {{.JavaType}} | {{.PythonType}} | +{{end}} diff --git a/docs/source/reference/config.md b/docs/source/reference/config.md new file mode 100644 index 0000000..ea4f966 --- /dev/null +++ b/docs/source/reference/config.md @@ -0,0 +1,442 @@ +# SecretFlow-Serving Config + +## Table of Contents +- Services + + + + + + + + + + + + + + + + + + + + + + + +- Messages + + + + - [ChannelDesc](#channeldesc) + - [ClusterConfig](#clusterconfig) + - [PartyDesc](#partydesc) + + + + + + - [CsvOptions](#csvoptions) + - [FeatureSourceConfig](#featuresourceconfig) + - [HttpOptions](#httpoptions) + - [MockOptions](#mockoptions) + + + + + + - [LoggingConfig](#loggingconfig) + + + + + + - [FileSourceMeta](#filesourcemeta) + - [ModelConfig](#modelconfig) + - [OSSSourceMeta](#osssourcemeta) + + + + + + - [ServerConfig](#serverconfig) + - [ServerConfig.FeatureMappingEntry](#serverconfig-featuremappingentry) + + + + + + - [ServingConfig](#servingconfig) + + + + + + - [TlsConfig](#tlsconfig) + + + + +- Enums + + + + + + + - [MockDataType](#mockdatatype) + + + + + + - [LogLevel](#loglevel) + + + + + + - [SourceType](#sourcetype) + + + + + + + + + + + + +- [Scalar Value Types](#scalar-value-types) + + + + + + + + + + + + + + + + + +## Messages + + + +{#ChannelDesc} +### ChannelDesc +Description for channels between joined parties + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| protocol | [ string](#string ) | https://github.com/apache/brpc/blob/master/docs/en/client.md#protocols | +| rpc_timeout_ms | [ int32](#int32 ) | Max duration of RPC. -1 means wait indefinitely. Default: 2000 (ms) | +| connect_timeout_ms | [ int32](#int32 ) | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | +| tls_config | [ TlsConfig](#tlsconfig ) | TLS related config. | +| handshake_max_retry_cnt | [ int32](#int32 ) | When the server starts, model information from all parties will be collected. 
At this time, the remote servers may not have started yet, and we need to retry. And if we connect gateway,the max waiting time for each operation will be rpc_timeout_ms + handshake_retry_interval_ms. Maximum number of retries, default: 60 | +| handshake_retry_interval_ms | [ int32](#int32 ) | time between retries, default: 5000ms | + + + + +{#ClusterConfig} +### ClusterConfig +Runtime config for a serving cluster + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| parties | [repeated PartyDesc](#partydesc ) | none | +| self_id | [ string](#string ) | none | +| channel_desc | [ ChannelDesc](#channeldesc ) | none | + + + + +{#PartyDesc} +### PartyDesc +Description for a joined party + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| id | [ string](#string ) | Unique id of the party | +| address | [ string](#string ) | e.g. 127.0.0.1:9001 | +| listen_address | [ string](#string ) | Optional. Address will be used if listen_address is empty. | + + + + + + +{#CsvOptions} +### CsvOptions +Options of a csv feature source. 
+ + +| Field | Type | Description | +| ----- | ---- | ----------- | +| file_path | [ string](#string ) | Input file path, specifies where to load data Note that this will load all of the data into memory at once | +| id_name | [ string](#string ) | Id column name, associated with FeatureParam::query_datas Query datas is a subset of id column | + + + + +{#FeatureSourceConfig} +### FeatureSourceConfig +Config for a feature source + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) options.mock_opts | [ MockOptions](#mockoptions ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) options.http_opts | [ HttpOptions](#httpoptions ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) options.csv_opts | [ CsvOptions](#csvoptions ) | none | + + + + +{#HttpOptions} +### HttpOptions +Options for a http feature source which should implement the feature service +spi. The defined of spi can be found in +secretflow_serving/spis/batch_feature_service.proto + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| endpoint | [ string](#string ) | none | +| enable_lb | [ bool](#bool ) | Whether to enable round robin load balancer. | +| connect_timeout_ms | [ int32](#int32 ) | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | +| timeout_ms | [ int32](#int32 ) | Max duration of http request. -1 means wait indefinitely. Default: 1000 (ms) | +| tls_config | [ TlsConfig](#tlsconfig ) | TLS related config. | + + + + +{#MockOptions} +### MockOptions +Options for a mock feature source. +Mock feature source will generates values(random or fixed, according to type) +for the desired features. 
+ + +| Field | Type | Description | +| ----- | ---- | ----------- | +| type | [ MockDataType](#mockdatatype ) | default MDT_RANDOM | + + + + + + +{#LoggingConfig} +### LoggingConfig + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| system_log_path | [ string](#string ) | system log default value: "serving.log" | +| log_level | [ LogLevel](#loglevel ) | default value: LogLevel.INFO_LOG_LEVEL | +| max_log_file_size | [ int32](#int32 ) | Byte. default value: 500 * 1024 * 1024 (500MB) | +| max_log_file_count | [ int32](#int32 ) | default value: 10 | + + + + + + +{#FileSourceMeta} +### FileSourceMeta +empty by design + + + + +{#ModelConfig} +### ModelConfig +Config for serving model + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| model_id | [ string](#string ) | Unique id of the model package | +| base_path | [ string](#string ) | Path used to cache and load model package | +| source_path | [ string](#string ) | Represent the path of the model package in the model source | +| source_sha256 | [ string](#string ) | Optional. 
The expect sha256 of the model package | +| source_type | [ SourceType](#sourcetype ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) kind.file_source_meta | [ FileSourceMeta](#filesourcemeta ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) kind.oss_source_meta | [ OSSSourceMeta](#osssourcemeta ) | none | + + + + +{#OSSSourceMeta} +### OSSSourceMeta +Options for a S3 Oss model source + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| access_key | [ string](#string ) | Bucket access key | +| secret_key | [ string](#string ) | Bucket secret key | +| virtual_hosted | [ bool](#bool ) | Whether to use virtual host mode, https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html | +| endpoint | [ string](#string ) | none | +| bucket | [ string](#string ) | none | + + + + + + +{#ServerConfig} +### ServerConfig + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| feature_mapping | [map ServerConfig.FeatureMappingEntry](#serverconfig-featuremappingentry ) | Optional. Feature name mapping rules. Key: source or predefined feature name Value: model feature name | +| tls_config | [ TlsConfig](#tlsconfig ) | Whether to enable tls for server | +| brpc_builtin_service_port | [ int32](#int32 ) | Brpc builtin service listen port Default: disable service | +| metrics_exposer_port | [ int32](#int32 ) | Whether `/metrics` service is enable/disable. | +| worker_num | [ int32](#int32 ) | Number of pthreads that server runs on. If this option <= 0, use default value. Default: #cpu-cores | +| max_concurrency | [ int32](#int32 ) | Server-level max number of requests processed in parallel Default: 0 (unlimited) | +| op_exec_worker_num | [ int32](#int32 ) | Number of pthreads that server runs to execute ops. If this option <= 0, use default value. 
Default: #cpu-cores | + + + + +{#ServerConfig.FeatureMappingEntry} +### ServerConfig.FeatureMappingEntry + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| key | [ string](#string ) | none | +| value | [ string](#string ) | none | + + + + + + +{#ServingConfig} +### ServingConfig +Related config of serving + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| id | [ string](#string ) | Unique id of the serving service | +| server_conf | [ ServerConfig](#serverconfig ) | none | +| model_conf | [ ModelConfig](#modelconfig ) | none | +| cluster_conf | [ ClusterConfig](#clusterconfig ) | none | +| feature_source_conf | [ FeatureSourceConfig](#featuresourceconfig ) | none | + + + + + + +{#TlsConfig} +### TlsConfig + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| certificate_path | [ string](#string ) | Certificate file path | +| private_key_path | [ string](#string ) | Private key file path | +| ca_file_path | [ string](#string ) | The trusted CA file to verify the peer's certificate If empty, use the system default CA files | + + + + + +## Enums + + + + + +### MockDataType + + +| Name | Number | Description | +| ---- | ------ | ----------- | +| INVALID_MOCK_DATA_TYPE | 0 | Placeholder for proto3 default value, do not use it. | +| MDT_RANDOM | 1 | random value for each feature | +| MDT_FIXED | 2 | fixed value for each feature | + + + + + + +### LogLevel + + +| Name | Number | Description | +| ---- | ------ | ----------- | +| INVALID_LOG_LEVEL | 0 | Placeholder for proto3 default value, do not use it. | +| DEBUG_LOG_LEVEL | 1 | debug | +| INFO_LOG_LEVEL | 2 | info | +| WARN_LOG_LEVEL | 3 | warn | +| ERROR_LOG_LEVEL | 4 | error | + + + + + + +### SourceType +Supported model source type + +| Name | Number | Description | +| ---- | ------ | ----------- | +| INVALID_SOURCE_TYPE | 0 | Placeholder for proto3 default value, do not use it. 
| +| ST_FILE | 1 | Local filesystem | +| ST_OSS | 2 | S3 OSS | + + + + + + + + + + + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +|

double | | double | double | float | +|

float | | float | float | float | +|

int32 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead. | int32 | int | int | +|

int64 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead. | int64 | long | int/long | +|

uint32 | Uses variable-length encoding. | uint32 | int | int/long | +|

uint64 | Uses variable-length encoding. | uint64 | long | int/long | +|

sint32 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s. | int32 | int | int | +|

sint64 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s. | int64 | long | int/long | +|

fixed32 | Always four bytes. More efficient than uint32 if values are often greater than 2^28. | uint32 | int | int | +|

fixed64 | Always eight bytes. More efficient than uint64 if values are often greater than 2^56. | uint64 | long | int/long | +|

sfixed32 | Always four bytes. | int32 | int | int | +|

sfixed64 | Always eight bytes. | int64 | long | int/long | +|

bool | | bool | boolean | boolean | +|

string | A string must always contain UTF-8 encoded or 7-bit ASCII text. | string | String | str/unicode | +|

bytes | May contain any arbitrary sequence of bytes. | string | ByteString | str | diff --git a/docs/source/reference/config_md.tmpl b/docs/source/reference/config_md.tmpl new file mode 100644 index 0000000..8b871ad --- /dev/null +++ b/docs/source/reference/config_md.tmpl @@ -0,0 +1,89 @@ +# SecretFlow-Serving Config + +## Table of Contents +- Services +{{range .Files}} + +{{if .HasServices}} + {{range .Services}} - [{{.Name}}](#{{.Name | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} + +- Messages +{{range .Files}} + +{{if .HasMessages}} + {{range .Messages}} - [{{.LongName}}](#{{.LongName | lower | replace "." "-"}}) + {{end}} +{{end}} +{{end}} + +- Enums +{{range .Files}} + +{{if .HasEnums}} + {{range .Enums}} - [{{.LongName}}](#{{.LongName | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} +- [Scalar Value Types](#scalar-value-types) + +{{range .Files}} +{{range .Services -}} +{#{{.LongName}}} +## {{.Name}} +{{.Description}} + +{{range .Methods -}} +### {{.Name}} + +> **rpc** {{.Name}}([{{.RequestLongType}}](#{{.RequestLongType | lower | replace "." ""}})) + [{{.ResponseLongType}}](#{{.ResponseLongType | lower | replace "." ""}}) + +{{ .Description}} +{{end}} +{{end}} +{{end}} + +## Messages +{{range .Files}} +{{range .Messages}} + +{#{{.LongName}}} +### {{.LongName}} +{{.Description}} + +{{if .HasFields}} +| Field | Type | Description | +| ----- | ---- | ----------- | +{{range .Fields -}} + | {{if .IsOneof}}[**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) {{.OneofDecl}}.{{end}}{{.Name}} | [{{if .IsMap}}map {{else}}{{.Label}} {{end}}{{.LongType}}](#{{if .IsMap}}{{.LongType | lower | replace "." "-"}} {{else}}{{.Type | lower | replace "." 
"-"}} {{end}}) | {{if .Description}}{{nobr .Description}}{{if .DefaultValue}} Default: {{.DefaultValue}}{{end}}{{else}}none{{end}} | +{{end}} +{{end}} +{{end}} +{{end}} + +## Enums +{{range .Files}} +{{range .Enums}} + +### {{.LongName}} +{{.Description}} + +| Name | Number | Description | +| ---- | ------ | ----------- | +{{range .Values -}} + | {{.Name}} | {{.Number}} | {{if .Description}}{{nobr .Description}}{{else}}none{{end}} | +{{end}} + +{{end}} +{{end}} + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +{{range .Scalars -}} + |

{{.ProtoType}} | {{.Notes}} | {{.CppType}} | {{.JavaType}} | {{.PythonType}} | +{{end}} diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst new file mode 100644 index 0000000..a715e7d --- /dev/null +++ b/docs/source/reference/index.rst @@ -0,0 +1,31 @@ +.. _reference: + +Reference +========= + +This part contains detailed explanation of Model, Configs, SPIs and APIs. + + +.. toctree:: + :maxdepth: 1 + :caption: API + + api + +.. toctree:: + :maxdepth: 1 + :caption: Config + + config + +.. toctree:: + :maxdepth: 1 + :caption: SPI + + spi + +.. toctree:: + :maxdepth: 1 + :caption: Model + + model diff --git a/docs/source/reference/model.md b/docs/source/reference/model.md new file mode 100644 index 0000000..1977c35 --- /dev/null +++ b/docs/source/reference/model.md @@ -0,0 +1,785 @@ +# SecretFlow-Serving Model + +## Table of Contents +- Services + + + + + + + + + + + + + + + + + + + + + + + +- Messages + + + + - [AttrDef](#attrdef) + - [AttrValue](#attrvalue) + - [BoolList](#boollist) + - [BytesList](#byteslist) + - [DoubleList](#doublelist) + - [FloatList](#floatlist) + - [Int32List](#int32list) + - [Int64List](#int64list) + - [StringList](#stringlist) + + + + + + - [IoDef](#iodef) + - [OpDef](#opdef) + - [OpTag](#optag) + + + + + + - [ExecutionDef](#executiondef) + - [GraphDef](#graphdef) + - [GraphView](#graphview) + - [NodeDef](#nodedef) + - [NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry) + - [NodeView](#nodeview) + - [RuntimeConfig](#runtimeconfig) + + + + + + - [ModelBundle](#modelbundle) + - [ModelInfo](#modelinfo) + - [ModelManifest](#modelmanifest) + + + + + + + + + - [ComputeTrace](#computetrace) + - [FunctionInput](#functioninput) + - [FunctionOutput](#functionoutput) + - [FunctionTrace](#functiontrace) + - [Scalar](#scalar) + + + + + + - [ComputeTrace](#computetrace) + - [FunctionInput](#functioninput) + - [FunctionOutput](#functionoutput) + - [FunctionTrace](#functiontrace) + - [Scalar](#scalar) + + + + +- Enums + + + + - 
[AttrType](#attrtype) + + + + + + + + + - [DispatchType](#dispatchtype) + + + + + + - [FileFormatType](#fileformattype) + + + + + + - [DataType](#datatype) + + + + + + - [ExtendFunctionName](#extendfunctionname) + + + + + + - [ExtendFunctionName](#extendfunctionname) + + + +- [Scalar Value Types](#scalar-value-types) + + + + + + + + + + + + + + + + + +## Messages + + + +{#AttrDef} +### AttrDef +The definition of an attribute. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Must be unique among all attr of the operator. | +| desc | [ string](#string ) | Description of the attribute | +| type | [ AttrType](#attrtype ) | none | +| is_optional | [ bool](#bool ) | If True, when AttrValue is not provided or is_na, default_value would be used. Else, AttrValue must be provided. | +| default_value | [ AttrValue](#attrvalue ) | A reasonable default for this attribute if it's optional and the user does not supply a value. If not, the user must supply a value. 
| + + + + +{#AttrValue} +### AttrValue +The value of an attribute + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i32 | [ int32](#int32 ) | INT | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i64 | [ int64](#int64 ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.f | [ float](#float ) | FLOAT | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.d | [ double](#double ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.s | [ string](#string ) | STRING | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.b | [ bool](#bool ) | BOOL | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.by | [ bytes](#bytes ) | BYTES | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i32s | [ Int32List](#int32list ) | INTS | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i64s | [ Int64List](#int64list ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.fs | [ FloatList](#floatlist ) | FLOATS | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ds | [ DoubleList](#doublelist ) | none | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ss | [ StringList](#stringlist ) | STRINGS | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.bs | [ BoolList](#boollist ) | BOOLS | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.bys | [ BytesList](#byteslist ) | BYTESS | + + + + +{#BoolList} +### BoolList + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated 
bool](#bool ) | none | + + + + +{#BytesList} +### BytesList + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated bytes](#bytes ) | none | + + + + +{#DoubleList} +### DoubleList + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated double](#double ) | none | + + + + +{#FloatList} +### FloatList + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated float](#float ) | none | + + + + +{#Int32List} +### Int32List + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated int32](#int32 ) | none | + + + + +{#Int64List} +### Int64List + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated int64](#int64 ) | none | + + + + +{#StringList} +### StringList + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [repeated string](#string ) | none | + + + + + + +{#IoDef} +### IoDef +Define an input/output for an operator. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Must be unique among all IOs of the operator. | +| desc | [ string](#string ) | Description of the IO | + + + + +{#OpDef} +### OpDef +The definition of an operator. 
+ + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Unique name of the op | +| desc | [ string](#string ) | Description of the op | +| version | [ string](#string ) | Version of the op | +| tag | [ OpTag](#optag ) | none | +| inputs | [repeated IoDef](#iodef ) | none | +| output | [ IoDef](#iodef ) | none | +| attrs | [repeated AttrDef](#attrdef ) | none | + + + + +{#OpTag} +### OpTag +Representation of operator properties + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| returnable | [ bool](#bool ) | The operator's output can be the final result | +| mergeable | [ bool](#bool ) | The operator accepts the output of operators with different participants and will somehow merge them. | +| session_run | [ bool](#bool ) | The operator needs to be executed in session. | + + + + + + +{#ExecutionDef} +### ExecutionDef +The definition of an execution. An execution represents a subgraph within a +graph that can be scheduled for execution in a specified pattern. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| nodes | [repeated string](#string ) | Represents the nodes contained in this execution. Note that these node names should be findable and unique within the node definitions. One node can only exist in one execution and must exist in one. | +| config | [ RuntimeConfig](#runtimeconfig ) | The runtime config of the execution. | + + + + +{#GraphDef} +### GraphDef +The definition of a Graph. A graph consists of a set of nodes carrying data +and a set of executions that describes the scheduling of the graph. 
+ + +| Field | Type | Description | +| ----- | ---- | ----------- | +| version | [ string](#string ) | Version of the graph | +| node_list | [repeated NodeDef](#nodedef ) | none | +| execution_list | [repeated ExecutionDef](#executiondef ) | none | + + + + +{#GraphView} +### GraphView +The view of a graph is used to display the structure of the graph, containing +only structural information and excluding the data components. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| version | [ string](#string ) | Version of the graph | +| node_list | [repeated NodeView](#nodeview ) | none | +| execution_list | [repeated ExecutionDef](#executiondef ) | none | + + + + +{#NodeDef} +### NodeDef +The definition of a node. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Must be unique among all nodes of the graph. | +| op | [ string](#string ) | The operator name. | +| parents | [repeated string](#string ) | The parent node names of the node. The order of the parent nodes should match the order of the inputs of the node. | +| attr_values | [map NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry ) | The attribute values configured in the node. Note that this should include all attrs defined in the corresponding OpDef. | +| op_version | [ string](#string ) | The operator version. | + + + + +{#NodeDef.AttrValuesEntry} +### NodeDef.AttrValuesEntry + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| key | [ string](#string ) | none | +| value | [ op.AttrValue](#attrvalue ) | none | + + + + +{#NodeView} +### NodeView +The view of a node, which could be public to other parties + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Must be unique among all nodes of the graph. | +| op | [ string](#string ) | The operator name. | +| parents | [repeated string](#string ) | The parent node names of the node. 
The order of the parent nodes should match the order of the inputs of the node. | +| op_version | [ string](#string ) | The operator version. | + + + + +{#RuntimeConfig} +### RuntimeConfig +The runtime config of the execution. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| dispatch_type | [ DispatchType](#dispatchtype ) | The dispatch type of the execution. | +| session_run | [ bool](#bool ) | The execution need run in session(stateful) TODO: not support yet. | +| specific_flag | [ bool](#bool ) | if dispatch_type is DP_SPECIFIED, only one party should be true | + + + + + + +{#ModelBundle} +### ModelBundle +Represents an exported secretflow model. It consists of a GraphDef and extra +metadata required for serving. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | none | +| desc | [ string](#string ) | none | +| graph | [ GraphDef](#graphdef ) | none | + + + + +{#ModelInfo} +### ModelInfo +Represents a secretflow model without private data. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | none | +| desc | [ string](#string ) | none | +| graph_view | [ GraphView](#graphview ) | none | + + + + +{#ModelManifest} +### ModelManifest +The manifest of the model package. Package format is as follows: +model.tar.gz + ├ MANIFIEST + ├ model_file + └ some op meta files +MANIFIEST should be json format + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| bundle_path | [ string](#string ) | Model bundle file path. | +| bundle_format | [ FileFormatType](#fileformattype ) | The format type of the model bundle file. | + + + + + + + + +{#ComputeTrace} +### ComputeTrace + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | The name of this Compute. 
| +| func_traces | [repeated FunctionTrace](#functiontrace ) | none | + + + + +{#FunctionInput} +### FunctionInput + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.data_id | [ int32](#int32 ) | '0' means root input data | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.custom_scalar | [ Scalar](#scalar ) | none | + + + + +{#FunctionOutput} +### FunctionOutput + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data_id | [ int32](#int32 ) | none | + + + + +{#FunctionTrace} +### FunctionTrace + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | The Function name. | +| option_bytes | [ bytes](#bytes ) | The serialized function options. | +| inputs | [repeated FunctionInput](#functioninput ) | Inputs of this function. | +| output | [ FunctionOutput](#functionoutput ) | Output of this function. | + + + + +{#Scalar} +### Scalar +Represents a single value with a specific data type. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i8 | [ int32](#int32 ) | INT8. 
| +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui8 | [ int32](#int32 ) | UINT8 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i16 | [ int32](#int32 ) | INT16 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui16 | [ int32](#int32 ) | UINT16 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i32 | [ int32](#int32 ) | INT32 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui32 | [ uint32](#uint32 ) | UINT32 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i64 | [ int64](#int64 ) | INT64 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui64 | [ uint64](#uint64 ) | UINT64 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.f | [ float](#float ) | FLOAT | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.d | [ double](#double ) | DOUBLE | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.s | [ string](#string ) | STRING | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.b | [ bool](#bool ) | BOOL | + + + + + + +{#ComputeTrace} +### ComputeTrace + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | The name of this Compute. 
| +| func_traces | [repeated FunctionTrace](#functiontrace ) | none | + + + + +{#FunctionInput} +### FunctionInput + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.data_id | [ int32](#int32 ) | '0' means root input data | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.custom_scalar | [ Scalar](#scalar ) | none | + + + + +{#FunctionOutput} +### FunctionOutput + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data_id | [ int32](#int32 ) | none | + + + + +{#FunctionTrace} +### FunctionTrace + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | The Function name. | +| option_bytes | [ bytes](#bytes ) | The serialized function options. | +| inputs | [repeated FunctionInput](#functioninput ) | Inputs of this function. | +| output | [ FunctionOutput](#functionoutput ) | Output of this function. | + + + + +{#Scalar} +### Scalar +Represents a single value with a specific data type. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i8 | [ int32](#int32 ) | INT8. 
| +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui8 | [ int32](#int32 ) | UINT8 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i16 | [ int32](#int32 ) | INT16 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui16 | [ int32](#int32 ) | UINT16 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i32 | [ int32](#int32 ) | INT32 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui32 | [ uint32](#uint32 ) | UINT32 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.i64 | [ int64](#int64 ) | INT64 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.ui64 | [ uint64](#uint64 ) | UINT64 | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.f | [ float](#float ) | FLOAT | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.d | [ double](#double ) | DOUBLE | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.s | [ string](#string ) | STRING | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) value.b | [ bool](#bool ) | BOOL | + + + + + +## Enums + + + +{#AttrType} +### AttrType +Supported attribute types. + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_AT_TYPE | 0 | Placeholder for proto3 default value, do not use it. 
| +| AT_INT32 | 1 | INT32 | +| AT_INT64 | 2 | INT64 | +| AT_FLOAT | 3 | FLOAT | +| AT_DOUBLE | 4 | DOUBLE | +| AT_STRING | 5 | STRING | +| AT_BOOL | 6 | BOOL | +| AT_BYTES | 7 | BYTES | +| AT_INT32_LIST | 11 | INT32 LIST | +| AT_INT64_LIST | 12 | INT64 LIST | +| AT_FLOAT_LIST | 13 | FLOAT LIST | +| AT_DOUBLE_LIST | 14 | DOUBLE LIST | +| AT_STRING_LIST | 15 | STRING LIST | +| AT_BOOL_LIST | 16 | BOOL LIST | +| AT_BYTES_LIST | 17 | BYTES LIST | + + + + + + + + +{#DispatchType} +### DispatchType +Supported dispatch type + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_DP_TYPE | 0 | Placeholder for proto3 default value, do not use it. | +| DP_ALL | 1 | Dispatch all participants. | +| DP_ANYONE | 2 | Dispatch any participant. | +| DP_SPECIFIED | 3 | Dispatch specified participant. | + + + + + + +{#FileFormatType} +### FileFormatType +Support model file format + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_FF_TYPE | 0 | none | +| FF_PB | 1 | Protobuf | +| FF_JSON | 2 | Json It is recommended to use protobuf's official json serialization method to ensure compatibility | + + + + + + +{#DataType} +### DataType +Mapping arrow::DataType +`https://arrow.apache.org/docs/cpp/api/datatype.html`. + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_DT_TYPE | 0 | Placeholder for proto3 default value, do not use it. | +| DT_BOOL | 1 | Boolean as 1 bit, LSB bit-packed ordering. | +| DT_UINT8 | 2 | Unsigned 8-bit little-endian integer. | +| DT_INT8 | 3 | Signed 8-bit little-endian integer. | +| DT_UINT16 | 4 | Unsigned 16-bit little-endian integer. | +| DT_INT16 | 5 | Signed 16-bit little-endian integer. | +| DT_UINT32 | 6 | Unsigned 32-bit little-endian integer. | +| DT_INT32 | 7 | Signed 32-bit little-endian integer. | +| DT_UINT64 | 8 | Unsigned 64-bit little-endian integer. | +| DT_INT64 | 9 | Signed 64-bit little-endian integer. 
| +| DT_FLOAT | 11 | 4-byte floating point value | +| DT_DOUBLE | 12 | 8-byte floating point value | +| DT_STRING | 13 | UTF8 variable-length string as List | +| DT_BINARY | 14 | Variable-length bytes (no guarantee of UTF8-ness) | + + + + + + +{#ExtendFunctionName} +### ExtendFunctionName + + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKOWN_EX_FUNCTION_NAME | 0 | Placeholder for proto3 default value, do not use it | +| EFN_TB_COLUMN | 1 | Get column from table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch6columnEi | +| EFN_TB_ADD_COLUMN | 2 | Add column to table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiNSt6stringERKNSt10shared_ptrI5ArrayEE | +| EFN_TB_REMOVE_COLUMN | 3 | Remove column from table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch12RemoveColumnEi | +| EFN_TB_SET_COLUMN | 4 | Set column to table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE | + + + + + + +{#ExtendFunctionName} +### ExtendFunctionName + + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKOWN_EX_FUNCTION_NAME | 0 | Placeholder for proto3 default value, do not use it | +| EFN_TB_COLUMN | 1 | Get column from table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch6columnEi | +| EFN_TB_ADD_COLUMN | 2 | Add column to table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiNSt6stringERKNSt10shared_ptrI5ArrayEE | +| EFN_TB_REMOVE_COLUMN | 3 | Remove column from table(record_batch). see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch12RemoveColumnEi | +| EFN_TB_SET_COLUMN | 4 | Set column to table(record_batch). 
see https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE | + + + + + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +|

double | | double | double | float | +|

float | | float | float | float | +|

int32 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead. | int32 | int | int | +|

int64 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead. | int64 | long | int/long | +|

uint32 | Uses variable-length encoding. | uint32 | int | int/long | +|

uint64 | Uses variable-length encoding. | uint64 | long | int/long | +|

sint32 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s. | int32 | int | int | +|

sint64 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s. | int64 | long | int/long | +|

fixed32 | Always four bytes. More efficient than uint32 if values are often greater than 2^28. | uint32 | int | int | +|

fixed64 | Always eight bytes. More efficient than uint64 if values are often greater than 2^56. | uint64 | long | int/long | +|

sfixed32 | Always four bytes. | int32 | int | int | +|

sfixed64 | Always eight bytes. | int64 | long | int/long | +|

bool | | bool | boolean | boolean | +|

string | A string must always contain UTF-8 encoded or 7-bit ASCII text. | string | String | str/unicode | +|

bytes | May contain any arbitrary sequence of bytes. | string | ByteString | str | diff --git a/docs/source/reference/model_md.tmpl b/docs/source/reference/model_md.tmpl new file mode 100644 index 0000000..60bcd48 --- /dev/null +++ b/docs/source/reference/model_md.tmpl @@ -0,0 +1,90 @@ +# SecretFlow-Serving Model + +## Table of Contents +- Services +{{range .Files}} + +{{if .HasServices}} + {{range .Services}} - [{{.Name}}](#{{.Name | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} + +- Messages +{{range .Files}} + +{{if .HasMessages}} + {{range .Messages}} - [{{.LongName}}](#{{.LongName | lower | replace "." "-"}}) + {{end}} +{{end}} +{{end}} + +- Enums +{{range .Files}} + +{{if .HasEnums}} + {{range .Enums}} - [{{.LongName}}](#{{.LongName | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} +- [Scalar Value Types](#scalar-value-types) + +{{range .Files}} +{{range .Services -}} +{#{{.LongName}}} +## {{.Name}} +{{.Description}} + +{{range .Methods -}} +### {{.Name}} + +> **rpc** {{.Name}}([{{.RequestLongType}}](#{{.RequestLongType | lower | replace "." ""}})) + [{{.ResponseLongType}}](#{{.ResponseLongType | lower | replace "." ""}}) + +{{ .Description}} +{{end}} +{{end}} +{{end}} + +## Messages +{{range .Files}} +{{range .Messages}} + +{#{{.LongName}}} +### {{.LongName}} +{{.Description}} + +{{if .HasFields}} +| Field | Type | Description | +| ----- | ---- | ----------- | +{{range .Fields -}} + | {{if .IsOneof}}[**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) {{.OneofDecl}}.{{end}}{{.Name}} | [{{if .IsMap}}map {{else}}{{.Label}} {{end}}{{.LongType}}](#{{if .IsMap}}{{.LongType | lower | replace "." "-"}} {{else}}{{.Type | lower | replace "." 
"-"}} {{end}}) | {{if .Description}}{{nobr .Description}}{{if .DefaultValue}} Default: {{.DefaultValue}}{{end}}{{else}}none{{end}} | +{{end}} +{{end}} +{{end}} +{{end}} + +## Enums +{{range .Files}} +{{range .Enums}} + +{#{{.LongName}}} +### {{.LongName}} +{{.Description}} + +| Name | Number | Description | +| ---- | ------ | ----------- | +{{range .Values -}} + | {{.Name}} | {{.Number}} | {{if .Description}}{{nobr .Description}}{{else}}none{{end}} | +{{end}} + +{{end}} +{{end}} + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +{{range .Scalars -}} + |

{{.ProtoType}} | {{.Notes}} | {{.CppType}} | {{.JavaType}} | {{.PythonType}} | +{{end}} diff --git a/docs/source/reference/spi.md b/docs/source/reference/spi.md new file mode 100644 index 0000000..3386294 --- /dev/null +++ b/docs/source/reference/spi.md @@ -0,0 +1,374 @@ +# SecretFlow-Serving SPI + +## Table of Contents +- Services + + + + - [BatchFeatureService](#batchfeatureservice) + + + + + + + + + + + + + +- Messages + + + + - [BatchFetchFeatureRequest](#batchfetchfeaturerequest) + - [BatchFetchFeatureResponse](#batchfetchfeatureresponse) + + + + + + - [Header](#header) + - [Header.DataEntry](#header-dataentry) + - [Status](#status) + + + + + + + + + - [Feature](#feature) + - [FeatureField](#featurefield) + - [FeatureParam](#featureparam) + - [FeatureValue](#featurevalue) + + + + +- Enums + + + + + + + + + + - [ErrorCode](#errorcode) + + + + + + - [FieldType](#fieldtype) + + + +- [Scalar Value Types](#scalar-value-types) + + +{#BatchFeatureService} +## BatchFeatureService +BatchFeatureService provides access to fetch features. + +### BatchFetchFeature + +> **rpc** BatchFetchFeature([BatchFetchFeatureRequest](#batchfetchfeaturerequest)) + [BatchFetchFeatureResponse](#batchfetchfeatureresponse) + + + + + + + + + + + + +## Messages + + + +{#BatchFetchFeatureRequest} +### BatchFetchFeatureRequest +BatchFetchFeature request containing one or more requests. +examples: +```json + { + "header": { + "data": { + "custom_str": "id_12345" + }, + }, + "model_service_id": "test_service_id", + "party_id": "alice", + "feature_fields": [ + { + "name": "f1", + "type": 2 + }, + { + "name": "f2", + "type": 4 + } + ] + "param": { + "query_datas": [ + "x1", + "x2" + ], + "query_context": "context_x" + } + } +``` + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data passed by the Predict request's header. | +| model_service_id | [ string](#string ) | Model service specification. 
| +| party_id | [ string](#string ) | The request party id. | +| feature_fields | [repeated secretflow.serving.FeatureField](#featurefield ) | Request feature field list | +| param | [ secretflow.serving.FeatureParam](#featureparam ) | Custom query paramters for fetch features | + + + + +{#BatchFetchFeatureResponse} +### BatchFetchFeatureResponse +BatchFetchFeatureResponse response containing one or more responses. +examples: +```json + { + "header": { + "data": { + "custom_value": "asdfvb" + } + }, + "status": { + "code": 0, + "msg": "success." + }, + "features": [ + { + "field": { + "name": "f1", + "type": 2 + }, + "value": { + "i32s": [ + 123, + 234 + ] + } + }, + { + "field": { + "name": "f2", + "type": 4 + }, + "value": { + "fs": [ + 0.123, + 1.234 + ] + } + } + ] + } +``` + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| header | [ Header](#header ) | Custom data. | +| status | [ Status](#status ) | none | +| features | [repeated secretflow.serving.Feature](#feature ) | Should include all the features mentioned in the BatchFetchFeatureRequest.feature_fields | + + + + + + +{#Header} +### Header +Header containing custom data + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| data | [map Header.DataEntry](#header-dataentry ) | none | + + + + +{#Header.DataEntry} +### Header.DataEntry + + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| key | [ string](#string ) | none | +| value | [ string](#string ) | none | + + + + +{#Status} +### Status +Represents the status of spi request + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| code | [ int32](#int32 ) | code value reference `ErrorCode` in secretflow_serving/spis/error_code.proto | +| msg | [ string](#string ) | none | + + + + + + + + +{#Feature} +### Feature +The definition of a feature + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| field | [ FeatureField](#featurefield ) | none | +| value | [ 
FeatureValue](#featurevalue ) | none | + + + + +{#FeatureField} +### FeatureField +The definition of a feature field. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| name | [ string](#string ) | Unique name of the feature | +| type | [ FieldType](#fieldtype ) | Field type of the feature | + + + + +{#FeatureParam} +### FeatureParam +The param for fetch features + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| query_datas | [repeated string](#string ) | The serialized datas for query features. Each one for query one row of features. | +| query_context | [ string](#string ) | Optional. Represents the common part of the query datas. | + + + + +{#FeatureValue} +### FeatureValue +The value of a feature + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| i32s | [repeated int32](#int32 ) | int list | +| i64s | [repeated int64](#int64 ) | none | +| fs | [repeated float](#float ) | float list | +| ds | [repeated double](#double ) | none | +| ss | [repeated string](#string ) | string list | +| bs | [repeated bool](#bool ) | bool list | + + + + + +## Enums + + + + + + + +### ErrorCode +ONLY for Reference by ResponseHeader +It's subset of google.rpc.Code + +| Name | Number | Description | +| ---- | ------ | ----------- | +| OK | 0 | Not an error; returned on success + +HTTP Mapping: 200 OK | +| INVALID_ARGUMENT | 3 | The client specified an invalid argument. Note that this differs from `FAILED_PRECONDITION`. `INVALID_ARGUMENT` indicates arguments that are problematic regardless of the state of the system (e.g., a malformed file name). + +HTTP Mapping: 400 Bad Request | +| DEADLINE_EXCEEDED | 4 | The deadline expired before the operation could complete. For operations that change the state of the system, this error may be returned even if the operation has completed successfully. For example, a successful response from a server could have been delayed long enough for the deadline to expire. 
+ +HTTP Mapping: 504 Gateway Timeout | +| NOT_FOUND | 5 | Some requested entity (e.g., file or directory) was not found. + +Note to server developers: if a request is denied for an entire class of users, such as gradual feature rollout or undocumented whitelist, `NOT_FOUND` may be used. If a request is denied for some users within a class of users, such as user-based access control, `PERMISSION_DENIED` must be used. + +HTTP Mapping: 404 Not Found | +| INTERNAL_ERROR | 13 | Internal errors. This means that some invariants expected by the underlying system have been broken. This error code is reserved for serious errors. + +HTTP Mapping: 500 Internal Server Error | +| UNAUTHENTICATED | 16 | The request does not have valid authentication credentials for the operation. + +HTTP Mapping: 401 Unauthorized | + + + + + + +### FieldType +Supported feature field type. + +| Name | Number | Description | +| ---- | ------ | ----------- | +| UNKNOWN_FIELD_TYPE | 0 | Placeholder for proto3 default value, do not use it. | +| FIELD_BOOL | 1 | BOOL | +| FIELD_INT32 | 2 | INT32 | +| FIELD_INT64 | 3 | INT64 | +| FIELD_FLOAT | 4 | FLOAT | +| FIELD_DOUBLE | 5 | DOUBLE | +| FIELD_STRING | 6 | STRING | + + + + + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +|

double | | double | double | float | +|

float | | float | float | float | +|

int32 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead. | int32 | int | int | +|

int64 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead. | int64 | long | int/long | +|

uint32 | Uses variable-length encoding. | uint32 | int | int/long | +|

uint64 | Uses variable-length encoding. | uint64 | long | int/long | +|

sint32 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s. | int32 | int | int | +|

sint64 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s. | int64 | long | int/long | +|

fixed32 | Always four bytes. More efficient than uint32 if values are often greater than 2^28. | uint32 | int | int | +|

fixed64 | Always eight bytes. More efficient than uint64 if values are often greater than 2^56. | uint64 | long | int/long | +|

sfixed32 | Always four bytes. | int32 | int | int | +|

sfixed64 | Always eight bytes. | int64 | long | int/long | +|

bool | | bool | boolean | boolean | +|

string | A string must always contain UTF-8 encoded or 7-bit ASCII text. | string | String | str/unicode | +|

bytes | May contain any arbitrary sequence of bytes. | string | ByteString | str | diff --git a/docs/source/reference/spi_md.tmpl b/docs/source/reference/spi_md.tmpl new file mode 100644 index 0000000..ee2715e --- /dev/null +++ b/docs/source/reference/spi_md.tmpl @@ -0,0 +1,89 @@ +# SecretFlow-Serving SPI + +## Table of Contents +- Services +{{range .Files}} + +{{if .HasServices}} + {{range .Services}} - [{{.Name}}](#{{.Name | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} + +- Messages +{{range .Files}} + +{{if .HasMessages}} + {{range .Messages}} - [{{.LongName}}](#{{.LongName | lower | replace "." "-"}}) + {{end}} +{{end}} +{{end}} + +- Enums +{{range .Files}} + +{{if .HasEnums}} + {{range .Enums}} - [{{.LongName}}](#{{.LongName | lower | replace "." ""}}) + {{end}} +{{end}} +{{end}} +- [Scalar Value Types](#scalar-value-types) + +{{range .Files}} +{{range .Services -}} +{#{{.LongName}}} +## {{.Name}} +{{.Description}} + +{{range .Methods -}} +### {{.Name}} + +> **rpc** {{.Name}}([{{.RequestLongType}}](#{{.RequestLongType | lower | replace "." ""}})) + [{{.ResponseLongType}}](#{{.ResponseLongType | lower | replace "." ""}}) + +{{ .Description}} +{{end}} +{{end}} +{{end}} + +## Messages +{{range .Files}} +{{range .Messages}} + +{#{{.LongName}}} +### {{.LongName}} +{{.Description}} + +{{if .HasFields}} +| Field | Type | Description | +| ----- | ---- | ----------- | +{{range .Fields -}} + | {{if .IsOneof}}[**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) {{.OneofDecl}}.{{end}}{{.Name}} | [{{if .IsMap}}map {{else}}{{.Label}} {{end}}{{.LongType}}](#{{if .IsMap}}{{.LongType | lower | replace "." "-"}} {{else}}{{.Type | lower | replace "." 
"-"}} {{end}}) | {{if .Description}}{{nobr .Description}}{{if .DefaultValue}} Default: {{.DefaultValue}}{{end}}{{else}}none{{end}} | +{{end}} +{{end}} +{{end}} +{{end}} + +## Enums +{{range .Files}} +{{range .Enums}} + +### {{.LongName}} +{{.Description}} + +| Name | Number | Description | +| ---- | ------ | ----------- | +{{range .Values -}} + | {{.Name}} | {{.Number}} | {{if .Description}}{{nobr .Description}}{{else}}none{{end}} | +{{end}} + +{{end}} +{{end}} + +## Scalar Value Types + +| .proto Type | Notes | C++ Type | Java Type | Python Type | +| ----------- | ----- | -------- | --------- | ----------- | +{{range .Scalars -}} + |

 {{.ProtoType}} | {{.Notes}} | {{.CppType}} | {{.JavaType}} | {{.PythonType}} | +{{end}} diff --git a/docs/source/topics/deployment/deployment.rst b/docs/source/topics/deployment/deployment.rst new file mode 100644 index 0000000..71e82fe --- /dev/null +++ b/docs/source/topics/deployment/deployment.rst @@ -0,0 +1,212 @@
+====================================
+How to deploy SecretFlow-Serving
+====================================
+
+This document describes how to deploy SecretFlow-Serving with docker; it's basically the same as :doc:`/intro/tutorial`, but deployed on multiple machines.
+
+Before starting this doc, we assume that the reader has some experience using the docker-compose utility. If you are new to Docker Compose, please consider reviewing the `official Docker Compose overview `_, or checking out the `Getting Started guide `_.
+
+Deployment Diagram
+==================
+
+The deployment diagram of the SecretFlow-Serving system that we plan to deploy is shown in the following figure; it involves two parties, named ``Alice`` and ``Bob``. We use two machines to simulate the different parties.
+
+.. image:: /imgs/architecture.png
+
+.. note::
+    1. The SecretFlow-Serving is served through the HTTP protocol. It is recommended to use HTTPS instead in production environments. Please check :ref:`TLS Configuration ` for details.
+
+Step 1: Deploy SecretFlow-Serving
+=================================
+
+Here we present how to deploy serving in party ``Alice``; it's the same for party ``Bob``.
+
+1.1 Create a Workspace
+-----------------------
+
+.. code-block:: bash
+
+    mkdir serving
+    cd serving
+
+Here, we use the model file from the "examples" directory as a demonstration and place it in the "serving" directory. Please replace the following path with the actual path according to your situation.
+
+.. code-block:: bash
+
+    cp serving/examples/alice/glm-test.tar.gz .
+
+
+.. 
note:: + + For ``Bob`` should use model file `serving/examples/bob/glm-test.tar.gz`. + + +1.2 Create Serving config file +------------------------------ + +Create a file called ``serving.config`` in your workspace and paste the following code in: + +.. code-block:: json + + { + "id": "test_service_id", + "serverConf": { + "featureMapping": { + "v24": "x24", + "v22": "x22", + "v21": "x21", + "v25": "x25", + "v23": "x23" + }, + "metricsExposerPort": 10306, + "brpcBuiltinServicePort": 10307 + }, + "modelConf": { + "modelId": "glm-test", + "basePath": "./data", + "sourcePath": "./glm-test.tar.gz", + "sourceSha256": "3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522", + "sourceType": "ST_FILE" + }, + "clusterConf": { + "selfId": "alice", + "parties": [ + { + "id": "alice", + "address": "0.0.0.0:9010" + }, + { + "id": "bob", + "address": "0.0.0.0:9011" + } + ], + "channel_desc": { + "protocol": "baidu_std" + } + }, + "featureSourceConf": { + "mockOpts": {} + } + } + +See :ref:`Serving Config ` for more config information + +.. note:: + + The above configuration is referenced from `alice-serving-config `_. + + For ``Bob``, you should refer to `bob-serving-config `_ . + + +1.3 Create logging config file +------------------------------ + +Create a file called ``logging.config`` in your workspace and paste the following code in: + +.. code-block:: json + + { + "systemLogPath": "./serving.log", + "logLevel": 2, + "maxLogFileSize": 104857600, + "maxLogFileCount": 2 + } + +See :ref:`Logging Config ` for more logging config information. + +.. note:: + + The above configuration is referenced from `alice-logging-config `_. + + For ``Bob``, you should refer to `bob-logging-config `_ . + + +1.4 Create docker-compose file +------------------------------ + +Create a file called ``docker-compose.yaml`` in your workspace and paste the following code in: + +.. 
code-block:: yaml
+
+  version: "3.8"
+  services:
+    serving:
+      cap_add:
+        - NET_ADMIN
+      command:
+        - /root/sf_serving/secretflow_serving
+        - --serving_config_file=/root/sf_serving/conf/serving.config
+        - --logging_config_file=/root/sf_serving/conf/logging.config
+      restart: always
+      image: secretflow/serving-anolis8:latest
+      ports:
+        - __ALICE_PORT__:9010
+      volumes:
+        - ./serving.config:/root/sf_serving/conf/serving.config
+        - ./logging.config:/root/sf_serving/conf/logging.config
+
+.. note::
+
+  ``__ALICE_PORT__`` is the published port on the host machine which is used for the SecretFlow-Serving service to listen on; you need to replace it with an accessible port number. In this case, we have designated it as ``9010`` for ``Alice``, ``9011`` for ``Bob``.
+
+
+Step 2: Start Serving Service
+=============================
+
+The files in your workspace should be as follows:
+
+.. code-block:: bash
+
+  └── serving
+    ├── serving.config
+    ├── logging.config
+    └── docker-compose.yaml
+
+Then you can start the serving service by running docker compose up:
+
+.. code-block:: bash
+
+  # If you installed docker with Compose V1, please use `docker-compose` instead of `docker compose`
+  docker compose -f docker-compose.yaml up -d
+
+You can use docker logs to check whether serving works well:
+
+.. code-block:: bash
+
+  docker logs -f serving_serving_1
+
+Now, ``Alice`` serving is listening on ``9010``; you can confirm whether the service is ready by accessing the ``/health`` endpoint.
+
+.. code-block:: bash
+
+  curl --location 'http://127.0.0.1:9010/health'
+
+When the endpoint returns a status code of ``200``, it means that the service is ready.
+
+Step 3: Predict Test
+====================
+
+Based on the capabilities of `Brpc `_, serving supports accessing through various protocols. Here, we are using an HTTP request to test the predict interface of serving.
+
+You can read :ref:`SecretFlow-Serving API ` for more information about serving APIs.
+
+.. 
code-block:: bash + + curl --location 'http://127.0.0.1:9010/PredictionService/Predict' \ + --header 'Content-Type: application/json' \ + --data '{ + "service_spec": { + "id": "test_service_id" + }, + "fs_params": { + "alice": { + "query_datas": [ + "a" + ] + }, + "bob": { + "query_datas": [ + "a" + ] + } + } + }' diff --git a/docs/source/topics/graph/intro_to_graph.rst b/docs/source/topics/graph/intro_to_graph.rst new file mode 100644 index 0000000..463c781 --- /dev/null +++ b/docs/source/topics/graph/intro_to_graph.rst @@ -0,0 +1,75 @@ +.. _intro-graph: + +Introduction to Graph +===================== + +Secretflow-Serving has defined a protocol for describing prediction computations, which mainly includes descriptions of operators, attributes, nodes, graphs, and executions. + +.. image:: /imgs/graph.png + :alt: graph structure + +Operators +--------- +Operators describe specific computations. By combining operators, different model computations can be achieved. Plain operators perform computations using local data only, while secure computation operators collaborate with peer operators from other participants for secure computations. + +OpDef +^^^^^ + +* name: Unique name of the operator. +* desc: Description of the operator. +* version: The version of the operator. +* tag: Some properties of the operator. +* attributes: Please check `Attributes` part below. +* inputs and output: The info of the inputs or output of the operator. + +Attributes +---------- +Operators have various attributes determined by their definitions. These attributes and their data support the operators in completing computations. + +AttrDef +^^^^^^^ + +* name: Must be unique among all attrs of the operator. +* desc: Description of the attribute. +* type: Please check :ref:`AttrType `. +* is_optional: If True, when AttrValue is not provided, `default_value` would be used. Else, AttrValue must be provided. +* default_value: Please check :ref:`AttrValue `. 
+
+Nodes
+-----
+Nodes are instances of operators. They store the attribute values (`AttrValue`) of the operators.
+
+NodeDef
+^^^^^^^
+
+* name: Must be unique among all nodes of the graph.
+* op: The operator name.
+* parents: The parent node names of the node. The order of the parent nodes should match the order of the inputs of the node.
+* attr_values: The attribute values config in the node. Note that this should include all attrs defined in the corresponding OpDef.
+* op_version: The operator version.
+
+Graphs
+------
+Graphs can consist of one or multiple nodes. They form a directed acyclic graph, where the direction represents the flow of data computation. A graph can represent a complete prediction computation process, including preprocessing, model prediction, and post-processing.
+
+Each participant will have a graph with the same structure but different data.
+
+GraphDef
+^^^^^^^^
+
+* version: Version of the graph.
+* node_list: The node list of the graph.
+* execution_list: Please check `Executions` part below.
+
+Executions
+----------
+An execution contains a subset of nodes from the main graph and forms a subgraph. Executions represent the model computation scheduling patterns. A graph can have multiple executions.
+
+.. image:: /imgs/execution.png
+    :alt: execution
+
+ExecutionDef
+^^^^^^^^^^^^
+
+* nodes: Represents the nodes contained in this execution. Note that these node names should be findable and unique within the node definitions. One node can only exist in one execution and must exist in one.
+* config: The runtime config of the execution. It describes the scheduling logic and session-related states of this execution unit. For more details, please check :ref:`RuntimeConfig `. 
diff --git a/docs/source/topics/graph/operator_list.md b/docs/source/topics/graph/operator_list.md new file mode 100644 index 0000000..63ca162 --- /dev/null +++ b/docs/source/topics/graph/operator_list.md @@ -0,0 +1,114 @@ + + + + +SecretFlow-Serving Operator List +================================ + + +Last update: Thu Dec 28 14:28:43 2023 +## MERGE_Y + + +Operator version: 0.0.2 + +Merge all partial y(score) and apply link function +### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|output_col_name|The column name of merged score|String|Y|| +|link_function|Type of link function, defined in `secretflow_serving/protos/link_function.proto`. Optional value: LF_LOG, LF_LOGIT, LF_INVERSE, LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS|String|Y|| +|input_col_name|The column name of partial_y|String|Y|| +|yhat_scale|In order to prevent value overflow, GLM training is performed on the scaled y label. 
So in the prediction process, you need to enlarge yhat back to get the real predicted value, `yhat = yhat_scale * link(X * W)`|Double|N|Default: 1.0.| + +### Tags + + +|Name|Description| +| :--- | :--- | +|returnable|The operator's output can be the final result| +|mergeable|The operator accept the output of operators with different participants and will somehow merge them.| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|partial_ys|The list of partial y, data type: `double`| + +### Output + + +|Name|Description| +| :--- | :--- | +|scores|The merge result of `partial_ys`, data type: `double`| + +## DOT_PRODUCT + + +Operator version: 0.0.2 + +Calculate the dot product of feature weights and values +### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|intercept|Value of model intercept|Double|N|Default: 0.0.| +|output_col_name|Column name of partial y|String|Y|| +|feature_weights|List of feature weights|Double List|Y|| +|input_types|List of input feature data types, Note that there is a loss of precision when using `DT_FLOAT` type. 
Optional value: DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, DT_INT64, DT_FLOAT, DT_DOUBLE|String List|Y|| +|feature_names|List of feature names|String List|Y|| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|features|Input feature table| + +### Output + + +|Name|Description| +| :--- | :--- | +|partial_ys|The calculation results, they have a data type of `double`.| + +## ARROW_PROCESSING + + +Operator version: 0.0.1 + +Replay secretflow compute functions +### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|content_json_flag|Whether `trace_content` is serialized json|Boolean|N|Default: False.| +|trace_content|Serialized data of secretflow compute trace|Bytes|N|| +|output_schema_bytes|Serialized data of output schema(arrow::Schema)|Bytes|Y|| +|input_schema_bytes|Serialized data of input schema(arrow::Schema)|Bytes|Y|| + +### Tags + + +|Name|Description| +| :--- | :--- | +|returnable|The operator's output can be the final result| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|input|| + +### Output + + +|Name|Description| +| :--- | :--- | +|output|| diff --git a/docs/source/topics/graph/update_operator_list.py b/docs/source/topics/graph/update_operator_list.py new file mode 100644 index 0000000..6a6bb92 --- /dev/null +++ b/docs/source/topics/graph/update_operator_list.py @@ -0,0 +1,192 @@ +# +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from secretflow_serving_lib import get_all_ops +from secretflow_serving_lib.attr_pb2 import AttrValue, AttrType +from mdutils.mdutils import MdUtils + +import datetime + +this_directory = os.path.abspath(os.path.dirname(__file__)) + +mdFile = MdUtils( + file_name=os.path.join(this_directory, 'operator_list.md'), +) + +mdFile.new_header(level=1, title='SecretFlow-Serving Operator List', style='setext') + +mdFile.new_paragraph(f'Last update: {datetime.datetime.now().strftime("%c")}') + +AttrTypeStrMap = { + AttrType.UNKNOWN_AT_TYPE: 'Undefined', + AttrType.AT_INT32: 'Integer32', + AttrType.AT_INT64: 'Integer64', + AttrType.AT_FLOAT: 'Float', + AttrType.AT_DOUBLE: 'Double', + AttrType.AT_STRING: 'String', + AttrType.AT_BOOL: 'Boolean', + AttrType.AT_BYTES: 'Bytes', + AttrType.AT_INT32_LIST: 'Integer32 List', + AttrType.AT_INT64_LIST: 'Integer64 List', + AttrType.AT_FLOAT_LIST: 'Float List', + AttrType.AT_DOUBLE_LIST: 'Double List', + AttrType.AT_STRING_LIST: 'String List', + AttrType.AT_BOOL_LIST: 'Boolean List', + AttrType.AT_BYTES_LIST: 'Bytes List', +} + + +def get_atomic_attr_value(at: AttrType, attr: AttrValue): + if at == AttrType.AT_FLOAT: + return round(attr.f, 5) + elif at == AttrType.AT_DOUBLE: + return round(attr.d, 8) + elif at == AttrType.AT_INT32: + return attr.i32 + elif at == AttrType.AT_INT64: + return attr.i64 + elif at == AttrType.AT_STRING: + return attr.s + elif at == AttrType.AT_BOOL: + return attr.b + elif at == AttrType.AT_FLOAT_LIST: + return [round(f, 5) for f in attr.fs.data] + elif at == AttrType.AT_DOUBLE_LIST: + return [round(f, 8) for f in attr.ds.data] + elif at == AttrType.AT_INT32_LIST: + return list(attr.i32s.data) + elif at == AttrType.AT_INT64_LIST: + return list(attr.i64s.data) + elif at == AttrType.AT_STRING_LIST: + return list(attr.ss.data) + elif at == AttrType.AT_BOOL_LIST: + return list(attr.bs.data) + elif at == AttrType.AT_BYTES_LIST: + return list(attr.bs.data) + else: + return None + + +def parse_comp_io(md, 
io_defs): + io_table_text = ['Name', 'Description'] + for io_def in io_defs: + io_table_text.extend([io_def.name, io_def.desc]) + + md.new_line() + md.new_table( + columns=2, + rows=len(io_defs) + 1, + text=io_table_text, + text_align='left', + ) + + +op_list = get_all_ops() + + +for op in op_list: + mdFile.new_header( + level=2, + title=op.name, + ) + mdFile.new_paragraph(f'Operator version: {op.version}') + mdFile.new_paragraph(op.desc) + + # build attrs + if len(op.attrs): + mdFile.new_header( + level=3, + title='Attrs', + ) + attr_table_text = ["Name", "Description", "Type", "Required", "Notes"] + for attr in op.attrs: + name_str = attr.name + type_str = AttrTypeStrMap[attr.type] + notes_str = '' + required_str = 'N/A' + + default_value = None + if attr.is_optional: + default_value = get_atomic_attr_value(attr.type, attr.default_value) + + if default_value is not None: + notes_str += f'Default: {default_value}. ' + + required_str = 'N' if attr.is_optional else 'Y' + + attr_table_text.extend( + [name_str, attr.desc, type_str, required_str, notes_str.rstrip()] + ) + + mdFile.new_line() + mdFile.new_table( + columns=5, + rows=len(op.attrs) + 1, + text=attr_table_text, + text_align='left', + ) + + # build tag + if op.tag.returnable or op.tag.mergeable or op.tag.session_run: + tag_rows = 0 + mdFile.new_header( + level=3, + title='Tags', + ) + tag_table_text = ["Name", "Description"] + if op.tag.returnable: + tag_rows += 1 + tag_table_text.extend( + ["returnable", "The operator's output can be the final result"] + ) + if op.tag.mergeable: + tag_rows += 1 + tag_table_text.extend( + [ + "mergeable", + "The operator accept the output of operators with different participants and will somehow merge them.", + ] + ) + if op.tag.session_run: + tag_rows += 1 + tag_table_text.extend( + ["session_run", "The operator needs to be executed in session."] + ) + + mdFile.new_line() + mdFile.new_table( + columns=2, + rows=tag_rows + 1, + text=tag_table_text, + text_align='left', + ) + 
+ # build inputs/output + if len(op.inputs): + mdFile.new_header( + level=3, + title='Inputs', + ) + parse_comp_io(mdFile, op.inputs) + + mdFile.new_header( + level=3, + title='Output', + ) + parse_comp_io(mdFile, [op.output]) + + +mdFile.create_md_file() diff --git a/docs/source/topics/index.rst b/docs/source/topics/index.rst new file mode 100644 index 0000000..6c402ea --- /dev/null +++ b/docs/source/topics/index.rst @@ -0,0 +1,25 @@ +.. _topics: + +Topics +====== + + +.. toctree:: + :maxdepth: 2 + :caption: system + + system/intro + system/observability + +.. toctree:: + :maxdepth: 2 + :caption: deployment + + deployment/deployment + +.. toctree:: + :maxdepth: 2 + :caption: graph + + graph/intro_to_graph + graph/operator_list diff --git a/docs/source/topics/system/intro.rst b/docs/source/topics/system/intro.rst new file mode 100644 index 0000000..5f5917b --- /dev/null +++ b/docs/source/topics/system/intro.rst @@ -0,0 +1,62 @@ +SecretFlow-Serving System Introduction +====================================== + +SecretFlow-Serving is a serving system for privacy-preserving machine learning models. + +Key Features +------------ + +* Support multiple parties (N >= 2). +* Parallel compute between parties. +* Batch Predict API Supported. +* Multi-protocol support. Secretflow-Serving is built on brpc, a high-performance rpc framework, and is capable of using multiple communication protocols. +* Support multiple types feature sources, e.g. SPI, CSV file, Mock data. +* Specific model graph definition. +* Federated learning model predict. +* One process one model/version. + + +Architecture +------------ + +Secretflow-Serving leverages the model package trained with Secretflow to provide model prediction capabilities at different security levels. It achieves this by utilizing the online feature data provided by each participant without compromising the integrity of the original data domain. + +.. 
image:: /imgs/architecture.png + :alt: Secretflow-Serving Deployment Architecture + + +Key Concepts +^^^^^^^^^^^^ + +To understand the architecture of Secretflow-Serving, you need to understand the following key concepts: + + +Model Package ++++++++++++++ + +A Secretflow-Serving model package is a compressed package comprising a model file, a manifest file, and other metadata files. + +The manifest file provides meta-information about the model file and follows the defined structure outlined :ref:`here `. + +The model file contains the graph that represents the model inference process, encompassing pre-processing, post-processing, and the specific inference algorithm. For graph details, please check :ref:`Introduction to Graph `. + +The metadata files, while optional, stores additional data information required during the model inference process. + + +Model Source ++++++++++++++ + +Secretflow-Serving supports retrieving model packages from different storage sources. Currently, the following data sources are supported: + +* Local Filesystem Data Source: Secretflow-Serving loads the model package from a specified local path. +* OSS/S3 Data Source: Secretflow-Serving attempts to download the model package from the OSS/S3 storage based on the provided configuration before loading it locally. + + +Feature Source ++++++++++++++++ + +Secretflow-Serving obtains the necessary features for the online inference process through the Feature Source. Currently, the platform supports the following feature data sources: + +* HTTP Source: Secretflow-Serving defines a Service Provider Interface (:doc:`SPI `) for retrieving feature data. Feature providers can implement this SPI to supply features to Secretflow-Serving. +* CSV Source: Secretflow-Serving supports direct loading of CSV file as a feature source. For performance reasons, the CSV file is fully loaded into memory and features are filtered based on the ID column. 
+* Mock Source: In this scenario, Secretflow-Serving uses randomly generated values as feature data. diff --git a/docs/source/topics/system/observability.rst b/docs/source/topics/system/observability.rst new file mode 100644 index 0000000..edbf05d --- /dev/null +++ b/docs/source/topics/system/observability.rst @@ -0,0 +1,2 @@ +SecretFlow-Serving System Observability +======================================= diff --git a/docs/update_po.sh b/docs/update_po.sh new file mode 100755 index 0000000..904f424 --- /dev/null +++ b/docs/update_po.sh @@ -0,0 +1,5 @@ +#!/bin/bash +mkdir -p _build/gettext && +make gettext && +sphinx-intl update -p _build/gettext -l zh_CN && +echo "po files has been updated. Please update po files in locales folder." diff --git a/docs/update_reference.sh b/docs/update_reference.sh new file mode 100755 index 0000000..63ef007 --- /dev/null +++ b/docs/update_reference.sh @@ -0,0 +1,43 @@ +#! /bin/bash +# +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# LIST=$(find ../secretflow_serving/apis -name "*.proto" | cut -c 4- | sort -t '\0' -n) +# echo ${LIST} + +echo "1. Update spi doc." +docker run --rm -v $(pwd)/source/reference/:/out \ + -v $(pwd)/..:/protos \ + pseudomuto/protoc-gen-doc \ + --doc_opt=/out/spi_md.tmpl,spi.md $(find ../secretflow_serving/spis -name "*.proto" | cut -c 4- | sort -t '\0' -n) secretflow_serving/protos/feature.proto + +echo "2. Update api doc." 
+docker run --rm -v $(pwd)/source/reference/:/out \ + -v $(pwd)/..:/protos \ + pseudomuto/protoc-gen-doc \ + --doc_opt=/out/api_md.tmpl,api.md $(find ../secretflow_serving/apis -name "*.proto" | cut -c 4- | sort -t '\0' -n) secretflow_serving/protos/feature.proto + +echo "3. Update config doc." +docker run --rm -v $(pwd)/source/reference/:/out \ + -v $(pwd)/..:/protos \ + pseudomuto/protoc-gen-doc \ + --doc_opt=/out/config_md.tmpl,config.md $(find ../secretflow_serving/config -name "*.proto" | cut -c 4- | sort -t '\0' -n) + +echo "4. Update model graph doc." +docker run --rm -v $(pwd)/source/reference/:/out \ + -v $(pwd)/..:/protos \ + pseudomuto/protoc-gen-doc \ + --doc_opt=/out/model_md.tmpl,model.md secretflow_serving/protos/attr.proto secretflow_serving/protos/op.proto secretflow_serving/protos/graph.proto secretflow_serving/protos/bundle.proto secretflow_serving/protos/data_type.proto secretflow_serving/protos/compute_trace.proto secretflow_serving/protos/compute_trace.proto diff --git a/examples/alice/glm-test.tar.gz b/examples/alice/glm-test.tar.gz index 05bc723..a3a3eb9 100644 Binary files a/examples/alice/glm-test.tar.gz and b/examples/alice/glm-test.tar.gz differ diff --git a/examples/alice/serving.config b/examples/alice/serving.config index 03a1ab4..92d64af 100644 --- a/examples/alice/serving.config +++ b/examples/alice/serving.config @@ -15,7 +15,7 @@ "modelId": "glm-test-1", "basePath": "/tmp/alice", "sourcePath": "examples/alice/glm-test.tar.gz", - "sourceMd5": "4216c62acba4a630d5039f917612780b", + "sourceSha256": "3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522", "sourceType": "ST_FILE" }, "clusterConf": { diff --git a/examples/bob/glm-test.tar.gz b/examples/bob/glm-test.tar.gz index a1e8e96..793e8b5 100644 Binary files a/examples/bob/glm-test.tar.gz and b/examples/bob/glm-test.tar.gz differ diff --git a/examples/bob/serving.config b/examples/bob/serving.config index 5a8a8d9..397dd2f 100644 --- a/examples/bob/serving.config +++ 
b/examples/bob/serving.config @@ -15,7 +15,7 @@ "modelId": "glm-test-1", "basePath": "/tmp/bob", "sourcePath": "examples/bob/glm-test.tar.gz", - "sourceMd5": "1ded1513dab8734e23152ef906c180fc", + "sourceSha256": "330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309", "sourceType": "ST_FILE" }, "clusterConf": { diff --git a/examples/docker-compose.yml b/examples/docker-compose.yml new file mode 100644 index 0000000..7d70161 --- /dev/null +++ b/examples/docker-compose.yml @@ -0,0 +1,18 @@ +version: '3.8' +services: + serving_alice: + image: ${SERVING_IMAGE:-secretflow/serving-anolis8:latest} + command: + - /root/sf_serving/secretflow_serving + - --serving_config_file=/root/sf_serving/examples/alice/serving.config + - --logging_config_file=/root/sf_serving/examples/alice/logging.config + restart: always + network_mode: host + serving_bob: + image: ${SERVING_IMAGE:-secretflow/serving-anolis8:latest} + command: + - /root/sf_serving/secretflow_serving + - --serving_config_file=/root/sf_serving/examples/bob/serving.config + - --logging_config_file=/root/sf_serving/examples/bob/logging.config + restart: always + network_mode: host diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d4796ce --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[tool.black] +skip-string-normalization = true + +[tool.isort] +profile = "black" + +[tool.rstcheck] +report_level = "ERROR" +ignore_directives = [ + "include", + "mermaid", + "autoclass", + "autofunction", +] +ignore_languages = [ + "cpp" +] + +[tool.pyright] +include = [ + "secretflow_serving_lib", +] + +reportMissingImports = true +reportMissingTypeStubs = false + +pythonVersion = "3.8" +pythonPlatform = "Linux" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5c3abaa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +grpcio>=1.42.0,!=1.48.0 +protobuf>=3.19, <4 +pyarrow==14.0.2 diff --git a/secretflow_serving/apis/BUILD.bazel 
b/secretflow_serving/apis/BUILD.bazel index 37a4dff..1f115b8 100644 --- a/secretflow_serving/apis/BUILD.bazel +++ b/secretflow_serving/apis/BUILD.bazel @@ -57,6 +57,7 @@ proto_library( ":common_proto", ":status_proto", "//secretflow_serving/protos:feature_proto", + "//secretflow_serving/protos:graph_proto", ], ) @@ -65,6 +66,23 @@ cc_proto_library( deps = [":prediction_service_proto"], ) +proto_library( + name = "model_service_proto", + srcs = [ + "model_service.proto", + ], + deps = [ + ":common_proto", + ":status_proto", + "//secretflow_serving/protos:bundle_proto", + ], +) + +cc_proto_library( + name = "model_service_cc_proto", + deps = [":model_service_proto"], +) + proto_library( name = "execution_service_proto", srcs = [ diff --git a/secretflow_serving/apis/error_code.proto b/secretflow_serving/apis/error_code.proto index 5aab3d1..08b3ac9 100644 --- a/secretflow_serving/apis/error_code.proto +++ b/secretflow_serving/apis/error_code.proto @@ -18,7 +18,7 @@ syntax = "proto3"; package secretflow.serving.errors; enum ErrorCode { - // placeholder for proto3 default value, do not use it + // Placeholder for proto3 default value, do not use it UNKNOWN = 0; // 001-099 for general code @@ -31,9 +31,10 @@ enum ErrorCode { NOT_FOUND = 5; NOT_IMPLEMENTED = 6; LOGIC_ERROR = 7; - SERIALIZE_FAILD = 8; - DESERIALIZE_FAILD = 9; + SERIALIZE_FAILED = 8; + DESERIALIZE_FAILED = 9; IO_ERROR = 10; + NOT_READY = 11; // 100-199 for mapping feature service code, see // `secretflow_serving/spis/error_code.proto` diff --git a/secretflow_serving/apis/model_service.proto b/secretflow_serving/apis/model_service.proto new file mode 100644 index 0000000..9cd8c92 --- /dev/null +++ b/secretflow_serving/apis/model_service.proto @@ -0,0 +1,50 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +import "secretflow_serving/apis/common.proto"; +import "secretflow_serving/apis/status.proto"; +import "secretflow_serving/protos/bundle.proto"; + +package secretflow.serving.apis; + +option cc_generic_services = true; + +// ModelService provides operation ralated to models. +service ModelService { + rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse); +} + +message GetModelInfoRequest { + // Custom data. + Header header = 1; + + // Model service specification. + ServiceSpec service_spec = 2; +} + +message GetModelInfoResponse { + // Custom data. + Header header = 1; + + // Staus of this response. + Status status = 2; + + // Model service specification. + ServiceSpec service_spec = 3; + + ModelInfo model_info = 4; +} diff --git a/secretflow_serving/apis/prediction_service.proto b/secretflow_serving/apis/prediction_service.proto index 76cc1d4..15d106d 100644 --- a/secretflow_serving/apis/prediction_service.proto +++ b/secretflow_serving/apis/prediction_service.proto @@ -31,9 +31,10 @@ service PredictionService { // Result of regression or one class of Classifications message Score { - // name of the score. + // The name of the score, it depends on the attribute configuration of the + // model. string name = 1; - // value of the score. + // The value of the score. double value = 2; } @@ -44,6 +45,35 @@ message PredictResult { } // Predict request containing one or more requests. 
+// examples: +// ```json +// { +// "header": { +// "data": { +// "custom_str": "id_12345" +// }, +// }, +// "service_spec": { +// "id": "test_service_id" +// }, +// "fs_params": { +// "alice": { +// "query_datas": [ +// "x1", +// "x2" +// ], +// "query_context": "context_x" +// }, +// "bob": { +// "query_datas": [ +// "y1", +// "y2" +// ], +// "query_context": "context_y" +// } +// } +// } +// ``` message PredictRequest { // Custom data. The header will be passed to the downstream system which // implement the feature service spi. @@ -65,6 +95,35 @@ message PredictRequest { } // Predict response containing one or more responses. +// examples: +// ```json +// { +// "header": { +// "data": { +// "custom_value": "asdfvb" +// }, +// }, +// "status": { +// "code": 0, +// "msg": "success." +// }, +// "service_spec": { +// "id": "test_service_id" +// }, +// "results": { +// "scores": [ +// { +// "name": "pred_y", +// "value": 0.32456 +// }, +// { +// "name": "pred_y", +// "value": 0.02456 +// } +// ] +// } +// } +// ``` message PredictResponse { // Custom data. Passed by the downstream system which implement the feature // service spi. diff --git a/secretflow_serving/config/cluster_config.proto b/secretflow_serving/config/cluster_config.proto index 32ab210..ab9b5a5 100644 --- a/secretflow_serving/config/cluster_config.proto +++ b/secretflow_serving/config/cluster_config.proto @@ -34,6 +34,16 @@ message ChannelDesc { // TLS related config. TlsConfig tls_config = 4; + + // When the server starts, model information from all parties will be + // collected. At this time, the remote servers may not have started yet, and + // we need to retry. And if we connect gateway,the max waiting time for each + // operation will be rpc_timeout_ms + handshake_retry_interval_ms. 
+ // Maximum number of retries, default: 60 + int32 handshake_max_retry_cnt = 5; + + // time between retries, default: 5000ms + int32 handshake_retry_interval_ms = 6; } // Description for a joined party diff --git a/secretflow_serving/config/feature_config.proto b/secretflow_serving/config/feature_config.proto index 93e5e07..e78cfe2 100644 --- a/secretflow_serving/config/feature_config.proto +++ b/secretflow_serving/config/feature_config.proto @@ -28,10 +28,22 @@ message FeatureSourceConfig { } } +enum MockDataType { + // Placeholder for proto3 default value, do not use it. + INVALID_MOCK_DATA_TYPE = 0; + + // random value for each feature + MDT_RANDOM = 1; + // fixed value for each feature + MDT_FIXED = 2; +} + // Options for a mock feature source. -// Mock feature source will generates random values for the desired features。 +// Mock feature source will generates values(random or fixed, according to type) +// for the desired features. message MockOptions { - // Empty by designed + // default MDT_RANDOM + MockDataType type = 1; } // Options for a http feature source which should implement the feature service diff --git a/secretflow_serving/config/logging_config.proto b/secretflow_serving/config/logging_config.proto index 8394e8e..61b04e8 100644 --- a/secretflow_serving/config/logging_config.proto +++ b/secretflow_serving/config/logging_config.proto @@ -18,11 +18,16 @@ syntax = "proto3"; package secretflow.serving; enum LogLevel { + // Placeholder for proto3 default value, do not use it. 
INVALID_LOG_LEVEL = 0; + // debug DEBUG_LOG_LEVEL = 1; + // info INFO_LOG_LEVEL = 2; + // warn WARN_LOG_LEVEL = 3; + // error ERROR_LOG_LEVEL = 4; } diff --git a/secretflow_serving/config/model_config.proto b/secretflow_serving/config/model_config.proto index f9b2ddf..3b099f4 100644 --- a/secretflow_serving/config/model_config.proto +++ b/secretflow_serving/config/model_config.proto @@ -15,17 +15,16 @@ syntax = "proto3"; -import "secretflow_serving/config/tls_config.proto"; - package secretflow.serving; // Supported model source type enum SourceType { + // Placeholder for proto3 default value, do not use it. INVALID_SOURCE_TYPE = 0; // Local filesystem ST_FILE = 1; - // Oss + // S3 OSS ST_OSS = 2; } @@ -62,8 +61,8 @@ message ModelConfig { string source_path = 3; // Optional. - // The expect md5 of the model package - string source_md5 = 4; + // The expect sha256 of the model package + string source_sha256 = 4; SourceType source_type = 5; oneof kind { diff --git a/secretflow_serving/config/server_config.proto b/secretflow_serving/config/server_config.proto index b121196..8445ab4 100644 --- a/secretflow_serving/config/server_config.proto +++ b/secretflow_serving/config/server_config.proto @@ -44,4 +44,9 @@ message ServerConfig { // Server-level max number of requests processed in parallel // Default: 0 (unlimited) int32 max_concurrency = 14; + + // Number of pthreads that server runs to execute ops. + // If this option <= 0, use default value. 
+ // Default: #cpu-cores + int32 op_exec_worker_num = 15; } diff --git a/secretflow_serving/core/link_func.cc b/secretflow_serving/core/link_func.cc index af36b5f..a74d5e7 100644 --- a/secretflow_serving/core/link_func.cc +++ b/secretflow_serving/core/link_func.cc @@ -16,24 +16,18 @@ #include "secretflow_serving/core/exception.h" -#include "secretflow_serving/protos/link_function.pb.h" - namespace secretflow::serving { -void ValidateLinkFuncType(const std::string& type) { - LinkFucntionType lf_type; - SERVING_ENFORCE(LinkFucntionType_Parse(type, &lf_type), +LinkFunctionType ParseLinkFuncType(const std::string& type) { + LinkFunctionType lf_type; + SERVING_ENFORCE(LinkFunctionType_Parse(type, &lf_type), errors::ErrorCode::UNEXPECTED_ERROR, - "unsupport link func type:{}", type); + "unsupported link func type:{}", type); + return lf_type; } template -T ApplyLinkFunc(T x, const std::string& type) { - LinkFucntionType lf_type; - SERVING_ENFORCE(LinkFucntionType_Parse(type, &lf_type), - errors::ErrorCode::UNEXPECTED_ERROR, - "unsupport link func type:{}", type); - +T ApplyLinkFunc(T x, LinkFunctionType lf_type) { auto ls7 = [](T x) -> T { return 5.00052959e-01 + 2.35176260e-01 * x - 3.97212202e-05 * std::pow(x, 2) - 1.23407424e-02 * std::pow(x, 3) + @@ -42,68 +36,65 @@ T ApplyLinkFunc(T x, const std::string& type) { }; switch (lf_type) { - case LinkFucntionType::LF_LOG: { + case LinkFunctionType::LF_LOG: { return std::exp(x); } - case LinkFucntionType::LF_LOGIT: { - return 1.0f / (1.0f + std::exp(-x)); + case LinkFunctionType::LF_LOGIT: { + return 1.0F / (1.0F + std::exp(-x)); } - case LinkFucntionType::LF_INVERSE: { + case LinkFunctionType::LF_INVERSE: { return std::exp(-x); } - case LinkFucntionType::LF_LOGIT_V2: { - return 0.5f * (x / std::sqrt(1 + std::pow(x, 2))) + 0.5f; - } - case LinkFucntionType::LF_RECIPROCAL: { - return 1.0f / x; + case LinkFunctionType::LF_RECIPROCAL: { + return 1.0F / x; } - case LinkFucntionType::LF_INDENTITY: { + case 
LinkFunctionType::LF_IDENTITY: { return x; } - case LinkFucntionType::LF_SIGMOID_RAW: { - return 1.0f / (1.0f + exp(-x)); + case LinkFunctionType::LF_SIGMOID_RAW: { + return 1.0F / (1.0F + exp(-x)); } - case LinkFucntionType::LF_SIGMOID_MM1: { - return 0.5f + 0.125f * x; + case LinkFunctionType::LF_SIGMOID_MM1: { + return 0.5F + 0.125F * x; } - case LinkFucntionType::LF_SIGMOID_MM3: { - return 0.5f + 0.197f * x - 0.004f * std::pow(x, 3); + case LinkFunctionType::LF_SIGMOID_MM3: { + return 0.5F + 0.197F * x - 0.004F * std::pow(x, 3); } - case LinkFucntionType::LF_SIGMOID_GA: { - return 0.5f + 0.15012f * x + 0.001593f * std::pow(x, 3); + case LinkFunctionType::LF_SIGMOID_GA: { + return 0.5F + 0.15012F * x + 0.001593F * std::pow(x, 3); } - case LinkFucntionType::LF_SIGMOID_T1: { - return 0.5f + 0.25f * x; + case LinkFunctionType::LF_SIGMOID_T1: { + return 0.5F + 0.25F * x; } - case LinkFucntionType::LF_SIGMOID_T3: { - return 0.5f + 0.25f * x - (1.0f / 48) * std::pow(x, 3); + case LinkFunctionType::LF_SIGMOID_T3: { + return 0.5F + 0.25F * x - (1.0F / 48) * std::pow(x, 3); } - case LinkFucntionType::LF_SIGMOID_T5: { - return 0.5f + 0.25f * x - (1.0f / 48) * std::pow(x, 3) + - (1.0f / 480) * std::pow(x, 5); + case LinkFunctionType::LF_SIGMOID_T5: { + return 0.5F + 0.25F * x - (1.0F / 48) * std::pow(x, 3) + + (1.0F / 480) * std::pow(x, 5); } - case LinkFucntionType::LF_SIGMOID_T7: { - return 0.5f + 0.25f * x - (1.0f / 48) * std::pow(x, 3) + - (1.0f / 480) * std::pow(x, 5) - (17.0f / 80640) * std::pow(x, 7); + case LinkFunctionType::LF_SIGMOID_T7: { + return 0.5F + 0.25F * x - (1.0F / 48) * std::pow(x, 3) + + (1.0F / 480) * std::pow(x, 5) - (17.0F / 80640) * std::pow(x, 7); } - case LinkFucntionType::LF_SIGMOID_T9: { - return 0.5f + 0.25f * x - (1.0f / 48) * std::pow(x, 3) + - (1.0f / 480) * std::pow(x, 5) - (17.0f / 80640) * std::pow(x, 7) + - (31.0f / 1451520) * std::pow(x, 9); + case LinkFunctionType::LF_SIGMOID_T9: { + return 0.5F + 0.25F * x - (1.0F / 48) * 
std::pow(x, 3) + + (1.0F / 480) * std::pow(x, 5) - (17.0F / 80640) * std::pow(x, 7) + + (31.0F / 1451520) * std::pow(x, 9); } - case LinkFucntionType::LF_SIGMOID_LS7: { + case LinkFunctionType::LF_SIGMOID_LS7: { return ls7(x); } - case LinkFucntionType::LF_SIGMOID_SEG3: { + case LinkFunctionType::LF_SIGMOID_SEG3: { if (x > 4) { return 1; } else if (x < -4) { return 0; } else { - return 0.5f + 0.125f * x; + return 0.5F + 0.125F * x; } } - case LinkFucntionType::LF_SIGMOID_SEG5: { + case LinkFunctionType::LF_SIGMOID_SEG5: { if (x > 35.75) { return 1; } else if (x > 3.75) { @@ -116,27 +107,27 @@ T ApplyLinkFunc(T x, const std::string& type) { return 0; } } - case LinkFucntionType::LF_SIGMOID_DF: { - return 0.5f * (x / (1.f + std::abs(x))) + 0.5f; + case LinkFunctionType::LF_SIGMOID_DF: { + return 0.5F * (x / (1.0F + std::abs(x))) + 0.5F; } - case LinkFucntionType::LF_SIGMOID_SR: { - return 0.5f * (x / std::sqrt(1.f + std::pow(x, 2))) + 0.5f; + case LinkFunctionType::LF_SIGMOID_SR: { + return 0.5F * (x / std::sqrt(1.0F + std::pow(x, 2))) + 0.5F; } - case LinkFucntionType::LF_SIGMOID_SEGLS: { + case LinkFunctionType::LF_SIGMOID_SEGLS: { if (std::abs(x) <= 5.87) { return ls7(x); } else { - return 0.5f * (x / std::sqrt(1.f + std::pow(x, 2))) + 0.5f; + return 0.5F * (x / std::sqrt(1.0F + std::pow(x, 2))) + 0.5F; } } default: { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "not support link func type {}", static_cast(lf_type)); + "unsupported link func type {}", static_cast(lf_type)); } } } -template float ApplyLinkFunc(float, const std::string&); -template double ApplyLinkFunc(double, const std::string&); +template float ApplyLinkFunc(float, LinkFunctionType); +template double ApplyLinkFunc(double, LinkFunctionType); } // namespace secretflow::serving diff --git a/secretflow_serving/core/link_func.h b/secretflow_serving/core/link_func.h index 6660b01..976146b 100644 --- a/secretflow_serving/core/link_func.h +++ b/secretflow_serving/core/link_func.h @@ -16,11 +16,13 
@@ #include +#include "secretflow_serving/protos/link_function.pb.h" + namespace secretflow::serving { -void ValidateLinkFuncType(const std::string& type); +LinkFunctionType ParseLinkFuncType(const std::string& type); template -T ApplyLinkFunc(T x, const std::string& type); +T ApplyLinkFunc(T x, LinkFunctionType type); } // namespace secretflow::serving diff --git a/secretflow_serving/feature_adapter/feature_adapter.cc b/secretflow_serving/feature_adapter/feature_adapter.cc index 7378829..a60016d 100644 --- a/secretflow_serving/feature_adapter/feature_adapter.cc +++ b/secretflow_serving/feature_adapter/feature_adapter.cc @@ -24,7 +24,7 @@ namespace secretflow::serving::feature { FeatureAdapter::FeatureAdapter( const FeatureSourceConfig& spec, const std::string& service_id, const std::string& party_id, - const std::shared_ptr& feature_schema) + const std::shared_ptr& feature_schema) : spec_(spec), service_id_(service_id), party_id_(party_id), @@ -33,15 +33,21 @@ FeatureAdapter::FeatureAdapter( void FeatureAdapter::FetchFeature(const Request& request, Response* response) { OnFetchFeature(request, response); - CheckFeatureValid(response->features); + CheckFeatureValid(request, response->features); } void FeatureAdapter::CheckFeatureValid( + const Request& request, const std::shared_ptr& features) { const auto& schema = features->schema(); SERVING_ENFORCE(schema->Equals(*feature_schema_), errors::ErrorCode::NOT_FOUND, "result schema does not match the request expect."); + SERVING_ENFORCE( + request.fs_param->query_datas().size() == features->num_rows(), + errors::ErrorCode::LOGIC_ERROR, + "query row_num {} should be equal to fetched row_num {}", + request.fs_param->query_datas().size(), features->num_rows()); } } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/feature_adapter.h b/secretflow_serving/feature_adapter/feature_adapter.h index 1a6327b..9b86fcc 100644 --- a/secretflow_serving/feature_adapter/feature_adapter.h +++ 
b/secretflow_serving/feature_adapter/feature_adapter.h @@ -43,7 +43,7 @@ class FeatureAdapter { public: FeatureAdapter(const FeatureSourceConfig& spec, const std::string& service_id, const std::string& party_id, - const std::shared_ptr& feature_schema); + const std::shared_ptr& feature_schema); virtual ~FeatureAdapter() = default; virtual void FetchFeature(const Request& request, Response* response); @@ -51,7 +51,8 @@ class FeatureAdapter { protected: virtual void OnFetchFeature(const Request& request, Response* response) = 0; - void CheckFeatureValid(const std::shared_ptr& features); + void CheckFeatureValid(const Request& request, + const std::shared_ptr& features); protected: FeatureSourceConfig spec_; @@ -59,7 +60,7 @@ class FeatureAdapter { const std::string service_id_; const std::string party_id_; - const std::shared_ptr feature_schema_; + const std::shared_ptr feature_schema_; }; } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/feature_adapter_factory.h b/secretflow_serving/feature_adapter/feature_adapter_factory.h index a5eabf2..cc306b8 100644 --- a/secretflow_serving/feature_adapter/feature_adapter_factory.h +++ b/secretflow_serving/feature_adapter/feature_adapter_factory.h @@ -26,7 +26,7 @@ class FeatureAdapterFactory : public Singleton { public: using CreateAdapterFunc = std::function( const FeatureSourceConfig&, const std::string&, const std::string&, - const std::shared_ptr&)>; + const std::shared_ptr&)>; template void Register(FeatureSourceConfig::OptionsCase opts_case) { @@ -38,7 +38,7 @@ class FeatureAdapterFactory : public Singleton { opts_case, [](const FeatureSourceConfig& spec, const std::string& service_id, const std::string& party_id, - const std::shared_ptr& feature_schema) { + const std::shared_ptr& feature_schema) { return std::make_unique(spec, service_id, party_id, feature_schema); }); @@ -47,7 +47,7 @@ class FeatureAdapterFactory : public Singleton { std::unique_ptr Create( const 
FeatureSourceConfig& spec, const std::string& service_id, const std::string& party_id, - const std::shared_ptr& feature_schema) { + const std::shared_ptr& feature_schema) { auto creator = creators_[spec.options_case()]; YACL_ENFORCE(creator, "no creator registered for operator type: {}", static_cast(spec.options_case())); diff --git a/secretflow_serving/feature_adapter/file_adapter.cc b/secretflow_serving/feature_adapter/file_adapter.cc index 897867a..ad3ac2f 100644 --- a/secretflow_serving/feature_adapter/file_adapter.cc +++ b/secretflow_serving/feature_adapter/file_adapter.cc @@ -23,10 +23,10 @@ namespace secretflow::serving::feature { -FileAdapater::FileAdapater(const FeatureSourceConfig& spec, - const std::string& service_id, - const std::string& party_id, - const std::shared_ptr& feature_schema) +FileAdapter::FileAdapter( + const FeatureSourceConfig& spec, const std::string& service_id, + const std::string& party_id, + const std::shared_ptr& feature_schema) : FeatureAdapter(spec, service_id, party_id, feature_schema) { SERVING_ENFORCE(spec_.has_csv_opts(), errors::ErrorCode::INVALID_ARGUMENT, "invalid mock options"); @@ -60,7 +60,7 @@ FileAdapater::FileAdapater(const FeatureSourceConfig& spec, SERVING_GET_ARROW_RESULT(csv_reader->Read(), csv_table_); } -void FileAdapater::OnFetchFeature(const Request& request, Response* response) { +void FileAdapter::OnFetchFeature(const Request& request, Response* response) { // query data is unique const auto& query_datas = request.fs_param->query_datas(); std::set query_data_set(query_datas.begin(), query_datas.end()); @@ -106,6 +106,6 @@ void FileAdapater::OnFetchFeature(const Request& request, Response* response) { response->features); } -REGISTER_ADAPTER(FeatureSourceConfig::OptionsCase::kCsvOpts, FileAdapater); +REGISTER_ADAPTER(FeatureSourceConfig::OptionsCase::kCsvOpts, FileAdapter); } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/file_adapter.h 
b/secretflow_serving/feature_adapter/file_adapter.h index 3c08d9f..d4b470e 100644 --- a/secretflow_serving/feature_adapter/file_adapter.h +++ b/secretflow_serving/feature_adapter/file_adapter.h @@ -21,13 +21,13 @@ namespace secretflow::serving::feature { -class FileAdapater : public FeatureAdapter { +class FileAdapter : public FeatureAdapter { public: - FileAdapater(const FeatureSourceConfig& spec, const std::string& service_id, - const std::string& party_id, - const std::shared_ptr& feature_schema); + FileAdapter(const FeatureSourceConfig& spec, const std::string& service_id, + const std::string& party_id, + const std::shared_ptr& feature_schema); - ~FileAdapater() = default; + ~FileAdapter() override = default; protected: void OnFetchFeature(const Request& request, Response* response) override; diff --git a/secretflow_serving/feature_adapter/http_adapter.cc b/secretflow_serving/feature_adapter/http_adapter.cc index fbc2370..3923e26 100644 --- a/secretflow_serving/feature_adapter/http_adapter.cc +++ b/secretflow_serving/feature_adapter/http_adapter.cc @@ -29,7 +29,7 @@ namespace secretflow::serving::feature { namespace { -const size_t kConnectTimoutMs = 500; +const size_t kConnectTimeoutMs = 500; const size_t kTimeoutMs = 1000; @@ -59,7 +59,7 @@ errors::ErrorCode MappingErrorCode(int fs_code) { HttpFeatureAdapter::HttpFeatureAdapter( const FeatureSourceConfig& spec, const std::string& service_id, const std::string& party_id, - const std::shared_ptr& feature_schema) + const std::shared_ptr& feature_schema) : FeatureAdapter(spec, service_id, party_id, feature_schema) { SERVING_ENFORCE(spec_.has_http_opts(), errors::ErrorCode::INVALID_ARGUMENT, "invalid http options"); @@ -81,7 +81,7 @@ HttpFeatureAdapter::HttpFeatureAdapter( http_opts.endpoint(), "http", http_opts.enable_lb(), http_opts.timeout_ms() > 0 ? http_opts.timeout_ms() : kTimeoutMs, http_opts.connect_timeout_ms() > 0 ? 
http_opts.connect_timeout_ms() - : kConnectTimoutMs, + : kConnectTimeoutMs, http_opts.has_tls_config() ? &http_opts.tls_config() : nullptr); } @@ -121,7 +121,7 @@ std::string HttpFeatureAdapter::SerializeRequest(const Request& request) { auto status = google::protobuf::util::MessageToJsonString(batch_request, &json_str, options); if (!status.ok()) { - SERVING_THROW(errors::ErrorCode::SERIALIZE_FAILD, + SERVING_THROW(errors::ErrorCode::SERIALIZE_FAILED, "serialize fetch feature request failed: {}", status.ToString()); } @@ -134,7 +134,7 @@ void HttpFeatureAdapter::DeserializeResponse(const std::string& res_context, spis::BatchFetchFeatureResponse batch_response; auto status = ::google::protobuf::util::JsonStringToMessage(res_context, &batch_response); - SERVING_ENFORCE(status.ok(), errors::ErrorCode::DESERIALIZE_FAILD, + SERVING_ENFORCE(status.ok(), errors::ErrorCode::DESERIALIZE_FAILED, "deserialize response context({}) failed: {}", res_context, status.ToString()); SERVING_ENFORCE(batch_response.status().code() == spis::ErrorCode::OK, @@ -147,7 +147,8 @@ void HttpFeatureAdapter::DeserializeResponse(const std::string& res_context, response->header->mutable_data()->swap( *batch_response.mutable_header()->mutable_data()); - response->features = FeaturesToTable(batch_response.features()); + response->features = + FeaturesToTable(batch_response.features(), feature_schema_); } REGISTER_ADAPTER(FeatureSourceConfig::OptionsCase::kHttpOpts, diff --git a/secretflow_serving/feature_adapter/http_adapter.h b/secretflow_serving/feature_adapter/http_adapter.h index 768af92..08e6732 100644 --- a/secretflow_serving/feature_adapter/http_adapter.h +++ b/secretflow_serving/feature_adapter/http_adapter.h @@ -24,10 +24,11 @@ namespace secretflow::serving::feature { class HttpFeatureAdapter : public FeatureAdapter { public: - HttpFeatureAdapter(const FeatureSourceConfig& spec, - const std::string& service_id, const std::string& party_id, - const std::shared_ptr& feature_schema); - 
~HttpFeatureAdapter() = default; + HttpFeatureAdapter( + const FeatureSourceConfig& spec, const std::string& service_id, + const std::string& party_id, + const std::shared_ptr& feature_schema); + ~HttpFeatureAdapter() override = default; protected: void OnFetchFeature(const Request& request, Response* response) override; diff --git a/secretflow_serving/feature_adapter/mock_adapter.cc b/secretflow_serving/feature_adapter/mock_adapter.cc index 3aaa3ae..f77af7a 100644 --- a/secretflow_serving/feature_adapter/mock_adapter.cc +++ b/secretflow_serving/feature_adapter/mock_adapter.cc @@ -34,16 +34,19 @@ std::shared_ptr CreateArray(size_t rows, Fn generator) { } // namespace -MockAdapater::MockAdapater(const FeatureSourceConfig& spec, - const std::string& service_id, - const std::string& party_id, - const std::shared_ptr& feature_schema) +MockAdapter::MockAdapter( + const FeatureSourceConfig& spec, const std::string& service_id, + const std::string& party_id, + const std::shared_ptr& feature_schema) : FeatureAdapter(spec, service_id, party_id, feature_schema) { SERVING_ENFORCE(spec_.has_mock_opts(), errors::ErrorCode::INVALID_ARGUMENT, "invalid mock options"); + mock_type_ = spec_.mock_opts().type() != MockDataType::INVALID_MOCK_DATA_TYPE + ? spec_.mock_opts().type() + : MockDataType::MDT_RANDOM; } -void MockAdapater::OnFetchFeature(const Request& request, Response* response) { +void MockAdapter::OnFetchFeature(const Request& request, Response* response) { SERVING_ENFORCE(!request.fs_param->query_datas().empty(), errors::ErrorCode::INVALID_ARGUMENT, "get empty feature service query datas."); @@ -55,23 +58,86 @@ void MockAdapater::OnFetchFeature(const Request& request, Response* response) { std::shared_ptr array; const auto& f = feature_schema_->field(c); if (f->type()->id() == arrow::Type::type::BOOL) { - const auto generator = [] { return std::rand() % 2; }; + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED ? 
1 : std::rand() % 2; + }; array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::INT8) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 1 + : (std::rand() % std::numeric_limits::max()); + }; + array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::UINT8) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 1 + : (std::numeric_limits::max()); + }; + array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::INT16) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 1 + : (std::numeric_limits::max()); + }; + array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::UINT16) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 1 + : (std::numeric_limits::max()); + }; + array = CreateArray(rows, generator); } else if (f->type()->id() == arrow::Type::type::INT32) { - const auto generator = [] { return std::rand(); }; + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED ? 1 : std::rand(); + }; array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::UINT32) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED ? 1 : std::rand(); + }; + array = CreateArray(rows, generator); } else if (f->type()->id() == arrow::Type::type::INT64) { - const auto generator = [] { return std::rand() * std::rand(); }; + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED ? 1 : std::rand(); + }; array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::UINT64) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED ? 
1 : std::rand(); + }; + array = CreateArray(rows, generator); } else if (f->type()->id() == arrow::Type::type::FLOAT) { - const auto generator = [] { return std::rand() / float(RAND_MAX); }; + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 1 + : ((float)std::rand() / float(RAND_MAX)); + }; array = CreateArray(rows, generator); } else if (f->type()->id() == arrow::Type::type::DOUBLE) { - const auto generator = [] { return std::rand() / double(RAND_MAX); }; + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 1 + : ((double)std::rand() / (RAND_MAX)); + }; array = CreateArray(rows, generator); } else if (f->type()->id() == arrow::Type::type::STRING) { - const auto generator = [] { return std::to_string(std::rand()); }; + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? "1" + : std::to_string(std::rand()); + }; array = CreateArray(rows, generator); + } else if (f->type()->id() == arrow::Type::type::BINARY) { + const auto generator = [this] { + return mock_type_ == MockDataType::MDT_FIXED + ? 
"1" + : std::to_string(std::rand()); + }; + array = CreateArray(rows, generator); } else { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "unkown field type {}", f->type()->ToString()); @@ -82,6 +148,6 @@ void MockAdapater::OnFetchFeature(const Request& request, Response* response) { MakeRecordBatch(feature_schema_, rows, std::move(arrays)); } -REGISTER_ADAPTER(FeatureSourceConfig::OptionsCase::kMockOpts, MockAdapater); +REGISTER_ADAPTER(FeatureSourceConfig::OptionsCase::kMockOpts, MockAdapter); } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/mock_adapter.h b/secretflow_serving/feature_adapter/mock_adapter.h index c53e0fc..dcf8365 100644 --- a/secretflow_serving/feature_adapter/mock_adapter.h +++ b/secretflow_serving/feature_adapter/mock_adapter.h @@ -18,16 +18,18 @@ namespace secretflow::serving::feature { -class MockAdapater : public FeatureAdapter { +class MockAdapter : public FeatureAdapter { public: - MockAdapater(const FeatureSourceConfig& spec, const std::string& service_id, - const std::string& party_id, - const std::shared_ptr& feature_schema); + MockAdapter(const FeatureSourceConfig& spec, const std::string& service_id, + const std::string& party_id, + const std::shared_ptr& feature_schema); - ~MockAdapater() = default; + ~MockAdapter() override = default; protected: void OnFetchFeature(const Request& request, Response* response) override; + + MockDataType mock_type_; }; } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/mock_adapter_test.cc b/secretflow_serving/feature_adapter/mock_adapter_test.cc index f89f582..f6251c1 100644 --- a/secretflow_serving/feature_adapter/mock_adapter_test.cc +++ b/secretflow_serving/feature_adapter/mock_adapter_test.cc @@ -32,11 +32,21 @@ TEST_F(MockAdapterTest, Work) { FeatureSourceConfig config; (void)config.mutable_mock_opts(); - auto model_schema = arrow::schema( - {arrow::field("x1", arrow::int32()), arrow::field("x2", 
arrow::int64()), - arrow::field("x3", arrow::float32()), - arrow::field("x4", arrow::float64()), arrow::field("x5", arrow::utf8()), - arrow::field("x6", arrow::boolean())}); + auto model_schema = arrow::schema({ + arrow::field("x1", arrow::int8()), + arrow::field("x2", arrow::uint8()), + arrow::field("x3", arrow::int16()), + arrow::field("x4", arrow::uint16()), + arrow::field("x5", arrow::int32()), + arrow::field("x6", arrow::uint32()), + arrow::field("x7", arrow::int64()), + arrow::field("x8", arrow::uint64()), + arrow::field("x9", arrow::float32()), + arrow::field("x10", arrow::float64()), + arrow::field("x11", arrow::boolean()), + arrow::field("x12", arrow::utf8()), + arrow::field("x13", arrow::binary()), + }); auto adapter = FeatureAdapterFactory::GetInstance()->Create( config, kTestModelServiceId, kTestPartyId, model_schema); diff --git a/secretflow_serving/framework/BUILD.bazel b/secretflow_serving/framework/BUILD.bazel index bf4b715..c52dbdb 100644 --- a/secretflow_serving/framework/BUILD.bazel +++ b/secretflow_serving/framework/BUILD.bazel @@ -34,6 +34,8 @@ serving_cc_library( deps = [ ":propagator", "//secretflow_serving/ops", + "//secretflow_serving/util:thread_pool", + "//secretflow_serving/util:thread_safe_queue", ], ) @@ -49,62 +51,76 @@ serving_cc_test( serving_cc_library( name = "executable", + srcs = ["executable.cc"], hdrs = ["executable.h"], deps = [ ":executor", - "@org_apache_arrow//:arrow", ], ) serving_cc_library( - name = "executable_impl", - srcs = ["executable_impl.cc"], - hdrs = ["executable_impl.h"], + name = "model_info_collector", + srcs = ["model_info_collector.cc"], + hdrs = ["model_info_collector.h"], deps = [ - ":executable", - ":executor", + "//secretflow_serving/apis:model_service_cc_proto", + "//secretflow_serving/core:exception", + "//secretflow_serving/util:utils", + "@com_github_brpc_brpc//:brpc", ], ) serving_cc_library( - name = "predictor", - hdrs = ["predictor.h"], + name = "execute_context", + srcs = 
["execute_context.cc"], + hdrs = ["execute_context.h"], deps = [ - "//secretflow_serving/apis:common_cc_proto", "//secretflow_serving/apis:execution_service_cc_proto", "//secretflow_serving/apis:prediction_service_cc_proto", + "//secretflow_serving/core:exception", + "//secretflow_serving/ops:graph", "//secretflow_serving/server:execution_core", + "//secretflow_serving/util:utils", + "@com_github_brpc_brpc//:brpc", ], ) serving_cc_library( - name = "predictor_impl", - srcs = ["predictor_impl.cc"], - hdrs = ["predictor_impl.h"], + name = "predictor", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], deps = [ - ":predictor", - "//secretflow_serving/core:exception", + ":execute_context", + "//secretflow_serving/apis:common_cc_proto", + "//secretflow_serving/apis:prediction_service_cc_proto", + "//secretflow_serving/server:execution_core", "//secretflow_serving/util:arrow_helper", "//secretflow_serving/util:utils", "@com_github_brpc_brpc//:brpc", - "@yacl//yacl/utils:elapsed_timer", ], ) serving_cc_test( - name = "predictor_impl_test", - srcs = ["predictor_impl_test.cc"], + name = "predictor_test", + srcs = ["predictor_test.cc"], deps = [ - ":predictor_impl", + ":predictor", + ], +) + +serving_cc_test( + name = "execute_context_test", + srcs = ["execute_context_test.cc"], + deps = [ + ":execute_context", ], ) serving_cc_library( - name = "interface", + name = "loader", hdrs = ["loader.h"], deps = [ - ":executable", - "//secretflow_serving/core:exception", + "//secretflow_serving/protos:bundle_cc_proto", ], ) @@ -113,10 +129,7 @@ serving_cc_library( srcs = ["model_loader.cc"], hdrs = ["model_loader.h"], deps = [ - ":executable_impl", - ":interface", - ":predictor_impl", - "//secretflow_serving/protos:bundle_cc_proto", + ":loader", "//secretflow_serving/util:sys_util", "//secretflow_serving/util:utils", ], diff --git a/secretflow_serving/framework/executable_impl.cc b/secretflow_serving/framework/executable.cc similarity index 73% rename from 
secretflow_serving/framework/executable_impl.cc rename to secretflow_serving/framework/executable.cc index 7855807..06dbb29 100644 --- a/secretflow_serving/framework/executable_impl.cc +++ b/secretflow_serving/framework/executable.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "secretflow_serving/framework/executable_impl.h" +#include "secretflow_serving/framework/executable.h" #include "spdlog/spdlog.h" @@ -20,28 +20,27 @@ namespace secretflow::serving { -ExecutableImpl::ExecutableImpl(std::vector> executors) - : executors_(std::move(executors)) { - SERVING_ENFORCE(!executors_.empty(), errors::ErrorCode::LOGIC_ERROR); -} +Executable::Executable(std::vector> executors) + : executors_(std::move(executors)) {} -void ExecutableImpl::Run(Task& task) { +void Executable::Run(Task& task) { SERVING_ENFORCE(task.id < executors_.size(), errors::ErrorCode::LOGIC_ERROR); auto executor = executors_[task.id]; if (task.features) { task.outputs = executor->Run(task.features); } else { SERVING_ENFORCE(!task.node_inputs->empty(), errors::ErrorCode::LOGIC_ERROR); - task.outputs = executor->Run(task.node_inputs); + task.outputs = executor->Run(*(task.node_inputs)); } - SPDLOG_DEBUG("ExecutableImpl::Run end, task.outputs.size:{}", + SPDLOG_DEBUG("Executable::Run end, task.outputs.size:{}", task.outputs->size()); } const std::shared_ptr& -ExecutableImpl::GetInputFeatureSchema() { +Executable::GetInputFeatureSchema() { const auto& schema = executors_.front()->GetInputFeatureSchema(); + SERVING_ENFORCE(schema, errors::ErrorCode::LOGIC_ERROR); return schema; } diff --git a/secretflow_serving/framework/executable.h b/secretflow_serving/framework/executable.h index a860553..397ac78 100644 --- a/secretflow_serving/framework/executable.h +++ b/secretflow_serving/framework/executable.h @@ -14,11 +14,6 @@ #pragma once -#include -#include - -#include "arrow/api.h" - #include 
"secretflow_serving/framework/executor.h" namespace secretflow::serving { @@ -30,7 +25,8 @@ class Executable { // input std::shared_ptr features; - std::shared_ptr>> + std::shared_ptr< + std::unordered_map>> node_inputs; // output @@ -38,13 +34,15 @@ class Executable { }; public: - explicit Executable() = default; + explicit Executable(std::vector> executors); virtual ~Executable() = default; - virtual const std::shared_ptr& - GetInputFeatureSchema() = 0; + virtual void Run(Task& task); + + virtual const std::shared_ptr& GetInputFeatureSchema(); - virtual void Run(Task& task) = 0; + private: + std::vector> executors_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/framework/executable_impl.h b/secretflow_serving/framework/executable_impl.h deleted file mode 100644 index 6a0cdd9..0000000 --- a/secretflow_serving/framework/executable_impl.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "secretflow_serving/framework/executable.h" -#include "secretflow_serving/framework/executor.h" - -namespace secretflow::serving { - -class ExecutableImpl : public Executable { - public: - explicit ExecutableImpl(std::vector> executors); - ~ExecutableImpl() override = default; - - void Run(Task& task) override; - - const std::shared_ptr& GetInputFeatureSchema() override; - - private: - std::vector> executors_; -}; - -} // namespace secretflow::serving diff --git a/secretflow_serving/framework/execute_context.cc b/secretflow_serving/framework/execute_context.cc new file mode 100644 index 0000000..b5f0c6d --- /dev/null +++ b/secretflow_serving/framework/execute_context.cc @@ -0,0 +1,129 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/framework/execute_context.h" + +#include "secretflow_serving/core/exception.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving { + +void ExecuteContext::CheckAndUpdateResponse() { + CheckAndUpdateResponse(exec_res_); +} +void ExecuteContext::CheckAndUpdateResponse( + const apis::ExecuteResponse& exec_res) { + if (!CheckStatusOk(exec_res.status())) { + SERVING_THROW( + exec_res.status().code(), + fmt::format("{} exec failed: code({}), {}", target_id_, + exec_res.status().code(), exec_res.status().msg())); + } + response_->mutable_header()->mutable_data()->insert( + exec_res.header().data().begin(), exec_res.header().data().end()); +} + +void ExeResponseToIoMap( + apis::ExecuteResponse& exec_res, + std::unordered_map>* + node_io_map) { + auto result = exec_res.mutable_result(); + for (int i = 0; i < result->nodes_size(); ++i) { + auto result_node_io = result->mutable_nodes(i); + auto prev_insert_iter = node_io_map->find(result_node_io->name()); + if (prev_insert_iter != node_io_map->end()) { + // found node, merge ios + auto& target_node_io = prev_insert_iter->second; + SERVING_ENFORCE(target_node_io->ios_size() == result_node_io->ios_size(), + errors::ErrorCode::LOGIC_ERROR); + for (int io_index = 0; io_index < target_node_io->ios_size(); + ++io_index) { + auto target_io = target_node_io->mutable_ios(io_index); + auto io = result_node_io->mutable_ios(io_index); + for (int data_index = 0; data_index < io->datas_size(); ++data_index) { + target_io->add_datas(std::move(*(io->mutable_datas(data_index)))); + } + } + } else { + auto node_name = result_node_io->name(); + node_io_map->emplace(node_name, std::make_shared( + std::move(*result_node_io))); + } + } +} + +void ExecuteContext::GetResultNodeIo( + std::unordered_map>* + node_io_map) { + ExeResponseToIoMap(exec_res_, node_io_map); +} + +void ExecuteContext::SetFeatureSource() { + auto feature_source = exec_req_.mutable_feature_source(); + if 
(execution_->IsEntry()) { + // entry execution need features + // get target_id's feature param + if (target_id_ == local_id_ && request_->predefined_features_size() != 0) { + // only loacl execute will use `predefined_features` + feature_source->set_type(apis::FeatureSourceType::FS_PREDEFINED); + feature_source->mutable_predefineds()->CopyFrom( + request_->predefined_features()); + } else { + feature_source->set_type(apis::FeatureSourceType::FS_SERVICE); + auto iter = request_->fs_params().find(target_id_); + SERVING_ENFORCE(iter != request_->fs_params().end(), + serving::errors::LOGIC_ERROR, + "missing {}'s feature params", target_id_); + feature_source->mutable_fs_param()->CopyFrom(iter->second); + } + } else { + feature_source->set_type(apis::FeatureSourceType::FS_NONE); + } +} + +ExecuteContext::ExecuteContext(const apis::PredictRequest* request, + apis::PredictResponse* response, + const std::shared_ptr& execution, + std::string target_id, std::string local_id) + : request_(request), + response_(response), + local_id_(std::move(local_id)), + target_id_(std::move(target_id)), + execution_(execution) { + exec_req_.mutable_header()->CopyFrom(request_->header()); + exec_req_.set_requester_id(local_id_); + exec_req_.mutable_service_spec()->CopyFrom(request_->service_spec()); + + SetFeatureSource(); +} + +void ExecuteContext::Execute( + std::shared_ptr<::google::protobuf::RpcChannel> channel, + brpc::Controller* cntl) { + apis::ExecutionService_Stub stub(channel.get()); + stub.Execute(cntl, &exec_req_, &exec_res_, brpc::DoNothing()); +} + +void ExecuteContext::Execute(std::shared_ptr execution_core) { + execution_core->Execute(&exec_req_, &exec_res_); +} + +void RemoteExecute::Run() { + // semisynchronous call + exe_ctx_.Execute(channel_, &cntl_); + + executing_ = true; +} + +} // namespace secretflow::serving diff --git a/secretflow_serving/framework/execute_context.h b/secretflow_serving/framework/execute_context.h new file mode 100644 index 0000000..2b63912 --- 
/dev/null +++ b/secretflow_serving/framework/execute_context.h @@ -0,0 +1,192 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "brpc/controller.h" + +#include "secretflow_serving/ops/graph.h" +#include "secretflow_serving/server/execution_core.h" + +#include "secretflow_serving/apis/execution_service.pb.h" +#include "secretflow_serving/apis/prediction_service.pb.h" + +namespace secretflow::serving { + +void ExeResponseToIoMap( + apis::ExecuteResponse& exec_res, + std::unordered_map>* + node_io_map); + +class ExecuteContext { + public: + ExecuteContext(const apis::PredictRequest* request, + apis::PredictResponse* response, + const std::shared_ptr& execution, + std::string target_id, std::string local_id); + + template < + typename T, + typename = std::enable_if_t, + std::unordered_map>>>> + void SetEntryNodesInputs(T&& node_io_map) { + if (node_io_map.empty()) { + return; + } + auto task = exec_req_.mutable_task(); + task->set_execution_id(execution_->id()); + auto entry_nodes = execution_->GetEntryNodes(); + for (const auto& n : entry_nodes) { + auto entry_node_io = task->add_nodes(); + entry_node_io->set_name(n->GetName()); + for (const auto& e : n->in_edges()) { + auto iter = node_io_map.find(e->src_node()); + SERVING_ENFORCE(iter != node_io_map.end(), + errors::ErrorCode::LOGIC_ERROR, + "Input of {} cannot be found in ctx(size:{})", + e->src_node(), node_io_map.size()); + for (auto& io 
: *(iter->second->mutable_ios())) { + if constexpr (std::is_lvalue_reference_v) { + *(entry_node_io->mutable_ios()->Add()) = io; + } else { + entry_node_io->mutable_ios()->Add(std::move(io)); + } + } + } + } + } + + void Execute(std::shared_ptr<::google::protobuf::RpcChannel> channel, + brpc::Controller* cntl); + void Execute(std::shared_ptr execution_core); + + void GetResultNodeIo( + std::unordered_map>* + node_io_map); + + void CheckAndUpdateResponse(const apis::ExecuteResponse& exec_res); + void CheckAndUpdateResponse(); + + const std::string& LocalId() const { return local_id_; } + const std::string& TargetId() const { return target_id_; } + + private: + void SetFeatureSource(); + + protected: + const apis::PredictRequest* request_; + apis::PredictResponse* response_; + + std::string local_id_; + std::string target_id_; + std::shared_ptr execution_; + + std::string session_id_; + + apis::ExecuteRequest exec_req_; + apis::ExecuteResponse exec_res_; +}; + +class ExecuteBase { + public: + ExecuteBase(const apis::PredictRequest* request, + apis::PredictResponse* response, + const std::shared_ptr& execution, + std::string target_id, std::string local_id) + : exe_ctx_{request, response, execution, std::move(target_id), + std::move(local_id)} {} + virtual ~ExecuteBase() = default; + + void SetInputs(std::unordered_map>& + node_io_map) { + exe_ctx_.SetEntryNodesInputs(node_io_map); + } + void SetInputs( + std::unordered_map>&& + node_io_map) { + exe_ctx_.SetEntryNodesInputs(std::move(node_io_map)); + } + virtual void GetOutputs( + std::unordered_map>* + node_io_map) { + exe_ctx_.GetResultNodeIo(node_io_map); + } + + virtual void Run() = 0; + + protected: + ExecuteContext exe_ctx_; +}; + +class RemoteExecute : public ExecuteBase, + public std::enable_shared_from_this { + public: + RemoteExecute(const apis::PredictRequest* request, + apis::PredictResponse* response, + const std::shared_ptr& execution, + std::string target_id, std::string local_id, + 
std::shared_ptr<::google::protobuf::RpcChannel> channel) + : ExecuteBase{request, response, execution, std::move(target_id), + std::move(local_id)}, + channel_(std::move(channel)) {} + + virtual ~RemoteExecute() { + if (executing_) { + Cancel(); + } + } + + virtual void Run() override; + virtual void Cancel() { + if (executing_) { + brpc::StartCancel(cntl_.call_id()); + } + } + virtual void WaitToFinish() { + brpc::Join(cntl_.call_id()); + SERVING_ENFORCE(!cntl_.Failed(), errors::ErrorCode::NETWORK_ERROR, + "call ({}) from ({}) execute failed, msg:{}", + exe_ctx_.TargetId(), exe_ctx_.LocalId(), cntl_.ErrorText()); + executing_ = false; + exe_ctx_.CheckAndUpdateResponse(); + } + + protected: + std::shared_ptr<::google::protobuf::RpcChannel> channel_; + brpc::Controller cntl_; + bool executing_{false}; +}; + +class LocalExecute : public ExecuteBase { + public: + LocalExecute(const apis::PredictRequest* request, + apis::PredictResponse* response, + const std::shared_ptr& execution, + std::string target_id, std::string local_id, + std::shared_ptr execution_core) + : ExecuteBase{request, response, execution, std::move(target_id), + std::move(local_id)}, + execution_core_(std::move(execution_core)) {} + + void Run() override { + exe_ctx_.Execute(execution_core_); + exe_ctx_.CheckAndUpdateResponse(); + } + + protected: + std::shared_ptr execution_core_; +}; + +} // namespace secretflow::serving diff --git a/secretflow_serving/framework/execute_context_test.cc b/secretflow_serving/framework/execute_context_test.cc new file mode 100644 index 0000000..2d578df --- /dev/null +++ b/secretflow_serving/framework/execute_context_test.cc @@ -0,0 +1,284 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/framework/execute_context.h" + +#include "brpc/channel.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving { + +namespace op { + +class MockOpKernel0 : public OpKernel { + public: + explicit MockOpKernel0(OpKernelOptions opts) : OpKernel(std::move(opts)) {} + + void DoCompute(ComputeContext* ctx) override {} + void BuildInputSchema() override {} + void BuildOutputSchema() override {} +}; + +class MockOpKernel1 : public OpKernel { + public: + explicit MockOpKernel1(OpKernelOptions opts) : OpKernel(std::move(opts)) {} + + void DoCompute(ComputeContext* ctx) override {} + void BuildInputSchema() override {} + void BuildOutputSchema() override {} +}; + +REGISTER_OP_KERNEL(TEST_OP_0, MockOpKernel0); +REGISTER_OP_KERNEL(TEST_OP_1, MockOpKernel1); +REGISTER_OP(TEST_OP_0, "0.0.1", "test_desc") + .StringAttr("attr_s", "attr_s_desc", false, false) + .Input("input", "input_desc") + .Output("output", "output_desc"); +REGISTER_OP(TEST_OP_1, "0.0.1", "test_desc") + .Mergeable() + .Returnable() + .StringAttr("attr_s", "attr_s_desc", false, false) + .Input("input", "input_desc") + .Output("output", "output_desc"); + +} // namespace op + +class MockExecuteContext : public ExecuteContext { + public: + using ExecuteContext::ExecuteContext; + const apis::ExecuteRequest& ExeRequest() const { return exec_req_; } + apis::ExecuteResponse& ExeResponse() { 
return exec_res_; } + const apis::ExecuteResponse& ExeResponse() const { return exec_res_; } +}; + +class ExecuteContextTest : public ::testing::Test { + protected: + void SetUp() override { + // mock execution + std::vector node_def_jsons = { + R"JSON( +{ + "name": "mock_node_1", + "op": "TEST_OP_0", +} +)JSON", + R"JSON( +{ + "name": "mock_node_2", + "op": "TEST_OP_1", + "parents": [ "mock_node_1" ], +} +)JSON"}; + + std::vector execution_def_jsons = { + R"JSON( +{ + "nodes": [ + "mock_node_1" + ], + "config": { + "dispatch_type": "DP_ALL" + } +} +)JSON", + R"JSON( +{ + "nodes": [ + "mock_node_2" + ], + "config": { + "dispatch_type": "DP_ANYONE" + } +} +)JSON"}; + + // build node + std::unordered_map> nodes; + for (const auto& j : node_def_jsons) { + NodeDef node_def; + JsonToPb(j, &node_def); + auto node = std::make_shared(std::move(node_def)); + nodes.emplace(node->GetName(), node); + } + // build edge + for (const auto& pair : nodes) { + const auto& input_nodes = pair.second->GetInputNodeNames(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto n_iter = nodes.find(input_nodes[i]); + SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); + auto edge = std::make_shared(n_iter->first, pair.first, i); + n_iter->second->AddOutEdge(edge); + pair.second->AddInEdge(edge); + } + } + std::vector> executions; + for (size_t i = 0; i < execution_def_jsons.size(); ++i) { + ExecutionDef executino_def; + JsonToPb(execution_def_jsons[i], &executino_def); + + std::unordered_map> e_nodes; + for (const auto& n : executino_def.nodes()) { + e_nodes.emplace(n, nodes.find(n)->second); + } + + executions.emplace_back(std::make_shared( + i, std::move(executino_def), std::move(e_nodes))); + } + + local_id_ = "alice"; + remote_id_ = "bob"; + executions_ = std::move(executions); + } + + void TearDown() override {} + + protected: + std::string local_id_; + std::string remote_id_; + std::vector> executions_; +}; + +TEST_F(ExecuteContextTest, BuildExecCtx) { + // 
mock predict request + apis::PredictRequest request; + apis::PredictResponse response; + + request.mutable_header()->mutable_data()->insert({"test-k", "test-v"}); + request.mutable_service_spec()->set_id("test_service_id"); + request.mutable_fs_params()->insert({"bob", {}}); + request.mutable_fs_params()->at("bob").set_query_context("bob_test_context"); + int params_num = 3; + for (int i = 0; i < params_num; ++i) { + request.mutable_fs_params()->at("bob").add_query_datas("bob_test_params"); + } + auto feature_1 = request.add_predefined_features(); + feature_1->mutable_field()->set_name("feature_1"); + feature_1->mutable_field()->set_type(FieldType::FIELD_STRING); + std::vector ss = {"true", "false", "true"}; + feature_1->mutable_value()->mutable_ss()->Assign(ss.begin(), ss.end()); + auto feature_2 = request.add_predefined_features(); + feature_2->mutable_field()->set_name("feature_2"); + feature_2->mutable_field()->set_type(FieldType::FIELD_DOUBLE); + std::vector ds = {1.1, 2.2, 3.3}; + feature_2->mutable_value()->mutable_ds()->Assign(ds.begin(), ds.end()); + + // build bob ctx + auto ctx_bob = std::make_shared( + &request, &response, executions_[0], remote_id_, local_id_); + ASSERT_EQ(request.header().data().at("test-k"), + ctx_bob->ExeRequest().header().data().at("test-k")); + ASSERT_EQ(ctx_bob->ExeRequest().service_spec().id(), + request.service_spec().id()); + ASSERT_EQ(ctx_bob->ExeRequest().requester_id(), local_id_); + ASSERT_TRUE(ctx_bob->ExeRequest().feature_source().type() == + apis::FeatureSourceType::FS_SERVICE); + ASSERT_TRUE(std::equal( + ctx_bob->ExeRequest().feature_source().fs_param().query_datas().begin(), + ctx_bob->ExeRequest().feature_source().fs_param().query_datas().end(), + request.fs_params().at(ctx_bob->TargetId()).query_datas().begin())); + ASSERT_EQ(ctx_bob->ExeRequest().feature_source().fs_param().query_context(), + request.fs_params().at(ctx_bob->TargetId()).query_context()); + 
ASSERT_TRUE(ctx_bob->ExeRequest().feature_source().predefineds().empty()); + ASSERT_EQ(ctx_bob->ExeRequest().task().execution_id(), 0); + ASSERT_TRUE(ctx_bob->ExeRequest().task().nodes().empty()); + + // build alice ctx + + auto ctx_alice = std::make_shared( + &request, &response, executions_[0], local_id_, local_id_); + ASSERT_EQ(request.header().data().at("test-k"), + ctx_alice->ExeRequest().header().data().at("test-k")); + ASSERT_EQ(ctx_alice->ExeRequest().service_spec().id(), + request.service_spec().id()); + ASSERT_EQ(ctx_alice->ExeRequest().requester_id(), local_id_); + ASSERT_TRUE(ctx_alice->ExeRequest().feature_source().type() == + apis::FeatureSourceType::FS_PREDEFINED); + ASSERT_TRUE(ctx_alice->ExeRequest() + .feature_source() + .fs_param() + .query_datas() + .empty()); + ASSERT_TRUE(ctx_alice->ExeRequest() + .feature_source() + .fs_param() + .query_context() + .empty()); + ASSERT_EQ(ctx_alice->ExeRequest().feature_source().predefineds_size(), + request.predefined_features_size()); + auto f1 = ctx_alice->ExeRequest().feature_source().predefineds(0); + ASSERT_FALSE(f1.field().name().empty()); + ASSERT_EQ(f1.field().name(), feature_1->field().name()); + ASSERT_EQ(f1.field().type(), feature_1->field().type()); + ASSERT_FALSE(f1.value().ss().empty()); + ASSERT_TRUE(std::equal(f1.value().ss().begin(), f1.value().ss().end(), + feature_1->value().ss().begin())); + auto f2 = ctx_alice->ExeRequest().feature_source().predefineds(1); + ASSERT_FALSE(f2.field().name().empty()); + ASSERT_EQ(f2.field().name(), feature_2->field().name()); + ASSERT_EQ(f2.field().type(), feature_2->field().type()); + ASSERT_FALSE(f2.value().ds().empty()); + ASSERT_TRUE(std::equal(f2.value().ds().begin(), f2.value().ds().end(), + feature_2->value().ds().begin())); + ASSERT_EQ(ctx_alice->ExeRequest().task().execution_id(), 0); + ASSERT_TRUE(ctx_alice->ExeRequest().task().nodes().empty()); + + // mock alice & bob response + { + auto& exec_response = ctx_bob->ExeResponse(); + 
exec_response.mutable_result()->set_execution_id(0); + auto node = exec_response.mutable_result()->add_nodes(); + node->set_name("mock_node_1"); + auto io = node->add_ios(); + io->add_datas("mock_bob_data"); + } + { + auto& exec_response = ctx_alice->ExeResponse(); + exec_response.mutable_result()->set_execution_id(0); + auto node_1 = exec_response.mutable_result()->add_nodes(); + node_1->set_name("mock_node_1"); + node_1->add_ios()->add_datas("mock_alice_data"); + } + + std::unordered_map> node_io_map; + ctx_bob->GetResultNodeIo(&node_io_map); + ctx_alice->GetResultNodeIo(&node_io_map); + + // build ctx + auto ctx_final = std::make_shared( + &request, &response, executions_[1], local_id_, local_id_); + ctx_final->SetEntryNodesInputs(node_io_map); + + EXPECT_EQ(request.header().data().at("test-k"), + ctx_final->ExeRequest().header().data().at("test-k")); + EXPECT_EQ(ctx_final->ExeRequest().service_spec().id(), + request.service_spec().id()); + EXPECT_EQ(ctx_final->ExeRequest().requester_id(), local_id_); + EXPECT_TRUE(ctx_final->ExeRequest().feature_source().type() == + apis::FeatureSourceType::FS_NONE); + EXPECT_EQ(ctx_final->ExeRequest().task().execution_id(), 1); + EXPECT_EQ(ctx_final->ExeRequest().task().nodes_size(), 1); + auto node1 = ctx_final->ExeRequest().task().nodes(0); + EXPECT_EQ(node1.name(), "mock_node_2"); + EXPECT_EQ(node1.ios_size(), 1); + EXPECT_EQ(node1.ios(0).datas_size(), 2); + EXPECT_EQ(node1.ios(0).datas(0), "mock_bob_data"); + EXPECT_EQ(node1.ios(0).datas(1), "mock_alice_data"); +} + +} // namespace secretflow::serving diff --git a/secretflow_serving/framework/executor.cc b/secretflow_serving/framework/executor.cc index c768812..27c8c65 100644 --- a/secretflow_serving/framework/executor.cc +++ b/secretflow_serving/framework/executor.cc @@ -14,8 +14,12 @@ #include "secretflow_serving/framework/executor.h" +#include + #include "secretflow_serving/ops/op_factory.h" #include "secretflow_serving/ops/op_kernel_factory.h" +#include 
"secretflow_serving/util/thread_pool.h" +#include "secretflow_serving/util/thread_safe_queue.h" namespace secretflow::serving { @@ -23,22 +27,25 @@ Executor::Executor(const std::shared_ptr& execution) : execution_(execution) { // create op_kernel auto nodes = execution_->nodes(); + node_items_ = std::make_shared< + std::unordered_map>>(); + for (const auto& [node_name, node] : nodes) { - op::OpKernelOptions ctx{node}; + op::OpKernelOptions ctx{node->node_def(), node->GetOpDef()}; auto item = std::make_shared(); item->node = node; item->op_kernel = op::OpKernelFactory::GetInstance()->Create(std::move(ctx)); - node_items_.emplace(node_name, item); + node_items_->emplace(node_name, item); } // get input schema const auto& entry_nodes = execution_->GetEntryNodes(); for (const auto& node : entry_nodes) { const auto& node_name = node->node_def().name(); - auto iter = node_items_.find(node_name); - SERVING_ENFORCE(iter != node_items_.end(), errors::ErrorCode::LOGIC_ERROR); + auto iter = node_items_->find(node_name); + SERVING_ENFORCE(iter != node_items_->end(), errors::ErrorCode::LOGIC_ERROR); const auto& input_schema = iter->second->op_kernel->GetAllInputSchema(); entry_node_names_.emplace_back(node_name); input_schema_map_.emplace(node_name, input_schema); @@ -47,19 +54,25 @@ Executor::Executor(const std::shared_ptr& execution) if (execution_->IsEntry()) { // build feature schema from entry execution auto iter = input_schema_map_.begin(); - const auto& fisrt_input_schema_list = iter->second; - SERVING_ENFORCE(fisrt_input_schema_list.size() == 1, + const auto& first_input_schema_list = iter->second; + SERVING_ENFORCE(first_input_schema_list.size() == 1, errors::ErrorCode::LOGIC_ERROR); - const auto& target_schema = fisrt_input_schema_list.front(); + const auto& target_schema = first_input_schema_list.front(); ++iter; for (; iter != input_schema_map_.end(); ++iter) { - SERVING_ENFORCE(iter->second.size() == 1, errors::ErrorCode::LOGIC_ERROR); + 
SERVING_ENFORCE(iter->second.size() == 1, errors::ErrorCode::LOGIC_ERROR, + "entry nodes should have one input table({})", + iter->second.size()); const auto& schema = iter->second.front(); - SERVING_ENFORCE(schema->Equals(target_schema), - errors::ErrorCode::LOGIC_ERROR, - "found entry nodes input schema not equals"); - // TODO: consider support entry nodes schema have same fields but - // different ordered + + SERVING_ENFORCE_EQ( + target_schema->num_fields(), schema->num_fields(), + "entry nodes should have same shape inputs, expect: {}, found: {}", + target_schema->num_fields(), schema->num_fields()); + CheckReferenceFields(schema, target_schema, + fmt::format("entry nodes should have same input " + "schema, found node:{} mismatch", + iter->first)); } input_feature_schema_ = target_schema; } @@ -68,69 +81,219 @@ Executor::Executor(const std::shared_ptr& execution) std::shared_ptr> Executor::Run( std::shared_ptr& features) { SERVING_ENFORCE(execution_->IsEntry(), errors::ErrorCode::LOGIC_ERROR); - auto inputs = std::make_shared< - std::map>>(); + auto inputs = + std::unordered_map>(); for (size_t i = 0; i < entry_node_names_.size(); ++i) { auto op_inputs = std::make_shared(); std::vector> record_list = {features}; op_inputs->emplace_back(std::move(record_list)); - inputs->emplace(entry_node_names_[i], std::move(op_inputs)); + inputs.emplace(entry_node_names_[i], std::move(op_inputs)); } return Run(inputs); } -std::shared_ptr> Executor::Run( - std::shared_ptr< - std::map>>& inputs) { - auto results = std::make_shared>(); - Propagator propagator; +class RunContext { + public: + explicit RunContext(size_t expect_result_count, + std::deque> buffer = {}) + : ready_nodes_(std::move(buffer)), + expect_result_count_(expect_result_count) {} - // get entry nodes - std::deque> ready_nodes; - const auto& entry_nodes = execution_->GetEntryNodes(); - for (const auto& node : entry_nodes) { - const auto& node_name = node->node_def().name(); - 
ready_nodes.emplace_back(node_items_.find(node_name)->second); - auto frame = propagator.CreateFrame(node); - auto in_iter = inputs->find(node_name); - SERVING_ENFORCE(in_iter != inputs->end(), - errors::ErrorCode::INVALID_ARGUMENT, - "can not found inputs for node:{}", node_name); - frame->compute_ctx.inputs = in_iter->second; - } - - // schedule ready - // TODO: support multi-thread run - size_t scheduled_count = 0; - while (!ready_nodes.empty()) { - auto n = ready_nodes.front(); - ready_nodes.pop_front(); - - auto frame = propagator.GetFrame(n->node->node_def().name()); - - n->op_kernel->Compute(&(frame->compute_ctx)); - ++scheduled_count; - - if (execution_->IsExitNode(n->node->node_def().name())) { - NodeOutput node_output{n->node->node_def().name(), - frame->compute_ctx.output}; - results->emplace_back(std::move(node_output)); - } else { - auto dst_node_name = n->node->out_edge()->dst_node(); - auto dst_node = execution_->GetNode(dst_node_name); - auto child_frame = propagator.FindOrCreateChildFrame(frame, dst_node); - child_frame->compute_ctx.inputs->at( - n->node->out_edge()->dst_input_id()) = {frame->compute_ctx.output}; - child_frame->pending_count--; - if (child_frame->pending_count == 0) { - ready_nodes.push_back( - node_items_.find(dst_node->node_def().name())->second); + void AddResult(std::string node_name, + std::shared_ptr table) { + { + std::lock_guard lock(results_mtx_); + + results_.emplace_back(std::move(node_name), std::move(table)); + results_cv_.notify_all(); + } + if (IsFinish()) { + Stop(); + } + } + + void AddReadyNode(std::shared_ptr node_item) { + ready_nodes_.Push(std::move(node_item)); + } + + bool GetReadyNode(std::shared_ptr& node_item) { + return ready_nodes_.WaitPop(node_item); + } + + bool IsFinish() const { return expect_result_count_ == results_.size(); } + + std::vector GetResults() { + std::lock_guard lock(results_mtx_); + return results_; + } + + void Stop() { ready_nodes_.StopPush(); } + + private: + ThreadSafeQueue> 
ready_nodes_; + + size_t expect_result_count_; + + mutable std::mutex results_mtx_; + std::condition_variable results_cv_; + std::vector results_; +}; + +class ExecuteScheduler : public std::enable_shared_from_this { + public: + class ExecuteOpTask : public ThreadPool::Task { + public: + const char* Name() override { return "ExecuteOpTask"; } + + ExecuteOpTask(std::shared_ptr node_item, + std::shared_ptr sched) + : node_item_(std::move(node_item)), sched_(std::move(sched)) {} + + void Exec() override { sched_->ExecuteOp(node_item_); } + + void OnException(std::exception_ptr e) noexcept override { + sched_->SetTaskException(e); + } + + private: + std::shared_ptr node_item_; + std::shared_ptr sched_; + }; + + ExecuteScheduler( + std::shared_ptr< + std::unordered_map>> + node_items, + uint64_t res_cnt, const std::shared_ptr& thread_pool, + std::shared_ptr execution) + : node_items_(std::move(node_items)), + context_(res_cnt), + thread_pool_(thread_pool), + execution_(std::move(execution)), + propagator_(execution_->nodes()), + sched_count_(0) {} + + void AddEntryNode(const std::shared_ptr& node, + std::shared_ptr& inputs) { + auto* frame = propagator_.GetFrame(node->GetName()); + frame->compute_ctx.inputs = std::move(*inputs); + context_.AddReadyNode(node_items_->find(node->node_def().name())->second); + } + + void ExecuteOp(const std::shared_ptr& node_item) { + if (stop_flag_.load()) { + return; + } + + auto* frame = propagator_.GetFrame(node_item->node->node_def().name()); + + node_item->op_kernel->Compute(&(frame->compute_ctx)); + sched_count_++; + + if (execution_->IsExitNode(node_item->node->node_def().name())) { + context_.AddResult(node_item->node->node_def().name(), + frame->compute_ctx.output); + } + + const auto& edges = node_item->node->out_edges(); + for (const auto& edge : edges) { + CompleteOutEdge(edge, frame->compute_ctx.output); + } + } + + void CompleteOutEdge(const std::shared_ptr& edge, + std::shared_ptr output) { + std::shared_ptr dst_node; + if 
(!execution_->TryGetNode(edge->dst_node(), &dst_node)) { + return; + } + + auto* child_frame = propagator_.GetFrame(dst_node->GetName()); + child_frame->compute_ctx.inputs[edge->dst_input_id()].emplace_back( + std::move(output)); + + if (child_frame->pending_count.fetch_sub(1) == 1) { + context_.AddReadyNode( + node_items_->find(dst_node->node_def().name())->second); + } + } + + void SubmitExecuteOpTask(std::shared_ptr& node_item) { + if (stop_flag_.load()) { + return; + } + thread_pool_->SubmitTask( + std::make_unique(node_item, shared_from_this())); + } + + void Schedule() { + while (!stop_flag_.load() && !context_.IsFinish()) { + // TODO: consider use bthread::Mutex and bthread::ConditionVariable + // to make this worker can switch to others + std::shared_ptr node_item; + if (!context_.GetReadyNode(node_item)) { + continue; } + SubmitExecuteOpTask(node_item); + } + } + + void SetTaskException(std::exception_ptr& e) noexcept { + bool expect_flag = false; + // store the first exception + if (stop_flag_.compare_exchange_strong(expect_flag, true)) { + task_exception_ = e; + context_.Stop(); } } - SERVING_ENFORCE_EQ(scheduled_count, execution_->nodes().size()); - return results; + std::exception_ptr GetTaskException() { return task_exception_; } + + uint64_t GetSchedCount() { return sched_count_.load(); } + std::vector GetResults() { return context_.GetResults(); } + + private: + std::shared_ptr>> + node_items_; + RunContext context_; + std::shared_ptr thread_pool_; + std::shared_ptr execution_; + Propagator propagator_; + std::atomic sched_count_{0}; + static constexpr uint64_t READY_NODE_WAIT_MS = 10; + std::atomic stop_flag_{false}; + std::exception_ptr task_exception_; +}; + +std::shared_ptr> Executor::Run( + std::unordered_map>& + inputs) { + std::vector> entry_node_inputs; + for (const auto& node : execution_->GetEntryNodes()) { + auto iter = inputs.find(node->node_def().name()); + SERVING_ENFORCE(iter != inputs.end(), errors::ErrorCode::INVALID_ARGUMENT, + 
"can not found inputs for node:{}", + node->node_def().name()); + entry_node_inputs.emplace_back(iter->second); + } + + auto sched = std::make_shared( + node_items_, execution_->GetExitNodeNum(), ThreadPool::GetInstance(), + execution_); + const auto& entry_nodes = execution_->GetEntryNodes(); + for (size_t i = 0; i != execution_->GetEntryNodeNum(); ++i) { + sched->AddEntryNode(entry_nodes[i], entry_node_inputs[i]); + } + + sched->Schedule(); + + auto task_exception = sched->GetTaskException(); + if (task_exception) { + SPDLOG_ERROR("Execution {} run with exception.", execution_->id()); + std::rethrow_exception(task_exception); + } + SERVING_ENFORCE_EQ(sched->GetSchedCount(), execution_->nodes().size()); + return std::make_shared>(sched->GetResults()); } } // namespace secretflow::serving diff --git a/secretflow_serving/framework/executor.h b/secretflow_serving/framework/executor.h index c08c471..7850fc8 100644 --- a/secretflow_serving/framework/executor.h +++ b/secretflow_serving/framework/executor.h @@ -23,6 +23,8 @@ namespace secretflow::serving { struct NodeOutput { std::string node_name; std::shared_ptr table; + NodeOutput(std::string name, std::shared_ptr table) + : node_name(std::move(name)), table(std::move(table)) {} }; struct NodeItem { @@ -36,8 +38,8 @@ class Executor { ~Executor() = default; std::shared_ptr> Run( - std::shared_ptr< - std::map>>& inputs); + std::unordered_map>& + inputs); // for entry executor std::shared_ptr> Run( @@ -52,9 +54,11 @@ class Executor { std::vector entry_node_names_; - std::map> node_items_; + std::shared_ptr>> + node_items_; - std::map>> + std::unordered_map>> input_schema_map_; std::shared_ptr input_feature_schema_; diff --git a/secretflow_serving/framework/executor_test.cc b/secretflow_serving/framework/executor_test.cc index da68864..250281e 100644 --- a/secretflow_serving/framework/executor_test.cc +++ b/secretflow_serving/framework/executor_test.cc @@ -14,62 +14,117 @@ #include 
"secretflow_serving/framework/executor.h" +#include +#include + #include "arrow/compute/api.h" #include "gtest/gtest.h" #include "secretflow_serving/ops/op_factory.h" #include "secretflow_serving/ops/op_kernel_factory.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/thread_pool.h" #include "secretflow_serving/util/utils.h" namespace secretflow::serving { +constexpr size_t MASSIVE_NODE_CNT = 1ULL << 14; + class ExecutorTest : public ::testing::Test { protected: - void SetUp() override {} - void TearDown() override {} + void SetUp() override { + ThreadPool::GetInstance()->Start(std::thread::hardware_concurrency()); + } + void TearDown() override { ThreadPool::GetInstance()->Stop(); } }; namespace op { -class MockOpKernel0 : public OpKernel { +#define REGISTER_TEMPLATE_OP_KERNEL(op_name, class_name, unique_class_id) \ + static KernelRegister regist_kernel_##unique_class_id(#op_name); + +template +class AddEleOpKernel : public OpKernel { public: - explicit MockOpKernel0(OpKernelOptions opts) : OpKernel(std::move(opts)) { + explicit AddEleOpKernel(OpKernelOptions opts) : OpKernel(std::move(opts)) { auto schema = arrow::schema({arrow::field("test_field_0", arrow::float64())}); - input_schema_list_ = {schema}; + input_schema_list_ = std::vector(INPUT_EDGE_COUNT, schema); output_schema_ = schema; } // array += 1; - void Compute(ComputeContext* ctx) override { - for (size_t i = 0; i < ctx->inputs->size(); ++i) { - for (size_t j = 0; j < ctx->inputs->at(i).size(); ++j) { - for (int col_index = 0; - col_index < ctx->inputs->at(i)[j]->num_columns(); ++col_index) { + void DoCompute(ComputeContext* ctx) override { + for (size_t i = 0; i < ctx->inputs.size(); ++i) { + for (size_t j = 0; j < ctx->inputs[i].size(); ++j) { + for (int col_index = 0; col_index < ctx->inputs[i][j]->num_columns(); + ++col_index) { // add index num to every item. 
- auto field = ctx->inputs->at(i)[j]->schema()->field(col_index); - auto array = ctx->inputs->at(i)[j]->column(col_index); - arrow::Datum incremented_datum(1); + auto field = ctx->inputs[i][j]->schema()->field(col_index); + auto array = ctx->inputs[i][j]->column(col_index); + SERVING_ENFORCE(ADD_DATUM != std::numeric_limits::max(), + serving::errors::INVALID_ARGUMENT, + "add datum: {} is too large", ADD_DATUM); + arrow::Datum incremented_datum(ADD_DATUM); SERVING_GET_ARROW_RESULT( arrow::compute::Add(incremented_datum, array), incremented_datum); SERVING_GET_ARROW_RESULT( - ctx->inputs->at(i)[j]->SetColumn( + ctx->inputs[i][j]->SetColumn( col_index, field, std::move(incremented_datum).make_array()), - ctx->inputs->at(i)[j]); + ctx->inputs[i][j]); + } + } + } + ctx->output = ctx->inputs.front()[0]; + } + + void BuildInputSchema() override {} + void BuildOutputSchema() override {} +}; + +template +class EdgeReduceOpKernel : public OpKernel { + public: + explicit EdgeReduceOpKernel(OpKernelOptions opts) + : OpKernel(std::move(opts)) { + auto schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + input_schema_list_ = + std::vector>(INPUT_EDGE_COUNT, schema); + output_schema_ = schema; + } + + void DoCompute(ComputeContext* ctx) override { + auto res = ctx->inputs.front(); + for (size_t i = 1; i < ctx->inputs.size(); ++i) { + for (size_t j = 0; j < ctx->inputs[i].size(); ++j) { + for (int col_index = 0; col_index < ctx->inputs[i][j]->num_columns(); + ++col_index) { + // add index num to every item. 
+ auto field = ctx->inputs[i][j]->schema()->field(col_index); + auto array = ctx->inputs[i][j]->column(col_index); + arrow::Datum incremented_datum; + SERVING_GET_ARROW_RESULT( + arrow::compute::Add(arrow::Datum(res[j]->column(col_index)), + array), + incremented_datum); + SERVING_GET_ARROW_RESULT( + res[j]->SetColumn(col_index, field, + std::move(incremented_datum).make_array()), + res[j]); } } } - ctx->output = ctx->inputs->front()[0]; + ctx->output = res[0]; } void BuildInputSchema() override {} void BuildOutputSchema() override {} }; -class MockOpKernel1 : public OpKernel { +class AddToStrOpKernel : public OpKernel { public: - explicit MockOpKernel1(OpKernelOptions opts) : OpKernel(std::move(opts)) { + explicit AddToStrOpKernel(OpKernelOptions opts) : OpKernel(std::move(opts)) { auto schema = arrow::schema({arrow::field("test_field_0", arrow::float64())}); input_schema_list_ = {schema, schema}; @@ -78,15 +133,14 @@ class MockOpKernel1 : public OpKernel { } // input0-array0 + input1-arry1 then cast to string array - void Compute(ComputeContext* ctx) override { - SERVING_ENFORCE(ctx->inputs->size() == 2, errors::ErrorCode::LOGIC_ERROR); - SERVING_ENFORCE(ctx->inputs->front().size() == 1, - errors::ErrorCode::LOGIC_ERROR); - SERVING_ENFORCE(ctx->inputs->at(1).size() == 1, + void DoCompute(ComputeContext* ctx) override { + SERVING_ENFORCE(ctx->inputs.size() == 2, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs[1].size() == 1, errors::ErrorCode::LOGIC_ERROR); - auto array_0 = ctx->inputs->front()[0]->column(0); - auto array_1 = ctx->inputs->at(1)[0]->column(0); + auto array_0 = ctx->inputs.front()[0]->column(0); + auto array_1 = ctx->inputs[1][0]->column(0); arrow::Datum incremented_datum; SERVING_GET_ARROW_RESULT(arrow::compute::Add(array_0, array_1), @@ -111,47 +165,298 @@ class MockOpKernel1 : public OpKernel { void BuildInputSchema() override {} void 
BuildOutputSchema() override {} }; - -REGISTER_OP_KERNEL(TEST_OP_0, MockOpKernel0); -REGISTER_OP_KERNEL(TEST_OP_1, MockOpKernel1); -REGISTER_OP(TEST_OP_0, "0.0.1", "test_desc") +using AddEleOpKernelOneEdgeAddMax = + AddEleOpKernel<1, std::numeric_limits::max()>; + +REGISTER_TEMPLATE_OP_KERNEL(TEST_OP_ONE_EDGE_ADD_ONE, AddEleOpKernel<1>, + AddEleOpKernel_1_1); +REGISTER_TEMPLATE_OP_KERNEL(TEST_OP_ONE_EDGE_ADD_MAX, + AddEleOpKernelOneEdgeAddMax, + AddEleOpKernel_1_MAX_Exception); +REGISTER_OP_KERNEL(TEST_OP_TWO_EDGE_ADD_TO_STR, AddToStrOpKernel); +REGISTER_TEMPLATE_OP_KERNEL(TEST_OP_REDUCE_MASSIVE_CNT, + EdgeReduceOpKernel, + EdgeReduceOpKernel_MASSIVE_NODE_CNT); +REGISTER_TEMPLATE_OP_KERNEL(TEST_OP_REDUCE_2, EdgeReduceOpKernel<2>, + EdgeReduceOpKernel_2); +REGISTER_TEMPLATE_OP_KERNEL(TEST_OP_REDUCE_COMPLEX_MASSIVE_CNT, + EdgeReduceOpKernel, + EdgeReduceOpKernel_COMPLEX_MASSIVE_CNT); +REGISTER_OP(TEST_OP_ONE_EDGE_ADD_ONE, "0.0.1", "test_desc") .StringAttr("attr_s", "attr_s_desc", false, false) .Input("input", "input_desc") .Output("output", "output_desc"); -REGISTER_OP(TEST_OP_1, "0.0.1", "test_desc") +REGISTER_OP(TEST_OP_ONE_EDGE_ADD_MAX, "0.0.1", "test_desc") + .StringAttr("attr_s", "attr_s_desc", false, false) + .Input("input", "input_desc") + .Output("output", "output_desc"); +REGISTER_OP(TEST_OP_TWO_EDGE_ADD_TO_STR, "0.0.1", "test_desc") .Returnable() .StringAttr("attr_s", "attr_s_desc", false, false) .Input("input_0", "input_desc") .Input("input_1", "input_desc") .Output("output", "output_desc"); +REGISTER_OP(TEST_OP_REDUCE_MASSIVE_CNT, "0.0.1", "test_desc") + .Returnable() + .StringAttr("attr_s", "attr_s_desc", false, false) + .InputList("input", MASSIVE_NODE_CNT, "input_desc") + .Output("output", "output_desc"); + +REGISTER_OP(TEST_OP_REDUCE_2, "0.0.1", "test_desc") + .Returnable() + .StringAttr("attr_s", "attr_s_desc", false, false) + .InputList("input", 2, "input_desc") + .Output("output", "output_desc"); +REGISTER_OP(TEST_OP_REDUCE_COMPLEX_MASSIVE_CNT, 
"0.0.1", "test_desc") + .Returnable() + .StringAttr("attr_s", "attr_s_desc", false, false) + .InputList("input", MASSIVE_NODE_CNT * 3 / 2, "input_desc") + .Output("output", "output_desc"); + } // namespace op -TEST_F(ExecutorTest, Works) { +std::string MakeNodeDefJson(const std::string& name, const std::string& op_name, + const std::vector& parents = {}) { + std::string ret = + R"( { "name": ")" + name + R"(", "op": ")" + op_name + R"(", )"; + if (!parents.empty()) { + ret += R"("parents": [ )"; + for (const auto& p : parents) { + ret += '"' + p + '"' + ','; + } + ret.back() = ']'; + } + ret += "}"; + return ret; +} + +TEST_F(ExecutorTest, MassiveWorks) { + std::vector node_def_jsons; + std::string node_list_str; + node_def_jsons.emplace_back( + MakeNodeDefJson("node_a", "TEST_OP_ONE_EDGE_ADD_ONE")); + node_list_str += R"("node_a",)"; + + std::vector last_level_parents; + for (auto i = 0; i != MASSIVE_NODE_CNT; ++i) { + std::string node_name = "node_b_" + std::to_string(i); + node_list_str += '"' + node_name + '"' + ','; + + node_def_jsons.emplace_back( + MakeNodeDefJson(node_name, "TEST_OP_ONE_EDGE_ADD_ONE", {"node_a"})); + last_level_parents.emplace_back(std::move(node_name)); + } + node_def_jsons.emplace_back(MakeNodeDefJson( + "node_c", "TEST_OP_REDUCE_MASSIVE_CNT", last_level_parents)); + node_list_str += R"("node_c")"; + + std::string execution_def_json = + R"({"nodes": [)" + node_list_str + + R"(],"config": {"dispatch_type": "DP_ALL"} })"; + + // build node + std::unordered_map> nodes; + for (const auto& j : node_def_jsons) { + NodeDef node_def; + JsonToPb(j, &node_def); + auto node = std::make_shared(std::move(node_def)); + nodes.emplace(node->GetName(), node); + } + // build edge + for (const auto& pair : nodes) { + const auto& input_nodes = pair.second->GetInputNodeNames(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto n_iter = nodes.find(input_nodes[i]); + SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); + auto edge = 
std::make_shared(n_iter->first, pair.first, i); + n_iter->second->AddOutEdge(edge); + pair.second->AddInEdge(edge); + } + } + + ExecutionDef executino_def; + JsonToPb(execution_def_json, &executino_def); + + auto execution = std::make_shared(0, std::move(executino_def), + std::move(nodes)); + auto executor = std::make_shared(execution); + + auto inputs = + std::unordered_map>(); + { + // mock input + auto input_schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + std::shared_ptr array_0; + arrow::DoubleBuilder double_builder; + SERVING_CHECK_ARROW_STATUS(double_builder.AppendValues({1, 2, 3, 4})); + SERVING_CHECK_ARROW_STATUS(double_builder.Finish(&array_0)); + double_builder.Reset(); + auto input_0 = MakeRecordBatch(input_schema, 4, {array_0}); + + auto op_inputs_0 = std::make_shared(); + std::vector> r_list_0 = {input_0}; + op_inputs_0->emplace_back(r_list_0); + + inputs.emplace("node_a", op_inputs_0); + } + // run + auto output = executor->Run(inputs); + + // build expect + auto expect_output_schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + std::shared_ptr expect_array; + arrow::DoubleBuilder array_builder; + SERVING_CHECK_ARROW_STATUS( + array_builder.AppendValues({3 * MASSIVE_NODE_CNT, 4 * MASSIVE_NODE_CNT, + 5 * MASSIVE_NODE_CNT, 6 * MASSIVE_NODE_CNT})); + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&expect_array)); + + EXPECT_EQ(output->size(), 1); + EXPECT_EQ(output->at(0).node_name, "node_c"); + EXPECT_EQ(output->at(0).table->num_columns(), 1); + EXPECT_TRUE(output->at(0).table->schema()->Equals(expect_output_schema)); + + std::cout << output->at(0).table->column(0)->ToString() << std::endl; + std::cout << expect_array->ToString() << std::endl; + + EXPECT_TRUE(output->at(0).table->column(0)->Equals(expect_array)); +} + +TEST_F(ExecutorTest, ComplexMassiveWorks) { + std::vector node_def_jsons; + std::string node_list_str; + node_def_jsons.emplace_back( + MakeNodeDefJson("node_a", 
"TEST_OP_ONE_EDGE_ADD_ONE")); + node_list_str += R"("node_a",)"; + + std::vector last_level_parents; + for (auto i = 0; i != MASSIVE_NODE_CNT; ++i) { + std::string node_name = "node_b_" + std::to_string(i); + node_list_str += '"' + node_name + '"' + ','; + + node_def_jsons.emplace_back( + MakeNodeDefJson(node_name, "TEST_OP_ONE_EDGE_ADD_ONE", {"node_a"})); + last_level_parents.emplace_back(std::move(node_name)); + } + + unsigned node_c_count = MASSIVE_NODE_CNT / 2; + for (unsigned i = 0; i != node_c_count; ++i) { + std::string node_name = "node_c_" + std::to_string(i); + node_list_str += '"' + node_name + '"' + ','; + + node_def_jsons.emplace_back( + MakeNodeDefJson(node_name, "TEST_OP_REDUCE_2", + {"node_b_" + std::to_string(i * 2), + "node_b_" + std::to_string(i * 2 + 1)})); + last_level_parents.emplace_back(std::move(node_name)); + } + + node_def_jsons.emplace_back(MakeNodeDefJson( + "node_d", "TEST_OP_REDUCE_COMPLEX_MASSIVE_CNT", last_level_parents)); + node_list_str += R"("node_d")"; + + std::string execution_def_json = + R"({"nodes": [)" + node_list_str + + R"(],"config": {"dispatch_type": "DP_ALL"} })"; + + // build node + std::unordered_map> nodes; + for (const auto& j : node_def_jsons) { + NodeDef node_def; + JsonToPb(j, &node_def); + auto node = std::make_shared(std::move(node_def)); + nodes.emplace(node->GetName(), node); + } + // build edge + for (const auto& pair : nodes) { + const auto& input_nodes = pair.second->GetInputNodeNames(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto n_iter = nodes.find(input_nodes[i]); + SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); + auto edge = std::make_shared(n_iter->first, pair.first, i); + n_iter->second->AddOutEdge(edge); + pair.second->AddInEdge(edge); + } + } + + ExecutionDef executino_def; + JsonToPb(execution_def_json, &executino_def); + + auto execution = std::make_shared(0, std::move(executino_def), + std::move(nodes)); + auto executor = std::make_shared(execution); + + 
auto inputs = + std::unordered_map>(); + { + // mock input + auto input_schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + std::shared_ptr array_0; + arrow::DoubleBuilder double_builder; + SERVING_CHECK_ARROW_STATUS(double_builder.AppendValues({1, 2, 3, 4})); + SERVING_CHECK_ARROW_STATUS(double_builder.Finish(&array_0)); + double_builder.Reset(); + auto input_0 = MakeRecordBatch(input_schema, 4, {array_0}); + + auto op_inputs_0 = std::make_shared(); + std::vector> r_list_0 = {input_0}; + op_inputs_0->emplace_back(r_list_0); + + inputs.emplace("node_a", op_inputs_0); + } + // run + auto output = executor->Run(inputs); + + // build expect + auto expect_output_schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + std::shared_ptr expect_array; + arrow::DoubleBuilder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.AppendValues( + {3 * MASSIVE_NODE_CNT * 2, 4 * MASSIVE_NODE_CNT * 2, + 5 * MASSIVE_NODE_CNT * 2, 6 * MASSIVE_NODE_CNT * 2})); + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&expect_array)); + + EXPECT_EQ(output->size(), 1); + EXPECT_EQ(output->at(0).node_name, "node_d"); + EXPECT_EQ(output->at(0).table->num_columns(), 1); + EXPECT_TRUE(output->at(0).table->schema()->Equals(expect_output_schema)); + + std::cout << output->at(0).table->column(0)->ToString() << std::endl; + std::cout << expect_array->ToString() << std::endl; + + EXPECT_TRUE(output->at(0).table->column(0)->Equals(expect_array)); +} + +TEST_F(ExecutorTest, BasicWorks) { std::vector node_def_jsons = { R"JSON( { "name": "node_a", - "op": "TEST_OP_0", + "op": "TEST_OP_ONE_EDGE_ADD_ONE", } )JSON", R"JSON( { "name": "node_b", - "op": "TEST_OP_0", + "op": "TEST_OP_ONE_EDGE_ADD_ONE", } )JSON", R"JSON( { "name": "node_c", - "op": "TEST_OP_0", + "op": "TEST_OP_ONE_EDGE_ADD_ONE", "parents": [ "node_a" ], } )JSON", R"JSON( { "name": "node_d", - "op": "TEST_OP_1", + "op": "TEST_OP_TWO_EDGE_ADD_TO_STR", "parents": [ "node_b", "node_c" ], } )JSON"}; 
@@ -168,7 +473,7 @@ TEST_F(ExecutorTest, Works) { )JSON"; // build node - std::map> nodes; + std::unordered_map> nodes; for (const auto& j : node_def_jsons) { NodeDef node_def; JsonToPb(j, &node_def); @@ -182,7 +487,7 @@ TEST_F(ExecutorTest, Works) { auto n_iter = nodes.find(input_nodes[i]); SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); auto edge = std::make_shared(n_iter->first, pair.first, i); - n_iter->second->SetOutEdge(edge); + n_iter->second->AddOutEdge(edge); pair.second->AddInEdge(edge); } } @@ -216,10 +521,10 @@ TEST_F(ExecutorTest, Works) { std::vector> r_list_1 = {input_1}; op_inputs_1->emplace_back(r_list_1); - auto inputs = std::make_shared< - std::map>>(); - inputs->emplace("node_a", op_inputs_0); - inputs->emplace("node_b", op_inputs_1); + auto inputs = + std::unordered_map>(); + inputs.emplace("node_a", op_inputs_0); + inputs.emplace("node_b", op_inputs_1); // run auto output = executor->Run(inputs); @@ -244,6 +549,207 @@ TEST_F(ExecutorTest, Works) { EXPECT_TRUE(output->at(0).table->column(0)->Equals(expect_array)); } -// TODO: exception case +TEST_F(ExecutorTest, ExceptionWorks) { + std::vector node_def_jsons = { + R"JSON( +{ + "name": "node_a", + "op": "TEST_OP_ONE_EDGE_ADD_ONE", +} +)JSON", + R"JSON( +{ + "name": "node_b", + "op": "TEST_OP_ONE_EDGE_ADD_ONE", +} +)JSON", + R"JSON( +{ + "name": "node_c", + "op": "TEST_OP_ONE_EDGE_ADD_MAX", + "parents": [ "node_a" ], +} +)JSON", + R"JSON( +{ + "name": "node_d", + "op": "TEST_OP_TWO_EDGE_ADD_TO_STR", + "parents": [ "node_b", "node_c" ], +} +)JSON"}; + + std::string execution_def_json = R"JSON( +{ + "nodes": [ + "node_a", "node_b", "node_c", "node_d" + ], + "config": { + "dispatch_type": "DP_ALL" + } +} +)JSON"; + + // build node + std::unordered_map> nodes; + for (const auto& j : node_def_jsons) { + NodeDef node_def; + JsonToPb(j, &node_def); + auto node = std::make_shared(std::move(node_def)); + nodes.emplace(node->GetName(), node); + } + // build edge + for (const auto& 
pair : nodes) { + const auto& input_nodes = pair.second->GetInputNodeNames(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto n_iter = nodes.find(input_nodes[i]); + SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); + auto edge = std::make_shared(n_iter->first, pair.first, i); + n_iter->second->AddOutEdge(edge); + pair.second->AddInEdge(edge); + } + } + + ExecutionDef executino_def; + JsonToPb(execution_def_json, &executino_def); + + auto execution = std::make_shared(0, std::move(executino_def), + std::move(nodes)); + auto executor = std::make_shared(execution); + + // mock input + auto input_schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + std::shared_ptr array_0; + std::shared_ptr array_1; + arrow::DoubleBuilder double_builder; + SERVING_CHECK_ARROW_STATUS(double_builder.AppendValues({1, 2, 3, 4})); + SERVING_CHECK_ARROW_STATUS(double_builder.Finish(&array_0)); + double_builder.Reset(); + SERVING_CHECK_ARROW_STATUS(double_builder.AppendValues({11, 22, 33, 44})); + SERVING_CHECK_ARROW_STATUS(double_builder.Finish(&array_1)); + double_builder.Reset(); + auto input_0 = MakeRecordBatch(input_schema, 4, {array_0}); + auto input_1 = MakeRecordBatch(input_schema, 4, {array_1}); + + auto op_inputs_0 = std::make_shared(); + std::vector> r_list_0 = {input_0}; + op_inputs_0->emplace_back(r_list_0); + auto op_inputs_1 = std::make_shared(); + std::vector> r_list_1 = {input_1}; + op_inputs_1->emplace_back(r_list_1); + + auto inputs = + std::unordered_map>(); + inputs.emplace("node_a", op_inputs_0); + inputs.emplace("node_b", op_inputs_1); + + // run + EXPECT_THROW(executor->Run(inputs), ::secretflow::serving::Exception); + + // expect + EXPECT_EQ(ThreadPool::GetInstance()->GetTaskSize(), 0); +} + +TEST_F(ExecutorTest, ExceptionComplexMassiveWorks) { + std::vector node_def_jsons; + std::string node_list_str; + node_def_jsons.emplace_back( + MakeNodeDefJson("node_a", "TEST_OP_ONE_EDGE_ADD_ONE")); + node_list_str += 
R"("node_a",)"; + + std::vector last_level_parents; + for (auto i = 0; i != MASSIVE_NODE_CNT - 1; ++i) { + std::string node_name = "node_b_" + std::to_string(i); + node_list_str += '"' + node_name + '"' + ','; + + node_def_jsons.emplace_back( + MakeNodeDefJson(node_name, "TEST_OP_ONE_EDGE_ADD_ONE", {"node_a"})); + last_level_parents.emplace_back(std::move(node_name)); + } + + // exception node + std::string node_name = "node_b_" + std::to_string(MASSIVE_NODE_CNT - 1); + node_list_str += '"' + node_name + '"' + ','; + node_def_jsons.emplace_back( + MakeNodeDefJson(node_name, "TEST_OP_ONE_EDGE_ADD_MAX", {"node_a"})); + last_level_parents.emplace_back(std::move(node_name)); + + unsigned node_c_count = MASSIVE_NODE_CNT / 2; + for (unsigned i = 0; i != node_c_count; ++i) { + std::string node_name = "node_c_" + std::to_string(i); + node_list_str += '"' + node_name + '"' + ','; + + node_def_jsons.emplace_back( + MakeNodeDefJson(node_name, "TEST_OP_REDUCE_2", + {"node_b_" + std::to_string(i * 2), + "node_b_" + std::to_string(i * 2 + 1)})); + last_level_parents.emplace_back(std::move(node_name)); + } + + node_def_jsons.emplace_back(MakeNodeDefJson( + "node_d", "TEST_OP_REDUCE_COMPLEX_MASSIVE_CNT", last_level_parents)); + node_list_str += R"("node_d")"; + + std::string execution_def_json = + R"({"nodes": [)" + node_list_str + + R"(],"config": {"dispatch_type": "DP_ALL"} })"; + + // build node + std::unordered_map> nodes; + for (const auto& j : node_def_jsons) { + NodeDef node_def; + JsonToPb(j, &node_def); + auto node = std::make_shared(std::move(node_def)); + nodes.emplace(node->GetName(), node); + } + // build edge + for (const auto& pair : nodes) { + const auto& input_nodes = pair.second->GetInputNodeNames(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto n_iter = nodes.find(input_nodes[i]); + SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); + auto edge = std::make_shared(n_iter->first, pair.first, i); + n_iter->second->AddOutEdge(edge); 
+ pair.second->AddInEdge(edge); + } + } + + ExecutionDef executino_def; + JsonToPb(execution_def_json, &executino_def); + + auto execution = std::make_shared(0, std::move(executino_def), + std::move(nodes)); + auto executor = std::make_shared(execution); + + auto inputs = + std::unordered_map>(); + { + // mock input + auto input_schema = + arrow::schema({arrow::field("test_field_0", arrow::float64())}); + std::shared_ptr array_0; + arrow::DoubleBuilder double_builder; + SERVING_CHECK_ARROW_STATUS(double_builder.AppendValues({1, 2, 3, 4})); + SERVING_CHECK_ARROW_STATUS(double_builder.Finish(&array_0)); + double_builder.Reset(); + auto input_0 = MakeRecordBatch(input_schema, 4, {array_0}); + + auto op_inputs_0 = std::make_shared(); + std::vector> r_list_0 = {input_0}; + op_inputs_0->emplace_back(r_list_0); + + inputs.emplace("node_a", op_inputs_0); + } + + // run + EXPECT_THROW(executor->Run(inputs), ::secretflow::serving::Exception); + + // wait for thread pool to pop remain tasks + executor.reset(); + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // expect + EXPECT_EQ(ThreadPool::GetInstance()->GetTaskSize(), 0); +} } // namespace secretflow::serving diff --git a/secretflow_serving/framework/loader.h b/secretflow_serving/framework/loader.h index 011695c..223844e 100644 --- a/secretflow_serving/framework/loader.h +++ b/secretflow_serving/framework/loader.h @@ -17,32 +17,21 @@ #include #include -#include "secretflow_serving/core/exception.h" -#include "secretflow_serving/framework/executable.h" -#include "secretflow_serving/framework/predictor.h" +#include "secretflow_serving/protos/bundle.pb.h" namespace secretflow::serving { class Loader { public: - struct Options { - std::string party_id; - }; - - public: - Loader(const Options& opts) : opts_(opts) { - SERVING_ENFORCE(!opts_.party_id.empty(), errors::ErrorCode::LOGIC_ERROR); - } + explicit Loader() = default; virtual ~Loader() = default; virtual void Load(const std::string& file_path) = 0; - virtual 
std::shared_ptr GetExecutable() = 0; - - virtual std::shared_ptr GetPredictor() = 0; + std::shared_ptr GetModelBundle() { return model_bundle_; } protected: - Options opts_; + std::shared_ptr model_bundle_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/framework/model_info_collector.cc b/secretflow_serving/framework/model_info_collector.cc new file mode 100644 index 0000000..0def48b --- /dev/null +++ b/secretflow_serving/framework/model_info_collector.cc @@ -0,0 +1,244 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/framework/model_info_collector.h" + +#include + +#include "spdlog/spdlog.h" + +#include "secretflow_serving/util/utils.h" + +#include "secretflow_serving/apis/model_service.pb.h" + +namespace secretflow::serving { + +namespace { + +using std::invoke_result_t; + +class RetryRunner { + public: + RetryRunner(uint32_t retry_counts, uint32_t retry_interval_ms) + : retry_counts_(retry_counts), retry_interval_ms_(retry_interval_ms) {} + + template >>> + bool Run(Func&& f, Args&&... 
args) const { + auto runner_func = [&] { + return std::invoke(std::forward(f), std::forward(args)...); + }; + for (uint32_t i = 0; i != retry_counts_; ++i) { + if (!runner_func()) { + std::this_thread::sleep_for( + std::chrono::milliseconds(retry_interval_ms_)); + } else { + return true; + } + } + return false; + } + + private: + uint32_t retry_counts_; + uint32_t retry_interval_ms_; +}; + +} // namespace + +ModelInfoCollector::ModelInfoCollector(Options opts) : opts_(std::move(opts)) { + // build model_info_ + model_info_.set_name(opts_.model_bundle->name()); + model_info_.set_desc(opts_.model_bundle->desc()); + auto* graph_view = model_info_.mutable_graph_view(); + graph_view->set_version(opts_.model_bundle->graph().version()); + for (const auto& node : opts_.model_bundle->graph().node_list()) { + NodeView view; + view.set_name(node.name()); + view.set_op(node.op()); + view.set_op_version(node.op_version()); + *(view.mutable_parents()) = node.parents(); + + graph_view->mutable_node_list()->Add(std::move(view)); + } + *(graph_view->mutable_execution_list()) = + opts_.model_bundle->graph().execution_list(); + + // build specific_party_map_ + auto execution_list = graph_view->execution_list(); + auto execution_list_size = execution_list.size(); + for (int i = 0; i != execution_list_size; ++i) { + auto& execution = execution_list[i]; + if (execution.config().dispatch_type() == DispatchType::DP_SPECIFIED && + !execution.config().specific_flag()) { + specific_party_map_[i] = std::string(); + } + } + + // build local_node_views_ + const auto& node_list = graph_view->node_list(); + for (const auto& node_view : node_list) { + local_node_views_[node_view.name()] = std::make_pair( + node_view, std::set(node_view.parents().begin(), + node_view.parents().end())); + } +} + +void ModelInfoCollector::DoCollect() { + RetryRunner runner(max_retry_cnt_, retry_interval_ms_); + for (auto& [remote_party_id, channel] : *(opts_.remote_channel_map)) { + 
SERVING_ENFORCE(runner.Run(&ModelInfoCollector::TryCollect, *this, + remote_party_id, channel), + serving::errors::LOGIC_ERROR, + "GetModelInfo from {} failed.", remote_party_id); + } + + CheckAndSetSpecificMap(); +} + +bool ModelInfoCollector::TryCollect( + const std::string& remote_party_id, + const std::shared_ptr<::google::protobuf::RpcChannel>& channel) { + brpc::Controller cntl; + apis::GetModelInfoResponse response; + apis::GetModelInfoRequest request; + request.mutable_service_spec()->set_id(opts_.service_id); + + apis::ModelService_Stub stub(channel.get()); + stub.GetModelInfo(&cntl, &request, &response, nullptr); + + if (cntl.Failed()) { + SPDLOG_WARN( + "call ({}) from ({}) GetModelInfo failed, msg:{}, may need retry", + remote_party_id, opts_.self_party_id, cntl.ErrorText()); + return false; + } + if (!CheckStatusOk(response.status())) { + SPDLOG_WARN( + "call ({}) from ({}) GetModelInfo failed, msg:{}, may need retry", + remote_party_id, opts_.self_party_id, response.status().msg()); + return false; + } + model_info_map_[remote_party_id] = response.model_info(); + return true; +} + +void ModelInfoCollector::CheckNodeViewList( + const ::google::protobuf::RepeatedPtrField<::secretflow::serving::NodeView>& + remote_node_views, + const std::string& remote_party_id) { + SERVING_ENFORCE_EQ( + local_node_views_.size(), static_cast(remote_node_views.size()), + "node views size is not equal, {} : {}, {} : {}", opts_.self_party_id, + local_node_views_.size(), remote_party_id, remote_node_views.size()); + for (const auto& remote_node_view : remote_node_views) { + auto iter = local_node_views_.find(remote_node_view.name()); + SERVING_ENFORCE(iter != local_node_views_.end(), + serving::errors::LOGIC_ERROR, + "can't find node view {} from {}", remote_node_view.name(), + remote_party_id); + auto& [local_node_view, local_node_parents] = iter->second; + SERVING_ENFORCE_EQ(local_node_view.op(), remote_node_view.op(), + "node view {} op name is not equal, {} : {}, {} : 
{}", + remote_node_view.name(), opts_.self_party_id, + local_node_view.op(), remote_party_id, + remote_node_view.op()); + SERVING_ENFORCE_EQ(local_node_view.op_version(), + remote_node_view.op_version(), + "node view {} op version is not equal, {} : {}, {} : {}", + remote_node_view.name(), opts_.self_party_id, + local_node_view.op_version(), remote_party_id, + remote_node_view.op_version()); + auto remote_node_parents = std::set( + remote_node_view.parents().begin(), remote_node_view.parents().end()); + SERVING_ENFORCE( + local_node_parents == remote_node_parents, serving::errors::LOGIC_ERROR, + "node view {} op parents is not equal, {} : {}, {} : {}", + remote_node_view.name(), opts_.self_party_id, + fmt::join(local_node_parents.begin(), local_node_parents.end(), ","), + remote_party_id, + fmt::join(remote_node_parents.begin(), remote_node_parents.end(), ",")); + } +} + +void ModelInfoCollector::CheckAndSetSpecificMap() { + const auto& local_graph_view = model_info_.graph_view(); + + for (auto& [remote_party_id, model_info] : model_info_map_) { + SERVING_ENFORCE_EQ(model_info.name(), model_info_.name(), + "model name mismatch with {}: {}, local: {}: {}", + remote_party_id, model_info.name(), opts_.self_party_id, + model_info.name()); + + const auto& graph_view = model_info.graph_view(); + SERVING_ENFORCE_EQ(graph_view.version(), local_graph_view.version(), + "version mismatch with {}: {}, local: {}: {}", + remote_party_id, graph_view.version(), + opts_.self_party_id, local_graph_view.version()); + SERVING_ENFORCE_EQ( + local_graph_view.execution_list_size(), + graph_view.execution_list_size(), + "execution list size mismatch with {}: {}, local: {}: {}", + remote_party_id, graph_view.execution_list_size(), opts_.self_party_id, + local_graph_view.execution_list_size()); + + CheckNodeViewList(graph_view.node_list(), remote_party_id); + + for (int i = 0; i != local_graph_view.execution_list_size(); ++i) { + const auto& local_execution = 
local_graph_view.execution_list(i); + const auto& remote_execution = graph_view.execution_list(i); + + SERVING_ENFORCE_EQ(remote_execution.nodes_size(), + local_execution.nodes_size(), + "node count mismatch: {}: {}, local: {}: {}", + remote_party_id, remote_execution.nodes_size(), + opts_.self_party_id, local_execution.nodes_size()); + + SERVING_ENFORCE( + remote_execution.config().dispatch_type() == + local_execution.config().dispatch_type(), + serving::errors::LOGIC_ERROR, + "node count mismatch: {}: {}, local: {}: {}", remote_party_id, + DispatchType_Name(remote_execution.config().dispatch_type()), + opts_.self_party_id, + DispatchType_Name(local_execution.config().dispatch_type())); + + for (int j = 0; j != local_execution.nodes_size(); ++j) { + SERVING_ENFORCE(remote_execution.nodes(j) == local_execution.nodes(j), + serving::errors::LOGIC_ERROR, + "node name mismatch: {}: {}, local: {}: {}", + remote_party_id, remote_execution.nodes(j), + opts_.self_party_id, local_execution.nodes(j)); + } + + if (remote_execution.config().dispatch_type() == + DispatchType::DP_SPECIFIED) { + if (remote_execution.config().specific_flag()) { + SERVING_ENFORCE(specific_party_map_[i].empty(), + serving::errors::LOGIC_ERROR, + "{} execution specific to multiple parties", i); + specific_party_map_[i] = remote_party_id; + } + } + } + } + + for (auto& [id, party_id] : specific_party_map_) { + SERVING_ENFORCE(!party_id.empty(), serving::errors::LOGIC_ERROR, + "{} execution specific to no party", id); + } +} + +} // namespace secretflow::serving diff --git a/secretflow_serving/framework/model_info_collector.h b/secretflow_serving/framework/model_info_collector.h new file mode 100644 index 0000000..70f8af3 --- /dev/null +++ b/secretflow_serving/framework/model_info_collector.h @@ -0,0 +1,88 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "brpc/controller.h" + +#include "secretflow_serving/core/exception.h" + +#include "secretflow_serving/protos/bundle.pb.h" + +namespace secretflow::serving { + +class ModelInfoCollector { + public: + struct Options { + std::string self_party_id; + std::string service_id; + + std::shared_ptr model_bundle; + + std::shared_ptr< + std::map>> + remote_channel_map; + }; + + public: + explicit ModelInfoCollector(Options opts); + + void DoCollect(); + + const ModelInfo& GetSelfModelInfo() { return model_info_; } + + void SetRetryCounts(uint32_t max_retry_cnt) { + max_retry_cnt_ = max_retry_cnt; + } + void SetRetryIntervalMs(uint32_t retry_interval_ms) { + retry_interval_ms_ = retry_interval_ms; + } + + std::unordered_map GetSpecificMap() const { + return specific_party_map_; + } + + private: + bool TryCollect( + const std::string& remote_party_id, + const std::shared_ptr<::google::protobuf::RpcChannel>& channel); + + void CheckAndSetSpecificMap(); + + void CheckNodeViewList( + const ::google::protobuf::RepeatedPtrField< + ::secretflow::serving::NodeView>& remote_node_views, + const std::string& remote_party_id); + + private: + Options opts_; + + ModelInfo model_info_; + + // key: execution_id(index), value: party_id + std::unordered_map specific_party_map_; + + // key: node name, value: + std::unordered_map>> + local_node_views_; + + std::unordered_map model_info_map_; + + uint32_t max_retry_cnt_{60}; + uint32_t retry_interval_ms_{5000}; +}; + +} // namespace secretflow::serving diff --git 
a/secretflow_serving/framework/model_loader.cc b/secretflow_serving/framework/model_loader.cc index 9421e62..44f91fd 100644 --- a/secretflow_serving/framework/model_loader.cc +++ b/secretflow_serving/framework/model_loader.cc @@ -14,29 +14,24 @@ #include "secretflow_serving/framework/model_loader.h" +#include #include #include +#include #include "spdlog/spdlog.h" -#include "secretflow_serving/framework/executable_impl.h" -#include "secretflow_serving/framework/executor.h" -#include "secretflow_serving/framework/predictor_impl.h" -#include "secretflow_serving/ops/graph.h" +#include "secretflow_serving/core/exception.h" #include "secretflow_serving/util/sys_util.h" #include "secretflow_serving/util/utils.h" -#include "secretflow_serving/protos/bundle.pb.h" - namespace secretflow::serving { namespace { + const std::string kManifestFileName = "MANIFEST"; -} -ModelLoader::ModelLoader(const Options& opts, - std::shared_ptr channels) - : Loader(opts), channels_(channels) {} +} // namespace void ModelLoader::Load(const std::string& file_path) { SPDLOG_INFO("begin load file: {}", file_path); @@ -72,34 +67,21 @@ void ModelLoader::Load(const std::string& file_path) { auto model_file_path = model_dir.append(manifest.bundle_path()); - ModelBundle model_pb; + auto model_bundle = std::make_shared(); if (manifest.bundle_format() == FileFormatType::FF_PB) { - LoadPbFromBinaryFile(model_file_path.string(), &model_pb); + LoadPbFromBinaryFile(model_file_path.string(), model_bundle.get()); } else if (manifest.bundle_format() == FileFormatType::FF_JSON) { - LoadPbFromJsonFile(model_file_path.string(), &model_pb); + LoadPbFromJsonFile(model_file_path.string(), model_bundle.get()); } else { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "found unkonwn bundle_format:{}", + "found unknown bundle_format:{}", FileFormatType_Name(manifest.bundle_format())); } + model_bundle_ = std::move(model_bundle); - SPDLOG_INFO("load model bundle:{} desc:{} graph version:{}", model_pb.name(), - 
model_pb.desc(), model_pb.graph().version()); - - auto graph = std::make_unique(model_pb.graph()); - const auto& executions = graph->GetExecutions(); - - std::vector> executors; - for (const auto& e : executions) { - executors.emplace_back(std::make_shared(e)); - } - executable_ = std::make_shared(std::move(executors)); - - Predictor::Options predictor_opts; - predictor_opts.party_id = opts_.party_id; - predictor_opts.channels = channels_; - predictor_opts.executions = executions; - predictor_ = std::make_shared(std::move(predictor_opts)); + SPDLOG_INFO("end load model bundle, name: {}, desc: {}, graph version: {}", + model_bundle_->name(), model_bundle_->desc(), + model_bundle_->graph().version()); } } // namespace secretflow::serving diff --git a/secretflow_serving/framework/model_loader.h b/secretflow_serving/framework/model_loader.h index fc434c8..8a46e10 100644 --- a/secretflow_serving/framework/model_loader.h +++ b/secretflow_serving/framework/model_loader.h @@ -20,20 +20,10 @@ namespace secretflow::serving { class ModelLoader : public Loader { public: - ModelLoader(const Options& opts, std::shared_ptr channels); - virtual ~ModelLoader() = default; + ModelLoader() = default; + ~ModelLoader() override = default; void Load(const std::string& file_path) override; - - std::shared_ptr GetExecutable() override { return executable_; } - - std::shared_ptr GetPredictor() override { return predictor_; } - - private: - std::shared_ptr channels_; - - std::shared_ptr executable_; - std::shared_ptr predictor_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/framework/predictor.cc b/secretflow_serving/framework/predictor.cc new file mode 100644 index 0000000..9d2807d --- /dev/null +++ b/secretflow_serving/framework/predictor.cc @@ -0,0 +1,165 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/framework/predictor.h" + +#include + +#include "arrow/compute/api.h" + +#include "secretflow_serving/core/exception.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving { + +Predictor::Predictor(Options opts) : opts_(std::move(opts)) {} + +void Predictor::Predict(const apis::PredictRequest* request, + apis::PredictResponse* response) { + std::unordered_map> + prev_node_io_map; + std::vector> async_running_execs; + async_running_execs.reserve(opts_.channels->size()); + + auto execute_locally = + [&](const std::shared_ptr& execution, + std::unordered_map>& + prev_io_map, + std::unordered_map>& + cur_io_map) { + // exec locally + auto local_exec = BuildLocalExecute(request, response, execution); + local_exec->SetInputs(std::move(prev_io_map)); + local_exec->Run(); + local_exec->GetOutputs(&cur_io_map); + }; + + for (const auto& e : opts_.executions) { + async_running_execs.clear(); + std::unordered_map> + new_node_io_map; + if (e->GetDispatchType() == DispatchType::DP_ALL) { + for (const auto& [party_id, channel] : *opts_.channels) { + auto ctx = BuildRemoteExecute(request, response, e, party_id, channel); + ctx->SetInputs(prev_node_io_map); + ctx->Run(); + async_running_execs.emplace_back(ctx); + } + + // exec locally + if (execution_core_) { + execute_locally(e, prev_node_io_map, new_node_io_map); + for (auto& exec : async_running_execs) { + exec->WaitToFinish(); + exec->GetOutputs(&new_node_io_map); + } + } else { + // TODO: support no 
execution core scene + SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented"); + } + + } else if (e->GetDispatchType() == DispatchType::DP_ANYONE) { + // exec locally + if (execution_core_) { + execute_locally(e, prev_node_io_map, new_node_io_map); + } else { + // TODO: support no execution core scene + SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented"); + } + } else if (e->GetDispatchType() == DispatchType::DP_SPECIFIED) { + if (e->SpecificToThis()) { + SERVING_ENFORCE(execution_core_, errors::ErrorCode::UNEXPECTED_ERROR); + execute_locally(e, prev_node_io_map, new_node_io_map); + } else { + auto iter = opts_.specific_party_map.find(e->id()); + SERVING_ENFORCE(iter != opts_.specific_party_map.end(), + serving::errors::LOGIC_ERROR, + "{} execution assign to no party", e->id()); + auto ctx = BuildRemoteExecute(request, response, e, iter->second, + opts_.channels->at(iter->second)); + ctx->SetInputs(prev_node_io_map); + ctx->Run(); + ctx->WaitToFinish(); + ctx->GetOutputs(&new_node_io_map); + } + } else { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "unsupported dispatch type: {}", + DispatchType_Name(e->GetDispatchType())); + } + prev_node_io_map.swap(new_node_io_map); + } + + DealFinalResult(prev_node_io_map, response); +} + +std::shared_ptr Predictor::BuildRemoteExecute( + const apis::PredictRequest* request, apis::PredictResponse* response, + const std::shared_ptr& execution, std::string target_id, + std::shared_ptr<::google::protobuf::RpcChannel> channel) { + return std::make_shared(request, response, execution, + target_id, opts_.party_id, channel); +} + +std::shared_ptr Predictor::BuildLocalExecute( + const apis::PredictRequest* request, apis::PredictResponse* response, + const std::shared_ptr& execution) { + return std::make_shared(request, response, execution, + opts_.party_id, opts_.party_id, + execution_core_); +} + +void Predictor::DealFinalResult( + std::unordered_map>& node_io_map, + apis::PredictResponse* 
response) { + SERVING_ENFORCE(node_io_map.size() == 1, errors::ErrorCode::LOGIC_ERROR); + auto& node_io = node_io_map.begin()->second; + SERVING_ENFORCE(node_io->ios_size() == 1, errors::ErrorCode::LOGIC_ERROR); + auto& ios = node_io->ios(0); + SERVING_ENFORCE(ios.datas_size() == 1, errors::ErrorCode::LOGIC_ERROR); + std::shared_ptr record_batch = + DeserializeRecordBatch(ios.datas(0)); + + std::vector results(record_batch->num_rows()); + for (int64_t i = 0; i != record_batch->num_rows(); ++i) { + results[i] = response->add_results(); + } + + for (int j = 0; j < record_batch->num_columns(); ++j) { + auto col = record_batch->column(j); + if (col->type_id() != arrow::Type::DOUBLE) { + arrow::Datum tmp; + SERVING_GET_ARROW_RESULT( + arrow::compute::Cast( + col, arrow::compute::CastOptions::Safe(arrow::float64())), + tmp); + col = std::move(tmp).make_array(); + } + + // index 0 is validity bitmap, real data start with 1 + const auto* data = col->data()->GetValues(1); + SERVING_ENFORCE(data, errors::ErrorCode::LOGIC_ERROR, + "found unsupported field type"); + + auto field_name = record_batch->schema()->field(j)->name(); + for (int64_t i = 0; i < record_batch->num_rows(); ++i) { + auto* score = results[i]->add_scores(); + score->set_name(field_name); + score->set_value(data[i]); + } + } +} + +} // namespace secretflow::serving diff --git a/secretflow_serving/framework/predictor.h b/secretflow_serving/framework/predictor.h index 789ea0f..64f7581 100644 --- a/secretflow_serving/framework/predictor.h +++ b/secretflow_serving/framework/predictor.h @@ -16,9 +16,12 @@ #include #include +#include +#include #include "google/protobuf/service.h" +#include "secretflow_serving/framework/execute_context.h" #include "secretflow_serving/server/execution_core.h" #include "secretflow_serving/apis/prediction_service.pb.h" @@ -38,19 +41,38 @@ class Predictor { std::shared_ptr channels; std::vector> executions; + + std::unordered_map specific_party_map; }; public: - explicit 
Predictor(Options opts) : opts_(std::move(opts)) {} + explicit Predictor(Options opts); virtual ~Predictor() = default; virtual void Predict(const apis::PredictRequest* request, - apis::PredictResponse* response) = 0; + apis::PredictResponse* response); void SetExecutionCore(std::shared_ptr& execution_core) { execution_core_ = execution_core; } + protected: + virtual std::shared_ptr BuildRemoteExecute( + const apis::PredictRequest* request, apis::PredictResponse* response, + const std::shared_ptr& execution, std::string target_id, + std::shared_ptr<::google::protobuf::RpcChannel> channel); + + virtual std::shared_ptr BuildLocalExecute( + const apis::PredictRequest* request, apis::PredictResponse* response, + const std::shared_ptr& execution); + + void BuildExecCtx(); + + void DealFinalResult( + std::unordered_map>& + node_io_map, + apis::PredictResponse* response); + protected: Options opts_; diff --git a/secretflow_serving/framework/predictor_impl.cc b/secretflow_serving/framework/predictor_impl.cc deleted file mode 100644 index cf70511..0000000 --- a/secretflow_serving/framework/predictor_impl.cc +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "secretflow_serving/framework/predictor_impl.h" - -#include "secretflow_serving/core/exception.h" -#include "secretflow_serving/util/arrow_helper.h" -#include "secretflow_serving/util/utils.h" - -namespace secretflow::serving { - -namespace { -class OnRPCDone : public google::protobuf::Closure { - public: - OnRPCDone(const std::shared_ptr& cntl) : cntl_(cntl) {} - - void Run() override { - std::unique_ptr self_guard(this); - // reduce cntl_ reference count - cntl_ = nullptr; - } - - private: - std::shared_ptr cntl_; -}; -} // namespace - -PredictorImpl::PredictorImpl(Options opts) : Predictor(std::move(opts)) {} - -void PredictorImpl::Predict(const apis::PredictRequest* request, - apis::PredictResponse* response) { - std::shared_ptr>> last_exec_ctxs; - for (const auto& e : opts_.executions) { - if (e->GetDispatchType() == DispatchType::DP_ALL) { - // async exec peers - auto peer_ctx_list = - std::make_shared>>(); - for (const auto& [party_id, _] : *opts_.channels) { - auto ctx = BuildExecCtx(request, response, party_id, e, last_exec_ctxs); - peer_ctx_list->emplace_back(ctx); - } - AsyncPeersExecute(peer_ctx_list); - - // exec self - auto local_ctx = - BuildExecCtx(request, response, opts_.party_id, e, last_exec_ctxs); - LocalExecute(local_ctx, peer_ctx_list); - - // join peers - JoinPeersExecute(peer_ctx_list); - - last_exec_ctxs = peer_ctx_list; - last_exec_ctxs->emplace_back(local_ctx); - } else if (e->GetDispatchType() == DispatchType::DP_ANYONE) { - auto local_ctx = - BuildExecCtx(request, response, opts_.party_id, e, last_exec_ctxs); - LocalExecute(local_ctx); - last_exec_ctxs = - std::make_shared>>(); - last_exec_ctxs->emplace_back(local_ctx); - } else { - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "unsupport dispatch type: {}", - DispatchType_Name(e->GetDispatchType())); - } - } - - SERVING_ENFORCE(last_exec_ctxs->size() == 1, errors::ErrorCode::LOGIC_ERROR); - DealFinalResult(last_exec_ctxs->front(), response); -} - -std::shared_ptr 
PredictorImpl::BuildExecCtx( - const apis::PredictRequest* request, apis::PredictResponse* response, - const std::string& target_id, const std::shared_ptr& execution, - std::shared_ptr>>& - last_exec_ctxs) { - auto ctx = std::make_shared(); - ctx->response = response; - ctx->request = request; - - ctx->target_id = target_id; - ctx->execution = execution; - - ctx->exec_res = std::make_shared(); - ctx->cntl = std::make_shared(); - - ctx->exec_req = std::make_shared(); - ctx->exec_req->mutable_header()->CopyFrom(request->header()); - ctx->exec_req->set_requester_id(opts_.party_id); - ctx->exec_req->mutable_service_spec()->CopyFrom(request->service_spec()); - auto feature_source = ctx->exec_req->mutable_feature_source(); - if (execution->IsEntry()) { - // entry execution need features - // get target_id's feature param - if (target_id == opts_.party_id && - request->predefined_features_size() != 0) { - // only loacl execute will use `predefined_features` - feature_source->set_type(apis::FeatureSourceType::FS_PREDEFINED); - feature_source->mutable_predefineds()->CopyFrom( - request->predefined_features()); - } else { - feature_source->set_type(apis::FeatureSourceType::FS_SERVICE); - auto iter = request->fs_params().find(target_id); - SERVING_ENFORCE(iter != request->fs_params().end(), - errors::ErrorCode::INVALID_ARGUMENT, - "missing {}'s feature params", target_id); - feature_source->mutable_fs_param()->CopyFrom(iter->second); - } - } else { - feature_source->set_type(apis::FeatureSourceType::FS_NONE); - } - // build task - auto task = ctx->exec_req->mutable_task(); - task->set_execution_id(execution->id()); - if (last_exec_ctxs) { - // merge output node_io - std::map> node_io_map; - for (auto& res_ctx : *last_exec_ctxs) { - auto result = res_ctx->exec_res->mutable_result(); - for (int i = 0; i < result->nodes_size(); ++i) { - auto node_io = result->mutable_nodes(i); - auto iter = node_io_map.find(node_io->name()); - if (iter != node_io_map.end()) { - // found node, 
merge ios - auto& target_node_io = iter->second; - SERVING_ENFORCE(target_node_io->ios_size() == node_io->ios_size(), - errors::ErrorCode::LOGIC_ERROR); - for (int io_index = 0; io_index < target_node_io->ios_size(); - ++io_index) { - auto target_io = target_node_io->mutable_ios(io_index); - auto io = node_io->mutable_ios(io_index); - for (int data_index = 0; data_index < io->datas_size(); - ++data_index) - target_io->add_datas(std::move(*(io->mutable_datas(data_index)))); - } - } else { - auto node_name = node_io->name(); - node_io_map.emplace( - node_name, std::make_shared(std::move(*node_io))); - } - } - } - // build intput from last output - auto entry_nodes = execution->GetEntryNodes(); - for (const auto& n : entry_nodes) { - auto entry_node_io = task->add_nodes(); - entry_node_io->set_name(n->GetName()); - for (const auto& e : n->in_edges()) { - auto iter = node_io_map.find(e->src_node()); - SERVING_ENFORCE(iter != node_io_map.end(), - errors::ErrorCode::LOGIC_ERROR); - for (auto& io : *(iter->second->mutable_ios())) { - entry_node_io->mutable_ios()->Add(std::move(io)); - } - } - } - } - return ctx; -} - -void PredictorImpl::AsyncPeersExecute( - std::shared_ptr>>& - context_list) { - for (size_t i = 0; i < context_list->size(); ++i) { - auto ctx = context_list->at(i); - - AsyncCallRpc(ctx->target_id, ctx->cntl, ctx->exec_req.get(), - ctx->exec_res.get()); - } -} - -void PredictorImpl::JoinPeersExecute( - std::shared_ptr>>& - context_list) { - for (auto& context : *context_list) { - JoinAsyncCall(context->target_id, context->cntl); - } - for (auto& context : *context_list) { - MergeHeader(context->response, context->exec_res); - CheckExecResponse(context->target_id, context->exec_res); - } -} - -void PredictorImpl::SyncPeersExecute( - std::shared_ptr>>& - context_list) { - AsyncPeersExecute(context_list); - JoinPeersExecute(context_list); -} - -void PredictorImpl::AsyncCallRpc(const std::string& target_id, - std::shared_ptr& cntl, - const apis::ExecuteRequest* 
request, - apis::ExecuteResponse* response) { - OnRPCDone* done = new OnRPCDone(cntl); - // semisynchronous call - apis::ExecutionService_Stub stub(opts_.channels->at(target_id).get()); - stub.Execute(cntl.get(), request, response, done); -} - -void PredictorImpl::JoinAsyncCall( - const std::string& target_id, - const std::shared_ptr& cntl) { - brpc::Join(cntl->call_id()); - SERVING_ENFORCE(!cntl->Failed(), errors::ErrorCode::NETWORK_ERROR, - "call ({}) execute failed, msg:{}", target_id, - cntl->ErrorText()); -} - -void PredictorImpl::CancelAsyncCall( - const std::shared_ptr& cntl) { - brpc::StartCancel(cntl->call_id()); -} - -void PredictorImpl::LocalExecute( - std::shared_ptr& context, - std::shared_ptr>> - exception_cancel_cxts) { - try { - context->exec_res = std::make_shared(); - execution_core_->Execute(context->exec_req.get(), context->exec_res.get()); - MergeHeader(context->response, context->exec_res); - CheckExecResponse(context->target_id, context->exec_res); - } catch (Exception& e) { - if (exception_cancel_cxts) { - for (const auto& cxt : *exception_cancel_cxts) { - CancelAsyncCall(cxt->cntl); - } - } - throw e; - } catch (std::exception& e) { - if (exception_cancel_cxts) { - for (const auto& cxt : *exception_cancel_cxts) { - CancelAsyncCall(cxt->cntl); - } - } - throw e; - } -} - -void PredictorImpl::MergeHeader( - apis::PredictResponse* response, - const std::shared_ptr& exec_response) { - response->mutable_header()->mutable_data()->insert( - exec_response->header().data().begin(), - exec_response->header().data().end()); -} - -void PredictorImpl::CheckExecResponse( - const std::string& party_id, - const std::shared_ptr& response) { - if (!CheckStatusOk(response->status())) { - SERVING_THROW( - response->status().code(), - fmt::format("{} exec failed: {}", party_id, response->status().msg())); - } -} - -void PredictorImpl::DealFinalResult(std::shared_ptr& ctx, - apis::PredictResponse* response) { - std::shared_ptr record_batch; - auto exec_result = 
ctx->exec_res->result(); - SERVING_ENFORCE(exec_result.nodes_size() == 1, - errors::ErrorCode::LOGIC_ERROR); - for (const auto& n : exec_result.nodes()) { - SERVING_ENFORCE(n.ios_size() == 1, errors::ErrorCode::LOGIC_ERROR); - for (const auto& io : n.ios()) { - SERVING_ENFORCE(io.datas_size() == 1, errors::ErrorCode::LOGIC_ERROR); - record_batch = DeserializeRecordBatch(io.datas(0)); - break; - } - break; - } - - for (int64_t i = 0; i < record_batch->num_rows(); ++i) { - auto result = response->add_results(); - for (int j = 0; j < record_batch->num_columns(); ++j) { - auto field = record_batch->schema()->field(j); - auto array = record_batch->column(j); - std::shared_ptr raw_scalar; - SERVING_GET_ARROW_RESULT(record_batch->column(j)->GetScalar(i), - raw_scalar); - std::shared_ptr scalar = raw_scalar; - if (raw_scalar->type->id() != arrow::Type::type::DOUBLE) { - // cast type - SERVING_GET_ARROW_RESULT(raw_scalar->CastTo(arrow::float64()), scalar); - } - auto score_value = - std::static_pointer_cast(scalar)->value; - auto score = result->add_scores(); - score->set_name(field->name()); - score->set_value(score_value); - } - } -} - -} // namespace secretflow::serving diff --git a/secretflow_serving/framework/predictor_impl.h b/secretflow_serving/framework/predictor_impl.h deleted file mode 100644 index fc36e96..0000000 --- a/secretflow_serving/framework/predictor_impl.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "brpc/controller.h" - -#include "secretflow_serving/framework/predictor.h" -#include "secretflow_serving/ops/graph.h" - -#include "secretflow_serving/apis/execution_service.pb.h" - -namespace secretflow::serving { - -class PredictorImpl : public Predictor { - public: - struct ExecuteContext { - const apis::PredictRequest* request; - apis::PredictResponse* response; - - std::string session_id; - std::string target_id; - std::shared_ptr execution; - - std::shared_ptr exec_req; - std::shared_ptr exec_res; - std::shared_ptr cntl; - }; - - public: - explicit PredictorImpl(Options opts); - ~PredictorImpl() = default; - - void Predict(const apis::PredictRequest* request, - apis::PredictResponse* response) override; - - protected: - std::shared_ptr BuildExecCtx( - const apis::PredictRequest* request, apis::PredictResponse* response, - const std::string& target_id, const std::shared_ptr& execution, - std::shared_ptr>>& - last_exec_ctxs); - - void AsyncPeersExecute( - std::shared_ptr>>& - context_list); - void JoinPeersExecute( - std::shared_ptr>>& - context_list); - void SyncPeersExecute( - std::shared_ptr>>& - context_list); - - void LocalExecute( - std::shared_ptr& context, - std::shared_ptr>> - exception_cancel_cxts = nullptr); - - virtual void AsyncCallRpc(const std::string& target_id, - std::shared_ptr& cntl, - const apis::ExecuteRequest* request, - apis::ExecuteResponse* response); - - virtual void JoinAsyncCall(const std::string& target_id, - const std::shared_ptr& cntl); - - virtual void CancelAsyncCall(const std::shared_ptr& cntl); - - void MergeHeader(apis::PredictResponse* response, - const std::shared_ptr& exec_response); - - void CheckExecResponse( - const std::string& party_id, - const std::shared_ptr& response); - - void DealFinalResult(std::shared_ptr& ctx, - apis::PredictResponse* response); -}; - -} // namespace 
secretflow::serving diff --git a/secretflow_serving/framework/predictor_impl_test.cc b/secretflow_serving/framework/predictor_impl_test.cc deleted file mode 100644 index 8a1a3cb..0000000 --- a/secretflow_serving/framework/predictor_impl_test.cc +++ /dev/null @@ -1,531 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "secretflow_serving/framework/predictor_impl.h" - -#include "brpc/channel.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -#include "secretflow_serving/ops/op_factory.h" -#include "secretflow_serving/ops/op_kernel_factory.h" -#include "secretflow_serving/util/arrow_helper.h" -#include "secretflow_serving/util/utils.h" - -#include "secretflow_serving/protos/field.pb.h" - -namespace secretflow::serving { - -namespace op { - -class MockOpKernel0 : public OpKernel { - public: - explicit MockOpKernel0(OpKernelOptions opts) : OpKernel(std::move(opts)) {} - - void Compute(ComputeContext* ctx) override {} - void BuildInputSchema() override {} - void BuildOutputSchema() override {} -}; - -class MockOpKernel1 : public OpKernel { - public: - explicit MockOpKernel1(OpKernelOptions opts) : OpKernel(std::move(opts)) {} - - void Compute(ComputeContext* ctx) override {} - void BuildInputSchema() override {} - void BuildOutputSchema() override {} -}; - -REGISTER_OP_KERNEL(TEST_OP_0, MockOpKernel0); -REGISTER_OP_KERNEL(TEST_OP_1, MockOpKernel1); -REGISTER_OP(TEST_OP_0, "0.0.1", "test_desc") - 
.StringAttr("attr_s", "attr_s_desc", false, false) - .Input("input", "input_desc") - .Output("output", "output_desc"); -REGISTER_OP(TEST_OP_1, "0.0.1", "test_desc") - .Mergeable() - .Returnable() - .StringAttr("attr_s", "attr_s_desc", false, false) - .Input("input", "input_desc") - .Output("output", "output_desc"); - -} // namespace op - -class MockExecutable : public Executable { - public: - MockExecutable() - : Executable(), - mock_schema_( - arrow::schema({arrow::field("test_name", arrow::utf8())})) {} - - const std::shared_ptr& GetInputFeatureSchema() { - return mock_schema_; - } - - MOCK_METHOD1(Run, void(Task&)); - - private: - std::shared_ptr mock_schema_; -}; - -class MockExecutionCore : public ExecutionCore { - public: - MockExecutionCore(Options opts) : ExecutionCore(std::move(opts)) {} - - MOCK_METHOD2(Execute, - void(const apis::ExecuteRequest*, apis::ExecuteResponse*)); -}; - -class MockPredictor : public PredictorImpl { - public: - MockPredictor(const Options& options) : PredictorImpl(options) {} - - std::shared_ptr TestBuildExecCtx( - const apis::PredictRequest* request, apis::PredictResponse* response, - const std::string& target_id, const std::shared_ptr& execution, - std::shared_ptr>>& - last_exec_ctxs) { - return BuildExecCtx(request, response, target_id, execution, - last_exec_ctxs); - } - - MOCK_METHOD4(AsyncCallRpc, - void(const std::string&, std::shared_ptr&, - const apis::ExecuteRequest*, apis::ExecuteResponse*)); - - MOCK_METHOD2(JoinAsyncCall, void(const std::string&, - const std::shared_ptr&)); - - MOCK_METHOD1(CancelAsyncCall, void(const std::shared_ptr&)); -}; - -class PredictorImplTest : public ::testing::Test { - protected: - void SetUp() override { - MockExecutionCore::Options exec_opts{"test_id", "alice", std::nullopt, - std::nullopt, - std::make_shared()}; - mock_exec_core_ = new MockExecutionCore(std::move(exec_opts)); - exec_core_ = std::shared_ptr(mock_exec_core_); - - // mock execution - std::vector node_def_jsons = { - R"JSON( 
-{ - "name": "mock_node_1", - "op": "TEST_OP_0", -} -)JSON", - R"JSON( -{ - "name": "mock_node_2", - "op": "TEST_OP_1", - "parents": [ "mock_node_1" ], -} -)JSON"}; - - std::vector execution_def_jsons = { - R"JSON( -{ - "nodes": [ - "mock_node_1" - ], - "config": { - "dispatch_type": "DP_ALL" - } -} -)JSON", - R"JSON( -{ - "nodes": [ - "mock_node_2" - ], - "config": { - "dispatch_type": "DP_ANYONE" - } -} -)JSON"}; - - // build node - std::map> nodes; - for (const auto& j : node_def_jsons) { - NodeDef node_def; - JsonToPb(j, &node_def); - auto node = std::make_shared(std::move(node_def)); - nodes.emplace(node->GetName(), node); - } - // build edge - for (const auto& pair : nodes) { - const auto& input_nodes = pair.second->GetInputNodeNames(); - for (size_t i = 0; i < input_nodes.size(); ++i) { - auto n_iter = nodes.find(input_nodes[i]); - SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); - auto edge = std::make_shared(n_iter->first, pair.first, i); - n_iter->second->SetOutEdge(edge); - pair.second->AddInEdge(edge); - } - } - std::vector> executions; - for (size_t i = 0; i < execution_def_jsons.size(); ++i) { - ExecutionDef executino_def; - JsonToPb(execution_def_jsons[i], &executino_def); - - std::map> e_nodes; - for (const auto& n : executino_def.nodes()) { - e_nodes.emplace(n, nodes.find(n)->second); - } - - executions.emplace_back(std::make_shared( - i, std::move(executino_def), std::move(e_nodes))); - } - - // mock channel - channel_map_ = std::make_shared(); - auto channel = std::make_unique(); - channel_map_->emplace(std::make_pair("bob", std::move(channel))); - - p_opts_.party_id = "alice"; - p_opts_.channels = channel_map_; - p_opts_.executions = std::move(executions); - - mock_predictor_ = std::make_shared(p_opts_); - mock_predictor_->SetExecutionCore(exec_core_); - } - - void TearDown() override { exec_core_ = nullptr; } - - protected: - Predictor::Options p_opts_; - - MockExecutionCore* mock_exec_core_; - - std::shared_ptr 
mock_predictor_; - std::shared_ptr exec_core_; - std::shared_ptr channel_map_; -}; - -MATCHER_P(ExecuteRequestEquel, expect, "") { - return arg->service_spec().id() == expect->service_spec().id() && - arg->requester_id() == expect->requester_id() && - arg->feature_source().type() == expect->feature_source().type() && - std::equal( - arg->feature_source().fs_param().query_datas().begin(), - arg->feature_source().fs_param().query_datas().end(), - expect->feature_source().fs_param().query_datas().begin()) && - arg->feature_source().fs_param().query_context() == - expect->feature_source().fs_param().query_context() && - std::equal(arg->feature_source().predefineds().begin(), - arg->feature_source().predefineds().end(), - expect->feature_source().predefineds().begin(), - [](const Feature& f1, const Feature& f2) { - return f1.field().name() == f2.field().name(); - }) && - arg->task().execution_id() == expect->task().execution_id() && - std::equal( - arg->task().nodes().begin(), arg->task().nodes().end(), - expect->task().nodes().begin(), - [](const apis::NodeIo& n1, const apis::NodeIo& n2) { - return n1.name() == n2.name() && - std::equal( - n1.ios().begin(), n1.ios().end(), n2.ios().begin(), - [](const apis::IoData& io1, const apis::IoData& io2) { - return std::equal(io1.datas().begin(), - io1.datas().end(), - io2.datas().begin()); - }); - }); -} - -TEST_F(PredictorImplTest, BuildExecCtx) { - // mock predict request - apis::PredictRequest request; - apis::PredictResponse response; - - request.mutable_header()->mutable_data()->insert({"test-k", "test-v"}); - request.mutable_service_spec()->set_id("test_service_id"); - request.mutable_fs_params()->insert({"bob", {}}); - request.mutable_fs_params()->at("bob").set_query_context("bob_test_context"); - int params_num = 3; - for (int i = 0; i < params_num; ++i) { - request.mutable_fs_params()->at("bob").add_query_datas("bob_test_params"); - } - auto feature_1 = request.add_predefined_features(); - 
feature_1->mutable_field()->set_name("feature_1"); - feature_1->mutable_field()->set_type(FieldType::FIELD_STRING); - std::vector ss = {"true", "false", "true"}; - feature_1->mutable_value()->mutable_ss()->Assign(ss.begin(), ss.end()); - auto feature_2 = request.add_predefined_features(); - feature_2->mutable_field()->set_name("feature_2"); - feature_2->mutable_field()->set_type(FieldType::FIELD_DOUBLE); - std::vector ds = {1.1, 2.2, 3.3}; - feature_2->mutable_value()->mutable_ds()->Assign(ds.begin(), ds.end()); - - // build bob ctx - std::shared_ptr>> - last_ctxs; - auto ctx_bob = mock_predictor_->TestBuildExecCtx( - &request, &response, "bob", p_opts_.executions[0], last_ctxs); - ASSERT_EQ(request.header().data().at("test-k"), - ctx_bob->exec_req->header().data().at("test-k")); - ASSERT_EQ(ctx_bob->exec_req->service_spec().id(), - request.service_spec().id()); - ASSERT_EQ(ctx_bob->exec_req->requester_id(), p_opts_.party_id); - ASSERT_TRUE(ctx_bob->exec_req->feature_source().type() == - apis::FeatureSourceType::FS_SERVICE); - ASSERT_TRUE(std::equal( - ctx_bob->exec_req->feature_source().fs_param().query_datas().begin(), - ctx_bob->exec_req->feature_source().fs_param().query_datas().end(), - request.fs_params().at(ctx_bob->target_id).query_datas().begin())); - ASSERT_EQ(ctx_bob->exec_req->feature_source().fs_param().query_context(), - request.fs_params().at(ctx_bob->target_id).query_context()); - ASSERT_TRUE(ctx_bob->exec_req->feature_source().predefineds().empty()); - ASSERT_EQ(ctx_bob->exec_req->task().execution_id(), 0); - ASSERT_TRUE(ctx_bob->exec_req->task().nodes().empty()); - - // build alice ctx - auto ctx_alice = mock_predictor_->TestBuildExecCtx( - &request, &response, "alice", p_opts_.executions[0], last_ctxs); - ASSERT_EQ(request.header().data().at("test-k"), - ctx_alice->exec_req->header().data().at("test-k")); - ASSERT_EQ(ctx_alice->exec_req->service_spec().id(), - request.service_spec().id()); - ASSERT_EQ(ctx_alice->exec_req->requester_id(), 
p_opts_.party_id); - ASSERT_TRUE(ctx_alice->exec_req->feature_source().type() == - apis::FeatureSourceType::FS_PREDEFINED); - ASSERT_TRUE( - ctx_alice->exec_req->feature_source().fs_param().query_datas().empty()); - ASSERT_TRUE( - ctx_alice->exec_req->feature_source().fs_param().query_context().empty()); - ASSERT_EQ(ctx_alice->exec_req->feature_source().predefineds_size(), - request.predefined_features_size()); - auto f1 = ctx_alice->exec_req->feature_source().predefineds(0); - ASSERT_FALSE(f1.field().name().empty()); - ASSERT_EQ(f1.field().name(), feature_1->field().name()); - ASSERT_EQ(f1.field().type(), feature_1->field().type()); - ASSERT_FALSE(f1.value().ss().empty()); - ASSERT_TRUE(std::equal(f1.value().ss().begin(), f1.value().ss().end(), - feature_1->value().ss().begin())); - auto f2 = ctx_alice->exec_req->feature_source().predefineds(1); - ASSERT_FALSE(f2.field().name().empty()); - ASSERT_EQ(f2.field().name(), feature_2->field().name()); - ASSERT_EQ(f2.field().type(), feature_2->field().type()); - ASSERT_FALSE(f2.value().ds().empty()); - ASSERT_TRUE(std::equal(f2.value().ds().begin(), f2.value().ds().end(), - feature_2->value().ds().begin())); - ASSERT_EQ(ctx_alice->exec_req->task().execution_id(), 0); - ASSERT_TRUE(ctx_alice->exec_req->task().nodes().empty()); - - // mock alice & bob response - { - auto exec_response = std::make_shared(); - exec_response->mutable_result()->set_execution_id(0); - auto node = exec_response->mutable_result()->add_nodes(); - node->set_name("mock_node_1"); - auto io = node->add_ios(); - io->add_datas("mock_bob_data"); - ctx_bob->exec_res = exec_response; - } - { - auto exec_response = std::make_shared(); - exec_response->mutable_result()->set_execution_id(0); - auto node_1 = exec_response->mutable_result()->add_nodes(); - node_1->set_name("mock_node_1"); - node_1->add_ios()->add_datas("mock_alice_data"); - ctx_alice->exec_res = exec_response; - } - auto ctx_list = std::make_shared< - std::vector>>(); - 
ctx_list->emplace_back(ctx_bob); - ctx_list->emplace_back(ctx_alice); - - // build ctx - auto ctx_final = mock_predictor_->TestBuildExecCtx( - &request, &response, "alice", p_opts_.executions[1], ctx_list); - EXPECT_EQ(request.header().data().at("test-k"), - ctx_final->exec_req->header().data().at("test-k")); - EXPECT_EQ(ctx_final->exec_req->service_spec().id(), - request.service_spec().id()); - EXPECT_EQ(ctx_final->exec_req->requester_id(), p_opts_.party_id); - EXPECT_TRUE(ctx_final->exec_req->feature_source().type() == - apis::FeatureSourceType::FS_NONE); - EXPECT_EQ(ctx_final->exec_req->task().execution_id(), 1); - EXPECT_EQ(ctx_final->exec_req->task().nodes_size(), 1); - auto node1 = ctx_final->exec_req->task().nodes(0); - EXPECT_EQ(node1.name(), "mock_node_2"); - EXPECT_EQ(node1.ios_size(), 1); - EXPECT_EQ(node1.ios(0).datas_size(), 2); - EXPECT_EQ(node1.ios(0).datas(0), "mock_bob_data"); - EXPECT_EQ(node1.ios(0).datas(1), "mock_alice_data"); -} - -TEST_F(PredictorImplTest, Predict) { - apis::PredictRequest request; - apis::PredictResponse response; - - // mock predict request - request.mutable_header()->mutable_data()->insert({"test-k", "test-v"}); - request.mutable_service_spec()->set_id("test_service_id"); - request.mutable_fs_params()->insert({"bob", {}}); - request.mutable_fs_params()->at("bob").set_query_context("bob_test_context"); - int params_num = 3; - for (int i = 0; i < params_num; ++i) { - request.mutable_fs_params()->at("bob").add_query_datas("bob_test_params"); - } - auto feature_1 = request.add_predefined_features(); - feature_1->mutable_field()->set_name("feature_1"); - feature_1->mutable_field()->set_type(FieldType::FIELD_STRING); - std::vector ss = {"true", "false", "true"}; - feature_1->mutable_value()->mutable_ss()->Assign(ss.begin(), ss.end()); - auto feature_2 = request.add_predefined_features(); - feature_2->mutable_field()->set_name("feature_2"); - feature_2->mutable_field()->set_type(FieldType::FIELD_DOUBLE); - std::vector ds = {1.1, 
2.2, 3.3}; - feature_2->mutable_value()->mutable_ds()->Assign(ds.begin(), ds.end()); - - // mock bob's req & res - apis::ExecuteResponse bob_exec0_res; - apis::ExecuteRequest bob_exec0_req; - { - // build execute reponse - bob_exec0_res.mutable_header()->mutable_data()->insert( - {"bob-res-k", "bob-res-v"}); - bob_exec0_res.mutable_status()->set_code(1); - bob_exec0_res.mutable_service_spec()->set_id("test_service_id"); - bob_exec0_res.mutable_result()->set_execution_id(0); - auto node = bob_exec0_res.mutable_result()->add_nodes(); - node->set_name("mock_node_1"); - auto io = node->add_ios(); - io->add_datas("mock_bob_data"); - - // build execute request - bob_exec0_req.mutable_header()->mutable_data()->insert( - {"test-k", "test-v"}); - *bob_exec0_req.mutable_service_spec() = request.service_spec(); - bob_exec0_req.set_requester_id(p_opts_.party_id); - bob_exec0_req.mutable_feature_source()->set_type( - apis::FeatureSourceType::FS_SERVICE); - bob_exec0_req.mutable_feature_source()->mutable_fs_param()->CopyFrom( - request.fs_params().at("bob")); - bob_exec0_req.mutable_task()->set_execution_id(0); - } - // mock alice's req & res - apis::ExecuteResponse alice_exec0_res; - apis::ExecuteRequest alice_exec0_req; - { - // build execute reponse - alice_exec0_res.mutable_header()->mutable_data()->insert( - {"alice-res-k", "alice-res-v"}); - alice_exec0_res.mutable_status()->set_code(1); - alice_exec0_res.mutable_service_spec()->set_id("test_service_id"); - alice_exec0_res.mutable_result()->set_execution_id(0); - auto node = alice_exec0_res.mutable_result()->add_nodes(); - node->set_name("mock_node_1"); - auto io = node->add_ios(); - io->add_datas("mock_alice_data"); - - // build execute request - alice_exec0_req.mutable_header()->mutable_data()->insert( - {"test-k", "test-v"}); - *alice_exec0_req.mutable_service_spec() = request.service_spec(); - alice_exec0_req.set_requester_id(p_opts_.party_id); - alice_exec0_req.mutable_feature_source()->set_type( - 
apis::FeatureSourceType::FS_PREDEFINED); - alice_exec0_req.mutable_feature_source()->mutable_predefineds()->CopyFrom( - request.predefined_features()); - alice_exec0_req.mutable_task()->set_execution_id(0); - } - - // mock alice any one req & res - apis::ExecuteResponse alice_exec1_res; - apis::ExecuteRequest alice_exec1_req; - { - // build execute reponse - alice_exec1_res.mutable_header()->mutable_data()->insert( - {"alice-res-k", "alice-res-v"}); - alice_exec1_res.mutable_header()->mutable_data()->insert( - {"alice-res-k1", "alice-res-v1"}); - alice_exec1_res.mutable_status()->set_code(1); - alice_exec1_res.mutable_service_spec()->set_id("test_service_id"); - alice_exec1_res.mutable_result()->set_execution_id(1); - auto node = alice_exec1_res.mutable_result()->add_nodes(); - node->set_name("mock_node_2"); - auto io = node->add_ios(); - auto schema = arrow::schema({arrow::field("score_0", arrow::utf8()), - arrow::field("score_1", arrow::utf8())}); - arrow::StringBuilder builder; - std::shared_ptr array; - SERVING_CHECK_ARROW_STATUS(builder.AppendValues({"1", "2", "3"})); - SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); - auto record_batch = MakeRecordBatch(schema, 3, {array, array}); - io->add_datas(SerializeRecordBatch(record_batch)); - } - { - // build execute request - alice_exec1_req.mutable_header()->mutable_data()->insert( - {"test-k", "test-v"}); - *alice_exec1_req.mutable_service_spec() = request.service_spec(); - alice_exec1_req.set_requester_id(p_opts_.party_id); - alice_exec1_req.mutable_feature_source()->set_type( - apis::FeatureSourceType::FS_NONE); - alice_exec1_req.mutable_task()->set_execution_id(1); - auto node = alice_exec1_req.mutable_task()->add_nodes(); - node->set_name("mock_node_2"); - auto io = node->add_ios(); - io->add_datas("mock_bob_data"); - io->add_datas("mock_alice_data"); - } - - EXPECT_CALL(*mock_predictor_, - AsyncCallRpc(::testing::_, ::testing::_, - testing::Matcher( - ExecuteRequestEquel(&bob_exec0_req)), - ::testing::_)) 
- .Times(1) - .WillOnce(::testing::DoAll(::testing::SetArgPointee<3>(bob_exec0_res))); - - EXPECT_CALL(*mock_exec_core_, - Execute(testing::Matcher( - ExecuteRequestEquel(&alice_exec0_req)), - ::testing::_)) - .Times(1) - .WillOnce(::testing::SetArgPointee<1>(alice_exec0_res)); - - EXPECT_CALL(*mock_predictor_, JoinAsyncCall(::testing::_, ::testing::_)) - .Times(1); - - EXPECT_CALL(*mock_exec_core_, - Execute(testing::Matcher( - ExecuteRequestEquel(&alice_exec1_req)), - ::testing::_)) - .Times(1) - .WillOnce(::testing::SetArgPointee<1>(alice_exec1_res)); - - ASSERT_NO_THROW(mock_predictor_->Predict(&request, &response)); - ASSERT_EQ(response.header().data_size(), 3); - ASSERT_EQ(response.header().data().at("bob-res-k"), "bob-res-v"); - ASSERT_EQ(response.header().data().at("alice-res-k"), "alice-res-v"); - ASSERT_EQ(response.header().data().at("alice-res-k1"), "alice-res-v1"); - ASSERT_EQ(response.header().data_size(), 3); - ASSERT_EQ(response.results_size(), params_num); - for (int i = 0; i < params_num; ++i) { - auto result = response.results(i); - ASSERT_EQ(result.scores_size(), 2); - for (int j = 0; j < result.scores_size(); ++j) { - ASSERT_EQ(result.scores(j).name(), "score_" + std::to_string(j)); - ASSERT_EQ(result.scores(j).value(), i + 1); - } - } -} - -} // namespace secretflow::serving diff --git a/secretflow_serving/framework/predictor_test.cc b/secretflow_serving/framework/predictor_test.cc new file mode 100644 index 0000000..4aa716b --- /dev/null +++ b/secretflow_serving/framework/predictor_test.cc @@ -0,0 +1,461 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/framework/predictor.h" + +#include "brpc/channel.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving { + +namespace op { + +class MockOpKernel0 : public OpKernel { + public: + explicit MockOpKernel0(OpKernelOptions opts) : OpKernel(std::move(opts)) {} + + void DoCompute(ComputeContext* ctx) override {} + void BuildInputSchema() override {} + void BuildOutputSchema() override {} +}; + +class MockOpKernel1 : public OpKernel { + public: + explicit MockOpKernel1(OpKernelOptions opts) : OpKernel(std::move(opts)) {} + + void DoCompute(ComputeContext* ctx) override {} + void BuildInputSchema() override {} + void BuildOutputSchema() override {} +}; + +REGISTER_OP_KERNEL(TEST_OP_0, MockOpKernel0); +REGISTER_OP_KERNEL(TEST_OP_1, MockOpKernel1); +REGISTER_OP(TEST_OP_0, "0.0.1", "test_desc") + .StringAttr("attr_s", "attr_s_desc", false, false) + .Input("input", "input_desc") + .Output("output", "output_desc"); +REGISTER_OP(TEST_OP_1, "0.0.1", "test_desc") + .Mergeable() + .Returnable() + .StringAttr("attr_s", "attr_s_desc", false, false) + .Input("input", "input_desc") + .Output("output", "output_desc"); + +} // namespace op + +class MockExecutable : public Executable { + public: + MockExecutable() + : Executable({}), + mock_schema_( + arrow::schema({arrow::field("test_name", arrow::utf8())})) {} 
+ + const std::shared_ptr& GetInputFeatureSchema() { + return mock_schema_; + } + + MOCK_METHOD1(Run, void(Task&)); + + private: + std::shared_ptr mock_schema_; +}; + +class MockExecutionCore : public ExecutionCore { + public: + MockExecutionCore(Options opts) : ExecutionCore(std::move(opts)) {} + + MOCK_METHOD2(Execute, + void(const apis::ExecuteRequest*, apis::ExecuteResponse*)); +}; + +class MockRemoteExecute : public RemoteExecute { + public: + using RemoteExecute::RemoteExecute; + void Run() override {} + void Cancel() override {} + void WaitToFinish() override { + exe_ctx_.CheckAndUpdateResponse(mock_exec_res); + } + void GetOutputs( + std::unordered_map>* + node_io_map) override { + ExeResponseToIoMap(mock_exec_res, node_io_map); + } + apis::ExecuteResponse mock_exec_res; +}; + +class MockPredictor : public Predictor { + public: + MockPredictor(const Options& options) : Predictor(options) {} + std::shared_ptr BuildRemoteExecute( + const apis::PredictRequest* request, apis::PredictResponse* response, + const std::shared_ptr& execution, std::string target_id, + std::shared_ptr<::google::protobuf::RpcChannel> channel) override { + auto exec = std::make_shared( + request, response, execution, target_id, opts_.party_id, channel); + exec->mock_exec_res = remote_exec_res_; + return exec; + } + + apis::ExecuteResponse remote_exec_res_; +}; + +bool ExeRequestEqual(const apis::ExecuteRequest* arg, + const apis::ExecuteRequest* expect) { + if (arg->service_spec().id() != expect->service_spec().id()) { + return false; + } + if (arg->requester_id() != expect->requester_id()) { + return false; + } + if (arg->feature_source().type() != expect->feature_source().type()) { + return false; + } + if (!std::equal(arg->feature_source().fs_param().query_datas().begin(), + arg->feature_source().fs_param().query_datas().end(), + expect->feature_source().fs_param().query_datas().begin())) { + return false; + } + if (arg->feature_source().fs_param().query_context() != + 
expect->feature_source().fs_param().query_context()) { + return false; + } + if (arg->feature_source().predefineds().size() != + expect->feature_source().predefineds().size()) { + return false; + } + + if (!std::equal(arg->feature_source().predefineds().begin(), + arg->feature_source().predefineds().end(), + expect->feature_source().predefineds().begin(), + [](const Feature& f1, const Feature& f2) { + return f1.field().name() == f2.field().name(); + })) { + return false; + } + if (arg->task().execution_id() != expect->task().execution_id()) { + return false; + } + if (!std::equal(arg->task().nodes().begin(), arg->task().nodes().end(), + expect->task().nodes().begin(), + [](const apis::NodeIo& n1, const apis::NodeIo& n2) { + return n1.name() == n2.name() && + std::equal(n1.ios().begin(), n1.ios().end(), + n2.ios().begin(), + [](const apis::IoData& io1, + const apis::IoData& io2) { + return std::equal(io1.datas().begin(), + io1.datas().end(), + io2.datas().begin()); + }); + })) { + return false; + } + return true; +} + +MATCHER_P(ExecuteRequestEqual, expect, "") { + return ExeRequestEqual(arg, expect); +} + +class PredictorTest : public ::testing::Test { + protected: + void SetUpEnv(std::vector& node_def_jsons, + std::vector& execution_def_jsons) { + MockExecutionCore::Options exec_opts{"test_id", "alice", std::nullopt, + std::nullopt, + std::make_shared()}; + mock_exec_core_ = new MockExecutionCore(std::move(exec_opts)); + exec_core_ = std::shared_ptr(mock_exec_core_); + + // build node + std::map> nodes; + for (const auto& j : node_def_jsons) { + NodeDef node_def; + JsonToPb(j, &node_def); + auto node = std::make_shared(std::move(node_def)); + nodes.emplace(node->GetName(), node); + } + // build edge + for (const auto& pair : nodes) { + const auto& input_nodes = pair.second->GetInputNodeNames(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto n_iter = nodes.find(input_nodes[i]); + SERVING_ENFORCE(n_iter != nodes.end(), errors::ErrorCode::LOGIC_ERROR); + 
auto edge = std::make_shared(n_iter->first, pair.first, i); + n_iter->second->AddOutEdge(edge); + pair.second->AddInEdge(edge); + } + } + std::vector> executions; + for (size_t i = 0; i < execution_def_jsons.size(); ++i) { + ExecutionDef executino_def; + JsonToPb(execution_def_jsons[i], &executino_def); + + std::unordered_map> e_nodes; + for (const auto& n : executino_def.nodes()) { + e_nodes.emplace(n, nodes.find(n)->second); + } + + executions.emplace_back(std::make_shared( + i, std::move(executino_def), std::move(e_nodes))); + } + + // mock channel + channel_map_ = std::make_shared(); + auto channel = std::make_unique(); + channel_map_->emplace(std::make_pair("bob", std::move(channel))); + + p_opts_.party_id = "alice"; + p_opts_.channels = channel_map_; + p_opts_.executions = std::move(executions); + + mock_predictor_ = std::make_shared(p_opts_); + mock_predictor_->SetExecutionCore(exec_core_); + } + void SetUp() override {} + void TearDown() override { exec_core_ = nullptr; } + + protected: + Predictor::Options p_opts_; + + MockExecutionCore* mock_exec_core_; + + std::shared_ptr mock_predictor_; + std::shared_ptr exec_core_; + std::shared_ptr channel_map_; +}; + +TEST_F(PredictorTest, Predict) { + // mock execution + std::vector node_def_jsons = { + R"JSON( +{ + "name": "mock_node_1", + "op": "TEST_OP_0", +} +)JSON", + R"JSON( +{ + "name": "mock_node_2", + "op": "TEST_OP_1", + "parents": [ "mock_node_1" ], +} +)JSON"}; + + std::vector> execution_def_jsons_list = {{ + R"JSON( +{ + "nodes": [ + "mock_node_1" + ], + "config": { + "dispatch_type": "DP_ALL" + } +} +)JSON", + R"JSON( +{ + "nodes": [ + "mock_node_2" + ], + "config": { + "dispatch_type": "DP_ANYONE" + } +} +)JSON"}, + { + R"JSON( +{ + "nodes": [ + "mock_node_1" + ], + "config": { + "dispatch_type": "DP_ALL" + } +} +)JSON", + R"JSON( +{ + "nodes": [ + "mock_node_2" + ], + "config": { + "dispatch_type": "DP_SPECIFIED", + "specific_flag": true + } +} +)JSON"}}; + for (auto& execution_def_jsons : 
execution_def_jsons_list) { + SetUpEnv(node_def_jsons, execution_def_jsons); + apis::PredictRequest request; + apis::PredictResponse response; + + // mock predict request + request.mutable_header()->mutable_data()->insert({"test-k", "test-v"}); + request.mutable_service_spec()->set_id("test_service_id"); + request.mutable_fs_params()->insert({"bob", {}}); + request.mutable_fs_params()->at("bob").set_query_context( + "bob_test_context"); + int params_num = 3; + for (int i = 0; i < params_num; ++i) { + request.mutable_fs_params()->at("bob").add_query_datas("bob_test_params"); + } + auto feature_1 = request.add_predefined_features(); + feature_1->mutable_field()->set_name("feature_1"); + feature_1->mutable_field()->set_type(FieldType::FIELD_STRING); + std::vector ss = {"true", "false", "true"}; + feature_1->mutable_value()->mutable_ss()->Assign(ss.begin(), ss.end()); + auto feature_2 = request.add_predefined_features(); + feature_2->mutable_field()->set_name("feature_2"); + feature_2->mutable_field()->set_type(FieldType::FIELD_DOUBLE); + std::vector ds = {1.1, 2.2, 3.3}; + feature_2->mutable_value()->mutable_ds()->Assign(ds.begin(), ds.end()); + + // mock bob's req & res + apis::ExecuteResponse bob_exec0_res; + apis::ExecuteRequest bob_exec0_req; + { + // build execute reponse + bob_exec0_res.mutable_header()->mutable_data()->insert( + {"bob-res-k", "bob-res-v"}); + bob_exec0_res.mutable_status()->set_code(1); + bob_exec0_res.mutable_service_spec()->set_id("test_service_id"); + bob_exec0_res.mutable_result()->set_execution_id(0); + auto node = bob_exec0_res.mutable_result()->add_nodes(); + node->set_name("mock_node_1"); + auto io = node->add_ios(); + io->add_datas("mock_bob_data"); + + // build execute request + bob_exec0_req.mutable_header()->mutable_data()->insert( + {"test-k", "test-v"}); + *bob_exec0_req.mutable_service_spec() = request.service_spec(); + bob_exec0_req.set_requester_id(p_opts_.party_id); + bob_exec0_req.mutable_feature_source()->set_type( + 
apis::FeatureSourceType::FS_SERVICE); + bob_exec0_req.mutable_feature_source()->mutable_fs_param()->CopyFrom( + request.fs_params().at("bob")); + bob_exec0_req.mutable_task()->set_execution_id(0); + } + // mock alice's req & res + apis::ExecuteResponse alice_exec0_res; + apis::ExecuteRequest alice_exec0_req; + { + // build execute reponse + alice_exec0_res.mutable_header()->mutable_data()->insert( + {"alice-res-k", "alice-res-v"}); + alice_exec0_res.mutable_status()->set_code(1); + alice_exec0_res.mutable_service_spec()->set_id("test_service_id"); + alice_exec0_res.mutable_result()->set_execution_id(0); + auto node = alice_exec0_res.mutable_result()->add_nodes(); + node->set_name("mock_node_1"); + auto io = node->add_ios(); + io->add_datas("mock_alice_data"); + + // build execute request + alice_exec0_req.mutable_header()->mutable_data()->insert( + {"test-k", "test-v"}); + *alice_exec0_req.mutable_service_spec() = request.service_spec(); + alice_exec0_req.set_requester_id(p_opts_.party_id); + alice_exec0_req.mutable_feature_source()->set_type( + apis::FeatureSourceType::FS_PREDEFINED); + alice_exec0_req.mutable_feature_source()->mutable_predefineds()->CopyFrom( + request.predefined_features()); + alice_exec0_req.mutable_task()->set_execution_id(0); + } + + // mock alice any one req & res + apis::ExecuteResponse alice_exec1_res; + apis::ExecuteRequest alice_exec1_req; + { + // build execute reponse + alice_exec1_res.mutable_header()->mutable_data()->insert( + {"alice-res-k", "alice-res-v"}); + alice_exec1_res.mutable_header()->mutable_data()->insert( + {"alice-res-k1", "alice-res-v1"}); + alice_exec1_res.mutable_status()->set_code(1); + alice_exec1_res.mutable_service_spec()->set_id("test_service_id"); + alice_exec1_res.mutable_result()->set_execution_id(1); + auto node = alice_exec1_res.mutable_result()->add_nodes(); + node->set_name("mock_node_2"); + auto io = node->add_ios(); + auto schema = arrow::schema({arrow::field("score_0", arrow::utf8()), + 
arrow::field("score_1", arrow::utf8())}); + arrow::StringBuilder builder; + std::shared_ptr array; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues({"1", "2", "3"})); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + auto record_batch = MakeRecordBatch(schema, 3, {array, array}); + io->add_datas(SerializeRecordBatch(record_batch)); + } + { + // build execute request + alice_exec1_req.mutable_header()->mutable_data()->insert( + {"test-k", "test-v"}); + *alice_exec1_req.mutable_service_spec() = request.service_spec(); + alice_exec1_req.set_requester_id(p_opts_.party_id); + alice_exec1_req.mutable_feature_source()->set_type( + apis::FeatureSourceType::FS_NONE); + alice_exec1_req.mutable_task()->set_execution_id(1); + auto node = alice_exec1_req.mutable_task()->add_nodes(); + node->set_name("mock_node_2"); + auto io = node->add_ios(); + // local first + io->add_datas("mock_alice_data"); + io->add_datas("mock_bob_data"); + } + + EXPECT_CALL(*mock_exec_core_, + Execute(testing::Matcher( + ExecuteRequestEqual(&alice_exec0_req)), + ::testing::_)) + .Times(1) + .WillOnce(::testing::SetArgPointee<1>(alice_exec0_res)); + + EXPECT_CALL(*mock_exec_core_, + Execute(testing::Matcher( + ExecuteRequestEqual(&alice_exec1_req)), + ::testing::_)) + .Times(1) + .WillOnce(::testing::SetArgPointee<1>(alice_exec1_res)); + + mock_predictor_->remote_exec_res_ = bob_exec0_res; + ASSERT_NO_THROW(mock_predictor_->Predict(&request, &response)); + for (const auto& [k, v] : response.header().data()) { + std::cout << k << " : " << v << std::endl; + } + ASSERT_EQ(response.header().data_size(), 3); + ASSERT_EQ(response.header().data().at("bob-res-k"), "bob-res-v"); + ASSERT_EQ(response.header().data().at("alice-res-k"), "alice-res-v"); + ASSERT_EQ(response.header().data().at("alice-res-k1"), "alice-res-v1"); + ASSERT_EQ(response.results_size(), params_num); + for (int i = 0; i < params_num; ++i) { + auto result = response.results(i); + ASSERT_EQ(result.scores_size(), 2); + for (int j = 0; j < 
result.scores_size(); ++j) { + ASSERT_EQ(result.scores(j).name(), "score_" + std::to_string(j)); + ASSERT_EQ(result.scores(j).value(), i + 1); + } + } + } +} + +} // namespace secretflow::serving diff --git a/secretflow_serving/framework/propagator.cc b/secretflow_serving/framework/propagator.cc index 54a805f..fb7d48a 100644 --- a/secretflow_serving/framework/propagator.cc +++ b/secretflow_serving/framework/propagator.cc @@ -16,51 +16,25 @@ namespace secretflow::serving { -Propagator::Propagator() {} - -FrameState* Propagator::CreateFrame(const std::shared_ptr& node) { - auto frame = std::make_unique(); - frame->node_name = node->node_def().name(); - frame->pending_count = node->GetInputNum(); - frame->compute_ctx.inputs = - std::make_shared(frame->pending_count); - - auto result = frame.get(); - std::lock_guard lock(mutex_); - SERVING_ENFORCE( - node_frame_map_.emplace(node->node_def().name(), std::move(frame)).second, - errors::ErrorCode::LOGIC_ERROR); - - return result; -} - -FrameState* Propagator::FindOrCreateChildFrame( - FrameState* frame, const std::shared_ptr& child_node) { - std::lock_guard lock(mutex_); - auto iter = node_frame_map_.find(child_node->node_def().name()); - if (iter != node_frame_map_.end()) { - return iter->second.get(); - } else { - auto child_frame = std::make_unique(); - child_frame->node_name = child_node->node_def().name(); - child_frame->parent_name = frame->node_name; - child_frame->pending_count = child_node->GetInputNum(); - child_frame->compute_ctx.inputs = - std::make_shared(child_node->GetInputNum()); - - auto result = child_frame.get(); - node_frame_map_.emplace(child_node->node_def().name(), - std::move(child_frame)); - return result; +Propagator::Propagator( + const std::unordered_map>& nodes) { + frame_pool_ = std::vector(nodes.size()); + size_t idx = 0; + for (auto& [node_name, node] : nodes) { + auto frame = &frame_pool_[idx++]; + frame->pending_count = node->GetInputNum(); + 
frame->compute_ctx.inputs.resize(frame->pending_count); + + SERVING_ENFORCE(node_frame_map_.emplace(node_name, std::move(frame)).second, + errors::ErrorCode::LOGIC_ERROR); } } FrameState* Propagator::GetFrame(const std::string& node_name) { - std::lock_guard lock(mutex_); auto iter = node_frame_map_.find(node_name); SERVING_ENFORCE(iter != node_frame_map_.end(), errors::ErrorCode::LOGIC_ERROR, "can not found frame for node: {}", node_name); - return iter->second.get(); + return iter->second; } } // namespace secretflow::serving diff --git a/secretflow_serving/framework/propagator.h b/secretflow_serving/framework/propagator.h index fb56307..94065f5 100644 --- a/secretflow_serving/framework/propagator.h +++ b/secretflow_serving/framework/propagator.h @@ -23,28 +23,20 @@ namespace secretflow::serving { struct FrameState { - std::string node_name; - - int pending_count; - - std::string parent_name; + std::atomic pending_count; op::ComputeContext compute_ctx; }; class Propagator { public: - explicit Propagator(); - - FrameState* CreateFrame(const std::shared_ptr& node); - - FrameState* FindOrCreateChildFrame(FrameState* frame, - const std::shared_ptr& child_node); + explicit Propagator( + const std::unordered_map>& nodes); FrameState* GetFrame(const std::string& node_name); private: - std::mutex mutex_; - std::map> node_frame_map_; + std::unordered_map node_frame_map_; + std::vector frame_pool_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/ops/BUILD.bazel b/secretflow_serving/ops/BUILD.bazel index 90dec39..8b73eb3 100644 --- a/secretflow_serving/ops/BUILD.bazel +++ b/secretflow_serving/ops/BUILD.bazel @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//bazel:serving.bzl", "serving_cc_library", "serving_cc_test") +load("//bazel:serving.bzl", "serving_cc_binary", "serving_cc_library", "serving_cc_test") package(default_visibility = ["//visibility:public"]) serving_cc_library( name = "ops", deps = [ + ":arrow_processing", ":dot_product", ":merge_y", ], @@ -84,7 +85,7 @@ serving_cc_library( ":node", "//secretflow_serving/core:exception", "//secretflow_serving/protos:op_cc_proto", - "@org_apache_arrow//:arrow", + "//secretflow_serving/util:arrow_helper", ], ) @@ -134,7 +135,6 @@ serving_cc_library( ":op_factory", ":op_kernel_factory", "//secretflow_serving/core:types", - "//secretflow_serving/util:arrow_helper", "//secretflow_serving/util:utils", ], alwayslink = True, @@ -145,6 +145,16 @@ serving_cc_test( srcs = ["dot_product_test.cc"], deps = [ ":dot_product", + "//secretflow_serving/util:test_utils", + ], +) + +serving_cc_binary( + name = "dot_product_benchmark", + srcs = ["dot_product_benchmark.cc"], + deps = [ + ":dot_product", + "@com_github_google_benchmark//:benchmark_main", ], ) @@ -157,7 +167,6 @@ serving_cc_library( ":op_factory", ":op_kernel_factory", "//secretflow_serving/core:link_func", - "//secretflow_serving/util:arrow_helper", ], alwayslink = True, ) @@ -170,3 +179,41 @@ serving_cc_test( "//secretflow_serving/util:utils", ], ) + +serving_cc_binary( + name = "merge_y_benchmark", + srcs = ["merge_y_benchmark.cc"], + deps = [ + ":merge_y", + "//secretflow_serving/util:utils", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +serving_cc_library( + name = "arrow_processing", + srcs = ["arrow_processing.cc"], + hdrs = ["arrow_processing.h"], + deps = [ + ":node_def_util", + ":op_factory", + ":op_kernel_factory", + "//secretflow_serving/protos:compute_trace_cc_proto", + ], + alwayslink = True, +) + +serving_cc_test( + name = "arrow_processing_test", + srcs = ["arrow_processing_test.cc"], + deps = [ + ":arrow_processing", + "//secretflow_serving/util:test_utils", + 
"//secretflow_serving/util:utils", + ], +) + +serving_cc_library( + name = "graph_version", + hdrs = ["graph_version.h"], +) diff --git a/secretflow_serving/ops/arrow_processing.cc b/secretflow_serving/ops/arrow_processing.cc new file mode 100644 index 0000000..76399d3 --- /dev/null +++ b/secretflow_serving/ops/arrow_processing.cc @@ -0,0 +1,482 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/ops/arrow_processing.h" + +#include +#include + +#include "arrow/compute/api.h" +#include "spdlog/spdlog.h" + +#include "secretflow_serving/ops/node_def_util.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op { + +namespace { + +const static std::set kReturnableFuncNames = { + compute::ExtendFunctionName_Name( + compute::ExtendFunctionName::EFN_TB_ADD_COLUMN), + compute::ExtendFunctionName_Name( + compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN), + compute::ExtendFunctionName_Name( + compute::ExtendFunctionName::EFN_TB_SET_COLUMN)}; + +arrow::Datum PbScalar2Datum(const compute::Scalar& scalar) { + switch (scalar.value_case()) { + case compute::Scalar::ValueCase::kI8: { + return arrow::Datum(static_cast(scalar.i8())); + } + case compute::Scalar::ValueCase::kUi8: { + return 
arrow::Datum(static_cast(scalar.ui8())); + } + case compute::Scalar::ValueCase::kI16: { + return arrow::Datum(static_cast(scalar.i16())); + } + case compute::Scalar::ValueCase::kUi16: { + return arrow::Datum(static_cast(scalar.ui16())); + } + case compute::Scalar::ValueCase::kI32: { + return arrow::Datum(scalar.i32()); + } + case compute::Scalar::ValueCase::kUi32: { + return arrow::Datum(scalar.ui32()); + } + case compute::Scalar::ValueCase::kI64: { + return arrow::Datum(scalar.i64()); + } + case compute::Scalar::ValueCase::kUi64: { + return arrow::Datum(scalar.ui64()); + } + case compute::Scalar::ValueCase::kF: { + return arrow::Datum(scalar.f()); + } + case compute::Scalar::ValueCase::kD: { + return arrow::Datum(scalar.d()); + } + case compute::Scalar::ValueCase::kS: { + return arrow::Datum(scalar.s()); + } + case compute::Scalar::ValueCase::kB: { + return arrow::Datum(scalar.b()); + } + default: + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "invalid pb scalar type: {}", + static_cast(scalar.value_case())); + } +} + +std::vector BuildInputDatumKind( + const ::google::protobuf::RepeatedPtrField& + func_inputs, + const std::map& data_id_map) { + std::vector results; + for (const auto& in : func_inputs) { + if (in.value_case() == compute::FunctionInput::ValueCase::kDataId) { + auto iter = data_id_map.find(in.data_id()); + SERVING_ENFORCE(iter != data_id_map.end(), errors::ErrorCode::LOGIC_ERROR, + "can not found input data_id({})", in.data_id()); + results.emplace_back(iter->second); + } else if (in.value_case() == + compute::FunctionInput::ValueCase::kCustomScalar) { + results.emplace_back(arrow::Datum::Kind::SCALAR); + } else { + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "invalid function input type:{}", + static_cast(in.value_case())); + } + } + + return results; +} + +std::vector BuildInputDatums( + const ::google::protobuf::RepeatedPtrField& + func_inputs, + const std::map& datas) { + std::vector results; + for (const auto& in : func_inputs) 
{ + if (in.value_case() == compute::FunctionInput::ValueCase::kDataId) { + auto iter = datas.find(in.data_id()); + SERVING_ENFORCE(iter != datas.end(), errors::ErrorCode::LOGIC_ERROR, + "can not found input data_id({})", in.data_id()); + results.emplace_back(iter->second); + } else if (in.value_case() == + compute::FunctionInput::ValueCase::kCustomScalar) { + results.emplace_back(PbScalar2Datum(in.custom_scalar())); + } else { + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "invalid function input type:{}", + static_cast(in.value_case())); + } + } + + return results; +} + +} // namespace + +ArrowProcessing::ArrowProcessing(OpKernelOptions opts) + : OpKernel(std::move(opts)) { + BuildInputSchema(); + BuildOutputSchema(); + + // optional attr + std::string trace_content; + GetNodeBytesAttr(opts_.node_def, "trace_content", &trace_content); + if (trace_content.empty()) { + dummy_flag_ = true; + return; + } + + bool content_json_flag = false; + GetNodeAttr(opts_.node_def, "content_json_flag", &content_json_flag); + + if (content_json_flag) { + JsonToPb(trace_content, &compute_trace_); + } else { + SERVING_ENFORCE(compute_trace_.ParseFromString(trace_content), + errors::ErrorCode::DESERIALIZE_FAILED, + "parse trace pb bytes failed"); + } + + if (compute_trace_.func_traces().empty()) { + dummy_flag_ = true; + return; + } + + // sanity check + // check the last compute func is returnable(output record_batch) + const auto& end_func = *(compute_trace_.func_traces().rbegin()); + SERVING_ENFORCE( + kReturnableFuncNames.find(end_func.name()) != kReturnableFuncNames.end(), + errors::ErrorCode::LOGIC_ERROR, + "the last compute function({}) is not returnable", end_func.name()); + result_id_ = end_func.output().data_id(); + + int num_fields = input_schema_list_.front()->num_fields(); + std::map data_id_map = { + {0, arrow::Datum::Kind::RECORD_BATCH}}; + + func_list_.reserve(compute_trace_.func_traces_size()); + for (int i = 0; i < compute_trace_.func_traces_size(); ++i) { + 
const auto& func = compute_trace_.func_traces(i); + compute::ExtendFunctionName ex_func_name; + auto input_kinds = BuildInputDatumKind(func.inputs(), data_id_map); + SERVING_ENFORCE(!input_kinds.empty(), errors::ErrorCode::LOGIC_ERROR); + if (compute::ExtendFunctionName_Parse(func.name(), &ex_func_name)) { + // check ext func inputs type valid + SERVING_ENFORCE(input_kinds[0] == arrow::Datum::Kind::RECORD_BATCH, + errors::ErrorCode::LOGIC_ERROR); + if (ex_func_name == compute::ExtendFunctionName::EFN_TB_COLUMN || + ex_func_name == compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN) { + // std::shared_ptr column(int) const + // + // Result> RemoveColumn(int) const + // + SERVING_ENFORCE(func.inputs_size() == 2, + errors::ErrorCode::INVALID_ARGUMENT); + // check index valid + SERVING_ENFORCE(input_kinds[1] == arrow::Datum::Kind::SCALAR, + errors::ErrorCode::LOGIC_ERROR); + const auto& index_scalar = func.inputs(1).custom_scalar(); + SERVING_ENFORCE(index_scalar.has_i64(), errors::ErrorCode::LOGIC_ERROR); + // 0 <= index < num_fields + SERVING_ENFORCE_GE(index_scalar.i64(), 0); + SERVING_ENFORCE_LT(index_scalar.i64(), num_fields); + + } else if (ex_func_name == + compute::ExtendFunctionName::EFN_TB_ADD_COLUMN || + ex_func_name == + compute::ExtendFunctionName::EFN_TB_SET_COLUMN) { + // Result> AddColumn(int, std::string + // field_name, const std::shared_ptr&) const + // + // Result> SetColumn(int, const + // std::shared_ptr&, const std::shared_ptr&) const + // + SERVING_ENFORCE(func.inputs_size() == 4, + errors::ErrorCode::INVALID_ARGUMENT); + SERVING_ENFORCE(input_kinds[1] == arrow::Datum::Kind::SCALAR && + input_kinds[2] == arrow::Datum::Kind::SCALAR, + errors::ErrorCode::LOGIC_ERROR); + + SERVING_ENFORCE(input_kinds[1] == arrow::Datum::Kind::SCALAR && + input_kinds[2] == arrow::Datum::Kind::SCALAR, + errors::ErrorCode::LOGIC_ERROR); + // check index valid + const auto& index_scalar = func.inputs(1).custom_scalar(); + SERVING_ENFORCE(index_scalar.has_i64(), 
errors::ErrorCode::LOGIC_ERROR); + // 0 <= index + SERVING_ENFORCE_GE(index_scalar.i64(), 0); + if (ex_func_name == compute::ExtendFunctionName::EFN_TB_ADD_COLUMN) { + // index <= num_fields + SERVING_ENFORCE_LE(index_scalar.i64(), num_fields); + } + if (ex_func_name == compute::ExtendFunctionName::EFN_TB_SET_COLUMN) { + // index < num_fields + SERVING_ENFORCE_LT(index_scalar.i64(), num_fields); + } + + // check field name valid + SERVING_ENFORCE(func.inputs(2).custom_scalar().value_case() == + compute::Scalar::ValueCase::kS, + errors::ErrorCode::INVALID_ARGUMENT, + "{}th input must be str for func:{}", 2, func.name()); + // check data valid + SERVING_ENFORCE(input_kinds[3] == arrow::Datum::Kind::ARRAY, + errors::ErrorCode::LOGIC_ERROR); + } else { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "invalid ext func name: {}", func.name()); + } + + auto output_kind = + ex_func_name == compute::ExtendFunctionName::EFN_TB_COLUMN + ? arrow::Datum::Kind::ARRAY + : arrow::Datum::Kind::RECORD_BATCH; + SERVING_ENFORCE( + data_id_map.emplace(func.output().data_id(), output_kind).second, + errors::ErrorCode::LOGIC_ERROR, "found duplicate data_id: {}", + func.output().data_id()); + + switch (ex_func_name) { + case compute::ExtendFunctionName::EFN_TB_COLUMN: { + func_list_.emplace_back([](arrow::Datum& result_datum, + std::vector& func_inputs) { + result_datum = func_inputs[0].record_batch()->column( + std::static_pointer_cast( + func_inputs[1].scalar()) + ->value); + }); + break; + } + case compute::ExtendFunctionName::EFN_TB_ADD_COLUMN: { + func_list_.emplace_back([](arrow::Datum& result_datum, + std::vector& func_inputs) { + int64_t index = std::static_pointer_cast( + func_inputs[1].scalar()) + ->value; + std::string field_name( + std::static_pointer_cast( + func_inputs[2].scalar()) + ->view()); + std::shared_ptr new_batch; + SERVING_GET_ARROW_RESULT( + func_inputs[0].record_batch()->AddColumn( + index, std::move(field_name), func_inputs[3].make_array()), + new_batch); 
+ result_datum = new_batch; + }); + break; + } + case compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN: { + func_list_.emplace_back([](arrow::Datum& result_datum, + std::vector& func_inputs) { + std::shared_ptr new_batch; + SERVING_GET_ARROW_RESULT( + func_inputs[0].record_batch()->RemoveColumn( + std::static_pointer_cast( + func_inputs[1].scalar()) + ->value), + new_batch); + result_datum = new_batch; + }); + break; + } + case compute::ExtendFunctionName::EFN_TB_SET_COLUMN: { + func_list_.emplace_back([](arrow::Datum& result_datum, + std::vector& func_inputs) { + int64_t index = std::static_pointer_cast( + func_inputs[1].scalar()) + ->value; + std::string field_name( + std::static_pointer_cast( + func_inputs[2].scalar()) + ->view()); + std::shared_ptr array = func_inputs[3].make_array(); + std::shared_ptr new_batch; + SERVING_GET_ARROW_RESULT( + func_inputs[0].record_batch()->SetColumn( + index, arrow::field(std::move(field_name), array->type()), + array), + new_batch); + result_datum = new_batch; + }); + break; + } + default: + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "invalid ext func name enum: {}", + static_cast(ex_func_name)); + } + } else { + // check arrow compute func valid + std::shared_ptr arrow_func; + SERVING_GET_ARROW_RESULT( + arrow::compute::GetFunctionRegistry()->GetFunction(func.name()), + arrow_func); + // Noticed, we only allowed scalar type arrow compute function + SERVING_ENFORCE( + arrow_func->kind() == arrow::compute::Function::Kind::SCALAR, + errors::ErrorCode::LOGIC_ERROR, "unsupported arrow compute func:{}", + func.name()); + + // check func options valid + if (arrow_func->doc().options_required) { + SERVING_ENFORCE(!func.option_bytes().empty(), + errors::ErrorCode::LOGIC_ERROR, + "arrow compute func:{} cannot be called without " + "options(empty `option_bytes`).", + func.name()); + } + if (func.option_bytes().empty()) { + func_list_.emplace_back( + [func_name = func.name()]( + arrow::Datum& result_datum, + std::vector& + 
func_inputs) { // call arrow compute func + std::for_each(func_inputs.begin(), func_inputs.end(), + [](const arrow::Datum& d) { + SERVING_ENFORCE(d.is_value(), + errors::ErrorCode::LOGIC_ERROR); + }); + SERVING_GET_ARROW_RESULT( + arrow::compute::CallFunction(func_name, func_inputs), + result_datum); + + }); + } else { + arrow::Buffer option_buf(func.option_bytes()); + std::unique_ptr func_opts; + SERVING_GET_ARROW_RESULT( + arrow::compute::FunctionOptions::Deserialize( + arrow_func->doc().options_class, option_buf), + func_opts); + + func_list_.emplace_back( + [func_name = func.name(), opt_ptr = func_opts.get(), arrow_func]( + arrow::Datum& result_datum, + std::vector& + func_inputs) { // call arrow compute func + std::for_each(func_inputs.begin(), func_inputs.end(), + [](const arrow::Datum& d) { + SERVING_ENFORCE(d.is_value(), + errors::ErrorCode::LOGIC_ERROR); + }); + + SERVING_GET_ARROW_RESULT( + arrow_func->Execute(func_inputs, opt_ptr, + arrow::compute::default_exec_context()), + result_datum); + }); + func_opt_map_.emplace(i, std::move(func_opts)); + } + + // check inputs invalid + for (size_t j = 0; j < input_kinds.size(); ++j) { + SERVING_ENFORCE(input_kinds[j] == arrow::Datum::Kind::ARRAY || + input_kinds[j] == arrow::Datum::Kind::SCALAR, + errors::ErrorCode::LOGIC_ERROR, + "invalid input type for func({}) {}th arg", func.name(), + j); + } + + SERVING_ENFORCE( + data_id_map + .emplace(func.output().data_id(), arrow::Datum::Kind::ARRAY) + .second, + errors::ErrorCode::LOGIC_ERROR, "found duplicate data_id: {}", + func.output().data_id()); + } + } +} + +void ArrowProcessing::DoCompute(ComputeContext* ctx) { + // sanity check + SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() == 1, + errors::ErrorCode::LOGIC_ERROR); + + if (dummy_flag_) { + ctx->output = ctx->inputs.front().front(); + return; + } + + SPDLOG_INFO("replay compute: {}", compute_trace_.name()); + + ctx->output = 
ReplayCompute(ctx->inputs.front().front()); +} + +std::shared_ptr ArrowProcessing::ReplayCompute( + const std::shared_ptr& input) { + std::map datas = {{0, input}}; + + arrow::Datum result_datum; + for (int i = 0; i < compute_trace_.func_traces_size(); ++i) { + const auto& func = compute_trace_.func_traces(i); + SPDLOG_DEBUG("replay func: {}", func.ShortDebugString()); + auto func_inputs = BuildInputDatums(func.inputs(), datas); + func_list_[i](result_datum, func_inputs); + + SERVING_ENFORCE( + datas.emplace(func.output().data_id(), std::move(result_datum)).second, + errors::ErrorCode::LOGIC_ERROR); + } + + return datas[result_id_].record_batch(); +} + +void ArrowProcessing::BuildInputSchema() { + input_schema_bytes_ = GetNodeBytesAttr(opts_.node_def, "input_schema_bytes"); + SERVING_ENFORCE(!input_schema_bytes_.empty(), + errors::ErrorCode::INVALID_ARGUMENT, + "get empty `input_schema_bytes`"); + auto input_schema = DeserializeSchema(input_schema_bytes_); + for (const auto& f : input_schema->fields()) { + CheckArrowDataTypeValid(f->type()); + } + input_schema_list_.emplace_back(std::move(input_schema)); +} + +void ArrowProcessing::BuildOutputSchema() { + output_schema_bytes_ = + GetNodeBytesAttr(opts_.node_def, "output_schema_bytes"); + SERVING_ENFORCE(!output_schema_bytes_.empty(), + errors::ErrorCode::INVALID_ARGUMENT, + "get empty `output_schema_bytes`"); + output_schema_ = DeserializeSchema(output_schema_bytes_); +} + +REGISTER_OP_KERNEL(ARROW_PROCESSING, ArrowProcessing) +REGISTER_OP(ARROW_PROCESSING, "0.0.1", "Replay secretflow compute functions") + .Returnable() + .BytesAttr("input_schema_bytes", + "Serialized data of input schema(arrow::Schema)", false, false) + .BytesAttr("output_schema_bytes", + "Serialized data of output schema(arrow::Schema)", false, false) + .BytesAttr("trace_content", "Serialized data of secretflow compute trace", + false, true, "") + .BoolAttr("content_json_flag", "Whether `trace_content` is serialized json", + false, true, false) + 
.Input("input", "") + .Output("output", ""); + +} // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/arrow_processing.h b/secretflow_serving/ops/arrow_processing.h new file mode 100644 index 0000000..b6de763 --- /dev/null +++ b/secretflow_serving/ops/arrow_processing.h @@ -0,0 +1,55 @@ + +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "secretflow_serving/ops/op_kernel.h" + +#include "secretflow_serving/protos/compute_trace.pb.h" + +namespace secretflow::serving::op { + +class ArrowProcessing : public OpKernel { + public: + explicit ArrowProcessing(OpKernelOptions opts); + + void DoCompute(ComputeContext* ctx) override; + + protected: + void BuildInputSchema() override; + + void BuildOutputSchema() override; + + std::shared_ptr ReplayCompute( + const std::shared_ptr& input); + + private: + compute::ComputeTrace compute_trace_; + + std::string input_schema_bytes_; + std::string output_schema_bytes_; + + int32_t result_id_; + + std::map> func_opt_map_; + std::vector&)>> + func_list_; + + bool dummy_flag_ = false; +}; + +} // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/arrow_processing_test.cc b/secretflow_serving/ops/arrow_processing_test.cc new file mode 100644 index 0000000..8fb664e --- /dev/null +++ b/secretflow_serving/ops/arrow_processing_test.cc @@ -0,0 +1,971 @@ +// Copyright 2023 Ant Group Co., Ltd. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/ops/arrow_processing.h" + +#include "arrow/compute/api.h" +#include "arrow/ipc/api.h" +#include "gtest/gtest.h" + +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/test_utils.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op { + +namespace { +std::shared_ptr BuildRecordBatch( + const std::vector& array_jsons, + std::vector> fields) { + std::vector> input_arrays; + std::shared_ptr tmp_array; + for (size_t i = 0; i < array_jsons.size(); ++i) { + SERVING_GET_ARROW_RESULT(arrow::ipc::internal::json::ArrayFromJSON( + fields[i]->type(), array_jsons[i]), + tmp_array); + input_arrays.emplace_back(tmp_array); + } + return MakeRecordBatch(arrow::schema(fields), tmp_array->length(), + std::move(input_arrays)); +} +} // namespace + +struct Param { + bool content_json_flag; + std::vector func_trace_contents; + std::map> + func_trace_opts; + + std::vector input_array_jsons; + std::vector> input_fields; + + std::vector output_array_jsons; + std::vector> output_fields; +}; + +class ArrowProcessingParamTest : public ::testing::TestWithParam { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_P(ArrowProcessingParamTest, Works) { + auto param = GetParam(); + + // build input & expect_output + 
auto input = BuildRecordBatch(param.input_array_jsons, param.input_fields); + + auto expect_output = + BuildRecordBatch(param.output_array_jsons, param.output_fields); + std::cout << "expect_output: " << expect_output->ToString() << std::endl; + + // build node + compute::ComputeTrace compute_trace; + compute_trace.set_name("test_compute"); + for (size_t i = 0; i < param.func_trace_contents.size(); ++i) { + auto* func_trace = compute_trace.add_func_traces(); + JsonToPb(param.func_trace_contents[i], func_trace); + + auto it = param.func_trace_opts.find(i); + if (it != param.func_trace_opts.end()) { + std::shared_ptr buf; + SERVING_GET_ARROW_RESULT(it->second->Serialize(), buf); + func_trace->set_option_bytes(reinterpret_cast(buf->data()), + buf->size()); + } + } + + NodeDef node_def; + node_def.set_name("test_node"); + node_def.set_op("ARROW_PROCESSING"); + + AttrValue trace_content; + AttrValue content_json_flag; + if (param.content_json_flag) { + trace_content.set_by(PbToJson(&compute_trace)); + content_json_flag.set_b(true); + node_def.mutable_attr_values()->insert( + {"content_json_flag", content_json_flag}); + } else { + trace_content.set_by(compute_trace.SerializeAsString()); + } + if (!param.func_trace_contents.empty()) { + node_def.mutable_attr_values()->insert({"trace_content", trace_content}); + } + + { + AttrValue input_schema_bytes; + std::shared_ptr buf; + SERVING_GET_ARROW_RESULT(arrow::ipc::SerializeSchema(*input->schema()), + buf); + input_schema_bytes.set_by(reinterpret_cast(buf->data()), + buf->size()); + node_def.mutable_attr_values()->insert( + {"input_schema_bytes", std::move(input_schema_bytes)}); + } + { + AttrValue output_schema_bytes; + std::shared_ptr buf; + SERVING_GET_ARROW_RESULT( + arrow::ipc::SerializeSchema(*expect_output->schema()), buf); + output_schema_bytes.set_by(reinterpret_cast(buf->data()), + buf->size()); + node_def.mutable_attr_values()->insert( + {"output_schema_bytes", std::move(output_schema_bytes)}); + } + + auto 
mock_node = std::make_shared(std::move(node_def)); + ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); + ASSERT_TRUE(mock_node->GetOpDef()->tag().returnable()); + + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + + // check input schema + ASSERT_EQ(kernel->GetInputsNum(), mock_node->GetOpDef()->inputs_size()); + const auto& input_schema_list = kernel->GetAllInputSchema(); + ASSERT_EQ(input_schema_list.size(), kernel->GetInputsNum()); + for (const auto& input_schema : input_schema_list) { + ASSERT_TRUE(input_schema->Equals(input->schema())); + } + + // check output schema + auto output_schema = kernel->GetOutputSchema(); + ASSERT_TRUE(output_schema->Equals(expect_output->schema())); + + // compute + ComputeContext compute_ctx; + compute_ctx.inputs.emplace_back( + std::vector>{input}); + + kernel->Compute(&compute_ctx); + + // check output + ASSERT_TRUE(compute_ctx.output); + + std::cout << "output: " << compute_ctx.output->ToString() << std::endl; + + double epsilon = 1E-13; + ASSERT_TRUE(compute_ctx.output->ApproxEquals( + *expect_output, arrow::EqualOptions::Defaults().atol(epsilon))); +} + +INSTANTIATE_TEST_SUITE_P( + ArrowProcessingParamTestSuite, ArrowProcessingParamTest, + ::testing::Values( + /*ext funcs*/ + Param{true, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + } + ], + "output": { + "data_id": 2 + } + })JSON", + R"JSON({ + "name": "add", + "inputs": [ + { + "data_id": 1 + }, + { + "data_id": 2 + } + ], + "output": { + "data_id": 3 + } + })JSON", + R"JSON({ + "name": "EFN_TB_SET_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + }, + { + "custom_scalar": { + "s": "s2" + } + }, + { + 
"data_id": 3 + } + ], + "output": { + "data_id": 4 + } + })JSON", + R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "data_id": 4 + }, + { + "custom_scalar": { + "i64": 2 + } + } + ], + "output": { + "data_id": 5 + } + })JSON", + R"JSON({ + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "data_id": 5 + }, + { + "custom_scalar": { + "i64": 2 + } + }, + { + "custom_scalar": { + "s": "a4" + } + }, + { + "data_id": 3 + } + ], + "output": { + "data_id": 6 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON", + R"JSON(["null", "null", "null"])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32()), + arrow::field("x3", arrow::utf8())}, + {R"JSON([1, 2, 3])JSON", R"JSON([5, 7, 9])JSON", + R"JSON([5, 7, 9])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32()), + arrow::field("a4", arrow::int32())}}, + /*run func with opts*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "round", + "inputs": [ + { + "data_id": 1 + } + ], + "output": { + "data_id": 2 + } + })JSON", + R"JSON({ + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + }, + { + "custom_scalar": { + "s": "round_c" + } + }, + { + "data_id": 2 + } + ], + "output": { + "data_id": 3 + } + })JSON"}, + {{1, std::make_shared( + 2, arrow::compute::RoundMode::DOWN)}}, + {R"JSON([1.234, 2.35864, 3.1415926])JSON"}, + {arrow::field("x1", arrow::float64())}, + {R"JSON([1.234, 2.35864, 3.1415926])JSON", + R"JSON([1.23, 2.35, 3.14])JSON"}, + {arrow::field("x1", arrow::float64()), + arrow::field("round_c", arrow::float64())}}, + /*run func with default opts*/ + Param{ + false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 
+ } + })JSON", + R"JSON({ + "name": "round", + "inputs": [ + { + "data_id": 1 + } + ], + "output": { + "data_id": 2 + } + })JSON", + R"JSON({ + "name": "EFN_TB_ADD_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + }, + { + "custom_scalar": { + "s": "round_c" + } + }, + { + "data_id": 2 + } + ], + "output": { + "data_id": 3 + } + })JSON"}, + {}, + {R"JSON([1.234, 2.78864, 3.1415926])JSON"}, + {arrow::field("x1", arrow::float64())}, + {R"JSON([1.234, 2.78864, 3.1415926])JSON", R"JSON([1, 3, 3])JSON"}, + {arrow::field("x1", arrow::float64()), + arrow::field("round_c", arrow::float64())}}, + /*dummy run*/ + Param{false, + {}, + {}, + {R"JSON([1.234, 2.78864, 3.1415926])JSON"}, + {arrow::field("x1", arrow::float64())}, + {R"JSON([1.234, 2.78864, 3.1415926])JSON"}, + {arrow::field("x1", arrow::float64())}})); + +class ArrowProcessingExceptionTest : public ::testing::TestWithParam { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_P(ArrowProcessingExceptionTest, Constructor) { + auto param = GetParam(); + + // build input & expect_output + auto input = BuildRecordBatch(param.input_array_jsons, param.input_fields); + + auto expect_output = + BuildRecordBatch(param.output_array_jsons, param.output_fields); + + // build node + compute::ComputeTrace compute_trace; + compute_trace.set_name("test_compute"); + for (size_t i = 0; i < param.func_trace_contents.size(); ++i) { + auto* func_trace = compute_trace.add_func_traces(); + JsonToPb(param.func_trace_contents[i], func_trace); + + auto it = param.func_trace_opts.find(i); + if (it != param.func_trace_opts.end()) { + std::shared_ptr buf; + SERVING_GET_ARROW_RESULT(it->second->Serialize(), buf); + func_trace->set_option_bytes(reinterpret_cast(buf->data()), + buf->size()); + } + } + + NodeDef node_def; + node_def.set_name("test_node"); + node_def.set_op("ARROW_PROCESSING"); + + // always use pb serialize + AttrValue content_json_flag; + 
content_json_flag.set_b(param.content_json_flag); + node_def.mutable_attr_values()->insert( + {"content_json_flag", content_json_flag}); + + AttrValue trace_content; + trace_content.set_by(compute_trace.SerializeAsString()); + node_def.mutable_attr_values()->insert({"trace_content", trace_content}); + + { + AttrValue input_schema_bytes; + std::shared_ptr buf; + SERVING_GET_ARROW_RESULT(arrow::ipc::SerializeSchema(*input->schema()), + buf); + input_schema_bytes.set_by(reinterpret_cast(buf->data()), + buf->size()); + node_def.mutable_attr_values()->insert( + {"input_schema_bytes", std::move(input_schema_bytes)}); + } + { + AttrValue output_schema_bytes; + std::shared_ptr buf; + SERVING_GET_ARROW_RESULT( + arrow::ipc::SerializeSchema(*expect_output->schema()), buf); + output_schema_bytes.set_by(reinterpret_cast(buf->data()), + buf->size()); + node_def.mutable_attr_values()->insert( + {"output_schema_bytes", std::move(output_schema_bytes)}); + } + + // create kernel + auto mock_node = std::make_shared(std::move(node_def)); + ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); + ASSERT_TRUE(mock_node->GetOpDef()->tag().returnable()); + + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + try { + OpKernelFactory::GetInstance()->Create(std::move(opts)); + } catch (const std::exception& e) { + std::cout << e.what() << std::endl; + } + + EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), + Exception); +} + +INSTANTIATE_TEST_SUITE_P( + ArrowProcessingExceptionTest, ArrowProcessingExceptionTest, + ::testing::Values( + /*wrong content format*/ Param{true, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + } + ], + "output": { + "data_id": 2 + } + })JSON", + R"JSON({ + "name": "add", + "inputs": [ + { + 
"data_id": 1 + }, + { + "data_id": 2 + } + ], + "output": { + "data_id": 3 + } + })JSON", + R"JSON({ + "name": "EFN_TB_SET_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + }, + { + "custom_scalar": { + "s": "s2" + } + }, + { + "data_id": 3 + } + ], + "output": { + "data_id": 4 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", + R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32())}, + {R"JSON([1, 2, 3])JSON", + R"JSON([5, 7, 9])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*invalid data_id*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "data_id": 1 + }, + { + "custom_scalar": { + "i64": 2 + } + } + ], + "output": { + "data_id": 2 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON", + R"JSON(["null", "null", "null"])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32()), + arrow::field("x3", arrow::utf8())}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*wrong data type*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "data_id": 1 + }, + { + "custom_scalar": { + "i64": 2 + } + } + ], + "output": { + "data_id": 2 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON", + R"JSON(["null", "null", "null"])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32()), + arrow::field("x3", arrow::utf8())}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*wrong index*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + 
"data_id": 0 + }, + { + "custom_scalar": { + "i64": 5 + } + } + ], + "output": { + "data_id": 1 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON", + R"JSON(["null", "null", "null"])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32()), + arrow::field("x3", arrow::utf8())}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*wrong index*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": -2 + } + } + ], + "output": { + "data_id": 1 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON", + R"JSON(["null", "null", "null"])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32()), + arrow::field("x3", arrow::utf8())}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*wrong index type*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "data_id": 0 + } + ], + "output": { + "data_id": 1 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON", + R"JSON(["null", "null", "null"])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32()), + arrow::field("x3", arrow::utf8())}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*wrong input type*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + } + ], + "output": { + "data_id": 2 + } + })JSON", + R"JSON({ + "name": "add", + "inputs": [ + { + "data_id": 0 + }, + 
{ + "data_id": 0 + } + ], + "output": { + "data_id": 3 + } + })JSON", + R"JSON({ + "name": "EFN_TB_SET_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + }, + { + "custom_scalar": { + "s": "s2" + } + }, + { + "data_id": 3 + } + ], + "output": { + "data_id": 4 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32())}, + {R"JSON([1, 2, 3])JSON", R"JSON([5, 7, 9])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*duplicate output_id*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "add", + "inputs": [ + { + "data_id": 1 + }, + { + "data_id": 1 + } + ], + "output": { + "data_id": 3 + } + })JSON", + R"JSON({ + "name": "EFN_TB_SET_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + }, + { + "custom_scalar": { + "s": "s2" + } + }, + { + "data_id": 3 + } + ], + "output": { + "data_id": 4 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32())}, + {R"JSON([1, 2, 3])JSON", R"JSON([5, 7, 9])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("s2", arrow::int32())}}, + /*not returnable*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32())}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 
6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32())}}, + /*require opts*/ + Param{false, + {R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 0 + } + } + ], + "output": { + "data_id": 1 + } + })JSON", + R"JSON({ + "name": "is_in", + "inputs": [ + { + "data_id": 1 + } + ], + "output": { + "data_id": 2 + } + })JSON", + R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [ + { + "data_id": 0 + }, + { + "custom_scalar": { + "i64": 1 + } + } + ], + "output": { + "data_id": 3 + } + })JSON"}, + {}, + {R"JSON([1, 2, 3])JSON", R"JSON([4, 5, 6])JSON"}, + {arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::int32())}, + {R"JSON([1, 2, 3])JSON"}, + {arrow::field("x1", arrow::int32())}})); + +} // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/dot_product.cc b/secretflow_serving/ops/dot_product.cc index e8a3d50..e84d4a3 100644 --- a/secretflow_serving/ops/dot_product.cc +++ b/secretflow_serving/ops/dot_product.cc @@ -16,11 +16,15 @@ #include +#include "arrow/compute/api.h" + #include "secretflow_serving/ops/node_def_util.h" #include "secretflow_serving/ops/op_factory.h" #include "secretflow_serving/ops/op_kernel_factory.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/protos/data_type.pb.h" + namespace secretflow::serving::op { namespace { @@ -32,17 +36,26 @@ Double::Matrix TableToMatrix(const std::shared_ptr& table) { Double::Matrix matrix; matrix.resize(rows, cols); - Eigen::Map map(matrix.data(), rows, cols); // 遍历table的每一列,将数据映射到Eigen::Matrix中 for (int i = 0; i < cols; ++i) { - auto col = table->column(i); + const auto& col = table->column(i); + std::shared_ptr double_array = col; + if (col->type_id() != arrow::Type::DOUBLE) { + arrow::Datum double_array_datum; + SERVING_GET_ARROW_RESULT( + arrow::compute::Cast( + col, arrow::compute::CastOptions::Safe(arrow::float64())), + double_array_datum); + double_array = 
std::move(double_array_datum).make_array(); + } + // index 0 is validity bitmap, real data start with 1 - auto data = col->data()->GetMutableValues(1); + auto data = double_array->data()->GetMutableValues(1); SERVING_ENFORCE(data, errors::ErrorCode::LOGIC_ERROR, - "found unsupport field type"); + "found unsupported field type"); Eigen::Map vec(data, rows); - map.col(i) = vec; + matrix.col(i) = vec; } return matrix; @@ -51,48 +64,55 @@ Double::Matrix TableToMatrix(const std::shared_ptr& table) { } // namespace DotProduct::DotProduct(OpKernelOptions opts) : OpKernel(std::move(opts)) { - output_col_name_ = - GetNodeAttr(opts_.node->node_def(), "output_col_name"); - // optional attr - GetNodeAttr(opts_.node->node_def(), "intercept", &intercept_); + SERVING_ENFORCE_EQ(opts_.node_def.op_version(), "0.0.2"); - // feature + // feature name + feature_name_list_ = + GetNodeAttr>(opts_.node_def, "feature_names"); std::set f_name_set; - feature_name_list_ = GetNodeAttr>( - opts_.node->node_def(), "feature_names"); for (auto& feature_name : feature_name_list_) { SERVING_ENFORCE(f_name_set.emplace(feature_name).second, - errors::ErrorCode::LOGIC_ERROR); + errors::ErrorCode::LOGIC_ERROR, + "found duplicate feature name:{}", feature_name); } - auto feature_weight_list = GetNodeAttr>( - opts_.node->node_def(), "feature_weights"); - - SERVING_ENFORCE(feature_name_list_.size() == feature_weight_list.size(), - errors::ErrorCode::UNEXPECTED_ERROR, - "attr:feature_names size={} does not match " - "attr:feature_weights size={}, node:{}, op:{}", - feature_name_list_.size(), feature_weight_list.size(), - opts_.node->node_def().name(), opts_.node->node_def().op()); + // feature types + feature_type_list_ = + GetNodeAttr>(opts_.node_def, "input_types"); + SERVING_ENFORCE_EQ(feature_name_list_.size(), feature_type_list_.size(), + "attr:feature_names size={} does not match " + "attr:input_types size={}, node:{}, op:{}", + feature_name_list_.size(), feature_type_list_.size(), + 
opts_.node_def.name(), opts_.node_def.op()); + + auto feature_weight_list = + GetNodeAttr>(opts_.node_def, "feature_weights"); + SERVING_ENFORCE_EQ(feature_name_list_.size(), feature_weight_list.size(), + "attr:feature_names size={} does not match " + "attr:feature_weights size={}, node:{}, op:{}", + feature_name_list_.size(), feature_weight_list.size(), + opts_.node_def.name(), opts_.node_def.op()); weights_ = Double::ColVec::Zero(feature_weight_list.size()); for (size_t i = 0; i < feature_weight_list.size(); i++) { weights_[i] = feature_weight_list[i]; } + output_col_name_ = + GetNodeAttr(opts_.node_def, "output_col_name"); + + // optional attr + GetNodeAttr(opts_.node_def, "intercept", &intercept_); + BuildInputSchema(); BuildOutputSchema(); } -void DotProduct::Compute(ComputeContext* ctx) { - SERVING_ENFORCE(ctx->inputs->size() == 1, errors::ErrorCode::LOGIC_ERROR); - SERVING_ENFORCE(ctx->inputs->front().size() == 1, +void DotProduct::DoCompute(ComputeContext* ctx) { + SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() == 1, errors::ErrorCode::LOGIC_ERROR); - auto input_table = ctx->inputs->front()[0]; - SERVING_ENFORCE(input_table->schema()->Equals(input_schema_list_.front()), - errors::ErrorCode::LOGIC_ERROR); - - auto features = TableToMatrix(input_table); + auto features = TableToMatrix(ctx->inputs.front().front()); Double::ColVec score_vec = features * weights_; score_vec.array() += intercept_; @@ -109,16 +129,15 @@ void DotProduct::Compute(ComputeContext* ctx) { void DotProduct::BuildInputSchema() { // build input schema - int inputs_size = opts_.node->GetOpDef()->inputs_size(); - for (int i = 0; i < inputs_size; ++i) { - std::vector> f_list; - for (const auto& f : feature_name_list_) { - f_list.emplace_back(arrow::field(f, arrow::float64())); - } - input_schema_list_.emplace_back(arrow::schema(std::move(f_list))); - // should only have 1 input - break; + std::vector> fields; + for 
(size_t i = 0; i < feature_name_list_.size(); ++i) { + auto data_type = DataTypeToArrowDataType(feature_type_list_[i]); + SERVING_ENFORCE( + arrow::is_numeric(data_type->id()), errors::INVALID_ARGUMENT, + "feature type must be numeric, get:{}", feature_type_list_[i]); + fields.emplace_back(arrow::field(feature_name_list_[i], data_type)); } + input_schema_list_.emplace_back(arrow::schema(std::move(fields))); } void DotProduct::BuildOutputSchema() { @@ -128,13 +147,21 @@ void DotProduct::BuildOutputSchema() { } REGISTER_OP_KERNEL(DOT_PRODUCT, DotProduct) -REGISTER_OP(DOT_PRODUCT, "0.0.1", +REGISTER_OP(DOT_PRODUCT, "0.0.2", "Calculate the dot product of feature weights and values") - .StringAttr("feature_names", "", true, false) - .DoubleAttr("feature_weights", "", true, false) - .StringAttr("output_col_name", "", false, false) - .DoubleAttr("intercept", "", false, true, 0.0d) - .Input("features", "") - .Output("ys", ""); + .StringAttr("feature_names", "List of feature names", true, false) + .DoubleAttr("feature_weights", "List of feature weights", true, false) + .StringAttr("input_types", + "List of input feature data types, Note that there is a loss " + "of precision when using `DT_FLOAT` type. 
Optional " + "value: DT_UINT8, " + "DT_INT8, DT_UINT16, DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, " + "DT_INT64, DT_FLOAT, DT_DOUBLE", + true, false) + .StringAttr("output_col_name", "Column name of partial y", false, false) + .DoubleAttr("intercept", "Value of model intercept", false, true, 0.0d) + .Input("features", "Input feature table") + .Output("partial_ys", + "The calculation results, they have a data type of `double`."); } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/dot_product.h b/secretflow_serving/ops/dot_product.h index a01ceeb..e19d948 100644 --- a/secretflow_serving/ops/dot_product.h +++ b/secretflow_serving/ops/dot_product.h @@ -23,7 +23,7 @@ class DotProduct : public OpKernel { public: explicit DotProduct(OpKernelOptions opts); - void Compute(ComputeContext* ctx) override; + void DoCompute(ComputeContext* ctx) override; protected: void BuildInputSchema() override; @@ -32,6 +32,7 @@ class DotProduct : public OpKernel { private: std::vector feature_name_list_; + std::vector feature_type_list_; std::string output_col_name_; diff --git a/secretflow_serving/ops/dot_product_benchmark.cc b/secretflow_serving/ops/dot_product_benchmark.cc new file mode 100644 index 0000000..dda2cd5 --- /dev/null +++ b/secretflow_serving/ops/dot_product_benchmark.cc @@ -0,0 +1,147 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "benchmark/benchmark.h" + +#include "secretflow_serving/ops/dot_product.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op { + +template +struct ArrowBuilderTrait { + using type = void; + constexpr static const char* const data_type = nullptr; +}; + +template <> +struct ArrowBuilderTrait { + using type = arrow::DoubleBuilder; + constexpr static const char* const data_type = "DT_DOUBLE"; +}; + +template <> +struct ArrowBuilderTrait { + using type = arrow::Int64Builder; + constexpr static const char* const data_type = "DT_INT64"; +}; + +template +void BMDotProductBench(benchmark::State& state) { + auto feature_nums = state.range(1); + auto row_nums = state.range(0); + + state.counters["log2(row_nums*feature_nums)"] = log2(row_nums * feature_nums); + state.counters["feature_nums"] = feature_nums; + state.counters["row_nums"] = row_nums; + state.SetLabel(ArrowBuilderTrait::data_type); + + // mock attr + AttrValue feature_name_value; + { + std::vector names; + for (int64_t i = 1; i <= feature_nums; ++i) { + names.push_back(fmt::format("x{}", i)); + } + feature_name_value.mutable_ss()->mutable_data()->Assign(names.begin(), + names.end()); + } + + AttrValue input_types; + { + std::vector types(feature_nums, + ArrowBuilderTrait::data_type); + input_types.mutable_ss()->mutable_data()->Assign(types.begin(), + types.end()); + } + + AttrValue feature_weight_value; + { + std::vector weights; + std::generate_n(std::back_inserter(weights), feature_nums, + []() { return rand(); }); + + feature_weight_value.mutable_ds()->mutable_data()->Assign(weights.begin(), + weights.end()); + } + AttrValue output_col_name_value; + output_col_name_value.set_s("score"); + AttrValue intercept_value; + intercept_value.set_d(1.313201881559211); + + // mock feature values + std::vector> 
feature_value_list; + std::generate_n(std::back_inserter(feature_value_list), feature_nums, + [row_nums]() { + std::vector tmp; + std::generate_n(std::back_inserter(tmp), row_nums, + [] { return rand(); }); + return tmp; + }); + + NodeDef node_def; + node_def.set_op_version("0.0.2"); + node_def.set_name("test_node"); + node_def.set_op("DOT_PRODUCT"); + node_def.mutable_attr_values()->insert({"feature_names", feature_name_value}); + node_def.mutable_attr_values()->insert( + {"feature_weights", feature_weight_value}); + node_def.mutable_attr_values()->insert({"input_types", input_types}); + node_def.mutable_attr_values()->insert( + {"output_col_name", output_col_name_value}); + node_def.mutable_attr_values()->insert({"intercept", intercept_value}); + auto mock_node = std::make_shared(std::move(node_def)); + + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + const auto& input_schema_list = kernel->GetAllInputSchema(); + + // compute + ComputeContext compute_ctx; + { + std::vector> arrays; + for (size_t i = 0; i < feature_value_list.size(); ++i) { + typename ArrowBuilderTrait::type builder; + // arrow::DoubleBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(feature_value_list[i])); + std::shared_ptr array; + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + arrays.emplace_back(array); + } + auto features = + MakeRecordBatch(input_schema_list.front(), row_nums, std::move(arrays)); + compute_ctx.inputs.emplace_back( + std::vector>{features}); + } + for (auto _ : state) { + kernel->Compute(&compute_ctx); + } +} + +BENCHMARK_TEMPLATE(BMDotProductBench, double) + ->ArgsProduct({ + benchmark::CreateRange(2, 1u << 17, /*multi=*/2), + benchmark::CreateRange(64, 512, /*multi=*/4), + }); + +BENCHMARK_TEMPLATE(BMDotProductBench, int64_t) + ->ArgsProduct({ + benchmark::CreateRange(2, 1u << 17, /*multi=*/2), + benchmark::CreateRange(64, 512, /*multi=*/4), + }); + +} // 
namespace secretflow::serving::op diff --git a/secretflow_serving/ops/dot_product_test.cc b/secretflow_serving/ops/dot_product_test.cc index 959e60d..fcdde18 100644 --- a/secretflow_serving/ops/dot_product_test.cc +++ b/secretflow_serving/ops/dot_product_test.cc @@ -14,10 +14,13 @@ #include "secretflow_serving/ops/dot_product.h" +#include "arrow/ipc/api.h" #include "gtest/gtest.h" +#include "secretflow_serving/ops/op_factory.h" #include "secretflow_serving/ops/op_kernel_factory.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/test_utils.h" #include "secretflow_serving/util/utils.h" namespace secretflow::serving::op { @@ -29,62 +32,71 @@ class DotProductTest : public ::testing::Test { }; TEST_F(DotProductTest, Works) { - // mock attr - AttrValue feature_name_value; - { - std::vector names = {"x1", "x2", "x3", "x4", "x5", - "x6", "x7", "x8", "x9", "x10"}; - feature_name_value.mutable_ss()->mutable_data()->Assign(names.begin(), - names.end()); - } - AttrValue feature_weight_value; - { - std::vector weights = {-0.32705051172041194, 0.95102386482712309, - 1.01145375640758, 1.3493328346102449, - -0.97103250283196174, -0.53749125086392879, - 0.92053884353604121, -0.72217737944554916, - 0.14693041881992241, -0.0707939985283586}; - feature_weight_value.mutable_ds()->mutable_data()->Assign(weights.begin(), - weights.end()); - } - AttrValue output_col_name_value; - output_col_name_value.set_s("score"); - AttrValue intercept_value; - intercept_value.set_d(1.313201881559211); - - // mock feature values - std::vector> feature_value_list = { - {93, 0.1}, {18, 0.1}, {17, 0.1}, {20, 0.1}, {76, 0.1}, - {74, 0.1}, {25, 0.1}, {2, 0.1}, {31, 0.1}, {37, 0.1}}; - - // expect result - double expect_score_0 = 1.313201881559211 + 1.01145375640758 * 17 + - -0.97103250283196174 * 76 + - -0.32705051172041194 * 93 + 1.3493328346102449 * 20 + - 0.95102386482712309 * 18 + -0.53749125086392879 * 74 + - 0.92053884353604121 * 25 + -0.72217737944554916 * 2 + - 
0.14693041881992241 * 31 + -0.0707939985283586 * 37; - double expect_score_1 = - 1.313201881559211 + 1.01145375640758 * 0.1 + -0.97103250283196174 * 0.1 + - -0.32705051172041194 * 0.1 + 1.3493328346102449 * 0.1 + - 0.95102386482712309 * 0.1 + -0.53749125086392879 * 0.1 + - 0.92053884353604121 * 0.1 + -0.72217737944554916 * 0.1 + - 0.14693041881992241 * 0.1 + -0.0707939985283586 * 0.1; - double epsilon = 1E-13; - + std::string json_content = R"JSON( +{ + "name": "test_node", + "op": "DOT_PRODUCT", + "attr_values": { + "feature_names": { + "ss": { + "data": [ + "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "x8", "x9", "x10" + ] + } + }, + "input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_FLOAT", "DT_INT8", "DT_UINT8", "DT_INT16", + "DT_UINT16", "DT_INT32", "DT_UINT32", "DT_INT64", "DT_UINT64" + ] + } + }, + "feature_weights": { + "ds": { + "data": [ + -0.32705051172041194, 0.95102386482712309, + 1.01145375640758, 1.3493328346102449, + -0.97103250283196174, -0.53749125086392879, + 0.92053884353604121, -0.72217737944554916, + 0.14693041881992241, -0.0707939985283586 + ] + } + }, + "output_col_name": { + "s": "score", + }, + "intercept": { + "d": 1.313201881559211 + } + }, + "op_version": "0.0.2", +} +)JSON"; NodeDef node_def; - node_def.set_name("test_node"); - node_def.set_op("DOT_PRODUCT"); - node_def.mutable_attr_values()->insert({"feature_names", feature_name_value}); - node_def.mutable_attr_values()->insert( - {"feature_weights", feature_weight_value}); - node_def.mutable_attr_values()->insert( - {"output_col_name", output_col_name_value}); - node_def.mutable_attr_values()->insert({"intercept", intercept_value}); + JsonToPb(json_content, &node_def); + + std::vector> input_fields = { + arrow::field("x1", arrow::float64()), + arrow::field("x2", arrow::float32()), + arrow::field("x3", arrow::int8()), + arrow::field("x4", arrow::uint8()), + arrow::field("x5", arrow::int16()), + arrow::field("x6", arrow::uint16()), + arrow::field("x7", arrow::int32()), + 
arrow::field("x8", arrow::uint32()), + arrow::field("x9", arrow::int64()), + arrow::field("x10", arrow::uint64())}; + + auto expect_input_schema = arrow::schema(input_fields); + auto expect_output_schema = + arrow::schema({arrow::field("score", arrow::float64())}); + auto mock_node = std::make_shared(std::move(node_def)); ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); - OpKernelOptions opts{mock_node}; + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); // check input schema @@ -93,41 +105,55 @@ TEST_F(DotProductTest, Works) { ASSERT_EQ(input_schema_list.size(), kernel->GetInputsNum()); for (size_t i = 0; i < input_schema_list.size(); ++i) { const auto& input_schema = input_schema_list[i]; - ASSERT_EQ(input_schema, kernel->GetInputSchema(i)); - ASSERT_EQ(input_schema->num_fields(), feature_name_value.ss().data_size()); - for (int j = 0; j < input_schema->num_fields(); ++j) { - auto field = input_schema->field(j); - ASSERT_EQ(field->name(), feature_name_value.ss().data(j)); - ASSERT_EQ(field->type()->id(), arrow::Type::type::DOUBLE); - } + ASSERT_TRUE(input_schema->Equals(expect_input_schema)); } - // check output schema auto output_schema = kernel->GetOutputSchema(); - ASSERT_EQ(output_schema->num_fields(), 1); - for (int j = 0; j < output_schema->num_fields(); ++j) { - auto field = output_schema->field(j); - ASSERT_EQ(field->name(), output_col_name_value.s()); - ASSERT_EQ(field->type()->id(), arrow::Type::type::DOUBLE); - } + ASSERT_TRUE(output_schema->Equals(expect_output_schema)); - // compute + // build input ComputeContext compute_ctx; - compute_ctx.inputs = std::make_shared(); { - std::vector> arrays; - for (size_t i = 0; i < feature_value_list.size(); ++i) { - arrow::DoubleBuilder builder; - SERVING_CHECK_ARROW_STATUS(builder.AppendValues(feature_value_list[i])); - std::shared_ptr array; - SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); - 
arrays.emplace_back(array); - } + std::shared_ptr x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11; + using arrow::ipc::internal::json::ArrayFromJSON; + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::float64(), "[93, -0.1]"), x1); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::float32(), "[18, 0.1]"), x2); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int8(), "[17, -1]"), x3); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::uint8(), "[20, 1]"), x4); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int16(), "[76, -1]"), x5); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::uint16(), "[74, 1]"), x6); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int32(), "[25, -1]"), x7); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::uint32(), "[2, 1]"), x8); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int64(), "[31, -1]"), x9); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::uint64(), "[37, 1]"), x10); + // redundant column + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::uint64(), "[23, 15]"), x11); + input_fields.emplace_back(arrow::field("x11", arrow::uint64())); + auto features = - MakeRecordBatch(input_schema_list.front(), 2, std::move(arrays)); - compute_ctx.inputs->emplace_back( - std::vector>{features}); + MakeRecordBatch(arrow::schema(input_fields), 2, + {x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11}); + auto shuffle_fs = test::ShuffleRecordBatch(features); + std::cout << shuffle_fs->ToString() << std::endl; + + compute_ctx.inputs.emplace_back( + std::vector>{shuffle_fs}); } + + // expect result + double expect_score_0 = 1.313201881559211 + -0.32705051172041194 * 93 + + 0.95102386482712309 * 18 + 1.01145375640758 * 17 + + 1.3493328346102449 * 20 + -0.97103250283196174 * 76 + + -0.53749125086392879 * 74 + 0.92053884353604121 * 25 + + -0.72217737944554916 * 2 + 0.14693041881992241 * 31 + + -0.0707939985283586 * 37; + double expect_score_1 = 1.313201881559211 + -0.32705051172041194 * -0.1 + + 0.95102386482712309 * 0.1 + 1.01145375640758 * -1 + + 1.3493328346102449 * 1 + 
-0.97103250283196174 * -1 + + -0.53749125086392879 * 1 + 0.92053884353604121 * -1 + + -0.72217737944554916 * 1 + 0.14693041881992241 * -1 + + -0.0707939985283586 * 1; + kernel->Compute(&compute_ctx); // check output @@ -143,14 +169,16 @@ TEST_F(DotProductTest, Works) { std::cout << "result: " << compute_ctx.output->column(0)->ToString() << std::endl; + // converting float to double causes the result to lose precision + // double epsilon = 1E-13; + double epsilon = 1E-8; ASSERT_TRUE(compute_ctx.output->column(0)->ApproxEquals( expect_score_array, arrow::EqualOptions::Defaults().atol(epsilon))); } TEST_F(DotProductTest, Constructor) { // default intercept - { - std::string json_content = R"JSON( + std::string json_content = R"JSON( { "name": "test_node", "op": "DOT_PRODUCT", @@ -163,6 +191,14 @@ TEST_F(DotProductTest, Constructor) { ] } }, + "input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE" + ] + } + }, "feature_weights": { "ds": { "data": [ @@ -177,20 +213,44 @@ TEST_F(DotProductTest, Constructor) { "output_col_name": { "s": "y", } - } + }, + "op_version": "0.0.2", } )JSON"; - NodeDef node_def; - JsonToPb(json_content, &node_def); + NodeDef node_def; + JsonToPb(json_content, &node_def); - OpKernelOptions opts{std::make_shared(std::move(node_def))}; - EXPECT_NO_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts))); - } + auto op_def = OpFactory::GetInstance()->Get("DOT_PRODUCT"); + OpKernelOptions opts{std::move(node_def), op_def}; + EXPECT_NO_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts))); +} - // name and weight num mismatch - { - std::string json_content = R"JSON( +struct Param { + std::string node_content; +}; + +class DotProductExceptionTest : public ::testing::TestWithParam { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_P(DotProductExceptionTest, Constructor) { + auto param = 
GetParam(); + + NodeDef node_def; + JsonToPb(param.node_content, &node_def); + + auto op_def = OpFactory::GetInstance()->Get(node_def.op()); + OpKernelOptions opts{std::move(node_def), op_def}; + EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), + Exception); +} + +INSTANTIATE_TEST_SUITE_P( + DotProductExceptionTestSuite, DotProductExceptionTest, + ::testing::Values(/*name and weight num mismatch*/ Param{R"JSON( { "name": "test_node", "op": "DOT_PRODUCT", @@ -203,6 +263,14 @@ TEST_F(DotProductTest, Constructor) { ] } }, + "input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE" + ] + } + }, "feature_weights": { "ds": { "data": [ @@ -217,21 +285,11 @@ TEST_F(DotProductTest, Constructor) { "output_col_name": { "s": "y", } - } + }, + "op_version": "0.0.2", } -)JSON"; - - NodeDef node_def; - JsonToPb(json_content, &node_def); - - OpKernelOptions opts{std::make_shared(std::move(node_def))}; - EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), - Exception); - } - - // missing feature_weights - { - std::string json_content = R"JSON( +)JSON"}, + /*missing feature_weights*/ Param{R"JSON( { "name": "test_node", "op": "DOT_PRODUCT", @@ -244,28 +302,34 @@ TEST_F(DotProductTest, Constructor) { ] } }, + "input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE" + ] + } + }, "output_col_name": { "s": "y", } - } + }, + "op_version": "0.0.2", } -)JSON"; - - NodeDef node_def; - JsonToPb(json_content, &node_def); - - OpKernelOptions opts{std::make_shared(std::move(node_def))}; - EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), - Exception); - } - - // missing feature_names - { - std::string json_content = R"JSON( +)JSON"}, + /*missing feature_names*/ Param{R"JSON( { "name": "test_node", "op": "DOT_PRODUCT", "attr_values": { + 
"input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE" + ] + } + }, "feature_weights": { "ds": { "data": [ @@ -280,21 +344,11 @@ TEST_F(DotProductTest, Constructor) { "output_col_name": { "s": "y", } - } + }, + "op_version": "0.0.2", } -)JSON"; - - NodeDef node_def; - JsonToPb(json_content, &node_def); - - OpKernelOptions opts{std::make_shared(std::move(node_def))}; - EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), - Exception); - } - - // missing output_col_name - { - std::string json_content = R"JSON( +)JSON"}, + /*missing output_col_name*/ Param{R"JSON( { "name": "test_node", "op": "DOT_PRODUCT", @@ -307,6 +361,14 @@ TEST_F(DotProductTest, Constructor) { ] } }, + "input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE" + ] + } + }, "feature_weights": { "ds": { "data": [ @@ -318,17 +380,79 @@ TEST_F(DotProductTest, Constructor) { ] } } - } + }, + "op_version": "0.0.2", } -)JSON"; - - NodeDef node_def; - JsonToPb(json_content, &node_def); - - OpKernelOptions opts{std::make_shared(std::move(node_def))}; - EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), - Exception); - } +)JSON"}, + /*missing feature types*/ Param{R"JSON( +{ + "name": "test_node", + "op": "DOT_PRODUCT", + "attr_values": { + "feature_names": { + "ss": { + "data": [ + "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "x8", "x9" + ] + } + }, + "feature_weights": { + "ds": { + "data": [ + -0.32705051172041194, 0.95102386482712309, + 1.01145375640758, 1.3493328346102449, + -0.97103250283196174, -0.53749125086392879, + 0.92053884353604121, -0.72217737944554916, + 0.14693041881992241, -0.0707939985283586 + ] + } + }, + "output_col_name": { + "s": "y", + } + }, + "op_version": "0.0.2", +} +)JSON"}, + /*mismatch op version*/ Param{R"JSON( +{ + 
"name": "test_node", + "op": "DOT_PRODUCT", + "attr_values": { + "feature_names": { + "ss": { + "data": [ + "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "x8", "x9", "x10" + ] + } + }, + "input_types": { + "ss": { + "data": [ + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", + "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE", "DT_DOUBLE" + ] + } + }, + "feature_weights": { + "ds": { + "data": [ + -0.32705051172041194, 0.95102386482712309, + 1.01145375640758, 1.3493328346102449, + -0.97103250283196174, -0.53749125086392879, + 0.92053884353604121, -0.72217737944554916, + 0.14693041881992241, -0.0707939985283586 + ] + } + }, + "output_col_name": { + "s": "y", + } + }, + "op_version": "0.0.1", } +)JSON"})); } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/graph.cc b/secretflow_serving/ops/graph.cc index 2fbf9f1..4f96101 100644 --- a/secretflow_serving/ops/graph.cc +++ b/secretflow_serving/ops/graph.cc @@ -28,7 +28,7 @@ namespace { // BFS, out_node ---> in_node void NodeTraversal( std::unordered_map>* visited, - const std::map>& nodes) { + const std::unordered_map>& nodes) { std::deque> queue; std::unordered_set> visited_edges; for (const auto& pair : *visited) { @@ -36,7 +36,7 @@ void NodeTraversal( } while (!queue.empty()) { - const auto& n = queue.front(); + auto n = queue.front(); queue.pop_front(); const auto& in_edges = n->in_edges(); for (const auto& e : in_edges) { @@ -56,8 +56,9 @@ void NodeTraversal( } // namespace -Execution::Execution(size_t id, ExecutionDef execution_def, - std::map> nodes) +Execution::Execution( + size_t id, ExecutionDef execution_def, + std::unordered_map> nodes) : id_(id), execution_def_(std::move(execution_def)), nodes_(std::move(nodes)), @@ -65,14 +66,17 @@ Execution::Execution(size_t id, ExecutionDef execution_def, is_exit_(false) { // get execution exit nodes & entry nodes for (const auto& [node_name, node] : nodes_) { - const auto& dst_edge = node->out_edge(); + const auto& dst_edges = 
node->out_edges(); const auto& in_edges = node->in_edges(); // find exit nodes - if (dst_edge == nullptr) { + if (dst_edges.empty()) { exit_node_names_.emplace(node_name); is_exit_ = true; } else { - if (nodes_.find(dst_edge->dst_node()) == nodes_.end()) { + if (std::any_of(dst_edges.begin(), dst_edges.end(), + [&](const auto& edge) { + return nodes_.find(edge->dst_node()) == nodes_.end(); + })) { exit_node_names_.emplace(node_name); } } @@ -106,10 +110,21 @@ bool Execution::IsExitNode(const std::string& node_name) const { const std::shared_ptr& Execution::GetNode(const std::string& name) const { auto iter = nodes_.find(name); - SERVING_ENFORCE(iter != nodes_.end(), errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(iter != nodes_.end(), errors::ErrorCode::LOGIC_ERROR, + "can not find {} in execution {}", name, id_); return iter->second; } +bool Execution::TryGetNode(const std::string& name, + std::shared_ptr* node) const { + auto iter = nodes_.find(name); + if (iter == nodes_.end()) { + return false; + } + *node = iter->second; + return true; +} + void Execution::CheckNodesReachability() { std::unordered_map> reachable_nodes; for (const auto& n : exit_node_names_) { @@ -137,6 +152,17 @@ Graph::Graph(GraphDef graph_def) : def_(std::move(graph_def)) { // TODO: consider not storing def_ to avoiding multiple copies of node_defs // and execution_defs + graph_view_.set_version(def_.version()); + for (auto& node : def_.node_list()) { + NodeView view; + *(view.mutable_name()) = node.name(); + *(view.mutable_op()) = node.op(); + *(view.mutable_op_version()) = node.op_version(); + *(view.mutable_parents()) = node.parents(); + graph_view_.mutable_node_list()->Add(std::move(view)); + } + *(graph_view_.mutable_execution_list()) = def_.execution_list(); + // create nodes for (int i = 0; i < def_.node_list_size(); ++i) { const auto node_name = def_.node_list(i).name(); @@ -163,7 +189,7 @@ Graph::Graph(GraphDef graph_def) : def_(std::move(graph_def)) { "can not found input 
node:{} for node:{}", input_nodes[i], name); auto edge = std::make_shared(n_iter->first, name, i); - n_iter->second->SetOutEdge(edge); + n_iter->second->AddOutEdge(edge); node->AddInEdge(edge); edges_.emplace_back(edge); } @@ -172,14 +198,13 @@ Graph::Graph(GraphDef graph_def) : def_(std::move(graph_def)) { // find exit node size_t exit_node_count = 0; for (const auto& pair : nodes_) { - if (pair.second->out_edge() == nullptr) { + if (pair.second->out_edges().empty()) { exit_node_ = pair.second; ++exit_node_count; } } SERVING_ENFORCE(!entry_nodes_.empty(), errors::ErrorCode::LOGIC_ERROR, - "can not found any entry node, please check graph def.", - exit_node_count); + "can not found any entry node, please check graph def."); SERVING_ENFORCE(exit_node_count == 1, errors::ErrorCode::LOGIC_ERROR, "found {} exit nodes, expect only 1 in graph", exit_node_count); @@ -217,12 +242,12 @@ void Graph::CheckNodesReachability() { } void Graph::CheckEdgeValidate() { - std::map> kernel_map; + std::unordered_map> kernel_map; const auto get_kernel_func = [&](const std::shared_ptr& n) -> std::shared_ptr { auto iter = kernel_map.find(n->GetName()); if (iter == kernel_map.end()) { - op::OpKernelOptions ctx{n}; + op::OpKernelOptions ctx{n->node_def(), n->GetOpDef()}; auto kernel = op::OpKernelFactory::GetInstance()->Create(std::move(ctx)); kernel_map.emplace(n->GetName(), kernel); return kernel; @@ -238,33 +263,20 @@ void Graph::CheckEdgeValidate() { const auto& src_schema = src_kernel->GetOutputSchema(); const auto& dst_schema = dst_kernel->GetInputSchema(e->dst_input_id()); - SERVING_ENFORCE(src_schema->num_fields() == dst_schema->num_fields(), - errors::ErrorCode::LOGIC_ERROR, - "node({}) output schema does not fit node({}) input " - "schema, size: {}-{}", - e->src_node(), e->dst_node(), src_schema->num_fields(), - dst_schema->num_fields()); - for (int i = 0; i < src_schema->num_fields(); ++i) { - const auto& src_f = src_schema->field(i); - auto dst_f = 
dst_schema->GetFieldByName(src_f->name()); - SERVING_ENFORCE(dst_f, errors::ErrorCode::LOGIC_ERROR, - "node({}) output schema does not fit node({}) input " - "schema, missed field:{}", - e->src_node(), e->dst_node(), src_f->name()); - SERVING_ENFORCE(src_f->Equals(dst_f), errors::ErrorCode::LOGIC_ERROR, - "node({}) output schema does not fit node({}) input " - "schema, field:{} not equal"); - } + // Check the dst_schema is a subset of the src_schema + CheckReferenceFields( + src_schema, dst_schema, + fmt::format("edge schema check failed, src: {}, dst: {}", e->src_node(), + e->dst_node())); } } void Graph::BuildExecution() { std::unordered_set node_name_set; const auto& execution_def_list = def_.execution_list(); - SERVING_ENFORCE(execution_def_list.size() > 0, - errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(!execution_def_list.empty(), errors::ErrorCode::LOGIC_ERROR); for (int i = 0; i < execution_def_list.size(); ++i) { - std::map> nodes; + std::unordered_map> nodes; for (const auto& n_name : execution_def_list[i].nodes()) { auto n_iter = nodes_.find(n_name); SERVING_ENFORCE(n_iter != nodes_.end(), errors::ErrorCode::LOGIC_ERROR, @@ -284,18 +296,26 @@ void Graph::BuildExecution() { } void Graph::CheckExecutionValidate() { - // TODO: remove limit - SERVING_ENFORCE(executions_.size() == 2, errors::ErrorCode::LOGIC_ERROR, - "graph must contain 2 executions."); + SERVING_ENFORCE(executions_.size() >= 2, errors::ErrorCode::LOGIC_ERROR, + "graph must contain 2 executions at least."); + + auto prev_dispatch_type = + DispatchType::DispatchType_INT_MAX_SENTINEL_DO_NOT_USE_; for (const auto& e : executions_) { + SERVING_ENFORCE(e->GetDispatchType() != prev_dispatch_type, + errors::ErrorCode::LOGIC_ERROR, + "The dispatch types of two adjacent executions cannot be " + "the same, cur exeution id {}, type: {}", + e->id(), DispatchType_Name(e->GetDispatchType())); + prev_dispatch_type = e->GetDispatchType(); + if (e->IsEntry()) { SERVING_ENFORCE(e->GetDispatchType() == 
DispatchType::DP_ALL, errors::ErrorCode::LOGIC_ERROR); } if (e->IsExit()) { - // TODO: allow DP_SPECIFIED - SERVING_ENFORCE(e->GetDispatchType() == DispatchType::DP_ANYONE, + SERVING_ENFORCE(e->GetDispatchType() != DispatchType::DP_ALL, errors::ErrorCode::LOGIC_ERROR); SERVING_ENFORCE(e->GetExitNodeNum() == 1, errors::ErrorCode::LOGIC_ERROR); SERVING_ENFORCE(e->IsExitNode(exit_node_->GetName()), @@ -304,4 +324,11 @@ void Graph::CheckExecutionValidate() { } } +const std::shared_ptr& Graph::GetNode(const std::string& name) const { + auto iter = nodes_.find(name); + SERVING_ENFORCE(iter != nodes_.end(), errors::ErrorCode::LOGIC_ERROR, + "can not find node({}) in graph", name); + return iter->second; +} + } // namespace secretflow::serving diff --git a/secretflow_serving/ops/graph.h b/secretflow_serving/ops/graph.h index 6118291..003872d 100644 --- a/secretflow_serving/ops/graph.h +++ b/secretflow_serving/ops/graph.h @@ -14,21 +14,24 @@ #pragma once -#include #include -#include +#include +#include #include #include "arrow/api.h" #include "secretflow_serving/ops/node.h" +#include "secretflow_serving/protos/graph.pb.h" + namespace secretflow::serving { class Execution final { public: - explicit Execution(size_t id, ExecutionDef execution_def, - std::map> nodes); + explicit Execution( + size_t id, ExecutionDef execution_def, + std::unordered_map> nodes); ~Execution() = default; size_t id() const { return id_; } @@ -41,6 +44,10 @@ class Execution final { DispatchType GetDispatchType() const; + bool SpecificToThis() const { + return execution_def_.config().specific_flag(); + } + size_t GetEntryNodeNum() const; size_t GetExitNodeNum() const; @@ -51,25 +58,27 @@ class Execution final { return entry_nodes_; } - const std::map>& nodes() const { + const std::unordered_map>& nodes() const { return nodes_; } const std::shared_ptr& GetNode(const std::string& name) const; + bool TryGetNode(const std::string& name, std::shared_ptr* node) const; + protected: void 
CheckNodesReachability(); private: const size_t id_; const ExecutionDef execution_def_; - const std::map> nodes_; + const std::unordered_map> nodes_; bool is_entry_; bool is_exit_; std::vector> entry_nodes_; - std::set exit_node_names_; + std::unordered_set exit_node_names_; }; class Graph final { @@ -79,10 +88,14 @@ class Graph final { const GraphDef& def() { return def_; } + GraphView GetView() { return graph_view_; } + const std::vector>& GetExecutions() const { return executions_; } + const std::shared_ptr& GetNode(const std::string& name) const; + protected: void CheckNodesReachability(); @@ -95,7 +108,9 @@ class Graph final { private: const GraphDef def_; - std::map> nodes_; + GraphView graph_view_; + + std::unordered_map> nodes_; std::vector> edges_; std::vector> executions_; diff --git a/secretflow_serving/ops/graph_test.cc b/secretflow_serving/ops/graph_test.cc index 14fabad..829c917 100644 --- a/secretflow_serving/ops/graph_test.cc +++ b/secretflow_serving/ops/graph_test.cc @@ -27,12 +27,13 @@ class MockOpKernel0 : public OpKernel { explicit MockOpKernel0(OpKernelOptions opts) : OpKernel(std::move(opts)) { auto schema = arrow::schema({arrow::field("test_field_0", arrow::float64()), - arrow::field("test_field_1", arrow::float64())}); + arrow::field("test_field_1", arrow::float64()), + arrow::field("test_field_2", arrow::float64())}); input_schema_list_ = {schema}; output_schema_ = schema; } - void Compute(ComputeContext* ctx) override {} + void DoCompute(ComputeContext* ctx) override {} void BuildInputSchema() override {} void BuildOutputSchema() override {} }; @@ -41,14 +42,14 @@ class MockOpKernel1 : public OpKernel { public: explicit MockOpKernel1(OpKernelOptions opts) : OpKernel(std::move(opts)) { auto schema = - arrow::schema({arrow::field("test_field_0", arrow::float64()), - arrow::field("test_field_1", arrow::float64())}); + arrow::schema({arrow::field("test_field_1", arrow::float64()), + arrow::field("test_field_0", arrow::float64())}); 
input_schema_list_ = {schema}; output_schema_ = arrow::schema({arrow::field("test_field_a", arrow::float64())}); } - void Compute(ComputeContext* ctx) override {} + void DoCompute(ComputeContext* ctx) override {} void BuildInputSchema() override {} void BuildOutputSchema() override {} }; @@ -62,7 +63,24 @@ class MockOpKernel2 : public OpKernel { output_schema_ = schema; } - void Compute(ComputeContext* ctx) override {} + void DoCompute(ComputeContext* ctx) override {} + void BuildInputSchema() override {} + void BuildOutputSchema() override {} +}; + +class MockOpKernel3 : public OpKernel { + public: + explicit MockOpKernel3(OpKernelOptions opts) : OpKernel(std::move(opts)) { + auto schema = + arrow::schema({arrow::field("test_field_2", arrow::float64()), + arrow::field("test_field_0", arrow::float64())}); + auto schema1 = + arrow::schema({arrow::field("test_field_a", arrow::float64())}); + input_schema_list_ = {schema, schema1}; + output_schema_ = schema1; + } + + void DoCompute(ComputeContext* ctx) override {} void BuildInputSchema() override {} void BuildOutputSchema() override {} }; @@ -70,6 +88,7 @@ class MockOpKernel2 : public OpKernel { REGISTER_OP_KERNEL(TEST_OP_0, MockOpKernel0); REGISTER_OP_KERNEL(TEST_OP_1, MockOpKernel1); REGISTER_OP_KERNEL(TEST_OP_2, MockOpKernel2); +REGISTER_OP_KERNEL(TEST_OP_3, MockOpKernel3); REGISTER_OP(TEST_OP_0, "0.0.1", "test_desc") .StringAttr("attr_s", "attr_s_desc", false, false) .Input("input", "input_desc") @@ -84,6 +103,12 @@ REGISTER_OP(TEST_OP_2, "0.0.1", "test_desc") .StringAttr("attr_s", "attr_s_desc", false, false) .Input("input", "input_desc") .Output("output", "output_desc"); +REGISTER_OP(TEST_OP_3, "0.0.1", "test_desc") + .Returnable() + .StringAttr("attr_s", "attr_s_desc", false, false) + .Input("input", "input_desc") + .Input("input2", "input_desc") + .Output("output", "output_desc"); class GraphTest : public ::testing::Test { protected: @@ -118,7 +143,7 @@ TEST_F(GraphTest, Works) { { "name": "node_c", "op": 
"TEST_OP_1", - "parents": [ "node_b" ], + "parents": [ "node_a" ], "attr_values": { "attr_s": { "s": "b" @@ -127,8 +152,18 @@ TEST_F(GraphTest, Works) { }, { "name": "node_d", + "op": "TEST_OP_3", + "parents": [ "node_b", "node_c" ], + "attr_values": { + "attr_s": { + "s": "b" + }, + }, + }, + { + "name": "node_e", "op": "TEST_OP_2", - "parents": [ "node_c" ], + "parents": [ "node_d" ], "attr_values": { "attr_s": { "s": "b" @@ -139,7 +174,7 @@ TEST_F(GraphTest, Works) { "execution_list": [ { "nodes": [ - "node_a", "node_b" + "node_a", "node_b", "node_c", "node_d", ], "config": { "dispatch_type": "DP_ALL" @@ -147,7 +182,7 @@ TEST_F(GraphTest, Works) { }, { "nodes": [ - "node_c", "node_d" + "node_e" ], "config": { "dispatch_type": "DP_ANYONE" @@ -170,8 +205,10 @@ TEST_F(GraphTest, Works) { EXPECT_FALSE(execution_list[0]->IsExit()); EXPECT_EQ(execution_list[0]->GetEntryNodeNum(), 1); EXPECT_EQ(execution_list[0]->GetExitNodeNum(), 1); - EXPECT_TRUE(execution_list[0]->IsExitNode("node_b")); + EXPECT_FALSE(execution_list[0]->IsExitNode("node_b")); EXPECT_FALSE(execution_list[0]->IsExitNode("node_a")); + EXPECT_FALSE(execution_list[0]->IsExitNode("node_c")); + EXPECT_TRUE(execution_list[0]->IsExitNode("node_d")); EXPECT_TRUE(execution_list[1]); EXPECT_EQ(execution_list[1]->id(), 1); @@ -179,8 +216,7 @@ TEST_F(GraphTest, Works) { EXPECT_TRUE(execution_list[1]->IsExit()); EXPECT_EQ(execution_list[1]->GetEntryNodeNum(), 1); EXPECT_EQ(execution_list[1]->GetExitNodeNum(), 1); - EXPECT_TRUE(execution_list[1]->IsExitNode("node_d")); - EXPECT_FALSE(execution_list[1]->IsExitNode("node_c")); + EXPECT_TRUE(execution_list[1]->IsExitNode("node_e")); } struct ErrorParam { @@ -452,26 +488,6 @@ INSTANTIATE_TEST_SUITE_P( "s": "a" }, }, - }, - { - "name": "node_b", - "op": "TEST_OP_1", - "parents": [ "node_a" ], - "attr_values": { - "attr_s": { - "s": "b" - }, - }, - }, - { - "name": "node_c", - "op": "TEST_OP_2", - "parents": [ "node_b" ], - "attr_values": { - "attr_s": { - "s": "b" - }, 
- }, } ], "execution_list": [ @@ -482,22 +498,6 @@ INSTANTIATE_TEST_SUITE_P( "config": { "dispatch_type": "DP_ALL" } - }, - { - "nodes": [ - "node_b" - ], - "config": { - "dispatch_type": "DP_ANYONE" - } - }, - { - "nodes": [ - "node_c" - ], - "config": { - "dispatch_type": "DP_ANYONE" - } } ] } diff --git a/secretflow_serving/ops/graph_version.h b/secretflow_serving/ops/graph_version.h new file mode 100644 index 0000000..2b4252d --- /dev/null +++ b/secretflow_serving/ops/graph_version.h @@ -0,0 +1,14 @@ +#pragma once + +// Version upgrade when `GraphDef` changed. +#define SERVING_GRAPH_MAJOR_VERSION 0 +#define SERVING_GRAPH_MINOR_VERSION 1 +#define SERVING_GRAPH_PATCH_VERSION 0 + +#define SERVING_STR_HELPER(x) #x +#define SERVING_STR(x) SERVING_STR_HELPER(x) + +#define SERVING_GRAPH_VERSION_STRING \ + SERVING_STR(SERVING_GRAPH_MAJOR_VERSION) \ + "." SERVING_STR(SERVING_GRAPH_MINOR_VERSION) "." SERVING_STR( \ + SERVING_GRAPH_PATCH_VERSION) diff --git a/secretflow_serving/ops/merge_y.cc b/secretflow_serving/ops/merge_y.cc index 31345f6..b86d7b4 100644 --- a/secretflow_serving/ops/merge_y.cc +++ b/secretflow_serving/ops/merge_y.cc @@ -27,37 +27,31 @@ namespace secretflow::serving::op { MergeY::MergeY(OpKernelOptions opts) : OpKernel(std::move(opts)) { - link_function_ = - GetNodeAttr(opts_.node->node_def(), "link_function"); - ValidateLinkFuncType(link_function_); + auto link_function_name = + GetNodeAttr(opts_.node_def, "link_function"); + link_function_ = ParseLinkFuncType(link_function_name); // optional attr - GetNodeAttr(opts_.node->node_def(), "yhat_scale", &yhat_scale_); + GetNodeAttr(opts_.node_def, "yhat_scale", &yhat_scale_); - input_col_name_ = - GetNodeAttr(opts_.node->node_def(), "input_col_name"); + input_col_name_ = GetNodeAttr(opts_.node_def, "input_col_name"); output_col_name_ = - GetNodeAttr(opts_.node->node_def(), "output_col_name"); + GetNodeAttr(opts_.node_def, "output_col_name"); BuildInputSchema(); BuildOutputSchema(); } -void 
MergeY::Compute(ComputeContext* ctx) { +void MergeY::DoCompute(ComputeContext* ctx) { // santiy check - SERVING_ENFORCE(ctx->inputs->size() == 1, errors::ErrorCode::LOGIC_ERROR); - SERVING_ENFORCE(ctx->inputs->front().size() >= 1, + SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() >= 1, errors::ErrorCode::LOGIC_ERROR); - auto num_rows = ctx->inputs->front()[0]->num_rows(); - for (size_t i = 1; i < ctx->inputs->front().size(); ++i) { - auto cur_num_rows = ctx->inputs->front()[i]->num_rows(); - SERVING_ENFORCE_EQ(num_rows, cur_num_rows); - } // merge partial_y - arrow::Datum incremented_datum(ctx->inputs->front()[0]->column(0)); - for (size_t i = 1; i < ctx->inputs->front().size(); ++i) { - auto cur_array = ctx->inputs->front()[i]->column(0); + arrow::Datum incremented_datum(ctx->inputs.front()[0]->column(0)); + for (size_t i = 1; i < ctx->inputs.front().size(); ++i) { + auto cur_array = ctx->inputs.front()[i]->column(0); SERVING_GET_ARROW_RESULT(arrow::compute::Add(incremented_datum, cur_array), incremented_datum); } @@ -74,7 +68,8 @@ void MergeY::Compute(ComputeContext* ctx) { } std::shared_ptr res_array; SERVING_CHECK_ARROW_STATUS(builder.Finish(&res_array)); - ctx->output = MakeRecordBatch(output_schema_, num_rows, {res_array}); + ctx->output = + MakeRecordBatch(output_schema_, res_array->length(), {res_array}); } void MergeY::BuildInputSchema() { @@ -91,25 +86,34 @@ void MergeY::BuildOutputSchema() { } REGISTER_OP_KERNEL(MERGE_Y, MergeY) -REGISTER_OP(MERGE_Y, "0.0.1", +REGISTER_OP(MERGE_Y, "0.0.2", "Merge all partial y(score) and apply link function") .Returnable() .Mergeable() - .DoubleAttr("yhat_scale", "", false, true, 1.0d) + .DoubleAttr( + "yhat_scale", + "In order to prevent value overflow, GLM training is performed on the " + "scaled y label. 
So in the prediction process, you need to enlarge " + "yhat back to get the real predicted value, `yhat = yhat_scale * " + "link(X * W)`", + false, true, 1.0) .StringAttr( "link_function", - "optinal value: LF_LOG, LF_LOGIT, LF_INVERSE, LF_LOGIT_V2, " + "Type of link function, defined in " + "`secretflow_serving/protos/link_function.proto`. Optional value: " + "LF_LOG, LF_LOGIT, LF_INVERSE, " "LF_RECIPROCAL, " - "LF_INDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, " + "LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, " "LF_SIGMOID_GA, " "LF_SIGMOID_T1, LF_SIGMOID_T3, " "LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, " "LF_SIGMOID_SEG3, " "LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS", false, false) - .StringAttr("input_col_name", "", false, false) - .StringAttr("output_col_name", "", false, false) - .Input("partial_ys", "") - .Output("scores", ""); + .StringAttr("input_col_name", "The column name of partial_y", false, false) + .StringAttr("output_col_name", "The column name of merged score", false, + false) + .Input("partial_ys", "The list of partial y, data type: `double`") + .Output("scores", "The merge result of `partial_ys`, data type: `double`"); } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/merge_y.h b/secretflow_serving/ops/merge_y.h index 892b019..51f772f 100644 --- a/secretflow_serving/ops/merge_y.h +++ b/secretflow_serving/ops/merge_y.h @@ -16,13 +16,15 @@ #include "secretflow_serving/ops/op_kernel.h" +#include "secretflow_serving/protos/link_function.pb.h" + namespace secretflow::serving::op { class MergeY : public OpKernel { public: explicit MergeY(OpKernelOptions opts); - void Compute(ComputeContext* ctx) override; + void DoCompute(ComputeContext* ctx) override; protected: void BuildInputSchema() override; @@ -32,7 +34,7 @@ class MergeY : public OpKernel { private: double yhat_scale_ = 1.0; - std::string link_function_; + LinkFunctionType link_function_; std::string 
input_col_name_; std::string output_col_name_; }; diff --git a/secretflow_serving/ops/merge_y_benchmark.cc b/secretflow_serving/ops/merge_y_benchmark.cc new file mode 100644 index 0000000..66cfac3 --- /dev/null +++ b/secretflow_serving/ops/merge_y_benchmark.cc @@ -0,0 +1,122 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark/benchmark.h" + +#include "secretflow_serving/core/link_func.h" +#include "secretflow_serving/ops/merge_y.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op { + +static constexpr const char* const kLinkFuncsArray[] = { + "LF_LOG", "LF_LOGIT", "LF_INVERSE", "LF_RECIPROCAL", + "LF_IDENTITY", "LF_SIGMOID_RAW", "LF_SIGMOID_MM1", "LF_SIGMOID_MM3", + "LF_SIGMOID_GA", "LF_SIGMOID_T1", "LF_SIGMOID_T3", "LF_SIGMOID_T5", + "LF_SIGMOID_T7", "LF_SIGMOID_T9", "LF_SIGMOID_LS7", "LF_SIGMOID_SEG3", + "LF_SIGMOID_SEG5", "LF_SIGMOID_DF", "LF_SIGMOID_SR", "LF_SIGMOID_SEGLS"}; + +void BMMergeYOPBench(benchmark::State& state) { + std::string json_content = R"JSON( +{ + "name": "test_node", + "op": "MERGE_Y", + "attr_values": { + "input_col_name": { + "s": "y" + }, + "output_col_name": { + "s": "score" + } + } +} +)JSON"; + + auto link_func_index = state.range(2); + auto party_nums = state.range(1); + auto row_nums 
= state.range(0); + + state.counters["log2(row_nums*party_nums)"] = log2(row_nums * party_nums); + state.counters["row_nums"] = row_nums; + state.counters["party_nums"] = party_nums; + state.counters["link_func_index"] = link_func_index; + state.counters["log2(row_nums*party_nums)"] = log2(row_nums * party_nums); + + state.SetLabel(kLinkFuncsArray[link_func_index]); + + NodeDef node_def; + JsonToPb(json_content, &node_def); + { + AttrValue link_func_value; + link_func_value.set_s(kLinkFuncsArray[link_func_index]); + node_def.mutable_attr_values()->insert( + {"link_function", std::move(link_func_value)}); + } + { + AttrValue scale_value; + scale_value.set_d(1.14); + node_def.mutable_attr_values()->insert( + {"yhat_scale", std::move(scale_value)}); + } + + // build node + auto mock_node = std::make_shared(std::move(node_def)); + + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + + // check input schema + const auto& input_schema_list = kernel->GetAllInputSchema(); + + // compute + ComputeContext compute_ctx; + std::vector> input_list; + + // mock input values + std::vector> feature_value_list; + std::generate_n(std::back_inserter(feature_value_list), party_nums, + [row_nums]() { + std::vector tmp; + std::generate_n(std::back_inserter(tmp), row_nums, + [] { return rand(); }); + return tmp; + }); + for (size_t i = 0; i < feature_value_list.size(); ++i) { + arrow::DoubleBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(feature_value_list[i])); + std::shared_ptr array; + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + input_list.emplace_back( + MakeRecordBatch(input_schema_list.front(), row_nums, {array})); + } + compute_ctx.inputs.emplace_back(std::move(input_list)); + + for (auto _ : state) { + kernel->Compute(&compute_ctx); + } +} + +BENCHMARK(BMMergeYOPBench) + ->ArgsProduct({ + benchmark::CreateRange(8, 1u << 25, /*multi=*/2), + {1, 2, 3}, + 
benchmark::CreateDenseRange( + 0, 1, // sizeof(kLinkFuncsArray) / sizeof(kLinkFuncsArray[0]) - 1, + /*step=*/1), + }); + +} // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/merge_y_test.cc b/secretflow_serving/ops/merge_y_test.cc index 265ce53..1b59b94 100644 --- a/secretflow_serving/ops/merge_y_test.cc +++ b/secretflow_serving/ops/merge_y_test.cc @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "secretflow_serving/core/link_func.h" +#include "secretflow_serving/ops/op_factory.h" #include "secretflow_serving/ops/op_kernel_factory.h" #include "secretflow_serving/util/arrow_helper.h" #include "secretflow_serving/util/utils.h" @@ -73,9 +74,11 @@ TEST_P(MergeYParamTest, Works) { // expect result double expect_score_0 = - ApplyLinkFunc(0.1 + 0.1 + 0.1, param.link_func) * param.yhat_scale; + ApplyLinkFunc(0.1 + 0.1 + 0.1, ParseLinkFuncType(param.link_func)) * + param.yhat_scale; double expect_score_1 = - ApplyLinkFunc(0.11 + 0.12 + 0.13, param.link_func) * param.yhat_scale; + ApplyLinkFunc(0.11 + 0.12 + 0.13, ParseLinkFuncType(param.link_func)) * + param.yhat_scale; double epsilon = 1E-13; // build node @@ -83,7 +86,7 @@ TEST_P(MergeYParamTest, Works) { ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); ASSERT_TRUE(mock_node->GetOpDef()->tag().returnable()); - OpKernelOptions opts{mock_node}; + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); // check input schema @@ -110,7 +113,6 @@ TEST_P(MergeYParamTest, Works) { // compute ComputeContext compute_ctx; - compute_ctx.inputs = std::make_shared(); std::vector> input_list; for (size_t i = 0; i < feature_value_list.size(); ++i) { arrow::DoubleBuilder builder; @@ -120,7 +122,7 @@ TEST_P(MergeYParamTest, Works) { input_list.emplace_back( MakeRecordBatch(input_schema_list.front(), 2, {array})); } - compute_ctx.inputs->emplace_back(std::move(input_list)); + compute_ctx.inputs.emplace_back(std::move(input_list)); 
kernel->Compute(&compute_ctx); @@ -145,15 +147,15 @@ INSTANTIATE_TEST_SUITE_P( MergeYParamTestSuite, MergeYParamTest, ::testing::Values( Param{"LF_LOG", 1.0}, Param{"LF_LOGIT", 1.0}, Param{"LF_INVERSE", 1.0}, - Param{"LF_LOGIT_V2", 1.0}, Param{"LF_RECIPROCAL", 1.1}, - Param{"LF_INDENTITY", 1.2}, Param{"LF_SIGMOID_RAW", 1.3}, - Param{"LF_SIGMOID_MM1", 1.4}, Param{"LF_SIGMOID_MM3", 1.5}, - Param{"LF_SIGMOID_GA", 1.6}, Param{"LF_SIGMOID_T1", 1.7}, - Param{"LF_SIGMOID_T3", 1.8}, Param{"LF_SIGMOID_T5", 1.9}, - Param{"LF_SIGMOID_T7", 1.01}, Param{"LF_SIGMOID_T9", 1.02}, - Param{"LF_SIGMOID_LS7", 1.03}, Param{"LF_SIGMOID_SEG3", 1.04}, - Param{"LF_SIGMOID_SEG5", 1.05}, Param{"LF_SIGMOID_DF", 1.06}, - Param{"LF_SIGMOID_SR", 1.07}, Param{"LF_SIGMOID_SEGLS", 1.08})); + Param{"LF_RECIPROCAL", 1.1}, Param{"LF_IDENTITY", 1.2}, + Param{"LF_SIGMOID_RAW", 1.3}, Param{"LF_SIGMOID_MM1", 1.4}, + Param{"LF_SIGMOID_MM3", 1.5}, Param{"LF_SIGMOID_GA", 1.6}, + Param{"LF_SIGMOID_T1", 1.7}, Param{"LF_SIGMOID_T3", 1.8}, + Param{"LF_SIGMOID_T5", 1.9}, Param{"LF_SIGMOID_T7", 1.01}, + Param{"LF_SIGMOID_T9", 1.02}, Param{"LF_SIGMOID_LS7", 1.03}, + Param{"LF_SIGMOID_SEG3", 1.04}, Param{"LF_SIGMOID_SEG5", 1.05}, + Param{"LF_SIGMOID_DF", 1.06}, Param{"LF_SIGMOID_SR", 1.07}, + Param{"LF_SIGMOID_SEGLS", 1.08})); // TODO: exception case @@ -187,7 +189,8 @@ TEST_F(MergeYTest, Constructor) { NodeDef node_def; JsonToPb(json_content, &node_def); - OpKernelOptions opts{std::make_shared(std::move(node_def))}; + auto op_def = OpFactory::GetInstance()->Get("MERGE_Y"); + OpKernelOptions opts{std::move(node_def), op_def}; EXPECT_NO_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts))); } @@ -214,7 +217,8 @@ TEST_F(MergeYTest, Constructor) { NodeDef node_def; JsonToPb(json_content, &node_def); - OpKernelOptions opts{std::make_shared(std::move(node_def))}; + auto op_def = OpFactory::GetInstance()->Get("MERGE_Y"); + OpKernelOptions opts{std::move(node_def), op_def}; 
EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), Exception); } @@ -239,7 +243,8 @@ TEST_F(MergeYTest, Constructor) { NodeDef node_def; JsonToPb(json_content, &node_def); - OpKernelOptions opts{std::make_shared(std::move(node_def))}; + auto op_def = OpFactory::GetInstance()->Get("MERGE_Y"); + OpKernelOptions opts{std::move(node_def), op_def}; EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), Exception); } @@ -264,7 +269,8 @@ TEST_F(MergeYTest, Constructor) { NodeDef node_def; JsonToPb(json_content, &node_def); - OpKernelOptions opts{std::make_shared(std::move(node_def))}; + auto op_def = OpFactory::GetInstance()->Get("MERGE_Y"); + OpKernelOptions opts{std::move(node_def), op_def}; EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), Exception); } @@ -289,7 +295,8 @@ TEST_F(MergeYTest, Constructor) { NodeDef node_def; JsonToPb(json_content, &node_def); - OpKernelOptions opts{std::make_shared(std::move(node_def))}; + auto op_def = OpFactory::GetInstance()->Get("MERGE_Y"); + OpKernelOptions opts{std::move(node_def), op_def}; EXPECT_THROW(OpKernelFactory::GetInstance()->Create(std::move(opts)), Exception); } diff --git a/secretflow_serving/ops/node.cc b/secretflow_serving/ops/node.cc index 472186a..cb168a3 100644 --- a/secretflow_serving/ops/node.cc +++ b/secretflow_serving/ops/node.cc @@ -40,8 +40,8 @@ void Node::AddInEdge(const std::shared_ptr& in_edge) { in_edges_.emplace_back(in_edge); } -void Node::SetOutEdge(const std::shared_ptr& out_edge) { - out_edge_ = out_edge; +void Node::AddOutEdge(const std::shared_ptr& out_edge) { + out_edges_.emplace_back(out_edge); } } // namespace secretflow::serving diff --git a/secretflow_serving/ops/node.h b/secretflow_serving/ops/node.h index 1b21477..fc697ef 100644 --- a/secretflow_serving/ops/node.h +++ b/secretflow_serving/ops/node.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include @@ -46,11 +45,13 @@ class Node final { return in_edges_; } - const 
std::shared_ptr& out_edge() const { return out_edge_; } + const std::vector>& out_edges() const { + return out_edges_; + } void AddInEdge(const std::shared_ptr& in_edge); - void SetOutEdge(const std::shared_ptr& out_edge); + void AddOutEdge(const std::shared_ptr& out_edge); private: const NodeDef node_def_; @@ -58,7 +59,7 @@ class Node final { std::vector input_nodes_; std::vector> in_edges_; - std::shared_ptr out_edge_; + std::vector> out_edges_; }; class Edge final { diff --git a/secretflow_serving/ops/node_def_util.cc b/secretflow_serving/ops/node_def_util.cc index 7c04f59..df1189f 100644 --- a/secretflow_serving/ops/node_def_util.cc +++ b/secretflow_serving/ops/node_def_util.cc @@ -30,7 +30,7 @@ bool GetAttrValue(const NodeDef& node_def, const std::string& attr_name, } // namespace -#define DEFINE_GETT_LIST_ATTR(TYPE, FIELD_LIST, CAST) \ +#define DEFINE_GET_LIST_ATTR(TYPE, FIELD_LIST, CAST) \ bool GetNodeAttr(const NodeDef& node_def, const std::string& attr_name, \ std::vector* value) { \ AttrValue attr_value; \ @@ -67,7 +67,7 @@ bool GetAttrValue(const NodeDef& node_def, const std::string& attr_name, *value = CAST; \ return true; \ } \ - DEFINE_GETT_LIST_ATTR(TYPE, FIELD##s, CAST) + DEFINE_GET_LIST_ATTR(TYPE, FIELD##s, CAST) DEFINE_GET_ATTR(std::string, s, v) DEFINE_GET_ATTR(int64_t, i64, v) @@ -77,4 +77,39 @@ DEFINE_GET_ATTR(double, d, v) DEFINE_GET_ATTR(bool, b, v) #undef DEFINE_GET_ATTR +bool GetNodeBytesAttr(const NodeDef& node_def, const std::string& attr_name, + std::string* value) { + AttrValue attr_value; + if (!GetAttrValue(node_def, attr_name, &attr_value)) { + return false; + } + SERVING_ENFORCE( + attr_value.has_by(), errors::ErrorCode::LOGIC_ERROR, + "attr_value({}) does not have expected type(bytes) value, node: {}", + attr_name, node_def.name()); + *value = attr_value.by(); + return true; +} + +bool GetNodeBytesAttr(const NodeDef& node_def, const std::string& attr_name, + std::vector* value) { + AttrValue attr_value; + if 
(!GetAttrValue(node_def, attr_name, &attr_value)) { + return false; + } + SERVING_ENFORCE( + attr_value.has_by(), errors::ErrorCode::LOGIC_ERROR, + "attr_value({}) does not have expected type(bytes) value, node: {}", + attr_name, node_def.name()); + SERVING_ENFORCE(!attr_value.bys().data().empty(), + errors::ErrorCode::INVALID_ARGUMENT, + "attr_value({}) type(BytesList) has empty value, node: {}", + attr_name, node_def.name()); + value->reserve(attr_value.bys().data().size()); + for (const auto& v : attr_value.bys().data()) { + value->emplace_back(v); + } + return true; +} + } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/node_def_util.h b/secretflow_serving/ops/node_def_util.h index 4a7a7d0..e6f6430 100644 --- a/secretflow_serving/ops/node_def_util.h +++ b/secretflow_serving/ops/node_def_util.h @@ -49,4 +49,20 @@ T GetNodeAttr(const NodeDef& node_def, const std::string& attr_name) { return value; } +bool GetNodeBytesAttr(const NodeDef& node_def, const std::string& attr_name, + std::string* value); +bool GetNodeBytesAttr(const NodeDef& node_def, const std::string& attr_name, + std::vector* value); + +inline std::string GetNodeBytesAttr(const NodeDef& node_def, + const std::string& attr_name) { + std::string value; + if (!GetNodeBytesAttr(node_def, attr_name, &value)) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "can not get attr:{} from node:{}, op:{}", attr_name, + node_def.name(), node_def.op()); + } + return value; +} + } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/op_def_builder.cc b/secretflow_serving/ops/op_def_builder.cc index 72273ad..3b5a79d 100644 --- a/secretflow_serving/ops/op_def_builder.cc +++ b/secretflow_serving/ops/op_def_builder.cc @@ -38,11 +38,10 @@ OpDefBuilder& OpDefBuilder::Int32Attr( attr_def.name()); if (!is_list) { - int32_t v = std::get(default_value.value()); + auto& v = std::get(default_value.value()); attr_def.mutable_default_value()->set_i32(v); } else { - 
std::vector& v_list = - std::get>(default_value.value()); + auto& v_list = std::get>(default_value.value()); *(attr_def.mutable_default_value()->mutable_i32s()->mutable_data()) = { v_list.begin(), v_list.end()}; } @@ -69,11 +68,10 @@ OpDefBuilder& OpDefBuilder::Int64Attr( "attr {}: default_value must be provided if optional", attr_def.name()); if (!is_list) { - int64_t v = std::get(default_value.value()); + auto& v = std::get(default_value.value()); attr_def.mutable_default_value()->set_i64(v); } else { - std::vector& v_list = - std::get>(default_value.value()); + auto& v_list = std::get>(default_value.value()); *(attr_def.mutable_default_value()->mutable_i64s()->mutable_data()) = { v_list.begin(), v_list.end()}; } @@ -100,11 +98,10 @@ OpDefBuilder& OpDefBuilder::FloatAttr( "attr {}: default_value must be provided if optional", attr_def.name()); if (!is_list) { - float v = std::get(default_value.value()); + auto& v = std::get(default_value.value()); attr_def.mutable_default_value()->set_f(v); } else { - std::vector& v_list = - std::get>(default_value.value()); + auto& v_list = std::get>(default_value.value()); *(attr_def.mutable_default_value()->mutable_fs()->mutable_data()) = { v_list.begin(), v_list.end()}; } @@ -131,11 +128,10 @@ OpDefBuilder& OpDefBuilder::DoubleAttr( "attr {}: default_value must be provided if optional", attr_def.name()); if (!is_list) { - double v = std::get(default_value.value()); + auto& v = std::get(default_value.value()); attr_def.mutable_default_value()->set_d(v); } else { - std::vector& v_list = - std::get>(default_value.value()); + auto& v_list = std::get>(default_value.value()); *(attr_def.mutable_default_value()->mutable_ds()->mutable_data()) = { v_list.begin(), v_list.end()}; } @@ -162,11 +158,10 @@ OpDefBuilder& OpDefBuilder::StringAttr( "attr {}: default_value must be provided if optional", attr_def.name()); if (!is_list) { - std::string v = std::get(default_value.value()); + auto& v = std::get(default_value.value()); 
attr_def.mutable_default_value()->set_s(v); } else { - std::vector& v_list = - std::get>(default_value.value()); + auto& v_list = std::get>(default_value.value()); *(attr_def.mutable_default_value()->mutable_ss()->mutable_data()) = { v_list.begin(), v_list.end()}; } @@ -193,11 +188,10 @@ OpDefBuilder& OpDefBuilder::BoolAttr( "attr {}: default_value must be provided if optional", attr_def.name()); if (!is_list) { - bool v = std::get(default_value.value()); + auto& v = std::get(default_value.value()); attr_def.mutable_default_value()->set_b(v); } else { - std::vector& v_list = - std::get>(default_value.value()); + auto& v_list = std::get>(default_value.value()); *(attr_def.mutable_default_value()->mutable_bs()->mutable_data()) = { v_list.begin(), v_list.end()}; } @@ -211,6 +205,36 @@ OpDefBuilder& OpDefBuilder::BoolAttr( return *this; } +OpDefBuilder& OpDefBuilder::BytesAttr( + std::string name, std::string desc, bool is_list, bool is_optional, + std::optional> default_value) { + AttrDef attr_def; + attr_def.set_name(std::move(name)); + attr_def.set_desc(std::move(desc)); + attr_def.set_type(is_list ? 
AttrType::AT_BYTES_LIST : AttrType::AT_BYTES); + attr_def.set_is_optional(is_optional); + if (is_optional) { + SERVING_ENFORCE(default_value.has_value(), errors::ErrorCode::LOGIC_ERROR, + "attr {}: default_value must be provided if optional", + attr_def.name()); + if (!is_list) { + auto& v = std::get(default_value.value()); + attr_def.mutable_default_value()->set_by(v); + } else { + auto& v_list = std::get>(default_value.value()); + *(attr_def.mutable_default_value()->mutable_bys()->mutable_data()) = { + v_list.begin(), v_list.end()}; + } + } + + SERVING_ENFORCE( + attr_defs_.emplace(attr_def.name(), std::move(attr_def)).second, + errors::ErrorCode::LOGIC_ERROR, "found duplicate attr:{}", + attr_def.name()); + + return *this; +} + OpDefBuilder& OpDefBuilder::Returnable() { returnable_ = true; return *this; diff --git a/secretflow_serving/ops/op_def_builder.h b/secretflow_serving/ops/op_def_builder.h index 886e068..55acab0 100644 --- a/secretflow_serving/ops/op_def_builder.h +++ b/secretflow_serving/ops/op_def_builder.h @@ -55,6 +55,9 @@ class OpDefBuilder final { OpDefBuilder& StringAttr( std::string name, std::string desc, bool is_list, bool is_optional, std::optional> default_value = std::nullopt); + OpDefBuilder& BytesAttr( + std::string name, std::string desc, bool is_list, bool is_optional, + std::optional> default_value = std::nullopt); // tag OpDefBuilder& Returnable(); diff --git a/secretflow_serving/ops/op_factory.h b/secretflow_serving/ops/op_factory.h index 73145f2..8fee0c2 100644 --- a/secretflow_serving/ops/op_factory.h +++ b/secretflow_serving/ops/op_factory.h @@ -45,8 +45,18 @@ class OpFactory final : public Singleton { return iter->second; } + std::vector> GetAllOps() { + std::vector> result; + + std::lock_guard lock(mutex_); + for (const auto& pair : op_defs_) { + result.emplace_back(pair.second); + } + return result; + } + private: - std::map> op_defs_; + std::unordered_map> op_defs_; std::mutex mutex_; }; @@ -100,6 +110,13 @@ class 
OpDefBuilderWrapper { std::move(default_value)); return *this; } + OpDefBuilderWrapper& BytesAttr( + std::string name, std::string desc, bool is_list, bool is_optional, + std::optional> default_value = std::nullopt) { + builder_.BytesAttr(std::move(name), std::move(desc), is_list, is_optional, + std::move(default_value)); + return *this; + } OpDefBuilderWrapper& Returnable() { builder_.Returnable(); return *this; @@ -112,6 +129,13 @@ class OpDefBuilderWrapper { builder_.Input(std::move(name), std::move(desc)); return *this; } + OpDefBuilderWrapper& InputList(const std::string& prefix, size_t num, + std::string desc) { + for (size_t i = 0; i != num; ++i) { + builder_.Input(prefix + std::to_string(i), desc); + } + return *this; + } OpDefBuilderWrapper& Output(std::string name, std::string desc) { builder_.Output(std::move(name), std::move(desc)); return *this; diff --git a/secretflow_serving/ops/op_factory_test.cc b/secretflow_serving/ops/op_factory_test.cc index 0302f47..0369cc8 100644 --- a/secretflow_serving/ops/op_factory_test.cc +++ b/secretflow_serving/ops/op_factory_test.cc @@ -174,7 +174,7 @@ TEST_F(OpFactoryTest, Works) { JsonToPb(attr_json_map[actual_attr_def.name()], &expect_attr_def); EXPECT_FALSE(expect_attr_def.name().empty()); EXPECT_FALSE(expect_attr_def.desc().empty()); - EXPECT_FALSE(expect_attr_def.type() == AttrType::UNKNOWN_AT_TYEP); + EXPECT_FALSE(expect_attr_def.type() == AttrType::UNKNOWN_AT_TYPE); EXPECT_EQ(expect_attr_def.name(), actual_attr_def.name()); EXPECT_EQ(expect_attr_def.desc(), actual_attr_def.desc()); @@ -385,7 +385,7 @@ TEST_F(OpFactoryTest, WorksDefaultValue) { JsonToPb(attr_json_map[actual_attr_def.name()], &expect_attr_def); EXPECT_FALSE(expect_attr_def.name().empty()); EXPECT_FALSE(expect_attr_def.desc().empty()); - EXPECT_FALSE(expect_attr_def.type() == AttrType::UNKNOWN_AT_TYEP); + EXPECT_FALSE(expect_attr_def.type() == AttrType::UNKNOWN_AT_TYPE); std::cout << "expect " << expect_attr_def.ShortDebugString() << std::endl; 
diff --git a/secretflow_serving/ops/op_kernel.h b/secretflow_serving/ops/op_kernel.h index 1487d87..ae99162 100644 --- a/secretflow_serving/ops/op_kernel.h +++ b/secretflow_serving/ops/op_kernel.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include @@ -24,31 +23,32 @@ #include "secretflow_serving/core/exception.h" #include "secretflow_serving/ops/node.h" +#include "secretflow_serving/util/arrow_helper.h" #include "secretflow_serving/protos/op.pb.h" namespace secretflow::serving::op { +// two level index: +// first for input edges of this node +// second for multiple parties to this op using OpComputeInputs = std::vector>>; struct OpKernelOptions { - const std::shared_ptr node; + const NodeDef node_def; + const std::shared_ptr op_def; }; struct ComputeContext { // TODO: Session - std::shared_ptr inputs; + OpComputeInputs inputs; std::shared_ptr output; }; class OpKernel { public: - explicit OpKernel(OpKernelOptions opts) : opts_(std::move(opts)) { - SPDLOG_INFO("op kernel: {}, version: {}, node: {}", - opts_.node->GetOpDef()->name(), - opts_.node->GetOpDef()->version(), opts_.node->GetName()); - } + explicit OpKernel(OpKernelOptions opts) : opts_(std::move(opts)) {} virtual ~OpKernel() = default; size_t GetInputsNum() const { return input_schema_list_.size(); } @@ -65,7 +65,48 @@ class OpKernel { return output_schema_; } - virtual void Compute(ComputeContext* ctx) = 0; + void Compute(ComputeContext* ctx) { + int64_t rows = ctx->inputs.front().front()->num_rows(); + SERVING_ENFORCE_EQ(ctx->inputs.size(), input_schema_list_.size(), + "schema size be equal to input edges"); + + for (size_t edge_index = 0; edge_index != ctx->inputs.size(); + ++edge_index) { + auto& edge_inputs = ctx->inputs[edge_index]; + for (size_t party_index = 0; party_index != edge_inputs.size(); + ++party_index) { + auto& input_table = edge_inputs[party_index]; + SERVING_ENFORCE_EQ(rows, input_table->num_rows(), + "rows of all inputs tables should be equal"); + + if 
(!input_table->schema()->Equals(input_schema_list_[edge_index])) { + // reshape real input base on kernel input_schema + std::vector> sorted_arrays; + for (int i = 0; i < input_schema_list_[edge_index]->num_fields(); + ++i) { + auto array_index = input_table->schema()->GetFieldIndex( + input_schema_list_[edge_index]->field(i)->name()); + SERVING_ENFORCE_GE(array_index, 0); + sorted_arrays.emplace_back(input_table->column(array_index)); + } + edge_inputs[party_index] = MakeRecordBatch( + input_schema_list_[edge_index], rows, sorted_arrays); + } + } + } + + DoCompute(ctx); + + SERVING_ENFORCE_EQ(rows, ctx->output->num_rows(), + "rows of input and output be equal"); + SERVING_ENFORCE(ctx->output->schema()->Equals(output_schema_), + errors::ErrorCode::LOGIC_ERROR, + "schema of output ({}) should match output_schema ({})", + ctx->output->schema()->ToString(), + output_schema_->ToString()); + } + + virtual void DoCompute(ComputeContext* ctx) = 0; protected: virtual void BuildInputSchema() = 0; diff --git a/secretflow_serving/ops/op_kernel_factory.h b/secretflow_serving/ops/op_kernel_factory.h index 1fd9d81..950e95f 100644 --- a/secretflow_serving/ops/op_kernel_factory.h +++ b/secretflow_serving/ops/op_kernel_factory.h @@ -37,15 +37,14 @@ class OpKernelFactory final : public Singleton { std::shared_ptr Create(OpKernelOptions opts) { std::lock_guard lock(mutex_); - auto creator = creators_[opts.node->node_def().op()]; + auto creator = creators_[opts.op_def->name()]; SERVING_ENFORCE(creator, errors::ErrorCode::UNEXPECTED_ERROR, - "no op kernel registered for {}", - opts.node->node_def().op()); + "no op kernel registered for {}", opts.op_def->name()); return creator(std::move(opts)); } private: - std::map creators_; + std::unordered_map creators_; std::mutex mutex_; }; diff --git a/secretflow_serving/protos/BUILD.bazel b/secretflow_serving/protos/BUILD.bazel index 13a5ed8..d95a366 100644 --- a/secretflow_serving/protos/BUILD.bazel +++ b/secretflow_serving/protos/BUILD.bazel 
@@ -13,24 +13,23 @@ # limitations under the License. load("@rules_cc//cc:defs.bzl", "cc_proto_library") -load("@rules_python//python:defs.bzl", "py_library") +load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile") package(default_visibility = ["//visibility:public"]) proto_library( - name = "field_proto", - srcs = ["field.proto"], + name = "data_type_proto", + srcs = ["data_type.proto"], ) cc_proto_library( - name = "field_cc_proto", - deps = [":field_proto"], + name = "data_type_cc_proto", + deps = [":data_type_proto"], ) proto_library( name = "feature_proto", srcs = ["feature.proto"], - deps = [":field_proto"], ) cc_proto_library( @@ -92,3 +91,69 @@ cc_proto_library( name = "link_function_cc_proto", deps = [":link_function_proto"], ) + +proto_library( + name = "compute_trace_proto", + srcs = ["compute_trace.proto"], +) + +cc_proto_library( + name = "compute_trace_cc_proto", + deps = [":compute_trace_proto"], +) + +python_proto_compile( + name = "graph_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":graph_proto"], +) + +python_proto_compile( + name = "bundle_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":bundle_proto"], +) + +python_proto_compile( + name = "op_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":op_proto"], +) + +python_proto_compile( + name = "attr_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":attr_proto"], +) + +python_proto_compile( + name = "data_type_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":data_type_proto"], +) + +python_proto_compile( + name = "compute_trace_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":compute_trace_proto"], +) + +python_proto_compile( + name = "link_function_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":link_function_proto"], +) + +python_proto_compile( + name = "feature_py_proto", + 
output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [":feature_proto"], +) diff --git a/secretflow_serving/protos/attr.proto b/secretflow_serving/protos/attr.proto index 7a08dde..293627e 100644 --- a/secretflow_serving/protos/attr.proto +++ b/secretflow_serving/protos/attr.proto @@ -19,25 +19,42 @@ package secretflow.serving.op; // Supported attribute types. enum AttrType { - UNKNOWN_AT_TYEP = 0; + // Placeholder for proto3 default value, do not use it. + UNKNOWN_AT_TYPE = 0; // Atomic types + // INT32 AT_INT32 = 1; + // INT64 AT_INT64 = 2; + // FLOAT AT_FLOAT = 3; + // DOUBLE AT_DOUBLE = 4; + // STRING AT_STRING = 5; + // BOOL AT_BOOL = 6; + // BYTES + AT_BYTES = 7; // List types + // INT32 LIST AT_INT32_LIST = 11; + // INT64 LIST AT_INT64_LIST = 12; + // FLOAT LIST AT_FLOAT_LIST = 13; + // DOUBLE LIST AT_DOUBLE_LIST = 14; + // STRING LIST AT_STRING_LIST = 15; + // BOOL LIST AT_BOOL_LIST = 16; + // BYTES LIST + AT_BYTES_LIST = 17; } message Int32List { @@ -64,6 +81,10 @@ message BoolList { repeated bool data = 11; } +message BytesList { + repeated bytes data = 11; +} + // The value of an attribute message AttrValue { oneof value { @@ -77,6 +98,8 @@ message AttrValue { string s = 5; // BOOL bool b = 6; + // BYTES + bytes by = 7; // Lists @@ -90,6 +113,8 @@ message AttrValue { StringList ss = 15; // BOOLS BoolList bs = 16; + // BYTESS + BytesList bys = 17; } } diff --git a/secretflow_serving/protos/bundle.proto b/secretflow_serving/protos/bundle.proto index 4413e2e..8870d6c 100644 --- a/secretflow_serving/protos/bundle.proto +++ b/secretflow_serving/protos/bundle.proto @@ -55,3 +55,12 @@ message ModelManifest { // The format type of the model bundle file. FileFormatType bundle_format = 2; } + +// Represents a secertflow model without private data. 
+message ModelInfo { + string name = 1; + + string desc = 2; + + GraphView graph_view = 3; +} diff --git a/secretflow_serving/protos/compute_trace.proto b/secretflow_serving/protos/compute_trace.proto new file mode 100644 index 0000000..0f93628 --- /dev/null +++ b/secretflow_serving/protos/compute_trace.proto @@ -0,0 +1,111 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +package secretflow.serving.compute; + +enum ExtendFunctionName { + // Placeholder for proto3 default value, do not use it + UNKOWN_EX_FUNCTION_NAME = 0; + + // Get colunm from table(record_batch). + // see + // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch6columnEi + EFN_TB_COLUMN = 1; + // Add colum to table(record_batch). + // see + // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiNSt6stringERKNSt10shared_ptrI5ArrayEE + EFN_TB_ADD_COLUMN = 2; + // Remove colunm from table(record_batch). + // see + // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch12RemoveColumnEi + EFN_TB_REMOVE_COLUMN = 3; + // Set colunm to table(record_batch). + // see + // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE + EFN_TB_SET_COLUMN = 4; +} + +// Represents a single value with a specific data type. 
+message Scalar { + oneof value { + // For integer types with a length below 32 bits, due to the limitations of + // protobuf data types, we will extract the data and then convert it to the + // desired type. + + // INT8. + int32 i8 = 1; + // UINT8 + int32 ui8 = 2; + // INT16 + int32 i16 = 3; + // UINT16 + int32 ui16 = 4; + + // INT32 + int32 i32 = 5; + // UINT32 + uint32 ui32 = 6; + // INT64 + int64 i64 = 7; + // UINT64 + uint64 ui64 = 8; + + // FLOAT + float f = 9; + // DOUBLE + double d = 10; + + // STRING + string s = 11; + + // BOOL + bool b = 12; + } +} + +message FunctionInput { + oneof value { + // '0' means root input data + int32 data_id = 1; + Scalar custom_scalar = 2; + } +} + +message FunctionOutput { + int32 data_id = 1; +} + +message FunctionTrace { + // The Function name. + string name = 1; + + // The serialized function options. + bytes option_bytes = 2; + + // Inputs of this function. + repeated FunctionInput inputs = 3; + + // Output of this function. + FunctionOutput output = 4; +} + +message ComputeTrace { + // The name of this Compute. + string name = 1; + + repeated FunctionTrace func_traces = 2; +} diff --git a/secretflow_serving/protos/data_type.proto b/secretflow_serving/protos/data_type.proto new file mode 100644 index 0000000..1a23f92 --- /dev/null +++ b/secretflow_serving/protos/data_type.proto @@ -0,0 +1,52 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +syntax = "proto3"; + +package secretflow.serving; + +// Mapping arrow::DataType +// `https://arrow.apache.org/docs/cpp/api/datatype.html`. +enum DataType { + // Placeholder for proto3 default value, do not use it. + UNKNOWN_DT_TYPE = 0; + + // Boolean as 1 bit, LSB bit-packed ordering. + DT_BOOL = 1; + // Unsigned 8-bit little-endian integer. + DT_UINT8 = 2; + // Signed 8-bit little-endian integer. + DT_INT8 = 3; + // Unsigned 16-bit little-endian integer. + DT_UINT16 = 4; + // Signed 16-bit little-endian integer. + DT_INT16 = 5; + // Unsigned 32-bit little-endian integer. + DT_UINT32 = 6; + // Signed 32-bit little-endian integer. + DT_INT32 = 7; + // Unsigned 64-bit little-endian integer. + DT_UINT64 = 8; + // Signed 64-bit little-endian integer. + DT_INT64 = 9; + // 4-byte floating point value + DT_FLOAT = 11; + // 8-byte floating point value + DT_DOUBLE = 12; + // UTF8 variable-length string as List + DT_STRING = 13; + // Variable-length bytes (no guarantee of UTF8-ness) + DT_BINARY = 14; +} diff --git a/secretflow_serving/protos/feature.proto b/secretflow_serving/protos/feature.proto index ffdb173..85f7629 100644 --- a/secretflow_serving/protos/feature.proto +++ b/secretflow_serving/protos/feature.proto @@ -15,10 +15,27 @@ syntax = "proto3"; -import "secretflow_serving/protos/field.proto"; - package secretflow.serving; +// Supported feature field type. +enum FieldType { + // Placeholder for proto3 default value, do not use it. + UNKNOWN_FIELD_TYPE = 0; + + // BOOL + FIELD_BOOL = 1; + // INT32 + FIELD_INT32 = 2; + // INT64 + FIELD_INT64 = 3; + // FLOAT + FIELD_FLOAT = 4; + // DOUBLE + FIELD_DOUBLE = 5; + // STRING + FIELD_STRING = 6; +} + // The value of a feature message FeatureValue { // int list diff --git a/secretflow_serving/protos/field.proto b/secretflow_serving/protos/field.proto deleted file mode 100644 index b7adea3..0000000 --- a/secretflow_serving/protos/field.proto +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -syntax = "proto3"; - -package secretflow.serving; - -// Supported field type. -enum FieldType { - UNKNOWN_FIELD_TYPE = 0; - - FIELD_BOOL = 1; - FIELD_INT32 = 2; - FIELD_INT64 = 3; - FIELD_FLOAT = 4; - FIELD_DOUBLE = 5; - FIELD_STRING = 6; -} diff --git a/secretflow_serving/protos/graph.proto b/secretflow_serving/protos/graph.proto index a84a8fe..d924ca2 100644 --- a/secretflow_serving/protos/graph.proto +++ b/secretflow_serving/protos/graph.proto @@ -21,13 +21,15 @@ package secretflow.serving; // Supported dispatch type enum DispatchType { + // Placeholder for proto3 default value, do not use it. UNKNOWN_DP_TYPE = 0; - // Dispath all participants. + // Dispatch all participants. DP_ALL = 1; - // Dispath any participant. + // Dispatch any participant. DP_ANYONE = 2; - // DP_SPECIFIED = 3; + // Dispatch specified participant. + DP_SPECIFIED = 3; } // The runtime config of the execution. @@ -38,6 +40,9 @@ message RuntimeConfig { // The execution need run in session(stateful) // TODO: not support yet. bool session_run = 2; + + // if dispatch_type is DP_SPECIFIED, only one party should be true + bool specific_flag = 3; } // The definition of a execution. A execution represents a subgraph within a @@ -58,7 +63,7 @@ message NodeDef { // Must be unique among all nodes of the graph. string name = 1; - // The operation name. + // The operator name. string op = 2; // The parent node names of the node. 
The order of the parent nodes should @@ -68,9 +73,29 @@ message NodeDef { // The attribute values configed in the node. Note that this should include // all attrs defined in the corresponding OpDef. map attr_values = 4; + + // The operator version. + string op_version = 5; +} + +// The view of a node, which could be public to other parties +message NodeView { + // Must be unique among all nodes of the graph. + string name = 1; + + // The operator name. + string op = 2; + + // The parent node names of the node. The order of the parent nodes should + // match the order of the inputs of the node. + repeated string parents = 3; + + // The operator version. + string op_version = 5; } -// Represents the graph of operations +// The definition of a Graph. A graph consists of a set of nodes carrying data +// and a set of executions that describes the scheduling of the graph. message GraphDef { // Version of the graph string version = 1; @@ -79,3 +104,14 @@ message GraphDef { repeated ExecutionDef execution_list = 3; } + +// The view of a graph is used to display the structure of the graph, containing +// only structural information and excluding the data components. +message GraphView { + // Version of the graph + string version = 1; + + repeated NodeView node_list = 2; + + repeated ExecutionDef execution_list = 3; +} diff --git a/secretflow_serving/protos/link_function.proto b/secretflow_serving/protos/link_function.proto index 79fda47..c256659 100644 --- a/secretflow_serving/protos/link_function.proto +++ b/secretflow_serving/protos/link_function.proto @@ -17,18 +17,25 @@ syntax = "proto3"; package secretflow.serving; -enum LinkFucntionType { +// Type of link function. +enum LinkFunctionType { + // Placeholder for proto3 default value, do not use it. 
UNKNOWN_LF_TYPE = 0; + // LOG LF_LOG = 1; + // LOGIT LF_LOGIT = 2; + // INVERSE LF_INVERSE = 3; - LF_LOGIT_V2 = 4; + // RECIPROCAL LF_RECIPROCAL = 5; - LF_INDENTITY = 6; + // IDENTITY + LF_IDENTITY = 6; + // Sigmoid LF_SIGMOID_RAW = 11; - // Taylor Maclaurin 1 order. + // MinMax approximation 1 order. LF_SIGMOID_MM1 = 12; // MinMax approximation 3 order. LF_SIGMOID_MM3 = 13; diff --git a/secretflow_serving/server/BUILD.bazel b/secretflow_serving/server/BUILD.bazel index 2d41a6c..e92726d 100644 --- a/secretflow_serving/server/BUILD.bazel +++ b/secretflow_serving/server/BUILD.bazel @@ -45,6 +45,7 @@ serving_cc_library( deps = [ "//secretflow_serving/apis:prediction_service_cc_proto", "//secretflow_serving/framework:predictor", + "//secretflow_serving/util:utils", ], ) @@ -70,6 +71,19 @@ serving_cc_library( ], ) +serving_cc_library( + name = "model_service_impl", + srcs = ["model_service_impl.cc"], + hdrs = ["model_service_impl.h"], + deps = [ + "//secretflow_serving/apis:model_service_cc_proto", + "//secretflow_serving/core:exception", + "//secretflow_serving/server/metrics:default_metrics_registry", + "@com_github_brpc_brpc//:brpc", + "@yacl//yacl/utils:elapsed_timer", + ], +) + serving_cc_library( name = "server", srcs = ["server.cc"], @@ -82,7 +96,9 @@ serving_cc_library( ":health", ":prediction_service_impl", "//secretflow_serving/config:serving_config_cc_proto", + "//secretflow_serving/framework:model_info_collector", "//secretflow_serving/framework:model_loader", + "//secretflow_serving/server:model_service_impl", "//secretflow_serving/server/kuscia:config_parser", "//secretflow_serving/server/metrics:default_metrics_registry", "//secretflow_serving/server/metrics:metrics_service", diff --git a/secretflow_serving/server/execution_core.cc b/secretflow_serving/server/execution_core.cc index 2d8a23f..ddd1cf7 100644 --- a/secretflow_serving/server/execution_core.cc +++ b/secretflow_serving/server/execution_core.cc @@ -19,21 +19,24 @@ #include 
"secretflow_serving/feature_adapter/feature_adapter_factory.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/thread_pool.h" namespace secretflow::serving { ExecutionCore::ExecutionCore(Options opts) : opts_(std::move(opts)), - stats_({{"handler", "ExecutionCore"}, - {"service_id", opts_.id}, - {"party_id", opts_.party_id}}) { + stats_({{"handler", "ExecutionCore"}, {"party_id", opts_.party_id}}) { SERVING_ENFORCE(!opts_.id.empty(), errors::ErrorCode::INVALID_ARGUMENT); SERVING_ENFORCE(!opts_.party_id.empty(), errors::ErrorCode::INVALID_ARGUMENT); SERVING_ENFORCE(opts_.executable, errors::ErrorCode::INVALID_ARGUMENT); + SERVING_ENFORCE(opts_.op_exec_workers_num > 0, + errors::ErrorCode::INVALID_ARGUMENT); + + ThreadPool::GetInstance()->Start(opts_.op_exec_workers_num); // key: model input feature name // value: source or predefined feature name - std::map model_feature_mapping; + std::unordered_map model_feature_mapping; valid_feature_mapping_flag_ = false; if (opts_.feature_mapping.has_value()) { for (const auto& pair : opts_.feature_mapping.value()) { @@ -47,10 +50,9 @@ ExecutionCore::ExecutionCore(Options opts) } } - std::shared_ptr source_schema; const auto& model_input_schema = opts_.executable->GetInputFeatureSchema(); if (model_feature_mapping.empty()) { - source_schema = arrow::schema(model_input_schema->fields()); + source_schema_ = model_input_schema; } else { arrow::SchemaBuilder builder; int num_fields = model_input_schema->num_fields(); @@ -63,14 +65,14 @@ ExecutionCore::ExecutionCore(Options opts) SERVING_CHECK_ARROW_STATUS( builder.AddField(arrow::field(iter->second, f->type()))); } - SERVING_GET_ARROW_RESULT(builder.Finish(), source_schema); + SERVING_GET_ARROW_RESULT(builder.Finish(), source_schema_); } if (opts_.feature_source_config.has_value()) { - SPDLOG_INFO("create feature adpater, type:{}", + SPDLOG_INFO("create feature adapter, type:{}", static_cast(opts_.feature_source_config->options_case())); - 
feature_adapater_ = feature::FeatureAdapterFactory::GetInstance()->Create( - *opts_.feature_source_config, opts_.id, opts_.party_id, source_schema); + feature_adapter_ = feature::FeatureAdapterFactory::GetInstance()->Create( + *opts_.feature_source_config, opts_.id, opts_.party_id, source_schema_); } } @@ -81,7 +83,7 @@ void ExecutionCore::Execute(const apis::ExecuteRequest* request, try { SERVING_ENFORCE(request->service_spec().id() == opts_.id, errors::ErrorCode::INVALID_ARGUMENT, - "invalid service sepc id: {}", + "invalid service spec id: {}", request->service_spec().id()); response->mutable_service_spec()->CopyFrom(request->service_spec()); @@ -102,7 +104,8 @@ void ExecutionCore::Execute(const apis::ExecuteRequest* request, "get empty predefined features."); SERVING_ENFORCE(request->task().nodes().empty(), errors::ErrorCode::LOGIC_ERROR); - features = FeaturesToTable(request->feature_source().predefineds()); + features = FeaturesToTable(request->feature_source().predefineds(), + source_schema_); } features = ApplyFeatureMappingRule(features); @@ -110,8 +113,8 @@ void ExecutionCore::Execute(const apis::ExecuteRequest* request, Executable::Task task; task.id = request->task().execution_id(); task.features = features; - task.node_inputs = std::make_shared< - std::map>>(); + task.node_inputs = std::make_shared>>(); for (const auto& n : request->task().nodes()) { auto compute_inputs = std::make_shared(); for (const auto& io : n.ios()) { @@ -125,11 +128,10 @@ void ExecutionCore::Execute(const apis::ExecuteRequest* request, } opts_.executable->Run(task); - for (size_t i = 0; i < task.outputs->size(); ++i) { - auto& output = task.outputs->at(i); - auto node_io = response->mutable_result()->add_nodes(); + for (auto& output : *task.outputs) { + auto* node_io = response->mutable_result()->add_nodes(); node_io->set_name(std::move(output.node_name)); - auto io_data = node_io->add_ios(); + auto* io_data = node_io->add_ios(); 
io_data->add_datas(SerializeRecordBatch(output.table)); } response->mutable_status()->set_code(errors::ErrorCode::OK); @@ -145,30 +147,36 @@ void ExecutionCore::Execute(const apis::ExecuteRequest* request, } timer.Pause(); - RecordMetrics(*request, *response, timer.CountMs()); + RecordMetrics(*request, *response, timer.CountMs(), "Execute"); SPDLOG_DEBUG("execute end, response: {}", response->ShortDebugString()); } void ExecutionCore::RecordMetrics(const apis::ExecuteRequest& request, const apis::ExecuteResponse& response, - double duration_ms) { - std::map labels = { - {"code", std::to_string(response.status().code())}, - {"requester_id", request.requester_id()}, - {"feature_source_type", - FeatureSourceType_Name(request.feature_source().type())}}; - stats_.execute_request_counter_family.Add(::prometheus::Labels(labels)) + double duration_ms, + const std::string& action) { + stats_.execute_request_counter_family + .Add(::prometheus::Labels( + {{"service_id", request.service_spec().id()}, + {"action", action}, + {"code", std::to_string(response.status().code())}, + {"requester_id", request.requester_id()}, + {"feature_source_type", + FeatureSourceType_Name(request.feature_source().type())}})) .Increment(); - stats_.execute_request_totol_duration_family.Add(::prometheus::Labels(labels)) - .Increment(duration_ms); - stats_.execute_request_duration_summary.Observe(duration_ms); + stats_.execute_request_duration_summary_family + .Add(::prometheus::Labels({{"service_id", request.service_spec().id()}, + {"action", action}}), + ::prometheus::Summary::Quantiles( + {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}})) + .Observe(duration_ms); } std::shared_ptr ExecutionCore::BatchFetchFeatures( const apis::ExecuteRequest* request, apis::ExecuteResponse* response) const { - SERVING_ENFORCE(feature_adapater_, errors::ErrorCode::INVALID_ARGUMENT, + SERVING_ENFORCE(feature_adapter_, errors::ErrorCode::INVALID_ARGUMENT, "feature source is not set, please check config."); 
yacl::ElapsedTimer timer; @@ -178,11 +186,15 @@ std::shared_ptr ExecutionCore::BatchFetchFeatures( fa_request.fs_param = &request->feature_source().fs_param(); feature::FeatureAdapter::Response fa_response; fa_response.header = response->mutable_header(); - feature_adapater_->FetchFeature(fa_request, &fa_response); + feature_adapter_->FetchFeature(fa_request, &fa_response); + RecordBatchFeatureMetrics(request->service_spec().id(), + request->requester_id(), errors::ErrorCode::OK, + timer.CountMs()); return fa_response.features; } catch (Exception& e) { - RecordBatchFeatureMetrics(request->requester_id(), e.code(), + RecordBatchFeatureMetrics(request->service_spec().id(), + request->requester_id(), e.code(), timer.CountMs()); throw e; } @@ -194,7 +206,7 @@ std::shared_ptr ExecutionCore::ApplyFeatureMappingRule( // no need mapping return features; } - auto& feature_mapping = opts_.feature_mapping.value(); + const auto& feature_mapping = opts_.feature_mapping.value(); int num_cols = features->num_columns(); const auto& old_schema = features->schema(); @@ -214,16 +226,22 @@ std::shared_ptr ExecutionCore::ApplyFeatureMappingRule( return MakeRecordBatch(schema, features->num_rows(), features->columns()); } -void ExecutionCore::RecordBatchFeatureMetrics(const std::string& requester_id, +void ExecutionCore::RecordBatchFeatureMetrics(const std::string& service_id, + const std::string& requester_id, int code, double duration_ms) const { - std::map labels = {{"requester_id", requester_id}, - {"code", std::to_string(code)}}; - stats_.fetch_feature_counter_family.Add(::prometheus::Labels(labels)) + stats_.fetch_feature_counter_family + .Add(::prometheus::Labels({{"service_id", service_id}, + {"action", "FetchFeature"}, + {"code", std::to_string(code)}, + {"requester_id", requester_id}})) .Increment(); - stats_.fetch_feature_total_duration_family.Add(::prometheus::Labels(labels)) - .Increment(duration_ms); - stats_.fetch_feature_duration_summary.Observe(duration_ms); + 
stats_.fetch_feature_duration_summary_family + .Add(::prometheus::Labels( + {{"service_id", service_id}, {"action", "FetchFeature"}}), + ::prometheus::Summary::Quantiles( + {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}})) + .Observe(duration_ms); } ExecutionCore::Stats::Stats( @@ -236,45 +254,24 @@ ExecutionCore::Stats::Stats( "this ExecutionCore.") .Labels(labels) .Register(*registry)), - execute_request_totol_duration_family( - ::prometheus::BuildCounter() - .Name("execution_core_request_total_duration_family") - .Help("total time to process the request in milliseconds") - .Labels(labels) - .Register(*registry)), execute_request_duration_summary_family( ::prometheus::BuildSummary() .Name("execution_core_request_duration_family") .Help("prediction service api request duration in milliseconds") .Labels(labels) .Register(*registry)), - execute_request_duration_summary( - execute_request_duration_summary_family.Add( - ::prometheus::Labels(), - ::prometheus::Summary::Quantiles( - {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}}))), fetch_feature_counter_family( ::prometheus::BuildCounter() .Name("fetch_feature_counter_family") - .Help("How many times to fetch remote features sevice by " + .Help("How many times to fetch remote features service by " "this ExecutionCore.") .Labels(labels) .Register(*registry)), - fetch_feature_total_duration_family( - ::prometheus::BuildCounter() - .Name("fetch_feature_total_duration_family") - .Help("total time of fetching remote features in milliseconds") - .Labels(labels) - .Register(*registry)), fetch_feature_duration_summary_family( ::prometheus::BuildSummary() .Name("fetch_feature_duration_family") .Help("durations of fetching remote features in milliseconds") .Labels(labels) - .Register(*registry)), - fetch_feature_duration_summary(fetch_feature_duration_summary_family.Add( - ::prometheus::Labels(), - ::prometheus::Summary::Quantiles( - {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}}))) {} + .Register(*registry)) {} } // namespace 
secretflow::serving diff --git a/secretflow_serving/server/execution_core.h b/secretflow_serving/server/execution_core.h index f2a105a..f8044a2 100644 --- a/secretflow_serving/server/execution_core.h +++ b/secretflow_serving/server/execution_core.h @@ -37,6 +37,8 @@ class ExecutionCore { std::optional feature_source_config; std::shared_ptr executable; + + uint32_t op_exec_workers_num{std::thread::hardware_concurrency()}; }; public: @@ -53,36 +55,36 @@ class ExecutionCore { std::shared_ptr BatchFetchFeatures( const apis::ExecuteRequest* request, apis::ExecuteResponse* response) const; - void RecordMetrics(const apis::ExecuteRequest& request, - const apis::ExecuteResponse& response, double duration_ms); - void RecordBatchFeatureMetrics(const std::string& requester_id, int code, - double duration_ms) const; std::shared_ptr ApplyFeatureMappingRule( const std::shared_ptr& features); + void RecordMetrics(const apis::ExecuteRequest& request, + const apis::ExecuteResponse& response, double duration_ms, + const std::string& action); + + void RecordBatchFeatureMetrics(const std::string& service_id, + const std::string& requester_id, int code, + double duration_ms) const; + private: const Options opts_; bool valid_feature_mapping_flag_; - std::unique_ptr feature_adapater_; + std::shared_ptr source_schema_; + + std::unique_ptr feature_adapter_; struct Stats { // for service interface ::prometheus::Family<::prometheus::Counter>& execute_request_counter_family; - ::prometheus::Family<::prometheus::Counter>& - execute_request_totol_duration_family; ::prometheus::Family<::prometheus::Summary>& execute_request_duration_summary_family; - ::prometheus::Summary& execute_request_duration_summary; ::prometheus::Family<::prometheus::Counter>& fetch_feature_counter_family; - ::prometheus::Family<::prometheus::Counter>& - fetch_feature_total_duration_family; ::prometheus::Family<::prometheus::Summary>& fetch_feature_duration_summary_family; - ::prometheus::Summary& 
fetch_feature_duration_summary; Stats(std::map labels, const std::shared_ptr<::prometheus::Registry>& registry = diff --git a/secretflow_serving/server/execution_service_impl.cc b/secretflow_serving/server/execution_service_impl.cc index 5d9f798..f8e8422 100644 --- a/secretflow_serving/server/execution_service_impl.cc +++ b/secretflow_serving/server/execution_service_impl.cc @@ -24,7 +24,6 @@ ExecutionServiceImpl::ExecutionServiceImpl( const std::shared_ptr& execution_core) : execution_core_(execution_core), stats_({{"handler", "ExecutionService"}, - {"service_id", execution_core->GetServiceID()}, {"party_id", execution_core->GetPartyID()}}) {} void ExecutionServiceImpl::Execute( @@ -41,21 +40,25 @@ void ExecutionServiceImpl::Execute( timer.Pause(); SPDLOG_DEBUG("execute end, response: {}", response->ShortDebugString()); - RecordMetrics(*request, *response, timer.CountMs()); + RecordMetrics(*request, *response, timer.CountMs(), "Execute"); } void ExecutionServiceImpl::RecordMetrics(const apis::ExecuteRequest& request, const apis::ExecuteResponse& response, - double duration_ms) { + double duration_ms, + const std::string& action) { stats_.api_request_counter_family .Add(::prometheus::Labels( - {{"code", std::to_string(response.status().code())}})) + {{"action", action}, + {"service_id", request.service_spec().id()}, + {"code", std::to_string(response.status().code())}})) .Increment(); - stats_.api_request_totol_duration_family - .Add(::prometheus::Labels( - {{"code", std::to_string(response.status().code())}})) - .Increment(duration_ms); - stats_.api_request_duration_summary.Observe(duration_ms); + stats_.api_request_duration_summary_family + .Add(::prometheus::Labels({{"action", action}, + {"service_id", request.service_spec().id()}}), + ::prometheus::Summary::Quantiles( + {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}})) + .Observe(duration_ms); } ExecutionServiceImpl::Stats::Stats( @@ -68,21 +71,11 @@ ExecutionServiceImpl::Stats::Stats( "this server.") .Labels(labels) 
.Register(*registry)), - api_request_totol_duration_family( - ::prometheus::BuildCounter() - .Name("execution_request_total_duration_family") - .Help("total time to process the request in milliseconds") - .Labels(labels) - .Register(*registry)), api_request_duration_summary_family( ::prometheus::BuildSummary() .Name("execution_request_duration_family") .Help("prediction service api request duration in milliseconds") .Labels(labels) - .Register(*registry)), - api_request_duration_summary(api_request_duration_summary_family.Add( - ::prometheus::Labels(), - ::prometheus::Summary::Quantiles( - {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}}))) {} + .Register(*registry)) {} } // namespace secretflow::serving diff --git a/secretflow_serving/server/execution_service_impl.h b/secretflow_serving/server/execution_service_impl.h index 6aaebb5..34bd1f8 100644 --- a/secretflow_serving/server/execution_service_impl.h +++ b/secretflow_serving/server/execution_service_impl.h @@ -38,15 +38,13 @@ class ExecutionServiceImpl : public apis::ExecutionService { private: void RecordMetrics(const apis::ExecuteRequest& request, - const apis::ExecuteResponse& response, double duration_ms); + const apis::ExecuteResponse& response, double duration_ms, + const std::string& action); struct Stats { // for service interface ::prometheus::Family<::prometheus::Counter>& api_request_counter_family; - ::prometheus::Family<::prometheus::Counter>& - api_request_totol_duration_family; ::prometheus::Family<::prometheus::Summary>& api_request_duration_summary_family; - ::prometheus::Summary& api_request_duration_summary; Stats(std::map labels, const std::shared_ptr<::prometheus::Registry>& registry = diff --git a/secretflow_serving/server/kuscia/config_parser_test.cc b/secretflow_serving/server/kuscia/config_parser_test.cc index 878a1f5..b4fd925 100644 --- a/secretflow_serving/server/kuscia/config_parser_test.cc +++ b/secretflow_serving/server/kuscia/config_parser_test.cc @@ -30,7 +30,7 @@ 
TEST_F(KusciaConfigParserTest, Works) { tmpfile.save(1 + R"JSON( { "serving_id": "kd-1", - "input_config": "{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceMd5\":\"ba9b4a121e139902e320a18c0610aa99\",\"sourcePath\":\"examples/alice/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"mockOpts\":{}},\"channel_desc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceMd5\":\"bdd57c453e64bcf4a9b5fa247a7cffa1\",\"sourcePath\":\"examples/bob/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"mockOpts\":{}},\"channel_desc\":{\"protocol\":\"http\"}}}}", + "input_config": "{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"examples/alice/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"mockOpts\":{}},\"channel_desc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"examples/bob/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"mockOpts\":{}},\"channel_desc\":{\"protocol\":\"http\"}}}}", "cluster_def": "{\"parties\":[{\"name\":\"alice\", \"role\":\"\", \"services\":[{\"portName\":\"service\", 
\"endpoints\":[\"kd-1-service.alice.svc\"]}, {\"portName\":\"internal\", \"endpoints\":[\"kd-1-internal.alice.svc:53510\"]}, {\"portName\":\"brpc-builtin\", \"endpoints\":[\"kd-1-brpc-builtin.alice.svc:53511\"]}]}, {\"name\":\"bob\", \"role\":\"\", \"services\":[{\"portName\":\"brpc-builtin\", \"endpoints\":[\"kd-1-brpc-builtin.bob.svc:53511\"]}, {\"portName\":\"service\", \"endpoints\":[\"kd-1-service.bob.svc\"]}, {\"portName\":\"internal\", \"endpoints\":[\"kd-1-internal.bob.svc:53510\"]}]}], \"selfPartyIdx\":0, \"selfEndpointIdx\":0}", "allocated_ports": "{\"ports\":[{\"name\":\"service\", \"port\":53509, \"scope\":\"Cluster\", \"protocol\":\"HTTP\"}, {\"name\":\"internal\", \"port\":53510, \"scope\":\"Domain\", \"protocol\":\"HTTP\"}, {\"name\":\"brpc-builtin\", \"port\":53511, \"scope\":\"Domain\", \"protocol\":\"HTTP\"}]}", "oss_meta": "" @@ -45,7 +45,8 @@ TEST_F(KusciaConfigParserTest, Works) { EXPECT_EQ("glm-test-1", model_config.model_id()); EXPECT_EQ("/tmp/alice", model_config.base_path()); EXPECT_EQ(SourceType::ST_FILE, model_config.source_type()); - EXPECT_EQ("ba9b4a121e139902e320a18c0610aa99", model_config.source_md5()); + EXPECT_EQ("3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522", + model_config.source_sha256()); auto cluster_config = config_parser.cluster_config(); EXPECT_EQ(2, cluster_config.parties_size()); diff --git a/secretflow_serving/server/main.cc b/secretflow_serving/server/main.cc index 17bc1d5..6771884 100644 --- a/secretflow_serving/server/main.cc +++ b/secretflow_serving/server/main.cc @@ -18,6 +18,7 @@ #include "secretflow_serving/core/exception.h" #include "secretflow_serving/core/logging.h" +#include "secretflow_serving/ops/op_factory.h" #include "secretflow_serving/server/kuscia/config_parser.h" #include "secretflow_serving/server/server.h" #include "secretflow_serving/server/version.h" @@ -27,7 +28,7 @@ DEFINE_string(config_mode, "", "config mode for serving, default value will use the raw config " - "defined. 
optinal value: kuscia"); + "defined. optional value: kuscia"); DEFINE_string(serving_config_file, "", "read an ascii config protobuf from the supplied file name."); @@ -51,8 +52,6 @@ int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); try { - SPDLOG_INFO("version: {}", SERVING_VERSION_STRING); - // init logger secretflow::serving::LoggingConfig log_config; if (!FLAGS_logging_config_file.empty()) { @@ -61,6 +60,22 @@ int main(int argc, char* argv[]) { } secretflow::serving::SetupLogging(log_config); + SPDLOG_INFO("version: {}", SERVING_VERSION_STRING); + + { + auto op_def_list = + secretflow::serving::op::OpFactory::GetInstance()->GetAllOps(); + std::vector op_names; + std::for_each( + op_def_list.begin(), op_def_list.end(), + [&](const std::shared_ptr& o) { + op_names.emplace_back(o->name()); + }); + + SPDLOG_INFO("op list: {}", + fmt::join(op_names.begin(), op_names.end(), ", ")); + } + STRING_EMPTY_VALIDATOR(FLAGS_serving_config_file); // init server options @@ -92,7 +107,8 @@ int main(int argc, char* argv[]) { server.WaitForEnd(); } catch (const secretflow::serving::Exception& e) { // TODO: custom status sink - SPDLOG_ERROR("server startup failed, code:{}, msg:{}", e.code(), e.what()); + SPDLOG_ERROR("server startup failed, code: {}, msg: {}, stack: {}", + e.code(), e.what(), e.stack_trace()); return -1; } catch (const std::exception& e) { // TODO: custom status sink diff --git a/secretflow_serving/server/model_service_impl.cc b/secretflow_serving/server/model_service_impl.cc new file mode 100644 index 0000000..27abdfe --- /dev/null +++ b/secretflow_serving/server/model_service_impl.cc @@ -0,0 +1,94 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/server/model_service_impl.h" + +#include "brpc/closure_guard.h" +#include "brpc/controller.h" +#include "spdlog/spdlog.h" +#include "yacl/utils/elapsed_timer.h" + +#include "secretflow_serving/core/exception.h" + +namespace secretflow::serving { + +ModelServiceImpl::ModelServiceImpl(std::map model_infos, + const std::string& self_party_id) + : model_infos_(std::move(model_infos)), + self_party_id_(self_party_id), + stats_({{"handler", "ModelService"}, {"party_id", self_party_id_}}) {} + +void ModelServiceImpl::GetModelInfo( + ::google::protobuf::RpcController* controller, + const apis::GetModelInfoRequest* request, + apis::GetModelInfoResponse* response, ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + auto* cntl = static_cast(controller); + cntl->set_always_print_primitive_fields(true); + + yacl::ElapsedTimer timer; + + response->mutable_service_spec()->CopyFrom(request->service_spec()); + + auto it = model_infos_.find(request->service_spec().id()); + if (it == model_infos_.end()) { + response->mutable_status()->set_code(errors::ErrorCode::NOT_FOUND); + response->mutable_status()->set_msg(fmt::format( + "invalid service spec id: {}", request->service_spec().id())); + } else { + *(response->mutable_model_info()) = it->second; + response->mutable_status()->set_code(errors::ErrorCode::OK); + } + + timer.Pause(); + RecordMetrics(*request, *response, timer.CountMs(), "GetModelInfo"); +} + +void ModelServiceImpl::RecordMetrics(const apis::GetModelInfoRequest& request, + const 
apis::GetModelInfoResponse& response, + double duration_ms, + const std::string& action) { + stats_.api_request_duration_summary_family + .Add(::prometheus::Labels({{"service_id", request.service_spec().id()}, + {"action", action}}), + ::prometheus::Summary::Quantiles( + {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}})) + .Observe(duration_ms); + stats_.api_request_counter_family + .Add(::prometheus::Labels( + {{"service_id", request.service_spec().id()}, + {"code", std::to_string(response.status().code())}, + {"action", action}})) + .Increment(); +} + +ModelServiceImpl::Stats::Stats( + std::map labels, + const std::shared_ptr<::prometheus::Registry>& registry) + : api_request_counter_family( + ::prometheus::BuildCounter() + .Name("model_service_request_count") + .Help("How many model service api requests are handled by " + "this server.") + .Labels(labels) + .Register(*registry)), + api_request_duration_summary_family( + ::prometheus::BuildSummary() + .Name("model_service_request_duration_seconds") + .Help("model service api request duration in " + "milliseconds.") + .Labels(labels) + .Register(*registry)) {} + +} // namespace secretflow::serving diff --git a/secretflow_serving/server/model_service_impl.h b/secretflow_serving/server/model_service_impl.h new file mode 100644 index 0000000..675859f --- /dev/null +++ b/secretflow_serving/server/model_service_impl.h @@ -0,0 +1,62 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "prometheus/counter.h" +#include "prometheus/family.h" +#include "prometheus/registry.h" +#include "prometheus/summary.h" + +#include "secretflow_serving/server/metrics/default_metrics_registry.h" + +#include "secretflow_serving/apis/model_service.pb.h" + +namespace secretflow::serving { + +// 模型 - 服务入口 +class ModelServiceImpl : public apis::ModelService { + public: + explicit ModelServiceImpl(std::map model_infos, + const std::string& self_party_id); + + void GetModelInfo(::google::protobuf::RpcController* controller, + const apis::GetModelInfoRequest* request, + apis::GetModelInfoResponse* response, + ::google::protobuf::Closure* done) override; + + private: + struct Stats { + // for request api + ::prometheus::Family<::prometheus::Counter>& api_request_counter_family; + ::prometheus::Family<::prometheus::Summary>& + api_request_duration_summary_family; + + explicit Stats(std::map labels, + const std::shared_ptr<::prometheus::Registry>& registry = + metrics::GetDefaultRegistry()); + }; + + void RecordMetrics(const apis::GetModelInfoRequest& request, + const apis::GetModelInfoResponse& response, + double duration_ms, const std::string& action); + + private: + std::map model_infos_; + + std::string self_party_id_; + Stats stats_; +}; + +} // namespace secretflow::serving diff --git a/secretflow_serving/server/prediction_core.cc b/secretflow_serving/server/prediction_core.cc index d3f4c16..96903ea 100644 --- a/secretflow_serving/server/prediction_core.cc +++ b/secretflow_serving/server/prediction_core.cc @@ -16,65 +16,108 @@ #include "spdlog/spdlog.h" +#include "secretflow_serving/util/utils.h" + namespace secretflow::serving { +namespace { + +struct FeatureLengthVisitor { + template + void operator()(const FeatureField& field, const Vec& values) { + len = values.size(); + field_name = field.name(); + } + int len = 0; + std::string 
field_name; +}; + +} // namespace + PredictionCore::PredictionCore(Options opts) : opts_(std::move(opts)) { SERVING_ENFORCE(!opts_.service_id.empty(), errors::ErrorCode::INVALID_ARGUMENT); SERVING_ENFORCE(!opts_.party_id.empty(), errors::ErrorCode::INVALID_ARGUMENT); SERVING_ENFORCE(!opts_.cluster_ids.empty(), errors::ErrorCode::INVALID_ARGUMENT); - SERVING_ENFORCE(opts_.predictor, errors::ErrorCode::INVALID_ARGUMENT); } void PredictionCore::Predict(const apis::PredictRequest* request, apis::PredictResponse* response) noexcept { try { - SERVING_ENFORCE(request->service_spec().id() == opts_.service_id, - errors::ErrorCode::INVALID_ARGUMENT, - "invalid service sepc id: {}", - request->service_spec().id()); response->mutable_service_spec()->CopyFrom(request->service_spec()); - auto status = response->mutable_status(); + auto* status = response->mutable_status(); CheckArgument(request); opts_.predictor->Predict(request, response); status->set_code(errors::ErrorCode::OK); } catch (const Exception& e) { - SPDLOG_ERROR("predict failed, code:{}, msg:{}, stack:{}", e.code(), + SPDLOG_ERROR("Predict failed, code:{}, msg:{}, stack:{}", e.code(), e.what(), e.stack_trace()); response->mutable_status()->set_code(e.code()); response->mutable_status()->set_msg(e.what()); } catch (const std::exception& e) { - SPDLOG_ERROR("predict failed, msg:{}", e.what()); + SPDLOG_ERROR("Predict failed, msg:{}", e.what()); response->mutable_status()->set_code(errors::ErrorCode::UNEXPECTED_ERROR); response->mutable_status()->set_msg(e.what()); } } void PredictionCore::CheckArgument(const apis::PredictRequest* request) { - SERVING_ENFORCE(request->service_spec().id() == opts_.service_id, - errors::ErrorCode::INVALID_ARGUMENT, "invalid service id: {}", - request->service_spec().id()); - + SERVING_ENFORCE_EQ(request->service_spec().id(), opts_.service_id, + "invalid service spec id: {}", + request->service_spec().id()); std::vector missing_params_party; + std::unordered_map party_row_num; for (const 
auto& party_id : opts_.cluster_ids) { + if (party_id == opts_.party_id && !request->predefined_features().empty()) { + int predefined_row_num = -1; + for (const auto& feature : request->predefined_features()) { + FeatureLengthVisitor len_visitor; + FeatureVisit(len_visitor, feature); + + if (predefined_row_num == -1) { + predefined_row_num = len_visitor.len; + } else { + SERVING_ENFORCE(predefined_row_num == len_visitor.len, + errors::ErrorCode::INVALID_ARGUMENT, + "predifined_features should have same length, {} : " + "{}, previous is {}", + len_visitor.field_name, len_visitor.len, + predefined_row_num); + } + } + party_row_num[party_id] = predefined_row_num; + continue; + } + auto it = request->fs_params().find(party_id); + if (it == request->fs_params().end()) { - if (party_id == opts_.party_id && - !request->predefined_features().empty()) { - // check whether set predefined features - continue; - } missing_params_party.emplace_back(party_id); + } else { + if (it->second.query_datas().empty()) { + missing_params_party.emplace_back(party_id); + } else { + party_row_num[party_id] = it->second.query_datas().size(); + } } } if (!missing_params_party.empty()) { SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, - "{} missing feature params", + "{} missing feature params or got empty query datas", fmt::join(missing_params_party.begin(), missing_params_party.end(), ",")); } + auto row_num_iter = party_row_num.begin(); + auto row_num = row_num_iter->second; + for (++row_num_iter; row_num_iter != party_row_num.end(); ++row_num_iter) { + SERVING_ENFORCE(row_num == row_num_iter->second, + errors::ErrorCode::INVALID_ARGUMENT, + "predict row nums should be same, expect:{}, " + "party({}) : {}", + row_num, row_num_iter->first, row_num_iter->second); + } } } // namespace secretflow::serving diff --git a/secretflow_serving/server/prediction_core.h b/secretflow_serving/server/prediction_core.h index 64325da..45264d7 100644 --- a/secretflow_serving/server/prediction_core.h +++ 
b/secretflow_serving/server/prediction_core.h @@ -26,7 +26,9 @@ class PredictionCore { struct Options { std::string service_id; std::string party_id; + std::vector cluster_ids; + std::shared_ptr predictor; }; @@ -38,6 +40,7 @@ class PredictionCore { apis::PredictResponse* response) noexcept; const std::string& GetServiceID() const { return opts_.service_id; } + const std::string& GetPartyID() const { return opts_.party_id; } protected: diff --git a/secretflow_serving/server/prediction_service_impl.cc b/secretflow_serving/server/prediction_service_impl.cc index c435b60..e72d439 100644 --- a/secretflow_serving/server/prediction_service_impl.cc +++ b/secretflow_serving/server/prediction_service_impl.cc @@ -21,12 +21,19 @@ namespace secretflow::serving { -PredictionServiceImpl::PredictionServiceImpl( - const std::shared_ptr& prediction_core) - : prediction_core_(prediction_core), - stats_({{"handler", "PredictionService"}, - {"service_id", prediction_core->GetServiceID()}, - {"party_id", prediction_core->GetPartyID()}}) {} +PredictionServiceImpl::PredictionServiceImpl(const std::string& party_id) + : party_id_(party_id), + stats_({{"handler", "PredictionService"}, {"party_id", party_id_}}), + init_flag_(false) {} + +void PredictionServiceImpl::Init( + const std::shared_ptr& prediction_core) { + SERVING_ENFORCE(prediction_core, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(!init_flag_, errors::ErrorCode::LOGIC_ERROR); + + prediction_core_ = prediction_core; + init_flag_ = true; +} void PredictionServiceImpl::Predict( ::google::protobuf::RpcController* controller, @@ -38,25 +45,38 @@ void PredictionServiceImpl::Predict( SPDLOG_DEBUG("predict begin, request: {}", request->ShortDebugString()); yacl::ElapsedTimer timer; - prediction_core_->Predict(request, response); + + if (!init_flag_) { + response->mutable_service_spec()->CopyFrom(request->service_spec()); + response->mutable_status()->set_code(errors::ErrorCode::NOT_READY); + response->mutable_status()->set_msg( + 
"prediction service is not ready to serve, please retry later."); + } else { + prediction_core_->Predict(request, response); + } + timer.Pause(); SPDLOG_DEBUG("predict end, time: {}", timer.CountMs()); - RecordMetrics(*request, *response, timer.CountMs()); + RecordMetrics(*request, *response, timer.CountMs(), "Predict"); } void PredictionServiceImpl::RecordMetrics(const apis::PredictRequest& request, const apis::PredictResponse& response, - const double duration_ms) { - stats_.api_request_duration_summary.Observe(duration_ms); + double duration_ms, + const std::string& action) { + stats_.api_request_duration_summary_family + .Add(::prometheus::Labels({{"action", action}, + {"service_id", request.service_spec().id()}}), + ::prometheus::Summary::Quantiles( + {{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}})) + .Observe(duration_ms); stats_.api_request_counter_family .Add(::prometheus::Labels( - {{"code", std::to_string(response.status().code())}})) + {{"action", action}, + {"service_id", request.service_spec().id()}, + {"code", std::to_string(response.status().code())}})) .Increment(); - stats_.api_request_total_duration_family - .Add(::prometheus::Labels( - {{"code", std::to_string(response.status().code())}})) - .Increment(duration_ms); stats_.predict_counter.Increment(response.results().size()); } @@ -70,22 +90,12 @@ PredictionServiceImpl::Stats::Stats( "this server.") .Labels(labels) .Register(*registry)), - api_request_total_duration_family( - ::prometheus::BuildCounter() - .Name("prediction_request_total_duration") - .Help("total time to process the request in milliseconds") - .Labels(labels) - .Register(*registry)), api_request_duration_summary_family( ::prometheus::BuildSummary() .Name("prediction_request_duration_seconds") .Help("prediction service api request duration in milliseconds.") .Labels(labels) .Register(*registry)), - api_request_duration_summary(api_request_duration_summary_family.Add( - ::prometheus::Labels{}, - ::prometheus::Summary::Quantiles( - 
{{0.5, 0.05}, {0.9, 0.01}, {0.99, 0.001}}))), predict_counter_family( ::prometheus::BuildCounter() .Name("prediction_count") diff --git a/secretflow_serving/server/prediction_service_impl.h b/secretflow_serving/server/prediction_service_impl.h index 4ad6503..86cc6ab 100644 --- a/secretflow_serving/server/prediction_service_impl.h +++ b/secretflow_serving/server/prediction_service_impl.h @@ -29,8 +29,9 @@ namespace secretflow::serving { // 预测 - 服务入口 class PredictionServiceImpl : public apis::PredictionService { public: - explicit PredictionServiceImpl( - const std::shared_ptr& prediction_core); + explicit PredictionServiceImpl(const std::string& party_id); + + void Init(const std::shared_ptr& prediction_core); void Predict(::google::protobuf::RpcController* controller, const apis::PredictRequest* request, @@ -41,11 +42,8 @@ class PredictionServiceImpl : public apis::PredictionService { struct Stats { // for request api ::prometheus::Family<::prometheus::Counter>& api_request_counter_family; - ::prometheus::Family<::prometheus::Counter>& - api_request_total_duration_family; ::prometheus::Family<::prometheus::Summary>& api_request_duration_summary_family; - ::prometheus::Summary& api_request_duration_summary; // for predict sample ::prometheus::Family<::prometheus::Counter>& predict_counter_family; ::prometheus::Counter& predict_counter; @@ -56,12 +54,17 @@ class PredictionServiceImpl : public apis::PredictionService { }; void RecordMetrics(const apis::PredictRequest& request, - const apis::PredictResponse& response, - const double duration_ms); + const apis::PredictResponse& response, double duration_ms, + const std::string& action); private: - std::shared_ptr prediction_core_; + const std::string& party_id_; + Stats stats_; + + std::shared_ptr prediction_core_; + + std::atomic init_flag_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/server/server.cc b/secretflow_serving/server/server.cc index 1d783af..392ce62 100644 --- 
a/secretflow_serving/server/server.cc +++ b/secretflow_serving/server/server.cc @@ -17,11 +17,14 @@ #include "absl/strings/str_split.h" #include "spdlog/spdlog.h" +#include "secretflow_serving/framework/model_info_collector.h" #include "secretflow_serving/framework/model_loader.h" +#include "secretflow_serving/ops/graph.h" #include "secretflow_serving/server/execution_service_impl.h" #include "secretflow_serving/server/health.h" #include "secretflow_serving/server/metrics/default_metrics_registry.h" #include "secretflow_serving/server/metrics/metrics_service.h" +#include "secretflow_serving/server/model_service_impl.h" #include "secretflow_serving/server/prediction_service_impl.h" #include "secretflow_serving/server/version.h" #include "secretflow_serving/source/factory.h" @@ -55,14 +58,13 @@ Server::~Server() { void Server::Start() { const auto& self_party_id = opts_.cluster_config.self_id(); - // 1. get model package + // get model package auto source = SourceFactory::GetInstance()->Create(opts_.model_config, opts_.service_id); auto package_path = source->PullModel(); + // build channels std::string self_address; - - // 2. load model package std::vector cluster_ids; auto channels = std::make_shared(); for (const auto& party : opts_.cluster_config.parties()) { @@ -87,17 +89,26 @@ void Server::Start() { ? &opts_.cluster_config.channel_desc().tls_config() : nullptr)); } - Loader::Options loader_opts; - loader_opts.party_id = self_party_id; - auto loader = std::make_unique(loader_opts, channels); + + // load model package + auto loader = std::make_unique(); loader->Load(package_path); - auto executable = loader->GetExecutable(); + const auto& model_bundle = loader->GetModelBundle(); + Graph graph(model_bundle->graph()); - // 3. 
create execution_service + // build execution core + std::vector> executors; + for (const auto& execution : graph.GetExecutions()) { + executors.emplace_back(std::make_shared(execution)); + } ExecutionCore::Options exec_opts; exec_opts.id = opts_.service_id; exec_opts.party_id = self_party_id; - exec_opts.executable = executable; + exec_opts.executable = std::make_shared(std::move(executors)); + if (opts_.server_config.op_exec_worker_num() > 0) { + exec_opts.op_exec_workers_num = opts_.server_config.op_exec_worker_num(); + } + if (!opts_.server_config.feature_mapping().empty()) { exec_opts.feature_mapping = {opts_.server_config.feature_mapping().begin(), opts_.server_config.feature_mapping().end()}; @@ -105,21 +116,9 @@ void Server::Start() { exec_opts.feature_source_config = opts_.feature_source_config; auto execution_core = std::make_shared(std::move(exec_opts)); - // 4. prediction core - auto predictor = loader->GetPredictor(); - predictor->SetExecutionCore(execution_core); - - PredictionCore::Options prediction_core_opts; - prediction_core_opts.service_id = opts_.service_id; - prediction_core_opts.party_id = self_party_id; - prediction_core_opts.cluster_ids = std::move(cluster_ids); - prediction_core_opts.predictor = predictor; - auto prediction_core = - std::make_shared(std::move(prediction_core_opts)); - - // mertrics server + // start mertrics server if (opts_.server_config.metrics_exposer_port() > 0) { - std::vector strs = absl::StrSplit(self_address, ":"); + std::vector strs = absl::StrSplit(self_address, ':'); SERVING_ENFORCE(strs.size() == 2, errors::ErrorCode::LOGIC_ERROR, "invalid self address."); auto metrics_listen_address = fmt::format( @@ -127,7 +126,7 @@ void Server::Start() { brpc::ServerOptions metrics_server_options; if (opts_.server_config.has_tls_config()) { - auto ssl_opts = metrics_server_options.mutable_ssl_options(); + auto* ssl_opts = metrics_server_options.mutable_ssl_options(); ssl_opts->default_cert.certificate = 
opts_.server_config.tls_config().certificate_path(); ssl_opts->default_cert.private_key = @@ -136,10 +135,11 @@ void Server::Start() { ssl_opts->verify.ca_file_path = opts_.server_config.tls_config().ca_file_path(); } - metrics_server_.set_version(SERVING_VERSION_STRING); - auto metrics_service = new metrics::MetricsService(); + auto* metrics_service = new metrics::MetricsService(); metrics_service->RegisterCollectable(metrics::GetDefaultRegistry()); + + metrics_server_.set_version(SERVING_VERSION_STRING); if (metrics_server_.AddService(metrics_service, brpc::SERVER_OWNS_SERVICE) != 0) { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, @@ -155,7 +155,49 @@ void Server::Start() { SPDLOG_INFO("begin metrics service listen at {}, ", metrics_listen_address); } - // start server + // build model_info_collector + ModelInfoCollector::Options m_c_opts; + m_c_opts.model_bundle = model_bundle; + m_c_opts.service_id = opts_.service_id; + m_c_opts.self_party_id = self_party_id; + m_c_opts.remote_channel_map = channels; + ModelInfoCollector model_info_collector(std::move(m_c_opts)); + { + auto max_retry_cnt = + opts_.cluster_config.channel_desc().handshake_max_retry_cnt(); + if (max_retry_cnt != 0) { + model_info_collector.SetRetryCounts(max_retry_cnt); + } + auto retry_interval_ms = + opts_.cluster_config.channel_desc().handshake_retry_interval_ms(); + if (retry_interval_ms != 0) { + model_info_collector.SetRetryIntervalMs(retry_interval_ms); + } + } + + // add services + auto* model_service = new ModelServiceImpl( + {{opts_.service_id, model_info_collector.GetSelfModelInfo()}}, + self_party_id); + if (service_server_.AddService(model_service, brpc::SERVER_OWNS_SERVICE) != + 0) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to add model service into brpc server."); + } + auto* execution_service = new ExecutionServiceImpl(execution_core); + if (service_server_.AddService(execution_service, + brpc::SERVER_OWNS_SERVICE) != 0) { + 
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to add execution service into brpc server."); + } + auto* prediction_service = new PredictionServiceImpl(self_party_id); + if (service_server_.AddService(prediction_service, + brpc::SERVER_OWNS_SERVICE) != 0) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to add prediction service into brpc server."); + } + + // build services server opts brpc::ServerOptions server_options; server_options.max_concurrency = opts_.server_config.max_concurrency(); if (opts_.server_config.worker_num() > 0) { @@ -168,7 +210,7 @@ void Server::Start() { SPDLOG_INFO("internal port: {}", server_options.internal_port); } if (opts_.server_config.has_tls_config()) { - auto ssl_opts = server_options.mutable_ssl_options(); + auto* ssl_opts = server_options.mutable_ssl_options(); ssl_opts->default_cert.certificate = opts_.server_config.tls_config().certificate_path(); ssl_opts->default_cert.private_key = @@ -177,31 +219,44 @@ void Server::Start() { ssl_opts->verify.ca_file_path = opts_.server_config.tls_config().ca_file_path(); } - - // health reporter health::ServingHealthReporter hr; - hr.SetStatusCode(200); server_options.health_reporter = &hr; + // start services server service_server_.set_version(SERVING_VERSION_STRING); - auto execution_service = new ExecutionServiceImpl(execution_core); - if (service_server_.AddService(execution_service, - brpc::SERVER_OWNS_SERVICE) != 0) { - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "fail to add execution service into brpc server."); - } - auto prediction_service = new PredictionServiceImpl(prediction_core); - if (service_server_.AddService(prediction_service, - brpc::SERVER_OWNS_SERVICE) != 0) { - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "fail to add prediction service into brpc server."); - } if (service_server_.Start(self_address.c_str(), &server_options) != 0) { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "fail to start brpc server at {}", 
self_address); } - SPDLOG_INFO("begin listen at {}, ", self_address); + // exchange model info + SPDLOG_INFO("start exchange model_info"); + + model_info_collector.DoCollect(); + auto specific_map = model_info_collector.GetSpecificMap(); + + SPDLOG_INFO("end exchange model_info"); + + // build prediction core, let prediction service begin to serve + Predictor::Options predictor_opts; + predictor_opts.party_id = self_party_id; + predictor_opts.channels = channels; + predictor_opts.executions = graph.GetExecutions(); + predictor_opts.specific_party_map = std::move(specific_map); + auto predictor = std::make_shared(predictor_opts); + predictor->SetExecutionCore(execution_core); + + PredictionCore::Options prediction_core_opts; + prediction_core_opts.service_id = opts_.service_id; + prediction_core_opts.party_id = self_party_id; + prediction_core_opts.cluster_ids = std::move(cluster_ids); + prediction_core_opts.predictor = predictor; + auto prediction_core = + std::make_shared(std::move(prediction_core_opts)); + prediction_service->Init(prediction_core); + + // set server ready code + hr.SetStatusCode(200); } void Server::WaitForEnd() { diff --git a/secretflow_serving/source/source.cc b/secretflow_serving/source/source.cc index 106c08f..5375831 100644 --- a/secretflow_serving/source/source.cc +++ b/secretflow_serving/source/source.cc @@ -42,10 +42,10 @@ std::string Source::PullModel() { } auto dst_file_path = dst_dir.append(kModelFileName); - const auto& source_md5 = config_.source_md5(); + const auto& source_sha256 = config_.source_sha256(); if (std::filesystem::exists(dst_file_path)) { - if (!source_md5.empty()) { - if (SysUtil::CheckMD5(dst_file_path.string(), source_md5)) { + if (!source_sha256.empty()) { + if (SysUtil::CheckSHA256(dst_file_path.string(), source_sha256)) { return dst_file_path; } } @@ -54,10 +54,10 @@ std::string Source::PullModel() { } OnPullModel(dst_file_path); - if (!source_md5.empty()) { - 
SERVING_ENFORCE(SysUtil::CheckMD5(dst_file_path.string(), source_md5), - errors::ErrorCode::IO_ERROR, "model({}) md5 check failed", - config_.source_path()); + if (!source_sha256.empty()) { + SERVING_ENFORCE(SysUtil::CheckSHA256(dst_file_path.string(), source_sha256), + errors::ErrorCode::IO_ERROR, + "model({}) sha256 check failed", config_.source_path()); } return dst_file_path; diff --git a/secretflow_serving/spis/batch_feature_service.proto b/secretflow_serving/spis/batch_feature_service.proto index 4275731..4b28fb6 100644 --- a/secretflow_serving/spis/batch_feature_service.proto +++ b/secretflow_serving/spis/batch_feature_service.proto @@ -29,6 +29,35 @@ service BatchFeatureService { } // BatchFetchFeature request containing one or more requests. +// examples: +// ```json +// { +// "header": { +// "data": { +// "custom_str": "id_12345" +// }, +// }, +// "model_service_id": "test_service_id", +// "party_id": "alice", +// "feature_fields": [ +// { +// "name": "f1", +// "type": 2 +// }, +// { +// "name": "f2", +// "type": 4 +// } +// ] +// "param": { +// "query_datas": [ +// "x1", +// "x2" +// ], +// "query_context": "context_x" +// } +// } +// ``` message BatchFetchFeatureRequest { // Custom data passed by the Predict request's header. Header header = 1; @@ -47,6 +76,46 @@ message BatchFetchFeatureRequest { } // BatchFetchFeatureResponse response containing one or more responses. +// examples: +// ```json +// { +// "header": { +// "data": { +// "custom_value": "asdfvb" +// } +// }, +// "status": { +// "code": 0, +// "msg": "success." +// }, +// "features": [ +// { +// "field": { +// "name": "f1", +// "type": 2 +// }, +// "value": { +// "i32s": [ +// 123, +// 234 +// ] +// } +// }, +// { +// "field": { +// "name": "f2", +// "type": 4 +// }, +// "value": { +// "fs": [ +// 0.123, +// 1.234 +// ] +// } +// } +// ] +// } +// ``` message BatchFetchFeatureResponse { // Custom data. 
Header header = 1; diff --git a/secretflow_serving/tools/model_view/BUILD.bazel b/secretflow_serving/tools/model_view/BUILD.bazel new file mode 100644 index 0000000..5eee0bd --- /dev/null +++ b/secretflow_serving/tools/model_view/BUILD.bazel @@ -0,0 +1,14 @@ +load("//bazel:serving.bzl", "serving_cc_binary") + +package(default_visibility = ["//visibility:public"]) + +serving_cc_binary( + name = "model_view", + srcs = ["main.cc"], + deps = [ + "//secretflow_serving/ops", + "//secretflow_serving/ops:graph", + "//secretflow_serving/protos:bundle_cc_proto", + "//secretflow_serving/util:utils", + ], +) diff --git a/secretflow_serving/tools/model_view/main.cc b/secretflow_serving/tools/model_view/main.cc new file mode 100644 index 0000000..f6b8a91 --- /dev/null +++ b/secretflow_serving/tools/model_view/main.cc @@ -0,0 +1,77 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "secretflow_serving/ops/graph.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/utils.h" + +#include "secretflow_serving/protos/bundle.pb.h" + +namespace secretflow::serving { + +static void ShowModel(const std::string& file_path) { + ModelBundle model_pb; + LoadPbFromBinaryFile(file_path, &model_pb); + + auto model_json_content = PbToJson(&model_pb); + std::cout << "Model content: " << std::endl; + std::cout << model_json_content << std::endl; + + // get input schema & output schema + std::cout << std::endl; + std::cout << "Io schema: " << std::endl; + Graph graph(model_pb.graph()); + for (const auto& node_def : model_pb.graph().node_list()) { + const auto& node = graph.GetNode(node_def.name()); + + op::OpKernelOptions ctx{node->node_def(), node->GetOpDef()}; + auto op_kernel = op::OpKernelFactory::GetInstance()->Create(std::move(ctx)); + + std::cout << "==========================" << std::endl; + std::cout << "node: " << node_def.name() << std::endl; + auto inputs_num = op_kernel->GetInputsNum(); + for (size_t i = 0; i < inputs_num; ++i) { + std::cout << "--------------------------" << std::endl; + std::cout << i << " input:" << std::endl; + std::cout << op_kernel->GetInputSchema(i)->ToString() << std::endl; + } + std::cout << "--------------------------" << std::endl; + std::cout << "output:" << std::endl; + std::cout << op_kernel->GetOutputSchema()->ToString() << std::endl; + } +} + +} // namespace secretflow::serving + +const char* help_msg = R"MSG( +Usage: model_view +View file content of secretflow binary format model file. 
+)MSG"; +int main(int argc, char** argv) { + if (argc != 2) { + std::cerr << help_msg << std::endl; + return 1; + } + + try { + secretflow::serving::ShowModel(argv[1]); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return -1; + } + + return 0; +} diff --git a/secretflow_serving/util/BUILD.bazel b/secretflow_serving/util/BUILD.bazel index 41708bd..0dd1623 100644 --- a/secretflow_serving/util/BUILD.bazel +++ b/secretflow_serving/util/BUILD.bazel @@ -43,6 +43,7 @@ serving_cc_library( deps = [ "//secretflow_serving/apis:status_cc_proto", "//secretflow_serving/core:exception", + "//secretflow_serving/protos:feature_cc_proto", ], ) @@ -51,12 +52,29 @@ serving_cc_library( srcs = ["arrow_helper.cc"], hdrs = ["arrow_helper.h"], deps = [ + ":utils", "//secretflow_serving/core:exception", - "//secretflow_serving/protos:feature_cc_proto", + "//secretflow_serving/protos:data_type_cc_proto", "@org_apache_arrow//:arrow", ], ) +serving_cc_library( + name = "thread_safe_queue", + hdrs = ["thread_safe_queue.h"], + deps = [ + "//secretflow_serving/core:exception", + ], +) + +serving_cc_library( + name = "thread_pool", + hdrs = ["thread_pool.h"], + deps = [ + ":thread_safe_queue", + ], +) + serving_cc_test( name = "arrow_helper_test", srcs = ["arrow_helper_test.cc"], @@ -75,3 +93,20 @@ serving_cc_library( "@com_github_brpc_brpc//:brpc", ], ) + +serving_cc_library( + name = "test_utils", + srcs = ["test_utils.cc"], + hdrs = ["test_utils.h"], + deps = [ + ":arrow_helper", + ], +) + +serving_cc_test( + name = "thread_safe_queue_test", + srcs = ["thread_safe_queue_test.cc"], + deps = [ + ":thread_safe_queue", + ], +) diff --git a/secretflow_serving/util/arrow_helper.cc b/secretflow_serving/util/arrow_helper.cc index 15516bf..539fc3f 100644 --- a/secretflow_serving/util/arrow_helper.cc +++ b/secretflow_serving/util/arrow_helper.cc @@ -17,140 +17,235 @@ #include #include -namespace secretflow::serving { +#include +#include -std::shared_ptr 
FieldTypeToDataType(FieldType field_type) { - const static std::map> - kDataTypeMap = { - {FieldType::FIELD_BOOL, arrow::boolean()}, - {FieldType::FIELD_INT32, arrow::int32()}, - {FieldType::FIELD_INT64, arrow::int64()}, - {FieldType::FIELD_FLOAT, arrow::float32()}, - {FieldType::FIELD_DOUBLE, arrow::float64()}, - {FieldType::FIELD_STRING, arrow::utf8()}, - }; +#include "secretflow_serving/core/exception.h" +#include "secretflow_serving/util/utils.h" - auto it = kDataTypeMap.find(field_type); - SERVING_ENFORCE(it != kDataTypeMap.end(), errors::ErrorCode::LOGIC_ERROR, - "unknow field type: {}", FieldType_Name(field_type)); - return it->second; -} +namespace secretflow::serving { -FieldType DataTypeToFieldType( - const std::shared_ptr& data_type) { - const static std::map kFieldTypeMap = { - {arrow::Type::type::BOOL, FieldType::FIELD_BOOL}, - {arrow::Type::type::INT32, FieldType::FIELD_INT32}, - {arrow::Type::type::INT64, FieldType::FIELD_INT64}, - {arrow::Type::type::FLOAT, FieldType::FIELD_FLOAT}, - {arrow::Type::type::DOUBLE, FieldType::FIELD_DOUBLE}, - {arrow::Type::type::STRING, FieldType::FIELD_STRING}, - }; +namespace { - auto it = kFieldTypeMap.find(data_type->id()); - SERVING_ENFORCE(it != kFieldTypeMap.end(), errors::ErrorCode::LOGIC_ERROR, - "unsupport arrow data type: {}", - arrow::internal::ToString(data_type->id())); - return it->second; -} - -std::shared_ptr FeaturesToTable( - const ::google::protobuf::RepeatedPtrField& features) { - arrow::SchemaBuilder schema_builder; - std::vector> arrays; - int num_rows = -1; - for (const auto& f : features) { - std::shared_ptr array; - int cur_num_rows = -1; - switch (f.field().type()) { - case FieldType::FIELD_BOOL: { - SERVING_CHECK_ARROW_STATUS(schema_builder.AddField( - arrow::field(f.field().name(), arrow::boolean()))); - arrow::BooleanBuilder array_builder; - SERVING_CHECK_ARROW_STATUS(array_builder.AppendValues( - f.value().bs().begin(), f.value().bs().end())); +struct FeatureToArrayVisitor { + void 
operator()(const FeatureField& field, + const ::google::protobuf::RepeatedField& values) { + arrow::BooleanBuilder array_builder; + SERVING_CHECK_ARROW_STATUS( + array_builder.AppendValues(values.begin(), values.end())); + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + } + void operator()(const FeatureField& field, + const ::google::protobuf::RepeatedField& values) { + switch (target_field->type()->id()) { + case arrow::Type::INT8: { + arrow::Int8Builder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.Resize(values.size())); + for (const auto& v : values) { + SERVING_ENFORCE_GE(v, std::numeric_limits::min()); + SERVING_ENFORCE_LE(v, std::numeric_limits::max()); + SERVING_CHECK_ARROW_STATUS(array_builder.Append(v)); + } + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + break; + } + case arrow::Type::UINT8: { + arrow::UInt8Builder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.Resize(values.size())); + for (const auto& v : values) { + SERVING_ENFORCE_GE(v, 0); + SERVING_ENFORCE_LE(v, std::numeric_limits::max()); + SERVING_CHECK_ARROW_STATUS(array_builder.Append(v)); + } SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); - cur_num_rows = f.value().bs_size(); break; } - case FieldType::FIELD_INT32: { - SERVING_CHECK_ARROW_STATUS(schema_builder.AddField( - arrow::field(f.field().name(), arrow::int32()))); + case arrow::Type::INT16: { + arrow::Int16Builder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.Resize(values.size())); + for (const auto& v : values) { + SERVING_ENFORCE_GE(v, std::numeric_limits::min()); + SERVING_ENFORCE_LE(v, std::numeric_limits::max()); + SERVING_CHECK_ARROW_STATUS(array_builder.Append(v)); + } + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + break; + } + case arrow::Type::UINT16: { + arrow::UInt16Builder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.Resize(values.size())); + for (const auto& v : values) { + SERVING_ENFORCE_GE(v, 0); + SERVING_ENFORCE_LE(v, 
std::numeric_limits::max()); + SERVING_CHECK_ARROW_STATUS(array_builder.Append(v)); + } + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + break; + } + case arrow::Type::INT32: { arrow::Int32Builder array_builder; - SERVING_CHECK_ARROW_STATUS(array_builder.AppendValues( - f.value().i32s().begin(), f.value().i32s().end())); + SERVING_CHECK_ARROW_STATUS( + array_builder.AppendValues(values.begin(), values.end())); + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + break; + } + default: + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "{} mismatch types, expect:{}, actual:{}", field.name(), + FieldType_Name(DataTypeToFieldType(target_field->type())), + FieldType_Name(field.type())); + } + } + void operator()(const FeatureField& field, + const ::google::protobuf::RepeatedField& values) { + switch (target_field->type()->id()) { + case arrow::Type::UINT32: { + arrow::UInt32Builder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.Resize(values.size())); + for (const auto& v : values) { + SERVING_ENFORCE_GE(v, 0); + SERVING_ENFORCE_LE(v, std::numeric_limits::max()); + SERVING_CHECK_ARROW_STATUS(array_builder.Append(v)); + } SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); - cur_num_rows = f.value().i32s_size(); break; } - case FieldType::FIELD_INT64: { - SERVING_CHECK_ARROW_STATUS(schema_builder.AddField( - arrow::field(f.field().name(), arrow::int64()))); + case arrow::Type::INT64: { arrow::Int64Builder array_builder; - SERVING_CHECK_ARROW_STATUS(array_builder.AppendValues( - f.value().i64s().begin(), f.value().i64s().end())); + SERVING_CHECK_ARROW_STATUS( + array_builder.AppendValues(values.begin(), values.end())); SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); - cur_num_rows = f.value().i64s_size(); break; } - case FieldType::FIELD_FLOAT: { - SERVING_CHECK_ARROW_STATUS(schema_builder.AddField( - arrow::field(f.field().name(), arrow::float32()))); - arrow::FloatBuilder array_builder; - 
SERVING_CHECK_ARROW_STATUS(array_builder.AppendValues( - f.value().fs().begin(), f.value().fs().end())); + case arrow::Type::UINT64: { + arrow::UInt64Builder array_builder; + SERVING_CHECK_ARROW_STATUS(array_builder.Resize(values.size())); + for (const auto& v : values) { + SERVING_ENFORCE_GE(v, 0); + SERVING_CHECK_ARROW_STATUS(array_builder.Append(v)); + } SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); - cur_num_rows = f.value().fs_size(); break; } - case FieldType::FIELD_DOUBLE: { - SERVING_CHECK_ARROW_STATUS(schema_builder.AddField( - arrow::field(f.field().name(), arrow::float64()))); - arrow::DoubleBuilder array_builder; - SERVING_CHECK_ARROW_STATUS(array_builder.AppendValues( - f.value().ds().begin(), f.value().ds().end())); + default: + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "{} mismatch types, expect:{}, actual:{}", field.name(), + FieldType_Name(DataTypeToFieldType(target_field->type())), + FieldType_Name(field.type())); + } + } + void operator()(const FeatureField& field, + const ::google::protobuf::RepeatedField& values) { + switch (target_field->type()->id()) { + case arrow::Type::HALF_FLOAT: { + // currently `half_float` is not completely supported. 
+ // see `https://arrow.apache.org/docs/12.0/status.html` + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "float16(halffloat) is unsupported."); + break; + } + case arrow::Type::FLOAT: { + arrow::FloatBuilder array_builder; + SERVING_CHECK_ARROW_STATUS( + array_builder.AppendValues(values.begin(), values.end())); SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); - cur_num_rows = f.value().ds_size(); break; } - case FieldType::FIELD_STRING: { - SERVING_CHECK_ARROW_STATUS(schema_builder.AddField( - arrow::field(f.field().name(), arrow::utf8()))); + default: + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "{} mismatch types, expect:{}, actual:{}", field.name(), + FieldType_Name(DataTypeToFieldType(target_field->type())), + FieldType_Name(field.type())); + } + } + void operator()(const FeatureField& field, + const ::google::protobuf::RepeatedField& values) { + SERVING_ENFORCE(target_field->type()->id() == arrow::Type::DOUBLE, + errors::INVALID_ARGUMENT, + "{} mismatch types, expect:{}, actual:{}", field.name(), + FieldType_Name(DataTypeToFieldType(target_field->type())), + FieldType_Name(field.type())); + arrow::DoubleBuilder array_builder; + SERVING_CHECK_ARROW_STATUS( + array_builder.AppendValues(values.begin(), values.end())); + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + } + void operator()( + const FeatureField& field, + const ::google::protobuf::RepeatedPtrField& values) { + switch (target_field->type()->id()) { + case arrow::Type::STRING: { arrow::StringBuilder array_builder; - auto ss_list = f.value().ss(); - for (const auto& s : ss_list) { + for (const auto& s : values) { + SERVING_CHECK_ARROW_STATUS(array_builder.Append(s)); + } + SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); + break; + } + case arrow::Type::BINARY: { + arrow::BinaryBuilder array_builder; + for (const auto& s : values) { SERVING_CHECK_ARROW_STATUS(array_builder.Append(s)); } SERVING_CHECK_ARROW_STATUS(array_builder.Finish(&array)); - 
cur_num_rows = f.value().ss_size(); break; } default: - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "unkown field type", - FieldType_Name(f.field().type())); + SERVING_THROW(errors::ErrorCode::INVALID_ARGUMENT, + "{} mismatch types, expect:{}, actual:{}", field.name(), + FieldType_Name(DataTypeToFieldType(target_field->type())), + FieldType_Name(field.type())); } - if (num_rows >= 0) { - SERVING_ENFORCE_EQ(num_rows, cur_num_rows, - "features must have same length value."); + } + + std::shared_ptr target_field; + std::shared_ptr array; +}; + +} // namespace + +std::shared_ptr FeaturesToTable( + const ::google::protobuf::RepeatedPtrField& features, + const std::shared_ptr& target_schema) { + arrow::SchemaBuilder schema_builder; + std::vector> arrays; + int num_rows = -1; + + for (const auto& field : target_schema->fields()) { + bool found = false; + for (const auto& f : features) { + if (f.field().name() == field->name()) { + FeatureToArrayVisitor visitor{.target_field = field, .array = {}}; + FeatureVisit(visitor, f); + + if (num_rows >= 0) { + SERVING_ENFORCE_EQ( + num_rows, visitor.array->length(), + "features must have same length value. 
{}:{}, others:{}", + f.field().name(), visitor.array->length(), num_rows); + } + num_rows = visitor.array->length(); + arrays.emplace_back(visitor.array); + found = true; + break; + } } - num_rows = cur_num_rows; - arrays.emplace_back(array); + SERVING_ENFORCE(found, errors::ErrorCode::UNEXPECTED_ERROR, + "can not found feature:{} in response", field->name()); } - std::shared_ptr schema; - SERVING_GET_ARROW_RESULT(schema_builder.Finish(), schema); - return MakeRecordBatch(schema, num_rows, std::move(arrays)); + return MakeRecordBatch(target_schema, num_rows, std::move(arrays)); } std::string SerializeRecordBatch( - std::shared_ptr& recordBatch) { + std::shared_ptr& record_batch) { std::shared_ptr out_stream; SERVING_GET_ARROW_RESULT(arrow::io::BufferOutputStream::Create(), out_stream); std::shared_ptr writer; SERVING_GET_ARROW_RESULT( - arrow::ipc::MakeStreamWriter(out_stream, recordBatch->schema()), writer); + arrow::ipc::MakeStreamWriter(out_stream, record_batch->schema()), writer); - SERVING_CHECK_ARROW_STATUS(writer->WriteRecordBatch(*recordBatch)); + SERVING_CHECK_ARROW_STATUS(writer->WriteRecordBatch(*record_batch)); SERVING_CHECK_ARROW_STATUS(writer->Close()); std::shared_ptr buffer; @@ -175,4 +270,101 @@ std::shared_ptr DeserializeRecordBatch( return record_batch; } +std::shared_ptr DeserializeSchema(const std::string& buf) { + std::shared_ptr result; + + std::shared_ptr buffer_reader = + std::make_shared(buf); + + arrow::ipc::DictionaryMemo tmp_memo; + SERVING_GET_ARROW_RESULT( + arrow::ipc::ReadSchema( + std::static_pointer_cast(buffer_reader).get(), + &tmp_memo), + result); + + return result; +} + +FieldType DataTypeToFieldType( + const std::shared_ptr& data_type) { + const static std::unordered_map kFieldTypeMap = + { + // supported data_type list: + // `secretflow_serving/protos/data_type.proto` + {arrow::Type::type::BOOL, FieldType::FIELD_BOOL}, + {arrow::Type::type::UINT8, FieldType::FIELD_INT32}, + {arrow::Type::type::INT8, FieldType::FIELD_INT32}, 
+ {arrow::Type::type::UINT16, FieldType::FIELD_INT32}, + {arrow::Type::type::INT16, FieldType::FIELD_INT32}, + {arrow::Type::type::INT32, FieldType::FIELD_INT32}, + {arrow::Type::type::UINT32, FieldType::FIELD_INT64}, + {arrow::Type::type::UINT64, FieldType::FIELD_INT64}, + {arrow::Type::type::INT64, FieldType::FIELD_INT64}, + // currently `half_float` is not completely supported. + // see `https://arrow.apache.org/docs/12.0/status.html` + // {arrow::Type::type::HALF_FLOAT, FieldType::FIELD_FLOAT}, + {arrow::Type::type::FLOAT, FieldType::FIELD_FLOAT}, + {arrow::Type::type::DOUBLE, FieldType::FIELD_DOUBLE}, + {arrow::Type::type::STRING, FieldType::FIELD_STRING}, + {arrow::Type::type::BINARY, FieldType::FIELD_STRING}, + }; + + auto it = kFieldTypeMap.find(data_type->id()); + SERVING_ENFORCE(it != kFieldTypeMap.end(), errors::ErrorCode::LOGIC_ERROR, + "unsupported arrow data type: {}", + arrow::internal::ToString(data_type->id())); + return it->second; +} + +std::shared_ptr DataTypeToArrowDataType(DataType data_type) { + const static std::unordered_map> + kDataTypeMap = { + {DT_BOOL, arrow::boolean()}, + {DT_UINT8, arrow::uint8()}, + {DT_INT8, arrow::int8()}, + {DT_UINT16, arrow::uint16()}, + {DT_INT16, arrow::int16()}, + {DT_INT32, arrow::int32()}, + {DT_UINT32, arrow::uint32()}, + {DT_UINT64, arrow::uint64()}, + {DT_INT64, arrow::int64()}, + // currently `half_float` is not completely supported. 
+ // see `https://arrow.apache.org/docs/12.0/status.html` + // {DT_FLOAT16, arrow::float16()}, + {DT_FLOAT, arrow::float32()}, + {DT_DOUBLE, arrow::float64()}, + {DT_STRING, arrow::utf8()}, + {DT_BINARY, arrow::binary()}, + }; + + auto it = kDataTypeMap.find(data_type); + SERVING_ENFORCE(it != kDataTypeMap.end(), errors::ErrorCode::LOGIC_ERROR, + "unsupported data type: {}", DataType_Name(data_type)); + return it->second; +} + +std::shared_ptr DataTypeToArrowDataType( + const std::string& data_type) { + DataType d_type; + SERVING_ENFORCE(DataType_Parse(data_type, &d_type), + errors::ErrorCode::UNEXPECTED_ERROR, "unknown data type: {}", + data_type); + return DataTypeToArrowDataType(d_type); +} + +void CheckReferenceFields(const std::shared_ptr& src, + const std::shared_ptr& dst, + const std::string& additional_msg) { + SERVING_CHECK_ARROW_STATUS( + src->CanReferenceFieldsByNames(dst->field_names())); + for (const auto& dst_f : dst->fields()) { + auto src_f = src->GetFieldByName(dst_f->name()); + SERVING_ENFORCE( + src_f->type()->id() == dst_f->type()->id(), errors::LOGIC_ERROR, + "{}. 
field: {} type not match, expect: {}, get: {}", additional_msg, + dst_f->name(), dst_f->type()->ToString(), src_f->type()->ToString()); + } +} + } // namespace secretflow::serving diff --git a/secretflow_serving/util/arrow_helper.h b/secretflow_serving/util/arrow_helper.h index 0438381..ef0f118 100644 --- a/secretflow_serving/util/arrow_helper.h +++ b/secretflow_serving/util/arrow_helper.h @@ -14,11 +14,14 @@ #pragma once +#include + #include "arrow/api.h" #include "google/protobuf/repeated_field.h" #include "secretflow_serving/core/exception.h" +#include "secretflow_serving/protos/data_type.pb.h" #include "secretflow_serving/protos/feature.pb.h" namespace secretflow::serving { @@ -38,12 +41,12 @@ namespace secretflow::serving { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, \ __r__.status().message()); \ } else { \ - value = __r__.ValueOrDie(); \ + value = std::move(__r__.ValueOrDie()); \ } \ } while (false) inline std::shared_ptr MakeRecordBatch( - std::shared_ptr schema, int64_t num_rows, + const std::shared_ptr& schema, int64_t num_rows, std::vector> columns) { auto record_batch = arrow::RecordBatch::Make(schema, num_rows, std::move(columns)); @@ -51,18 +54,49 @@ inline std::shared_ptr MakeRecordBatch( return record_batch; } +inline std::shared_ptr MakeRecordBatch( + const std::shared_ptr& schema, int64_t num_rows, + std::vector> columns) { + return MakeRecordBatch(std::const_pointer_cast(schema), + num_rows, std::move(columns)); +} + std::string SerializeRecordBatch( - std::shared_ptr& recordBatch); + std::shared_ptr& record_batch); std::shared_ptr DeserializeRecordBatch( const std::string& buf); -std::shared_ptr FieldTypeToDataType(FieldType field_type); +std::shared_ptr DeserializeSchema(const std::string& buf); FieldType DataTypeToFieldType( const std::shared_ptr& data_type); std::shared_ptr FeaturesToTable( - const ::google::protobuf::RepeatedPtrField& features); + const ::google::protobuf::RepeatedPtrField& features, + const std::shared_ptr& 
target_schema); + +inline void CheckArrowDataTypeValid( + const std::shared_ptr& data_type) { + SERVING_ENFORCE( + arrow::is_numeric(data_type->id()) || arrow::is_string(data_type->id()) || + arrow::is_binary(data_type->id()), + errors::ErrorCode::LOGIC_ERROR, "unsupported arrow data type: {}", + arrow::internal::ToString(data_type->id())); + SERVING_ENFORCE(data_type->id() != arrow::Type::HALF_FLOAT, + errors::ErrorCode::LOGIC_ERROR, + "float16(halffloat) is unsupported."); +} + +std::shared_ptr DataTypeToArrowDataType(DataType data_type); + +std::shared_ptr DataTypeToArrowDataType( + const std::string& data_type); + +// Check that all fields in 'dst' can be found in 'src' and that the data type +// of each field is also consistent. +void CheckReferenceFields(const std::shared_ptr& src, + const std::shared_ptr& dst, + const std::string& additional_msg = ""); } // namespace secretflow::serving diff --git a/secretflow_serving/util/arrow_helper_test.cc b/secretflow_serving/util/arrow_helper_test.cc index 75fda07..fe2550c 100644 --- a/secretflow_serving/util/arrow_helper_test.cc +++ b/secretflow_serving/util/arrow_helper_test.cc @@ -26,87 +26,199 @@ class ArrowHelperTest : public ::testing::Test { TEST_F(ArrowHelperTest, FeaturesToTable) { ::google::protobuf::RepeatedPtrField features; - // bool - std::vector bool_list = {true, false, false}; - auto bool_f = features.Add(); - bool_f->mutable_field()->set_name("bool"); - bool_f->mutable_field()->set_type(FieldType::FIELD_BOOL); - bool_f->mutable_value()->mutable_bs()->Assign(bool_list.begin(), - bool_list.end()); - std::shared_ptr bool_array; - arrow::BooleanBuilder bool_builder; - SERVING_CHECK_ARROW_STATUS(bool_builder.AppendValues(bool_list)); - SERVING_CHECK_ARROW_STATUS(bool_builder.Finish(&bool_array)); - // int32 - std::vector int32_list = {1, 2, 3}; - auto i32_f = features.Add(); - i32_f->mutable_field()->set_name("int32"); - i32_f->mutable_field()->set_type(FieldType::FIELD_INT32); - 
i32_f->mutable_value()->mutable_i32s()->Assign(int32_list.begin(), - int32_list.end()); - std::shared_ptr i32_array; - arrow::Int32Builder i32_builder; - SERVING_CHECK_ARROW_STATUS(i32_builder.AppendValues(int32_list)); - SERVING_CHECK_ARROW_STATUS(i32_builder.Finish(&i32_array)); - // int64 - std::vector int64_list = {4294967296, 4294967297, 4294967298}; - auto i64_f = features.Add(); - i64_f->mutable_field()->set_name("int64"); - i64_f->mutable_field()->set_type(FieldType::FIELD_INT64); - i64_f->mutable_value()->mutable_i64s()->Assign(int64_list.begin(), - int64_list.end()); - std::shared_ptr i64_array; - arrow::Int64Builder i64_builder; - SERVING_CHECK_ARROW_STATUS(i64_builder.AppendValues(int64_list)); - SERVING_CHECK_ARROW_STATUS(i64_builder.Finish(&i64_array)); - // float - std::vector float_list = {1.1f, 2.2f, 3.3f}; - auto f_f = features.Add(); - f_f->mutable_field()->set_name("float"); - f_f->mutable_field()->set_type(FieldType::FIELD_FLOAT); - f_f->mutable_value()->mutable_fs()->Assign(float_list.begin(), - float_list.end()); - std::shared_ptr f_array; - arrow::FloatBuilder f_builder; - SERVING_CHECK_ARROW_STATUS(f_builder.AppendValues(float_list)); - SERVING_CHECK_ARROW_STATUS(f_builder.Finish(&f_array)); - // double - std::vector double_list = {16777217.1d, 16777218.2d, 16777218.3d}; - auto d_f = features.Add(); - d_f->mutable_field()->set_name("double"); - d_f->mutable_field()->set_type(FieldType::FIELD_DOUBLE); - d_f->mutable_value()->mutable_ds()->Assign(double_list.begin(), - double_list.end()); - std::shared_ptr d_array; - arrow::DoubleBuilder d_builder; - SERVING_CHECK_ARROW_STATUS(d_builder.AppendValues(double_list)); - SERVING_CHECK_ARROW_STATUS(d_builder.Finish(&d_array)); - // string + + std::vector int8_list = {1, -2, -3}; + std::vector int16_list = {1, -2, -3}; + std::vector int32_list = {1, -2, -3}; + std::vector uint8_list = {11, 22, 33}; + std::vector uint16_list = {1, 22, 33}; + std::vector uint32_list = {2294967296, 2294967297, 
2294967298}; + std::vector int64_list = {4294967296, -4294967297, 4294967298}; + std::vector uint64_list = {4294967296, 4294967297, 4294967298}; + std::vector float_list = {1.1F, 2.2F, 3.3F}; + std::vector double_list = {16777217.1, 16777218.2, 16777218.3}; std::vector str_list = {"test_0", "test_1", "test_2"}; - auto s_f = features.Add(); - s_f->mutable_field()->set_name("string"); - s_f->mutable_field()->set_type(FieldType::FIELD_STRING); - s_f->mutable_value()->mutable_ss()->Assign(str_list.begin(), str_list.end()); - std::shared_ptr str_array; - arrow::StringBuilder str_builder; - SERVING_CHECK_ARROW_STATUS(str_builder.AppendValues(str_list)); - SERVING_CHECK_ARROW_STATUS(str_builder.Finish(&str_array)); + std::vector bool_list = {true, false, false}; - auto record_batch = FeaturesToTable(features); + std::vector> fields = { + arrow::field("int8", arrow::int8()), + arrow::field("uint8", arrow::uint8()), + arrow::field("int16", arrow::int16()), + arrow::field("uint16", arrow::uint16()), + arrow::field("int32", arrow::int32()), + arrow::field("uint32", arrow::uint32()), + arrow::field("int64", arrow::int64()), + arrow::field("uint64", arrow::uint64()), + arrow::field("float", arrow::float32()), + arrow::field("double", arrow::float64()), + arrow::field("string", arrow::utf8()), + arrow::field("binary", arrow::binary()), + arrow::field("bool", arrow::boolean())}; + + std::vector> arrays; + for (const auto& f : fields) { + std::shared_ptr array; + switch (f->type()->id()) { + case arrow::Type::INT8: { + auto* i32_f = features.Add(); + i32_f->mutable_field()->set_name("int8"); + i32_f->mutable_field()->set_type(FieldType::FIELD_INT32); + i32_f->mutable_value()->mutable_i32s()->Assign(int8_list.begin(), + int8_list.end()); + arrow::Int8Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(int8_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::UINT8: { + auto* i32_f = features.Add(); + 
i32_f->mutable_field()->set_name("uint8"); + i32_f->mutable_field()->set_type(FieldType::FIELD_INT32); + i32_f->mutable_value()->mutable_i32s()->Assign(uint8_list.begin(), + uint8_list.end()); + arrow::UInt8Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(uint8_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::INT16: { + auto* i32_f = features.Add(); + i32_f->mutable_field()->set_name("int16"); + i32_f->mutable_field()->set_type(FieldType::FIELD_INT32); + i32_f->mutable_value()->mutable_i32s()->Assign(int16_list.begin(), + int16_list.end()); + arrow::Int16Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(int16_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::UINT16: { + auto* i32_f = features.Add(); + i32_f->mutable_field()->set_name("uint16"); + i32_f->mutable_field()->set_type(FieldType::FIELD_INT32); + i32_f->mutable_value()->mutable_i32s()->Assign(uint16_list.begin(), + uint16_list.end()); + arrow::UInt16Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(uint16_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::INT32: { + auto* i32_f = features.Add(); + i32_f->mutable_field()->set_name("int32"); + i32_f->mutable_field()->set_type(FieldType::FIELD_INT32); + i32_f->mutable_value()->mutable_i32s()->Assign(int32_list.begin(), + int32_list.end()); + arrow::Int32Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(int32_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::UINT32: { + auto* i64_f = features.Add(); + i64_f->mutable_field()->set_name("uint32"); + i64_f->mutable_field()->set_type(FieldType::FIELD_INT64); + i64_f->mutable_value()->mutable_i64s()->Assign(uint32_list.begin(), + uint32_list.end()); + arrow::UInt32Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(uint32_list)); + 
SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::INT64: { + auto* i64_f = features.Add(); + i64_f->mutable_field()->set_name("int64"); + i64_f->mutable_field()->set_type(FieldType::FIELD_INT64); + i64_f->mutable_value()->mutable_i64s()->Assign(int64_list.begin(), + int64_list.end()); + arrow::Int64Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(int64_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::UINT64: { + auto* i64_f = features.Add(); + i64_f->mutable_field()->set_name("uint64"); + i64_f->mutable_field()->set_type(FieldType::FIELD_INT64); + i64_f->mutable_value()->mutable_i64s()->Assign(uint64_list.begin(), + uint64_list.end()); + arrow::UInt64Builder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(uint64_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::FLOAT: { + auto* f_f = features.Add(); + f_f->mutable_field()->set_name("float"); + f_f->mutable_field()->set_type(FieldType::FIELD_FLOAT); + f_f->mutable_value()->mutable_fs()->Assign(float_list.begin(), + float_list.end()); + arrow::FloatBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(float_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::DOUBLE: { + auto* d_f = features.Add(); + d_f->mutable_field()->set_name("double"); + d_f->mutable_field()->set_type(FieldType::FIELD_DOUBLE); + d_f->mutable_value()->mutable_ds()->Assign(double_list.begin(), + double_list.end()); + arrow::DoubleBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(double_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::STRING: { + auto* s_f = features.Add(); + s_f->mutable_field()->set_name("string"); + s_f->mutable_field()->set_type(FieldType::FIELD_STRING); + s_f->mutable_value()->mutable_ss()->Assign(str_list.begin(), + str_list.end()); + arrow::StringBuilder 
builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(str_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::BINARY: { + auto* s_f = features.Add(); + s_f->mutable_field()->set_name("binary"); + s_f->mutable_field()->set_type(FieldType::FIELD_STRING); + s_f->mutable_value()->mutable_ss()->Assign(str_list.begin(), + str_list.end()); + arrow::BinaryBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(str_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + case arrow::Type::BOOL: { + auto* bool_f = features.Add(); + bool_f->mutable_field()->set_name("bool"); + bool_f->mutable_field()->set_type(FieldType::FIELD_BOOL); + bool_f->mutable_value()->mutable_bs()->Assign(bool_list.begin(), + bool_list.end()); + arrow::BooleanBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.AppendValues(bool_list)); + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + break; + } + default: + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "logistic error"); + } + arrays.emplace_back(std::move(array)); + } // expect - auto expect_schema = arrow::schema({arrow::field("bool", arrow::boolean()), - arrow::field("int32", arrow::int32()), - arrow::field("int64", arrow::int64()), - arrow::field("float", arrow::float32()), - arrow::field("double", arrow::float64()), - arrow::field("string", arrow::utf8())}); - auto expect_record = MakeRecordBatch( - expect_schema, 3, - {bool_array, i32_array, i64_array, f_array, d_array, str_array}); + auto expect_schema = arrow::schema(fields); + auto expect_record = MakeRecordBatch(expect_schema, 3, arrays); + + std::cout << "expect_record: " << expect_record->ToString() << std::endl; + + auto record_batch = FeaturesToTable(features, expect_schema); - std::cout << record_batch->ToString() << std::endl; + std::cout << "record_batch: " << record_batch->ToString() << std::endl; + EXPECT_TRUE(record_batch->schema()->Equals(expect_schema)); 
EXPECT_TRUE(record_batch->Equals(*expect_record)); } diff --git a/secretflow_serving/util/network.cc b/secretflow_serving/util/network.cc index 780565b..065e968 100644 --- a/secretflow_serving/util/network.cc +++ b/secretflow_serving/util/network.cc @@ -67,7 +67,7 @@ std::shared_ptr CreateBrpcChannel( opts.timeout_ms = rpc_timeout_ms; } if (connect_timeout_ms > 0) { - opts.timeout_ms = connect_timeout_ms; + opts.connect_timeout_ms = connect_timeout_ms; } if (tls_config != nullptr) { opts.mutable_ssl_options()->client_cert.certificate = diff --git a/secretflow_serving/util/sys_util.cc b/secretflow_serving/util/sys_util.cc index a17d960..3a28f54 100644 --- a/secretflow_serving/util/sys_util.cc +++ b/secretflow_serving/util/sys_util.cc @@ -22,7 +22,8 @@ #include #include "absl/strings/escaping.h" -#include "openssl/md5.h" +#include "openssl/evp.h" +#include "openssl/sha.h" #include "spdlog/spdlog.h" #include "secretflow_serving/core/exception.h" @@ -70,14 +71,12 @@ int cmd_through_popen(std::ostream& os, const char* cmd) { namespace { -std::string MD5String(const std::string& str) { - unsigned char results[MD5_DIGEST_LENGTH]; - MD5_CTX ctx; - MD5_Init(&ctx); - MD5_Update(&ctx, str.data(), str.length()); - MD5_Final(results, &ctx); +std::string SHA256String(const std::string& str) { + unsigned char results[SHA256_DIGEST_LENGTH]; + EVP_Digest(str.data(), str.length(), results, nullptr, EVP_sha256(), nullptr); + return absl::BytesToHexString(absl::string_view( - reinterpret_cast(results), MD5_DIGEST_LENGTH)); + reinterpret_cast(results), SHA256_DIGEST_LENGTH)); } } // namespace @@ -92,9 +91,10 @@ void SysUtil::System(const std::string& cmd, std::string* command_output) { if (content.length() > 2048) { content.resize(2048); } - YACL_THROW("execute cmd={} return error code={}: {}", cmd, ret, content); + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "execute cmd={} return error code={}: {}", cmd, ret, content); } - if (command_output) { + if (command_output != 
nullptr) { *command_output = cmd_output.str(); } } @@ -103,25 +103,26 @@ void SysUtil::ExtractGzippedArchive(const std::string& package_path, const std::string& target_dir) { if (!std::filesystem::exists(package_path) || std::filesystem::file_size(package_path) == 0) { - YACL_THROW("file {} not exist or file size == 0. extract fail", - package_path); + SERVING_THROW(errors::ErrorCode::IO_ERROR, + "file {} not exist or file size == 0. extract fail", + package_path); } std::filesystem::create_directories(target_dir); - auto cmd = - fmt::format("tar zxf \"{0}\" -C \"{1}\"", package_path, target_dir); + auto cmd = fmt::format(R"(tar zxf "{0}" -C "{1}")", package_path, target_dir); SysUtil::System(cmd); } -bool SysUtil::CheckMD5(const std::string& fname, const std::string& md5sum) { +bool SysUtil::CheckSHA256(const std::string& fname, + const std::string& expect_sha256) { std::ifstream file_is(fname); std::string content((std::istreambuf_iterator(file_is)), std::istreambuf_iterator()); - std::string md5_str = MD5String(content); - if (md5_str.compare(md5sum) != 0) { - SPDLOG_WARN("file({}) md5 check failed, expect:{}, get:{}", fname, md5sum, - md5_str); + std::string sha256_str = SHA256String(content); + if (sha256_str != expect_sha256) { + SPDLOG_WARN("file({}) sha256 check failed, expect:{}, get:{}", fname, + expect_sha256, sha256_str); return false; } return true; diff --git a/secretflow_serving/util/sys_util.h b/secretflow_serving/util/sys_util.h index ded38fb..ae914d1 100644 --- a/secretflow_serving/util/sys_util.h +++ b/secretflow_serving/util/sys_util.h @@ -26,7 +26,8 @@ class SysUtil { static void ExtractGzippedArchive(const std::string& package_path, const std::string& target_dir); - static bool CheckMD5(const std::string& fname, const std::string& md5sum); + static bool CheckSHA256(const std::string& fname, + const std::string& expect_sha256); }; } // namespace secretflow::serving diff --git a/secretflow_serving/util/test_utils.cc 
b/secretflow_serving/util/test_utils.cc new file mode 100644 index 0000000..650bf5a --- /dev/null +++ b/secretflow_serving/util/test_utils.cc @@ -0,0 +1,26 @@ +#include "secretflow_serving/util/test_utils.h" + +#include + +namespace secretflow::serving::test { + +std::shared_ptr ShuffleRecordBatch( + std::shared_ptr input_batch) { + auto fields = input_batch->schema()->fields(); + + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(fields.begin(), fields.end(), g); + + std::vector> new_columns; + new_columns.reserve(fields.size()); + for (const auto& f : fields) { + new_columns.emplace_back( + input_batch->column(input_batch->schema()->GetFieldIndex(f->name()))); + } + + return arrow::RecordBatch::Make(arrow::schema(fields), + input_batch->num_rows(), new_columns); +} + +} // namespace secretflow::serving::test diff --git a/secretflow_serving/util/test_utils.h b/secretflow_serving/util/test_utils.h new file mode 100644 index 0000000..689fc51 --- /dev/null +++ b/secretflow_serving/util/test_utils.h @@ -0,0 +1,10 @@ +#pragma once + +#include "arrow/api.h" + +namespace secretflow::serving::test { + +std::shared_ptr ShuffleRecordBatch( + std::shared_ptr input_batch); + +} // namespace secretflow::serving::test diff --git a/secretflow_serving/util/thread_pool.h b/secretflow_serving/util/thread_pool.h new file mode 100644 index 0000000..d320b2c --- /dev/null +++ b/secretflow_serving/util/thread_pool.h @@ -0,0 +1,154 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include + +#include "spdlog/spdlog.h" + +#include "secretflow_serving/core/exception.h" +#include "secretflow_serving/util/thread_safe_queue.h" + +namespace secretflow::serving { + +class ThreadPool { + public: + static std::shared_ptr GetInstance() { + static std::shared_ptr thread_pool{new ThreadPool()}; + return thread_pool; + } + + class Task { + public: + virtual void OnException(std::exception_ptr exception) noexcept = 0; + + virtual void Exec() = 0; + virtual const char* Name() { return "Task"; } + virtual ~Task() = default; + }; + + ~ThreadPool() { Stop(); } + + void Stop() { + bool started = true; + if (!started_.compare_exchange_strong(started, false)) { + return; + } + + if (tasks_num_ != 0) { + SPDLOG_ERROR("thread pool stoped with {} tasks not executed.", + tasks_num_.load()); + } + + for (auto& task_queue : task_queues_) { + std::unique_ptr task; + task_queue.StopPush(); + } + + for (auto& thread : threads_) { + if (thread.joinable()) { + thread.join(); + } + } + } + + bool IsRunning() const { return started_; } + + void SubmitTask(std::unique_ptr task) { + if (!started_) { + SPDLOG_WARN("submit task: {} while threadpool is not started", + task->Name()); + } else { + SPDLOG_DEBUG("submit task: {}", task->Name()); + } + + task_queues_[insert_queue_index_.fetch_add(1) % task_queues_.size()].Push( + std::move(task)); + tasks_num_++; + } + + void Start(int32_t thread_num) { + bool started = false; + if (!started_.compare_exchange_strong(started, true)) { + SPDLOG_WARN( + "Thread pool cannot be started multiple times, already have {} " + "threads running.", + threads_.size()); + + return; + } + + SPDLOG_INFO("Create and start thread pool with {} threads", thread_num); + task_queues_ = + std::vector>>(thread_num); + + auto exec_task = [this](auto& task) { + --tasks_num_; + 
SPDLOG_DEBUG("start execute: {}", task->Name()); + try { + task->Exec(); + } catch (std::exception& e) { + SPDLOG_ERROR("execute task {} with exception: {}", task->Name(), + e.what()); + task->OnException(std::current_exception()); + } + SPDLOG_DEBUG("end execute: {}", task->Name()); + }; + + for (int32_t i = 0; i != thread_num; ++i) { + threads_.emplace_back( + [this, exec_task](size_t i) { + size_t spurious_cnt = 0; + while (started_) { + std::unique_ptr task; + if (task_queues_[i].BlockPop(task) && started_) { + exec_task(task); + continue; + } + if (started_ && tasks_num_ > 0) { + // TODO: steal tasks + } + spurious_cnt++; + } + SPDLOG_DEBUG("spurious BlockPop times: {}", spurious_cnt); + }, + i); + } + } + + size_t GetTaskSize() const { + return std::accumulate( + task_queues_.begin(), task_queues_.end(), 0, + [](int size, auto& queue) { return size + queue.size(); }); + } + + private: + explicit ThreadPool() {} + + std::vector threads_; + + std::vector>> task_queues_; + // Hint for submit tasks + std::atomic insert_queue_index_{0}; + + std::atomic tasks_num_{0}; + std::atomic started_{false}; +}; + +} // namespace secretflow::serving diff --git a/secretflow_serving/util/thread_safe_queue.h b/secretflow_serving/util/thread_safe_queue.h new file mode 100644 index 0000000..e1c2c59 --- /dev/null +++ b/secretflow_serving/util/thread_safe_queue.h @@ -0,0 +1,155 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include + +#include "secretflow_serving/core/exception.h" + +namespace secretflow::serving { + +constexpr uint32_t kMaxQueueSize = 4096; + +template +class ThreadSafeQueue { + public: + using ValueType = T; + using SizeType = size_t; + + explicit ThreadSafeQueue(uint32_t pop_wait_ms = DEFAULT_POP_WAIT_MS) + : ThreadSafeQueue({}, pop_wait_ms) {} + + explicit ThreadSafeQueue(std::deque buffer, + uint32_t pop_wait_ms = DEFAULT_POP_WAIT_MS) + : wait_ms_(pop_wait_ms) { + SERVING_ENFORCE_LE(buffer.size(), kMaxQueueSize, + "init buffer size is too big: {} > {}", buffer.size(), + kMaxQueueSize); + for (size_t i = 0; i != buffer.size(); ++i) { + buffer_[i] = std::move(buffer[i]); + } + SERVING_ENFORCE(wait_ms_ > 0, errors::ErrorCode::INVALID_ARGUMENT, + "wait_ms should not be zero", wait_ms_); + } + + ThreadSafeQueue(const ThreadSafeQueue&) = delete; + ThreadSafeQueue(ThreadSafeQueue&&) = delete; + ThreadSafeQueue& operator=(const ThreadSafeQueue&) = delete; + ThreadSafeQueue& operator=(ThreadSafeQueue&&) = delete; + + void Push(T t) { + // wait queue not full + { + std::unique_lock lock(mtx_); + if (length_ >= kMaxQueueSize) { + full_cv_.wait(lock, + [this] { return length_ < kMaxQueueSize || stop_flag_; }); + } + if (stop_flag_) { + empty_cv_.notify_all(); + return; + } + length_++; + buffer_[tail_index_] = std::move(t); + ++tail_index_; + tail_index_ = tail_index_ % kMaxQueueSize; + } + empty_cv_.notify_one(); + } + + void StopPush() { + stop_flag_ = true; + full_cv_.notify_all(); + empty_cv_.notify_all(); + } + + bool BlockPop(T& t) { + { + std::unique_lock lock(mtx_); + if (length_ <= 0) { + empty_cv_.wait(lock, [this] { return stop_flag_ || (length_ > 0); }); + if (length_ <= 0) { + return false; + } + } + length_--; + t = std::move(buffer_[head_index_]); + + ++head_index_; + head_index_ = head_index_ % kMaxQueueSize; + } + full_cv_.notify_one(); + return true; + } + + bool TryPop(T& t) { + { + std::unique_lock lock(mtx_); + if 
(length_ <= 0) { + return false; + } + length_--; + t = std::move(buffer_[head_index_]); + ++head_index_; + head_index_ = head_index_ % kMaxQueueSize; + } + full_cv_.notify_one(); + return true; + } + + bool WaitPop(T& t) { + { + std::unique_lock lock(mtx_); + if (length_ <= 0) { + empty_cv_.wait_for(lock, std::chrono::milliseconds(wait_ms_), + [this] { return stop_flag_ || (length_ > 0); }); + if (length_ <= 0) { + return false; + } + } + length_--; + t = std::move(buffer_[head_index_]); + ++head_index_; + head_index_ = head_index_ % kMaxQueueSize; + } + full_cv_.notify_one(); + + return true; + } + + int size() const { + std::lock_guard guard_(mtx_); + return length_; + } + + ~ThreadSafeQueue() { StopPush(); } + + private: + std::array buffer_; + + mutable std::mutex mtx_; + std::condition_variable full_cv_; + std::condition_variable empty_cv_; + uint32_t head_index_{0}; + uint32_t tail_index_{0}; + uint32_t length_{0}; + std::atomic stop_flag_{false}; + + uint32_t wait_ms_; + + static constexpr uint32_t DEFAULT_POP_WAIT_MS = 10; +}; +} // namespace secretflow::serving diff --git a/secretflow_serving/util/thread_safe_queue_test.cc b/secretflow_serving/util/thread_safe_queue_test.cc new file mode 100644 index 0000000..8f68760 --- /dev/null +++ b/secretflow_serving/util/thread_safe_queue_test.cc @@ -0,0 +1,111 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/util/thread_safe_queue.h" + +#include + +#include +#include +#include + +namespace secretflow::serving { + +using namespace std::chrono_literals; + +TEST(ThreadSafeQueueTest, Basic) { + ThreadSafeQueue q; + + EXPECT_TRUE((std::is_same_v::ValueType>)); + + EXPECT_EQ(q.size(), 0); + q.Push(1); + q.Push(2); + EXPECT_EQ(q.size(), 2); + int val = 0; + q.TryPop(val); + EXPECT_EQ(q.size(), 1); + EXPECT_EQ(val, 1); + q.TryPop(val); + EXPECT_EQ(val, 2); + EXPECT_EQ(q.size(), 0); +} + +TEST(ThreadSafeQueueTest, TryPop) { + ThreadSafeQueue q; + + int val = 0; + EXPECT_FALSE(q.TryPop(val)); + EXPECT_EQ(val, 0); + + q.Push(999); + EXPECT_TRUE(q.TryPop(val)); + EXPECT_EQ(val, 999); +} + +TEST(ThreadSafeQueueTest, WaitPop) { + ThreadSafeQueue q(50); + int val = 0; + auto start_time = std::chrono::steady_clock::now(); + EXPECT_FALSE(q.WaitPop(val)); + auto end_time = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time); + EXPECT_GE(duration.count(), 50); + EXPECT_LE(duration.count(), 60); + q.Push(999); + start_time = std::chrono::steady_clock::now(); + EXPECT_TRUE(q.WaitPop(val)); + end_time = std::chrono::steady_clock::now(); + duration = std::chrono::duration_cast(end_time - + start_time); + EXPECT_LE(duration.count(), 1); +} + +TEST(ThreadSafeQueueTest, Push) { + ThreadSafeQueue q; + uint32_t val = 0; + + auto done = std::async(std::launch::async, [&]() { + for (uint32_t i = 0; i < kMaxQueueSize + 1; ++i) { + q.Push(i); + } + }); + EXPECT_EQ(done.wait_for(std::chrono::milliseconds(1)), + std::future_status::timeout); + EXPECT_EQ(q.size(), kMaxQueueSize); + q.TryPop(val); + EXPECT_EQ(done.wait_for(std::chrono::milliseconds(1)), + std::future_status::ready); + EXPECT_EQ(q.size(), kMaxQueueSize); +} + +TEST(ThreadSafeQueueTest, StopPush) { + ThreadSafeQueue q; + + auto done = std::async(std::launch::async, [&]() { + for (uint32_t i = 0; i < kMaxQueueSize + 10; ++i) { + q.Push(i); + } + }); + 
EXPECT_EQ(done.wait_for(std::chrono::milliseconds(1)), + std::future_status::timeout); + EXPECT_EQ(q.size(), kMaxQueueSize); + q.StopPush(); + EXPECT_EQ(done.wait_for(std::chrono::milliseconds(1)), + std::future_status::ready); + EXPECT_EQ(q.size(), kMaxQueueSize); +} + +} // namespace secretflow::serving diff --git a/secretflow_serving/util/utils.cc b/secretflow_serving/util/utils.cc index aa32fe6..3391afc 100644 --- a/secretflow_serving/util/utils.cc +++ b/secretflow_serving/util/utils.cc @@ -14,6 +14,7 @@ #include "secretflow_serving/util/utils.h" +#include #include #include @@ -22,34 +23,50 @@ namespace secretflow::serving { -void LoadPbFromJsonFile(const std::string& file, - ::google::protobuf::Message* message) { +namespace { +std::string ReadFileContent(const std::string& file) { + if (!std::filesystem::exists(file)) { + SERVING_THROW(errors::ErrorCode::IO_ERROR, "can not find file: {}", file); + } std::ifstream file_is(file); - SERVING_ENFORCE(file_is.good(), errors::ErrorCode::FS_INVALID_ARGUMENT, + SERVING_ENFORCE(file_is.good(), errors::ErrorCode::IO_ERROR, "open failed, file: {}", file); - std::string content((std::istreambuf_iterator(file_is)), - std::istreambuf_iterator()); - JsonToPb(content, message); + return std::string((std::istreambuf_iterator(file_is)), + std::istreambuf_iterator()); +} +} // namespace + +void LoadPbFromJsonFile(const std::string& file, + ::google::protobuf::Message* message) { + JsonToPb(ReadFileContent(file), message); } void LoadPbFromBinaryFile(const std::string& file, ::google::protobuf::Message* message) { - std::ifstream file_is(file); - std::string content((std::istreambuf_iterator(file_is)), - std::istreambuf_iterator()); - - SERVING_ENFORCE(message->ParseFromString(content), - errors::ErrorCode::DESERIALIZE_FAILD, + SERVING_ENFORCE(message->ParseFromString(ReadFileContent(file)), + errors::ErrorCode::DESERIALIZE_FAILED, "parse pb failed, file: {}", file); } void JsonToPb(const std::string& json, 
::google::protobuf::Message* message) { auto status = ::google::protobuf::util::JsonStringToMessage(json, message); if (!status.ok()) { - SPDLOG_ERROR("json to pb faied, msg:{}, json:{}", status.ToString(), json); - SERVING_THROW(errors::ErrorCode::DESERIALIZE_FAILD, + SPDLOG_ERROR("json to pb failed, msg:{}, json:{}", status.ToString(), json); + SERVING_THROW(errors::ErrorCode::DESERIALIZE_FAILED, "json to pb failed, msg:{}", status.ToString()); } } +std::string PbToJson(const ::google::protobuf::Message* message) { + std::string json; + auto status = ::google::protobuf::util::MessageToJsonString(*message, &json); + if (!status.ok()) { + SPDLOG_ERROR("pb to json failed, msg:{}, message:{}", status.ToString(), + message->ShortDebugString()); + SERVING_THROW(errors::ErrorCode::SERIALIZE_FAILED, + "pb to json failed, msg:{}", status.ToString()); + } + return json; +} + } // namespace secretflow::serving diff --git a/secretflow_serving/util/utils.h b/secretflow_serving/util/utils.h index 629c82c..a5ccaec 100644 --- a/secretflow_serving/util/utils.h +++ b/secretflow_serving/util/utils.h @@ -18,6 +18,7 @@ #include "secretflow_serving/apis/error_code.pb.h" #include "secretflow_serving/apis/status.pb.h" +#include "secretflow_serving/protos/feature.pb.h" namespace secretflow::serving { @@ -37,4 +38,39 @@ void LoadPbFromBinaryFile(const std::string& file, void JsonToPb(const std::string& json, ::google::protobuf::Message* message); +std::string PbToJson(const ::google::protobuf::Message* message); + +template +void FeatureVisit(Func&& visitor, const Feature& f) { + switch (f.field().type()) { + case FieldType::FIELD_BOOL: { + visitor(f.field(), f.value().bs()); + break; + } + case FieldType::FIELD_INT32: { + visitor(f.field(), f.value().i32s()); + break; + } + case FieldType::FIELD_INT64: { + visitor(f.field(), f.value().i64s()); + break; + } + case FieldType::FIELD_FLOAT: { + visitor(f.field(), f.value().fs()); + break; + } + case FieldType::FIELD_DOUBLE: { + 
visitor(f.field(), f.value().ds()); + break; + } + case FieldType::FIELD_STRING: { + visitor(f.field(), f.value().ss()); + break; + } + default: + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "unkown field type", + FieldType_Name(f.field().type())); + } +} + } // namespace secretflow::serving diff --git a/secretflow_serving_lib/BUILD.bazel b/secretflow_serving_lib/BUILD.bazel new file mode 100644 index 0000000..c504e52 --- /dev/null +++ b/secretflow_serving_lib/BUILD.bazel @@ -0,0 +1,97 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package(default_visibility = ["//visibility:public"]) + +exports_files( + [ + "exported_symbols.lds", + "version_script.lds", + ], + visibility = ["//visibility:private"], +) + +pybind_extension( + name = "libserving", + srcs = ["libserving.cc"], + linkopts = select({ + "@bazel_tools//src/conditions:darwin": [ + "-Wl,-exported_symbols_list,$(location //secretflow_serving_lib:exported_symbols.lds)", + ], + "//conditions:default": [ + "-Wl,--version-script,$(location //secretflow_serving_lib:version_script.lds)", + ], + }), + deps = [ + ":exported_symbols.lds", + ":version_script.lds", + "//secretflow_serving/ops", + "//secretflow_serving/ops:graph_version", + "@yacl//yacl/base:exception", + ], +) + +py_library( + name = "protos", + srcs = [ + "attr_pb2.py", + "bundle_pb2.py", + "compute_trace_pb2.py", + "data_type_pb2.py", + "feature_pb2.py", + "graph_pb2.py", + "link_function_pb2.py", + "op_pb2.py", + "//secretflow_serving/protos:attr_py_proto", + "//secretflow_serving/protos:bundle_py_proto", + "//secretflow_serving/protos:compute_trace_py_proto", + "//secretflow_serving/protos:data_type_py_proto", + "//secretflow_serving/protos:feature_py_proto", + "//secretflow_serving/protos:graph_py_proto", + "//secretflow_serving/protos:link_function_py_proto", + "//secretflow_serving/protos:op_py_proto", + ], +) + +py_library( + name = "api", + srcs = [ + "api.py", + ":protos", + ], + data = [ + ":libserving.so", + ], +) + +py_library( + name = "init", + srcs = [ + "__init__.py", + "version.py", + ":api", + ":protos", + ], +) + +py_test( + name = "api_test", + srcs = ["api_test.py"], + deps = [ + ":init", + ], +) diff --git a/secretflow_serving_lib/__init__.py b/secretflow_serving_lib/__init__.py new file mode 100644 index 0000000..34edc9e --- /dev/null +++ b/secretflow_serving_lib/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2023 Ant Group Co., 
Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .version import __version__ + +from . import op_pb2 +from . import attr_pb2 +from . import compute_trace_pb2 +from . import graph_pb2 +from . import feature_pb2 +from . import bundle_pb2 +from . import data_type_pb2 +from . import link_function_pb2 + +from .api import get_all_ops, get_op, get_graph_version + +__all__ = [ + # api + "get_all_ops", + "get_op", + "get_graph_version", + "op_pb2", + "attr_pb2", + "compute_trace_pb2", + "graph_pb2", + "feature_pb2", + "bundle_pb2", + "data_type_pb2", + "link_function_pb2", +] diff --git a/secretflow_serving_lib/api.py b/secretflow_serving_lib/api.py new file mode 100644 index 0000000..a014b4c --- /dev/null +++ b/secretflow_serving_lib/api.py @@ -0,0 +1,42 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List + +from . import op_pb2 +from . 
import libserving # type: ignore + + +def get_all_ops() -> List[op_pb2.OpDef]: + ret = [] + contents = libserving.get_all_op_defs_impl() + for c in contents: + o = op_pb2.OpDef() + o.ParseFromString(c) + ret.append(o) + + return ret + + +def get_op(name: str) -> op_pb2.OpDef: + content = libserving.get_op_def_impl(name) + o = op_pb2.OpDef() + o.ParseFromString(content) + return o + + +def get_graph_version() -> str: + return libserving.get_graph_def_version_impl() diff --git a/secretflow_serving_lib/api_test.py b/secretflow_serving_lib/api_test.py new file mode 100644 index 0000000..82edf45 --- /dev/null +++ b/secretflow_serving_lib/api_test.py @@ -0,0 +1,25 @@ +import os +import sys + +import secretflow_serving_lib as serving + +print("about to import", file=sys.stderr) +print("python is", sys.version_info) +print("pid is", os.getpid()) + + +print("imported, about to call", file=sys.stderr) + +# get ops +ops = serving.get_all_ops() +assert len(ops) == 3 + +# get ops +op = serving.get_op("MERGE_Y") +print(op) + +# get graph version +g_v = serving.get_graph_version() +print(g_v) + +print("done!", file=sys.stderr) diff --git a/secretflow_serving_lib/attr_pb2.py b/secretflow_serving_lib/attr_pb2.py new file mode 100755 index 0000000..790051b --- /dev/null +++ b/secretflow_serving_lib/attr_pb2.py @@ -0,0 +1 @@ +from secretflow_serving.protos.attr_pb2 import * diff --git a/secretflow_serving_lib/bundle_pb2.py b/secretflow_serving_lib/bundle_pb2.py new file mode 100755 index 0000000..ced30c9 --- /dev/null +++ b/secretflow_serving_lib/bundle_pb2.py @@ -0,0 +1 @@ +from secretflow_serving.protos.bundle_pb2 import * diff --git a/secretflow_serving_lib/compute_trace_pb2.py b/secretflow_serving_lib/compute_trace_pb2.py new file mode 100644 index 0000000..a6f3ef6 --- /dev/null +++ b/secretflow_serving_lib/compute_trace_pb2.py @@ -0,0 +1 @@ +from secretflow_serving.protos.compute_trace_pb2 import * diff --git a/secretflow_serving_lib/data_type_pb2.py 
b/secretflow_serving_lib/data_type_pb2.py new file mode 100755 index 0000000..d52ba41 --- /dev/null +++ b/secretflow_serving_lib/data_type_pb2.py @@ -0,0 +1 @@ +from secretflow_serving.protos.data_type_pb2 import * diff --git a/secretflow_serving_lib/exported_symbols.lds b/secretflow_serving_lib/exported_symbols.lds new file mode 100644 index 0000000..2637585 --- /dev/null +++ b/secretflow_serving_lib/exported_symbols.lds @@ -0,0 +1 @@ +_PyInit_* \ No newline at end of file diff --git a/secretflow_serving_lib/feature_pb2.py b/secretflow_serving_lib/feature_pb2.py new file mode 100755 index 0000000..bb973cd --- /dev/null +++ b/secretflow_serving_lib/feature_pb2.py @@ -0,0 +1 @@ +from secretflow_serving.protos.feature_pb2 import * diff --git a/secretflow_serving_lib/graph_pb2.py b/secretflow_serving_lib/graph_pb2.py new file mode 100755 index 0000000..9289116 --- /dev/null +++ b/secretflow_serving_lib/graph_pb2.py @@ -0,0 +1 @@ +from secretflow_serving.protos.graph_pb2 import * diff --git a/secretflow_serving_lib/libserving.cc b/secretflow_serving_lib/libserving.cc new file mode 100644 index 0000000..c50a561 --- /dev/null +++ b/secretflow_serving_lib/libserving.cc @@ -0,0 +1,83 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <algorithm>
+#include <vector>
+
+#include "fmt/format.h"
+#include "pybind11/complex.h"
+#include "pybind11/functional.h"
+#include "pybind11/iostream.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "yacl/base/exception.h"
+
+#include "secretflow_serving/ops/graph_version.h"
+#include "secretflow_serving/ops/op_factory.h"
+#include "secretflow_serving/ops/op_kernel_factory.h"
+#include "secretflow_serving/util/arrow_helper.h"
+
+namespace py = pybind11;
+
+namespace secretflow::serving::op {
+
+// Release the GIL while long-running native code executes.
+#define NO_GIL py::call_guard<py::gil_scoped_release>()
+
+PYBIND11_MODULE(libserving, m) {
+  m.doc() = R"pbdoc(
+    Secretflow-Serving Python Library
+  )pbdoc";
+
+  py::register_exception_translator(
+      [](std::exception_ptr p) {  // NOLINT: pybind11
+        try {
+          if (p) {
+            std::rethrow_exception(p);
+          }
+        } catch (const yacl::Exception& e) {
+          // Translate this exception to a standard RuntimeError
+          PyErr_SetString(PyExc_RuntimeError,
+                          fmt::format("what: \n\t{}\nstacktrace: \n{}\n",
+                                      e.what(), e.stack_trace())
+                              .c_str());
+        }
+      });
+
+  m.def("get_all_op_defs_impl", []() -> std::vector<py::bytes> {
+    std::vector<py::bytes> result;
+    auto op_defs = OpFactory::GetInstance()->GetAllOps();
+    std::for_each(op_defs.begin(), op_defs.end(),
+                  [&](const std::shared_ptr<OpDef>& op) {
+                    std::string content;
+                    YACL_ENFORCE(op->SerializeToString(&content));
+                    result.emplace_back(std::move(content));
+                  });
+    return result;
+  });
+
+  m.def(
+      "get_op_def_impl",
+      [](const std::string& name) -> py::bytes {
+        std::string result;
+        const auto def = OpFactory::GetInstance()->Get(name);
+        YACL_ENFORCE(def->SerializeToString(&result));
+        return result;
+      },
+      py::arg("name"));
+
+  m.def("get_graph_def_version_impl",
+        []() -> std::string { return SERVING_GRAPH_VERSION_STRING; });
+}
+
+}  // namespace secretflow::serving::op
diff --git a/secretflow_serving_lib/link_function_pb2.py b/secretflow_serving_lib/link_function_pb2.py
new file mode 100644
index 0000000..d28c66b
--- /dev/null
+++ b/secretflow_serving_lib/link_function_pb2.py
@@ -0,0 +1 @@
+from secretflow_serving.protos.link_function_pb2 import *
diff --git a/secretflow_serving_lib/op_pb2.py b/secretflow_serving_lib/op_pb2.py
new file mode 100644
index 0000000..4123451
--- /dev/null
+++ b/secretflow_serving_lib/op_pb2.py
@@ -0,0 +1 @@
+from secretflow_serving.protos.op_pb2 import *
diff --git a/secretflow_serving_lib/version.py b/secretflow_serving_lib/version.py
new file mode 100644
index 0000000..e22a684
--- /dev/null
+++ b/secretflow_serving_lib/version.py
@@ -0,0 +1,16 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__version__ = "0.2.0dev$$DATE$$"
diff --git a/secretflow_serving_lib/version_script.lds b/secretflow_serving_lib/version_script.lds
new file mode 100644
index 0000000..a7e3bc0
--- /dev/null
+++ b/secretflow_serving_lib/version_script.lds
@@ -0,0 +1,9 @@
+VERS_1.0 {
+  # Export symbols in pybind.
+  global:
+    PyInit_*;
+
+  # Hide everything else.
+  local:
+    *;
+};
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..133bf5a
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,257 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Ideas borrowed from: https://github.com/ray-project/ray/blob/master/python/setup.py
+
+import io
+import logging
+import os
+import platform
+import re
+import shutil
+import subprocess
+import sys
+import setuptools
+import setuptools.command.build_ext
+
+from datetime import datetime, timedelta
+
+logger = logging.getLogger(__name__)
+# 3.8 is the minimum python version we can support
+SUPPORTED_PYTHONS = [(3, 8), (3, 9), (3, 10), (3, 11)]
+BAZEL_MAX_JOBS = os.getenv("BAZEL_MAX_JOBS")
+ROOT_DIR = os.path.dirname(__file__)
+SKIP_BAZEL_CLEAN = os.getenv("SKIP_BAZEL_CLEAN")
+BAZEL_CACHE_DIR = os.getenv("BAZEL_CACHE_DIR")
+
+pyd_suffix = ".so"
+
+
+def add_date_to_version(*filepath):
+    local_time = datetime.utcnow()
+    chn_time = local_time + timedelta(hours=8)
+    dstr = chn_time.strftime("%Y%m%d")
+    with open(os.path.join(ROOT_DIR, *filepath), "r") as fp:
+        content = fp.read()
+
+    content = content.replace("$$DATE$$", dstr)
+
+    with open(os.path.join(ROOT_DIR, *filepath), "w+") as fp:
+        fp.write(content)
+
+
+def find_version(*filepath):
+    add_date_to_version(*filepath)
+    # Extract version information from filepath
+    with open(os.path.join(ROOT_DIR, *filepath)) as fp:
+        version_match = re.search(
+            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M
+        )
+        if version_match:
+            return version_match.group(1)
+        raise RuntimeError("Unable to find version string.")
+
+
+def read_requirements(*filepath):
+    requirements = []
+    with open(os.path.join(ROOT_DIR, *filepath)) as file:
+        requirements = file.read().splitlines()
+    return requirements
+
+
+class SetupSpec:
+    def __init__(self, name: str, description: str):
+        self.name: str = name
+        self.version = find_version("secretflow_serving_lib", "version.py")
+        self.description: str = description
+        self.files_to_include: list = []
+        self.install_requires: list = []
+        self.extras: dict = {}
+
+    def get_packages(self):
+        return setuptools.find_packages("./secretflow_serving_lib")
+
+
+setup_spec = SetupSpec(
+    "secretflow-serving-lib",
+    "Serving is a subproject of Secretflow that implements model serving capabilities.",
+)
+
+# These are the directories where automatically generated Python protobuf
+# bindings are created.
+generated_python_directories = [
+    "bazel-bin/secretflow_serving_lib",
+    "bazel-bin/secretflow_serving/protos",
+]
+setup_spec.install_requires = read_requirements('requirements.txt')
+files_to_remove = []
+
+
+# NOTE: The lists below must be kept in sync with spu/BUILD.bazel.
+serving_ops_lib_files = [
+    "bazel-bin/secretflow_serving_lib/libserving" + pyd_suffix,
+]
+
+
+# Calls Bazel in PATH
+def bazel_invoke(invoker, cmdline, *args, **kwargs):
+    try:
+        result = invoker(['bazel'] + cmdline, *args, **kwargs)
+        return result
+    except IOError:
+        raise
+
+
+def build(build_python, build_cpp):
+    if tuple(sys.version_info[:2]) not in SUPPORTED_PYTHONS:
+        msg = (
+            "Detected Python version {}, which is not supported. "
+            "Only Python {} are supported."
+        ).format(
+            ".".join(map(str, sys.version_info[:2])),
+            ", ".join(".".join(map(str, v)) for v in SUPPORTED_PYTHONS),
+        )
+        raise RuntimeError(msg)
+
+    bazel_env = dict(os.environ, PYTHON3_BIN_PATH=sys.executable)
+
+    bazel_flags = ["--verbose_failures"]
+    if BAZEL_MAX_JOBS:
+        n = int(BAZEL_MAX_JOBS)  # the value must be an int
+        bazel_flags.append("--jobs")
+        bazel_flags.append(f"{n}")
+    if BAZEL_CACHE_DIR:
+        bazel_flags.append(f"--repository_cache={BAZEL_CACHE_DIR}")
+
+    bazel_precmd_flags = []
+
+    bazel_targets = []
+    bazel_targets += ["//secretflow_serving_lib:init"] if build_python else []
+    bazel_targets += ["//secretflow_serving_lib:api"] if build_cpp else []
+
+    bazel_flags.extend(["-c", "opt"])
+
+    if sys.platform == "linux":
+        bazel_flags.extend(["--config=linux-release"])
+
+    if platform.machine() == "x86_64":
+        bazel_flags.extend(["--config=avx"])
+
+    return bazel_invoke(
+        subprocess.check_call,
+        bazel_precmd_flags + ["build"] + bazel_flags + ["--"] + bazel_targets,
+        env=bazel_env,
+    )
+
+
+def remove_prefix(text, prefix):
+    return text[text.startswith(prefix) and len(prefix) :]
+
+
+def copy_file(target_dir, filename, rootdir):
+    source = os.path.relpath(filename, rootdir)
+    destination = os.path.join(target_dir, remove_prefix(source, 'bazel-bin/'))
+
+    # Create the target directory if it doesn't already exist.
+    os.makedirs(os.path.dirname(destination), exist_ok=True)
+    if not os.path.exists(destination):
+        print(f"Copy file from {source} to {destination}")
+        shutil.copy(source, destination, follow_symlinks=True)
+        return 1
+    return 0
+
+
+def remove_file(target_dir, filename):
+    file = os.path.join(target_dir, filename)
+    if os.path.exists(file):
+        print(f"delete {file}")
+        os.remove(file)
+        return 1
+    return 0
+
+
+def pip_run(build_ext):
+    build(True, True)
+
+    setup_spec.files_to_include += serving_ops_lib_files
+
+    # Copy over the autogenerated protobuf Python bindings.
+    for directory in generated_python_directories:
+        for filename in os.listdir(directory):
+            if filename[-3:] == ".py":
+                setup_spec.files_to_include.append(os.path.join(directory, filename))
+
+    copied_files = 0
+    for filename in setup_spec.files_to_include:
+        copied_files += copy_file(build_ext.build_lib, filename, ROOT_DIR)
+    print("# of files copied to {}: {}".format(build_ext.build_lib, copied_files))
+
+    deleted_files = 0
+    for filename in files_to_remove:
+        deleted_files += remove_file(build_ext.build_lib, filename)
+    print("# of files deleted in {}: {}".format(build_ext.build_lib, deleted_files))
+
+
+class build_ext(setuptools.command.build_ext.build_ext):
+    def run(self):
+        return pip_run(self)
+
+
+class BinaryDistribution(setuptools.Distribution):
+    def has_ext_modules(self):
+        return True
+
+
+# Ensure no remaining lib files.
+build_dir = os.path.join(ROOT_DIR, "build")
+if os.path.isdir(build_dir):
+    shutil.rmtree(build_dir)
+
+if not SKIP_BAZEL_CLEAN:
+    bazel_invoke(subprocess.check_call, ['clean'])
+
+# Default Linux platform tag
+plat_name = "manylinux2014_x86_64"
+folder_name = "secretflow_serving_lib"
+
+
+def add_folder_name(l):
+    return [folder_name + "/" + element for element in l]
+
+
+setuptools.setup(
+    name=setup_spec.name,
+    version=setup_spec.version,
+    author="SecretFlow Team",
+    author_email='secretflow-contact@service.alipay.com',
+    description=(setup_spec.description),
+    long_description=io.open(
+        os.path.join(ROOT_DIR, "README.md"), "r", encoding="utf-8"
+    ).read(),
+    long_description_content_type='text/markdown',
+    classifiers=[
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+    ],
+    packages=[folder_name] + add_folder_name(setuptools.find_packages(folder_name)),
+    cmdclass={"build_ext": build_ext},
+    # The BinaryDistribution argument triggers build_ext.
+    distclass=BinaryDistribution,
+    install_requires=setup_spec.install_requires,
+    setup_requires=["wheel"],
+    extras_require=setup_spec.extras,
+    license="Apache 2.0",
+    options={'bdist_wheel': {'plat_name': plat_name}},
+)