Skip to content

Commit

Permalink
Adding mixing matrix! Still need to set dist_ent without replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 15, 2024
1 parent f337325 commit ed0f80d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples/11-entities/Makefile
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
main.o: main.cpp
g++ -std=c++17 -g main.cpp -o main.o
g++ -std=c++17 -O3 -g main.cpp -o main.o
11 changes: 9 additions & 2 deletions examples/11-entities/main.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#define EPI_DEBUG
// #define EPI_DEBUG
#include "../../include/epiworld/epiworld.hpp"

using namespace epiworld;

int main() {

std::vector< double > contact_matrix = {
0.9, 0.1, 0.1,
0.05, 0.8, .2,
0.05, 0.1, 0.7
};

epimodels::ModelSEIREntitiesConn model(
"Flu", // std::string vname,
100000, // epiworld_fast_uint n,
Expand All @@ -14,7 +20,8 @@ int main() {
4.0,// epiworld_double avg_incubation_days,
1.0/7.0,// epiworld_double recovery_rate,
{.1, .1, .8},// std::vector< epiworld_double > entities,
{"A", "B", "C"}// std::vector< std::string > entities_names
{"A", "B", "C"},// std::vector< std::string > entities_names
contact_matrix
);

// Running and checking the results
Expand Down
5 changes: 4 additions & 1 deletion include/epiworld/entity-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ template<typename TSeq>
inline Agent<TSeq> * Entity<TSeq>::operator[](size_t i)
{
if (n_agents <= i)
throw std::logic_error("There are not that many agents in this entity.");
throw std::logic_error(
"There are not that many agents in this entity. " +
std::to_string(n_agents) + " <= " + std::to_string(i)
);

return &model->get_agents()[i];
}
Expand Down
64 changes: 39 additions & 25 deletions include/epiworld/models/seirentitiesconnected.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ class ModelSEIREntitiesConn : public epiworld::Model<TSeq>
static const int INFECTED = 2;
static const int RECOVERED = 3;

GroupSampler<TSeq> group_sampler;


ModelSEIREntitiesConn() {};


Expand Down Expand Up @@ -106,9 +103,10 @@ class GroupSampler {

private:

epiworld::Model<TSeq> & model;
const std::vector< double > contact_matrix; ///< Contact matrix between groups
const std::vector< size_t > group_sizes; ///< Sizes of the groups
epiworld::Model<TSeq> * model;
std::vector< Entity<TSeq> > * entities;
std::vector< double > contact_matrix; ///< Contact matrix between groups
std::vector< size_t > group_sizes; ///< Sizes of the groups
std::vector< double > cumulate; ///< Cumulative sum of the contact matrix (row-major for faster access)

/**
Expand All @@ -132,18 +130,22 @@ class GroupSampler {

public:

GroupSampler() {};

GroupSampler(
epiworld::Model<TSeq> & model,
const std::vector< double > & contact_matrix,
const std::vector< size_t > & group_sizes,
epiworld::Model<TSeq> * model_,
const std::vector< double > & contact_matrix_,
const std::vector< size_t > & group_sizes_,
bool normalize = true
): model(model), contact_matrix(contact_matrix), group_sizes(group_sizes) {
): model(model_), contact_matrix(contact_matrix_), group_sizes(group_sizes_) {

entities = &model->get_entities();

this->cumulate.resize(contact_matrix.size());
std::fill(cumulate.begin(), cumulate.end(), 0.0);

// Cumulative sum
for (size_t j = 1; j < group_sizes.size(); ++j)
for (size_t j = 0; j < group_sizes.size(); ++j)
{
for (size_t i = 0; i < group_sizes.size(); ++i)
cumulate[idx(i, j, true)] +=
Expand Down Expand Up @@ -180,7 +182,7 @@ int GroupSampler<TSeq>::sample_1(const int origin_group)
{

// Random number
double r = model.runif();
double r = model->runif();

// Finding the group
size_t j = 0;
Expand All @@ -194,11 +196,11 @@ int GroupSampler<TSeq>::sample_1(const int origin_group)
std::floor(r * group_sizes[j])
);

// Making sure we are not picling outside of the group
// Making sure we are not picking outside of the group
if (res >= static_cast<int>(group_sizes[j]))
res = static_cast<int>(group_sizes[j]) - 1;

return res;
return entities->at(j)[res]->get_id();

}

Expand Down Expand Up @@ -302,40 +304,52 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
group_sizes[i] = static_cast<size_t>(entities[i] * n);

// Setting up the group sampler
GroupSampler<TSeq> group_sampler(
model,
contact_matrix,
group_sizes
std::shared_ptr<GroupSampler<TSeq>> group_sampler =
std::make_shared<GroupSampler<TSeq>>(
dynamic_cast<Model<TSeq>*>(&model),
contact_matrix,
group_sizes
);

epiworld::UpdateFun<TSeq> update_susceptible = [](
epiworld::UpdateFun<TSeq> update_susceptible = [group_sampler](
epiworld::Agent<TSeq> * p, epiworld::Model<TSeq> * m
) -> void
{

if (p->get_n_entities() == 0)
return;

// Sampling how many individuals
int ndraw = m->rbinom();

if (ndraw == 0)
return;

// Sampling from the agent's entities
epiworld::AgentsSample<TSeq> sample(m, *p, ndraw, {}, true);
std::vector< size_t > sample(ndraw);
group_sampler->sample_n(
sample,
p->get_entity(0u).get_id(),
ndraw
);

// Drawing from the set
int nviruses_tmp = 0;
for (const auto & neighbor: sample)
auto & agents = m->get_agents();
for (const auto & i: sample)
{

auto neighbor = agents[i];

// Can't sample itself
if (neighbor->get_id() == static_cast<int>(p->get_id()))
if (neighbor.get_id() == static_cast<int>(p->get_id()))
continue;

// If the neighbor is infected, then proceed
if (neighbor->get_state() == ModelSEIREntitiesConn<TSeq>::INFECTED)
if (neighbor.get_state() == ModelSEIREntitiesConn<TSeq>::INFECTED)
{

auto & v = neighbor->get_virus();
auto & v = neighbor.get_virus();

#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
Expand All @@ -346,7 +360,7 @@ inline ModelSEIREntitiesConn<TSeq>::ModelSEIREntitiesConn(
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor->get_transmission_reduction(v, m))
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);
Expand Down

0 comments on commit ed0f80d

Please sign in to comment.