Edit Flows

Reference: 📄 Paper: Edit Flows: Flow Matching with Edit Operations

This directory provides an educational reference for training EditFlow models. It demonstrates how to adapt open-weight DLLMs, such as LLaDA and Dream, to support insertion and deletion beyond the standard substitution (mask -> tokens) operation. It also includes examples for training (pretraining and finetuning) EditFlow models from scratch.

Note

  • Examples are available for both LLaDA and Dream, but this README focuses on adapting open-weight LLaDA for edit operations (adapt_llada.py) and reusing its architecture for training from scratch (pt_llada.py -> sft_llada.py).
  • While EditFlowCollator supports custom x0, this README uses a fixed-length (128) mask sequence as x0. The trained model generates text by replacing masks, deleting redundant ones, and inserting tokens as needed. To change the default x0 distribution (e.g., empty sequences for OneFlow-like insertion-only generation), pass --x0_sampler "empty"; see the sketch after this list.
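
For intuition, here is a minimal sketch of the two x0 distributions. The helper names and the <mask> token id below are assumptions for illustration, not the actual EditFlowCollator API:

MASK_ID = 126336  # assumed LLaDA <mask> token id; check your tokenizer

def sample_x0_masks(length: int = 128) -> list[int]:
    # --x0_sampler "masks[length:128]": start from a fixed-length block of masks;
    # decoding substitutes them, deletes redundant ones, and inserts tokens as needed.
    return [MASK_ID] * length

def sample_x0_empty() -> list[int]:
    # --x0_sampler "empty": start from an empty sequence, so generation proceeds
    # by insertions only (OneFlow-like).
    return []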

Table of Contents

  • Setup
  • Files overview
  • Training
  • Sampling
  • Acknowledgement

Setup

Important

Slurm users: update scripts/train.slurm.sh and create the log directory (mkdir logps); see (optional) Slurm setup for details.

Files overview

dllm/pipelines/editflow
├── __init__.py                 # Package initialization
├── models
│   ├── dream
│   │   └── modelling_dream.py  # EditFlowDream: architecture based on Dream
│   └── llada
│       └── modelling_llada.py  # EditFlowLLaDA: architecture based on LLaDA
├── trainer.py
└── utils.py

# example entry points for training / sampling
examples/editflow
├── adapt_dream.py              # Example of adapting Dream for EditFlow directly
├── adapt_llada.py              # Example of adapting LLaDA for EditFlow directly
├── generate.py                 # Generation example
├── pt_dream.py                 # EditFlowDream pretraining example
├── pt_llada.py                 # EditFlowLLaDA pretraining example
├── pt.py                       # Pretraining function
├── README.md                   # Documentation (you are here)
├── sft_dream.py                # EditFlowDream SFT example
├── sft_llada.py                # EditFlowLLaDA SFT example
└── sft.py                      # Supervised finetuning function

Training

Adapting LLaDA-8B-Instruct to support insertion and deletion

The original LLaDA model generates text by iteratively substituting the given <mask> tokens with real tokens.

LLaDA demo

Figure: Example Gradio demo for LLaDA.

However, LLaDA supports only substitution. This example shows how to adapt it so that, during decoding, the model can not only replace fixed-length masks (e.g., 128 tokens) with real text but also insert new tokens and delete unnecessary masks adaptively:

accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/editflow/adapt_llada.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
    --lm_head_key "model.transformer.ff_out" \
    --init_editflow_from_src True \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
    --x0_sampler "masks[length:128]" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 5e-5
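
Conceptually, EditFlow decoding discretizes a continuous-time process: at each small step tau, every position may fire a substitution or deletion, and every gap may fire an insertion, according to model-predicted rates (an event with rate r fires within a step of length tau with probability 1 - exp(-r * tau)). The following is a minimal sketch of a single step under that reading; the function name, shapes, and sampling details are illustrative assumptions, not the actual sampler in generate.py:

import torch

def edit_step(tokens, sub_rate, del_rate, ins_rate, sub_logits, ins_logits, tau=0.02):
    # Illustrative single Euler step of an edit-flow sampler (hypothetical API).
    # tokens: (L,) current sequence; sub_rate, del_rate: (L,) per-token rates;
    # ins_rate: (L+1,) per-gap rates; sub_logits: (L, V); ins_logits: (L+1, V).
    out = []
    for i, tok in enumerate(tokens.tolist()):
        if torch.rand(()) < 1 - torch.exp(-ins_rate[i] * tau):  # insert before token i
            out.append(int(torch.distributions.Categorical(logits=ins_logits[i]).sample()))
        if torch.rand(()) < 1 - torch.exp(-del_rate[i] * tau):  # delete token i
            continue  # e.g., a redundant <mask> disappears
        if torch.rand(()) < 1 - torch.exp(-sub_rate[i] * tau):  # substitute token i
            tok = int(torch.distributions.Categorical(logits=sub_logits[i]).sample())
        out.append(tok)
    if torch.rand(()) < 1 - torch.exp(-ins_rate[len(tokens)] * tau):  # final gap
        out.append(int(torch.distributions.Categorical(logits=ins_logits[len(tokens)]).sample()))
    return torch.tensor(out)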

If you are using Slurm and want to train across, for example, four nodes (32 GPUs in total), run:

sbatch --nodes=4 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/editflow/adapt_llada.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
    --lm_head_key "model.transformer.ff_out" \
    --init_editflow_from_src True \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
    --x0_sampler "masks[length:128]" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 5e-5

After training, you can use the generate.py script to produce a visualized decoding trace showing how the model performs insertion and deletion beyond regular mask substitution. See Sampling for details.

Pretraining & Finetuning from scratch

You can also train an EditFlow model from scratch (pretrain → SFT) without adapting an existing DLLM.

Pretrain on a subset of mlfoundations/dclm-baseline-1.0 using 192 GPUs (24x8) and FSDP:

sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/editflow/pt_llada.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "mlfoundations/dclm-baseline-1.0" \
    --output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
    --x0_sampler "masks[length:128]" \
    --max_length 1024 \
    --max_steps 2000 \
    --learning_rate 3e-4

Finetune on a subset of allenai/tulu-3-sft-mixture using 8 GPUs and FSDP for better instruction following:

# you can also run locally with `accelerate ...`
sbatch --nodes=1 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/editflow/sft_llada.py" \
    --model_name_or_path "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0/checkpoint-final" \
    --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
    --output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
    --x0_sampler "masks[length:128]" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 5e-5
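
The bracket suffix in --dataset_args selects a subset of each split (here, the first 10000 training and 1000 test examples). A rough sketch of what that convention encodes; this parser is a hypothetical illustration, not the repo's actual loader:

import re

def parse_dataset_args(spec: str):
    # "name[train:10000,test:1000]" -> ("name", {"train": 10000, "test": 1000})
    m = re.fullmatch(r"(?P<name>[^\[]+)(?:\[(?P<splits>[^\]]+)\])?", spec)
    name, sizes = m.group("name"), {}
    if m.group("splits"):
        for part in m.group("splits").split(","):
            split, size = part.split(":")
            sizes[split] = int(size)
    return name, sizes

print(parse_dataset_args("allenai/tulu-3-sft-mixture[train:10000,test:1000]"))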

Sampling

After training, you can visualize how the model performs mask substitution, insertion, and deletion during generation with generate.py. Inserted tokens appear blue, tokens substituted from <mask> appear black, and deleted tokens are shown with a strikethrough before they disappear.

# Generate a long sequence to visualize insertions after 128 <mask> tokens
python examples/editflow/generate.py \
  --model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
  --tau 0.02 --mask_length 128 --seed 7070 \
  --prompt "write a romantic story" --make_gif

# Generate a short sequence to visualize deletions after 128 <mask> tokens
python examples/editflow/generate.py \
  --model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
  --tau 0.02 --mask_length 128 --seed 7070 \
  --prompt "write a single-sentence romantic story" --make_gif

EditFlow deletion demo

Figure: Deletion & Substitution trace

EditFlow insertion demo

Figure: Insertion & Substitution trace
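
The coloring in these traces can be mimicked in a terminal with ANSI escape codes. Below is a minimal, hypothetical rendering helper (not part of generate.py) that follows the scheme described above:

BLUE, STRIKE, RESET = "\033[34m", "\033[9m", "\033[0m"

def render_step(tokens_with_ops):
    # tokens_with_ops: list of (text, op) pairs, op in {"ins", "sub", "del"}.
    parts = []
    for text, op in tokens_with_ops:
        if op == "ins":
            parts.append(f"{BLUE}{text}{RESET}")    # inserted tokens: blue
        elif op == "del":
            parts.append(f"{STRIKE}{text}{RESET}")  # deleted tokens: strikethrough
        else:
            parts.append(text)                      # substituted from <mask>: default color
    return " ".join(parts)

print(render_step([("Once", "sub"), ("upon", "ins"), ("<mask>", "del")]))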

Acknowledgement

This Edit Flows implementation is inspired by https://github.com/TheMatrixMaster/edit-flows-demo.