Source code for libdse.data.librispeech

"""Streaming PyTorch dataset for the LibriSpeech ASR corpus.

.. _LibriSpeech ASR corpus: https://www.openslr.org/12

This module provides :class:`LibriSpeechDataset`, an
:class:`~torch.utils.data.IterableDataset` that streams ``(sample, label)``
tensor pairs directly from raw FLAC audio files.  Feature extraction is fully
delegated to a :class:`~libdse.data.features.BaseExtractor` instance supplied at
construction, keeping the dataset class decoupled from the specific feature
representation.

Internally the dataset discovers all FLAC files under *entry_point*, shuffles
them once at construction to reduce temporal correlation between consecutive
batches, and then iterates through each file.  For every utterance the raw
waveform is loaded at the configured sample rate (16 kHz by default), passed
to the extractor, and every ``(sample, label)`` pair the extractor produces is
yielded in turn.

Layout assumption
-----------------
The *entry_point* directory must contain exactly one sub-directory named
``LibriSpeech/``, matching the structure produced by the official LibriSpeech
tar archives::

    entry_point/
    └── LibriSpeech/
        └── <speaker>/<chapter>/<utterance>.flac

Classes
-------
- :class:`LibriSpeechDataset` — Iterable PyTorch dataset.

Exceptions
----------
- :exc:`~libdse.data.err.EntryPointError` — Raised when *entry_point* is invalid.
"""

from pathlib import Path
from typing import Generator

import librosa
import numpy as np
from numpy import random
from numpy.typing import NDArray
from torch.utils.data import IterableDataset, get_worker_info

from libdse.data.err import EntryPointError
from libdse.data.features import BaseExtractor


class _LengthNotSupportedError(TypeError, NotImplementedError):
    """Raised by :meth:`LibriSpeechDataset.__len__`.

    Derives from :exc:`TypeError` so that length probes (``list()``,
    :func:`operator.length_hint`, and callers that treat ``TypeError`` as
    "no length available") fall back to plain iteration instead of aborting.
    Also derives from :exc:`NotImplementedError` so that existing
    ``except NotImplementedError`` handlers keep working.
    """


