This repository contains PyTorch/GPU and TorchXLA/TPU implementations of our paper, Diffusion Transformers with Representation Autoencoders. For a JAX/TPU implementation, please refer to diffuse_nnx.
Diffusion Transformers with Representation Autoencoders
Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie
New York University
We present Representation Autoencoders (RAE), a class of autoencoders that pair pretrained, frozen representation encoders such as DINOv2 and SigLIP2 with trained ViT decoders. RAE can be used in a two-stage training pipeline for high-fidelity image synthesis, where a Stage 2 diffusion model is trained on the latent space of a pretrained RAE to generate images.
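As a rough mental model of this two-stage pipeline, here is a minimal, self-contained sketch; the modules below are toy placeholders (a linear "encoder" and "decoder"), not the repository's actual classes or API:

```python
# Conceptual sketch only: the real Stage 1 encoder is a frozen DINOv2/SigLIP2 ViT
# and the real decoder is a trained ViT; Stage 2 is a diffusion transformer on the latents.
import torch
import torch.nn as nn

class ToyRAE(nn.Module):
    def __init__(self, latent_dim=768, patch_pixels=3 * 16 * 16):
        super().__init__()
        self.encoder = nn.Linear(patch_pixels, latent_dim)   # stand-in for the frozen representation encoder
        self.decoder = nn.Linear(latent_dim, patch_pixels)   # stand-in for the trained ViT decoder
        self.encoder.requires_grad_(False)                   # Stage 1 never updates the encoder

    def encode(self, patches):   # (B, tokens, patch_pixels) -> (B, tokens, latent_dim)
        return self.encoder(patches)

    def decode(self, latents):   # (B, tokens, latent_dim) -> (B, tokens, patch_pixels)
        return self.decoder(latents)

rae = ToyRAE()
patches = torch.randn(2, 256, 3 * 16 * 16)   # two images split into 16x16 patches
latents = rae.encode(patches)                # Stage 2 diffusion is trained on latents like these
recon = rae.decode(latents)                  # Stage 1 is trained to reconstruct the input from them
```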
This repository contains:
PyTorch/GPU:
- A PyTorch implementation of RAE and pretrained weights.
- A PyTorch implementation of LightningDiT, DiTDH and pretrained weights.
- Training and sampling scripts for the two-stage RAE+DiT pipeline.
TorchXLA/TPU:
- A TPU implementation of RAE and pretrained weights.
- Sampling of RAE and DiTDH on TPU.
- Create an environment and install via `uv`:

```bash
conda create -n rae python=3.10 -y
conda activate rae
pip install uv

# Install PyTorch 2.2.0 with CUDA 12.1
uv pip install torch==2.2.0 torchvision==0.17.0 torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install other dependencies
uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
uv pip install "numpy<2" transformers einops omegaconf
```
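After installation, a quick sanity check (a generic snippet, not part of this repo) confirms the expected versions and that CUDA is visible:

```python
import torch, torchvision

print(torch.__version__, torchvision.__version__)     # expect 2.2.0 / 0.17.0
print(torch.cuda.is_available(), torch.version.cuda)  # expect True / 12.1 on a CUDA 12.1 install
```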
We release three kinds of models: RAE decoders, DiTDH diffusion transformers, and statistics for latent normalization. To download all models at once:
```bash
cd RAE
pip install huggingface_hub
hf download nyu-visionx/RAE-collections \
    --local-dir models
```
To download specific models, run:
```bash
hf download nyu-visionx/RAE-collections \
    <remote_model_path> \
    --local-dir models
```
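If you prefer doing this from Python, `huggingface_hub` exposes the same functionality. The snippet below is an equivalent sketch; replace the placeholder filename with an actual path from the model page:

```python
from huggingface_hub import snapshot_download, hf_hub_download

# Download the whole collection (equivalent to the `hf download` call above).
snapshot_download(repo_id="nyu-visionx/RAE-collections", local_dir="models")

# Or fetch a single file; substitute a real path from the collection.
hf_hub_download(
    repo_id="nyu-visionx/RAE-collections",
    filename="<remote_model_path>",
    local_dir="models",
)
```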
- Download ImageNet-1k.
- Point the Stage 1 and Stage 2 scripts to the training split via `--data-path` (a quick sanity check of the path follows below).
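Assuming the standard ImageNet directory layout (one subfolder per class, as read by `torchvision`-style `ImageFolder` loaders), you can verify the directory you plan to pass via `--data-path` with a generic check like this:

```python
from torchvision.datasets import ImageFolder

# Point this at the same directory you will pass via --data-path.
train_set = ImageFolder("/path/to/imagenet/train")
print(len(train_set.classes), "classes,", len(train_set), "images")  # expect 1000 classes, ~1.28M images
```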
All training and sampling entrypoints are driven by OmegaConf YAML files. A single config describes the Stage 1 autoencoder, the Stage 2 diffusion model, and the solver used during training or inference. A minimal example looks like:
```yaml
stage_1:
  target: stage1.RAE
  params: { ... }
  ckpt: <path_to_ckpt>
stage_2:
  target: stage2.models.DDT.DiTwDDTHead
  params: { ... }
  ckpt: <path_to_ckpt>
transport:
  params:
    path_type: Linear
    prediction: velocity
    ...
sampler:
  mode: ODE
  params:
    num_steps: 50
    ...
guidance:
  method: cfg/autoguidance
  scale: 1.0
  ...
misc:
  latent_size: [768, 16, 16]
  num_classes: 1000
training:
  ...
```
- `stage_1` instantiates the frozen encoder and trainable decoder. For Stage 1 training you can point to an existing checkpoint via `stage_1.ckpt` or start from `pretrained_decoder_path`.
- `stage_2` defines the diffusion transformer. During sampling you must provide `ckpt`; during training you typically omit it so weights initialise randomly.
- `transport`, `sampler`, and `guidance` select the forward/backward SDE/ODE integrator and the optional classifier-free or autoguidance schedule.
- `misc` collects shapes, class counts, and scaling constants used by both stages.
- `training` contains defaults that the training scripts consume (epochs, learning rate, EMA decay, gradient accumulation, etc.).
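The `target`/`params` pairs above follow the usual dotted-path instantiation pattern. A minimal sketch of how such a config might be loaded and resolved with OmegaConf (illustrative only; the `instantiate` helper is hypothetical and the actual loading logic lives in the training/sampling scripts):

```python
import importlib
from omegaconf import OmegaConf

def instantiate(node):
    """Hypothetical helper: import `node.target` as a dotted path and call it with `node.params`."""
    module_name, class_name = node.target.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**OmegaConf.to_container(node.params, resolve=True))

cfg = OmegaConf.load("configs/stage2/sampling/<config>.yaml")  # placeholder path
stage_1 = instantiate(cfg.stage_1)   # frozen encoder + trained decoder
stage_2 = instantiate(cfg.stage_2)   # diffusion transformer
print(cfg.misc.latent_size)          # e.g. [768, 16, 16]
```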
Stage 1 training configs additionally include a top-level `gan` block that configures the discriminator architecture and the LPIPS/GAN loss schedule.
We release decoders for DINOv2-B, SigLIP-B, and MAE-B at `configs/stage1/pretrained/`.
There is also a training config for a ViT-XL decoder on DINOv2-B: `configs/stage1/training/DINOv2-B_decXL.yaml`.
We release our best model, DiTDH-XL, and its guidance model at `configs/stage2/sampling/`.
We also provide training configs for DiTDH at `configs/stage2/training/`.
`src/train_stage1.py` fine-tunes the ViT decoder while keeping the representation encoder frozen. Launch it with PyTorch DDP (single or multi-GPU):
```bash
torchrun --standalone --nproc_per_node=N \
    src/train_stage1.py \
    --config <config> \
    --data-path <imagenet_train_split> \
    --results-dir results/stage1 \
    --image-size 256 --precision bf16/fp32 \
    --ckpt <optional_ckpt>
```
where `N` is the number of available GPUs, and `--ckpt` resumes training from an existing checkpoint.
Logging. To enable `wandb`, first set `WANDB_KEY`, `ENTITY`, and `PROJECT` as environment variables:

```bash
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"
```
Then add the `--wandb` flag to the training command.
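For reference, these environment variables map onto the standard `wandb` API; a generic Python equivalent (not this repo's logging code) looks like:

```python
import os
import wandb

wandb.login(key=os.environ["WANDB_KEY"])
run = wandb.init(entity=os.environ["ENTITY"], project=os.environ["PROJECT"])
run.log({"loss": 0.123})   # the training script logs metrics in a similar way when --wandb is set
run.finish()
```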
Use `src/stage1_sample.py` to encode/decode a single image:
```bash
python src/stage1_sample.py \
    --config <config> \
    --image assets/pixabay_cat.png
```
For batched reconstructions and `.npz` export, run the DDP variant:
```bash
torchrun --standalone --nproc_per_node=N \
    src/stage1_sample_ddp.py \
    --config <config> \
    --data-path <imagenet_val_split> \
    --sample-dir recon_samples \
    --image-size 256
```
The script writes per-image PNGs as well as a packed `.npz` suitable for FID evaluation.
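If you want to inspect the packed file, it is a regular NumPy archive containing the sampled images. The sketch below assumes a hypothetical output name and simply lists the stored arrays rather than relying on a specific key:

```python
import numpy as np

data = np.load("recon_samples/samples.npz")   # hypothetical output path
print(data.files)                             # inspect which arrays were stored
images = data[data.files[0]]                  # typically an (N, H, W, 3) uint8 image array
print(images.shape, images.dtype)
```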
`src/train.py` trains the Stage 2 diffusion transformer using PyTorch DDP. Edit one of the configs under `configs/stage2/training/` and launch:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=N \
    src/train.py \
    --config <training_config> \
    --data-path <imagenet_train_split> \
    --results-dir results/stage2 \
    --precision bf16
```
`src/sample.py` uses the same config schema to draw a small batch of images on a single device and saves them to `sample.png`:
```bash
python src/sample.py \
    --config <sample_config> \
    --seed 42
```
`src/sample_ddp.py` parallelises sampling across GPUs, producing PNGs and an FID-ready `.npz`:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=N \
    src/sample_ddp.py \
    --config <sample_config> \
    --sample-dir samples \
    --precision bf16 \
    --label-sampling equal
```
`--label-sampling {equal,random}`: `equal` uses exactly 50 images per class for FID-50k, while `random` samples labels uniformly. Using `equal` consistently gives an FID about 0.1 lower than `random`; we use `equal` by default.
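The two modes differ only in how the 50k class labels are drawn; a small illustrative sketch of the difference (not the script's exact code):

```python
import numpy as np

num_classes, total = 1000, 50_000
rng = np.random.default_rng(0)

# equal: exactly total // num_classes = 50 samples per class, shuffled
equal_labels = np.repeat(np.arange(num_classes), total // num_classes)
rng.shuffle(equal_labels)

# random: labels drawn uniformly, so per-class counts fluctuate around 50
random_labels = rng.integers(0, num_classes, size=total)

print(np.bincount(equal_labels).std())                            # 0.0: perfectly balanced
print(np.bincount(random_labels, minlength=num_classes).std())    # > 0: class counts vary
```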
Autoguidance and classifier-free guidance are controlled via the config's `guidance` block.
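Both options combine two model evaluations per sampling step. The sketch below shows the generic extrapolation they correspond to; the actual logic lives in the sampler, and `scale`/`method` come from the `guidance` block:

```python
import torch

def guided_prediction(v_main: torch.Tensor, v_ref: torch.Tensor, scale: float) -> torch.Tensor:
    """Generic guidance: extrapolate from a reference prediction toward the main one.

    - cfg:          v_ref is the unconditional prediction of the same model.
    - autoguidance: v_ref comes from a separate, weaker guidance model.
    """
    return v_ref + scale * (v_main - v_ref)

# scale = 1.0 returns the main model's prediction unchanged; larger scales push further from v_ref.
v = guided_prediction(torch.randn(4, 768, 16, 16), torch.randn(4, 768, 16, 16), scale=1.0)
```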
Use the ADM evaluation suite to score generated samples:
- Clone the repo:

  ```bash
  git clone https://github.com/openai/guided-diffusion.git
  cd guided-diffusion/evaluation
  ```

- Create an environment and install dependencies:

  ```bash
  conda create -n adm-fid python=3.10
  conda activate adm-fid
  pip install 'tensorflow[and-cuda]'==2.19 scipy requests tqdm
  ```

- Download ImageNet statistics (256×256 shown here):

  ```bash
  wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
  ```

- Evaluate:

  ```bash
  python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz
  ```
See the `XLA` branch for TPU support.
This code is built upon the following repositories:
- SiT - for diffusion implementation and training codebase.
- DDT - for some of the DiTDH implementation.
- LightningDiT - for the LightningDiT implementation.
- MAE - for the ViT decoder architecture.