# Architecture Deep Dive

Understand the internal structure and components of the DREAM cell.
This guide provides a comprehensive look inside the DREAM cell architecture.
## Cell Structure Overview
The DREAM cell processes input through four main computational blocks:
```text
┌────────────────────────────────────────────────────────────────┐
│                           DREAM Cell                           │
│                                                                │
│  Input: x (batch, input_dim)                                   │
│        │                                                       │
│        ▼                                                       │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ BLOCK 1: Predictive Coding                               │  │
│  │   x_pred = tanh((C + U × V^T) × h_prev)                  │  │
│  │   error  = x - x_pred                                    │  │
│  └──────────────────────────────────────────────────────────┘  │
│        │                                                       │
│        ▼                                                       │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ BLOCK 2: Surprise Gate                                   │  │
│  │   error_norm    = ||error||                              │  │
│  │   entropy       = 0.5 × log(2πe × error_var)             │  │
│  │   effective_tau = 0.3×classical + 0.7×adaptive           │  │
│  │   surprise = sigmoid((error_norm - effective_tau) / γ)   │  │
│  └──────────────────────────────────────────────────────────┘  │
│        │                                                       │
│        ▼                                                       │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ BLOCK 3: Fast Weights Update (Hebbian)                   │  │
│  │   dU = -λ(U - U_target) + η × surprise × (h ⊗ error × V) │  │
│  │   U ← U + dU × dt                                        │  │
│  └──────────────────────────────────────────────────────────┘  │
│        │                                                       │
│        ▼                                                       │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ BLOCK 4: Liquid Time-Constant Update                     │  │
│  │   τ = τ_sys / (1 + surprise × scale)                     │  │
│  │   h_target = tanh(input_effect)                          │  │
│  │   h_new = (1 - dt/τ) × h_prev + (dt/τ) × h_target        │  │
│  └──────────────────────────────────────────────────────────┘  │
│        │                                                       │
│        ▼                                                       │
│  Output: h_new (batch, hidden_dim)                             │
│                                                                │
└────────────────────────────────────────────────────────────────┘
```

## Block 1: Predictive Coding
### Purpose
Generate predictions of the next input based on the current state. This enables the model to compute prediction errors, which drive learning.
### Parameters

| Parameter | Shape | Purpose |
|---|---|---|
| `C` | `(input_dim, hidden_dim)` | Base prediction matrix (stable) |
| `W` | `(hidden_dim, input_dim)` | Error projection |
| `B` | `(hidden_dim, input_dim)` | Input projection |
| `U` | `(batch, hidden_dim, rank)` | Fast weights (dynamic, per-batch) |
| `V` | `(input_dim, rank)` | Fast weights basis (fixed) |
### Computation

```
# Dynamic modulation via fast weights
dynamic = U @ V.transpose(-2, -1)              # (batch, hidden, input)

# Effective prediction matrix (base + dynamic modulation)
C_effective = C.transpose(0, 1) + dynamic * 0.1

# Generate prediction, rescaled to the input norm
x_pred = tanh(C_effective @ h_prev) * ||x||
```

### Design Rationale
- Base matrix C: Provides stable, learned predictions
- Fast weights U: Add input-specific adaptation
- Multiplicative modulation: Preserves structure while allowing flexibility
- Norm scaling: Maintains appropriate prediction magnitude
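As a concrete, runnable illustration of this step, here is a NumPy sketch (the actual implementation is PyTorch; the function name, shapes, and random data are illustrative, while the `0.1` modulation scale mirrors the snippet above):

```python
import numpy as np

def predictive_coding_step(x, h_prev, C, U, V, mod_scale=0.1):
    """Prediction from the base matrix C plus low-rank fast-weight modulation."""
    # Low-rank dynamic modulation: (batch, hidden, rank) @ (rank, input)
    dynamic = U @ V.T                                  # (batch, hidden, input)
    # Effective prediction matrix: stable base + scaled modulation
    C_eff = C.T + mod_scale * dynamic                  # (batch, hidden, input)
    # Project the hidden state through C_eff into input space
    pre = np.einsum('bhi,bh->bi', C_eff, h_prev)       # (batch, input)
    # Bounded prediction, rescaled to the input norm
    x_pred = np.tanh(pre) * np.linalg.norm(x, axis=-1, keepdims=True)
    return x_pred, x - x_pred                          # prediction, error

batch, input_dim, hidden_dim, rank = 2, 4, 8, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(batch, input_dim))
h_prev = rng.normal(size=(batch, hidden_dim))
C = 0.1 * rng.normal(size=(input_dim, hidden_dim))
U = 0.1 * rng.normal(size=(batch, hidden_dim, rank))
V = 0.1 * rng.normal(size=(input_dim, rank))
x_pred, error = predictive_coding_step(x, h_prev, C, U, V)
```

Because `tanh` is bounded, the prediction magnitude can never exceed the input norm, which is the point of the norm scaling.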
## Block 2: Surprise Gate

### Purpose
Compute "surprise" from prediction error. Not all errors are equal—expected errors produce low surprise, while unexpected errors trigger strong updates.
### Parameters

| Parameter | Default | Purpose |
|---|---|---|
| `τ₀` (`base_threshold`) | 0.5 | Base surprise threshold |
| `α` (`entropy_influence`) | 0.2 | Entropy weighting |
| `γ` (`surprise_temperature`) | 0.1 | Sigmoid smoothness |
### Computation

```
# Error statistics (exponential moving averages)
error_mean = (1 - β) * error_mean + β * error
error_var  = (1 - β) * error_var + β * (error - error_mean)²

# Classical entropy-based threshold
entropy = 0.5 * log(2πe * error_var)
classical_tau = τ₀ * (1 + α * entropy)

# Adaptive habituation threshold (slow EMA of the error norm)
adaptive_tau = (1 - 0.001) * adaptive_tau + 0.001 * error_norm
adaptive_tau = clamp(adaptive_tau, max=0.8)

# Effective threshold (weighted combination)
effective_tau = 0.3 * classical_tau + 0.7 * adaptive_tau

# Surprise (sigmoid activation)
surprise = sigmoid((error_norm - effective_tau) / γ)
```

### Design Rationale
- Entropy accounting: Higher uncertainty → higher threshold
- Habituation: Prevents over-reaction to constant errors
- Sigmoid gating: Smooth transition from 0 (expected) to 1 (surprising)
- Adaptive threshold: Learns what "normal" error looks like
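The gating behavior is easy to demonstrate in a NumPy sketch (illustrative only; the function name and the simple mean over input dimensions for the entropy term are assumptions, while the constants follow the pseudocode above):

```python
import numpy as np

def surprise_gate(error, error_mean, error_var, adaptive_tau,
                  tau0=0.5, alpha=0.2, gamma=0.1, beta=0.1):
    """Map a prediction error to a surprise value in [0, 1]."""
    # EMA of error statistics
    error_mean = (1 - beta) * error_mean + beta * error
    error_var = (1 - beta) * error_var + beta * (error - error_mean) ** 2
    # Classical entropy-based threshold (mean entropy across input dims)
    entropy = 0.5 * np.log(2 * np.pi * np.e * error_var + 1e-8).mean(axis=-1)
    classical_tau = tau0 * (1 + alpha * entropy)
    # Habituation: slowly track the typical error norm, capped at 0.8
    error_norm = np.linalg.norm(error, axis=-1)
    adaptive_tau = np.minimum(0.999 * adaptive_tau + 0.001 * error_norm, 0.8)
    # Weighted combination of thresholds, then sigmoid gating
    effective_tau = 0.3 * classical_tau + 0.7 * adaptive_tau
    surprise = 1.0 / (1.0 + np.exp(-(error_norm - effective_tau) / gamma))
    return surprise, error_mean, error_var, adaptive_tau

mean0, var0 = np.zeros((1, 4)), np.ones((1, 4))
tau_state = np.full(1, 0.5)
# A tiny error stays well below threshold; a large one shoots past it
s_small, *_ = surprise_gate(np.full((1, 4), 0.01), mean0, var0, tau_state)
s_large, *_ = surprise_gate(np.full((1, 4), 2.0), mean0, var0, tau_state)
```

The same threshold machinery produces a near-zero gate for the small error and a near-one gate for the large one, which is exactly the smooth 0-to-1 transition the sigmoid gating bullet describes.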
## Block 3: Fast Weights Update

### Purpose
Update fast weights via Hebbian learning: "neurons that fire together, wire together."
### Parameters

| Parameter | Default | Purpose |
|---|---|---|
| `λ` (`forgetting_rate`) | 0.01 | Decay toward target |
| `η` (`base_plasticity`) | 0.1 | Learning rate scale |
### Computation

```
# Hebbian term (outer product projected onto V)
hebbian = (h_prev ⊗ error) @ V   # (batch, hidden, rank)

# Update with decay toward target
dU = -λ * (U - U_target) + η * surprise * hebbian

# Euler integration
U_new = U + dU * dt

# Normalize to target norm (stability)
U_norm = ||U_new||
U_new = U_new * (target_norm / U_norm)
```

### Design Rationale
- Decay term: Prevents unbounded weight growth
- Surprise modulation: High surprise → faster learning
- Normalization: Maintains numerical stability
- Target U_target: Provides consolidation anchor
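The whole update, including the projected outer product and the per-batch renormalization, fits in a short NumPy sketch (illustrative; the function name, `target_norm=1.0`, and the `1e-8` guard are assumptions):

```python
import numpy as np

def fast_weight_update(U, U_target, h_prev, error, V, surprise,
                       lam=0.01, eta=0.1, dt=0.1, target_norm=1.0):
    """Surprise-gated Hebbian update of the low-rank fast weights."""
    # Hebbian outer product (h_prev ⊗ error), projected onto the basis V
    hebbian = np.einsum('bh,bi,ir->bhr', h_prev, error, V)  # (batch, hidden, rank)
    # Decay toward the consolidation target plus surprise-gated Hebbian term
    dU = -lam * (U - U_target) + eta * surprise[:, None, None] * hebbian
    # Euler integration
    U_new = U + dU * dt
    # Renormalize each batch element's Frobenius norm for stability
    norms = np.linalg.norm(U_new, axis=(1, 2), keepdims=True)
    return U_new * (target_norm / (norms + 1e-8))

rng = np.random.default_rng(1)
U = rng.normal(size=(2, 8, 3))
V = rng.normal(size=(4, 3))
h_prev = rng.normal(size=(2, 8))
error = rng.normal(size=(2, 4))
U_new = fast_weight_update(U, np.zeros_like(U), h_prev, error, V,
                           surprise=np.array([0.2, 0.9]))
```

Note how `surprise` only scales the Hebbian term, not the decay: a surprising input writes strongly into the fast weights, while the decay toward `U_target` acts at the same rate regardless.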
## Block 4: Liquid Time-Constant Update

### Purpose
Update hidden state with adaptive integration speed. Time constants change based on input novelty.
### Parameters

| Parameter | Default | Purpose |
|---|---|---|
| `τ_sys` (`ltc_tau_sys`) | 10.0 | System time constant |
| `scale` (`ltc_surprise_scale`) | 10.0 | Surprise influence |
Computation
if LTC disabled:
h_new = tanh(input_effect)
else:
# Dynamic time constant
tau_dynamic = τ_sys / (1 + surprise × scale)
# Clamp for stability
tau_effective = clamp(tau_dynamic, min=0.01, max=50.0)
# Target state
h_target = tanh(input_effect)
# Time-step normalization
dt_over_tau = dt / (tau_effective + dt)
dt_over_tau = clamp(dt_over_tau, min=0.01, max=0.5)
# LTC update (Euler integration)
h_new = (1 - dt_over_tau) * h_prev + dt_over_tau * h_targetDesign Rationale
- High surprise → small τ: Fast response to novelty
- Low surprise → large τ: Slow integration (maintain memory)
- Clamping: Prevents numerical instability
- Continuous-time dynamics: Interpretable temporal behavior
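A NumPy sketch of the LTC branch makes the "high surprise → fast response" behavior directly observable (illustrative; the function name and the example values are assumptions, the constants and clamps follow the pseudocode above):

```python
import numpy as np

def ltc_update(h_prev, input_effect, surprise, tau_sys=10.0, scale=10.0, dt=0.1):
    """Liquid time-constant update: higher surprise means faster integration."""
    # Dynamic time constant shrinks as surprise grows, clamped for stability
    tau = np.clip(tau_sys / (1 + surprise * scale), 0.01, 50.0)
    # Bounded target state
    h_target = np.tanh(input_effect)
    # Normalized, clamped Euler step size
    step = np.clip(dt / (tau + dt), 0.01, 0.5)[:, None]
    return (1 - step) * h_prev + step * h_target

h_prev = np.zeros((2, 8))
input_effect = np.ones((2, 8))
# Batch element 0 is unsurprised, element 1 is maximally surprised
h_new = ltc_update(h_prev, input_effect, surprise=np.array([0.0, 1.0]))
```

With identical inputs, the surprised element moves noticeably further toward `h_target` in a single step, while the unsurprised one barely changes, preserving its memory.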
## State Management

The `DREAMState` contains all persistent tensors:

| Tensor | Shape | Purpose |
|---|---|---|
| `h` | `(batch, hidden_dim)` | Hidden state |
| `U` | `(batch, hidden_dim, rank)` | Fast weights (left factor) |
| `U_target` | `(batch, hidden_dim, rank)` | Target for sleep consolidation |
| `adaptive_tau` | `(batch,)` | Habituated surprise threshold |
| `error_mean` | `(batch, input_dim)` | EMA of prediction error |
| `error_var` | `(batch, input_dim)` | EMA of error variance |
| `avg_surprise` | `(batch,)` | EMA of surprise |
### Per-Batch Design
Each batch element has independent state for:
- Independent adaptation: Each sequence adapts differently
- Variable-length handling: Proper masking support
- Memory retention: State persists across sequences
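A minimal container with these fields might look like the following NumPy sketch (the class name, constructor, and initial values are illustrative assumptions; the field names and shapes come from the table above):

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class DreamStateSketch:
    """Illustrative per-batch state container (shapes from the table above)."""
    h: np.ndarray             # (batch, hidden_dim)
    U: np.ndarray             # (batch, hidden_dim, rank)
    U_target: np.ndarray      # (batch, hidden_dim, rank)
    adaptive_tau: np.ndarray  # (batch,)
    error_mean: np.ndarray    # (batch, input_dim)
    error_var: np.ndarray     # (batch, input_dim)
    avg_surprise: np.ndarray  # (batch,)

def init_state(batch, input_dim, hidden_dim, rank):
    # Illustrative initial values: zero state, unit variance, mid threshold
    return DreamStateSketch(
        h=np.zeros((batch, hidden_dim)),
        U=np.zeros((batch, hidden_dim, rank)),
        U_target=np.zeros((batch, hidden_dim, rank)),
        adaptive_tau=np.full(batch, 0.5),
        error_mean=np.zeros((batch, input_dim)),
        error_var=np.ones((batch, input_dim)),
        avg_surprise=np.zeros(batch),
    )

state = init_state(batch=2, input_dim=4, hidden_dim=8, rank=3)
```

Because every field carries a leading batch axis, each sequence in the batch adapts, habituates, and remembers independently.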
### State Operations

```
# Initialize state
state = model.init_state(batch_size=32, device='cuda')

# Detach from the computation graph (for truncated BPTT)
state = state.detach()

# Access components
hidden = state.h
fast_weights = state.U
surprise = state.avg_surprise
```

## Information Flow
Here's how information flows through the cell:
1. Input arrives → Block 1 generates a prediction
2. Prediction error is computed → Block 2 calculates surprise
3. Surprise gates learning → Block 3 updates the fast weights
4. Adaptive integration → Block 4 updates the hidden state
5. Output is produced → the new state is passed to the next timestep
This creates a feedback loop where surprise drives adaptation, which improves predictions, which reduces surprise.
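This feedback loop can be demonstrated with a deliberately tiny scalar toy (this is not the DREAM equations, just the surprise → adaptation → lower-surprise dynamic in miniature; all values are illustrative):

```python
import numpy as np

# Toy scalar model: a weight w learns to predict a constant input x.
# Surprise gates the learning rate, so large errors adapt w quickly;
# as predictions improve, surprise (and with it plasticity) dies down.
x, w, gamma, tau = 1.0, 0.0, 0.1, 0.3
surprises = []
for _ in range(50):
    error = x - np.tanh(w)                                  # prediction error
    surprise = 1 / (1 + np.exp(-(abs(error) - tau) / gamma))  # sigmoid gate
    w += 0.5 * surprise * error                             # surprise-gated update
    surprises.append(float(surprise))
```

The first iterations see near-maximal surprise and large weight updates; by the end of the run the prediction error has shrunk and the gate has closed most of the way, which is the self-stabilizing loop described above.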
## Next Steps
- Mathematical Foundations - Deep dive into the equations
- Configuration Guide - Tune architectural parameters
- API Reference - Complete class documentation