Skip to content

Commit

Permalink
Added geo based restricted endpoint. Resulting data is based on user'…
Browse files Browse the repository at this point in the history
…s role, so user can see only role allowed data (county based)
  • Loading branch information
dmytrotsko committed Sep 13, 2023
1 parent f7da659 commit 2bf7e00
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 0 deletions.
118 changes: 118 additions & 0 deletions integrations/server/test_differentiated_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import requests

# third party
import mysql.connector

# frirst party
from delphi.epidata.acquisition.covidcast.test_utils import (
CovidcastBase,
CovidcastTestRow,
)


class DifferentiatedAccessTests(CovidcastBase):
def localSetUp(self):
"""Perform per-test setup"""
self._db._cursor.execute(
'update covidcast_meta_cache set timestamp = 0, epidata = "[]"'
)

def setUp(self):
# connect to the `epidata` database

super().setUp()

self.maxDiff = None

cnx = mysql.connector.connect(
user="user",
password="pass",
host="delphi_database_epidata",
database="epidata",
)

cur = cnx.cursor()

cur.execute("DELETE FROM `api_user`")
cur.execute("TRUNCATE TABLE `user_role`")
cur.execute("TRUNCATE TABLE `user_role_link`")

cur.execute(
'INSERT INTO `api_user`(`api_key`, `email`) VALUES ("api_key", "[email protected]")'
)
cur.execute(
'INSERT INTO `api_user`(`api_key`, `email`) VALUES("ny_key", "[email protected]")'
)
cur.execute('INSERT INTO `user_role`(`name`) VALUES("state:ny")')
cur.execute(
'INSERT INTO `user_role_link`(`user_id`, `role_id`) SELECT `api_user`.`id`, 1 FROM `api_user` WHERE `api_key` = "ny_key"'
)

cnx.commit()
cur.close()
cnx.close()

def request_based_on_row(self, row: CovidcastTestRow, **kwargs):
params = self.params_from_row(row, endpoint="differentiated_access", **kwargs)
# use local instance of the Epidata API

response = requests.get(
"http://delphi_web_epidata/epidata/api.php", params=params
)
response.raise_for_status()
return response.json()

def _insert_placeholder_restricted_geo(self):
geo_values = ["36029", "36047", "36097", "36103", "36057", "36041", "36033"]
rows = [
CovidcastTestRow.make_default_row(
source="restricted-source",
geo_type="county",
geo_value=geo_values[i],
time_value=2000_01_01 + i,
value=i * 1.0,
stderr=i * 10.0,
sample_size=i * 100.0,
issue=2000_01_03,
lag=2 - i,
)
for i in [1, 2, 3]
] + [
# time value intended to overlap with the time values above, with disjoint geo values
CovidcastTestRow.make_default_row(
source="restricted-source",
geo_type="county",
geo_value=geo_values[i],
time_value=2000_01_01 + i - 3,
value=i * 1.0,
stderr=i * 10.0,
sample_size=i * 100.0,
issue=2000_01_03,
lag=5 - i,
)
for i in [4, 5, 6]
]
self._insert_rows(rows)
return rows

def test_restricted_geo_ny_role(self):
# insert placeholder data
rows = self._insert_placeholder_restricted_geo()

# make request
response = self.request_based_on_row(rows[0], token="ny_key")
expected = {
"result": 1,
"epidata": [rows[0].as_api_compatibility_row_dict()],
"message": "success",
}
self.assertEqual(response, expected)

def test_restricted_geo_default_role(self):
# insert placeholder data
rows = self._insert_placeholder_restricted_geo()

# make request
response = self.request_based_on_row(rows[0], token="api_key")
expected = {"result": -2, "message": "no results"}
self.assertEqual(response, expected)
2 changes: 2 additions & 0 deletions src/server/endpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
wiki,
signal_dashboard_status,
signal_dashboard_coverage,
differentiated_access
)

endpoints = [
Expand Down Expand Up @@ -66,6 +67,7 @@
wiki,
signal_dashboard_status,
signal_dashboard_coverage,
differentiated_access
]

__all__ = ["endpoints"]
159 changes: 159 additions & 0 deletions src/server/endpoints/differentiated_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from flask import Blueprint
from werkzeug.exceptions import Unauthorized

from .._common import is_compatibility_mode
from .._params import (
extract_date,
extract_dates,
extract_integer,
parse_geo_sets,
parse_source_signal_sets,
parse_time_set,
)
from .._query import QueryBuilder, execute_query
from .._security import current_user, sources_protected_by_roles
from .covidcast_utils.model import create_source_signal_alias_mapper
from delphi_utils import GeoMapper
from delphi.epidata.common.logger import get_structured_logger