class LibriSpeechDataset(IterableDataset):
    """Iterable PyTorch dataset for the `LibriSpeech <https://www.openslr.org/12>`_ ASR corpus.

    Streams ``(sample, label)`` pairs from raw FLAC audio files.  Feature
    extraction is entirely delegated to the *extractor* argument, making this
    class agnostic about the feature representation.

    FLAC files are discovered recursively under *entry_point* at construction
    and shuffled once to reduce temporal correlation between consecutive
    training batches.  During iteration each file is loaded with
    :func:`librosa.load` and handed to ``extractor(waveform)``; every item the
    extractor produces for that utterance is yielded in turn.

    When iterated inside a :class:`~torch.utils.data.DataLoader` worker
    process, each worker handles a disjoint, strided subset of the file list
    so that utterances are not duplicated across workers.

    :param entry_point: LibriSpeech root directory.  Must contain an entry
        named ``LibriSpeech``.
    :type entry_point: :class:`pathlib.Path`
    :param extractor: Feature extractor instance.  Called once per utterance
        with the mono waveform as its sole argument; its result is iterated
        and each item is yielded.  Must expose a ``sample_shape`` attribute.
    :type extractor: :class:`~libdse.data.features.BaseExtractor`
    :param sample_rate: Rate in Hz that every file is resampled to at load
        time.  Defaults to 16 kHz.
    :type sample_rate: int
    :raises EntryPointError: If *entry_point* is not a directory or does not
        contain an entry named ``LibriSpeech``.

    .. note::
        Because the number of feature chunks per utterance is not known
        without reading every file, :meth:`__len__` is not supported.  Use
        the :class:`~torch.utils.data.DataLoader` and iterate until
        exhaustion.

    .. seealso::
        :class:`~libdse.data.features.LogMelPowerSpectrumExtractor`
            Extractor used by this module's ``__main__`` benchmark.
        :class:`~libdse.data.noise.DEMANDNoiseDataset`
            Noise dataset passed to that extractor in the benchmark.

    .. rubric:: Typical usage

    .. code-block:: python

        from pathlib import Path
        from torch.utils.data import DataLoader
        from libdse.data.features import LogMelPowerSpectrumExtractor
        from libdse.data.librispeech import LibriSpeechDataset
        from libdse.data.noise import DEMANDNoiseDataset, DEMANDNoiseType

        noise_ds = DEMANDNoiseDataset(
            entry_point=Path("data/noise/DEMAND"),
            noise_types=DEMANDNoiseType.ALL,
        )
        extractor = LogMelPowerSpectrumExtractor(
            sampling_rate=16_000,
            window_length=512,
            hop_length=128,
            n_mels=40,
            chunks_per_feature=7,
            noise=noise_ds,
        )
        ds = LibriSpeechDataset(
            entry_point=Path("data/train-clean-100"),
            extractor=extractor,
        )
        loader = DataLoader(ds, batch_size=32)
        for noisy, clean in loader:
            loss = criterion(model(noisy), clean)
    """

    def __init__(
        self,
        entry_point: Path,
        extractor: BaseExtractor,
        sample_rate: int = 16_000,
    ) -> None:
        """Validate *entry_point*, collect FLAC paths, and store the extractor.

        :param entry_point: LibriSpeech root directory.
        :type entry_point: :class:`pathlib.Path`
        :param extractor: Feature extractor called once per utterance.
        :type extractor: :class:`~libdse.data.features.BaseExtractor`
        :param sample_rate: Target sampling rate in Hz used when loading
            audio.  Defaults to ``16_000``.
        :type sample_rate: int
        :raises EntryPointError: If *entry_point* is not a valid LibriSpeech
            root.
        """
        super().__init__()

        # Verify that the directory has the expected LibriSpeech entry.
        if not entry_point.is_dir() or "LibriSpeech" not in {
            p.name for p in entry_point.iterdir()
        }:
            raise EntryPointError(
                f"`{entry_point}` is not an entry point to a LibriSpeech "
                "Dataset. In case you manually disassembled the dataset prior "
                "to loading, please download an unaltered copy from "
                "https://www.openslr.org/12 and set the extraction target "
                "to this Dataset's `entry_point`."
            )

        #: Sampling rate used when loading audio.  The original corpus is
        #: sampled at 16 kHz; all files are resampled to this rate at load
        #: time.
        self.fs = sample_rate

        # Materialise the glob eagerly so the list can be reused across
        # epochs without rescanning the file system on every iteration.
        self._source_flac_paths: list[Path] = list(entry_point.rglob("*.flac"))

        # Shuffle once at construction to reduce temporal correlation between
        # consecutive batches during training.  This uses NumPy's global RNG.
        random.shuffle(self._source_flac_paths)

        #: Feature extractor applied to every loaded waveform.
        self.extractor = extractor

        #: Shape of a single feature vector, as reported by the extractor.
        self.sample_shape = self.extractor.sample_shape

    def __repr__(self) -> str:
        """Return a concise string representation of the dataset.

        :return: ``LibriSpeechDataset(n_files=M, sample_shape=S)``
        :rtype: str
        """
        return (
            f"LibriSpeechDataset("
            f"n_files={len(self._source_flac_paths)}, "
            f"sample_shape={self.sample_shape})"
        )

    def __len__(self) -> int:
        """Unsupported — the dataset length cannot be determined cheaply.

        The number of yielded pairs depends on what the extractor produces
        for every audio file, so ``len()`` is intentionally unsupported.
        Iterate the dataset (or a :class:`~torch.utils.data.DataLoader`
        wrapping it) until exhaustion instead.

        The exception is a :exc:`TypeError`, which is what Python's length
        probing (``list()``, :func:`operator.length_hint`) expects from
        objects without a length; a plain :exc:`NotImplementedError` would
        propagate out of ``list(dataset)``.  It also derives from
        :exc:`NotImplementedError` for backward compatibility.

        :raises TypeError: Always (the instance is also a
            :exc:`NotImplementedError`).
        """
        raise _LengthNotSupportedError(
            "len() is not supported for LibriSpeechDataset. "
            "Iterate until StopIteration instead."
        )

    def __iter__(self) -> Generator[tuple[NDArray, NDArray], None, None]:
        """Yield ``(sample, label)`` pairs by streaming each FLAC file.

        Every file is loaded as mono audio at :attr:`fs` Hz and passed to
        :attr:`extractor`; the items the extractor produces are forwarded via
        ``yield from``.

        In a multi-worker :class:`~torch.utils.data.DataLoader` each worker
        receives its own copy of the dataset.  To avoid every worker emitting
        the full corpus, worker ``k`` of ``n`` only processes files
        ``k, k + n, k + 2n, ...`` of the shuffled list.  Outside a worker
        process the whole list is used.

        :return: Generator of ``(sample, label)`` pairs as produced by the
            extractor.
        :rtype: Generator[tuple[NDArray, NDArray], None, None]
        """
        worker = get_worker_info()
        if worker is None:
            # Single-process loading: this iterator owns the whole corpus.
            paths = self._source_flac_paths
        else:
            # Strided shard: disjoint across workers, jointly exhaustive.
            paths = self._source_flac_paths[worker.id :: worker.num_workers]

        for file in paths:
            waveform, _ = librosa.load(file, sr=self.fs, mono=True)
            yield from self.extractor(waveform)
if __name__ == "__main__":
    # Throughput smoke benchmark: stream batches from a local LibriSpeech
    # copy and print the running average time per batch.
    import time
    from itertools import islice
    from pathlib import Path

    from torch.utils.data import DataLoader

    from libdse.data.features import LogMelPowerSpectrumExtractor
    from libdse.data.noise import DEMANDNoiseDataset, DEMANDNoiseType

    # Upper bound on the number of batches to time.
    N_BATCHES = 500

    # Build a noise dataset covering all DEMAND environments.
    noise_ds = DEMANDNoiseDataset(
        entry_point=Path("data/noise/DEMAND"),
        noise_types=DEMANDNoiseType.ALL,
    )

    # Construct the extractor with desired STFT and mel parameters.
    extractor = LogMelPowerSpectrumExtractor(
        sampling_rate=16_000,
        window_length=512,
        hop_length=128,
        n_mels=40,
        chunks_per_feature=7,
        noise=noise_ds,
    )

    dataset = LibriSpeechDataset(
        entry_point=Path("data/train-clean-100"),
        extractor=extractor,
    )

    loader = DataLoader(dataset, batch_size=512)
    # Create the iterator before starting the clock so that its set-up is
    # not counted in the per-batch average.
    loader_iter = iter(loader)

    global_start = time.perf_counter()
    # islice stops cleanly if the loader runs out before N_BATCHES, instead
    # of letting StopIteration escape from a bare next() call.
    for i, _batch in enumerate(islice(loader_iter, N_BATCHES)):
        avg_time = (time.perf_counter() - global_start) / (i + 1)
        print(f"Batch {i} - {avg_time:.3f} s/batch")