Skip to content

Commit dd3fae3

Browse files
authored
Merge pull request #1005 from vanandrew/zstd_support
NF: Add zstd compression support
2 parents a0d2534 + 258d0cd commit dd3fae3

13 files changed

+129
-34
lines changed

nibabel/analyze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ class AnalyzeImage(SpatialImage):
906906
_meta_sniff_len = header_class.sizeof_hdr
907907
files_types = (('image', '.img'), ('header', '.hdr'))
908908
valid_exts = ('.img', '.hdr')
909-
_compressed_suffixes = ('.gz', '.bz2')
909+
_compressed_suffixes = ('.gz', '.bz2', '.zst')
910910

911911
makeable = True
912912
rw = True

nibabel/benchmarks/bench_fileslice.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
from ..fileslice import fileslice
1919
from ..rstutils import rst_table
2020
from ..tmpdirs import InTemporaryDirectory
21+
from ..optpkg import optional_package
2122

2223
SHAPE = (64, 64, 32, 100)
2324
ROW_NAMES = [f'axis {i}, len {dim}' for i, dim in enumerate(SHAPE)]
2425
COL_NAMES = ['mid int',
2526
'step 1',
2627
'half step 1',
2728
'step mid int']
29+
HAVE_ZSTD = optional_package("pyzstd")[1]
2830

2931

3032
def _slices_for_len(L):
@@ -70,7 +72,8 @@ def g():
7072
def bench_fileslice(bytes=True,
7173
file_=True,
7274
gz=True,
73-
bz2=False):
75+
bz2=False,
76+
zst=True):
7477
sys.stdout.flush()
7578
repeat = 2
7679

@@ -103,4 +106,10 @@ def my_table(title, times, base):
103106
my_table('bz2 slice - raw (ratio)',
104107
np.dstack((bz2_times, bz2_times / bz2_base)),
105108
bz2_base)
109+
if zst and HAVE_ZSTD:
110+
with InTemporaryDirectory():
111+
zst_times, zst_base = run_slices('data.zst', repeat)
112+
my_table('zst slice - raw (ratio)',
113+
np.dstack((zst_times, zst_times / zst_base)),
114+
zst_base)
106115
sys.stdout.flush()

nibabel/brikhead.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ class AFNIImage(SpatialImage):
490490
header_class = AFNIHeader
491491
valid_exts = ('.brik', '.head')
492492
files_types = (('image', '.brik'), ('header', '.head'))
493-
_compressed_suffixes = ('.gz', '.bz2', '.Z')
493+
_compressed_suffixes = ('.gz', '.bz2', '.Z', '.zst')
494494
makeable = False
495495
rw = False
496496
ImageArrayProxy = AFNIArrayProxy

nibabel/loadsave.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from .arrayproxy import is_proxy
2020
from .deprecated import deprecate_with_version
2121

22+
_compressed_suffixes = ('.gz', '.bz2', '.zst')
23+
2224

