Tutorial 3: Modules

The nn module provides neural network layers, activations, and the Module trait that unifies them all. Modules compose naturally — a model is a Module that contains other Modules.

This tutorial builds on Tutorial 2: Automatic Differentiation.

The Module Trait

Every layer in floDl implements this trait:

pub trait Module {
    fn forward(&self, input: &Variable) -> Result<Variable>;

    fn parameters(&self) -> Vec<Parameter> { vec![] }
    fn name(&self) -> &str { "module" }
    fn sub_modules(&self) -> Vec<Rc<dyn Module>> { vec![] }
    fn move_to_device(&self, _device: Device) {}
    fn set_training(&self, _training: bool) {}
    fn reset(&self) {}          // see Stateful Module Methods below
    fn detach_state(&self) {}   // see Stateful Module Methods below
    fn as_named_input(&self) -> Option<&dyn NamedInputModule> { None }
}

forward takes an input variable and returns an output variable. parameters returns all learnable weights. Modules with no learnable parameters (like activations) return an empty vec.
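
A minimal sketch of a parameter-free module: only forward is overridden, everything else keeps its default (Variable::relu appears again in the MLP example at the end of this tutorial):

struct MyReLU;

impl Module for MyReLU {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.relu()
    }
    fn name(&self) -> &str { "my_relu" }
    // parameters() keeps its default: no learnable weights
}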

Linear

Fully connected layer: y = x @ W^T + b.

let linear = Linear::new(784, 128)?;

Weights are Kaiming-initialized (suitable for ReLU). Input shape: [batch, in_features]. Output shape: [batch, out_features].

let output = linear.forward(&input)?;  // [batch, 784] -> [batch, 128]

Builder options

// Without bias
let linear = Linear::no_bias(784, 128)?;

// On a specific device
let linear = Linear::on_device(784, 128, Device::CUDA(0))?;

Convolutions

Conv1d

1D convolution over [N, C, L] inputs. Same builder pattern as Conv2d (below).

let conv = Conv1d::new(1, 16, 3)?;  // in=1, out=16, kernel=3

// Fluent builder for full control
let conv = Conv1d::configure(3, 16, 5)
    .with_stride(2)
    .with_padding(2)
    .on_device(Device::CUDA(0))
    .done()?;

Conv2d

2D convolution over [N, C, H, W] inputs.

let conv = Conv2d::new(3, 64, 3)?;  // in=3, out=64, kernel=3 (stride=1, padding=0)

// Fluent builder
let conv = Conv2d::configure(3, 64, 3)
    .with_padding(1)
    .with_stride(2)
    .done()?;

// Full control
let conv = Conv2d::build(3, 64, 3, true, [1,1], [1,1], [1,1], 1, Device::CPU)?;

Conv3d

3D convolution over [N, C, D, H, W] inputs. For volumetric data (video, 3D medical).

let conv = Conv3d::new(1, 32, [3, 3, 3])?;

let conv = Conv3d::configure(1, 32, [3, 3, 3])
    .with_padding([1, 1, 1])
    .done()?;

Transpose Convolutions

Transpose (deconvolution) variants for upsampling:

let deconv1d = ConvTranspose1d::new(16, 1, 3)?;
let deconv2d = ConvTranspose2d::new(64, 3, 4)?;
let deconv3d = ConvTranspose3d::new(32, 1, [3, 3, 3])?;

Pooling

MaxPool2d / AvgPool2d

2D pooling over [N, C, H, W] inputs. Stride defaults to kernel size.

let pool = MaxPool2d::new(2);                            // kernel=2, stride=2
let pool = MaxPool2d::with_stride(3, 2).padding(1);     // kernel=3, stride=2, padding=1
let output = pool.forward(&input)?;                      // [B, C, H, W] -> [B, C, H/2, W/2]

let pool = AvgPool2d::new(2);                            // average pooling
let pool = AvgPool2d::with_stride(3, 2).padding(1).count_include_pad(false);

No learnable parameters. Commonly paired with Conv2d + BatchNorm2d:

let model = FlowBuilder::from(Conv2d::new(3, 64, 3)?)
    .through(BatchNorm2d::new(64)?)
    .through(ReLU)
    .through(MaxPool2d::new(2))
    .build()?;

MaxPool1d / AvgPool1d

1D pooling over [N, C, L] inputs, for sequence and signal processing:

let pool = MaxPool1d::new(2);
let pool = AvgPool1d::with_stride(3, 2).padding(1);

Adaptive Pooling

Output a fixed spatial size regardless of input dimensions:

// As a free function
let pooled = adaptive_avg_pool2d(&input, [1, 1])?;  // [B, C, H, W] -> [B, C, 1, 1]

// As a module
let pool = AdaptiveMaxPool2d::new(7, 7);             // fixed 7x7 output
let pool = AdaptiveAvgPool2d::new([1, 1]);           // global avg (ResNet head)
let output = pool.forward(&input)?;

PixelShuffle / PixelUnshuffle

Rearrange elements for sub-pixel convolution (super-resolution):

let shuffle = PixelShuffle::new(2);     // [B, C*4, H, W] -> [B, C, H*2, W*2]
let unshuffle = PixelUnshuffle::new(2); // inverse

let model = FlowBuilder::from(Conv2d::new(3, 12, 3)?)  // 12 = 3 * 2^2 (out_channels * upscale^2)
    .through(PixelShuffle::new(2))                      // -> [B, 3, H*2, W*2]
    .build()?;

Upsample

Resize spatial dimensions via interpolation:

let up = Upsample::new(&[64, 64], 1);  // output_size, mode (0=nearest, 1=bilinear, 2=bicubic)
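
Like any module, it is applied with forward:

let y = up.forward(&x)?;  // [B, C, H, W] -> [B, C, 64, 64]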

Unfold / Fold

Extract and reconstruct sliding local blocks (im2col / col2im as modules):

let unfold = Unfold::new([3, 3], [1, 1], [0, 0], [1, 1]);  // kernel, dilation, padding, stride
let fold = Fold::new([28, 28], [3, 3], [1, 1], [0, 0], [1, 1]);  // output_size, kernel, ...
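
Shapes follow the usual im2col convention. A sketch for a [B, C, 28, 28] input with the modules above:

let patches = unfold.forward(&x)?;    // [B, C, 28, 28] -> [B, C*9, 676] (26*26 positions)
let summed = fold.forward(&patches)?; // back to [B, C, 28, 28]; overlapping values are summed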

Normalization

LayerNorm

Normalizes the last dimension. Commonly used in transformers.

let ln = LayerNorm::new(512)?;
let output = ln.forward(&input)?;  // [batch, 512] -> [batch, 512]

RMSNorm

Root Mean Square normalization. Simpler and faster than LayerNorm — no mean subtraction, just RMS scaling. Used in LLaMA, Gemma, and other modern architectures:

let rn = RMSNorm::new(512)?;
let rn = RMSNorm::new(512)?.eps(1e-6);  // custom epsilon
let output = rn.forward(&input)?;

BatchNorm

Normalizes over the batch dimension. Uses running statistics at inference.

// For fully-connected layers: input [batch, features]
let bn = BatchNorm::new(128)?;
let output = bn.forward(&input)?;  // [batch, 128] -> [batch, 128]

// For conv layers: input [batch, channels, height, width]
let bn2d = BatchNorm2d::new(64)?;
let output = bn2d.forward(&input)?;  // [B, 64, H, W] -> [B, 64, H, W]

Use BatchNorm after Linear layers and BatchNorm2d after Conv2d layers. Both behave differently during training (batch statistics) vs. inference (running statistics). They track num_batches_tracked and will error in eval mode if no training has occurred — this catches a common silent bug.
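
A sketch of that guard in action, using set_training from the Module trait:

let bn = BatchNorm::new(128)?;
bn.set_training(false);           // eval mode, but no batches seen yet
let result = bn.forward(&input);  // Err: running statistics were never updated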

See Train/Eval Mode below.

GroupNorm

Normalizes over groups of channels. Independent of batch size — works well with small batches where BatchNorm struggles:

let gn = GroupNorm::new(4, 16)?;   // 4 groups, 16 channels
let output = gn.forward(&input)?;  // [B, 16, H, W] -> [B, 16, H, W]

InstanceNorm

Normalizes each channel independently. Standard for style transfer:

let inn = InstanceNorm::new(64, true)?;   // 64 features, affine=true
let output = inn.forward(&input)?;

Dropout

Randomly zeroes elements during training. Uses inverted dropout so no scaling is needed at inference.

let drop = Dropout::new(0.1);       // 10% drop probability — zeroes individual elements
let drop2d = Dropout2d::new(0.1);   // drops entire channels (for conv features)
let adrop = AlphaDropout::new(0.1); // maintains self-normalizing property (for SELU networks)
let output = drop.forward(&input)?;

During inference, all dropout variants become identity functions.
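
A minimal sketch of the two modes (inverted dropout scales survivors by 1/(1-p) during training):

let drop = Dropout::new(0.5);
drop.set_training(true);
let y = drop.forward(&x)?;   // ~50% of elements zeroed, survivors scaled by 2.0
drop.set_training(false);
let y = drop.forward(&x)?;   // identity: y equals x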

Padding

Padding modules for use in graph builder pipelines:

let pad = ZeroPad2d::new(1);                           // 1 pixel on all sides
let pad = ZeroPad2d::asymmetric(1, 1, 2, 2);           // left, right, top, bottom
let pad = ReflectionPad2d::new(1);                      // reflect at boundaries
let pad = ReflectionPad2d::asymmetric(1, 1, 2, 2);

Embedding

Lookup table mapping integer indices to dense vectors.

let emb = Embedding::new(10000, 64)?;  // vocab=10000, dim=64

Input is a Variable wrapping an Int64 tensor:

// [batch, seq_len] -> [batch, seq_len, 64]
let output = emb.forward(&indices)?;

EmbeddingBag

Computes bag-of-embeddings (sum, mean, or max of groups of indices). Useful when input sequences have variable length and you need a fixed-size output per bag:

let bag = EmbeddingBag::new(10000, 64)?;  // vocab=10000, dim=64

// indices: [total_indices], offsets: [num_bags] (start positions)
// mode: 0=sum, 1=mean, 2=max
let output = bag.forward_bag(&indices, &offsets, 1)?;  // [num_bags, 64]

Recurrent Layers

GRUCell / LSTMCell

Single-timestep cells. Backed by fused ATen kernels (~2 GPU kernels instead of ~25-40):

let gru = GRUCell::new(128, 256)?;
let h = gru.forward_step(&x, None)?;      // first step: h initialized to zeros
let h = gru.forward_step(&x2, Some(&h))?; // subsequent steps

let lstm = LSTMCell::new(128, 256)?;
let state = lstm.forward_step(&x, None)?;             // first step
let state = lstm.forward_step(&x2, Some(&state))?;    // subsequent steps
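
A cell unrolls over a sequence in a plain loop. A sketch, assuming inputs is a Vec<Variable> of per-timestep [batch, input_size] slices (the GRU / LSTM modules below do this in one fused call instead):

let gru = GRUCell::new(128, 256)?;
let mut h: Option<Variable> = None;
for x_t in &inputs {
    h = Some(gru.forward_step(x_t, h.as_ref())?);  // h is None on the first step
}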

GRU / LSTM

Multi-layer sequence modules matching PyTorch’s nn.GRU / nn.LSTM: they process entire sequences and stack multiple layers. forward_seq uses the fused at::lstm / at::gru kernels (cuDNN-accelerated on CUDA), so the full sequence is processed in a single kernel call with no per-timestep dispatch overhead:

let gru = GRU::new(128, 256, 2)?;  // input=128, hidden=256, 2 layers
// Input: [seq_len, batch, input_size] (default) or [batch, seq_len, input_size] (batch_first)
let (output, h_n) = gru.forward_seq(&x, None)?;
// output: [seq_len, batch, hidden_size], h_n: [num_layers, batch, hidden_size]

let lstm = LSTM::new(128, 256, 2)?;
let (output, (h_n, c_n)) = lstm.forward_seq(&x, None)?;
// output: [seq_len, batch, hidden_size]
// h_n, c_n: [num_layers, batch, hidden_size]

// Batch-first ordering:
let gru = GRU::on_device(128, 256, 2, true, Device::CUDA(0))?;
let lstm = LSTM::on_device(128, 256, 2, true, Device::CUDA(0))?;

Attention

MultiheadAttention

Standard multi-head attention matching PyTorch’s nn.MultiheadAttention. Supports self-attention and cross-attention with optional masking:

let mha = MultiheadAttention::new(512, 8)?;  // embed_dim=512, 8 heads

// Self-attention (query = key = value)
let y = mha.forward(&x)?;                           // [B, seq, 512] -> [B, seq, 512]

// Cross-attention or masked attention
let y = mha.forward_ext(&query, &key, &value, Some(&mask))?;

Bilinear

Bilinear transformation: y = x1^T A x2 + b. Useful for modeling interactions between two feature sets:

let bi = Bilinear::new(128, 64, 32, true)?;  // in1=128, in2=64, out=32, bias=true
let y = bi.forward_bilinear(&x1, &x2)?;      // [B, 128] x [B, 64] -> [B, 32]

Activations

Activation functions are also modules, making them composable in the graph builder:

// Zero-sized types — no parameters, no allocation
ReLU              // max(0, x)
Sigmoid           // 1 / (1 + exp(-x))
Tanh              // hyperbolic tangent
GELU              // Gaussian Error Linear Unit
SiLU              // x * sigmoid(x), also called Swish
SELU              // scaled ELU — self-normalizing (pair with AlphaDropout)
Mish              // x * tanh(softplus(x))
Hardswish         // efficient Swish approximation
Hardsigmoid       // piecewise-linear sigmoid approximation
Identity          // pass-through

// Parameterized — take a config value at construction
LeakyReLU::new(0.01)          // max(x, slope * x)
ELU::new(1.0)                 // alpha * (exp(x) - 1) for x < 0
Softplus::new(1.0, 20.0)      // smooth approximation of ReLU (beta, threshold)
Softmax::new(-1)              // softmax along dim
LogSoftmax::new(-1)           // log(softmax(x)) — numerically stable
Flatten::new(1, -1)           // flattens dims 1..=-1, e.g. [B, C, H, W] -> [B, C*H*W] (a shape op, listed here for convenience)

// Learnable
let prelu = PReLU::new(1, Device::CPU)?;  // parametric ReLU (num_parameters, device)

All zero-sized activations compile to direct tensor calls with no overhead.
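
Because activations are ordinary modules, they can also be called directly:

let y = ReLU.forward(&x)?;                       // same result as x.relu()?
let probs = Softmax::new(-1).forward(&logits)?;  // sums to 1 along the last dim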

Train/Eval Mode

Some modules (Dropout, BatchNorm) behave differently during training vs. inference. The set_training method on Module controls this, and convenience aliases train() / eval() make it concise:

model.train();   // same as set_training(true)
model.eval();    // same as set_training(false)

When using the graph builder, Graph::set_training(bool) (and its aliases) propagates to all nodes recursively.

Optional Module Traits

NamedInputModule

For modules that receive extra inputs as a named map of references instead of as additional positional arguments:

pub trait NamedInputModule: Module {
    fn forward_named(
        &self,
        input: &Variable,
        refs: &HashMap<String, Variable>,
    ) -> Result<Variable>;
}
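
A minimal sketch of implementing both traits; the "skip" key and Clone on Variable are illustrative assumptions:

use std::collections::HashMap;

struct SkipOrInput;

impl Module for SkipOrInput {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        Ok(input.clone())  // positional path: pass the input through
    }
    fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
        Some(self)  // advertise the named-input capability
    }
}

impl NamedInputModule for SkipOrInput {
    fn forward_named(
        &self,
        input: &Variable,
        refs: &HashMap<String, Variable>,
    ) -> Result<Variable> {
        // prefer the named "skip" reference when supplied (hypothetical key)
        Ok(refs.get("skip").cloned().unwrap_or_else(|| input.clone()))
    }
}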

Stateful Module Methods

For modules with per-forward mutable state (an attention location, a counter), override reset() on Module. Loop nodes call it automatically before iterating:

impl Module for AttentionStep {
    fn reset(&self) {
        self.location.set(None); // clear stale state
    }
    // ...
}

For modules holding Variables across forward calls (recurrent state), override detach_state() on Module. graph.detach_state() propagates to all modules:

impl Module for RecurrentModule {
    fn detach_state(&self) {
        // break gradient chain on retained hidden state
    }
    // ...
}
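
A typical use is truncated backpropagation through time, sketched here with the loss/optimizer steps elided:

for chunk in &chunks {
    let output = model.forward(chunk)?;
    // ... compute loss, backward, optimizer step ...
    model.detach_state();  // hidden state is kept, its gradient chain is cut
}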

Composing Modules Manually

Without the graph builder, you compose modules in plain Rust. Implement sub_modules() to declare children — the framework then handles device placement, training mode, and parameter collection:

struct MLP {
    fc1: Linear,
    fc2: Linear,
}

impl MLP {
    fn new() -> Result<Self> {
        Ok(MLP {
            fc1: Linear::new(784, 128)?,
            fc2: Linear::new(128, 10)?,
        })
    }
}

impl Module for MLP {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let x = self.fc1.forward(input)?;
        let x = x.relu()?;
        self.fc2.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.fc1.parameters();
        params.extend(self.fc2.parameters());
        params
    }

    fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
        vec![Rc::new(self.fc1.clone()), Rc::new(self.fc2.clone())]
    }
}
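
Using it is just another forward call:

let mlp = MLP::new()?;
let logits = mlp.forward(&batch)?;  // [batch, 784] -> [batch, 10]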

This is the same pattern as PyTorch’s nn.Module — declare children, let the framework walk the tree. For anything involving residual connections, parallel branches, loops, or conditional execution, the graph builder API is more expressive and handles the wiring automatically.