# API Reference

Complete API documentation for DREAM (Dynamic Recall and Elastic Adaptive Memory) classes and methods.
## DREAMConfig

Configuration dataclass for DREAM cell parameters.

```python
from dataclasses import dataclass

@dataclass
class DREAMConfig:
    """Configuration for DREAM cell."""
```

### Model Dimensions
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_dim | int | 39 | Input feature dimension |
| hidden_dim | int | 256 | Hidden state size |
| rank | int | 16 | Fast weights rank |
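A quick way to see why the fast weights are stored at low rank: a full per-sample hidden-by-hidden matrix would need hidden_dim² entries, while the rank-r factor U needs only hidden_dim × rank. A back-of-the-envelope check with the defaults above:

```python
hidden_dim = 256
rank = 16

full_matrix = hidden_dim * hidden_dim  # dense per-sample memory
low_rank = hidden_dim * rank           # rank-16 factor U

print(full_matrix, low_rank, full_matrix // low_rank)  # → 65536 4096 16
```

With the defaults, the low-rank factor is 16× smaller per sample than a dense matrix would be.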
### Time Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| time_step | float | 0.1 | Integration step size (dt) |
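The exact update rule lives in the cell implementation, but an explicit-Euler step with `dt` is the standard way such continuous-time dynamics are discretized. A minimal sketch (the leak-toward-target dynamics here are illustrative, not DREAM's actual equations):

```python
def euler_step(h, target, dt=0.1):
    # Explicit Euler for dh/dt = target - h:
    # each step moves h a fraction dt of the way toward target.
    return h + dt * (target - h)

h = 0.0
for _ in range(10):
    h = euler_step(h, 1.0)
print(h)  # approaches 1.0 as steps accumulate
```

Smaller `time_step` values integrate the dynamics more finely at the cost of more steps per unit of simulated time.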
### Plasticity Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| forgetting_rate | float | 0.01 | Decay rate toward target (λ) |
| base_plasticity | float | 0.1 | Base learning rate (η) |
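Together, η and λ set how quickly fast weights absorb new information and how quickly they relax back toward their consolidated target. A scalar sketch of this kind of update (the update term is a stand-in; DREAM's actual rule operates on the low-rank factor U):

```python
eta, lam = 0.1, 0.01   # base_plasticity, forgetting_rate

u, u_target = 0.0, 0.0
update = 1.0           # stand-in for a (gated) plasticity signal

for _ in range(100):
    # Hebbian-style write at rate eta, decay toward target at rate lam
    u = u + eta * update - lam * (u - u_target)
print(u)  # climbs toward the fixed point eta / lam = 10.0
```

The ratio η/λ bounds how far the fast weights can drift from their target under a sustained write signal.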
### Surprise Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| base_threshold | float | 0.5 | Base surprise threshold (τ₀) |
| entropy_influence | float | 0.2 | Entropy weighting (α) |
| surprise_temperature | float | 0.1 | Sigmoid smoothness (γ) |
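These three parameters suggest a smooth, entropy-adjusted surprise gate. The combination below is an assumption based on the symbols, not DREAM's verified formula: the threshold shifts to τ = τ₀ + α·entropy, and the gate is a sigmoid in (surprise − τ)/γ:

```python
import math

def surprise_gate(surprise, entropy, tau0=0.5, alpha=0.2, gamma=0.1):
    # Hypothetical gate: the threshold rises with input entropy,
    # and gamma controls how sharply the sigmoid switches on.
    tau = tau0 + alpha * entropy
    return 1.0 / (1.0 + math.exp(-(surprise - tau) / gamma))

print(surprise_gate(0.5, 0.0))  # exactly at threshold → 0.5
```

Under this reading, a small γ makes the gate nearly binary, while a larger γ lets moderately surprising inputs write partially.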
### Smoothing Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| error_smoothing | float | 0.01 | EMA smoothing for error (β) |
| surprise_smoothing | float | 0.01 | EMA smoothing for surprise |
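Both smoothing factors read as standard exponential-moving-average coefficients: with β = 0.01, each new observation contributes 1% to the running statistic, giving an effective memory of roughly 1/β = 100 steps. The generic EMA update (assumed form, shown for intuition):

```python
beta = 0.01  # error_smoothing / surprise_smoothing

mean = 0.0
for x in [1.0] * 200:  # a constant signal
    mean = (1 - beta) * mean + beta * x
print(mean)  # converges toward 1.0; ~0.87 after 200 steps
```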
### Homeostasis Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| target_norm | float | 2.0 | Target norm for fast weights |
| kappa | float | 0.5 | Homeostasis strength |
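One common way to regulate a weight norm toward a target, sketched here as an illustration rather than DREAM's exact mechanism: rescale the weights so their norm moves a fraction κ of the way to target_norm (κ = 1 would snap it there in one step):

```python
import math

def homeostatic_rescale(u, target_norm=2.0, kappa=0.5):
    # Illustrative: move ||u|| a fraction kappa of the way to target_norm.
    norm = math.sqrt(sum(x * x for x in u))
    if norm == 0.0:
        return u
    new_norm = norm + kappa * (target_norm - norm)
    return [x * new_norm / norm for x in u]

u = homeostatic_rescale([3.0, 4.0])  # norm 5.0 → 3.5
```

Whatever the exact rule, the role of κ is the same: it keeps fast-weight norms from growing without bound during long sequences.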
### Sleep Consolidation

| Parameter | Type | Default | Description |
|---|---|---|---|
| sleep_rate | float | 0.005 | Consolidation rate (ζ) |
| min_surprise_for_sleep | float | 0.2 | Threshold for consolidation |
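Reading the two parameters together: consolidation plausibly nudges the slow target weights toward the current fast weights at rate ζ, gated on the running surprise clearing min_surprise_for_sleep. Both the gating and the update form below are assumptions inferred from the parameter names, not the verified implementation:

```python
zeta = 0.005        # sleep_rate
min_surprise = 0.2  # min_surprise_for_sleep

def consolidate(u, u_target, avg_surprise):
    # Hypothetical: only consolidate experience that was surprising enough.
    if avg_surprise < min_surprise:
        return u_target
    return u_target + zeta * (u - u_target)

print(consolidate(1.0, 0.0, avg_surprise=0.5))  # 0.005
print(consolidate(1.0, 0.0, avg_surprise=0.1))  # gate closed → 0.0
```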
### Liquid Time-Constants

| Parameter | Type | Default | Description |
|---|---|---|---|
| ltc_enabled | bool | True | Enable LTC dynamics |
| ltc_tau_sys | float | 10.0 | System time constant (τ_sys) |
| ltc_surprise_scale | float | 10.0 | Surprise influence scale |
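In liquid time-constant networks the effective time constant varies with the input; here the parameter names suggest it is modulated by surprise. One hypothetical form, assumed purely for illustration: higher surprise shrinks the time constant so the state adapts faster to novel input:

```python
def effective_tau(surprise, tau_sys=10.0, surprise_scale=10.0):
    # Hypothetical modulation: tau_sys at zero surprise,
    # shrinking as surprise * surprise_scale grows.
    return tau_sys / (1.0 + surprise_scale * surprise)

print(effective_tau(0.0))  # 10.0 - slow, stable dynamics
print(effective_tau(1.0))  # ~0.91 - fast adaptation under surprise
```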
### Example

```python
from dream import DREAMConfig

config = DREAMConfig(
    input_dim=64,
    hidden_dim=256,
    rank=16,
    ltc_enabled=True,
    base_plasticity=0.15,
)
```

## DREAMCell
Low-level DREAM cell implementation.

```python
import torch.nn as nn

class DREAMCell(nn.Module):
    """
    DREAM cell: core building block.

    Parameters
    ----------
    config : DREAMConfig
        Model configuration
    """
```

### Methods
#### __init__(config: DREAMConfig)

Initialize the DREAM cell with configuration.

```python
from dream import DREAMCell, DREAMConfig

config = DREAMConfig(input_dim=64, hidden_dim=128)
cell = DREAMCell(config)
```

#### init_state(batch_size: int, device: str | None = None, dtype: torch.dtype | None = None) -> DREAMState
Initialize cell state for a batch.

Parameters:
- batch_size: Number of sequences in batch
- device: Optional device (e.g., 'cuda', 'cpu')
- dtype: Optional data type (e.g., torch.float32)

Returns:
- DREAMState: Initialized state object

```python
state = cell.init_state(batch_size=32, device='cuda')
```

#### forward(x: Tensor, state: DREAMState) -> Tuple[Tensor, DREAMState]
Single timestep forward pass.

Parameters:
- x: Input tensor `(batch, input_dim)`
- state: Previous state

Returns:
- output: Output tensor `(batch, hidden_dim)`
- new_state: Updated state

```python
x = torch.randn(32, 64)  # (batch, input)
output, new_state = cell(x, state)
```

#### forward_sequence(x_seq: Tensor, state: DREAMState, return_all: bool = True) -> Tuple[Tensor, DREAMState]
Process a full sequence.

Parameters:
- x_seq: Input sequence `(batch, time, input_dim)`
- state: Initial state
- return_all: If True, return all timesteps; if False, only the final output

Returns:
- output: Output tensor `(batch, time, hidden_dim)` or `(batch, hidden_dim)`
- final_state: Final state

```python
x_seq = torch.randn(32, 50, 64)  # (batch, time, input)
output, final_state = cell.forward_sequence(x_seq, state)
```

## DREAM
High-level DREAM sequence model with an nn.LSTM-like API.

```python
import torch.nn as nn

class DREAM(nn.Module):
    """
    High-level DREAM sequence model.

    Parameters
    ----------
    input_dim : int
        Input feature dimension
    hidden_dim : int
        Hidden state size
    rank : int
        Fast weights rank
    **kwargs
        Passed to DREAMConfig
    """
```

### Constructor
```python
from dream import DREAM

model = DREAM(
    input_dim=64,
    hidden_dim=128,
    rank=8,
    ltc_enabled=True,
)
```

### Methods
#### init_state(batch_size: int, device: str | None = None, dtype: torch.dtype | None = None) -> DREAMState

Initialize model state.

```python
state = model.init_state(batch_size=32, device='cuda')
```

#### forward(x: Tensor, state: DREAMState | None = None, return_sequences: bool = True) -> Tuple[Tensor, DREAMState]
Process a sequence.

Parameters:
- x: Input sequence `(batch, time, input_dim)`
- state: Optional initial state (if None, a fresh state is initialized)
- return_sequences: If True, return all timesteps; if False, only the final output

Returns:
- output: Output tensor
- state: Final state

```python
# Return all timesteps
x = torch.randn(32, 50, 64)
output, state = model(x, return_sequences=True)   # (32, 50, 128)

# Return only the final output
output, state = model(x, return_sequences=False)  # (32, 128)
```

#### forward_with_mask(x: Tensor, lengths: Tensor, state: DREAMState | None = None) -> Tuple[Tensor, DREAMState]
Process a padded sequence with masking.

Parameters:
- x: Padded input `(batch, max_time, input_dim)`
- lengths: Actual sequence lengths `(batch,)`
- state: Optional initial state

Returns:
- output: Output tensor (zeros at padded positions)
- state: Final state

```python
x = torch.randn(32, 100, 64)               # padded
lengths = torch.tensor([45, 67, 89, ...])  # actual lengths, one per sequence
output, state = model.forward_with_mask(x, lengths)
```

## DREAMStack
Stack of multiple DREAM layers.

```python
import torch.nn as nn

class DREAMStack(nn.Module):
    """
    Stack of multiple DREAM layers.

    Parameters
    ----------
    input_dim : int
        Input dimension
    hidden_dims : list[int]
        Hidden dimensions for each layer
    rank : int
        Fast weights rank
    dropout : float
        Dropout between layers
    **kwargs
        Passed to DREAMConfig
    """
```

### Constructor
```python
from dream import DREAMStack

model = DREAMStack(
    input_dim=64,
    hidden_dims=[128, 128, 64],  # 3 layers
    rank=8,
    dropout=0.1,
)
```

### Methods
#### forward(x: Tensor, state: list[DREAMState] | None = None) -> Tuple[Tensor, list[DREAMState]]

Process a sequence through all layers.

Parameters:
- x: Input sequence `(batch, time, input_dim)`
- state: Optional list of states, one per layer

Returns:
- output: Output tensor `(batch, time, final_hidden_dim)`
- states: List of final states, one per layer

```python
x = torch.randn(32, 50, 64)
output, states = model(x)
# output: (32, 50, 64) - final layer output
# states: list of 3 DREAMState objects
```

## DREAMState
State container for DREAM cell.

```python
from dataclasses import dataclass
from torch import Tensor

@dataclass
class DREAMState:
    """State container for DREAM cell."""
    h: Tensor             # (batch, hidden_dim) - Hidden state
    U: Tensor             # (batch, hidden_dim, rank) - Fast weights
    U_target: Tensor      # (batch, hidden_dim, rank) - Target weights
    adaptive_tau: Tensor  # (batch,) - Habituated threshold
    error_mean: Tensor    # (batch, input_dim) - Error EMA mean
    error_var: Tensor     # (batch, input_dim) - Error EMA variance
    avg_surprise: Tensor  # (batch,) - Surprise EMA
```

### Methods
#### detach() -> DREAMState

Detach all tensors from the computation graph.

Use case: truncated BPTT (backpropagation through time).

```python
# Process chunk
output, state = model(chunk, state)

# Detach for truncated BPTT
state = state.detach()
```

#### to(device: str | torch.device) -> DREAMState
Move all tensors to a device.

```python
state = state.to('cuda')
```

#### shape() -> dict[str, tuple]

Get shapes of all state tensors.

```python
shapes = state.shape()
# {'h': (32, 128), 'U': (32, 128, 8), ...}
```

### Accessing State Components
```python
# Hidden state
hidden = state.h               # (batch, hidden_dim)

# Fast weights
fast_weights = state.U         # (batch, hidden_dim, rank)

# Average surprise (adaptation metric)
surprise = state.avg_surprise  # (batch,)

# Error statistics
error_mean = state.error_mean  # (batch, input_dim)
error_var = state.error_var   # (batch, input_dim)
```

## Utilities
### reset_parameters()

Reset model parameters with Xavier initialization.

```python
model.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None)
```

### get_config() -> DREAMConfig

Get the configuration of a model.

```python
config = model.get_config()
```

## Type Aliases
```python
from typing import Tuple

import torch

Tensor = torch.Tensor
DREAMOutput = Tuple[Tensor, DREAMState]
```

## Exceptions
### DREAMError

Base exception for DREAM-specific errors.

```python
from dream import DREAM, DREAMError

try:
    model = DREAM(input_dim=-1)  # Invalid dimension
except DREAMError as e:
    print(f"DREAM error: {e}")
```

## Next Steps
- Usage Patterns - Common usage patterns
- Configuration Guide - Parameter tuning
- Examples - Real-world examples