covidcast

rzats committed Jan 17, 2024
1 parent 7f240fe commit cf459bd
Showing 3 changed files with 21 additions and 16 deletions.
2 changes: 1 addition & 1 deletion integrations/acquisition/covidcast/test_db.py
@@ -23,7 +23,7 @@ def _find_matches_for_row(self, row):
       cur.execute(q)
       res = cur.fetchone()
       if res:
-        results[table] = dict(zip(cur.column_names, res))
+        results[table] = dict(zip([desc[0] for desc in cur.description], res))
       else:
         results[table] = None
     return results
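This one-line change is the template for the rest of the commit: mysql.connector cursors expose a column_names property, while MySQLdb follows PEP 249 (DB-API 2.0), where column metadata lives in cursor.description, a sequence of 7-item tuples whose first element is the column name. A minimal sketch of the pattern, assuming a MySQLdb (or any DB-API 2.0) cursor; the helper name is hypothetical:

def row_as_dict(cursor):
  # PEP 249: cursor.description holds one 7-item tuple per result column;
  # item 0 of each tuple is the column name.
  columns = [desc[0] for desc in cursor.description]
  row = cursor.fetchone()  # MySQLdb returns a plain tuple, or None when exhausted
  return dict(zip(columns, row)) if row is not None else None

Because cursor.description is part of the DB-API 2.0 standard, this spelling works under either connector, which is why the commit can swap drivers with such small diffs.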
2 changes: 0 additions & 2 deletions integrations/server/test_covidcast_endpoints.py
@@ -294,7 +294,6 @@ def test_backfill(self):
 
   def test_meta(self):
     """Request a signal from the /meta endpoint."""
-
     num_rows = 10
     rows = [CovidcastTestRow.make_default_row(time_value=2020_04_01 + i, value=i, source="fb-survey", signal="smoothed_cli") for i in range(num_rows)]
     self._insert_rows(rows)
@@ -343,7 +342,6 @@ def test_meta_restricted(self):
     # and fed by src/server/endpoints/covidcast_utils/db_sources.csv, but also surreptitiously augmened
     # by _load_data_signals() which attaches a list of signals to each source,
     # in turn fed by src/server/endpoints/covidcast_utils/db_signals.csv)
-
     # insert data from two different sources, one restricted/protected (quidel), one not
     self._insert_rows([
       CovidcastTestRow.make_default_row(source="quidel", signal="raw_pct_negative"),
33 changes: 20 additions & 13 deletions src/acquisition/covidcast/database.py
@@ -10,7 +10,7 @@
 
 # third party
 import json
-import mysql.connector
+import MySQLdb
 
 # first party
 import delphi.operations.secrets as secrets
@@ -38,16 +38,23 @@ class Database:
   # TODO: also consider that for composite key tuples, like short_comp_key and long_comp_key as used in delete_batch()
 
 
-  def connect(self, connector_impl=mysql.connector):
+  def connect(self, connector_impl=None):
     """Establish a connection to the database."""
 
     u, p = secrets.db.epi
     self._connector_impl = connector_impl
-    self._connection = self._connector_impl.connect(
-        host=secrets.db.host,
-        user=u,
-        password=p,
-        database=Database.DATABASE_NAME)
+    if connector_impl is None:
+      self._connection = MySQLdb.connect(
+          host=secrets.db.host,
+          user=u,
+          password=p,
+          database=Database.DATABASE_NAME)
+    else:
+      self._connection = self._connector_impl.connect(
+          host=secrets.db.host,
+          user=u,
+          password=p,
+          database=Database.DATABASE_NAME)
     self._cursor = self._connection.cursor()
 
   def commit(self):
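Because connector_impl now defaults to None, production code picks up MySQLdb without ever naming it, while tests can still inject a mysql.connector-compatible module. A usage sketch, assuming the repository's usual import path for Database (the mock module is hypothetical):

from delphi.epidata.acquisition.covidcast.database import Database

db = Database()
db.connect()  # default branch above: MySQLdb.connect(...)
# db.connect(connector_impl=mock_connector)  # tests inject a stand-in module
print(db.count_all_load_rows())
db.disconnect(False)  # the positional argument is the commit flag (see the next hunk)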
@@ -71,8 +78,7 @@ def disconnect(self, commit):
 
   def count_all_load_rows(self):
     self._cursor.execute(f'SELECT count(1) FROM `{self.load_table}`')
-    for (num,) in self._cursor:
-      return num
+    return self._cursor.fetchone()[0]
 
   def _reset_load_table_ai_counter(self):
     """Corrects the AUTO_INCREMENT counter in the load table.
@@ -101,7 +107,7 @@ def do_analyze(self):
       f'''ANALYZE TABLE
         signal_dim, geo_dim,
         {self.load_table}, {self.history_table}, {self.latest_table}''')
-    output = [self._cursor.column_names] + self._cursor.fetchall()
+    output = [desc[0] for desc in self._cursor.description] + list(self._cursor.fetchall())
     get_structured_logger('do_analyze').info("ANALYZE results", results=str(output))
 
   def insert_or_update_bulk(self, cc_rows):
@@ -456,8 +462,8 @@ def compute_covidcast_meta(self, table_name=None, n_threads=None):
     srcsigs = Queue() # multi-consumer threadsafe!
     sql = f'SELECT `source`, `signal` FROM `{table_name}` GROUP BY `source`, `signal` ORDER BY `source` ASC, `signal` ASC;'
     self._cursor.execute(sql)
-    for source, signal in self._cursor:
-      srcsigs.put((source, signal))
+    for res in self._cursor.fetchall():
+      srcsigs.put((res[0], res[1])) # source, signal
 
     inner_sql = f'''
       SELECT
@@ -505,8 +511,9 @@ def worker():
           logger.info("starting pair", thread=name, pair=f"({source}, {signal})")
           w_cursor.execute(inner_sql, (source, signal))
           with meta_lock:
+            # Create a dictionary of column names (from cursor.description) & values
             meta.extend(list(
-              dict(zip(w_cursor.column_names, x)) for x in w_cursor
+              dict(zip([desc[0] for desc in w_cursor.description], x)) for x in w_cursor
             ))
           srcsigs.task_done()
         except Empty:
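The worker changes keep compute_covidcast_meta's threading shape intact: (source, signal) pairs go into a multi-consumer Queue, and each worker drains it, runs the per-pair query on its own cursor, and extends the shared meta list under a lock. A self-contained sketch of that shape, with the database call replaced by a stand-in (all names here are illustrative, not the repository's):

import threading
from queue import Queue, Empty

srcsigs = Queue()
meta = []
meta_lock = threading.Lock()

def worker():
  while True:
    try:
      source, signal = srcsigs.get(block=False)
      rows = [{"source": source, "signal": signal}]  # stand-in for the per-pair query
      with meta_lock:
        meta.extend(rows)
      srcsigs.task_done()
    except Empty:
      break  # queue drained; this worker is done

for pair in [("src_a", "sig_1"), ("src_b", "sig_2")]:
  srcsigs.put(pair)
threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
  t.start()
for t in threads:
  t.join()
print(meta)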
