diff --git a/epiworld.hpp b/epiworld.hpp index 83bf73ae..4f28fb54 100644 --- a/epiworld.hpp +++ b/epiworld.hpp @@ -6116,11 +6116,7 @@ class Model { bool directed = false; std::vector< VirusPtr > viruses = {}; - std::vector< ToolPtr > tools = {}; - std::vector< epiworld_double > prevalence_tool = {}; - std::vector< bool > prevalence_tool_as_proportion = {}; - std::vector< ToolToAgentFun > tools_dist_funs = {}; std::vector< Entity > entities = {}; std::vector< Entity > entities_backup = {}; @@ -6317,9 +6313,7 @@ class Model { */ ///@{ void add_virus(Virus & v); - void add_tool(Tool & t, epiworld_double preval); - void add_tool_n(Tool & t, epiworld_fast_uint preval); - void add_tool_fun(Tool & t, ToolToAgentFun fun); + void add_tool(Tool & t); void add_entity(Entity e); void rm_virus(size_t virus_pos); void rm_tool(size_t tool_pos); @@ -7122,9 +7116,6 @@ inline Model::Model(const Model & model) : directed(model.directed), viruses(model.viruses), tools(model.tools), - prevalence_tool(model.prevalence_tool), - prevalence_tool_as_proportion(model.prevalence_tool_as_proportion), - tools_dist_funs(model.tools_dist_funs), entities(model.entities), entities_backup(model.entities_backup), rewire_fun(model.rewire_fun), @@ -7192,9 +7183,6 @@ inline Model::Model(Model && model) : viruses(std::move(model.viruses)), // Tools tools(std::move(model.tools)), - prevalence_tool(std::move(model.prevalence_tool)), - prevalence_tool_as_proportion(std::move(model.prevalence_tool_as_proportion)), - tools_dist_funs(std::move(model.tools_dist_funs)), // Entities entities(std::move(model.entities)), entities_backup(std::move(model.entities_backup)), @@ -7265,9 +7253,6 @@ inline Model & Model::operator=(const Model & m) viruses = m.viruses; tools = m.tools; - prevalence_tool = m.prevalence_tool; - prevalence_tool_as_proportion = m.prevalence_tool_as_proportion; - tools_dist_funs = m.tools_dist_funs; entities = m.entities; entities_backup = m.entities_backup; @@ -7507,55 +7492,10 @@ template inline void Model::dist_tools() { - // Starting first infection - int n = size(); - std::vector< size_t > idx(n); - for (epiworld_fast_uint t = 0; t < tools.size(); ++t) + for (auto & tool: tools) { - if (tools_dist_funs[t]) - { - - tools_dist_funs[t](*tools[t], this); - - } else { - - // Picking how many - int nsampled; - if (prevalence_tool_as_proportion[t]) - { - nsampled = static_cast(std::floor(prevalence_tool[t] * size())); - } - else - { - nsampled = static_cast(prevalence_tool[t]); - } - - if (nsampled > static_cast(size())) - throw std::range_error("There are only " + std::to_string(size()) + - " individuals in the population. Cannot add the tool to " + std::to_string(nsampled)); - - ToolPtr tool = tools[t]; - - int n_left = n; - std::iota(idx.begin(), idx.end(), 0); - while (nsampled > 0) - { - int loc = static_cast(floor(runif() * n_left--)); - - population[idx[loc]].add_tool( - tool, - const_cast< Model * >(this), - tool->state_init, tool->queue_init - ); - - nsampled--; - - std::swap(idx[loc], idx[n_left]); - - } - - } + tool->distribute(this); // Apply the events events_run(); @@ -7740,51 +7680,17 @@ inline void Model::add_virus( } template -inline void Model::add_tool(Tool & t, epiworld_double preval) +inline void Model::add_tool(Tool & t) { - if (preval > 1.0) - throw std::range_error("Prevalence of tool cannot be above 1.0"); - - if (preval < 0.0) - throw std::range_error("Prevalence of tool cannot be negative"); - + db.record_tool(t); // Adding the tool to the model (and database.) tools.push_back(std::make_shared< Tool >(t)); - prevalence_tool.push_back(preval); - prevalence_tool_as_proportion.push_back(true); - tools_dist_funs.push_back(nullptr); } -template -inline void Model::add_tool_n(Tool & t, epiworld_fast_uint preval) -{ - - db.record_tool(t); - - tools.push_back(std::make_shared >(t)); - prevalence_tool.push_back(preval); - prevalence_tool_as_proportion.push_back(false); - tools_dist_funs.push_back(nullptr); - -} - -template -inline void Model::add_tool_fun(Tool & t, ToolToAgentFun fun) -{ - - db.record_tool(t); - - tools.push_back(std::make_shared >(t)); - prevalence_tool.push_back(0.0); - prevalence_tool_as_proportion.push_back(false); - tools_dist_funs.push_back(fun); -} - - template inline void Model::add_entity(Entity e) { @@ -7850,8 +7756,6 @@ inline void Model::rm_tool(size_t tool_pos) // Flipping with the last one std::swap(tools[tool_pos], tools[tools.size() - 1]); - std::swap(tools_dist_funs[tool_pos], tools_dist_funs[tools.size() - 1]); - std::swap(prevalence_tool[tool_pos], prevalence_tool[tools.size() - 1]); /* There's an error on windows: https://github.com/UofUEpiBio/epiworldR/actions/runs/4801482395/jobs/8543744180#step:6:84 @@ -7859,20 +7763,8 @@ inline void Model::rm_tool(size_t tool_pos) More clear here: https://stackoverflow.com/questions/58660207/why-doesnt-stdswap-work-on-vectorbool-elements-under-clang-win */ - std::vector::swap( - prevalence_tool_as_proportion[tool_pos], - prevalence_tool_as_proportion[tools.size() - 1] - ); - - // auto old = prevalence_tool_as_proportion[tool_pos]; - // prevalence_tool_as_proportion[tool_pos] = prevalence_tool_as_proportion[tools.size() - 1]; - // prevalence_tool_as_proportion[tools.size() - 1] = old; - tools.pop_back(); - tools_dist_funs.pop_back(); - prevalence_tool.pop_back(); - prevalence_tool_as_proportion.pop_back(); return; @@ -8987,6 +8879,7 @@ inline const Model & Model::print(bool lite) const size_t n_tools_model = tools.size(); for (size_t i = 0u; i < tools.size(); ++i) { + const auto & tool = tools[i]; if ((n_tools_model > 10) && (i >= 10)) { @@ -8999,13 +8892,13 @@ inline const Model & Model::print(bool lite) const if (i < n_tools_model) { - if (prevalence_tool_as_proportion[i]) + if (tool->get_prevalence_as_proportion()) { printf_epiworld( " - %s (baseline prevalence: %.2f%%)\n", - tools[i]->get_name().c_str(), - prevalence_tool[i] * 100.0 + tool->get_name().c_str(), + tool->get_prevalence() * 100.0 ); } @@ -9014,8 +8907,8 @@ inline const Model & Model::print(bool lite) const printf_epiworld( " - %s (baseline prevalence: %i seeds)\n", - tools[i]->get_name().c_str(), - static_cast(prevalence_tool[i]) + tool->get_name().c_str(), + static_cast(tool->get_prevalence()) ); } @@ -9024,7 +8917,7 @@ inline const Model & Model::print(bool lite) const printf_epiworld( " - %s (originated in the model...)\n", - tools[i]->get_name().c_str() + tool->get_name().c_str() ); } @@ -9719,18 +9612,6 @@ inline bool Model::operator==(const Model & other) const ) } - - VECT_MATCH( - prevalence_tool, - other.prevalence_tool, - "tools prevalence don't match" - ) - - VECT_MATCH( - prevalence_tool_as_proportion, - other.prevalence_tool_as_proportion, - "tools as prop don't match" - ) VECT_MATCH( entities, @@ -10254,7 +10135,7 @@ class Virus { void set_prevalence(epiworld_double prevalence, bool as_proportion); bool get_prevalence_as_proportion() const; void distribute(Model * model); - void set_distribution(VirusToAgentFun fun); + void set_dist_fun(VirusToAgentFun fun); ///@} @@ -10370,7 +10251,7 @@ inline Virus::Virus( set_name(name); set_prevalence(prevalence, prevalence_as_proportion); - set_distribution(dist_fun); + set_dist_fun(dist_fun); } template @@ -11080,7 +10961,7 @@ inline void Virus::distribute(Model * model) } template -inline void Virus::set_distribution(VirusToAgentFun fun) +inline void Virus::set_dist_fun(VirusToAgentFun fun) { dist_fun = fun; } @@ -11386,9 +11267,17 @@ class Tool { void set_agent(Agent * p, size_t idx); + epiworld_double prevalence = 0.0; + bool prevalence_as_proportion = false; + ToolToAgentFun dist_fun = nullptr; + public: - Tool(std::string name = "unknown tool"); - // Tool(TSeq d, std::string name = "unknown tool"); + Tool( + std::string name = "unknown tool", + epiworld_double prevalence = 0.0, + bool prevalence_as_proportion = false, + ToolToAgentFun dist_fun = nullptr + ); void set_sequence(TSeq d); void set_sequence(std::shared_ptr d); @@ -11443,6 +11332,13 @@ class Tool { void print() const; + void distribute(Model * model); + + void set_prevalence(epiworld_double p, bool as_proportion = false); + epiworld_double get_prevalence() const; + bool get_prevalence_as_proportion() const; + void set_dist_fun(ToolToAgentFun fun); + }; #endif @@ -11547,9 +11443,16 @@ inline ToolFun tool_fun_logit( } template -inline Tool::Tool(std::string name) +inline Tool::Tool( + std::string name, + epiworld_double prevalence, + bool as_proportion, + ToolToAgentFun dist_fun + ) { set_name(name); + set_prevalence(prevalence, as_proportion); + set_dist_fun(dist_fun); } // template @@ -11963,6 +11866,89 @@ inline void Tool::print() const } +template +inline void Tool::distribute(Model * model) +{ + + if (dist_fun) + { + + dist_fun(*this, model); + + } else { + + // Picking how many + int n_to_distribute; + int n = model->size(); + if (prevalence_as_proportion) + { + n_to_distribute = static_cast(std::floor(prevalence * n)); + + if (n_to_distribute == n) + n_to_distribute--; + } + else + { + n_to_distribute = static_cast(prevalence); + } + + if (n_to_distribute > n) + throw std::range_error("There are only " + std::to_string(n) + + " individuals in the population. Cannot add the tool to " + std::to_string(n_to_distribute)); + + std::vector< int > idx(n); + std::iota(idx.begin(), idx.end(), 0); + auto & population = model->get_agents(); + for (int i = 0u; i < n_to_distribute; ++i) + { + int loc = static_cast( + floor(model->runif() * n--) + ); + + if ((loc > 0) && (loc == n)) + loc--; + + population[idx[loc]].add_tool( + *this, + const_cast< Model * >(model), + state_init, queue_init + ); + + std::swap(idx[loc], idx[n]); + + } + + } + +} + +template +inline void Tool::set_dist_fun(ToolToAgentFun fun) +{ + dist_fun = fun; +} + +template +inline epiworld_double Tool::get_prevalence() const +{ + return prevalence; +} + +template +inline void Tool::set_prevalence( + epiworld_double prevalence, + bool as_proportion +) +{ + this->prevalence = prevalence; + this->prevalence_as_proportion = as_proportion; +} + +template +inline bool Tool::get_prevalence_as_proportion() const +{ + return prevalence_as_proportion; +} #endif /*////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -16815,11 +16801,11 @@ inline ModelSURV::ModelSURV( model.add_globalevent(surveillance_program, "Surveilance program", -1); // Vaccine tool ----------------------------------------------------------- - epiworld::Tool vax("Vaccine"); + epiworld::Tool vax("Vaccine", prop_vaccinated, true); vax.set_susceptibility_reduction(&model("Vax efficacy")); vax.set_transmission_reduction(&model("Vax redux transmission")); - model.add_tool(vax, prop_vaccinated); + model.add_tool(vax); model.set_name("Surveillance"); diff --git a/examples/00-hello-world/main.cpp b/examples/00-hello-world/main.cpp index fb6a2ea0..c2336807 100644 --- a/examples/00-hello-world/main.cpp +++ b/examples/00-hello-world/main.cpp @@ -23,8 +23,8 @@ int main() virus.set_prob_death(.01); model.add_virus(virus); - epiworld::Tool tool("vaccine"); - model.add_tool(tool, .5); + epiworld::Tool tool("vaccine", .5, true); + model.add_tool(tool); // Generating a random pop model.agents_smallworld(10000, 20, false, .01); diff --git a/examples/03-simple-sir/main.cpp b/examples/03-simple-sir/main.cpp index aed018fa..89253b7a 100644 --- a/examples/03-simple-sir/main.cpp +++ b/examples/03-simple-sir/main.cpp @@ -10,7 +10,7 @@ int main() covid19.set_state(1,2,2); // Creating a tool - epiworld::Tool<> vax("vaccine"); + epiworld::Tool<> vax("vaccine", .5, true); vax.set_susceptibility_reduction(.95); // Creating a model @@ -23,7 +23,7 @@ int main() // Adding the tool and virus model.add_virus(covid19); - model.add_tool(vax, .5); + model.add_tool(vax); // Generating a random pop model.agents_from_adjlist( diff --git a/examples/04-advanced-usage/main.cpp b/examples/04-advanced-usage/main.cpp index 1339a738..a036b5fd 100644 --- a/examples/04-advanced-usage/main.cpp +++ b/examples/04-advanced-usage/main.cpp @@ -72,17 +72,17 @@ int main() { covid19.set_state(1,2,3); // Creating tools --------------------------------------------------------- - epiworld::Tool vaccine("Vaccine"); + epiworld::Tool vaccine("Vaccine", 0.5, true); vaccine.set_susceptibility_reduction(&model("vax efficacy")); vaccine.set_recovery_enhancer(0.4); vaccine.set_death_reduction(&model("vax death")); vaccine.set_transmission_reduction(0.5); - epiworld::Tool mask("Face masks"); + epiworld::Tool mask("Face masks", 0.5, true); mask.set_susceptibility_reduction(0.8); mask.set_transmission_reduction(0.05); - epiworld::Tool immune("Immune system"); + epiworld::Tool immune("Immune system", 1.0, true); immune.set_susceptibility_reduction(&model("imm efficacy")); immune.set_recovery_enhancer(&model("imm recovery")); immune.set_death_reduction(&model("imm death")); @@ -90,16 +90,16 @@ int main() { DAT seq0(base_seq.size(), false); immune.set_sequence(seq0); - epiworld::Tool post_immunity("Post Immune"); + epiworld::Tool post_immunity("Post Immune", 0, true); post_immunity.set_susceptibility_reduction(1.0); // Adding the virus and the tools to the model ---------------------------- model.add_virus(covid19); - model.add_tool(immune, 1.0); - model.add_tool(vaccine, 0.5); - model.add_tool(mask, 0.5); - model.add_tool_n(post_immunity, 0); + model.add_tool(immune); + model.add_tool(vaccine); + model.add_tool(mask); + model.add_tool(post_immunity); // Initializing and printing information about the model ------------------ model.queuing_off(); // Not working with rewiring just yet. diff --git a/examples/05-user-data/main.cpp b/examples/05-user-data/main.cpp index f2be2ac8..2592ebda 100644 --- a/examples/05-user-data/main.cpp +++ b/examples/05-user-data/main.cpp @@ -41,16 +41,16 @@ int main() v.set_prob_infecting(&model("infectiousness")); // Setting up tool --------------------------------------------------------- - epiworld::Tool<> is("immune system"); + epiworld::Tool<> is("immune system", 1.0, true); is.set_susceptibility_reduction(.3); is.set_death_reduction(.9); is.set_recovery_enhancer(&model("recovery")); - epiworld::Tool<> postImm("post immunity"); + epiworld::Tool<> postImm("post immunity", 0, false); postImm.set_susceptibility_reduction(1.0); - model.add_tool(is, 1.0); - model.add_tool_n(postImm, 0u); + model.add_tool(is); + model.add_tool(postImm); model.add_virus(v); model.run(112, 30); model.print(); diff --git a/examples/07-surveillance/07-surveillance.md b/examples/07-surveillance/07-surveillance.md index 722b0d61..46804f51 100644 --- a/examples/07-surveillance/07-surveillance.md +++ b/examples/07-surveillance/07-surveillance.md @@ -61,7 +61,7 @@ int main() covid19.set_infectiousness(.8); // Creating a tool - epiworld::Tool<> vax("vaccine"); + epiworld::Tool<> vax("vaccine", .5, true); vax.set_contagion_reduction(.95); // Creating a model @@ -69,7 +69,7 @@ int main() // Adding the tool and virus model.add_virus(covid19); - model.add_tool(vax, .5); + model.add_tool(vax); // Generating a random pop model.population_from_adjlist( diff --git a/include/epiworld/model-bones.hpp b/include/epiworld/model-bones.hpp index 3c7f7ca2..f57db90d 100644 --- a/include/epiworld/model-bones.hpp +++ b/include/epiworld/model-bones.hpp @@ -135,11 +135,7 @@ class Model { bool directed = false; std::vector< VirusPtr > viruses = {}; - std::vector< ToolPtr > tools = {}; - std::vector< epiworld_double > prevalence_tool = {}; - std::vector< bool > prevalence_tool_as_proportion = {}; - std::vector< ToolToAgentFun > tools_dist_funs = {}; std::vector< Entity > entities = {}; std::vector< Entity > entities_backup = {}; @@ -336,9 +332,7 @@ class Model { */ ///@{ void add_virus(Virus & v); - void add_tool(Tool & t, epiworld_double preval); - void add_tool_n(Tool & t, epiworld_fast_uint preval); - void add_tool_fun(Tool & t, ToolToAgentFun fun); + void add_tool(Tool & t); void add_entity(Entity e); void rm_virus(size_t virus_pos); void rm_tool(size_t tool_pos); diff --git a/include/epiworld/model-meat-print.hpp b/include/epiworld/model-meat-print.hpp index 0de4c86f..d429dfea 100644 --- a/include/epiworld/model-meat-print.hpp +++ b/include/epiworld/model-meat-print.hpp @@ -206,6 +206,7 @@ inline const Model & Model::print(bool lite) const size_t n_tools_model = tools.size(); for (size_t i = 0u; i < tools.size(); ++i) { + const auto & tool = tools[i]; if ((n_tools_model > 10) && (i >= 10)) { @@ -218,13 +219,13 @@ inline const Model & Model::print(bool lite) const if (i < n_tools_model) { - if (prevalence_tool_as_proportion[i]) + if (tool->get_prevalence_as_proportion()) { printf_epiworld( " - %s (baseline prevalence: %.2f%%)\n", - tools[i]->get_name().c_str(), - prevalence_tool[i] * 100.0 + tool->get_name().c_str(), + tool->get_prevalence() * 100.0 ); } @@ -233,8 +234,8 @@ inline const Model & Model::print(bool lite) const printf_epiworld( " - %s (baseline prevalence: %i seeds)\n", - tools[i]->get_name().c_str(), - static_cast(prevalence_tool[i]) + tool->get_name().c_str(), + static_cast(tool->get_prevalence()) ); } @@ -243,7 +244,7 @@ inline const Model & Model::print(bool lite) const printf_epiworld( " - %s (originated in the model...)\n", - tools[i]->get_name().c_str() + tool->get_name().c_str() ); } diff --git a/include/epiworld/model-meat.hpp b/include/epiworld/model-meat.hpp index c52b64b0..75331639 100644 --- a/include/epiworld/model-meat.hpp +++ b/include/epiworld/model-meat.hpp @@ -381,9 +381,6 @@ inline Model::Model(const Model & model) : directed(model.directed), viruses(model.viruses), tools(model.tools), - prevalence_tool(model.prevalence_tool), - prevalence_tool_as_proportion(model.prevalence_tool_as_proportion), - tools_dist_funs(model.tools_dist_funs), entities(model.entities), entities_backup(model.entities_backup), rewire_fun(model.rewire_fun), @@ -451,9 +448,6 @@ inline Model::Model(Model && model) : viruses(std::move(model.viruses)), // Tools tools(std::move(model.tools)), - prevalence_tool(std::move(model.prevalence_tool)), - prevalence_tool_as_proportion(std::move(model.prevalence_tool_as_proportion)), - tools_dist_funs(std::move(model.tools_dist_funs)), // Entities entities(std::move(model.entities)), entities_backup(std::move(model.entities_backup)), @@ -524,9 +518,6 @@ inline Model & Model::operator=(const Model & m) viruses = m.viruses; tools = m.tools; - prevalence_tool = m.prevalence_tool; - prevalence_tool_as_proportion = m.prevalence_tool_as_proportion; - tools_dist_funs = m.tools_dist_funs; entities = m.entities; entities_backup = m.entities_backup; @@ -766,55 +757,10 @@ template inline void Model::dist_tools() { - // Starting first infection - int n = size(); - std::vector< size_t > idx(n); - for (epiworld_fast_uint t = 0; t < tools.size(); ++t) + for (auto & tool: tools) { - if (tools_dist_funs[t]) - { - - tools_dist_funs[t](*tools[t], this); - - } else { - - // Picking how many - int nsampled; - if (prevalence_tool_as_proportion[t]) - { - nsampled = static_cast(std::floor(prevalence_tool[t] * size())); - } - else - { - nsampled = static_cast(prevalence_tool[t]); - } - - if (nsampled > static_cast(size())) - throw std::range_error("There are only " + std::to_string(size()) + - " individuals in the population. Cannot add the tool to " + std::to_string(nsampled)); - - ToolPtr tool = tools[t]; - - int n_left = n; - std::iota(idx.begin(), idx.end(), 0); - while (nsampled > 0) - { - int loc = static_cast(floor(runif() * n_left--)); - - population[idx[loc]].add_tool( - tool, - const_cast< Model * >(this), - tool->state_init, tool->queue_init - ); - - nsampled--; - - std::swap(idx[loc], idx[n_left]); - - } - - } + tool->distribute(this); // Apply the events events_run(); @@ -999,51 +945,17 @@ inline void Model::add_virus( } template -inline void Model::add_tool(Tool & t, epiworld_double preval) +inline void Model::add_tool(Tool & t) { - if (preval > 1.0) - throw std::range_error("Prevalence of tool cannot be above 1.0"); - - if (preval < 0.0) - throw std::range_error("Prevalence of tool cannot be negative"); - + db.record_tool(t); // Adding the tool to the model (and database.) tools.push_back(std::make_shared< Tool >(t)); - prevalence_tool.push_back(preval); - prevalence_tool_as_proportion.push_back(true); - tools_dist_funs.push_back(nullptr); - -} - -template -inline void Model::add_tool_n(Tool & t, epiworld_fast_uint preval) -{ - - db.record_tool(t); - - tools.push_back(std::make_shared >(t)); - prevalence_tool.push_back(preval); - prevalence_tool_as_proportion.push_back(false); - tools_dist_funs.push_back(nullptr); } -template -inline void Model::add_tool_fun(Tool & t, ToolToAgentFun fun) -{ - - db.record_tool(t); - - tools.push_back(std::make_shared >(t)); - prevalence_tool.push_back(0.0); - prevalence_tool_as_proportion.push_back(false); - tools_dist_funs.push_back(fun); -} - - template inline void Model::add_entity(Entity e) { @@ -1109,8 +1021,6 @@ inline void Model::rm_tool(size_t tool_pos) // Flipping with the last one std::swap(tools[tool_pos], tools[tools.size() - 1]); - std::swap(tools_dist_funs[tool_pos], tools_dist_funs[tools.size() - 1]); - std::swap(prevalence_tool[tool_pos], prevalence_tool[tools.size() - 1]); /* There's an error on windows: https://github.com/UofUEpiBio/epiworldR/actions/runs/4801482395/jobs/8543744180#step:6:84 @@ -1118,20 +1028,8 @@ inline void Model::rm_tool(size_t tool_pos) More clear here: https://stackoverflow.com/questions/58660207/why-doesnt-stdswap-work-on-vectorbool-elements-under-clang-win */ - std::vector::swap( - prevalence_tool_as_proportion[tool_pos], - prevalence_tool_as_proportion[tools.size() - 1] - ); - - // auto old = prevalence_tool_as_proportion[tool_pos]; - // prevalence_tool_as_proportion[tool_pos] = prevalence_tool_as_proportion[tools.size() - 1]; - // prevalence_tool_as_proportion[tools.size() - 1] = old; - tools.pop_back(); - tools_dist_funs.pop_back(); - prevalence_tool.pop_back(); - prevalence_tool_as_proportion.pop_back(); return; @@ -2614,18 +2512,6 @@ inline bool Model::operator==(const Model & other) const ) } - - VECT_MATCH( - prevalence_tool, - other.prevalence_tool, - "tools prevalence don't match" - ) - - VECT_MATCH( - prevalence_tool_as_proportion, - other.prevalence_tool_as_proportion, - "tools as prop don't match" - ) VECT_MATCH( entities, diff --git a/include/epiworld/models/surveillance.hpp b/include/epiworld/models/surveillance.hpp index c51eccce..7a6cee23 100644 --- a/include/epiworld/models/surveillance.hpp +++ b/include/epiworld/models/surveillance.hpp @@ -336,11 +336,11 @@ inline ModelSURV::ModelSURV( model.add_globalevent(surveillance_program, "Surveilance program", -1); // Vaccine tool ----------------------------------------------------------- - epiworld::Tool vax("Vaccine"); + epiworld::Tool vax("Vaccine", prop_vaccinated, true); vax.set_susceptibility_reduction(&model("Vax efficacy")); vax.set_transmission_reduction(&model("Vax redux transmission")); - model.add_tool(vax, prop_vaccinated); + model.add_tool(vax); model.set_name("Surveillance"); diff --git a/include/epiworld/tool-bones.hpp b/include/epiworld/tool-bones.hpp index 8e877eab..98d3c319 100644 --- a/include/epiworld/tool-bones.hpp +++ b/include/epiworld/tool-bones.hpp @@ -50,9 +50,17 @@ class Tool { void set_agent(Agent * p, size_t idx); + epiworld_double prevalence = 0.0; + bool prevalence_as_proportion = false; + ToolToAgentFun dist_fun = nullptr; + public: - Tool(std::string name = "unknown tool"); - // Tool(TSeq d, std::string name = "unknown tool"); + Tool( + std::string name = "unknown tool", + epiworld_double prevalence = 0.0, + bool prevalence_as_proportion = false, + ToolToAgentFun dist_fun = nullptr + ); void set_sequence(TSeq d); void set_sequence(std::shared_ptr d); @@ -107,6 +115,13 @@ class Tool { void print() const; + void distribute(Model * model); + + void set_prevalence(epiworld_double p, bool as_proportion = false); + epiworld_double get_prevalence() const; + bool get_prevalence_as_proportion() const; + void set_dist_fun(ToolToAgentFun fun); + }; #endif \ No newline at end of file diff --git a/include/epiworld/tool-meat.hpp b/include/epiworld/tool-meat.hpp index 604ca2ec..b4ae3906 100644 --- a/include/epiworld/tool-meat.hpp +++ b/include/epiworld/tool-meat.hpp @@ -81,9 +81,16 @@ inline ToolFun tool_fun_logit( } template -inline Tool::Tool(std::string name) +inline Tool::Tool( + std::string name, + epiworld_double prevalence, + bool as_proportion, + ToolToAgentFun dist_fun + ) { set_name(name); + set_prevalence(prevalence, as_proportion); + set_dist_fun(dist_fun); } // template @@ -497,4 +504,87 @@ inline void Tool::print() const } +template +inline void Tool::distribute(Model * model) +{ + + if (dist_fun) + { + + dist_fun(*this, model); + + } else { + + // Picking how many + int n_to_distribute; + int n = model->size(); + if (prevalence_as_proportion) + { + n_to_distribute = static_cast(std::floor(prevalence * n)); + + if (n_to_distribute == n) + n_to_distribute--; + } + else + { + n_to_distribute = static_cast(prevalence); + } + + if (n_to_distribute > n) + throw std::range_error("There are only " + std::to_string(n) + + " individuals in the population. Cannot add the tool to " + std::to_string(n_to_distribute)); + + std::vector< int > idx(n); + std::iota(idx.begin(), idx.end(), 0); + auto & population = model->get_agents(); + for (int i = 0u; i < n_to_distribute; ++i) + { + int loc = static_cast( + floor(model->runif() * n--) + ); + + if ((loc > 0) && (loc == n)) + loc--; + + population[idx[loc]].add_tool( + *this, + const_cast< Model * >(model), + state_init, queue_init + ); + + std::swap(idx[loc], idx[n]); + + } + + } + +} + +template +inline void Tool::set_dist_fun(ToolToAgentFun fun) +{ + dist_fun = fun; +} + +template +inline epiworld_double Tool::get_prevalence() const +{ + return prevalence; +} + +template +inline void Tool::set_prevalence( + epiworld_double prevalence, + bool as_proportion +) +{ + this->prevalence = prevalence; + this->prevalence_as_proportion = as_proportion; +} + +template +inline bool Tool::get_prevalence_as_proportion() const +{ + return prevalence_as_proportion; +} #endif \ No newline at end of file diff --git a/include/epiworld/virus-bones.hpp b/include/epiworld/virus-bones.hpp index 8820f939..d3014c5e 100644 --- a/include/epiworld/virus-bones.hpp +++ b/include/epiworld/virus-bones.hpp @@ -174,7 +174,7 @@ class Virus { void set_prevalence(epiworld_double prevalence, bool as_proportion); bool get_prevalence_as_proportion() const; void distribute(Model * model); - void set_distribution(VirusToAgentFun fun); + void set_dist_fun(VirusToAgentFun fun); ///@} diff --git a/include/epiworld/virus-meat.hpp b/include/epiworld/virus-meat.hpp index 8c571d58..1270a6f0 100644 --- a/include/epiworld/virus-meat.hpp +++ b/include/epiworld/virus-meat.hpp @@ -89,7 +89,7 @@ inline Virus::Virus( set_name(name); set_prevalence(prevalence, prevalence_as_proportion); - set_distribution(dist_fun); + set_dist_fun(dist_fun); } template @@ -799,7 +799,7 @@ inline void Virus::distribute(Model * model) } template -inline void Virus::set_distribution(VirusToAgentFun fun) +inline void Virus::set_dist_fun(VirusToAgentFun fun) { dist_fun = fun; } diff --git a/tests/00-cloning-model.cpp b/tests/00-cloning-model.cpp index ccdad29f..5338b94b 100644 --- a/tests/00-cloning-model.cpp +++ b/tests/00-cloning-model.cpp @@ -10,14 +10,14 @@ EPIWORLD_TEST_CASE("Cloning", "[clone]") { m.add_state("Recovered"); epiworld::Virus v("covid 19", 0.5, true); - epiworld::Tool t; + epiworld::Tool t("vax", .5, true); v.set_state(0, 1); m.seed(1333); m.agents_smallworld(1000); m.add_virus(v); - m.add_tool(t, .5); + m.add_tool(t); // Cloning epiworld::Model m2 = m; diff --git a/tests/01-sample.cpp b/tests/01-sample.cpp index 4e9c379f..0ec0b1a4 100644 --- a/tests/01-sample.cpp +++ b/tests/01-sample.cpp @@ -23,8 +23,8 @@ int main() virus.set_prob_death(.01); model.add_virus(virus); - epiworld::Tool tool("vaccine"); - model.add_tool(tool, .5); + epiworld::Tool tool("vaccine", .5, true); + model.add_tool(tool); // Generating a random pop model.agents_smallworld(10000); diff --git a/tests/05-mixing.cpp b/tests/05-mixing.cpp index a04eb1b8..a7fc57b7 100644 --- a/tests/05-mixing.cpp +++ b/tests/05-mixing.cpp @@ -28,7 +28,7 @@ EPIWORLD_TEST_CASE("SEIRMixing", "[SEIR-mixing]") { // Copy the original virus Virus<> v1 = model.get_virus(0); model.rm_virus(0); - v1.set_distribution(dist_virus<>(0)); + v1.set_dist_fun(dist_virus<>(0)); model.add_virus(v1); diff --git a/tests/06-mixing.cpp b/tests/06-mixing.cpp index c21e06d0..1afa0e1f 100644 --- a/tests/06-mixing.cpp +++ b/tests/06-mixing.cpp @@ -29,7 +29,7 @@ EPIWORLD_TEST_CASE("SIRMixing", "[SIR-mixing]") { // Copy the original virus Virus<> v1 = model.get_virus(0); model.rm_virus(0); - v1.set_distribution(dist_virus<>(0)); + v1.set_dist_fun(dist_virus<>(0)); model.add_virus(v1);