# Guides

Tutorials and examples for using DREAM.
## Table of Contents

### Beginner

- Basic Sequence Processing — Your first DREAM model
- Stateful Processing — Preserve memory across sequences
- Using DREAMStack — Multi-layer models

### Intermediate

- Truncated BPTT — Training on long sequences
- Custom Configuration — Tuning hyperparameters
- Saving and Loading — Checkpoint management

### Advanced

- Custom Cell Extension — Extending DREAMCell
- Integration with Lightning — PyTorch Lightning support
## Basic Sequence Processing

### Step 1: Import and Configure

```python
import torch
from dream import DREAMConfig, DREAMCell

config = DREAMConfig(
    input_dim=80,    # Mel spectrogram bins
    hidden_dim=256,  # Hidden state size
    rank=16,         # Fast weights rank
)
```

### Step 2: Create Model

```python
cell = DREAMCell(config)
```

### Step 3: Initialize State

```python
batch_size = 4
state = cell.init_state(batch_size)
```

### Step 4: Process Sequence

```python
# Single timestep
x = torch.randn(batch_size, 80)
h_new, state = cell(x, state)

# Or a full sequence at once
x_seq = torch.randn(batch_size, 100, 80)
output, final_state = cell.forward_sequence(x_seq, return_all=True)
print(f"Output shape: {output.shape}")  # (4, 100, 256)
```

## Stateful Processing
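Carrying state between calls versus re-initializing it changes the output. The effect is generic to recurrent cells; in this sketch a plain `torch.nn.GRUCell` stands in for DREAM, and all names are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=8, hidden_size=16)
seq_a = torch.randn(2, 8)
seq_b = torch.randn(2, 8)

# Carried state: the second call starts from where the first ended
h = torch.zeros(2, 16)
h = cell(seq_a, h)
h_carried = cell(seq_b, h)

# Reset state: the second call starts from scratch
h_fresh = cell(seq_b, torch.zeros(2, 16))

print(torch.allclose(h_carried, h_fresh))  # False: history matters
```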
Preserve memory across multiple sequences:

```python
# Initialize ONCE
state = cell.init_state(batch_size=4)

# Process multiple sequences
for seq in sequences:
    # State (U, h, adaptive_tau) is preserved across calls
    output, state = cell.forward_sequence(seq, state)

# Model adapts and remembers!
# Surprise should decrease as the model learns
```

### Example: Memory Retention Test
```python
state = cell.init_state(batch_size=1)

for pass_idx in range(5):
    output, state = cell.forward_sequence(same_sequence, state)
    # Surprise should decrease as the model adapts
    print(f"Pass {pass_idx}: Surprise = {state.avg_surprise.mean().item():.4f}")
```

## Using DREAMStack
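Stacking simply means each layer consumes the previous layer's output sequence. A schematic with plain PyTorch (`torch.nn.GRUCell` as a stand-in, not DREAMStack's actual implementation):

```python
import torch
import torch.nn as nn

dims = [80, 256, 256, 128]  # input_dim followed by hidden_dims
layers = nn.ModuleList(
    nn.GRUCell(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
)

x = torch.randn(4, 100, 80)  # (batch, time, features)
out = x
for layer in layers:
    h = torch.zeros(x.size(0), layer.hidden_size)
    steps = []
    for t in range(out.size(1)):
        h = layer(out[:, t, :], h)
        steps.append(h)
    out = torch.stack(steps, dim=1)  # becomes next layer's input

print(out.shape)  # torch.Size([4, 100, 128])
```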
Multi-layer DREAM models:

```python
from dream import DREAMStack

model = DREAMStack(
    input_dim=80,
    hidden_dims=[256, 256, 128],  # 3 layers
    rank=16,
    dropout=0.1,
)

# Process a batch of sequences
x = torch.randn(4, 100, 80)
output, states = model(x, return_sequences=True)
print(f"Output shape: {output.shape}")  # (4, 100, 128)
```

### Accessing Layer States
```python
# Initialize states for all layers
states = model.init_state(batch_size=4)

# Process with explicit per-layer states
output = x  # start from the raw input
for layer_idx, layer in enumerate(model.layers):
    output, states[layer_idx] = layer(output, states[layer_idx])
```

## Truncated BPTT
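Truncated BPTT is model-agnostic: process a long sequence in segments, backpropagate within each segment, and detach the recurrent state at segment boundaries so the autograd graph never spans the whole sequence. A self-contained sketch with a plain `torch.nn.GRUCell` (a stand-in, not DREAM):

```python
import torch
import torch.nn as nn

cell = nn.GRUCell(input_size=8, hidden_size=16)
head = nn.Linear(16, 1)
opt = torch.optim.Adam(list(cell.parameters()) + list(head.parameters()), lr=1e-2)

x = torch.randn(4, 400, 8)      # (batch, time, features)
target = torch.randn(4, 400, 1)
segment_size = 100

h = torch.zeros(4, 16)          # recurrent state carried across segments
for start in range(0, x.size(1), segment_size):
    opt.zero_grad()
    preds = []
    for t in range(start, start + segment_size):
        h = cell(x[:, t, :], h)
        preds.append(head(h))
    loss = nn.functional.mse_loss(torch.stack(preds, dim=1),
                                  target[:, start:start + segment_size, :])
    loss.backward()
    opt.step()
    # Detach: the next segment starts from the current state's value,
    # but gradients will not flow back past this boundary
    h = h.detach()
```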
Training on long sequences:

```python
import torch
import torch.nn as nn
from dream import DREAMCell

model = DREAMCell(config)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop (x: (batch, seq_len, input_dim), target: matching shape)
state = model.init_state(batch_size=4)
for epoch in range(100):
    optimizer.zero_grad()

    # Process in segments; gradients accumulate across segments,
    # with one optimizer step per epoch
    segment_size = 100
    for start in range(0, seq_len, segment_size):
        segment = x[:, start:start+segment_size, :]
        output, state = model.forward_sequence(segment, state)
        loss = criterion(output, target[:, start:start+segment_size, :])
        loss.backward()

        # Detach state between segments so gradients do not flow
        # past the segment boundary
        state = state.detach()

    optimizer.step()
```

## Custom Configuration
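A note on `rank`: assuming the fast weights are stored as a rank-r factorization U = A Bᵀ (an assumption about DREAM's internals, used here only for illustration), the storage and update cost scales with r·(n+m) rather than n·m:

```python
import numpy as np

n = m = 256   # hidden_dim-sized fast-weight matrix
r = 16        # rank

A = np.random.randn(n, r)
B = np.random.randn(m, r)
U = A @ B.T   # full matrix reconstructed on demand

full_params = n * m
factored_params = r * (n + m)
print(U.shape, full_params, factored_params)  # (256, 256) 65536 8192
```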
### For ASR (MFCC 39D)

```python
config = DREAMConfig(
    input_dim=39,           # 13 MFCC + 13 Δ + 13 ΔΔ
    hidden_dim=512,         # Larger capacity
    rank=16,
    forgetting_rate=0.005,
    base_plasticity=0.5,
    ltc_tau_sys=5.0,
)
```

### For Time Series
```python
config = DREAMConfig(
    input_dim=features_dim,
    hidden_dim=128,
    rank=8,
    ltc_enabled=True,
    ltc_tau_sys=5.0,        # Faster response
    ltc_surprise_scale=5.0,
)
```

### For Fast Adaptation
```python
config = DREAMConfig(
    input_dim=80,
    hidden_dim=256,
    rank=16,
    base_plasticity=1.0,       # Higher learning rate
    base_threshold=0.2,        # More sensitive
    surprise_temperature=0.1,  # Smoother surprise
)
```

## Saving and Loading
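Checkpointing DREAM uses ordinary PyTorch serialization. As a quick sanity-check sketch of the round trip (a plain `nn.Linear` stands in for the DREAM cell):

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), 'ckpt.pt')
torch.save({'model_state_dict': model.state_dict()}, path)

restored = nn.Linear(4, 2)  # fresh, randomly initialized weights
restored.load_state_dict(torch.load(path)['model_state_dict'])

# Parameters now match exactly
same = all(torch.equal(p, q) for p, q in
           zip(model.state_dict().values(), restored.state_dict().values()))
print(same)  # True
```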
### Save Checkpoint

```python
# Save model weights and configuration
torch.save({
    'model_state_dict': cell.state_dict(),
    'config': config,
}, 'dream_checkpoint.pt')
```

### Load Checkpoint
```python
checkpoint = torch.load('dream_checkpoint.pt')

# Recreate config and model
config = checkpoint['config']
cell = DREAMCell(config)
cell.load_state_dict(checkpoint['model_state_dict'])
```

### Save Full State (for inference)
```python
# Save model weights plus the fast-weights state
torch.save({
    'model_state_dict': cell.state_dict(),
    'state_dict': {
        'U': state.U,
        'U_target': state.U_target,
        'h': state.h,
    },
}, 'dream_full_checkpoint.pt')
```

## Custom Cell Extension
Extend DREAMCell for custom tasks:

```python
import torch.nn as nn
from dream import DREAMCell, DREAMConfig

class CustomDREAMCell(DREAMCell):
    def __init__(self, config, output_dim):
        super().__init__(config)
        self.output_layer = nn.Linear(config.hidden_dim, output_dim)

    def forward(self, x, state):
        h_new, state = super().forward(x, state)
        output = self.output_layer(h_new)
        return output, state
```

## Integration with Lightning
A PyTorch Lightning module:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from dream import DREAMCell, DREAMConfig

class DREAMModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.cell = DREAMCell(config)
        self.criterion = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, target = batch
        output, state = self.cell.forward_sequence(x, return_all=True)
        loss = self.criterion(output, target)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

## Next Steps
- API Reference — Complete API documentation
- Benchmarks — Performance results
- GitHub Issues — Ask questions