Quick Start
Get started with DREAM in 5 minutes
This guide will help you get started with DREAM (Dynamic Recall and Elastic Adaptive Memory) in just a few minutes.
Basic Usage
Here's the simplest way to use DREAM:

```python
import torch
from dream import DREAM

# Create model
model = DREAM(
    input_dim=64,      # Input feature dimension
    hidden_dim=128,    # Hidden state size
    rank=8,            # Fast weights rank
    ltc_enabled=True,  # Enable liquid time-constants
)

# Process sequence
batch_size = 32
seq_len = 50
x = torch.randn(batch_size, seq_len, 64)  # (batch, time, features)
output, state = model(x)

print(f"Output shape: {output.shape}")  # (32, 50, 128)
```

Key Concepts
Input/Output Shapes
DREAM follows the standard RNN convention:
- Input: `(batch, time, features)`
- Output: `(batch, time, hidden_dim)` (when `return_sequences=True`)
- State: `DREAMState` object containing hidden state and fast weights
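The exact contents of `DREAMState` are defined by the library; as a rough mental model only (the field and function names below are assumptions, not the real `dream` API — only `.h` is suggested by the examples in this guide), it can be pictured as a small container holding the hidden state and the rank-limited fast weights:

```python
from dataclasses import dataclass
from typing import List

# Hypothetical sketch of a DREAMState-like container. NOT the real dream API;
# field names besides .h are illustrative assumptions.
@dataclass
class ToyDREAMState:
    h: List[List[float]]             # hidden state, shape (batch, hidden_dim)
    fast_weights: List[List[float]]  # low-rank fast-weight factors, (hidden_dim, rank)

def init_toy_state(batch_size: int, hidden_dim: int, rank: int) -> ToyDREAMState:
    """Zero-initialized state, loosely mirroring model.init_state(batch_size)."""
    return ToyDREAMState(
        h=[[0.0] * hidden_dim for _ in range(batch_size)],
        fast_weights=[[0.0] * rank for _ in range(hidden_dim)],
    )

state = init_toy_state(batch_size=2, hidden_dim=4, rank=2)
print(len(state.h), len(state.h[0]))  # 2 4
```

Only `.h` is used by the examples in this guide (as `final_state.h`); treat the rest of the sketch as illustrative.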
State Management
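A toy illustration of why carried state matters: processing a sequence in chunks while passing the state back in gives the same result as one long pass. This uses a deliberately trivial recurrence (`h' = 0.5 * h + x`), not the DREAM update rule:

```python
# Toy recurrent cell (NOT the DREAM update rule): h' = 0.5 * h + x.
# It only illustrates why carrying state between calls matters.
def run(inputs, h=0.0):
    for x in inputs:
        h = 0.5 * h + x
    return h

full = run([1.0, 2.0, 3.0, 4.0])         # one long pass

h = 0.0                                  # analogous to model.init_state(...)
for chunk in ([1.0, 2.0], [3.0, 4.0]):   # two shorter sequences
    h = run(chunk, h)                    # pass the state back in

assert h == full  # chunked processing matches the single long pass
# Without passing h along, the second chunk would "forget" the first.
```

DREAM's `model(sequence, state)` call plays the same role: pass the returned state into the next call to retain memory.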
DREAM maintains state across sequences, enabling memory retention:
```python
# Initialize state
state = model.init_state(batch_size=32)

# Process multiple sequences with memory
for sequence in sequences:
    output, state = model(sequence, state)
    # State persists between sequences!
```

Common Patterns
Classification
Use DREAM as an encoder with a classification head:
```python
import torch.nn as nn

# DREAM encoder
encoder = DREAM(input_dim=39, hidden_dim=256, rank=16)

# Classifier head
classifier = nn.Linear(256, 10)  # 10 classes

# Process sequence
x = torch.randn(32, 100, 39)  # (batch, time, features)
output, final_state = encoder(x, return_sequences=False)  # (batch, hidden)

# Classify
logits = classifier(final_state.h)  # (batch, 10)
predictions = logits.argmax(dim=-1)
```

Sequence-to-Sequence
Use DREAM for encoder-decoder architectures:
```python
# Encoder
encoder = DREAM(input_dim=64, hidden_dim=256)

# Decoder
decoder = DREAM(input_dim=64, hidden_dim=256)

# Encode
x_enc = torch.randn(32, 50, 64)
_, enc_state = encoder(x_enc, return_sequences=False)

# Use encoder state as decoder initial state
dec_state = enc_state

# Decode
x_dec = torch.randn(32, 30, 64)
output, _ = decoder(x_dec, dec_state)
```

Variable-Length Sequences
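A padding mask simply marks which timesteps are real. Before using the DREAM API below, here is the underlying idea in pure Python (illustrative only, not the DREAM implementation): build a boolean mask from per-sequence lengths and zero out padded positions, matching the documented behavior that output has zeros for padded positions.

```python
# Toy illustration of padding masks (pure Python, not the DREAM API).
# mask[b][t] is True for real timesteps of sequence b and False for padding.
max_len = 5
lengths = [3, 5, 2]  # actual lengths of three padded sequences

mask = [[t < n for t in range(max_len)] for n in lengths]

# Zero out values at padded positions, as forward_with_mask is documented to do.
outputs = [[1.0] * max_len for _ in lengths]  # dummy per-timestep outputs
masked = [[o if keep else 0.0 for o, keep in zip(row, m)]
          for row, m in zip(outputs, mask)]

print(masked[0])  # [1.0, 1.0, 1.0, 0.0, 0.0]
```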
Handle padded sequences with masking:
```python
# Padded input
x = torch.randn(32, 100, 64)  # padded to max length
lengths = torch.tensor([45, 67, 89, ...])  # actual lengths

# Process with masking
output, state = model.forward_with_mask(x, lengths)
# output has zeros for padded positions
```

Training Example
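The training loop below recommends gradient clipping. What clipping by global norm does can be sketched in a few lines of pure Python (a sketch of the idea behind `torch.nn.utils.clip_grad_norm_`, not its implementation): if the combined L2 norm of all gradients exceeds `max_norm`, every gradient is scaled down by the same factor.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale gradients so their combined L2 norm is at most max_norm.
    Pure-Python sketch of the idea behind torch.nn.utils.clip_grad_norm_."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

grads, norm = clip_grad_norm([3.0, 4.0], max_norm=2.5)
print(norm)   # 5.0
print(grads)  # [1.5, 2.0]  (new global norm is exactly 2.5)
```

Uniform scaling preserves the gradient's direction while bounding the step size, which helps keep recurrent training stable.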
Here's a complete training loop:
```python
import torch
import torch.nn as nn
from dream import DREAM

# Model
model = DREAM(input_dim=39, hidden_dim=256, rank=16)
classifier = nn.Linear(256, 10)

# Loss and optimizer (include the classifier's parameters so its
# weights are trained too)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(classifier.parameters()), lr=1e-3
)

# Training loop
for epoch in range(10):
    for batch_x, batch_y in dataloader:  # your DataLoader of (inputs, labels)
        # Forward pass
        _, final_state = model(batch_x, return_sequences=False)
        logits = classifier(final_state.h)

        # Compute loss
        loss = criterion(logits, batch_y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (recommended)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
```

Next Steps
Now that you know the basics, explore these topics:
- Architecture Deep Dive - Understand how DREAM works internally
- Configuration Guide - Tune parameters for your use case
- Usage Patterns - Learn advanced patterns
- Training Best Practices - Optimize your training process