Skip to content

allow string input in dwt_max_level #269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions pywt/_dwt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import sys
import numpy as np
from numbers import Number

from ._extensions._pywt import Wavelet, Modes, _check_dtype

from ._extensions._pywt import (Wavelet, Modes, _check_dtype, wavelist)
from ._extensions._dwt import (dwt_single, dwt_axis, idwt_single, idwt_axis,
upcoef as _upcoef, downcoef as _downcoef,
dwt_max_level as _dwt_max_level,
dwt_coeff_len as _dwt_coeff_len)

__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level", "dwt_coeff_len"]
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
"dwt_coeff_len"]

# define string_types as in six for Python 2/3 compatibility
if sys.version_info[0] == 3:
string_types = str,
else:
string_types = basestring,


def dwt_max_level(data_len, filter_len):
Expand All @@ -19,8 +29,9 @@ def dwt_max_level(data_len, filter_len):
----------
data_len : int
Input data length.
filter_len : int
Wavelet filter length.
filter_len : int, str or Wavelet
The wavelet filter length. Alternatively, the name of a discrete
wavelet or a Wavelet object can be specified.

Returns
-------
Expand All @@ -35,9 +46,26 @@ def dwt_max_level(data_len, filter_len):
6
>>> pywt.dwt_max_level(1000, w)
6
>>> pywt.dwt_max_level(1000, 'sym5')
6
"""
if isinstance(filter_len, Wavelet):
filter_len = filter_len.dec_len
elif isinstance(filter_len, string_types):
if filter_len in wavelist(kind='discrete'):
filter_len = Wavelet(filter_len).dec_len
else:
raise ValueError(
("'{}', is not a recognized discrete wavelet. A list of "
"supported wavelet names can be obtained via "
"pywt.wavelist(kind='discrete')").format(filter_len))
elif not (isinstance(filter_len, Number) and filter_len % 1 == 0):
raise ValueError(
"filter_len must be an integer, discrete Wavelet object, or the "
"name of a discrete wavelet.")

if filter_len < 2:
raise ValueError("invalid wavelet filter length")

return _dwt_max_level(data_len, filter_len)

Expand Down
14 changes: 14 additions & 0 deletions pywt/tests/test__pywt.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,22 @@ def test_dwt_max_level():
assert_(pywt.dwt_max_level(16, 8) == 1)
assert_(pywt.dwt_max_level(16, 9) == 1)
assert_(pywt.dwt_max_level(16, 10) == 0)
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
assert_(pywt.dwt_max_level(16, 10.) == 0)
assert_(pywt.dwt_max_level(16, 18) == 0)

# accepts discrete Wavelet object or string as well
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
assert_(pywt.dwt_max_level(32, 'sym5') == 1)

# string input that is not a discrete wavelet
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')

# filter_len must be an integer >= 2
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)


def test_ContinuousWavelet_errs():
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
Expand Down