diff --git a/include/epiworld/models/seirentitiesconnected.hpp b/include/epiworld/models/seirentitiesconnected.hpp index 3dcd20f9..3a96a236 100644 --- a/include/epiworld/models/seirentitiesconnected.hpp +++ b/include/epiworld/models/seirentitiesconnected.hpp @@ -1,6 +1,9 @@ #ifndef EPIWORLD_MODELS_SEIRENTITIESCONNECTED_HPP #define EPIWORLD_MODELS_SEIRENTITIESCONNECTED_HPP +template +class GroupSampler; + /** * @file seirentitiesconnected.hpp * @brief Template for a Susceptible-Exposed-Infected-Removed (SEIR) model with entities @@ -15,6 +18,8 @@ class ModelSEIREntitiesConn : public epiworld::Model static const int INFECTED = 2; static const int RECOVERED = 3; + GroupSampler group_sampler; + ModelSEIREntitiesConn() {}; @@ -43,7 +48,8 @@ class ModelSEIREntitiesConn : public epiworld::Model epiworld_double avg_incubation_days, epiworld_double recovery_rate, std::vector< epiworld_double > entities, - std::vector< std::string > entities_names + std::vector< std::string > entities_names, + std::vector< double > contact_matrix ); /** @@ -67,7 +73,8 @@ class ModelSEIREntitiesConn : public epiworld::Model epiworld_double avg_incubation_days, epiworld_double recovery_rate, std::vector< epiworld_double > entities, - std::vector< std::string > entities_names + std::vector< std::string > entities_names, + std::vector< double > contact_matrix ); ModelSEIREntitiesConn & run( @@ -91,14 +98,17 @@ class ModelSEIREntitiesConn : public epiworld::Model }; +/** + * @brief Weighted sampling of groups + */ template class GroupSampler { private: epiworld::Model & model; - const std::vector< double > & contact_matrix; ///< Contact matrix between groups - const std::vector< size_t > & group_sizes; ///< Sizes of the groups + const std::vector< double > contact_matrix; ///< Contact matrix between groups + const std::vector< size_t > group_sizes; ///< Sizes of the groups std::vector< double > cumulate; ///< Cumulative sum of the contact matrix (row-major for faster access) /** @@ -125,7 +135,8 @@ class GroupSampler { GroupSampler( epiworld::Model & model, const std::vector< double > & contact_matrix, - const std::vector< size_t > & group_sizes + const std::vector< size_t > & group_sizes, + bool normalize = true ): model(model), contact_matrix(contact_matrix), group_sizes(group_sizes) { this->cumulate.resize(contact_matrix.size()); @@ -140,6 +151,18 @@ class GroupSampler { contact_matrix[idx(i, j)]; } + if (normalize) + { + for (size_t i = 0; i < group_sizes.size(); ++i) + { + double sum = 0.0; + for (size_t j = 0; j < group_sizes.size(); ++j) + sum += contact_matrix[idx(i, j, true)]; + for (size_t j = 0; j < group_sizes.size(); ++j) + contact_matrix[idx(i, j, true)] /= sum; + } + } + }; int sample_1(const int origin_group); @@ -261,7 +284,8 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( epiworld_double avg_incubation_days, epiworld_double recovery_rate, std::vector< epiworld_double > entities, - std::vector< std::string > entities_names + std::vector< std::string > entities_names, + std::vector< double > contact_matrix ) { @@ -273,6 +297,17 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( e * n ); + std::vector< size_t > group_sizes(entities.size()); + for (size_t i = 0; i < entities.size(); ++i) + group_sizes[i] = static_cast(entities[i] * n); + + // Setting up the group sampler + GroupSampler group_sampler( + model, + contact_matrix, + group_sizes + ); + epiworld::UpdateFun update_susceptible = []( epiworld::Agent * p, epiworld::Model * m ) -> void @@ -454,7 +489,8 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( epiworld_double avg_incubation_days, epiworld_double recovery_rate, std::vector< epiworld_double > entities, - std::vector< std::string > entity_names + std::vector< std::string > entity_names, + std::vector< double > contact_matrix ) { @@ -468,7 +504,8 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( avg_incubation_days, recovery_rate, entities, - entity_names + entity_names, + contact_matrix ); return;