Skip to content

Commit 4f39a82

Browse files
dshemetov, krivard, and melange396
committed
CovidcastRow: address code review #1044
Co-authored-by: Katie Mazaitis <[email protected]>
Co-authored-by: melange396 <[email protected]>
1 parent 8568e77 commit 4f39a82

File tree

8 files changed

+110
-78
lines changed

8 files changed

+110
-78
lines changed

integrations/acquisition/covidcast/test_db.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,8 @@ def _find_matches_for_row(self, row):
3131

3232
def test_insert_or_update_with_nonempty_load_table(self):
3333
# make rows
34-
a_row = CovidcastRow.make_default_row(time_value=20200202)
35-
another_row = CovidcastRow.make_default_row(time_value=20200203, issue=20200203)
34+
a_row = CovidcastRow.make_default_row(time_value=2020_02_02)
35+
another_row = CovidcastRow.make_default_row(time_value=2020_02_03, issue=2020_02_03)
3636
# insert one
3737
self._db.insert_or_update_bulk([a_row])
3838
# put something into the load table

integrations/client/test_delphi_epidata.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -68,8 +68,8 @@ def test_covidcast(self):
6868
)
6969

7070
expected = [
71-
row_latest_issue.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields),
72-
rows[-1].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)
71+
row_latest_issue.as_api_compatibility_row_dict(),
72+
rows[-1].as_api_compatibility_row_dict()
7373
]
7474

7575
self.assertEqual(response['epidata'], expected)
@@ -88,10 +88,10 @@ def test_covidcast(self):
8888

8989
expected = [{
9090
rows[0].signal: [
91-
row_latest_issue.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields + ['signal']),
91+
row_latest_issue.as_api_compatibility_row_dict(ignore_fields=['signal']),
9292
],
9393
rows[-1].signal: [
94-
rows[-1].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields + ['signal']),
94+
rows[-1].as_api_compatibility_row_dict(ignore_fields=['signal']),
9595
],
9696
}]
9797

@@ -108,7 +108,7 @@ def test_covidcast(self):
108108
**self.params_from_row(rows[0])
109109
)
110110

111-
expected = [row_latest_issue.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)]
111+
expected = [row_latest_issue.as_api_compatibility_row_dict()]
112112

113113
# check result
114114
self.assertEqual(response_1, {
@@ -123,7 +123,7 @@ def test_covidcast(self):
123123
**self.params_from_row(rows[0], as_of=rows[1].issue)
124124
)
125125

126-
expected = [rows[1].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)]
126+
expected = [rows[1].as_api_compatibility_row_dict()]
127127

128128
# check result
129129
self.maxDiff=None
@@ -140,8 +140,8 @@ def test_covidcast(self):
140140
)
141141

142142
expected = [
143-
rows[0].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields),
144-
rows[1].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)
143+
rows[0].as_api_compatibility_row_dict(),
144+
rows[1].as_api_compatibility_row_dict()
145145
]
146146

147147
# check result
@@ -157,7 +157,7 @@ def test_covidcast(self):
157157
**self.params_from_row(rows[0], lag=2)
158158
)
159159

160-
expected = [row_latest_issue.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)]
160+
expected = [row_latest_issue.as_api_compatibility_row_dict()]
161161

162162
# check result
163163
self.assertDictEqual(response_3, {
@@ -231,7 +231,7 @@ def test_geo_value(self):
231231
self._insert_rows(rows)
232232

233233
counties = [
234-
rows[i].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for i in range(N)
234+
rows[i].as_api_compatibility_row_dict() for i in range(N)
235235
]
236236

237237
def fetch(geo):

integrations/server/test_covidcast.py

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ def test_round_trip(self):
9393
# make the request
9494
response = self.request_based_on_row(row)
9595

96-
expected = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)]
96+
expected = [row.as_api_compatibility_row_dict()]
9797

9898
self.assertEqual(response, {
9999
'result': 1,
@@ -160,13 +160,13 @@ def test_csv_format(self):
160160
**{'format':'csv'}
161161
)
162162

163-
# TODO: This is a mess because of api.php.
163+
# This is a hardcoded mess because of api.php.
164164
column_order = [
165165
"geo_value", "signal", "time_value", "direction", "issue", "lag", "missing_value",
166166
"missing_stderr", "missing_sample_size", "value", "stderr", "sample_size"
167167
]
168168
expected = (
169-
row.api_compatibility_row_df
169+
row.as_api_compatibility_row_df()
170170
.assign(direction = None)
171171
.to_csv(columns=column_order, index=False)
172172
)
@@ -183,7 +183,7 @@ def test_raw_json_format(self):
183183
# make the request
184184
response = self.request_based_on_row(row, **{'format':'json'})
185185

186-
expected = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)]
186+
expected = [row.as_api_compatibility_row_dict()]
187187

