# API Reference

Complete API documentation for DREAM.
## Core Classes

### DREAMConfig

Configuration container for the DREAM cell.
```python
from dream import DREAMConfig

config = DREAMConfig(
    input_dim=80,
    hidden_dim=256,
    rank=16,
    forgetting_rate=0.005,
    base_plasticity=0.5,
    base_threshold=0.3,
    ltc_tau_sys=5.0,
)
```

#### Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `input_dim` | int | 39 | Input feature dimension |
| `hidden_dim` | int | 256 | Hidden state size |
| `rank` | int | 16 | Fast-weights rank |
| `time_step` | float | 0.1 | Integration time step (dt) |
| `forgetting_rate` | float | 0.005 | Fast-weights decay rate (λ) |
| `base_plasticity` | float | 0.5 | Hebbian learning rate (η) |
| `base_threshold` | float | 0.3 | Surprise threshold (τ₀) |
| `entropy_influence` | float | 0.1 | Entropy effect on threshold (α) |
| `surprise_temperature` | float | 0.05 | Surprise scaling (γ) |
| `error_smoothing` | float | 0.05 | Error EMA coefficient (β) |
| `surprise_smoothing` | float | 0.05 | Surprise EMA coefficient (β_s) |
| `target_norm` | float | 2.0 | Target fast-weights norm |
| `kappa` | float | 0.5 | Gain modulation strength (κ) |
| `ltc_enabled` | bool | True | Enable LTC dynamics |
| `ltc_tau_sys` | float | 5.0 | Base LTC time constant |
| `ltc_surprise_scale` | float | 5.0 | Surprise modulation of the time constant |
| `sleep_rate` | float | 0.005 | Sleep consolidation rate (ζ_sleep) |
| `min_surprise_for_sleep` | float | 0.2 | Sleep surprise threshold (S_min) |
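To give intuition for how `forgetting_rate` (λ), `base_plasticity` (η), and `target_norm` interact, here is a toy sketch of a decay-plus-Hebbian low-rank fast-weights step in plain PyTorch. This is an illustrative assumption about the general shape of such an update, not the library's actual rule — the real `DREAMCell` update also involves surprise gating and LTC dynamics, and the key vector `k` here is hypothetical.

```python
import torch

# Toy sketch (NOT the library implementation): decay the fast weights,
# add a Hebbian outer-product term, then cap the norm at target_norm.
lam, eta, target_norm = 0.005, 0.5, 2.0   # forgetting_rate, base_plasticity, target_norm
batch, hidden_dim, rank = 4, 256, 16

U = torch.zeros(batch, hidden_dim, rank)  # fast weights, shaped like DREAMState.U
h = torch.randn(batch, hidden_dim)        # hidden state
k = torch.randn(batch, rank)              # hypothetical low-rank key vector

# Decay toward zero, then Hebbian outer product h kᵀ scaled by η.
U = (1 - lam) * U + eta * h.unsqueeze(-1) * k.unsqueeze(1)

# Rescale so the per-sample Frobenius norm does not exceed target_norm.
norm = U.flatten(1).norm(dim=-1, keepdim=True).clamp(min=1e-8)
U = U * (target_norm / norm).clamp(max=1.0).view(batch, 1, 1)
```

Capping (rather than always rescaling to) `target_norm` keeps small updates untouched while preventing unbounded fast-weight growth.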
### DREAMState

State container for the DREAM cell.

```python
from dream import DREAMState

# Access state components
state.h             # Hidden state (batch, hidden_dim)
state.U             # Fast weights (batch, hidden_dim, rank)
state.U_target      # Target weights (batch, hidden_dim, rank)
state.adaptive_tau  # Adaptive threshold (batch,)
state.error_mean    # Error mean (batch, input_dim)
state.error_var     # Error variance (batch, input_dim)
state.avg_surprise  # Average surprise (batch,)

# Detach for BPTT
state = state.detach()
```

#### Methods
##### `detach()`

Detach all tensors from the computation graph.

```python
state = state.detach()
```

##### `init_from_config()`

Initialize a state from a config.

```python
state = DREAMState.init_from_config(
    config,
    batch_size=4,
    device='cuda',
    dtype=torch.float32,
)
```

### DREAMCell
Main DREAM cell implementation.

```python
from dream import DREAMCell

cell = DREAMCell(config)
```

#### Methods
##### `__init__(config)`

Initialize the DREAM cell.

**Parameters:**

- `config` (DREAMConfig): Model configuration
##### `init_state(batch_size, device, dtype)`

Initialize the cell state.

**Parameters:**

- `batch_size` (int): Batch size
- `device` (torch.device, optional): Device
- `dtype` (torch.dtype, optional): Data type

**Returns:**

- `DREAMState`: Initialized state

**Example:**

```python
state = cell.init_state(batch_size=4, device='cuda')
```

##### `forward(x, state)`
Forward pass for a single timestep.

**Parameters:**

- `x` (torch.Tensor): Input (batch, input_dim)
- `state` (DREAMState): Current state

**Returns:**

- `h_new` (torch.Tensor): New hidden state (batch, hidden_dim)
- `state` (DREAMState): Updated state

**Example:**

```python
x = torch.randn(4, 80)
h_new, state = cell(x, state)
```

##### `forward_sequence(x_seq, state, return_all)`
Process a full sequence.

**Parameters:**

- `x_seq` (torch.Tensor): Input (batch, time, input_dim)
- `state` (DREAMState, optional): Initial state
- `return_all` (bool): Return all timesteps

**Returns:**

- `output` (torch.Tensor): Output (batch, time, hidden_dim) if `return_all`
- `state` (DREAMState): Final state

**Example:**

```python
x_seq = torch.randn(4, 100, 80)
output, final_state = cell.forward_sequence(x_seq, return_all=True)
```

### DREAM
High-level DREAM model (LSTM-like).

```python
from dream import DREAM

model = DREAM(
    input_dim=80,
    hidden_dim=256,
    rank=16,
)
```

#### Methods
##### `__init__(input_dim, hidden_dim, rank, **kwargs)`

Initialize the DREAM model.

**Parameters:**

- `input_dim` (int): Input dimension
- `hidden_dim` (int): Hidden dimension
- `rank` (int): Fast weights rank
- `**kwargs`: Additional config options
##### `forward(x, state, return_sequences)`

Process a sequence.

**Parameters:**

- `x` (torch.Tensor): Input (batch, time, input_dim)
- `state` (DREAMState, optional): Initial state
- `return_sequences` (bool): Return all timesteps

**Returns:**

- `output` (torch.Tensor): Output
- `state` (DREAMState): Final state
##### `init_state(batch_size, device, dtype)`

Initialize the model state.
### DREAMStack

Stack of multiple DREAM layers.

```python
from dream import DREAMStack

model = DREAMStack(
    input_dim=80,
    hidden_dims=[256, 256, 128],  # 3 layers
    rank=16,
    dropout=0.1,
)
```

#### Methods
##### `__init__(input_dim, hidden_dims, rank, dropout, **kwargs)`

Initialize the DREAM stack.

**Parameters:**

- `input_dim` (int): Input dimension
- `hidden_dims` (List[int]): Hidden dimensions per layer
- `rank` (int): Fast weights rank
- `dropout` (float): Dropout rate
- `**kwargs`: Additional config options
##### `forward(x, states, return_sequences)`

Process the input through all layers.

**Returns:**

- `output` (torch.Tensor): Final output
- `states` (List[DREAMState]): States for each layer
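The wiring follows the usual stacked-recurrent pattern: each layer's output sequence feeds the next layer, with dropout applied between layers (not after the last). A minimal sketch of that pattern, using `nn.GRU` as a stand-in for a DREAM layer purely for illustration:

```python
import torch
import torch.nn as nn

# Sketch of stacked-recurrent wiring with per-layer hidden sizes and
# inter-layer dropout. nn.GRU is a stand-in for a DREAM layer here.
input_dim, hidden_dims, p_drop = 80, [256, 256, 128], 0.1

dims = [input_dim] + hidden_dims
layers = nn.ModuleList(
    nn.GRU(d_in, d_out, batch_first=True)
    for d_in, d_out in zip(dims[:-1], dims[1:])
)
dropout = nn.Dropout(p_drop)

x = torch.randn(4, 100, input_dim)  # (batch, time, input_dim)
out, states = x, []
for i, layer in enumerate(layers):
    out, h_n = layer(out)           # output sequence feeds the next layer
    states.append(h_n)
    if i < len(layers) - 1:         # dropout between layers only
        out = dropout(out)

print(out.shape)  # torch.Size([4, 100, 128])
```

The final output dimension is the last entry of `hidden_dims`, and one state object per layer is collected, mirroring the `List[DREAMState]` return above.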
## Examples

### Basic Sequence Processing
```python
import torch
from dream import DREAMConfig, DREAMCell

config = DREAMConfig(input_dim=80, hidden_dim=256)
cell = DREAMCell(config)
state = cell.init_state(batch_size=4)

# Single timestep
x = torch.randn(4, 80)
h, state = cell(x, state)

# Full sequence
x_seq = torch.randn(4, 100, 80)
output, final_state = cell.forward_sequence(x_seq, return_all=True)
```

### Stateful Processing
```python
# Initialize ONCE
state = cell.init_state(batch_size=4)

# Process multiple sequences (state preserved)
for seq in sequences:
    output, state = cell.forward_sequence(seq, state)
    # Model adapts and remembers!
```

### Sequence Classification
```python
import torch
import torch.nn as nn
from dream import DREAMCell

cell = DREAMCell(config)
classifier = nn.Linear(256, 10)  # 10 classes

x = torch.randn(32, 100, 80)  # (batch, time, input_dim)
state = cell.init_state(batch_size=32)
output, final_state = cell.forward_sequence(x, state)

# Classify using the final hidden state
logits = classifier(final_state.h)
predictions = logits.argmax(dim=-1)
```

### Truncated BPTT
```python
# x, targets, criterion, and optimizer are assumed defined elsewhere
state = cell.init_state(batch_size=4)

for start in range(0, seq_len, segment_size):
    segment = x[:, start:start + segment_size, :]
    output, state = cell.forward_sequence(segment, state, return_all=True)

    loss = criterion(output, targets[:, start:start + segment_size])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    state = state.detach()  # Truncate the graph between segments
```

## Next Steps
- Benchmarks — Performance comparison
- Guides — Tutorials and examples