# first argument is the endpoint name
bp = Blueprint("differentiated_access", __name__)
alias = None

latest_table = "epimetric_latest_v"
history_table = "epimetric_full_v"


def restrict_by_roles(source_signal_sets):
# takes a list of SourceSignalSet objects
# and returns only those from the list
# that the current user is permitted to access.
user = current_user
allowed_source_signal_sets = []
for src_sig_set in source_signal_sets:
src = src_sig_set.source
if src in sources_protected_by_roles:
role = sources_protected_by_roles[src]
if user and user.has_role(role):
allowed_source_signal_sets.append(src_sig_set)
else:
# protected src and user does not have permission => leave it out of the srcsig sets
get_structured_logger("covcast_endpt").warning(
"non-authZd request for restricted 'source'",
api_key=(user and user.api_key),
src=src,
)
else:
allowed_source_signal_sets.append(src_sig_set)
return allowed_source_signal_sets


def serve_geo_restricted(geo_sets: set):
geomapper = GeoMapper()
if not current_user:
raise Unauthorized("User is not authenticated.")
allowed_counties = set()
# Getting allowed counties set from user's roles.
# Example: role 'state:ny' will give user access to all counties in ny state.
for role in current_user.roles:
if role.name.startswith("state:"):
state = role.name.split(":", 1)[1]
counties_in_state = geomapper.get_geos_within(state, "county", "state")
allowed_counties.update(counties_in_state)

for geo_set in geo_sets:
# Reject if `geo_type` is not county.
if geo_set.geo_type != "county":
raise Unauthorized("Only `county` geo_type is allowed")
# If `geo_value` = '*' then we want to query only that counties that user has access to.
if geo_set.geo_values is True:
geo_set.geo_values = list(allowed_counties)
# Actually we don't need to check whether `geo_set.geo_values` (user requested counties) is a superset of `allowed_counties`
# We do want to return set of counties that are in both `geo_set.geo_values` and `allowed_counties`
# Because if user requested less -> we will get only requested list of counties, in other case (user requested more
# than he can get -> he will get only that counties that he is allowed to).

# elif set(geo_set.geo_values).issuperset(allowed_counties):
# geo_set.geo_values = list(set(geo_set.geo_values).intersection(allowed_counties))

# If user provided more counties that he is able to query, then we want to show him only
# that counties that he is allowed to.
else:
geo_set.geo_values = list(
set(geo_set.geo_values).intersection(allowed_counties)
)
return geo_sets


@bp.route("/", methods=("GET", "POST"))
def handle():
source_signal_sets = parse_source_signal_sets()
source_signal_sets = restrict_by_roles(source_signal_sets)
source_signal_sets, alias_mapper = create_source_signal_alias_mapper(
source_signal_sets
)
time_set = parse_time_set()
geo_sets = serve_geo_restricted(parse_geo_sets())

as_of = extract_date("as_of")
issues = extract_dates("issues")
lag = extract_integer("lag")

# build query
q = QueryBuilder(latest_table, "t")

fields_string = ["geo_value", "signal"]
fields_int = [
"time_value",
"direction",
"issue",
"lag",
"missing_value",
"missing_stderr",
"missing_sample_size",
]
fields_float = ["value", "stderr", "sample_size"]
is_compatibility = is_compatibility_mode()
if is_compatibility:
q.set_sort_order("signal", "time_value", "geo_value", "issue")
else:
# transfer also the new detail columns
fields_string.extend(["source", "geo_type", "time_type"])
q.set_sort_order(
"source",
"signal",
"time_type",
"time_value",
"geo_type",
"geo_value",
"issue",
)
q.set_fields(fields_string, fields_int, fields_float)

# basic query info
# data type of each field
# build the source, signal, time, and location (type and id) filters

q.apply_source_signal_filters("source", "signal", source_signal_sets)
q.apply_geo_filters("geo_type", "geo_value", geo_sets)
q.apply_time_filter("time_type", "time_value", time_set)

q.apply_issues_filter(history_table, issues)
q.apply_lag_filter(history_table, lag)
q.apply_as_of_filter(history_table, as_of)

def transform_row(row, proxy):
if is_compatibility or not alias_mapper or "source" not in row:
return row
row["source"] = alias_mapper(row["source"], proxy["signal"])
return row

# send query
return execute_query(
str(q),
q.params,
fields_string,
fields_int,
fields_float,
transform=transform_row,
)

0 comments on commit 2bf7e00

Please sign in to comment.