DREAM

API Reference

Complete API documentation for DREAM

Core Classes

DREAMConfig

Configuration container for the DREAM cell.

from dream import DREAMConfig

config = DREAMConfig(
    input_dim=80,
    hidden_dim=256,
    rank=16,
    forgetting_rate=0.005,
    base_plasticity=0.5,
    base_threshold=0.3,
    ltc_tau_sys=5.0,
)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| input_dim | int | 39 | Input feature dimension |
| hidden_dim | int | 256 | Hidden state size |
| rank | int | 16 | Fast weights rank |
| time_step | float | 0.1 | Integration time step (dt) |
| forgetting_rate | float | 0.005 | Fast weights decay (λ) |
| base_plasticity | float | 0.5 | Hebbian learning rate (η) |
| base_threshold | float | 0.3 | Surprise threshold (τ₀) |
| entropy_influence | float | 0.1 | Entropy effect (α) |
| surprise_temperature | float | 0.05 | Surprise scaling (γ) |
| error_smoothing | float | 0.05 | Error EMA (β) |
| surprise_smoothing | float | 0.05 | Surprise EMA (β_s) |
| target_norm | float | 2.0 | Fast weights norm |
| kappa | float | 0.5 | Gain modulation (κ) |
| ltc_enabled | bool | True | Enable LTC |
| ltc_tau_sys | float | 5.0 | Base time constant |
| ltc_surprise_scale | float | 5.0 | Surprise modulation |
| sleep_rate | float | 0.005 | Sleep rate (ζ_sleep) |
| min_surprise_for_sleep | float | 0.2 | Sleep threshold (S_min) |
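Several of these parameters are ordinary exponential-decay or EMA coefficients. The sketch below is illustrative only — plain Python, not DREAMCell's internal update rule — showing what a per-step decay like forgetting_rate (λ) and a smoothing coefficient like surprise_smoothing (β_s) typically do numerically.

```python
# Illustrative sketch only -- NOT DREAMCell's internal update rule.
lam = 0.005      # forgetting_rate: per-step fast-weights decay
beta_s = 0.05    # surprise_smoothing: EMA coefficient for avg_surprise

# A quantity decayed by (1 - lam) per step shrinks geometrically:
# after T steps it is scaled by (1 - lam)**T.
norm = 2.0       # e.g. target_norm
for _ in range(100):
    norm *= (1.0 - lam)
print(norm)      # 2.0 * 0.995**100, roughly 1.21

# An EMA with coefficient beta_s tracks its input signal; fed a
# constant surprise of 1.0, it converges toward 1.0.
avg_surprise = 0.0
for _ in range(200):
    avg_surprise = (1.0 - beta_s) * avg_surprise + beta_s * 1.0
print(avg_surprise)  # close to 1.0
```

With λ = 0.005, fast weights lose roughly half their magnitude every ~138 steps in the absence of new Hebbian input, which gives a feel for the memory horizon these defaults imply.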

DREAMState

State container for the DREAM cell.

from dream import DREAMState

# Access state components
state.h              # Hidden state (batch, hidden_dim)
state.U              # Fast weights (batch, hidden_dim, rank)
state.U_target       # Target weights (batch, hidden_dim, rank)
state.adaptive_tau   # Adaptive threshold (batch,)
state.error_mean     # Error mean (batch, input_dim)
state.error_var      # Error variance (batch, input_dim)
state.avg_surprise   # Average surprise (batch,)

# Detach for BPTT
state = state.detach()

Methods

detach()

Detach all tensors from computation graph.

state = state.detach()

init_from_config(config, batch_size, device, dtype)

Initialize state from config.

state = DREAMState.init_from_config(
    config,
    batch_size=4,
    device='cuda',
    dtype=torch.float32
)

DREAMCell

Main DREAM cell implementation.

from dream import DREAMCell

cell = DREAMCell(config)

Methods

__init__(config)

Initialize DREAM cell.

Parameters:

  • config (DREAMConfig): Model configuration

init_state(batch_size, device, dtype)

Initialize cell state.

Parameters:

  • batch_size (int): Batch size
  • device (torch.device, optional): Device
  • dtype (torch.dtype, optional): Data type

Returns:

  • DREAMState: Initialized state

Example:

state = cell.init_state(batch_size=4, device='cuda')

forward(x, state)

Forward pass for single timestep.

Parameters:

  • x (torch.Tensor): Input (batch, input_dim)
  • state (DREAMState): Current state

Returns:

  • h_new (torch.Tensor): New hidden state (batch, hidden_dim)
  • state (DREAMState): Updated state

Example:

x = torch.randn(4, 80)
h_new, state = cell(x, state)

forward_sequence(x_seq, state, return_all)

Process full sequence.

Parameters:

  • x_seq (torch.Tensor): Input (batch, time, input_dim)
  • state (DREAMState, optional): Initial state
  • return_all (bool): Return all timesteps

Returns:

  • output (torch.Tensor): Output (batch, time, hidden_dim) if return_all
  • state (DREAMState): Final state

Example:

x_seq = torch.randn(4, 100, 80)
output, final_state = cell.forward_sequence(x_seq, return_all=True)

DREAM

High-level DREAM model (LSTM-like).

from dream import DREAM

model = DREAM(
    input_dim=80,
    hidden_dim=256,
    rank=16,
)

Methods

__init__(input_dim, hidden_dim, rank, **kwargs)

Initialize DREAM model.

Parameters:

  • input_dim (int): Input dimension
  • hidden_dim (int): Hidden dimension
  • rank (int): Fast weights rank
  • **kwargs: Additional config options

forward(x, state, return_sequences)

Process sequence.

Parameters:

  • x (torch.Tensor): Input (batch, time, input_dim)
  • state (DREAMState, optional): Initial state
  • return_sequences (bool): Return all timesteps

Returns:

  • output (torch.Tensor): Output
  • state (DREAMState): Final state

init_state(batch_size, device, dtype)

Initialize model state.


DREAMStack

Stack of multiple DREAM layers.

from dream import DREAMStack

model = DREAMStack(
    input_dim=80,
    hidden_dims=[256, 256, 128],  # 3 layers
    rank=16,
    dropout=0.1,
)

Methods

__init__(input_dim, hidden_dims, rank, dropout, **kwargs)

Initialize DREAM stack.

Parameters:

  • input_dim (int): Input dimension
  • hidden_dims (List[int]): Hidden dimensions per layer
  • rank (int): Fast weights rank
  • dropout (float): Dropout rate
  • **kwargs: Additional config options

forward(x, states, return_sequences)

Process through all layers.

Returns:

  • output (torch.Tensor): Final output
  • states (List[DREAMState]): States for each layer
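Assuming DREAMStack wires layers the way stacked recurrent models usually do — each layer consumes the previous layer's hidden state — the per-layer input/output dimensions for the example above chain like this (pure-Python bookkeeping sketch, not library code):

```python
# Assumed wiring for DREAMStack(input_dim=80, hidden_dims=[256, 256, 128]):
# each layer's input dimension is the previous layer's hidden dimension.
input_dim = 80
hidden_dims = [256, 256, 128]

layer_io = []
d = input_dim
for h in hidden_dims:
    layer_io.append((d, h))  # (layer input dim, layer hidden dim)
    d = h

print(layer_io)  # [(80, 256), (256, 256), (256, 128)]
```

The final output dimension is therefore hidden_dims[-1] (128 here), which is what a downstream head should expect.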

Examples

Basic Sequence Processing

import torch
from dream import DREAMConfig, DREAMCell

config = DREAMConfig(input_dim=80, hidden_dim=256)
cell = DREAMCell(config)
state = cell.init_state(batch_size=4)

# Single timestep
x = torch.randn(4, 80)
h, state = cell(x, state)

# Full sequence
x_seq = torch.randn(4, 100, 80)
output, final_state = cell.forward_sequence(x_seq, return_all=True)

Stateful Processing

# Initialize ONCE
state = cell.init_state(batch_size=4)

# Process multiple sequences (state carries over between them)
for seq in sequences:  # each seq: (batch, time, input_dim)
    output, state = cell.forward_sequence(seq, state)
    # The model adapts and remembers across sequence boundaries

Sequence Classification

import torch
import torch.nn as nn
from dream import DREAMConfig, DREAMCell

config = DREAMConfig(input_dim=80, hidden_dim=256)
cell = DREAMCell(config)
classifier = nn.Linear(256, 10)  # 10 classes

x = torch.randn(32, 100, 80)  # (batch, time, input_dim)
state = cell.init_state(batch_size=32)
output, final_state = cell.forward_sequence(x, state)

# Classify using final hidden state
logits = classifier(final_state.h)
predictions = logits.argmax(dim=-1)

Truncated BPTT

state = cell.init_state(batch_size=4)

for start in range(0, seq_len, segment_size):
    segment = x[:, start:start+segment_size, :]
    output, state = cell.forward_sequence(segment, state, return_all=True)

    loss = criterion(output, targets[:, start:start+segment_size])  # your loss fn
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    state = state.detach()  # Truncate the backprop graph between segments
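The slicing in the loop above is worth spelling out: range(0, seq_len, segment_size) covers the sequence in non-overlapping windows, with a shorter final window when segment_size does not divide seq_len. A pure-Python sketch of the index arithmetic (illustrative values):

```python
# Segment boundaries produced by the truncated-BPTT loop.
seq_len, segment_size = 100, 32

segments = [(start, min(start + segment_size, seq_len))
            for start in range(0, seq_len, segment_size)]
print(segments)  # [(0, 32), (32, 64), (64, 96), (96, 100)]
```

Slicing a tensor as x[:, start:start+segment_size, :] already clamps to the sequence end, so the final short segment needs no special handling.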
