Skip to content

Commit

Permalink
Sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 12, 2024
1 parent f5fe2cb commit f337325
Showing 1 changed file with 45 additions and 8 deletions.
53 changes: 45 additions & 8 deletions include/epiworld/models/seirentitiesconnected.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef EPIWORLD_MODELS_SEIRENTITIESCONNECTED_HPP
#define EPIWORLD_MODELS_SEIRENTITIESCONNECTED_HPP

template<typename TSeq>
class GroupSampler;

/**
* @file seirentitiesconnected.hpp
* @brief Template for a Susceptible-Exposed-Infected-Removed (SEIR) model with entities
Expand All @@ -15,6 +18,8 @@ class ModelSEIREntitiesConn : public epiworld::Model<TSeq>
static const int INFECTED = 2;
static const int RECOVERED = 3;

GroupSampler<TSeq> group_sampler;


ModelSEIREntitiesConn() {};

Expand Down Expand Up @@ -43,7 +48,8 @@ class ModelSEIREntitiesConn : public epiworld::Model<TSeq>
epiworld_double avg_incubation_days,
epiworld_double recovery_rate,
std::vector< epiworld_double > entities,
std::vector< std::string > entities_names
std::vector< std::string > entities_names,
std::vector< double > contact_matrix
);

/**
Expand All @@ -67,7 +73,8 @@ class ModelSEIREntitiesConn : public epiworld::Model<TSeq>
epiworld_double avg_incubation_days,
epiworld_double recovery_rate,
std::vector< epiworld_double > entities,
std::vector< std::string > entities_names
std::vector< std::string > entities_names,
std::vector< double > contact_matrix
);

ModelSEIREntitiesConn<TSeq> & run(
Expand All @@ -91,14 +98,17 @@ class ModelSEIREntitiesConn : public epiworld::Model<TSeq>

};

/**
* @brief Weighted sampling of groups
*/
template<typename TSeq>
class GroupSampler {

private:

epiworld::Model<TSeq> & model;
const std::vector< double > & contact_matrix; ///< Contact matrix between groups
const std::vector< size_t > & group_sizes; ///< Sizes of the groups
const std::vector< double > contact_matrix; ///< Contact matrix between groups
const std::vector< size_t > group_sizes; ///< Sizes of the groups
std::vector< double > cumulate; ///< Cumulative sum of the contact matrix (row-major for faster access)

/**
Expand All @@ -125,7 +135,8 @@ class GroupSampler {
GroupSampler(
epiworld::Model<TSeq> & model,
const std::vector< double > & contact_matrix,
const std::vector< size_t > & group_sizes
const std::vector< size_t > & group_sizes,
bool normalize = true
): model(model), contact_matrix(contact_matrix), group_sizes(group_sizes) {

this->cumulate.resize(contact_matrix.size());
Expand All @@ -140,6 +151,18 @@ class GroupSampler {
contact_matrix[idx(i, j)];
}

if (normalize)
{
for (size_t i = 0; i < group_sizes.size(); ++i)
{
double sum = 0.0;
for (size_t j = 0; j < group_sizes.size(); ++j)
sum += contact_matrix[idx(i, j, true)];
for (size_t j = 0; j < group_sizes.size(); ++j)
contact_matrix[idx(i, j, true)] /= sum;
}
}

};

int sample_1(const int origin_group);
Expand Down Expand Up @@ -261,7 +284,8 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
epiworld_double avg_incubation_days,
epiworld_double recovery_rate,
std::vector< epiworld_double > entities,
std::vector< std::string > entities_names
std::vector< std::string > entities_names,
std::vector< double > contact_matrix
)
{

Expand All @@ -273,6 +297,17 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
e * n
);

std::vector< size_t > group_sizes(entities.size());
for (size_t i = 0; i < entities.size(); ++i)
group_sizes[i] = static_cast<size_t>(entities[i] * n);

// Setting up the group sampler
GroupSampler<TSeq> group_sampler(
model,
contact_matrix,
group_sizes
);

epiworld::UpdateFun<TSeq> update_susceptible = [](
epiworld::Agent<TSeq> * p, epiworld::Model<TSeq> * m
) -> void
Expand Down Expand Up @@ -454,7 +489,8 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
epiworld_double avg_incubation_days,
epiworld_double recovery_rate,
std::vector< epiworld_double > entities,
std::vector< std::string > entity_names
std::vector< std::string > entity_names,
std::vector< double > contact_matrix
)
{

Expand All @@ -468,7 +504,8 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
avg_incubation_days,
recovery_rate,
entities,
entity_names
entity_names,
contact_matrix
);

return;
Expand Down

0 comments on commit f337325

Please sign in to comment.