188188
# assert that the right data came back
189189
self.assertEqual(response, expected)
@@ -197,7 +197,7 @@ def test_fields(self):
197197
# limit fields
198198
response = self.request_based_on_row(row, fields='time_value,geo_value')
199199

200-
expected = row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)
200+
expected = row.as_api_compatibility_row_dict()
201201
expected_all = {
202202
'result': 1,
203203
'epidata': [{
@@ -230,7 +230,7 @@ def test_location_wildcard(self):
230230

231231
# insert placeholder data
232232
rows = self._insert_placeholder_set_two()
233-
expected = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for row in rows[:3]]
233+
expected = [row.as_api_compatibility_row_dict() for row in rows[:3]]
234234
# make the request
235235
response = self.request_based_on_row(rows[0], geo_value="*")
236236

@@ -247,7 +247,7 @@ def test_time_values_wildcard(self):
247247

248248
# insert placeholder data
249249
rows = self._insert_placeholder_set_three()
250-
expected = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for row in rows[:3]]
250+
expected = [row.as_api_compatibility_row_dict() for row in rows[:3]]
251251

252252
# make the request
253253
response = self.request_based_on_row(rows[0], time_values="*")
@@ -265,7 +265,7 @@ def test_issues_wildcard(self):
265265

266266
# insert placeholder data
267267
rows = self._insert_placeholder_set_five()
268-
expected = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for row in rows[:3]]
268+
expected = [row.as_api_compatibility_row_dict() for row in rows[:3]]
269269

270270
# make the request
271271
response = self.request_based_on_row(rows[0], issues="*")
@@ -283,7 +283,7 @@ def test_signal_wildcard(self):
283283

284284
# insert placeholder data
285285
rows = self._insert_placeholder_set_four()
286-
expected_signals = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for row in rows[:3]]
286+
expected_signals = [row.as_api_compatibility_row_dict() for row in rows[:3]]
287287

288288
# make the request
289289
response = self.request_based_on_row(rows[0], signals="*")
@@ -301,7 +301,7 @@ def test_geo_value(self):
301301

302302
# insert placeholder data
303303
rows = self._insert_placeholder_set_two()
304-
expected = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for row in rows[:3]]
304+
expected = [row.as_api_compatibility_row_dict() for row in rows[:3]]
305305

306306
def fetch(geo_value):
307307
# make the request
@@ -337,7 +337,7 @@ def test_location_timeline(self):
337337

338338
# insert placeholder data
339339
rows = self._insert_placeholder_set_three()
340-
expected_timeseries = [row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields) for row in rows[:3]]
340+
expected_timeseries = [row.as_api_compatibility_row_dict() for row in rows[:3]]
341341

342342
# make the request
343343
response = self.request_based_on_row(rows[0], time_values='20000101-20000105')
@@ -374,8 +374,7 @@ def test_nullable_columns(self):
374374

375375
# make the request
376376
response = self.request_based_on_row(row)
377-
expected = row.as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)
378-
# expected.update(stderr=None, sample_size=None)
377+
expected = row.as_api_compatibility_row_dict()
379378

380379
# assert that the right data came back
381380
self.assertEqual(response, {
@@ -395,8 +394,8 @@ def test_temporal_partitioning(self):
395394
self._insert_rows(rows)
396395

397396
# make the request
398-
response = self.request_based_on_row(rows[1], time_values="20000101-30010201")
399-
expected = [rows[1].as_dict(ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)]
397+
response = self.request_based_on_row(rows[1], time_values="*")
398+
expected = [rows[1].as_api_compatibility_row_dict()]
400399

401400
# assert that the right data came back
402401
self.assertEqual(response, {

integrations/server/test_covidcast_endpoints.py

Lines changed: 27 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params):
3939
# make the request
4040
if is_compatibility:
4141
url = BASE_URL_OLD
42+
# only set endpoint if it's not already set
43+
# only set endpoint if it's not already set
4244
params.setdefault("endpoint", "covidcast")
4345
if params.get("source"):
4446
params.setdefault("data_source", params.get("source"))
@@ -49,7 +51,10 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params):
4951
return response.json()
5052

5153
def _diff_rows(self, rows: Sequence[float]):
52-
return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])]
54+
return [
55+
float(x - y) if x is not None and y is not None else None
56+
for x, y in zip(rows[1:], rows[:-1])
57+
]
5358

