@@ -39,6 +39,8 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params):
39
39
# make the request
40
40
if is_compatibility :
41
41
url = BASE_URL_OLD
42
+ # only set endpoint if it's not already set
43
+ # only set endpoint if it's not already set
42
44
params .setdefault ("endpoint" , "covidcast" )
43
45
if params .get ("source" ):
44
46
params .setdefault ("data_source" , params .get ("source" ))
@@ -49,7 +51,10 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params):
49
51
return response .json ()
50
52
51
53
def _diff_rows (self , rows : Sequence [float ]):
52
- return [float (x - y ) if x is not None and y is not None else None for x , y in zip (rows [1 :], rows [:- 1 ])]
54
+ return [
55
+ float (x - y ) if x is not None and y is not None else None
56
+ for x , y in zip (rows [1 :], rows [:- 1 ])
57
+ ]
53
58
54
59
def _smooth_rows (self , rows : Sequence [float ]):
55
60
return [
@@ -59,7 +64,7 @@ def _smooth_rows(self, rows: Sequence[float]):
59
64
60
65
def test_basic (self ):
61
66
"""Request a signal from the / endpoint."""
62
- rows = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i ) for i in range (10 )]
67
+ rows = [CovidcastRow .make_default_row (time_value = 2020_04_01 + i , value = i ) for i in range (10 )]
63
68
first = rows [0 ]
64
69
self ._insert_rows (rows )
65
70
@@ -68,12 +73,12 @@ def test_basic(self):
68
73
self .assertEqual (out ["result" ], - 1 )
69
74
70
75
with self .subTest ("simple" ):
71
- out = self ._fetch ("/" , signal = first .signal_pair , geo = first .geo_pair , time = "day:*" )
76
+ out = self ._fetch ("/" , signal = first .signal_pair () , geo = first .geo_pair () , time = "day:*" )
72
77
self .assertEqual (len (out ["epidata" ]), len (rows ))
73
78
74
79
def test_compatibility (self ):
75
80
"""Request at the /api.php endpoint."""
76
- rows = [CovidcastRow .make_default_row (source = "src" , signal = "sig" , time_value = 20200401 + i , value = i ) for i in range (10 )]
81
+ rows = [CovidcastRow .make_default_row (source = "src" , signal = "sig" , time_value = 2020_04_01 + i , value = i ) for i in range (10 )]
77
82
first = rows [0 ]
78
83
self ._insert_rows (rows )
79
84
@@ -82,20 +87,20 @@ def test_compatibility(self):
82
87
self .assertEqual (out ["result" ], - 1 )
83
88
84
89
with self .subTest ("simple" ):
85
- out = self ._fetch ("/" , signal = first .signal_pair , geo = first .geo_pair , time = "day:*" , is_compatibility = True )
90
+ out = self ._fetch ("/" , signal = first .signal_pair () , geo = first .geo_pair () , time = "day:*" , is_compatibility = True )
86
91
self .assertEqual (len (out ["epidata" ]), len (rows ))
87
92
88
93
def test_trend (self ):
89
94
"""Request a signal from the /trend endpoint."""
90
95
91
96
num_rows = 30
92
- rows = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i ) for i in range (num_rows )]
97
+ rows = [CovidcastRow .make_default_row (time_value = 2020_04_01 + i , value = i ) for i in range (num_rows )]
93
98
first = rows [0 ]
94
99
last = rows [- 1 ]
95
100
ref = rows [num_rows // 2 ]
96
101
self ._insert_rows (rows )
97
102
98
- out = self ._fetch ("/trend" , signal = first .signal_pair , geo = first .geo_pair , date = last .time_value , window = "20200401-20201212" , basis = ref .time_value )
103
+ out = self ._fetch ("/trend" , signal = first .signal_pair () , geo = first .geo_pair () , date = last .time_value , window = "20200401-20201212" , basis = ref .time_value )
99
104
100
105
101
106
self .assertEqual (out ["result" ], 1 )
@@ -125,12 +130,12 @@ def test_trendseries(self):
125
130
"""Request a signal from the /trendseries endpoint."""
126
131
127
132
num_rows = 3
128
- rows = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = num_rows - i ) for i in range (num_rows )]
133
+ rows = [CovidcastRow .make_default_row (time_value = 2020_04_01 + i , value = num_rows - i ) for i in range (num_rows )]
129
134
first = rows [0 ]
130
135
last = rows [- 1 ]
131
136
self ._insert_rows (rows )
132
137
133
- out = self ._fetch ("/trendseries" , signal = first .signal_pair , geo = first .geo_pair , date = last .time_value , window = "20200401-20200410" , basis = 1 )
138
+ out = self ._fetch ("/trendseries" , signal = first .signal_pair () , geo = first .geo_pair () , date = last .time_value , window = "20200401-20200410" , basis = 1 )
134
139
135
140
self .assertEqual (out ["result" ], 1 )
136
141
self .assertEqual (len (out ["epidata" ]), 3 )
@@ -199,7 +204,7 @@ def test_correlation(self):
199
204
self ._insert_rows (other_rows )
200
205
max_lag = 3
201
206
202
- out = self ._fetch ("/correlation" , reference = first .signal_pair , others = other .signal_pair , geo = first .geo_pair , window = "20200401-20201212" , lag = max_lag )
207
+ out = self ._fetch ("/correlation" , reference = first .signal_pair () , others = other .signal_pair () , geo = first .geo_pair () , window = "20200401-20201212" , lag = max_lag )
203
208
self .assertEqual (out ["result" ], 1 )
204
209
df = pd .DataFrame (out ["epidata" ])
205
210
self .assertEqual (len (df ), max_lag * 2 + 1 ) # -...0...+
@@ -217,26 +222,27 @@ def test_correlation(self):
217
222
def test_csv (self ):
218
223
"""Request a signal from the /csv endpoint."""
219
224
220
- rows = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i ) for i in range (10 )]
225
+ rows = [CovidcastRow .make_default_row (time_value = 2020_04_01 + i , value = i ) for i in range (10 )]
221
226
first = rows [0 ]
222
227
self ._insert_rows (rows )
223
228
224
229
response = requests .get (
225
230
f"{ BASE_URL } /csv" ,
226
- params = dict (signal = first .signal_pair , start_day = "2020-04-01" , end_day = "2020-12-12" , geo_type = first .geo_type ),
231
+ params = dict (signal = first .signal_pair () , start_day = "2020-04-01" , end_day = "2020-12-12" , geo_type = first .geo_type ),
227
232
)
228
233
229
234
def test_backfill (self ):
230
235
"""Request a signal from the /backfill endpoint."""
231
236
237
+ TEST_DATE_VALUE = 2020_04_01
232
238
num_rows = 10
233
- issue_0 = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i , sample_size = 1 , lag = 0 , issue = 20200401 + i ) for i in range (num_rows )]
234
- issue_1 = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i + 1 , sample_size = 2 , lag = 1 , issue = 20200401 + i + 1 ) for i in range (num_rows )]
235
- last_issue = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i + 2 , sample_size = 3 , lag = 2 , issue = 20200401 + i + 2 ) for i in range (num_rows )] # <-- the latest issues
239
+ issue_0 = [CovidcastRow .make_default_row (time_value = TEST_DATE_VALUE + i , value = i , sample_size = 1 , lag = 0 , issue = TEST_DATE_VALUE + i ) for i in range (num_rows )]
240
+ issue_1 = [CovidcastRow .make_default_row (time_value = TEST_DATE_VALUE + i , value = i + 1 , sample_size = 2 , lag = 1 , issue = TEST_DATE_VALUE + i + 1 ) for i in range (num_rows )]
241
+ last_issue = [CovidcastRow .make_default_row (time_value = TEST_DATE_VALUE + i , value = i + 2 , sample_size = 3 , lag = 2 , issue = TEST_DATE_VALUE + i + 2 ) for i in range (num_rows )] # <-- the latest issues
236
242
self ._insert_rows ([* issue_0 , * issue_1 , * last_issue ])
237
243
first = issue_0 [0 ]
238
244
239
- out = self ._fetch ("/backfill" , signal = first .signal_pair , geo = first .geo_pair , time = "day:20200401-20201212" , anchor_lag = 3 )
245
+ out = self ._fetch ("/backfill" , signal = first .signal_pair () , geo = first .geo_pair () , time = "day:20200401-20201212" , anchor_lag = 3 )
240
246
self .assertEqual (out ["result" ], 1 )
241
247
df = pd .DataFrame (out ["epidata" ])
242
248
self .assertEqual (len (df ), 3 * num_rows ) # num issues
@@ -258,7 +264,7 @@ def test_meta(self):
258
264
"""Request a signal from the /meta endpoint."""
259
265
260
266
num_rows = 10
261
- rows = [CovidcastRow .make_default_row (time_value = 20200401 + i , value = i , source = "fb-survey" , signal = "smoothed_cli" ) for i in range (num_rows )]
267
+ rows = [CovidcastRow .make_default_row (time_value = 2020_04_01 + i , value = i , source = "fb-survey" , signal = "smoothed_cli" ) for i in range (num_rows )]
262
268
self ._insert_rows (rows )
263
269
first = rows [0 ]
264
270
last = rows [- 1 ]
@@ -298,23 +304,23 @@ def test_coverage(self):
298
304
"""Request a signal from the /coverage endpoint."""
299
305
300
306
num_geos_per_date = [10 , 20 , 30 , 40 , 44 ]
301
- dates = [20200401 + i for i in range (len (num_geos_per_date ))]
307
+ dates = [2020_04_01 + i for i in range (len (num_geos_per_date ))]
302
308
rows = [CovidcastRow .make_default_row (time_value = dates [i ], value = i , geo_value = str (geo_value )) for i , num_geo in enumerate (num_geos_per_date ) for geo_value in range (num_geo )]
303
309
self ._insert_rows (rows )
304
310
first = rows [0 ]
305
311
306
312
with self .subTest ("default" ):
307
- out = self ._fetch ("/coverage" , signal = first .signal_pair , geo_type = first .geo_type , latest = dates [- 1 ], format = "json" )
313
+ out = self ._fetch ("/coverage" , signal = first .signal_pair () , geo_type = first .geo_type , latest = dates [- 1 ], format = "json" )
308
314
self .assertEqual (len (out ), len (num_geos_per_date ))
309
315
self .assertEqual ([o ["time_value" ] for o in out ], dates )
310
316
self .assertEqual ([o ["count" ] for o in out ], num_geos_per_date )
311
317
312
318
with self .subTest ("specify window" ):
313
- out = self ._fetch ("/coverage" , signal = first .signal_pair , geo_type = first .geo_type , window = f"{ dates [0 ]} -{ dates [1 ]} " , format = "json" )
319
+ out = self ._fetch ("/coverage" , signal = first .signal_pair () , geo_type = first .geo_type , window = f"{ dates [0 ]} -{ dates [1 ]} " , format = "json" )
314
320
self .assertEqual (len (out ), 2 )
315
321
self .assertEqual ([o ["time_value" ] for o in out ], dates [:2 ])
316
322
self .assertEqual ([o ["count" ] for o in out ], num_geos_per_date [:2 ])
317
323
318
324
with self .subTest ("invalid geo_type" ):
319
- out = self ._fetch ("/coverage" , signal = first .signal_pair , geo_type = "doesnt_exist" , format = "json" )
325
+ out = self ._fetch ("/coverage" , signal = first .signal_pair () , geo_type = "doesnt_exist" , format = "json" )
320
326
self .assertEqual (len (out ), 0 )
0 commit comments