Skip to content

Commit 1ae00c6

Browse files
authored
PERF: use bisect_right_i8 in vectorized (#46341)
1 parent 7143c44 commit 1ae00c6

File tree

1 file changed

+62
-73
lines changed

1 file changed

+62
-73
lines changed

pandas/_libs/tslibs/vectorized.pyx

Lines changed: 62 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@ from cpython.datetime cimport (
99

1010
import numpy as np
1111

12+
cimport numpy as cnp
1213
from numpy cimport (
1314
int64_t,
1415
intp_t,
1516
ndarray,
1617
)
1718

19+
cnp.import_array()
20+
1821
from .conversion cimport normalize_i8_stamp
1922

2023
from .dtypes import Resolution
@@ -35,52 +38,13 @@ from .timezones cimport (
3538
is_tzlocal,
3639
is_utc,
3740
)
38-
from .tzconversion cimport tz_convert_utc_to_tzlocal
41+
from .tzconversion cimport (
42+
bisect_right_i8,
43+
tz_convert_utc_to_tzlocal,
44+
)
3945

4046
# -------------------------------------------------------------------------
4147

42-
cdef inline object create_datetime_from_ts(
43-
int64_t value,
44-
npy_datetimestruct dts,
45-
tzinfo tz,
46-
object freq,
47-
bint fold,
48-
):
49-
"""
50-
Convenience routine to construct a datetime.datetime from its parts.
51-
"""
52-
return datetime(
53-
dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
54-
tz, fold=fold,
55-
)
56-
57-
58-
cdef inline object create_date_from_ts(
59-
int64_t value,
60-
npy_datetimestruct dts,
61-
tzinfo tz,
62-
object freq,
63-
bint fold
64-
):
65-
"""
66-
Convenience routine to construct a datetime.date from its parts.
67-
"""
68-
# GH#25057 add fold argument to match other func_create signatures
69-
return date(dts.year, dts.month, dts.day)
70-
71-
72-
cdef inline object create_time_from_ts(
73-
int64_t value,
74-
npy_datetimestruct dts,
75-
tzinfo tz,
76-
object freq,
77-
bint fold
78-
):
79-
"""
80-
Convenience routine to construct a datetime.time from its parts.
81-
"""
82-
return time(dts.hour, dts.min, dts.sec, dts.us, tz, fold=fold)
83-
8448

8549
@cython.wraparound(False)
8650
@cython.boundscheck(False)
@@ -119,29 +83,29 @@ def ints_to_pydatetime(
11983
ndarray[object] of type specified by box
12084
"""
12185
cdef:
122-
Py_ssize_t i, n = len(stamps)
86+
Py_ssize_t i, ntrans =- 1, n = len(stamps)
12387
ndarray[int64_t] trans
12488
int64_t[::1] deltas
125-
intp_t[:] pos
89+
int64_t* tdata = NULL
90+
intp_t pos
12691
npy_datetimestruct dts
12792
object dt, new_tz
12893
str typ
12994
int64_t value, local_val, delta = NPY_NAT # dummy for delta
13095
ndarray[object] result = np.empty(n, dtype=object)
131-
object (*func_create)(int64_t, npy_datetimestruct, tzinfo, object, bint)
13296
bint use_utc = False, use_tzlocal = False, use_fixed = False
13397
bint use_pytz = False
98+
bint use_date = False, use_time = False, use_ts = False, use_pydt = False
13499

135100
if box == "date":
136101
assert (tz is None), "tz should be None when converting to date"
137-
138-
func_create = create_date_from_ts
102+
use_date = True
139103
elif box == "timestamp":
140-
func_create = create_timestamp_from_ts
104+
use_ts = True
141105
elif box == "time":
142-
func_create = create_time_from_ts
106+
use_time = True
143107
elif box == "datetime":
144-
func_create = create_datetime_from_ts
108+
use_pydt = True
145109
else:
146110
raise ValueError(
147111
"box must be one of 'datetime', 'date', 'time' or 'timestamp'"
@@ -153,12 +117,13 @@ def ints_to_pydatetime(
153117
use_tzlocal = True
154118
else:
155119
trans, deltas, typ = get_dst_info(tz)
120+
ntrans = trans.shape[0]
156121
if typ not in ["pytz", "dateutil"]:
157122
# static/fixed; in this case we know that len(deltas) == 1
158123
use_fixed = True
159124
delta = deltas[0]
160125
else:
161-
pos = trans.searchsorted(stamps, side="right") - 1
126+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
162127
use_pytz = typ == "pytz"
163128

164129
for i in range(n):
@@ -176,14 +141,26 @@ def ints_to_pydatetime(
176141
elif use_fixed:
177142
local_val = value + delta
178143
else:
179-
local_val = value + deltas[pos[i]]
144+
pos = bisect_right_i8(tdata, value, ntrans) - 1
145+
local_val = value + deltas[pos]
180146

181-
if use_pytz:
182-
# find right representation of dst etc in pytz timezone
183-
new_tz = tz._tzinfos[tz._transition_info[pos[i]]]
147+
if use_pytz:
148+
# find right representation of dst etc in pytz timezone
149+
new_tz = tz._tzinfos[tz._transition_info[pos]]
184150

185151
dt64_to_dtstruct(local_val, &dts)
186-
result[i] = func_create(value, dts, new_tz, freq, fold)
152+
153+
if use_ts:
154+
result[i] = create_timestamp_from_ts(value, dts, new_tz, freq, fold)
155+
elif use_pydt:
156+
result[i] = datetime(
157+
dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
158+
new_tz, fold=fold,
159+
)
160+
elif use_date:
161+
result[i] = date(dts.year, dts.month, dts.day)
162+
else:
163+
result[i] = time(dts.hour, dts.min, dts.sec, dts.us, new_tz, fold=fold)
187164

188165
return result
189166

@@ -219,12 +196,13 @@ cdef inline int _reso_stamp(npy_datetimestruct *dts):
219196

220197
def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
221198
cdef:
222-
Py_ssize_t i, n = len(stamps)
199+
Py_ssize_t i, ntrans=-1, n = len(stamps)
223200
npy_datetimestruct dts
224201
int reso = RESO_DAY, curr_reso
225202
ndarray[int64_t] trans
226203
int64_t[::1] deltas
227-
intp_t[:] pos
204+
int64_t* tdata = NULL
205+
intp_t pos
228206
int64_t local_val, delta = NPY_NAT
229207
bint use_utc = False, use_tzlocal = False, use_fixed = False
230208

@@ -234,12 +212,13 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
234212
use_tzlocal = True
235213
else:
236214
trans, deltas, typ = get_dst_info(tz)
215+
ntrans = trans.shape[0]
237216
if typ not in ["pytz", "dateutil"]:
238217
# static/fixed; in this case we know that len(deltas) == 1
239218
use_fixed = True
240219
delta = deltas[0]
241220
else:
242-
pos = trans.searchsorted(stamps, side="right") - 1
221+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
243222

244223
for i in range(n):
245224
if stamps[i] == NPY_NAT:
@@ -252,7 +231,8 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
252231
elif use_fixed:
253232
local_val = stamps[i] + delta
254233
else:
255-
local_val = stamps[i] + deltas[pos[i]]
234+
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
235+
local_val = stamps[i] + deltas[pos]
256236

257237
dt64_to_dtstruct(local_val, &dts)
258238
curr_reso = _reso_stamp(&dts)
@@ -282,12 +262,13 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
282262
result : int64 ndarray of converted and normalized nanosecond timestamps
283263
"""
284264
cdef:
285-
Py_ssize_t i, n = len(stamps)
265+
Py_ssize_t i, ntrans =- 1, n = len(stamps)
286266
int64_t[:] result = np.empty(n, dtype=np.int64)
287267
ndarray[int64_t] trans
288268
int64_t[::1] deltas
269+
int64_t* tdata = NULL
289270
str typ
290-
Py_ssize_t[:] pos
271+
Py_ssize_t pos
291272
int64_t local_val, delta = NPY_NAT
292273
bint use_utc = False, use_tzlocal = False, use_fixed = False
293274

@@ -297,12 +278,13 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
297278
use_tzlocal = True
298279
else:
299280
trans, deltas, typ = get_dst_info(tz)
281+
ntrans = trans.shape[0]
300282
if typ not in ["pytz", "dateutil"]:
301283
# static/fixed; in this case we know that len(deltas) == 1
302284
use_fixed = True
303285
delta = deltas[0]
304286
else:
305-
pos = trans.searchsorted(stamps, side="right") - 1
287+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
306288

307289
for i in range(n):
308290
if stamps[i] == NPY_NAT:
@@ -316,7 +298,8 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
316298
elif use_fixed:
317299
local_val = stamps[i] + delta
318300
else:
319-
local_val = stamps[i] + deltas[pos[i]]
301+
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
302+
local_val = stamps[i] + deltas[pos]
320303

321304
result[i] = normalize_i8_stamp(local_val)
322305

@@ -341,10 +324,11 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
341324
is_normalized : bool True if all stamps are normalized
342325
"""
343326
cdef:
344-
Py_ssize_t i, n = len(stamps)
327+
Py_ssize_t i, ntrans =- 1, n = len(stamps)
345328
ndarray[int64_t] trans
346329
int64_t[::1] deltas
347-
intp_t[:] pos
330+
int64_t* tdata = NULL
331+
intp_t pos
348332
int64_t local_val, delta = NPY_NAT
349333
str typ
350334
int64_t day_nanos = 24 * 3600 * 1_000_000_000
@@ -356,12 +340,13 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
356340
use_tzlocal = True
357341
else:
358342
trans, deltas, typ = get_dst_info(tz)
343+
ntrans = trans.shape[0]
359344
if typ not in ["pytz", "dateutil"]:
360345
# static/fixed; in this case we know that len(deltas) == 1
361346
use_fixed = True
362347
delta = deltas[0]
363348
else:
364-
pos = trans.searchsorted(stamps, side="right") - 1
349+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
365350

366351
for i in range(n):
367352
if use_utc:
@@ -371,7 +356,8 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
371356
elif use_fixed:
372357
local_val = stamps[i] + delta
373358
else:
374-
local_val = stamps[i] + deltas[pos[i]]
359+
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
360+
local_val = stamps[i] + deltas[pos]
375361

376362
if local_val % day_nanos != 0:
377363
return False
@@ -386,11 +372,12 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
386372
@cython.boundscheck(False)
387373
def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
388374
cdef:
389-
Py_ssize_t i, n = len(stamps)
375+
Py_ssize_t i, ntrans =- 1, n = len(stamps)
390376
int64_t[:] result = np.empty(n, dtype=np.int64)
391377
ndarray[int64_t] trans
392378
int64_t[::1] deltas
393-
Py_ssize_t[:] pos
379+
int64_t* tdata = NULL
380+
intp_t pos
394381
npy_datetimestruct dts
395382
int64_t local_val, delta = NPY_NAT
396383
bint use_utc = False, use_tzlocal = False, use_fixed = False
@@ -401,12 +388,13 @@ def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
401388
use_tzlocal = True
402389
else:
403390
trans, deltas, typ = get_dst_info(tz)
391+
ntrans = trans.shape[0]
404392
if typ not in ["pytz", "dateutil"]:
405393
# static/fixed; in this case we know that len(deltas) == 1
406394
use_fixed = True
407395
delta = deltas[0]
408396
else:
409-
pos = trans.searchsorted(stamps, side="right") - 1
397+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
410398

411399
for i in range(n):
412400
if stamps[i] == NPY_NAT:
@@ -420,7 +408,8 @@ def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
420408
elif use_fixed:
421409
local_val = stamps[i] + delta
422410
else:
423-
local_val = stamps[i] + deltas[pos[i]]
411+
pos = bisect_right_i8(tdata, stamps[i], ntrans) - 1
412+
local_val = stamps[i] + deltas[pos]
424413

425414
dt64_to_dtstruct(local_val, &dts)
426415
result[i] = get_period_ordinal(&dts, freq)

0 commit comments

Comments
 (0)