Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Geo based access POC #1286

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions integrations/server/test_differentiated_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import requests

# third party
import mysql.connector

# first party
from delphi.epidata.acquisition.covidcast.test_utils import (
CovidcastBase,
CovidcastTestRow,
)


class DifferentiatedAccessTests(CovidcastBase):
    """Integration tests for the `differentiated_access` endpoint.

    Seeds the `epidata` database with two API users — "api_key" (no roles)
    and "ny_key" (role "state:ny") — then verifies that only the role-holding
    user can read rows from the restricted source.

    NOTE(review): requires the delphi docker network to be up — connects to
    the `delphi_database_epidata` MySQL host and issues HTTP requests against
    `delphi_web_epidata`; this cannot run standalone.
    """

    def localSetUp(self):
        """Perform per-test setup"""
        # Invalidate the cached covidcast metadata so no stale entries leak
        # between tests (timestamp 0 forces a recompute, epidata emptied).
        self._db._cursor.execute(
            'update covidcast_meta_cache set timestamp = 0, epidata = "[]"'
        )

    def setUp(self):
        # connect to the `epidata` database

        super().setUp()

        # show full diffs on assertion failures
        self.maxDiff = None

        cnx = mysql.connector.connect(
            user="user",
            password="pass",
            host="delphi_database_epidata",
            database="epidata",
        )

        cur = cnx.cursor()

        # Reset all user/role tables so each test starts from a known state.
        # `api_user` is DELETEd (not TRUNCATEd) — presumably because of
        # foreign-key constraints from `user_role_link`; TODO confirm.
        cur.execute("DELETE FROM `api_user`")
        cur.execute("TRUNCATE TABLE `user_role`")
        cur.execute("TRUNCATE TABLE `user_role_link`")

        # Fixture users: "api_key" has no roles; "ny_key" is linked below to
        # the "state:ny" role (role_id 1 — valid because `user_role` was just
        # truncated, so the single inserted role gets id 1).
        cur.execute(
            'INSERT INTO `api_user`(`api_key`, `email`) VALUES ("api_key", "[email protected]")'
        )
        cur.execute(
            'INSERT INTO `api_user`(`api_key`, `email`) VALUES("ny_key", "[email protected]")'
        )
        cur.execute('INSERT INTO `user_role`(`name`) VALUES("state:ny")')
        cur.execute(
            'INSERT INTO `user_role_link`(`user_id`, `role_id`) SELECT `api_user`.`id`, 1 FROM `api_user` WHERE `api_key` = "ny_key"'
        )

        cnx.commit()
        cur.close()
        cnx.close()

    def request_based_on_row(self, row: CovidcastTestRow, **kwargs):
        """Issue a GET against the differentiated_access endpoint for `row`.

        Extra keyword args (e.g. `token=`) are forwarded into the query
        params. Raises on HTTP error; returns the decoded JSON body.
        """
        params = self.params_from_row(row, endpoint="differentiated_access", **kwargs)
        # use local instance of the Epidata API

        response = requests.get(
            "http://delphi_web_epidata/epidata/api.php", params=params
        )
        response.raise_for_status()
        return response.json()

    def _insert_placeholder_restricted_geo(self):
        """Insert six rows for "restricted-source" and return them.

        All geo_values are 36xxx county FIPS codes (prefix 36 = New York
        state), so the "state:ny" role should grant access to every row.
        time_value literals like 2000_01_01 are ints in YYYYMMDD form
        (20000101) with digit-group underscores.
        """
        geo_values = ["36029", "36047", "36097", "36103", "36057", "36041", "36033"]
        rows = [
            CovidcastTestRow.make_default_row(
                source="restricted-source",
                geo_type="county",
                geo_value=geo_values[i],
                time_value=2000_01_01 + i,
                value=i * 1.0,
                stderr=i * 10.0,
                sample_size=i * 100.0,
                issue=2000_01_03,
                lag=2 - i,
            )
            for i in [1, 2, 3]
        ] + [
            # time value intended to overlap with the time values above, with disjoint geo values
            CovidcastTestRow.make_default_row(
                source="restricted-source",
                geo_type="county",
                geo_value=geo_values[i],
                time_value=2000_01_01 + i - 3,
                value=i * 1.0,
                stderr=i * 10.0,
                sample_size=i * 100.0,
                issue=2000_01_03,
                lag=5 - i,
            )
            for i in [4, 5, 6]
        ]
        self._insert_rows(rows)
        return rows

    def test_restricted_geo_ny_role(self):
        """A user holding "state:ny" can read restricted NY-county data."""
        # insert placeholder data
        rows = self._insert_placeholder_restricted_geo()

        # make request
        response = self.request_based_on_row(rows[0], token="ny_key")
        expected = {
            "result": 1,
            "epidata": [rows[0].as_api_compatibility_row_dict()],
            "message": "success",
        }
        self.assertEqual(response, expected)

    def test_restricted_geo_default_role(self):
        """A user without the role gets no results for the same query."""
        # insert placeholder data
        rows = self._insert_placeholder_restricted_geo()

        # make request
        response = self.request_based_on_row(rows[0], token="api_key")
        expected = {"result": -2, "message": "no results"}
        self.assertEqual(response, expected)
2 changes: 2 additions & 0 deletions src/server/endpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
wiki,
signal_dashboard_status,
signal_dashboard_coverage,
differentiated_access
)

endpoints = [
Expand Down Expand Up @@ -64,6 +65,7 @@
wiki,
signal_dashboard_status,
signal_dashboard_coverage,
differentiated_access
]

__all__ = ["endpoints"]
159 changes: 159 additions & 0 deletions src/server/endpoints/differentiated_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from flask import Blueprint
from werkzeug.exceptions import Unauthorized

