DREAM

Architecture

How DREAM works — 4 blocks explained

The DREAM cell combines four key mechanisms for adaptive sequence processing.

Overview

┌─────────────────────────────────────────────────────────────┐
│                    DREAM Cell                               │
│  Input: x_t (batch, input_dim)                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────────┐                                        │
│  │ 1. Predictive   │  x̂ = tanh(C^T @ h)                    │
│  │    Coding       │  e = x - x̂                            │
│  └─────────────────┘                                        │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────────┐                                        │
│  │ 2. Surprise     │  S = σ((r - τ) / (2γ))                │
│  │    Gate         │  r = ||e|| / ||μ_e||                  │
│  └─────────────────┘                                        │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────────┐                                        │
│  │ 3. Fast Weights │  dU = -λ(U-U_tgt) + η·S·(h⊗e)@V       │
│  │    (STDP)       │  U ← U + dU·dt                        │
│  └─────────────────┘                                        │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────────┐                                        │
│  │ 4. LTC          │  τ = τ_sys / (1 + S·scale)            │
│  │    Update       │  h_new = (1-α)h + α·tanh(u_eff)       │
│  └─────────────────┘                                        │
│                                                             │
│  Output: h_new (batch, hidden_dim)                          │
└─────────────────────────────────────────────────────────────┘

Block 1: Predictive Coding

Purpose: Generate predictions and compute errors.

Formulas

Prediction:  x̂_t = tanh(C^T @ h_{t-1})
Error:       e_t = x_t - x̂_t

Implementation

# Matrices
self.C = nn.Parameter(torch.randn(hidden_dim, input_dim) * 0.1)
self.W = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)
self.B = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)

# Forward
x_pred = torch.tanh(state.h @ self.C)
error = x - x_pred

Why This Design?

  • C — decodes hidden state to input space
  • W — projects error back to hidden space
  • B — processes new input
  • Small initialization (0.1) for stability
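The forward snippet above only exercises C; one plausible way W and B enter is in assembling the effective drive u_eff consumed later by the LTC update. A minimal sketch, assuming u_eff is simply the sum of the error feedback and the input projection (the exact combination, and the absence of a recurrent term, are assumptions):

```python
import torch

torch.manual_seed(0)
batch, input_dim, hidden_dim = 4, 8, 16

# Small-scale stand-ins for the cell's matrices (0.1 init as above)
C = torch.randn(hidden_dim, input_dim) * 0.1
W = torch.randn(input_dim, hidden_dim) * 0.1
B = torch.randn(input_dim, hidden_dim) * 0.1

h = torch.zeros(batch, hidden_dim)
x = torch.randn(batch, input_dim)

x_pred = torch.tanh(h @ C)   # decode hidden state into input space
error = x - x_pred           # predictive-coding error
u_eff = error @ W + x @ B    # error feedback plus input drive (assumed combination)
```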

Block 2: Surprise Gate

Purpose: Detect novelty and modulate plasticity.

Formulas

Entropy:     H = 0.5 · log(2πe · var)
Threshold:   τ = 1.0 + α · H
Relative:    r = ||e|| / (||μ_e|| + ε)
Surprise:    S = σ((r - τ) / (2γ))

Implementation

def compute_surprise(self, error, state):
    eps = 1e-8  # small constant for numerical stability
    # Entropy from the running error variance
    variance = state.error_var.mean(dim=-1)
    entropy = 0.5 * torch.log(2 * torch.pi * torch.e * (variance + eps))
    
    # Adaptive threshold
    tau = 1.0 + self.alpha * entropy
    
    # Relative error
    baseline = state.error_mean.norm(dim=-1) + eps
    relative_error = error.norm(dim=-1) / baseline
    
    # Surprise
    surprise = torch.sigmoid((relative_error - tau) / (2 * self.gamma))
    return surprise

Why Relative Error?

  • Absolute error can be small while the model is still "surprised"
  • Relative error detects anomalies more reliably
  • More stable across different data scales
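compute_surprise reads running error statistics (error_mean, error_var) from the state, so these must be maintained between steps. A minimal sketch, assuming an exponential moving average with a hypothetical decay of 0.99:

```python
import torch

decay = 0.99                      # hypothetical smoothing factor
error = torch.randn(4, 8)         # current prediction error
error_mean = torch.zeros(4, 8)    # running mean, per input dimension
error_var = torch.ones(4, 8)      # running variance, per input dimension

# Exponential moving average of the error statistics
error_mean = decay * error_mean + (1 - decay) * error
error_var = decay * error_var + (1 - decay) * (error - error_mean) ** 2
```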

Block 3: Fast Weights (STDP)

Purpose: Online learning during inference.

Formulas

Fast Weights:  W_fast = U @ V^T
STDP Update:   dU = -λ(U - U_target) + (η · S) · ((h ⊗ e) @ V)
Euler Step:    U ← U + dU · dt

Low-Rank Decomposition

Full Matrix:     hidden × input = 256 × 80 = 20,480 params
Low-Rank:        (hidden × rank) + (input × rank)
                 = 256×16 + 80×16 = 5,376 params
Savings:         ~4× reduction

Implementation

# Fixed orthogonal V
V_init = torch.randn(input_dim, rank)
Q, _ = torch.linalg.qr(V_init)
self.register_buffer('V', Q)

# Update
eV = error @ self.V  # (batch, rank)
hebbian = state.h.unsqueeze(2) * eV.unsqueeze(1)  # (batch, hidden, rank)

plasticity = self.eta * surprise.view(-1, 1, 1)  # (batch, 1, 1) broadcasts over (hidden, rank)
forgetting = -self.forgetting_rate * (state.U - state.U_target)

dU = forgetting + plasticity * hebbian
state.U = state.U + dU * self.dt

Why Low-Rank?

  • 4× parameter savings
  • Preserves expressiveness
  • V is fixed (orthogonal) for stability
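A practical consequence of the factorization is that W_fast never has to be materialized: applying it factor by factor costs O(rank · (hidden + input)) per example instead of O(hidden · input). A sketch with the dimensions from the table above:

```python
import torch

torch.manual_seed(0)
batch, hidden_dim, input_dim, rank = 4, 256, 80, 16

Q, _ = torch.linalg.qr(torch.randn(input_dim, rank))  # fixed orthogonal V
V = Q
U = torch.randn(batch, hidden_dim, rank) * 0.01       # per-sample fast weights
error = torch.randn(batch, input_dim)

# Factored application: W_fast @ e == U @ (V^T @ e)
eV = error @ V                                        # (batch, rank)
fast_out = torch.bmm(U, eV.unsqueeze(2)).squeeze(2)   # (batch, hidden_dim)

# Same result via the explicit low-rank product
W_fast = U @ V.T                                      # (batch, hidden_dim, input_dim)
full_out = torch.bmm(W_fast, error.unsqueeze(2)).squeeze(2)
assert torch.allclose(fast_out, full_out, atol=1e-4)
```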

Block 4: LTC Update

Purpose: Adaptive integration speeds.

Formulas

Dynamic τ:   τ = τ_sys / (1 + S · scale)
Target:      h_target = tanh(u_eff)
Integration: h_new = (1 - α) · h_prev + α · h_target,  α = dt / (τ + dt)

Implementation

def compute_ltc_update(self, h_prev, u_eff, surprise):
    # Dynamic tau
    tau = self.tau_sys / (1.0 + surprise * self.tau_surprise_scale)
    tau = torch.clamp(tau, 0.01, 50.0)
    
    # Semi-implicit Euler step: α = dt / (τ + dt), clamped for stability
    h_target = torch.tanh(u_eff)
    dt_over_tau = self.dt / (tau.unsqueeze(1) + self.dt)
    dt_over_tau = torch.clamp(dt_over_tau, 0.01, 0.5)
    
    h_new = (1 - dt_over_tau) * h_prev + dt_over_tau * h_target
    return h_new

Why LTC?

  • High surprise → small τ → fast updates
  • Low surprise → large τ → smooth integration
  • Adaptive to event importance
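The effect on the integration step is easy to check numerically; τ_sys, scale, and dt below are hypothetical values:

```python
tau_sys, scale, dt = 5.0, 10.0, 0.1  # hypothetical hyperparameters
alphas = []
for s in (0.0, 0.5, 1.0):
    tau = tau_sys / (1.0 + s * scale)             # dynamic time constant
    alpha = min(max(dt / (tau + dt), 0.01), 0.5)  # clamped update fraction
    alphas.append(alpha)

# Higher surprise -> smaller tau -> larger update fraction
assert alphas[0] < alphas[1] < alphas[2]
```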

Sleep Consolidation

Purpose: Stabilize fast changes into long-term memory.

Formulas

If S̄ > S_min:
    dU_target = ζ_sleep · S̄ · (U - U_target)
    U_target ← U_target + dU_target

Implementation

avg_surprise = state.avg_surprise.mean()

if avg_surprise > self.S_min:
    dU_target = self.sleep_rate * avg_surprise * (state.U - state.U_target)
    state.U_target = state.U_target + dU_target

Why Sleep?

  • Consolidates important patterns (S̄ > S_min)
  • Slow update (ζ_sleep = 0.01)
  • Prevents catastrophic forgetting
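Numerically, one consolidation step moves U_target only a small fraction of the way toward U; with ζ_sleep = 0.01 (from the text) and S̄ = 0.5, the step is 0.5% of the gap (S_min below is a hypothetical value):

```python
import torch

sleep_rate, S_min = 0.01, 0.3     # ζ_sleep from the text; S_min hypothetical
U = torch.ones(2, 4, 3)           # current fast weights
U_target = torch.zeros(2, 4, 3)   # long-term target
avg_surprise = torch.tensor(0.5)

if avg_surprise > S_min:
    U_target = U_target + sleep_rate * avg_surprise * (U - U_target)
# U_target moved from 0.0 to 0.005: a 0.5% step toward U
```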

State Management

DREAMState Components

@dataclass
class DREAMState:
    h: torch.Tensor              # (batch, hidden_dim)
    U: torch.Tensor              # (batch, hidden_dim, rank)
    U_target: torch.Tensor       # (batch, hidden_dim, rank)
    adaptive_tau: torch.Tensor   # (batch,)
    error_mean: torch.Tensor     # (batch, input_dim)
    error_var: torch.Tensor      # (batch, input_dim)
    avg_surprise: torch.Tensor   # (batch,)

Detachment for BPTT

# Between BPTT segments: truncate the computation graph
state = state.detach()
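A detach() consistent with the dataclass above can detach every tensor field in one pass. A sketch over a reduced set of fields (the real state carries all seven):

```python
from dataclasses import dataclass, fields
import torch

@dataclass
class DREAMState:
    h: torch.Tensor
    U: torch.Tensor
    U_target: torch.Tensor
    # ... remaining fields as above

    def detach(self) -> "DREAMState":
        # Detach every tensor field so gradients stop at the segment boundary
        return DREAMState(**{f.name: getattr(self, f.name).detach()
                             for f in fields(self)})

h = torch.zeros(2, 16, requires_grad=True)
state = DREAMState(h=h * 2, U=torch.zeros(2, 16, 4), U_target=torch.zeros(2, 16, 4))
state = state.detach()
```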
