diff --git a/include/epiworld/model-bones.hpp b/include/epiworld/model-bones.hpp index 033c0d42..d8b11694 100644 --- a/include/epiworld/model-bones.hpp +++ b/include/epiworld/model-bones.hpp @@ -372,6 +372,12 @@ class Model { const std::vector & entities_ids ); + void load_agents_entities_ties( + const int * agents_id, + const int * entities_id, + size_t n + ); + /** * @name Accessing population of the model * diff --git a/include/epiworld/model-meat.hpp b/include/epiworld/model-meat.hpp index 4f1caea3..00f81c1a 100644 --- a/include/epiworld/model-meat.hpp +++ b/include/epiworld/model-meat.hpp @@ -1359,46 +1359,87 @@ inline void Model::load_agents_entities_ties( const std::vector< int > & entities_ids ) { + // Checking the size if (agents_ids.size() != entities_ids.size()) throw std::length_error( - std::string("agents_ids (") + + std::string("The size of agents_ids (") + std::to_string(agents_ids.size()) + std::string(") and entities_ids (") + std::to_string(entities_ids.size()) + - std::string(") should match.") + std::string(") must be the same.") ); + return this->load_agents_entities_ties( + agents_ids.data(), + entities_ids.data(), + agents_ids.size() + ); + +} - size_t n_entries = agents_ids.size(); - for (size_t i = 0u; i < n_entries; ++i) +template +inline void Model::load_agents_entities_ties( + const int * agents_ids, + const int * entities_ids, + size_t n +) { + + auto get_agent = [agents_ids](int i) -> int { + return *(agents_ids + i); + }; + + auto get_entity = [entities_ids](int i) -> int { + return *(entities_ids + i); + }; + + for (size_t i = 0u; i < n; ++i) { - if (agents_ids[i] >= this->population.size()) + if (get_agent(i) < 0) + throw std::length_error( + std::string("agents_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_agent(i)) + + std::string(" is negative.") + ); + + if (get_entity(i) < 0) + throw std::length_error( + std::string("entities_ids[") + + std::to_string(i) + + std::string("] = ") + + std::to_string(get_entity(i)) + + std::string(" is negative.") + ); + + int pop_size = static_cast(this->population.size()); + if (get_agent(i) >= pop_size) throw std::length_error( std::string("agents_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(agents_ids[i]) + + std::to_string(get_agent(i)) + std::string(" is out of range (population size: ") + - std::to_string(this->population.size()) + + std::to_string(pop_size) + std::string(").") ); - - if (entities_ids[i] >= this->entities.size()) + int ent_size = static_cast(this->entities.size()); + if (get_entity(i) >= ent_size) throw std::length_error( std::string("entities_ids[") + std::to_string(i) + std::string("] = ") + - std::to_string(entities_ids[i]) + + std::to_string(get_entity(i)) + std::string(" is out of range (entities size: ") + - std::to_string(this->entities.size()) + + std::to_string(ent_size) + std::string(").") ); // Adding the entity to the agent - this->population[agents_ids[i]].add_entity( - this->entities[entities_ids[i]], + this->population[get_agent(i)].add_entity( + this->entities[get_entity(i)], nullptr /* Immediately add it to the agent */ ); diff --git a/include/epiworld/models/seirmixing.hpp b/include/epiworld/models/seirmixing.hpp index 91c86a93..50304a0e 100644 --- a/include/epiworld/models/seirmixing.hpp +++ b/include/epiworld/models/seirmixing.hpp @@ -164,7 +164,10 @@ inline void ModelSEIRMixing::update_infected() { if (a.get_state() == ModelSEIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } } diff --git a/include/epiworld/models/sirmixing.hpp b/include/epiworld/models/sirmixing.hpp index 74c2169a..d87a7eae 100644 --- a/include/epiworld/models/sirmixing.hpp +++ b/include/epiworld/models/sirmixing.hpp @@ -159,7 +159,10 @@ inline void ModelSIRMixing::update_infected_list() { if (a.get_state() == ModelSIRMixing::INFECTED) - infected[a.get_entity(0u).get_id()].push_back(&a); + { + if (a.get_n_entities() > 0u) + infected[a.get_entity(0u).get_id()].push_back(&a); + } }