AI Atlas
Intermediate · ~2 min read · #batch-normalization #layer-norm #training-stability

Batch Normalization

Stabilizing layer activations

Normalizes a layer's activations using mini-batch statistics — speeds up training and stabilizes deep networks.

[Diagram: activations without BN keep shifting; after BN they are centered and stable. Stabilizes per-layer input distributions for faster, steadier training.]
Definition

In a deep network, each layer's output feeds the next. If activation distributions keep shifting layer to layer (internal covariate shift), every layer chases a moving target — training slows and grows unstable. Batch normalization fixes this by normalizing each layer's input by the mini-batch's mean and standard deviation.

Mechanism: per batch, compute mean and variance per feature dimension, normalize, then rescale and shift with two learnable parameters (γ and β). γ and β let the model decide its preferred distribution. At inference time, running averages tracked during training are used in place of batch statistics.
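The mechanism above can be sketched in a few lines. This is a minimal training-mode sketch, not PyTorch's actual implementation; the function name, shapes, and `eps` value are illustrative:

```python
import torch

def batch_norm_train(x, gamma, beta, eps=1e-5):
    # x: (batch, features); statistics are computed per feature, over the batch
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)  # normalize
    return gamma * x_hat + beta                 # learnable rescale and shift

x = torch.randn(32, 4) * 3 + 5  # activations shifted and scaled off-center
y = batch_norm_train(x, gamma=torch.ones(4), beta=torch.zeros(4))
print(y.mean(dim=0))  # ~0 per feature
print(y.var(dim=0, unbiased=False))  # ~1 per feature
```

With γ=1 and β=0 the output is simply the standardized activations; during real training γ and β are learned, so the network can undo the normalization wherever that helps.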

Published in 2015, BatchNorm transformed CNN training. It enabled higher learning rates, made initialization less fragile, and often replaced dropout. In transformers, layer normalization dominates — sequential data and small inference batches don't suit BatchNorm. Modern large models use LayerNorm or its faster variant RMSNorm.

Analogy

An orchestra. Each section starts at a different volume — one too loud, another lost. The conductor calls out a standard reference level for each section to begin from; each section then expresses around that level. The concert stays coherent. BatchNorm provides that standard reference at every layer.

Real-world example

Training an 18-layer CNN on CIFAR-10:

- No BatchNorm: lr=0.1 → loss NaN in epoch 1; lr=0.001 → 78% accuracy after 60 epochs.
- With BatchNorm: lr=0.1 stays stable; 85% accuracy in 30 epochs. 2× faster with higher accuracy.

Dropout often becomes unnecessary; the noise in batch statistics gives BatchNorm a mild regularizing effect of its own. Hence post-2015 architectures (ResNet, EfficientNet) bake BatchNorm into every block.

Transformers go a different way. Sequence models are sensitive to batch composition, and inference often runs on single examples, where batch statistics are meaningless. LayerNorm, which normalizes across the feature dimension of each sample independently, became the default.
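The axis difference is easy to see directly. A minimal sketch; the tensor shapes are illustrative:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 5, 8)  # (batch, sequence, features)

# BatchNorm1d normalizes each feature over the batch *and* sequence positions;
# it expects (N, C, L), hence the transposes
bn = nn.BatchNorm1d(8)
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)

# LayerNorm normalizes each token's feature vector independently,
# so a batch of one behaves exactly like a batch of a thousand
ln = nn.LayerNorm(8)
y_ln = ln(x)

print(y_ln.mean(dim=-1).abs().max())  # ~0: every token is centered on its own
```

Because LayerNorm's statistics never cross the batch dimension, single-example inference needs no running averages and no train/eval mode switch.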

Code examples
PyTorch · BatchNorm in a CNN
import torch.nn as nn

# Classic ResNet block layout
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity  # assumes in_ch == out_ch (no projection shortcut)
        return self.relu(out)

# Training: model.train() — batch stats
# Inference: model.eval() — running averages
PyTorch · LayerNorm in a transformer block
import torch.nn as nn

# Pre-LN transformer block
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, ff_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, dim),
        )

    def forward(self, x):
        h = self.norm1(x)  # pre-LN: normalize once, feed as query, key, value
        x = x + self.attn(h, h, h)[0]
        x = x + self.ff(self.norm2(x))
        return x
When to use
  • Deep CNN training — BatchNorm is the default
  • Loss is unstable and you want higher LR
  • Diagnosed internal covariate shift
  • Transformer / sequential models: LayerNorm or RMSNorm
When not to use
  • Very small batch (<8) — statistics get noisy
  • Sequential models (RNN / Transformer) — prefer LayerNorm
  • Inference without a batch — switch to eval mode
Common pitfalls

Forgetting model.eval()

BatchNorm uses batch statistics during training and running averages at inference. Without model.eval(), inference normalizes each input with its own batch statistics; for a single example those are degenerate, and predictions break.
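A minimal demonstration of the mode mismatch (the layer sizes and data are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)

# a few "training" batches populate the running averages
for _ in range(10):
    bn(torch.randn(32, 3) * 2 + 1)

x = torch.randn(2, 3)  # a tiny inference batch
bn.eval()
y_eval = bn(x)   # running averages: stable, input-independent statistics
bn.train()
y_train = bn(x)  # this batch's own (noisy) statistics

print(torch.allclose(y_eval, y_train))  # False: the two modes disagree
```

The gap grows as the inference batch shrinks; with a single example, train mode either errors out (PyTorch refuses one value per channel) or normalizes against statistics that carry no information.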

Tiny batch size

Batch <8 → unstable statistics → unstable training. Group normalization or layer normalization fits better.
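Group normalization sidesteps the problem entirely because its statistics never touch the batch dimension. A quick sketch (group and channel counts are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# GroupNorm computes statistics per sample, over groups of channels,
# so it behaves identically at any batch size
gn = nn.GroupNorm(num_groups=8, num_channels=64)

x = torch.randn(4, 64, 16, 16)
y_batch = gn(x)       # normalize a batch of four
y_single = gn(x[:1])  # normalize the first sample alone

# each sample is normalized independently, so the results match
print(torch.allclose(y_batch[:1], y_single, atol=1e-6))  # True
```

The same check on a BatchNorm2d layer in train mode would print False, since its statistics change with the rest of the batch.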

Leaving the bias on

Conv/Linear before BatchNorm doesn't need a bias — BatchNorm's β covers it. bias=False is standard.
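To see the cancellation concretely, here is a minimal sketch (layer sizes are arbitrary): because BatchNorm subtracts the per-feature mean, adding a constant bias beforehand changes nothing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(8, 4, bias=True)
bn = nn.BatchNorm1d(4)  # train mode: normalizes with batch statistics
x = torch.randn(16, 8)

with torch.no_grad():
    y_with_bias = bn(lin(x))
    lin.bias.zero_()  # drop the bias entirely
    y_no_bias = bn(lin(x))

# the bias shifts the activations and their mean by the same amount,
# so normalization cancels it exactly
print(torch.allclose(y_with_bias, y_no_bias, atol=1e-6))  # True
```

The bias parameters are pure dead weight, which is why `bias=False` appears on every conv in the BasicBlock example above.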