1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

139
dllm/.gitignore vendored Normal file
View File

@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
applications/DeepSpeed-Chat/data
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Others
/.vscode/
/tmp/
/data/
/wandb/
/logs/
/models*/

4
dllm/.gitmodules vendored Normal file
View File

@ -0,0 +1,4 @@
[submodule "lm-evaluation-harness"]
path = lm-evaluation-harness
url = https://github.com/ZHZisZZ/lm-evaluation-harness
branch = dllm

21
dllm/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Zhanhui Zhou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

283
dllm/README.md Normal file
View File

@ -0,0 +1,283 @@
<h1 align="center">dLLM</h1>
<p align="center">
Simple Diffusion Language Modeling
</p>
<p align="center">
<img
src="assets/logo.gif"
alt="dLLM logo">
</p>
## Overview
**dLLM** is a library that unifies the training and evaluation of **diffusion language models**, bringing transparency and reproducibility to the entire development pipeline:
<!-- and [RND1](https://www.radicalnumerics.ai/assets/rnd1_report.pdf) -->
- dLLM provides scalable training pipelines (inspired by the [`transformers`](https://github.com/huggingface/transformers/blob/main/src/transformers) [Trainer](https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py)), with support for [LoRA](https://github.com/huggingface/peft), [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), [FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/), and more.
- dLLM provides unified evaluation pipelines (inspired by [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness)) that abstract away inference details and make customization simple.
- Built on these components, dLLM provides minimal **pretraining / finetuning / evaluation** recipes for open-weight models (e.g., [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487)), as well as implementations of training algorithms (e.g., [Edit Flows](https://arxiv.org/abs/2506.09018)).
<!-- > [!NOTE]
> This repository is primarily for educational purposes and does not aim for 100% exact reproduction of official models (which is impossible). We hope it serves as a helpful reference for the community — contributions and improvements are always welcome! -->
## News
**[2025/11]** We released a collection of BERTs finetuned for instruction-following: [`ModernBERT-{large,base}-chat-v0`](https://huggingface.co/collections/dllm-collection/bert-chat). This proof of concept shows that BERT's internal knowledge can be leveraged for generative tasks via masked instruction tuning. See [![blog](https://img.shields.io/badge/W&B-white?logo=weightsandbiases) BERT Chat Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg) for detailed recipes, experimental results, and lessons learned; see [`examples/bert`](/examples/bert) for training / inference / evaluation instructions.
## Table of Contents
- [Features](#features)
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
- [Citation](#citation)
## Features
<!-- - [`examples/rnd`](/examples/rnd): (WIP) Finetuning open-weight RND1 [RND1-Base](https://www.radicalnumerics.ai/assets/rnd1_report.pdf). -->
- [`examples/llada`](/examples/llada): Pretraining, finetuning, and evaluating [LLaDA](https://arxiv.org/abs/2502.09992) / [LLaDA-MoE](https://arxiv.org/abs/2509.24389).
- [`examples/dream`](/examples/dream): Pretraining, finetuning, and evaluating [Dream](https://arxiv.org/abs/2508.15487).
- [`examples/bert`](/examples/bert): Finetuning any [BERT](https://arxiv.org/abs/1810.04805) into a lightweight chatbot.
<details>
<summary>🎬 Click to show BERT Chat Demo</summary>
<p align="center">
<img src="/examples/bert/assets/chat.gif" alt="chat" width="80%">
</p>
<p align="center">
<em>
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
</em>
</p>
</details>
- [`examples/editflow`](/examples/editflow): Educational reference for training [EditFlow](https://arxiv.org/abs/2506.09018) models, demonstrating how to extend existing DLLMs (e.g., LLaDA, Dream, BERT Chat) with *edit operations*—insertion, deletion, and substitution—and how to pretrain or finetune EditFlow models from scratch on public data.
<details>
<summary>🎬 Click to show EditFlow Demo</summary>
<p align="center">
<img src="/examples/editflow/assets/all.gif" alt="EditFlow demo" width="100%">
</p>
<p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p>
</details>
- More upcoming.
## Setup
### Installation
```bash
# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm
# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
conda install cuda=12.4 -c nvidia
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
--index-url https://download.pytorch.org/whl/cu124
# install dllm package
pip install -e .
```
### (optional) Evaluation setup
```bash
# initialize `lm-evaluation-harness` submodule
git submodule update --init --recursive
# install submodule in editable mode with IFEval & Math dependencies
pip install -e "lm-evaluation-harness[ifeval,math]"
```
### (optional) Slurm setup
For [Slurm](https://slurm.schedmd.com/) users, update [`scripts/train.slurm.sh`](/scripts/train.slurm.sh) for your cluster:
```diff
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE
```
Next, create a directory for your job logs:
```shell
mkdir logs
```
This folder will store the log files generated by your sbatch jobs.
## Files overview
```
# modules for training / sampling
dllm
├── core # Core reusable modules shared across `dllm/pipelines`
│ ├── generation
│ ├── schedulers
│ └── trainers
├── data
├── pipelines # Application-specific training & inference pipelines
│ ├── bert
│ ├── dream
│ ├── editflow
│ └── llada
│ ├── models # Model architecture and configs
│ ├── generator.py # Generation utilities
│ ├── trainer.py # Core training logic
│ └── eval.py # Evaluation entry point
├── tools
└── utils
# entry points for training / sampling
examples
├── bert
├── dream
├── editflow
└── llada
├── chat.py # Interactive inference example
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
├── sft.py # Supervised finetuning example
└── eval.sh # Evaluation script
```
## Training
A typical training entry script (for example, [`examples/llada/sft.py`](/examples/llada/sft.py)) looks like this:
```python
import transformers
import dllm
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
dataset = "..."
# ----- Training --------------------------------------------------------------
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
label_pad_token_id=tokenizer.pad_token_id,
),
)
trainer.train()
```
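In the snippet above, `parser` is a `transformers.HfArgumentParser` built from the script's argument dataclasses. A minimal sketch of that setup (with placeholder dataclass names and defaults, not the exact ones used in dLLM) looks like this:
```python
# Sketch only: ModelArguments / DataArguments are placeholders for the
# dataclasses defined in the actual entry script.
from dataclasses import dataclass, field
import transformers

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="GSAI-ML/LLaDA-8B-Base")

@dataclass
class DataArguments:
    dataset_args: str = field(default="tatsu-lab/alpaca")

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, transformers.TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
```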
You can launch a training job locally with `accelerate`, or submit it to a [Slurm](https://slurm.schedmd.com/) cluster using `sbatch`.
```shell
# Run locally (ZeRO-2 on 8 GPUs with 4bit quantization and LoRA)
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/llada/sft.py \
--num_train_epochs 4 \
--load_in_4bit True --lora True
```
```shell
# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
```
See [Features](#features) for specific training recipes.
> Here are some useful tips for training:
> 1. Use a subset of data:
> `--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"`
> 2. Concatenate datasets:
> `--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk"`
> 3. Train with LoRA and 4bit quantization:
> `--load_in_4bit True --lora True`
> 4. Train with different distributed training methods:
> `--accelerate_config "ddp,zero-{1,2,3},fsdp"`
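The dataset spec strings from tips 1–2 above can also be passed programmatically to the data loader used by the entry scripts. A minimal sketch, mirroring the examples documented in `dllm/data/utils.py`:
```python
# Minimal sketch: mix two SFT sources, keeping 5,000 training examples from
# each (spec syntax as documented in dllm/data/utils.py).
from dllm.data import load_sft_dataset

dataset = load_sft_dataset(
    "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]"
)
print(dataset)
```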
## Inference
We provide unified [generators](/dllm/core/generation/generator.py) that abstract away inference details.
A typical inference entry script (for example, [`examples/llada/generate.py`](/examples/llada/generate.py)) looks like this:
```python
import dllm
from dllm import llada
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
# for other models, change your generator and keep others unchanged
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = generator.generate(inputs, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
```
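Here, `decode_trim` is a small dLLM helper; as a rough mental model (a hedged sketch, not the library's actual implementation), it drops each prompt prefix and decodes the remaining tokens:
```python
# Hypothetical sketch of a decode_trim-style helper; the real dllm utility
# may differ in details (padding, special tokens, batching).
def decode_trim_sketch(tokenizer, sequences, inputs):
    return [
        tokenizer.decode(seq[len(prompt):], skip_special_tokens=True)
        for seq, prompt in zip(sequences, inputs)
    ]
```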
You can also try the interactive chat script (for example, [`examples/llada/chat.py`](/examples/llada/chat.py)) for visualized multi-turn dialogue:
<p align="center">
<img src="/assets/chat.gif" alt="chat" width="80%">
</p>
<!-- <p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p> -->
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [`MMLU_Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro), run:
```shell
accelerate launch --num_processes 4 \
dllm/pipelines/llada/eval.py \
--tasks "mmlu_pro" \
--model "llada" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
```
We also provide scripts to automatically evaluate [LLaDA](https://arxiv.org/abs/2502.09992), [Dream](https://arxiv.org/abs/2508.15487), and [BERT-Chat](https://huggingface.co/collections/dllm-collection/bert-chat) on all benchmarks.
For example, you can launch [`examples/llada/eval.sh`](/examples/llada/eval.sh) directly using the following commands:
```shell
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False
```
## Citation
```
@misc{dllm,
author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
title = {dLLM: Simple Diffusion Language Modeling},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}
```

Binary file not shown.

BIN
dllm/assets/chat.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.5 MiB

BIN
dllm/assets/logo.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 956 KiB

BIN
dllm/assets/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

119
dllm/assets/logo.py Normal file
View File

@ -0,0 +1,119 @@
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import os
# ---------- Configuration (smaller size) ----------
W, H = 480, 210 # lower resolution
TOTAL_DURATION = 3.0
FPS = 15 # lower fps
TEXT = "dLLM"
# TEXT_COLOR = (235, 235, 235)
TEXT_COLOR = (0, 0, 0)
OUTPUT = "logo.gif"
LAST_FRAME_PNG = "logo.png"
DIFFUSION_PORTION = 0.3 # fewer diffusion frames
SEED = 8
# ---------- Auto font size ----------
def load_font_auto_size(text, w, h, target_width_ratio=0.95, target_height_ratio=0.95):
lo, hi = 10, 2000
best_font, best_size = None, lo
while lo <= hi:
mid = (lo + hi) // 2
try:
font = ImageFont.truetype(
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=mid
)
except Exception:
font = ImageFont.load_default()
dummy = Image.new("L", (w, h), 0)
d = ImageDraw.Draw(dummy)
bbox = d.textbbox((0, 0), text, font=font)
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
width_ok = tw <= w * target_width_ratio
height_ok = th <= h * target_height_ratio
if width_ok and height_ok:
best_font, best_size = font, mid
lo = mid + 1
else:
hi = mid - 1
return best_font if best_font is not None else font
# ---------- Text rendering ----------
def render_text_mask(w, h, text, font):
img = Image.new("L", (w, h), 0)
d = ImageDraw.Draw(img)
bbox = d.textbbox((0, 0), text, font=font)
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
x = (w - tw) // 2 - bbox[0]
y = (h - th) // 2 - bbox[1]
d.text((x, y), text, font=font, fill=255)
return np.asarray(img, np.float32) / 255.0
# ---------- Initialization ----------
font = load_font_auto_size(TEXT, W, H)
mask = render_text_mask(W, H, TEXT, font)
num_frames = int(TOTAL_DURATION * FPS)
diffusion_frames = max(1, int(num_frames * DIFFUSION_PORTION))
hold_ms = int((TOTAL_DURATION - diffusion_frames / FPS) * 1000)
rng = np.random.default_rng(SEED)
frames = []
# ---------- Diffusion stage ----------
for i in range(diffusion_frames):
t = i / max(1, diffusion_frames - 1)
progress = t**0.9
noise_sigma = (1.0 - progress) ** 2.2
noise = rng.standard_normal((H, W, 1)).astype(np.float32)
noise_img = 1.0 - noise_sigma * 0.5 * np.abs(noise)
np.clip(noise_img, 0.0, 1.0, out=noise_img)
alpha = progress**2.0
alpha_map = (mask * alpha).astype(np.float32)[..., None]
text_rgb = np.zeros((H, W, 3), dtype=np.float32)
for c in range(3):
text_rgb[..., c] = (mask > 0).astype(np.float32) * (TEXT_COLOR[c] / 255.0)
frame = (1.0 - alpha_map) * noise_img + alpha_map * text_rgb
frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
frames.append(Image.fromarray(frame, mode="RGB"))
# ---------- Last frame ----------
final_frame = frames[-1]
# ---------- Save last frame as PNG ----------
final_frame.save(LAST_FRAME_PNG)
print(f"🖼️ Last frame saved as: {LAST_FRAME_PNG}")
# ---------- Quantization (reduce size) ----------
pal_frames = [f.convert("P", palette=Image.ADAPTIVE, colors=64) for f in frames]
pal_final = final_frame.convert("P", palette=Image.ADAPTIVE, colors=64)
# ---------- Save GIF ----------
normal_ms = int(1000 / FPS)
durations = [normal_ms] * len(pal_frames) + [hold_ms]
pal_frames[0].save(
OUTPUT,
save_all=True,
append_images=pal_frames[1:] + [pal_final],
duration=durations,
loop=0,
optimize=True,
)
print(f"✅ GIF saved: {OUTPUT}")
print(
f"Frames (diffusion only): {len(pal_frames)} at {FPS} FPS, final hold {hold_ms} ms, resolution {W}x{H}"
)

1
dllm/dllm/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import core, data, pipelines, utils

View File

@ -0,0 +1 @@
from dllm.core import trainers, schedulers, generation

View File

@ -0,0 +1 @@
from . import generator, visualizer

View File

@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from transformers import PreTrainedTokenizer, PreTrainedModel
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
@dataclass
class GeneratorOutput:
sequences: torch.Tensor
histories: list[torch.Tensor] | None = None
@dataclass
class GeneratorConfig:
return_dict_in_generate: bool = False
@dataclass
class BaseGenerator(ABC):
model: PreTrainedModel
tokenizer: PreTrainedTokenizer
scheduler: BaseAlphaScheduler | None = None
def __post_init__(self):
if self.scheduler is None:
self.scheduler = LinearAlphaScheduler()
@abstractmethod
@torch.no_grad()
def generate(
self,
prompts: list[torch.Tensor | list],
config: GeneratorConfig | None = None,
**kwargs,
) -> GeneratorOutput:
raise NotImplementedError
@abstractmethod
@torch.no_grad()
def infill(
self,
inputs: list[torch.Tensor | list],
config: GeneratorConfig | None = None,
**kwargs,
) -> GeneratorOutput:
raise NotImplementedError
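# --- Illustrative sketch (not part of the original file) ---------------------
# A hypothetical minimal subclass showing how BaseGenerator is meant to be
# extended: fill every mask token with the model's argmax prediction in a
# single forward pass. Assumes all prompts are tensors of equal length.
@dataclass
class ArgmaxFillGenerator(BaseGenerator):
    @torch.no_grad()
    def generate(self, prompts, config=None, **kwargs):
        batch = torch.stack([p if p.dim() == 1 else p[0] for p in prompts])
        logits = self.model(input_ids=batch).logits
        preds = logits.argmax(dim=-1)
        mask = batch == self.tokenizer.mask_token_id
        return GeneratorOutput(sequences=torch.where(mask, preds, batch))

    @torch.no_grad()
    def infill(self, inputs, config=None, **kwargs):
        return self.generate(inputs, config=config, **kwargs)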

View File

@ -0,0 +1,427 @@
from __future__ import annotations
import os
import re
import sys
import time
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Sequence, Optional
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer
@dataclass
class BaseVisualizer(ABC):
tokenizer: PreTrainedTokenizer
@abstractmethod
def visualize(self, history: list[torch.Tensor], **kwargs):
raise NotImplementedError
@dataclass
class VideoVisualizer(BaseVisualizer):
def visualize(
self,
history: list[torch.Tensor],
output_path: str = "visualization.gif",
**kwargs,
):
raise NotImplementedError
@dataclass
class TerminalVisualizer(BaseVisualizer):
# Configuration (adjust as needed)
HEADER_SIZE = 3 # Fixed number of lines for the header (0 if show_header is False)
PROGRESS_SIZE = 3 # Fixed number of lines for the progress bar
PANEL_PADDING_TOP = 1 # Top padding of the Panel (padding=(top, side))
PANEL_PADDING_BOTTOM = 1 # Bottom padding of the Panel
PANEL_PADDING_SIDE = 1 # Number of characters used for left and right padding
PANEL_BORDER = 2 # Number of columns taken by the Panel border (usually 2)
MIN_TOTAL_HEIGHT = 10 # Minimum terminal height (in lines)
MAX_TOTAL_HEIGHT = 60 # Maximum terminal height to prevent overflowing the terminal
DEFAULT_TERM_WIDTH = 120 # Default terminal width (in columns)
ansi_escape = re.compile(r"\x1b\[[0-9;]*m") # Regex to match ANSI escape codes
def visualize(
self,
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
fps: int = 16,
rich: bool = True,
title: str = "dllm",
max_chars: int = None,
every_n_steps: int = 1,
show_header: bool = True,
skip_special_tokens: bool = False,
) -> None:
"""
Visualize a masked-diffusion decoding trajectory stored in `history`.
If items have batch dimension [B, T], visualize each sequence separately.
"""
try:
# detect batch size
first_step = history[0]
if first_step.dim() > 1 and first_step.shape[0] > 1:
B = first_step.shape[0]
for b_idx in range(B):
# build per-sequence history
seq_history = [step[b_idx].unsqueeze(0) for step in history]
self.visualize_one_history(
seq_history,
fps,
rich,
title=f"{title} (Batch {b_idx})",
max_chars=max_chars,
every_n_steps=every_n_steps,
show_header=show_header,
skip_special_tokens=skip_special_tokens,
)
else:
# no batch, just visualize normally
self.visualize_one_history(
history,
fps,
rich,
title,
max_chars,
every_n_steps,
show_header,
skip_special_tokens,
)
except Exception as e:
print(f"(Visualization skipped due to error: {e})")
def visualize_one_history(
self,
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
fps: int = 16,
rich: bool = True,
title: str = "dllm",
max_chars: int = None,
every_n_steps: int = 1, # re-render frequency (perf knob)
show_header: bool = True,
skip_special_tokens: bool = False, # NEW ARGUMENT
) -> None:
"""
Visualize a masked-diffusion decoding trajectory stored in `history`.
Args:
history: Sequence of token tensors for each step. Each item is [T] or [B,T].
fps: Frames per second for the live UI (Rich) or sleep cadence for tqdm fallback.
title: Header title.
max_chars: Cap on rendered characters to keep terminal snappy.
every_n_steps: Only redraw text every N steps (progress still updates every step).
show_header: Show the magenta header bar (Rich path).
skip_special_tokens: Whether to skip special/pad/eos tokens when rendering (default: False).
Notes:
- Masked positions are detected via `self.tokenizer.mask_token_id`.
- Special tokens are determined via `self.tokenizer.all_special_ids`.
- All layout, styling, and progress are encapsulated here.
"""
# --------- imports & env checks ----------
try:
from rich.console import Console
from rich.live import Live
from rich.text import Text
from rich.panel import Panel
from rich.progress import (
Progress,
BarColumn,
TextColumn,
TimeRemainingColumn,
MofNCompleteColumn,
SpinnerColumn,
)
from rich.layout import Layout
_RICH_IMPORTED = True
except Exception:
_RICH_IMPORTED = False
try:
from tqdm import tqdm
_TQDM_IMPORTED = True
except Exception:
_TQDM_IMPORTED = False
if self.tokenizer is None:
raise ValueError(
"TerminalVisualizer.tokenizer must be set to a valid tokenizer."
)
tokenizer = self.tokenizer
specials: set[int] = set(getattr(tokenizer, "all_special_ids", []) or [])
self._specials = specials # store for helpers
self._mask_token_id: Optional[int] = getattr(tokenizer, "mask_token_id", None)
self._pad_token_id: Optional[int] = getattr(tokenizer, "pad_token_id", None)
self._eos_token_id: Optional[int] = getattr(tokenizer, "eos_token_id", None)
# --------- helpers inside class scope ----------
# (keep everything inside this class as requested)
# throttle settings
sleep_s = 0.0 if fps <= 0 else 1.0 / float(max(1, fps))
total_steps = len(history)
every_n_steps = max(1, int(every_n_steps))
# decode final text up-front (used after render)
final_text = self._detok(history[-1], skip_special_tokens=skip_special_tokens)
final_text = self._truncate(final_text, max_chars)
# ------------------ new: estimate height from final_text ------------------
import textwrap
import shutil
def strip_ansi(s: str) -> str:
return self.ansi_escape.sub("", s) if s else ""
def estimate_height_from_text(text: str, console_width: int) -> int:
"""
Estimate how many terminal rows the panel with `text` will need given console_width.
Uses class constants for paddings/borders and header/progress sizes.
"""
plain = strip_ansi(text or "")
# inner width = console width minus left/right panel paddings & border
inner_width = max(
10, console_width - 2 * self.PANEL_PADDING_SIDE - self.PANEL_BORDER
)
lines = 0
# preserve existing newlines: wrap each paragraph separately
for para in plain.splitlines() or [""]:
if para.strip() == "":
lines += 1
continue
wrapped = textwrap.wrap(
para,
width=inner_width,
replace_whitespace=False,
drop_whitespace=False,
)
lines += max(1, len(wrapped))
text_block_lines = (
lines + self.PANEL_PADDING_TOP + self.PANEL_PADDING_BOTTOM
)
extra = 2 # for panel title / subtitle / small margin
header_h = self.HEADER_SIZE if show_header else 0
total = header_h + text_block_lines + self.PROGRESS_SIZE + extra
# clamp
total = max(self.MIN_TOTAL_HEIGHT, min(total, self.MAX_TOTAL_HEIGHT))
return int(total)
# try to detect terminal width; fallback to 100
try:
term_width = shutil.get_terminal_size().columns
if not isinstance(term_width, int) or term_width <= 0:
term_width = self.DEFAULT_TERM_WIDTH
except Exception:
term_width = self.DEFAULT_TERM_WIDTH
est_height = estimate_height_from_text(final_text, console_width=term_width)
# ------------------ end new ----------------------------------------------
# choose rich or tqdm
use_rich = bool(rich and _RICH_IMPORTED)
if not use_rich or not _RICH_IMPORTED:
# ---------- tqdm fallback ----------
if not _TQDM_IMPORTED:
for i, toks in enumerate(history, start=1):
if sleep_s > 0:
time.sleep(sleep_s)
print("\n✨ Generation complete!\n")
print(final_text)
return
pbar = tqdm(total=total_steps, desc="Diffusion", leave=True)
for i, toks in enumerate(history, start=1):
pbar.update(1)
pbar.set_postfix(
{
"masks": self._count_masks(toks),
"pct": f"{int(100 * i / max(total_steps, 1))}%",
}
)
if sleep_s > 0:
time.sleep(sleep_s)
pbar.close()
print("\n✨ Generation complete!\n")
if final_text:
print(final_text)
return
# ---------- rich live UI ----------
# replaced fixed height=100 with the estimated height from history[-1]
console = Console(
force_terminal=True,
color_system="truecolor",
width=term_width,
height=est_height,
)
layout = Layout()
layout.split_column(
(
Layout(name="header", size=3)
if show_header
else Layout(name="header", size=0)
),
Layout(name="text", ratio=1),
Layout(name="progress", size=3),
)
progress = Progress(
SpinnerColumn(),
TextColumn("[bold blue]Diffusion"),
BarColumn(),
MofNCompleteColumn(),
TextColumn(""),
TextColumn("[cyan]Masks: {task.fields[masks]}"),
TextColumn(""),
TextColumn("[magenta]{task.fields[pct]:>4s}"),
TimeRemainingColumn(),
expand=True,
)
init_masks = self._count_masks(history[0]) if history else 0
task_id = progress.add_task(
"Generating", total=total_steps, masks=init_masks, pct="0%"
)
with Live(layout, console=console, refresh_per_second=max(1, fps)):
for step_idx, toks in enumerate(history, start=1):
if show_header:
header = Text(title, style="bold magenta", justify="center")
layout["header"].update(Panel(header, border_style="bright_blue"))
# progress bar
masks_remaining = self._count_masks(toks)
pct = f"{int(100 * step_idx / max(total_steps, 1))}%"
progress.update(task_id, advance=1, masks=masks_remaining, pct=pct)
# text panel: decode whole sequence (avoids Ġ/Ċ artifacts)
if (
every_n_steps <= 1
or (step_idx % every_n_steps == 0)
or step_idx in (1, total_steps)
):
text_str = self._detok(
toks, skip_special_tokens=skip_special_tokens
)
text_str = self._truncate(text_str, max_chars)
text_rich = Text.from_ansi(text_str) if text_str else Text("")
layout["text"].update(
Panel(
(
text_rich
if text_rich.plain
else Text("[dim]— no tokens —[/dim]")
),
title="[bold]Generated Text",
subtitle=f"[dim]Step {step_idx}/{total_steps}[/dim]",
border_style="cyan",
padding=(1, 1),
)
)
layout["progress"].update(Panel(progress))
if sleep_s > 0:
time.sleep(sleep_s)
console.print("\n[bold green]✨ Generation complete![/bold green]\n")
# console.print(
# Panel(
# final_text if final_text else "[dim]— no decodable text —[/dim]",
# title="[bold]Final Generated Text",
# border_style="green",
# padding=(1, 2),
# )
# )
# ======================== helpers (kept inside class) ========================
def _has_tty(self) -> bool:
return sys.stdout.isatty() and os.environ.get("TERM", "") not in ("", "dumb")
def _first_item(self, x: torch.Tensor) -> torch.Tensor:
return x[0] if x.dim() > 1 else x
def _count_masks(self, toks: torch.Tensor) -> int:
if getattr(self, "_mask_token_id", None) is None:
return 0
t = self._first_item(toks)
return int((t == self._mask_token_id).sum().item())
def _detok(self, ids_or_tensor, *, skip_special_tokens: bool) -> str:
"""
Robust detokenize for list[int] / torch.Tensor([T]) / torch.Tensor([B,T]).
Decode the whole sequence to avoid byte-level artifacts like Ġ/Ċ.
"""
tokenizer = self.tokenizer
# normalize to python list[int]
if isinstance(ids_or_tensor, torch.Tensor):
t = self._first_item(ids_or_tensor).long()
ids = t.tolist()
elif isinstance(ids_or_tensor, (list, tuple)):
ids = list(ids_or_tensor)
else:
# unknown type
return ""
# Optionally drop specials/pad/eos *before* decode if desired
if skip_special_tokens:
keep = []
specials = getattr(self, "_specials", set())
pad_id = getattr(self, "_pad_token_id", None)
eos_id = getattr(self, "_eos_token_id", None)
for tid in ids:
if tid in specials:
continue
if pad_id is not None and tid == pad_id:
continue
if eos_id is not None and tid == eos_id:
continue
keep.append(tid)
ids = keep
# Prefer tokenizer.decode (handles Ġ/Ċ, merges properly)
text = ""
try:
if hasattr(tokenizer, "decode"):
text = tokenizer.decode(
ids,
skip_special_tokens=False,
clean_up_tokenization_spaces=True,
)
else:
# fallback: tokens -> string
toks = tokenizer.convert_ids_to_tokens(ids)
if hasattr(tokenizer, "convert_tokens_to_string"):
text = tokenizer.convert_tokens_to_string(toks)
else:
text = " ".join(map(str, toks))
except Exception:
# extremely defensive fallback
try:
text = tokenizer.decode(ids, skip_special_tokens=True)
except Exception:
text = ""
# sanitize control chars for terminal
if text:
text = text.replace("\r", "")
return text
def _truncate(self, s: str, max_chars: Optional[int]) -> str:
if max_chars is None or (isinstance(max_chars, int) and max_chars < 0):
return s
return s[:max_chars]
if __name__ == "__main__":
pass
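    # --- Illustrative sketch (not part of the original file) -----------------
    # Build a toy decoding trajectory for TerminalVisualizer: start from all
    # mask tokens and reveal the target left to right, one token per step.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    target = torch.tensor(tok("hello diffusion world")["input_ids"])
    history = [target.clone() for _ in range(len(target) + 1)]
    for i, step in enumerate(history):
        step[i:] = tok.mask_token_id  # positions from i onward are still masked
    TerminalVisualizer(tokenizer=tok).visualize(history, fps=8)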

View File

@ -0,0 +1,2 @@
from .alpha import *
from .kappa import *

View File

@ -0,0 +1,132 @@
from __future__ import annotations
import dataclasses
import math
from typing import ClassVar, Dict, Type, Any, Union
import torch
Number = Union[float, torch.Tensor]
# ---------------- Registry-enabled Base ---------------- #
@dataclasses.dataclass
class BaseAlphaScheduler:
__registry__: ClassVar[dict[str, type[BaseAlphaScheduler]]] = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
BaseAlphaScheduler.__registry__[cls.__name__] = cls
BaseAlphaScheduler.__registry__[cls.__name__.lower()] = cls
# Make instances callable (sched(i) -> alpha(i))
def __call__(self, t: Number) -> Number:
return self.alpha(t)
# ---- common API ----
def alpha(self, i: Number) -> Number:
i_t = torch.as_tensor(
i,
dtype=torch.float32,
device=i.device if isinstance(i, torch.Tensor) else None,
)
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
raise ValueError(f"i={i} not in [0,1]")
out = self._alpha(i_t)
return out.item() if isinstance(i, float) else out
def alpha_derivative(self, i: Number) -> Number:
i_t = torch.as_tensor(
i,
dtype=torch.float32,
device=i.device if isinstance(i, torch.Tensor) else None,
)
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
raise ValueError(f"i={i} not in [0,1]")
out = self._alpha_derivative(i_t)
return out.item() if isinstance(i, float) else out
def reverse_mask_prob(self, s: Number, t: Number) -> Number:
t_t = torch.as_tensor(
t,
dtype=torch.float32,
device=t.device if isinstance(t, torch.Tensor) else None,
)
s_t = torch.as_tensor(
s,
dtype=torch.float32,
device=s.device if isinstance(s, torch.Tensor) else None,
)
if not torch.all((0.0 <= s_t) & (s_t < 1.0) & (0.0 < t_t) & (t_t <= 1.0)):
raise ValueError(f"(t={t}, s={s}) out of range")
if not torch.all(s_t < t_t):
raise ValueError(f"Require s < t elementwise, but got (t={t}, s={s})")
out = (1 - self(s_t)) / (1 - self(t_t))
return out.item() if isinstance(t, float) and isinstance(s, float) else out
def weight(self, i: Number) -> Number:
# w(t) = - α'(t) / (1 - α(t))
return - self.alpha_derivative(i) / (1 - self.alpha(i) + 1e-6)
# ---- hooks implemented by subclasses ----
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
# ---------------- Implementations ---------------- #
@dataclasses.dataclass
class LinearAlphaScheduler(BaseAlphaScheduler):
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
return 1 - i
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
return -torch.ones_like(i)
@dataclasses.dataclass
class CosineAlphaScheduler(BaseAlphaScheduler):
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
return 1 - torch.cos((math.pi / 2) * (1 - i))
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
return -(math.pi / 2) * torch.sin((math.pi / 2) * (1 - i))
# ---------------- Factory helpers ---------------- #
def get_alpha_scheduler_class(name: str) -> type[BaseAlphaScheduler]:
"""Return the scheduler class by name (case-insensitive)."""
cls = BaseAlphaScheduler.__registry__.get(
name
) or BaseAlphaScheduler.__registry__.get(name.lower())
if cls is None:
available = sorted(k for k in BaseAlphaScheduler.__registry__ if k[0].isupper())
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
return cls
def make_alpha_scheduler(name: str, **kwargs: Any) -> BaseAlphaScheduler:
"""Instantiate a scheduler by name with optional kwargs."""
cls = get_alpha_scheduler_class(name)
return cls(**kwargs)
# ---------------- Example usage ---------------- #
if __name__ == "__main__":
lin_sched = make_alpha_scheduler("LinearAlphaScheduler")
print("Linear α(0.5):", lin_sched.alpha(0.5))
print("Linear w(0.5):", lin_sched.weight(0.5))
print("Linear α([.25,.5,.75]):", lin_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
print("==========================================")
cos_sched = make_alpha_scheduler("CosineAlphaScheduler")
print("Cosine α(0.5):", cos_sched.alpha(0.5))
print("Cosine w(0.5):", cos_sched.weight(0.5))
print("Cosine α([.25,.5,.75]):", cos_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import dataclasses
import math
from typing import ClassVar, Dict, Type, Any, Union
import torch
Number = Union[float, torch.Tensor]
# ---------------- Registry-enabled Base ---------------- #
@dataclasses.dataclass
class BaseKappaScheduler:
__registry__: ClassVar[dict[str, type[BaseKappaScheduler]]] = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
BaseKappaScheduler.__registry__[cls.__name__] = cls
BaseKappaScheduler.__registry__[cls.__name__.lower()] = cls
# Make instances callable (sched(t) -> kappa(t))
def __call__(self, t: Number) -> Number:
return self.kappa(t)
# ---- common API ----
def kappa(self, t: Number) -> Number:
t_tensor = torch.as_tensor(
t,
dtype=torch.float32,
device=t.device if isinstance(t, torch.Tensor) else None,
)
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
raise ValueError(f"t={t} not in [0,1]")
out = self._kappa(t_tensor)
return out.item() if isinstance(t, float) else out
def kappa_derivative(self, t: Number) -> Number:
t_tensor = torch.as_tensor(
t,
dtype=torch.float32,
device=t.device if isinstance(t, torch.Tensor) else None,
)
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
raise ValueError(f"t={t} not in [0,1]")
out = self._kappa_derivative(t_tensor)
return out.item() if isinstance(t, float) else out
def weight(self, t: Number) -> Number:
# w(t) = κ'(t) / (1 - κ(t))
return self.kappa_derivative(t) / (1 - self.kappa(t) + 1e-6)
# ---- hooks implemented by subclasses ----
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
# ---------------- Implementations ---------------- #
@dataclasses.dataclass
class CubicKappaScheduler(BaseKappaScheduler):
a: float = 1.0
b: float = 1.0
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
# κ(t) = (a+1) t^3 - (a+b+1) t^2 + (b+1) t
return (self.a + 1) * (t**3) - (self.a + self.b + 1) * (t**2) + (self.b + 1) * t
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
# κ'(t) = 3(a+1) t^2 - 2(a+b+1) t + (b+1)
return 3 * (self.a + 1) * (t**2) - 2 * (self.a + self.b + 1) * t + (self.b + 1)
@dataclasses.dataclass
class LinearKappaScheduler(CubicKappaScheduler):
# Special case: κ(t) = t corresponds to a=-1, b=0
a: float = -1.0
b: float = 0.0
@dataclasses.dataclass
class CosineKappaScheduler(BaseKappaScheduler):
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
# κ(t) = 1 - cos((π/2) * t)
return 1.0 - torch.cos(0.5 * math.pi * t)
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
# κ'(t) = (π/2) * sin((π/2) * t)
return 0.5 * math.pi * torch.sin(0.5 * math.pi * t)
# ---------------- Factory helpers ---------------- #
def get_kappa_scheduler_class(name: str) -> type[BaseKappaScheduler]:
"""Return the scheduler class by name (case-insensitive)."""
cls = BaseKappaScheduler.__registry__.get(
name
) or BaseKappaScheduler.__registry__.get(name.lower())
if cls is None:
available = sorted(k for k in BaseKappaScheduler.__registry__ if k[0].isupper())
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
return cls
def make_kappa_scheduler(name: str, **kwargs: Any) -> BaseKappaScheduler:
"""Instantiate a scheduler by name with optional kwargs."""
cls = get_kappa_scheduler_class(name)
return cls(**kwargs)
# ---------------- Example usage ---------------- #
if __name__ == "__main__":
lin_sched = make_kappa_scheduler("LinearKappaScheduler")
print("Linear κ(0.5):", lin_sched.kappa(0.5))
print("Linear w(0.5):", lin_sched.weight(0.5))
print("Linear κ([.25,.5,.75]):", lin_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
print("==========================================")
cos_sched = make_kappa_scheduler("CosineKappaScheduler")
print("Cosine κ(0.5):", cos_sched.kappa(0.5))
print("Cosine w(0.5):", cos_sched.weight(0.5))
print("Cosine κ([.25,.5,.75]):", cos_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))

View File

@ -0,0 +1 @@
from dllm.core.trainers.mdlm import MDLMTrainer

View File

@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from typing import Any
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
class MDLMTrainer(transformers.Trainer):
"""
Masked Diffusion Language Model Trainer.
"""
def __init__(
self,
*args,
scheduler: BaseAlphaScheduler | None = None,
time_epsilon: float = 1e-3,
loss_weight_type: str = "scheduler", # "ones"
**kwargs,
):
super().__init__(*args, **kwargs)
self.scheduler = scheduler or LinearAlphaScheduler()
if not (0.0 < time_epsilon < 1.0):
raise ValueError("time_epsilon must be in (0, 1)")
self.time_epsilon = time_epsilon
self.loss_weight_type = loss_weight_type
def _preprocess_inputs(self, inputs):
pass
def _postprocess_outputs(self, outputs):
pass
def _compute_loss_weights(
self,
t: torch.Tensor,
inputs: dict[str, Any],
*args,
**kwargs,
) -> torch.Tensor:
"""Compute loss weights given timestep t and other arguments."""
b, l = inputs["input_ids"].shape
if self.loss_weight_type == "scheduler":
loss_weights = self.scheduler.weight(t).unsqueeze(1).repeat(1, l)  # (b, l)
elif self.loss_weight_type == "ones":
loss_weights = torch.ones_like(inputs["input_ids"])
else:
raise NotImplementedError
return loss_weights
@torch.no_grad()
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
if prediction_loss_only:
return (loss.detach(), None, None)
logits = getattr(outputs, "logits", outputs)
if isinstance(logits, torch.Tensor):
logits = logits.detach().contiguous()
labels = inputs.get("labels")
if isinstance(labels, torch.Tensor):
labels = labels.detach().contiguous()
return (loss.detach(), logits, labels)
def compute_loss(
self,
model: transformers.PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs: bool = False,
**kwargs,
):
assert self.processing_class.padding_side == "right"
self._preprocess_inputs(inputs)
input_ids, labels, attention_mask = (
inputs["input_ids"],
inputs["labels"],
inputs.get("attention_mask", None),
)
b, l = input_ids.shape
# === 1. Sample diffusion timesteps ===
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
# The scheduler defines α(t), the probability that a token remains unmasked; we convert it to a masking probability p_mask = 1 - α(t).
t = self.time_epsilon + (1 - self.time_epsilon) * torch.rand(
b, device=input_ids.device
)
p_mask = 1 - self.scheduler(t).unsqueeze(1).expand(b, l)
# === 2. Apply stochastic masking ===
# Tokens are masked independently according to p_mask(t).
# Positions with label = -100 are excluded (ignored in loss).
masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
labels != -100
)
# Replace masked tokens with the special [MASK] token.
noised_input_ids = torch.where(
masked_indices, self.processing_class.mask_token_id, input_ids
)
# === 3. Forward pass through the model ===
# The model predicts clean tokens given noised inputs.
outputs = model(input_ids=noised_input_ids, attention_mask=attention_mask)
self._postprocess_outputs(outputs)
logits = outputs.logits
# === 4. Handle degenerate cases (no tokens masked) ===
# If no positions were masked, return a zero loss to keep gradients valid.
# This step is necessary for DeepSpeed ZeRO-{2,3}.
if not masked_indices.any():
return (
(logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
)
# === 5. Compute per-token loss weights ===
# Depending on the configuration, weights may depend on timestep t
# (e.g., scheduler-based) or be uniform (ones).
loss_weights = self._compute_loss_weights(
t=t, inputs=inputs, masked_indices=masked_indices
)
# === 6. Compute weighted cross-entropy ===
# Only masked tokens contribute to the loss.
assert (input_ids[masked_indices] == labels[masked_indices]).all()
token_loss = F.cross_entropy(
logits[masked_indices], input_ids[masked_indices], reduction="none"
)
token_loss = token_loss * loss_weights[masked_indices]
# === 7. Normalize loss per effective token length ===
# Normalize each sequence's contribution by its number of valid tokens,
# then average over the batch for stability across variable-length inputs.
effective_lengths = torch.sum(labels != -100, dim=1, keepdim=True).expand(b, l)
loss = torch.sum(token_loss / effective_lengths[masked_indices]) / b
# === 8. Return final loss (and optionally model outputs) ===
return (loss, outputs) if return_outputs else loss
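# --- Illustrative sketch (not part of the original file) ---------------------
# The forward-noising step from compute_loss (steps 1-2), reproduced standalone
# with toy shapes: sample a timestep t, turn α(t) into a mask probability, and
# mask tokens independently.
if __name__ == "__main__":
    scheduler = LinearAlphaScheduler()
    b, l, mask_token_id = 2, 8, 0  # toy batch; the mask id is arbitrary here
    input_ids = torch.randint(1, 100, (b, l))
    t = 1e-3 + (1 - 1e-3) * torch.rand(b)  # t ∈ [ε, 1)
    p_mask = 1 - scheduler(t).unsqueeze(1).expand(b, l)
    masked_indices = torch.rand(b, l) < p_mask
    noised_input_ids = torch.where(masked_indices, mask_token_id, input_ids)
    print(noised_input_ids)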

View File

@ -0,0 +1 @@
from .utils import load_sft_dataset, load_pt_dataset

63
dllm/dllm/data/alpaca.py Normal file
View File

@ -0,0 +1,63 @@
from typing import Optional
from datasets import load_dataset, DatasetDict
def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
"""Construct a clean text prompt from Alpaca fields.
We intentionally *do not* include Anthropic-style role tags (e.g., "Human:", "Assistant:")
in the returned prompt, to mirror the return shape of `load_hh_rlhf_dataset` which removes
those tags from the prompt it returns.
"""
instruction = (instruction or "").strip()
input_text = (input_text or "").strip()
if input_text:
# Keep instruction and input separated by a blank line for readability.
return f"{instruction}\n\n{input_text}"
else:
return instruction
def load_dataset_alpaca(dataset_name_or_path: str) -> DatasetDict:
"""Load the Alpaca dataset (tatsu-lab/alpaca) and expose unified fields.
Returns a `DatasetDict` where each split contains:
- prompt: Combined instruction (+ optional input), with clean formatting
- response: The target output (model answer)
Parameters
----------
dataset_name_or_path : str
Usually "tatsu-lab/alpaca" or a local path.
"""
dataset = load_dataset(dataset_name_or_path)
def map_fn(example):
prompt = _build_alpaca_prompt(
example.get("instruction", ""), example.get("input", "")
)
response = (example.get("output", "") or "").strip()
return {
"messages": [
{"role": "user", "content": prompt},
{"role": "assistant", "content": response},
]
}
dataset = dataset.map(
map_fn, remove_columns=dataset["train"].column_names, num_proc=4
)
# make train test split
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
return dataset
if __name__ == "__main__":
from dllm.utils import resolve_with_base_env
dataset_name_or_path = resolve_with_base_env(
"tatsu-lab/alpaca", "BASE_DATASETS_DIR"
)
dataset = load_dataset_alpaca(dataset_name_or_path)
breakpoint()

133
dllm/dllm/data/opc.py Normal file
View File

@ -0,0 +1,133 @@
from typing import Optional, Text, List, Dict
from datasets import (
load_dataset,
get_dataset_config_names,
concatenate_datasets,
DatasetDict,
Dataset,
IterableDatasetDict,
)
from dllm.data.utils import (
_merge_datasetdicts,
_merge_iterabledatasetdicts,
_ensure_datasetdict,
_ensure_iterabledatasetdict,
)
def load_dataset_opc_sft(
dataset_name_or_path: str, name: str | None = None, lang: str | None = None
) -> DatasetDict:
"""
Load OpenCoder OPC SFT dataset(s) and produce a DatasetDict with a train/test split.
- If `name` is provided: load that specific config.
- If `name` is None: load *all* available configs and concatenate them.
"""
def _map_to_messages(ds: Dataset) -> Dataset:
def map_fn(example):
return {
"messages": [
{"role": "user", "content": example["instruction"]},
{"role": "assistant", "content": example["output"]},
]
}
# Remove all original columns after mapping
remove_cols = ds.column_names
return ds.map(map_fn, remove_columns=remove_cols, num_proc=4)
def _load_one_config(dataset_name_or_path: str, cfg_name: str) -> Dataset:
ds = load_dataset(dataset_name_or_path, cfg_name, split="train")
return _map_to_messages(ds)
if name is not None:
train_ds = _load_one_config(dataset_name_or_path, name)
else:
# Enumerate and load all configs, then concatenate
cfgs: list[str] = get_dataset_config_names(dataset_name_or_path)
if not cfgs:
raise ValueError(f"No configs found for dataset: {dataset_name_or_path}")
parts = [_load_one_config(dataset_name_or_path, c) for c in cfgs]
train_ds = concatenate_datasets(parts)
# Final split
ds_dict = train_ds.train_test_split(test_size=0.1, seed=42)
if lang is not None:
ds_dict = ds_dict.filter(lambda row: lang in row["messages"][1]["content"])
return DatasetDict(ds_dict)
def load_dataset_opc_annealing(
dataset_name_or_path: str,
name: str | None = None,
lang: str | None = None,
streaming: bool = True,
) -> DatasetDict:
def _load_one_config(_name):
ds = load_dataset(
dataset_name_or_path, _name, split="train", streaming=streaming
)
if lang:
if _name in ["synthetic_code_snippet", "algorithmic_corpus"]:
ds = ds.filter(lambda row: row["lang"] == lang)
elif _name in ["synthetic_qa"]:
ds = ds.filter(lambda row: row["program_lang"] == lang)
else:
raise NotImplementedError
# return IterableDatasetDict({"train": ds})
if streaming:
return _ensure_iterabledatasetdict(ds)
return _ensure_datasetdict(ds)
if name is not None:
return _load_one_config(name)
if streaming:
parts = [
_load_one_config(name)
for name in get_dataset_config_names(dataset_name_or_path)
]
merged = parts[0]
for p in parts[1:]:
merged = _merge_iterabledatasetdicts(merged, p)
return merged
else:
parts = [
_load_one_config(name)
for name in get_dataset_config_names(dataset_name_or_path)
]
if len(parts) == 1:
return _ensure_datasetdict(parts[0])
merged = parts[0]
for p in parts[1:]:
merged = _merge_datasetdicts(merged, p)
return _ensure_datasetdict(merged)
if __name__ == "__main__":
from dllm.utils import resolve_with_base_env
dataset_name_or_path = resolve_with_base_env(
"OpenCoder-LLM/opc-sft-stage1", "BASE_DATASETS_DIR"
)
# If you want a specific config:
dataset_edu = load_dataset_opc_sft(dataset_name_or_path, "realuser_instruct")
# Otherwise, all configs concatenated:
dataset_all = load_dataset_opc_sft(dataset_name_or_path, None)
dataset_all_python = load_dataset_opc_sft(dataset_name_or_path, None, "python")
breakpoint()
# streaming = True
# dataset_name_or_path = resolve_with_base_env(
# "OpenCoder-LLM/opc-annealing-corpus", "BASE_DATASETS_DIR"
# )
# # If you want a specific config:
# dataset_alg_all = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus")
# dataset_alg_python = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus", "python")
# # Otherwise, all configs concatenated:
# dataset_all_python = load_dataset_opc_annealing(dataset_name_or_path, None, "python")
# dataset_all_all = load_dataset_opc_annealing(dataset_name_or_path, None)
# breakpoint()

108
dllm/dllm/data/ultrachat.py Normal file
View File

@ -0,0 +1,108 @@
from typing import Optional, List, Dict
from datasets import load_dataset, DatasetDict
def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None:
"""
Given a list of chat messages like:
[{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."},
...]
return a dict with the first user/assistant exchange as:
{"prompt": <user content>, "response": <assistant content>}
If no valid first turn exists, return None.
"""
if not isinstance(messages, list) or len(messages) < 2:
return None
# Find the first user message and the first assistant *after* that user msg
# (Most entries start as [user, assistant, ...], but we guard anyway.)
user_idx = None
for i, m in enumerate(messages):
if (
isinstance(m, dict)
and m.get("role") == "user"
and isinstance(m.get("content"), str)
):
user_idx = i
break
if user_idx is None:
return None
# Find first assistant after that user
for j in range(user_idx + 1, len(messages)):
m = messages[j]
if (
isinstance(m, dict)
and m.get("role") == "assistant"
and isinstance(m.get("content"), str)
):
user_text = messages[user_idx]["content"].strip()
assistant_text = m["content"].strip()
if user_text and assistant_text:
return {"prompt": user_text, "response": assistant_text}
return None
return None
def load_dataset_ultrachat(dataset_name_or_path: str) -> DatasetDict:
"""
Load the UltraChat 200k dataset (HuggingFaceH4/ultrachat_200k) and keep only the *first turn*
(first user message and the assistant reply).
Returns a `DatasetDict` where each split contains:
- prompt: first user message content
- response: first assistant reply content
Parameters
----------
dataset_name_or_path : str
Typically "HuggingFaceH4/ultrachat_200k" or a local path.
"""
dataset = load_dataset(dataset_name_or_path)
# We only keep examples that have a valid first (user, assistant) turn.
def has_first_turn(example):
messages = example.get("messages")
return _extract_first_turn(messages) is not None
dataset = dataset.filter(has_first_turn, num_proc=4)
def map_fn(example):
first = _extract_first_turn(example["messages"])
# Fallbacks for robustness (shouldn't be hit after filter, but just in case)
if first is None:
first = {"prompt": (example.get("prompt") or "").strip(), "response": ""}
return {"prompt": first["prompt"], "response": first["response"]}
# Remove original columns for a clean schema (infer from any available split)
cols_to_remove = None
for split_name in dataset.keys():
cols_to_remove = dataset[split_name].column_names
break
dataset = dataset.map(map_fn, remove_columns=cols_to_remove, num_proc=4)
dataset = DatasetDict(
{
new: dataset[old]
for old, new in {
"train_sft": "train",
"test_sft": "test",
}.items()
if old in dataset
}
)
return dataset
if __name__ == "__main__":
# Mirrors the style from your previous loaders: resolve path via env helper if available.
from dllm.utils import resolve_with_base_env
dataset_name_or_path = resolve_with_base_env(
"HuggingFaceH4/ultrachat_200k", "BASE_DATASETS_DIR"
)
dataset = load_dataset_ultrachat(dataset_name_or_path)
breakpoint()

377
dllm/dllm/data/utils.py Normal file
View File

@ -0,0 +1,377 @@
from datasets import (
Dataset,
DatasetDict,
IterableDatasetDict,
IterableDataset,
load_dataset,
load_from_disk,
)
from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger
logger = get_default_logger(__name__)
def load_sft_dataset(
dataset_args: str, load_preprocessed_data: bool = False
) -> DatasetDict:
"""
Examples of dataset_args:
- "tatsu-lab/alpaca"
- "OpenCoder-LLM/opc-sft-stage2[name:educational_instruct]"
- "tatsu-lab/alpaca[train:5000]"
- "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]"
"""
from dllm.data.alpaca import load_dataset_alpaca
from dllm.data.opc import load_dataset_opc_sft
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
all_parts = []
for raw in specs:
dataset_name_or_path, kvs = parse_spec(raw)
dataset_name_or_path = resolve_with_base_env(
dataset_name_or_path, "BASE_DATASETS_DIR"
)
if load_preprocessed_data:
logger.info("Load preprocessed data from disk.")
ds = load_from_disk(dataset_name_or_path)
# Implement your customized dataset here
elif _match(dataset_name_or_path, "tatsu-lab/alpaca"):
ds = load_dataset_alpaca(dataset_name_or_path)
elif _match(dataset_name_or_path, "allenai/tulu-3-sft-mixture"):
ds = load_dataset(dataset_name_or_path)
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
elif _match(dataset_name_or_path, "HuggingFaceTB/smoltalk"):
name = kvs.pop("name", "all")
ds = load_dataset(dataset_name_or_path, name=name)
elif _match(dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage1") or _match(
dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage2"
):
name = kvs.pop("name", None)
lang = kvs.pop("lang", None)
ds = load_dataset_opc_sft(dataset_name_or_path, name=name, lang=lang)
elif _match(dataset_name_or_path, "HuggingFaceH4/ultrachat_200k"):
ds = load_dataset(dataset_name_or_path)
ds = DatasetDict({"train": ds["train_sft"], "test": ds["test_sft"]})
else:
ds = load_dataset(dataset_name_or_path)
# Normalize to DatasetDict and apply per-split limits
ds = _ensure_datasetdict(ds)
ds = _truncate_dataset(ds, kvs)
all_parts.append(ds)
# If only one part, return as DatasetDict
if len(all_parts) == 1:
return _ensure_datasetdict(all_parts[0])
# Merge all parts into a single DatasetDict
merged = all_parts[0]
for part in all_parts[1:]:
merged = _merge_datasetdicts(merged, part)
return _ensure_datasetdict(merged)
def load_pt_dataset(
dataset_args: str, streaming: bool = True, load_preprocessed_data: bool = False
) -> DatasetDict | IterableDatasetDict:
"""
Examples of dataset_args:
- "mlfoundations/dclm-baseline-1.0"
- "OpenCoder-LLM/opc-fineweb-code-corpus"
- "OpenCoder-LLM/opc-fineweb-math-corpus"
- "OpenCoder-LLM/opc-annealing-corpus[lang:python]"
- "wikitext[name:wikitext-103-v1}]"
"""
from dllm.data.opc import load_dataset_opc_annealing
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
if not specs:
raise ValueError("Empty dataset_args for load_pt_dataset.")
# ---------- Shared loader (only differs by streaming flag) ----------
def _load_base_dataset(
raw: str, *, streaming: bool
) -> tuple[DatasetDict | IterableDatasetDict, dict, str]:
"""
Returns: (base, kvs, dataset_name_or_path)
- Pops 'name' from kvs when applicable (e.g., wikitext).
- Applies identical matching logic for both streaming/non-streaming.
"""
dataset_name_or_path, kvs = parse_spec(raw)
dataset_name_or_path = resolve_with_base_env(
dataset_name_or_path, "BASE_DATASETS_DIR"
)
name = kvs.pop("name", None)
if load_preprocessed_data:
base = load_from_disk(dataset_name_or_path)
elif _match(dataset_name_or_path, ["OpenCoder-LLM/opc-annealing-corpus"]):
lang = kvs.pop("lang", None)
base = load_dataset_opc_annealing(
dataset_name_or_path, name=name, lang=lang, streaming=streaming
)
else:
base = load_dataset(dataset_name_or_path, name=name, streaming=streaming)
return base, kvs, dataset_name_or_path
# ---------- Streaming path ----------
def _load_one_streaming_spec(raw: str) -> IterableDatasetDict:
base, kvs, dataset_name_or_path = _load_base_dataset(raw, streaming=True)
split_names = list(base.keys())
single_split = len(split_names) == 1
single_split_name = split_names[0] if single_split else None
n_train = kvs.get("train")
n_test = kvs.get("test")
if (n_train is not None) or (n_test is not None):
if (n_train is not None) and (n_test is not None):
if single_split:
stream = base[single_split_name]
head = stream.take(n_train + n_test)
test = head.take(n_test)
train = head.skip(n_test).take(n_train)
return IterableDatasetDict({"train": train, "test": test})
else:
if "train" not in base or "test" not in base:
raise ValueError(
f"{dataset_name_or_path}: require 'train' and 'test' splits for train+test limits."
)
train = base["train"].take(n_train)
test = base["test"].take(n_test)
return IterableDatasetDict({"train": train, "test": test})
if n_train is not None:
if single_split:
train = base[single_split_name].take(n_train)
else:
if "train" not in base:
raise ValueError(
f"{dataset_name_or_path}: missing 'train' split for train limit."
)
train = base["train"].take(n_train)
return IterableDatasetDict({"train": train})
if n_test is not None:
if single_split:
test = base[single_split_name].take(n_test)
else:
if "test" not in base:
raise ValueError(
f"{dataset_name_or_path}: missing 'test' split for test limit."
)
test = base["test"].take(n_test)
return IterableDatasetDict({"test": test})
return base # already an IterableDatasetDict
# ---------- Non-streaming path (mirror load_sft_dataset; NO shuffle) ----------
def _load_one_nonstreaming_spec(raw: str) -> DatasetDict:
base, kvs, _ = _load_base_dataset(raw, streaming=False)
ds = _ensure_datasetdict(base) # normalize
ds = _truncate_dataset(ds, kvs) # apply limits (train/test/...)
return ds
# ---------- Load & Merge ----------
if streaming:
logger.info("Loading dataset in streaming mode.")
parts = [_load_one_streaming_spec(raw) for raw in specs]
merged = parts[0]
for p in parts[1:]:
merged = _merge_iterabledatasetdicts(merged, p)
# repeat streaming dataset infinitely
merged = IterableDatasetDict(
{k: (v.repeat(None) if k == "train" else v) for k, v in merged.items()}
)
return merged
else:
logger.info("Loading dataset in non-streaming mode.")
parts = [_load_one_nonstreaming_spec(raw) for raw in specs]
if len(parts) == 1:
return _ensure_datasetdict(parts[0])
merged = parts[0]
for p in parts[1:]:
merged = _merge_datasetdicts(merged, p)
return _ensure_datasetdict(merged)
def _truncate_split(split_data, n: int):
if n is None:
return split_data
try:
if hasattr(split_data, "select"):
# Hugging Face Dataset path
total = getattr(split_data, "num_rows", None)
if total is None:
# some Dataset types expose len(...)
total = len(split_data)
idx = list(range(min(n, total)))
return split_data.select(idx)
except Exception:
pass
try:
return split_data[:n]
except Exception:
# Last resort: iterate
return type(split_data)(item for i, item in enumerate(split_data) if i < n)
def _truncate_dataset(ds, limits: dict):
"""
Ensure and return a DatasetDict, truncating splits mentioned in `limits`.
"""
ds = _ensure_datasetdict(ds) # normalize first
out = {}
for split, data in ds.items():
n = limits.get(split, None)
out[split] = _truncate_split(data, n) if n is not None else data
return DatasetDict(out)
def _concat_splits(a, b):
"""
Concatenate two split objects (prefer 🤗 datasets).
"""
if a is b:
return a
if a is None:
return b
if b is None:
return a
# Prefer datasets' concatenate_datasets when both are Datasets
try:
from datasets import concatenate_datasets
if isinstance(a, Dataset) and isinstance(b, Dataset):
return concatenate_datasets([a, b])
except Exception:
pass
# Fallbacks
try:
return a + b
except Exception:
pass
try:
return type(a)(list(a) + list(b))
except Exception:
pass
raise TypeError(
f"Cannot concatenate split objects of types {type(a)} and {type(b)}"
)
def _merge_datasetdicts(d1, d2):
"""
Merge two DatasetDict-like mappings by concatenating splits present in either.
Always returns a DatasetDict.
"""
d1 = _ensure_datasetdict(d1)
d2 = _ensure_datasetdict(d2)
all_splits = set(d1.keys()) | set(d2.keys())
out = {}
for split in all_splits:
a = d1.get(split, None)
b = d2.get(split, None)
if a is None:
out[split] = b
elif b is None:
out[split] = a
else:
out[split] = _concat_splits(a, b)
return DatasetDict(out)
def _ensure_datasetdict(ds):
"""
Normalize various loader outputs into a DatasetDict.
- If loader returns a DatasetDict, return as is.
- If loader returns a mapping (e.g., dict of splits), wrap into DatasetDict.
- If loader returns a single Dataset/list/etc., assume it's 'train'.
"""
if isinstance(ds, DatasetDict):
return ds
if isinstance(ds, dict):
# Try to convert each split value to a Dataset if they aren't already.
# If they are already Datasets, DatasetDict will accept them directly.
return DatasetDict(ds)
# Single split -> assume train
return DatasetDict({"train": ds})
def _match(name: str, needle) -> bool:
"""
Returns True if `name` matches any of the provided needles.
Accepts a single string or a list/tuple of strings.
Match condition: name endswith(needle) or needle in name.
"""
if isinstance(needle, (list, tuple)):
return any(name.endswith(n) or n in name for n in needle)
return name.endswith(needle) or needle in name
def _concat_iterable_datasets(parts: list[IterableDataset]) -> IterableDataset:
"""
Concatenate IterableDatasets sequentially without materialization.
Preserves streaming nature; supports downstream .take()/.skip()/.shuffle().
"""
if not parts:
raise ValueError("No IterableDatasets to concatenate.")
# Try to reuse features from the first dataset when available
features = getattr(parts[0], "features", None)
def _gen():
for ds in parts:
yield from ds
return IterableDataset.from_generator(_gen, features=features)
def _ensure_iterabledatasetdict(obj) -> IterableDatasetDict:
if isinstance(obj, IterableDatasetDict):
return obj
if isinstance(obj, dict):
return IterableDatasetDict(obj)
# Single stream -> assume train
return IterableDatasetDict({"train": obj})
def _merge_iterabledatasetdicts(
d1: IterableDatasetDict, d2: IterableDatasetDict
) -> IterableDatasetDict:
"""
Merge by concatenating any overlapping splits (streaming-safe).
"""
d1 = _ensure_iterabledatasetdict(d1)
d2 = _ensure_iterabledatasetdict(d2)
all_splits = set(d1.keys()) | set(d2.keys())
out = {}
for split in all_splits:
a = d1.get(split, None)
b = d2.get(split, None)
if a is None:
out[split] = b
elif b is None:
out[split] = a
else:
out[split] = _concat_iterable_datasets([a, b])
return IterableDatasetDict(out)
def _truncate_stream(ds: IterableDataset, n: int | None) -> IterableDataset:
if n is None:
return ds
return ds.take(n)
if __name__ == "__main__":
breakpoint()
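As a quick reference, here is a minimal usage sketch for the two loaders above; the spec strings follow the docstrings of load_sft_dataset / load_pt_dataset, and the per-split limits are illustrative.
# Illustrative sketch only -- not part of the file above.
from dllm.data.utils import load_sft_dataset, load_pt_dataset
# SFT: mix 5k Alpaca rows with 5k UltraChat rows; overlapping splits are concatenated.
sft = load_sft_dataset("tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]")
print(sft)
# Pretraining: stream wikitext-103; the streaming train split is repeated indefinitely,
# so bound any inspection with .take().
pt = load_pt_dataset("wikitext[name:wikitext-103-v1]", streaming=True)
for example in pt["train"].take(2): print(example)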

View File

@ -0,0 +1 @@
from . import llada, dream, rnd, editflow

View File

@ -0,0 +1,362 @@
"""
accelerate launch \
--num_processes 2 \
dllm/pipelines/bert/eval.py \
--tasks gsm8k \
--batch_size 1 \
--model bert \
--device cuda \
--num_fewshot 8 \
--model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
"""
from types import SimpleNamespace
from dataclasses import dataclass
import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
import dllm
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
@dataclass
class BERTEvalConfig(LLaDAGeneratorConfig):
max_new_tokens: int = 128
max_length: int = 512
steps: int = 128
block_length: int = 128
pretrained: str = ""
dtype: str | torch.dtype = "auto"
batch_size: int = 32
mc_num: int = 128
is_check_greedy: bool = True
device: str = "cuda"
@register_model("bert")
class BERTEvalHarness(LM):
def __init__(
self,
config: BERTEvalConfig | None = None,
**kwargs,
):
super().__init__()
# Initialize config if not provided
if config is None:
config = BERTEvalConfig()
# Pull args from config, allow kwargs to override
pretrained = kwargs.get("pretrained", config.pretrained)
dtype = kwargs.get("dtype", config.dtype)
batch_size = kwargs.get("batch_size", config.batch_size)
mc_num = kwargs.get("mc_num", config.mc_num)
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
device = kwargs.get("device", config.device)
cfg = kwargs.get("cfg", config.cfg_scale)
steps = kwargs.get("steps", config.steps)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
block_length = kwargs.get("block_length", config.block_length)
max_length = kwargs.get("max_length", config.max_length)
remasking = kwargs.get("remasking", config.remasking)
accelerator = accelerate.Accelerator()
# Get GLOBAL rank from torch.distributed (not accelerator)
if torch.distributed.is_initialized():
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
self._world_size = (
torch.distributed.get_world_size()
) # ← GLOBAL world size (16)
else:
self._rank = 0
self._world_size = 1
# Use accelerator for device placement
self.model = dllm.utils.get_model(
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
)
self.model.eval()
if accelerator.num_processes > 1:
# Let accelerator handle device placement
self.model = accelerator.prepare(self.model)
self.device = (
accelerator.device
) # ← Accelerator figures out local device correctly
self.accelerator = accelerator
else:
# Single GPU
self.model = self.model.to(device)
self.device = torch.device(device)
self.accelerator = None
self.tokenizer = dllm.utils.get_tokenizer(
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
)
# generation params
self.mask_id = self.tokenizer.mask_token_id
self.batch_size = int(batch_size)
self.max_length = int(max_length)
self.max_new_tokens = int(max_new_tokens)
self.block_length = int(block_length)
self.steps = int(steps)
self.cfg = float(cfg)
self.remasking = remasking
self.is_check_greedy = is_check_greedy
# loglikelihood params
self.mc_num = int(mc_num)
assert mc_num % self.batch_size == 0
self.sampling_eps = 0.0
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def _forward_process(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
b, l = batch.shape
target_len = (l - prompt_index.sum()).item()
k = torch.randint(1, target_len + 1, (), device=batch.device)
x = torch.round(
torch.linspace(
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
)
).long()
x = ((x - 1) % target_len) + 1
assert x.min() >= 1 and x.max() <= target_len
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
is_mask = indices < x.unsqueeze(1)
for i in range(b):
is_mask[i] = is_mask[i][torch.randperm(target_len)]
is_mask = torch.cat(
(
torch.zeros(
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
),
is_mask,
),
dim=1,
)
noisy_batch = torch.where(is_mask, self.mask_id, batch)
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
@torch.no_grad()
def get_logits(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
if self.cfg > 0.0:
assert len(prompt_index) == batch.shape[1]
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
un_batch = batch.clone()
un_batch[prompt_index] = self.mask_id
batch = torch.cat([batch, un_batch])
logits = self.model(batch).logits
if self.cfg > 0.0:
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
return logits[:, : batch.shape[1]]
@torch.no_grad()
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
seq = torch.concatenate([prefix, target])[None, :]
seq = seq.repeat((self.batch_size, 1)).to(self.device)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
loss_acc = []
for _ in range(self.mc_num // self.batch_size):
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
mask_indices = perturbed_seq == self.mask_id
logits = self.get_logits(perturbed_seq, prompt_index)
loss = (
F.cross_entropy(
logits[mask_indices], seq[mask_indices], reduction="none"
)
/ p_mask[mask_indices]
)
loss = loss.sum() / self.batch_size
loss_acc.append(loss.item())
return -sum(loss_acc) / len(loss_acc)
@torch.no_grad()
def suffix_greedy_prediction(
self, prefix: torch.Tensor, target: torch.Tensor
) -> bool:
if not self.is_check_greedy:
return False
seq = torch.full(
(1, len(prefix) + len(target)), self.mask_id, device=self.device
)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
prefix, target = prefix.to(self.device), target.to(self.device)
seq[0, : len(prefix)] = prefix
for i in range(len(target)):
mask_index = seq == self.mask_id
logits = self.get_logits(seq, prompt_index)[mask_index]
x0 = torch.argmax(logits, dim=-1)
p = torch.softmax(logits.to(torch.float32), dim=-1)
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
dim=-1
)
_, index = torch.sort(confidence, descending=True)
x0[index[1:]] = self.mask_id
seq[mask_index] = x0.clone()
correct = target == seq[0, len(prefix) :]
correct = torch.all(correct)
return correct
def _encode_pair(
self, context: str, continuation: str
) -> tuple[torch.Tensor, torch.Tensor]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tokenizer(context + continuation)["input_ids"]
context_enc = self.tokenizer(context)["input_ids"]
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
assert max(prompt_len) <= 4096
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]
ll = self.get_loglikelihood(prefix, target)
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
torch.cuda.empty_cache()
return out
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
raise NotImplementedError
def generate_until(self, requests: list[Instance]):
def _tokenize(e):
return {
"question": self.tokenizer(e["question"])["input_ids"],
"question_text": e["question"],
"until": e["until"],
}
ds = [
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
out = []
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
for elem in tqdm(ds, desc="Generating..."):
prompt = [elem["question"][1:-1].to(self.device)]
stop_tokens = elem["until"]
generated_ids = generator.generate(
inputs=prompt,
steps=self.steps,
max_new_tokens=self.max_new_tokens,
block_length=self.block_length,
temperature=0.0,
cfg_scale=self.cfg,
remasking=self.remasking,
)
generated_answer = self.tokenizer.decode(
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
)
for stop_seq in stop_tokens:
if stop_seq in generated_answer:
generated_answer = generated_answer.split(stop_seq)[0]
# remove special tokens
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
generated_answer = self.tokenizer.decode(
generated_answer_ids, skip_special_tokens=True
)
out.append(generated_answer)
if self.accelerator is not None:
self.accelerator.wait_for_everyone()
return out
if __name__ == "__main__":
cli_evaluate()
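Besides the accelerate command in the module docstring, the registered "bert" model can also be driven programmatically. A sketch, assuming the pinned lm-evaluation-harness fork keeps the upstream simple_evaluate API; the module path follows the script path shown in the docstring.
# Illustrative sketch only -- not part of the file above.
import lm_eval
import dllm.pipelines.bert.eval  # noqa: F401  (importing runs @register_model("bert"))
results = lm_eval.simple_evaluate(
    model="bert",
    model_args="pretrained=dllm-collection/ModernBERT-base-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0",
    tasks=["gsm8k"],
    num_fewshot=8,
    batch_size=1,
)
print(results["results"]["gsm8k"])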

View File

@ -0,0 +1,6 @@
from . import generator, models, trainer, utils
from .models.modeling_dream import DreamModel
from .models.configuration_dream import DreamConfig
from .models.tokenization_dream import DreamTokenizer
from .generator import DreamGeneratorConfig, DreamGenerator
from .trainer import DreamTrainer

View File

@ -0,0 +1,533 @@
"""
accelerate launch \
--num_processes 2 \
dllm/pipelines/dream/eval.py \
--tasks gsm8k \
--batch_size 1 \
--model dream \
--device cuda \
--num_fewshot 0 \
--model_args "pretrained=Dream-org/Dream-v0-Base-7B,mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
"""
import logging
from types import SimpleNamespace
from dataclasses import dataclass
import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
import dllm
from dllm.pipelines.dream import DreamGenerator, DreamGeneratorConfig
eval_logger = logging.getLogger(__name__)
@dataclass
class DreamEvalConfig(DreamGeneratorConfig):
top_p: float | None = None
top_k: float | None = None
max_new_tokens: int = 128
max_length: int = 2048
steps: int = 128
temperature: float = 0.0
alg: str = "entropy"
pretrained: str = ""
batch_size: int = 1
device: str = "cuda"
dtype: str | torch.dtype = "auto"
add_bos_token: bool = False
nll_type: str = "mc"
log_type: str = "ftb"
mc_num: int = 128
classifier_free_guidance: float = 1.0
sampling_eps: float = 1e-3
escape_until: bool = False
@register_model("dream")
class DreamEvalHarness(LM):
def __init__(
self,
config: DreamEvalConfig | None = None,
**kwargs,
) -> None:
super().__init__()
# Initialize config if not provided
if config is None:
config = DreamEvalConfig()
# Pull args from config, allow kwargs to override
pretrained = kwargs.get("pretrained", config.pretrained)
batch_size = kwargs.get("batch_size", config.batch_size)
device = kwargs.get("device", config.device)
dtype = kwargs.get("dtype", config.dtype)
max_length = kwargs.get("max_length", config.max_length)
add_bos_token = kwargs.get("add_bos_token", config.add_bos_token)
nll_type = kwargs.get("nll_type", config.nll_type)
log_type = kwargs.get("log_type", config.log_type)
mc_num = kwargs.get("mc_num", config.mc_num)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
classifier_free_guidance = kwargs.get(
"classifier_free_guidance", config.classifier_free_guidance
)
sampling_eps = kwargs.get("sampling_eps", config.sampling_eps)
steps = kwargs.get("steps", config.steps)
temperature = kwargs.get("temperature", config.temperature)
top_p = kwargs.get("top_p", config.top_p)
top_k = kwargs.get("top_k", config.top_k)
alg = kwargs.get("alg", config.alg)
alg_temp = kwargs.get("alg_temp", config.alg_temp)
escape_until = kwargs.get("escape_until", config.escape_until)
accelerator = accelerate.Accelerator()
# Get GLOBAL rank from torch.distributed (not accelerator)
if torch.distributed.is_initialized():
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
self._world_size = (
torch.distributed.get_world_size()
) # ← GLOBAL world size (16)
else:
self._rank = 0
self._world_size = 1
# Use accelerator for device placement
self.model = dllm.utils.get_model(
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
)
self.model.eval()
if accelerator.num_processes > 1:
# Let accelerator handle device placement
self.model = accelerator.prepare(self.model)
self.device = (
accelerator.device
) # ← Accelerator figures out local device correctly
self.accelerator = accelerator
else:
# Single GPU
self.model = self.model.to(device)
self.device = torch.device(device)
self.accelerator = None
self.tokenizer = dllm.utils.get_tokenizer(
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
)
# generation params
self.mask_id = self.tokenizer.mask_token_id
self.max_length = max_length
self.add_bos_token = add_bos_token
self.batch_size = int(batch_size)
self.max_new_tokens = max_new_tokens
self.steps = steps
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.alg = alg
self.alg_temp = alg_temp
self.escape_until = escape_until
# loglikelihood params
self.nll_type = nll_type
self.log_type = log_type
self.mc_num = mc_num
self.classifier_free_guidance = classifier_free_guidance
self.sampling_eps = sampling_eps
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_decode(
self, tokens: torch.Tensor | list[int], skip_special_tokens: bool = True
) -> str:
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def tok_encode(self, text: str, add_special_tokens: bool = True) -> torch.Tensor:
return self.tokenizer(
text, return_tensors="pt", add_special_tokens=add_special_tokens
).input_ids
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
def generate_until(
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests",
)
generator = DreamGenerator(model=self.model, tokenizer=self.tokenizer)
for batch_idx in range(0, len(requests), self.batch_size):
batch_requests = requests[batch_idx : batch_idx + self.batch_size]
contexts, gen_args = zip(*[req.arguments for req in batch_requests])
# ====== BEGIN merged _generate_batch logic ======
prompts = list(contexts)
if self.add_bos_token:
prompts = [self.tokenizer.bos_token + p for p in prompts]
# tokenize
prompt_ids = [
self.tokenizer(
p, return_tensors="pt", padding=False
).input_ids.squeeze()
for p in prompts
]
prompt_lens = [len(p_id) for p_id in prompt_ids]
if max(prompt_lens) > self.max_length - self.max_new_tokens:
cutoff_len = self.max_length - self.max_new_tokens
eval_logger.warning(
f"Prompt length {max(prompt_lens)} exceeds {cutoff_len}, cutoff on the left side"
)
# Trim from the left (keep the last cutoff_len tokens) and recompute lengths
prompt_ids = [p_id[-cutoff_len:] for p_id in prompt_ids]
prompt_lens = [len(p_id) for p_id in prompt_ids]
# generation
generation_ids = generator.generate(
max_new_tokens=self.max_new_tokens,
inputs=prompt_ids,
steps=self.steps,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
alg=self.alg,
alg_temp=self.alg_temp,
output_history=False,
return_dict_in_generate=False,
)
# decode and cleanup
cleaned_generation_ids = [
(
seq[seq.ne(self.tokenizer.eos_token_id).float().argmax().long() :]
if (seq != self.tokenizer.eos_token_id).any()
else seq[-1:]
)
for seq in generation_ids
]
truncated_generation_ids = [
seq[prompt_lens[i] :] for i, seq in enumerate(cleaned_generation_ids)
]
responses = [
g.lstrip("<|endoftext|>").split(self.tokenizer.eos_token, 1)[0]
for g in self.tokenizer.batch_decode(truncated_generation_ids)
]
# ====== END merged _generate_batch logic ======
# handle "until" truncation
if not self.escape_until:
for i, r in enumerate(responses):
for s in gen_args[0]["until"]:
r = r.split(s)[0]
responses[i] = r
res.extend(responses)
pbar.update(len(contexts))
return res
def _forward_process(
self, batch: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
b, l = batch.shape
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
indices = torch.arange(b, device=batch.device).float()
t = (u0 + indices / b) % 1
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
p_mask = p_mask[:, None].repeat(1, l)
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
# always unmask bos and eos
mask_indices[:, 0] = False
mask_indices[:, -1] = False
noisy_batch = torch.where(mask_indices, self.mask_id, batch)
return noisy_batch, p_mask
@torch.no_grad()
def get_logits(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
"""
prompt_index : 1D bool tensor, length=batch.shape[1]
"""
if self.classifier_free_guidance > 1.0:
assert len(prompt_index) == batch.shape[1]
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
un_batch = batch.clone()
un_batch[prompt_index] = self.mask_id
batch = torch.cat([batch, un_batch])
input = batch
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
logits = self.model(input).logits
# since bos always unmask, the first logits will not be used
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
if self.classifier_free_guidance > 1.0:
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
return logits[:, : batch.shape[1]]
@torch.no_grad()
def _eval_target_nll_mc(
self, prefix: torch.Tensor | None, target: torch.Tensor
) -> float:
if prefix is None:
seq = target[None, :]
else:
seq = torch.concatenate([prefix, target])[None, :]
seq = seq.repeat((self.batch_size, 1)).to(self.device)
if self.log_type == "ftb":
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
else:
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
loss_acc = []
for _ in range(max(self.mc_num // self.batch_size, 1)):
perturbed_seq = seq.clone()
# eval_logger.info("before noising")
perturbed_seq_, p_mask = self._forward_process(seq)
# eval_logger.info("end noising")
if self.log_type == "ftb":
perturbed_seq[:, -len(target) :] = perturbed_seq_[:, -len(target) :]
elif self.log_type == "btf":
perturbed_seq[:, : len(prefix)] = perturbed_seq_[:, : len(prefix)]
elif self.log_type == "union":
perturbed_seq = perturbed_seq_
else:
raise NotImplementedError(self.log_type)
mask_indices = perturbed_seq == self.mask_id
logits = self.get_logits(perturbed_seq, prompt_index)
loss = (
F.cross_entropy(
logits[mask_indices], seq[mask_indices], reduction="none"
)
/ p_mask[mask_indices]
)
loss = loss.sum() / self.batch_size
loss_acc.append(loss.item())
return sum(loss_acc) / len(loss_acc)
@torch.no_grad()
def _eval_target_nll_ar(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
assert self.log_type in ["ftb", "btf"]
assert self.nll_type in ["ar_ftb", "ar_btf"]
if self.log_type == "ftb":
prompt_index = (
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
< prefix.shape[1]
)
else:
prompt_index = (
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
>= prefix.shape[1]
)
if self.log_type == "ftb":
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
else:
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
mask_index = torch.ones(
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
)
if self.nll_type == "ar_ftb":
mask_index = torch.triu(mask_index)
else:
mask_index = torch.tril(mask_index)
perturbed_[mask_index] = self.mask_id
if self.log_type == "ftb":
perturbed_seq = torch.cat(
[prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1
)
else:
perturbed_seq = torch.cat(
[perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1
)
logits_ = []
num = (
len(perturbed_seq) // self.batch_size
if len(perturbed_seq) % self.batch_size == 0
else len(perturbed_seq) // self.batch_size + 1
)
for i in range(num):
end = (
(i + 1) * self.batch_size
if (i + 1) * self.batch_size < len(perturbed_seq)
else len(perturbed_seq)
)
perturbed_seq_ = perturbed_seq[i * self.batch_size : end]
perturbed_seq_ = perturbed_seq_.to(self.device)
if len(perturbed_seq_.shape) == 1:
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
logits = self.get_logits(perturbed_seq_, prompt_index)
logits_.append(logits.cpu())
logits = torch.cat(logits_, dim=0)
temp_index = torch.ones(
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
)
if self.nll_type == "ar_ftb":
temp_index = torch.triu(temp_index, diagonal=1)
else:
temp_index = torch.tril(temp_index, diagonal=-1)
mask_index[temp_index] = False
if self.log_type == "ftb":
logits_index = torch.cat(
[
torch.zeros(
(perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool
),
mask_index,
],
dim=-1,
)
else:
logits_index = torch.cat(
[
mask_index,
torch.zeros(
(perturbed_.shape[1], target.shape[1]), dtype=torch.bool
),
],
dim=-1,
)
if self.log_type == "ftb":
loss = (
F.cross_entropy(logits[logits_index], target[0], reduction="sum")
.cpu()
.item()
)
else:
loss = (
F.cross_entropy(logits[logits_index], prefix[0], reduction="sum")
.cpu()
.item()
)
return loss
def _encode_pair(
self, context: str, continuation: str
) -> tuple[torch.Tensor, torch.Tensor]:
if self.add_bos_token:
context = self.tokenizer.bos_token + context
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tokenizer.encode(context + continuation) + [
self.tokenizer.eos_token_id
]
context_enc = self.tokenizer.encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
# by default truncate on the left
cutoff_length = max(len(whole_enc) - self.max_length, 0)
if cutoff_length > 0:
eval_logger.warning(
f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side"
)
context_remain = context_enc_len - cutoff_length
if context_remain > 0:
context_enc = context_enc[-context_remain:]
else:
eval_logger.warning(f"All context (prompt) is truncated.")
context_enc = ""
continuation_enc = whole_enc[-self.max_length :]
return context_enc, continuation_enc
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
if self.nll_type == "mc":
ll = -self._eval_target_nll_mc(prefix, target)
if self.log_type == "union":
ll = ll / (len(target) + len(prefix))
elif self.nll_type == "ar_ftb" or self.nll_type == "ar_btf":
ll = -self._eval_target_nll_ar(prefix, target)
else:
raise NotImplementedError(self.nll_type)
# TODO: greedy decoding
is_target_greedy_dec = False
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
return out
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
raise NotImplementedError
if __name__ == "__main__":
cli_evaluate()
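For reference, a sketch of the model_args combinations that select each likelihood estimator in loglikelihood() above; the values mirror the branches of _eval_target_nll_mc / _eval_target_nll_ar, and the pairings shown are illustrative.
# Illustrative sketch only -- not part of the file above.
# Default: Monte Carlo estimate over masked target positions (mc_num samples):
mc_args = "pretrained=Dream-org/Dream-v0-Base-7B,nll_type=mc,log_type=ftb,mc_num=128"
# Alternative: chain-rule-style scoring by progressively masking the target:
ar_args = "pretrained=Dream-org/Dream-v0-Base-7B,nll_type=ar_ftb,log_type=ftb"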

View File

@ -0,0 +1,426 @@
"""
reference: https://huggingface.co/Dream-org/Dream-v0-Base-7B/blob/main/generation_utils.py
"""
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import torch.distributions as dists
from dllm.utils.generation_utils import get_num_transfer_tokens
from dllm.pipelines.dream.utils import top_p_logits, top_k_logits
from dllm.core.generation.generator import (
GeneratorOutput,
GeneratorConfig,
BaseGenerator,
)
def sample_tokens(
logits: torch.Tensor,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
margin_confidence: bool = False,
neg_entropy: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
logits = top_p_logits(logits, top_p)
if top_k is not None:
logits = top_k_logits(logits, top_k)
probs = torch.softmax(logits, dim=-1)
if temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
except Exception:
confidence, x0 = probs.max(dim=-1)
else:
confidence, x0 = probs.max(dim=-1)
if margin_confidence:
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
top1_probs = sorted_probs[:, 0]
top2_probs = sorted_probs[:, 1]
confidence = top1_probs - top2_probs
if neg_entropy:
epsilon = 1e-10
log_probs = torch.log(probs + epsilon)
confidence = torch.sum(probs * log_probs, dim=-1)
return confidence, x0
@dataclass
class DreamGeneratorConfig(GeneratorConfig):
max_new_tokens: int = 20
max_length: int | None = None  # if unset, resolved in generate() as max_new_tokens + longest prompt length
steps: int = 512
eps: float = 1e-3
alg: str = "origin"
alg_temp: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
stochastic_transfer: bool = False
@dataclass
class DreamGenerator(BaseGenerator):
@torch.no_grad()
def generate(
self,
inputs: list[torch.Tensor | list[int]],
config: DreamGeneratorConfig | None = None,
generation_tokens_hook_func=lambda step, x, logits: x,
generation_logits_hook_func=lambda step, x, logits: logits,
**kwargs,
) -> GeneratorOutput | torch.Tensor:
"""
Diffusion-style masked decoding for *generation from inputs*.
(docstring unchanged)
"""
if config is None:
config = DreamGeneratorConfig()
# ----- pull args from config, allow kwargs to override -----
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
max_length = kwargs.get("max_length", config.max_length)
steps = kwargs.get("steps", config.steps)
eps = kwargs.get("eps", config.eps)
alg = kwargs.get("alg", config.alg)
alg_temp = kwargs.get("eps", config.alg_temp)
temperature = kwargs.get("temperature", config.temperature)
top_p = kwargs.get("top_p", config.top_p)
top_k = kwargs.get("top_k", config.top_k)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
# generation_tokens_hook_func = kwargs.get("generation_tokens_hook_func", config.generation_tokens_hook_func)
# generation_logits_hook_func = kwargs.get("generation_logits_hook_func", config.generation_logits_hook_func)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
# --- Initialization ---
mask_token_id = self.tokenizer.mask_token_id
eos_token_id = self.tokenizer.eos_token_id
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
prompt_lens = [p.shape[0] for p in inputs]
if max_new_tokens:
max_length = max_new_tokens + max(prompt_lens)
else:
max_new_tokens = max_length - max(prompt_lens)
B = len(inputs)
T = max_length
x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
seq_length = []
for i, p in enumerate(inputs):
total_len = prompt_lens[i] + max_new_tokens
seq_length.append(total_len)
start = T - total_len
x[i, start : start + prompt_lens[i]] = p
x[i, start + prompt_lens[i] : T] = mask_token_id
attention_mask = torch.zeros(
(B, T), dtype=torch.float32, device=self.model.device
)
for j, L in enumerate(seq_length):
if L > 0:
attention_mask[j, -L:] = 1.0  # sequences are left-padded; valid tokens sit on the right
if attention_mask is not None and torch.any(attention_mask == 0.0):
pos_id = attention_mask.long().cumsum(-1) - 1
pos_id.masked_fill_(attention_mask == 0, 1)
else:
pos_id = None
mask_index = x == mask_token_id
num_transfer_tokens_list = get_num_transfer_tokens(
mask_index=mask_index,
steps=steps,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
effective_steps = num_transfer_tokens_list.size(1)
# --- Iterative refinement ---
x = generation_tokens_hook_func(None, x, None)
histories = [x.clone()] if return_dict_in_generate else None
for i in range(effective_steps):
mask_index = x == mask_token_id
logits = self.model(x, attention_mask, pos_id).logits
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
logits = generation_logits_hook_func(i, x, logits)
mask_logits = logits[mask_index]
if alg == "maskgit_plus":
confidence, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
)
elif alg == "topk_margin":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
margin_confidence=True,
)
elif alg == "entropy":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
neg_entropy=True,
)
else:
raise RuntimeError(f"Unknown alg: {alg}")
full_confidence = torch.full_like(
x, -torch.inf, device=self.model.device, dtype=logits.dtype
)
full_confidence[mask_index] = confidence
for j in range(full_confidence.shape[0]):
number_transfer_tokens = num_transfer_tokens_list[j, i]
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(
full_confidence[j], number_transfer_tokens
)
else:
fc = full_confidence[j] / alg_temp
fc = F.softmax(fc, dim=-1)
transfer_index = torch.multinomial(
fc, num_samples=number_transfer_tokens
)
x_ = torch.full_like(x, mask_token_id, device=self.model.device)
x_[mask_index] = x0.clone()
x[j, transfer_index] = x_[j, transfer_index]
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
@torch.no_grad()
def infill(
self,
inputs: list[torch.Tensor | list[int]],
config: DreamGeneratorConfig,
generation_tokens_hook_func=lambda step, x, logits: x,
generation_logits_hook_func=lambda step, x, logits: logits,
**kwargs,
) -> GeneratorOutput | torch.Tensor:
"""
Fill in-place the tokenizer's `<mask>` tokens contained in `inputs`.
The whole (right-aligned) canvas is denoised iteratively: at each step, a scheduler
decides how many masked positions to commit, and a confidence rule (`alg`)
selects *which* positions to reveal (MaskGIT-style). Non-mask tokens are never changed.
High-level:
1) Build a right-aligned canvas per sample (left side padded with EOS).
2) Compute a per-sample transfer schedule via `scheduler` and `steps`.
3) At each step: forward pass → AR-shift logits → score masked positions
via `alg` → choose indices to commit (top-k or soft sampling) → write tokens.
Notes:
- Right padding uses EOS (serves as pad here).
- Only `[MASK]` positions are updated; original tokens remain intact.
- Logits are AR-shifted to preserve next-token prediction alignment.
Args:
model:
Mask predictor; returns logits of shape [B, T, V] when called as
`model(x, attention_mask, pos_id)`.
tokenizer:
Must provide `mask_token_id` and `eos_token_id`.
inputs:
List of 1D LongTensors (token ids). Each may contain `<mask>` tokens
to be filled; other tokens are treated as fixed context.
scheduler (BaseAlphaScheduler):
Controls how many masks to commit per step (deterministic or stochastic).
generation_tokens_hook_func / generation_logits_hook_func:
Optional hooks to intercept tokens/logits at each step.
output_history (bool):
If True, save intermediate canvases at each step.
return_dict_in_generate (bool):
If True, return `DreamModelOutput(sequences, history)`, else only `[B, T]`.
steps (int):
Total reverse-diffusion steps (quality-speed trade-off).
alg (str):
Confidence rule to rank masked positions:
- "maskgit_plus": softmax probs
- "topk_margin": top1 - top2 margin
- "entropy": negative entropy
alg_temp (float):
Temperature for *confidence-based index sampling* (when > 0, soft selection).
temperature / top_p / top_k:
Token sampling hyperparameters within `sample_tokens`.
stochastic_transfer (bool):
If True, sample the number of transfers per step (Binomial); else use expectation.
Returns:
DreamModelOutput | torch.LongTensor:
If `return_dict_in_generate=True`, returns
- sequences: `[B, T]` final tokens
- history: optional list of intermediate canvases
Otherwise returns only `[B, T]`.
"""
# ----- pull args from config, allow kwargs to override -----
steps = kwargs.get("steps", config.steps)
eps = kwargs.get("eps", config.eps)
alg = kwargs.get("alg", config.alg)
alg_temp = kwargs.get("eps", config.alg_temp)
temperature = kwargs.get("temperature", config.temperature)
top_p = kwargs.get("top_p", config.top_p)
top_k = kwargs.get("top_k", config.top_k)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
# generation_tokens_hook_func = kwargs.get("stochastic_transfer", config.generation_tokens_hook_func)
# generation_logits_hook_func = kwargs.get("stochastic_transfer", config.generation_logits_hook_func)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
# --- Initialization ---
mask_token_id = self.tokenizer.mask_token_id
eos_token_id = self.tokenizer.eos_token_id
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
B = len(inputs)
seq_lens = [t.shape[0] for t in inputs]
T = max(seq_lens)
# Build right-aligned canvas; left side padded with EOS (used as pad)
x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
for i, t in enumerate(inputs):
L = seq_lens[i]
x[i, -L:] = t
# Build 1D attention mask (valid tokens on the right)
attention_mask = torch.zeros((B, T), dtype=torch.bool, device=self.model.device)
for j, L in enumerate(seq_lens):
if L > 0:
attention_mask[j, -L:] = True
# Expand to pairwise attention if left padding is present
if torch.any(attention_mask == 0.0):
pos_id = attention_mask.long().cumsum(-1) - 1
pos_id.masked_fill_(attention_mask == 0, 1)
else:
pos_id = None
attention_mask = "full"
# Precompute per-sample transfer schedule (how many to commit per step)
mask_index = x == mask_token_id
num_transfer_tokens_list = get_num_transfer_tokens(
mask_index=mask_index,
steps=steps,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
effective_steps = num_transfer_tokens_list.size(1)
# Optional initial token hook
x = generation_tokens_hook_func(None, x, None)
histories = [x.clone()] if return_dict_in_generate else None
for i in range(effective_steps):
mask_index = x == mask_token_id
# Forward pass, then AR-shift to predict token at position i+1
logits = self.model(x, attention_mask, pos_id).logits
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
logits = generation_logits_hook_func(i, x, logits)
# Logits restricted to current `[MASK]` positions
mask_logits = logits[mask_index]
# Confidence scoring for masked positions
if alg == "maskgit_plus":
confidence, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
)
elif alg == "topk_margin":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
margin_confidence=True,
)
elif alg == "entropy":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
neg_entropy=True,
)
else:
raise RuntimeError(f"Unknown alg: {alg}")
# Scatter per-position confidence back to full canvas
full_confidence = torch.full_like(
x, -torch.inf, device=self.model.device, dtype=logits.dtype
)
full_confidence[mask_index] = confidence
# Commit the scheduled number of tokens per sample
for j in range(B):
number_transfer_tokens = num_transfer_tokens_list[j, i]
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(
full_confidence[j], number_transfer_tokens
)
else:
fc = full_confidence[j] / alg_temp
fc = F.softmax(fc, dim=-1)
transfer_index = torch.multinomial(
fc, num_samples=number_transfer_tokens
)
# Candidate tokens at masked positions only
x_ = torch.full_like(x, mask_token_id, device=self.model.device)
x_[mask_index] = x0.clone()
x[j, transfer_index] = x_[j, transfer_index]
# Optional token hook + history logging
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
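A minimal end-to-end sketch of DreamGenerator.generate; the model/tokenizer loading mirrors the eval script above, and the prompt and sampling values are illustrative.
# Illustrative sketch only -- not part of the file above.
from types import SimpleNamespace
import torch
import dllm
from dllm.pipelines.dream import DreamGenerator
model = dllm.utils.get_model(SimpleNamespace(model_name_or_path="Dream-org/Dream-v0-Base-7B", dtype=torch.bfloat16)).eval().cuda()
tokenizer = dllm.utils.get_tokenizer(SimpleNamespace(model_name_or_path="Dream-org/Dream-v0-Base-7B", model=model))
generator = DreamGenerator(model=model, tokenizer=tokenizer)
prompt_ids = tokenizer("Question: What is 2 + 2?\nAnswer:", return_tensors="pt").input_ids.squeeze(0)
sequences = generator.generate(inputs=[prompt_ids], steps=128, max_new_tokens=64, temperature=0.2, top_p=0.95, alg="entropy", return_dict_in_generate=False)
print(tokenizer.decode(sequences[0], skip_special_tokens=True))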

View File

@ -0,0 +1,13 @@
from .configuration_dream import DreamConfig
from .modeling_dream import DreamModel
# Register with HuggingFace Auto classes for local usage
try:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
AutoConfig.register("Dream", DreamConfig)
AutoModel.register(DreamConfig, DreamModel)
AutoModelForMaskedLM.register(DreamConfig, DreamModel)
except ImportError:
# transformers not available or Auto classes not imported
pass
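With the registrations above, a checkpoint whose config.json declares model_type "Dream" can be loaded through the stock Auto classes once this package has been imported; a sketch, with the checkpoint path as a placeholder.
# Illustrative sketch only -- not part of the file above; "/path/to/dream-checkpoint"
# is a placeholder for a local checkpoint whose config sets model_type "Dream".
from transformers import AutoConfig, AutoModel
import dllm.pipelines.dream.models  # noqa: F401  (triggers the register() calls above)
config = AutoConfig.from_pretrained("/path/to/dream-checkpoint")
model = AutoModel.from_pretrained("/path/to/dream-checkpoint", config=config)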

View File

@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dream model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DreamConfig(PretrainedConfig):
model_type = "Dream"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=False, # cache not used in diffusion
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
mask_token_id=151666,
pad_token_id=151643,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id

View File

@ -0,0 +1,465 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import (
GenerationConfig
)
from transformers.utils import (
ModelOutput,
is_torchdynamo_compiling,
logging,
)
logger = logging.get_logger(__name__)
def top_p_logits(logits, top_p=None):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
return logits
def top_k_logits(logits, top_k=None):
top_k = min(top_k, logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
if temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
logits = top_p_logits(logits, top_p)
if top_k is not None:
logits = top_k_logits(logits, top_k)
probs = torch.softmax(logits, dim=-1)
if temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
except Exception:
confidence, x0 = probs.max(dim=-1)
else:
confidence, x0 = probs.max(dim=-1)
if margin_confidence:
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
# Extract top1 and top2 probabilities
top1_probs = sorted_probs[:, 0]
top2_probs = sorted_probs[:, 1]
# Calculate confidence as top1 - top2
confidence = top1_probs - top2_probs
if neg_entropy:
epsilon = 1e-10
log_probs = torch.log(probs + epsilon)
confidence = torch.sum(probs * log_probs, dim=-1)
return confidence, x0
@dataclass
class DreamModelOutput(ModelOutput):
sequences: torch.LongTensor = None
history: Optional[Tuple[torch.FloatTensor]] = None
class DreamGenerationConfig(GenerationConfig):
def __init__(self, **kwargs):
self.temperature: float = kwargs.pop("temperature", 0.0)
self.top_p: Optional[float] = kwargs.pop("top_p", None)
self.top_k: Optional[int] = kwargs.pop("top_k", None)
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
# diffusion specific params
self.eps: float = kwargs.pop("eps", 1e-3)
self.steps: int = kwargs.pop("steps", 512)
self.alg: str = kwargs.pop("alg", 'origin')
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
# Parameters that define the output variables of `generate`
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
self.output_history: bool = kwargs.pop("output_history", False)
# Special tokens that can be used at generation time
self.mask_token_id = kwargs.pop("mask_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
# interface.
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
# Additional attributes without default values
if not self._from_model_config:
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
# model's default configuration file
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
# Validate the values of the attributes
self.validate(is_init=True)
def validate(self, is_init=False, **kwargs):
pass
class DreamGenerationMixin:
@staticmethod
def _expand_inputs_for_generation(
expand_size: int = 1,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
# the input tensor and thus requires more memory although no change is applied
if expand_size == 1:
return input_ids, attention_mask
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
if attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
return input_ids, attention_mask
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
"""Performs validation related to the resulting generated length"""
# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():
return
# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
"generation.",
UserWarning,
)
if input_ids_length >= generation_config.max_length:
input_ids_string = "input_ids"
raise ValueError(
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_length` or, better yet, setting `max_new_tokens`."
)
def _prepare_generated_length(
self,
generation_config,
has_default_max_length,
input_ids_length,
):
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
elif has_default_max_length:
if generation_config.max_length == DreamGenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
return generation_config
def _prepare_generation_config(
self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
) -> DreamGenerationConfig:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs. This
function handles retrocompatibility with respect to configuration files.
"""
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
using_model_generation_config = False
if generation_config is None:
generation_config = DreamGenerationConfig.from_model_config(self.config)
using_model_generation_config = True
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.mask_token_id is None:
generation_config.mask_token_id = self.generation_config.mask_token_id
return generation_config
def _prepare_special_tokens(
self,
generation_config: DreamGenerationConfig,
device: Optional[Union[torch.device, str]] = None,
):
"""
Prepares the special tokens for generation, overwriting the generation config with their processed versions
converted to tensor.
Note that `generation_config` is changed in place and stops being serializable after this method is called.
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""
# Convert special tokens to tensors
def _tensor_or_none(token, device=None):
if token is None:
return token
device = device if device is not None else self.device
if isinstance(token, torch.Tensor):
return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
eos_token_tensor = eos_token_tensor.unsqueeze(0)
# Set pad token if unset (and there are conditions to do so)
if pad_token_tensor is None and eos_token_tensor is not None:
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# Update generation config with the updated special tokens tensors
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
# (in their non-tensor form), in order to enable end-to-end compilation. See
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
generation_config._bos_token_tensor = bos_token_tensor
generation_config._eos_token_tensor = eos_token_tensor
generation_config._pad_token_tensor = pad_token_tensor
generation_config._mask_token_tensor = mask_token_tensor
@torch.no_grad()
def diffusion_generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[DreamGenerationConfig] = None,
**kwargs,
) -> Union[DreamModelOutput, torch.LongTensor]:
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
generation_config = self._prepare_generation_config(generation_config, **kwargs)
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
# 2. Define model inputs
assert inputs is not None
input_ids = inputs
device = input_ids.device
attention_mask = kwargs.pop("attention_mask", None)
self._prepare_special_tokens(generation_config, device=device)
# 3. Prepare `max_length`.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
input_ids_length=input_ids_length,
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 4. Check input_ids
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
" Please make sure that you have put `input_ids` to the"
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
" running `.generate()`.",
UserWarning,
)
if (
hasattr(generation_config, "pad_token_id") and
torch.any(input_ids == generation_config.pad_token_id) and
attention_mask is None
):
warnings.warn(
"Padding was detected but no attention mask is passed here. For correct "
"generation results, please set `attention_mask` when batch-padding inputs.",
UserWarning,
)
input_ids, attention_mask = self._expand_inputs_for_generation(
expand_size=generation_config.num_return_sequences,
input_ids=input_ids,
attention_mask=attention_mask
)
result = self._sample(
input_ids,
attention_mask=attention_mask,
generation_config=generation_config,
generation_tokens_hook_func=generation_tokens_hook_func,
generation_logits_hook_func=generation_logits_hook_func
)
return result
def _sample(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor],
generation_config: DreamGenerationConfig,
generation_tokens_hook_func,
generation_logits_hook_func
) -> Union[DreamModelOutput, torch.LongTensor]:
# init values
output_history = generation_config.output_history
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
mask_token_id = generation_config.mask_token_id
steps = generation_config.steps
eps = generation_config.eps
alg = generation_config.alg
alg_temp = generation_config.alg_temp
temperature = generation_config.temperature
top_p = generation_config.top_p
top_k = generation_config.top_k
histories = [] if (return_dict_in_generate and output_history) else None
# pad input_ids to max_length
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
if attention_mask is not None and torch.any(attention_mask == 0.0):
# we do not mask the [MASK] tokens so value = 1.0
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
tok_idx = attention_mask.long().cumsum(-1) - 1
tok_idx.masked_fill_(attention_mask == 0, 1)
# attention_mask is of shape [B, N]
# broadcast to [B, 1, N, N]
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
else:
tok_idx = None
attention_mask = "full"
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
# this allows user-defined token control of the intermediate steps
x = generation_tokens_hook_func(None, x, None)
for i in range(steps):
mask_index = (x == mask_token_id)
logits = self(x, attention_mask, tok_idx).logits
# shift logits one step: the prediction for position i comes from the hidden state at position i-1
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
# this allows user-defined logits control of the intermediate steps
logits = generation_logits_hook_func(i, x, logits)
mask_logits = logits[mask_index]
t = timesteps[i]
s = timesteps[i + 1]
if alg == 'origin':
p_transfer = 1 - s / t if i < steps - 1 else 1
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
_, x0[transfer_index_t_s] = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
x[mask_index] = x0.clone()
else:
if alg == 'maskgit_plus':
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
elif alg == 'topk_margin':
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
elif alg == 'entropy':
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
else:
raise RuntimeError(f"Unknown alg: {alg}")
# average number of still-masked tokens per sequence, and how many of them to commit at this step
num_mask_token = mask_index.sum() / mask_index.shape[0]
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
full_confidence[mask_index] = confidence
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
else:
full_confidence = full_confidence / alg_temp
full_confidence = F.softmax(full_confidence, dim=-1)
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
x_[mask_index] = x0.clone()
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
x[row_indices, transfer_index] = x_[row_indices, transfer_index]
# this allows user-defined token control of the intermediate steps
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if return_dict_in_generate:
return DreamModelOutput(
sequences=x,
history=histories,
)
else:
return x
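# Minimal usage sketch for `diffusion_generate`, assuming a checkpoint whose
# `auto_map` routes `AutoModel` to `DreamModel` and that ships this module via
# `trust_remote_code`. The model id below is the one referenced elsewhere in
# this repo; the sampling values are illustrative, and only the keyword names
# (`max_new_tokens`, `steps`, `temperature`, `top_p`, `alg`, `alg_temp`,
# `return_dict_in_generate`) come from the generation config fields read above.
if __name__ == "__main__":
    import torch
    from transformers import AutoModel, AutoTokenizer

    model_id = "Dream-org/Dream-v0-Base-7B"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to(device).eval()

    input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.to(device)
    out = model.diffusion_generate(
        input_ids,
        max_new_tokens=64,
        steps=64,            # number of denoising steps
        temperature=0.2,
        top_p=0.95,
        alg="entropy",       # "origin" | "maskgit_plus" | "topk_margin" | "entropy"
        alg_temp=0.0,
        return_dict_in_generate=True,
    )
    print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))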

View File

@ -0,0 +1,850 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT and Qwen implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Dream model."""
import math
from typing import List, Optional, Tuple, Union
import os
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from transformers import PretrainedConfig
from .configuration_dream import DreamConfig
from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Dream-7B"
_CONFIG_FOR_DOC = "DreamConfig"
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
class DreamRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DreamRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
class DreamRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[DreamConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def reset_parameters(self):
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
class DreamMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class DreamAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = False
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = DreamRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class DreamSdpaAttention(DreamAttention):
"""
Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from DreamAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# causal_mask = attention_mask
# if attention_mask is not None: # no matter the length, we just slice it
# causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
# is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=False, # hard coded
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class DreamDecoderLayer(nn.Module):
def __init__(self, config: DreamConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
if config.sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
# self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.self_attn = DreamSdpaAttention(config, layer_idx)
self.mlp = DreamMLP(config)
self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class DreamPreTrainedModel(PreTrainedModel):
config_class = DreamConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["DreamDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: Optional[bool] = None,
weights_only: bool = True,
**kwargs,
):
_model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
# NOTE(Lin): we need to override the generation config
# because the generation config loaded in `from_pretrained`
# does not include all the attributes of DreamGenerationConfig
resume_download = kwargs.get("resume_download", None)
proxies = kwargs.get("proxies", None)
subfolder = kwargs.get("subfolder", "")
from_auto_class = kwargs.get("_from_auto", False)
from_pipeline = kwargs.get("_from_pipeline", None)
_model.generation_config = DreamGenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
)
return _model
class DreamBaseModel(DreamPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
Args:
config: DreamConfig
"""
def __init__(self, config: DreamConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._attn_implementation = config._attn_implementation
self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DreamRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = DreamBaseModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def reset_rope_parameters(self):
self.model.rotary_emb.reset_parameters()
for layer in self.model.layers:
layer.self_attn.rotary_emb.reset_parameters()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MaskedLMOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if attention_mask is None or (isinstance(attention_mask, str) and attention_mask == "full"):
# no mask or an explicit "full" mask: run with full bidirectional attention
pass
elif isinstance(attention_mask, torch.Tensor):
if not torch.any(attention_mask == 0.0):
attention_mask = 'full'
elif attention_mask.dim() == 2:
# [B, L] → [B, 1, L, L]
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
attention_mask = attention_mask.to(torch.bool)
elif attention_mask.dim() in (3, 4):
# already extended/broadcasted form
if attention_mask.dtype != torch.bool:
attention_mask = attention_mask.to(torch.bool)
else:
raise ValueError(f"Unexpected attention_mask shape: {attention_mask.shape}")
else:
raise TypeError(f"Unsupported attention_mask type: {type(attention_mask)}")
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
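# Note on the `attention_mask` forms accepted by `DreamModel.forward` above:
#   - `None` or the string "full": full bidirectional attention (no padding),
#     the fast path used by `diffusion_generate` when no pad tokens are present;
#   - a 2-D [batch, seq_len] 0/1 padding mask: expanded here into a boolean
#     [batch, 1, seq_len, seq_len] mask so padded positions neither attend nor
#     are attended to;
#   - a 3-D/4-D mask: cast to bool if needed and passed through to SDPA as-is.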

View File

@ -0,0 +1,346 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on Qwen's implementations in this library.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Dream."""
import json
import os
import unicodedata
from functools import lru_cache
from typing import Optional, Tuple
import regex as re
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768}
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
@lru_cache()
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
"""
Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class DreamTokenizer(PreTrainedTokenizer):
"""
Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding.
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```python
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True)
>>> tokenizer("Hello world")["input_ids"]
[9707, 1879]
>>> tokenizer(" Hello world")["input_ids"]
[21927, 1879]
```
This is expected.
You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*):
The beginning of sequence token. Not applicable for this tokenizer.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding, for example when batching sequences of different lengths.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
split_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the special tokens should be split during the tokenization process. The default behavior is
to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
`['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
'|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
clean_up_tokenization_spaces=False,
split_special_tokens=False,
**kwargs,
):
# Dream vocab does not contain control tokens; added tokens need to be special
bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(bos_token, str)
else bos_token
)
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(eos_token, str)
else eos_token
)
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(unk_token, str)
else unk_token
)
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(pad_token, str)
else pad_token
)
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_merges = []
with open(merges_file, encoding="utf-8") as merges_handle:
for i, line in enumerate(merges_handle):
line = line.strip()
if (i == 0 and line.startswith("#version:")) or not line:
continue
bpe_merges.append(tuple(line.split()))
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
# NOTE: the cache can grow without bound and will get really large for long running processes
# (esp. for texts of language that do not use space between word, e.g. Chinese); technically
# not a memory leak but appears as one.
# GPT2Tokenizer has the same problem, so let's be consistent.
self.cache = {}
self.pat = re.compile(PRETOKENIZE_REGEX)
if kwargs.get("add_prefix_space", False):
logger.warning_once(
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
)
super().__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
unk_token=unk_token,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
split_special_tokens=split_special_tokens,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self.encoder)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
# `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
# and cannot be configured elsewhere, but it should default to False for DreamTokenizer
return super().decode(
token_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs,
)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs):
text = unicodedata.normalize("NFC", text)
return (text, kwargs)
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
from .configuration_dream import DreamConfig
TOKENIZER_MAPPING.register(DreamConfig, (DreamTokenizer, None))

View File

@ -0,0 +1,84 @@
from typing import Any
import torch
from dllm.core.trainers import MDLMTrainer
def cart_weight(
masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3
) -> torch.Tensor:
"""
Optimized CART weight computation using matrix operations.
Args:
masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions.
t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART.
p (float): Parameter of geometric distribution (0 < p <= 1).
Returns:
torch.Tensor: (b, l) float tensor of weights.
"""
b, l = masked_indices.shape
device = masked_indices.device
idx = torch.arange(l, device=device)
dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1
dist_matrix = torch.clamp(dist_matrix, min=0) # (l, l)
geo_matrix = (
torch.log(torch.tensor(p, device=device))
+ (dist_matrix - 1).clamp(min=0) * torch.log(torch.tensor(1 - p, device=device))
).exp() * 0.5 # geometric kernel, computed in log space for numerical stability
geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # zero out self and adjacent positions (clamped distance == 0)
valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked
weights = valid_mask @ geo_matrix.T # (b, l)
weights = weights * masked_indices.float()
return weights
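# Reading of the kernel above: with k(d) = 0.5 * p * (1 - p) ** (d - 2) for
# distance d >= 2 and k(d) = 0 for d <= 1, each masked position i gets
#     weight[i] = sum over unmasked positions j of k(|i - j|),
# so masked tokens near (but not adjacent to) observed context are up-weighted.
# Tiny worked example (p = 0.5, mask pattern [0, 1, 1, 1, 0]):
#   masked = torch.tensor([[False, True, True, True, False]])
#   cart_weight(masked, t=torch.zeros(1), p=0.5)
#   # -> tensor([[0.0000, 0.1250, 0.5000, 0.1250, 0.0000]])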
class DreamTrainer(MDLMTrainer):
"""
DreamTrainer: specialization of MDLMTrainer for Dream training.
"""
def __init__(
self,
*args,
loss_weight_type: str = "cart[geo_p:0.3]",
**kwargs,
):
super().__init__(*args, loss_weight_type=loss_weight_type, **kwargs)
def _preprocess_inputs(self, inputs):
labels = inputs["labels"]
# the first position has no predecessor under the shifted prediction convention below, so it must never be a training target
assert (labels[:, 0] == -100).all()
def _postprocess_outputs(self, outputs):
logits = outputs.logits
# mirror the generation-time shift: the prediction for position i comes from the hidden state at position i-1
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
def _compute_loss_weights(
self,
t: torch.Tensor,
inputs: dict[str, Any],
masked_indices: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
if self.loss_weight_type.startswith("cart"):
# parse geo_p
import re
match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type)
geo_p = float(match.group(1)) if match else 0.3
loss_weights = cart_weight(masked_indices, t, p=geo_p)
else:
loss_weights = super()._compute_loss_weights(
t=t,
inputs=inputs,
masked_indices=masked_indices,
*args,
**kwargs,
)
return loss_weights

View File

@ -0,0 +1,180 @@
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F
import transformers
def top_p_logits(logits, top_p=None):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
return logits
def top_k_logits(logits, top_k=None):
top_k = min(top_k, logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
@dataclass
class DreamSFTCollator(transformers.DataCollatorForSeq2Seq):
"""
Randomly crop response length to reduce length bias during generation.
Reference: https://github.com/DreamLM/Dream/blob/main/src/trainer/fsdp_sft_trainer.py
"""
perbatch_cutoff: bool = True # use per-batch (pre-collation) response truncation if True
resp_cutoff_ratio: float = 0.0 # probability of post-collation truncation (only used when `perbatch_cutoff` is False)
# -------------------------------------------------------------------------
# 1) Pre-collation truncation (per-sample)
# -------------------------------------------------------------------------
def apply_perbatch_cutoff(self, features):
"""
Randomly pick a response length from batch (`kept_len`) and trim other responses.
Before:
[<--promptA----><------responseA------>]
[<--promptB-><---responseB--->]
[<---promptC----><--respC-->]
After:
[<--promptA----><---respA--->]
[<--promptB-><--respB-->]
[<---promptC----><--respC-->]
kept_len = 10 → trim each response to ≤10 tokens (before padding)
"""
resp_lens = torch.tensor(
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
)
kept_len = int(np.random.choice(resp_lens))
for f, r_len in zip(features, resp_lens):
remove_len = max(r_len - kept_len, 0)
if remove_len > 0:
# f["input_ids"] = f["input_ids"][:-remove_len]
# f["attention_mask"] = f["attention_mask"][:-remove_len]
# f["labels"] = f["labels"][:-remove_len]
for key in ["input_ids", "labels", "attention_mask"]:
if key in f:
f[key] = f[key][:-remove_len]
return features
# -------------------------------------------------------------------------
# 2) Post-collation truncation
# -------------------------------------------------------------------------
def apply_resp_cutoff(self, batch, features):
"""
Uniformly chop tail *after padding*. All sequences truncated to new_seq_len.
Before:
[<--promptA----><-----respA----->] 40
[<--promptB-><respB><----pad---->] 40
[<---promptC----><--respC--><pad>] 40
cutoff_len = 5
After:
[<--promptA----><--respA--->] 35
[<--promptB-><respB><--pad->] 35
[<---promptC----><--respC-->] 35
"""
orig_seq_lens = [len(f["input_ids"]) for f in features]
resp_lens = torch.tensor(
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
)
min_resp_len = resp_lens.min().item()
if min_resp_len <= 1:
return batch
cutoff_len = int(np.random.randint(1, min_resp_len))
new_seq_len = max(orig_seq_lens) - cutoff_len
for key in ["input_ids", "labels", "attention_mask"]:
if key in batch:
batch[key] = batch[key][:, :new_seq_len].contiguous()
return batch
# -------------------------------------------------------------------------
# 3) Main call: pick truncation mode
# -------------------------------------------------------------------------
def __call__(self, features, return_tensors=None):
# optional pre-collation truncation
if self.perbatch_cutoff:
features = self.apply_perbatch_cutoff(features)
# always collate only the needed fields
base = [
{k: f[k] for k in ("input_ids", "labels", "attention_mask") if k in f}
for f in features
]
batch = super().__call__(base, return_tensors=return_tensors)
# optional post-collation truncation
if (
not self.perbatch_cutoff
and self.resp_cutoff_ratio > 0
and np.random.rand() < self.resp_cutoff_ratio
):
batch = self.apply_resp_cutoff(batch, features)
batch.pop("prompt_len", None)
return batch
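# Illustrative usage (added sketch; `tok` and the toy features are assumptions): each
# feature carries `prompt_len` so the collator can distinguish prompt tokens from
# response tokens when cropping responses.
def _dream_sft_collator_example(tok: transformers.PreTrainedTokenizer) -> dict:
    collator = DreamSFTCollator(tokenizer=tok, padding=True, label_pad_token_id=-100)
    features = [
        {"input_ids": [1, 5, 6, 7, 8, 9], "labels": [-100, -100, 7, 8, 9, 2],
         "attention_mask": [1] * 6, "prompt_len": 2},
        {"input_ids": [1, 5, 6, 7], "labels": [-100, -100, 7, 2],
         "attention_mask": [1] * 4, "prompt_len": 2},
    ]
    # Responses are randomly cropped to one of the batch's response lengths, then padded.
    return collator(features, return_tensors="pt")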
@dataclass
class DreamPTCollator(transformers.DataCollatorForSeq2Seq):
random_length_ratio: float = 0.01
def __call__(self, features, return_tensors=None):
outputs = super().__call__(features, return_tensors=return_tensors)
input_ids, labels, attention_mask = (
outputs["input_ids"],
outputs["labels"],
outputs["attention_mask"],
)
bsz, seq_len = input_ids.shape
# --- Random truncation for robustness ---
if torch.rand(1).item() < self.random_length_ratio:
random_len = torch.randint(1, seq_len + 1, (1,)).item()
input_ids = input_ids[:, :random_len]
labels = labels[:, :random_len]
attention_mask = attention_mask[:, :random_len]
# --- Add BOS token to the beginning of input_ids ---
bos = torch.full(
(bsz, 1),
self.tokenizer.bos_token_id,
dtype=input_ids.dtype,
device=input_ids.device,
)
input_ids = torch.cat([bos, input_ids], dim=1)
# --- Prepend zeros to labels instead of BOS ---
ignore_labels = self.label_pad_token_id * torch.ones(
(bsz, 1), dtype=labels.dtype, device=labels.device
)
labels = torch.cat([ignore_labels, labels], dim=1)
# --- Prepend ones to attention_mask ---
bos_attention = torch.ones(
(bsz, 1), dtype=attention_mask.dtype, device=attention_mask.device
)
attention_mask = torch.cat([bos_attention, attention_mask], dim=1)
# --- Update and return ---
outputs["input_ids"] = input_ids
outputs["labels"] = labels
outputs["attention_mask"] = attention_mask
# If attention_mask is all ones, drop it (no padding, so no mask is needed)
if torch.all(outputs["attention_mask"] == 1):
outputs.pop("attention_mask")
return outputs

View File

@ -0,0 +1,14 @@
from . import trainer, utils
from .models.dream.modelling_dream import (
EditFlowDreamConfig,
EditFlowDreamModel,
)
from .models.llada.modelling_llada import (
EditFlowLLaDAConfig,
EditFlowLLaDAModel,
)
from .models.bert.modelling_modernbert import (
EditFlowModernBertConfig,
EditFlowModernBertModel,
)
from dllm.pipelines.editflow.trainer import EditFlowTrainer

View File

@ -0,0 +1,89 @@
import torch
from torch import nn
import transformers
class EditFlowModernBertConfig(transformers.ModernBertConfig):
model_type = "editflow-modernbert" # <- NEW model_type
class EditFlowModernBertModel(transformers.ModernBertForMaskedLM):
config_class = EditFlowModernBertConfig
modules_to_save = {
"rate_heads",
"sub_logits",
"ins_logits",
} # fully finetuned even when using LoRA
def __init__(self, config):
# FlashAttention-2 has bugs with forward(output_hidden_states=True), so force SDPA
config._attn_implementation = "sdpa"
super().__init__(config)
in_lm, out_lm = self.decoder.in_features, self.decoder.out_features
use_bias = self.decoder.bias is not None
# Create new, independent heads (no deepcopy)
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.rate_heads = nn.Sequential(nn.Linear(in_lm, 3), nn.Softplus())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
t: torch.Tensor | None = None,
**kwargs,
):
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
h = output["hidden_states"][-1] # final hidden states
h = self.head(h)
# Position heads
sub_log = self.sub_logits(h) # [B, L, V]
ins_log = self.ins_logits(h) # [B, L, V]
rates = self.rate_heads(h)
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
-1
) # [B, L], [B, L], [B, L]
return dict(
sub_rate_hat=sub_rate_hat, # [B,L]
del_rate_hat=del_rate_hat, # [B,L]
ins_rate_hat=ins_rate_hat, # [B,L]
ins_logits=ins_log, # [B,L,V]
sub_logits=sub_log, # [B,L,V]
)
from transformers.models.auto import AutoModel, AutoConfig
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("editflow-modernbert", EditFlowModernBertConfig)
AutoModel.register(EditFlowModernBertConfig, EditFlowModernBertModel)
if __name__ == "__main__":
import dllm
import torch
from transformers import AutoConfig, AutoModel
# Load a config from a local path (either a directory containing config.json, or the file itself)
config_path = dllm.utils.resolve_with_base_env(
"answerdotai/ModernBERT-base", "BASE_MODELS_DIR"
)
config = EditFlowModernBertConfig.from_pretrained(config_path)
if hasattr(config, "auto_map"):
delattr(config, "auto_map")
if hasattr(config, "architectures"):
delattr(config, "architectures")
torch.set_default_device("cuda")
model = EditFlowModernBertModel(config)
model.save_pretrained("models-tmp/editflow-modernbert")
auto_model = AutoModel.from_pretrained("models-tmp/editflow-modernbert")

View File

@ -0,0 +1,97 @@
import copy
from typing import Optional
import torch
from torch import nn
from dllm.pipelines import dream
class EditFlowDreamConfig(dream.DreamConfig):
model_type = "editflow-dream" # <- NEW model_type
class EditFlowDreamModel(dream.DreamModel):
config_class = EditFlowDreamConfig
modules_to_save = {
"rate_heads",
"sub_logits",
"ins_logits",
} # fully finetuned even when using LoRA
def __init__(self, config):
super().__init__(config)
in_lm, out_lm = self.lm_head.in_features, self.lm_head.out_features
use_bias = self.lm_head.bias is not None
# Create new, independent heads (no deepcopy)
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
t: torch.Tensor | None = None,
**kwargs,
):
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
h = output["hidden_states"][-1] # final hidden states
# Position heads
sub_log = self.sub_logits(h) # [B, L, V]
sub_log = torch.concatenate(
[torch.zeros_like(sub_log)[:, :1], sub_log[:, :-1]], dim=1
) # [B, L, V]
ins_log = self.ins_logits(h) # [B, L, V]
rates = self.rate_heads(h)
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
-1
) # [B, L], [B, L], [B, L]
sub_rate_hat = torch.concatenate(
[torch.zeros_like(sub_rate_hat[:, :1]), sub_rate_hat[:, :-1]], dim=1
) # [B, L]
del_rate_hat = torch.concatenate(
[torch.zeros_like(del_rate_hat[:, :1]), del_rate_hat[:, :-1]], dim=1
) # [B, L]
return dict(
sub_rate_hat=sub_rate_hat, # [B,L]
del_rate_hat=del_rate_hat, # [B,L]
ins_rate_hat=ins_rate_hat, # [B,L]
ins_logits=ins_log, # [B,L,V]
sub_logits=sub_log, # [B,L,V]
)
from transformers.models.auto import AutoModel, AutoConfig
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("editflow-dream", EditFlowDreamConfig)
AutoModel.register(EditFlowDreamConfig, EditFlowDreamModel)
if __name__ == "__main__":
import dllm
import torch
from transformers import AutoConfig, AutoModel
# Load a config from a local path (either a directory containing config.json, or the file itself)
config_path = dllm.utils.resolve_with_base_env(
"Dream-org/Dream-v0-Base-7B", "BASE_MODELS_DIR"
)
config = EditFlowDreamConfig.from_pretrained(config_path)
if hasattr(config, "auto_map"):
delattr(config, "auto_map")
if hasattr(config, "architectures"):
delattr(config, "architectures")
torch.set_default_device("cuda")
model = EditFlowDreamModel(config)
model.save_pretrained("models-tmp/editflow-dream")
auto_model = AutoModel.from_pretrained("models-tmp/editflow-dream")

View File

@ -0,0 +1,91 @@
import copy
from typing import Optional
import torch
from torch import nn
from dllm.pipelines import llada
class EditFlowLLaDAConfig(llada.LLaDAConfig):
model_type = "editflow-llada" # <- NEW model_type
class EditFlowLLaDAModel(llada.LLaDAModelLM):
config_class = EditFlowLLaDAConfig
modules_to_save = {
"rate_heads",
"sub_logits",
"ins_logits",
} # fully finetuned even when using LoRA
def __init__(self, config):
# TODO: time embedding
super().__init__(config)
ff = self.model.transformer.ff_out
in_f, out_f = ff.in_features, ff.out_features
use_bias = ff.bias is not None
# Create new, independent heads (no deepcopy)
self.sub_logits = nn.Linear(in_f, out_f, bias=use_bias)
self.ins_logits = nn.Linear(in_f, out_f, bias=use_bias)
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
t: torch.Tensor | None = None,
**kwargs,
):
# TODO: time embedding
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
h = output["hidden_states"][-1] # final hidden states
# Position heads
sub_log = self.sub_logits(h) # [B, L, V]
ins_log = self.ins_logits(h) # [B, L, V]
rates = self.rate_heads(h)
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
-1
) # [B, L], [B, L], [B, L]
return dict(
sub_rate_hat=sub_rate_hat, # [B,L]
del_rate_hat=del_rate_hat, # [B,L]
ins_rate_hat=ins_rate_hat, # [B,L]
ins_logits=ins_log, # [B,L,V]
sub_logits=sub_log, # [B,L,V]
)
from transformers.models.auto import AutoModel, AutoConfig
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("editflow-llada", EditFlowLLaDAConfig)
AutoModel.register(EditFlowLLaDAConfig, EditFlowLLaDAModel)
if __name__ == "__main__":
import dllm
import torch
from transformers import AutoConfig, AutoModel
# Load a config from a local path (either a directory containing config.json, or the file itself)
config_path = dllm.utils.resolve_with_base_env(
"GSAI-ML/LLaDA-8B-Base", "BASE_MODELS_DIR"
)
config = EditFlowLLaDAConfig.from_pretrained(config_path)
if hasattr(config, "auto_map"):
delattr(config, "auto_map")
if hasattr(config, "architectures"):
delattr(config, "architectures")
torch.set_default_device("cuda")
model = EditFlowLLaDAModel(config)
model.save_pretrained("models-tmp/editflow-llada")
auto_model = AutoModel.from_pretrained("models-tmp/editflow-llada")

View File

@ -0,0 +1,407 @@
from typing import Any, Dict, Union, List, Tuple, Optional
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
from dllm.pipelines.editflow.utils import pad_1d
BLANK = -1
def align_with_blanks(
x0: List[int], x1: List[int], sub_cost: int = 1, gap_cost: int = 1
) -> Dict:
"""
Needleman-Wunsch global alignment of two integer sequences with:
match cost = 0, substitution cost = sub_cost, gap cost = gap_cost.
Returns aligned sequences (z0, z1) of equal length containing BLANK = ε where gaps occur.
"""
n, m = len(x0), len(x1)
# DP tables
dp = [[0] * (m + 1) for _ in range(n + 1)]
ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag', 'up', 'left'
for i in range(1, n + 1):
dp[i][0] = i * gap_cost
ptr[i][0] = "up"
for j in range(1, m + 1):
dp[0][j] = j * gap_cost
ptr[0][j] = "left"
for i in range(1, n + 1):
for j in range(1, m + 1):
cost_diag = dp[i - 1][j - 1] + (0 if x0[i - 1] == x1[j - 1] else sub_cost)
cost_up = dp[i - 1][j] + gap_cost
cost_left = dp[i][j - 1] + gap_cost
best = min(cost_diag, cost_up, cost_left)
dp[i][j] = best
if best == cost_diag:
ptr[i][j] = "diag"
elif best == cost_up:
ptr[i][j] = "up"
else:
ptr[i][j] = "left"
# traceback
z0, z1 = [], []
i, j = n, m
while i > 0 or j > 0:
p = ptr[i][j]
if p == "diag":
z0.append(x0[i - 1])
z1.append(x1[j - 1])
i -= 1
j -= 1
elif p == "up":
z0.append(x0[i - 1])
z1.append(BLANK)
i -= 1
else: # 'left'
z0.append(BLANK)
z1.append(x1[j - 1])
j -= 1
z0.reverse()
z1.reverse()
# return Alignment(z0=z0, z1=z1)
# return {"z0": z0, "z1": z1}
return dict(z0=z0, z1=z1)
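# Worked example (added for illustration; the token ids are made up): aligning
# x0 = [7, 3, 9] with x1 = [7, 5, 9, 2] costs 2 (one substitution, one gap) and
# this implementation returns z0 = [7, 3, 9, BLANK], z1 = [7, 5, 9, 2].
def _align_example() -> dict:
    return align_with_blanks([7, 3, 9], [7, 5, 9, 2])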
# def align_with_blanks(
# x0: list[int], x1: list[int], sub_cost: int = 1, gap_cost: int = 1
# ) -> dict:
# """
# NeedlemanWunsch with a secondary objective that defers gaps to the end:
# - 'up' (gap in z1) is penalized if j < m
# - 'left' (gap in z0) is penalized if i < n
# This pushes blanks (-1) to the *right* whether x0 > x1 or x0 < x1.
# """
# n, m = len(x0), len(x1)
# dp_cost = [[0] * (m + 1) for _ in range(n + 1)]
# dp_pen = [[0] * (m + 1) for _ in range(n + 1)]
# ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag' | 'up' | 'left'
# # Left edge: all 'up' moves with j=0 (< m) → penalize each step
# for i in range(1, n + 1):
# dp_cost[i][0] = i * gap_cost
# dp_pen[i][0] = i # i early 'up' moves
# ptr[i][0] = "up"
# # Top edge: all 'left' moves with i=0 (< n) → penalize each step
# for j in range(1, m + 1):
# dp_cost[0][j] = j * gap_cost
# dp_pen[0][j] = j # j early 'left' moves
# ptr[0][j] = "left"
# for i in range(1, n + 1):
# xi = x0[i - 1]
# for j in range(1, m + 1):
# yj = x1[j - 1]
# # diag
# cost_diag = dp_cost[i - 1][j - 1] + (0 if xi == yj else sub_cost)
# pen_diag = dp_pen[i - 1][j - 1]
# cand_diag = (cost_diag, pen_diag)
# # up: add blank to z1, penalize if j < m (early)
# cost_up = dp_cost[i - 1][j] + gap_cost
# pen_up = dp_pen[i - 1][j] + (1 if j < m else 0)
# cand_up = (cost_up, pen_up)
# # left: add blank to z0, penalize if i < n (early)
# cost_left = dp_cost[i][j - 1] + gap_cost
# pen_left = dp_pen[i][j - 1] + (1 if i < n else 0)
# cand_left = (cost_left, pen_left)
# # choose (cost,pen) min; deterministic tie-break: diag > left > up
# best = min(cand_diag, cand_left, cand_up)
# dp_cost[i][j], dp_pen[i][j] = best
# if best == cand_diag:
# ptr[i][j] = "diag"
# elif best == cand_left:
# ptr[i][j] = "left"
# else:
# ptr[i][j] = "up"
# # traceback
# z0, z1 = [], []
# i, j = n, m
# while i > 0 or j > 0:
# p = ptr[i][j]
# if p == "diag":
# z0.append(x0[i - 1])
# z1.append(x1[j - 1])
# i -= 1
# j -= 1
# elif p == "up":
# z0.append(x0[i - 1])
# z1.append(BLANK)
# i -= 1
# else: # 'left'
# z0.append(BLANK)
# z1.append(x1[j - 1])
# j -= 1
# z0.reverse()
# z1.reverse()
# return dict(z0=z0, z1=z1)
def strip_blanks(z: list[int]) -> list[int]:
# IMPORTANT: do NOT strip BOS; we only remove BLANKs
return [t for t in z if t != BLANK]
@dataclass
class Edit:
kind: str # "SUB" | "DEL" | "INS"
pos: int # position (for SUB/DEL) or token-row idx for INS (incl. BOS row 0)
token: int | None # token for SUB/INS, else None
def build_remaining_edits(zt: list[int], z1: list[int]) -> list[Edit]:
edits: list[Edit] = []
def count_nonblank_prefix(z: list[int], j: int) -> int:
c = 0
for k in range(j):
if z[k] != BLANK:
c += 1
return c
for j, (a, b) in enumerate(zip(zt, z1)):
if a == b:
continue
nb = count_nonblank_prefix(
zt, j
) # counts BOS as 1, first content token will be nb=1 before its column
if a == BLANK and b != BLANK:
# INSERT after row (nb-1): BOS insert => nb=1 -> gap=0; general case works too
gap = max(nb - 1, 0)
edits.append(Edit("INS", gap, b))
elif a != BLANK and b == BLANK:
# DELETE token at row nb (first content token => nb=1, allowed; BOS is never BLANK so nb>=1)
pos = nb
# if pos > 0: # forbid BOS (row 0)
edits.append(Edit("DEL", pos, None))
else: # a != BLANK, b != BLANK, a != b
# SUB token at row nb
pos = nb
# if pos > 0: # forbid BOS (row 0)
edits.append(Edit("SUB", pos, b))
return edits
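# Worked example (added sketch; values follow the alignment example above): if z_t
# still matches the source alignment while the target is z1, the remaining edits are
# one substitution and one insertion.
def _remaining_edits_example() -> list[Edit]:
    zt = [7, 3, 9, BLANK]  # current aligned state (7 stands in for BOS here)
    z1 = [7, 5, 9, 2]      # aligned target
    # Expected: [Edit("SUB", pos=1, token=5), Edit("INS", pos=2, token=2)]
    return build_remaining_edits(zt, z1)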
class EditFlowTrainer(transformers.Trainer):
"""
Trainer for Edit Flows where the model returns:
- sub_logits: [B,L,V] (token dist for SUB)
- ins_logits: [B,L,V] (token dist for INS)
- sub_rate_hat: [B,L] (normalized rates; NO kappa factor)
- del_rate_hat: [B,L]
- ins_rate_hat: [B,L]
True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)).
"""
def __init__(
self,
*args,
scheduler: BaseKappaScheduler | None = None,
normalize_per_position: bool = True,
time_epsilon: float = 1e-3,
max_w: float | None = None,
**kwargs,
):
self.scheduler = scheduler or CubicKappaScheduler()
self.normalize_per_position = normalize_per_position
self.time_epsilon = time_epsilon
self.max_w = max_w
super().__init__(*args, **kwargs)
def compute_loss(
self,
model: transformers.PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs: bool = False,
**kwargs,
):
device = self.model.device
B = len(inputs["x0_ids"])
# -------- 1) Align with blanks (z0,z1) and sample time t --------
aligns = [
align_with_blanks(x0, x1)
for x0, x1 in zip(inputs["x0_ids"], inputs["x1_ids"])
]
z0_list = [a["z0"] for a in aligns]
z1_list = [a["z1"] for a in aligns]
assert all(len(z0) == len(z1) for z0, z1 in zip(z0_list, z1_list))
assert all(z0[0] != BLANK for z0 in z0_list) # BOS must remain
t = (1 - self.time_epsilon) * torch.rand(B, 1, device=device) # [B,1]
k = self.scheduler.kappa(t).to(device) # [B,1]
w = self.scheduler.weight(t).squeeze(1).to(device) # [B]
if self.max_w:
w = w.clamp(max=self.max_w)
# -------- 2) Sample z_t by κ-mixing (vectorized per example) --------
# Keep python lists -> tensors per-example to reuse build_remaining_edits
zt_list: list[list[int]] = []
for z0, z1, kb in zip(z0_list, z1_list, k.squeeze(1).tolist()):
# per-column Bernoulli(κ) mix; BOS is equal in z0/z1 so it stays BOS
choose_target = torch.rand(len(z0)) < kb
zt = [b if choose_target[j] else a for j, (a, b) in enumerate(zip(z0, z1))]
zt_list.append(zt)
# -------- 3) Strip blanks to x_t and compute remaining edits --------
xt_list = [strip_blanks(zt) for zt in zt_list]
edits_list: list[list[Edit]] = [
build_remaining_edits(zt, z1) for zt, z1 in zip(zt_list, z1_list)
]
# -------- 4) Collate x_t for the model --------
x_tok, x_mask = pad_1d(
xt_list, pad_val=self.processing_class.pad_token_id
) # [B,Lmax], [B,Lmax]
x_tok = x_tok.to(device)
x_mask = x_mask.to(device)
# -------- 5) Forward pass --------
out = model(input_ids=x_tok, attention_mask=x_mask, t=t.to(device))
# Rename for clarity: model returns normalized rates (no kappa)
sub_rate_hat = out["sub_rate_hat"] # [B,L]
del_rate_hat = out["del_rate_hat"] # [B,L]
ins_rate_hat = out["ins_rate_hat"] # [B,L]
logQ_sub = F.log_softmax(out["sub_logits"], dim=-1) # [B,L,V]
logQ_ins = F.log_softmax(out["ins_logits"], dim=-1) # [B,L,V]
# *** NEW: zero-cost anchor to "touch" every head even if unused this step ***
# Using .sum() * 0.0 keeps a graph dependency without changing the loss value.
# Include both logits (for SUB/INS heads) and rates (for SUB/DEL/INS heads).
# This is important for Deepspeed ZeRO stage 2/3 to avoid skipping unused parameters.
anchor = (
sub_rate_hat.sum() * 0.0
+ del_rate_hat.sum() * 0.0
+ ins_rate_hat.sum() * 0.0
+ logQ_sub.sum() * 0.0
+ logQ_ins.sum() * 0.0
)
# Utility
def safe_log(x: torch.Tensor) -> torch.Tensor:
return torch.log(x.clamp_min(1e-12))
# -------- 6) Survival term --------
# Survival = E[sum of true intensities over valid rows]
# true intensity = w[b] * rate_hat[b, i]
mask_f = x_mask.float()
# L = mask_f.sum(dim=1).clamp_min(1.0) # [B] number of positions (incl. BOS)
L1 = torch.tensor(
[len(x1) for x1 in inputs["x1_ids"]], device=device, dtype=torch.float
).clamp_min(1.0)
denom = L1 if self.normalize_per_position else torch.ones_like(L1)
Lambda_hat = ((sub_rate_hat + del_rate_hat + ins_rate_hat) * mask_f).sum(
dim=1
) # [B]
loss_surv = ((w * Lambda_hat) / denom).mean()
# -------- 7) Positive edit terms --------
# For each remaining edit e: -log true rate(e) - log token prob(e) if tokenized
# loss_pos_per = sub_rate_hat.new_zeros(B) # [B]
# for b, edits in enumerate(edits_list):
# if not edits:
# continue
# cur_len = int(x_mask[b].sum().item())
# for e in edits:
# pos = e.pos
# assert 0 <= pos < cur_len, f"pos {pos} out of range {cur_len}"
# if e.kind == "SUB":
# loss_pos_per[b] -= logQ_sub[b, pos, e.token] + safe_log(
# sub_rate_hat[b, pos]
# )
# elif e.kind == "DEL":
# loss_pos_per[b] -= safe_log(del_rate_hat[b, pos])
# else: # "INS"
# loss_pos_per[b] -= logQ_ins[b, pos, e.token] + safe_log(
# ins_rate_hat[b, pos]
# )
# -------- 7) Positive edit terms (vectorized) --------
pos_sub, tok_sub, pos_ins, tok_ins, pos_del = [], [], [], [], []
for b, edits in enumerate(edits_list):
cur_len = int(x_mask[b].sum().item())
ps, ts, pi, ti, pd = [], [], [], [], []
for e in edits:
if not (0 <= e.pos < cur_len):
raise AssertionError(
f"pos {e.pos} out of range {cur_len} for b={b}"
)
if e.kind == "SUB":
ps.append(e.pos)
ts.append(e.token)
elif e.kind == "INS":
pi.append(e.pos)
ti.append(e.token)
else:
pd.append(e.pos)
pos_sub.append(
torch.tensor(ps, device=x_tok.device, dtype=torch.long) if ps else None
)
tok_sub.append(
torch.tensor(ts, device=x_tok.device, dtype=torch.long) if ts else None
)
pos_ins.append(
torch.tensor(pi, device=x_tok.device, dtype=torch.long) if pi else None
)
tok_ins.append(
torch.tensor(ti, device=x_tok.device, dtype=torch.long) if ti else None
)
pos_del.append(
torch.tensor(pd, device=x_tok.device, dtype=torch.long) if pd else None
)
loss_pos_terms = []
for b in range(B):
lp = x_tok.new_zeros(())
if pos_sub[b] is not None:
lp = (
lp
- (
logQ_sub[b, pos_sub[b], tok_sub[b]]
+ safe_log(sub_rate_hat[b, pos_sub[b]])
).sum()
)
if pos_ins[b] is not None:
lp = (
lp
- (
logQ_ins[b, pos_ins[b], tok_ins[b]]
+ safe_log(ins_rate_hat[b, pos_ins[b]])
).sum()
)
if pos_del[b] is not None:
lp = lp - safe_log(del_rate_hat[b, pos_del[b]]).sum()
loss_pos_terms.append(lp)
loss_pos_per = torch.stack(loss_pos_terms) # [B]
# Average positive term per sequence (MC estimator across batch)
loss_pos = ((w * loss_pos_per) / denom).mean()
# -------- 8) Total --------
loss = loss_surv + loss_pos + anchor
return (loss, out) if return_outputs else loss
if __name__ == "__main__":
pass

View File

@ -0,0 +1,218 @@
import math
import random
from dataclasses import dataclass
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Text
from collections.abc import Callable
import torch
import transformers
from dllm.utils.utils import parse_spec
# ------------------------------- Collator (x0 source) --------------------------------
@dataclass
class X0Sampler:
def __call__(self, *args, **kwargs) -> list[int]:
raise NotImplementedError("Subclasses must implement __call__.")
@dataclass
class SampleX0Empty(X0Sampler):
"""Return BOS-only (i.e., empty tail)."""
def __call__(self, *args, **kwargs) -> list[int]:
return []
@dataclass
class SampleX0Masks(X0Sampler):
"""Return a run of mask tokens of given length."""
length: int = 128
tokenizer: transformers.PreTrainedTokenizer = None
def __call__(self, *args, **kwargs) -> list[int]:
mask_id = getattr(self.tokenizer, "mask_token_id", None)
if mask_id is None:
raise ValueError("tokenizer needs mask_token_id for mask-based sampler")
return [int(mask_id)] * self.length
# ---------------- Factory ---------------- #
_X0_SAMPLER_CLASSES: dict[str, type[X0Sampler]] = {
"empty": SampleX0Empty,
"masks": SampleX0Masks,
}
def make_x0_sampler(name: str, tokenizer: Any, **kwargs) -> X0Sampler:
try:
name, kvs = parse_spec(name)
cls = _X0_SAMPLER_CLASSES[name.lower()]
except KeyError:
raise ValueError(
f"Unknown x0 sampler '{name}'. Available: {list(_X0_SAMPLER_CLASSES)}"
)
# merged_kwargs = {**kvs, **kwargs}
return cls(tokenizer=tokenizer, **kvs, **kwargs)
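# Illustrative usage (added sketch; `tok` is an assumed tokenizer exposing
# `mask_token_id`): the factory resolves a spec name to a sampler class.
def _x0_sampler_example(tok: transformers.PreTrainedTokenizer) -> list[int]:
    sampler = make_x0_sampler("masks", tok, length=32)  # same as SampleX0Masks(length=32, tokenizer=tok)
    return sampler()  # 32 copies of tok.mask_token_id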
@dataclass
class EditFlowCollator:
tokenizer: transformers.PreTrainedTokenizer = None
x0_sampler: Callable | str | None = X0Sampler # can be func OR name
def __post_init__(self):
if isinstance(self.x0_sampler, str):
self.x0_sampler = make_x0_sampler(self.x0_sampler, self.tokenizer)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, list[Any]]:
if not features:
return {}
keys = features[0].keys()
batch = {k: [ex[k] for ex in features] for k in keys}
batch["x1_ids"] = batch["input_ids"]
if "prompt_len" not in batch:
assert self.tokenizer.bos_token_id is not None
bos = self.tokenizer.bos_token_id
batch["x1_ids"] = [
x if x and x[0] == bos else [bos] + x for x in batch["x1_ids"]
]
batch["x0_ids"] = [
x1_ids[:1] + self.x0_sampler(x1_ids=x1_ids[1:])
for x1_ids in batch["x1_ids"]
]
else:
batch["x0_ids"] = [
x1_ids[:prompt_len] + self.x0_sampler(x1_ids=x1_ids[prompt_len:])
for x1_ids, prompt_len in zip(batch["x1_ids"], batch["prompt_len"])
]
batch["return_loss"] = True
return batch
# ------------------------------- Trainer utils --------------------------------
def pad_1d(
batch_lists: list[list[int]], pad_val: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pads a list of variable-length integer lists into:
- out: tensor of shape [B, Lmax] with padding value `pad_val`
- mask: tensor of shape [B, Lmax] with 1 for real tokens and 0 for padding (int mask)
"""
B = len(batch_lists)
Lmax = max((len(x) for x in batch_lists), default=0)
out = torch.full((B, Lmax), pad_val, dtype=torch.long)
mask = torch.zeros((B, Lmax), dtype=torch.long) # 0/1 mask (int)
for b, x in enumerate(batch_lists):
if not x:
continue
L = len(x)
out[b, :L] = torch.tensor(x, dtype=torch.long)
mask[b, :L] = 1 # mark valid positions with 1
return out, mask
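# Minimal sketch (added for illustration): padding two variable-length id lists.
def _pad_1d_example() -> tuple[torch.Tensor, torch.Tensor]:
    out, mask = pad_1d([[1, 2, 3], [4]], pad_val=0)
    # out  == [[1, 2, 3], [4, 0, 0]]
    # mask == [[1, 1, 1], [1, 0, 0]]
    return out, mask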
def init_editflow_from_src(
ef_model, src_model, lm_head_key: str = "lm_head", verbose: bool = True
):
"""
Initialize an EditFlowModel (ef_model) from a pretrained source model.
If DeepSpeed ZeRO-3 is enabled (detected via HF's `is_deepspeed_zero3_enabled()`),
this function temporarily gathers full parameters for both models on rank 0,
performs the copy there, and then returns to sharded mode automatically.
Otherwise it behaves like a normal CPU/GPU single-process copy.
Returns (missing_keys, unexpected_keys) from load_state_dict(strict=False).
"""
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled
dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
rank = torch.distributed.get_rank() if dist_ok else 0
def _copy_once():
src_sd = src_model.state_dict()
tgt_sd = ef_model.state_dict()
new_sd = OrderedDict()
# 1) copy matching backbone tensors
for k, v in src_sd.items():
if k in tgt_sd and tgt_sd[k].shape == v.shape:
new_sd[k] = v
# 2) duplicate lm_head -> sub_logits & ins_logits (weight + optional bias)
lm_w = f"{lm_head_key}.weight"
lm_b = f"{lm_head_key}.bias"
if lm_w in src_sd:
if "sub_logits.weight" in tgt_sd:
new_sd["sub_logits.weight"] = src_sd[lm_w]
if "ins_logits.weight" in tgt_sd:
new_sd["ins_logits.weight"] = src_sd[lm_w]
if lm_b in src_sd:
if "sub_logits.bias" in tgt_sd:
new_sd["sub_logits.bias"] = src_sd[lm_b]
if "ins_logits.bias" in tgt_sd:
new_sd["ins_logits.bias"] = src_sd[lm_b]
# 3) non-strict load so new rate heads remain randomly initialized
missing, unexpected = ef_model.load_state_dict(new_sd, strict=False)
return new_sd, missing, unexpected
if is_deepspeed_zero3_enabled():
# All ranks enter/exit together; only rank 0 materializes full tensors.
params = list(ef_model.parameters()) + list(src_model.parameters())
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
if rank == 0:
new_sd, missing, unexpected = _copy_once()
else:
new_sd, missing, unexpected = OrderedDict(), [], []
if dist_ok:
torch.distributed.barrier()
if verbose and rank == 0:
_p = getattr(globals().get("dllm", None), "utils", None)
printer = getattr(_p, "print_main", print) if _p else print
printer(
f"[EditFlow init][ZeRO-3] Copied {len(new_sd)} tensors from Src Model."
)
if missing:
printer(" Missing (expected for new rate heads, etc.):")
for k in missing:
printer(" -", k)
if unexpected:
printer(" Unexpected (check key names):")
for k in unexpected:
printer(" -", k)
return missing, unexpected
# --- Non-ZeRO (or DS not present) path ---
new_sd, missing, unexpected = _copy_once()
if verbose:
_p = getattr(globals().get("dllm", None), "utils", None)
printer = getattr(_p, "print_main", print) if _p else print
printer(f"[EditFlow init] Copied {len(new_sd)} tensors from Src Model.")
if missing:
printer(" Missing (expected for new rate heads, etc.):")
for k in missing:
printer(" -", k)
if unexpected:
printer(" Unexpected (check key names):")
for k in unexpected:
printer(" -", k)
return missing, unexpected
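# Illustrative usage (added sketch; the hub id and direct `from_pretrained` calls are
# assumptions that mirror the __main__ blocks in the EditFlow model files above):
# copy the backbone and lm_head into a fresh EditFlow model, leaving the new rate
# heads randomly initialized.
def _init_editflow_example():
    from dllm.pipelines import dream
    from dllm.pipelines.editflow import EditFlowDreamConfig, EditFlowDreamModel

    src = dream.DreamModel.from_pretrained("Dream-org/Dream-v0-Base-7B")
    ef_config = EditFlowDreamConfig.from_pretrained("Dream-org/Dream-v0-Base-7B")
    ef_model = EditFlowDreamModel(ef_config)
    missing, unexpected = init_editflow_from_src(ef_model, src, lm_head_key="lm_head")
    return ef_model, missing, unexpected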
if __name__ == "__main__":
pass

View File

@ -0,0 +1,7 @@
from . import generator, trainer
from .models.modeling_llada import LLaDAModelLM
from .models.configuration_llada import LLaDAConfig
from .models.modeling_lladamoe import LLaDAMoEModelLM
from .models.configuration_lladamoe import LLaDAMoEConfig
from .generator import LLaDAGeneratorConfig, LLaDAGenerator
from .trainer import LLaDATrainer

View File

@ -0,0 +1,357 @@
"""
accelerate launch \
--num_processes 2 \
dllm/pipelines/llada/eval.py \
--tasks gsm8k \
--model llada \
--num_fewshot 8 \
--model_args "pretrained=GSAI-ML/LLaDA-8B-Base,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
"""
from types import SimpleNamespace
from dataclasses import dataclass
import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
import dllm
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
@dataclass
class LLaDAEvalConfig(LLaDAGeneratorConfig):
max_new_tokens: int = 1024
max_length: int = 4096
steps: int = 1024
block_length: int = 1024
pretrained: str = ""
dtype: str | torch.dtype = "auto"
batch_size: int = 32
mc_num: int = 128
is_check_greedy: bool = True
device: str = "cuda"
@register_model("llada")
class LLaDAEvalHarness(LM):
def __init__(
self,
config: LLaDAEvalConfig | None = None,
**kwargs,
):
super().__init__()
if config is None:
config = LLaDAEvalConfig()
# Pull args from config, allow kwargs to override
pretrained = kwargs.get("pretrained", config.pretrained)
dtype = kwargs.get("dtype", config.dtype)
batch_size = kwargs.get("batch_size", config.batch_size)
mc_num = kwargs.get("mc_num", config.mc_num)
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
device = kwargs.get("device", config.device)
cfg = kwargs.get("cfg", config.cfg_scale)
steps = kwargs.get("steps", config.steps)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
block_length = kwargs.get("block_length", config.block_length)
max_length = kwargs.get("max_length", config.max_length)
remasking = kwargs.get("remasking", config.remasking)
accelerator = accelerate.Accelerator()
# Get GLOBAL rank from torch.distributed (not accelerator)
if torch.distributed.is_initialized():
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
self._world_size = (
torch.distributed.get_world_size()
) # ← GLOBAL world size (16)
else:
self._rank = 0
self._world_size = 1
# Use accelerator for device placement
self.model = dllm.utils.get_model(
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
)
self.model.eval()
if accelerator.num_processes > 1:
# Let accelerator handle device placement
self.model = accelerator.prepare(self.model)
self.device = (
accelerator.device
) # ← Accelerator figures out local device correctly
self.accelerator = accelerator
else:
# Single GPU
self.model = self.model.to(device)
self.device = torch.device(device)
self.accelerator = None
self.tokenizer = dllm.utils.get_tokenizer(
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
)
# generation params
self.mask_id = self.tokenizer.mask_token_id
self.batch_size = int(batch_size)
self.max_length = max_length
self.max_new_tokens = int(max_new_tokens)
self.block_length = int(block_length)
self.steps = int(steps)
self.cfg = float(cfg)
self.remasking = remasking
self.is_check_greedy = is_check_greedy
# loglikelihood params
self.mc_num = int(mc_num)
assert mc_num % self.batch_size == 0
self.sampling_eps = 0.0
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Apply a chat template to a chat history (list of messages) between the user and the model.
"""
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def _forward_process(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
b, l = batch.shape
target_len = (l - prompt_index.sum()).item()
k = torch.randint(1, target_len + 1, (), device=batch.device)
x = torch.round(
torch.linspace(
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
)
).long()
x = ((x - 1) % target_len) + 1
assert x.min() >= 1 and x.max() <= target_len
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
is_mask = indices < x.unsqueeze(1)
for i in range(b):
is_mask[i] = is_mask[i][torch.randperm(target_len)]
is_mask = torch.cat(
(
torch.zeros(
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
),
is_mask,
),
dim=1,
)
noisy_batch = torch.where(is_mask, self.mask_id, batch)
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
@torch.no_grad()
def get_logits(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
if self.cfg > 0.0:
assert len(prompt_index) == batch.shape[1]
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
un_batch = batch.clone()
un_batch[prompt_index] = self.mask_id
batch = torch.cat([batch, un_batch])
logits = self.model(batch).logits
if self.cfg > 0.0:
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
return logits[:, : batch.shape[1]]
@torch.no_grad()
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
seq = torch.concatenate([prefix, target])[None, :]
seq = seq.repeat((self.batch_size, 1)).to(self.device)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
loss_acc = []
for _ in range(self.mc_num // self.batch_size):
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
mask_indices = perturbed_seq == self.mask_id
logits = self.get_logits(perturbed_seq, prompt_index)
loss = (
F.cross_entropy(
logits[mask_indices], seq[mask_indices], reduction="none"
)
/ p_mask[mask_indices]
)
loss = loss.sum() / self.batch_size
loss_acc.append(loss.item())
return -sum(loss_acc) / len(loss_acc)
@torch.no_grad()
def suffix_greedy_prediction(
self, prefix: torch.Tensor, target: torch.Tensor
) -> bool:
if not self.is_check_greedy:
return False
seq = torch.full(
(1, len(prefix) + len(target)), self.mask_id, device=self.device
)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
prefix, target = prefix.to(self.device), target.to(self.device)
seq[0, : len(prefix)] = prefix
for i in range(len(target)):
mask_index = seq == self.mask_id
logits = self.get_logits(seq, prompt_index)[mask_index]
x0 = torch.argmax(logits, dim=-1)
p = torch.softmax(logits.to(torch.float32), dim=-1)
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
dim=-1
)
_, index = torch.sort(confidence, descending=True)
x0[index[1:]] = self.mask_id
seq[mask_index] = x0.clone()
correct = target == seq[0, len(prefix) :]
correct = torch.all(correct)
return correct
def _encode_pair(
self, context: str, continuation: str
) -> tuple[torch.Tensor, torch.Tensor]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tokenizer(context + continuation)["input_ids"]
context_enc = self.tokenizer(context)["input_ids"]
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
assert max(prompt_len) <= 4096
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]
ll = self.get_loglikelihood(prefix, target)
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
torch.cuda.empty_cache()
return out
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
raise NotImplementedError
def generate_until(self, requests: list[Instance]) -> list[str]:
def _tokenize(e):
return {
"question": self.tokenizer(e["question"])["input_ids"],
"question_text": e["question"],
"until": e["until"],
}
ds = [
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
out = []
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
for elem in tqdm(ds, desc="Generating..."):
prompt = [elem["question"].to(self.device)]
stop_tokens = elem["until"]
generated_ids = generator.generate(
inputs=prompt,
steps=self.steps,
max_new_tokens=self.max_new_tokens,
block_length=self.block_length,
temperature=0.0,
cfg_scale=self.cfg,
remasking=self.remasking,
)
generated_answer = self.tokenizer.decode(
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
)
for stop_seq in stop_tokens:
if stop_seq in generated_answer:
generated_answer = generated_answer.split(stop_seq)[0]
# remove special tokens
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
generated_answer = self.tokenizer.decode(
generated_answer_ids, skip_special_tokens=True
)
out.append(generated_answer)
if self.accelerator is not None:
self.accelerator.wait_for_everyone()
return out
if __name__ == "__main__":
cli_evaluate()

View File

@ -0,0 +1,379 @@
"""
reference: https://github.com/ML-GSAI/LLaDA/blob/main/generate.py
"""
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F
from dllm.utils.generation_utils import get_num_transfer_tokens
from dllm.core.generation.generator import (
GeneratorOutput,
GeneratorConfig,
BaseGenerator,
)
def add_gumbel_noise(logits: torch.Tensor, temperature: float) -> torch.Tensor:
"""
The Gumbel max is a method for sampling categorical distributions.
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
Thus, we use float64.
"""
if temperature == 0:
return logits
logits = logits.to(torch.float64)
noise = torch.rand_like(logits, dtype=torch.float64)
gumbel_noise = (-torch.log(noise)) ** temperature
return logits.exp() / gumbel_noise
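# Minimal sketch (added for illustration): taking the argmax of the noised scores
# samples from softmax(logits / temperature) via the Gumbel-max trick, while
# temperature == 0 falls back to plain greedy argmax.
def _gumbel_argmax_example() -> torch.Tensor:
    logits = torch.randn(2, 8)  # dummy (batch, vocab) scores
    return torch.argmax(add_gumbel_noise(logits, temperature=1.0), dim=-1)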
@dataclass
class LLaDAGeneratorConfig(GeneratorConfig):
max_new_tokens: int = 128
max_length: int | None = None # no explicit length limit beyond the tokenizer/model context
block_length: int = 128
steps: int = 128
temperature: float = 0.0
remasking: str = "low_confidence"
stochastic_transfer: bool = False
cfg_scale: float = 0.0
cfg_keep_tokens: list[int] | None = None
@dataclass
class LLaDAGenerator(BaseGenerator):
@torch.no_grad()
def generate(
self,
inputs: list[torch.Tensor | list],
config: LLaDAGeneratorConfig | None = None,
**kwargs,
) -> GeneratorOutput | torch.Tensor:
if config is None:
config = LLaDAGeneratorConfig()
# ----- pull args from config, allow kwargs to override -----
steps = kwargs.get("steps", config.steps)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
max_length = kwargs.get("max_length", config.max_length)
block_length = kwargs.get("block_length", config.block_length)
temperature = kwargs.get("temperature", config.temperature)
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
remasking = kwargs.get("remasking", config.remasking)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
assert 1 <= block_length
assert 1 <= steps
mask_id = self.tokenizer.mask_token_id
eos_id = self.tokenizer.eos_token_id
# ----- Shape bookkeeping: per-sample prompt lengths and final canvas width -----
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
prompt_lens = [p.shape[0] for p in inputs]
if max_new_tokens:
max_length = max_new_tokens + max(prompt_lens)
else:
max_new_tokens = max_length - max(prompt_lens)
B = len(inputs)
T = max_length
# ----- Initialize canvas with EOS, copy inputs, and append mask tail -----
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
for i, p in enumerate(inputs):
x[i, : prompt_lens[i]] = p # keep original prompt tokens
x[i, prompt_lens[i] : prompt_lens[i] + max_new_tokens] = (
mask_id # append `max_new_tokens` masks to be generated
)
attention_mask = (x != eos_id).long() if B > 1 else None
# Tokens that were *given* at the start (non-mask, non-EOS).
# These will be masked in the unconditional forward pass for CFG.
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
unmasked_index = (x != mask_id) & (x != eos_id)
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
keep_mask = torch.isin(
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
)
unmasked_index = unmasked_index & ~keep_mask
# ----- Block scheduling over the appended mask tail -----
num_blocks = math.ceil(max_new_tokens / block_length)
steps = math.ceil(steps / num_blocks) # per-block step budget
histories = [x.clone()] if return_dict_in_generate else None
for b in range(num_blocks):
# Build a per-sample mask *within this block* (aligned to each prompt's tail)
block_mask_index = torch.zeros(
(B, block_length), dtype=torch.bool, device=x.device
)
for j in range(B):
start = prompt_lens[j] + b * block_length
end = min(start + block_length, prompt_lens[j] + max_new_tokens, T)
if start < end:
width = end - start
block_mask_index[j, :width] = (
x[j, start:end] == mask_id
) # which positions in this block are still masked
# Decide how many tokens to reveal per step in this block
num_transfer_tokens = get_num_transfer_tokens(
mask_index=block_mask_index,
steps=steps,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
# Some steps may be skipped if there are no transfers
effective_steps = num_transfer_tokens.size(1)
# ----- Iterative reveal inside the current block -----
for i in range(effective_steps):
mask_index = x == mask_id # current global mask map
# Optional CFG: second forward where original prompt tokens are masked out
if cfg_scale > 0.0:
un_x = x.clone()
un_x[unmasked_index] = mask_id
x_ = torch.cat([x, un_x], dim=0)
logits = self.model(
x_, attention_mask=attention_mask
).logits # Use attention mask here
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
else:
logits = self.model(
x, attention_mask=attention_mask
).logits # Use attention mask here
# Argmax decoding with optional Gumbel-Max noise for exploration
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(
logits_with_noise, dim=-1
) # [B, T] predicted token ids
# Per-position confidence used to pick which masks to commit this step
if remasking == "low_confidence":
p = F.softmax(logits, dim=-1)
x0_p = torch.squeeze(
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
) # [B, T] confidence of predicted token
elif remasking == "random":
x0_p = torch.rand(
(x0.shape[0], x0.shape[1]), device=x0.device
) # random scores
else:
raise NotImplementedError(remasking)
# Restrict selection window to the *current block's* tail region
for j in range(B):
x0_p[j, prompt_lens[j] + (b + 1) * block_length :] = -np.inf
# Only allow updates at currently masked positions; keep others fixed
x0 = torch.where(mask_index, x0, x)
confidence = torch.where(
mask_index, x0_p, -np.inf
) # consider masked positions only
# Pick exactly `num_transfer_tokens[j, i]` highest-confidence positions per sample
transfer_index = torch.zeros_like(
x0, dtype=torch.bool, device=x0.device
)
for j in range(confidence.shape[0]):
_, select_index = torch.topk(
confidence[j], k=num_transfer_tokens[j, i]
)
transfer_index[j, select_index] = True
# Commit chosen predictions into the canvas
x[transfer_index] = x0[transfer_index]
if histories is not None:
histories.append(x.clone())
# ----- Output format -----
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
@torch.no_grad()
def infill(
self, inputs: list[torch.Tensor | list], config, **kwargs
) -> GeneratorOutput | torch.Tensor:
"""
Fill in-place the <|mdm_mask|> tokens contained in `inputs`.
The whole (padded) sequence is split into block windows of length
`block_length`; within each window we progressively "unmask" positions
according to the scheduler and chosen remasking strategy.
Notes:
- Right padding uses EOS.
- CFG masks out *originally known* (non-mask, non-EOS) tokens in the
unconditional branch, identical to `generate`.
- Only masked positions are ever updated; non-mask tokens are left intact.
"""
# ----- pull args from config, allow kwargs to override -----
steps = kwargs.get("steps", config.steps)
block_length = kwargs.get("block_length", config.block_length)
temperature = kwargs.get("temperature", config.temperature)
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
remasking = kwargs.get("remasking", config.remasking)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
mask_id = self.tokenizer.mask_token_id
eos_id = self.tokenizer.eos_token_id
# ----- Build canvas: right-pad with EOS to the max length in the batch -----
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
B = len(inputs)
seq_lens = [t.shape[0] for t in inputs]
T = max(seq_lens)
# Default to a single block spanning the whole sequence
if block_length is None:
block_length = T
assert 1 <= block_length
assert 1 <= steps
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
for i, t in enumerate(inputs):
x[i, : seq_lens[i]] = t
attention_mask = (x != eos_id).long() if B > 1 else None
# Tokens that were *given* at the start (non-mask, non-EOS).
# These will be masked in the unconditional forward pass for CFG.
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
unmasked_index = (x != mask_id) & (x != eos_id)
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
keep_mask = torch.isin(
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
)
unmasked_index = unmasked_index & ~keep_mask
# ----- Blockwise schedule over the *entire* (padded) sequence -----
num_blocks = math.ceil(T / block_length)
steps_per_block = math.ceil(steps / num_blocks)
histories = [x.clone()] if return_dict_in_generate else None
# Create attention mask where eos_token_id is masked (set to 0)
attention_mask = (x != eos_id).long()
for b in range(num_blocks):
start = b * block_length
stop = min(start + block_length, T)
# Per-sample view of which positions in this block are masks
block_mask_index = torch.zeros(
(B, block_length), dtype=torch.bool, device=self.model.device
)
widths = []
for j in range(B):
# Width limited by sample's true length and sequence end
width = max(0, min(seq_lens[j], stop) - start)
widths.append(width)
if width > 0:
block_mask_index[j, :width] = x[j, start : start + width] == mask_id
# Decide how many tokens to reveal at each step in this block
num_transfer_tokens = get_num_transfer_tokens(
mask_index=block_mask_index,
steps=steps_per_block,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
# Some blocks may have no masks => effective_steps == 0
effective_steps = num_transfer_tokens.size(1)
for s in range(effective_steps):
mask_index_full = x == mask_id
# ----- Forward pass (+ optional CFG) -----
if cfg_scale > 0.0:
un_x = x.clone()
un_x[unmasked_index] = mask_id
x_ = torch.cat([x, un_x], dim=0)
logits = self.model(
x_, attention_mask=attention_mask
).logits # Use attention mask here
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
else:
logits = self.model(
x, attention_mask=attention_mask
).logits # Use attention mask here
# Greedy with optional Gumbel-Max noise
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(logits_with_noise, dim=-1) # [B, T]
# Confidence used for choosing which masks to commit this step
if remasking == "low_confidence":
p = F.softmax(logits, dim=-1)
x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(
-1
) # [B, T]
elif remasking == "random":
x0_p = torch.rand((B, T), device=self.model.device)
else:
raise NotImplementedError(remasking)
# Restrict selection to the *current* block only
for j in range(B):
end_j = start + widths[j]
# Outside current block => impossible to select
x0_p[j, :start] = -np.inf
x0_p[j, end_j:] = -np.inf
# Only consider currently-masked positions as candidates
x0 = torch.where(mask_index_full, x0, x)
confidence = torch.where(mask_index_full, x0_p, -np.inf)
# Pick exactly num_transfer_tokens[j, s] positions per sample
transfer_index = torch.zeros_like(x, dtype=torch.bool)
for j in range(B):
k = int(num_transfer_tokens[j, s].item())
if k > 0:
_, select_idx = torch.topk(confidence[j], k=k)
transfer_index[j, select_idx] = True
# Commit selected predictions into the canvas
x[transfer_index] = x0[transfer_index]
if histories is not None:
histories.append(x.clone())
# ----- Output format -----
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
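# Illustrative usage (added sketch; the generator/prompt construction is assumed, e.g.
# built as in the eval script with dllm.utils.get_model / get_tokenizer):
def _llada_generate_example(
    generator: "LLaDAGenerator", prompt_ids: torch.Tensor
) -> torch.Tensor:
    cfg = LLaDAGeneratorConfig(
        max_new_tokens=64, block_length=32, steps=64, temperature=0.0
    )
    # `prompt_ids` is a 1-D tensor of token ids, ideally already on the model device.
    return generator.generate(inputs=[prompt_ids], config=cfg)  # [1, len(prompt) + 64]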

View File

@ -0,0 +1,19 @@
from .configuration_llada import LLaDAConfig
from .modeling_llada import LLaDAModelLM
from .configuration_lladamoe import LLaDAMoEConfig
from .modeling_lladamoe import LLaDAMoEModelLM
# Register with HuggingFace Auto classes for local usage
try:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
AutoConfig.register("llada", LLaDAConfig)
AutoModel.register(LLaDAConfig, LLaDAModelLM)
AutoModelForMaskedLM.register(LLaDAConfig, LLaDAModelLM)
AutoConfig.register("lladamoe", LLaDAMoEConfig)
AutoModel.register(LLaDAMoEConfig, LLaDAMoEModelLM)
AutoModelForMaskedLM.register(LLaDAMoEConfig, LLaDAMoEModelLM)
except ImportError:
# transformers not available or Auto classes not imported
pass

View File

@ -0,0 +1,459 @@
"""
LLaDA configuration
"""
from transformers import PretrainedConfig
from enum import Enum
from os import PathLike
from typing import Union
from dataclasses import asdict, dataclass, field
from glob import glob
from pathlib import Path
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
__all__ = [
"ActivationType",
"ActivationCheckpointingStrategy",
"BlockType",
"LayerNormType",
"InitFnType",
"ModelConfig",
]
PathOrStr = Union[str, PathLike]
class StrEnum(str, Enum):
"""
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
We include this here for compatibility with older versions of Python.
"""
def __str__(self) -> str:
return self.value
def __repr__(self) -> str:
return f"'{str(self)}'"
class LayerNormType(StrEnum):
default = "default"
"""
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
"""
low_precision = "low_precision"
"""
A low-precision version of the default LayerNorm.
"""
rms = "rms"
"""
An RMSNorm implementation. When using ``torch.compile`` this is
probably the fastest implementation.
"""
gemma_rms = "gemma_rms"
"""
An RMSNorm implementation from Gemma. When using ``torch.compile`` this is
probably the fastest implementation.
"""
amd_compatible = "amd_compatible"
"""
LayerNorm implemented manually to work around an issue with ROCm.
"""
class ActivationType(StrEnum):
gelu = "gelu"
relu = "relu"
silu = "silu"
swiglu = "swiglu"
class BlockType(StrEnum):
sequential = "sequential"
parallel = "parallel"
llama = "llama"
"""
A block similar to the sequential block with slightly different
implementations of operations like attention to imitate the behavior of Llama.
"""
class InitFnType(StrEnum):
mitchell = "mitchell"
"""
The strategy suggested to us by Mitchell Wortsman from UW.
This uses a truncated normal distribution with an adaptive standard deviation that depends
on the size of the weights as well as the depth of the layer.
"""
normal = "normal"
"""
All weights are initialized from the same normal distribution.
"""
kaiming_normal = "kaiming_normal"
"""
All weights are initialized with the Kaiming method from a normal distribution.
Note this currently won't work with FSDP.
"""
fan_in = "fan_in"
"""
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
is the input dimensionality of the kernel.
"""
full_megatron = "full_megatron"
"""
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
"""
@dataclass
class ModelConfig:
"""
LLaDA (model) configuration.
"""
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
d_model: int = 768
"""
The hidden size of the model.
"""
n_heads: int = 12
"""
The number of self-attention heads.
"""
n_kv_heads: Optional[int] = None
"""
The number of heads to use for keys and values. Defaults to `n_heads`.
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
Set this to 1 for multi-query attention.
Set it to some in-between value for Llama2-style grouped query attention.
"""
n_layers: int = 12
"""
The number of layers/blocks.
"""
mlp_ratio: int = 4
"""
The ratio of the inner MLP dimensionality to ``d_model``.
This is only used when ``mlp_hidden_size`` is not set.
"""
mlp_hidden_size: Optional[int] = None
"""
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
"""
activation_type: ActivationType = ActivationType.swiglu
"""
The activation function to use within the MLP layers.
"""
block_type: BlockType = BlockType.sequential
"""
The transformer block implementation.
"""
block_group_size: int = 1
"""
The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
of blocks together with a single FSDP wrapper during training.
"""
alibi: bool = False
"""
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
"""
alibi_bias_max: float = 8.0
"""
Maximum absolute value of ALiBi bias.
"""
rope: bool = False
"""
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
"""
rope_full_precision: bool = True
"""
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
apply RoPE at the precision of the input.
"""
flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
"""
attention_dropout: float = 0.1
"""
The dropout probability within the attention modules.
"""
multi_query_attention: Optional[bool] = None
"""
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
and is more efficient during inference.
"""
attention_layer_norm: bool = False
"""
Apply layer norm to the keys and queries within the attention mechanism.
This can help stabilize training.
"""
residual_dropout: float = 0.1
"""
The dropout probability for the MLP and attention output within each block.
"""
embedding_dropout: float = 0.1
"""
The dropout probability for embeddings.
"""
input_emb_norm: bool = False
"""
    An input hidden_states normalization, as implemented in Gemma.
"""
layer_norm_type: LayerNormType = LayerNormType.default
"""
The layernorm implementation to use.
"""
layer_norm_with_affine: bool = True
"""
Whether to include bias and weight parameters for the layer norms.
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
to ``False``.
"""
rms_norm_eps: float = 1e-05
"""
The rms layernorm eps param.
"""
attention_layer_norm_with_affine: bool = True
"""
Toggle affine transform for the QK norms.
"""
max_sequence_length: int = 1024
"""
The maximum input sequence length supported by the model.
"""
rope_theta: float = 10000.0
"""
The rope base param.
"""
include_qkv_bias: Optional[bool] = False
"""
Whether or not to include bias parameters in qkv linear layers.
"""
include_bias: bool = False
"""
Whether or not to include bias parameters in linear layers.
In PaLM, they got rid of all bias terms because they found that large
models tend to have near 0 bias terms anyway.
"""
bias_for_layer_norm: Optional[bool] = None
"""
Whether or not to include bias parameters in layer norm.
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
layer norm.
When this is None (the default), it inherits the setting from include_bias.
"""
scale_logits: bool = False
"""
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
"""
vocab_size: int = 50257
"""
Vocabulary size of the model.
"""
embedding_size: Optional[int] = 50304
"""
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
substantially.
"""
weight_tying: bool = True
"""
Whether to tie output linear weights to the input embedding.
"""
eos_token_id: int = 50256
"""
The ID of the end-of-sentence special token.
"""
pad_token_id: int = 50256
"""
The ID of the token to use for padding. Defaults to the ID of the EOS token.
"""
mask_token_id: Optional[int] = 50256
"""
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
"""
init_device: Optional[str] = None
"""
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
"""
init_fn: InitFnType = InitFnType.normal
"""
The weight initialization strategy.
"""
init_std: float = 0.02
"""
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
as "normal".
"""
init_cutoff_factor: Optional[float] = None
"""
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
as "normal". Setting this to None means values are not cutoff.
"""
precision: Optional[str] = None
"""
Precision used to train/evaluate with. You shouldn't set this directly.
See :data:`TrainConfig.precision` instead.
"""
@property
def effective_n_kv_heads(self) -> int:
if self.n_kv_heads is None:
if self.multi_query_attention is True:
return 1
else:
return self.n_heads
else:
if self.multi_query_attention is None:
return self.n_kv_heads
if self.multi_query_attention:
n_kv_heads_should_be = 1
else:
n_kv_heads_should_be = self.n_heads
if self.n_kv_heads == n_kv_heads_should_be:
return n_kv_heads_should_be
else:
raise Exception(
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
)
class ActivationCheckpointingStrategy(StrEnum):
whole_layer = "whole_layer"
"""
Checkpoint every transformer layer.
"""
one_in_two = "one_in_two"
"""
Checkpoint one in two transformer layers.
"""
one_in_three = "one_in_three"
"""
Checkpoint one in three transformer layers.
"""
one_in_four = "one_in_four"
"""
Checkpoint one in four transformer layers.
"""
two_in_three = "two_in_three"
"""
Checkpoint two out of every three transformer layers.
"""
three_in_four = "three_in_four"
"""
    Checkpoint three out of every four transformer layers.
"""
four_in_five = "four_in_five"
"""
    Checkpoint four out of every five transformer layers.
"""
nine_in_ten = "nine_in_ten"
"""
    Checkpoint nine out of every ten transformer layers.
"""
fine_grained = "fine_grained"
"""
    Focus checkpointing where recomputation is cheap and the memory savings are largest.
"""
class LLaDAConfig(PretrainedConfig):
model_type = "llada"
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
def __init__(self, use_cache: bool = False, **kwargs):
model_config = ModelConfig()
all_kwargs = model_config.__dict__
all_kwargs.update(kwargs)
all_kwargs.update({"use_cache": use_cache})
all_kwargs.update(
{
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
}
)
super().__init__(**all_kwargs)
@property
def num_attention_heads(self):
return self.n_heads
@property
def num_hidden_layers(self):
return self.n_layers
@property
def hidden_size(self):
return self.d_model
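A minimal usage sketch (not part of the diff) of the configuration classes above; it assumes ``ModelConfig`` and ``LLaDAConfig`` are in scope, and the values are illustrative only:

model_cfg = ModelConfig(d_model=4096, n_heads=32, n_kv_heads=8, rope=True)
print(model_cfg.effective_n_kv_heads)  # 8 -> Llama2-style grouped-query attention

hf_cfg = LLaDAConfig(d_model=4096, n_heads=32)
print(hf_cfg.hidden_size, hf_cfg.num_attention_heads)  # 4096 32 -> HF-style aliases for d_model / n_heads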

View File

@ -0,0 +1,96 @@
"""
LLaDA MoE configuration
"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class LLaDAMoEConfig(PretrainedConfig):
model_type = "lladamoe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=-1,
hidden_size=-1,
dense_intermediate_size=-1,
expert_intermediate_size=-1,
shared_expert_intermediate_size=-1,
num_hidden_layers=-1,
num_attention_heads=-1,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=False,
pad_token_id=1,
bos_token_id=None,
eos_token_id=50279,
tie_word_embeddings=False,
rope_theta=-1,
partial_rotary_factor=-1,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
clip_qkv=None,
num_experts_per_tok=-1,
num_experts=-1,
output_router_logits=False,
router_aux_loss_coef=0.01,
norm_topk_prob=None,
qk_layernorm=None,
moe_layer_freq=[],
moe_router_enable_expert_bias=None,
moe_router_score_function=None,
routed_scaling_factor=1,
router_num_group=-2,
router_topk_group=-2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expert_intermediate_size = expert_intermediate_size
self.dense_intermediate_size = dense_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.clip_qkv = clip_qkv
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.norm_topk_prob = norm_topk_prob
self.qk_layernorm = qk_layernorm
self.moe_layer_freq = moe_layer_freq
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
self.moe_router_score_function = moe_router_score_function
self.partial_rotary_factor = partial_rotary_factor
self.routed_scaling_factor = routed_scaling_factor
self.router_num_group = router_num_group
self.router_topk_group = router_topk_group
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
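A small sketch (assuming ``LLaDAMoEConfig`` is in scope; the sizes are illustrative) showing that unspecified key/value heads fall back to the number of attention heads:

cfg = LLaDAMoEConfig(
    vocab_size=32000,
    hidden_size=1024,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_experts=8,
    num_experts_per_tok=2,
)
print(cfg.num_key_value_heads)  # 8 -- defaults to num_attention_heads when not provided
print(cfg.eos_token_id)         # 50279 -- constructor default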

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,3 @@
from dllm.core.trainers import MDLMTrainer
LLaDATrainer = MDLMTrainer

View File

@ -0,0 +1,7 @@
# from dllm.pipelines.rnd import generate, trainer
from . import models
from .models import RND1LM, RND1Config, RND1GenerationConfig
# from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
# from dllm.pipelines.rnd.models.configuration_rnd import RND1Config
from .trainer import RNDTrainer

View File

@ -0,0 +1,53 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
Radical Numerics Diffusion (RND1) - Diffusion-based Language Model.
"""
from .configuration_rnd import RND1Config
from .modeling_rnd import (
RND1LM,
RND1Model,
RND1PreTrainedModel,
RND1Attention,
RND1DecoderLayer,
RND1SparseMoeBlock,
)
from .generation_config import RND1GenerationConfig
from .generation_utils import RND1GenerationMixin
from .sampling import (
diffusion_sample,
apply_top_k_filtering,
apply_top_p_filtering,
)
from .terminal_visualizer import TerminalVisualizer, SimpleProgressBar
__version__ = "0.1.0"
__all__ = [
"RND1Config",
"RND1GenerationConfig",
"RND1LM",
"RND1Model",
"RND1PreTrainedModel",
"RND1Attention",
"RND1DecoderLayer",
"RND1SparseMoeBlock",
"RND1GenerationMixin",
"TerminalVisualizer",
"SimpleProgressBar",
]
# Register with HuggingFace Auto classes for local usage
try:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
AutoConfig.register("rnd1", RND1Config)
AutoModel.register(RND1Config, RND1LM)
AutoModelForMaskedLM.register(RND1Config, RND1LM)
except ImportError:
# transformers not available or Auto classes not imported
pass
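Because of the registration above, local code can resolve the custom classes through the Auto API once this package has been imported. A hedged sketch (the config is built in memory; nothing is downloaded):

from transformers import AutoConfig

config = AutoConfig.for_model("rnd1", moe_backend="hf", num_diffusion_steps=128)
print(type(config).__name__)       # RND1Config
print(config.num_diffusion_steps)  # 128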

View File

@ -0,0 +1,124 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Model Configuration.
This module defines the configuration class for RND1 models.
The default settings are derived from Qwen/Qwen3-30B-A3B and augmented
with RND1-specific parameters.
"""
from transformers.configuration_utils import PretrainedConfig
# Qwen3-30B-A3B / checkpoint defaults
CONFIG_DEFAULTS = {
"attention_bias": False,
"attention_dropout": 0.0,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 40960,
"max_window_layers": 48,
"mlp_only_layers": [],
"moe_intermediate_size": 768,
"norm_topk_prob": True,
"num_attention_heads": 32,
"num_experts": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 48,
"num_key_value_heads": 4,
"output_router_logits": False,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": False,
"rope_theta": 1000000.0,
"router_aux_loss_coef": 0.001,
"sliding_window": False,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"use_cache": False,
"use_sliding_window": False,
"vocab_size": 151936,
}
class RND1Config(PretrainedConfig):
"""
Configuration class for RND1 models.
This configuration extends Qwen3MoeConfig with additional parameters
specific to the RND1 (Radical Numerics Diffusion v1) architecture.
Args:
moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer")
num_diffusion_steps: Default number of diffusion steps for generation
mask_token_id: Token ID used for masking (default: 151669 for Qwen)
**kwargs: Additional arguments passed to Qwen3MoeConfig
"""
model_type = "rnd1"
def __init__(
self,
moe_backend: str = "hf",
num_diffusion_steps: int = 256,
mask_token_id: int = 151669,
**kwargs,
):
# Force non-causal and no caching for RND1
kwargs["use_cache"] = False
kwargs["is_causal"] = False
super().__init__(**kwargs)
# Set defaults after pretrained init to prevent overrides
self.set_config_defaults()
# QoL: set attn impl directly from config
if "attn_implementation" in kwargs:
self._attn_implementation = kwargs["attn_implementation"]
# RND1-specific parameters
self.moe_backend = moe_backend
self.num_diffusion_steps = num_diffusion_steps
self.mask_token_id = mask_token_id
# Ensure bidirectional attention and no caching
self.is_causal = False
self.use_cache = False
def set_config_defaults(self):
"""
        Ensure model defaults are set according to the final training checkpoint.
        Qwen3MoeConfig defaults don't match the Qwen/Qwen3-30B-A3B settings from which
        RND1 is derived.
"""
for k, v in CONFIG_DEFAULTS.items():
setattr(self, k, v)
def to_dict(self):
"""
Serializes configuration to dictionary with auto_map for Hub.
The auto_map ensures that when users load from HuggingFace Hub,
the correct custom classes are automatically resolved.
"""
data = super().to_dict()
data.setdefault(
"auto_map",
{
"AutoConfig": "configuration_rnd.RND1Config",
"AutoModel": "modeling_rnd.RND1Model",
"AutoModelForMaskedLM": "modeling_rnd.RND1LM",
},
)
return data
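A quick sketch (assuming ``RND1Config`` is in scope) of how the forced settings and checkpoint defaults behave:

cfg = RND1Config(moe_backend="hf", num_diffusion_steps=128)
print(cfg.use_cache, cfg.is_causal)      # False False -- forced for diffusion
print(cfg.hidden_size, cfg.num_experts)  # 2048 128    -- Qwen3-30B-A3B-derived defaults
print(cfg.to_dict()["auto_map"]["AutoModelForMaskedLM"])  # modeling_rnd.RND1LM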

View File

@ -0,0 +1,77 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Generation Configuration.
This module defines the generation configuration for RND1 models,
controlling the diffusion-based generation process.
"""
from typing import Optional
from transformers.generation.configuration_utils import GenerationConfig
class RND1GenerationConfig(GenerationConfig):
"""
Configuration class for RND1 generation parameters.
This class extends the base GenerationConfig to include parameters
specific to diffusion-based language generation.
Args:
max_length: Maximum sequence length
num_diffusion_steps: Number of denoising steps in the diffusion process
mask_token_id: Token ID used for masking during diffusion
temperature: Temperature for sampling (higher = more random)
top_k: Optional top-k filtering
top_p: Optional nucleus (top-p) filtering
greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
**kwargs: Additional arguments passed to GenerationConfig
"""
def __init__(
self,
max_length: int = 256,
num_diffusion_steps: int = 256,
mask_token_id: int = 151669,
temperature: float = 0.1,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
greedy: bool = False,
bos_token_id: int = None,
eos_token_id: int = None,
pad_token_id: int = None,
use_cache: bool = False,
**kwargs,
):
# Force no caching for RND generation
# kwargs['use_cache'] = False
kwargs.pop('use_cache', None)
super().__init__(
max_length=max_length,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=not greedy,
use_cache=False,
**kwargs,
)
# RND-specific parameters
self.num_diffusion_steps = num_diffusion_steps
self.mask_token_id = mask_token_id
self.greedy = greedy
def to_dict(self):
"""Convert configuration to dictionary."""
output = super().to_dict()
output["num_diffusion_steps"] = self.num_diffusion_steps
output["mask_token_id"] = self.mask_token_id
output["greedy"] = self.greedy
return output
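A minimal sketch (assuming ``RND1GenerationConfig`` is importable) of building a generation config; note that ``greedy`` is mapped onto ``do_sample``:

gen_cfg = RND1GenerationConfig(
    max_new_tokens=64,
    num_diffusion_steps=64,
    temperature=0.7,
    top_p=0.9,
    greedy=False,
)
print(gen_cfg.do_sample)                         # True, because greedy=False
print(gen_cfg.to_dict()["num_diffusion_steps"])  # 64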

View File

@ -0,0 +1,187 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Generation Utilities.
This module provides generation utilities and mixins for RND1 models,
including the main GenerationMixin class that integrates with HuggingFace.
"""
import torch
import torch.nn as nn
from typing import Optional, Union, Dict, Any
from transformers import GenerationMixin as HFGenerationMixin
from transformers.generation import GenerationConfig
from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
class RND1GenerationMixin(HFGenerationMixin):
"""
Generation mixin for RND1 models.
This mixin provides generation methods compatible with HuggingFace's
generation API while using RND1's diffusion-based sampling internally.
"""
def generate(
self,
inputs: Optional[torch.LongTensor] = None,
generation_config: Optional[GenerationConfig] = None,
# RND1-specific parameters
prefix_ids: Optional[torch.LongTensor] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
return_dict_in_generate: Optional[bool] = None,
**kwargs, # Accept all kwargs to be compatible with pipelines
) -> Union[torch.LongTensor, Dict[str, Any]]:
"""
Generate text using RND1's diffusion-based sampling.
Follows HuggingFace's standard generate API, using diffusion sampling
internally. Supports both standard generation and infilling.
Args:
inputs: Input token IDs to use as prefix (standard HF parameter)
generation_config: Generation configuration object
prefix_ids: Alternative to inputs for infilling tasks
suffix_ids: Optional suffix for infilling tasks
infill_length: Length of infill region (for infilling)
return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
**kwargs: Additional arguments (accepted for compatibility)
Returns:
Generated token IDs or GenerateDecoderOnlyOutput
"""
if generation_config is not None:
gen_config = generation_config
model_kwargs = kwargs.copy()
else:
# Only prepare config from kwargs if no config was provided
gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
device = next(self.parameters()).device
if inputs is not None:
prefix_ids = inputs.to(device)
elif prefix_ids is not None:
prefix_ids = prefix_ids.to(device)
else:
prefix_ids = None
if suffix_ids is not None:
suffix_ids = suffix_ids.to(device)
eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
eos_token_id = None if eos_token_id == -1 else eos_token_id
pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
if infill_length is not None and prefix_ids is not None:
# Infilling mode: use specified infill_length
prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
seq_len = prefix_len + infill_length + suffix_len
else:
# Standard generation mode
if prefix_ids is not None:
prefix_len = prefix_ids.shape[1]
if gen_config.max_new_tokens is not None:
seq_len = prefix_len + gen_config.max_new_tokens
else:
seq_len = gen_config.max_length or self.config.max_position_embeddings
else:
seq_len = gen_config.max_length or self.config.max_position_embeddings
num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
getattr(self.config, "num_diffusion_steps", 256))
temperature = float(getattr(gen_config, "temperature", 1.0))
top_k = getattr(gen_config, "top_k", None)
top_p = getattr(gen_config, "top_p", None)
greedy = getattr(gen_config, "greedy",
not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
with torch.inference_mode():
sequences = diffusion_sample(
model=self,
seq_len=seq_len,
num_steps=num_diffusion_steps,
mask_token_id=mask_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
greedy=greedy,
prefix_ids=prefix_ids,
suffix_ids=suffix_ids,
infill_length=infill_length,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
device=device,
visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs
)
if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
from transformers.generation.utils import GenerateDecoderOnlyOutput
return GenerateDecoderOnlyOutput(sequences=sequences)
return sequences
def generate_with_visualization(
self,
tokenizer,
inputs: Optional[torch.LongTensor] = None,
generation_config: Optional[GenerationConfig] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
**kwargs,
) -> torch.LongTensor:
"""
Generate with live visualization (for demos).
This method requires a tokenizer to display the generation process.
For production use, prefer `generate()`.
Args:
tokenizer: Tokenizer for decoding tokens to text
inputs: Input token IDs to use as prefix
generation_config: Generation configuration object
suffix_ids: Optional suffix token IDs
infill_length: Length of infill region
**kwargs: Additional arguments for backward compatibility
Returns:
Generated token IDs as LongTensor
"""
from .terminal_visualizer import TerminalVisualizer
visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
return self.generate(
inputs=inputs,
generation_config=generation_config,
suffix_ids=suffix_ids,
infill_length=infill_length,
visualizer=visualizer,
return_dict_in_generate=False,
**kwargs,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
**kwargs,
) -> Dict[str, Any]:
"""
Prepare inputs for generation (required by HuggingFace).
For RND1, we don't use the standard autoregressive generation,
so this just returns the input_ids.
"""
return {"input_ids": input_ids}

View File

@ -0,0 +1,653 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 model implementation.
This module implements the RND1 architecture with bidirectional attention for
diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
with multiple backend options (HF, vLLM, SGLang, FlashInfer).
Based on the Qwen3Moe architecture:
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
"""
from __future__ import annotations
import os
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
from transformers.utils import logging
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
MoeModelOutputWithPast,
MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
from .configuration_rnd import RND1Config
from .generation_utils import RND1GenerationMixin
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeRMSNorm,
Qwen3MoeRotaryEmbedding,
Qwen3MoeMLP,
apply_rotary_pos_emb
)
import torch.nn.functional as F
try:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts as fused_experts_vllm, fused_topk as fused_topk_vllm
from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm
except Exception:
fused_experts_vllm = None
fused_topk_vllm = None
try:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
# from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm
from sglang.srt.layers.moe.topk import StandardTopKOutput
except Exception:
sglang_fused_moe = None
StandardTopKOutput = None
try:
import flashinfer.fused_moe as fused_moe
## TODO: below needs flashinfer>=0.4.0, but has some bug atm
# from flashinfer.norm import rmsnorm as flashinfer_rmsnorm
# class FlashInferRMSNorm(Qwen3MoeRMSNorm):
# """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface"""
# def forward(self, hidden_states):
# return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon)
except Exception:
fused_moe = None
logger = logging.get_logger(__name__)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Expand key/value heads to match query heads for grouped-query attention."""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class RND1Attention(nn.Module):
"""RND1 attention layer with bidirectional attention for diffusion modeling."""
def __init__(self, config: RND1Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = config.attention_dropout
self.is_causal = False
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
if config.moe_backend == "vllm":
RMSNormClass = VLLMRMSNorm
else:
RMSNormClass = Qwen3MoeRMSNorm
self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
self.sliding_window = getattr(config, "sliding_window", None)
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
dual_cache: Optional[bool] = False,
replace_position: Optional[torch.Tensor] = None,
**kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")
if use_sdpa:
if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
attention_mask = attention_mask.to(dtype=query_states.dtype)
assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
attn_out = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=self.is_causal,
)
attn_out = attn_out.transpose(1, 2).contiguous()
attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
attn_out = self.o_proj(attn_out)
return attn_out, None
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
# TODO: modify this to boolean masks
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_out = torch.matmul(attn_weights, value_states)
attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
attn_out = self.o_proj(attn_out)
return attn_out, None
class RND1DecoderLayer(nn.Module):
"""RND1 decoder layer with bidirectional attention for diffusion language modeling."""
def __init__(self, config: RND1Config, layer_idx: int):
super().__init__()
self.self_attn = RND1Attention(config, layer_idx)
self.mlp = RND1SparseMoeBlock(config)
if config.moe_backend == "vllm":
RMSNormClass = VLLMRMSNorm
else:
RMSNormClass = Qwen3MoeRMSNorm
self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
replace_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_out, attn_weights = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
replace_position=replace_position,
)
hidden_states = residual + attn_out
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
ff_out = self.mlp(hidden_states)
if isinstance(ff_out, tuple):
ff_out = ff_out[0]
hidden_states = residual + ff_out
return hidden_states, attn_weights
class RND1SparseMoeBlock(nn.Module):
"""RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer)."""
def __init__(self, config: RND1Config):
super().__init__()
self.config = config
self.backend = getattr(config, "moe_backend", "hf")
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_size = config.hidden_size
self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
self.experts = nn.ModuleList(
[Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
)
# Cached weight tensors for optimized backends
self._w1 = None
self._w2 = None
if self.backend == "sglang":
if sglang_fused_moe is None or StandardTopKOutput is None:
raise RuntimeError("sglang is not available, cannot use sglang backend")
elif self.backend == "flashinfer":
if fused_moe is None:
raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")
elif self.backend == "vllm":
if fused_experts_vllm is None or fused_topk_vllm is None:
raise RuntimeError("vllm is not available, cannot use vllm backend")
@torch.no_grad()
def _initialize_weights(
self,
free_experts: bool = True,
mode: str = "vllm",
) -> None:
logger.info(f"Initializing weights for {mode} backend")
# Stack directly on device where weights already reside (loaded by HF)
gate_list: List[torch.Tensor] = []
up_list: List[torch.Tensor] = []
down_list: List[torch.Tensor] = []
# Collect weight references without any device moves
for expert in self.experts:
gate_list.append(expert.gate_proj.weight.data)
up_list.append(expert.up_proj.weight.data)
down_list.append(expert.down_proj.weight.data)
gate_w_stacked = torch.stack(gate_list, dim=0).contiguous()
up_w_stacked = torch.stack(up_list, dim=0).contiguous()
down_w_stacked = torch.stack(down_list, dim=0).contiguous()
if mode == "flashinfer":
w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering
else:
w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1)
w2 = down_w_stacked
self._w1 = w1
self._w2 = w2
if free_experts:
# Free per-expert modules to reclaim memory
logger.info(f"Freeing experts for {mode} backend")
del self.experts
self.experts = None
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with expert routing and computation."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
x = hidden_states.view(-1, hidden_dim)
# Expert routing
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
if self.backend == "vllm":
routing_weights, selected_experts, *_ = fused_topk_vllm(
hidden_states=x,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.norm_topk_prob,
)
else:
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
# if self.backend == "hf":
# final_hidden_states = torch.zeros(
# (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
# )
# expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
# for expert_idx in expert_hit:
# expert_layer = self.experts[expert_idx]
# idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
# current_state = x[top_x]
# current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
# out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
# return out, router_logits.view(batch_size, sequence_length, -1)
if self.backend == "hf":
# Accumulate buffer: [B*S, H]
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# expert_mask: [E, top_k, tokens]
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0).contiguous()
            # Iterate over every expert in order: even if this rank has no routed tokens,
            # each expert must still run a forward pass to avoid ZeRO-3 control-flow divergence.
            for e in range(self.num_experts):
                expert_layer = self.experts[int(e)]
                # Indices of the tokens routed to this expert
                idx, top_x = torch.where(expert_mask[e])  # idx ∈ [0, top_k), shapes: [n_tok_e]
                current_state = x[top_x]  # [n_tok_e, H]; n_tok_e may be 0
                # Forward even on an empty batch: most Linear/MLP layers are a no-op on 0-row input,
                # but running them keeps the ZeRO-3 parameter paths aligned across ranks.
                expert_out = expert_layer(current_state)  # [n_tok_e, H]
                # Weight by the routing probabilities
                w = routing_weights[top_x, idx]  # [n_tok_e]
                expert_out = expert_out * w.unsqueeze(-1)  # [n_tok_e, H]
                # Accumulate into the global buffer; a legal no-op when n_tok_e == 0
                final_hidden_states.index_add_(0, top_x, expert_out.to(hidden_states.dtype))
out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "flashinfer":
# if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
# self._initialize_flashinfer_weights()
if self._w1 is None or self._w2 is None:
self._initialize_weights(mode="flashinfer")
result = fused_moe.cutlass_fused_moe(
input=x,
token_selected_experts=selected_experts.to(torch.int),
token_final_scales=routing_weights.to(torch.float32),
fc1_expert_weights=self._w1,
fc2_expert_weights=self._w2,
output_dtype=x.dtype,
quant_scales=None,
)
if isinstance(result, (list, tuple)):
out_flat = result[0]
else:
out_flat = result
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "sglang":
if self._w1 is None or self._w2 is None:
self._initialize_weights(mode="sglang")
topk_output = StandardTopKOutput(
topk_weights=routing_weights,
topk_ids=selected_experts,
router_logits=router_logits,
)
out_flat = sglang_fused_moe(
hidden_states=x,
w1=self._w1,
w2=self._w2,
topk_output=topk_output,
)
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "vllm":
if self._w1 is None or self._w2 is None:
self._initialize_weights()
out_flat = fused_experts_vllm(
x,
self._w1,
self._w2,
routing_weights,
selected_experts,
)
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
else:
raise ValueError(f"Invalid backend: {self.backend}")
class RND1PreTrainedModel(PreTrainedModel):
"""Base class for RND1 models with weight initialization and loading support."""
config_class = RND1Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["RND1DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
"""Initialize weights using normal distribution."""
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: Optional[bool] = None,
weights_only: bool = True,
**kwargs,
):
"""Load pretrained model with generation config."""
_model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
resume_download = kwargs.get("resume_download", None)
proxies = kwargs.get("proxies", None)
subfolder = kwargs.get("subfolder", "")
from_auto_class = kwargs.get("_from_auto", False)
from_pipeline = kwargs.get("_from_pipeline", None)
_model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
)
# If configured to use a fused backend, pack fused tensors once after load
try:
cfg = getattr(_model, "config", None)
backend = getattr(cfg, "moe_backend", "hf") if cfg is not None else "hf"
if backend in ("sglang", "vllm"):
# Walk decoder layers and initialize fused weights
model_core = getattr(_model, "model", _model)
layers = getattr(model_core, "layers", None)
if isinstance(layers, nn.ModuleList):
for layer in layers:
mlp = getattr(layer, "mlp", None)
if hasattr(mlp, "_initialize_weights"):
mlp._initialize_weights(
free_experts=True,
mode=backend,
)
except Exception as _e:
logger.warning(f"Backend {backend} weight processing skipped: {_e}")
return _model
class RND1Model(RND1PreTrainedModel):
"""RND1 transformer model with bidirectional attention for diffusion language modeling."""
def __init__(self, config: RND1Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
if config.moe_backend == "vllm":
RMSNormClass = VLLMRMSNorm
else:
RMSNormClass = Qwen3MoeRMSNorm
self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> MoeModelOutputWithPast:
"""Forward pass through the RND1 model."""
if (input_ids is None) == (inputs_embeds is None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
if isinstance(attention_mask, torch.Tensor):
# shape: (batch_size, 1, 1, seq_len)
attention_mask = attention_mask.to(dtype=torch.float)[:, None, None, :]
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
hidden_states = inputs_embeds
for layer in self.layers:
hidden_states, _ = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
)
hidden_states = self.norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
router_logits=None,
)
class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
"""Radical Numerics Diffusion Language Model with bidirectional attention."""
def __init__(self, config: RND1Config):
super().__init__(config)
self.model = RND1Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
"""Get the input embeddings layer."""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""Set the input embeddings layer."""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""Get the output embeddings layer (lm_head)."""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Set the output embeddings layer (lm_head)."""
self.lm_head = new_embeddings
@classmethod
def can_generate(cls) -> bool:
"""Indicates this model can generate text."""
return True
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> MaskedLMOutput:
"""Forward pass with optional loss computation."""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return MaskedLMOutput(
loss=loss,
logits=logits,
)
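``RND1Model.forward`` converts a standard 0/1 padding mask into an additive bias before the bidirectional attention layers; a tiny standalone sketch of that transformation:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                  # 1 = real token, 0 = padding
bias = attention_mask.to(dtype=torch.float)[:, None, None, :]  # shape (batch, 1, 1, seq_len)
bias = (1.0 - bias) * torch.finfo(bias.dtype).min              # 0 where attended, ~-inf at padding
print(bias.shape)     # torch.Size([1, 1, 1, 4])
print(bias[0, 0, 0])  # tensor([0., 0., 0., -3.4028e+38])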

View File

@ -0,0 +1,260 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 sampling module for masked diffusion generation.
This module implements entropy-based token selection for iterative denoising
in diffusion language models. Supports both greedy and stochastic sampling
with optional prefix/suffix constraints and infilling.
"""
import torch
import torch.nn as nn
from typing import Optional, Union
def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
"""
    Apply top-k filtering to logits, setting all non-top-k values to -inf.
"""
top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
filtered_logits = torch.full_like(logits, float('-inf'))
filtered_logits.scatter_(-1, top_k_indices, top_k_values)
return filtered_logits
def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
"""
    Apply top-p (nucleus) filtering to logits, setting tokens beyond the cumulative-probability threshold to -inf.
"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above threshold
sorted_indices_to_remove = cumulative_probs > p
    # Shift right so the first token that crosses the threshold is still kept,
    # then force-keep the most probable token.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
return logits.masked_fill(indices_to_remove, float('-inf'))
@torch.no_grad()
def diffusion_sample(
model: nn.Module,
seq_len: int = 256,
num_steps: int = 256,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: float = 1.0,
greedy: bool = True,
mask_token_id: int = 151669,
prefix_ids: Optional[torch.LongTensor] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
eos_token_id: int = 151645,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
visualizer: Optional[object] = None,
) -> torch.LongTensor:
"""
Perform masked diffusion sampling with entropy-based token selection.
Args:
model: The RND1 language model
seq_len: Target sequence length
num_steps: Number of denoising steps
top_k: Optional top-k filtering for sampling (None = no filtering)
top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
When both top_k and top_p are set, top_k is applied first, then top_p
temperature: Temperature for sampling (higher = more random, lower = more deterministic)
Values close to 0 are clamped to 1e-8 to avoid division by zero
greedy: Whether to use greedy sampling (True) or stochastic (False)
mask_token_id: Token ID for masked positions (default: 151669)
prefix_ids: Optional prefix token IDs to preserve
suffix_ids: Optional suffix token IDs to preserve
infill_length: Length of infill region between prefix/suffix
eos_token_id: End of sequence token ID (default: 151645)
pad_token_id: Padding token ID (default: None, uses 0 if needed)
bos_token_id: Beginning of sequence token ID (default: None)
device: Device for computation (None = infer from model)
visualizer: Optional visualizer for live visualization
Returns:
Generated token IDs as LongTensor
"""
model.eval()
if device is None:
device = next(model.parameters()).device
else:
device = torch.device(device)
if pad_token_id is None:
pad_token_id = 0
# Build initial masked sequence
# When prefix_ids is provided, we create a sequence of length seq_len where:
# - The prefix occupies the first pre_len positions
# - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
if prefix_ids is not None or suffix_ids is not None:
if prefix_ids is not None:
prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
else:
pre_len = 0
if suffix_ids is not None:
suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
else:
suf_len = 0
reserved = (1 if eos_token_id is not None else 0)
used = pre_len + suf_len + reserved
if used > seq_len:
raise ValueError(
f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
f"Please increase seq_len or reduce input lengths."
)
elif used == seq_len:
raise ValueError(
f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
f"Need at least 1 position for generation."
)
infill_length = min(infill_length or (seq_len - used), seq_len - used)
x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
pos = 0
# if bos_token_id is not None:
# x[0, pos] = bos_token_id; pos += 1
if eos_token_id is not None:
x[0, -1] = eos_token_id
if pre_len > 0:
x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len]
pos += pre_len
fill_start, fill_end = pos, pos + infill_length
x[0, fill_start:fill_end] = mask_token_id
# print(fill_start, fill_end, seq_len, used, x[0, -1])
pos = fill_end
if suf_len > 0:
x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len]
pos += suf_len
init_maskable = torch.zeros_like(x, dtype=torch.bool)
init_maskable[0, fill_start:fill_end] = True
else:
x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
if bos_token_id is not None:
x[0, 0] = bos_token_id
if eos_token_id is not None:
x[0, -1] = eos_token_id
init_maskable = x.eq(mask_token_id)
if bos_token_id is not None:
init_maskable[:, 0] = False
if eos_token_id is not None:
init_maskable &= x.ne(eos_token_id)
init_maskable &= x.ne(pad_token_id)
maskable = init_maskable.clone()
xt = x.clone()
if visualizer:
visualizer.start_visualization(xt, maskable, num_steps)
def forward_scores(tokens):
"""Compute predictions and entropy scores for next tokens."""
# Try with input_ids parameter first (standard HF models)
try:
model_output = model(input_ids=tokens)
except TypeError:
# Fall back to positional argument
model_output = model(tokens)
# Apply temperature scaling (with safety for near-zero temperature)
safe_temperature = max(temperature, 1e-8) # Prevent division by zero
logits = model_output.logits / safe_temperature
# Apply filtering strategies
# Note: When both top_k and top_p are provided, they are applied sequentially:
# First top_k filters to k tokens, then top_p filters from those k tokens
if top_k is not None and top_k > 0:
logits = apply_top_k_filtering(logits, top_k)
if top_p is not None and 0 < top_p < 1.0:
logits = apply_top_p_filtering(logits, top_p)
# Convert to log probabilities
logp = torch.log_softmax(logits, dim=-1)
# Greedy or stochastic sampling
if greedy:
pred_next = logp.argmax(-1)
else:
pred_next = torch.distributions.Categorical(logits=logp).sample()
conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
p = logp.exp()
ent_next = -(p * logp).sum(-1)
# Shift predictions: pos i predicts token i+1
pred_i = tokens.clone()
conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
ent_i = torch.zeros_like(ent_next)
pred_i[:, 1:] = pred_next[:, :-1]
conf_i[:, 1:] = conf_next[:, :-1]
ent_i[:, 1:] = ent_next[:, :-1]
return pred_i, conf_i, ent_i
pred_i, conf_i, ent_i = forward_scores(xt)
total_masked = init_maskable.sum(1, keepdim=True)
finf = torch.finfo(conf_i.dtype)
for step in range(num_steps - 1, 0, -1):
rate = step / num_steps
cutoff_len = (total_masked * rate).long().clamp(min=0)
# Choose HIGH-entropy tokens to keep masked
sel_scores = ent_i.masked_fill(~maskable, -finf.max)
B, L = sel_scores.shape
k_max = cutoff_len.max().item()
if k_max > 0:
            _, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
for b in range(B):
k_b = int(cutoff_len[b].item())
if k_b > 0:
keep_mask[b, idx[b, :k_b]] = True
else:
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
to_unmask = maskable & ~keep_mask
if to_unmask.any():
xt[to_unmask] = pred_i[to_unmask]
maskable[to_unmask] = False
if visualizer:
visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
if maskable.any():
pred_i, conf_i, ent_i = forward_scores(xt)
if maskable.any():
xt[maskable] = pred_i[maskable]
if visualizer:
visualizer.stop_visualization()
return xt
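A small self-contained check (not from the source) of the filtering helpers defined above:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(apply_top_k_filtering(logits, k=2))
# tensor([[2., 1., -inf, -inf]]) -- only the two largest logits survive
print(apply_top_p_filtering(logits, p=0.9))
# keeps the smallest prefix of sorted tokens whose cumulative probability covers ~0.9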

View File

@ -0,0 +1,251 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
Terminal visualization for RND1 generation.
This module provides real-time visualization of the diffusion denoising process,
showing token evolution and generation progress in the terminal using rich
formatting when available.
"""
import torch
from typing import Optional
from tqdm import tqdm
try:
from rich.console import Console
from rich.live import Live
from rich.text import Text
from rich.panel import Panel
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.layout import Layout
RICH_AVAILABLE = True
except ImportError:
RICH_AVAILABLE = False
class TerminalVisualizer:
"""
Rich-based visualization for diffusion process with live updates.
Provides real-time visualization of the token denoising process during
diffusion-based language generation, with colored highlighting of masked
positions and progress tracking.
"""
def __init__(self, tokenizer, show_visualization: bool = True):
"""
Initialize the terminal visualizer.
Args:
tokenizer: The tokenizer for decoding tokens to text
show_visualization: Whether to show visualization (requires rich)
"""
self.tokenizer = tokenizer
self.show_visualization = show_visualization and RICH_AVAILABLE
if not RICH_AVAILABLE and show_visualization:
print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
self.show_visualization = False
if self.show_visualization:
self.console = Console()
self.live = None
self.progress = None
self.layout = None
else:
self.pbar = None
self.current_tokens = None
self.mask_positions = None
self.total_steps = 0
self.current_step = 0
def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
"""
Start the visualization.
Args:
initial_tokens: Initial token IDs (possibly masked)
mask_positions: Boolean mask indicating which positions are masked
total_steps: Total number of diffusion steps
"""
if not self.show_visualization:
self.pbar = tqdm(total=total_steps, desc="Diffusion")
return
self.current_tokens = initial_tokens.clone()
self.mask_positions = mask_positions
self.total_steps = total_steps
self.current_step = 0
self.layout = Layout()
self.layout.split_column(
Layout(name="header", size=3),
Layout(name="text", ratio=1),
Layout(name="progress", size=3)
)
self.progress = Progress(
TextColumn("[bold blue]Diffusion"),
BarColumn(),
MofNCompleteColumn(),
TextColumn(""),
TextColumn("[cyan]Masks: {task.fields[masks]}"),
TimeRemainingColumn(),
)
self.progress_task = self.progress.add_task(
"Generating",
total=total_steps,
masks=mask_positions.sum().item()
)
self.live = Live(self.layout, console=self.console, refresh_per_second=4)
self.live.start()
self._update_display()
def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
"""
Update visualization for current step.
Args:
tokens: Current token IDs
maskable: Boolean mask of remaining masked positions
step: Current step number
entropy: Optional entropy scores for each position
confidence: Optional confidence scores for each position
"""
if not self.show_visualization:
if self.pbar:
self.pbar.update(1)
masks = maskable.sum().item() if maskable is not None else 0
self.pbar.set_postfix({'masks': masks})
return
self.current_tokens = tokens.clone()
self.mask_positions = maskable
self.current_step = step
masks_remaining = maskable.sum().item() if maskable is not None else 0
self.progress.update(
self.progress_task,
advance=1,
masks=masks_remaining
)
self._update_display()
def _update_display(self):
"""Update the live display."""
if not self.live:
return
header = Text("RND1-Base Generation", style="bold magenta", justify="center")
self.layout["header"].update(Panel(header, border_style="bright_blue"))
text_display = self._format_text_with_masks()
self.layout["text"].update(
Panel(
text_display,
title="[bold]Generated Text",
subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
border_style="cyan"
)
)
self.layout["progress"].update(Panel(self.progress))
def _format_text_with_masks(self) -> Text:
"""
Format text with colored masks.
Returns:
Rich Text object with formatted tokens
"""
text = Text()
if self.current_tokens is None:
return text
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
for i, token_id in enumerate(token_ids):
if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
# Alternate colors for visual effect
text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
else:
try:
token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
# Skip special tokens in display
if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
# Color based on position
text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
except Exception:
# skip tokens that fail to decode
continue
return text
def stop_visualization(self):
"""Stop the visualization and display final result."""
if not self.show_visualization:
if self.pbar:
self.pbar.close()
print("\n✨ Generation complete!\n")
return
if self.live:
self.live.stop()
self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
# Display final text
if self.current_tokens is not None:
try:
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
self.console.print(Panel(
final_text,
title="[bold]Final Generated Text",
border_style="green",
padding=(1, 2)
))
except Exception:
# decoding the final text is best-effort; ignore failures
pass
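# Hedged usage sketch of the visualizer lifecycle a generator loop would follow:
# start once, call update_step after every denoising step, then stop. The tokenizer
# name and the "mask" id below are placeholders; a real loop would write model
# predictions into `tokens` instead of merely flipping flags.
def _demo_terminal_visualizer():
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    mask_id = tokenizer.eos_token_id  # placeholder id standing in for the mask token
    seq_len, steps = 16, 16
    tokens = torch.full((1, seq_len), mask_id, dtype=torch.long)
    maskable = torch.ones((1, seq_len), dtype=torch.bool)

    viz = TerminalVisualizer(tokenizer, show_visualization=True)
    viz.start_visualization(tokens, maskable, total_steps=steps)
    for step in range(steps):
        maskable[0, step] = False  # pretend one position gets unmasked per step
        viz.update_step(tokens, maskable, step=step + 1)
    viz.stop_visualization()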
class SimpleProgressBar:
"""
Simple progress bar fallback when rich is not available.
Provides basic progress tracking using tqdm when the rich library
is not installed.
"""
def __init__(self, total_steps: int):
"""
Initialize simple progress bar.
Args:
total_steps: Total number of steps
"""
self.pbar = tqdm(total=total_steps, desc="Diffusion")
def update(self, masks_remaining: int = 0):
"""
Update progress bar.
Args:
masks_remaining: Number of masks still remaining
"""
self.pbar.update(1)
self.pbar.set_postfix({'masks': masks_remaining})
def close(self):
"""Close the progress bar."""
self.pbar.close()
print("\n✨ Generation complete!\n")

View File

@ -0,0 +1,23 @@
from typing import Any
import torch
from dllm.core.trainers import MDLMTrainer
class RNDTrainer(MDLMTrainer):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
def _preprocess_inputs(self, inputs):
# The first position must carry no loss: after the logit shift below,
# its prediction is only a duplicated placeholder.
labels = inputs["labels"]
assert (labels[:, 0] == -100).all()
def _postprocess_outputs(self, outputs):
# Shift logits right by one position (duplicating position 0) so that the
# label at position i is scored with the logits produced at position i - 1.
logits = outputs.logits
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

242
dllm/dllm/tools/chat.py Normal file
View File

@ -0,0 +1,242 @@
import shutil
from typing import List, Literal
import textwrap
import dllm
# ============================================================
# Utility helpers
# ============================================================
try:
L = shutil.get_terminal_size().columns
if not isinstance(L, int) or L <= 0:
L = 120
except Exception:
L = 120
DIV = "=" * L
SUB = "-" * L
def banner_line(text: str, width: int = L, fill: str = "=") -> str:
"""Return a centered banner line with given width and fill."""
text = f" {text.strip()} "
fill_len = width - len(text)
if fill_len <= 0:
return text
left = fill_len // 2
right = fill_len - left
return f"{fill * left}{text}{fill * right}"
def print_wrapped(text: str, width: int = L):
"""Print text with automatic line wrapping."""
wrapped = textwrap.fill(text, width=width)
print(wrapped)
def boxed(text: str, width: int = L, padding: int = 1):
"""Render a centered box with the given text and width."""
lines = text.splitlines()
content_width = max(len(line) for line in lines)
box_width = min(width, content_width + padding * 2 + 2)
# compute left margin for centering
terminal_width = width
left_margin = max((terminal_width - box_width) // 2, 0)
margin = " " * left_margin
top = margin + "┌" + "─" * (box_width - 2) + "┐"
bottom = margin + "└" + "─" * (box_width - 2) + "┘"
print(top)
for line in lines:
inner = line.center(content_width)
print(margin + "│" + " " * padding + inner + " " * padding + "│")
print(bottom)
def decode_trim(tokenizer, seq_ids_list, input_ids_list) -> List[str]:
"""
Return only the generated text for each sample, truncated at the first EOS **after** the prompt.
Args:
tokenizer: HF tokenizer with eos_token_id / pad_token_id.
seq_ids_list: Full sequences of token ids from the model (prompt + generation), one list per sample.
input_ids_list: The prompt token ids that were fed into the model, one list per sample.
Behavior:
- Finds the first eos_token_id (or eot_token_id) that occurs at or after len(input_ids).
- Slices the generation up to (but not including) that EOS.
- Decodes only the generation span, skipping special/pad tokens.
"""
# Make sure we can index these
sequences = []
for seq_ids, input_ids in zip(seq_ids_list, input_ids_list):
full = list(seq_ids)
prompt = list(input_ids)
# Skip left padding tokens (necessary for dream)
pad_id = getattr(tokenizer, "pad_token_id", None)
if pad_id is not None:
while full and full[0] == pad_id:
full.pop(0)
start = len(prompt)
end = len(full)
eos_id = getattr(tokenizer, "eos_token_id", None)
eot_id = getattr(tokenizer, "eot_token_id", None)
if eos_id is not None:
for i in range(start, len(full)):
if full[i] in (eos_id, eot_id):
end = i
break
gen_ids = full[start:end]
text = tokenizer.decode(gen_ids, skip_special_tokens=True)
# If eos/eot token ids are unavailable, fall back to splitting on the raw token strings
eos = getattr(tokenizer, "eos_token", None)
eot = getattr(tokenizer, "eot_token", None)
if eos:
text = text.split(eos)[0]
if eot:
text = text.split(eot)[0]
# return text.strip()
sequences.append(text)
return sequences
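# Hedged usage sketch of decode_trim's contract. The tokenizer name is a
# placeholder; any HF tokenizer with an eos_token_id behaves the same way.
def _demo_decode_trim():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    prompt_ids = tokenizer("Hello", add_special_tokens=False)["input_ids"]
    gen_ids = tokenizer(" world", add_special_tokens=False)["input_ids"]
    # Full sequence = prompt + generation + EOS + trailing junk that should be dropped.
    full = prompt_ids + gen_ids + [tokenizer.eos_token_id] + gen_ids
    out = decode_trim(tokenizer, [full], [prompt_ids])
    # `out` holds one string per sample: the decoded generation (" world" here),
    # cut at the first EOS that appears after the prompt.
    return out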
def render_menu(round_idx: int):
"""Render a boxed menu of possible actions."""
if round_idx == 0:
text = (
"Possible next actions:\n"
"[1] Continue this chat\n"
"[2] End this chat and start a new one\n"
"[3] Exit"
)
else:
text = (
f"(Round {round_idx})\n"
"Possible next actions:\n"
"[1] Continue this chat\n"
"[2] End this chat and start a new one\n"
"[3] Exit"
)
print() # spacing
boxed(text)
def prompt_choice() -> Literal["1", "2", "3"]:
while True:
print("Select action [1/2/3]: ")
choice = input().strip()
if choice in ("1", "2", "3"):
return choice
print(banner_line("<Invalid choice. Please type 1, 2, or 3.>", fill=" "))
def build_chat_inputs(tokenizer, messages: List[dict], add_generation_prompt: bool):
"""Tokenize chat messages into inputs tensor."""
return tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
)
def visualize_histories(tokenizer, histories):
try:
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
tokenizer=tokenizer
)
terminal_visualizer.visualize(histories, rich=True)
except Exception as e:
print(f"(Visualization skipped: {e})")
# ============================================================
# Modes
# ============================================================
def single_turn_generate(generator, gen_config, visualize: bool):
print()
print(banner_line("continuation mode"))
model, tokenizer = generator.model, generator.tokenizer
while True:
print(banner_line("<Type your prompt below. Press Ctrl+C to exit.>", fill=" "))
try:
# user_text = input("Prompt > ").strip()
print("[Prompt] > ")
user_text = input().strip()
except (EOFError, KeyboardInterrupt):
print("\n" + banner_line("Exiting. Bye!", width=len(DIV)))
return
# if not user_text:
# print("(Empty input, skipped)\n")
# continue
inputs = tokenizer([user_text], add_special_tokens=False)["input_ids"]
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
text = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
print(banner_line("Output"))
print_wrapped(text if text else "<empty>")
print(DIV + "\n")
if visualize:
visualize_histories(tokenizer, outputs.histories)
def multi_turn_chat(generator, gen_config, visualize: bool):
# """Chat mode with chat template & message history."""
print()
print(banner_line("multi-turn chat mode"))
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
model, tokenizer = generator.model, generator.tokenizer
messages: List[dict] = []
round_idx = 0
while True:
try:
print("[You]:")
user_msg = input().strip()
except (EOFError, KeyboardInterrupt):
print("\nExiting. Bye!")
return
messages.append({"role": "user", "content": user_msg})
inputs = build_chat_inputs(tokenizer, [messages], add_generation_prompt=True)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
reply = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
print(DIV)
print_wrapped("[Assistant]: " + reply if reply else "<empty>")
print(DIV + "\n")
messages.append({"role": "assistant", "content": reply})
if visualize:
visualize_histories(tokenizer, outputs.histories)
render_menu(round_idx)
choice = prompt_choice()
if choice == "1":
print(banner_line("<Type your message.>", fill=" "))
round_idx += 1
continue
elif choice == "2":
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
messages = []
round_idx = 0
continue
else:
print("\nExiting. Bye!")
return

View File

@ -0,0 +1,30 @@
from dataclasses import dataclass
import tyro
from huggingface_hub import snapshot_download
@dataclass
class ScriptArguments:
dataset_id: str = "Anthropic/hh-rlhf"
allow_patterns: str = None
script_args = tyro.cli(ScriptArguments)
# Replace with the dataset repo you want, e.g. "wikitext"
dataset_id = script_args.dataset_id
# Replace with your desired local directory
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/datasets/huggingface/{dataset_id}"
# Download the dataset snapshot
snapshot_download(
repo_id=dataset_id,
repo_type="dataset", # 👈 tell HF it's a dataset
local_dir=local_dir,
local_dir_use_symlinks=False, # ensures real files, not symlinks
allow_patterns=script_args.allow_patterns,
)
print(f"Dataset downloaded to: {local_dir}")

View File

@ -0,0 +1,27 @@
from dataclasses import dataclass
import tyro
from huggingface_hub import snapshot_download
@dataclass
class ScriptArguments:
model_id: str = "GSAI-ML/LLaDA-8B-Instruct"
script_args = tyro.cli(ScriptArguments)
# Replace with the model repo you want, e.g. "bert-base-uncased"
model_id = script_args.model_id
# Replace with your desired local directory
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/models/huggingface/{model_id}"
# Download the model snapshot
snapshot_download(
repo_id=model_id,
local_dir=local_dir,
local_dir_use_symlinks=False, # ensures real files, not symlinks
)
print(f"Model downloaded to: {local_dir}")

View File

@ -0,0 +1 @@
# TODO

View File

@ -0,0 +1,80 @@
"""
Merge a PEFT/LoRA adapter into its base model (auto-detected from adapter_config.json).
Usage:
python dllm/tools/merge_peft_adapter.py \
--adapter_model_name_or_path your-org/your-lora \
--output_model_name_or_path ./merged-model \
--dtype bf16
"""
from dataclasses import dataclass, field
from typing import Optional
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModel, AutoTokenizer, HfArgumentParser
import dllm # importing dllm registers the custom model classes, so trust_remote_code is not needed
DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float}
@dataclass
class ScriptArguments:
adapter_model_name_or_path: str | None = field(
default=None, metadata={"help": "Adapter repo or local path"}
)
output_model_name_or_path: str | None = field(
default=None,
metadata={"help": "Where to save the merged model (folder or repo id)"},
)
dtype: str | None = field(default="fp16", metadata={"help": "fp16|bf16|fp32"})
push_to_hub: bool | None = field(
default=False, metadata={"help": "Push merged weights to the Hub"}
)
# Optional override if adapter config lacks base info:
base_model_name_or_path: str | None = field(
default=None,
metadata={"help": "Override base model if adapter config lacks it"},
)
def main():
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
assert args.adapter_model_name_or_path, "please provide the adapter (repo or path)"
assert args.output_model_name_or_path, "please provide output_model_name_or_path"
assert args.dtype in DTYPE_MAP, f"dtype must be one of {list(DTYPE_MAP.keys())}"
# Read base path from adapter_config.json
peft_cfg = PeftConfig.from_pretrained(args.adapter_model_name_or_path)
base_id = args.base_model_name_or_path or getattr(
peft_cfg, "base_model_name_or_path", None
)
assert base_id, (
"adapter_config.json does not include base_model_name_or_path; "
"pass --base_model_name_or_path to override."
)
# Load base model and tokenizer
model = AutoModel.from_pretrained(
base_id, return_dict=True, dtype=DTYPE_MAP[args.dtype]
)
tokenizer = AutoTokenizer.from_pretrained(base_id)
# Attach adapter, merge, and unload PEFT layers
model = PeftModel.from_pretrained(model, args.adapter_model_name_or_path)
model.eval()
model = model.merge_and_unload() # plain transformers model
# Save locally
model.save_pretrained(args.output_model_name_or_path)
tokenizer.save_pretrained(args.output_model_name_or_path)
print(f"✓ merged model saved to: {args.output_model_name_or_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,109 @@
"""
python dllm/tools/preprocess_pt_dataset.py
"""
import os
from dataclasses import dataclass, asdict
from functools import partial
import datasets
import tyro
import transformers
from pprint import pprint
import dllm
@dataclass
class ScriptArguments:
"""Preprocess PT dataset"""
model_name_or_path: str = "answerdotai/ModernBERT-large"
dataset_args: str = "OpenCoder-LLM/opc-annealing-corpus[lang:python]" # required
output_dir: str = "data/pt/modernbert/opc-annealing-corpus[lang:python]" # required
text_field: str = "text"
max_length: int = 1024
insert_eos: bool = True
drop_tail: bool = True
remove_columns: bool = False
num_proc: int = 32
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
def preprocess_pt_dataset(
dataset: datasets.DatasetDict,
tokenizer: transformers.PreTrainedTokenizer,
output_dir: str,
text_field: str = "text",
max_length: int = 1024,
insert_eos: bool = True,
drop_tail: bool = True,
remove_columns: bool = False,
num_proc: int = 32,
):
processed = dataset.map(
partial(
dllm.utils.tokenize_and_group,
tokenizer=tokenizer,
text_field=text_field,
seq_length=max_length,
insert_eos=insert_eos,
drop_tail=drop_tail,
),
batched=True,
num_proc=num_proc,
remove_columns=dataset["train"].column_names,
)
# Keep only the required columns (input_ids, labels) to save space.
if remove_columns:
keep = {"input_ids", "labels"}
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
drop = [c for c in ds.column_names if c not in keep]
return ds.remove_columns(drop) if drop else ds
if isinstance(processed, datasets.DatasetDict):
for split in list(processed.keys()):
processed[split] = strip_cols(processed[split])
else:
processed = strip_cols(processed)
output_dir = os.path.join(
output_dir,
f"max_length-{max_length}-insert_eos-{insert_eos}-drop_tail-{drop_tail}",
)
os.makedirs(output_dir, exist_ok=True)
processed.save_to_disk(output_dir)
print(f"[OK] Saved to: {output_dir}")
def main():
# Parse with tyro
args = tyro.cli(ScriptArguments)
dllm.utils.print_args(args)
tokenizer = dllm.utils.get_tokenizer(args)
# Load the raw pretraining dataset (must contain the text field given by --text_field).
dataset = dllm.data.load_pt_dataset(args.dataset_args, streaming=False)
preprocess_pt_dataset(
dataset=dataset,
tokenizer=tokenizer,
output_dir=args.output_dir,
text_field=args.text_field,
max_length=args.max_length,
insert_eos=args.insert_eos,
drop_tail=args.drop_tail,
remove_columns=args.remove_columns,
num_proc=args.num_proc,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,117 @@
"""
Example:
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
--num_proc 64
"""
import os
import importlib
from dataclasses import dataclass
from functools import partial
import datasets
import tyro
import dllm
@dataclass
class ScriptArguments:
"""Preprocess SFT dataset"""
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
sft_map_fn_path: str = "dllm.utils.default_sft_map_fn"
dataset_args: str = "HuggingFaceTB/smoltalk" # required
output_dir: str = "data/sft/llada/smoltalk" # required
mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
num_proc: int = 32
remove_columns: bool = False
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
def preprocess_sft_dataset(
dataset: datasets.DatasetDict,
map_fn: callable,
output_dir: str,
remove_columns: bool = False,
num_proc: int = 32,
):
processed = dataset.map(
map_fn,
batched=False,
num_proc=num_proc,
load_from_cache_file=True,
writer_batch_size=512,
desc="offline preprocessing",
)
# Keep only the required columns (input_ids, labels, prompt_len, attention_mask) to save space.
if remove_columns:
keep = {"input_ids", "labels", "prompt_len", "attention_mask"}
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
drop = [c for c in ds.column_names if c not in keep]
return ds.remove_columns(drop) if drop else ds
if isinstance(processed, datasets.DatasetDict):
for split in list(processed.keys()):
processed[split] = strip_cols(processed[split])
else:
processed = strip_cols(processed)
os.makedirs(output_dir, exist_ok=True)
processed.save_to_disk(output_dir)
print(f"[OK] Saved to: {output_dir}")
def main():
# Parse with tyro
args = tyro.cli(ScriptArguments)
dllm.utils.print_args(args)
tokenizer = dllm.utils.get_tokenizer(args)
# Load your raw dataset (must contain a "messages" field per example).
dataset = dllm.data.load_sft_dataset(args.dataset_args)
# Dynamically import the SFT map function from its dotted path
try:
# Split the path into module and function name
module_path, function_name = args.sft_map_fn_path.rsplit(".", 1)
# Import the module
module = importlib.import_module(module_path)
# Get the function from the module
sft_map_fn = getattr(module, function_name)
except (ImportError, AttributeError, ValueError) as e:
print(f"Error: Could not import '{args.sft_map_fn_path}'.")
print(f"Details: {e}")
return
map_fn = partial(
sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=args.mask_prompt_loss,
)
preprocess_sft_dataset(
dataset=dataset,
map_fn=map_fn,
output_dir=args.output_dir,
remove_columns=args.remove_columns,
num_proc=args.num_proc,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,6 @@
from . import configs, data_utils, generation_utils, model_utils, utils
from .configs import *
from .generation_utils import *
from .data_utils import *
from .model_utils import *
from .utils import *

View File

@ -0,0 +1,77 @@
import os
from dataclasses import dataclass, field
import transformers
from dllm.utils.utils import resolve_with_base_env, get_default_logger
logger = get_default_logger(__name__)
@dataclass
class ModelArguments:
model_name_or_path: str = None # overwrite this
dtype: str = "bfloat16"
load_in_4bit: bool = False
attn_implementation: str = None
# --- fold PEFT args here ---
lora: bool = False
target_modules: str = "all-linear"
r: int = 32
lora_alpha: int = 64
lora_dropout: float = 0.05
bias: str = "none"
modules_to_save: str = None
def __post_init__(self):
self.model_name_or_path = resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class DataArguments:
dataset_args: str = None # overwrite this
num_proc: int = 8
disable_caching: bool = False
max_length: int = 1024
truncation: str = field(
default="right",
metadata={
"help": (
'The truncation strategy to use ("filter" or "right"). '
'"filter" only keeps sequences that are shorter than max_length; '
'"right" only keeps the rightmost max_length tokens for each sequence.'
)
},
)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
output_dir: str = None # overwrite this
report_to: str = "wandb"
overwrite_output_dir: bool = True
seed: int = 42
per_device_train_batch_size: int = 4
per_device_eval_batch_size: int = 4
gradient_accumulation_steps: int = 1
learning_rate: float = 2e-5
lr_scheduler_type: str = "cosine"
warmup_ratio: float = 0.1
bf16: bool = True
num_train_epochs: float = 4
logging_steps: float = 10
eval_on_start: bool = False
eval_strategy: str = "steps"
eval_steps: float = 0.25
save_steps: float = 0.25
save_only_model: bool = True
def __post_init__(self):
super().__post_init__()
if self.group_by_length:
logger.info(
"training_args.group_by_length=True: preprocessing "
"may take some time after `trainer.train()` starts."
)

View File

@ -0,0 +1,222 @@
import random
import warnings
from dataclasses import dataclass
from itertools import chain
from typing import TYPE_CHECKING
import torch
import datasets
import transformers
if TYPE_CHECKING:
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
def tokenize_and_group(
examples,
tokenizer,
text_field: str = "text",
seq_length: int = 1024,
insert_eos: bool = False,
drop_tail: bool = True,
add_special_tokens: bool = False,
):
# 1) Tokenize (batched input)
tokenized = tokenizer(examples[text_field], add_special_tokens=add_special_tokens)
ids = tokenized["input_ids"]
# --- optionally append EOS to each sample ---
if insert_eos:
eos_id = tokenizer.eos_token_id
assert eos_id is not None, "insert_eos=True requires the tokenizer to define eos_token_id"
# append EOS only if the sample doesn't already end with it
ids = [seq + ([] if (seq and seq[-1] == eos_id) else [eos_id]) for seq in ids]
# ----------------------------------------------------------------
# 2) Flatten and concatenate all token lists
concatenated = list(chain.from_iterable(ids))
if not concatenated:
return {"input_ids": [], "labels": []} # Safe return for empty batch
# 3) Calculate the total length based on drop_tail
if drop_tail:
total_len = (len(concatenated) // seq_length) * seq_length
concatenated = concatenated[:total_len] # Truncate the last incomplete chunk
else:
total_len = len(concatenated)
# Split into fixed-length chunks
chunks = [concatenated[i : i + seq_length] for i in range(0, total_len, seq_length)]
return {
"input_ids": chunks,
"labels": [c[:] for c in chunks], # Labels are the same as input_ids
}
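# Hedged usage sketch of how tokenize_and_group packs a batch of documents into
# fixed-length chunks. The tokenizer name is a placeholder.
def _demo_tokenize_and_group():
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    examples = {"text": ["first document", "second, slightly longer document"]}
    out = tokenize_and_group(
        examples,
        tokenizer,
        text_field="text",
        seq_length=4,
        insert_eos=True,  # append EOS after each document before packing
        drop_tail=True,   # drop the final chunk if it is shorter than seq_length
    )
    # Every chunk has exactly seq_length tokens and labels mirror input_ids.
    assert all(len(chunk) == 4 for chunk in out["input_ids"])
    assert out["labels"] == out["input_ids"]
    return out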
def clip_row(row: dict, max_length: int, truncation: str = "right") -> dict:
for key in ("input_ids", "labels", "attention_mask"):
if key in row:
if truncation == "right":
row[key] = row[key][:max_length]
elif truncation == "left":
row[key] = row[key][-max_length:]
else:
raise NotImplementedError
return row
def post_process_dataset(
dataset: datasets.DatasetDict, data_args: "DataArguments"
) -> datasets.DatasetDict:
if data_args.truncation == "filter":
return dataset.filter(
lambda row: len(row["input_ids"]) <= data_args.max_length,
num_proc=data_args.num_proc,
desc=f"Filtering samples with length <= {data_args.max_length}",
)
elif data_args.truncation == "right":
# do this only if dataset has "prompt_len"
if "prompt_len" in dataset.column_names["train"]:
dataset = dataset.filter(
lambda row: row["prompt_len"] <= data_args.max_length,
num_proc=data_args.num_proc,
desc=f"Filtering samples with `prompt_len` <= {data_args.max_length}",
)
return dataset.map(
lambda row: clip_row(row, data_args.max_length, truncation="right"),
num_proc=data_args.num_proc,
desc=f"Right-truncating samples to max_length={data_args.max_length}",
)
else:
raise NotImplementedError
def clip_row_streaming(row: dict, max_length: int, truncation: str = "right") -> dict:
"""Clip whole sequence OR (if prompt_len present) preserve prompt and clip only the response."""
if truncation not in {"right", "left"}:
raise NotImplementedError(f"Unknown truncation: {truncation}")
def clip(seq):
return seq[:max_length] if truncation == "right" else seq[-max_length:]
def clip_preserve_prompt(seq, prompt_len: int):
prompt = seq[:prompt_len]
resp = seq[prompt_len:]
budget = max(0, max_length - len(prompt))
resp = resp[:budget] if truncation == "right" else resp[-budget:]
return prompt + resp
prompt_len = row.get("prompt_len", None)
for k in ("input_ids", "labels", "attention_mask"):
if k in row and isinstance(row[k], list):
row[k] = (
clip_preserve_prompt(row[k], prompt_len)
if isinstance(prompt_len, int) and prompt_len >= 0
else clip(row[k])
)
return row
def post_process_dataset_streaming(
dataset: datasets.IterableDatasetDict,
data_args: "DataArguments",
) -> datasets.IterableDatasetDict:
def _train_has_prompt_len_streaming(dataset: datasets.IterableDatasetDict) -> bool:
"""Replicates: 'if \"prompt_len\" in dataset.column_names[\"train\"]' for streaming."""
it = dataset["train"].take(1)
try:
ex = next(iter(it))
except StopIteration:
return False
return "prompt_len" in ex
mode = data_args.truncation
max_len = data_args.max_length
if mode == "filter":
# Keep rows with len(input_ids) <= max_len (emulate .filter with generator map)
def keep_if_short(row):
if (
"input_ids" in row
and isinstance(row["input_ids"], list)
and len(row["input_ids"]) <= max_len
):
yield row # keep
# else: drop (yield nothing)
return datasets.IterableDatasetDict(
{name: ds.map(keep_if_short) for name, ds in dataset.items()}
)
elif mode == "right":
ds_out = dataset
# Do this only if TRAIN split has "prompt_len" (same condition as your non-streaming code)
if _train_has_prompt_len_streaming(ds_out):
def keep_if_prompt_fits(row):
pl = row.get("prompt_len", None)
if isinstance(pl, int) and pl <= max_len:
yield row # keep
elif pl is None:
# If a row lacks prompt_len but train had it, the non-streaming code would try to access it and fail.
# Here we conservatively drop such rows to mirror "requires prompt_len <= max_len".
return
# else: drop
ds_out = datasets.IterableDatasetDict(
{name: ds.map(keep_if_prompt_fits) for name, ds in ds_out.items()}
)
# Then clip right (same clipping as clip_row)
def clip_right(row):
return clip_row(row, max_len, truncation="right")
return datasets.IterableDatasetDict(
{name: ds.map(clip_right) for name, ds in ds_out.items()}
)
else:
raise NotImplementedError
@dataclass
class NoAttentionMaskCollator(transformers.DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
outputs = super().__call__(features, return_tensors)
# we finetune on the padding <eos> tokens, so they should not be masked out
outputs.pop("attention_mask")
return outputs
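# Hedged sketch of what the collator returns. The tokenizer name is a placeholder;
# any HF tokenizer with a pad token behaves the same way.
def _demo_no_attention_mask_collator():
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder
    collator = NoAttentionMaskCollator(
        tokenizer=tokenizer, padding=True, label_pad_token_id=-100
    )
    features = [
        {"input_ids": [101, 2023, 102], "labels": [101, 2023, 102]},
        {"input_ids": [101, 102], "labels": [-100, 102]},
    ]
    batch = collator(features)
    # Sequences are padded to a common length, labels are padded with -100, and
    # the attention_mask produced by the parent collator is dropped on purpose.
    assert "attention_mask" not in batch
    return batch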
def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
"""
Build input_ids and labels for SFT.
Args:
row: a dataset row with `messages`
tokenizer: a HF tokenizer
mask_prompt_loss: whether to mask prompt tokens (set their labels to -100)
Returns:
dict with keys: input_ids, labels, and optionally prompt_len
"""
prompt_response_tokens = tokenizer.apply_chat_template(
row["messages"], tokenize=True, add_generation_prompt=False
)
labels = prompt_response_tokens.copy()
if mask_prompt_loss:
prompt_tokens = tokenizer.apply_chat_template(
row["messages"][:-1], tokenize=True, add_generation_prompt=True
)
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
return {
"input_ids": prompt_response_tokens,
"labels": labels,
"prompt_len": len(prompt_tokens),
}
return {"input_ids": prompt_response_tokens, "labels": labels}

View File

@ -0,0 +1,53 @@
import torch
from dllm.core.schedulers import BaseAlphaScheduler
def get_num_transfer_tokens(
mask_index: torch.Tensor,
steps: int,
scheduler: BaseAlphaScheduler,
stochastic: bool = False,
) -> torch.Tensor:
mask_num = mask_index.sum(dim=1, keepdim=True)
num_transfer_tokens = torch.zeros(
mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
)
for i in range(mask_num.size(0)):
for t, s, j in zip(range(steps, 0, -1), range(steps - 1, -1, -1), range(steps)):
s /= steps
t /= steps
reverse_transfer_prob = 1 - scheduler.reverse_mask_prob(s=s, t=t)
if not stochastic:
x = mask_num[i, 0].to(torch.float64) * reverse_transfer_prob
num_transfer_tokens[i, j] = torch.round(x).to(torch.int64)
else:
n = mask_num[i, 0].to(torch.float64)
num_transfer_tokens[i, j] = (
torch.distributions.Binomial(n, reverse_transfer_prob)
.sample()
.to(torch.int64)
)
num_transfer_tokens[i, j] = torch.minimum(
num_transfer_tokens[i, j], mask_num[i, 0]
)
mask_num[i, 0] -= num_transfer_tokens[i, j]
if mask_num[i, 0].item() == 0:
break
# Note: because LLaDA is not conditioned on time, we can skip steps with no unmasking (i.e. no transfer).
# Compact each row by removing zero entries, then right-pad with zeros up to the max length across rows.
rows = []
max_len = 0
for i in range(num_transfer_tokens.size(0)):
nonzero = num_transfer_tokens[i][num_transfer_tokens[i] > 0]
rows.append(nonzero)
max_len = max(max_len, nonzero.numel())
# Pad each row to max_len
padded_rows = []
for r in rows:
if r.numel() < max_len:
pad = torch.zeros(max_len - r.numel(), dtype=r.dtype, device=r.device)
r = torch.cat([r, pad])
padded_rows.append(r)
return torch.stack(padded_rows, dim=0)
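# Hedged sketch using a hypothetical linear-alpha scheduler stub (alpha(t) = 1 - t),
# for which the probability of a token staying masked from time t down to s is s / t.
# The stub only mimics the reverse_mask_prob interface used above; the real
# schedulers live in dllm.core.schedulers.
def _demo_get_num_transfer_tokens():
    class _LinearAlphaSchedulerStub:
        def reverse_mask_prob(self, s: float, t: float) -> float:
            return s / t  # fraction of still-masked tokens carried over to time s

    mask_index = torch.tensor([[True] * 10 + [False] * 6])  # 10 masked positions
    schedule = get_num_transfer_tokens(
        mask_index, steps=4, scheduler=_LinearAlphaSchedulerStub()
    )
    # Each row says how many tokens to unmask at each step; every row sums back
    # to the number of initially masked positions (10 here).
    assert schedule.sum().item() == 10
    return schedule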

View File

@ -0,0 +1,180 @@
import torch
import accelerate
import transformers
from peft import prepare_model_for_kbit_training
from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft
from dllm.utils.configs import ModelArguments, TrainingArguments
def get_model(
model_args,
config: transformers.PretrainedConfig | None = None,
) -> transformers.PreTrainedModel:
"""
Load a model with flexible input sources.
Args:
model_args: An optional dataclass or namespace containing model parameters.
model_name_or_path: Optional direct model path or name (overrides model_args.model_name_or_path).
dtype: Dtype (string or torch.dtype).
load_in_4bit: Whether to load using 4-bit quantization (can override model_args.load_in_4bit).
Returns:
transformers.PreTrainedModel
"""
model_name_or_path = getattr(model_args, "model_name_or_path")
dtype = getattr(model_args, "dtype", "bfloat16")
load_in_4bit = getattr(model_args, "load_in_4bit", False)
attn_implementation = getattr(model_args, "attn_implementation", None)
# Device map: skip when ZeRO-3
device_map = (
{"": accelerate.PartialState().local_process_index}
if not transformers.modeling_utils.is_deepspeed_zero3_enabled()
and torch.cuda.is_available()
else None
)
quant_config = None
if load_in_4bit and transformers.utils.is_bitsandbytes_available():
quant_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
params = {
"dtype": dtype,
"device_map": device_map,
"quantization_config": quant_config,
"attn_implementation": attn_implementation,
"config": config,
}
try:
model = transformers.AutoModelForMaskedLM.from_pretrained(
model_name_or_path, **params
)
except Exception:
# Fall back to AutoModel for architectures without a MaskedLM mapping
model = transformers.AutoModel.from_pretrained(model_name_or_path, **params)
# --- if quantized, prepare for LoRA / QLoRA training ---
if load_in_4bit and quant_config is not None:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
# Optionally train with lora
model = load_peft(model, model_args)
return model
def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
"""
Load a tokenizer with flexible input sources.
Args:
model_args: Optional dataclass or namespace containing model parameters.
model: Optional model instance to configure tokenizer behavior.
model_name_or_path: Optional direct model name or path (overrides model_args.model_name_or_path).
Returns:
transformers.PreTrainedTokenizer
"""
# Lazy imports to avoid circular dependencies
from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
from dllm.pipelines.dream.models.modeling_dream import DreamModel
from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
from transformers import (
BertPreTrainedModel,
RobertaPreTrainedModel,
ModernBertPreTrainedModel,
)
model_name_or_path = getattr(model_args, "model_name_or_path")
# ---------------- Tokenizer loading ----------------
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path,
padding_side="right",
)
assert tokenizer.eos_token is not None or tokenizer.pad_token is not None
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
if not tokenizer.eos_token:
tokenizer.eos_token = tokenizer.pad_token
if not tokenizer.bos_token:
tokenizer.bos_token = tokenizer.pad_token
# Resolve the concrete model class from the config to apply model-specific customization
model_cfg = transformers.AutoConfig.from_pretrained(model_name_or_path)
model_cls = transformers.AutoModel._model_mapping[type(model_cfg)]
# ---------------- Model-specific customization ----------------
if issubclass(model_cls, LLaDAModelLM):
tokenizer.add_special_tokens({"mask_token": "<|mdm_mask|>"})
tokenizer.eot_token = "<|eot_id|>"
# tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) # can not do this for llada base directly
# TODO: for llada base, add special_tokens = {"<|start_header_id|>": 126346, "<|end_header_id|>": 126347, "<|eot_id|>": 126348}
# fix bugs in chat template
tokenizer.chat_template = """\
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}
<|start_header_id|>{{ message['role'] }}<|end_header_id|>
{{ message['content'] | trim }}<|eot_id|>
{%- endfor %}
{% if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
<|start_header_id|>assistant<|end_header_id|>
{% endif %}
"""
elif issubclass(model_cls, LLaDAMoEModelLM):
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
tokenizer.eot_token = "<|role_end|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
elif issubclass(model_cls, DreamModel):
tokenizer.eot_token = "<|im_end|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
elif issubclass(model_cls, RND1LM):
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
elif issubclass(
model_cls,
(BertPreTrainedModel, RobertaPreTrainedModel, ModernBertPreTrainedModel),
):
tokenizer.eot_token = "[/Answer]"
tokenizer.chat_template = """\
{% if messages[0]['role'] == 'system' %}
[SYS]
{{ messages[0]['content'] | trim }}
[/SYS]
{% set loop_messages = messages[1:] %}
{% else %}
{% set loop_messages = messages %}
{% endif -%}
{%- for message in loop_messages %}
{% if message['role'] == 'user' %}
[Question]
{{ message['content'] | trim }}
[/Question]
{% elif message['role'] == 'assistant' %}
[Answer]
{{ message['content'] | trim }}
[/Answer]
{% endif %}
{% endfor -%}
{%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
[Answer]
{% endif %}
"""
else:
print_main("no tokenizer customization for model class:", model_cls)
return tokenizer

284
dllm/dllm/utils/utils.py Normal file
View File

@ -0,0 +1,284 @@
import os
import re
import sys
import logging
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
import pprint
import torch
import peft
import accelerate
import transformers
def resolve_with_base_env(path: str, env_name: str) -> str:
"""
If `env_name` is set and `path` is NOT absolute, NOT a URL/scheme,
and does not already exist locally, prepend the `env_name` directory.
If the resulting path does not exist, return the base environment directory instead.
Otherwise return `path` unchanged.
"""
base = os.getenv(env_name, "").strip()
if not base:
return path
if os.path.isabs(path):
return path
if os.path.exists(path):
return path
candidate = os.path.join(base.rstrip("/"), path.lstrip("/"))
if os.path.exists(candidate):
return candidate
else:
raise FileNotFoundError(f"Neither '{path}' nor '{candidate}' exists")
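# Hedged sketch of the resolution order, assuming "org/model" does not exist
# relative to the CWD. The directory names are hypothetical; the demo creates a
# throwaway directory under /tmp.
def _demo_resolve_with_base_env(tmp_base: str = "/tmp/dllm_demo_models"):
    os.makedirs(os.path.join(tmp_base, "org/model"), exist_ok=True)
    os.environ["BASE_MODELS_DIR"] = tmp_base
    # Relative name that only exists under the base dir -> prefixed with it.
    assert resolve_with_base_env("org/model", "BASE_MODELS_DIR") == os.path.join(
        tmp_base, "org/model"
    )
    # Absolute paths are returned unchanged, even if they do not exist.
    assert resolve_with_base_env("/abs/model", "BASE_MODELS_DIR") == "/abs/model"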
@contextmanager
def init_device_context_manager(device: str | torch.device | None = None):
"""
Temporarily set torch default dtype and default device so that tensors
created inside the context are allocated on `device` with dtype `dtype`.
Restores previous settings on exit.
"""
if transformers.integrations.is_deepspeed_zero3_enabled():
yield
return
# Resolve device
if device is None:
try:
from accelerate import PartialState
idx = PartialState().local_process_index
except Exception:
idx = 0
device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
elif isinstance(device, int):
device = f"cuda:{device}"
try:
torch.set_default_device(device)
yield
finally:
torch.set_default_device("cpu")
def print_main(*args, **kwargs):
"""
Print only from the global main process (rank 0 across all nodes).
Usage: print_main("Hello from main process!")
"""
if accelerate.PartialState().is_main_process:
print(*args, **kwargs)
def pprint_main(*args, **kwargs):
"""
Print (with pprint) only from the global main process (rank 0 across all nodes).
Usage: pprint_main({"hello": "from main process"})
"""
if accelerate.PartialState().is_main_process:
pprint.pprint(*args, **kwargs)
def load_peft(
model: transformers.PreTrainedModel, model_args: "ModelArguments"
) -> transformers.PreTrainedModel:
"""
e.g.,
--modules_to_save "lm_head" --target_modules "q_proj,k_proj,v_proj,o_proj,up_proj,down_proj,gate_proj"
--target_modules "all-linear"
"""
if not getattr(model_args, "lora", False):
return model
target_modules = (
model_args.target_modules.split(",") if model_args.target_modules else None
)
# if it's a single 'all-linear', drop the list and use the string directly
if (
target_modules
and len(target_modules) == 1
and target_modules[0].strip() == "all-linear"
):
target_modules = target_modules[0]
modules_to_save = (
model_args.modules_to_save.split(",") if model_args.modules_to_save else None
)
peft_config = peft.LoraConfig(
r=model_args.r,
target_modules=target_modules,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
bias=model_args.bias,
modules_to_save=modules_to_save,
)
model = peft.get_peft_model(model, peft_config)
if accelerate.PartialState().is_main_process:
print(model)
model.print_trainable_parameters()
return model
def print_args_main(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "TrainingArguments",
):
print_main("\n===== Parsed arguments =====")
for name, args in [
("model_args", model_args),
("data_args", data_args),
("training_args", training_args),
]:
d = asdict(args)
# copy all entries; slice list(d) here if you only want to show the first few
short = {k: d[k] for k in list(d)}
print_main(f"{name}:")
pprint_main(short, width=100, compact=True, sort_dicts=False)
print_main("============================\n")
def print_args(args):
print_main("\n===== Parsed arguments =====")
d = asdict(args)
# copy all entries; slice list(d) here if you only want to show the first few
short = {k: d[k] for k in list(d)}
pprint_main(short, width=100, compact=True, sort_dicts=False)
print_main("============================\n")
def disable_caching_allocator_warmup():
try:
from transformers import modeling_utils as _mu
def _noop(*args, **kwargs):
return
_mu.caching_allocator_warmup = _noop
except Exception:
pass
def disable_dataset_progress_bar_except_main():
# state = accelerate.PartialState() # figures out your rank/world automatically
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
if accelerate.PartialState().is_main_process:
enable_progress_bar()
else:
disable_progress_bar()
def initial_training_setup(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "TrainingArguments",
):
transformers.set_seed(training_args.seed)
disable_caching_allocator_warmup()
disable_dataset_progress_bar_except_main()
if getattr(data_args, "disable_caching", False):
disable_dataset_caching()
def disable_dataset_caching():
from datasets import disable_caching
disable_caching()
tmp_root = f"/tmp/hf_cache_rank{accelerate.PartialState().process_index}"
os.environ["HF_DATASETS_CACHE"] = tmp_root
os.environ["HF_DATASETS_TEMP_DIR"] = tmp_root
os.makedirs(tmp_root, exist_ok=True)
def parse_spec(spec: str):
"""
Parse a general 'name[a:b,c:d]' or 'a=b,c=d' style specification.
Supports:
- Bare name, e.g. "foo/bar"
- Optional bracket suffix with comma-separated entries:
key:value or key:int_value (underscores allowed)
- Optional "key=value" pairs outside the bracket.
Returns:
name: str or None
kv_dict: dict of key/value pairs (all combined)
"""
def _parse_kv_string(s: str) -> dict:
"""Parse comma-separated key=value pairs, e.g. 'a=1,b=2'."""
return dict(part.split("=", 1) for part in s.split(",") if "=" in part)
s = spec.strip()
# Extract bracket content if present
m = re.search(r"\[(.*?)\]$", s)
bracket_kvs = {}
numeric_kvs = {}
if m:
bracket = m.group(1).strip()
if bracket:
for part in bracket.split(","):
part = part.strip()
if not part:
continue
if ":" not in part:
raise ValueError(
f"Invalid entry '{part}' in '{spec}' (expected key:value)."
)
key, value = part.split(":", 1)
key = key.strip()
value = value.strip()
# Integers (with optional underscores)
if re.fullmatch(r"\d(?:_?\d)*", value):
numeric_kvs[key] = int(value.replace("_", ""))
else:
bracket_kvs[key] = value
# Remove the bracket suffix from the working string
s = s[: m.start()].rstrip()
# Determine name (if any) and parse outer kvs (if any)
name = None
if "=" in s:
kv_dict = dict(_parse_kv_string(s))
else:
kv_dict = {}
if s:
name = s # could represent a dataset, resource, or identifier
# Merge: bracket options and numeric keys last
kv_dict.update(bracket_kvs)
kv_dict.update(numeric_kvs)
return name, kv_dict
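# Concrete parses of the spec grammar documented above; the dataset names are
# arbitrary examples.
def _demo_parse_spec():
    # Bare name with bracket options; integer values may use underscores.
    assert parse_spec("HuggingFaceTB/smoltalk[name:all,num:10_000]") == (
        "HuggingFaceTB/smoltalk",
        {"name": "all", "num": 10000},
    )
    # Outer key=value pairs instead of a name.
    assert parse_spec("split=train,field=text") == (
        None,
        {"split": "train", "field": "text"},
    )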
def get_default_logger(name):
logger = logging.getLogger(name)
if accelerate.PartialState().is_main_process:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.WARNING)
handler = logging.StreamHandler(sys.stdout) # print to terminal
formatter = logging.Formatter(
fmt=(
"\x1b[38;5;110m[%(asctime)s "
"\x1b[38;5;174m%(levelname)s "
"\x1b[38;5;109m%(name)s"
"/%(lineno)d-%(processName)s\x1b[38;5;110m] "
"\x1b[0m%(message)s"
),
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger

View File

@ -0,0 +1,190 @@
# Generative BERT
[![Hugging Face Checkpoints](https://img.shields.io/badge/Hugging%20Face-Checkpoints-yellow)](https://huggingface.co/collections/dllm-collection/bert-chat)
[![W&B Report](https://img.shields.io/badge/W&B-Report-white?logo=weightsandbiases)](https://api.wandb.ai/links/asap-zzhou/101h5xvg)
This directory provides two key sets of resources:
1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text.
2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERTs finetuned as Chatbots. For a deep dive into experimental results, lessons learned, and more reproduction details, please see our full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
<p align="center" style="margin-top: 15px;">
<img src="/examples/bert/assets/chat.gif" alt="chat" width="70%">
</p>
<p align="center">
<em>
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
</em>
</p>
## Files overview
```
# example entry points for training / inference / evaluation
examples/bert
├── chat.py # Interactive inference example
├── eval.sh # Automatic evaluation script
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
└── sft.py # Supervised finetuning example
```
## Warmup
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text.
You can use any BERT-style model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
### Pretrain
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/bert/pt.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "Trelis/tiny-shakespeare" \
--text_field "Text" \
--insert_eos False \
--max_length 128 \
--num_train_epochs 20 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-large/tiny-shakespeare"
```
To run inference with the model:
```shell
# just press enter (empty prompt) if you want the model to generate text from scratch
python -u examples/bert/chat.py \
--model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
--chat False --remasking "random" --steps 128 --max_new_tokens 128
```
### SFT
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \
examples/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "tatsu-lab/alpaca" \
--max_length 512 \
--num_train_epochs 20 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-large/alpaca"
```
To chat with the model:
```shell
python -u examples/bert/chat.py \
--model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True
```
## BERT Chat
Here we show the exact commands we use to train and interact with the BERT Chat models:
[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0).
For training curves and other details, please see [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
### Training
To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
examples/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-base" \
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
--max_length 1024 \
--num_train_epochs 10 \
--per_device_train_batch_size 48 \
--per_device_eval_batch_size 48 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```
To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
examples/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
--max_length 1024 \
--num_train_epochs 10 \
--per_device_train_batch_size 48 \
--per_device_eval_batch_size 48 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```
### Inference
To chat with the model:
```shell
python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True
```
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
dllm/pipelines/bert/eval.py \
--tasks "mmlu_pro" \
--model "bert" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
```
To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run:
```shell
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0"
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0"
```
### Evaluation results
<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results are taken from the original paper.
> Because the original work does not fully disclose its evaluation techniques or implementation tricks, we reproduce the setup using the best available methods. As a result, our reproduced scores may show a small residual gap relative to the reported numbers. -->
<!-- | [`GPT-2`](https://huggingface.co/openai-community/gpt2)(reported) | 0.460 | | | | | | | | |
| [`GPT-2`](https://huggingface.co/openai-community/gpt2)(evaluated) | 0.438 | 0.020 | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported) | 0.555 | | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(evaluated) | 0.549 | 0.021 | | | | | | | | -->
<!-- <div align="center" style="min-width:1500px;"> -->
| Model | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 |
| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 |
| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(<ins>reported</ins> & evaluated) | 48.6 | <ins>22.0</ins> | <ins>50.5</ins> | <ins>18.3</ins> | <ins>3.1</ins> | <ins>39.2</ins> | 55.0 | 48.2 | <ins>46.6</ins> |
| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(<ins>reported</ins> & evaluated) | 41.2 | <ins>11.3</ins> | <ins>37.2</ins> | 18.2 | 2.1 | <ins>35.0</ins> | 52.0 | 36.9 | 32.2 |
| [`gpt2`](https://huggingface.co/openai-community/gpt2)(<ins>reported</ins> & evaluated) | <ins>46.0</ins> | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 |
| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(<ins>reported</ins> & evaluated) | <ins>55.5</ins> | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 |53.1 | 39.4 | 0.3 |
<p align="left" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0" style="color: #808080; text-decoration: none;">
<code>ModernBERT-base-chat-v0</code>
</a>,
<a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0" style="color: #808080; text-decoration: none;">
<code>ModernBERT-large-chat-v0</code>
</a>,
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" style="color: #808080; text-decoration: none;">
<code>Qwen1.5-0.5B</code>
</a>,
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat" style="color: #808080; text-decoration: none;">
<code>Qwen1.5-0.5B-Chat</code>
</a>,
<a href="https://huggingface.co/openai-community/gpt2" style="color: #808080; text-decoration: none;">
<code>gpt2</code>
</a>, and
<a href="https://huggingface.co/openai-community/gpt2-medium" style="color: #808080; text-decoration: none;">
<code>gpt2-medium</code>
</a>.
<ins>Underlined entries</ins> are results from official reports: <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf" style="color: #808080; text-decoration: none;">GPT-2 paper</a>, <a href="https://qwen.ai/blog?id=qwen1.5" style="color: #808080; text-decoration: none;">Qwen 1.5 blog</a>, and <a href="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct" style="color: #808080; text-decoration: none;">Qwen2-0.5B-Instruct model card</a>. All other results are evaluated using our framework.
</p>

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.1 MiB

View File

@ -0,0 +1,71 @@
"""
Interactive chat / generation script for Bert models.
Examples
--------
# Raw multi-turn generation (default)
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
"""
import sys
from dataclasses import dataclass
import transformers
import dllm
from dllm.pipelines import llada
from dllm.tools.chat import multi_turn_chat, single_turn_generate
@dataclass
class ScriptArguments:
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
seed: int = 42
chat: bool = True
visualize: bool = True
def __post_init__(self):
# same base-path resolution logic as in generate.py
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
block_length: int = 32
temperature: float = 0.0
remasking: str = "low_confidence"
def main():
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
if script_args.chat:
multi_turn_chat(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
else:
print("\nSingle-turn generation (no chat template).")
single_turn_generate(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nInterrupted. Bye!")
sys.exit(0)

View File

@ -0,0 +1,50 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
# ===== Basic Settings =====
model_name_or_path="dllm-collection/ModernBERT-large-chat-v0"
num_gpu=4
while [[ $# -gt 0 ]]; do
case "$1" in
--model_name_or_path)
model_name_or_path="$2"; shift 2 ;;
--num_gpu)
num_gpu="$2"; shift 2 ;;
esac
done
# ===== Common arguments =====
common_args="--model bert --apply_chat_template" # BERT chat models use the chat template by default
# =======================
# BERT Instruct (Chat) Tasks
# =======================
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks hellaswag_gen --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks mmlu_generative --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks mmlu_pro --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks winogrande --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

View File

@ -0,0 +1,73 @@
"""
python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""
from dataclasses import dataclass
import transformers
import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import llada
@dataclass
class ScriptArguments:
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
seed: int = 42
visualize: bool = True
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
block_length: int = 64
temperature: float = 0.0
remasking: str = "low_confidence"
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
tokenizer=tokenizer
)
# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: bert.generate()".center(80))
print("=" * 80)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
for iter, s in enumerate(sequences):
print("\n" + "-" * 80)
print(f"[Case {iter}]")
print("-" * 80)
print(s.strip() if s.strip() else "<empty>")
print("\n" + "=" * 80 + "\n")
if script_args.visualize:
terminal_visualizer.visualize(outputs.histories, rich=True)

127
dllm/examples/bert/pt.py Normal file
View File

@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/bert/pt.py
- 8 GPUs (DDP):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml \
examples/bert/pt.py
Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 8 GPUs (DDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "ddp" \
--script_path "examples/bert/pt.py"
"""
import os
import functools
from dataclasses import dataclass, field
import transformers
import accelerate
import dllm
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "answerdotai/ModernBERT-large"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "Trelis/tiny-shakespeare"
text_field: str = "Text"
max_length: int = 128
streaming: bool = False
drop_tail: bool = True
insert_eos: bool = field(
default=True,
metadata={
"help": "False when adjacent samples from the datasets are semantically coherent."
},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = "models/ModernBERT-large/tiny-shakespeare"
num_train_epochs: int = 20
learning_rate: float = 1e-4
per_device_train_batch_size: int = 64
per_device_eval_batch_size: int = 64
eval_steps: float = 0.1
save_steps: float = 0.1
def train():
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_pt_dataset(
data_args.dataset_args,
streaming=data_args.streaming,
)
dataset = dataset.map(
functools.partial(
dllm.utils.tokenize_and_group,
tokenizer=tokenizer,
text_field=data_args.text_field,
seq_length=data_args.max_length,
insert_eos=data_args.insert_eos,
drop_tail=data_args.drop_tail,
),
batched=True,
remove_columns=dataset["train"].column_names,
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
)
if data_args.streaming:
dataset = dataset.shuffle(seed=training_args.seed)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

127
dllm/examples/bert/sft.py Normal file
View File

@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/bert/sft.py
- 8 GPUs (DDP):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml \
examples/bert/sft.py
Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 1 Node, 8 GPUs (DDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "ddp" \
--script_path "examples/bert/sft.py"
- 2 Nodes, 16 GPUs (DDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "ddp" \
--script_path "examples/bert/sft.py"
"""
import os
from dataclasses import dataclass, field
from functools import partial
import transformers
import accelerate
import dllm
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "answerdotai/ModernBERT-large"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "tatsu-lab/alpaca"
max_length: int = 512
load_preprocessed_data: bool = False
mask_prompt_loss: bool = field(
default=True,
metadata={"help": "Whether to mask the loss on the prompt tokens"},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = "models/ModernBERT-large/alpaca"
group_by_length: bool = True
learning_rate: float = 1e-4
num_train_epochs: int = 20
per_device_train_batch_size: int = 64
per_device_eval_batch_size: int = 64
eval_steps: float = 0.1
save_steps: float = 0.1
def train():
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_sft_dataset(
data_args.dataset_args,
load_preprocessed_data=data_args.load_preprocessed_data,
)
if not data_args.load_preprocessed_data:
map_fn = partial(
dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
dataset = dataset.map(
map_fn,
num_proc=data_args.num_proc,
desc="Mapping dataset to SFT format",
)
# truncate / filter long sequences if needed
dataset = dllm.utils.post_process_dataset(dataset, data_args)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
data_collator=dllm.utils.NoAttentionMaskCollator(
tokenizer,
return_tensors="pt",
padding=True,
label_pad_token_id=tokenizer.pad_token_id, # finetune on padding <eos_token>
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

View File

@ -0,0 +1,187 @@
# Dream
> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream)
Resources and examples for training (finetuning & pretraining) and evaluating the **Dream** diffusion language models.
## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logps`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
>
## Files overview
```
# tools relevant with Dream
dllm/pipelines/dream
├── __init__.py # Package initialization
├── models/
│ ├── configuration_dream.py # Dream model configuration
│ ├── generation_utils.py # Diffusion-based generation logic
│ ├── modeling_dream.py # Core Dream model architecture
│ └── tokenization_dream.py # Tokenizer implementation for Dream
├── generator.py # Inference logic
├── trainer.py # Training logic (pretraining and SFT)
└── utils.py # Auxiliary utilities and helper functions
# example entry points for training / inference / evaluation
examples/dream
├── chat.py # Interactive inference example
├── eval.sh # Automatic evaluation script
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
└── sft.py # Supervised finetuning example
```
<!-- > [!NOTE]
> We slightly modified [`modeling_dream.py`](/dllm/pipelines/dream/models/modeling_dream.py) so that the `model.forward()` supports 2-D attention masks. We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
>
> We fixed bugs in `chat_template` and standardize `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->
## Training
### Finetuning
For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run:
```shell
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/sft.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 2e-5
```
If you are using Slurm and want to train across, for example, 2 nodes (16 GPUs total), run:
```shell
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 2e-5
```
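Under the hood, `examples/dream/sft.py` tokenizes each chat example with the model's chat template and (by default) masks the prompt tokens out of the loss. The following is a condensed sketch of that mapping; the full version, including the branch that also trains on prompt tokens, lives in the script:

```python
# Condensed from sft_map_fn in examples/dream/sft.py.
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
    # Tokens of the prompt only (all messages except the final assistant turn).
    prompt_tokens = tokenizer.apply_chat_template(
        row["messages"][:-1], tokenize=True, add_generation_prompt=True
    )
    # Tokens of the full conversation (prompt + response).
    prompt_response_tokens = tokenizer.apply_chat_template(
        row["messages"], tokenize=True, add_generation_prompt=False
    )
    labels = prompt_response_tokens.copy()
    if mask_prompt_loss:
        # Ignore the loss on prompt positions.
        labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
    return {
        "input_ids": prompt_response_tokens,
        "labels": labels,
        "prompt_len": len(prompt_tokens),
    }
```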
<!-- **Reproducing [Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Base-7B)**. We tried our best to reproduce Dream-v0-Instruct-7B by finetuning Dream-v0-Base-7B using our training pipeline on the public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->
#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)
We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
```shell
# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
--num_proc 64
# train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "data/sft/dream/tulu-3-sft-mixture" \
--load_preprocessed_data True \
--output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
--max_length 2048 \
--truncation "right" \
--group_by_length True \
--num_train_epochs 5 \
--learning_rate 1e-5 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 2 \
--eval_on_start False \
--eval_steps 0.1 \
--save_steps 0.05
```
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
### Pretraining
Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/pt.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "mlfoundations/dclm-baseline-1.0" \
--output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \
--max_length 1024 \
--max_steps 2000 \
--learning_rate 3e-4
```
## Inference
We support batch inference for standard generation and infilling:
<!-- See [`examples/dream/generate.py`](/examples/dream/generate.py) for a full example: -->
```shell
python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
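The same generator API can be used programmatically. Below is a minimal sketch mirroring `examples/dream/generate.py`; the `Args` dataclass is a stand-in for that script's `ScriptArguments`, and the `DreamGeneratorConfig` field values are assumptions copied from the defaults that script overrides:

```python
from dataclasses import dataclass

import transformers

import dllm
from dllm.pipelines import dream
from dllm.tools.chat import decode_trim


@dataclass
class Args:  # minimal stand-in for the script's ScriptArguments
    model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"


args = Args()
transformers.set_seed(42)
model = dllm.utils.get_model(model_args=args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)

# Field names follow the GeneratorConfig overrides in examples/dream/generate.py.
gen_config = dream.DreamGeneratorConfig(
    steps=128, max_new_tokens=128, temperature=0.2, top_p=0.95
)

messages = [[{"role": "user", "content": "Explain diffusion language models in one sentence."}]]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
```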
We also support interactive multi-turn dialogue with visualization:
```shell
python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
dllm/pipelines/dream/eval.py \
--tasks "mmlu_pro" \
--model "dream" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
```
To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run:
```shell
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False
```
### Evaluation results
> Results marked (evaluated) are produced with our evaluation framework, while results marked (reported) come from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.
| | MMLU | BBH | ARC&#8209;C | ARC&#8209;E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip&nbsp;planning |
|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:|
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 |
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | | | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | | 35.5 | 45.8 | | 43.0 | | | |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/Dream-org/Dream-v0-Base-7B" style="color: #808080; text-decoration: none;">
<code>Dream-v0-Base-7B</code>
</a>.
</p>
| | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval |
|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:|
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) (reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 |
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) (evaluated) | | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | | 62.3 |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 2. Evaluation results of
<a href="https://huggingface.co/Dream-org/Dream-v0-Instruct-7B" style="color: #808080; text-decoration: none;">
<code>Dream-v0-Instruct-7B</code>
</a>.
</p>

View File

@ -0,0 +1,75 @@
"""
Interactive chat / generation script for Dream models.
Examples
--------
# Chat mode (multi-turn, chat template)
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
# Raw single-turn generation
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False
"""
import sys
from dataclasses import dataclass
import transformers
import dllm
from dllm.pipelines import dream
from dllm.tools.chat import multi_turn_chat, single_turn_generate
@dataclass
class ScriptArguments:
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
seed: int = 42
chat: bool = True
visualize: bool = True
def __post_init__(self):
# same base-path resolution logic as in generate.py
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(dream.DreamGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
temperature: float = 0.2
top_p: float = 0.95
alg: str = "entropy"
alg_temp: float = 0.0
def main():
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
if script_args.chat:
multi_turn_chat(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
else:
print("\nSingle-turn generation (no chat template).")
single_turn_generate(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nInterrupted. Bye!")
sys.exit(0)

139
dllm/examples/dream/eval.sh Normal file
View File

@ -0,0 +1,139 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
# ===== Input Arguments =====
model_name_or_path="Dream-org/Dream-v0-Instruct-7B"
instruct=True
num_gpu=4
while [[ $# -gt 0 ]]; do
case "$1" in
--model_name_or_path)
model_name_or_path="$2"; shift 2 ;;
--instruct)
instruct="$2"; shift 2 ;;
--num_gpu)
num_gpu="$2"; shift 2 ;;
esac
done
# ===== Conditional Configurations =====
if [ "$instruct" = "True" ]; then
echo ">>> Running in INSTRUCT mode"
common_args="--model dream --apply_chat_template"
else
echo ">>> Running in BASE mode"
common_args="--model dream"
fi
# =======================
# Generation / Instruct Tasks
# =======================
if [ "$instruct" = "True" ]; then
# Instruct Tasks
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mmlu_generative --num_fewshot 4 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mmlu_pro --num_fewshot 4 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gsm8k_cot --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=256,max_length=256,steps=256,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks minerva_math --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.0,top_p=1.0,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks humaneval_instruct_dream --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=768,max_length=768,steps=768,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mbpp_instruct --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1024,max_length=1024,steps=1024,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks ifeval --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1280,max_length=1280,steps=1280,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
else
# Base Generation Tasks
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks humaneval --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gsm8k_cot --num_fewshot 8 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=256,steps=256,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mbpp --num_fewshot 3 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks minerva_math --num_fewshot 4 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks bbh --num_fewshot 3 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
fi
# =======================
# Likelihood Tasks (Base Only)
# =======================
if [ "$instruct" != "True" ]; then
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mmlu --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks arc_easy --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks arc_challenge --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks hellaswag --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks piqa --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks winogrande --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks race --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
fi

View File

@ -0,0 +1,117 @@
"""
python -u examples/dream/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""
from dataclasses import dataclass
import transformers
import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import dream
@dataclass
class ScriptArguments:
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
seed: int = 42
visualize: bool = True
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(dream.DreamGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
temperature: float = 0.2
top_p: float = 0.95
alg: str = "entropy"
alg_temp: float = 0.0
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
tokenizer=tokenizer
)
# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: dream.generate()".center(80))
print("=" * 80)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
for iter, s in enumerate(sequences):
print("\n" + "-" * 80)
print(f"[Case {iter}]")
print("-" * 80)
print(s.strip() if s.strip() else "<empty>")
print("\n" + "=" * 80 + "\n")
if script_args.visualize:
terminal_visualizer.visualize(outputs.histories, rich=True)
# --- Example 2: Batch fill-in-the-blanks ---
print("\n" + "=" * 80)
print("TEST: dream.infilling()".center(80))
print("=" * 80)
masked_messages = [
[
{"role": "user", "content": tokenizer.mask_token * 20},
{
"role": "assistant",
"content": "Sorry, I do not have answer to this question.",
},
],
[
{"role": "user", "content": "Please write an educational python function."},
{
"role": "assistant",
"content": "def hello_" + tokenizer.mask_token * 20 + " return",
},
],
]
inputs = tokenizer.apply_chat_template(
masked_messages,
add_generation_prompt=False,
tokenize=True,
)
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
for iter, (i, s) in enumerate(zip(inputs, sequences)):
print("\n" + "-" * 80)
print(f"[Case {iter}]")
print("-" * 80)
print("[Masked]:\n" + tokenizer.decode(i))
print("\n[Filled]:\n" + (s.strip() if s.strip() else "<empty>"))
print("\n" + "=" * 80 + "\n")
if script_args.visualize:
terminal_visualizer.visualize(outputs.histories, rich=True)

162
dllm/examples/dream/pt.py Normal file
View File

@ -0,0 +1,162 @@
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/dream/pt.py \
--load_in_4bit True --lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/pt.py
Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 24 Nodes, 192 GPUs (FSDP):
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/pt.py"
"""
import os
import functools
from dataclasses import dataclass, field
import torch
import transformers
import accelerate
import dllm
from dllm.pipelines import dream
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
text_field: str = "text"
streaming: bool = True
drop_tail: bool = True
insert_eos: bool = field(
default=True,
metadata={
"help": "False when adjacent samples from the datasets are semantically coherent."
},
)
random_length_ratio: float = field(
default=0.01,
metadata={
"help": (
"The probability of randomly cut sequences during training. "
"See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md."
)
},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = (
"models/Dream-7B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
)
learning_rate: float = 3e-4
max_steps: int = 2_000
per_device_train_batch_size: int = 4
gradient_accumulation_steps: int = 4
eval_steps: float = 0.05
save_steps: float = 0.05
# Dream PT specific args
# Note: Since Dream's pretraining recipe is not public,
# this is only a reference implementation following LLaDA's data processing approach.
loss_weight_type: str = field(
default="cart[geo_p:0.3]",
metadata={
"help": (
"The loss weight type. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
def train():
# ----- Parse & setup --------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# necessary for streaming dataset
if data_args.streaming:
training_args.accelerator_config.dispatch_batches = False
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ---------------------------------------------------------------
# initialize model weights from scratch
config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
with dllm.utils.init_device_context_manager():
model = transformers.AutoModel.from_config(config, dtype=torch.bfloat16)
# ----- Tokenizer -----------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Optional PEFT: LoRA -------------------------------------------------
model = dllm.utils.load_peft(model=model, model_args=model_args)
# ----- Dataset -------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_pt_dataset(
data_args.dataset_args,
streaming=data_args.streaming,
)
dataset = dataset.map(
functools.partial(
dllm.utils.tokenize_and_group,
tokenizer=tokenizer,
text_field=data_args.text_field,
seq_length=data_args.max_length,
insert_eos=data_args.insert_eos,
drop_tail=data_args.drop_tail,
),
batched=True,
remove_columns=dataset["train"].column_names,
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
)
if data_args.streaming:
dataset = dataset.shuffle(seed=training_args.seed)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dream.DreamTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
loss_weight_type=training_args.loss_weight_type,
data_collator=dream.utils.DreamPTCollator(
tokenizer,
return_tensors="pt",
padding=True,
random_length_ratio=data_args.random_length_ratio,
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

192
dllm/examples/dream/sft.py Normal file
View File

@ -0,0 +1,192 @@
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/dream/sft.py \
--load_in_4bit True --lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/sft.py
Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py"
- 2 Nodes, 16 GPUs (FSDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py"
"""
import os
from dataclasses import dataclass, field
from functools import partial
import transformers
import accelerate
import dllm
from dllm.pipelines import dream
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
load_preprocessed_data: bool = False
mask_prompt_loss: bool = field(
default=True,
metadata={"help": "Whether to mask the loss on the prompt tokens"},
)
# Dream SFT specific args
perbatch_cutoff: bool = field(
default=True,
metadata={
"help": (
"Randomly pick a response length from batch and trim other responses. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
resp_cutoff_ratio: float = field(
default=0.0,
metadata={
"help": (
"The probability of randomly cutting sequences during training. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = "models/Dream-7B-SFT"
group_by_length: bool = True
# Dream SFT specific args
loss_weight_type: str = field(
default="cart[geo_p:0.3]",
metadata={
"help": (
"The loss weight type. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
# ------------------------------------------------------------------------------
# SFT mapping function
# ------------------------------------------------------------------------------
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool) -> dict:
"""
Build Dream SFT features from a chat-format row.
Returns:
dict with input_ids, labels, prompt_len
"""
prompt_tokens = tokenizer.apply_chat_template(
row["messages"][:-1], tokenize=True, add_generation_prompt=True
)
prompt_response_tokens = tokenizer.apply_chat_template(
row["messages"], tokenize=True, add_generation_prompt=False
)
labels = prompt_response_tokens.copy()
if mask_prompt_loss:
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
else:
# When training on all tokens, prepend a BOS token (if missing)
# so the model can predict the first token.
if prompt_response_tokens[0] != tokenizer.bos_token_id:
bos = [tokenizer.bos_token_id]
prompt_response_tokens = bos + prompt_response_tokens
prompt_tokens = bos + prompt_tokens
labels = bos + labels
labels[0] = -100 # ignore loss on BOS
return {
"input_ids": prompt_response_tokens,
"labels": labels,
"prompt_len": len(prompt_tokens),
}
def train():
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# necessary when batch contains customized fields
training_args.remove_unused_columns = False
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_sft_dataset(
data_args.dataset_args,
load_preprocessed_data=data_args.load_preprocessed_data,
)
if not data_args.load_preprocessed_data:
map_fn = partial(
sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
dataset = dataset.map(
map_fn,
num_proc=data_args.num_proc,
desc="Mapping dataset to SFT format",
)
# truncate / filter long sequences if needed
dataset = dllm.utils.post_process_dataset(dataset, data_args)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dream.DreamTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
loss_weight_type=training_args.loss_weight_type,
data_collator=dream.utils.DreamSFTCollator(
tokenizer,
return_tensors="pt",
padding=True,
perbatch_cutoff=data_args.perbatch_cutoff,
resp_cutoff_ratio=data_args.resp_cutoff_ratio,
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

View File

@ -0,0 +1,3 @@
Work in progress.
Please see [`examples/editflow/bert/README.md`](/examples/editflow/bert/README.md) for examples of finetuning BERT with EditFlow.

View File

@ -0,0 +1,162 @@
# Edit Flows
> **Reference**
> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)
This directory provides an educational reference for training EditFlow models. It demonstrates how to adapt open-weight DLLMs, such as [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487), to support *insertion* and *deletion* in addition to the standard *substitution* (`mask` -> `tokens`) operation. It also includes examples for training (pretraining and finetuning) EditFlow models from scratch.
> [!NOTE]
> - Examples are available for both LLaDA and Dream, but this README focuses on adapting open-weight LLaDA for edit operations ([`adapt_llada.py`](/examples/editflow/adapt_llada.py)) and reusing its architecture for training from scratch ([`pt_llada.py`](/examples/editflow/pt_llada.py) -> [`sft_llada.py`](/examples/editflow/sft_llada.py)).
> - While `EditFlowCollator` supports custom `x0`, this README uses a fixed-length (128-token) mask sequence as `x0`. The trained model generates text by replacing masks, deleting redundant ones, and inserting tokens as needed. To change the default `x0` distribution (e.g., empty sequences for [OneFlow](https://arxiv.org/abs/2510.03506)-like insertion-only generation), pass `--x0_sampler "empty"`.
## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Adapting LLaDA-8B-Instruct to support insertion and deletion](#adapting-llada-8b-instruct-to-support-insertion-and-deletion)
- [Pretraining & Finetuning from scratch](#pretraining--finetuning-from-scratch)
- [Sampling](#sampling)
- [Acknowledgement](#acknowledgement)
## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`; see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
## Files overview
```
dllm/pipelines/editflow
├── __init__.py # Package initialization
├── models
│ ├── dream
│ │ └── modelling_dream.py # EditFlowDream: architecture based on Dream
│ └── llada
│ └── modelling_llada.py # EditFlowLLaDA: architecture based on LLaDA
├── trainer.py
└── utils.py
# example entry point for training / sampling
examples/editflow
├── adapt_dream.py # Example of adapting Dream for EditFlow directly
├── adapt_llada.py # Example of adapting LLaDA for EditFlow directly
├── generate.py # Generation example
├── pt_dream.py # EditFlowDream pretraining example
├── pt_llada.py # EditFlowLLaDA pretraining example
├── pt.py # Pretraining function
├── README.md # Documentation (you are here)
├── sft_dream.py # EditFlowDream SFT example
├── sft_llada.py # EditFlowLLaDA SFT example
└── sft.py # Supervised finetuning function
```
## Training
### Adapting [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) to support *insertion* and *deletion*
The original LLaDA model generates text by iteratively substituting the given `<mask>` tokens with real tokens.
<p align="center">
<img src="https://github.com/ML-GSAI/LLaDA/blob/main/imgs/example_gradio.gif" alt="LLaDA demo" width="80%">
</p>
<p align="center"><em>Figure: Example Gradio demo for LLaDA.</em></p>
However, LLaDA supports only substitution. This example shows how to adapt it so that, during decoding, the model can not only replace fixed-length masks (e.g., 128 tokens) with real text but also insert new tokens and delete unnecessary masks adaptively:
```shell
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/editflow/adapt_llada.py \
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
--lm_head_key "model.transformer.ff_out" \
--init_editflow_from_src True \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 5e-5
```
If you are using Slurm and want to train across, for example, four nodes (32 GPUs total), run:
```shell
sbatch --nodes=4 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/adapt_llada.py" \
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
--lm_head_key "model.transformer.ff_out" \
--init_editflow_from_src True \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 5e-5
```
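For reference, the adaptation entry point roughly follows the pattern below (adapted from `examples/editflow/dream/adapt.py` in this repository): an EditFlow config is built from the source checkpoint, a fresh EditFlow model is created from that config, and `init_editflow_from_src` copies the backbone weights and clones the LM head. The `EditFlowLLaDAConfig` name and the use of `AutoModelForMaskedLM` to load the source weights are assumptions by analogy with the Dream adapter:

```python
import torch
import transformers

import dllm

src = "GSAI-ML/LLaDA-8B-Instruct"

# Assumed LLaDA analogue of the documented EditFlowDreamConfig.
ef_cfg = dllm.pipelines.editflow.EditFlowLLaDAConfig.from_pretrained(src)
with dllm.utils.init_device_context_manager():
    model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)

# Copy backbone weights and clone the LM head from the source model;
# lm_head_key matches the --lm_head_key flag used above.
src_model = transformers.AutoModelForMaskedLM.from_pretrained(src, dtype=torch.bfloat16)
dllm.pipelines.editflow.utils.init_editflow_from_src(
    model, src_model, lm_head_key="model.transformer.ff_out"
)
del src_model
```

After this initialization, training proceeds with the standard EditFlow SFT loop used by the other adapt/SFT examples.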
After training, you can use the [generate.py](/examples/editflow/generate.py) script to produce a visualized decoding trace showing how the model performs *insertion* and *deletion* beyond regular mask *substitution*. See [Sampling](#sampling) for details.
### Pretraining & Finetuning from scratch
You can also train an EditFlow model from scratch (pretrain → SFT) without adapting an existing DLLM.
Pretrain on a subset of [mlfoundations/dclm-baseline-1.0](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) using 192 GPUs (24x8) and FSDP:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/pt_llada.py" \
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
--dataset_args "mlfoundations/dclm-baseline-1.0" \
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--max_steps 2000 \
--learning_rate 3e-4
```
Finetune on a subset of [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) using 8 GPUs and FSDP for better instruction following:
```shell
# you can also run locally with `accelerate ...`
sbatch --nodes=1 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/sft_llada.py" \
--model_name_or_path "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0/checkpoint-final" \
--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
--output_dir "models/EditFlow-LLaDA-8B-SFT/tulu-3-sft-mixture[train:10000,test:1000]" \
--x0_sampler "masks[length:128]" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 5e-5
```
## Sampling
After training, you can visualize how the model performs mask substitution, insertion, and deletion during generation with [generate.py](/examples/editflow/generate.py). Inserted tokens appear <span style="color:blue; font-weight:bold">blue</span>, tokens substituted from `<mask>` appear <span style="color:black; font-weight:bold">black</span>, and deleted tokens are shown with a strikethrough before they disappear.
```shell
# Generate a long sequence to visualize insertions after 128 <mask> tokens
python examples/editflow/generate.py \
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
--tau 0.02 --mask_length 128 --seed 7070 \
--prompt "write a romantic story" --make_gif
# Generate a short sequence to visualize deletions after 128 <mask> tokens
python examples/editflow/generate.py \
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
--tau 0.02 --mask_length 128 --seed 7070 \
--prompt "write a single-sentence romantic story" --make_gif
```
<p align="center">
<img src="/examples/editflow/assets/deletion.gif" alt="EditFlow deletion demo" width="95%">
</p>
<p align="center"><em>Figure: Deletion & Substitution trace</code></em></p>
<p align="center">
<img src="/examples/editflow/assets/insertion.gif" alt="LLaDA demo" width="95%">
</p>
<p align="center"><em>Figure: Inserction & Substitution trace</em></p>
## Acknowledgement
This Edit Flows implementation is inspired by https://github.com/TheMatrixMaster/edit-flows-demo.

Binary file not shown (2.3 MiB)

Binary file not shown (1.7 MiB)

Binary file not shown (7.4 MiB)

View File

@ -0,0 +1,77 @@
# Edit Flows - BERT
> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)
## Warmup
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text with EditFlow.
You can use any BERT model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
### Pretrain
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
```shell
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/editflow/bert/pt.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "Trelis/tiny-shakespeare" \
--text_field "Text" \
--insert_eos False \
--max_length 128 \
--num_train_epochs 20 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--save_steps 0.1 \
--x0_sampler "masks[length:64]" \
--output_dir "models/EditFlow/ModernBERT-large/tiny-shakespeare"
```
To run inference with the model:
```shell
PYTHONPATH=. python examples/editflow/generate.py \
--model_name_or_path "models/EditFlow/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
--tau 0.01 --mask_length 64 --seed 42 --make_gif
# see `decode_trace.gif`
```
### SFT
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
```shell
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
examples/editflow/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "tatsu-lab/alpaca" \
--max_length 512 \
--num_train_epochs 20 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--save_steps 0.1 \
--x0_sampler "masks[length:64]" \
--output_dir "models/EditFlow/ModernBERT-large/alpaca"
```
To run inference with the model:
```shell
PYTHONPATH=. python examples/editflow/generate.py \
--model_name_or_path "models/EditFlow/ModernBERT-large/alpaca/checkpoint-final" \
--prompt "Could you please write a poem for me?" --tau 0.01 --mask_length 64 --seed 42 --make_gif
# see `decode_trace.gif`
```
<!-- ```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
examples/editflow/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
--max_length 1024 \
--num_train_epochs 10 \
--per_device_train_batch_size 48 \
--per_device_eval_batch_size 48 \
--save_steps 0.1 \
--x0_sampler "masks[length:64]" \
--output_dir "models/EditFlow/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
``` -->

View File

@ -0,0 +1,48 @@
from dataclasses import dataclass
import transformers
import dllm
from examples.editflow import pt as editflow_pt
@dataclass
class ModelArguments(editflow_pt.ModelArguments):
model_name_or_path: str = "answerdotai/ModernBERT-large"
lm_head_key: str = "decoder"
@dataclass
class DataArguments(editflow_pt.DataArguments):
dataset_args: str = "Trelis/tiny-shakespeare"
text_field: str = "Text"
max_length: int = 128
streaming: bool = False
drop_tail: bool = True
insert_eos: bool = False
@dataclass
class TrainingArguments(editflow_pt.TrainingArguments):
output_dir: str = "models/EditFlow/ModernBERT-large/tiny-shakespeare"
num_train_epochs: float = 20
learning_rate: float = 3e-4
per_device_train_batch_size: int = 64
per_device_eval_batch_size: int = 64
eval_steps: float = 0.1
save_steps: float = 0.1
x0_sampler: str = "masks[length:64]"
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
editflow_pt.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
)

View File

@ -0,0 +1,44 @@
from dataclasses import dataclass
import transformers
import dllm
from examples.editflow import sft as editflow_sft
@dataclass
class ModelArguments(editflow_sft.ModelArguments):
model_name_or_path: str = "answerdotai/ModernBERT-large"
lm_head_key: str = "decoder"
@dataclass
class DataArguments(editflow_sft.DataArguments):
dataset_args: str = "tatsu-lab/alpaca"
max_length: int = 512
@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
output_dir: str = "models/EditFlow/ModernBERT-large/alpaca"
num_train_epochs: float = 20
learning_rate: float = 3e-4
per_device_train_batch_size: int = 64
per_device_eval_batch_size: int = 64
eval_steps: float = 0.1
save_steps: float = 0.1
x0_sampler: str = "masks[length:64]"
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
editflow_sft.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
)

View File

@ -0,0 +1,88 @@
"""
Local users
------------
- 1 GPU (LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/editflow/dream/adapt.py \
--lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/editflow/dream/adapt.py
Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/dream/adapt.py"
- 2 Nodes, 16 GPUs (FSDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/dream/adapt.py"
"""
from dataclasses import dataclass
import torch
import transformers
import dllm
from examples.editflow import sft as editflow_sft
@dataclass
class ModelArguments(editflow_sft.ModelArguments):
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
lm_head_key: str = "lm_head"
init_editflow_from_src: bool = True
@dataclass
class DataArguments(editflow_sft.DataArguments):
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
output_dir: str = (
"models/EditFlow-Dream-7B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]"
)
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# Create EditFlow model (bf16 init on CUDA)
ef_cfg = dllm.pipelines.editflow.EditFlowDreamConfig.from_pretrained(
model_args.model_name_or_path
)
with dllm.utils.init_device_context_manager():
model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
if model_args.init_editflow_from_src:
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path, dtype=torch.bfloat16
)
dllm.pipelines.editflow.utils.init_editflow_from_src(
model, src_model, lm_head_key=model_args.lm_head_key
)
del src_model
model = dllm.utils.load_peft(model, model_args)
editflow_sft.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
model=model,
)

View File

@ -0,0 +1,67 @@
"""
Local users
------------
- 1 GPU (LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/editflow/dream/pt.py \
--lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/editflow/dream/pt.py
Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/dream/pt.py"
- 24 Nodes, 192 GPUs (FSDP):
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/dream/pt.py"
"""
from dataclasses import dataclass
import transformers
import dllm
from examples.editflow import pt as editflow_pt
@dataclass
class ModelArguments(editflow_pt.ModelArguments):
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
lm_head_key: str = "lm_head"
@dataclass
class DataArguments(editflow_pt.DataArguments):
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
@dataclass
class TrainingArguments(editflow_pt.TrainingArguments):
output_dir: str = (
"models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
)
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
editflow_pt.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
ef_config_cls=dllm.pipelines.editflow.EditFlowDreamConfig,
)

View File

@ -0,0 +1,66 @@
"""
Local users
------------
- 1 GPU (LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/editflow/dream/sft.py \
--lora True
- 8 GPUs (DeepSpeed ZeRO-2):
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/editflow/dream/sft.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/dream/sft.py"
- 2 Nodes, 16 GPUs (FSDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/dream/sft.py"
"""
from dataclasses import dataclass
import transformers
from examples.editflow import sft as editflow_sft
@dataclass
class ModelArguments(editflow_sft.ModelArguments):
model_name_or_path: str = (
"models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]/checkpoint-final"
)
@dataclass
class DataArguments(editflow_sft.DataArguments):
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
output_dir: str = (
"models/EditFlow-Dream-7B-Instruct-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
)
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
editflow_sft.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
)

View File

@ -0,0 +1,418 @@
"""
Minimal EditFlow τ-leap generator for EditBase-Dream with diffusion-style visualization.
Overview:
- tau_leap_step_minimal returns (x_next, any_edit, step_trace, out_for_next), preserving all intermediates.
- generate_editflow_minimal returns (final_text, trace).
- render_consecutive_trace_gif(trace, tokenizer, ...) draws a GIF where each frame shows
ONLY the current output (like the Gemini diffusion page shows progressive refinement):
* SUB tokens in the current frame are orange
* INS tokens in the current frame are blue
* KEEP tokens are black
* If any deletions happened in the step, the title shows ⌫N (red)
"""
# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --time=03:00:00 python examples/editflow/generate.py --model_name_or_path "models/EditFlow-Dream-Instruct-7B/tulu-3-sft-mixture/checkpoint-final" --tau 0.02 --mask_length 128 --seed 7070 --prompt "write a romantic story" --make_gif
import math
from dataclasses import dataclass
from typing import Annotated
import tyro
import torch
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from dllm.core.schedulers import BaseKappaScheduler, LinearKappaScheduler
# ------------------------------- Small utilities --------------------------------
def _bernoulli_from_rate(rate: torch.Tensor, tau: float) -> torch.Tensor:
"""First-order τ-leap Bernoulli with p ≈ rate * τ (clamped)."""
p = (rate.float() * float(tau)).clamp_(0.0, 1.0 - 1e-6)
return torch.bernoulli(p)
def _sample_from_logits(logits_row: torch.Tensor, temperature: float) -> int:
"""Sample one token id from a 1D logits row with temperature.
temperature <= 0 -> greedy (argmax).
"""
if temperature <= 0.0:
return int(torch.argmax(logits_row).item())
return int(
torch.distributions.Categorical(logits=(logits_row / temperature))
.sample()
.item()
)
@dataclass
class GenCfg:
tau: float = 0.02 # τ step
device: str = "cuda" if torch.cuda.is_available() else "cpu"
seed: int = 1234
edit_prompt: bool = False # allow editing inside prompt region?
temperature: float = 0.7 # token sampling temperature (sub/ins)
verbose: bool = True # whether to show intermediate decoding traces
time_independent: bool = True  # model outputs do not depend on t; lets us reuse the previous forward when nothing changed
# -------------------------------- τ-leap one step --------------------------------
@torch.no_grad()
def tau_leap_step_minimal(
x: torch.Tensor, # [T]
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
prompt_len: int, # number of initial prompt tokens (including BOS)
t: float,
sched: BaseKappaScheduler,
cfg: GenCfg,
prev_out: dict | None = None, # <-- pass prior step's model outputs
reuse_prev: bool = False, # <-- if True, reuse prev_out instead of forward()
) -> tuple[torch.Tensor, bool, dict, dict]:
"""
Single τ-leap step with deletion/substitution conflict resolution
and right-insert policy.
Reuse semantics:
• If cfg.time_independent == True and reuse_prev == True and prev_out is not None,
we reuse `prev_out` tensors instead of calling model() again.
• Otherwise we run a fresh forward().
Viz-only convention:
• Any local annotated as _Ann[*, "viz-only"] is used only for human-visible
tracing / debugging (console logs, GIFs) and does not affect generation.
• Such variables are also prefixed with '_' for quick visual scanning.
Returns:
x_next, any_edit, _step_trace, out_for_next (the freshly used model outputs)
"""
device = x.device
T = x.numel()
# Decide whether to reuse the previous forward results
use_reuse = bool(cfg.time_independent and reuse_prev and (prev_out is not None))
if use_reuse:
out = prev_out
else:
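# Fresh forward on the single (unpadded) sequence: all-ones attention mask,
# with the current time t broadcast as a [1, 1] tensor.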
attn = torch.ones(1, T, dtype=torch.long, device=device)
t_tensor = torch.full((1, 1), float(t), device=device)
out = model(input_ids=x.unsqueeze(0), attention_mask=attn, t=t_tensor)
del_rate_h = out["del_rate_hat"] # [1, T]
sub_rate_h = out["sub_rate_hat"] # [1, T]
ins_rate_h = out["ins_rate_hat"] # [1, T]
sub_logits = out["sub_logits"] # [1, T, V]
ins_logits = out["ins_logits"] # [1, T, V]
# Scale normalized rates to true rates
tt = torch.tensor([[t]], device=device)
w = sched.weight(tt)
del_rate = del_rate_h * w
sub_rate = sub_rate_h * w
ins_rate = ins_rate_h * w
# Clamp prompt_len within current T (robustness)
prompt_len_clamped = int(max(1, min(prompt_len, T)))
if not cfg.edit_prompt:
# Protect the entire prompt span from del/sub
del_rate[:, :prompt_len_clamped] = 0.0
sub_rate[:, :prompt_len_clamped] = 0.0
# Disallow insertions inside the prompt EXCEPT at the last prompt token
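# (insertions attach to the RIGHT of a position, so leaving the last prompt token
# insert-able lets generated tokens appear immediately after the prompt)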
if prompt_len_clamped >= 2:
ins_rate[:, : prompt_len_clamped - 1] = 0.0
# Combined "edit" (delete or substitute) event
comb_rate = (del_rate + sub_rate).squeeze(0) # [T]
comb_fire = _bernoulli_from_rate(comb_rate, cfg.tau).bool() # [T]
# If an edit fires at i, choose deletion with prob λ_del/(λ_del+λ_sub)
p_del = (del_rate.squeeze(0) / (comb_rate + 1e-8)).clamp(0, 1) # [T]
choose_del = (torch.rand_like(p_del) < p_del) & comb_fire # [T]
choose_sub = comb_fire & (~choose_del) # [T]
# Insertions (right of token i)
ins_fire = _bernoulli_from_rate(ins_rate.squeeze(0), cfg.tau).bool() # [T]
# Token draws (algorithmic, not viz-only)
sub_samples: list[int | None] = [
(
_sample_from_logits(sub_logits[0, i], cfg.temperature)
if choose_sub[i]
else None
)
for i in range(T)
]
ins_samples: list[int | None] = [
_sample_from_logits(ins_logits[0, i], cfg.temperature) if ins_fire[i] else None
for i in range(T)
]
# Build new sequence left→right (apply insertions to the RIGHT)
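# e.g. x = [a, b, c] with SUB@0 (-> a'), INS@1 (-> y), DEL@2 rebuilds to [a', b, y]:
# a is substituted, b is kept and y inserted to its right, c is dropped.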
new_ids: list[int] = []
# --- viz-only per-position labels (for trace/GIF) ---
_before_ops: Annotated[list[str], "viz-only"] = (
[]
) # per 'before' position: DEL/SUB/KEEP
_after_ops: Annotated[list[str], "viz-only"] = (
[]
) # per 'after' token aligned to new_ids: INS/SUB/KEEP
for i in range(T):
if choose_del[i]:
_before_ops.append("DEL")
# deletion -> no token appended
elif choose_sub[i]:
_before_ops.append("SUB")
new_tok = sub_samples[i]
new_ids.append(int(new_tok))
_after_ops.append("SUB")
else:
_before_ops.append("KEEP")
new_ids.append(int(x[i].item()))
_after_ops.append("KEEP")
if ins_samples[i] is not None:
new_ids.append(int(ins_samples[i]))
_after_ops.append("INS")
x_next = torch.tensor(new_ids, dtype=torch.long, device=device)
any_edit = bool(comb_fire.any().item() or ins_fire.any().item())
# Provide the exact outputs we used this step for the caller to pass forward
out_for_next = out
# --- (vis) used only for verbose console trace ---
if cfg.verbose and (comb_fire.any() or ins_fire.any()):
def _tok_str(tok_id: int) -> str: # viz-only helper
try:
s = tokenizer.decode([int(tok_id)])
return s if s.strip() else f"<{int(tok_id)}>"
except Exception:
return f"<{int(tok_id)}>"
_ops_strs: Annotated[list[str], "viz-only"] = []
for i in range(T):
if choose_del[i]:
_ops_strs.append(f"DEL@{i}:{_tok_str(int(x[i]))}")
elif choose_sub[i]:
_ops_strs.append(
f"SUB@{i}:{_tok_str(int(x[i]))}->{_tok_str(sub_samples[i])}"
)
if ins_samples[i] is not None:
_ops_strs.append(f"INS@{i}->{i+1}:{_tok_str(ins_samples[i])}")
print("[time]", f"{t:.4f}")
print("[events]", "; ".join(_ops_strs))
print("[decode]\n", tokenizer.decode(new_ids, skip_special_tokens=False))
print()
# --- (vis) step trace payload (returned; used only for visualization downstream) ---
_step_trace: Annotated[dict, "viz-only"] = {
"t": float(t),
"x_before_ids": [int(i) for i in x.tolist()],
"x_after_ids": [int(i) for i in new_ids],
"before_ops": _before_ops, # viz-only labels
"after_ops": _after_ops, # viz-only labels
# below are algorithmic signals copied for visualization/analysis
"choose_del": [bool(v) for v in choose_del.tolist()],
"choose_sub": [bool(v) for v in choose_sub.tolist()],
"ins_fire": [bool(v) for v in ins_fire.tolist()],
"sub_samples": [int(s) if s is not None else None for s in sub_samples],
"ins_samples": [int(s) if s is not None else None for s in ins_samples],
"prompt_len": prompt_len_clamped,
"used_reuse": bool(use_reuse),
}
return x_next, any_edit, _step_trace, out_for_next
# -------------------------------- top-level generate -------------------------------
@torch.no_grad()
def generate_editflow_minimal(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
args,
cfg: GenCfg,
) -> tuple[str, dict]:
"""
Returns:
final_text, trace
Notes on annotations:
• Any local annotated with Annotated[..., "viz-only"] is only used to build
the decode trace for visualization (e.g., GIF rendering) and has no effect
on the actual generation. Such variables are also prefixed with '_' to make
this visually obvious in code.
"""
torch.manual_seed(cfg.seed)
# If prompt is None, start from BOS alone; otherwise build the chat-formatted prompt
bos = getattr(tokenizer, "bos_token_id", None)
if bos is None:
raise ValueError("Tokenizer must have a BOS token for this sampler.")
prompt = args.prompt
if prompt is None:
ids = [bos] # BOS alone
else:
ids = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=True,
add_generation_prompt=True,
)
# ids = tokenizer.encode(prompt, add_special_tokens=False)
# ids = [bos] + enc["input_ids"] # ALWAYS prefix BOS
prompt_len = len(ids)
if args.mask_length:
if getattr(tokenizer, "mask_token_id", None) is None:
raise ValueError(
"Tokenizer must define mask_token_id when --mask_length > 0."
)
ids = ids + [tokenizer.mask_token_id] * args.mask_length
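# The initial canvas is the (chat-formatted) prompt followed by mask_length mask
# tokens; τ-leap edits then substitute / insert / delete until t approaches 1.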
x = torch.tensor(ids, dtype=torch.long, device=model.device)
sched = LinearKappaScheduler()
tau = cfg.tau
steps = math.ceil(1.0 / max(tau, 1e-9))
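# ⌈1/τ⌉ leap steps suffice to sweep t from 0 to 1; the loop below also breaks
# early once t >= 1 - time_epsilon.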
_trace: Annotated[dict, "viz-only: full decode trace for GIF/inspection"] = {
"steps": [],
"init": {
"t": 0.0,
"x_ids": [int(i) for i in x.tolist()],
"prompt_len": int(prompt_len),
},
"end_t": 0.0,
}
# Local-only reuse: if previous iteration had no edits, reuse its forward.
prev_out: dict | None = None
prev_had_edits = True # first iteration must run a forward
t = 0.0
for _ in range(steps):
# We can reuse prev_out only if the model is declared time-independent
# and the previous step had NO edits (sequence unchanged).
reuse_prev = (
cfg.time_independent and not prev_had_edits and (prev_out is not None)
)
x, edited, _step_trace, prev_out = tau_leap_step_minimal(
x=x,
model=model,
tokenizer=tokenizer,
prompt_len=prompt_len,
t=t,
sched=sched,
cfg=cfg,
prev_out=prev_out,
reuse_prev=reuse_prev,
)
_step_trace: Annotated[dict, "viz-only: per-step intermediates for trace"]
_trace["steps"].append(_step_trace)
prev_had_edits = edited
t = min(1.0, t + tau)
if t >= 1.0 - args.time_epsilon:
break
_trace["end_t"] = float(t)
final_text = tokenizer.decode(x.tolist(), skip_special_tokens=False)
print("[final]")
return final_text, _trace
# ---------------------------------------- CLI -------------------------------------
def main():
@dataclass
class ScriptArgs:
# Required (no default)
model_name_or_path: Annotated[str, "Path or hub id for the model"]
time_independent: Annotated[
bool, "Whether the model's outputs are independent of the time step t"
] = True
prompt: Annotated[str | None, "Text prompt. If None, start from BOS alone."] = (
None
)
# Boolean flag: tyro exposes --edit-prompt / --no-edit-prompt automatically for bools
edit_prompt: Annotated[
bool,
"Allow delete/substitute and insertions in the prompt region (BOS+prompt).",
] = False
# Generation-related args
tau: Annotated[float, "τ-leap size"] = 0.01
time_epsilon: Annotated[
float, "Match this with the `time_epsilon` arg used in your EditFlowTrainer"
] = 1e-3
mask_length: Annotated[
int,
"Number of <mask> tokens appended after the prompt.\n"
"EditFlow will iteratively substitute, insert, or delete masks to form the output.",
] = 128
temperature: Annotated[float, "Token sampling temperature; 0 for greedy."] = 0.7
seed: Annotated[int, "Random seed"] = 1234
verbose: Annotated[bool, "Whether to show intermediate decoding traces"] = True
# Visualization
make_gif: Annotated[bool, "Render a decoding trace GIF after generation."] = (
False
)
gif_path: Annotated[
str | None, "Output GIF path (default: decode_trace.gif)"
] = None
frame_ms: Annotated[int, "Per-frame duration in ms"] = 120
args = tyro.cli(ScriptArgs)
cfg = GenCfg(
tau=args.tau,
seed=args.seed,
edit_prompt=args.edit_prompt,
temperature=args.temperature,
verbose=args.verbose,
time_independent=args.time_independent,
)
model = AutoModel.from_pretrained(
args.model_name_or_path,
dtype=torch.bfloat16,
device_map="auto",
).eval()
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
final_text, trace = generate_editflow_minimal(model, tokenizer, args, cfg)
print(final_text)
if args.make_gif:
from examples.editflow.viz import render_consecutive_trace_gif
out = args.gif_path or "decode_trace.gif"
path = render_consecutive_trace_gif(
trace,
tokenizer,
out_path=out,
frame_ms=args.frame_ms,
)
print(f"[gif saved] {path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,88 @@
"""
Local users
------------
- 1 GPU (LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/editflow/llada/adapt.py \
--lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/editflow/llada/adapt.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/llada/adapt.py"
- 2 Nodes, 16 GPUs (FSDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/llada/adapt.py"
"""
from dataclasses import dataclass
import torch
import transformers
import dllm
from examples.editflow import sft as editflow_sft
@dataclass
class ModelArguments(editflow_sft.ModelArguments):
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
lm_head_key: str = "model.transformer.ff_out"
init_editflow_from_src: bool = True
@dataclass
class DataArguments(editflow_sft.DataArguments):
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
output_dir: str = (
"models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]"
)
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# Create EditFlow model (bf16 init on CUDA)
ef_cfg = dllm.pipelines.editflow.EditFlowLLaDAConfig.from_pretrained(
model_args.model_name_or_path
)
with dllm.utils.init_device_context_manager():
model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
if model_args.init_editflow_from_src:
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path, dtype=torch.bfloat16
)
dllm.pipelines.editflow.utils.init_editflow_from_src(
model, src_model, lm_head_key=model_args.lm_head_key
)
del src_model
model = dllm.utils.load_peft(model, model_args)
editflow_sft.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
model=model,
)

View File

@ -0,0 +1,67 @@
"""
Local users
------------
- 1 GPU (LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/editflow/llada/pt.py \
--lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/editflow/llada/pt.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/llada/pt.py"
- 24 Nodes, 192 GPUs (FSDP):
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/editflow/llada/pt.py"
"""
from dataclasses import dataclass
import transformers
import dllm
from examples.editflow import pt as editflow_pt
@dataclass
class ModelArguments(editflow_pt.ModelArguments):
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
lm_head_key: str = "model.transformer.ff_out"
@dataclass
class DataArguments(editflow_pt.DataArguments):
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
@dataclass
class TrainingArguments(editflow_pt.TrainingArguments):
output_dir: str = (
"models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
)
if __name__ == "__main__":
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
editflow_pt.train(
model_args=model_args,
data_args=data_args,
training_args=training_args,
ef_config_cls=dllm.pipelines.editflow.EditFlowLLaDAConfig,
)

Some files were not shown because too many files have changed in this diff