Usage Patterns

Common usage patterns and best practices for DREAM

This guide covers common usage patterns for DREAM in various scenarios.

Pattern 1: Simple Sequence Processing

Process a sequence and get outputs at all timesteps.

import torch
from dream import DREAM

model = DREAM(input_dim=64, hidden_dim=128)
x = torch.randn(32, 50, 64)

output, state = model(x)
# output: (32, 50, 128) - all timesteps
# state: DREAMState - final state

Use case: Standard sequence modeling where you need outputs at every timestep.


Pattern 2: Stateful Processing (Memory Retention)

Preserve state across multiple sequences for continuous memory.

# Initialize state once; batch size 1 matches the single-sequence
# batches created by unsqueeze(0) below
state = model.init_state(batch_size=1)

# Process multiple sequences with memory
for seq in sequence_batch:
    output, state = model(seq.unsqueeze(0), state)  # (1, seq_len, input_dim)
    # State persists between sequences!

Use case:

  • Streaming data processing
  • Continuous speech recognition
  • Online learning scenarios

Key benefit: The model adapts and remembers across sequence boundaries.
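
When a stream genuinely ends (a new speaker, a new session), re-initialize the state instead of carrying stale memory across unrelated data. A minimal sketch using the init_state API above, with a hypothetical is_new_stream flag:

for seq, is_new_stream in stream:
    if is_new_stream:
        state = model.init_state(batch_size=1)  # drop stale memory
    output, state = model(seq.unsqueeze(0), state)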


Pattern 3: Variable-Length Sequences

Handle padded sequences with proper masking.

# Padded input
x = torch.randn(32, 100, 64)  # padded to max length
lengths = torch.tensor([45, 67, 89, ...])  # actual lengths

# Process with masking
output, state = model.forward_with_mask(x, lengths)

# output has zeros for padded positions

Use case:

  • Batched training with variable-length sequences
  • Text processing with different sentence lengths
  • Audio with varying durations
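
A common follow-up is recovering each sequence's last valid output from the padded tensor, for example to feed a classifier. A sketch that continues from the block above (output: (32, 100, hidden), lengths as defined there):

# Gather the output at index lengths-1 for each sequence
idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, output.size(-1))  # (32, 1, H)
last_valid = output.gather(dim=1, index=idx).squeeze(1)            # (32, H)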

Pattern 4: Encoder-Decoder

Use DREAM in encoder-decoder architectures.

# Encoder
encoder = DREAM(input_dim=64, hidden_dim=256)

# Decoder
decoder = DREAM(input_dim=64, hidden_dim=256)

# Encode
x_enc = torch.randn(32, 50, 64)
_, enc_state = encoder(x_enc, return_sequences=False)

# Use encoder state as decoder initial state
dec_state = enc_state  # or transform if dimensions differ

# Decode
x_dec = torch.randn(32, 30, 64)
output, _ = decoder(x_dec, dec_state)

Use case:

  • Sequence-to-sequence tasks
  • Machine translation
  • Speech-to-text
  • Video captioning

With State Transformation

If encoder and decoder have different hidden dimensions:

# State transformation layer
state_transform = nn.Linear(256, 128)

# Transform encoder state for decoder
dec_state = DREAMState(
    h=state_transform(enc_state.h),
    U=...,  # Initialize or transform fast weights
    # ... other state components
)
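
One way to fill in the elided components, assuming DREAMState carries the fields shown in Pattern 12 and that the remaining components are best re-initialized fresh on the decoder side (an assumption, not a documented recipe):

fresh = decoder.init_state(batch_size=32)  # fresh decoder-side components

dec_state = DREAMState(
    h=state_transform(enc_state.h),  # carry over the projected summary
    U=fresh.U,
    U_target=fresh.U_target,
    adaptive_tau=fresh.adaptive_tau,
    error_mean=fresh.error_mean,
    error_var=fresh.error_var,
    avg_surprise=fresh.avg_surprise,
)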

Pattern 5: Classification

Use DREAM as an encoder with a classification head.

import torch.nn as nn

encoder = DREAM(input_dim=39, hidden_dim=256)
classifier = nn.Linear(256, 10)

# Process
x = torch.randn(32, 100, 39)
_, final_state = encoder(x, return_sequences=False)

# Classify using final hidden state
logits = classifier(final_state.h)
predictions = logits.argmax(dim=-1)
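
For training, the logits feed a standard cross-entropy loss. A minimal sketch with placeholder labels:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3
)

labels = torch.randint(0, 10, (32,))  # placeholder targets

_, final_state = encoder(x, return_sequences=False)
loss = criterion(classifier(final_state.h), labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()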

Use case:

  • Speech command recognition
  • Time series classification
  • Activity recognition
  • Sentiment analysis

With Pooling

For better robustness, pool over all timesteps:

# Get all outputs
output, _ = encoder(x, return_sequences=True)  # (32, 100, 256)

# Mean pooling
pooled = output.mean(dim=1)  # (32, 256)
logits = classifier(pooled)
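
If the batch is padded (Pattern 3), exclude the padded positions from the mean rather than averaging zeros in. A sketch, assuming lengths holds the true sequence lengths:

mask = (torch.arange(100)[None, :] < lengths[:, None]).float()  # (32, 100)
pooled = (output * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
logits = classifier(pooled)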

Pattern 6: Multi-Layer Stack

Stack multiple DREAM layers for hierarchical processing.

from dream import DREAMStack

model = DREAMStack(
    input_dim=64,
    hidden_dims=[128, 128, 64],  # 3 layers
    rank=8,
    dropout=0.1
)

x = torch.randn(32, 50, 64)
output, states = model(x)
# output: (32, 50, 64)
# states: list of 3 DREAMState

Use case:

  • Deep sequence modeling
  • Hierarchical feature extraction
  • Complex temporal patterns

Manual Stack

For more control, build the stack manually:

layer1 = DREAM(input_dim=64, hidden_dim=128)
layer2 = DREAM(input_dim=128, hidden_dim=128)
layer3 = DREAM(input_dim=128, hidden_dim=64)

# Process through layers; each layer's input_dim must match the
# previous layer's hidden_dim
out1, state1 = layer1(x)
out2, state2 = layer2(out1)
out3, state3 = layer3(out2)
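
To recover the inter-layer dropout that DREAMStack applies (dropout=0.1 above), insert it between layers yourself. A sketch, assuming DREAMStack places dropout between layers, which is the usual convention:

drop = nn.Dropout(0.1)

out1, state1 = layer1(x)
out2, state2 = layer2(drop(out1))
out3, state3 = layer3(drop(out2))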

Pattern 7: Truncated BPTT

Process long sequences in chunks for memory efficiency.

# Long sequence: process in chunks
batch_size = 32
seq_len = 1000
chunk_size = 100
x = torch.randn(batch_size, seq_len, 64)

state = model.init_state(batch_size)

for start in range(0, seq_len, chunk_size):
    chunk = x[:, start:start+chunk_size]
    
    # Process chunk
    output, state = model(chunk, state)
    
    # Detach state for truncated BPTT
    state = state.detach()

Use case:

  • Very long sequences (audio, video)
  • Memory-constrained training
  • Online/streaming processing

Why detach? Detaching cuts the computation graph at each chunk boundary, so backpropagation never reaches through the entire history and memory stays bounded by the chunk size.
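
During training, compute a loss per chunk so gradients only flow within each chunk. A sketch, assuming a hypothetical prediction head and a target tensor y aligned with x:

head = nn.Linear(128, 64)  # hypothetical prediction head
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(list(model.parameters()) + list(head.parameters()))

state = model.init_state(batch_size)
for start in range(0, seq_len, chunk_size):
    chunk = x[:, start:start+chunk_size]
    target = y[:, start:start+chunk_size]  # hypothetical targets

    output, state = model(chunk, state)
    loss = criterion(head(output), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    state = state.detach()  # truncate the graph at the chunk boundary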


Pattern 8: Bidirectional Processing

Process sequences in both directions.

class BidirectionalDREAM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Note: the submodules must not be named "forward"/"backward";
        # assigning to self.forward would shadow the forward() method below
        self.fwd = DREAM(input_dim, hidden_dim)
        self.bwd = DREAM(input_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim * 2, hidden_dim)
    
    def forward(self, x):
        # Forward pass
        fwd_out, _ = self.fwd(x)  # (batch, time, hidden)
        
        # Backward pass (reverse time)
        x_rev = x.flip(dims=[1])
        bwd_out, _ = self.bwd(x_rev)
        bwd_out = bwd_out.flip(dims=[1])  # restore time order
        
        # Concatenate and project
        combined = torch.cat([fwd_out, bwd_out], dim=-1)
        return self.output_proj(combined)

model = BidirectionalDREAM(input_dim=64, hidden_dim=128)
output = model(x)  # (32, 50, 128)

Use case:

  • Context-dependent predictions
  • Sequence labeling
  • Named entity recognition
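
For sequence labeling, attach a per-timestep classifier on top of the bidirectional outputs. A sketch with a hypothetical num_tags and placeholder labels:

num_tags = 12  # hypothetical tag count
tagger = nn.Linear(128, num_tags)

out = model(x)                               # (32, 50, 128)
logits = tagger(out)                         # (32, 50, num_tags)

tags = torch.randint(0, num_tags, (32, 50))  # placeholder labels
loss = nn.CrossEntropyLoss()(logits.reshape(-1, num_tags), tags.reshape(-1))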

Pattern 9: Attention Augmentation

Combine DREAM with attention mechanisms.

class DREAMWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.dream = DREAM(input_dim, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):
        # DREAM encoding
        dream_out, _ = self.dream(x)  # (batch, time, hidden)
        
        # Self-attention (transpose for MHA format)
        attn_out, _ = self.attention(
            dream_out.transpose(0, 1),
            dream_out.transpose(0, 1),
            dream_out.transpose(0, 1)
        )
        attn_out = attn_out.transpose(0, 1)
        
        # Residual + normalization
        return self.norm(dream_out + attn_out)

model = DREAMWithAttention(input_dim=64, hidden_dim=128)
output = model(x)
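
On PyTorch 1.9+, nn.MultiheadAttention accepts batch_first=True, which removes the transposes. An equivalent sketch of the attention step, using the same constructor arguments as above:

attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

# Inside forward(), the attention call then simplifies to:
attn_out, _ = attention(dream_out, dream_out, dream_out)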

Use case:

  • Long-range dependency modeling
  • Transformer-RNN hybrids
  • Enhanced context awareness

Pattern 10: Curriculum Learning with Surprise

Use the model's surprise signal for curriculum learning.

model = DREAM(input_dim=64, hidden_dim=128)

# Track surprise for each batch
surprise_scores = []

for batch in dataloader:
    output, state = model(batch)
    
    # Get average surprise from the returned state
    surprise = state.avg_surprise.mean().item()
    surprise_scores.append(surprise)
    
    # Use surprise to adjust the learning rate
    # (set_learning_rate stands in for your own scheduler)
    if surprise > 0.7:
        # High surprise - slow down learning
        set_learning_rate(1e-4)
    elif surprise < 0.3:
        # Low surprise - speed up learning
        set_learning_rate(1e-3)

Use case:

  • Adaptive curriculum
  • Difficulty-based sampling
  • Self-paced learning
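
For difficulty-based sampling, per-example surprise scores can weight a sampler. A sketch, assuming you record one score per dataset example rather than per batch as in the loop above:

from torch.utils.data import DataLoader, WeightedRandomSampler

# Favor examples whose surprise sits near a target difficulty of 0.5
scores = torch.tensor(surprise_scores)
weights = 1.0 / (1e-6 + (scores - 0.5).abs())

sampler = WeightedRandomSampler(weights, num_samples=len(weights))
loader = DataLoader(dataset, batch_size=32, sampler=sampler)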

Pattern 11: State Inspection and Debugging

Monitor internal state during inference.

model = DREAM(input_dim=64, hidden_dim=128)
state = model.init_state(batch_size=1)

# Process one timestep at a time and monitor
for timestep, x_t in enumerate(sequence):
    # x_t: (input_dim,) -> (1, 1, input_dim): add batch and time dims
    output, state = model(x_t.view(1, 1, -1), state)
    
    # Log state statistics
    print(f"Timestep {timestep}:")
    print(f"  Hidden norm: {state.h.norm().item():.4f}")
    print(f"  U norm: {state.U.norm().item():.4f}")
    print(f"  Surprise: {state.avg_surprise.mean().item():.4f}")
    print(f"  Adaptive tau: {state.adaptive_tau.mean().item():.4f}")

Use case:

  • Debugging model behavior
  • Understanding adaptation dynamics
  • Monitoring training stability

Pattern 12: Checkpointing State

Save and load state for resumption.

# Save state
torch.save({
    'h': state.h,
    'U': state.U,
    'U_target': state.U_target,
    'adaptive_tau': state.adaptive_tau,
    'error_mean': state.error_mean,
    'error_var': state.error_var,
    'avg_surprise': state.avg_surprise,
}, 'checkpoint.pt')

# Load state
checkpoint = torch.load('checkpoint.pt')
state = DREAMState(**checkpoint)
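
When resuming on a different device, remap the saved tensors at load time:

checkpoint = torch.load('checkpoint.pt', map_location='cuda:0')
state = DREAMState(**checkpoint)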

Use case:

  • Long-running inference
  • State persistence across sessions
  • Distributed processing
