Edit Flows
📄 Reference paper: Edit Flows: Flow Matching with Edit Operations
This directory provides an educational reference for training EditFlow models. It demonstrates how to adapt open-weight DLLMs, such as LLaDA and Dream, to support insertion and deletion in addition to the standard substitution (mask -> tokens) operation. It also includes examples for training (pretraining and finetuning) EditFlow models from scratch.
Note
- Examples are available for both LLaDA and Dream, but this README focuses on adapting open-weight LLaDA for edit operations (`adapt_llada.py`) and reusing its architecture for training from scratch (`pt_llada.py` -> `sft_llada.py`).
- While `EditFlowCollator` supports a custom `x0`, this README uses fixed-length (128) masks as `x0`. The trained model generates text by replacing masks, deleting redundant ones, and inserting tokens as needed. To change the default `x0` distribution (e.g., empty sequences for OneFlow-like insertion-only generation), pass `--x0_sampler "empty"`.
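As a rough illustration of these two `x0` choices, the sketch below contrasts a fixed-length block of masks with an empty sequence. The mask id and helper names are hypothetical, not the repo's `EditFlowCollator` API:

```python
# Purely illustrative sketch of two x0 choices; the mask id and helper
# names are hypothetical, not the repo's EditFlowCollator API.
MASK_ID = 0  # stand-in for the tokenizer's <mask> id

def x0_masks(length: int = 128) -> list[int]:
    # Default in this README: x0 is a fixed-length block of <mask> tokens.
    # Decoding then substitutes masks, deletes redundant ones, and inserts tokens.
    return [MASK_ID] * length

def x0_empty() -> list[int]:
    # --x0_sampler "empty": x0 is an empty sequence, so generation is
    # insertion-only (OneFlow-like).
    return []

print(len(x0_masks()))  # 128
print(x0_empty())       # []
```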
Table of Contents
- Setup
- Files overview
- Training
- Sampling
- Acknowledgement
Setup
Important
Slurm users: update `scripts/train.slurm.sh` and `mkdir logps`; see (optional) Slurm setup for details.
Files overview
dllm/pipelines/editflow
├── __init__.py # Package initialization
├── models
│ ├── dream
│ │ └── modelling_dream.py # EditFlowDream: architecture based on Dream
│ └── llada
│ └── modelling_llada.py # EditFlowLLaDA: architecture based on LLaDA
├── trainer.py
└── utils.py
# example entry points for training / sampling
examples/editflow
├── adapt_dream.py # Example of adapting Dream for EditFlow directly
├── adapt_llada.py # Example of adapting LLaDA for EditFlow directly
├── generate.py # Generation example
├── pt_dream.py # EditFlowDream pretraining example
├── pt_llada.py # EditFlowLLaDA pretraining example
├── pt.py # Pretraining function
├── README.md # Documentation (you are here)
├── sft_dream.py # EditFlowDream SFT example
├── sft_llada.py # EditFlowLLaDA SFT example
└── sft.py # Supervised finetuning function
Training
Adapting LLaDA-8B-Instruct to support insertion and deletion
The original LLaDA model generates text by iteratively substituting the given `<mask>` tokens with real tokens.
Figure: Example Gradio demo for LLaDA.
However, LLaDA supports only substitution. This example shows how to adapt it so that, during decoding, the model can not only replace fixed-length masks (e.g., 128 tokens) with real text but also insert new tokens and delete unnecessary masks adaptively:
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/editflow/adapt_llada.py \
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
--lm_head_key "model.transformer.ff_out" \
--init_editflow_from_src True \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 5e-5
If you are using Slurm and want to train across, for example, four nodes (32 GPUs total), run:
sbatch --nodes=4 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/adapt_llada.py" \
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
--lm_head_key "model.transformer.ff_out" \
--init_editflow_from_src True \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 5e-5
After training, you can use the generate.py script to produce a visualized decoding trace showing how the model performs insertion and deletion beyond regular mask substitution; a conceptual sketch of these edit operations follows below. See Sampling for details.
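As a purely conceptual illustration of what insertion and deletion add on top of mask substitution, the toy sketch below applies the three edit operations to a small decoding state. It is not the repo's decoding code; tokens and positions are made up:

```python
# Toy illustration of the three edit operations on a decoding state
# (made-up tokens; not the repo's decoder).
state = ["<mask>", "<mask>", "<mask>", "<mask>"]

# Substitution: a <mask> becomes a real token (what vanilla LLaDA does).
state[0] = "Once"
state[1] = "upon"

# Deletion: a redundant <mask> is removed, so the final text can be
# shorter than the initial mask canvas.
del state[3]

# Insertion: a new token is added between existing positions, so the final
# text can also grow beyond the initial mask canvas.
state.insert(2, "a")

print(state)  # ['Once', 'upon', 'a', '<mask>']
```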
Pretraining & Finetuning from scratch
You can also train an EditFlow model from scratch (pretrain → SFT) without adapting an existing DLLM.
Pretrain on a subset of mlfoundations/dclm-baseline-1.0 using 192 GPUs (24x8) and FSDP:
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/pt_llada.py" \
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
--dataset_args "mlfoundations/dclm-baseline-1.0" \
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--max_steps 2000 \
--learning_rate 3e-4
Finetune on a subset of allenai/tulu-3-sft-mixture using 8 GPUs and FSDP for better instruction following:
# you can also run locally with `accelerate ...`
sbatch --nodes=1 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/sft_llada.py" \
--model_name_or_path "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0/checkpoint-final" \
--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 5e-5
Sampling
After training, you can visualize how the model performs mask substitution, insertion, and deletion during generation with generate.py. Inserted tokens appear in blue, tokens substituted from `<mask>` appear in black, and deleted tokens are shown with a strikethrough before they disappear.
# Generate a long sequence to visualize insertions after 128 <mask> tokens
python examples/editflow/generate.py \
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
--tau 0.02 --mask_length 128 --seed 7070 \
--prompt "write a romantic story" --make_gif
# Generate a short sequence to visualize deletions after 128 <mask> tokens
python examples/editflow/generate.py \
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
--tau 0.02 --mask_length 128 --seed 7070 \
--prompt "write a single-sentence romantic story" --make_gif
Figure: Deletion & Substitution trace
Figure: Insertion & Substitution trace
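For intuition about the --tau flag used above, here is a toy tau-leaping-style loop in the spirit of Edit Flows: at each small time step, substitution, deletion, and insertion each fire with probability roughly proportional to a rate times tau. The rates below are random stand-ins for the model's predictions, and this is not the repo's generate.py:

```python
# Conceptual tau-leaping decoding loop; random numbers stand in for the
# model-predicted edit rates. This is NOT the repo's generate.py.
import random

random.seed(7070)
VOCAB = ["a", "romantic", "story", "rain", "smile", "."]
state = ["<mask>"] * 8   # tiny stand-in for the 128-<mask> canvas
tau = 0.02               # time-step size, as in --tau 0.02

step = 0
while "<mask>" in state and step < 500:
    new_state = []
    for tok in state:
        # Each edit fires with probability ~ rate * tau (rates are made up).
        if random.random() < 1.0 * tau:                        # insertion
            new_state.append(random.choice(VOCAB))
        if tok == "<mask>" and random.random() < 2.0 * tau:    # deletion
            continue                                           # drop this mask
        if tok == "<mask>" and random.random() < 3.0 * tau:    # substitution
            tok = random.choice(VOCAB)
        new_state.append(tok)
    state = new_state
    step += 1

print(" ".join(state))  # may end up shorter or longer than the 8-mask canvas
```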
Acknowledgement
This Edit Flows implementation is inspired by https://github.com/TheMatrixMaster/edit-flows-demo.


