Usage Patterns
Common usage patterns and best practices for DREAM
This guide covers common usage patterns for DREAM in various scenarios.
Pattern 1: Simple Sequence Processing
Process a sequence and get outputs at all timesteps.
import torch
from dream import DREAM
model = DREAM(input_dim=64, hidden_dim=128)
x = torch.randn(32, 50, 64)
output, state = model(x)
# output: (32, 50, 128) - all timesteps
# state: DREAMState - final state

Use case: Standard sequence modeling where you need outputs at every timestep.
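If you only need the final state rather than per-timestep outputs, the return_sequences=False flag used in the encoder and classification patterns below can be applied here as well; a minimal sketch:

# Only the final state is needed
_, final_state = model(x, return_sequences=False)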
Pattern 2: Stateful Processing (Memory Retention)
Preserve state across multiple sequences for continuous memory.
# Initialize state once
state = model.init_state(batch_size=1)  # batch of 1 to match the unsqueezed sequences below
# Process multiple sequences with memory
for seq in sequence_batch:
    output, state = model(seq.unsqueeze(0), state)
# State persists between sequences!

Use case:
- Streaming data processing
- Continuous speech recognition
- Online learning scenarios
Key benefit: The model adapts and remembers across sequence boundaries.
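When a boundary marks genuinely unrelated data (for example a new speaker or session), reinitializing the state rather than carrying it over may be preferable; a minimal sketch using the init_state call shown above (new_stream is a hypothetical flag, not part of the DREAM API):

# Reset memory when an unrelated stream begins (new_stream is illustrative)
if new_stream:
    state = model.init_state(batch_size=1)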
Pattern 3: Variable-Length Sequences
Handle padded sequences with proper masking.
# Padded input
x = torch.randn(32, 100, 64) # padded to max length
lengths = torch.tensor([45, 67, 89, ...]) # actual lengths
# Process with masking
output, state = model.forward_with_mask(x, lengths)
# output has zeros for padded positions

Use case:
- Batched training with variable-length sequences
- Text processing with different sentence lengths
- Audio with varying durations
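If you need a padding mask yourself (for example for length-aware pooling downstream), it can be built from lengths with plain PyTorch; a minimal sketch, independent of the DREAM API:

# Build a (batch, time) boolean mask from sequence lengths
max_len = x.size(1)
mask = torch.arange(max_len, device=x.device)[None, :] < lengths[:, None]
# Zero out padded positions and compute a length-aware mean over time
masked = output * mask.unsqueeze(-1)
pooled = masked.sum(dim=1) / lengths[:, None].to(masked.dtype)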
Pattern 4: Encoder-Decoder
Use DREAM in encoder-decoder architectures.
# Encoder
encoder = DREAM(input_dim=64, hidden_dim=256)
# Decoder
decoder = DREAM(input_dim=64, hidden_dim=256)
# Encode
x_enc = torch.randn(32, 50, 64)
_, enc_state = encoder(x_enc, return_sequences=False)
# Use encoder state as decoder initial state
dec_state = enc_state # or transform if dimensions differ
# Decode
x_dec = torch.randn(32, 30, 64)
output, _ = decoder(x_dec, dec_state)

Use case:
- Sequence-to-sequence tasks
- Machine translation
- Speech-to-text
- Video captioning
With State Transformation
If encoder and decoder have different hidden dimensions:
import torch.nn as nn

# State transformation layer
state_transform = nn.Linear(256, 128)
# Transform encoder state for decoder
dec_state = DREAMState(
    h=state_transform(enc_state.h),
    U=...,  # Initialize or transform fast weights
    # ... other state components
)

Pattern 5: Classification
Use DREAM as an encoder with a classification head.
import torch.nn as nn
encoder = DREAM(input_dim=39, hidden_dim=256)
classifier = nn.Linear(256, 10)
# Process
x = torch.randn(32, 100, 39)
_, final_state = encoder(x, return_sequences=False)
# Classify using final hidden state
logits = classifier(final_state.h)
predictions = logits.argmax(dim=-1)

Use case:
- Speech command recognition
- Time series classification
- Activity recognition
- Sentiment analysis
With Pooling
For better robustness, pool over all timesteps:
# Get all outputs
output, _ = encoder(x, return_sequences=True) # (32, 100, 256)
# Mean pooling
pooled = output.mean(dim=1) # (32, 256)
logits = classifier(pooled)

Pattern 6: Multi-Layer Stack
Stack multiple DREAM layers for hierarchical processing.
from dream import DREAMStack
model = DREAMStack(
    input_dim=64,
    hidden_dims=[128, 128, 64],  # 3 layers
    rank=8,
    dropout=0.1
)
x = torch.randn(32, 50, 64)
output, states = model(x)
# output: (32, 50, 64)
# states: list of 3 DREAMState

Use case:
- Deep sequence modeling
- Hierarchical feature extraction
- Complex temporal patterns
Manual Stack
For more control, build the stack manually:
layer1 = DREAM(input_dim=64, hidden_dim=128)
layer2 = DREAM(input_dim=128, hidden_dim=128)
layer3 = DREAM(input_dim=128, hidden_dim=64)
# Process through layers
out1, state1 = layer1(x)
out2, state2 = layer2(out1)
out3, state3 = layer3(out2)

Pattern 7: Truncated BPTT
Process long sequences in chunks for memory efficiency.
# Long sequence: process in chunks
seq_len = 1000
chunk_size = 100
batch_size = 32
x = torch.randn(batch_size, seq_len, 64)
state = model.init_state(batch_size)
for start in range(0, seq_len, chunk_size):
    chunk = x[:, start:start + chunk_size]
    # Process chunk
    output, state = model(chunk, state)
    # Detach state for truncated BPTT
    state = state.detach()

Use case:
- Very long sequences (audio, video)
- Memory-constrained training
- Online/streaming processing
Why detach? Detaching prevents backpropagation through the entire history, which keeps memory usage bounded.
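To make the chunked loop concrete as a training step, a minimal sketch follows; loss_fn, the targets y, and optimizer are placeholders, not part of the DREAM API:

state = model.init_state(batch_size)
for start in range(0, seq_len, chunk_size):
    chunk = x[:, start:start + chunk_size]
    target = y[:, start:start + chunk_size]
    output, state = model(chunk, state)
    loss = loss_fn(output, target)
    optimizer.zero_grad()
    loss.backward()         # gradients flow only within this chunk
    optimizer.step()
    state = state.detach()  # cut the graph before the next chunk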
Pattern 8: Bidirectional Processing
Process sequences in both directions.
class BidirectionalDREAM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Use fwd/bwd rather than forward/backward so the submodules do not
        # collide with nn.Module's own forward method
        self.fwd = DREAM(input_dim, hidden_dim)
        self.bwd = DREAM(input_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        # Forward-direction pass
        fwd_out, _ = self.fwd(x)  # (batch, time, hidden)
        # Backward-direction pass (reverse time)
        x_rev = x.flip(dims=[1])
        bwd_out, _ = self.bwd(x_rev)
        bwd_out = bwd_out.flip(dims=[1])  # restore time order
        # Concatenate and project
        combined = torch.cat([fwd_out, bwd_out], dim=-1)
        return self.output_proj(combined)

model = BidirectionalDREAM(input_dim=64, hidden_dim=128)
output = model(x)  # (32, 50, 128)

Use case:
- Context-dependent predictions
- Sequence labeling
- Named entity recognition
Pattern 9: Attention Augmentation
Combine DREAM with attention mechanisms.
class DREAMWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.dream = DREAM(input_dim, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # DREAM encoding
        dream_out, _ = self.dream(x)  # (batch, time, hidden)
        # Self-attention (transpose to (time, batch, hidden), MHA's default format)
        attn_out, _ = self.attention(
            dream_out.transpose(0, 1),
            dream_out.transpose(0, 1),
            dream_out.transpose(0, 1)
        )
        attn_out = attn_out.transpose(0, 1)
        # Residual + normalization
        return self.norm(dream_out + attn_out)

model = DREAMWithAttention(input_dim=64, hidden_dim=128)
output = model(x)

Use case:
- Long-range dependency modeling
- Transformer-RNN hybrids
- Enhanced context awareness
Pattern 10: Curriculum Learning with Surprise
Use the model's surprise signal for curriculum learning.
model = DREAM(input_dim=64, hidden_dim=128)
state = model.init_state(batch_size=32)
# Track surprise for each sample
surprise_scores = []
for batch in dataloader:
    output, state = model(batch, state)
    # Get average surprise
    surprise = state.avg_surprise.mean().item()
    surprise_scores.append(surprise)
    # Use surprise to adjust learning rate
    # (set_learning_rate is a placeholder; see the sketch after the use cases below)
    if surprise > 0.7:
        # High surprise - slow down learning
        set_learning_rate(1e-4)
    elif surprise < 0.3:
        # Low surprise - speed up learning
        set_learning_rate(1e-3)

Use case:
- Adaptive curriculum
- Difficulty-based sampling
- Self-paced learning
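set_learning_rate above is not part of the DREAM API; one possible implementation, assuming a standard PyTorch optimizer, is:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def set_learning_rate(lr):
    # Adjust the learning rate of every parameter group in place
    for group in optimizer.param_groups:
        group['lr'] = lr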
Pattern 11: State Inspection and Debugging
Monitor internal state during inference.
model = DREAM(input_dim=64, hidden_dim=128)
state = model.init_state(batch_size=1)
# Process and monitor
for timestep, x_t in enumerate(sequence):
    output, state = model(x_t.unsqueeze(0), state)
    # Log state statistics
    print(f"Timestep {timestep}:")
    print(f"  Hidden norm: {state.h.norm().item():.4f}")
    print(f"  U norm: {state.U.norm().item():.4f}")
    print(f"  Surprise: {state.avg_surprise.mean().item():.4f}")
    print(f"  Adaptive tau: {state.adaptive_tau.mean().item():.4f}")

Use case:
- Debugging model behavior
- Understanding adaptation dynamics
- Monitoring training stability
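For longer runs, it can be more convenient to accumulate the same statistics into a history for later plotting; a minimal sketch (plain Python bookkeeping, not part of the DREAM API):

history = {'hidden_norm': [], 'surprise': [], 'tau': []}
for x_t in sequence:
    output, state = model(x_t.unsqueeze(0), state)
    history['hidden_norm'].append(state.h.norm().item())
    history['surprise'].append(state.avg_surprise.mean().item())
    history['tau'].append(state.adaptive_tau.mean().item())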
Pattern 12: Checkpointing State
Save and load state for resumption.
# Save state
torch.save({
    'h': state.h,
    'U': state.U,
    'U_target': state.U_target,
    'adaptive_tau': state.adaptive_tau,
    'error_mean': state.error_mean,
    'error_var': state.error_var,
    'avg_surprise': state.avg_surprise,
}, 'checkpoint.pt')

# Load state
checkpoint = torch.load('checkpoint.pt')
state = DREAMState(**checkpoint)

Use case:
- Long-running inference
- State persistence across sessions
- Distributed processing
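If the state lives on the GPU and may still carry autograd history, detaching it and moving it to the CPU before saving is usually safer; a minimal sketch assuming the state fields listed above:

fields = ['h', 'U', 'U_target', 'adaptive_tau', 'error_mean', 'error_var', 'avg_surprise']
cpu_state = {name: getattr(state, name).detach().cpu() for name in fields}
torch.save(cpu_state, 'checkpoint.pt')

# Restore onto the desired device when loading
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
state = DREAMState(**checkpoint)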
Next Steps
- Configuration Guide - Tune parameters
- Training Best Practices - Optimize training
- Examples - Real-world examples