Memorax is a library for efficient recurrent models. Using category theory, we provide a simple interface that works for nearly all recurrent models. Unlike most other recurrent modeling libraries, we offer a unified interface for fast recurrent state resets within a sequence, allowing you to avoid truncating BPTT.
We implement both linear and log-complexity recurrent models.
| Name | Parallel Time Complexity | Paper | Code |
|---|---|---|---|
| Linear Recurrent Unit | O(log n) | [paper] | [code] |
| Selective State Space Model (S6) | O(log n) | [paper] | [code] |
| Diagonal Selective State Space Model (S6D) | O(log n) | [paper] | [code] |
| Linear Recurrent Neural Network | O(log n) | [paper] | [code] |
| Fast Autoregressive Transformer | O(log n) | [paper] | [code] |
| Fast and Forgetful Memory | O(log n) | [paper] | [code] |
| Rotational RNN (RotRNN) | O(log n) | [paper] | [code] |
| Fast Weight Programmer | O(log n) | [paper] | [code] |
| DeltaNet | O(log n) | [paper] | [code] |
| Gated DeltaNet | O(log n) | [paper] | [code] |
| DeltaProduct | O(log n) | [paper] | [code] |
| Dot Product Attention | O(log n) | [paper] | [code] |
| Elman Network | O(n) | [paper] | [code] |
| Gated Recurrent Unit | O(n) | [paper] | [code] |
| Minimal Gated Unit | O(n) | [paper] | [code] |
| Long Short-Term Memory Unit | O(n) | [paper] | [code] |
We provide datasets to test our recurrent models.
Sequential MNIST [HuggingFace] [Code]
The recurrent model receives an MNIST image pixel by pixel, and must predict the digit class.
Sequence Lengths:
[784]
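As a rough illustration of the input format only (this is not the library's loader; it assumes the public `mnist` dataset on the HuggingFace hub and the `datasets` package), a single 784-step pixel sequence could be built like this:

```python
# Illustrative only: not the memorax loader. Assumes the public "mnist"
# dataset on the HuggingFace hub and the `datasets` package.
import numpy as np
from datasets import load_dataset

ds = load_dataset("mnist", split="train")
example = ds[0]
# Flatten the 28x28 image into a length-784 sequence of single-pixel features
pixels = np.asarray(example["image"], dtype=np.float32).reshape(-1, 1) / 255.0
label = example["label"]
print(pixels.shape)  # (784, 1): one pixel per timestep
```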
MNIST Math [HuggingFace] [Code]
The recurrent model receives a sequence of MNIST images and operators, pixel by pixel, and must predict the percentile of the operators applied to the MNIST image classes.
Sequence Lengths:
[784 * 5, 784 * 100, 784 * 1_000, 784 * 10_000, 784 * 1_000_000]
Continuous Localization [HuggingFace] [Code]
The recurrent model receives a sequence of translation and rotation vectors in the local coordinate frame, and must predict the corresponding position and orientation in the global coordinate frame.
Sequence Lengths:
[20, 100, 1_000]
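As a conceptual sketch of the prediction target only (a simplified 2D case, not the dataset code), composing per-step local motions into a global pose looks like:

```python
# Conceptual sketch of the localization target (2D case), not the dataset code.
import jax.numpy as jnp

def compose(pose, local_motion):
    """Apply a local (dx, dy, dtheta) motion to a global (x, y, theta) pose."""
    x, y, theta = pose
    dx, dy, dtheta = local_motion
    x_new = x + dx * jnp.cos(theta) - dy * jnp.sin(theta)
    y_new = y + dx * jnp.sin(theta) + dy * jnp.cos(theta)
    return jnp.array([x_new, y_new, theta + dtheta])

motions = jnp.array([[1.0, 0.0, jnp.pi / 2]] * 4)  # drive a unit square
pose = jnp.zeros(3)
for m in motions:
    pose = compose(pose, m)
print(pose)  # back near the origin after four quarter turns
```

The recurrent model must track this running composition from the local motion vectors alone.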
Install `memorax` using pip and git for your specific framework:

```bash
pip install "memorax[equinox]@git+https://github.com/smorad/memorax"
pip install "memorax[flax]@git+https://github.com/smorad/memorax"
```
```python
from equinox import filter_jit, filter_vmap
import jax
import jax.numpy as jnp
from memorax.equinox.train_utils import add_batch_dim, get_residual_memory_models

T, F = 5, 6  # time and feature dims
model = get_residual_memory_models(
    input=F, hidden=8, output=1, num_layers=2,
    models=["LRU"], key=jax.random.key(0)
)["LRU"]

# True marks the start of a new subsequence: the recurrent state is reset there
starts = jnp.array([True, False, False, True, False])
xs = jnp.zeros((T, F))
hs, ys = filter_jit(model)(model.initialize_carry(), (xs, starts))
last_h = filter_jit(model.latest_recurrent_state)(hs)

# With a batch dimension
B = 4
starts = jnp.zeros((B, T), dtype=bool)
xs = jnp.zeros((B, T, F))
hs_0 = add_batch_dim(model.initialize_carry(), B)
hs, ys = filter_jit(filter_vmap(model))(hs_0, (xs, starts))
```
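To train such a model, pair it with any JAX optimizer. Below is a minimal sketch of one gradient step, assuming `optax` (not a memorax dependency) and a placeholder regression target:

```python
# Minimal training-step sketch, assuming optax. `ys_target` is a
# placeholder regression target for illustration only.
import equinox as eqx
import optax

xs = jnp.zeros((T, F))
starts = jnp.array([True, False, False, True, False])
ys_target = jnp.zeros((T, 1))

optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit
def train_step(model, opt_state, xs, starts, targets):
    def loss_fn(m):
        _, ys_pred = m(m.initialize_carry(), (xs, starts))
        return jnp.mean((ys_pred - targets) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optim.update(
        grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

model, opt_state, loss = train_step(model, opt_state, xs, starts, ys_target)
```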
You can compare various recurrent models on our datasets with a single command:

```bash
python run_equinox_experiments.py  # equinox framework
python run_linen_experiments.py    # flax linen framework
```
Memorax uses the `equinox` neural network library; we also provide a beta `flax.linen` API. In this example, we focus on `equinox`. See the `semigroups` directory for fast recurrent models that utilize an associative scan; the idea is sketched conceptually below.
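As a rough illustration of why associative scans matter (a simplified sketch, not memorax internals): a diagonal linear recurrence h_t = a_t * h_{t-1} + b_t can be written as an associative combine over (a, b) pairs, so `jax.lax.associative_scan` evaluates all T states in O(log T) parallel depth instead of an O(T) sequential loop.

```python
# Conceptual sketch (not memorax internals): a diagonal linear recurrence
# h_t = a_t * h_{t-1} + b_t expressed as an associative combine over (a, b).
import jax
import jax.numpy as jnp

def combine(left, right):
    a_l, b_l = left
    a_r, b_r = right
    # Composing two recurrence steps is itself a recurrence step: associativity.
    return a_r * a_l, a_r * b_l + b_r

a = jnp.full((8,), 0.9)  # decay per step
b = jnp.arange(8.0)      # input per step
_, h = jax.lax.associative_scan(combine, (a, b))  # O(log T) parallel depth
print(h)  # h[t] = 0.9 * h[t-1] + b[t], with h[-1] = 0
```

The full `equinox` example below uses the library's actual API.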
```python
import equinox as eqx
import jax
import jax.numpy as jnp
from memorax.equinox.set_actions.gru import GRU
from memorax.equinox.models.residual import ResidualModel
from memorax.equinox.semigroups.lru import LRU, LRUSemigroup
from memorax.utils import debug_shape
# You can pack multiple subsequences into a single sequence using the start flag
sequence_starts = jnp.array([True, False, False, True, False])
x = jnp.zeros((5, 3))
inputs = (x, sequence_starts)
# Initialize a multi-layer recurrent model
key = jax.random.key(0)
make_layer_fn = lambda recurrent_size, key: LRU(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
)
model = ResidualModel(
make_layer_fn=make_layer_fn,
input_size=3,
recurrent_size=16,
output_size=4,
num_layers=2,
key=key,
)
# Note: We also have layers if you want to build your own model
layer = LRU(hidden_size=16, recurrent_size=16, key=key)
# Or semigroups/set actions (scanned functions) if you want to build your own layer
sg = LRUSemigroup(recurrent_size=16)
# Run the model! All models are jit-capable, using equinox.filter_jit
h = eqx.filter_jit(model.initialize_carry)()
# Unlike most other libraries, we output ALL recurrent states h, not just the most recent
h, y = eqx.filter_jit(model)(h, inputs)
# Since we have two layers, we have a recurrent state of shape
print(debug_shape(h))
# ((5, 16), # Recurrent states of first layer
# (5,) # Start carries for first layer
# (5, 16) # Recurrent states of second layer
# (5,)) # Start carries for second layer
#
# Do your prediction
prediction = jax.nn.softmax(y)
# If you want to continue rolling out the RNN from h[-1]
# you should use the following helper function to extract
# h[-1] from the nested recurrent state
latest_h = eqx.filter_jit(model.latest_recurrent_state)(h)
# Continue rolling out as you please! You can use a single timestep
# or another sequence.
last_h, last_y = eqx.filter_jit(model)(latest_h, inputs)
# We can use a similar approach with RNNs
make_layer_fn = lambda recurrent_size, key: GRU(
recurrent_size=recurrent_size, key=key
)
model = ResidualModel(
make_layer_fn=make_layer_fn,
input_size=3,
recurrent_size=16,
output_size=4,
num_layers=2,
key=jax.random.key(0),
)
h = eqx.filter_jit(model.initialize_carry)()
h, y = eqx.filter_jit(model)(h, inputs)
prediction = jax.nn.softmax(y)
latest_h = eqx.filter_jit(model.latest_recurrent_state)(h)
h, y = eqx.filter_jit(model)(latest_h, inputs)
```
All recurrent cells should follow the `GRAS` interface. A recurrent cell consists of an `Algebra`. You can roughly think of the `Algebra` as the function that updates the recurrent state, and the `GRAS` as the `Algebra` plus all the associated MLPs/gates. You may reuse our `Algebra`s in your custom `GRAS`, or even write your own `Algebra`.

To implement your own `Algebra` and `GRAS`, we suggest copying one from our existing code, such as the LRNN for a `Semigroup` or the Elman Network for a `SetAction`. A simplified sketch of the semigroup idea is shown below.
Full documentation is available here.
Please cite the library as:

```bibtex
@misc{morad_memorax_2025,
  title = {Memorax},
  url = {https://github.com/smorad/memorax},
  author = {Morad, Steven and Toledo, Edan and Kortvelesy, Ryan and He, Zhe},
  month = jun,
  year = {2025},
}
```
If you use the recurrent state resets (`sequence_starts`) with the log-complexity memory models, please cite:

```bibtex
@article{morad2024recurrent,
  title={Recurrent reinforcement learning with memoroids},
  author={Morad, Steven and Lu, Chris and Kortvelesy, Ryan and Liwicki, Stephan and Foerster, Jakob and Prorok, Amanda},
  journal={Advances in Neural Information Processing Systems},
  volume={37},
  pages={14386--14416},
  year={2024}
}
```