Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(optimizer): support external partition in virtual circuits #1187

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@ class KeysetCache {
};

Message<concreteprotocol::KeysetInfo> keysetInfoFromVirtualCircuit(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitions,
bool generate_fks, std::optional<concrete_optimizer::Options> options);
std::vector<concrete_optimizer::utils::PartitionDefinition>
internalPartitions,
std::vector<concrete_optimizer::utils::ExternalPartition>
externalPartitions,
std::optional<concrete_optimizer::Options> options);

} // namespace keysets
} // namespace concretelang
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
// ------------------------------------------------------------------------------//
// PARTITION DEFINITION //
// ------------------------------------------------------------------------------//
//
pybind11::class_<concrete_optimizer::utils::PartitionDefinition>(
m, "PartitionDefinition")
.def(init([](uint8_t precision, double norm2)
Expand All @@ -1081,6 +1080,28 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.doc() = "Definition of a partition (in terms of precision in bits and "
"norm2 in value).";

// ------------------------------------------------------------------------------//
// EXTERNAL PARTITION //
// ------------------------------------------------------------------------------//
pybind11::class_<concrete_optimizer::utils::ExternalPartition>(
m, "ExternalPartition")
.def(init([](std::string name, uint64_t macroLog2PolynomialSize,
uint64_t macroGlweDimension, uint64_t macroInternalDim,
double maxVariance, double variance)
-> concrete_optimizer::utils::ExternalPartition {
return concrete_optimizer::utils::ExternalPartition{
name,
macroLog2PolynomialSize,
macroGlweDimension,
macroInternalDim,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if macroInternalDim have any use now or in the future !?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how this is used, but it is needed to construct the external partition structure in the optimizer ...

maxVariance,
variance};
}),
arg("name"), arg("macro_log2_polynomial_size"),
arg("macro_glwe_dimension"), arg("macro_internal_dim"),
arg("max_variance"), arg("variance"))
.doc() = "Definition of an external partition.";

// ------------------------------------------------------------------------------//
// KEYSET INFO //
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about keygen using KeySetInfo instead of ProgramInfo? If it requires a lot of changes in the frontend, then we can think about having both.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentionned at the end of our discussion, I thought it may incorrectly nudge the user into generating keysets from the virtual keyset info, which would yield a big keygen, while just a small subset would be needed for the circuit.

And thinking about it a bit more, I don't think it would work to use a superset of the necessary keyset on a circuit (nothing fancy, just that we index keys by number, and as such, we may incorrectly lookup keys).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense yeah. I'm just thinking how can we showcase the feature then... We just need a way to construct a keyset from a program info using pregenerated keys. If we had a big keyset (with more keys than what the circuit needs), we could have a method that takes the program info and the big keyset, and extract necessary keys with their appropriate index.
We can also simulate all this with a keygen for now, so it's not really necessary, but it will make a better demo with something in this direction.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// ------------------------------------------------------------------------------//
Expand All @@ -1089,15 +1110,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_static(
"generate_virtual",
[](std::vector<concrete_optimizer::utils::PartitionDefinition>
partitions,
bool generateFks,
internalPartitions,
std::vector<concrete_optimizer::utils::ExternalPartition>
externalPartitions,
std::optional<concrete_optimizer::Options> options) -> KeysetInfo {
if (partitions.size() < 2) {
if (internalPartitions.size() + externalPartitions.size() < 2) {
throw std::runtime_error("Need at least two partition defs to "
"generate a virtual keyset info.");
}
return ::concretelang::keysets::keysetInfoFromVirtualCircuit(
partitions, generateFks, options);
internalPartitions, externalPartitions, options);
},
arg("partition_defs"), arg("generate_fks"),
arg("options") = std::nullopt,
Expand Down
12 changes: 10 additions & 2 deletions compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,21 @@ generateKeysetInfoFromParameters(CircuitKeys parameters,

Message<concreteprotocol::KeysetInfo> keysetInfoFromVirtualCircuit(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitionDefs,
bool generateFks, std::optional<concrete_optimizer::Options> options) {
std::vector<concrete_optimizer::utils::ExternalPartition>
externalPartitions,
std::optional<concrete_optimizer::Options> options) {

rust::Vec<concrete_optimizer::utils::PartitionDefinition> rustPartitionDefs{};
for (auto def : partitionDefs) {
rustPartitionDefs.push_back(def);
}

rust::Vec<concrete_optimizer::utils::ExternalPartition>
rustExternalPartitions{};
for (auto def : externalPartitions) {
rustExternalPartitions.push_back(def);
}

auto defaultOptions = concrete_optimizer::Options{};
defaultOptions.security_level = 128;
defaultOptions.maximum_acceptable_error_probability = 0.000063342483999973;
Expand All @@ -577,7 +585,7 @@ Message<concreteprotocol::KeysetInfo> keysetInfoFromVirtualCircuit(
auto opts = options.value_or(defaultOptions);

auto parameters = concrete_optimizer::utils::generate_virtual_keyset_info(
rustPartitionDefs, generateFks, opts);
rustPartitionDefs, rustExternalPartitions, opts);

return generateKeysetInfoFromParameters(parameters, opts);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ struct FunctionToDag {
options, logPolySize, glweDim, lweDim, pbsLevel, pbsLogBase);
auto name = partitionAttr.getName().getValue().str();
// TODO: max_variance vs variance
return concrete_optimizer::utils::get_external_partition(
name, logPolySize, glweDim, lweDim, max_variance, max_variance);
return concrete_optimizer::utils::ExternalPartition{
name, logPolySize, glweDim, lweDim, max_variance, max_variance};
};

if (auto srcPartitionAttr =
Expand All @@ -287,14 +287,14 @@ struct FunctionToDag {
auto partition = partitionBuilder(srcPartitionAttr);

index[val] = dagBuilder.add_change_partition_with_src(
encrypted_input, *partition, *loc_to_location(val.getLoc()));
encrypted_input, partition, *loc_to_location(val.getLoc()));
} else if (auto destPartitionAttr =
op.getAttrOfType<mlir::concretelang::FHE::PartitionAttr>(
"dest")) {
auto partition = partitionBuilder(destPartitionAttr);

index[val] = dagBuilder.add_change_partition_with_dst(
encrypted_input, *partition, *loc_to_location(val.getLoc()));
encrypted_input, partition, *loc_to_location(val.getLoc()));
} else {
assert(false &&
"ChangePartition: one of src or dest partitions need to be set");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ use concrete_optimizer::optimization::dag::multi_parameters::optimize::{
KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction,
SearchSpaceRestriction,
};
use concrete_optimizer::optimization::dag::multi_parameters::partition_cut::PartitionCut;
use concrete_optimizer::optimization::dag::multi_parameters::partition_cut::{
ExternalPartition, PartitionCut,
};
use concrete_optimizer::optimization::dag::multi_parameters::virtual_circuit::generate_virtual_parameters;
use concrete_optimizer::optimization::dag::multi_parameters::{keys_spec, PartitionIndex};
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{
Expand Down Expand Up @@ -62,35 +64,6 @@ fn caches_from(options: &ffi::Options) -> decomposition::PersistDecompCaches {
)
}

#[derive(Clone)]
pub struct ExternalPartition(
concrete_optimizer::optimization::dag::multi_parameters::partition_cut::ExternalPartition,
);

pub fn get_external_partition(
name: String,
log2_polynomial_size: u64,
glwe_dimension: u64,
internal_dim: u64,
max_variance: f64,
variance: f64,
) -> Box<ExternalPartition> {
Box::new(ExternalPartition(
concrete_optimizer::optimization::dag::multi_parameters::partition_cut::ExternalPartition {
name,
macro_params: MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size,
glwe_dimension,
},
internal_dim,
},
max_variance,
variance,
},
))
}

pub fn get_noise_br(
options: &ffi::Options,
log2_polynomial_size: u64,
Expand Down Expand Up @@ -852,13 +825,13 @@ impl DagBuilder<'_> {
fn add_change_partition_with_src(
&mut self,
input: ffi::OperatorIndex,
src_partition: &ExternalPartition,
src_partition: &ffi::ExternalPartition,
location: &Location,
) -> ffi::OperatorIndex {
self.0
.add_change_partition(
input.into(),
Some(src_partition.0.clone()),
Some(src_partition.clone().into()),
None,
location.0.clone(),
)
Expand All @@ -868,14 +841,14 @@ impl DagBuilder<'_> {
fn add_change_partition_with_dst(
&mut self,
input: ffi::OperatorIndex,
dst_partition: &ExternalPartition,
dst_partition: &ffi::ExternalPartition,
location: &Location,
) -> ffi::OperatorIndex {
self.0
.add_change_partition(
input.into(),
None,
Some(dst_partition.0.clone()),
Some(dst_partition.clone().into()),
location.0.clone(),
)
.into()
Expand Down Expand Up @@ -915,8 +888,8 @@ fn location_from_string(string: &str) -> Box<Location> {
}

fn generate_virtual_keyset_info(
inputs: Vec<ffi::PartitionDefinition>,
generate_fks: bool,
internal_partitions: Vec<ffi::PartitionDefinition>,
external_partitions: Vec<ffi::ExternalPartition>,
options: &ffi::Options,
) -> ffi::CircuitKeys {
let config = Config {
Expand All @@ -928,13 +901,13 @@ fn generate_virtual_keyset_info(
complexity_model: &CpuComplexity::default(),
};
generate_virtual_parameters(
inputs
internal_partitions
.into_iter()
.map(
|ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::virtual_circuit::PartitionDefinition { precision, norm2 },
|ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::virtual_circuit::InternalPartition { precision, norm2 },
)
.collect(),
generate_fks,
external_partitions.into_iter().map(|part| part.into()).collect(),
config
)
.into()
Expand Down Expand Up @@ -975,6 +948,24 @@ impl Into<Encoding> for ffi::Encoding {
}
}

#[allow(clippy::from_over_into)]
impl Into<ExternalPartition> for ffi::ExternalPartition {
fn into(self) -> ExternalPartition {
ExternalPartition {
name: self.name,
macro_params: MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size: self.macro_log2_polynomial_size,
glwe_dimension: self.macro_glwe_dimension,
},
internal_dim: self.macro_internal_dim,
},
max_variance: self.max_variance,
variance: self.variance,
}
}
}

#[allow(
unused_must_use,
clippy::needless_lifetimes,
Expand All @@ -1000,8 +991,6 @@ mod ffi {

type Location;

type ExternalPartition;

#[namespace = "concrete_optimizer::utils"]
fn location_unknown() -> Box<Location>;

Expand All @@ -1010,21 +999,11 @@ mod ffi {

#[namespace = "concrete_optimizer::utils"]
fn generate_virtual_keyset_info(
partitions: Vec<PartitionDefinition>,
generate_fks: bool,
internal_partitions: Vec<PartitionDefinition>,
external_partitions: Vec<ExternalPartition>,
options: &Options,
) -> CircuitKeys;

#[namespace = "concrete_optimizer::utils"]
fn get_external_partition(
name: String,
log2_polynomial_size: u64,
glwe_dimension: u64,
internal_dim: u64,
max_variance: f64,
variance: f64,
) -> Box<ExternalPartition>;

#[namespace = "concrete_optimizer::utils"]
fn get_noise_br(
options: &Options,
Expand Down Expand Up @@ -1400,6 +1379,17 @@ mod ffi {
pub precision: u8,
pub norm2: f64,
}

#[namespace = "concrete_optimizer::utils"]
#[derive(Debug, Clone)]
pub struct ExternalPartition {
pub name: String,
pub macro_log2_polynomial_size: u64,
pub macro_glwe_dimension: u64,
pub macro_internal_dim: u64,
pub max_variance: f64,
pub variance: f64,
}
}

fn processing_unit(options: &ffi::Options) -> ProcessingUnit {
Expand Down
Loading
Loading