Added geo-based restricted endpoint. Resulting data is based on the user's role, so the user can see only role-allowed data (county based).
1 parent f7da659 · commit 2bf7e00
Showing 3 changed files with 279 additions and 0 deletions.
@@ -0,0 +1,118 @@
import requests

# third party
import mysql.connector

# first party
from delphi.epidata.acquisition.covidcast.test_utils import (
    CovidcastBase,
    CovidcastTestRow,
)


class DifferentiatedAccessTests(CovidcastBase):
    def localSetUp(self):
        """Perform per-test setup"""
        self._db._cursor.execute(
            'update covidcast_meta_cache set timestamp = 0, epidata = "[]"'
        )

    def setUp(self):
        # connect to the `epidata` database

        super().setUp()

        self.maxDiff = None

        cnx = mysql.connector.connect(
            user="user",
            password="pass",
            host="delphi_database_epidata",
            database="epidata",
        )

        cur = cnx.cursor()

        cur.execute("DELETE FROM `api_user`")
        cur.execute("TRUNCATE TABLE `user_role`")
        cur.execute("TRUNCATE TABLE `user_role_link`")

        cur.execute(
            'INSERT INTO `api_user`(`api_key`, `email`) VALUES ("api_key", "[email protected]")'
        )
        cur.execute(
            'INSERT INTO `api_user`(`api_key`, `email`) VALUES("ny_key", "[email protected]")'
        )
        cur.execute('INSERT INTO `user_role`(`name`) VALUES("state:ny")')
        cur.execute(
            'INSERT INTO `user_role_link`(`user_id`, `role_id`) SELECT `api_user`.`id`, 1 FROM `api_user` WHERE `api_key` = "ny_key"'
        )

        cnx.commit()
        cur.close()
        cnx.close()

    def request_based_on_row(self, row: CovidcastTestRow, **kwargs):
        params = self.params_from_row(row, endpoint="differentiated_access", **kwargs)
        # use local instance of the Epidata API

        response = requests.get(
            "http://delphi_web_epidata/epidata/api.php", params=params
        )
        response.raise_for_status()
        return response.json()

    def _insert_placeholder_restricted_geo(self):
        geo_values = ["36029", "36047", "36097", "36103", "36057", "36041", "36033"]
        rows = [
            CovidcastTestRow.make_default_row(
                source="restricted-source",
                geo_type="county",
                geo_value=geo_values[i],
                time_value=2000_01_01 + i,
                value=i * 1.0,
                stderr=i * 10.0,
                sample_size=i * 100.0,
                issue=2000_01_03,
                lag=2 - i,
            )
            for i in [1, 2, 3]
        ] + [
            # time value intended to overlap with the time values above, with disjoint geo values
            CovidcastTestRow.make_default_row(
                source="restricted-source",
                geo_type="county",
                geo_value=geo_values[i],
                time_value=2000_01_01 + i - 3,
                value=i * 1.0,
                stderr=i * 10.0,
                sample_size=i * 100.0,
                issue=2000_01_03,
                lag=5 - i,
            )
            for i in [4, 5, 6]
        ]
        self._insert_rows(rows)
        return rows

    def test_restricted_geo_ny_role(self):
        # insert placeholder data
        rows = self._insert_placeholder_restricted_geo()

        # make request
        response = self.request_based_on_row(rows[0], token="ny_key")
        expected = {
            "result": 1,
            "epidata": [rows[0].as_api_compatibility_row_dict()],
            "message": "success",
        }
        self.assertEqual(response, expected)

    def test_restricted_geo_default_role(self):
        # insert placeholder data
        rows = self._insert_placeholder_restricted_geo()

        # make request
        response = self.request_based_on_row(rows[0], token="api_key")
        expected = {"result": -2, "message": "no results"}
        self.assertEqual(response, expected)
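The placeholder rows above use NY county FIPS codes, which is why the `ny_key` user (holding the `state:ny` role) sees results while the default key gets none. Below is a minimal sketch, not part of this commit, of how such a role expands into a set of allowed county FIPS codes; it mirrors the `GeoMapper.get_geos_within` call used by `serve_geo_restricted` in the endpoint file that follows.

    # Illustrative only: mirrors the role-to-county expansion performed by the endpoint below.
    from delphi_utils import GeoMapper

    geomapper = GeoMapper()

    role_name = "state:ny"              # role granted to the "ny_key" user in setUp()
    state = role_name.split(":", 1)[1]  # -> "ny"

    # Same call as serve_geo_restricted(): all county FIPS codes within the state.
    ny_counties = geomapper.get_geos_within(state, "county", "state")

    # The test data uses NY county FIPS codes such as "36047" (Kings County),
    # so they should be members of this set.
    assert "36047" in ny_counties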
@@ -0,0 +1,159 @@
from flask import Blueprint
from werkzeug.exceptions import Unauthorized

from .._common import is_compatibility_mode
from .._params import (
    extract_date,
    extract_dates,
    extract_integer,
    parse_geo_sets,
    parse_source_signal_sets,
    parse_time_set,
)
from .._query import QueryBuilder, execute_query
from .._security import current_user, sources_protected_by_roles
from .covidcast_utils.model import create_source_signal_alias_mapper
from delphi_utils import GeoMapper
from delphi.epidata.common.logger import get_structured_logger

# first argument is the endpoint name
bp = Blueprint("differentiated_access", __name__)
alias = None

latest_table = "epimetric_latest_v"
history_table = "epimetric_full_v"


def restrict_by_roles(source_signal_sets):
    # takes a list of SourceSignalSet objects
    # and returns only those from the list
    # that the current user is permitted to access.
    user = current_user
    allowed_source_signal_sets = []
    for src_sig_set in source_signal_sets:
        src = src_sig_set.source
        if src in sources_protected_by_roles:
            role = sources_protected_by_roles[src]
            if user and user.has_role(role):
                allowed_source_signal_sets.append(src_sig_set)
            else:
                # protected src and user does not have permission => leave it out of the srcsig sets
                get_structured_logger("covcast_endpt").warning(
                    "non-authZd request for restricted 'source'",
                    api_key=(user and user.api_key),
                    src=src,
                )
        else:
            allowed_source_signal_sets.append(src_sig_set)
    return allowed_source_signal_sets


def serve_geo_restricted(geo_sets: set):
    geomapper = GeoMapper()
    if not current_user:
        raise Unauthorized("User is not authenticated.")
    allowed_counties = set()
    # Build the set of allowed counties from the user's roles.
    # Example: the role 'state:ny' gives the user access to all counties in NY state.
    for role in current_user.roles:
        if role.name.startswith("state:"):
            state = role.name.split(":", 1)[1]
            counties_in_state = geomapper.get_geos_within(state, "county", "state")
            allowed_counties.update(counties_in_state)

    for geo_set in geo_sets:
        # Reject if `geo_type` is not county.
        if geo_set.geo_type != "county":
            raise Unauthorized("Only `county` geo_type is allowed")
        # If `geo_value` is '*', query only the counties the user has access to.
        if geo_set.geo_values is True:
            geo_set.geo_values = list(allowed_counties)
        # There is no need to check whether `geo_set.geo_values` (the requested counties) is a
        # superset of `allowed_counties`: we return the intersection of the two sets, so a
        # narrower request yields only the requested counties, and a broader request yields
        # only the counties the user is allowed to see.

        # elif set(geo_set.geo_values).issuperset(allowed_counties):
        #     geo_set.geo_values = list(set(geo_set.geo_values).intersection(allowed_counties))

        # If the user requested more counties than they may query, show only the allowed ones.
        else:
            geo_set.geo_values = list(
                set(geo_set.geo_values).intersection(allowed_counties)
            )
    return geo_sets


@bp.route("/", methods=("GET", "POST"))
def handle():
    source_signal_sets = parse_source_signal_sets()
    source_signal_sets = restrict_by_roles(source_signal_sets)
    source_signal_sets, alias_mapper = create_source_signal_alias_mapper(
        source_signal_sets
    )
    time_set = parse_time_set()
    geo_sets = serve_geo_restricted(parse_geo_sets())

    as_of = extract_date("as_of")
    issues = extract_dates("issues")
    lag = extract_integer("lag")

    # build query
    q = QueryBuilder(latest_table, "t")

    fields_string = ["geo_value", "signal"]
    fields_int = [
        "time_value",
        "direction",
        "issue",
        "lag",
        "missing_value",
        "missing_stderr",
        "missing_sample_size",
    ]
    fields_float = ["value", "stderr", "sample_size"]
    is_compatibility = is_compatibility_mode()
    if is_compatibility:
        q.set_sort_order("signal", "time_value", "geo_value", "issue")
    else:
        # also transfer the new detail columns
        fields_string.extend(["source", "geo_type", "time_type"])
        q.set_sort_order(
            "source",
            "signal",
            "time_type",
            "time_value",
            "geo_type",
            "geo_value",
            "issue",
        )
    q.set_fields(fields_string, fields_int, fields_float)

    # basic query info
    # data type of each field
    # build the source, signal, time, and location (type and id) filters

    q.apply_source_signal_filters("source", "signal", source_signal_sets)
    q.apply_geo_filters("geo_type", "geo_value", geo_sets)
    q.apply_time_filter("time_type", "time_value", time_set)

    q.apply_issues_filter(history_table, issues)
    q.apply_lag_filter(history_table, lag)
    q.apply_as_of_filter(history_table, as_of)

    def transform_row(row, proxy):
        if is_compatibility or not alias_mapper or "source" not in row:
            return row
        row["source"] = alias_mapper(row["source"], proxy["signal"])
        return row

    # send query
    return execute_query(
        str(q),
        q.params,
        fields_string,
        fields_int,
        fields_float,
        transform=transform_row,
    )
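For reference, a hedged client-side sketch of calling the new endpoint through the local API gateway, in the same style as the integration test above. The parameter names mirror the usual covidcast endpoint conventions and are illustrative assumptions, not the definitive request format.

    # Illustrative only: a request from a user whose API key carries the "state:ny" role.
    import requests

    params = {
        "endpoint": "differentiated_access",  # routed via api.php, as in request_based_on_row()
        "data_source": "restricted-source",   # assumed param names, mirroring covidcast conventions
        "signals": "sig",
        "time_type": "day",
        "time_values": "20000102",
        "geo_type": "county",                 # the endpoint rejects any other geo_type
        "geo_value": "*",                     # expanded server-side to the caller's allowed counties
        "token": "ny_key",                    # key linked to the "state:ny" role in the test setup
    }

    response = requests.get("http://delphi_web_epidata/epidata/api.php", params=params)
    response.raise_for_status()
    print(response.json())  # e.g. {"result": 1, "epidata": [...], "message": "success"}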