Skip to content

Commit

Permalink
refactor: streamline assort
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmcgrath13 committed Feb 18, 2021
1 parent 4cb4a7d commit bfcff90
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
2 changes: 1 addition & 1 deletion scripts/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

outdir = "results"
save_pop = False
setting = "scott"
setting = "nyc-msm"
params_path = "tests/params/integration_base.yml"
sweepfile = None
rows = None
Expand Down
46 changes: 22 additions & 24 deletions titan/partnering.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,53 +53,51 @@ def select_partner(
return None

if params.features.assort_mix:
random_partner = None
assort_attrs = get_assort_attrs(params.assort_mix.values(), agent, rand_gen)
for partner in utils.safe_shuffle(eligible, rand_gen):
if is_assortable(partner, assort_attrs):
random_partner = partner
break
else:
random_partner = utils.safe_random_choice(eligible, rand_gen)
match_fns = get_match_fns(params.assort_mix.values(), agent, rand_gen)
# if no definitions match this agent, don't try to assort
if len(match_fns) > 0:
for partner in utils.safe_shuffle(eligible, rand_gen):
if is_assortable(partner, match_fns):
return partner
return None

return random_partner
return utils.safe_random_choice(eligible, rand_gen)


# does an agent match the criteria of the randomly chosen assort values?
def is_assortable(agent, assort_attrs):
for attr, match_fn in assort_attrs.items():
if not match_fn(agent, attr):
def is_assortable(agent, match_fns):
for match_fn in match_fns:
if not match_fn(agent):
return False

return True


# if this assort rule applies for this agent, get the randomly chosen value the partner must have
def get_assort_attrs(assort_defs, agent, rand_gen):
assort_attrs = {}
# if this assort rule applies for this agent, get the match function for a potential partner
def get_match_fns(assort_defs, agent, rand_gen):
match_fns = []
for assort_def in assort_defs:
if getattr(agent, assort_def.attribute) == assort_def.agent_value:
assort_attrs[assort_def.attribute] = get_assort_attr_value(
assort_def, rand_gen
)
match_fns.append(get_match_fn(assort_def, rand_gen))

return assort_attrs
return match_fns


# given an assort def, randomly select the type the partner must have given the weights
def get_assort_attr_value(assort_def, rand_gen):
# given an assort def, randomly select the type the partner must have given the
# weights and return a function to determin if a potential partner matches it
def get_match_fn(assort_def, rand_gen):
partner_types = list(assort_def.partner_values.keys())
partner_weights = [assort_def.partner_values[p] for p in partner_types]
partner_type = utils.safe_random_choice(
partner_types, rand_gen, weights=partner_weights
)

# python is a little weird about what gets captured in a lam
attr = assort_def.attribute
if partner_type == "__other__":
partner_types.remove("__other__")
return lambda ag, attr: str(getattr(ag, attr)) not in partner_types
return lambda ag: str(getattr(ag, attr)) not in partner_types
else:
return lambda ag, attr: str(getattr(ag, attr)) == partner_type
return lambda ag: str(getattr(ag, attr)) == partner_type


@utils.memo
Expand Down

0 comments on commit bfcff90

Please sign in to comment.