5459
def _smooth_rows(self, rows: Sequence[float]):
5560
return [
@@ -59,7 +64,7 @@ def _smooth_rows(self, rows: Sequence[float]):
5964

6065
def test_basic(self):
6166
"""Request a signal from the / endpoint."""
62-
rows = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i) for i in range(10)]
67+
rows = [CovidcastRow.make_default_row(time_value=2020_04_01 + i, value=i) for i in range(10)]
6368
first = rows[0]
6469
self._insert_rows(rows)
6570

@@ -68,12 +73,12 @@ def test_basic(self):
6873
self.assertEqual(out["result"], -1)
6974

7075
with self.subTest("simple"):
71-
out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*")
76+
out = self._fetch("/", signal=first.signal_pair(), geo=first.geo_pair(), time="day:*")
7277
self.assertEqual(len(out["epidata"]), len(rows))
7378

7479
def test_compatibility(self):
7580
"""Request at the /api.php endpoint."""
76-
rows = [CovidcastRow.make_default_row(source="src", signal="sig", time_value=20200401 + i, value=i) for i in range(10)]
81+
rows = [CovidcastRow.make_default_row(source="src", signal="sig", time_value=2020_04_01 + i, value=i) for i in range(10)]
7782
first = rows[0]
7883
self._insert_rows(rows)
7984

@@ -82,20 +87,20 @@ def test_compatibility(self):
8287
self.assertEqual(out["result"], -1)
8388

8489
with self.subTest("simple"):
85-
out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*", is_compatibility=True)
90+
out = self._fetch("/", signal=first.signal_pair(), geo=first.geo_pair(), time="day:*", is_compatibility=True)
8691
self.assertEqual(len(out["epidata"]), len(rows))
8792

8893
def test_trend(self):
8994
"""Request a signal from the /trend endpoint."""
9095

9196
num_rows = 30
92-
rows = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i) for i in range(num_rows)]
97+
rows = [CovidcastRow.make_default_row(time_value=2020_04_01 + i, value=i) for i in range(num_rows)]
9398
first = rows[0]
9499
last = rows[-1]
95100
ref = rows[num_rows // 2]
96101
self._insert_rows(rows)
97102

98-
out = self._fetch("/trend", signal=first.signal_pair, geo=first.geo_pair, date=last.time_value, window="20200401-20201212", basis=ref.time_value)
103+
out = self._fetch("/trend", signal=first.signal_pair(), geo=first.geo_pair(), date=last.time_value, window="20200401-20201212", basis=ref.time_value)
99104

100105

101106
self.assertEqual(out["result"], 1)
@@ -125,12 +130,12 @@ def test_trendseries(self):
125130
"""Request a signal from the /trendseries endpoint."""
126131

127132
num_rows = 3
128-
rows = [CovidcastRow.make_default_row(time_value=20200401 + i, value=num_rows - i) for i in range(num_rows)]
133+
rows = [CovidcastRow.make_default_row(time_value=2020_04_01 + i, value=num_rows - i) for i in range(num_rows)]
129134
first = rows[0]
130135
last = rows[-1]
131136
self._insert_rows(rows)
132137

133-
out = self._fetch("/trendseries", signal=first.signal_pair, geo=first.geo_pair, date=last.time_value, window="20200401-20200410", basis=1)
138+
out = self._fetch("/trendseries", signal=first.signal_pair(), geo=first.geo_pair(), date=last.time_value, window="20200401-20200410", basis=1)
134139

135140
self.assertEqual(out["result"], 1)
136141
self.assertEqual(len(out["epidata"]), 3)
@@ -199,7 +204,7 @@ def test_correlation(self):
199204
self._insert_rows(other_rows)
200205
max_lag = 3
201206

202-
out = self._fetch("/correlation", reference=first.signal_pair, others=other.signal_pair, geo=first.geo_pair, window="20200401-20201212", lag=max_lag)
207+
out = self._fetch("/correlation", reference=first.signal_pair(), others=other.signal_pair(), geo=first.geo_pair(), window="20200401-20201212", lag=max_lag)
203208
self.assertEqual(out["result"], 1)
204209
df = pd.DataFrame(out["epidata"])
205210
self.assertEqual(len(df), max_lag * 2 + 1) # -...0...+
@@ -217,26 +222,27 @@ def test_correlation(self):
217222
def test_csv(self):
218223
"""Request a signal from the /csv endpoint."""
219224

220-
rows = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i) for i in range(10)]
225+
rows = [CovidcastRow.make_default_row(time_value=2020_04_01 + i, value=i) for i in range(10)]
221226
first = rows[0]
222227
self._insert_rows(rows)
223228

224229
response = requests.get(
225230
f"{BASE_URL}/csv",
226-
params=dict(signal=first.signal_pair, start_day="2020-04-01", end_day="2020-12-12", geo_type=first.geo_type),
231+
params=dict(signal=first.signal_pair(), start_day="2020-04-01", end_day="2020-12-12", geo_type=first.geo_type),
227232
)
228233