2325
def _signature_matches_extension(filename, sniff):
2426
"""Check if signature aka magic number matches filename extension.
@@ -153,7 +155,7 @@ def save(img, filename):
153155
return
154156

155157
# Be nice to users by making common implicit conversions
156-
froot, ext, trailing = splitext_addext(filename, ('.gz', '.bz2'))
158+
froot, ext, trailing = splitext_addext(filename, _compressed_suffixes)
157159
lext = ext.lower()
158160

159161
# Special-case Nifti singles and Pairs

nibabel/minc1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ class Minc1Image(SpatialImage):
316316
_meta_sniff_len = 4
317317
valid_exts = ('.mnc',)
318318
files_types = (('image', '.mnc'),)
319-
_compressed_suffixes = ('.gz', '.bz2')
319+
_compressed_suffixes = ('.gz', '.bz2', '.zst')
320320

321321
makeable = True
322322
rw = False

nibabel/openers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from os.path import splitext
1616
from distutils.version import StrictVersion
1717

18+
from nibabel.optpkg import optional_package
19+
1820
# is indexed_gzip present and modern?
1921
try:
2022
import indexed_gzip as igzip
@@ -55,6 +57,12 @@ def _gzip_open(filename, mode='rb', compresslevel=9, keep_open=False):
5557
return gzip_file
5658

5759

60+
def _zstd_open(filename, mode="r", *, level_or_option=None, zstd_dict=None):
    """Open `filename` through ``pyzstd.ZstdFile``.

    The ``pyzstd`` module is fetched via ``optional_package`` inside the
    function body, i.e. on every call rather than at import time of this
    module.

    Parameters
    ----------
    filename : object
        Forwarded unchanged as the first positional argument of
        ``pyzstd.ZstdFile``.
    mode : str, optional
        Forwarded unchanged as the second positional argument.
        Default is ``"r"``.
    level_or_option : object, optional
        Keyword-only; forwarded unchanged to ``pyzstd.ZstdFile``.
    zstd_dict : object, optional
        Keyword-only; forwarded unchanged to ``pyzstd.ZstdFile``.

    Returns
    -------
    fobj
        Whatever ``pyzstd.ZstdFile(...)`` returns.
    """
    # NOTE(review): behaviour when pyzstd is not installed depends on what
    # ``optional_package`` returns in that case -- confirm against optpkg.
    zstd_module = optional_package("pyzstd")[0]
    forwarded = {"level_or_option": level_or_option, "zstd_dict": zstd_dict}
    return zstd_module.ZstdFile(filename, mode, **forwarded)
64+
65+
5866
class Opener(object):
5967
r""" Class to accept, maybe open, and context-manage file-likes / filenames
6068
@@ -77,13 +85,20 @@ class Opener(object):
7785
"""
7886
gz_def = (_gzip_open, ('mode', 'compresslevel', 'keep_open'))
7987
bz2_def = (BZ2File, ('mode', 'buffering', 'compresslevel'))
88+
zstd_def = (_zstd_open, ('mode', 'level_or_option', 'zstd_dict'))
8089
compress_ext_map = {
8190
'.gz': gz_def,
8291
'.bz2': bz2_def,
92+
'.zst': zstd_def,
8393
None: (open, ('mode', 'buffering')) # default
8494
}
8595
#: default compression level when writing gz and bz2 files
8696
default_compresslevel = 1
97+
#: default option for zst files
98+
default_zst_compresslevel = 3
99+
default_level_or_option = {"rb": None, "r": None,
100+
"wb": default_zst_compresslevel,
101+
"w": default_zst_compresslevel}
87102
#: whether to ignore case looking for compression extensions
88103
compress_ext_icase = True
89104

@@ -100,10 +115,15 @@ def __init__(self, fileish, *args, **kwargs):
100115
full_kwargs.update(dict(zip(arg_names[:n_args], args)))
101116
# Set default mode
102117
if 'mode' not in full_kwargs:
103-
kwargs['mode'] = 'rb'
118+
mode = 'rb'
119+
kwargs['mode'] = mode
120+
else:
121+
mode = full_kwargs['mode']
104122
# Default compression level
105123
if 'compresslevel' in arg_names and 'compresslevel' not in kwargs:
106124
kwargs['compresslevel'] = self.default_compresslevel
125+
if 'level_or_option' in arg_names and 'level_or_option' not in kwargs:
126+
kwargs['level_or_option'] = self.default_level_or_option[mode]
107127
# Default keep_open hint
108128
if 'keep_open' in arg_names:
109129
kwargs.setdefault('keep_open', False)

nibabel/tests/test_analyze.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..casting import as_int
3131
from ..tmpdirs import InTemporaryDirectory
3232
from ..arraywriters import WriterError
33+
from ..optpkg import optional_package
3334

3435
import pytest
3536
from numpy.testing import (assert_array_equal, assert_array_almost_equal)
@@ -40,6 +41,8 @@
4041
from .test_wrapstruct import _TestLabeledWrapStruct
4142
from . import test_spatialimages as tsi
4243

44+
HAVE_ZSTD = optional_package("pyzstd")[1]
45+
4346
header_file = os.path.join(data_path, 'analyze.hdr')
4447

4548
PIXDIM0_MSG = 'pixdim[1,2,3] should be non-zero; setting 0 dims to 1'
@@ -788,6 +791,8 @@ def test_big_offset_exts(self):
788791
aff = np.eye(4)
789792
img_ext = img_klass.files_types[0][1]
790793
compressed_exts = ['', '.gz', '.bz2']
794+
if HAVE_ZSTD:
795+
compressed_exts += ['.zst']
791796
with InTemporaryDirectory():
792797
for offset in (0, 2048):
793798
# Set offset in in-memory image

nibabel/tests/test_minc1.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..deprecated import ModuleProxy
2323
from .. import minc1
2424
from ..minc1 import Minc1File, Minc1Image, MincHeader
25+
from ..optpkg import optional_package
2526

2627
from ..tmpdirs import InTemporaryDirectory
2728
from ..deprecator import ExpiredDeprecationError
@@ -32,6 +33,8 @@
3233
from . import test_spatialimages as tsi
3334
from .test_fileslice import slicer_samples
3435

36+
pyzstd, HAVE_ZSTD, _ = optional_package("pyzstd")
37+
3538
EG_FNAME = pjoin(data_path, 'tiny.mnc')
3639

3740
# Example images in format expected for ``test_image_api``, adding ``zooms``
@@ -170,7 +173,10 @@ def test_compressed(self):
170173
# Not so for MINC2; hence this small sub-class
171174
for tp in self.test_files:
172175
content = open(tp['fname'], 'rb').read()
173-
openers_exts = ((gzip.open, '.gz'), (bz2.BZ2File, '.bz2'))
176+
openers_exts = [(gzip.open, '.gz'),
177+
(bz2.BZ2File, '.bz2')]
178+
if HAVE_ZSTD: # add .zst to test if installed
179+
openers_exts += [(pyzstd.ZstdFile, '.zst')]
174180
with InTemporaryDirectory():
175181
for opener, ext in openers_exts:
176182
fname = 'test.mnc' + ext

nibabel/tests/test_openers.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,22 @@
1414
from distutils.version import StrictVersion
1515

1616
from numpy.compat.py3k import asstr, asbytes
17-
from ..openers import Opener, ImageOpener, HAVE_INDEXED_GZIP, BZ2File
17+
from ..openers import (Opener,
18+
ImageOpener,
19+
HAVE_INDEXED_GZIP,
20+
BZ2File,
21+
)
1822
from ..tmpdirs import InTemporaryDirectory
1923
from ..volumeutils import BinOpener
24+
from ..optpkg import optional_package
2025

2126
import unittest
2227
from unittest import mock
2328
import pytest
2429
from ..testing import error_warnings
2530

31+
pyzstd, HAVE_ZSTD, _ = optional_package("pyzstd")
32+
2633

2734
class Lunk(object):
2835
# bare file-like for testing
@@ -71,10 +78,13 @@ def test_Opener_various():
7178
import indexed_gzip as igzip
7279
with InTemporaryDirectory():
7380
sobj = BytesIO()
74-
for input in ('test.txt',
75-
'test.txt.gz',
76-
'test.txt.bz2',
77-
sobj):
81+
files_to_test = ['test.txt',
82+
'test.txt.gz',
83+
'test.txt.bz2',
84+
sobj]
85+
if HAVE_ZSTD:
86+
files_to_test += ['test.txt.zst']
87+
for input in files_to_test:
7888
with Opener(input, 'wb') as fobj:
7989
fobj.write(message)
8090
assert fobj.tell() == len(message)
@@ -240,6 +250,8 @@ def test_compressed_ext_case():
240250
class StrictOpener(Opener):
241251
compress_ext_icase = False
242252
exts = ('gz', 'bz2', 'GZ', 'gZ', 'BZ2', 'Bz2')
253+
if HAVE_ZSTD:
254+
exts += ('zst', 'ZST', 'Zst')
243255
with InTemporaryDirectory():
244256
# Make a basic file to check type later
245257
with open(__file__, 'rb') as a_file:
@@ -264,6 +276,8 @@ class StrictOpener(Opener):
264276
except ImportError:
265277
IndexedGzipFile = GzipFile
266278
assert isinstance(fobj.fobj, (GzipFile, IndexedGzipFile))
279+
elif lext == 'zst':
280+
assert isinstance(fobj.fobj, pyzstd.ZstdFile)
267281
else:
268282
assert isinstance(fobj.fobj, BZ2File)
269283

@@ -273,11 +287,14 @@ def test_name():
273287
sobj = BytesIO()
274288
lunk = Lunk('in ART')
275289
with InTemporaryDirectory():
276-
for input in ('test.txt',
277-
'test.txt.gz',
278-
'test.txt.bz2',
279-
sobj,
280-
lunk):
290+
files_to_test = ['test.txt',
291+
'test.txt.gz',
292+
'test.txt.bz2',
293+
sobj,
294+
lunk]
295+
if HAVE_ZSTD:
296+
files_to_test += ['test.txt.zst']
297+
for input in files_to_test:
281298
exp_name = input if type(input) == type('') else None
282299
with Opener(input, 'wb') as fobj:
283300
assert fobj.name == exp_name
@@ -329,10 +346,13 @@ def test_iter():
329346
""".split('\n')
330347
with InTemporaryDirectory():
331348
sobj = BytesIO()
332-
for input, does_t in (('test.txt', True),
333-
('test.txt.gz', False),
334-
('test.txt.bz2', False),
335-
(sobj, True)):
349+
files_to_test = [('test.txt', True),
350+
('test.txt.gz', False),
351+
('test.txt.bz2', False),
352+
(sobj, True)]
353+
if HAVE_ZSTD:
354+
files_to_test += [('test.txt.zst', False)]
355+
for input, does_t in files_to_test:
336356
with Opener(input, 'wb') as fobj:
337357
for line in lines:
338358
fobj.write(asbytes(line + os.linesep))

0 commit comments

Comments
 (0)