On Streaming and Causal Convolutions
Some sections of this article include Google Drive links to Manim animations to help explain some concepts more clearly. Please look out for them.
In customer-facing, speech- and audio-based systems, such as speech-to-text (Automatic Speech Recognition), text-to-speech, voice enhancement, and voice cloning, user interactions generally occur in two modes: offline (non-streaming) and streaming inference.
In offline mode, requests are processed without the need for real-time responses or the latency constraints typically associated with human conversation. For example, in a meeting transcription service, users upload a recording of a meeting. The system processes it to produce detailed minutes, including speaker turns, discussion points, and key takeaways. In this case, the model is typically optimized for high throughput rather than low latency. The goal is to process large audio segments efficiently, rather than responding in real-time.
Streaming inference involves real-time processing and the continuous generation of responses in response to user input. The system handles incoming audio as it arrives and produces incremental outputs conditioned on past context with minimal delay. This is important for applications such as live transcriptions, real-time translation, and conversational agents, where responses must align with the user's speech rhythm. Models in these cases are optimized for low latency.
Streaming inference requires that models cannot access future audio frames. Unlike OpenAI's Whisper, which analyzes the entire waveform before producing an output, streaming models like Kyutai's Moshi and Microsoft's VibeVoice generate predictions based only on current and past context.
Before diving deeper into how models handle such constraints, it is useful to first examine an important building block of many modern speech and audio architectures: convolutions.
At their core, convolutions are mathematical operations that aggregate information from a local neighborhood of input values. In one dimension, the convolution of an input signal $x$ with a kernel $w$ is given by:

$$y[t] = \sum_{k=-\lfloor K/2 \rfloor}^{\lfloor K/2 \rfloor} w[k]\, x[t+k]$$

Here, $y[t]$ is the output at time $t$, $K$ is the kernel size and each $w[k]$ learns to detect a temporal pattern. A convolution kernel acts as a window: in one dimension, imagine it as a straight line of length $K$ sliding across an input signal of length $L$ and channels $C$, taking a weighted sum of the signal below it at each position $t$.
The rest of this blog will focus on one-dimensional convolutions.
We'll write some code in NumPy to simulate the convolution operation given an audio signal with multiple channels.
import numpy as np

def conv1d(input: np.ndarray, kernel_size: int):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)
    output_length = input_length - kernel_size + 1
    output = np.empty((num_channels, output_length), dtype=input.dtype)
    for channel in range(num_channels):
        for position in range(output_length):
            start = position
            end = start + kernel_size
            output[channel, position] = np.dot(kernel[channel, :], input[channel, start:end])
    return output

num_channels = 3
kernel_size = 3
input_length = 10

input = np.random.randn(num_channels, input_length)  # two-dimensional array of shape (num_channels, input_length)
output = conv1d(input, kernel_size)
print(output.shape)  # (3, 8)
Manim Animation for conv1d here
I'd like you to notice two things in this code. The first is the formula to compute the length of the output of our convolution operation:

$L_{out} = L_{in} - K + 1$

It may be confusing why this computes the output length. At the start of the operation, you can only place the kernel where it fully fits inside the sequence. The first valid position is when the kernel's left edge aligns with index 0 of the sequence. The last valid position is when the kernel's right edge aligns with index $L_{in} - 1$. As the kernel moves from left to right, its left edge can only move $L_{in} - K$ positions before it falls off the edge.
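To convince ourselves, we can count the valid placements directly with a tiny helper (purely illustrative) and compare against $L_{in} - K + 1$:

```python
def count_valid_positions(input_length: int, kernel_size: int) -> int:
    # A placement is valid when the kernel fully fits inside the sequence.
    return sum(1 for start in range(input_length)
               if start + kernel_size <= input_length)

print(count_valid_positions(10, 3))  # 8, matching 10 - 3 + 1
```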
Using the same parameters as our previous convolution operations, our sliding kernel can occupy the following positions:
[0:3]
[1:4]
[2:5]
[3:6]
[4:7]
[5:8]
[6:9]
[7:10]
Secondly, notice that we effectively slide through the multi-channel sequence one position at a time. But what if we want to move $s$ positions at a time, where $s > 1$? Perhaps sliding one position at a time is too computationally intensive, or we simply do not need that much detail. This argument is called the stride of a convolution operation, defined as the number of positions the kernel moves at a time as it traverses the signal.
Using the arguments in our previous code snippet, the positions our kernel can occupy if $s = 2$ are:
[0:3]
[2:5]
[4:7]
[6:9]
The formula of the output length gets modified to

$L_{out} = \left\lfloor \frac{L_{in} - K}{s} \right\rfloor + 1$

Let's modify the previous code snippet to show how stride affects the convolution operation.
num_channels = 3
input_length = 10
kernel_size = 3
stride = 2

def conv1d_with_stride(input: np.ndarray, kernel_size: int, stride: int):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)
    output_length = ((input_length - kernel_size) // stride) + 1
    output = np.empty((num_channels, output_length), dtype=input.dtype)
    for channel in range(num_channels):
        for position in range(output_length):
            start = position * stride
            end = start + kernel_size
            output[channel, position] = np.dot(kernel[channel, :], input[channel, start:end])
    return output

input = np.random.randn(num_channels, input_length)
output = conv1d_with_stride(input, kernel_size, stride)
print(output.shape)  # (3, 4)
Manim Animation for conv1d with stride here
Great, we have two more parameters to introduce. Up until now, the kernel window has been dense, meaning that every element in the kernel looks at consecutive input samples. But what if we want to expand the field of view of the kernel without increasing its size or adding more parameters? That's where dilation comes in.
Dilation introduces gaps between the sampled input elements covered by the kernel. The dilation rate $d$ controls the distance between input samples when computing each output. I should clarify that dilation does not add trainable parameters; it only changes the spacing of the sampled elements. When $d = 1$, we get the standard convolution; when $d = 2$, the kernel skips over every other input value, effectively doubling the receptive field without having to increase the size of the kernel itself.
The indices sampled by the kernel appear as follows.
Standard convolution ($d = 1$) → [0, 1, 2]
Dilation $d = 2$ → [0, 2, 4]
The formula of the output length becomes

$L_{out} = \left\lfloor \frac{L_{in} - d(K - 1) - 1}{s} \right\rfloor + 1$
And the code snippet becomes
num_channels = 3
input_length = 10
kernel_size = 3
stride = 2
dilation = 2

def conv1d_with_stride_and_dilation(input: np.ndarray, kernel_size: int, stride: int, dilation: int):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)
    effective_kernel = (kernel_size - 1) * dilation + 1
    output_length = ((input_length - effective_kernel) // stride) + 1
    output = np.empty((num_channels, output_length), dtype=input.dtype)
    for channel in range(num_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            output[channel, position] = np.dot(kernel[channel, :], input[channel, indices])
    return output

input = np.random.randn(num_channels, input_length)
output = conv1d_with_stride_and_dilation(input, kernel_size, stride, dilation)
print(output.shape)
Manim Animation for conv1d with stride and dilation here
Finally, we may want to control the output length by deciding how to handle the boundaries of the input before performing convolution. This is where padding comes in. Padding adds extra values (typically zeros) to the ends of the input signal, allowing the kernel to slide over edge positions that would otherwise be inaccessible.
This becomes especially important when we want the output length to match the input length, a common requirement in deep convolutional networks that use residual or skip connections, where feature maps from different layers must align in shape for addition or concatenation.
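For instance, with stride 1 the output length matches the input length when the total padding equals $d(K - 1)$. A small sketch (the helper names here are mine, not from the snippets above):

```python
def same_length_padding(kernel_size: int, dilation: int = 1) -> tuple:
    # Total padding d * (K - 1), split as evenly as possible across both ends.
    total = dilation * (kernel_size - 1)
    left = total // 2
    return (left, total - left)

def output_length(L, K, stride=1, dilation=1, padding=(0, 0)):
    # Same output-length formula as in the padded convolution below.
    effective = (K - 1) * dilation + 1
    return (L - effective + sum(padding)) // stride + 1

pad = same_length_padding(kernel_size=3, dilation=2)
print(pad)                              # (2, 2)
print(output_length(10, 3, 1, 2, pad))  # 10: output length equals input length
```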
The formula of our output length finally becomes

$L_{out} = \left\lfloor \frac{L_{in} + p_{left} + p_{right} - d(K - 1) - 1}{s} \right\rfloor + 1$

where $p_{left}$, $p_{right}$ denote the extra values added to the left and right of the signal respectively.
Our code now becomes
num_channels = 3
input_length = 10
kernel_size = 3
stride = 2
dilation = 2
padding = (2, 2)

def conv1d_with_stride_dilation_padding(input: np.ndarray, kernel_size: int, stride: int, dilation: int, padding: tuple):
    num_channels, input_length = input.shape
    kernel = np.random.randn(num_channels, kernel_size)
    effective_kernel = (kernel_size - 1) * dilation + 1
    output_length = ((input_length - effective_kernel + sum(padding)) // stride) + 1
    output = np.empty((num_channels, output_length), dtype=input.dtype)
    padded_input = np.pad(input, pad_width=((0, 0), padding), mode='constant', constant_values=0)
    for channel in range(num_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            output[channel, position] = np.dot(kernel[channel, :], padded_input[channel, indices])
    return output

input = np.random.randn(num_channels, input_length)
output = conv1d_with_stride_dilation_padding(input, kernel_size, stride, dilation, padding)
print(output.shape)  # (3, 5)
Manim Animation for conv1d with stride, dilation and padding Link
We'll add two more cases to our function to derive the final, general conv1d operation.
We may sometimes wish not to limit the number of output channels to be equal to the number of input channels. Let's redefine our convolution kernel to have a shape of (out_channels, in_channels, K) so that each output channel is a combination of all input channels. We also add a bias term, similar to dense layers, to provide each output channel with an independent offset not tied to the convolution operation. The output length remains the same, but now we have a complete function of the convolution operation.
num_input_channels = 3
num_output_channels = 5
input_length = 10
kernel_size = 3
stride = 2
dilation = 2
padding = (2, 2)

def general_conv1d(
    input: np.ndarray,
    num_output_channels: int,
    kernel_size: int,
    stride: int,
    dilation: int,
    padding: tuple,
    bias=True,
):
    num_input_channels, input_length = input.shape
    kernel = np.random.randn(num_output_channels, num_input_channels, kernel_size)
    bias = np.random.randn(num_output_channels) if bias is True else np.zeros(num_output_channels)
    effective_kernel = (kernel_size - 1) * dilation + 1
    output_length = ((input_length - effective_kernel + sum(padding)) // stride) + 1
    output = np.empty((num_output_channels, output_length), dtype=input.dtype)
    padded_input = np.pad(input, pad_width=((0, 0), padding), mode='constant', constant_values=0)
    for channel in range(num_output_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            value = 0
            for in_channel in range(num_input_channels):
                value += np.dot(kernel[channel, in_channel, :], padded_input[in_channel, indices])
            output[channel, position] = value + bias[channel]
    return output

input = np.random.randn(num_input_channels, input_length)
output = general_conv1d(input, num_output_channels, kernel_size, stride, dilation, padding, bias=True)
print(output.shape)  # (5, 5)
Nice! We've now derived the full convolution operation.
We can compute a one-dimensional convolution in PyTorch straightaway, like this:
import torch
from torch import nn

num_input_channels = 3
num_output_channels = 5
input_length = 10
kernel_size = 3
stride = 2
dilation = 2
padding = 2
bias = True

conv = nn.Conv1d(
    in_channels=num_input_channels,
    out_channels=num_output_channels,
    kernel_size=kernel_size,
    stride=stride,
    dilation=dilation,
    padding=padding,
    bias=bias,
)

input = torch.randn(num_input_channels, input_length)
output = conv(input)
print(output.shape)  # torch.Size([5, 5])
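As a sanity check, we can feed one fixed kernel both to a manual loop (the same logic as our NumPy version, bias omitted) and to PyTorch's functional API, and confirm the two agree:

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 10))    # (in_channels, length)
w = rng.standard_normal((5, 3, 3))  # (out_channels, in_channels, K)
stride, dilation = 2, 2

# Manual convolution, mirroring the nested loops from earlier
eff = (w.shape[-1] - 1) * dilation + 1
out_len = (x.shape[-1] - eff) // stride + 1
manual = np.zeros((5, out_len))
for o in range(5):
    for p in range(out_len):
        idx = np.arange(p * stride, p * stride + eff, dilation)
        manual[o, p] = (w[o] * x[:, idx]).sum()

# PyTorch's functional conv1d with the same weights
torch_out = F.conv1d(torch.from_numpy(x)[None], torch.from_numpy(w),
                     stride=stride, dilation=dilation)[0].numpy()
print(np.allclose(manual, torch_out))  # True
```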
We will not discuss the groups argument in this blog. You can refer to the PyTorch documentation for nn.Conv1d to learn how it affects convolutions.
When applying convolutional operations in neural networks, the learnable parameters are the elements of the kernel $w[0], \dots, w[K-1]$. Each $w[k]$ represents how strongly the model should emphasize or suppress the input sample it is aligned with.
There's one other thing we mentioned earlier that you should pay some attention to: the receptive field.
The receptive field of a convolution refers to the range of input samples that influence a single output value.
When we had the kernel sliding over the signal, the receptive field is simply the range of the input signal that the kernel hovers over at each time $t$.
We referred to this as the effective kernel when deriving the general convolution operation earlier. We also saw this value change from just the kernel size $K$ when the dilation was 1 to

$R = (K - 1) \cdot d + 1$

when dilation was greater than 1. This is because dilation introduces spaces between the kernel elements, making the area over which the kernel sums larger. We denote the receptive field of a convolution as $R$.
Recall that dilation increases the spacing between the kernel elements, so the convolution sees a wider portion of the input without increasing the size of the kernel and parameter count.
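In code, the receptive field for a kernel of size 3 at a few dilation rates (a one-line helper for illustration):

```python
def receptive_field(kernel_size: int, dilation: int) -> int:
    # R = (K - 1) * d + 1: the span of input samples one output sees.
    return (kernel_size - 1) * dilation + 1

for d in (1, 2, 4):
    print(d, receptive_field(3, d))  # 1 -> 3, 2 -> 5, 4 -> 9
```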
It is helpful to note that a single kernel (which, aside from the bias, holds the only learnable parameters in a convolution) is reused across the entire sequence. This weight sharing is what makes convolutional layers parameter-efficient. A kernel of shape (out_channels, in_channels, K) contains far fewer parameters than a dense layer connecting all input positions.
Also, as the input signal length increases, the memory cost of convolutions increases quickly, too. This is because the convolutions must compute and store many overlapping multiplications as the kernel slides over long distances and channels. In practice, this trade-off between parameter efficiency and computational cost becomes more pronounced in audio and speech models that process long sequences.
To recap, we've so far covered:
- An overview of offline inference
- An overview of streaming inference
- Derived functions for the convolution operation
- Discussed the receptive field of a convolution
With these basics in place, next we introduce causal convolutions. But first, let's understand what causal systems are. A system is causal if its output at time $t$ depends only on present and past inputs, never on future inputs. Future samples are unavailable in real-time processing, so causality prevents access to them.
In decoder-only Transformers, such as GPT, causality is enforced through attention masking. This prevents each token from attending to future positions and ensures that predictions at time $t$ depend solely on tokens up to $t$.
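For intuition, this masking can be sketched in a few lines: a lower-triangular boolean mask lets position $t$ attend only to positions $0$ through $t$ (a minimal illustration, not GPT's actual implementation):

```python
import torch

# Causal attention mask for 5 positions: row t marks which positions
# t is allowed to attend to (only 0..t, never the future).
T = 5
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
print(mask.int())
```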
In standard convolutions, the receptive field of the kernel is centered around the current position, meaning that it extends into the future. This makes it unsuitable for real-time applications.
In order to make the convolution operation causal, we modify it so that each output depends only on inputs from time steps $t - K + 1$ through $t$.
Let's review the convolution equation from a while back:

$y[t] = \sum_{k=-\lfloor K/2 \rfloor}^{\lfloor K/2 \rfloor} w[k]\, x[t+k]$

Here, $k$ indexes positions within the kernel, ranging from $-\lfloor K/2 \rfloor$ to $\lfloor K/2 \rfloor$. The lower limit corresponds to the leftmost part of the kernel (past inputs), while the upper limit corresponds to the rightmost part (future inputs).
Consider a specific timestep $t$ on the input signal. When the kernel reaches this point, it is centered at it; this means that the middle of the kernel aligns exactly with $x[t]$. Because the kernel extends both backward and forward around this center point, the convolution operation uses samples from both the past and the future relative to $t$.
This is exactly why standard convolution is not causal. The receptive field at time $t$ includes input values that have not yet occurred in real time. Effectively, the model is "looking ahead".
Now, we wish to make this operation causal. To do this, we modify the limits of the summation so that only the current and past inputs contribute to each output. This means shifting the receptive field entirely to the left, up to the current point in time.
We express causal convolution as:

$y[t] = \sum_{k=0}^{K-1} w[k]\, x[t-k]$

Here, the summation index now starts from 0 and goes only backward in time. The receptive field size remains the same as in the standard convolution, but the key difference is that the kernel no longer overlaps future samples of the input signal.
This means that when the kernel reaches point $t$, its rightmost edge now aligns with $x[t]$ while the remaining kernel elements cover only the past inputs $x[t-1], \dots, x[t-K+1]$.
This modification preserves the temporal order of the signal, making the convolution suitable for streaming and autoregressive tasks, where future inputs must never influence the current output.
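We can check this numerically: with a causal sum, perturbing a future sample never changes earlier outputs (a small sketch with an arbitrary random kernel; out-of-range indices are treated as zeros):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(8)
w = rng.standard_normal(3)  # K = 3

def causal_output(x, w, t):
    # y[t] = sum_k w[k] * x[t - k], with x treated as 0 before the start
    return sum(w[k] * (x[t - k] if t - k >= 0 else 0.0)
               for k in range(len(w)))

y = np.array([causal_output(x, w, t) for t in range(len(x))])

x2 = x.copy()
x2[5] += 100.0  # change a "future" sample
y2 = np.array([causal_output(x2, w, t) for t in range(len(x2))])
print(np.allclose(y[:5], y2[:5]))  # True: outputs before t = 5 are unchanged
```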
Now how do we make a regular convolution causal?
To do this, we pad the signal on the left (the past) before performing the operation. This padding controls where the convolution kernel "starts" and "ends" relative to the input sequence.
Recall that in a standard 1D convolution, the kernel is centered around the current position.
For example, with a kernel size $K = 3$ and dilation $d = 1$, the receptive field spans three samples: $x[t-1]$, $x[t]$, $x[t+1]$. This means that it looks one step into the future.
However, for this convolution to be causal, we want each output to depend only on the current and past inputs $x[t]$, $x[t-1]$, $x[t-2]$.
To do this, we pad the left side of the signal with zeros. This effectively shifts the receptive field leftward so that it no longer extends into the future, ensuring that the model cannot access information from later time steps.
How about if dilation $d > 1$?
In this case, the minimum left padding required to achieve causality is given by:

$p_{left} = (K - 1) \cdot d$

Let's take $K = 3$ and $d = 2$:
Computing the above equation gives the left padding needed as $(3 - 1) \times 2 = 4$. Our effective receptive field also becomes 5 (recall $R = (K - 1) \cdot d + 1$).
So now, given the sequence

$x = [x_0, x_1, x_2, x_3, x_4, x_5]$

with indices $[0, 1, 2, 3, 4, 5]$,
we first pad with the required number of zeros:

$x_{padded} = [0, 0, 0, 0, x_0, x_1, x_2, x_3, x_4, x_5]$

Our indices now become $[0, 1, 2, \dots, 9]$.
At position 0,
the convolution kernel operates over the indices $[0, 2, 4]$. This corresponds to the values $[0, 0, x_0]$, so the kernel never looks ahead.
At position 1,
input indices: $[1, 3, 5]$, values become $[0, 0, x_1]$
At position 2,
input indices: $[2, 4, 6]$, values become $[0, x_0, x_2]$
For further outputs, you'd see that the kernel never looks into the future. We achieved causality!
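The index arithmetic in the example above can be generated programmatically (a small helper for illustration, using the same $K = 3$, $d = 2$ setup):

```python
K, d = 3, 2
left_pad = (K - 1) * d  # 4

def kernel_indices(position):
    # Indices the dilated kernel reads in the padded signal, plus the
    # corresponding indices in the original signal (negative = zero padding).
    padded = [position + k * d for k in range(K)]
    original = [i - left_pad for i in padded]
    return padded, original

for position in range(3):
    print(position, *kernel_indices(position))
```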
Let's update our convolution code.
num_input_channels = 3
num_output_channels = 5
input_length = 10
kernel_size = 3
stride = 2
dilation = 2

def causal_conv1d(
    input: np.ndarray,
    num_output_channels: int,
    kernel_size: int,
    stride: int,
    dilation: int,
    bias=True,
):
    num_input_channels, input_length = input.shape
    kernel = np.random.randn(num_output_channels, num_input_channels, kernel_size)
    bias = np.random.randn(num_output_channels) if bias is True else np.zeros(num_output_channels)
    effective_kernel = (kernel_size - 1) * dilation + 1
    left_padding = (kernel_size - 1) * dilation
    padding = (left_padding, 0)
    output_length = ((input_length - effective_kernel + sum(padding)) // stride) + 1
    output = np.empty((num_output_channels, output_length), dtype=input.dtype)
    padded_input = np.pad(input, pad_width=((0, 0), padding), mode='constant', constant_values=0)
    for channel in range(num_output_channels):
        for position in range(output_length):
            start = position * stride
            end = start + effective_kernel
            indices = np.arange(start, end, dilation)
            value = 0
            for in_channel in range(num_input_channels):
                value += np.dot(kernel[channel, in_channel, :], padded_input[in_channel, indices])
            output[channel, position] = value + bias[channel]
    return output

input = np.random.randn(num_input_channels, input_length)
output = causal_conv1d(input, num_output_channels, kernel_size, stride, dilation, bias=True)
print(output.shape)  # (5, 5)
We can also implement this in PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.left_padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, dilation=dilation, bias=bias)

    def forward(self, x):
        x = F.pad(x, (self.left_padding, 0))
        return self.conv(x)

input = torch.randn(1, 3, 10)
causal_conv = CausalConv1d(3, 5, kernel_size=3, dilation=2)
output = causal_conv(input)
print(output.shape)  # torch.Size([1, 5, 10])
Manim Animation for Causal Convolutions: Link
Now, what happens when we stack convolutions, like what you'd see in state-of-the-art neural audio codecs? Why even have a stack of convolutions?
Recall that the receptive field is the range of input samples that influence an output. When you stack several convolutional layers, each layer expands the range of past samples that can influence an output, effectively increasing the receptive field of the entire system.
The receptive field of a single convolution, as stated before, is

$R = (K - 1) \cdot d + 1$

Assume we have three convolutions stacked atop one another with the following arguments (one illustrative choice):
- Layer 1: $k_1 = 3$, $d_1 = 1$, $s_1 = 2$
- Layer 2: $k_2 = 3$, $d_2 = 3$, $s_2 = 2$
- Layer 3: $k_3 = 3$, $d_3 = 2$, $s_3 = 1$
Let the input sample to the stacked convolution module be $x$.
$x$ passes through Layer 1 to return $y_1$. The receptive field so far is given as $R_1 = (k_1 - 1) d_1 + 1 = 3$.
$y_1$ passes through Layer 2 to return $y_2$. The receptive field relative to $x$ becomes $R_2 = R_1 + (k_2 - 1) d_2 s_1 = 3 + 12 = 15$.
But how does stride get into the computation?
Let $x = [x_0, x_1, x_2, x_3, x_4, x_5]$, and
recall that stride tells the number of temporal steps the convolution kernel jumps through as it moves atop the signal.
After Layer 1 with $s_1 = 2$:

$y_1 = [y_0, y_1, y_2]$

Notice the stride effect: six input samples have been summarized into three outputs.
Now, when Layer 2 looks at the three values of $y_1$, it is also looking at summaries of all six values of $x$. Because Layer 1 has a stride of 2, moving 1 step in $y_1$ also equals moving $s_1 = 2$ steps in $x$. So if Layer 2 has kernel size $k_2$ and dilation $d_2$:
- Without considering stride from Layer 1, Layer 2 would span $(k_2 - 1) d_2 + 1$ positions in $y_1$
- But each position in $y_1$ represents $s_1 = 2$ positions in $x$
- Therefore, relative to $x$, Layer 2's receptive field spans $(k_2 - 1) d_2 \cdot s_1$ additional positions.
This is why we multiply by the cumulative product of strides from previous layers.
Now, for Layer 3: $y_2$ passes through Layer 3 to return $y_3$. The receptive field relative to $x$ becomes $R_3 = R_2 + (k_3 - 1) d_3 s_1 s_2 = 15 + 16 = 31$.
A general equation to derive the effective receptive field of a convolutional network is

$R = 1 + \sum_{l=1}^{L} (k_l - 1)\, d_l \prod_{i=1}^{l-1} s_i$

- $L$: number of layers
- $s_i$: stride of layer $i$
- $k_l$: kernel size of layer $l$
- $d_l$: dilation of layer $l$
It looks daunting at first, but it is really just the recursive computation we went through earlier, applied for any number of layers $L$.
So, by stacking convolutions, we have increased the receptive field from 3 to 15 to 31, a considerable expansion. This growing contextual window enables a powerful hierarchical feature extraction process.
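Plugging one possible set of layer parameters into the recursion reproduces this progression (the exact kernel/dilation/stride values here are an illustrative choice consistent with 3 → 15 → 31):

```python
# (kernel_size, dilation, stride) per layer
layers = [(3, 1, 2), (3, 3, 2), (3, 2, 1)]

receptive_fields = []
R, stride_prod = 1, 1
for k, d, s in layers:
    R += (k - 1) * d * stride_prod  # context added by this layer
    stride_prod *= s                # cumulative stride of earlier layers
    receptive_fields.append(R)
print(receptive_fields)  # [3, 15, 31]
```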
Early layers with smaller receptive fields would capture fine-grained, local patterns in the input signal. In audio applications, these might correspond to basic acoustic features, such as specific frequencies, onsets, or phoneme-level characteristics.
As we progress deeper and the receptive field expands, each subsequent layer can integrate information across these local patterns to form more complex, abstract representations.
- Mid-level layers might detect combinations of phonemes, forming syllables or simple words.
- Deeper layers, with the largest receptive fields, can understand prosody, speaker information and even semantics by observing long-range dependencies across the entire context window.
This layer-wise processing allows the network to capture both fine and broad patterns, understanding quick acoustic events as well as long-term acoustic structure.
In PyTorch, we can implement a simple stacked convolutional neural network like this.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvModule(nn.Module):
    def __init__(self, channels: list, kernel_sizes: list, dilations: list, strides: list, bias=False):
        super().__init__()
        num_layers = len(channels) - 1
        assert len(kernel_sizes) == len(dilations) == len(strides) == num_layers
        layers = [
            nn.Conv1d(
                in_channels=channels[i],
                out_channels=channels[i + 1],
                kernel_size=kernel_sizes[i],
                stride=strides[i],
                dilation=dilations[i],
                bias=bias,
            )
            for i in range(num_layers)
        ]
        self.layers = nn.Sequential(*layers)
        self.num_layers = num_layers
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.strides = strides

    def forward(self, tensor):
        output = tensor
        for layer in self.layers:
            output = layer(output)
        return output

    def get_receptive_field(self):
        stride_product = 1
        receptive_field = 1
        for kernel_size, dilation, stride in zip(self.kernel_sizes, self.dilations, self.strides):
            receptive_field += (kernel_size - 1) * dilation * stride_product
            stride_product *= stride
        return receptive_field

channels = [1, 3, 5, 7, 11]
kernel_sizes = [3, 3, 3, 3]
dilations = [1, 2, 1, 2]
strides = [2, 1, 2, 1]

conv_stack = ConvModule(
    channels=channels,
    kernel_sizes=kernel_sizes,
    dilations=dilations,
    strides=strides,
)
input_signal = torch.randn(2, 1, 24000)
output = conv_stack(input_signal)
print(output.shape)                      # torch.Size([2, 11, 5993])
print(conv_stack.get_receptive_field())  # 31
For a causal neural network,
class CausalConvModule(nn.Module):
    def __init__(self, channels: list, kernel_sizes: list, dilations: list, strides: list, bias=False):
        super().__init__()
        num_layers = len(channels) - 1
        assert len(kernel_sizes) == len(dilations) == len(strides) == num_layers
        layers = [
            CausalConv1d(
                in_channels=channels[i],
                out_channels=channels[i + 1],
                kernel_size=kernel_sizes[i],
                stride=strides[i],
                dilation=dilations[i],
                bias=bias,
            )
            for i in range(num_layers)
        ]
        self.layers = nn.Sequential(*layers)
        self.num_layers = num_layers
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.strides = strides

    def forward(self, tensor):
        output = tensor
        for layer in self.layers:
            output = layer(output)
        return output

    def get_receptive_field(self):
        # Same size as the standard receptive field; for a causal stack it
        # simply lies entirely in the past.
        stride_product = 1
        receptive_field = 1
        for kernel_size, dilation, stride in zip(self.kernel_sizes, self.dilations, self.strides):
            receptive_field += (kernel_size - 1) * dilation * stride_product
            stride_product *= stride
        return receptive_field

channels = [1, 3, 5, 7, 11]
kernel_sizes = [3, 3, 3, 3]
dilations = [1, 2, 1, 2]
strides = [2, 1, 2, 1]

causal_conv_stack = CausalConvModule(
    channels=channels,
    kernel_sizes=kernel_sizes,
    dilations=dilations,
    strides=strides,
)
input_signal = torch.randn(2, 1, 24000)
output = causal_conv_stack(input_signal)
print(output.shape)                             # torch.Size([2, 11, 6000])
print(causal_conv_stack.get_receptive_field())  # 31
Manim Animation for Stacked Convolutions: Link
Finally, let's bring all that we've discussed so far back into the context of streaming inference.
When performing streaming inference with models that utilize causal convolutions, such as neural audio codecs, text-to-speech, and speech-to-text models, we run the model on an incoming audio stream (or token stream) chunk by chunk, without recomputing everything from scratch each time. This is possible because, as we have established about causal convolutions, each output depends only on current and past inputs.
So, if you have already processed samples $x_0, \dots, x_{t-1}$, when the next chunk arrives, you do not need to re-run the whole model. All you have to do is reuse the last few samples. If you consider our previous PyTorch causal convolution code, instead of padding with zeros during inference, you pad with a number of previous samples.
For a single one-dimensional convolution, the amount of past context you need to left-pad your sequence with is the same as the left padding for causality, $(K - 1) \cdot d$ samples.
In a neural network with multiple convolutional layers operating in streaming mode, each layer must maintain its own cache, comprising the last few samples required to predict the next output frame. This cache ensures that when a new chunk arrives,
- The layer reuses the relevant past context rather than zero-padding as in plain causal convolutions.
- Computation continues without recomputing old frames.
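To see the caching idea in isolation, here is a minimal single-layer NumPy sketch (stride 1, dilation 1; the function and variable names are mine, not part of the implementation we write below). The chunked run with carried context matches the full run exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(12)
w = rng.standard_normal(3)
K = len(w)

def causal_conv(signal, context):
    # Prepend the carried context, slide the kernel, and return the last
    # K - 1 samples of the buffer as the context for the next chunk.
    buf = np.concatenate([context, signal])
    out = np.array([buf[i:i + K] @ w for i in range(len(signal))])
    return out, buf[-(K - 1):]

full, _ = causal_conv(x, np.zeros(K - 1))  # whole signal at once

context = np.zeros(K - 1)  # zeros stand in for "no past yet"
chunks = []
for start in range(0, len(x), 4):  # three chunks of 4 samples
    out, context = causal_conv(x[start:start + 4], context)
    chunks.append(out)

print(np.allclose(full, np.concatenate(chunks)))  # True
```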
For example, let's take a three-layer causal stack with an incoming chunk $c_t$:
- Layer 1 takes $c_t$ prepended with the last frames of its input from the previous timestep. We denote these frames as $h^{(1)}_{t-1}$; they were previously stored in cache. After the forward pass, the layer updates $h^{(1)}_t$ with the last samples from its input.
- Layer 2 receives Layer 1's output. It also prepends its own cache $h^{(2)}_{t-1}$ from the previous step, runs its convolution and updates $h^{(2)}_t$ accordingly.
- Layer 3 performs the same operation with its cache, $h^{(3)}_{t-1}$.
Formally, if $f^{(l)}$ denotes the $l$-th convolutional layer, then for each time step:

$y^{(l)}_t = f^{(l)}\!\left(\left[h^{(l)}_{t-1};\; y^{(l-1)}_t\right]\right)$

where $y^{(l)}_t$ denotes the output frame of layer $l$ at time step $t$, $y^{(0)}_t = c_t$, and $[\,\cdot\,;\,\cdot\,]$ denotes concatenation along the time axis.
This makes the model's streaming output identical to what you'd get if you processed the full sequence offline, except it does so incrementally.
Now let's write some code to see how all of this works. First, we'll implement a cache that stores previous frames for each layer.
import torch
from torch import nn
from typing import Optional

class Cache:
    def __init__(self):
        self.cache = {}

    def get(self, layer_id: str):
        return self.cache.get(layer_id, None)

    def set(self, layer_id: str, states: torch.Tensor):
        self.cache[layer_id] = states.detach()

    def clear(self):
        self.cache = {}
The Cache is just a dictionary that uses a layer_id as a unique identifier for each convolutional layer in the neural network.
Next, we write our causal convolutional layer with context, adding an identifier property to the class so we can pull the specific layer's context from the cache. The convolution has two forward passes: one for streaming and one for non-streaming scenarios. When not streaming, we pad the input tensor in the same causal fashion.
Then we declare our neural network.
class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.left_padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, dilation=dilation, bias=bias)
        context_size = (kernel_size - 1) * dilation - (stride - 1)
        self.context_size = context_size if context_size > 0 else 0
        self.padding = self.context_size

    def forward(self, tensor, cache: Optional[Cache] = None):
        if cache is not None:
            return self._forward_streaming(tensor, cache)
        return self._forward_offline(tensor)

    @property
    def layer_id(self):
        return str(id(self))

    def _forward_streaming(self, tensor, cache: Cache):
        B, C, T = tensor.shape
        cached_states = cache.get(self.layer_id)
        if cached_states is None:
            cached_states = torch.zeros(B, C, self.context_size, device=tensor.device, dtype=tensor.dtype)
        input_with_context = torch.cat([cached_states, tensor], dim=2)
        output = self.conv(input_with_context)
        if self.context_size > 0:
            total_input_length = input_with_context.shape[2]
            if total_input_length >= self.context_size:
                new_cache_start = total_input_length - self.context_size
                new_cache = input_with_context[:, :, new_cache_start:]
            else:
                new_cache = input_with_context
            cache.set(self.layer_id, new_cache)
        return output

    def _forward_offline(self, tensor):
        tensor = nn.functional.pad(tensor, (self.padding, 0))
        return self.conv(tensor)

class ConvModule(nn.Module):
    def __init__(self, channels: list, kernel_sizes: list, dilations: list, strides: list, bias=False):
        super().__init__()
        num_layers = len(channels) - 1
        assert len(kernel_sizes) == len(dilations) == len(strides) == num_layers
        layers = [
            CausalConv1d(
                in_channels=channels[i],
                out_channels=channels[i + 1],
                kernel_size=kernel_sizes[i],
                stride=strides[i],
                dilation=dilations[i],
                bias=bias,
            )
            for i in range(num_layers)
        ]
        self.layers = nn.Sequential(*layers)
        self.num_layers = num_layers
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.strides = strides

    def forward(self, tensor: torch.Tensor, cache: Optional[Cache] = None):
        output = tensor
        for layer in self.layers:
            output = layer(output, cache)
        return output
It is essential to ensure that the neural network produces the same results whether it is streaming or not. To test this, we can simulate a real-time scenario where audio chunks are fed into the neural network as they become available.
channels = [1,3,5,7,11]
kernel_sizes = [3,3,3,3]
dilations = [1,2,1,2]
strides = [2,1,2,1]
cache = Cache()
causal_conv_stack = ConvModule(
channels=channels,
kernel_sizes=kernel_sizes,
dilations = dilations,
strides = strides
)
test_audio = torch.randn(2,1,24000)
output_frames = []
chunk_size = 2000 # New chunk of audio stream that keeps coming in
for i in range(0, test_audio.shape[-1], chunk_size):
chunk = test_audio[...,i:i+chunk_size]
output_frame = causal_conv_stack(
chunk, cache = cache
)
output_frames.append(output_frame)
print('streaming inference')
stream_output = torch.concat(output_frames, dim = -1)
print(stream_output.shape)
print('offline inference')
offline_output = causal_conv_stack(test_audio)
print(offline_output.shape)
print(torch.abs(stream_output-offline_output).mean())
streaming inference
torch.Size([2, 11, 6000])
offline inference
torch.Size([2, 11, 6000])
tensor(0., grad_fn=<MeanBackward0>)
Great!
We were able to match sending in an entire audio signal to the model versus sending in chunks of audio in a streaming fashion.
One last thing before we call it a wrap, it would be nice to run our setup on an actual causal audio model. As a part of an upcoming project, Iāve been able to extract the weights of the acoustic variational autoencoders built by Microsoft as part of their VibeVoice project. VibeVoice is a speech synthesis model specifically trained for long-form, expressive audio such as podcasts. The Variational Autoencoder from Vibevoice differs from the traditional method of currently tokenizing audio into discrete tokens for language modeling. It instead leaves audio latents as continuous representations, and language modeling is done directly in this form. We would leave talking about this idea and my own upcoming related project for a future write-up.
The goal is to simulate streaming audio through a variational auto-encoder and compare audio results.
git clone https://github.com/odunola499/latent-lm-lessons.git
cd latent-lm-lessons
pip install -e .
Now, you can access the code for this simulation in blogs/streaming_convolutions/stream_simulation.py. For convenience, so I can explain how it works, Iāll add it here as well. However, you should first skim through the modelās streaming cache and architecture, also available in the repository.
from queue import Queue
from threading import Thread
from latentlm.vae.audio.model import AcousticTokenizerModel, AcousticTokenizerConfig
from latentlm.vae.audio.cache import StreamingCache
import torchaudio
from torchaudio.io import StreamWriter
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
import math
import time
def pad_to_multiple(x: torch.Tensor, multiple: int = 3200) -> torch.Tensor:
L = x.shape[-1]
target_len = multiple * math.ceil(L / multiple)
if L < target_len:
pad_amount = target_len - L
x = torch.nn.functional.pad(x, (0, pad_amount))
return x
def compute_receptive_field(kernel_sizes, dilation, strides):
stride_product = 1
receptive_field = 1
for stride, kernel_size in zip(strides, kernel_sizes):
receptive_field += (kernel_size - 1) * dilation * stride_product
stride_product *= stride
return receptive_field
def load_model(device, repo_id='odunola/vibevoice_vae_weights'):
cache = StreamingCache()
config = AcousticTokenizerConfig()
model = AcousticTokenizerModel(config).to(device)
path = hf_hub_download(
repo_id=repo_id,
filename='acoustic.safetensors'
)
weights = load_file(path)
model.load_state_dict(weights)
return model, cache
def audio_stream(file_path, chunk_size):
waveform, sample_rate = torchaudio.load(file_path)
print("Starting streaming")
waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=24000)
waveform = waveform.mean(0, keepdim=True).unsqueeze(0)
for i in range(0, waveform.shape[-1], chunk_size):
yield waveform[..., i:i + chunk_size]
def producer(input_queue: Queue, file_path: str, chunk_size=6000):
for chunk in audio_stream(file_path, chunk_size):
chunk = pad_to_multiple(chunk, chunk_size)
input_queue.put(chunk)
time.sleep(0.05)
input_queue.put(None)
def consumer(
input_queue: Queue,
output_queue: Queue,
model: AcousticTokenizerModel,
cache: StreamingCache,
sample_indices: torch.Tensor,
device
):
while True:
chunk = input_queue.get()
if chunk is None:
output_queue.put(None)
break
print(f"Input to model: {chunk.shape}")
recon_chunk, _ = model(
chunk.to(device),
cache=cache,
sample_indices=sample_indices,
use_cache=True,
debug=False
)
output_queue.put(recon_chunk)
def save_to_disk(stream_output_path: str, output_queue: Queue):
writer = StreamWriter(stream_output_path)
writer.add_audio_stream(sample_rate=24000, num_channels=1)
writer.open()
while True:
chunk = output_queue.get()
if chunk is None:
print("Reached end")
writer.close()
break
chunk = chunk[0].detach().cpu()
chunk = chunk.squeeze(0).unsqueeze(-1)
writer.write_audio_chunk(0, chunk)
if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_queue = Queue()
output_queue = Queue()
model, cache = load_model(device)
sample_indices = torch.tensor([0], device=device)
file_path = 'audio.mp3'
offline_output_path = 'offline_recon.wav'
stream_output_path = 'stream_recon.wav'
chunk_size = 3200
producer_thread = Thread(target=producer, args=(
input_queue, file_path, chunk_size)
)
consumer_thread = Thread(target=consumer, args=(
input_queue, output_queue, model, cache, sample_indices, device)
)
save_thread = Thread(target=save_to_disk, args=(stream_output_path, output_queue))
producer_thread.start()
consumer_thread.start()
save_thread.start()
producer_thread.join()
consumer_thread.join()
save_thread.join()
print('Finished Streaming inference')
audio, sample_rate = torchaudio.load(file_path)
audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=24000)
audio = audio.unsqueeze(0).to(device)
audio = pad_to_multiple(audio, chunk_size)
recon, _ = model(audio)
torchaudio.save(offline_output_path, recon[0].detach().cpu(), sample_rate=24000)
print('Finished Offline inference')
Letās talk about what is happening here.
We have two queues and three threads. In the first thread, the producer simulates a real-time audio source. In reality, this would be a microphone or socket stream. Our audio source is an audio file that our producer loads, resamples, and then slices into fixed-size chunks, pushing each chunk into the first queue, our input queue. Once all the chunks have been pushed, we push a None token to signal the end of the stream.
In the second thread, the consumer continuously retrieves items from the input queue and performs the actual streaming simulation. Every chunk is fed to the modelās forward method. For every pass, the cache of all causal convolutions in the model is updated. The model produces a latent reconstruction for that specific chunk and also computes the reconstructed audio. I should add that this is an unconventional streaming task, as a variational autoencoder like this would likely have the encoder providing latents in a streaming fashion for another task, or the decoder reconstructing audio from latents coming in from perhaps an autoregressive generative task.
The third thread retrieves the reconstructed chunks from the output queue and uses Torchaudioās StreamWriter to write to the file incrementally, so we don't have to wait for the entire audio to finish processing before saving it to the file. This aligns with how real text-to-speech models synthesize audio for playback.
We set the chunk size to 3200, as this is the length of audio input that would give exactly one latent frame for this model. You can confirm by multiplying the strides, as strides are responsible for downsampling in CNNs, as we saw earlier.
You can try running this to get results, but I have made some results available here, so you can listen straight away. Try out other settings too, like different chunk sizes larger and smaller than 3200. There are also a few examples on streaming in the repo and i would be adding a few more soon.
Throughout this writeup, we've covered how convolutions worked, how kernel sizes, dilations, and strides shape the behavior of the network, and why causal convolutions prevent future parts of a sequence from influencing the current timestep. We've also explored how streaming inference works in practice with a cache to manage state. Finally, we put everything to the test on a trained audio variational autoencoder.
Thank you for stopping by! I hope this deep dive into convolutions has been as enlightening for you as it was exciting for me.
If you enjoyed this post, please consider giving a thumbs up below!
Further Reading
- Oord, A. v. d., Dieleman, S., Zen, H., et al. (2016). WaveNet: A Generative Model for Raw Audio. arXiv:1609.03499.
- Younesi, A., Ansari, M., Fazli, M. A., Ejlali, A., Shafique, M., & Henkel, J. (2024). A Comprehensive Survey of Convolutions in Deep Learning: Applications, Challenges, and Future Trends. arXiv:2402.15490. https://arxiv.org/pdf/2402.15490
- How to Calculate Receptive Field in CNN: https://www.baeldung.com/cs/cnn-receptive-field-size
- Moshi by Kyutai Labs https://github.com/kyutai-labs/moshi
- VibeVoice by Microsoft https://github.com/vibevoice-community/VibeVoice
- Continuous Audio Language Models: https://arxiv.org/abs/2509.06926
- OpenAI Whisper https://arxiv.org/abs/2212.04356
- Conformer: Convolution-augmented Transformer for Speech Recognition: https://arxiv.org/abs/2005.08100
- Anatomy of Industrial Scale Multilingual ASR: https://arxiv.org/abs/2005.08100