Skip to content

Commit 18fc2bb

Browse files
committed
ENH: dwt_max_level: also allow the name of a wavelet to be used as the 2nd argument.
update docstring to indicate that a wavelet can be passed in place of the filter length raise ValueError if filter_len < 2 formerly filter_len = 0 or 1 returned 0 while a negative filter_len raised an OverflowError raise informative ValueError on unrecognized string or non-integer input
1 parent 3806a79 commit 18fc2bb

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

pywt/_dwt.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
import sys
12
import numpy as np
3+
from numbers import Number
24

3-
from ._extensions._pywt import Wavelet, Modes, _check_dtype
5+
6+
from ._extensions._pywt import (Wavelet, Modes, _check_dtype, wavelist)
47
from ._extensions._dwt import (dwt_single, dwt_axis, idwt_single, idwt_axis,
58
upcoef as _upcoef, downcoef as _downcoef,
69
dwt_max_level as _dwt_max_level,
710
dwt_coeff_len as _dwt_coeff_len)
811

9-
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level", "dwt_coeff_len"]
12+
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
13+
"dwt_coeff_len"]
14+
15+
# define string_types as in six for Python 2/3 compatibility
16+
if sys.version_info[0] == 3:
17+
string_types = str,
18+
else:
19+
string_types = basestring,
1020

1121

1222
def dwt_max_level(data_len, filter_len):
@@ -19,8 +29,9 @@ def dwt_max_level(data_len, filter_len):
1929
----------
2030
data_len : int
2131
Input data length.
22-
filter_len : int
23-
Wavelet filter length.
32+
filter_len : int, str or Wavelet
33+
The wavelet filter length. Alternatively, the name of a discrete
34+
wavelet or a Wavelet object can be specified.
2435
2536
Returns
2637
-------
@@ -35,9 +46,26 @@ def dwt_max_level(data_len, filter_len):
3546
6
3647
>>> pywt.dwt_max_level(1000, w)
3748
6
49+
>>> pywt.dwt_max_level(1000, 'sym5')
50+
6
3851
"""
3952
if isinstance(filter_len, Wavelet):
4053
filter_len = filter_len.dec_len
54+
elif isinstance(filter_len, string_types):
55+
if filter_len in wavelist(kind='discrete'):
56+
filter_len = Wavelet(filter_len).dec_len
57+
else:
58+
raise ValueError(
59+
("'{}', is not a recognized discrete wavelet. A list of "
60+
"supported wavelet names can be obtained via "
61+
"pywt.wavelist(kind='discrete')").format(filter_len))
62+
elif not (isinstance(filter_len, Number) and filter_len % 1 == 0):
63+
raise ValueError(
64+
"filter_len must be an integer, discrete Wavelet object, or the "
65+
"name of a discrete wavelet.")
66+
67+
if filter_len < 2:
68+
raise ValueError("invalid wavelet filter length")
4169

4270
return _dwt_max_level(data_len, filter_len)
4371

pywt/tests/test__pywt.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,22 @@ def test_dwt_max_level():
106106
assert_(pywt.dwt_max_level(16, 8) == 1)
107107
assert_(pywt.dwt_max_level(16, 9) == 1)
108108
assert_(pywt.dwt_max_level(16, 10) == 0)
109+
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
110+
assert_(pywt.dwt_max_level(16, 10.) == 0)
109111
assert_(pywt.dwt_max_level(16, 18) == 0)
110112

113+
# accepts discrete Wavelet object or string as well
114+
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
115+
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
116+
117+
# string input that is not a discrete wavelet
118+
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
119+
120+
# filter_len must be an integer >= 2
121+
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
122+
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
123+
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
124+
111125

112126
def test_ContinuousWavelet_errs():
113127
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')

0 commit comments

Comments
 (0)