DREAM

Guides

Tutorials and examples for DREAM

Guides

Tutorials and examples for using DREAM.


Table of Contents

Beginner

Intermediate

Advanced


Basic Sequence Processing

Step 1: Import and Configure

import torch
from dream import DREAMConfig, DREAMCell

# Hyperparameters for a single DREAM cell; input_dim must match the last
# dimension of the tensors fed to the cell.
config = DREAMConfig(
    input_dim=80,      # Mel spectrogram bins
    hidden_dim=256,    # Hidden state size
    rank=16,           # Fast weights rank
)

Step 2: Create Model

# Build the cell from the configuration object
cell = DREAMCell(config)

Step 3: Initialize State

batch_size = 4
# One state object for the whole batch; it is passed into the cell and
# replaced by the state the cell returns.
state = cell.init_state(batch_size)

Step 4: Process Sequence

# Single timestep: input is (batch_size, input_dim); the call returns the
# step output together with the updated state.
x = torch.randn(batch_size, 80)
h_new, state = cell(x, state)

# Or full sequence: input is (batch_size, seq_len, input_dim).
# No state is passed in this call.
x_seq = torch.randn(batch_size, 100, 80)
output, final_state = cell.forward_sequence(x_seq, return_all=True)

print(f"Output shape: {output.shape}")  # (4, 100, 256)

Stateful Processing

Preserve memory across multiple sequences:

# Initialize ONCE, outside the loop, so the state is not reset between sequences
state = cell.init_state(batch_size=4)

# Process multiple sequences
# (`sequences` is assumed to be an iterable of input tensors with batch size 4)
for seq in sequences:
    # The state returned by one call is fed into the next call, so
    # (U, h, adaptive_tau) is preserved across sequences
    output, state = cell.forward_sequence(seq, state)
    
    # Model adapts and remembers!
    # Surprise should decrease as model learns

Example: Memory Retention Test

# Start from a fresh state, then replay one sequence several times
state = cell.init_state(batch_size=1)

num_passes = 5
for pass_idx in range(num_passes):
    # Carry the state over from the previous pass instead of resetting it
    output, state = cell.forward_sequence(same_sequence, state)

    # Surprise should decrease as model adapts
    mean_surprise = state.avg_surprise.mean().item()
    print(f"Pass {pass_idx}: Surprise = {mean_surprise:.4f}")

Using DREAMStack

Stack several DREAM layers into a multi-layer model:

from dream import DREAMStack

# One hidden_dims entry per layer; the last entry is the feature size of
# the stack's output (128 in the shape printed below).
model = DREAMStack(
    input_dim=80,
    hidden_dims=[256, 256, 128],  # 3 layers
    rank=16,
    dropout=0.1,
)

# Process a batch of 4 sequences, 100 timesteps each, 80 features per step
x = torch.randn(4, 100, 80)
output, states = model(x, return_sequences=True)

print(f"Output shape: {output.shape}")  # (4, 100, 128)

Accessing Layer States

# Initialize states for all layers
states = model.init_state(batch_size=4)

# Process with explicit states, feeding each layer's output into the next.
# The chain must start from the raw input: the first layer expects input_dim
# features, not the output of a previous full-stack call.
# NOTE(review): confirm whether calling a layer directly expects a single
# timestep or a full sequence, and shape `x` accordingly.
output = x
for layer_idx, layer in enumerate(model.layers):
    output, states[layer_idx] = layer(output, states[layer_idx])

Truncated BPTT

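Training on long sequences: split each sequence into fixed-size segments, backpropagate within each segment, and detach the state between segments so gradients do not flow across segment boundaries.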

from dream import DREAMCell, DREAMState
import torch.nn as nn

# `x` and `target` are (batch, seq_len, features) tensors and `seq_len` is
# their time dimension; all three must be prepared before this snippet runs.
model = DREAMCell(config)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Length of each truncated-BPTT segment; constant, so defined once up front
segment_size = 100

# Training loop
state = model.init_state(batch_size=4)

for epoch in range(100):
    optimizer.zero_grad()

    # Process in segments
    for start in range(0, seq_len, segment_size):
        # Slicing clamps at seq_len, so the last segment may be shorter
        end = start + segment_size
        segment = x[:, start:end, :]

        # return_all=True requests the output for every timestep so that it
        # lines up with the per-timestep target slice
        output, state = model.forward_sequence(segment, state, return_all=True)
        loss = criterion(output, target[:, start:end, :])

        # Gradients from every segment accumulate until the single
        # optimizer step at the end of the pass
        loss.backward()

        # Detach state between segments so the next backward pass stops here
        state = state.detach()

    optimizer.step()

Custom Configuration

For ASR (MFCC 39D)

# Configuration for speech recognition on 39-dimensional MFCC features
config = DREAMConfig(
    input_dim=39,       # 13 MFCC + 13Δ + 13ΔΔ
    hidden_dim=512,     # Larger capacity
    rank=16,            # Fast weights rank
    forgetting_rate=0.005,
    base_plasticity=0.5,
    ltc_tau_sys=5.0,
)

For Time Series

# Configuration for generic time-series input.
# `features_dim` (number of features per timestep) must be defined beforehand.
config = DREAMConfig(
    input_dim=features_dim,
    hidden_dim=128,
    rank=8,
    ltc_enabled=True,
    ltc_tau_sys=5.0,  # Faster response
    ltc_surprise_scale=5.0,
)

For Fast Adaptation

# Configuration tuned for quick adaptation to new input
config = DREAMConfig(
    input_dim=80,
    hidden_dim=256,
    rank=16,
    base_plasticity=1.0,      # Higher learning rate
    base_threshold=0.2,        # More sensitive
    # NOTE(review): a low temperature usually sharpens a response rather than
    # smoothing it; confirm this comment against the DREAMConfig defaults.
    surprise_temperature=0.1,  # Smoother surprise
)

Saving and Loading

Save Checkpoint

# Save model state
# The config object is stored next to the weights so that the cell can be
# rebuilt from the checkpoint alone. It is pickled as a Python object, so
# loading requires the `dream` package to be importable.
torch.save({
    'model_state_dict': cell.state_dict(),
    'config': config,
}, 'dream_checkpoint.pt')

Load Checkpoint

# The checkpoint contains a pickled DREAMConfig object. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses arbitrary objects,
# so full unpickling is requested explicitly.
# SECURITY: weights_only=False can execute code embedded in the file; only
# load checkpoints that come from a trusted source.
checkpoint = torch.load('dream_checkpoint.pt', weights_only=False)

# Recreate config and model
config = checkpoint['config']
cell = DREAMCell(config)
cell.load_state_dict(checkpoint['model_state_dict'])

Save Full State (for inference)

# Save model + fast weights state
# The config is included so the cell can be rebuilt from this file alone,
# matching the layout of the basic checkpoint.
# NOTE(review): this guide also mentions other state fields (adaptive_tau,
# avg_surprise) that are not stored here; confirm whether resuming inference
# needs them.
torch.save({
    'model_state_dict': cell.state_dict(),
    'config': config,
    'state_dict': {
        'U': state.U,
        'U_target': state.U_target,
        'h': state.h,
    },
}, 'dream_full_checkpoint.pt')

Custom Cell Extension

Extend DREAMCell for custom tasks:

from dream import DREAMCell, DREAMConfig

class CustomDREAMCell(DREAMCell):
    """DREAM cell with a linear read-out applied to the hidden output."""

    def __init__(self, config, output_dim):
        super().__init__(config)
        # Projects hidden_dim features down (or up) to the task's output size
        self.output_layer = nn.Linear(config.hidden_dim, output_dim)

    def forward(self, x, state):
        """Run the base cell, then project its hidden output.

        Returns the projected output together with the updated state.
        """
        hidden, next_state = super().forward(x, state)
        return self.output_layer(hidden), next_state

Integration with Lightning

Wrap a DREAMCell in a PyTorch Lightning module:

import pytorch_lightning as pl
from dream import DREAMCell, DREAMConfig

class DREAMModule(pl.LightningModule):
    """Lightning wrapper that trains a DREAMCell with an MSE objective."""

    def __init__(self, config):
        super().__init__()
        self.cell = DREAMCell(config)
        self.criterion = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        """Return the MSE loss between the full-sequence output and the target."""
        inputs, target = batch
        # The final recurrent state is not needed to compute the loss
        prediction, _ = self.cell.forward_sequence(inputs, return_all=True)
        return self.criterion(prediction, target)

    def configure_optimizers(self):
        """Optimize all module parameters with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Next Steps

On this page