Odunolaoluwa Shadrack Jenrola

On Streaming and Causal Convolutions

Some sections of this article include Google Drive links to Manim animations to help explain some concepts more clearly. Please look out for them.

In customer-facing, speech- and audio-based systems, such as speech-to-text (Automatic Speech Recognition), text-to-speech, voice enhancement, and voice cloning, user interactions generally occur in two modes: offline (non-streaming) and streaming inference.

In offline mode, requests are processed without the need for real-time responses or the latency constraints typically associated with human conversation. For example, in a meeting transcription service, users upload a recording of a meeting. The system processes it to produce detailed minutes, including speaker turns, discussion points, and key takeaways. In this case, the model is typically optimized for high throughput rather than low latency. The goal is to process large audio segments efficiently, rather than responding in real-time.

Streaming inference involves real-time processing and the continuous generation of responses in response to user input. The system handles incoming audio as it arrives and produces incremental outputs conditioned on past context with minimal delay. This is important for applications such as live transcriptions, real-time translation, and conversational agents, where responses must align with the user’s speech rhythm. Models in these cases are optimized for low latency.

Streaming inference imposes a hard constraint: the model cannot access future audio frames. Unlike OpenAI’s Whisper, which analyzes the entire waveform before producing an output, streaming models such as Kyutai’s Moshi and Microsoft’s VibeVoice generate predictions based only on current and past context.

Before diving deeper into how models handle such constraints, it is useful to first examine an important building block of many modern speech and audio architectures: convolutions.

At their core, convolutions are mathematical operations that aggregate information from a local neighborhood of input values. In one dimension, the convolution of an input signal x with a kernel w is given by:

$$y[t] = \sum_{k=-(K-1)/2}^{(K-1)/2} w[k] \cdot x[t-k]$$

Here, y[t] is the output at time t, K is the kernel size, and each w[k] learns to detect a temporal pattern. A convolution kernel acts as a window: in one dimension, imagine it as a straight line of length K sliding across an input signal of length L with C channels, taking a weighted sum of the signal beneath it at each position t.

The rest of this blog will focus on one-dimensional convolutions.

Let’s write some NumPy code to simulate the convolution operation on an audio signal with multiple channels.

import numpy as np

def conv1d(input:np.ndarray, kernel_size:int):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)
    output_length = input_length - kernel_size + 1

    output = np.empty((num_channels, output_length), dtype = input.dtype)

    for channel in range(num_channels):
        for position in range(output_length):
            start = position
            end = start + kernel_size
            output[channel, position] = np.dot(kernel[channel,:], input[channel, start:end])

    return output

num_channels = 3
kernel_size = 3
input_length = 10
input = np.random.randn(num_channels, input_length) # multi-channel signal of shape (num_channels, input_length)
output = conv1d(input, kernel_size)
print(output.shape) # (3,8)

Manim Animation for conv1d here

I’d like you to notice two things in this code. The first is the formula that computes the length of our convolution’s output:

$$L_{out} = L_{in} - K + 1$$

It may not be obvious why this computes the output length. The kernel can only be placed where it fully fits inside the sequence. The first valid position is when the kernel’s left edge aligns with index 0, and the last valid position is when its right edge aligns with index L_in - 1. As the kernel slides from left to right, its left edge can therefore occupy only L_in - K + 1 positions before the kernel runs off the end.

Using the same parameters as our previous convolution operations, our sliding kernel can occupy the following positions:

[0:3]
[1:4]
[2:5]
[3:6]
[4:7]
[5:8]
[6:9]
[7:10]
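Here is a quick sanity check of that count, enumerating the same windows in code:

input_length, kernel_size = 10, 3
windows = [(start, start + kernel_size) for start in range(input_length - kernel_size + 1)]
print(len(windows)) # 8, matching L_out = L_in - K + 1
print(windows) # [(0, 3), (1, 4), ..., (7, 10)]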

Secondly, notice that we effectively slide through the multi-channel sequence one position at a time. But what if we want to move S positions at a time, where S>1? Perhaps sliding one position at a time is too computationally intensive, or we simply do not need that much detail. This argument S is called the stride of a convolution operation, defined as the number of positions the kernel moves at a time as it traverses the signal.

Using the arguments in our previous code snippet, the positions our kernel can have if S=2 are:

[0:3]
[2:5]
[4:7]
[6:9]

The formula of the output length gets modified to

$$L_{out} = \left\lfloor \frac{L_{in} - K}{S} \right\rfloor + 1$$

Let’s modify the previous code snippet to show how stride affects the convolution operation.

num_channels = 3
input_length = 10
kernel_size = 3
stride = 2

def conv1d_with_stride(input:np.ndarray, kernel_size:int, stride:int):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)
    output_length = ((input_length - kernel_size) // stride) + 1

    output = np.empty((num_channels, output_length), dtype = input.dtype)
    for channel in range(num_channels):
        for position in range(output_length):
            start = position * stride
            end = start + kernel_size
            output[channel, position] = np.dot(kernel[channel, :], input[channel, start:end])
    return output

input = np.random.randn(num_channels, input_length)
output = conv1d_with_stride(input, kernel_size, stride)
print(output.shape) #(3,4)

Manim Animation for conv1d with stride here

Great, we have two more parameters to introduce. Up until now, the kernel window has been dense, meaning that every element in the kernel looks at consecutive input samples. But what if we want to expand the field of view of the kernel without increasing its size or adding more parameters? That’s where dilation comes in.

Dilation introduces gaps between the input elements sampled by the kernel. The dilation rate D controls the distance between input samples when computing each output. I should clarify that dilation does not add trainable parameters; it only changes the spacing of the sampled elements. When D = 1, we get the standard convolution. When D = 2, the kernel skips over every other input value, effectively doubling the receptive field without increasing the size of the kernel itself.

The indices in the kernel appear as follows.

Standard Convolution → [0, 1, 2]

Dilation → [0, 2, 4]

The formula of the output length becomes

$$L_{out} = \left\lfloor \frac{L_{in} - D(K-1) - 1}{S} \right\rfloor + 1$$

And the code snippet becomes

num_channels = 3
input_length = 10
kernel_size = 3
stride = 2
dilation = 2

def conv1d_with_stride_and_dilation(input:np.ndarray, kernel_size:int, stride:int, dilation:int):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)

    effective_kernel = (kernel_size-1) * dilation + 1
    output_length = ((input_length - effective_kernel) // stride) + 1

    output = np.empty((num_channels, output_length), dtype = input.dtype)
    for channel in range(num_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            output[channel, position] = np.dot(kernel[channel, :], input[channel, indices])
    return output

input = np.random.randn(num_channels, input_length)
output = conv1d_with_stride_and_dilation(input, kernel_size, stride, dilation)
print(output.shape) # (3, 3)

Manim Animation for conv1d with stride and dilation here

Finally, we may want to control the output length by deciding how to handle the boundaries of the input before performing convolution. This is where padding comes in. Padding adds extra values (typically zeros) to the ends of the input signal, allowing the kernel to cover edge positions it could not otherwise reach.

This becomes especially important when we want the output length to match the input length (L_out = L_in), a common requirement in deep convolutional networks that use residual or skip connections, where feature maps from different layers must align in shape for addition or concatenation.

The formula of our output length finally becomes

$$L_{out} = \left\lfloor \frac{L_{in} + P_{left} + P_{right} - D(K-1) - 1}{S} \right\rfloor + 1$$

where P_left and P_right denote the number of extra values added to the left and right of the signal, respectively.

Our code now becomes

num_channels = 3
input_length = 10
kernel_size = 3
stride = 2
dilation = 2
padding = (2,2)

def conv1d_with_stride_dilation_padding(input:np.ndarray, kernel_size:int, stride:int, dilation:int, padding:tuple):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)

    effective_kernel = (kernel_size-1) * dilation + 1
    output_length = ((input_length - effective_kernel + sum(padding)) // stride) + 1

    output = np.empty((num_channels, output_length), dtype = input.dtype)
    padded_input = np.pad(input, pad_width = ((0,0), padding), mode = 'constant', constant_values=0)

    for channel in range(num_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            output[channel, position] = np.dot(kernel[channel, :], padded_input[channel, indices])
    return output

input = np.random.randn(num_channels, input_length)
output = conv1d_with_stride_dilation_padding(input, kernel_size, stride, dilation, padding)
print(output.shape) # (3, 5)
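As a quick illustration of the same-length case mentioned earlier, here is a sketch that reuses the function above with stride 1 and a symmetric padding of D * (K - 1) / 2 on each side:

same_padding = (1, 1) # total padding of D * (K - 1) = 2 for K = 3, D = 1, split evenly
output_same = conv1d_with_stride_dilation_padding(input, kernel_size = 3, stride = 1, dilation = 1, padding = same_padding)
print(output_same.shape) # (3, 10), so L_out equals L_in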

Manim Animation for conv1d with stride, dilation and padding Link

We’ll add two more features to our function to arrive at the final, general conv1d operation.

Sometimes we do not want the number of output channels to equal the number of input channels. Let’s redefine our convolution kernel to have shape (out_channels, in_channels, K), so that each output channel is a combination of all input channels. We also add a bias term, similar to dense layers, to give each output channel an independent offset not tied to the convolution itself. The output length formula stays the same, but now we have a complete implementation of the convolution operation.

num_input_channels = 3
num_output_channels = 5
input_length = 10
kernel_size = 3
stride = 2
dilation = 2
padding = (2,2)

def general_conv1d(
        input:np.ndarray,
        num_output_channels,
        kernel_size:int,
        stride:int,
        dilation:int,
        padding:tuple,
        bias = True
):
    num_input_channels, input_length = input.shape
    kernel = np.random.randn(num_output_channels, num_input_channels,kernel_size)
    bias = np.random.randn(num_output_channels) if bias is True else np.zeros(num_output_channels)

    effective_kernel = (kernel_size-1) * dilation + 1
    output_length = ((input_length - effective_kernel + sum(padding)) // stride) + 1

    output = np.empty((num_output_channels,output_length), dtype = input.dtype)
    padded_input = np.pad(input, pad_width = ((0,0), padding), mode = 'constant', constant_values=0)

    for channel in range(num_output_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            value = 0
            for in_channel in range(num_input_channels):
                result = np.dot(kernel[channel, in_channel,:], padded_input[in_channel, indices])
                value += result
            output[channel, position] = value + bias[channel]
    return output

input = np.random.randn(num_input_channels, input_length)
output = general_conv1d(input, num_output_channels,kernel_size, stride, dilation, padding, bias = True)
print(output.shape) # (5, 5)

Nice! We’ve so far derived the convolution operation.

We can compute a one-dimensional convolution in PyTorch directly, like this:

import torch
from torch import nn

num_input_channels = 3
num_output_channels = 5
input_length = 10
kernel_size = 3
stride = 2
dilation = 2
padding = 2 # nn.Conv1d pads both sides by this amount
bias = True

conv = nn.Conv1d(
    in_channels = num_input_channels,
    out_channels = num_output_channels,
    kernel_size = kernel_size,
    stride = stride,
    dilation = dilation,
    padding = padding,
    bias = bias
)

input = torch.randn(num_input_channels,input_length)
output = conv(input)
print(output.shape)
torch.Size([5, 5])

We will not discuss the groups argument in this blog. You can refer to the PyTorch documentation for nn.Conv1d to learn how it affects convolutions.

When applying convolutional operations in neural networks, the learnable parameters are the elements of the kernel w[k]. Each w[k] represents how strongly the model should emphasize or suppress the input sample x[tk].

There’s one other thing we mentioned earlier that you should pay some attention to: the receptive field.

The receptive field of a convolution refers to the range of input samples that influence a single output value.

When the kernel slides over the signal, the receptive field is simply the stretch of the input that the kernel covers at each position t.

We referred to this as the effective kernel when deriving the general convolution operation earlier, and we denote the receptive field of a convolution as R. When the dilation is 1, R is simply the kernel size K; when the dilation is greater than 1, it becomes

$$R = D \cdot (K - 1) + 1$$

Recall that dilation increases the spacing between the kernel elements, so the convolution sees a wider portion of the input without increasing the size of the kernel and parameter count.

It is helpful to note that a single kernel (the only learnable parameter in a convolution aside from the bias) is reused across the entire sequence. This weight sharing is what makes convolutional layers parameter-efficient: a kernel of shape (C, K) contains far fewer parameters than a dense layer connecting all input positions.
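As a rough, back-of-the-envelope comparison (a sketch; the exact dense-layer count depends on how you choose to connect positions):

num_channels, kernel_size, sequence_length = 3, 3, 10
conv_params = num_channels * kernel_size # 9 shared weights, reused at every position
dense_params = num_channels * sequence_length * sequence_length # 300 weights if every output position connects to every input position, per channel
print(conv_params, dense_params) # 9 300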

Also, as the input signal grows longer, the compute and memory cost of convolutions grows with it: the kernel must be applied at every position across every channel, and the resulting activations must be stored. In practice, this trade-off between parameter efficiency and computational cost becomes more pronounced in audio and speech models that process long sequences.

To recap, we’ve so far covered the basic one-dimensional convolution, its stride, dilation, and padding arguments, multiple input and output channels with a bias term, and the receptive field.

With these basics in place, we can now introduce causal convolutions. But first, let’s understand what causal systems are. A system is causal if its output at time t depends only on present and past inputs, never on future inputs. Future samples are unavailable in real-time processing, so streaming models must never rely on them.

In decoder-only Transformers, such as GPT, causality is enforced through attention masking. This prevents each token from attending to future positions and ensures that predictions at time t depend solely on tokens up to t.
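For intuition, such a mask is just a lower-triangular boolean matrix. Here is a minimal sketch for a sequence of length 5, where True marks the positions a token is allowed to attend to:

import torch

causal_mask = torch.tril(torch.ones(5, 5, dtype = torch.bool)) # row t can attend to columns 0..t
print(causal_mask)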

In standard convolutions, the receptive field of the kernel is centered around the current position, meaning that it extends into the future. This makes it unsuitable for real-time applications.

In order to make the convolution operation causal, we modify the convolution so that each output y[t] depends only on inputs from time steps 0 to t.

Let’s review the convolution equation from a while back.

$$y[t] = \sum_{k=-(K-1)/2}^{(K-1)/2} w[k] \cdot x[t-k]$$

Here, k indexes positions within the kernel, ranging from -(K-1)/2 to +(K-1)/2. One limit corresponds to the leftmost part of the kernel (past inputs), while the other corresponds to the rightmost part (future inputs).

Consider a specific timestep t on the input signal. When the kernel reaches this point, it is centered there: the middle of the kernel aligns exactly with x[t]. Because the kernel extends both backward and forward around this center, the convolution uses samples from both the past and the future relative to t.

This is exactly why standard convolution is not causal. The receptive field at time t includes input values that have not yet occurred in real time. Effectively, the model is ā€œlooking aheadā€.

Now, we wish to make this operation causal. To do this, we modify the limits of the summation so that only the current and past inputs contribute to each output. This means shifting the receptive field entirely to the left, up to and including the current point in time.

We express causal convolution as:

$$y[t] = \sum_{k=0}^{K-1} w[k] \cdot x[t-k]$$

Here, the summation index k now starts from 0 and looks only backward in time. The receptive field size remains the same as in the standard convolution, but the key difference is that the kernel no longer overlaps future samples of the input signal.

This means that when the kernel reaches position t, its rightmost edge aligns with x[t], while the remaining kernel elements cover only past inputs (x[t-1], x[t-2], ..., x[t-(K-1)]).

This modification preserves the temporal order of the signal, making the convolution suitable for streaming and autoregressive tasks, where future inputs must never influence the current output.
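As a direct, minimal reading of the causal formula in code (a sketch on a single-channel signal):

import numpy as np

x = np.random.randn(20) # a single-channel signal
w = np.random.randn(3) # a kernel with K = 3
t = 10
y_t = sum(w[k] * x[t - k] for k in range(3)) # uses only x[10], x[9] and x[8]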

Now how do we make a regular convolution causal?

To do this, we pad the signal on the left (the past) before performing the operation. This padding controls where the convolution kernel ā€œstartsā€ and ā€œendsā€ relative to the input sequence.

Recall that in a standard 1D convolution, the kernel is centered around the current position.

For example, with kernel size K = 3 and dilation D = 1, the receptive field spans three samples: x[t-1], x[t], x[t+1]. This means it looks one step into the future.

However, for this convolution to be causal, we want each output y[t] to depend only on the current and past inputs x[t], x[t-1], x[t-2].

To do this, we pad the left side of the signal with (K - 1) zeros. This effectively shifts the receptive field leftward so that it no longer extends into the future, ensuring that the model cannot access information from later time steps.

How about if dilation D>1?

In this case, the minimum left padding required to achieve causality is given by:

$$P_{left} = (K - 1) \times D$$

Let’s take K=3 and D=2:

Computing the above equation gives a required left padding of 4. Our effective receptive field also becomes 5 (recall R = (K - 1) × D + 1).

So now, given the sequence

x = [x0, x1, x2, x3, x4]

with indices

indices=[0, 1, 2, 3, 4]

We first pad with the required number of zeros.

x = [0, 0, 0, 0, x0, x1, x2, x3, x4]

Our indices, relative to the original signal, now become

indices = [-4, -3, -2, -1, 0, 1, 2, 3, 4]

At position 0,

The convolution kernel operates over the indices [-4, -2, 0]. These correspond to the values [0, 0, x0], so the kernel never looks ahead.

At position 1,

Input indices: [-3, -1, 1]; values: [0, 0, x1]

At position 2,

Input indices: [-2, 0, 2]; values: [0, x0, x2]

At further outputs, you’d see that the kernel never sees into the future. We achieved causality!
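Here is a small snippet that reproduces this index walk (stride 1, as in the example above):

K, D = 3, 2
left_padding = (K - 1) * D
for position in range(3):
    padded_indices = [position + k * D for k in range(K)] # indices into the padded signal
    original_indices = [i - left_padding for i in padded_indices] # indices relative to the original signal
    print(position, original_indices)
# 0 [-4, -2, 0]
# 1 [-3, -1, 1]
# 2 [-2, 0, 2]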

Let’s update our convolution code.

num_input_channels = 3
num_output_channels = 5
input_length = 10
kernel_size = 3
stride = 2
dilation = 2

def causal_conv1d(
        input:np.ndarray,
        num_output_channels,
        kernel_size:int,
        stride:int,
        dilation:int,
        bias = True
):
    num_input_channels, input_length = input.shape
    kernel = np.random.randn(num_output_channels, num_input_channels,kernel_size)
    bias = np.random.randn(num_output_channels) if bias is True else np.zeros(num_output_channels)

    effective_kernel = (kernel_size-1) * dilation + 1
    left_padding = (kernel_size - 1) * dilation
    padding = (left_padding, 0)

    output_length = ((input_length - effective_kernel + sum(padding)) // stride) + 1
    output = np.empty((num_output_channels,output_length), dtype = input.dtype)

    padded_input = np.pad(input, pad_width = ((0,0), padding), mode = 'constant', constant_values=0)

    for channel in range(num_output_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            value = 0
            for in_channel in range(num_input_channels):
                result = np.dot(kernel[channel, in_channel,:], padded_input[in_channel, indices])
                value += result
            output[channel, position] = value + bias[channel]
    return output

input = np.random.randn(num_input_channels, input_length)
output = causal_conv1d(input, num_output_channels,kernel_size, stride, dilation, bias = True)
print(output.shape) #  (5,5)

We can also implement this in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.left_padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, dilation=dilation, bias=bias)

    def forward(self, x):
        x = F.pad(x, (self.left_padding, 0))
        return self.conv(x)

input = torch.randn(1,3,10)
causal_conv = CausalConv1d(3,5, kernel_size = 3, dilation = 2)
output = causal_conv(input)
print(output.shape)
# torch.Size([1, 5, 10])

Manim Animation for Causal Convolutions: Link

Now, what happens when we stack convolutions, like what you’d see in state-of-the-art neural audio codecs? Why even have a stack of convolutions?

Recall that the receptive field is the range of input samples that influence an output. When you stack several convolutional layers, each layer expands the range of past samples that can influence an output, effectively increasing the receptive field of the entire system.

The receptive field of a single convolution as stated before is

$$R = D \times (K - 1) + 1$$

Assume we have three convolutions stacked atop one another with, for concreteness, the following arguments: Layer 1 with kernel size K1 = 3, dilation D1 = 1, and stride S1 = 2; Layer 2 with K2 = 4, D2 = 2, and S2 = 2; and Layer 3 with K3 = 3 and D3 = 2.

Let the input sample to the stacked convolution module be X0

X0 passes through Layer 1 to return X1. The receptive field so far is given as

$$R_1 = D_1 \times (K_1 - 1) + 1 = 1 \times (3 - 1) + 1 = 3$$

X1 passes through Layer 2 to return X2. The receptive field relative to X0 becomes

$$R_2 = R_1 + (K_2 - 1) \times D_2 \times S_1 = 3 + (4 - 1) \times 2 \times 2 = 15$$

But how does stride get into the computation?

Let X0=[0,1,2,3,4,5,6]

Recall that the stride tells us how many temporal steps the convolution kernel jumps as it moves over the signal.

After Layer 1 with K1 = 3 and S1 = 2, X1 has three values: X1[0] summarizes X0[0:3], X1[1] summarizes X0[2:5], and X1[2] summarizes X0[4:7].

Notice the stride effect: neighbouring values of X1 are centered two samples apart in X0.

Now, when Layer 2 looks at the three values of X1, it is also looking at summaries of all seven values of X0. Because Layer 1 has a stride of 2, moving 1 step in X1 equals moving S1 = 2 steps in X0. So if Layer 2 has kernel size K2 = 4 and dilation D2 = 2, its window, which is (K2 - 1) × D2 steps wide in X1, spans (K2 - 1) × D2 × S1 = 12 samples of X0 beyond the receptive field already contributed by Layer 1.

This is why we multiply by the cumulative product of strides from previous layers.

Now, for Layer 3: X2 passes through Layer 3 to return X3. The receptive field relative to X0 becomes

$$R_3 = R_2 + (K_3 - 1) \times D_3 \times S_2 \times S_1 = 15 + (3 - 1) \times 2 \times 2 \times 2 = 31$$

A general equation for the effective receptive field of a convolutional network with L layers is

$$R_L = 1 + \sum_{i=1}^{L} \left[ (K_i - 1) \times D_i \times \prod_{j=1}^{i-1} S_j \right]$$

It looks daunting at first, but it is really just the recursive computation we went through earlier, generalized to any number of layers L.
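As a quick check, here is the formula applied to the three-layer example above, using the layer arguments we assumed for concreteness:

kernel_sizes = [3, 4, 3]
dilations = [1, 2, 2]
strides = [2, 2, 1] # the stride of the last layer never enters the computation
receptive_field = 1
stride_product = 1
for K, D, S in zip(kernel_sizes, dilations, strides):
    receptive_field += (K - 1) * D * stride_product
    stride_product *= S
print(receptive_field) # 31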

So, by stacking convolutions, we have increased the receptive field from 3 to 15 to 31, a considerable expansion. This growing contextual window enables a powerful hierarchical feature extraction process.

Early layers with smaller receptive fields would capture fine-grained, local patterns in the input signal. In audio applications, these might correspond to basic acoustic features, such as specific frequencies, onsets, or phoneme-level characteristics.

As we progress deeper and the receptive field expands, each subsequent layer can integrate information across these local patterns to form more complex, abstract representations.

This layer-wise processing allows the network to capture both fine and broad patterns, to understand brief acoustic events, and to grasp long-term acoustic structure.

In PyTorch, we can implement a simple stacked convolutional neural network like this.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvModule(nn.Module):
    def __init__(self, channels:list, kernel_sizes:list, dilations:list, strides:list, bias = False):
        super().__init__()
        num_layers = len(channels) - 1
        assert len(kernel_sizes) == len(dilations) == len(strides) == num_layers
        layers = [
            nn.Conv1d(
                in_channels = channels[i],
                out_channels = channels[i+1],
                kernel_size = kernel_sizes[i],
                stride = strides[i],
                dilation = dilations[i],
                bias = bias,
            )
        for i in range(num_layers)]
        self.layers = nn.Sequential(*layers)

        self.num_layers = num_layers
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.strides = strides

    def forward(self, tensor):
        output = tensor
        for layer in self.layers:
            output = layer(output)
        return output

    def get_receptive_field(self):
        stride_product = 1
        receptive_field = 1
        for kernel_size, dilation, stride in zip(self.kernel_sizes, self.dilations, self.strides):
            receptive_field += (kernel_size - 1) * dilation * stride_product
            stride_product *= stride
        return receptive_field

channels = [1,3,5,7,11]
kernel_sizes = [3,3,3,3]
dilations = [1,2,1,2]
strides = [2,1,2,1]

conv_stack = ConvModule(
    channels=channels,
    kernel_sizes=kernel_sizes,
    dilations = dilations,
    strides = strides
)
input_signal = torch.randn(2,1,24000)
output = conv_stack(input_signal)
print(output.shape)
print(conv_stack.get_receptive_field())
torch.Size([2, 11, 5993])
31

For a causal convolutional stack, we simply swap in the CausalConv1d layer we defined earlier:

class CausalConvModule(nn.Module):
    def __init__(self, channels:list, kernel_sizes:list, dilations:list, strides:list, bias = False):
        super().__init__()
        num_layers = len(channels) - 1
        assert len(kernel_sizes) == len(dilations) == len(strides) == num_layers
        layers = [
            CausalConv1d(
                in_channels = channels[i],
                out_channels = channels[i+1],
                kernel_size = kernel_sizes[i],
                stride = strides[i],
                dilation = dilations[i],
                bias = bias,
            )
        for i in range(num_layers)]
        self.layers = nn.Sequential(*layers)

        self.num_layers = num_layers
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.strides = strides

    def forward(self, tensor):
        output = tensor
        for layer in self.layers:
            output = layer(output)
        return output

    def get_receptive_field(self):
        # Causality changes where the receptive field sits (entirely in the past), not its size,
        # so the computation is the same as in the non-causal stack.
        stride_product = 1
        receptive_field = 1
        for kernel_size, dilation, stride in zip(self.kernel_sizes, self.dilations, self.strides):
            receptive_field += (kernel_size - 1) * dilation * stride_product
            stride_product *= stride
        return receptive_field

channels = [1,3,5,7,11]
kernel_sizes = [3,3,3,3]
dilations = [1,2,1,2]
strides = [2,1,2,1]

causal_conv_stack = CausalConvModule(
    channels=channels,
    kernel_sizes=kernel_sizes,
    dilations = dilations,
    strides = strides
)
input_signal = torch.randn(2,1,24000)
output = causal_conv_stack(input_signal)
print(output.shape)
print(causal_conv_stack.get_receptive_field())
torch.Size([2, 11, 6000])
31

Manim Animation for Stacked Convolutions: Link

Finally, let’s bring all that we’ve discussed so far back into the context of streaming inference.

When performing streaming inference with models that utilize causal convolutions, such as neural audio codecs, text-to-speech, and speech-to-text models, we run the model on an incoming audio stream (or token stream) chunk by chunk, without recomputing everything from scratch each time. This should be possible because, as we have established about causal convolutions, each output depends only on current and past inputs.

So, if you have already processed samples [0, ..., t-1], when the next chunk [t, ..., t+k] arrives, you do not need to re-run the whole model. All you have to do is reuse the last few samples. If you consider our previous PyTorch causal convolution code, instead of padding with zeros during inference, you pad with a number of previous samples.

For a single one-dimensional convolution, the amount of past context you need to prepend to your sequence is the same as the left padding required for causality:

$$\text{context} = D \times (K - 1)$$

In a neural network with multiple convolutional layers operating in streaming mode, each layer must maintain its own cache, holding the last few samples of its input needed to compute the next output frame. This cache ensures that when a new chunk arrives:

  1. The layer reuses the relevant past context rather than zero-padding as in plain causal convolutions.
  2. Computation continues without recomputing old frames.

For example, let’s take a three-layer causal stack with an incoming chunk xt:

  1. Layer 1 takes x_t prepended with the last context frames of its input from the previous timestep. We denote these frames as c1; they were previously stored in the cache. After the forward pass, the layer updates c1 with the most recent context_1 samples of its input (where context_1 is Layer 1’s context size).
  2. Layer 2 receives Layer 1’s output. It also prepends its own cache c2 from the previous step, runs its convolution, and updates c2 accordingly.
  3. Layer 3 performs the same operation with its cache, c3.

Formally, if f_i denotes the i-th convolutional layer, then for each time step:

$$y_i(t) = f_i([c_i, y_{i-1}(t)])$$

where y_i(t) denotes the output frame of layer i at time step t, and [c_i, y_{i-1}(t)] denotes concatenating the cache and the incoming frames along the time axis.

This makes the model’s streaming output identical to what you’d get if you processed the full sequence offline, except it does so incrementally.

Now let’s write some code to see how all of this works. First, we’ll implement a cache that stores previous frames for each layer.

import torch
from torch import nn
from typing import Optional

class Cache:
    def __init__(self):
        self.cache = {}

    def get(self, layer_id:str):
        states = self.cache.get(layer_id, None)
        return states

    def set(self, layer_id:str, states:torch.Tensor):
        self.cache[layer_id]= states.detach()

    def clear(self):
        self.cache = {}

The Cache is just a dictionary keyed by a layer_id that uniquely identifies each convolutional layer in the neural network.

Next, we write our causal convolutional layer with context, adding an identifier property to the class so we can pull the specific layer’s context from the cache. The layer has two forward passes: one for streaming and one for non-streaming scenarios. When not streaming, we left-pad the input tensor with zeros in causal fashion.

Then we declare our neural network.

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.left_padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, dilation=dilation, bias=bias)

        # Number of past input samples this layer keeps between chunks:
        # the causal left padding, reduced by (stride - 1).
        context_size = (kernel_size - 1) * dilation - (stride - 1)
        self.context_size = context_size if context_size > 0 else 0
        self.padding = self.context_size

    def forward(self, tensor, cache: Optional[Cache] = None):
        if cache is not None:
            return self._forward_streaming(tensor, cache)
        return self._forward_offline(tensor)

    @property
    def layer_id(self):
        return str(id(self))

    def _forward_streaming(self, tensor, cache:Cache):
        B, C, T = tensor.shape
        cached_states = cache.get(self.layer_id)
        if cached_states is None:
            cached_states = torch.zeros(B, C, self.context_size, device = tensor.device, dtype = tensor.dtype)

        input_with_context = torch.cat([cached_states, tensor], dim = 2)
        output = self.conv(input_with_context)

        if self.context_size > 0:
            total_input_length = input_with_context.shape[2]
            if total_input_length >= self.context_size:
                new_cache_start = total_input_length - self.context_size
                new_cache = input_with_context[:, :, new_cache_start:]
            else:
                new_cache = input_with_context

            cache.set(self.layer_id, new_cache)
        return output

    def _forward_offline(self, tensor):
        tensor = nn.functional.pad(tensor,(self.padding, 0))
        return self.conv(tensor)

class ConvModule(nn.Module):
    def __init__(self, channels:list, kernel_sizes:list, dilations:list, strides:list, bias = False):
        super().__init__()
        num_layers = len(channels) - 1
        assert len(kernel_sizes) == len(dilations) == len(strides) == num_layers
        layers = [
            CausalConv1d(
                in_channels = channels[i],
                out_channels = channels[i+1],
                kernel_size = kernel_sizes[i],
                stride = strides[i],
                dilation = dilations[i],
                bias = bias,
            )
        for i in range(num_layers)]
        self.layers = nn.Sequential(*layers)

        self.num_layers = num_layers
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.strides = strides

    def forward(self, tensor:torch.Tensor, cache:Optional[Cache] = None):
        output = tensor
        for layer in self.layers:
            output = layer(output, cache)
        return output

It is essential that we ensure the same results are obtained from the neural network, whether it is streaming or not. To test this, we can simulate a real-time scenario where audio chunks are fed into the neural network as they are made available.

channels = [1,3,5,7,11]
kernel_sizes = [3,3,3,3]
dilations = [1,2,1,2]
strides = [2,1,2,1]

cache = Cache()
causal_conv_stack = ConvModule(
    channels=channels,
    kernel_sizes=kernel_sizes,
    dilations = dilations,
    strides = strides
)
test_audio = torch.randn(2,1,24000)
output_frames = []

chunk_size = 2000 # New chunk of audio stream that keeps coming in
for i in range(0, test_audio.shape[-1], chunk_size):
    chunk = test_audio[...,i:i+chunk_size]
    output_frame = causal_conv_stack(
        chunk, cache = cache
    )
    output_frames.append(output_frame)

print('streaming inference')
stream_output = torch.concat(output_frames, dim = -1)
print(stream_output.shape)

print('offline inference')
offline_output = causal_conv_stack(test_audio)
print(offline_output.shape)

print(torch.abs(stream_output-offline_output).mean())
streaming inference
torch.Size([2, 11, 6000])
offline inference
torch.Size([2, 11, 6000])
tensor(0., grad_fn=<MeanBackward0>)

Great!

The output from sending the entire audio signal through the model matches the output from sending the audio in chunks in a streaming fashion.

One last thing before we call it a wrap: it would be nice to run our setup on an actual causal audio model. As part of an upcoming project, I’ve extracted the weights of the acoustic variational autoencoder built by Microsoft as part of their VibeVoice project. VibeVoice is a speech synthesis model specifically trained for long-form, expressive audio such as podcasts. The variational autoencoder from VibeVoice differs from the currently common approach of tokenizing audio into discrete tokens for language modeling. Instead, it keeps the audio latents as continuous representations, and language modeling is done directly on them. We’ll leave this idea, and my own upcoming related project, for a future write-up.

The goal is to simulate streaming audio through the variational autoencoder and compare the reconstructed audio against an offline pass.

git clone https://github.com/odunola499/latent-lm-lessons.git
cd latent-lm-lessons
pip install -e .

Now, you can access the code for this simulation in blogs/streaming_convolutions/stream_simulation.py. For convenience, so I can explain how it works, I’ll add it here as well. However, you should first skim through the model’s streaming cache and architecture, also available in the repository.

from queue import Queue
from threading import Thread
from latentlm.vae.audio.model import AcousticTokenizerModel, AcousticTokenizerConfig
from latentlm.vae.audio.cache import StreamingCache
import torchaudio
from torchaudio.io import StreamWriter
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
import math
import time

def pad_to_multiple(x: torch.Tensor, multiple: int = 3200) -> torch.Tensor:
    L = x.shape[-1]
    target_len = multiple * math.ceil(L / multiple)
    
    if L < target_len:
        pad_amount = target_len - L
        x = torch.nn.functional.pad(x, (0, pad_amount))
    
    return x

def compute_receptive_field(kernel_sizes, dilation, strides):
    stride_product = 1
    receptive_field = 1
    for stride, kernel_size in zip(strides, kernel_sizes):
        receptive_field += (kernel_size - 1) * dilation * stride_product
        stride_product *= stride
    return receptive_field

def load_model(device, repo_id='odunola/vibevoice_vae_weights'):
    cache = StreamingCache()
    config = AcousticTokenizerConfig()
    model = AcousticTokenizerModel(config).to(device)
    path = hf_hub_download(
        repo_id=repo_id,
        filename='acoustic.safetensors'
    )
    weights = load_file(path)
    model.load_state_dict(weights)
    return model, cache

def audio_stream(file_path, chunk_size):
    waveform, sample_rate = torchaudio.load(file_path)
    print("Starting streaming")
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=24000)
    waveform = waveform.mean(0, keepdim=True).unsqueeze(0)
    for i in range(0, waveform.shape[-1], chunk_size):
        yield waveform[..., i:i + chunk_size]

def producer(input_queue: Queue, file_path: str, chunk_size=6000):
    for chunk in audio_stream(file_path, chunk_size):
        chunk = pad_to_multiple(chunk, chunk_size)
        input_queue.put(chunk)
        time.sleep(0.05)
    input_queue.put(None)

def consumer(
        input_queue: Queue,
        output_queue: Queue,
        model: AcousticTokenizerModel,
        cache: StreamingCache,
        sample_indices: torch.Tensor,
        device
):
    while True:
        chunk = input_queue.get()
        if chunk is None:
            output_queue.put(None)
            break
        print(f"Input to model: {chunk.shape}")
        recon_chunk, _ = model(
            chunk.to(device),
            cache=cache,
            sample_indices=sample_indices,
            use_cache=True,
            debug=False
        )
        output_queue.put(recon_chunk)

def save_to_disk(stream_output_path: str, output_queue: Queue):
    writer = StreamWriter(stream_output_path)
    writer.add_audio_stream(sample_rate=24000, num_channels=1)
    writer.open()

    while True:
        chunk = output_queue.get()
        if chunk is None:
            print("Reached end")
            writer.close()
            break
        chunk = chunk[0].detach().cpu()
        chunk = chunk.squeeze(0).unsqueeze(-1)
        writer.write_audio_chunk(0, chunk)

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    input_queue = Queue()
    output_queue = Queue()
    model, cache = load_model(device)
    sample_indices = torch.tensor([0], device=device)

    file_path = 'audio.mp3'
    offline_output_path = 'offline_recon.wav'
    stream_output_path = 'stream_recon.wav'

    chunk_size = 3200

    producer_thread = Thread(target=producer, args=(
        input_queue, file_path, chunk_size)
                             )
    consumer_thread = Thread(target=consumer, args=(
        input_queue, output_queue, model, cache, sample_indices, device)
                             )
    save_thread = Thread(target=save_to_disk, args=(stream_output_path, output_queue))

    producer_thread.start()
    consumer_thread.start()
    save_thread.start()

    producer_thread.join()
    consumer_thread.join()
    save_thread.join()

    print('Finished Streaming inference')

    audio, sample_rate = torchaudio.load(file_path)
    audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=24000)
    audio = audio.unsqueeze(0).to(device)
    audio = pad_to_multiple(audio, chunk_size)
 
    recon, _ = model(audio)
    torchaudio.save(offline_output_path, recon[0].detach().cpu(), sample_rate=24000)

    print('Finished Offline inference')

Let’s talk about what is happening here.

We have two queues and three threads. In the first thread, the producer simulates a real-time audio source. In reality, this would be a microphone or socket stream. Our audio source is an audio file that our producer loads, resamples, and then slices into fixed-size chunks, pushing each chunk into the first queue, our input queue. Once all the chunks have been pushed, we push a None token to signal the end of the stream.

In the second thread, the consumer continuously retrieves items from the input queue and performs the actual streaming simulation. Every chunk is fed to the model’s forward method, and on every pass the caches of all causal convolutions in the model are updated. The model produces a latent representation for that specific chunk and also the reconstructed audio. I should add that this is a somewhat unconventional streaming task: in practice, the encoder of a VAE like this would stream latents to some downstream task, or the decoder would reconstruct audio from latents arriving from, say, an autoregressive generator.

The third thread retrieves the reconstructed chunks from the output queue and uses Torchaudio’s StreamWriter to write to the file incrementally, so we don't have to wait for the entire audio to finish processing before saving it to the file. This aligns with how real text-to-speech models synthesize audio for playback.

We set the chunk size to 3200, since this is the length of audio input that yields exactly one latent frame for this model. You can confirm this by multiplying the encoder’s strides together, since strides are what downsample the signal in CNNs, as we saw earlier.
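To make that relationship concrete, here is the arithmetic with made-up stride values whose product is 3200; the model’s actual strides live in its config in the repository:

example_strides = [5, 5, 4, 4, 8] # hypothetical values, chosen only so that their product is 3200
samples_per_latent_frame = 1
for s in example_strides:
    samples_per_latent_frame *= s
print(samples_per_latent_frame) # 3200 input samples per latent frame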

You can try running this to get results, but I have made some results available here, so you can listen straight away. Try out other settings too, such as chunk sizes larger and smaller than 3200. There are also a few more streaming examples in the repo, and I will be adding more soon.

Throughout this writeup, we’ve covered how convolutions work, how kernel sizes, dilations, and strides shape the behavior of a network, and why causal convolutions prevent future parts of a sequence from influencing the current timestep. We’ve also explored how streaming inference works in practice, with a cache to manage state. Finally, we put everything to the test on a trained audio variational autoencoder.

Thank you for stopping by! I hope this deep dive into convolutions has been as enlightening for you as it was exciting for me.

If you enjoyed this post, please consider giving a thumbs up below!

Further Reading