From ed0f80da3b597b10361e6c44acdc632a6d188b71 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Mon, 15 Apr 2024 00:06:11 -0600 Subject: [PATCH] Adding mixing matrix! Still need to set dist_ent without replacement --- examples/11-entities/Makefile | 2 +- examples/11-entities/main.cpp | 11 +++- include/epiworld/entity-meat.hpp | 5 +- .../epiworld/models/seirentitiesconnected.hpp | 64 +++++++++++-------- 4 files changed, 53 insertions(+), 29 deletions(-) diff --git a/examples/11-entities/Makefile b/examples/11-entities/Makefile index 86dbde3c..0c4d9f21 100644 --- a/examples/11-entities/Makefile +++ b/examples/11-entities/Makefile @@ -1,2 +1,2 @@ main.o: main.cpp - g++ -std=c++17 -g main.cpp -o main.o \ No newline at end of file + g++ -std=c++17 -O3 -g main.cpp -o main.o \ No newline at end of file diff --git a/examples/11-entities/main.cpp b/examples/11-entities/main.cpp index da3a298f..b8b8d6e7 100644 --- a/examples/11-entities/main.cpp +++ b/examples/11-entities/main.cpp @@ -1,10 +1,16 @@ -#define EPI_DEBUG +// #define EPI_DEBUG #include "../../include/epiworld/epiworld.hpp" using namespace epiworld; int main() { + std::vector< double > contact_matrix = { + 0.9, 0.1, 0.1, + 0.05, 0.8, .2, + 0.05, 0.1, 0.7 + }; + epimodels::ModelSEIREntitiesConn model( "Flu", // std::string vname, 100000, // epiworld_fast_uint n, @@ -14,7 +20,8 @@ int main() { 4.0,// epiworld_double avg_incubation_days, 1.0/7.0,// epiworld_double recovery_rate, {.1, .1, .8},// std::vector< epiworld_double > entities, - {"A", "B", "C"}// std::vector< std::string > entities_names + {"A", "B", "C"},// std::vector< std::string > entities_names + contact_matrix ); // Running and checking the results diff --git a/include/epiworld/entity-meat.hpp b/include/epiworld/entity-meat.hpp index 4c5a30d4..ef01b507 100644 --- a/include/epiworld/entity-meat.hpp +++ b/include/epiworld/entity-meat.hpp @@ -92,7 +92,10 @@ template inline Agent * Entity::operator[](size_t i) { if (n_agents <= i) - throw std::logic_error("There are not that many agents in this entity."); + throw std::logic_error( + "There are not that many agents in this entity. " + + std::to_string(n_agents) + " <= " + std::to_string(i) + ); return &model->get_agents()[i]; } diff --git a/include/epiworld/models/seirentitiesconnected.hpp b/include/epiworld/models/seirentitiesconnected.hpp index 3a96a236..bc65cd61 100644 --- a/include/epiworld/models/seirentitiesconnected.hpp +++ b/include/epiworld/models/seirentitiesconnected.hpp @@ -18,9 +18,6 @@ class ModelSEIREntitiesConn : public epiworld::Model static const int INFECTED = 2; static const int RECOVERED = 3; - GroupSampler group_sampler; - - ModelSEIREntitiesConn() {}; @@ -106,9 +103,10 @@ class GroupSampler { private: - epiworld::Model & model; - const std::vector< double > contact_matrix; ///< Contact matrix between groups - const std::vector< size_t > group_sizes; ///< Sizes of the groups + epiworld::Model * model; + std::vector< Entity > * entities; + std::vector< double > contact_matrix; ///< Contact matrix between groups + std::vector< size_t > group_sizes; ///< Sizes of the groups std::vector< double > cumulate; ///< Cumulative sum of the contact matrix (row-major for faster access) /** @@ -132,18 +130,22 @@ class GroupSampler { public: + GroupSampler() {}; + GroupSampler( - epiworld::Model & model, - const std::vector< double > & contact_matrix, - const std::vector< size_t > & group_sizes, + epiworld::Model * model_, + const std::vector< double > & contact_matrix_, + const std::vector< size_t > & group_sizes_, bool normalize = true - ): model(model), contact_matrix(contact_matrix), group_sizes(group_sizes) { + ): model(model_), contact_matrix(contact_matrix_), group_sizes(group_sizes_) { + + entities = &model->get_entities(); this->cumulate.resize(contact_matrix.size()); std::fill(cumulate.begin(), cumulate.end(), 0.0); // Cumulative sum - for (size_t j = 1; j < group_sizes.size(); ++j) + for (size_t j = 0; j < group_sizes.size(); ++j) { for (size_t i = 0; i < group_sizes.size(); ++i) cumulate[idx(i, j, true)] += @@ -180,7 +182,7 @@ int GroupSampler::sample_1(const int origin_group) { // Random number - double r = model.runif(); + double r = model->runif(); // Finding the group size_t j = 0; @@ -194,11 +196,11 @@ int GroupSampler::sample_1(const int origin_group) std::floor(r * group_sizes[j]) ); - // Making sure we are not picling outside of the group + // Making sure we are not picking outside of the group if (res >= static_cast(group_sizes[j])) res = static_cast(group_sizes[j]) - 1; - return res; + return entities->at(j)[res]->get_id(); } @@ -302,17 +304,21 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( group_sizes[i] = static_cast(entities[i] * n); // Setting up the group sampler - GroupSampler group_sampler( - model, - contact_matrix, - group_sizes + std::shared_ptr> group_sampler = + std::make_shared>( + dynamic_cast*>(&model), + contact_matrix, + group_sizes ); - epiworld::UpdateFun update_susceptible = []( + epiworld::UpdateFun update_susceptible = [group_sampler]( epiworld::Agent * p, epiworld::Model * m ) -> void { + if (p->get_n_entities() == 0) + return; + // Sampling how many individuals int ndraw = m->rbinom(); @@ -320,22 +326,30 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( return; // Sampling from the agent's entities - epiworld::AgentsSample sample(m, *p, ndraw, {}, true); + std::vector< size_t > sample(ndraw); + group_sampler->sample_n( + sample, + p->get_entity(0u).get_id(), + ndraw + ); // Drawing from the set int nviruses_tmp = 0; - for (const auto & neighbor: sample) + auto & agents = m->get_agents(); + for (const auto & i: sample) { + auto neighbor = agents[i]; + // Can't sample itself - if (neighbor->get_id() == static_cast(p->get_id())) + if (neighbor.get_id() == static_cast(p->get_id())) continue; // If the neighbor is infected, then proceed - if (neighbor->get_state() == ModelSEIREntitiesConn::INFECTED) + if (neighbor.get_state() == ModelSEIREntitiesConn::INFECTED) { - auto & v = neighbor->get_virus(); + auto & v = neighbor.get_virus(); #ifdef EPI_DEBUG if (nviruses_tmp >= static_cast(m->array_virus_tmp.size())) @@ -346,7 +360,7 @@ inline ModelSEIREntitiesConn::ModelSEIREntitiesConn( m->array_double_tmp[nviruses_tmp] = (1.0 - p->get_susceptibility_reduction(v, m)) * v->get_prob_infecting(m) * - (1.0 - neighbor->get_transmission_reduction(v, m)) + (1.0 - neighbor.get_transmission_reduction(v, m)) ; m->array_virus_tmp[nviruses_tmp++] = &(*v);