229234
def test_backfill(self):
230235
"""Request a signal from the /backfill endpoint."""
231236

237+
TEST_DATE_VALUE = 2020_04_01
232238
num_rows = 10
233-
issue_0 = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i, sample_size=1, lag=0, issue=20200401 + i) for i in range(num_rows)]
234-
issue_1 = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i + 1, sample_size=2, lag=1, issue=20200401 + i + 1) for i in range(num_rows)]
235-
last_issue = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i + 2, sample_size=3, lag=2, issue=20200401 + i + 2) for i in range(num_rows)] # <-- the latest issues
239+
issue_0 = [CovidcastRow.make_default_row(time_value=TEST_DATE_VALUE + i, value=i, sample_size=1, lag=0, issue=TEST_DATE_VALUE + i) for i in range(num_rows)]
240+
issue_1 = [CovidcastRow.make_default_row(time_value=TEST_DATE_VALUE + i, value=i + 1, sample_size=2, lag=1, issue=TEST_DATE_VALUE + i + 1) for i in range(num_rows)]
241+
last_issue = [CovidcastRow.make_default_row(time_value=TEST_DATE_VALUE + i, value=i + 2, sample_size=3, lag=2, issue=TEST_DATE_VALUE + i + 2) for i in range(num_rows)] # <-- the latest issues
236242
self._insert_rows([*issue_0, *issue_1, *last_issue])
237243
first = issue_0[0]
238244

239-
out = self._fetch("/backfill", signal=first.signal_pair, geo=first.geo_pair, time="day:20200401-20201212", anchor_lag=3)
245+
out = self._fetch("/backfill", signal=first.signal_pair(), geo=first.geo_pair(), time="day:20200401-20201212", anchor_lag=3)
240246
self.assertEqual(out["result"], 1)
241247
df = pd.DataFrame(out["epidata"])
242248
self.assertEqual(len(df), 3 * num_rows) # num issues
@@ -258,7 +264,7 @@ def test_meta(self):
258264
"""Request a signal from the /meta endpoint."""
259265

260266
num_rows = 10
261-
rows = [CovidcastRow.make_default_row(time_value=20200401 + i, value=i, source="fb-survey", signal="smoothed_cli") for i in range(num_rows)]
267+
rows = [CovidcastRow.make_default_row(time_value=2020_04_01 + i, value=i, source="fb-survey", signal="smoothed_cli") for i in range(num_rows)]
262268
self._insert_rows(rows)
263269
first = rows[0]
264270
last = rows[-1]
@@ -298,23 +304,23 @@ def test_coverage(self):
298304
"""Request a signal from the /coverage endpoint."""
299305

300306
num_geos_per_date = [10, 20, 30, 40, 44]
301-
dates = [20200401 + i for i in range(len(num_geos_per_date))]
307+
dates = [2020_04_01 + i for i in range(len(num_geos_per_date))]
302308
rows = [CovidcastRow.make_default_row(time_value=dates[i], value=i, geo_value=str(geo_value)) for i, num_geo in enumerate(num_geos_per_date) for geo_value in range(num_geo)]
303309
self._insert_rows(rows)
304310
first = rows[0]
305311

306312
with self.subTest("default"):
307-
out = self._fetch("/coverage", signal=first.signal_pair, geo_type=first.geo_type, latest=dates[-1], format="json")
313+
out = self._fetch("/coverage", signal=first.signal_pair(), geo_type=first.geo_type, latest=dates[-1], format="json")
308314
self.assertEqual(len(out), len(num_geos_per_date))
309315
self.assertEqual([o["time_value"] for o in out], dates)
310316
self.assertEqual([o["count"] for o in out], num_geos_per_date)
311317

312318
with self.subTest("specify window"):
313-
out = self._fetch("/coverage", signal=first.signal_pair, geo_type=first.geo_type, window=f"{dates[0]}-{dates[1]}", format="json")
319+
out = self._fetch("/coverage", signal=first.signal_pair(), geo_type=first.geo_type, window=f"{dates[0]}-{dates[1]}", format="json")
314320
self.assertEqual(len(out), 2)
315321
self.assertEqual([o["time_value"] for o in out], dates[:2])
316322
self.assertEqual([o["count"] for o in out], num_geos_per_date[:2])
317323

318324
with self.subTest("invalid geo_type"):
319-
out = self._fetch("/coverage", signal=first.signal_pair, geo_type="doesnt_exist", format="json")
325+
out = self._fetch("/coverage", signal=first.signal_pair(), geo_type="doesnt_exist", format="json")
320326
self.assertEqual(len(out), 0)

0 commit comments

Comments
 (0)