from .._common import is_compatibility_mode
from .._params import (
extract_date,
extract_dates,
extract_integer,
parse_geo_sets,
parse_source_signal_sets,
parse_time_set,
)
from .._query import QueryBuilder, execute_query
from .._security import current_user, sources_protected_by_roles
from .covidcast_utils.model import create_source_signal_alias_mapper
from delphi_utils import GeoMapper
from delphi.epidata.common.logger import get_structured_logger

# first argument is the endpoint name
bp = Blueprint("differentiated_access", __name__)
# NOTE(review): `alias` is never read in this module — looks like dead state
# copied from the covidcast endpoint; candidate for removal.
alias = None

# Database views queried by this endpoint: latest-issue rows by default,
# and the full history view when issues/lag/as_of filters are applied.
latest_table = "epimetric_latest_v"
history_table = "epimetric_full_v"


def restrict_by_roles(source_signal_sets):
    """Return only the source-signal sets the current user may access.

    A set whose source appears in `sources_protected_by_roles` is kept only
    when the authenticated user holds the required role; unprotected sources
    always pass through. Rejected protected sources are logged (request is
    not failed — the set is silently dropped from the query).
    """
    user = current_user
    permitted = []
    for sig_set in source_signal_sets:
        required_role = sources_protected_by_roles.get(sig_set.source)
        if required_role is None or (user and user.has_role(required_role)):
            permitted.append(sig_set)
        else:
            # protected source and no matching role: drop it and log the attempt
            get_structured_logger("covcast_endpt").warning(
                "non-authZd request for restricted 'source'",
                api_key=(user and user.api_key),
                src=sig_set.source,
            )
    return permitted


def serve_geo_restricted(geo_sets: set):
    """Narrow the requested geo sets to the counties the user may query.

    Each role named `state:<abbr>` grants access to every county within that
    state (resolved via GeoMapper). Only `county` geo_type requests are
    accepted. A wildcard (`geo_values is True`) expands to exactly the
    permitted counties; an explicit list is intersected with them, so
    over-broad requests are silently trimmed rather than rejected.

    Raises Unauthorized when unauthenticated or on a non-county geo_type.
    """
    geomapper = GeoMapper()
    if not current_user:
        raise Unauthorized("User is not authenticated.")

    # Union of counties unlocked by the user's `state:<abbr>` roles.
    permitted_counties = set()
    for role in current_user.roles:
        prefix, sep, state_abbr = role.name.partition(":")
        if sep and prefix == "state":
            permitted_counties.update(
                geomapper.get_geos_within(state_abbr, "county", "state")
            )

    for geo_set in geo_sets:
        if geo_set.geo_type != "county":
            raise Unauthorized("Only `county` geo_type is allowed")
        if geo_set.geo_values is True:
            # wildcard request: substitute the full permitted set
            geo_set.geo_values = list(permitted_counties)
        else:
            # explicit request: keep only the overlap with permitted counties
            geo_set.geo_values = list(set(geo_set.geo_values) & permitted_counties)
    return geo_sets


@bp.route("/", methods=("GET", "POST"))
def handle():
    """Serve a covidcast-style query, filtered by the caller's roles.

    Pipeline: parse the source/signal sets from the request, drop the ones
    the user's roles do not permit, map source aliases, then restrict the
    requested geos to the user's allowed counties (raises Unauthorized for
    unauthenticated users or non-county geo types). The remaining filters
    are assembled into a QueryBuilder over the latest-issue view and the
    result is streamed via execute_query.
    """
    source_signal_sets = parse_source_signal_sets()
    # role restriction must happen before alias mapping so protection is
    # checked against the original source names
    source_signal_sets = restrict_by_roles(source_signal_sets)
    source_signal_sets, alias_mapper = create_source_signal_alias_mapper(
        source_signal_sets
    )
    time_set = parse_time_set()
    geo_sets = serve_geo_restricted(parse_geo_sets())

    as_of = extract_date("as_of")
    issues = extract_dates("issues")
    lag = extract_integer("lag")

    # build query
    q = QueryBuilder(latest_table, "t")

    fields_string = ["geo_value", "signal"]
    fields_int = [
        "time_value",
        "direction",
        "issue",
        "lag",
        "missing_value",
        "missing_stderr",
        "missing_sample_size",
    ]
    fields_float = ["value", "stderr", "sample_size"]
    is_compatibility = is_compatibility_mode()
    if is_compatibility:
        q.set_sort_order("signal", "time_value", "geo_value", "issue")
    else:
        # transfer also the new detail columns
        fields_string.extend(["source", "geo_type", "time_type"])
        q.set_sort_order(
            "source",
            "signal",
            "time_type",
            "time_value",
            "geo_type",
            "geo_value",
            "issue",
        )
    q.set_fields(fields_string, fields_int, fields_float)

    # basic query info
    # data type of each field
    # build the source, signal, time, and location (type and id) filters

    q.apply_source_signal_filters("source", "signal", source_signal_sets)
    q.apply_geo_filters("geo_type", "geo_value", geo_sets)
    q.apply_time_filter("time_type", "time_value", time_set)

    # issue/lag/as_of queries go against the full-history view instead of
    # the latest-issue view
    q.apply_issues_filter(history_table, issues)
    q.apply_lag_filter(history_table, lag)
    q.apply_as_of_filter(history_table, as_of)

    def transform_row(row, proxy):
        # rewrite aliased source names back to their public form; skipped in
        # compatibility mode or when no alias mapping exists
        if is_compatibility or not alias_mapper or "source" not in row:
            return row
        row["source"] = alias_mapper(row["source"], proxy["signal"])
        return row

    # send query
    return execute_query(
        str(q),
        q.params,
        fields_string,
        fields_int,
        fields_float,
        transform=transform_row,
    )
Loading