Skip to content

Commit b88b3d9

Browse files
authored
Fix order of start/end values in audinterface.Segment (#136)
* Add failing tests
* Fix segment for file given
* Add more test ideas
* Update tests
* Fix process_file()
* First part of fixing process_files()
* Fix process_index()
* Fix process_signal with start argument
* Add test for process_signal_from_index
* Extend tests
* Add test for process_folder()
* Fix empty line
* Base file length on given end values
1 parent e615890 commit b88b3d9

File tree

2 files changed

+117
-65
lines changed

2 files changed

+117
-65
lines changed

audinterface/core/segment.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ def process_file(
272272
).values[0]
273273
return audformat.segmented_index(
274274
files=[file] * len(index),
275-
starts=index.levels[0] + start,
276-
ends=index.levels[1] + start,
275+
starts=index.get_level_values('start') + start,
276+
ends=index.get_level_values('end') + start,
277277
)
278278

279279
def process_files(
@@ -324,8 +324,8 @@ def process_files(
324324
ends = []
325325
for (file, start, _), index in y.items():
326326
files.extend([file] * len(index))
327-
starts.extend(index.levels[0] + start)
328-
ends.extend(index.levels[1] + start)
327+
starts.extend(index.get_level_values('start') + start)
328+
ends.extend(index.get_level_values('end') + start)
329329

330330
return audformat.segmented_index(files, starts, ends)
331331

@@ -416,8 +416,8 @@ def process_index(
416416
ends = []
417417
for (file, start, _), index in y.items():
418418
files.extend([file] * len(index))
419-
starts.extend(index.levels[0] + start)
420-
ends.extend(index.levels[1] + start)
419+
starts.extend(index.get_level_values('start') + start)
420+
ends.extend(index.get_level_values('end') + start)
421421

422422
return audformat.segmented_index(files, starts, ends)
423423

@@ -466,6 +466,12 @@ def process_signal(
466466
).values[0]
467467
utils.assert_index(index)
468468
if start is not None:
469+
start = utils.to_timedelta(start)
470+
# Here we directly change the levels,
471+
# so we need to use
472+
# `index.levels[0]`
473+
# instead of
474+
# `index.get_level_values('start')`
469475
index = index.set_levels(
470476
[
471477
index.levels[0] + start,
@@ -476,9 +482,10 @@ def process_signal(
476482
if file is not None:
477483
index = audformat.segmented_index(
478484
files=[file] * len(index),
479-
starts=index.levels[0],
480-
ends=index.levels[1],
485+
starts=index.get_level_values('start'),
486+
ends=index.get_level_values('end'),
481487
)
488+
482489
return index
483490

484491
def process_signal_from_index(

tests/test_process.py

Lines changed: 102 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import audformat
99
import audiofile
1010
import audiofile as af
11+
import audmath
1112
import audobject
1213

1314
import audinterface
@@ -1521,102 +1522,146 @@ def process_func(signal, sampling_rate, idx, file, root):
15211522

15221523

15231524
@pytest.mark.parametrize(
1524-
'segment',
1525+
# `starts` and `ends`
1526+
# are used to create a segment object
1527+
# using audinterface.utils.signal_index()
1528+
'starts, ends',
15251529
[
1526-
audinterface.Segment(
1527-
process_func=lambda x, sr: audinterface.utils.signal_index()
1528-
),
1529-
audinterface.Segment(
1530-
process_func=lambda x, sr:
1531-
audinterface.utils.signal_index(
1532-
pd.to_timedelta(0),
1533-
pd.to_timedelta(x.shape[1] / sr, unit='s') / 2,
1534-
)
1535-
),
1536-
audinterface.Segment(
1537-
process_func=lambda x, sr:
1538-
audinterface.utils.signal_index(
1539-
pd.to_timedelta(x.shape[1] / sr, unit='s') / 2,
1540-
pd.to_timedelta(x.shape[1] / sr, unit='s'),
1541-
)
1542-
),
1543-
audinterface.Segment(
1544-
process_func=lambda x, sr:
1545-
audinterface.utils.signal_index(
1546-
[
1547-
pd.to_timedelta(0),
1548-
pd.to_timedelta(x.shape[1] / sr, unit='s') / 2,
1549-
],
1550-
[
1551-
pd.to_timedelta(x.shape[1] / sr, unit='s') / 2,
1552-
pd.to_timedelta(x.shape[1] / sr),
1553-
],
1554-
)
1555-
)
1530+
(None, None),
1531+
(0, 1.5),
1532+
(1.5, 3),
1533+
([0, 1.5], [1.5, 3]),
1534+
# Blocked by https://github.com/audeering/audinterface/issues/134
1535+
# or a similar issue
1536+
# ([0, 1.5], [1, 2.000000003]),
1537+
([0, 2], [1, 3]),
1538+
([0, 1], [2, 2]),
1539+
# https://github.com/audeering/audinterface/issues/135
1540+
([0, 1], [3, 2]),
15561541
]
15571542
)
1558-
def test_process_with_segment(tmpdir, segment):
1543+
def test_process_with_segment(tmpdir, starts, ends):
15591544

1560-
process = audinterface.Process()
1561-
process_with_segment = audinterface.Process(
1562-
segment=segment,
1545+
# Segment and process objects
1546+
segment = audinterface.Segment(
1547+
process_func=lambda x, sr:
1548+
audinterface.utils.signal_index(starts, ends)
15631549
)
1550+
process = audinterface.Process()
1551+
process_with_segment = audinterface.Process(segment=segment)
15641552

1565-
# create signal and file
1553+
# Create signal and file
15661554
sampling_rate = 8000
1567-
signal = np.zeros((1, sampling_rate))
1555+
if ends is None:
1556+
duration = 1
1557+
else:
1558+
duration = audmath.duration_in_seconds(
1559+
max(audeer.to_list(ends))
1560+
)
1561+
signal = np.zeros((1, audmath.samples(duration, sampling_rate)))
15681562
root = tmpdir
15691563
file = 'file.wav'
15701564
path = os.path.join(root, file)
15711565
audiofile.write(path, signal, sampling_rate)
15721566

1567+
# Expected index
1568+
if starts is None:
1569+
files = None
1570+
files_abs = None
1571+
else:
1572+
files = [file] * len(audeer.to_list(starts))
1573+
files_abs = [audeer.path(root, file) for file in files]
1574+
expected = audformat.segmented_index(files, starts, ends)
1575+
expected_folder_index = audformat.segmented_index(files_abs, starts, ends)
1576+
expected_signal_index = audinterface.utils.signal_index(starts, ends)
1577+
15731578
# process signal
1574-
index = segment.process_signal(
1579+
index = segment.process_signal(signal, sampling_rate)
1580+
pd.testing.assert_index_equal(index, expected_signal_index)
1581+
1582+
# process signal with start argument
1583+
index = segment.process_signal(signal, sampling_rate, start=0)
1584+
pd.testing.assert_index_equal(index, expected_signal_index)
1585+
1586+
# process signal with file argument
1587+
index = segment.process_signal(signal, sampling_rate, file=file)
1588+
pd.testing.assert_index_equal(index, expected)
1589+
1590+
pd.testing.assert_series_equal(
1591+
process.process_index(index, root=root, preserve_index=True),
1592+
process_with_segment.process_signal(signal, sampling_rate, file=file)
1593+
)
1594+
1595+
# process signal from index
1596+
index = segment.process_signal_from_index(
15751597
signal,
15761598
sampling_rate,
1577-
file=file,
1599+
audinterface.utils.signal_index(0, duration),
15781600
)
1579-
pd.testing.assert_series_equal(
1580-
process.process_index(index, root=root),
1581-
process_with_segment.process_signal(
1582-
signal,
1583-
sampling_rate,
1584-
file=file,
1585-
)
1601+
pd.testing.assert_index_equal(index, expected_signal_index)
1602+
index = segment.process_signal_from_index(
1603+
signal,
1604+
sampling_rate,
1605+
audformat.segmented_index(file, 0, duration),
15861606
)
1607+
pd.testing.assert_index_equal(index, expected)
15871608
index = segment.process_signal_from_index(
15881609
signal,
15891610
sampling_rate,
15901611
audformat.filewise_index(file),
15911612
)
1613+
pd.testing.assert_index_equal(index, expected)
1614+
15921615
pd.testing.assert_series_equal(
1593-
process.process_index(index, root=root),
1616+
process.process_index(index, root=root, preserve_index=True),
15941617
process_with_segment.process_signal_from_index(
15951618
signal,
15961619
sampling_rate,
15971620
audformat.filewise_index(file),
1598-
)
1621+
),
15991622
)
16001623

16011624
# process file
16021625
index = segment.process_file(file, root=root)
1626+
pd.testing.assert_index_equal(index, expected)
1627+
16031628
pd.testing.assert_series_equal(
1604-
process.process_index(index, root=root),
1605-
process_with_segment.process_file(file, root=root)
1606-
)
1607-
index = segment.process_index(
1608-
audformat.filewise_index(file),
1609-
root=root,
1629+
process.process_index(index, root=root, preserve_index=True),
1630+
process_with_segment.process_file(file, root=root),
16101631
)
1632+
1633+
# process files
1634+
index = segment.process_files([file], root=root)
1635+
pd.testing.assert_index_equal(index, expected)
1636+
1637+
# https://github.com/audeering/audinterface/issues/138
1638+
# pd.testing.assert_series_equal(
1639+
# process.process_index(index, root=root, preserve_index=True),
1640+
# process_with_segment.process_files([file], root=root)
1641+
# )
1642+
1643+
# process folder
1644+
index = segment.process_folder(root)
1645+
pd.testing.assert_index_equal(index, expected_folder_index)
1646+
1647+
# https://github.com/audeering/audinterface/issues/139
1648+
# pd.testing.assert_series_equal(
1649+
# process.process_index(index, root=root, preserve_index=True),
1650+
# process_with_segment.process_folder(root),
1651+
# )
1652+
1653+
# process index
1654+
index = segment.process_index(audformat.filewise_index(file), root=root)
1655+
pd.testing.assert_index_equal(index, expected)
1656+
16111657
pd.testing.assert_series_equal(
1612-
process.process_index(index, root=root),
1658+
process.process_index(index, root=root, preserve_index=True),
16131659
process_with_segment.process_index(
16141660
audformat.filewise_index(file),
16151661
root=root,
1616-
)
1662+
),
16171663
)
16181664

1619-
16201665
def test_read_audio(tmpdir):
16211666
sampling_rate = 8000
16221667
signal = np.ones((1, 8000))

0 commit comments

Comments
 (0)