Source code for libdse.nets

"""Neural network architectures for denoising and source separation.

Currently provides the following classes:

- :class:`VanillaAutoEncoder` - fully-connected DAE used in production.
- :class:`VariationalAutoEncoder` - placeholder for a future VAE variant.
- :class:`WaveUNet` - 1-D U-Net that operates on raw audio waveforms.
"""

import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
from itertools import pairwise


[docs] class VanillaAutoEncoder(nn.Module): """Symmetric fully-connected denoising autoencoder. The encoder compresses the input through a user-defined stack of linear → ReLU → LayerNorm (→ Dropout) layers down to a bottleneck of size *latent_dim*. The decoder mirrors this structure and maps the bottleneck back to the original input dimension. A :class:`~torch.nn.LayerNorm` is prepended to the encoder to normalise the raw input features to zero mean and unit variance. :param input_dim: Dimensionality of the input feature vector. :type input_dim: int :param latent_dim: Bottleneck (latent) dimensionality. :type latent_dim: int :param hidden_layer_struct: Ordered list of hidden-layer widths between the input and the bottleneck. *latent_dim* is appended automatically. Defaults to ``[1024, 512, 256, 128]``. :type hidden_layer_struct: list[int] or None :param dropout: Dropout probability applied after the *first* hidden layer of the encoder (and the corresponding decoder layer). ``None`` or ``0.0`` disables dropout. :type dropout: float or None """ def __init__( self, input_dim: int, latent_dim: int, hidden_layer_struct: list[int] | None = None, dropout: list[float] | None = None, ) -> None: """Build encoder and decoder :class:`~torch.nn.Sequential` stacks. Hidden layer widths are taken from *hidden_layer_struct* (with *latent_dim* appended); dropout is only applied after the first hidden layer. 
""" super().__init__() if hidden_layer_struct is None: hidden_layer_struct = [1024, 512, 256, 128, latent_dim] else: hidden_layer_struct.append(latent_dim) if dropout is None: dropout_struct = [0.0] * len(hidden_layer_struct) else: drst = [0.0] * (len(hidden_layer_struct) - 1) drst.insert(0, dropout) dropout_struct = drst # Normalise input to zero mean and unit variance, then feed through encoder_modules = [nn.LayerNorm(input_dim)] i_dim = input_dim for h_dim, dropout in zip(hidden_layer_struct, dropout_struct): encoder_modules.append( nn.Sequential( nn.Linear(in_features=i_dim, out_features=h_dim), nn.ReLU(), nn.LayerNorm(h_dim), nn.Dropout(dropout), ) ) i_dim = h_dim self.encoder = nn.Sequential(*encoder_modules) decoder_modules = [] i_dim = latent_dim for h_dim, dropout in zip( reversed(hidden_layer_struct[:-1]), reversed(dropout_struct[:-1]) ): decoder_modules.append( nn.Sequential( nn.Linear(in_features=i_dim, out_features=h_dim), nn.ReLU(), nn.LayerNorm(h_dim), nn.Dropout(dropout), ) ) i_dim = h_dim decoder_modules.append( nn.Sequential( nn.Linear( in_features=hidden_layer_struct[0], out_features=input_dim ), ) ) self.decoder = nn.Sequential(*decoder_modules)
[docs] def forward(self, input: Tensor) -> Tensor: """Encode *input* to the bottleneck, then decode back to input space. :param input: Feature batch, shape ``(B, input_dim)``. :type input: :class:`torch.Tensor` :return: Reconstructed batch, shape ``(B, input_dim)``. :rtype: :class:`torch.Tensor` """ return self.decoder(self.encoder(input))
[docs] class WaveUNet(nn.Module): """Wave-U-Net for end-to-end audio source separation (Stoller et al., 2018). **Conceptual overview** Wave-U-Net operates *directly on the raw audio waveform* - no STFT, no spectrogram. The architecture is a 1-D analogue of the image-segmentation U-Net: a contracting encoder path progressively halves the time resolution while doubling the number of feature channels, a bottleneck captures the most abstract representation, and a symmetric expanding decoder path recovers the original resolution step by step. The key insight that makes this work for separation is the **skip connections**: every encoder layer's output is concatenated (channel-wise) to the corresponding decoder layer's input. This gives the decoder access to fine-grained temporal detail that would otherwise be lost during downsampling, letting the network combine high-level context (what is happening globally) with low-level detail (exactly how the waveform looks locally) at every scale simultaneously. **Signal flow**:: raw audio → [DS 1] → decimate → [DS 2] → decimate → … → bottleneck ↓ ↓ saved saved (skip connections) ↓ ↓ output ← [US 1] ← upsample ← [US 2] ← upsample ← … **Channel schedule** (following Table 1 of the paper) Let ``F_c`` be the channel-growth factor. The encoder layer ``k`` (1-indexed) produces ``k * F_c`` channels. The bottleneck produces ``(n_layers + 1) * F_c`` channels. During decoding the skip connection from the *mirror* encoder layer is concatenated before the convolution, so the number of input channels to each decoder convolution equals the sum of the upsampled decoder channels and the corresponding encoder channels. **Output** The network predicts the *foreground* source (e.g. vocals / speech) as a residual mask on the original waveform. The *background* (accompaniment / noise residual) is obtained for free as ``original - foreground``, which enforces the implicit mixture constraint that both outputs must sum back to the input. 
:param n_layers: Number of encoder (= decoder) layers. More layers mean a larger receptive field and more levels of temporal abstraction. :type n_layers: int :param f_u: Kernel size of every upsampling (decoder) convolution. :type f_u: int :param f_d: Kernel size of every downsampling (encoder) and bottleneck convolution. :type f_d: int :param F_c: Base channel-growth factor. Encoder layer *k* will have ``k * F_c`` output channels. :type F_c: int Reference: Stoller, D., Ewert, S., & Dixon, S. (2018). *Wave-U-Net: A Multi-Scale Neural Network for End-to-End Audio Source Separation.* arXiv:1806.03185. """ def __init__(self, n_layers: int, f_u: int, f_d: int, F_c: int) -> None: """Build the encoder stack, bottleneck, decoder stack, and output layer. :param n_layers: Number of encoder/decoder layer pairs. :param f_u: Decoder convolution kernel size. :param f_d: Encoder / bottleneck convolution kernel size. :param F_c: Base channel multiplier (see class docstring). """ super().__init__() self.n_layers = n_layers # ------------------------------------------------------------------ # Encoder (downsampling path) # ------------------------------------------------------------------ # Each encoder layer is a single Conv1d that *learns* to summarise the # local neighbourhood. No pooling is used here; temporal downsampling # is done explicitly by decimation (stride-2 slicing) *after* the # activation in forward(). Separating the convolution from the # downsampling step means the convolution can still see the full # pre-decimation context. 
# # Channel schedule: [1, F_c, 2*F_c, ..., n_layers*F_c] # (the leading 1 is the single raw-audio input channel) self.encoder = nn.ModuleList() channels_ds = [1] + [F_c * i for i in range(1, n_layers + 1)] for ch_in, ch_out in pairwise(channels_ds): self.encoder.append( nn.Conv1d( in_channels=ch_in, out_channels=ch_out, kernel_size=f_d, ) ) # ------------------------------------------------------------------ # Bottleneck # ------------------------------------------------------------------ # The bottleneck sits between the encoder and decoder. It sees the # most heavily decimated (shortest) feature sequence and must capture # the global structure of the mixture. It outputs (n_layers+1)*F_c # channels - one step wider than the deepest encoder layer - giving # the decoder a richer starting point. self.bottleneck = nn.Sequential( nn.Conv1d( in_channels=channels_ds[-1], out_channels=F_c * (n_layers + 1), kernel_size=f_d, ), nn.LeakyReLU(0.2), ) # ------------------------------------------------------------------ # Decoder (upsampling path) # ------------------------------------------------------------------ # The decoder mirrors the encoder. Before each Conv1d the upsampled # feature map is *concatenated* with the corresponding encoder output # (skip connection), so the input channel count is the sum of both. 
# # Decoder-side channels *before* concatenation (bottleneck down to 1): # [(n_layers+1)*F_c, n_layers*F_c, ..., 1*F_c] # After concatenation with encoder skip (channels_ds in reverse): # channels_us[i] = decoder_ch[i] + encoder_ch[i] self.decoder = nn.ModuleList() channels_us = [F_c * i for i in range(n_layers + 1, 0, -1)] channels_us = [ channels_us[i] + channels_ds[i] for i in range(n_layers, 0, -1) ] for ch_in, ch_out in pairwise(channels_us): self.decoder.append( nn.Conv1d( in_channels=ch_in, out_channels=ch_out, kernel_size=f_u ) ) # ------------------------------------------------------------------ # Output layer # ------------------------------------------------------------------ # The very last step concatenates the final decoder output (F_c # channels) with the *original raw waveform* (1 channel), giving # F_c + 1 input channels. A pointwise (kernel_size=1) Conv1d then # collapses these to a single-channel waveform, and Tanh clamps the # predicted sample amplitudes to [-1, 1]. self.output_layer = nn.Sequential( nn.Conv1d(in_channels=F_c + 1, out_channels=1, kernel_size=1), nn.Tanh(), )
[docs] def center_crop(self, input: Tensor, target_shape: int) -> Tensor: """Crop the time axis of *input* symmetrically to *target_shape*. Because ``Conv1d`` without padding shortens the time axis by ``kernel_size - 1``, encoder and decoder tensors at the same depth will have slightly different lengths. Before concatenating a skip connection we therefore crop the *longer* tensor to the length of the *shorter* one, always removing an equal number of samples from both ends to keep the remaining samples centred in time. :param input: Tensor of shape ``(B, C, T_in)`` to be cropped. :type input: :class:`torch.Tensor` :param target_shape: Desired length *T_out* along the time axis. Must satisfy ``T_out <= T_in``. :type target_shape: int :return: Tensor of shape ``(B, C, T_out)``. :rtype: :class:`torch.Tensor` """ input_shape = input.shape[-1] shape_difference = target_shape - input_shape # Integer floor division ensures we always remove a whole number of # samples; the off-by-one remainder (for odd differences) is # silently absorbed by Python's negative-index slicing. lr_offset = shape_difference // 2 # samples to remove from each end return input[:, :, lr_offset:-lr_offset]
[docs] def stack_channels(self, input1: Tensor, input2: Tensor) -> Tensor: """Concatenate two feature maps along the channel dimension. This is the skip-connection merge operation. Because encoder and decoder tensors differ in length (due to unpadded convolutions), *input2* is centre-cropped to match the time length of *input1* before concatenation. :param input1: Primary tensor, shape ``(B, C1, T)``. Its time length determines the output length. :type input1: :class:`torch.Tensor` :param input2: Skip-connection tensor, shape ``(B, C2, T')``. Will be cropped to ``T`` along the time axis. :type input2: :class:`torch.Tensor` :return: Merged tensor of shape ``(B, C1 + C2, T)``. :rtype: :class:`torch.Tensor` """ input2 = self.center_crop(input2, input1.shape[-1]) return torch.cat([input1, input2], dim=1)
[docs] def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: """Run a full separation forward pass. The pass has four conceptual phases: 1. **Encoder**: ``n_layers`` rounds of Conv1d + LeakyReLU followed by hard decimation (keep every other sample). Each round halves the temporal resolution and increases the channel count by ``F_c``. The *pre-decimation* activations are stashed as skip connections. 2. **Bottleneck**: one Conv1d + LeakyReLU on the most compressed representation. 3. **Decoder**: ``n_layers`` rounds of linear interpolation back to the previous resolution, skip-connection concatenation, Conv1d, and LeakyReLU. The skip connections are consumed in reverse order (deepest encoder layer first). 4. **Output**: the decoder output is concatenated with the original raw waveform, collapsed to one channel by a pointwise convolution, and passed through Tanh. The complementary output (accompaniment / noise residual) is derived as ``original - foreground``, enforcing the mixture constraint. :param x: Raw waveform batch, shape ``(B, 1, T)``. :type x: :class:`torch.Tensor` :return: Tuple ``(foreground, background)`` where both tensors have shape ``(B, 1, T')``. ``T'`` is slightly shorter than ``T`` due to unpadded convolutions reducing the time axis at each layer. :rtype: tuple[:class:`torch.Tensor`, :class:`torch.Tensor`] """ # Keep the encoder activations so we can wire them as skip connections # to their mirror decoder layers later. enc_intermediates = list() # The raw input is saved so the network can reference the unmodified # mixture waveform in the final output layer. orig = x.clone() # ------------------------------------------------------------------ # Phase 1 - Encoder (downsampling) # ------------------------------------------------------------------ for layer_num in range(self.n_layers): # Learn local features at the current temporal resolution. 
x = self.encoder[layer_num](x) x = F.leaky_relu(x, 0.2) # Save the full-resolution activation for the skip connection # before decimation so the decoder can access it. enc_intermediates.append(x) # Decimate: discard every odd-indexed time step, effectively # halving the sequence length. This is equivalent to stride-2 # pooling but keeps the decimation logic separate from learning. x = x[:, :, ::2] # (B, C, T) → (B, C, T//2) # ------------------------------------------------------------------ # Phase 2 - Bottleneck # ------------------------------------------------------------------ # The most compressed representation passes through one final # convolution. From here on the network must reconstruct fine detail # solely from what it learned. x = self.bottleneck(x) # ------------------------------------------------------------------ # Phase 3 - Decoder (upsampling) # ------------------------------------------------------------------ for layer_num in range(self.n_layers): # Linear interpolation restores the time axis to roughly twice its # current length. The target size (2T - 1) ensures that the # upsampled grid aligns with the original sample positions when # align_corners=True: the first and last samples are pinned, and # new samples are inserted exactly between existing ones. x = F.interpolate( x, size=x.shape[-1] * 2 - 1, # every second sample is unchanged mode="linear", align_corners=True, ) # Retrieve the mirror encoder activation (deepest first) and # concatenate it channel-wise. This is the skip connection that # gives the decoder access to fine-grained temporal structure. x = self.stack_channels( x, enc_intermediates[self.n_layers - layer_num - 1] ) # Refine the merged representation and reduce the channel count # back towards F_c. 
x = self.decoder[layer_num](x) x = F.leaky_relu(x, 0.2) # ------------------------------------------------------------------ # Phase 4 - Output # ------------------------------------------------------------------ # Append the original waveform as an extra channel so the network can # learn a residual correction rather than generating the output from # scratch. This biases the model towards the right answer and # generally speeds up convergence. second_to_last = self.stack_channels(x, orig) # Pointwise conv collapses all channels to one; Tanh keeps amplitudes # in [-1, 1], matching the range of normalised audio. output = self.output_layer(second_to_last) # The accompaniment is obtained for free: because output + accompaniment # must equal the original mixture, accompaniment = original - output. # We crop 'orig' to the (slightly shorter) output length first. output_accompaniment = self.center_crop(orig, output.shape[-1]) - output return output, output_accompaniment