Architecture
How DREAM works — 4 blocks explained

The DREAM cell combines four key mechanisms for adaptive sequence processing.
Overview
┌─────────────────────────────────────────────────────────────┐
│ DREAM Cell │
│ Input: x_t (batch, input_dim) │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ │
│ │ 1. Predictive │ x̂ = tanh(C^T @ h) │
│ │ Coding │ e = x - x̂ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 2. Surprise │ S = σ((r - τ) / (2γ)) │
│ │ Gate │ r = ||e|| / ||μ_e|| │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 3. Fast Weights │ dU = -λ(U-U_tgt) + η·S·(h⊗e)@V │
│ │ (STDP) │ U ← U + dU·dt │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 4. LTC │ τ = τ_sys / (1 + S·scale) │
│ │ Update │ h_new = (1-α)h + α·tanh(u_eff) │
│ └─────────────────┘ │
│ │
│ Output: h_new (batch, hidden_dim) │
└─────────────────────────────────────────────────────────────┘

Block 1: Predictive Coding
Purpose: Generate predictions and compute errors.
Formulas
Prediction: x̂_t = tanh(C^T @ h_{t-1})
Error: e_t = x_t - x̂_t

Implementation
# Matrices
self.C = nn.Parameter(torch.randn(hidden_dim, input_dim) * 0.1)
self.W = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)
self.B = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)
# Forward
x_pred = torch.tanh(state.h @ self.C)
error = x - x_pred

Why This Design?
- C — decodes the hidden state to input space
- W — projects the error back to hidden space
- B — processes the new input
- Small initialization (0.1) for stability
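As a minimal standalone sketch of Block 1 (toy dimensions and plain tensors instead of module parameters; all names here are illustrative):

```python
import torch

# Toy dimensions for illustration only
batch, input_dim, hidden_dim = 4, 8, 16

C = torch.randn(hidden_dim, input_dim) * 0.1  # decoder: hidden -> input space
h = torch.zeros(batch, hidden_dim)            # previous hidden state h_{t-1}
x = torch.randn(batch, input_dim)             # current input x_t

x_pred = torch.tanh(h @ C)  # x̂_t = tanh(C^T @ h_{t-1}), batched
error = x - x_pred          # e_t = x_t - x̂_t
```

With a zero hidden state the prediction is zero and the error equals the input; as training shapes C, the prediction absorbs whatever the hidden state can explain and the error carries only the residual.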
Block 2: Surprise Gate
Purpose: Detect novelty and modulate plasticity.
Formulas
Entropy: H = 0.5 · log(2πe · var)
Threshold: τ = 1.0 + α · H
Relative: r = ||e|| / (||μ_e|| + ε)
Surprise: S = σ((r - τ) / (2γ))

Implementation
def compute_surprise(self, error, state):
    eps = 1e-8  # numerical stability floor
    # Entropy from the running error variance
    variance = state.error_var.mean(dim=-1)
    entropy = 0.5 * torch.log(2 * torch.pi * torch.e * (variance + eps))
    # Adaptive threshold
    tau = 1.0 + self.alpha * entropy
    # Relative error against the running mean
    baseline = state.error_mean.norm(dim=-1) + eps
    relative_error = error.norm(dim=-1) / baseline
    # Surprise
    surprise = torch.sigmoid((relative_error - tau) / (2 * self.gamma))
    return surprise

Why Relative Error?
- Absolute error can be small but model still "surprised"
- Relative error better detects anomalies
- More stable across different data scales
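The scale-invariance point can be checked directly. This standalone sketch (the gamma and threshold values are illustrative) evaluates the surprise formula at two very different absolute scales with the same relative deviation:

```python
import torch

gamma = 0.5   # illustrative smoothing width
eps = 1e-8

def surprise(error, error_mean, tau):
    # r = ||e|| / ||mu_e||, S = sigmoid((r - tau) / (2*gamma))
    r = error.norm(dim=-1) / (error_mean.norm(dim=-1) + eps)
    return torch.sigmoid((r - tau) / (2 * gamma))

tau = torch.tensor([1.0])

# Same 2x relative deviation at two very different absolute scales
s_small = surprise(torch.tensor([[0.02, 0.02]]), torch.tensor([[0.01, 0.01]]), tau)
s_large = surprise(torch.tensor([[20.0, 20.0]]), torch.tensor([[10.0, 10.0]]), tau)
```

An absolute-error gate would respond very differently to these two inputs; the relative gate produces the same surprise for both.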
Block 3: Fast Weights (STDP)
Purpose: Online learning during inference.
Formulas
Fast Weights: W_fast = U @ V^T
STDP Update: dU = -λ(U - U_target) + (η · S) · ((h ⊗ e) @ V)
Euler Step: U ← U + dU · dt

Low-Rank Decomposition
Full Matrix: hidden × input = 256 × 80 = 20,480 params
Low-Rank: (hidden × rank) + (input × rank)
= 256×16 + 80×16 = 5,376 params
Savings: ~4× reduction

Implementation
# Fixed orthogonal V
V_init = torch.randn(input_dim, rank)
Q, _ = torch.linalg.qr(V_init)
self.register_buffer('V', Q)
# Update
eV = error @ self.V # (batch, rank)
hebbian = state.h.unsqueeze(2) * eV.unsqueeze(1) # (batch, hidden, rank)
plasticity = self.eta * surprise.view(-1, 1, 1)  # broadcast over (hidden, rank)
forgetting = -self.forgetting_rate * (state.U - state.U_target)
dU = forgetting + plasticity * hebbian
state.U = state.U + dU * self.dt

Why Low-Rank?
- 4× parameter savings
- Preserves expressiveness
- V is fixed (orthogonal) for stability
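The whole fast-weight update can be run end to end as a standalone sketch; the sizes match the example above, while `eta`, `forgetting_rate`, and `dt` are illustrative hyperparameters:

```python
import torch

batch, hidden_dim, input_dim, rank = 2, 256, 80, 16
eta, forgetting_rate, dt = 0.1, 0.01, 0.05   # illustrative hyperparameters

# Fixed orthogonal V via reduced QR
V, _ = torch.linalg.qr(torch.randn(input_dim, rank))   # (input_dim, rank)

U = torch.zeros(batch, hidden_dim, rank)
U_target = torch.zeros_like(U)
h = torch.randn(batch, hidden_dim)
error = torch.randn(batch, input_dim)
surprise = torch.rand(batch)

eV = error @ V                              # error projected onto V: (batch, rank)
hebbian = h.unsqueeze(2) * eV.unsqueeze(1)  # low-rank h ⊗ e term: (batch, hidden, rank)
plasticity = eta * surprise.view(-1, 1, 1)  # surprise gate, broadcast over (hidden, rank)
dU = -forgetting_rate * (U - U_target) + plasticity * hebbian
U = U + dU * dt                             # Euler step

W_fast = U @ V.T                            # effective fast weights: (batch, hidden, input)
low_rank_params = hidden_dim * rank + input_dim * rank   # 5,376 vs 20,480 full
```

Note that only U adapts online; V stays fixed, so the effective matrix U @ V^T always lives in a rank-16 subspace.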
Block 4: LTC Update
Purpose: Adaptive integration speeds.
Formulas
Dynamic τ: τ = τ_sys / (1 + S · scale)
Target: h_target = tanh(u_eff)
Integration: h_new = (1 - dt/τ) · h_prev + (dt/τ) · h_target

Implementation
def compute_ltc_update(self, h_prev, u_eff, surprise):
    # Dynamic tau: high surprise shrinks the time constant
    tau = self.tau_sys / (1.0 + surprise * self.tau_surprise_scale)
    tau = torch.clamp(tau, 0.01, 50.0)
    # Semi-implicit Euler step: dt / (tau + dt) stays in (0, 1)
    h_target = torch.tanh(u_eff)
    dt_over_tau = self.dt / (tau.unsqueeze(1) + self.dt)
    dt_over_tau = torch.clamp(dt_over_tau, 0.01, 0.5)
    h_new = (1 - dt_over_tau) * h_prev + dt_over_tau * h_target
    return h_new

Why LTC?
- High surprise → small τ → fast updates
- Low surprise → large τ → smooth integration
- Adaptive to event importance
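The surprise-to-speed relationship can be seen in a standalone sketch (tau_sys and the scale factor are illustrative constants):

```python
import torch

tau_sys, tau_scale, dt = 5.0, 10.0, 0.05   # illustrative constants

def dynamic_tau(surprise):
    tau = tau_sys / (1.0 + surprise * tau_scale)
    return torch.clamp(tau, 0.01, 50.0)

tau_calm = dynamic_tau(torch.tensor([0.0]))   # no surprise: tau stays at tau_sys
tau_alert = dynamic_tau(torch.tensor([1.0]))  # full surprise: tau shrinks ~11x

# Smaller tau -> larger effective step dt / (tau + dt) -> faster state update
step_calm = dt / (tau_calm + dt)
step_alert = dt / (tau_alert + dt)
```

Under these constants a fully surprised step integrates roughly an order of magnitude faster than a calm one.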
Sleep Consolidation
Purpose: Stabilize fast changes into long-term memory.
Formulas
If S̄ > S_min:
dU_target = ζ_sleep · S̄ · (U - U_target)
U_target ← U_target + dU_target

Implementation
avg_surprise = state.avg_surprise.mean()
if avg_surprise > self.S_min:
    dU_target = self.sleep_rate * avg_surprise * (state.U - state.U_target)
    state.U_target = state.U_target + dU_target

Why Sleep?
- Consolidates important patterns (S̄ > S_min)
- Slow update (ζ_sleep = 0.01)
- Prevents catastrophic forgetting
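A standalone sketch of the gating behavior (ζ_sleep = 0.01 comes from the text; the S_min value here is hypothetical):

```python
import torch

sleep_rate = 0.01   # ζ_sleep from the text
S_min = 0.3         # hypothetical consolidation threshold

def consolidate(U, U_target, avg_surprise):
    # U_target drifts slowly toward U only when average surprise is high enough
    if avg_surprise > S_min:
        U_target = U_target + sleep_rate * avg_surprise * (U - U_target)
    return U_target

U = torch.ones(1, 4, 2)          # toy fast weights
U_target = torch.zeros_like(U)

calm = consolidate(U, U_target, torch.tensor(0.1))   # below S_min: no change
alert = consolidate(U, U_target, torch.tensor(0.8))  # above S_min: slow drift toward U
```

Because the drift rate is tiny, U_target accumulates only patterns that stay surprising over many steps, which is what protects old memories from being overwritten.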
State Management
DREAMState Components
@dataclass
class DREAMState:
    h: torch.Tensor             # (batch, hidden_dim)
    U: torch.Tensor             # (batch, hidden_dim, rank)
    U_target: torch.Tensor      # (batch, hidden_dim, rank)
    adaptive_tau: torch.Tensor  # (batch,)
    error_mean: torch.Tensor    # (batch, input_dim)
    error_var: torch.Tensor     # (batch, input_dim)
    avg_surprise: torch.Tensor  # (batch,)

Detachment for BPTT
# Between segments
state = state.detach()  # Reset computation graph between segments

Next Steps
- API Reference — Complete API documentation
- Benchmarks — Performance results
- Guides — Tutorials and examples